<a href="https://colab.research.google.com/github/dulhara79/ddpm/blob/main/Diffusion_Models_%7C_PyTorch_Implementation_YT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch import optim
from tqdm import tqdm
import logging
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torchvision
from PIL import Image
import torch.nn.functional as F

# Notebook-wide logging: INFO and above, prefixed with a clock time and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s: %(message)s",
    datefmt="%I:%M:%S",
)

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only) so dataset paths under
# /content/drive/MyDrive, as used in launch(), are readable.
drive.mount('/content/drive')

Mounted at /content/drive


## DDPM

In [None]:
class Diffusion:
    """DDPM forward (noising) and reverse (sampling) process with a linear beta schedule.

    Args:
        noise_steps: Number of diffusion steps; the schedule tensors have this length.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        img_size: Height and width of the square images produced by ``sample``.
        device: Device holding the schedule tensors and the sampled images.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        # alpha_hat[t] = alpha[0] * ... * alpha[t]; lets noise_images jump straight to step t.
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return ``noise_steps`` betas spaced linearly from beta_start to beta_end (on CPU)."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise clean images ``x`` to timesteps ``t`` in closed form.

        Args:
            x: Clean images. Assumed (batch, C, H, W), since the per-sample
                coefficients are broadcast with ``[:, None, None, None]``.
            t: Integer timesteps, one per sample, used to index the schedule.

        Returns:
            ``(x_t, epsilon)``: the noisy images and the Gaussian noise that was added.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        epsilon = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon

    def sample_timesteps(self, n):
        """Draw ``n`` timesteps uniformly from [1, noise_steps - 1].

        Index 0 is never drawn; this matches ``sample``, whose reverse loop also stops at 1.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        """Generate ``n`` images by running the reverse process from pure Gaussian noise.

        Args:
            model: Noise predictor called as ``model(x_t, t)``.
            n: Number of images to generate.

        Returns:
            uint8 tensor of shape (n, 3, img_size, img_size) with values in [0, 255].

        The model's train/eval mode is restored afterwards, even if sampling raises.
        """
        logging.info(f"Sampling {n} new images....")
        print(f"Sampling {n} new images....")
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
                # reversed(range(...)) has no len(), so give tqdm the total explicitly.
                steps = reversed(range(1, self.noise_steps))
                for i in tqdm(steps, total=self.noise_steps - 1, position=0):
                    t = (torch.ones(n) * i).long().to(self.device)
                    predicted_noise = model(x, t)
                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]
                    # No fresh noise is injected on the last reverse step.
                    if i > 1:
                        noise = torch.randn_like(x)
                    else:
                        noise = torch.zeros_like(x)
                    # x_{t-1} = 1/sqrt(a_t) * (x_t - (1 - a_t)/sqrt(1 - a_hat_t) * eps) + sqrt(b_t) * z
                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        finally:
            # Restore whatever mode the caller had, instead of always forcing train mode.
            model.train(was_training)
        # Map from [-1, 1] to [0, 255] uint8 pixels.
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

In [None]:
# def train(args):
#   setup_logging(args.run_name)
#   device=args.device
#   dataloader=get_data(args)
#   model=UNet().to(device)
#   optimizer=optim.AdamW(model.parameters(), lr=args.lr)
#   mse=nn.MSELoss()
#   diffusion=Diffusion(img_size=args.image_size, device=device)
#   logger=SummaryWriter(os.path.join("runs", args.run_name))
#   l = len(dataloader)

#   for epoch in range(args.epochs):
#     logging.info(f"Starting epoch {epoch}:")
#     pbar=tqdm(dataloader)
#     for i, (images, _) in enumerate(pbar):
#       images=images.to(device)
#       t=diffusion.sample_timesteps(images.shape[0]).to(device)
#       x_t, noise=diffusion.noise_images(images, t)
#       predicted_noice=model(x_t, t)
#       loss=mse(noise, predicted_noice)

#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()

#       pbar.set_postfix(MSE=loss.item())
#       logger.add_scalar("MSE", loss.item(), global_step=epoch*l+i)

#     sampled_images=diffusion.sample(model, n=images.shape[0])
#     save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
#     torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))

