<a href="https://colab.research.google.com/github/MehrdadDastouri/diffusion_model_mnist/blob/main/diffusion_model_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Output folder for the periodically sampled image grids (no error if it exists).
os.makedirs("generated_images", exist_ok=True)

# --- Hyperparameters ---------------------------------------------------------
image_size = 28    # MNIST images are 28x28 pixels
channels = 1       # a single grayscale channel
batch_size = 128   # images per optimisation step
epochs = 100       # full passes over the training set
lr = 1e-4          # Adam learning rate
timesteps = 1000   # length of the forward-diffusion chain

# Define the diffusion schedule (beta values)
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return a linearly spaced diffusion noise schedule.

    Args:
        timesteps: number of diffusion steps, i.e. the length of the schedule.
        beta_start: noise variance used at the first step.
        beta_end: noise variance used at the last step.

    Returns:
        1-D float tensor of length ``timesteps`` running linearly from
        ``beta_start`` to ``beta_end`` (both endpoints included).
    """
    return torch.linspace(beta_start, beta_end, timesteps)

# Noise schedule, computed once on `device`; each tensor has length `timesteps`:
#   beta[t]      - variance of the noise injected at step t
#   alpha[t]     - 1 - beta[t]
#   alpha_hat[t] - running product alpha[0] * alpha[1] * ... * alpha[t]
beta = linear_beta_schedule(timesteps).to(device)
alpha = torch.ones_like(beta) - beta
alpha_hat = alpha.cumprod(dim=0)

# Add noise to images
def add_noise(x, t, noise=None):
    """Jump clean samples ``x`` directly to diffusion step ``t``.

    Computes ``sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise`` using
    the module-level ``alpha_hat`` schedule.

    Args:
        x: clean batch whose first dimension is the batch dimension (for
            this script, images of shape (B, C, H, W)).
        t: integer time steps, one per sample, used to index ``alpha_hat``.
        noise: optional noise tensor shaped like ``x``; drawn from a standard
            normal distribution when omitted.

    Returns:
        Tuple ``(noisy_x, noise)``; ``noise`` is returned so the caller can
        use it as the regression target.
    """
    if noise is None:
        noise = torch.randn_like(x)
    # Reshape the per-sample coefficients to (B, 1, ..., 1) so they broadcast
    # over every non-batch dimension of x, whatever its rank.
    coef_shape = (-1,) + (1,) * (x.dim() - 1)
    a_hat = alpha_hat[t].view(coef_shape)
    return torch.sqrt(a_hat) * x + torch.sqrt(1 - a_hat) * noise, noise

# Define the U-Net-style encoder-decoder (no skip connections) that predicts the added noise
class UNet(nn.Module):
    """Small convolutional encoder-decoder that predicts the noise in an image.

    The network downsamples twice (2x max-pool) and upsamples twice (2x
    ``nn.Upsample``), so H and W should be divisible by 4 for the output to
    match the input size (MNIST's 28x28 is). Despite the name, there are no
    skip connections between the down- and up-sampling paths.

    The diffusion time step is supplied as one extra constant input channel.

    Args:
        channels: number of image channels (both the input and the predicted
            noise).
        num_timesteps: length of the diffusion schedule. Integer time steps
            are divided by this, so the conditioning channel lies in [0, 1)
            instead of [0, num_timesteps) and stays on a scale comparable to
            the normalized pixels. Defaults to 1000.
    """

    def __init__(self, channels, num_timesteps=1000):
        super().__init__()
        self.num_timesteps = num_timesteps
        # channels + 1: the extra input channel carries the time step.
        self.conv1 = nn.Conv2d(channels + 1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(64, channels, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at diffusion step ``t``.

        Args:
            x: batch of images, shape (B, channels, H, W).
            t: time steps, either one per sample (shape (B,)) or a single
                value shared by the whole batch.

        Returns:
            Predicted noise with the same shape as ``x``.
        """
        batch, _, height, width = x.shape
        # Turn each sample's time step into a constant H x W plane. Cast it to
        # x's dtype explicitly (instead of relying on torch.cat's int/float
        # promotion) and normalize so it does not dwarf the pixel values.
        t_plane = (t.to(x.dtype) / self.num_timesteps).view(-1, 1, 1, 1)
        t_plane = t_plane.expand(batch, 1, height, width)
        x = torch.cat([x, t_plane], dim=1)

        # Down-sampling: H x W -> H/2 -> H/4
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(self.pool(x1)))
        x3 = F.relu(self.conv3(self.pool(x2)))

        # Up-sampling: H/4 -> H/2 -> H
        x4 = F.relu(self.conv4(self.up(x3)))
        x5 = F.relu(self.conv5(self.up(x4)))
        return self.conv6(x5)

