In [2]:
import os
import torch
import torch.optim as optim
import torchvision
from tqdm import tqdm
import sys

from ddpm import config as _config
from ddpm.config import cifar10_config
from ddpm.data import get_cifar10_dataloaders
from ddpm.diffusion_model import DiffusionModel

In [None]:
_config.DEBUG = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

cifar10_config.res_net_config.initial_pad = 0
batch_size = cifar10_config.batch_size
max_epochs = 250  
learning_rate = 1e-4

Using device: cuda


In [4]:
os.makedirs("samples", exist_ok=True)

In [5]:
train_loader, test_loader = get_cifar10_dataloaders(batch_size=batch_size)

In [6]:
model = DiffusionModel(cifar10_config).to(device)

# Create an EMA model (exact copy of the original model)
model_ema = DiffusionModel(cifar10_config).to(device)
model_ema.load_state_dict(model.state_dict())
model_ema.eval()

# Utility function to update EMA weights
def update_ema(model, ema_model, alpha=0.9999):
    """EMA update for each parameter."""
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), ema_model.parameters()):
            p_ema.data = alpha * p_ema.data + (1 - alpha) * p.data


In [7]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# CosineAnnealingLR will decay the LR smoothly over max_epochs
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)


In [8]:
for epoch in range(max_epochs):
    model.train()
    batch_progress = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)

    for batch_idx, (images, labels) in enumerate(batch_progress):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass & loss
        loss = model(images, labels)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Update EMA after each optimizer step
        update_ema(model, model_ema)

        # Progress bar info
        batch_progress.set_postfix(loss=loss.item())
        sys.stdout.flush()

    # Step the LR scheduler after each epoch
    scheduler.step()

    # Print the final loss & current LR
    current_lr = scheduler.get_last_lr()[0]
    tqdm.write(f"Epoch {epoch}, loss={loss.item():.4f}, LR={current_lr}")

    # ----------------------------------
    # Periodically sample with EMA model
    # ----------------------------------
    model_ema.eval()
    with torch.no_grad():
        # Let's pick random class labels for 16 samples
        labels_for_sampling = torch.randint(low=0, high=10, size=(16,), device=device)
        samples = model_ema.sample(
            shape=(16, 3, 32, 32),
            device=device,
            y=labels_for_sampling
        )

    # Convert from [-1,1] to [0,1] if your model outputs are in [-1,1]
    samples = (samples.clamp(-1, 1) + 1) / 2

    # Save image grid for this epoch to "samples/" folder
    img_name = f"samples/generated_samples_epoch_{epoch}.png"
    torchvision.utils.save_image(samples, img_name, nrow=4)
    # print(f"[INFO] Saved samples to {img_name}")

                                                                       

Epoch 0, loss=0.0668, LR=9.999975326009292e-05


                                                                       

Epoch 1, loss=0.0651, LR=9.999901304280685e-05


                                                                       

KeyboardInterrupt: 

In [None]:
torch.save({
    'model_state': model.state_dict(),
    'model_ema_state': model_ema.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epoch
}, "diffusion_model_checkpoint.pth")

print("Training complete, model and EMA weights saved.")

Training complete, model and EMA weights saved.


: 