In [None]:
import os
import torch
import torch.optim as optim
import torchvision
from tqdm import tqdm
import sys

from ddpm import config as _config
from ddpm.config import cifar10_config
from ddpm.data import get_cifar10_dataloaders
from ddpm.diffusion_model import DiffusionModel

In [None]:
# --- Run configuration -------------------------------------------------------
# Everything a reader might want to tune lives in this cell.
_config.DEBUG = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed PyTorch's global RNG so that weight initialisation and the random
# draws made later in the notebook (e.g. the torch.randint labels used for
# sampling) are repeatable after Restart Kernel -> Run All.
# NOTE(review): some GPU kernels are non-deterministic, so runs on CUDA may
# still differ slightly even with a fixed seed.
SEED = 0
torch.manual_seed(SEED)

cifar10_config.res_net_config.initial_pad = 0
batch_size = cifar10_config.batch_size  # taken from the shared CIFAR-10 config
max_epochs = 100          # number of passes over the training loader
learning_rate = 1e-4      # initial Adam learning rate (decayed by the scheduler)

Using device: cuda


In [None]:
# Output folder for the per-epoch sample grids; a no-op if it already exists.
os.makedirs(name="samples", exist_ok=True)

In [None]:
# Build both CIFAR-10 loaders with the configured batch size, then unpack
# them into the training and held-out loaders.
cifar10_loaders = get_cifar10_dataloaders(batch_size=batch_size)
train_loader, test_loader = cifar10_loaders



In [None]:
# Online network: the one the optimizer updates.
model = DiffusionModel(cifar10_config)
model = model.to(device)

# EMA network: built from the same config, then initialised with the online
# network's weights so both start out identical. It is kept in eval mode
# because it is only used for sampling, never for training.
model_ema = DiffusionModel(cifar10_config).to(device)
initial_weights = model.state_dict()
model_ema.load_state_dict(initial_weights)
model_ema.eval()

# Utility function to update EMA weights
def update_ema(model, ema_model, alpha=0.9999):
    """Blend the current weights of ``model`` into ``ema_model`` in place.

    Every parameter of ``ema_model`` becomes
    ``alpha * ema_param + (1 - alpha) * model_param``. Buffers (non-parameter
    tensors registered on the module) are copied verbatim instead of being
    averaged, so the EMA copy does not keep the values it had when it was
    created.

    Args:
        model: Module providing the freshly optimised weights. Not modified.
        ema_model: Module holding the running average; updated in place.
            Parameters and buffers are paired positionally, so it must have
            the same architecture as ``model``.
        alpha: Decay factor. Values closer to 1 make the average track
            ``model`` more slowly.
    """
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), ema_model.parameters()):
            # In-place update: no new tensor per parameter per step, and the
            # EMA parameter keeps its existing storage.
            p_ema.mul_(alpha).add_(p, alpha=1 - alpha)

        for b, b_ema in zip(model.buffers(), ema_model.buffers()):
            # Buffers are not touched by the optimizer; mirror them so the
            # EMA model sees the same auxiliary state as the trained model.
            b_ema.copy_(b)


In [None]:
# Adam over the online model's parameters, starting at the configured LR.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Cosine schedule spanning the whole run: stepped once per epoch, it anneals
# the learning rate smoothly across max_epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=max_epochs)


In [None]:
# Main training loop. Each epoch: one optimisation pass over train_loader
# (updating the EMA weights after every step), one LR-scheduler step, then a
# 4x4 grid of class-conditional samples from the EMA model is written to disk.
for epoch in range(max_epochs):
    model.train()
    batch_progress = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)

    # Running totals so the epoch summary reflects every batch instead of
    # only the (noisy) last mini-batch.
    epoch_loss_sum = 0.0
    num_batches = 0

    for batch_idx, (images, labels) in enumerate(batch_progress):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass: the model returns the training loss directly
        loss = model(images, labels)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update EMA after each optimizer step
        update_ema(model, model_ema)

        # Read the scalar once and reuse it for both the running mean and
        # the progress bar.
        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        num_batches += 1

        # Progress bar info
        batch_progress.set_postfix(loss=batch_loss)
        sys.stdout.flush()

    # Step the LR scheduler after each epoch
    scheduler.step()

    # Report the mean loss over the epoch. An empty loader yields NaN rather
    # than failing on an undefined `loss`.
    mean_loss = epoch_loss_sum / num_batches if num_batches else float("nan")
    # get_last_lr() is read after scheduler.step(), so this is the learning
    # rate the next epoch will use.
    current_lr = scheduler.get_last_lr()[0]
    tqdm.write(f"Epoch {epoch}, loss={mean_loss:.4f}, LR={current_lr}")

    # ----------------------------------
    # Sample with the EMA model (runs at the end of every epoch)
    # ----------------------------------
    model_ema.eval()
    with torch.no_grad():
        # Draw random class labels in [0, 10) for 16 samples
        labels_for_sampling = torch.randint(low=0, high=10, size=(16,), device=device)
        samples = model_ema.sample(
            shape=(16, 3, 32, 32),
            device=device,
            y=labels_for_sampling
        )

    # Map from [-1, 1] to [0, 1] for saving
    # (assumes the model's outputs are in [-1, 1] - TODO confirm)
    samples = (samples.clamp(-1, 1) + 1) / 2

    # Save image grid for this epoch to "samples/" folder
    img_name = f"samples/generated_samples_epoch_{epoch}.png"
    torchvision.utils.save_image(samples, img_name, nrow=4)

