This notebook trains a denoising diffusion model on the MNIST dataset. It then uses the model to generate new handwritten-digit images.

In [1]:
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm


In [2]:
# Small U-Net-style encoder / bottleneck / decoder network (note: no skip connections, no timestep input)
class UNet(nn.Module):
    """Small convolutional encoder -> bottleneck -> decoder noise predictor.

    Takes a batch shaped (N, 1, H, W) and returns a tensor of the same shape
    (for even H and W: MaxPool2d(2) halves the spatial size and the stride-2
    ConvTranspose2d doubles it back).

    NOTE(review): despite the name this is not a true U-Net. The encoder
    features are not concatenated into the decoder (no skip connections), and
    forward() does not receive the diffusion timestep.
    """

    def __init__(self):
        super().__init__()
        # Downsampling path: 1 -> 64 -> 128 channels, then halve H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        # Bottleneck at half resolution: 128 -> 256 -> 256 channels.
        self.middle = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # Upsampling path: double H and W, then project 128 -> 64 -> 1 channel.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1),
        )

    def forward(self, x):
        """Run the three stages in sequence and return the 1-channel prediction."""
        encoded = self.encoder(x)
        bottleneck = self.middle(encoded)
        return self.decoder(bottleneck)


In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the global torch RNG before any random work (weight init, data
# shuffling, timestep/noise sampling) so Restart & Run All is reproducible.
seed = 42
torch.manual_seed(seed)

# Hyperparameters
epochs = 100  # full passes over MNIST; lower this for a quick demo run
batch_size = 64
learning_rate = 1e-3
timesteps = 1000  # number of diffusion steps T in the noise schedule

# Model, Loss Function, Optimizer
model = UNet().to(device)
criterion = nn.MSELoss()  # regresses the model's predicted noise onto the true noise
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Dataloader: ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = MNIST(root='.', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


In [5]:
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced diffusion noise schedule.

    Args:
        timesteps: number of diffusion steps T; must be >= 1.
        beta_start: beta value at the first step (default 0.0001).
        beta_end: beta value at the last step (default 0.02).

    Returns:
        1-D float tensor of length `timesteps`, going from beta_start to beta_end.

    Raises:
        ValueError: if timesteps < 1. An empty schedule would make every
            later index into it fail.
    """
    if timesteps < 1:
        raise ValueError(f"timesteps must be >= 1, got {timesteps}")
    return torch.linspace(beta_start, beta_end, timesteps)

# Precompute the noise schedule once; the sampling/denoising cells index into these.
betas = linear_beta_schedule(timesteps)
alphas = 1.0 - betas                    # alpha_t = 1 - beta_t
alphas_cumprod = alphas.cumprod(dim=0)  # alpha_bar_t = prod_{s <= t} alpha_s

def forward_diffusion_sample(x_0, t, device):
    """Noise a clean batch to timestep t in closed form.

    Computes x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    where noise is drawn with randn_like(x_0).

    Args:
        x_0: clean batch; the per-sample coefficients are reshaped to
            (-1, 1, 1, 1), so presumably (N, C, H, W) — TODO confirm for other ranks.
        t: integer tensor of timestep indices into alphas_cumprod, one per sample.
        device: device the coefficient tensors are moved to; should match x_0's device.

    Returns:
        (x_t, noise): the noised batch and the noise that was added.
    """
    noise = torch.randn_like(x_0)
    alpha_bar = alphas_cumprod[t.cpu()]  # the schedule lives on the CPU
    signal_scale = torch.sqrt(alpha_bar).view(-1, 1, 1, 1).to(device)
    noise_scale = torch.sqrt(1 - alpha_bar).view(-1, 1, 1, 1).to(device)
    x_t = signal_scale * x_0 + noise_scale * noise
    return x_t, noise


In [6]:
def train(model, train_loader, optimizer, criterion, epochs, timesteps, device=None):
    """Train `model` to predict the noise added by the forward diffusion process.

    For every batch, one timestep in [0, timesteps) is drawn per image and the
    batch is noised with forward_diffusion_sample. `criterion` is then
    minimised between the model output and the true noise. After each epoch it
    writes `checkpoint_epoch_{k}.pth` and a 5x5 grid of the last batch's *noisy
    inputs* to `noisy_epoch_{k}.png`.

    Args:
        model: network mapping a noisy batch to a noise prediction.
        train_loader: iterable of (images, labels) batches; labels are ignored.
        optimizer: optimizer over model's parameters.
        criterion: loss between predicted and true noise.
        epochs: number of passes over train_loader.
        timesteps: number of diffusion steps; timesteps are sampled from [0, timesteps).
        device: device batches are moved to. Defaults to the device of the
            model's parameters, so the function does not rely on a global.

    NOTE(review): the model is called without `t`, so it cannot condition on the
    noise level. This limits sample quality.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    for epoch in range(epochs):
        loop = tqdm(train_loader, leave=False)
        loop.set_description(f'Epoch [{epoch+1}/{epochs}]')
        x_noisy = None  # stays None if the loader yields no batches
        for data, _ in loop:
            data = data.to(device)
            optimizer.zero_grad()

            # Sample a random timestep for each image (randint already returns int64)
            t = torch.randint(0, timesteps, (data.size(0),), device=device)
            x_noisy, noise = forward_diffusion_sample(data, t, device)

            # Forward pass: predict the added noise
            output = model(x_noisy)
            loss = criterion(output, noise)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            loop.set_postfix(loss=loss.item())

        # Save model checkpoint
        torch.save(model.state_dict(), f'checkpoint_epoch_{epoch+1}.pth')
        # Save a grid of noisy training inputs from the last batch (not generated samples)
        if x_noisy is not None:
            save_image(x_noisy[:25], f'noisy_epoch_{epoch+1}.png', nrow=5, normalize=True)

train(
    model,
    train_loader,
    optimizer,
    criterion,
    epochs=epochs,
    timesteps=timesteps,
)




In [7]:
def denoise_step(model, x, t, betas, device):
    """Perform one DDPM reverse (ancestral sampling) step, x_t -> x_{t-1}.

    Implements
        x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t) + sigma_t * z
    with eps = model(x_t), z ~ N(0, I) and sigma_t^2 = beta_t. No noise is
    added at the final step (t == 0).

    Args:
        model: callable mapping x_t to a noise prediction of the same shape.
        x: current sample x_t; the coefficients are scalars, so any batch shape works.
        t: integer timestep index into `betas` (Python int or 1-element tensor).
        betas: 1-D noise schedule beta_0 .. beta_{T-1}. alpha_t and alpha_bar_t
            are derived from it, so no global schedule tensors are needed.
        device: unused; kept for backward compatibility. All work happens on x's device.

    Returns:
        x_{t-1}, with the same shape and device as x.

    NOTE(review): model is called without t because the UNet in this notebook
    takes no timestep input.
    """
    t = int(t)
    beta_t = float(betas[t])
    alpha_t = 1.0 - beta_t
    # alpha_bar_t = prod_{s <= t} (1 - beta_s)
    alpha_bar_t = float(torch.prod(1.0 - betas[: t + 1]))

    pred_noise = model(x)
    noise_coef = beta_t / (1.0 - alpha_bar_t) ** 0.5
    mean = (x - noise_coef * pred_noise) / alpha_t ** 0.5
    if t == 0:
        # Final step returns the posterior mean without extra noise.
        return mean
    return mean + beta_t ** 0.5 * torch.randn_like(x)

def generate_images(model, num_images, steps, device, image_shape=(1, 28, 28), batch_size=64):
    """Sample images by running the reverse diffusion chain from pure noise.

    Images are denoised together in batches of up to `batch_size`. The model
    is therefore called about `steps * ceil(num_images / batch_size)` times
    instead of `steps * num_images` times. Each image is written to
    `generated_image_{i}.png` (1-based) with normalize=True, so every file is
    rescaled by its own min/max.

    Args:
        model: noise-prediction network.
        num_images: number of images to generate; 0 writes nothing.
        steps: number of reverse steps, run from t = steps - 1 down to 0.
        device: device on which the initial noise is created.
        image_shape: (C, H, W) of each sample; defaults to MNIST's (1, 28, 28).
        batch_size: maximum number of images denoised at once (must be >= 1).
    """
    model.eval()
    with torch.no_grad():
        for start in range(0, num_images, batch_size):
            count = min(batch_size, num_images - start)
            x = torch.randn(count, *image_shape).to(device)
            for step in reversed(range(steps)):
                x = denoise_step(model, x, step, betas, device)
            for offset in range(count):
                # Keep a 4-D (1, C, H, W) slice so each file matches the per-image output.
                save_image(x[offset:offset + 1], f'generated_image_{start + offset + 1}.png', normalize=True)

# Load the weights saved after the final training epoch. The filename is built
# from `epochs` so it matches train()'s checkpoint naming. map_location lets a
# checkpoint written on GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(f'checkpoint_epoch_{epochs}.pth', map_location=device))
model = model.to(device)

# Generate new images
generate_images(model, num_images=10, steps=timesteps, device=device)
