# The Diffusion Procedure

In [1]:
from google.colab import drive

drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [2]:
!ls drive

MyDrive


In [3]:
# Defining the diffusion procedure

from typing import Tuple, Optional
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

import torch.utils.data


def gather(consts: torch.Tensor, t: torch.Tensor):
    """Gather consts for $t$ and reshape to feature map shape"""
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)


class DenoiseDiffusion:
    def __init__(self, latent_model: nn.Module, eps_model: nn.Module, n_steps: int, device: torch.device):
        # super().__init__()
        self.device = device
        self.latent_model = latent_model
        self.eps_model = eps_model
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        ''' Note : x0 is in the latent space'''
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        var = 1 - gather(self.alpha_bar, t)
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        ''' Note : x0 is in the latent space'''
        if eps is None:
            eps = torch.randn_like(x0)

        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, map_latent=False):
        if map_latent:
            pass # removed the latent Encoder
        eps_theta = self.eps_model(xt, t)
        alpha_bar = gather(self.alpha_bar, t)
        alpha = gather(self.alpha, t)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        var = gather(self.sigma2, t)
        eps = torch.randn(xt.shape, device=xt.device)
        return mean + (var ** .5) * eps
    
    def reconstruct(self, xt: torch.Tensor, eps_theta: torch.Tensor, t: torch.Tensor):
        out = self.p_x0(xt, t, eps_theta)
        return out

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None, reconstruct: bool = False):
        recon = None
        batch_size = x0.shape[0]
        x0 = x0.to(self.device)
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps=noise)
        eps_theta = self.eps_model(xt, t)
        if reconstruct:
            recon = self.reconstruct(xt, eps_theta, t)
        self.loss_ = F.mse_loss(noise, eps_theta)
        return self.loss_, recon
    
    def serialize_eps_model(self, iter_number, optimizer, loss):
        path = '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_'+str(iter_number)
        torch.save({
            'iteration_number': iter_number,
            'model_state_dict': self.eps_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, path)
    
    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):

        alpha_bar = gather(self.alpha_bar, t)
        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)


# Datasets

In [4]:
# Datasets for training

import torch
import torch.utils.data
import torchvision
from glob import glob
import os
import cv2
import numpy as np

def by_sample_number(path):
    num = int(path.split(os.path.sep)[-1].split('.')[0])
    return num


