# In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
import math, os

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
n_steps = 1000                      # number of diffusion timesteps T
beta_start, beta_end = 1e-4, 0.02   # endpoints of the linear beta schedule
batch_size = 128
learning_rate = 2e-4
epochs = 50                         # diffusion models usually need many epochs to sample well

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Fashion-MNIST training data (ToTensor -> float pixels in [0, 1])
# ---------------------------------------------------------------------------
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# ---------------------------------------------------------------------------
# Forward-process constants
#   betas[t]      : per-step noise variance, linear from beta_start to beta_end
#   alphas[t]     : 1 - betas[t]
#   alpha_bars[t] : running product alphas[0] * ... * alphas[t]  (\bar{\alpha}_t)
# ---------------------------------------------------------------------------
betas = torch.linspace(beta_start, beta_end, n_steps).to(device)
alphas = 1.0 - betas
alpha_bars = alphas.cumprod(dim=0).to(device)

# Sinusoidal position embedding for timesteps (like in Transformer, or use nn.Embedding)
def sinusoidal_time_embedding(timesteps, embed_dim=256):
    """
    Create sinusoidal timestep embeddings (batch_size x embed_dim) for given timesteps.
    """
    # Timesteps shape: (batch,)
    half_dim = embed_dim // 2
    # Compute sinusoidal frequencies
    freq = torch.exp(-math.log(10000) * torch.arange(0, half_dim, device=device) / half_dim)
    # Outer product: (batch_size, half_dim)
    angles = timesteps[:, None].float() * freq[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    return emb  # shape: (batch, embed_dim)

# Define the U-Net like model for epsilon_theta(x_t, t)
class DiffusionModel(nn.Module):
    def __init__(self):
        super(DiffusionModel, self).__init__()
        self.embed_dim = 256
        # Define convolutional layers for downsampling
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)         # 28x28 -> 28x28 (no downsample yet)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 28x28 -> 14x14
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1) # 14x14 -> 7x7
        # Bottleneck convolution
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)      # 7x7 -> 7x7
        # Define transposed conv layers for upsampling
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1) # 7x7 -> 14x14
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # 14x14 -> 28x28
        # Final output conv (predict noise)
        self.conv_out = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        # Embedding transform layers for time conditioning (to match channels)
        self.time_fc1 = nn.Linear(self.embed_dim, 64)
        self.time_fc2 = nn.Linear(self.embed_dim, 128)
        self.time_fc3 = nn.Linear(self.embed_dim, 256)
        self.time_fc4 = nn.Linear(self.embed_dim, 256)
        self.time_fc5 = nn.Linear(self.embed_dim, 128)
        self.time_fc6 = nn.Linear(self.embed_dim, 64)

    def forward(self, x, t):
        # x shape: [batch, 1, 28, 28], t shape: [batch] (timestep indices)
        # Obtain time embeddings for t
        # We'll use sinusoidal embedding; alternatively, could use nn.Embedding learned embeddings
        time_emb = sinusoidal_time_embedding(t, self.embed_dim)  # shape: (batch, embed_dim)
        # Map time embedding to each required channel dimension
        emb1 = self.time_fc1(time_emb)[:, :, None, None]    # to shape (batch, 64, 1, 1)
        emb2 = self.time_fc2(time_emb)[:, :, None, None]    # to (batch, 128, 1, 1)
        emb3 = self.time_fc3(time_emb)[:, :, None, None]    # to (batch, 256, 1, 1)
        emb4 = self.time_fc4(time_emb)[:, :, None, None]    # to (batch, 256, 1, 1)
        emb5 = self.time_fc5(time_emb)[:, :, None, None]    # to (batch, 128, 1, 1)
        emb6 = self.time_fc6(time_emb)[:, :, None, None]    # to (batch, 64, 1, 1)

        # Downward (encoding) path with time conditioning
        h1 = F.relu(self.conv1(x) + emb1)            # add time embed to first conv output (broadcasted)
        h2 = F.relu(self.conv2(h1) + emb2)           # downsample to 14x14, add time embedding
        h3 = F.relu(self.conv3(h2) + emb3)           # downsample to 7x7, add time embedding
        h4 = F.relu(self.conv4(h3) + emb4)           # bottleneck conv at 7x7, add time embedding
        # Upward (decoding) path with skip connections
        u1 = F.relu(self.deconv1(h4) + emb5)         # upsample to 14x14, add time embedding
        # Skip connection from h2 (14x14 feature map)
        u1 = u1 + h2                                 # (both u1 and h2 have 128 channels)
        u2 = F.relu(self.deconv2(u1) + emb6)         # upsample to 28x28, add time embedding
        # Skip connection from h1 (28x28 feature map)
        u2 = u2 + h1                                 # (both u2 and h1 have 64 channels)
        out = self.conv_out(u2)                      # output noise prediction (no activation here)
        return out

