In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

from ddpm.data import CustomMnistDataset
from ddpm.config import CONFIG
from ddpm.unet_utils import Unet
from ddpm.process import ForwardProcess

In [None]:
# Build the MNIST dataset from the CSV found under CONFIG.data_dir and wrap it
# in a loader that yields batches of 128 in a fixed (unshuffled) order.
print("Loading MNIST dataset...")
train_csv = CONFIG.data_dir / "train.csv"
mnist_ds = CustomMnistDataset(str(train_csv))
mnist_dl = DataLoader(dataset=mnist_ds, batch_size=128, shuffle=False)

# Report the number of samples and the number of batches the loader produces.
print(
    f"Dataset size: {len(mnist_ds)} images",
    f"Number of batches: {len(mnist_dl)}",
    sep="\n",
)

In [None]:
# Compute global statistics over the whole dataset: every batch from the
# loader is concatenated along dim 0 and reduced to scalar mean / std / min / max.
print("Calculating mean and std...\n")

# tqdm only wraps the iteration for a progress bar; the batches are collected
# in loader order and joined into a single tensor.
all_data = torch.cat(list(tqdm(mnist_dl, desc="Loading batches")), dim=0)
print(f"Total data shape: {all_data.shape}\n")

# Scalar statistics over every element (std uses torch's default correction).
mean = all_data.mean().item()
std = all_data.std().item()

rule = "=" * 50
print(
    rule,
    "MNIST Statistics (after standardization to [-1, 1]):",
    rule,
    f"Mean: {mean:.6f}",
    f"Std:  {std:.6f}",
    rule,
    sep="\n",
)

# Value range of the data as delivered by the dataset.
data_min = all_data.min().item()
data_max = all_data.max().item()
print(f"\nData min value: {data_min:.4f}")
print(f"Data max value: {data_max:.4f}")

# Drop the reference to the concatenated tensor; it is not needed afterwards.
del all_data

In [None]:
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}\n")

# Fresh UNet with its default (random) initialisation, placed on the device.
model = Unet().to(device)
print("Created untrained UNet with random weights")

num_params = sum(weight.numel() for weight in model.parameters())
print(f"Total parameters: {num_params:,}\n")

fp = ForwardProcess()

# Move each schedule attribute of the forward process onto the same device as
# the model, so later calls can combine them with device-resident tensors.
for schedule_name in (
    "betas",
    "sqrt_betas",
    "alphas",
    "sqrt_alphas",
    "alpha_bars",
    "sqrt_alpha_bars",
    "sqrt_one_minus_alpha_bars",
):
    setattr(fp, schedule_name, getattr(fp, schedule_name).to(device))

# Mean-squared-error loss, used to compare predicted noise with sampled noise.
criterion = torch.nn.MSELoss()

In [None]:
# Sanity check: loss of the untrained model on the first batch of the loader.
imgs = next(iter(mnist_dl)).to(device)
print(f"Batch shape: {imgs.shape}", f"Batch size: {len(imgs)}\n", sep="\n")

# Gaussian noise with the same shape/device as the batch, plus one random
# timestep per image in [0, CONFIG.num_timesteps).
noise = torch.randn_like(imgs)
t = torch.randint(low=0, high=CONFIG.num_timesteps, size=(len(imgs),)).to(device)
# presumably produces the noised sample x_t for each (image, t) pair — see ddpm.process
noisy_imgs = fp.add_noise(imgs, noise, t)

# Evaluation mode and no autograd tracking: this is a measurement only.
model.eval()
with torch.no_grad():
    noise_pred = model(noisy_imgs, t)
    loss = criterion(noise_pred, noise)

rule = "=" * 50
print(
    rule,
    "Initial Loss (Untrained UNet):",
    rule,
    f"MSE Loss: {loss.item():.6f}",
    rule,
    sep="\n",
)

In [None]:
# Prepare a separate model, optimizer and shuffled loader for the short
# training run that monitors gradient flow.
rule = "=" * 50
print(f"\n{rule}", "Setting up training for gradient flow check", f"{rule}\n", sep="\n")

# Same dataset as before, but visited in a random order each epoch.
mnist_dl_train = DataLoader(dataset=mnist_ds, batch_size=128, shuffle=True)

# A second, freshly initialised UNet so the earlier `model` stays untouched.
model_train = Unet().to(device)
model_train.train()
optimizer = torch.optim.Adam(params=model_train.parameters(), lr=1e-4)

In [None]:
import matplotlib.pyplot as plt

# Tunables for the gradient-flow check. NUM_EPOCHS drives the training loop,
# the progress-bar label and the plot x-axis, so the three cannot drift apart.
NUM_EPOCHS = 3
# Gradient entries whose magnitude is below this value are counted as "zero".
ZERO_GRAD_THRESHOLD = 1e-7


