<a href="https://colab.research.google.com/github/ShreySharma07/DDPM-Denoising-Diffusion-Probabilistic-Model-/blob/main/diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import math

In [None]:
# Permit TF32 math for CUDA matmuls and for cuDNN ops
# (see PyTorch's CUDA notes on TF32 for the precision/speed trade-off).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [None]:
# Global hyperparameters shared by the schedule, model, data pipeline and training loop.
image_size = 32  # side length images are resized to (the dataset used below is CIFAR-10)
image_channels = 3  # RGB images
timesteps = 1000  # number of diffusion steps in the noise schedule
batch_size = 256  # samples per DataLoader batch
num_epochs = 500  # training epochs; also used as T_max of the cosine LR scheduler
lr = 2e-5  # initial learning rate passed to AdamW
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # prefer the GPU when one is available

In [None]:
def cosine_beta_schedule(timesteps, s=0.008):
    """Build the cosine noise schedule proposed in Improved DDPMs.

    Args:
        timesteps: Number of diffusion steps; one beta is produced per step.
        s: Small offset added inside the cosine argument.

    Returns:
        1-D tensor holding ``timesteps`` betas, clipped to [0.0001, 0.9999].
    """
    # timesteps + 1 grid points, so that consecutive ratios yield `timesteps` betas.
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    # Normalise so the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, 0.0001, 0.9999)

In [None]:
# Per-timestep lookup tables for the forward (noising) and reverse (sampling) processes.
betas = cosine_beta_schedule(timesteps)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Cumulative product shifted right by one step, with 1.0 filled in before the first step.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
# Variance used when adding noise back in during a reverse step.
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

In [None]:
def forward_diffusion_sample(x_0, t, device):
    """Noise a batch of images to the timesteps given in ``t``.

    Args:
        x_0: Image batch to corrupt; schedule coefficients are broadcast
            against it with shape (-1, 1, 1, 1).
        t: Index tensor into the precomputed schedule tables.
        device: Device the schedule tables are moved to before indexing.

    Returns:
        Tuple ``(noisy_images, noise)`` where ``noise`` is the Gaussian sample
        that was mixed into ``x_0``.
    """
    def _coefficient(table):
        # Look up one scalar per sample and make it broadcastable over C, H, W.
        return table.to(device)[t].view(-1, 1, 1, 1)

    eps = torch.randn_like(x_0)
    signal_scale = _coefficient(sqrt_alphas_cumprod)
    noise_scale = _coefficient(sqrt_one_minus_alphas_cumprod)

    noisy = signal_scale * x_0 + noise_scale * eps
    return noisy, eps


In [None]:
def get_sinusoidal_embeddings(timesteps, embedding_dim):
    """Return sinusoidal embeddings for a batch of timesteps.

    Args:
        timesteps: 1-D tensor of timestep values, one per sample.
        embedding_dim: Width of the returned embedding.

    Returns:
        Tensor of shape ``(len(timesteps), embedding_dim)``. The first
        ``embedding_dim // 2`` columns are sines, the next ``embedding_dim // 2``
        are cosines; when ``embedding_dim`` is odd a trailing zero column pads
        the result to the requested width.
    """
    half_dim = embedding_dim // 2
    # max(..., 1) avoids a ZeroDivisionError when half_dim == 1 (embedding_dim 2 or 3);
    # for half_dim >= 2 the divisor is unchanged.
    scale = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -scale)
    angles = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if embedding_dim % 2 == 1:
        # sin/cos halves only cover 2 * half_dim columns; pad so callers get embedding_dim.
        emb = F.pad(emb, (0, 1))
    return emb


