In [None]:
# DDPM: Using https://github.com/labmlai/annotated_deep_learning_paper_implementations
# DDPM: Specifically https://nn.labml.ai/diffusion/ddpm/index.html
# DDIM: Using https://github.com/ermongroup/ddim

from typing import Tuple, Optional, Union, List

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F

import torchvision.utils as tvu
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.datasets import CIFAR10, STL10

import sys
import imageio
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

In [None]:
# Colab-only: mount Google Drive at /content/drive. The checkpoint path used by
# `Configs` ('drive/My Drive/Colab Notebooks/<dataset>.chkpt') is relative, so it
# presumably relies on the working directory being /content -- TODO confirm.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### U-Net

In [None]:
class Swish(nn.Module):
    """Swish activation: multiplies the input element-wise by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of the timestep followed by a two-layer MLP.

    `n_channels` is the size of the returned embedding. The sinusoidal features
    fed to the MLP have `n_channels // 4` entries (a sin half and a cos half of
    `n_channels // 8` frequencies each).
    """

    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
        self.lin1 = nn.Linear(n_channels // 4, n_channels)
        self.act = Swish()
        self.lin2 = nn.Linear(n_channels, n_channels)

    def forward(self, t: torch.Tensor):
        """Embed a 1-D tensor of timesteps `t` into shape `(len(t), n_channels)`."""
        n_freqs = self.n_channels // 8
        # Geometric frequency ladder from 1 down to 1/10000.
        log_step = np.log(10_000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=t.device) * -log_step)
        angles = t[:, None] * freqs[None, :]
        features = torch.cat((angles.sin(), angles.cos()), dim=1)
        hidden = self.act(self.lin1(features))
        return self.lin2(hidden)


class ResidualBlock(nn.Module):
    """Two (GroupNorm -> Swish -> 3x3 conv) stages with a time-embedding bias and a skip connection.

    The skip path is a 1x1 convolution when the channel count changes and an
    identity otherwise.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
        super().__init__()
        # First stage: in_channels -> out_channels.
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        # Second stage: out_channels -> out_channels.
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
            if in_channels != out_channels
            else nn.Identity()
        )
        # Projects the time embedding to one bias value per output channel.
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Apply the block to feature map `x` conditioned on time embedding `t`."""
        hidden = self.norm1(x)
        hidden = self.act1(hidden)
        hidden = self.conv1(hidden)
        # Broadcast the per-channel time bias over the two spatial dimensions.
        hidden = hidden + self.time_emb(t)[:, :, None, None]
        hidden = self.norm2(hidden)
        hidden = self.act2(hidden)
        hidden = self.conv2(hidden)
        return hidden + self.shortcut(x)


class AttentionBlock(nn.Module):
    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        super().__init__()
        if d_k is None: d_k = n_channels
        self.norm = nn.GroupNorm(n_groups, n_channels)
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        self.output = nn.Linear(n_heads * d_k, n_channels)
        self.scale = d_k ** -0.5
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        _ = t
        batch_size, n_channels, height, width = x.shape
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        attn = attn.softmax(dim=1)
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        res = self.output(res)
        res += x
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
        return res


class DownBlock(nn.Module):
    """Residual block of the contracting path, optionally followed by self-attention."""

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        # Identity keeps `forward` branch-free when attention is disabled.
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Run the residual block (conditioned on `t`) and then the attention stage."""
        features = self.res(x, t)
        return self.attn(features)


class UpBlock(nn.Module):
    """Residual block of the expanding path, optionally followed by self-attention.

    The residual block takes `in_channels + out_channels` input channels because
    the caller concatenates a skip connection onto the input before this block.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        # Identity keeps `forward` branch-free when attention is disabled.
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Run the residual block (conditioned on `t`) and then the attention stage."""
        features = self.res(x, t)
        return self.attn(features)


class MiddleBlock(nn.Module):
    """Bottleneck of the U-Net: residual block -> self-attention -> residual block."""

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Apply the three stages in order; only the residual blocks use `t`."""
        hidden = self.res1(x, t)
        hidden = self.attn(hidden)
        return self.res2(hidden, t)


class Upsample(nn.Module):
    """Doubles the spatial resolution with a 4x4 stride-2 transposed convolution."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            n_channels, n_channels, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Upsample `x`; `t` is accepted for a uniform block interface and ignored."""
        _ = t
        upsampled = self.conv(x)
        return upsampled


