<a href="https://colab.research.google.com/github/abrham17/Diffusion_model-UNet-implementation/blob/main/Diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, utils as vutils
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

In [2]:
# Noise schedule
def get_noise_schedule(T=1000, beta_start=0.0001, beta_end=0.02):
    """Build a linear DDPM variance schedule.

    Args:
        T: number of diffusion steps.
        beta_start: noise variance at the first step.
        beta_end: noise variance at the last step.

    Returns:
        (betas, alphas, alphas_cumprod), each a 1-D tensor of length T, where
        alphas = 1 - betas and alphas_cumprod[t] = alphas[0] * ... * alphas[t].
    """
    beta_schedule = torch.linspace(beta_start, beta_end, T)
    alpha_schedule = 1.0 - beta_schedule
    return beta_schedule, alpha_schedule, alpha_schedule.cumprod(dim=0)

# Forward diffusion
def forward_diffusion(x_0, t, alphas_cumprod, device):
    """Draw x_t ~ q(x_t | x_0) in closed form.

    x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps,  with eps ~ N(0, I).

    Args:
        x_0: clean images. The per-sample coefficients are reshaped to
            (-1, 1, 1, 1), so a 4-D (batch, C, H, W) tensor is expected.
        t: integer timestep indices into ``alphas_cumprod``, one per sample
            (shape (batch,) or (batch, 1)).
        alphas_cumprod: cumulative product of alphas (see get_noise_schedule).
        device: device the sampled noise is moved to.

    Returns:
        (x_t, noise): the noised images and the Gaussian noise that was added.
    """
    eps = torch.randn_like(x_0).to(device)
    abar = alphas_cumprod[t]
    signal_scale = abar.sqrt().view(-1, 1, 1, 1)
    noise_scale = (1 - abar).sqrt().view(-1, 1, 1, 1)
    return signal_scale * x_0 + noise_scale * eps, eps

In [3]:
# Sinusoidal embedding
def get_sinusoidal_embedding(t, time_dim, device):
    """Encode diffusion timesteps as sinusoidal features.

    Args:
        t: timesteps, shape (batch,) or (batch, 1); any numeric dtype.
        time_dim: embedding width. The first time_dim // 2 columns hold sines
            and the next time_dim // 2 hold cosines. For odd time_dim, one
            zero column is appended. Must be >= 4.
        device: device the frequency table is created on (should match t).

    Returns:
        Float tensor of shape (batch, time_dim).

    Raises:
        ValueError: if time_dim < 4 (the frequency spacing would divide by 0).
    """
    half_dim = time_dim // 2
    if half_dim < 2:
        raise ValueError(f"time_dim must be >= 4, got {time_dim}")
    # Geometric frequency ladder from 1 down to 1/10000.
    scale = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
    # Force t into a (batch, 1) column so it broadcasts against (1, half_dim).
    # A flat (batch,) tensor would otherwise broadcast along the frequency
    # axis: it fails when batch != half_dim and mixes samples when equal.
    args = t.float().reshape(-1, 1) * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if time_dim % 2:
        # Odd width: pad one zero column so callers can view(-1, time_dim).
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

# Visualization
def visualize_samples(samples, step, epoch):
    """Save a grid of generated images to ``epoch_{epoch}_step_{step}.png``.

    Args:
        samples: image batch in the model's [-1, 1] space, (N, C, H, W).
        step: step index, used only in the title and file name.
        epoch: epoch index, used only in the title and file name.
    """
    # Map [-1, 1] -> [0, 1] exactly once. make_grid must not normalize again:
    # its per-grid min/max stretch would rescale every saved image differently
    # and hide how far raw samples are from the data range.
    images = (samples.detach().clamp(-1, 1) + 1) / 2
    grid = vutils.make_grid(images.cpu(), nrow=8)
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.title(f"Epoch {epoch}, Step {step}")
        plt.axis('off')
        plt.savefig(f"epoch_{epoch}_step_{step}.png")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