In [None]:
class ResidualBlock(nn.Module):
    """Two GroupNorm -> SiLU -> 3x3 conv stages with a time-embedding bias and a skip path.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both conv stages.
        time_emb_dim: Width of the time embedding fed to ``forward``.
        num_groups: Number of groups used by both GroupNorm layers.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, num_groups=8):
        super().__init__()
        # Projects the time embedding to one value per output channel.
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels),
        )
        self.block1 = self._norm_act_conv(num_groups, in_channels, out_channels)
        self.block2 = self._norm_act_conv(num_groups, out_channels, out_channels)
        # A 1x1 conv matches channel counts on the skip path only when they differ.
        needs_projection = in_channels != out_channels
        self.residual_conv = (
            nn.Conv2d(in_channels, out_channels, 1) if needs_projection else nn.Identity()
        )

    @staticmethod
    def _norm_act_conv(num_groups, channels_in, channels_out):
        """Return a GroupNorm -> SiLU -> 3x3 same-padding conv stack."""
        return nn.Sequential(
            nn.GroupNorm(num_groups, channels_in),
            nn.SiLU(),
            nn.Conv2d(channels_in, channels_out, 3, padding=1),
        )

    def forward(self, x, time_emb):
        """Apply the block to feature map ``x`` conditioned on ``time_emb``."""
        hidden = self.block1(x)
        # Broadcast the per-channel time bias over the two spatial dimensions.
        hidden += self.time_mlp(time_emb)[:, :, None, None]
        hidden = self.block2(hidden)
        return hidden + self.residual_conv(x)

class AttentionBlock(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    Args:
        channels: Channel count of the input, which is also the output channel count.
        num_groups: Number of groups for the GroupNorm applied before attention.
    """

    def __init__(self, channels, num_groups=8):
        super().__init__()
        self.channels = channels
        self.norm = nn.GroupNorm(num_groups, channels)
        # Query, key, value and output projections are all 1x1 convolutions.
        self.q, self.k, self.v, self.out = (
            nn.Conv2d(channels, channels, 1) for _ in range(4)
        )

    def forward(self, x):
        """Return ``x`` plus the attention output computed from its normalised copy."""
        batch, chans, height, width = x.shape
        positions = height * width
        normed = self.norm(x)

        queries = self.q(normed).view(batch, chans, positions).permute(0, 2, 1)
        keys = self.k(normed).view(batch, chans, positions)
        values = self.v(normed).view(batch, chans, positions).permute(0, 2, 1)

        # (positions x positions) similarity scores, scaled by sqrt(channels).
        scores = torch.bmm(queries, keys) / math.sqrt(chans)
        weights = torch.softmax(scores, dim=-1)
        mixed = torch.bmm(weights, values)
        mixed = mixed.permute(0, 2, 1).view(batch, chans, height, width)

        return x + self.out(mixed)

