In [None]:
import os
import math

import tqdm
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Enable the cuDNN autotuner for convolution algorithm selection.
torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# torch.cuda.get_device_name() only accepts CUDA devices, so fall back to a plain
# label on CPU-only machines instead of crashing the cell.
device_name = torch.cuda.get_device_name(device) if device.type == 'cuda' else 'cpu'
print(f"Using device: {device_name}")

In [None]:
# MNIST training split; Normalize((0.5,), (0.5,)) rescales ToTensor's output around zero.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
print(f"Total training images: {len(train_dataset)}")

In [None]:
# MNIST test split. It is not shuffled: evaluation does not need a random order, and a
# fixed order makes any "first batch" accuracy check or preview grid reproducible.
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)
print(f"Total test images: {len(test_dataset)}")

In [None]:
# Training a digit classifier before we train the diffusion model 
class DigitClassifier(nn.Module):
    """Small MLP classifier: flatten -> Linear(784, 128) -> ReLU -> Linear(128, 10)."""

    def __init__(self,):
        super().__init__()
        # The attribute name `fc` and the layer positions are part of the saved
        # state-dict keys (e.g. "fc.1.weight"), so they must stay as they are.
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class logits produced by the `fc` stack for input `x`."""
        logits = self.fc(x)
        return logits

In [None]:
# Classifier instance plus the criterion and Adam optimizer used to train it.
classifier = DigitClassifier().to(device)
classifier_loss = nn.CrossEntropyLoss()
classifier_optimizer = torch.optim.Adam(params=classifier.parameters(), lr=1e-4)

In [None]:
scaler = GradScaler()
epochs = 100  # Enough for MNIST classification, ~10-15 minutes
checkpoint_path_classifier = "checkpoint_classifier.pth"

# Fresh-run defaults; replaced below when a usable checkpoint is found.
start_epoch = 0
losses_per_epoch = []
accuracies_per_epoch = []

if os.path.exists(checkpoint_path_classifier):
    try:
        # map_location lets a checkpoint written on one device (e.g. GPU) be resumed on
        # another (e.g. CPU) instead of failing and silently restarting training.
        checkpoint = torch.load(checkpoint_path_classifier, map_location=device)
        classifier.load_state_dict(checkpoint['classifier_state_dict'])
        classifier_optimizer.load_state_dict(checkpoint['classifier_optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        losses_per_epoch = checkpoint['losses_per_epoch']
        accuracies_per_epoch = checkpoint['accuracies_per_epoch']
        print(f"Resumed classifier training from epoch {start_epoch}")
    except Exception as e:
        # Best effort: an unreadable checkpoint falls back to training from scratch.
        print(f"Error loading checkpoint: {e}. Starting fresh.")
        start_epoch = 0
        losses_per_epoch = []
        accuracies_per_epoch = []

print("Starting classifier training...")
for epoch in range(start_epoch, epochs):
    classifier.train()
    pbar = tqdm.tqdm(train_dataloader)
    total_loss, correct_preds, total_preds = 0, 0, 0
    completed_steps = 0  # batches that finished both the forward and the backward pass

    for step, (x, labels) in enumerate(pbar):
        try:
            x = x.to(device)  # (batch, 1, 28, 28)
            labels = labels.to(device)  # (batch), indices 0-9
            batch_size = x.shape[0]

            with autocast():
                flat_x = x.view(batch_size, -1)  # Flatten to (batch, 784)
                pred_labels = classifier(flat_x)  # Should be (batch, 10)
                loss = classifier_loss(pred_labels, labels)  # Labels are class indices

            classifier_optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(classifier_optimizer)
            scaler.update()

            _, predicted = torch.max(pred_labels, 1)  # Predicted class indices
            correct_preds += (predicted == labels).sum().item()
            total_preds += batch_size
            total_loss += loss.item()
            completed_steps += 1
            accuracy = correct_preds / total_preds * 100
            pbar.set_description(f"Classifier Loss: {loss.item():.4f} | Acc: {accuracy:.2f}%")
        except RuntimeError as e:
            # The rest of the epoch is abandoned (not just this batch), so say so.
            print(f"Runtime error at step {step}: {e}. Stopping this epoch early. Check dimensions or reduce batch size to 32.")
            break
        except ValueError as e:
            # Only report tensors that are guaranteed to be bound here; `pred_labels` may
            # not exist yet (or may belong to the previous batch) when the error fires.
            print(f"Value error at step {step}: {e}. Likely dimension mismatch. input shape: {tuple(x.shape)}, labels shape: {tuple(labels.shape)}")
            break

    if completed_steps > 0:
        # Average over the batches that actually contributed to total_loss; a batch that
        # raised must not be counted in the denominator.
        avg_loss = total_loss / completed_steps
        accuracy = correct_preds / total_preds * 100
        losses_per_epoch.append(avg_loss)
        accuracies_per_epoch.append(accuracy)
        print(f"Epoch {epoch+1}/{epochs} | Average Loss: {avg_loss:.6f} | Accuracy: {accuracy:.2f}%")
    else:
        print(f"Epoch {epoch+1}/{epochs} | No valid batches processed. Check data loading.")
        losses_per_epoch.append(float('nan'))
        accuracies_per_epoch.append(0.0)

    torch.save({
        'epoch': epoch,
        'classifier_state_dict': classifier.state_dict(),
        'classifier_optimizer_state_dict': classifier_optimizer.state_dict(),
        'losses_per_epoch': losses_per_epoch,
        'accuracies_per_epoch': accuracies_per_epoch,
    }, checkpoint_path_classifier)
    print(f"Classifier checkpoint saved at epoch {epoch+1}")

    # Plotting. The x-axis is derived from the history length so the plot stays valid even
    # if the stored history and the epoch counter ever disagree.
    clear_output(wait=True)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    ax_loss.plot(range(1, len(losses_per_epoch) + 1), losses_per_epoch, label='Loss', color='blue')
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")
    ax_loss.set_title("Classifier Loss vs Epoch")
    ax_loss.grid(True)
    ax_acc.plot(range(1, len(accuracies_per_epoch) + 1), accuracies_per_epoch, label='Accuracy (%)', color='green')
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("Accuracy")
    ax_acc.set_title("Classifier Accuracy vs Epoch")
    ax_acc.grid(True)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch

In [None]:
# Test DigitClassifier on one batch from the test loader
classifier.eval()

# Get a batch for testing
dataiter = iter(test_dataloader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)
batch_size = images.shape[0]

with torch.no_grad():
    flat_images = images.view(batch_size, -1)  # Flatten to (batch, 784)
    outputs = classifier(flat_images)  # (batch, 10)
    _, predicted = torch.max(outputs, 1)  # Get predicted class indices

# Calculate accuracy (this is a single-batch estimate, not the full test-set accuracy)
correct = (predicted == labels).sum().item()
accuracy = correct / batch_size * 100
print(f"Test Accuracy on batch: {accuracy:.2f}%")
print(f"Correct predictions: {correct} out of {batch_size}")

# Visualize some predictions. numpy and matplotlib are already imported in the first cell.
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
axes = axes.ravel()
for ax in axes:
    ax.axis('off')
# Never index past the end of the batch, even if it holds fewer than 10 images.
n_show = min(len(axes), batch_size)
for idx in range(n_show):
    ax = axes[idx]
    img = images[idx].cpu().numpy().squeeze()  # (1, 28, 28) -> (28, 28)
    ax.imshow(img, cmap='gray')
    ax.set_title(f"True: {labels[idx].item()}\nPred: {predicted[idx].item()}")
plt.tight_layout()
plt.show()

In [None]:
# Diffusion schedules: linear betas and the cumulative-product terms derived from them.
T = 1000
beta_start, beta_end = 1e-4, 0.02
betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1. - betas
# Everything below is computed from `betas`, so it already lives on `device`.
alphas_cumprod = alphas.cumprod(dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1 - alphas_cumprod).sqrt()

In [None]:
# Forward diffusion function
def forward_diffusion_sample(x_0, t, device=device):
    """Noise a clean batch to timestep(s) `t` in a single step.

    Args:
        x_0: Clean input batch. The coefficients are broadcast with three trailing
            singleton axes, so a 4-D tensor is expected.
        t: 1-D tensor of integer indices into the precomputed schedule tables,
            one per sample.
        device: Device the sampled noise is moved to.

    Returns:
        Tuple `(x_t, noise)`: the noised batch and the Gaussian noise that was mixed in.
    """
    def _per_sample(table):
        # One coefficient per sample, reshaped so it broadcasts over the remaining axes.
        return table[t][:, None, None, None]

    noise = torch.randn_like(x_0).to(device)
    signal_scale = _per_sample(sqrt_alphas_cumprod)
    noise_scale = _per_sample(sqrt_one_minus_alphas_cumprod)
    x_t = signal_scale * x_0 + noise_scale * noise
    return x_t, noise

In [None]:
# Sinusoidal timestep embedding, used by the label-conditioned UNet defined below
class SinusoidalTimeEmbedding(nn.Module):
    """Map a batch of scalar timesteps to fixed sinusoidal feature vectors.

    The first `dim // 2` output columns are sines and the next `dim // 2` are cosines
    of the timestep multiplied by geometrically spaced frequencies. If `dim` is odd,
    the output is zero-padded on the right to exactly `dim` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Return a `(len(t), dim)` embedding for the 1-D tensor of timesteps `t`."""
        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) `half_dim - 1` is zero; clamp the
        # divisor so tiny embedding sizes work instead of raising ZeroDivisionError.
        # For half_dim >= 2 this is identical to dividing by `half_dim - 1`.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)
        angles = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if emb.shape[1] < self.dim:
            # Odd `dim`: pad the missing trailing column(s) with zeros.
            emb = F.pad(emb, (0, self.dim - emb.shape[1]))
        return emb

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions (padding 1) with a skip connection and ReLU activations."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # The skip path needs a 1x1 convolution only when the channel count changes.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return relu(conv2(relu(conv1(x))) + shortcut(x))."""
        skip = self.shortcut(x)
        out = self.conv2(F.relu(self.conv1(x)))
        return F.relu(out + skip)