In [None]:
def train(args):
    """Train the unconditional DDPM UNet, resuming from ``models/<run_name>/ckpt.pt`` if present.

    Reads ``args.run_name``, ``args.device``, ``args.lr``, ``args.image_size`` and
    ``args.epochs``, plus whatever ``get_data`` needs. In each epoch it:
    - logs the per-batch MSE to TensorBoard under ``runs/<run_name>``;
    - writes a grid of sampled images to ``results/<run_name>/<epoch>.jpg``;
    - saves a checkpoint dict (epoch, model/optimizer state, last loss).

    Raises:
        RuntimeError: if an existing checkpoint file cannot be loaded.
        ValueError: if the dataloader yields no batches.
    """
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    # UNet's own ``device`` argument defaults to "cuda"; pass the real device so
    # CPU-only runs work.
    model = UNet(device=device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    model_path = os.path.join("models", args.run_name, "ckpt.pt")
    start_epoch = 0

    if os.path.exists(model_path):
        logging.info("Loading checkpoint...")
        print("Loading checkpoint...")
        start_epoch = _load_checkpoint(model_path, model, optimizer, device)
        logging.info(f"Resuming training from epoch {start_epoch}")
        print(f"Resuming training from epoch {start_epoch}")

    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    num_batches = len(dataloader)
    if num_batches == 0:
        # The end-of-epoch sampling and checkpoint use the last batch and loss.
        raise ValueError("The dataloader produced no batches; check the dataset.")

    logger = SummaryWriter(os.path.join("runs", args.run_name))
    try:
        for epoch in range(start_epoch, args.epochs):
            logging.info(f"Starting epoch {epoch}:")
            print(f"Starting epoch {epoch}:")
            pbar = tqdm(dataloader)
            for i, (images, _) in enumerate(pbar):
                images = images.to(device)
                t = diffusion.sample_timesteps(images.shape[0]).to(device)
                x_t, noise = diffusion.noise_images(images, t)
                predicted_noise = model(x_t, t)
                loss = mse(noise, predicted_noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.set_postfix(MSE=loss.item())
                logger.add_scalar("MSE", loss.item(), global_step=epoch * num_batches + i)

            # After the epoch is done, generate sample images.
            sampled_images = diffusion.sample(model, n=images.shape[0])
            save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))

            # One checkpoint per epoch.
            _save_checkpoint(model_path, {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
            })
            logging.info(f"Epoch {epoch} finished and checkpoint saved.")
            print(f"Epoch {epoch} finished and checkpoint saved.")
    finally:
        logger.close()


def _load_checkpoint(path, model, optimizer, device):
    """Restore model (and, if present, optimizer) state from ``path``; return the epoch to resume at."""
    try:
        # train() only saves tensors, ints, floats and optimizer hyper-parameters,
        # so weights_only=True suffices and avoids running arbitrary pickle code.
        # map_location allows resuming a GPU checkpoint on CPU and vice versa.
        checkpoint = torch.load(path, map_location=device, weights_only=True)
    except Exception as err:
        raise RuntimeError(
            f"Could not load checkpoint {path!r}. It may be corrupted (e.g. by an "
            "interrupted save); delete it to start training from scratch."
        ) from err
    if "model_state_dict" not in checkpoint:
        # Earlier versions of this notebook saved a bare model.state_dict() here,
        # without optimizer state or epoch number.
        logging.warning("Checkpoint holds only model weights; restarting epoch count at 0.")
        model.load_state_dict(checkpoint)
        return 0
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1


def _save_checkpoint(path, state):
    """Write ``state`` to ``path`` atomically so an interruption never leaves a truncated file."""
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)

In [None]:
from ast import arg
def launch():
    """Configure the unconditional DDPM training run and start it.

    If the dataset directory does not exist, an error is logged and the
    function returns without training.
    """
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args([])
    args.run_name = "DDPM_Unconditional"
    args.epochs = 500
    args.batch_size = 8
    args.image_size = 64
    # Dataset locations on other Google Drive accounts:
    # primary
    # args.dataset_path=r"/content/drive/MyDrive/Colab Notebooks/MLOM/MLOM_Assignment-1/landscape_pictures"
    # nwatch
    # args.dataset_path=r"/content/drive/MyDrive/Colab Notebooks/landscape_pictures"
    # args.dataset_path=r"/content/drive/MyDrive/MLOM/landscape/MLOM_Assignment-1"
    # kaushalyadulhara
    args.dataset_path = r"/content/drive/MyDrive/MLOM-Assignment/landscape_pictures"
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    args.lr = 3e-4

    if not os.path.isdir(args.dataset_path):
        # Stop here: training would otherwise fail inside the dataset loader.
        logging.error(f"Dataset path not found: {args.dataset_path}. Please make sure it exists and contains one subfolder per class with images inside.")
        return

    train(args)

## Modules