class UNet(nn.Module):
    """Time-conditioned U-Net built from ResidualBlock and AttentionBlock stages.

    The encoder runs one stage per entry of ``channel_mults`` and halves the
    spatial size after every stage except the last. The decoder mirrors it,
    doubling the spatial size before every stage except its first and
    concatenating the matching encoder output. Attention is inserted in stages
    whose multiplier is 4 or 8, and always in the middle block.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the prediction.
        time_emb_dim: Width of the sinusoidal time embedding and of the time MLP output.
        base_channels: Channel count after the initial convolution.
        channel_mults: Per-stage multipliers applied to ``base_channels``.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=128, base_channels=128, channel_mults=(1, 2, 4, 8)):
        super().__init__()

        self.time_emb_dim = time_emb_dim
        self.base_channels = base_channels
        self.channel_mults = channel_mults
        num_levels = len(channel_mults)

        # Time embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim)
        )

        # Initial convolution
        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder: [resblock, resblock, attention-or-identity, downsample-or-identity] per stage.
        self.downs = nn.ModuleList()
        in_chs = base_channels
        for i, mult in enumerate(channel_mults):
            out_chs = base_channels * mult
            is_last = (i == num_levels - 1)
            self.downs.append(nn.ModuleList([
                ResidualBlock(in_chs, out_chs, time_emb_dim),
                ResidualBlock(out_chs, out_chs, time_emb_dim),
                AttentionBlock(out_chs) if mult in [4, 8] else nn.Identity(),
                nn.Conv2d(out_chs, out_chs, 3, stride=2, padding=1) if not is_last else nn.Identity()
            ]))
            in_chs = out_chs

        # Middle
        mid_channels = base_channels * channel_mults[-1]
        self.middle = nn.ModuleList([
            ResidualBlock(mid_channels, mid_channels, time_emb_dim),
            AttentionBlock(mid_channels),
            ResidualBlock(mid_channels, mid_channels, time_emb_dim)
        ])

        # Decoder: [resblock, resblock, attention-or-identity, upsample-or-identity] per stage.
        # forward() applies the upsample layer *before* the resblocks of its stage.
        self.ups = nn.ModuleList()
        for i in reversed(range(num_levels)):
            mult = channel_mults[i]
            # The skip connection of stage i has the same width as the stage's output.
            out_chs = base_channels * mult

            # Channels arriving from the deeper stage (or from the middle block for the deepest one).
            prev_level_chs = base_channels * channel_mults[min(i + 1, num_levels - 1)]
            input_chs_resblock = prev_level_chs + out_chs

            is_first_decoder_block = (i == num_levels - 1)
            upsample_layer = nn.Identity() if is_first_decoder_block else nn.ConvTranspose2d(prev_level_chs, prev_level_chs, 2, stride=2)

            self.ups.append(nn.ModuleList([
                ResidualBlock(input_chs_resblock, out_chs, time_emb_dim),
                ResidualBlock(out_chs, out_chs, time_emb_dim),
                AttentionBlock(out_chs) if mult in [4, 8] else nn.Identity(),
                upsample_layer
            ]))

        # Output head. The last decoder stage emits base_channels * channel_mults[0]
        # channels, which equals base_channels only when the first multiplier is 1.
        final_chs = base_channels * channel_mults[0]
        self.out = nn.Sequential(
            nn.GroupNorm(8, final_chs),
            nn.SiLU(),
            nn.Conv2d(final_chs, out_channels, 3, padding=1)
        )

    def forward(self, x, t):
        """Run the network on image batch ``x`` at timesteps ``t``.

        Args:
            x: Input batch with ``in_channels`` channels. The spatial size has to
                survive every stride-2 downsample/upsample pair unchanged for the
                skip concatenations to line up (e.g. 32x32 with four stages).
            t: Timestep tensor, either 0-d (promoted to a batch of one) or 1-D.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of ``x``.
        """
        # Time embedding
        if t.dim() == 0:
            t = t.unsqueeze(0)
        t_emb = get_sinusoidal_embeddings(t.float(), self.time_emb_dim).to(x.device)
        t_emb = self.time_mlp(t_emb)

        # Initial conv
        x = self.init_conv(x)

        # Encoder: remember each stage's output (before downsampling) for the decoder.
        skips = []
        for resblock1, resblock2, attnblock, downsample in self.downs:
            x = resblock1(x, t_emb)
            x = resblock2(x, t_emb)
            x = attnblock(x)
            skips.append(x)
            x = downsample(x)

        # Middle: only the residual blocks take the time embedding.
        for block in self.middle:
            if isinstance(block, ResidualBlock):
                x = block(x, t_emb)
            else:
                x = block(x)

        # Decoder: consume skips deepest-first.
        for resblock1, resblock2, attnblock, upsample in self.ups:
            skip = skips.pop()
            x = upsample(x)
            x = torch.cat([x, skip], dim=1)  # Concatenate skip connection
            x = resblock1(x, t_emb)
            x = resblock2(x, t_emb)
            x = attnblock(x)

        # Output
        return self.out(x)

In [None]:
# Preprocessing applied to every training image. Order matters: Normalize operates on
# the tensor produced by ToTensor.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # force image_size x image_size
    transforms.RandomHorizontalFlip(),  # augmentation
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel; generate_samples() assumes a [-1, 1] range when mapping back to [0, 1].
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

In [None]:
@torch.no_grad()
def sample_timestep(x, t, model):
    """Run one reverse-diffusion step on the batch ``x``.

    Args:
        x: Current noisy batch; schedule coefficients are broadcast against it
            with shape (-1, 1, 1, 1).
        t: Either a Python int (applied to every sample) or an index tensor
            into the precomputed schedule tables.
        model: Callable ``model(x, t)`` returning the predicted noise for ``x``.

    Returns:
        The batch after one denoising step. Fresh Gaussian noise scaled by the
        posterior standard deviation is added to every sample whose timestep is
        non-zero; samples at timestep 0 receive the model mean unchanged.
    """
    if isinstance(t, int):
        t = torch.tensor([t] * x.shape[0], device=x.device)

    def _at_t(table):
        # One schedule value per sample, broadcastable over C, H, W.
        return table.to(x.device)[t].view(-1, 1, 1, 1)

    betas_t = _at_t(betas)
    sqrt_one_minus_alphas_cumprod_t = _at_t(sqrt_one_minus_alphas_cumprod)
    sqrt_recip_alphas_t = _at_t(sqrt_recip_alphas)
    posterior_variance_t = _at_t(posterior_variance)

    predicted_noise = model(x, t)

    model_mean = sqrt_recip_alphas_t * (x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)

    # Decide per sample whether to add noise. Checking only t[0] would apply the
    # first sample's decision to the whole batch and would fail on a 0-d tensor.
    nonzero_mask = (t != 0).to(x.dtype).view(-1, 1, 1, 1)
    noise = torch.randn_like(x)
    return model_mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise

In [None]:
@torch.no_grad()
def generate_samples(model, num_samples, device):
    """Generate images by running the full reverse-diffusion chain from pure noise.

    Args:
        model: Noise-prediction model passed through to ``sample_timestep``.
        num_samples: Number of images to generate.
        device: Device on which the initial noise is created.

    Returns:
        Tensor of shape ``(num_samples, image_channels, image_size, image_size)``
        with values in [0, 1].

    Note:
        The model is switched to eval mode and is not switched back.
    """
    model.eval()

    # Start with pure noise
    x = torch.randn(num_samples, image_channels, image_size, image_size, device=device)

    # Denoise step by step, from the last timestep down to 0. reversed() yields an
    # iterator without len(), so the total is passed explicitly for the progress bar.
    for step in tqdm(reversed(range(timesteps)), desc='Sampling', total=timesteps):
        x = sample_timestep(x, step, model)

    # Convert from [-1, 1] to [0, 1]
    return (x.clamp(-1, 1) + 1) / 2

In [None]:
from torchvision import datasets, transforms
# CIFAR-10 training split; every image goes through the `transform` pipeline defined above.
train_dataset = datasets.CIFAR10(
    root='./data',        # Directory to store the data
    train=True,           # Use the training set
    download=True,        # Download if not found
    transform=transform
)

# Shuffled mini-batches of `batch_size` images, loaded by 4 worker processes into pinned memory.
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)




In [None]:
# Build the network with matching input and output channel counts.
model = UNet(in_channels=image_channels, out_channels=image_channels)

# Wrap in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    print(f'lets use {torch.cuda.device_count()} gpus')
    model = nn.DataParallel(model)

model.to(device)

# NOTE(review): `model` now refers to the compiled wrapper, so state_dict() keys saved later
# presumably carry a wrapper prefix -- confirm before loading a checkpoint into a bare UNet.
model = torch.compile(model)

# Optimizer and scheduler are created last, from the wrapped model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
# Cosine decay over num_epochs; the training loop steps this once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


In [None]:
# Mixed precision is enabled only when training on CUDA; on any other device both
# autocast and the gradient scaler are constructed disabled and act as pass-throughs.
# The device-generic torch.autocast / torch.amp.GradScaler entry points replace the
# deprecated torch.cuda.amp ones.
use_amp = device.type == 'cuda'
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)


for epoch in range(num_epochs):
    model.train()  # generate_samples() leaves the model in eval mode, so reset every epoch
    total_loss = 0
    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for images, _ in pbar:
        images = images.to(device)
        # One uniformly drawn timestep per image.
        t = torch.randint(0, timesteps, (images.shape[0],), device=device)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            # Forward pass runs in mixed precision
            noisy_images, noise = forward_diffusion_sample(images, t, device)
            predicted_noise = model(noisy_images, t)
            loss = F.mse_loss(predicted_noise, noise)

        # Backward pass
        optimizer.zero_grad()
        # Scale the loss so that backward() produces scaled gradients.
        scaler.scale(loss).backward()
        # Unscale in place so that clipping sees the true gradient magnitudes.
        scaler.unscale_(optimizer)

        # Clip the unscaled gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Apply the optimizer step through the scaler.
        scaler.step(optimizer)
        # Updates the scale for next iteration.
        scaler.update()

        # Read the loss back once per step; each .item() call synchronises with the device.
        loss_value = loss.item()
        total_loss += loss_value
        pbar.set_postfix({'loss': f'{loss_value:.4f}'})

    scheduler.step()
    avg_loss = total_loss / len(dataloader)

    # Generate samples every 25 epochs
    if (epoch + 1) % 25 == 0:
        print(f"\nEpoch {epoch+1}, Average Loss: {avg_loss:.4f}")
        samples = generate_samples(model, 8, device)

        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        for i, ax in enumerate(axes.flat):
            # imshow expects channels last.
            ax.imshow(samples[i].cpu().permute(1, 2, 0))
            ax.axis('off')

        plt.suptitle(f'Generated Samples - Epoch {epoch+1}')
        plt.tight_layout()
        plt.show()

        # Save checkpoint. Scheduler and scaler state are included so that training
        # can be resumed with the same learning rate and loss scale.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'loss': avg_loss,
        }, f'landscape_diffusion_epoch_{epoch+1}.pth')

print("Training complete!")

  scaler = GradScaler()
  with autocast():
W0723 10:19:49.215000 584 torch/_inductor/utils.py:1137] [1/0] Not enough SMs to use max_autotune_gemm mode
  with autocast():
Epoch 1/500:  99%|█████████▉| 195/196 [04:29<00:00,  1.26it/s, loss=0.1238]