class UNet(nn.Module):
    """Noise-prediction network conditioned on a timestep and a one-hot class label.

    Three encoder blocks (1 -> 64 -> 128 -> 256 channels) each receive an additive
    per-channel bias projected from the combined time + label embedding. A bottleneck
    block is followed by three decoder stages that concatenate the matching encoder
    output along the channel axis before reducing back to a single channel.
    """

    def __init__(self, time_dim=256, label_dim=10):
        super().__init__()
        # Timestep -> sinusoidal features -> two-layer MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalTimeEmbedding(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim),
        )
        # One-hot label -> same width as the time embedding so the two can be summed.
        self.label_mlp = nn.Sequential(
            nn.Linear(label_dim, time_dim),
            nn.ReLU(),
        )

        # Encoder blocks.
        self.enc1 = ResidualBlock(1, 64)
        self.enc2 = ResidualBlock(64, 128)
        self.enc3 = ResidualBlock(128, 256)

        # Embedding projections, one per encoder stage (output width = stage channels).
        self.time_proj1 = nn.Linear(time_dim, 64)
        self.time_proj2 = nn.Linear(time_dim, 128)
        self.time_proj3 = nn.Linear(time_dim, 256)

        self.bottleneck = ResidualBlock(256, 256)

        # Decoder inputs are doubled by the channel-wise concatenation with the skips.
        self.dec1 = ResidualBlock(512, 128)          # 256 (bottleneck) + 256 (skip)
        self.dec2 = ResidualBlock(256, 64)           # 128 + 128 (skip)
        self.dec3 = nn.Conv2d(128, 1, 3, padding=1)  # 64 + 64 (skip)

    @staticmethod
    def _as_channel_bias(vec):
        """Reshape a (batch, channels) tensor to (batch, channels, 1, 1) for broadcasting."""
        return vec[:, :, None, None]

    def forward(self, x, t, labels):
        """Predict the noise in `x` given timesteps `t` and one-hot `labels`."""
        cond = self.time_mlp(t) + self.label_mlp(labels)  # (batch, time_dim)

        skip1 = self.enc1(x) + self._as_channel_bias(self.time_proj1(cond))
        skip2 = self.enc2(skip1) + self._as_channel_bias(self.time_proj2(cond))
        skip3 = self.enc3(skip2) + self._as_channel_bias(self.time_proj3(cond))

        h = self.bottleneck(skip3)
        h = self.dec1(torch.cat([h, skip3], dim=1))
        h = self.dec2(torch.cat([h, skip2], dim=1))
        return self.dec3(torch.cat([h, skip1], dim=1))