In [None]:
class EMA:
    """Keep an exponential moving average (EMA) copy of a model's parameters.

    For the first ``step_start_ema`` calls to ``step_ema`` the EMA model is
    simply overwritten with the live model's state. After that, each parameter
    becomes ``beta * ema + (1 - beta) * live``.

    Args:
        beta: Decay factor; larger values make the average change more slowly.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0  # number of step_ema calls made so far

    def update_model_avarage(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model`` (paired by order)."""
        paired = zip(current_model.parameters(), ma_model.parameters())
        for live, averaged in paired:
            averaged.data = self.update_avarage(averaged.data, live.data)

    def update_avarage(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when there is no ``old`` value."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the EMA by one step: copy during warm-up, average afterwards."""
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_avarage(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Overwrite ``ema_model``'s state with ``model``'s state."""
        ema_model.load_state_dict(model.state_dict())

In [None]:
class SelfAttention(nn.Module):
  """Multi-head self-attention over the spatial positions of a feature map.

  Each spatial position becomes one token with ``channels`` features. The block
  applies pre-LayerNorm 4-head attention and a feed-forward layer, each with a
  residual connection, and returns a tensor of the input's shape.

  Args:
    channels: Number of feature channels; must be divisible by the 4 heads.
    size: Nominal spatial side length. Kept for backward compatibility only:
      height and width are read from the input, so feature maps of any
      resolution (including non-square ones) are handled correctly instead of
      being silently folded into the batch dimension.
  """
  def __init__(self, channels, size):
    super(SelfAttention, self).__init__()
    self.channels = channels
    self.size = size
    self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
    self.ln = nn.LayerNorm([channels])
    self.ff_self = nn.Sequential(
        nn.LayerNorm([channels]),
        nn.Linear(channels, channels),
        nn.GELU(),
        nn.Linear(channels, channels),
    )

  def forward(self, x):
    batch, channels, height, width = x.shape
    # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
    tokens = x.reshape(batch, channels, height * width).transpose(1, 2)
    x_ln = self.ln(tokens)
    attention_value, _ = self.mha(x_ln, x_ln, x_ln)
    attention_value = attention_value + tokens
    attention_value = self.ff_self(attention_value) + attention_value
    # Back to (B, C, H, W).
    return attention_value.transpose(1, 2).reshape(batch, channels, height, width)

In [None]:
class DoubleConv(nn.Module):
    """Two 3x3 convolution + GroupNorm stages with a GELU between them.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions. Any falsy value
            (None, 0) falls back to ``out_channels``.
        residual: If True, return ``gelu(x + double_conv(x))``; the addition
            requires ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels if mid_channels else out_channels
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        if not self.residual:
            return out
        return F.gelu(x + out)

In [None]:
class Down(nn.Module):
    """Downsampling stage: 2x max-pool, two DoubleConvs, then add a projected time embedding.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        emb_dim: Width of the time embedding ``t`` passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        features = self.maxpool_conv(x)
        rows, cols = features.shape[-2], features.shape[-1]
        # One embedding value per output channel, tiled over every pixel.
        bias = self.emb_layer(t)[:, :, None, None].repeat(1, 1, rows, cols)
        return features + bias

In [None]:
class Up(nn.Module):
    """Upsampling stage: 2x bilinear upsample, concatenate the skip tensor,
    apply two DoubleConvs, then add a projected time embedding.

    Args:
        in_channels: Channels after concatenating the skip tensor with the
            upsampled tensor (the input width of the first DoubleConv).
        out_channels: Channels of the output feature map.
        emb_dim: Width of the time embedding ``t`` passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        # Skip tensor first, then the upsampled tensor, along the channel axis.
        merged = torch.cat([skip_x, self.up(x)], dim=1)
        features = self.conv(merged)
        rows, cols = features.shape[-2], features.shape[-1]
        bias = self.emb_layer(t)[:, :, None, None].repeat(1, 1, rows, cols)
        return features + bias

In [None]:
class UNet(nn.Module):
  """Time-conditioned UNet that predicts the noise added to an image.

  The encoder has three Down stages, each followed by self-attention. A
  three-layer DoubleConv bottleneck follows, then three Up stages (each
  followed by self-attention) that consume the matching skip tensors.

  Args:
    c_in: Input image channels.
    c_out: Output channels.
    time_dim: Width of the sinusoidal timestep embedding; also forwarded to
      every Down/Up stage as its ``emb_dim``.
    device: Kept for backward compatibility. The timestep embedding is now
      built on the device of ``t``, so the "cuda" default no longer breaks
      CPU runs.
    img_size: Side length of the square input images (default 64). It sets
      the nominal sizes of the attention layers; three 2x max-pools imply it
      should be divisible by 8.
  """
  def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda", img_size=64):
    super().__init__()
    self.device = device
    self.time_dim = time_dim

    # Spatial side length after 1, 2 and 3 downsampling stages.
    half, quarter, eighth = img_size // 2, img_size // 4, img_size // 8

    self.inc = DoubleConv(c_in, 64)

    self.down1 = Down(64, 128, emb_dim=time_dim)
    self.sa1 = SelfAttention(128, half)
    self.down2 = Down(128, 256, emb_dim=time_dim)
    self.sa2 = SelfAttention(256, quarter)
    self.down3 = Down(256, 256, emb_dim=time_dim)
    self.sa3 = SelfAttention(256, eighth)

    self.bot1 = DoubleConv(256, 512)
    self.bot2 = DoubleConv(512, 512)
    self.bot3 = DoubleConv(512, 256)

    # Up in_channels = upsampled channels + skip channels (256+256, 128+128, 64+64).
    self.up1 = Up(512, 128, emb_dim=time_dim)
    self.sa4 = SelfAttention(128, quarter)

    self.up2 = Up(256, 64, emb_dim=time_dim)
    self.sa5 = SelfAttention(64, half)

    self.up3 = Up(128, 64, emb_dim=time_dim)
    self.sa6 = SelfAttention(64, img_size)

    self.outc = nn.Conv2d(64, c_out, kernel_size=1)

  def pos_encoding(self, t, channels):
    """Sinusoidal timestep embedding.

    Args:
      t: Float tensor of timesteps; ``forward`` passes shape (batch, 1).
      channels: Embedding width (expected to be even).

    Returns:
      (batch, channels) tensor: the sin terms followed by the cos terms.
    """
    # Create the frequencies on t's device rather than self.device, whose
    # default "cuda" fails on CPU-only hosts.
    inv_freq = 1.0 / (
        10000
        ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
    )
    angles = t.repeat(1, channels // 2) * inv_freq
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

  def forward(self, x, t):
    """Predict noise for images ``x`` at integer timesteps ``t`` (one per sample)."""
    t = t.unsqueeze(-1).type(torch.float)
    t = self.pos_encoding(t, self.time_dim)

    x1 = self.inc(x)
    x2 = self.down1(x1, t)
    x2 = self.sa1(x2)
    x3 = self.down2(x2, t)
    x3 = self.sa2(x3)
    x4 = self.down3(x3, t)
    x4 = self.sa3(x4)

    x4 = self.bot1(x4)
    x4 = self.bot2(x4)
    x4 = self.bot3(x4)

    x = self.up1(x4, x3, t)
    x = self.sa4(x)
    x = self.up2(x, x2, t)
    x = self.sa5(x)
    x = self.up3(x, x1, t)
    x = self.sa6(x)
    output = self.outc(x)
    return output

## Utils

In [None]:
def plot_images(images):
    """Display a batch of images side by side in one large matplotlib figure.

    Each item of ``images`` is treated as channels-first (C, H, W), given the
    ``permute(1, 2, 0)`` before display. The items are concatenated along the
    width axis.
    """
    strip = torch.cat(list(images.cpu()), dim=-1)
    plt.figure(figsize=(32, 32))
    plt.imshow(strip.permute(1, 2, 0).cpu())
    plt.show()

In [None]:
def save_images(images, path, **kwargs):
    """Tile ``images`` into a single grid and write it to ``path`` with PIL.

    Extra keyword arguments are forwarded to ``torchvision.utils.make_grid``.
    The grid is converted to an (H, W, C) numpy array before saving. The
    images presumably need a PIL-compatible dtype such as uint8 (TODO confirm
    for new callers).
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    hwc = grid.permute(1, 2, 0).to("cpu")
    Image.fromarray(hwc.numpy()).save(path)

In [None]:
def get_data(args):
    """Build a shuffled DataLoader over an ImageFolder at ``args.dataset_path``.

    Each image is processed in four steps:
    1. Resize so the shorter side is ``args.image_size * 5/4``.
    2. Take a random crop covering 80-100% of the area, resized to
       ``args.image_size`` x ``args.image_size``.
    3. Convert to a tensor.
    4. Normalise each channel from [0, 1] to [-1, 1].

    Batches have ``args.batch_size`` images.
    """
    # Leave a 25% margin for the random crop (80 px for a 64 px target). The
    # value was previously hard-coded to 80, which only matched image_size=64.
    resize_to = args.image_size + args.image_size // 4
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resize_to),
        torchvision.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return dataloader

In [None]:
def setup_logging(run_name):
    """Create the per-run output folders ``models/<run_name>`` and ``results/<run_name>``.

    Parent folders are created as needed; folders that already exist are left alone.
    """
    for root in ("models", "results"):
        os.makedirs(os.path.join(root, run_name), exist_ok=True)

In [None]:
if __name__ == '__main__':
    # Start the training run. In a notebook __name__ is also '__main__', so
    # running this cell launches training.
    try:
        launch()
    except TypeError as e:
        # Keep the session alive, but log the full traceback so the failing
        # line is visible; a TypeError can come from anywhere in training.
        logging.exception(f"An error occurred during launch: {e}")
        print("This may be due to the UNet's forward pass not exactly matching the expected channel sizes for the Up blocks after concatenation. Please double-check the channel math.")

Loading checkpoint...


UnpicklingError: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 198

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.