def collect_grad_stats(model, zero_threshold=ZERO_GRAD_THRESHOLD):
    """Summarise the gradients currently stored on ``model``'s parameters.

    Intended to be called after ``backward()`` and before the gradients are
    cleared again.

    Args:
        model: Module whose ``parameters()`` are inspected.
        zero_threshold: Magnitude below which a gradient entry counts as zero.

    Returns:
        A dict with the keys ``mean_abs_grad``, ``max_abs_grad``,
        ``grad_norm`` and ``zero_grad_percentage`` (plain Python numbers), or
        ``None`` when no parameter carries a gradient.
    """
    grads = [p.grad.flatten() for p in model.parameters() if p.grad is not None]
    if not grads:
        return None

    flat = torch.cat(grads)
    abs_flat = flat.abs()  # computed once, reused by three of the statistics
    return {
        "mean_abs_grad": abs_flat.mean().item(),
        "max_abs_grad": abs_flat.max().item(),
        "grad_norm": torch.norm(flat).item(),
        "zero_grad_percentage": (
            torch.sum(abs_flat < zero_threshold).item() / flat.numel()
        ) * 100,
    }


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN for an empty list (avoids a RuntimeWarning)."""
    return np.mean(values) if len(values) > 0 else float("nan")


print(f"Training for {NUM_EPOCHS} epochs to monitor gradients\n")

# Per-epoch averages, one entry per epoch.
all_epoch_losses = []
all_mean_abs_grads = []
all_max_abs_grads = []  # NOTE: mean of the per-step maxima, not the epoch maximum
all_grad_norms = []
all_zero_grad_pcts = []

for epoch in range(NUM_EPOCHS):
    epoch_losses = []
    epoch_grad_stats = {
        "mean_abs_grad": [],
        "max_abs_grad": [],
        "grad_norm": [],
        "zero_grad_percentage": []
    }

    for imgs in tqdm(mnist_dl_train, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        imgs = imgs.to(device)
        # randn_like already allocates on the batch's device.
        noise = torch.randn_like(imgs)
        # Sample the timesteps directly on the device instead of copying them over.
        t = torch.randint(0, CONFIG.num_timesteps, (imgs.shape[0],), device=device)
        noisy_imgs = fp.add_noise(imgs, noise, t)

        optimizer.zero_grad()
        noise_pred = model_train(noisy_imgs, t)
        loss = criterion(noise_pred, noise)
        loss.backward()

        # Inspect the gradients after backward() but before the optimizer step.
        step_stats = collect_grad_stats(model_train)
        if step_stats is not None:
            for stat_name, stat_value in step_stats.items():
                epoch_grad_stats[stat_name].append(stat_value)

        optimizer.step()
        epoch_losses.append(loss.item())

    # Store epoch averages
    epoch_loss = _mean_or_nan(epoch_losses)
    all_epoch_losses.append(epoch_loss)
    all_mean_abs_grads.append(_mean_or_nan(epoch_grad_stats["mean_abs_grad"]))
    all_max_abs_grads.append(_mean_or_nan(epoch_grad_stats["max_abs_grad"]))
    all_grad_norms.append(_mean_or_nan(epoch_grad_stats["grad_norm"]))
    all_zero_grad_pcts.append(_mean_or_nan(epoch_grad_stats["zero_grad_percentage"]))

    print(f"\nEpoch {epoch+1}: Loss={epoch_loss:.6f}")

# Plot gradient statistics
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

epochs = range(1, NUM_EPOCHS + 1)

# One entry per panel: (axis, values, line format, y-label, title, scientific y-axis?)
panel_specs = [
    (axes[0, 0], all_epoch_losses, 'b-o', 'Loss', 'Training Loss', False),
    (axes[0, 1], all_mean_abs_grads, 'g-o', 'Mean |grad|', 'Mean Absolute Gradient', True),
    (axes[1, 0], all_grad_norms, 'r-o', 'Gradient Norm', 'Gradient Norm', True),
    (axes[1, 1], all_zero_grad_pcts, 'm-o', 'Zero Grad %', 'Zero Gradient Percentage', False),
]

for ax, values, line_fmt, y_label, title, use_sci in panel_specs:
    ax.plot(epochs, values, line_fmt, linewidth=2, markersize=8)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    if use_sci:
        # Gradient magnitudes are tiny; scientific notation keeps ticks readable.
        ax.ticklabel_format(style='scientific', axis='y', scilimits=(0,0))

plt.tight_layout()
plt.show()

print(f"\n{'='*50}")
print(f"Training Summary")
print(f"{'='*50}")
print(f"Final Loss: {all_epoch_losses[-1]:.6f}")
print(f"Final Mean |grad|: {all_mean_abs_grads[-1]:.6e}")
print(f"Final Grad Norm: {all_grad_norms[-1]:.6e}")
print(f"{'='*50}")

In [None]:
import matplotlib.pyplot as plt

rule = "=" * 50
print(f"\n{rule}", "Forward Process Visualization", f"{rule}\n", sep="\n")

# First dataset sample with a leading batch dimension added, on the device.
sample_img = mnist_ds[0].unsqueeze(0).to(device)

# Timesteps to visualise; together with the original image they fill the 2x5 grid.
timesteps = [1, 2, 5, 10, 20, 50, 100, 500, 999]

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))
axes = axes.flatten()


def _show_panel(ax, img, title):
    """Draw element [0, 0] of ``img`` as a grayscale image clipped to [-1, 1]."""
    ax.imshow(img[0, 0].cpu().numpy(), cmap="gray", vmin=-1, vmax=1)
    ax.set_title(title)
    ax.axis("off")


# Panel 0 is the clean image; the remaining panels show increasing timesteps.
_show_panel(axes[0], sample_img, "Original (t=0)")

for panel_ax, t_val in zip(axes[1:], timesteps):
    # Fresh noise for every panel, so the panels are independent draws.
    noise = torch.randn_like(sample_img)
    t_tensor = torch.tensor([t_val], device=device)
    # presumably returns the noised version of sample_img at timestep t_val — see ddpm.process
    noisy_img = fp.add_noise(sample_img, noise, t_tensor)
    _show_panel(panel_ax, noisy_img, f"t={t_val}")

plt.tight_layout()
plt.show()