In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Configurations
batch_size = 1024  # Number of images per training batch (passed to the DataLoader)
image_size = 28  # Side length, in pixels, of the square images the sampler generates
timesteps = 1000  # Total diffusion steps
seed = 0  # Fixed RNG seed so that Restart & Run All repeats the same random draws

# Seed PyTorch's global generator before any random op runs (weight init,
# DataLoader shuffling, timestep/noise sampling and generation all draw from it).
torch.manual_seed(seed)

# Helper function to compute the linear noise schedule
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced noise (beta) schedule.

    Args:
        timesteps: Number of diffusion steps, i.e. the length of the schedule.
        beta_start: Value of the first entry of the schedule.
        beta_end: Value of the last entry of the schedule.

    Returns:
        A 1-D tensor of length ``timesteps`` whose values go from
        ``beta_start`` to ``beta_end`` in equal increments.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

# Precompute the schedule tensors shared by training and sampling
betas = linear_beta_schedule(timesteps)
# alphas[t] = 1 - betas[t]
alphas = 1.0 - betas
# alphas_cumprod[t] is the running product alphas[0] * ... * alphas[t]
alphas_cumprod = alphas.cumprod(dim=0)

# Look up per-timestep schedule values and reshape them to broadcast against a batch
def extract(a, t, x_shape):
    """Pick the entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Args:
        a: 1-D tensor of per-timestep values; it is indexed with a CPU copy of ``t``.
        t: 1-D integer tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the result will be broadcast against;
            only its number of dimensions is used.

    Returns:
        A tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on ``t``'s device.
    """
    n_items = t.shape[0]
    # Index on the CPU, matching where the schedule tensors are created.
    picked = a.gather(-1, t.cpu())
    # One singleton dimension for every non-batch dimension of x_shape.
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(n_items, *singleton_dims).to(t.device)

# Generation: start from Gaussian noise and run the reverse diffusion process step by step
def sample_from_model(model, n_samples, image_size, device):
    """Generate images by running the reverse diffusion process.

    Starts from standard Gaussian noise and iterates from step
    ``timesteps - 1`` down to ``0``, using ``model`` to predict the noise at
    each step. Reads the module-level ``timesteps``, ``betas``, ``alphas`` and
    ``alphas_cumprod`` schedule values.

    Args:
        model: Noise predictor called as ``model(x, t)``; it is switched to
            eval mode and left in that mode.
        n_samples: Number of images to generate.
        image_size: Side length of the square, single-channel images.
        device: Device on which sampling runs.

    Returns:
        The final ``x`` after the last step. It starts with shape
        ``(n_samples, 1, image_size, image_size)``; this assumes ``model``
        returns a tensor of the same shape as its input.
    """
    model.eval()
    # Sampling never backpropagates. Without no_grad every one of the
    # `timesteps` model calls would record autograd state for nothing.
    with torch.no_grad():
        x = torch.randn(n_samples, 1, image_size, image_size).to(device)  # Initialize with random noise
        for t in range(timesteps - 1, -1, -1):
            t_tensor = torch.full((n_samples,), t, dtype=torch.long, device=device)
            predicted_noise = model(x, t_tensor)

            beta_t = extract(betas, t_tensor, x.shape)
            alpha_t = extract(alphas, t_tensor, x.shape)
            alpha_t_cumprod = extract(alphas_cumprod, t_tensor, x.shape)

            # Mean of the reverse step, computed the same way for every t.
            mean = (1 / torch.sqrt(alpha_t)) * (x - beta_t / torch.sqrt(1 - alpha_t_cumprod) * predicted_noise)

            if t > 0:
                # Intermediate steps add fresh noise scaled by sqrt(beta_t).
                x = mean + torch.sqrt(beta_t) * torch.randn_like(x)
            else:
                # The final step (t = 0) returns the mean without added noise.
                x = mean
    return x

