In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

from ddpm.ddpm_mnist.data import CustomMnistDataset
from ddpm.ddpm_mnist.config import CONFIG
from ddpm.unet_utils import Unet
from ddpm.process import ForwardProcess

In [None]:
# --- Load the MNIST training split ---------------------------------------
# The dataset is read from `train.csv` under CONFIG.data_dir and wrapped in a
# DataLoader that yields fixed-order (unshuffled) mini-batches.
print("Loading MNIST dataset...")

BATCH_SIZE = 128  # images per mini-batch served by `mnist_dl`

train_csv = CONFIG.data_dir / "train.csv"
mnist_ds = CustomMnistDataset(str(train_csv))
mnist_dl = DataLoader(dataset=mnist_ds, batch_size=BATCH_SIZE, shuffle=False)

# Report how much data there is and how many batches one pass produces.
print(
    f"Dataset size: {len(mnist_ds)} images",
    f"Number of batches: {len(mnist_dl)}",
    sep="\n",
)

In [None]:
# --- Preview the first 16 dataset items in a 4x4 grid ---------------------
banner = "=" * 50
print("\n" + banner)
print("Sample Images from Dataset")
print(banner + "\n")

fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
axes = axes.ravel()  # 1-D view so panels can be walked in reading order

for idx, ax in enumerate(axes):
    img = mnist_ds[idx]
    # Rescale [-1, 1] -> [0, 255] and cast to uint8 so imshow gets plain
    # 8-bit grayscale values.
    img_display = img.squeeze().add(1).div(2).mul(255).numpy().astype(np.uint8)

    ax.imshow(img_display, cmap='gray', vmin=0, vmax=255)
    ax.set_title(f"Sample {idx+1}", fontsize=10)
    ax.axis('off')

plt.suptitle("MNIST Dataset Samples", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Images are preprocessed to range [-1, 1], displayed as [0, 255]")

In [None]:
# --- Visualize the forward (noising) process on one image ------------------
print("\n" + "="*50)
print("Forward Process Visualization")
print("="*50 + "\n")

# This cell runs before the model-setup cell further down, which is where
# `device` and `fp` were originally first created. Build them here as well so
# a fresh kernel (Restart & Run All) does not fail with NameError. The setup
# cell re-creates both in the same way, so re-assignment there is harmless.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fp = ForwardProcess()
for _attr in (
    "betas",
    "sqrt_betas",
    "alphas",
    "sqrt_alphas",
    "alpha_bars",
    "sqrt_alpha_bars",
    "sqrt_one_minus_alpha_bars",
):
    # Same attribute-by-attribute device move the setup cell performs.
    setattr(fp, _attr, getattr(fp, _attr).to(device))


def _to_uint8_image(img):
    """Convert a single-channel image tensor in [-1, 1] to a uint8 array.

    Values are clamped to [-1, 1] first: once Gaussian noise is mixed in,
    pixels can fall outside that range, and casting an out-of-range float to
    uint8 wraps around (or is undefined), which shows up as speckle artifacts
    instead of saturated black/white.
    """
    return ((img.cpu().clamp(-1.0, 1.0) + 1) / 2 * 255).numpy().astype(np.uint8)


sample_img = mnist_ds[0].unsqueeze(0).to(device)

# Timesteps to visualize: clean image (0), then 1, 2, 5, 10, 25, 100, 500, 999
timesteps = [0, 1, 2, 5, 10, 25, 100, 500, 999]

fig, axes = plt.subplots(3, 3, figsize=(12, 12))
axes = axes.flatten()

for ax, t_val in zip(axes, timesteps):
    if t_val == 0:
        # Panel 0 shows the untouched input for reference.
        shown = sample_img
        ax.set_title("Clean Image (t=0)", fontsize=12, fontweight='bold')
    else:
        # Noise the clean image directly to timestep `t_val`; each panel
        # draws its own independent noise sample.
        noise = torch.randn_like(sample_img)
        t_tensor = torch.tensor([t_val], device=device)
        shown = fp.add_noise(sample_img, noise, t_tensor)
        ax.set_title(f"After t={t_val} steps", fontsize=12)

    # First batch item, first channel -> 2-D grayscale image.
    ax.imshow(_to_uint8_image(shown[0, 0]), cmap="gray", vmin=0, vmax=255)
    ax.axis("off")

plt.suptitle("Forward Diffusion Process: Progressive Noise Addition", fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"Visualization shows how noise is gradually added over {CONFIG.num_timesteps} timesteps")

In [None]:
# --- Global pixel statistics over the whole dataset -----------------------
print("Calculating mean and std...\n")

# Pull every batch out of the loader (with a progress bar) and stack them
# along the batch dimension into one tensor.
pixels = torch.cat(list(tqdm(mnist_dl, desc="Loading batches")), dim=0)
print(f"Total data shape: {pixels.shape}\n")

# Scalar statistics taken over all elements of the stacked tensor.
mean = pixels.mean().item()
std = pixels.std().item()

rule = "=" * 50
print(
    rule,
    "MNIST Statistics (after standardization to [-1, 1]):",
    rule,
    f"Mean: {mean:.6f}",
    f"Std:  {std:.6f}",
    rule,
    sep="\n",
)

print(f"\nData min value: {pixels.min().item():.4f}")
print(f"Data max value: {pixels.max().item():.4f}")

# Release the full-dataset tensor; only the scalar stats are kept.
del pixels

In [None]:
# --- Device, model, forward process and loss ------------------------------
# Prefer CUDA when PyTorch reports it as available, otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}\n")

# Freshly constructed network, moved to the chosen device.
model = Unet().to(device)
print("Created untrained UNet with random weights")

num_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {num_params:,}\n")

fp = ForwardProcess()

# Move each of the forward-process attributes below onto `device`, one by
# one and in this order, re-binding the attribute to the moved object.
for schedule_attr in (
    "betas",
    "sqrt_betas",
    "alphas",
    "sqrt_alphas",
    "alpha_bars",
    "sqrt_alpha_bars",
    "sqrt_one_minus_alpha_bars",
):
    setattr(fp, schedule_attr, getattr(fp, schedule_attr).to(device))

# Mean-squared error between predicted and true noise.
criterion = torch.nn.MSELoss()

In [None]:
# --- Loss of the untrained model on a single batch ------------------------
# Take the first batch from the loader and place it on the working device.
imgs = next(iter(mnist_dl)).to(device)
n_imgs = imgs.shape[0]
print(f"Batch shape: {imgs.shape}", f"Batch size: {n_imgs}\n", sep="\n")

# `randn_like` already allocates on the same device as `imgs`.
noise = torch.randn_like(imgs)
# One random timestep index per image, drawn from [0, CONFIG.num_timesteps).
t = torch.randint(low=0, high=CONFIG.num_timesteps, size=(n_imgs,)).to(device)
noisy_imgs = fp.add_noise(imgs, noise, t)

# Evaluation mode and no gradient tracking: this is a measurement only.
model.eval()
with torch.no_grad():
    noise_pred = model(noisy_imgs, t)
    loss = criterion(noise_pred, noise)

rule = "=" * 50
print(
    rule,
    "Initial Loss (Untrained UNet):",
    rule,
    f"MSE Loss: {loss.item():.6f}",
    rule,
    sep="\n",
)