class Downsample(nn.Module):
    """Halves the spatial resolution with a 3x3 stride-2 convolution."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            n_channels, n_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Downsample `x`; `t` is accepted for a uniform block interface and ignored."""
        _ = t
        downsampled = self.conv(x)
        return downsampled


class UNet(nn.Module):
    """U-Net that maps a noisy image and its timestep to a tensor with `image_channels` channels.

    Args:
        image_channels: number of channels of the input (and output) image.
        n_channels: channels produced by the initial image projection; the time
            embedding has `n_channels * 4` channels.
        ch_mults: per-resolution multiplier applied to the running channel count.
        is_attn: per-resolution flag enabling attention in the blocks of that level.
        n_blocks: number of `DownBlock`s per resolution (each resolution of the
            up path gets `n_blocks + 1` `UpBlock`s).

    The order in which modules are appended to `self.down` / `self.up` determines
    their indices in the `ModuleList`s, and therefore the parameter names in the
    state dict.
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        super().__init__()
        n_resolutions = len(ch_mults)
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
        self.time_emb = TimeEmbedding(n_channels * 4)

        # Contracting path: n_blocks DownBlocks per resolution, with a Downsample
        # between resolutions (none after the last one).
        down = []
        out_channels = in_channels = n_channels
        for i in range(n_resolutions):
            # Multipliers are cumulative: each level scales the previous level's width.
            out_channels = in_channels * ch_mults[i]
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))
        self.down = nn.ModuleList(down)
        self.middle = MiddleBlock(out_channels, n_channels * 4, )

        # Expanding path, mirroring the contracting path in reverse resolution order.
        up = []
        in_channels = out_channels
        for i in reversed(range(n_resolutions)):
            # n_blocks UpBlocks that keep the channel count ...
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # ... plus one UpBlock that undoes this level's channel multiplier.
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            if i > 0:
                up.append(Upsample(in_channels))
        self.up = nn.ModuleList(up)
        # After the loop, in_channels is back to n_channels (every multiplier was
        # applied once on the way down and divided out once on the way up).
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Run the network on image batch `x` at timesteps `t` (one timestep per batch element)."""
        t = self.time_emb(t)
        x = self.image_proj(x)
        # Stack of skip connections: the projected input plus the output of every
        # module in the contracting path (DownBlocks and Downsamples alike).
        h = [x]
        for m in self.down:
            x = m(x, t)
            h.append(x)
        x = self.middle(x, t)

        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Every UpBlock consumes one skip tensor, concatenated on the channel axis.
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)
        return self.final(self.act(self.norm(x)))

### Denoise Diffusion

