In [None]:
import os
import torch
import torch.optim as optim
import torchvision
from tqdm import tqdm
import sys
import time
import matplotlib.pyplot as plt

from ddpm import config as _config
from ddpm.config import cifar10_config
from ddpm.data import get_cifar10_dataloaders
from ddpm.diffusion_model import DiffusionModel

In [None]:
_config.DEBUG = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("System info: ")
print("Device:", device)
print("Device count", torch.cuda.device_count())
# Only query per-GPU properties when CUDA is actually available:
# get_device_name(0) / get_device_properties(0) raise on CPU-only machines,
# which would crash this cell even though `device` already fell back to CPU.
if device.type == "cuda":
    print("GPU Device:", torch.cuda.get_device_name(0))
    print("GPU RAM:", f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Mutates the shared CIFAR-10 config in place; this must run before the
# model is constructed from cifar10_config.
cifar10_config.res_net_config.initial_pad = 0
batch_size = cifar10_config.batch_size

max_epochs = 5  # short run; set to 500 for the full training schedule

learning_rate = 1e-4

In [None]:
# Build the CIFAR-10 train/test loaders at the batch size chosen in the config cell.
[train_loader, test_loader] = get_cifar10_dataloaders(batch_size=batch_size)

In [None]:
model = DiffusionModel(cifar10_config).to(device)

# Create an EMA model (exact copy of the original model's weights).
# In this notebook its parameters are only ever overwritten by update_ema and
# are never handed to an optimizer, so they are frozen: autograd then never
# tracks them, even if the EMA model is run outside a no_grad context.
model_ema = DiffusionModel(cifar10_config).to(device)
model_ema.load_state_dict(model.state_dict())
model_ema.requires_grad_(False)
model_ema.eval()

# Utility function to update EMA weights
def update_ema(model, ema_model, alpha=0.9999):
    """Move ``ema_model``'s weights one step toward ``model``'s.

    Every parameter is updated in place as
    ``p_ema <- alpha * p_ema + (1 - alpha) * p``.
    Buffers (e.g. normalization running statistics, if the model has any) are
    not averaged. They are copied from ``model`` so the EMA copy never
    evaluates with stale buffer values.

    Args:
        model: The module being trained (source of the new weights).
        ema_model: A module with the same architecture holding the running
            average. It is modified in place.
        alpha: Decay factor in [0, 1]. Values closer to 1 average over more steps.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1], or if the two modules do
            not have the same number of parameters and buffers.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    params = list(model.parameters())
    ema_params = list(ema_model.parameters())
    buffers = list(model.buffers())
    ema_buffers = list(ema_model.buffers())
    # zip() would silently truncate mismatched modules; fail loudly instead.
    if len(params) != len(ema_params) or len(buffers) != len(ema_buffers):
        raise ValueError("model and ema_model must have matching parameters and buffers")

    with torch.no_grad():
        for p, p_ema in zip(params, ema_params):
            # In-place lerp: p_ema + (1 - alpha) * (p - p_ema). This equals
            # alpha * p_ema + (1 - alpha) * p without allocating a new tensor.
            p_ema.lerp_(p, 1.0 - alpha)
        for b, b_ema in zip(buffers, ema_buffers):
            b_ema.copy_(b)


In [None]:
# Adam over the trainable weights of the online (non-EMA) model.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Half-cosine LR decay. The training loop calls scheduler.step() once per
# epoch, so T_max is measured in epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=max_epochs)

In [None]:
for epoch in range(max_epochs):
    t0 = time.time()
    model.train()
    batch_progress = tqdm(train_loader, desc=f"Epoch {epoch}", leave=True)

    # Running totals on the host so the epoch summary reports the mean loss
    # over all batches, not just the final batch's loss.
    loss_sum = 0.0
    num_batches = 0

    for images, labels in batch_progress:
        images = images.to(device)
        labels = labels.to(device)

        # The model returns a scalar training loss for the batch.
        loss = model(images, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Refresh the EMA weights after every optimizer step.
        update_ema(model, model_ema)

        loss_value = loss.item()
        loss_sum += loss_value
        num_batches += 1
        batch_progress.set_postfix(loss=loss_value)

    # Without this guard an empty loader surfaced as a NameError on `loss`.
    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches; cannot train")

    scheduler.step()

    current_lr = scheduler.get_last_lr()[0]
    epoch_loss = loss_sum / num_batches
    epoch_time = time.time() - t0
    tqdm.write(
        f"Epoch {epoch}, loss={epoch_loss:.4f}, LR={current_lr}, time={epoch_time:.1f}s"
    )

    # The EMA copy is never trained directly; keep it in eval mode.
    model_ema.eval()