In [4]:
# Sampling
def sample(model, n_samples, in_channels, height, width, num_classes, T, betas, alphas, alphas_cumprod, device):
    """Generate class-conditional images with DDPM ancestral sampling.

    Starts from x_T ~ N(0, I) and, for t = T-1 down to 0, applies
        x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - abar_t) * eps_theta(x_t, t, c))
                  + sqrt(beta_t) * z,   z ~ N(0, I)   (no noise added at t = 0).

    The model runs in eval mode under no_grad. Its previous train/eval mode
    is restored before returning, so this is safe to call mid-training.

    Args:
        model: noise predictor called as model(x, t, c).
        n_samples: number of images to generate.
        in_channels, height, width: shape of each generated image.
        num_classes: labels are drawn uniformly from [0, num_classes).
        T: number of diffusion steps; must match the schedule length.
        betas, alphas, alphas_cumprod: schedule tensors (see get_noise_schedule).
        device: device to sample on.

    Returns:
        (x, c): images of shape (n_samples, in_channels, height, width) and
        their class labels of shape (n_samples,).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x = torch.randn(n_samples, in_channels, height, width, device=device)
            c = torch.randint(0, num_classes, (n_samples,), device=device)
            for t in reversed(range(T)):
                # (n_samples, 1) matches the timestep shape used in training
                # and broadcasts correctly in the sinusoidal embedding.
                t_tensor = torch.full((n_samples, 1), t, dtype=torch.float, device=device)
                predicted_noise = model(x, t_tensor, c)
                alpha_t = alphas[t].view(-1, 1, 1, 1)
                beta_t = betas[t].view(-1, 1, 1, 1)
                alpha_cumprod_t = alphas_cumprod[t].view(-1, 1, 1, 1)
                x = (1 / torch.sqrt(alpha_t)) * (x - (beta_t / torch.sqrt(1 - alpha_cumprod_t)) * predicted_noise)
                if t > 0:
                    x += torch.sqrt(beta_t) * torch.randn_like(x)
            return x, c
    finally:
        # Leaving the model in eval() would freeze BatchNorm statistics for
        # every subsequent training step of the caller.
        model.train(was_training)


In [5]:
class UNet(nn.Module):
    """
    Small conditional U-Net used as the noise predictor of a DDPM.

    args:
        in_channels (int): Number of input image channels.
        out_channels (int): Number of output channels (the predicted noise).
        time_dim (int): Width of the timestep embedding; the class embedding
            uses the same width so the two can be summed.
        num_classes (int | None): Number of class labels for conditional
            generation; None (or 0) builds an unconditional model.
    returns:
        out (Tensor): Predicted noise with out_channels channels and the same
            spatial size as the input.
    process:
        1. down1: in_channels -> 64 at full resolution; the projected
           time (+ class) embedding is added to this feature map.
        2. down2: 64 -> 128 with stride 2 (half resolution).
        3. bottleneck: 128 -> 256 at half resolution.
        4. up1: 256 -> 128, transposed conv back to full resolution.
        5. up2: concat(up1, down1) = 128 + 64 -> 64.
        6. out: 64 -> out_channels.
    """

    def __init__(self, in_channels=1, out_channels=1, time_dim=128, num_classes=None):
        super().__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.up2 = nn.Sequential(
            nn.Conv2d(128 + 64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.out = nn.Conv2d(64, out_channels, 3, padding=1)
        self.time_dim = time_dim
        self.class_emb = nn.Embedding(num_classes, time_dim) if num_classes else None
        # 1x1 conv on a (B, time_dim, 1, 1) map = per-sample linear projection
        # to 64 channels, broadcast over H x W when added to down1's output.
        self.emb_projection = nn.Conv2d(time_dim, 64, 1)

    def forward(self, x, t, c=None):
        """Predict the noise in x at timestep t, optionally conditioned on c.

        Args:
            x: noisy images, (B, in_channels, H, W).
            t: timesteps, passed to get_sinusoidal_embedding.
            c: integer class labels (B,), or None for no conditioning.

        Raises:
            ValueError: if labels are given but the model has no class embedding.
        """
        emb = get_sinusoidal_embedding(t, self.time_dim, x.device)
        if c is not None:
            if self.class_emb is None:
                raise ValueError("class labels were passed but UNet was built without num_classes")
            emb = emb + self.class_emb(c)
        emb = self.emb_projection(emb.view(-1, self.time_dim, 1, 1))
        d1 = self.down1(x) + emb
        d2 = self.down2(d1)
        b = self.bottleneck(d2)
        u1 = self.up1(b)
        # For odd H/W the stride-2 path returns one extra row/column; crop it
        # so the skip connection concatenates tensors of equal size.
        u1 = u1[:, :, :d1.shape[2], :d1.shape[3]]
        u2 = self.up2(torch.cat([u1, d1], dim=1))
        return self.out(u2)

In [None]:
# Training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
all_transforms = transforms.Compose([
    transforms.ToTensor(),
    # Map pixels from [0, 1] to [-1, 1], the range the sampler/visualizer assume.
    transforms.Normalize((0.5,), (0.5,))
])
train_data = MNIST(root='./data', train=True, download=True, transform=all_transforms)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
model = UNet(in_channels=1, out_channels=1, time_dim=128, num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

T = 1000
betas, alphas, alphas_cumprod = get_noise_schedule(T)
betas, alphas, alphas_cumprod = betas.to(device), alphas.to(device), alphas_cumprod.to(device)

for epoch in range(10):
    # sample() at the end of each epoch switches the model to eval();
    # re-enter train mode every epoch so BatchNorm uses batch statistics
    # and keeps updating its running estimates.
    model.train()
    total_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        # One random timestep per image as a (batch, 1) column: the integer
        # form indexes the schedule, the float form feeds the time embedding.
        t_idx = torch.randint(0, T, (x.shape[0], 1), device=device)
        x_t, noise = forward_diffusion(x, t_idx, alphas_cumprod, device)
        predicted_noise = model(x_t, t_idx.float(), y)
        loss = criterion(predicted_noise, noise)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch: {epoch+1}, Loss: {total_loss / len(train_loader)}")
    # Visualize samples every epoch
    samples, classes = sample(model, 64, 1, 28, 28, 10, T, betas, alphas, alphas_cumprod, device)
    visualize_samples(samples, 0, epoch+1)

100%|██████████| 9.91M/9.91M [00:00<00:00, 40.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.11MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.79MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.1MB/s]


Epoch: 1, Loss: 0.06060605171495981
Epoch: 2, Loss: 0.054923035764395556