In [None]:
# Diffusion model, its loss and optimizer, and a fresh gradient scaler for mixed precision.
model = UNet().to(device)
mse_loss = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.5e-4)
scaler = GradScaler()

In [None]:
# Using FID as a metric
# NOTE(review): this third-party import sits mid-notebook; consider moving it to the
# top import cell so a missing `torcheval` install is reported before any training runs.
from torcheval.metrics.image import FrechetInceptionDistance

# Single metric object, created on the same device as the model.
fid = FrechetInceptionDistance(device=device)

In [None]:
checkpoint_path_diffusion = 'checkpoint_diffusion.pth'


# Load pre-trained classifier.
# map_location keeps the load working when the checkpoint was written on another device.
checkpoint_classifier = torch.load("checkpoint_classifier.pth", map_location=device)
classifier = DigitClassifier().to(device)
classifier.load_state_dict(checkpoint_classifier['classifier_state_dict'])
classifier.eval()
# eval() only switches layer behaviour; disabling gradients is what actually freezes
# the classifier's weights.
for param in classifier.parameters():
    param.requires_grad_(False)

if os.path.exists(checkpoint_path_diffusion):
    checkpoint = torch.load(checkpoint_path_diffusion, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    losses_per_epoch = checkpoint['losses_per_epoch']
    # Tolerate checkpoints that were written without an FID history.
    fids_per_epoch = checkpoint.get('fids_per_epoch', [])
    print(f"Resumed diffusion training from epoch {start_epoch}")
else:
    start_epoch = 0
    losses_per_epoch = []
    fids_per_epoch = []
    # NOTE(review): nothing in the diffusion cells visible here reads this list; it is
    # kept only so the notebook's global state is unchanged.
    accuracies_per_epoch = []

In [None]:
# Upper bound of the diffusion training loop; this rebinds the `epochs` name that the
# classifier training cell set earlier.
epochs = 200

In [None]:
import warnings
# NOTE(review): this suppresses every warning category for the rest of the kernel
# session, including deprecation and numerical warnings; consider narrowing it with
# the `category=` / `module=` arguments once the noisy source is identified.
warnings.filterwarnings("ignore")

In [None]:
# Halve the learning rate (factor=0.5) when the monitored value has not decreased
# (mode='min') for more than `patience` consecutive scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

In [None]:
def _collect_real_images(n_images):
    """Draw `n_images` training images and their labels in a single pass over the loader.

    Returns:
        Tuple `(images, labels)` of CPU tensors. Images are mapped back from the
        normalized range to [0, 1] so they are comparable with the generated samples.
    """
    image_batches, label_batches = [], []
    n_seen = 0
    # One iterator for all batches: calling next(iter(loader)) per batch would rebuild
    # the iterator (and its worker processes) every time.
    for batch_images, batch_labels in train_dataloader:
        image_batches.append(batch_images)
        label_batches.append(batch_labels)
        n_seen += batch_images.shape[0]
        if n_seen >= n_images:
            break
    images = torch.cat(image_batches, dim=0)[:n_images]
    labels = torch.cat(label_batches, dim=0)[:n_images]
    # Undo Normalize((0.5,), (0.5,)). Clamping the normalized values straight to [0, 1]
    # would zero out every pixel below mid-gray and distort the real-image statistics.
    images = torch.clamp(images * 0.5 + 0.5, 0, 1)
    return images, labels


def _generate_images(labels_one_hot, sample_batch_size):
    """Run the reverse diffusion chain once per row of `labels_one_hot`.

    Args:
        labels_one_hot: (n, 10) float tensor on `device` with the class conditioning.
        sample_batch_size: Number of images denoised together.

    Returns:
        CPU tensor of shape (n, 1, 28, 28) with values in [0, 1].
    """
    n_total = labels_one_hot.shape[0]
    n_batches = math.ceil(n_total / sample_batch_size)
    generated = []
    with torch.no_grad(), autocast():
        for batch_idx, start in enumerate(range(0, n_total, sample_batch_size)):
            batch_labels = labels_one_hot[start:start + sample_batch_size]
            n_batch = batch_labels.shape[0]
            x_gen = torch.randn((n_batch, 1, 28, 28), device=device)
            for t in reversed(range(0, T)):
                if t % 100 == 0:
                    print(f"  Denoising step {t} for batch {batch_idx + 1}/{n_batches}")
                t_batch = torch.full((n_batch,), t, device=device, dtype=torch.long)
                beta_t = betas[t]
                sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t]
                sqrt_recip_alpha_t = 1.0 / torch.sqrt(alphas[t])
                epsilon_theta = model(x_gen, t_batch.float(), batch_labels)
                model_mean = sqrt_recip_alpha_t * (x_gen - beta_t * epsilon_theta / sqrt_one_minus_alphas_cumprod_t)
                # Timesteps are 0-indexed here, so only the final step (t == 0) is
                # noise-free; every earlier step, including t == 1, adds noise.
                if t > 0:
                    x_gen = model_mean + torch.sqrt(beta_t) * torch.randn_like(x_gen)
                else:
                    x_gen = model_mean
            # Map from the model's [-1, 1] range to [0, 1].
            x_gen = torch.clamp(x_gen, -1, 1) * 0.5 + 0.5
            generated.append(x_gen.float().cpu())
    return torch.cat(generated, dim=0)


print("Starting diffusion training...")
for epoch in range(start_epoch, epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    pbar = tqdm.tqdm(train_dataloader)
    total_loss = 0
    step_count = 0

    model.train()
    for step, (x, labels) in enumerate(pbar):
        x = x.to(device)  # (batch, 1, 28, 28)
        labels_one_hot = F.one_hot(labels, num_classes=10).float().to(device)
        batch_size = x.shape[0]
        t = torch.randint(0, T, (batch_size,), device=device).long()

        with autocast():
            x_t, noise = forward_diffusion_sample(x, t)
            if x_t.shape != noise.shape or x_t.shape[1:] != (1, 28, 28):
                raise ValueError(f"Shape mismatch: x_t {x_t.shape}, noise {noise.shape}")

            noise_pred = model(x_t, t.float(), labels_one_hot)
            if noise_pred.shape != noise.shape:
                raise ValueError(f"Shape mismatch: noise_pred {noise_pred.shape}, noise {noise.shape}")

            loss = mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        step_count += 1
        pbar.set_description(f"Loss: {loss.item():.4f}")

    if step_count > 0:
        avg_loss = total_loss / step_count
        losses_per_epoch.append(avg_loss)
        print(f"Average Loss: {avg_loss:.6f}")
    else:
        # Keep avg_loss defined so nothing further down sees a stale or missing value.
        avg_loss = float('nan')
        print(f"Epoch {epoch+1}/{epochs} | No valid batches processed. Check data or reduce batch size.")
        losses_per_epoch.append(avg_loss)

    # Compute FID every 5 epochs or at the end
    if epoch % 5 == 0 or epoch == epochs - 1:
        print("Computing FID score...")
        model.eval()
        fid.reset()
        n_samples = 500
        batch_size_fid = 64

        # Real reference images, and the label distribution to match when sampling.
        real_data, real_labels = _collect_real_images(n_samples)
        label_dist = np.bincount(real_labels.numpy(), minlength=10) / len(real_labels)
        labels_all = np.random.choice(10, n_samples, p=label_dist)
        labels_all = torch.from_numpy(labels_all).to(device)
        labels_one_hot_all = F.one_hot(labels_all, num_classes=10).float()

        x_gen = _generate_images(labels_one_hot_all, batch_size_fid)

        # The Inception-based metric takes RGB input: replicate the gray channel.
        fid.update(real_data.repeat(1, 3, 1, 1).to(device), is_real=True)
        fid.update(x_gen.repeat(1, 3, 1, 1).to(device), is_real=False)

        fid_score = fid.compute().item()
        fids_per_epoch.append(fid_score)
        print(f"FID score at epoch {epoch+1}: {fid_score:.2f}")
        # Debug: Save a sample image
        sample_img = x_gen[0].cpu().numpy().squeeze()
        plt.imsave(f'sample_epoch_{epoch+1}.png', sample_img, cmap='gray')
        print(f"Saved sample image at epoch {epoch+1}")
        model.train()

    # Step the LR scheduler before saving so the checkpointed optimizer state already
    # carries any learning-rate reduction triggered by this epoch. Skipped when the
    # epoch produced no loss.
    scheduler_stepped = False
    if 'scheduler' in locals() and step_count > 0:
        scheduler.step(avg_loss)
        scheduler_stepped = True

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses_per_epoch': losses_per_epoch,
        'fids_per_epoch': fids_per_epoch,
    }, checkpoint_path_diffusion)
    print(f"Diffusion checkpoint saved at epoch {epoch+1}")

    # Plotting
    clear_output(wait=True)
    fig, (ax_loss, ax_fid) = plt.subplots(1, 2, figsize=(12, 4))
    ax_loss.plot(range(1, len(losses_per_epoch) + 1), losses_per_epoch, label='MSE Loss')
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")
    ax_loss.set_title("Loss vs Epoch")
    ax_loss.grid(True)

    # FID is only measured on some epochs, so its x-axis counts evaluations, not epochs.
    ax_fid.plot(range(1, len(fids_per_epoch) + 1), fids_per_epoch, label='FID Score', color='green')
    ax_fid.set_xlabel("FID evaluation #")
    ax_fid.set_ylabel("FID")
    ax_fid.set_title("FID vs Evaluation")
    ax_fid.grid(True)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch

    if scheduler_stepped:
        print(f"Learning rate adjusted to {optimizer.param_groups[0]['lr']:.6f}")