In [None]:
class DenoiseDiffusion(nn.Module):
    """Forward/reverse diffusion process built around a noise-prediction model.

    Args:
        eps_model: module called as `eps_model(xt, t)` that predicts the noise in `xt`.
        n_steps: number of diffusion steps; valid timestep indices are `0 .. n_steps - 1`.
        device: device on which the schedule tensors are created.

    The schedule (`beta`, `alpha`, `alpha_bar`) is stored as plain tensor
    attributes, not registered buffers, so it is not part of `state_dict()`.
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        super().__init__()
        self.eps_model = eps_model
        self.n_steps = n_steps
        self.device = device

        # Linear beta schedule; alpha_bar[k] is the cumulative product of alpha[0..k].
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def gather(self, consts: torch.Tensor, t: torch.Tensor):
        """Pick `consts[t]` for each (int64) index in `t`, shaped `(len(t), 1, 1, 1)` for broadcasting."""
        return consts.gather(-1, t).reshape(-1, 1, 1, 1)

    def p_x0(self, xt: torch.Tensor, t: torch.Tensor):
        """Estimate x0 from `xt` using the noise predicted by the model at timestep `t`."""
        eps_theta = self.eps_model(xt, t)
        alpha_bar = self.gather(self.alpha_bar, t)
        return (xt - (1 - alpha_bar) ** 0.5 * eps_theta) / (alpha_bar ** 0.5)

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        """Sample from q(xt|x0); `eps` is the noise to use (standard normal if omitted)."""
        if eps is None:
            eps = torch.randn_like(x0)
        alpha_bar = self.gather(self.alpha_bar, t)
        mean = alpha_bar ** 0.5 * x0
        var = 1 - alpha_bar
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """Sample from p_theta(xt_1|xt), using beta_t as the variance."""
        eps_theta = self.eps_model(xt, t)
        alpha_bar = self.gather(self.alpha_bar, t)
        alpha = self.gather(self.alpha, t)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        var = self.gather(self.beta, t)
        eps = torch.randn(xt.shape, device=xt.device)
        return mean + (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        """Simplified loss: MSE between the true and predicted noise at a random timestep per sample."""
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps=noise)
        eps_theta = self.eps_model(xt, t)
        return F.mse_loss(noise, eps_theta)

    def generalized_steps(self, x: torch.Tensor, seq: Union[range, List[int]], eta: float = 0.0):
        """Generalized (DDIM-style) sampling, from https://github.com/ermongroup/ddim

        Args:
            x: starting noise batch.
            seq: increasing timestep indices to visit; they are traversed in reverse.
            eta: scale of the noise injected at every step. The default `0.0`
                makes the update deterministic (the original hard-coded behavior).

        Returns:
            `(xs, x0_preds)`: `xs` holds the input followed by the sample after
            each step, `x0_preds` the x0 estimate made at each step. Everything
            appended by this method is moved to the CPU.
        """
        # Prepend alpha_bar = 1 so that the "next" index -1 (shifted to 0) means "no noise".
        alpha_bar = torch.cat([torch.ones(1).to(self.device), self.alpha_bar], dim=0)
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]

        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq), leave=None):
            # Timesteps are passed to the model as floats, as in the reference implementation.
            t = torch.full((n,), float(i), device=self.device)
            next_t = torch.full((n,), float(j), device=self.device)
            at = alpha_bar.index_select(0, t.long() + 1).view(-1, 1, 1, 1)
            at_next = alpha_bar.index_select(0, next_t.long() + 1).view(-1, 1, 1, 1)

            xt = xs[-1].to(self.device)
            et = self.eps_model(xt, t)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.cpu())

            c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            # Draw the noise like `xt` (already on self.device) rather than like the
            # caller's `x`, which may live on a different device. The draw is kept
            # even for eta == 0 so the random stream matches the original code.
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(xt) + c2 * et
            xs.append(xt_next.cpu())

        return xs, x0_preds

### Sampler

In [None]:
class Sampler(nn.Module):
    """Sampling, sequence and interpolation helpers built on a `DenoiseDiffusion` instance.

    Args:
        diffusion: object providing `n_steps`, `p_sample`, `p_x0`, `q_sample`
            and `generalized_steps`.
        image_channels: stored on the instance; not used by the methods below.
        image_size: stored on the instance; not used by the methods below.
        device: device used for tensors created by this class.
    """

    def __init__(self, diffusion: "DenoiseDiffusion", image_channels: int, image_size: int, device: torch.device):
        super().__init__()
        self.diffusion = diffusion
        self.n_steps = diffusion.n_steps

        self.device = device
        self.image_channels = image_channels
        self.image_size = image_size

    def make_grid(self, x: Union[torch.Tensor, List[torch.Tensor]], nrow: int):
        """Show a batch (or list) of images as a grid with `nrow` images per row."""
        plt.imshow(tvu.make_grid(x, nrow=nrow, padding=2).permute(1, 2, 0))
        plt.axis('off')
        plt.show()

    def slerp(self, z1: torch.Tensor, z2: torch.Tensor, alpha: torch.Tensor):
        """Spherical linear interpolation between `z1` (alpha=0) and `z2` (alpha=1).

        The cosine is clamped to [-1, 1] to absorb rounding error, and the result
        falls back to linear interpolation when the angle's sine is (numerically)
        zero, where the slerp formula would divide by zero and return NaN.
        """
        cos_theta = torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))
        theta = torch.acos(cos_theta.clamp(-1.0, 1.0))
        sin_theta = torch.sin(theta)
        if sin_theta.abs() < 1e-6:
            return (1 - alpha) * z1 + alpha * z2
        return torch.sin((1 - alpha) * theta) / sin_theta * z1 + torch.sin(alpha * theta) / sin_theta * z2

    def DDPM_sampling(self, xt: torch.Tensor, n_steps: Optional[int] = None):
        """Generate a batch of images with DDPM sampling.

        Applies `p_sample` for timestep indices `n_steps - 1` down to `0`.
        `n_steps` defaults to the diffusion's own number of steps.
        """
        if n_steps is None:
            n_steps = self.n_steps
        for t_inv in tqdm(range(n_steps), leave=None):
            t_ = n_steps - t_inv - 1
            t = xt.new_full((xt.size(0),), t_, dtype=torch.long)
            xt = self.diffusion.p_sample(xt, t)

        return xt

    def DDPM_sequence(self, xt: torch.Tensor, n_frames: int):
        """Sample each image step-by-step using p_theta(xt_1|xt) and collect x0 estimates.

        Returns a flat list with, for every element of the batch, the initial
        noise followed by the x0 estimates at evenly spaced steps. The input
        `xt` is left untouched.
        """
        # Reverse-step counters at which an x0 estimate is recorded; the first
        # slot is dropped because the initial noise is shown there instead.
        shown_steps = {int(k) for k in np.linspace(0, self.n_steps - 1, n_frames)[1:]}

        frames = []
        # Get diffusion sequence for each noise image in batch
        for i in tqdm(range(xt.size(0)), leave=None):
            # Work on a separate running sample instead of writing back into `xt`.
            current = xt[i][None, :, :, :]
            # Add initial noise to sequence
            sequence = [torch.squeeze(xt[i]).cpu()]
            # Perform reverse (generative) process for all T steps
            for t_inv in tqdm(range(self.n_steps), leave=None):
                t_ = self.n_steps - t_inv - 1
                t = current.new_full((1,), t_, dtype=torch.long)
                # Save intermediate image to sequence
                if t_inv in shown_steps:
                    x0 = self.diffusion.p_x0(current, t)
                    sequence.append(x0[0].cpu())
                # Perform reverse process
                current = self.diffusion.p_sample(current, t)
            # Add to frames
            frames += sequence

        return frames

    def DDPM_interpolation(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int, t_: int = 100):
        """Interpolate pairs of images `x1[i]`, `x2[i]` in the noised space.

        Both images are diffused for `t_` forward steps, linearly mixed, and each
        mix is decoded with `t_` DDPM reverse steps. Returns a flat list of
        `n_frames` images per pair, starting with `x1[i]` and ending with `x2[i]`.

        Raises:
            ValueError: if `t_` is not in `1 .. n_steps`.
        """
        if not 1 <= t_ <= self.n_steps:
            raise ValueError('t_ must be between 1 and {}, got {}'.format(self.n_steps, t_))
        # `t_` forward steps end at schedule index t_ - 1, which is also the index
        # at which DDPM_sampling(xt, t_) starts denoising.
        t = torch.full((1,), t_ - 1, device=self.device, dtype=torch.long)

        frames = []
        # Get interpolation sequence for each image pair in batch
        for i in tqdm(range(len(x1)), leave=None):
            # Forward process for t steps (larger t = more coarse)
            x1t = self.diffusion.q_sample(x1[i][None, :, :, :], t)
            x2t = self.diffusion.q_sample(x2[i][None, :, :, :], t)

            # Add initial image1 to sequence
            sequence = [x1[i].detach().clone().cpu()]
            # Perform linear interpolation between the noised images
            for j in tqdm(range(1, n_frames-1), leave=None):
                lambda_ = j / (n_frames-1)
                xt = (1 - lambda_) * x1t + lambda_ * x2t
                # Decode latents into image space
                x0 = self.DDPM_sampling(xt, t_)
                sequence.append(x0[0].cpu())
            # Add initial image2 to sequence
            sequence.append(x2[i].detach().clone().cpu())
            # Add to frames
            frames += sequence

        return frames

    def DDIM_sampling(self, xt: torch.Tensor, n_timesteps: int):
        """Based on the implementation in https://github.com/ermongroup/ddim

        Visits every `n_steps // n_timesteps`-th timestep and returns the
        `(xs, x0_preds)` pair produced by `generalized_steps`. When `n_steps` is
        not a multiple of `n_timesteps`, more than `n_timesteps` steps are taken.

        Raises:
            ValueError: if `n_timesteps` is not in `1 .. n_steps`.
        """
        if not 1 <= n_timesteps <= self.n_steps:
            raise ValueError('n_timesteps must be between 1 and {}, got {}'.format(self.n_steps, n_timesteps))
        seq = range(0, self.n_steps, self.n_steps // n_timesteps)
        return self.diffusion.generalized_steps(xt, seq)

    def DDIM_sequence(self, xt: torch.Tensor, n_frames: int, n_timesteps: int):
        """Based on the implementation in https://github.com/ermongroup/ddim

        Returns a flat list with, for every element of the batch, the initial
        noise followed by x0 estimates at evenly spaced DDIM steps.
        """
        frames = []
        # Get diffusion sequence for each noise image in batch
        for i in tqdm(range(xt.size(0)), leave=None):
            # Get entire diffusion sequence
            x0_preds = self.DDIM_sampling(xt[i][None, :, :, :], n_timesteps)[1]
            # Spread the shown indices over the predictions actually produced, so
            # the last frame is always the final estimate.
            shown = sorted({int(k) for k in np.linspace(0, len(x0_preds) - 1, n_frames)})
            sequence = [torch.squeeze(x0_preds[j]) for j in shown]
            # Add initial noise instead of first frame
            sequence[0] = torch.squeeze(xt[i]).cpu()
            # Add to frames
            frames += sequence

        return frames

    def DDIM_interpolation(self, z1: torch.Tensor, z2: torch.Tensor, n_frames: int, n_timesteps: int):
        """Based on the implementation in https://github.com/ermongroup/ddim

        For each pair `z1[i]`, `z2[i]` of noise tensors, slerps between them at
        `n_frames` points and decodes every point with DDIM sampling.
        """
        # Calculate interpolation points
        alpha = torch.linspace(0, 1, n_frames).to(self.device)

        frames = []
        # Get interpolation sequence for each noise pair in batch
        for i in tqdm(range(z1.size(0)), leave=None):
            # Interpolate noise using slerp; each entry keeps a leading batch dimension of 1
            noises = [self.slerp(z1[i][None, :, :, :], z2[i][None, :, :, :], alpha[j]) for j in range(n_frames)]
            # Generate images from noise at interpolation points
            sequence = [torch.squeeze(self.DDIM_sampling(noises[j], n_timesteps)[0][-1]) for j in tqdm(range(n_frames), leave=None)]
            # Add to frames
            frames += sequence

        return frames

### Configs

In [None]:
class Configs:
    """Experiment configuration plus the `init` / `train` / `sample` entry points.

    The class attributes are the tunable settings; `init` builds the model,
    diffusion process, sampler, dataset, data loader and optimizer from them and
    restores a checkpoint from Google Drive when one exists.
    """

    device: torch.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    eps_model: UNet
    diffusion: DenoiseDiffusion
    sampler: Sampler
    dataset_name: str = 'cifar10'  # cifar10 or stl10
    image_channels: int = 3
    image_size: int = 32  # 32 or 96
    n_channels: int = 64
    channel_multipliers: List[int] = [1, 2, 2, 4]
    is_attention: List[bool] = [False, False, False, True]
    n_steps: int = 1_000
    batch_size: int = 64  # 64 or 16
    learning_rate: float = 2e-5
    epoch: int = 1
    dataset: data.Dataset
    data_loader: data.DataLoader
    optimizer: optim.Adam

    def _checkpoint_path(self) -> str:
        """Path of the checkpoint file for the configured dataset (relative to the working directory)."""
        return 'drive/My Drive/Colab Notebooks/{}.chkpt'.format(self.dataset_name)

    def init(self):
        """Build all components and, if a checkpoint file exists, resume from it.

        Raises:
            ValueError: if `dataset_name` is neither 'cifar10' nor 'stl10'.
        """
        import os

        self.eps_model = UNet(image_channels=self.image_channels, n_channels=self.n_channels, ch_mults=self.channel_multipliers, is_attn=self.is_attention).to(self.device)
        self.diffusion = DenoiseDiffusion(eps_model=self.eps_model, n_steps=self.n_steps, device=self.device)
        self.sampler = Sampler(self.diffusion, image_channels=self.image_channels, image_size=self.image_size, device=self.device)

        if self.dataset_name == 'cifar10':
            self.dataset = CIFAR10(root='./data', download=True, transform=transforms.ToTensor())
        elif self.dataset_name == 'stl10':
            self.dataset1 = STL10(root='./data', split='train', download=True, transform=transforms.ToTensor())
            self.dataset2 = STL10(root='./data', split='test', download=True, transform=transforms.ToTensor())
            self.dataset = data.ConcatDataset([self.dataset1, self.dataset2])  # Join test and train datasets
        else:
            raise ValueError("Unknown dataset_name {!r}; expected 'cifar10' or 'stl10'".format(self.dataset_name))

        self.data_loader = data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        self.optimizer = optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Load previous parameters from drive, if any. On a first run there is no
        # checkpoint yet, so training starts from the freshly initialized model.
        checkpoint_path = self._checkpoint_path()
        if os.path.isfile(checkpoint_path):
            # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
            params = torch.load(checkpoint_path, map_location=self.device)
            self.diffusion.load_state_dict(params['diffusion'])
            self.optimizer.load_state_dict(params['optimizer'])
            self.epoch = params['epoch']
        else:
            print('No checkpoint found at {!r}; starting from scratch.'.format(checkpoint_path))

    def train(self, n_epochs: Optional[int] = None):
        """Train, show a DDIM sample grid and save a checkpoint after every epoch.

        Args:
            n_epochs: number of epochs to run. The default `None` keeps the
                original behavior of training until interrupted.
        """
        epochs_done = 0
        while n_epochs is None or epochs_done < n_epochs:
            print('epoch:', self.epoch)
            # Incremented up front, so the saved value is the next epoch to run.
            self.epoch += 1

            # Train
            for x, _ in tqdm(self.data_loader, leave=None):
                data_ = x.to(self.device)
                self.optimizer.zero_grad()
                loss = self.diffusion.loss(data_)
                loss.backward()
                self.optimizer.step()

            # Sample
            with torch.no_grad():
                noise = torch.randn([64, self.image_channels, self.image_size, self.image_size], device=self.device)
                x = self.sampler.DDIM_sampling(noise, n_timesteps=10)[0][-1]
                self.sampler.make_grid(x, nrow=int(np.sqrt(x.size(0))))

            # Save parameters to drive
            torch.save({'diffusion': self.diffusion.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'epoch': self.epoch},
                       self._checkpoint_path())
            epochs_done += 1

    def sample(self):
        """Render DDIM and DDPM sample grids, generation sequences and interpolations."""
        # Frames per interpolation row; also used to pick each row's end points below.
        n_interp_frames = 8

        with torch.no_grad():
            # Generate batches of random noise for sampling
            xt1 = torch.randn([64, self.image_channels, self.image_size, self.image_size], device=self.device)
            xt2 = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)
            z1 = torch.randn([8, self.image_channels, self.image_size, self.image_size], device=self.device)
            z2 = torch.randn([8, self.image_channels, self.image_size, self.image_size], device=self.device)

            # DDIM sampling
            x = self.sampler.DDIM_sampling(xt1, n_timesteps=10)[0][-1]
            self.sampler.make_grid(x, nrow=int(np.sqrt(x.size(0))))
            x = self.sampler.DDIM_sampling(xt1, n_timesteps=1000)[0][-1]
            self.sampler.make_grid(x, nrow=int(np.sqrt(x.size(0))))

            frames1 = self.sampler.DDIM_sequence(xt2, n_frames=20, n_timesteps=100)
            self.sampler.make_grid(frames1, nrow=len(frames1)//xt2.size(0))

            frames2 = self.sampler.DDIM_interpolation(z1, z2, n_frames=n_interp_frames, n_timesteps=100)
            self.sampler.make_grid(frames2, nrow=len(frames2)//z1.size(0))

            # Get generated images from DDIM interpolation for DDPM interpolation:
            # the first and last frame of every row.
            x1 = [frame.to(self.device) for frame in frames2[0::n_interp_frames]]
            x2 = [frame.to(self.device) for frame in frames2[n_interp_frames - 1::n_interp_frames]]

            # DDPM sampling (run the full schedule configured for this model)
            x = self.sampler.DDPM_sampling(xt1, n_steps=self.n_steps).cpu()
            self.sampler.make_grid(x, nrow=int(np.sqrt(x.size(0))))

            frames1 = self.sampler.DDPM_sequence(xt2, n_frames=20)
            self.sampler.make_grid(frames1, nrow=len(frames1)//xt2.size(0))

            frames2 = self.sampler.DDPM_interpolation(x1, x2, n_frames=n_interp_frames, t_=100)
            self.sampler.make_grid(frames2, nrow=len(frames2)//len(x1))
            frames2 = self.sampler.DDPM_interpolation(x1, x2, n_frames=n_interp_frames, t_=500)
            self.sampler.make_grid(frames2, nrow=len(frames2)//len(x1))

### Main

In [None]:
# Build the model, data pipeline and optimizer; `init` also loads the saved checkpoint.
configs = Configs()
configs.init()
# Training is disabled in this run; uncomment the next line to (continue to) train.
# configs.train()
# Render the DDIM / DDPM sample grids, generation sequences and interpolations.
configs.sample()