In [None]:
# train_models_dual.ipynb
%load_ext autoreload
%autoreload 2

import torch
import matplotlib.pyplot as plt
import os
# Import your modules
from data import generate_all_datasets
from model import ScoreNet  # your ScoreNet architecture
from trainer import train_diffusion_model
   

# Select the compute device: probe the accelerators in priority order
# (Apple MPS first, then CUDA) and fall back to the CPU if neither is usable.
# `next()` stops at the first backend that reports itself available, so later
# probes are never called once one succeeds.
_accelerator_probes = (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
)
device = torch.device(
    next((backend for backend, is_usable in _accelerator_probes if is_usable()), "cpu")
)
print(f"Using device: {device}")


# Create the directory that the checkpoint_path passed to train_diffusion_model
# below ("checkpoints/<dataset>.pt") lives in.
os.makedirs("checkpoints", exist_ok=True)

# Hyperparameters
seed = 42                       # RNG seed so Restart & Run All reproduces the run
num_samples = 100000            # forwarded to generate_all_datasets
num_epochs = 10000              # forwarded to train_diffusion_model
batch_size = 5000               # forwarded to train_diffusion_model
lr = 1e-3                       # Adam learning rate
num_diffusion_timesteps = 1000  # forwarded to train_diffusion_model

# Seed torch's RNGs before any random data or network weights are created.
# NOTE(review): this only covers torch; if generate_all_datasets draws from
# numpy/`random`, seed those too. Bit-exact GPU determinism may additionally
# need torch.use_deterministic_algorithms(True) — TODO confirm if required.
torch.manual_seed(seed)

# Generate the datasets provided by the `data` module. The result is indexed by
# dataset name (e.g. 'Gaussian Mixture', 'Rectangle'), and each value supports
# `.to(device)`.
all_datasets = generate_all_datasets(num_samples)
print(f"Available datasets: {list(all_datasets.keys())}")

# Choose the dataset to train on (must be one of the names printed above).
dataset_name_x = "Rectangle"

# Fail early, naming the bad key, instead of hitting a bare KeyError below.
if dataset_name_x not in all_datasets:
    raise ValueError(
        f"Dataset '{dataset_name_x}' not found. "
        f"Available datasets: {list(all_datasets.keys())}"
    )

# Move the whole training set onto the training device once, up front.
data_x = all_datasets[dataset_name_x].to(device)

print(f"Dataset X '{dataset_name_x}' shape: {data_x.shape}")


# Initialize the score network on the training device.
score_net1 = ScoreNet().to(device)

# Setup optimizer (`params` kept as a named list so later cells can reuse it).
params = list(score_net1.parameters())
optimizer = torch.optim.Adam(params, lr=lr)

# Checkpoint settings, named instead of inlined in the call below.
# NOTE(review): the save_every unit is presumably epochs — confirm in trainer.
checkpoint_path = f"checkpoints/{dataset_name_x}.pt"
save_every = 100

# Train the model. The returned `losses` is plotted against epochs in the next step.
print(f"Starting training with {dataset_name_x}  dataset")

losses = train_diffusion_model(
    data_x,
    score_net1,
    optimizer,
    num_diffusion_timesteps=num_diffusion_timesteps,
    batch_size=batch_size,
    num_epochs=num_epochs,
    device=device,
    checkpoint_path=checkpoint_path,
    save_every=save_every,
)

# Plot the loss history returned by train_diffusion_model (x-axis labelled in
# epochs) using the explicit Figure/Axes API, save it to a PNG, then display it.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(losses)
ax.set_title(f'Training Loss: {dataset_name_x} ')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_yscale('log')  # a log y-axis usually makes a decaying loss easier to read
ax.grid(True)
fig.savefig(f"dual_training_{dataset_name_x}_loss.png")
plt.show()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: mps
Available datasets: ['Gaussian Mixture', 'Rectangle']
Dataset X 'Rectangle' shape: torch.Size([100000, 2])
Starting training with Rectangle  dataset


Training Progress:   0%|          | 0/10000 [00:00<?, ?it/s]

Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint saved to checkpoints/Rectangle.pt
Checkpoint