# Build the noise-prediction network and move it to the selected device.
model = UNet(channels=channels)
model = model.to(device)

# Adam over every trainable parameter of the network.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# MNIST training split: ToTensor scales pixels to [0, 1], then Normalize with
# mean 0.5 / std 0.5 maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Training loop
sample_every = 10  # epochs between saved sample grids


@torch.no_grad()
def sample_images(net, num_samples=16):
    """Generate images by running the reverse diffusion process from noise.

    Args:
        net: noise-prediction model, called as ``net(x, t)``.
        num_samples: number of images to generate.

    Returns:
        Tensor of shape (num_samples, channels, image_size, image_size),
        clamped to [-1, 1].
    """
    was_training = net.training
    net.eval()
    try:
        x = torch.randn((num_samples, channels, image_size, image_size), device=device)
        for t_step in reversed(range(timesteps)):
            t = torch.full((num_samples,), t_step, device=device, dtype=torch.long)
            predicted_noise = net(x, t)
            # Index the schedule with the Python int so each coefficient is a
            # 0-d tensor that broadcasts over the whole (N, C, H, W) batch.
            # Indexing with the (N,)-shaped `t` yields (N,) tensors, which
            # broadcasting aligns with the trailing width dimension -> error.
            alpha_t = alpha[t_step]
            alpha_hat_t = alpha_hat[t_step]
            beta_t = beta[t_step]
            x = (1 / torch.sqrt(alpha_t)) * (
                x - (1 - alpha_t) / torch.sqrt(1 - alpha_hat_t) * predicted_noise
            )
            if t_step > 0:
                # Fresh noise with std sqrt(beta_t); none on the final step.
                x = x + torch.sqrt(beta_t) * torch.randn_like(x)
    finally:
        net.train(was_training)
    return torch.clamp(x, -1, 1)


model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    for images, _ in train_loader:
        images = images.to(device)

        # One uniformly random diffusion step per image (randint is int64).
        t = torch.randint(0, timesteps, (images.size(0),), device=device)

        # Noise the images to step t and ask the model to recover that noise.
        noisy_images, noise = add_noise(images, t)
        predicted_noise = model(noisy_images, t)

        # MSE between predicted and true noise.
        loss = F.mse_loss(predicted_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss / len(train_loader):.4f}")

    # Save a sample grid periodically and always after the last epoch, so the
    # visualization of epoch_{epochs}.png has a file even when `epochs` is not
    # a multiple of `sample_every`.
    if (epoch + 1) % sample_every == 0 or epoch + 1 == epochs:
        samples = sample_images(model)
        # Map from [-1, 1] back to [0, 1] for saving.
        save_image((samples + 1) / 2, f"generated_images/epoch_{epoch+1}.png")

# Display the sample grid saved after the final epoch, if that file was written.
generated_image_path = f"generated_images/epoch_{epochs}.png"
if os.path.exists(generated_image_path):
    generated_image = plt.imread(generated_image_path)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(generated_image)
    ax.axis("off")
    ax.set_title("Generated Images")
    plt.show()