Initialization

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from tqdm import tqdm
import numpy as np
from PIL import Image
from torchvision.models import resnet18
import torch.nn.functional as F

import swanlab

# ----------------------------
# 1. Configuration & Data Processing
# ----------------------------
# Every tunable hyperparameter lives here; the whole dict is also sent to
# SwanLab below so each run records the settings it was trained with.
CONFIG = {
    "image_size": 32,
    "batch_size_train": 32,
    "batch_size_eval": 32,
    "num_epochs": 5,  # "training loops"
    "lr": 1e-4,
    "timesteps": 1000,
    "cfg_scale": 7.5,
    "seed": 42,  # fixed seed so Restart & Run All reproduces the same run
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

# Seed PyTorch's RNG before any model initialisation, shuffling, or noise
# sampling takes place in later cells.
torch.manual_seed(CONFIG["seed"])

# Data transforms: Resize → ToTensor → Normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((CONFIG["image_size"], CONFIG["image_size"])),
    transforms.ToTensor(),                      # [0,1]
    transforms.Normalize(mean=[0.5], std=[0.5])  # => [-1,1]
])

# Initialize experiment
swanlab.init(
    project="mnist-diffusion",
    experiment_name="ddpm-mnist-32x32-cfg7.5",
    config=CONFIG  # logs all hyperparameters
)

W1126 21:47:57.360000 24936 site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
  torch.utils._pytree._register_pytree_node(


Output()

Output()

<swanlab.data.run.main.SwanLabRun at 0x21358f66190>

Load Dataset

In [2]:
# Both MNIST splits share the same download location and preprocessing pipeline.
_mnist_common = {"root": "./data", "download": True, "transform": transform}
train_dataset = datasets.MNIST(train=True, **_mnist_common)
test_dataset = datasets.MNIST(train=False, **_mnist_common)

# Only the training split is shuffled; the test split keeps dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=CONFIG["batch_size_train"],
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=CONFIG["batch_size_eval"],
    shuffle=False,
)

Workflow Setup

In [3]:
class ClassConditionedUNet(nn.Module):
    """UNet noise predictor conditioned on a class label, with label dropout.

    While the module is in training mode, each sample's label is replaced by a
    reserved "null" class id with probability ``label_dropout_prob``. In eval
    mode the labels are forwarded untouched, so the caller chooses between a
    real label and ``null_class_id`` explicitly.

    Args:
        num_classes: Number of real classes. The null class gets the id
            ``num_classes``.
        image_size: Passed to ``UNet2DModel`` as ``sample_size``.
        label_dropout_prob: Per-sample probability, in ``[0, 1]``, of swapping
            the label for ``null_class_id`` during training. Defaults to 0.1.

    Raises:
        ValueError: If ``label_dropout_prob`` lies outside ``[0, 1]``.
    """

    def __init__(self, num_classes=10, image_size=32, label_dropout_prob=0.1):
        super().__init__()
        if not 0.0 <= label_dropout_prob <= 1.0:
            raise ValueError("label_dropout_prob must be in [0, 1].")
        self.label_dropout_prob = label_dropout_prob
        # The id right after the last real class is reserved for "no label".
        self.null_class_id = num_classes
        # NOTE(review): with class_embed_type="timestep", diffusers appears to
        # embed the label through its timestep projection rather than a learned
        # lookup table, in which case num_class_embeds would be unused -
        # confirm against the installed diffusers version.
        self.unet = UNet2DModel(
            sample_size=image_size,
            in_channels=1,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(128, 256, 256),
            down_block_types=(
                "DownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D"
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D"
            ),
            class_embed_type="timestep",
            num_class_embeds=num_classes + 1,
        )

    def forward(self, x, t, class_labels=None):
        """Return the wrapped UNet's ``.sample`` output for ``x`` at timesteps ``t``.

        Args:
            x: Input tensor handed to the UNet unchanged.
            t: Timesteps handed to the UNet unchanged.
            class_labels: One label per sample. Required; pass
                ``null_class_id`` for an unconditional sample.

        Returns:
            The ``.sample`` attribute of the UNet output.

        Raises:
            ValueError: If ``class_labels`` is ``None``.
        """
        if class_labels is None:
            # In eval, caller should provide labels (even for unconditional)
            raise ValueError("class_labels must be provided.")

        if self.training:
            # Apply label dropout: replace some labels with null_class_id.
            # masked_fill is out-of-place, so the caller's tensor is untouched.
            drop_mask = torch.rand(class_labels.shape[0], device=x.device) < self.label_dropout_prob
            class_labels = class_labels.masked_fill(drop_mask, self.null_class_id)

        return self.unet(x, t, class_labels=class_labels).sample

# Build the class-conditioned denoiser and move it to the configured device.
model = ClassConditionedUNet(image_size=CONFIG["image_size"], num_classes=10)
model = model.to(CONFIG["device"])

# Noise scheduler (DDPM): linear beta schedule over CONFIG["timesteps"] steps.
noise_scheduler = DDPMScheduler(
    beta_schedule="linear",
    beta_start=1e-4,
    beta_end=2e-2,
    num_train_timesteps=CONFIG["timesteps"],
)

# Optimizer: AdamW over every model parameter at the configured learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=CONFIG["lr"])

Training

In [4]:
def train_epoch(model, dataloader, optimizer, noise_scheduler, device,total_batch):
    """Train ``model`` for one pass over ``dataloader`` on the noise-prediction loss.

    For every batch: draw random timesteps, noise the images with
    ``noise_scheduler.add_noise``, predict the noise with the model, and take
    one optimizer step on the MSE between predicted and true noise. The loss of
    every batch is logged to SwanLab.

    Args:
        model: Module called as ``model(noisy_images, timesteps, class_labels=labels)``.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        noise_scheduler: Object exposing ``config.num_train_timesteps`` and
            ``add_noise(images, noise, timesteps)``.
        device: Device the batches are moved to.
        total_batch: Step value logged as ``"batch"`` for the first batch; it is
            incremented once per batch inside this function. It is a plain int,
            so those increments are NOT visible to the caller - the caller must
            advance its own counter between epochs.

    Returns:
        float: Mean of the per-batch losses, or ``0.0`` if the dataloader
        yielded no batches.
    """
    model.train()
    num_timesteps = noise_scheduler.config.num_train_timesteps
    running_loss = 0.0
    batches_seen = 0
    progress = tqdm(dataloader, desc="Training")

    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        # Sample one random timestep per image in the batch.
        timesteps = torch.randint(
            0, num_timesteps, (images.shape[0],), device=device, dtype=torch.long
        )

        # Add noise
        noise = torch.randn_like(images)
        noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

        # Forward pass
        optimizer.zero_grad()
        noise_pred = model(noisy_images, timesteps, class_labels=labels)

        # MSE loss
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        optimizer.step()

        # Convert to a Python float once; the logger and the progress bar
        # should receive a plain number, not a tensor.
        loss_value = loss.item()
        running_loss += loss_value
        progress.set_postfix({"loss": loss_value})
        swanlab.log({
            "train_loss_batch": loss_value,
            "batch": total_batch
        })
        total_batch += 1
        batches_seen += 1

    # Guard against an empty dataloader instead of dividing by zero.
    return running_loss / batches_seen if batches_seen else 0.0

In [5]:
print("Starting Training...\n")
# Global batch counter used as the step value for per-batch loss logging.
total_batch = 0
for epoch in range(1, CONFIG["num_epochs"] + 1):
    print(f"Epoch {epoch}/{CONFIG['num_epochs']}")
    train_loss = train_epoch(model, train_loader, optimizer, noise_scheduler, CONFIG["device"], total_batch)
    # train_epoch receives the counter by value, so advance it here; otherwise
    # every epoch would log its batches starting from step 0 again.
    total_batch += len(train_loader)
    swanlab.log({
        "train_loss_epoch": train_loss,
        "epoch": epoch
    })
    print(f"Epoch {epoch} finished. Avg Loss: {train_loss:.5f}")

# Persist the trained weights so the message below is true.
torch.save(model.state_dict(), "mnist_diffusion_model.pth")
print("Model saved as 'mnist_diffusion_model.pth'")

Starting Training...

Epoch 1/5


Training: 100%|██████████| 1875/1875 [02:23<00:00, 13.02it/s, loss=0.0263]


Epoch 1 finished. Avg Loss: 0.02817
Epoch 2/5


Training: 100%|██████████| 1875/1875 [02:23<00:00, 13.03it/s, loss=0.0186]


Epoch 2 finished. Avg Loss: 0.01825
Epoch 3/5


Training: 100%|██████████| 1875/1875 [02:21<00:00, 13.24it/s, loss=0.0166]


Epoch 3 finished. Avg Loss: 0.01675
Epoch 4/5


Training: 100%|██████████| 1875/1875 [02:20<00:00, 13.39it/s, loss=0.0157]


Epoch 4 finished. Avg Loss: 0.01622
Epoch 5/5


Training: 100%|██████████| 1875/1875 [02:21<00:00, 13.26it/s, loss=0.0185]

Epoch 5 finished. Avg Loss: 0.01564
Model saved as 'mnist_diffusion_model.pth'