class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset
    """

    def __init__(self, image_size, data_path):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])
        self.data_path = data_path
        super().__init__(self.data_path, train=True, download=True, transform=transform)

    def __getitem__(self, item):
        return super().__getitem__(item)[0]

class CarsDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path):
        self.imsize = 224
        self.dataset_path = dataset_path
        self._load_dataset_paths()
    
    def _load_dataset_paths(self):
        self.all_samples = sorted(glob(self.dataset_path + os.path.sep + '*.jpg'), key=by_sample_number)
    
    def __getitem__(self, index):
        img = cv2.imread(self.all_samples[index])
        img = cv2.resize(img, (self.imsize, self.imsize))
        img = img.transpose(2, 0, 1)
        assert img.shape == (3, self.imsize, self.imsize)
        assert np.max(img) <= 255
        img = torch.FloatTensor(img / 255.)
        return img
    
    def __len__(self):
        return len(self.all_samples)


# UNet Model Implementation

In [5]:
# Defining the UNet Model


"""
UNET implementation containing a bunch of modifications (residual blocks, multi-head attention, time-step embedding t)
"""

import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn
from torch.nn import Module


class Swish(Module):

    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):

    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        self.act = Swish()
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)
        return emb


class ResidualBlock(Module):

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
        super().__init__()
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        h = self.conv1(self.act1(self.norm1(x)))
        h += self.time_emb(t)[:, :, None, None]
        h = self.conv2(self.act2(self.norm2(h)))
        return h + self.shortcut(x)


class AttentionBlock(Module):
    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):

        super().__init__()

        if d_k is None:
            d_k = n_channels
        self.norm = nn.GroupNorm(n_groups, n_channels)
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        self.output = nn.Linear(n_heads * d_k, n_channels)
        self.scale = d_k ** -0.5
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        _ = t
        batch_size, n_channels, height, width = x.shape
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        attn = attn.softmax(dim=2)
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        res = self.output(res)
        res += x
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
        return res


class DownBlock(Module):

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class UpBlock(Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class MiddleBlock(Module):
    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x


class Upsample(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        _ = t
        return self.conv(x)


class Downsample(nn.Module):

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        _ = t
        return self.conv(x)


class UNet(Module):
    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
                 n_blocks: int = 2):
        super().__init__()
        n_resolutions = len(ch_mults)

        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        self.time_emb = TimeEmbedding(n_channels * 4)

        down = []
        out_channels = in_channels = n_channels
        for i in range(n_resolutions):
            out_channels = in_channels * ch_mults[i]
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        self.down = nn.ModuleList(down)
        self.middle = MiddleBlock(out_channels, n_channels * 4, )
        up = []
        in_channels = out_channels
        for i in reversed(range(n_resolutions)):
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            if i > 0:
                up.append(Upsample(in_channels))

        self.up = nn.ModuleList(up)
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        t = self.time_emb(t)
        x = self.image_proj(x)
        h = [x]
        for m in self.down:
            x = m(x, t)
            h.append(x)

        x = self.middle(x, t)

        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)
        return self.final(self.act(self.norm(x)))


# Traning and Sampling

In [10]:
# Training and sampling

from typing import List
import torch
import torch.utils.data
try:
  from torchinfo import summary
except:
  !pip3 install torchinfo
  from torchinfo import summary
import numpy as np
import os
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt
from glob import glob
from tqdm import trange
import cv2

warnings.filterwarnings('ignore')
plt.ion()

def by_iter_num(path):
    num = int(path.split(os.path.sep)[-1].split('_')[-1])
    return num


class Configs:
    device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    iter_counter : int = 0

    eps_model: UNet
    diffusion: DenoiseDiffusion
    image_channels: int = 3
    image_size: int = 224
    n_channels: int = 64
    channel_multipliers: List[int] = [1, 2, 4]
    is_attention: List[int] = [False, False, True]

    n_steps: int = 1_000 # number of diffusion steps
    batch_size: int = 4
    n_samples: int = 4
    learning_rate: float = 2e-5
    epochs: int = 1_000
    dataset_path: str = '/content/drive/MyDrive/diffusion_dataset/train/' 
    dataset: torch.utils.data.Dataset = CarsDataset(dataset_path)
    data_loader: torch.utils.data.DataLoader
    optimizer: torch.optim.Adam

    def init(self, debug=True, cold_start=True):
        if cold_start:
            self.eps_model = UNet(
                image_channels=self.image_channels,
                n_channels=self.n_channels,
                ch_mults=self.channel_multipliers,
                is_attn=self.is_attention,
            ).to(self.device)
            self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
        else:
            self.eps_model = UNet(
                image_channels=self.image_channels,
                n_channels=self.n_channels,
                ch_mults=self.channel_multipliers,
                is_attn=self.is_attention,
            )
            print(sorted(glob('/content/drive/MyDrive/diffusion_dataset/model_checkpoints/*'), key=by_iter_num))
            chkpt_path = sorted(glob('/content/drive/MyDrive/diffusion_dataset/model_checkpoints/*'), key=by_iter_num)[-1]
            state_dict = torch.load(chkpt_path)
            self.eps_model.load_state_dict(state_dict['model_state_dict'])
            self.eps_model = self.eps_model.to(self.device)
            self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            self.iter_counter = int(state_dict['iteration_number'])
        if debug:
            summary(self.eps_model, input_data=(torch.from_numpy(np.random.uniform(size=(1, 3, 224, 224))).type(torch.float32).to(self.device),
                                                 torch.randint(0, self.n_steps, (1,), device=self.device, dtype=torch.long)) )
            os.system('nvidia-smi')
        self.diffusion = DenoiseDiffusion(
            latent_model=None,
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        

    def show_output(self, output):
        show_image = np.moveaxis(output.detach().cpu().numpy()[0], 0, -1)
        show_image = np.clip(show_image, 0, 1)
        plt.cla()
        plt.imshow(show_image)
        plt.draw()
        plt.pause(0.001)
        return


    def sample(self):
        """
        ### Sample images
        """
            
        with torch.no_grad():
            x = torch.randn([1, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            for t_ in trange(self.n_steps):
                
                t = self.n_steps - t_ - 1
                x = self.diffusion.p_sample(x, x.new_full((1,), t, dtype=torch.long))

                if t_ % 50 == 0:
                    self._serialize_generations(x.detach().cpu(), prefix=str(t_))
            for _ in range(15):
                t = 0
                x = self.diffusion.p_sample(x, x.new_full((1,), t, dtype=torch.long))

                if t % 50 == 0:
                    self._serialize_generations(x.detach().cpu(), prefix=str(t_))
        




    def train(self, reconstruct=False):
        for _, data in enumerate(tqdm(self.data_loader)):
            self.iter_counter += 1
            data = data.to(self.device)
            self.optimizer.zero_grad()
            loss, recon = self.diffusion.loss(data, reconstruct=reconstruct)
            loss.backward()
            self.optimizer.step()
            if reconstruct:
                self.show_output(recon)
            # print('[INFO] Current Batch Loss : ', loss.item())
            # print('Iter counter : ', self.iter_counter)
            if self.iter_counter % 2_000 == 0: # saves the model every 1500 updates
                print('[INFO] Serializing model...')
                self.diffusion.serialize_eps_model(self.iter_counter, self.optimizer, loss)

    def _serialize_generations(self, generations, prefix=None):
        generations = generations.numpy()

        if not os.path.exists('./gen_images'):
          os.mkdir('./gen_images')

        for k, gen_img in enumerate(generations):
            show_image = np.moveaxis(gen_img, 0, -1)
            show_image = np.clip(show_image, 0, 1)
            show_image = (show_image * 255).astype(np.uint8)
            if not prefix is None:
                cv2.imwrite(f'./gen_images/{prefix}{k}.png', show_image)
            else:
                cv2.imwrite(f'./gen_images/{k}.png', show_image)
        


    def run(self):
        """
        ### Training loop
        """
        for _ in range(self.epochs):
            self.train(reconstruct=False)
    
    def generate_samples(self):
      """
      ### Generate Samples from Noise
      """
      self.sample()


def main():
    configs = Configs()
    configs.init(cold_start=False) # setting cold_start to False loads the most recent checkpoint
    # configs.run() # trains the diffusion model
    configs.generate_samples() # for running inference


if __name__ == '__main__':
    main()


['/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_102000', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_103500', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_103600', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_103800', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_104000', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_104200', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_104400', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_104600', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_104800', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_105000', '/content/drive/MyDrive/diffusion_dataset/model_checkpoints/model_checkpoint_105200', '/content/drive/MyDrive/diffusion_dataset/model_check

100%|██████████| 1000/1000 [01:55<00:00,  8.64it/s]


In [None]:
!ls /content/drive/MyDrive/diffusion_dataset/model_checkpoints

In [None]:
torch.cuda.empty_cache()

In [12]:
from google.colab import files
!zip -r ./gen_images1.zip ./gen_images
!files.download('./gen_images1.zip')

  adding: gen_images/ (stored 0%)
  adding: gen_images/9501.png (deflated 0%)
  adding: gen_images/8000.png (deflated 0%)
  adding: gen_images/2502.png (deflated 0%)
  adding: gen_images/1001.png (deflated 0%)
  adding: gen_images/501.png (deflated 0%)
  adding: gen_images/8003.png (deflated 0%)
  adding: gen_images/4500.png (deflated 0%)
  adding: gen_images/7000.png (deflated 0%)
  adding: gen_images/9990.png (deflated 0%)
  adding: gen_images/3501.png (deflated 0%)
  adding: gen_images/8501.png (deflated 0%)
  adding: gen_images/4503.png (deflated 0%)
  adding: gen_images/4001.png (deflated 0%)
  adding: gen_images/3001.png (deflated 0%)
  adding: gen_images/5002.png (deflated 0%)
  adding: gen_images/3503.png (deflated 0%)
  adding: gen_images/3003.png (deflated 0%)
  adding: gen_images/4501.png (deflated 0%)
  adding: gen_images/6002.png (deflated 0%)
  adding: gen_images/8001.png (deflated 0%)
  adding: gen_images/5001.png (deflated 0%)
  adding: gen_images/7001.png (deflated 0%)