In [None]:
# Collect everything written to the checkpoint: the online weights, the EMA
# weights, the optimizer state and the index of the last epoch that ran.
checkpoint = {
    'model_state': model.state_dict(),
    'model_ema_state': model_ema.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epoch,
}
torch.save(checkpoint, "diffusion_model_checkpoint.pth")

print("Training complete, model and EMA weights saved.")

Training Progress:   0%|          | 0/10 [00:00<?, ?it/s]

[DEBUG] Epoch 0, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:   0%|          | 0/10 [02:27<?, ?it/s]

Epoch 0, loss=0.0601, LR=9.779754323328192e-05


Training Progress:  10%|█         | 1/10 [02:42<24:25, 162.89s/it]

[INFO] Saved samples to generated_samples_epoch_0.png




[DEBUG] Epoch 1, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  10%|█         | 1/10 [05:16<24:25, 162.89s/it]

Epoch 1, loss=0.0448, LR=9.140576474687264e-05


Training Progress:  20%|██        | 2/10 [05:31<22:11, 166.49s/it]

[INFO] Saved samples to generated_samples_epoch_1.png




[DEBUG] Epoch 2, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  20%|██        | 2/10 [08:05<22:11, 166.49s/it]

Epoch 2, loss=0.0414, LR=8.14503363531613e-05


Training Progress:  30%|███       | 3/10 [08:22<19:37, 168.14s/it]

[INFO] Saved samples to generated_samples_epoch_2.png




[DEBUG] Epoch 3, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  30%|███       | 3/10 [10:48<19:37, 168.14s/it]

Epoch 3, loss=0.0424, LR=6.890576474687264e-05


Training Progress:  40%|████      | 4/10 [11:03<16:33, 165.56s/it]

[INFO] Saved samples to generated_samples_epoch_3.png




[DEBUG] Epoch 4, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  40%|████      | 4/10 [13:24<16:33, 165.56s/it]

Epoch 4, loss=0.0426, LR=5.500000000000001e-05


Training Progress:  50%|█████     | 5/10 [13:40<13:31, 162.36s/it]

[INFO] Saved samples to generated_samples_epoch_4.png




[DEBUG] Epoch 5, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  50%|█████     | 5/10 [16:07<13:31, 162.36s/it]

Epoch 5, loss=0.0259, LR=4.109423525312737e-05


Training Progress:  60%|██████    | 6/10 [16:23<10:50, 162.58s/it]

[INFO] Saved samples to generated_samples_epoch_5.png




[DEBUG] Epoch 6, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  60%|██████    | 6/10 [18:48<10:50, 162.58s/it]

Epoch 6, loss=0.0251, LR=2.8549663646838717e-05


Training Progress:  70%|███████   | 7/10 [19:04<08:06, 162.11s/it]

[INFO] Saved samples to generated_samples_epoch_6.png




[DEBUG] Epoch 7, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  70%|███████   | 7/10 [21:29<08:06, 162.11s/it]

Epoch 7, loss=0.0155, LR=1.8594235253127375e-05


Training Progress:  80%|████████  | 8/10 [21:45<05:23, 161.69s/it]

[INFO] Saved samples to generated_samples_epoch_7.png




[DEBUG] Epoch 8, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  80%|████████  | 8/10 [24:12<05:23, 161.69s/it]

Epoch 8, loss=0.0204, LR=1.2202456766718093e-05


Training Progress:  90%|█████████ | 9/10 [24:27<02:41, 161.84s/it]

[INFO] Saved samples to generated_samples_epoch_8.png




[DEBUG] Epoch 9, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  90%|█████████ | 9/10 [26:59<02:41, 161.84s/it]

Epoch 9, loss=0.0089, LR=1e-05


Training Progress: 100%|██████████| 10/10 [27:14<00:00, 163.48s/it]

[INFO] Saved samples to generated_samples_epoch_9.png



