<a href="https://colab.research.google.com/github/CakeNuthep/Diffusion-Models-pytorch/blob/main/ddpm_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/CakeNuthep/Diffusion-Models-pytorch.git

fatal: destination path 'Diffusion-Models-pytorch' already exists and is not an empty directory.


In [2]:
!unzip "/content/Diffusion-Models-pytorch/data/Linnaeus 5 64X64.zip" -d "/content/Diffusion-Models-pytorch/"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/819_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/82_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/820_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/821_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/822_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/823_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/824_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/825_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/826_64.jpg  
  inflating: /content/Diffusion-Models-pytorch/Linnaeus 5 64X64/train/berry/827_64.jpg  
  inflating: /content/Diffusion-Models-pytorch

In [3]:
%cd "/content/Diffusion-Models-pytorch"

/content/Diffusion-Models-pytorch


In [4]:
class argument:
  def __init__(self,run_name=None,epochs=None,batch_size=None,image_size=None,dataset_path=None,device=None,lr=None):
    super().__init__()
    self.run_name = run_name
    self.epochs = epochs
    self.batch_size = batch_size
    self.image_size = image_size
    self.dataset_path = dataset_path
    self.device = device
    self.lr = lr

In [5]:
import os
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from utils import *
from modules import UNet
import logging
from torch.utils.tensorboard import SummaryWriter

In [6]:
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")


class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

In [7]:
def train(args):
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    model = UNet(device=device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, _) in enumerate(pbar):
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

        sampled_images = diffusion.sample(model, n=images.shape[0])
        save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
        torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))

In [None]:
args = argument()
args.run_name = "DDPM_Uncondtional"
args.epochs = 1        #500
args.batch_size = 5    #12
args.image_size = 64
args.dataset_path = "/content/Diffusion-Models-pytorch/Linnaeus 5 64X64"
args.device = "cpu"
args.lr = 3e-4
train(args)

  1%|          | 18/1600 [01:41<2:27:45,  5.60s/it, MSE=0.529]