# Small convolutional encoder-decoder for noise prediction (named UNet, but it has no skip connections)
class UNet(nn.Module):
    """Convolutional encoder-decoder that predicts noise, conditioned on the timestep.

    The encoder halves the spatial resolution with a max-pool and the decoder
    restores it with a stride-2 transposed convolution, so the output has the
    same shape as the single-channel input (for even spatial sizes).

    The timestep ``t`` is turned into a sinusoidal embedding, passed through a
    small MLP, and added to every spatial position of the encoder output.
    Without this conditioning the network would return the same prediction for
    every noise level.

    Args:
        time_dim: Size of the sinusoidal timestep embedding; must be a
            positive even number.

    Raises:
        ValueError: If ``time_dim`` is not a positive even number.
    """

    def __init__(self, time_dim=128):
        super().__init__()
        if time_dim < 2 or time_dim % 2 != 0:
            raise ValueError(f"time_dim must be a positive even number, got {time_dim}")

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        # Maps the sinusoidal embedding to one bias per encoder output channel (128).
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
        )
        self.middle = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

        # Geometrically spaced frequencies for the sinusoidal embedding. Stored
        # as a non-persistent buffer so it follows .to(device) but is not
        # written to the state dict.
        half_dim = time_dim // 2
        exponents = torch.arange(half_dim, dtype=torch.float32) / half_dim
        time_freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * exponents)
        self.register_buffer("time_freqs", time_freqs, persistent=False)

    def _timestep_embedding(self, t):
        """Return the ``(len(t), time_dim)`` sin/cos embedding of timesteps ``t``."""
        angles = t.float()[:, None] * self.time_freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timesteps ``t``.

        Args:
            x: Batch of single-channel images, ``(batch, 1, H, W)``.
            t: 1-D tensor of timestep indices, one per batch element, on the
                same device as the model.

        Returns:
            Tensor with the same shape as ``x`` (for even ``H`` and ``W``).
        """
        x = self.encoder(x)
        # Broadcast the per-channel time bias over the spatial dimensions.
        time_bias = self.time_mlp(self._timestep_embedding(t))
        x = x + time_bias[:, :, None, None]
        x = self.middle(x)
        x = self.decoder(x)
        return x

# Loss function (simple L2 loss between predicted and actual noise)
def diffusion_loss(model, x, t, noise):
    """Mean squared error between the model's noise prediction and ``noise``.

    Args:
        model: Callable invoked as ``model(x, t)``.
        x: Input passed to the model as its first argument.
        t: Timesteps passed to the model as its second argument.
        noise: Target tensor the prediction is compared against.

    Returns:
        The value of ``F.mse_loss(model(x, t), noise)``.
    """
    return F.mse_loss(model(x, t), noise)

# Data Loading
# Convert images to tensors, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# Pinned host memory only speeds up host-to-CUDA copies; request it only when
# CUDA is present so CPU/MPS runs neither warn nor pin memory for nothing.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

# Training setup: choose the compute device, then build the model and optimizer.
# Device preference: Apple MPS first, then CUDA, then CPU.
# torch.backends.mps is missing on older PyTorch releases, so look it up
# defensively instead of raising AttributeError there.
mps_backend = getattr(torch.backends, "mps", None)
if mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = UNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

n_epochs = 100

In [None]:
# Train the noise predictor: each step corrupts a batch at random timesteps
# and regresses the model output onto the noise that was added.
for epoch in range(n_epochs):
    # sample_from_model() leaves the model in eval mode, so re-enable training
    # mode at the start of every epoch.
    model.train()
    epoch_loss = 0.0
    for images, _ in train_loader:
        images = images.to(device)
        # Keep the per-batch size in its own name. Assigning to `batch_size`
        # here would overwrite the configuration value used to build the
        # DataLoader (the last batch is usually smaller).
        n_images = images.shape[0]

        # Sample random timesteps for each image in the batch
        t = torch.randint(0, timesteps, (n_images,), device=device)

        # Sample random noise and corrupt the image
        noise = torch.randn_like(images)
        alpha_bar_t = extract(alphas_cumprod, t, images.shape)
        noisy_images = torch.sqrt(alpha_bar_t) * images + torch.sqrt(1 - alpha_bar_t) * noise

        # Compute loss
        loss = diffusion_loss(model, noisy_images, t, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f'Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss/len(train_loader)}')

In [None]:
# Draw a batch of images from the trained model and convert them to NumPy
n_samples = 16
samples = (
    sample_from_model(model, n_samples, image_size, device)
    .detach()
    .cpu()
    .numpy()
)

# Show the samples on a 4x4 grid with the axes hidden
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(6, 6))
for i, ax in enumerate(axes.ravel()):
    # Drop the singleton channel dimension so imshow receives a 2-D image.
    ax.imshow(samples[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()