# Initialize model and optimizer
model = DiffusionModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

os.makedirs('ddpm_outputs', exist_ok=True)

# Training loop
for epoch in range(1, 3):#epochs+1):
    model.train()
    epoch_loss = 0.0
    for batch_idx, (real_imgs, _) in enumerate(train_loader):
        # print(batch_idx)
        real_imgs = real_imgs.to(device)
        # Sample random timesteps for each image in the batch
        t = torch.randint(0, n_steps, (real_imgs.size(0),), device=device).long()
        # Sample random noise
        noise = torch.randn_like(real_imgs)
        # Compute x_t from x_0 (real image) and noise: x_t = sqrt(alpha_bar[t])*x0 + sqrt(1-alpha_bar[t])*noise
        alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1)       # shape (batch,1,1,1)
        noisy_imgs = torch.sqrt(alpha_bar_t) * real_imgs + torch.sqrt(1 - alpha_bar_t) * noise
        # Predict the noise using the model
        pred_noise = model(noisy_imgs, t)
        # Loss: MSE between the predicted noise and true noise
        loss = F.mse_loss(pred_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * real_imgs.size(0)
        if batch_idx % 20 == 0:
            print(f"Train Epoch {epoch} [{batch_idx*len(real_imgs)}/{len(train_loader.dataset)}] Loss: {loss.item()/len(real_imgs):.4f}")
    avg_loss = epoch_loss / len(train_loader.dataset)
    print(f"Epoch {epoch}/{epochs} - Training loss (MSE): {avg_loss:.6f}")
    # (Optional) save model or generate samples at intermediate epochs for monitoring
    # We will generate final samples after training below.
#%%
# Sampling loop (generate new images from the model)
model.eval()
with torch.no_grad():
    num_samples = 16
    # Start from pure noise
    x_t = torch.randn(num_samples, 1, 28, 28, device=device)  # x_T
    for t in range(n_steps-1, -1, -1):  # from T-1 down to 0
        t_batch = torch.tensor([t] * num_samples, device=device)
        # Predict noise at this timestep
        pred_noise = model(x_t, t_batch)
        # Compute parameters for reverse update
        alpha_t = alphas[t]
        alpha_bar_t = alpha_bars[t]
        # Formula for the mean of p(x_{t-1} | x_t):
        # (1/sqrt(alpha_t)) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * pred_noise)
        mean = (1.0 / torch.sqrt(alpha_t)) * (x_t - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * pred_noise)
        if t > 0:
            # Sample z ~ N(0, I)
            z = torch.randn_like(x_t)
            sigma_t = torch.sqrt(betas[t])
            x_t = mean + sigma_t * z  # add noise for stochasticity
        else:
            x_t = mean  # at last step, no noise added
    generated = x_t.cpu()
# Save generated samples as an image grid
utils.save_image(generated, "ddpm_outputs/sample_generation.png", nrow=2)
print("DDPM sampling complete. Generated samples saved to 'ddpm_outputs/sample_generation.png'.")

# Captured output of an earlier, manually interrupted run:
# Train Epoch 1 [0/60000] Loss: 0.0097
# Train Epoch 1 [2560/60000] Loss: 0.0047
# Train Epoch 1 [5120/60000] Loss: 0.0016
#
# KeyboardInterrupt: