# DDPM (Denoising Diffusion Probabilistic Model)

https://arxiv.org/abs/2006.11239

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


## Prepare Dataset

In [None]:

# Convert images to tensors, then normalise the single channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# MNIST training split, downloaded into ./data when not already present.
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Shuffled mini-batches of 128 samples.
dataloader = DataLoader(dataset=mnist_dataset, batch_size=128, shuffle=True)


## Forward Process

In [None]:

def forward_process(x_0, t, beta):
    """
    Sample a noised version of ``x_0`` at timestep ``t`` of the forward process.

    Computes ``x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise`` where
    ``a_bar`` is the cumulative product of ``1 - beta`` and ``noise`` is drawn
    from a standard normal with the same shape as ``x_0``.

    Args:
        x_0: Clean input tensor. When ``t`` holds one index per sample, the
            first dimension of ``x_0`` is treated as the batch dimension.
        t: Timestep index into ``beta``: a Python int, a 0-d tensor, or a 1-D
            integer tensor with one entry per sample of ``x_0``.
        beta: 1-D tensor holding the variance schedule.

    Returns:
        A tuple ``(x_t, noise)`` of the noised sample and the noise that was
        added to produce it.
    """
    noise = torch.randn_like(x_0)

    # Keep the schedule (and any index tensor) on the same device as the data
    # so that indexing and the arithmetic below never mix devices.
    alpha_cumprod = (1 - beta).to(x_0.device).cumprod(dim=0)
    if torch.is_tensor(t):
        t = t.to(alpha_cumprod.device)
    alpha_bar_t = alpha_cumprod[t]

    # A per-sample index vector yields coefficients of shape (B,). Those cannot
    # broadcast against (B, C, H, W) data, so expand them to (B, 1, ..., 1).
    per_sample = (
        alpha_bar_t.dim() == 1
        and x_0.dim() > 1
        and alpha_bar_t.shape[0] == x_0.shape[0]
    )
    if per_sample:
        alpha_bar_t = alpha_bar_t.reshape(-1, *([1] * (x_0.dim() - 1)))

    x_t = (alpha_bar_t ** 0.5) * x_0 + ((1 - alpha_bar_t) ** 0.5) * noise
    return x_t, noise

# Number of diffusion timesteps.
T = 1000
# Variance schedule: T values spaced linearly from 1e-4 to 0.02.
beta = torch.linspace(start=1e-4, end=0.02, steps=T)


## Reverse Process Model

In [None]:

class UNet(nn.Module):
    """
    Small convolutional network used to predict the noise added to an image.

    The encoder lifts a 1-channel input to 128 channels and the decoder maps
    it back to 1 channel. Every layer uses a 3x3 kernel with padding 1 and the
    default stride, so the spatial size of the input is preserved.

    NOTE: ``forward`` takes only the noised image; the network is not
    conditioned on the diffusion timestep.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, padding=1),
            nn.ReLU(),
            # The output is left linear on purpose: the regression target in
            # this file is standard-normal noise, which is unbounded, so a
            # Tanh here would clip predictions to [-1, 1].
            nn.ConvTranspose2d(64, 1, 3, padding=1),
        )

    def forward(self, x):
        """
        Run the encoder followed by the decoder.

        Args:
            x: Input batch with 1 channel, shaped (N, 1, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.decoder(self.encoder(x))

# Build the noise-prediction network and an Adam optimizer over its parameters.
model = UNet()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)


## Training Loop

In [None]:

def loss_fn(pred, target):
    """Return the mean of the element-wise squared difference between ``pred`` and ``target``."""
    squared_error = (pred - target) ** 2
    return squared_error.mean()

num_epochs = 10

# nn.Module has no `.device` attribute, so read the device off the model's
# parameters instead. The schedule is moved there once, outside the loop.
device = next(model.parameters()).device
beta_on_device = beta.to(device)

for epoch in range(num_epochs):
    for x_batch, _ in dataloader:
        # Move the batch first so the timesteps can be sampled on the same device.
        x_batch = x_batch.to(device=device, dtype=torch.float32)
        # One uniformly sampled timestep in [0, T) per sample (int64 by default).
        t = torch.randint(0, T, (x_batch.size(0),), device=device)
        x_t, noise = forward_process(x_batch, t, beta_on_device)
        optimizer.zero_grad()
        noise_pred = model(x_t)
        loss = loss_fn(noise_pred, noise)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last mini-batch of the epoch only.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
