In [116]:
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
import os
import torchvision
import pandas as pd

class ResidualConvBlock(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> GELU stages with an optional residual sum.

    Both convolutions use kernel 3, stride 1, padding 1, so height and width
    are left unchanged; only the channel count goes from ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both conv stages.
        is_res: when True, ``forward`` adds a skip connection and rescales the
            sum; when False, the output of the second stage is returned as is.
    """

    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        self.same_channels = in_channels == out_channels
        self.is_res = is_res

        # First stage changes the channel count, second stage keeps it.
        first_stage = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        ]
        self.conv1 = nn.Sequential(*first_stage)

        second_stage = [
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        ]
        self.conv2 = nn.Sequential(*second_stage)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both conv stages and, if ``is_res`` is set, add the skip path."""
        hidden = self.conv1(x)
        processed = self.conv2(hidden)

        if not self.is_res:
            return processed

        # The input can only be added directly when its channel count matches
        # the output; otherwise the first stage's output serves as the skip.
        shortcut = x if self.same_channels else hidden
        # 1.414 is presumably an approximation of sqrt(2) used to rescale the
        # sum of two branches -- kept verbatim so numerics do not change.
        return (shortcut + processed) / 1.414


class UnetDown(nn.Module):
    """Encoder stage: a ResidualConvBlock followed by 2x2 max-pooling.

    The conv block maps ``in_channels`` to ``out_channels``; ``MaxPool2d(2)``
    then halves the spatial resolution of the feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            ResidualConvBlock(in_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Return the processed and downscaled feature map."""
        downscaled = self.model(x)
        return downscaled


class UnetUp(nn.Module):
    """Decoder stage: upscale with a transposed conv, then refine with two conv blocks.

    ``forward`` concatenates the incoming features with a skip tensor along
    the channel axis before anything else, so ``in_channels`` must equal the
    combined channel count of ``x`` and ``skip``. The transposed convolution
    uses kernel 2 and stride 2.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        )

    def forward(self, x, skip):
        """Join ``x`` with the encoder ``skip`` features and upscale the result."""
        joined = torch.cat((x, skip), 1)
        return self.model(joined)


class EmbedFC(nn.Module):
    """Small MLP (Linear -> GELU -> Linear) that embeds a flat feature vector.

    Args:
        input_dim: size of each input vector; inputs are reshaped to
            ``(-1, input_dim)`` before the first linear layer.
        emb_dim: size of the hidden layer and of the returned embedding.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        """Flatten ``x`` to rows of length ``input_dim`` and embed each row."""
        flat = x.view(-1, self.input_dim)
        return self.model(flat)


class ContextUnet(nn.Module):
    """U-Net noise predictor conditioned on a class label and a timestep.

    Layout (channel counts in units of ``n_feat``):
      * ``init_conv``: in_channels -> 1, residual block
      * ``down1``: 1 -> 1, ``down2``: 1 -> 2 (each halves the resolution)
      * ``to_vec``: 7x7 average pooling + GELU, then ``up0`` undoes it with a
        kernel-7 / stride-7 transposed conv (presumably sized for 28x28
        inputs, which are 7x7 after two poolings -- confirm before using
        other resolutions)
      * ``up1``: (2 + 2 skip) -> 1, ``up2``: (1 + 1 skip) -> 1
      * ``out``: (1 + 1 skip) -> in_channels

    Class and time embeddings modulate the decoder inputs as
    ``context_emb * features + time_emb``.

    Args:
        in_channels: channels of the image and of the predicted noise.
        n_feat: base feature width; used with ``GroupNorm(8, ...)``, so it
            has to be divisible by 8.
        n_classes: number of class labels for the one-hot context.
    """

    def __init__(self, in_channels, n_feat = 256, n_classes=10):
        super().__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        # Encoder
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        # Bottleneck pooling
        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        # One time embedding and one context embedding per decoder stage.
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_classes, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_classes, 1 * n_feat)

        # Decoder. The embeddings are applied by multiply/add rather than by
        # concatenation, so the bottleneck stays at 2 * n_feat channels
        # (concatenating both embeddings would make it 6 * n_feat).
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Output head
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        """Predict the noise contained in ``x``.

        Args:
            x: (noisy) image batch.
            c: integer class labels, one per sample (fed to ``one_hot``).
            t: timestep values, reshaped to one scalar per sample by the
                time embedding.
            context_mask: one entry per sample; 1 blocks the class context
                for that sample, 0 keeps it.
        """
        # Encoder path
        stem = self.init_conv(x)
        enc1 = self.down1(stem)
        enc2 = self.down2(enc1)
        bottleneck = self.to_vec(enc2)

        # Class label as a float one-hot vector.
        onehot = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

        # Per-class multiplier derived from the mask: rows with mask == 1
        # become 0 (context removed), rows with mask == 0 become -1.
        # NOTE(review): the negative sign looks like a side effect of the
        # flip; it is kept because trained weights depend on it.
        multiplier = context_mask[:, None].repeat(1, self.n_classes)
        multiplier = -1 * (1 - multiplier)
        onehot = onehot * multiplier

        # Embeddings reshaped to broadcast over the spatial dimensions.
        wide = (-1, self.n_feat * 2, 1, 1)
        narrow = (-1, self.n_feat, 1, 1)
        cemb1 = self.contextembed1(onehot).view(*wide)
        temb1 = self.timeembed1(t).view(*wide)
        cemb2 = self.contextembed2(onehot).view(*narrow)
        temb2 = self.timeembed2(t).view(*narrow)

        # Decoder path: scale by the context embedding, shift by the time
        # embedding, and concatenate the matching encoder features.
        dec0 = self.up0(bottleneck)
        dec1 = self.up1(cemb1 * dec0 + temb1, enc2)
        dec2 = self.up2(cemb2 * dec1 + temb2, enc1)
        return self.out(torch.cat((dec2, stem), 1))


def ddpm_schedules(beta1, beta2, T):
    """
    Returns pre-computed schedules for DDPM sampling, training process.

    The noise variance ``beta_t`` grows linearly from ``beta1`` (t = 0) to
    ``beta2`` (t = T); every returned tensor is 1-D with ``T + 1`` entries so
    that it can be indexed directly by the timestep.

    Args:
        beta1: first value of the linear schedule, must satisfy 0 < beta1.
        beta2: last value of the linear schedule, must satisfy beta1 < beta2 < 1.
        T: number of diffusion steps, must be positive.

    Returns:
        dict mapping schedule names to float32 tensors (see the inline
        comments on the return statement for the formula behind each key).

    Raises:
        AssertionError: if the betas are not strictly increasing inside
            (0, 1) or if ``T`` is not positive.
    """
    # The lower bound matters: beta1 <= 0 makes alpha_t >= 1, which turns
    # sqrt(1 - alphabar_t) (and everything derived from it) into NaN or 0/0.
    assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must satisfy 0 < beta1 < beta2 < 1"
    # T divides the ramp below; T == 0 would produce 0/0.
    assert T > 0, "T must be a positive number of diffusion steps"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    # Cumulative product of alpha_t, computed in log space.
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }


class DDPM(nn.Module):
    """Training wrapper that noises images and scores a noise-prediction network.

    Args:
        nn_model: network called as ``nn_model(x_t, c, t, context_mask)``;
            it is moved to ``device`` on construction.
        betas: indexable pair; ``betas[0]`` and ``betas[1]`` are the two
            endpoints handed to ``ddpm_schedules``.
        n_T: number of diffusion steps.
        device: device used for the sampled timesteps and the context mask.
        drop_prob: probability of setting a sample's context-mask entry to 1.
    """

    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super().__init__()
        self.nn_model = nn_model.to(device)

        # Every schedule becomes a buffer, so it is reachable as an attribute
        # (for example ``self.sqrtab``) and is stored in the state dict.
        schedules = ddpm_schedules(betas[0], betas[1], n_T)
        for name, values in schedules.items():
            self.register_buffer(name, values)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        """Training step: noise ``x`` at random timesteps and return the MSE loss.

        For each sample a timestep is drawn uniformly from the integers
        1..n_T and fresh Gaussian noise is mixed into the image; the loss
        compares that noise with the network's prediction.
        """
        n_batch = x.shape[0]
        # randint's upper bound is exclusive, hence n_T + 1.
        _ts = torch.randint(1, self.n_T + 1, (n_batch,)).to(self.device)
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        # x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps
        signal_scale = self.sqrtab[_ts, None, None, None]
        noise_scale = self.sqrtmab[_ts, None, None, None]
        x_t = signal_scale * x + noise_scale * noise

        # Bernoulli(drop_prob) draw per entry of c; a 1 tells the network to
        # block the context for that sample.
        drop_probability = torch.zeros_like(c) + self.drop_prob
        context_mask = torch.bernoulli(drop_probability).to(self.device)

        # The network receives the timestep rescaled by 1 / n_T.
        predicted_noise = self.nn_model(x_t, c, _ts / self.n_T, context_mask)
        return self.loss_mse(noise, predicted_noise)
        

In [117]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Image geometry as (channels, height, width).
img_shape = (3, 28, 28)
in_channels = img_shape[0]

# Network width and number of class labels.
n_feat = 128
n_classes = 20


In [118]:
# Assemble the diffusion wrapper around the conditional U-Net.
unet = ContextUnet(in_channels=in_channels, n_feat=n_feat, n_classes=n_classes)
model = DDPM(
    nn_model=unet,
    betas=(1e-4, 0.02),
    n_T=400,
    device=device,
    drop_prob=0.1,
)


In [119]:
def min_max_normalize(img):
    """Linearly rescale a tensor so that its values span [0, 1].

    The minimum over the whole tensor is mapped to 0 and the maximum to 1.
    The input is not modified; a new tensor is returned.

    A tensor whose elements are all equal has no range to stretch; it is
    returned as all zeros instead of being divided by zero (which would
    fill the result with NaN).

    Args:
        img: tensor to normalize.

    Returns:
        Tensor of the same shape as ``img``.
    """
    out = img.clone()
    out = out - out.min()
    max_value = out.max()
    # After the shift the maximum equals the value range; skip the division
    # when that range is zero.
    if max_value > 0:
        out /= max_value

    return out

def sample(model, n_sample, c, guide_w, make_history=False):
    """Generate ``n_sample`` images of class ``c`` by running the reverse process.

    Every step evaluates the network on a doubled batch: the first half with
    the class context kept (mask 0), the second half with it blocked
    (mask 1). The two noise estimates are blended as
    ``(1 + guide_w) * eps_kept - guide_w * eps_blocked``.

    Relies on the module-level ``img_shape`` and ``device``.

    Args:
        model: object exposing ``n_T``, ``nn_model`` and the schedule tensors
            ``oneover_sqrta``, ``mab_over_sqrtmab`` and ``sqrt_beta_t``.
        n_sample: number of images to generate.
        c: class label, repeated for every sample.
        guide_w: guidance weight used in the blend above.
        make_history: when True, also return intermediate snapshots.

    Returns:
        The final images on the CPU, or ``(images, history)`` when
        ``make_history`` is True. ``history`` holds a CPU snapshot for every
        step ``i`` with ``(i + 1) % 20 == 0``, followed by the final images.
    """
    with torch.no_grad():
        # Start from pure Gaussian noise.
        x = torch.randn((n_sample, *img_shape), device=device)

        # Labels for the doubled batch.
        c = torch.tensor(c, device=device).repeat(n_sample).repeat(2)

        # First half keeps the context, second half has it blocked.
        context_mask = torch.zeros_like(c, device=device)
        context_mask[n_sample:] = 1.

        history = []
        for i in tqdm(range(model.n_T, 0, -1)):
            # Normalised timestep, one entry per element of the doubled batch.
            t = torch.tensor([i / model.n_T], device=device).repeat(n_sample*2, 1, 1, 1)
            doubled = x.repeat(2, 1, 1, 1)

            # No fresh noise is injected on the very last step.
            if i > 1:
                z = torch.randn(n_sample, *img_shape).to(device)
            else:
                z = 0

            eps_all = model.nn_model(doubled, c, t, context_mask)
            eps_kept = eps_all[:n_sample]
            eps_blocked = eps_all[n_sample:]
            eps = (1+guide_w) * eps_kept - guide_w * eps_blocked

            # Reverse update: remove the predicted noise, then add z scaled
            # by sqrt(beta_t).
            denoised = x - eps * model.mab_over_sqrtmab[i]
            x = model.oneover_sqrta[i] * denoised + model.sqrt_beta_t[i] * z

            if make_history and (i + 1) % 20 == 0:
                history.append(x.cpu())

        x = x.cpu()
        history.append(x)
        if make_history:
            return (x, history)
        return x

In [120]:
# Restore the trained weights.
# - map_location="cpu" lets a checkpoint written on a GPU machine load when
#   only a CPU is available; load_state_dict then copies each tensor onto
#   whatever device the matching parameter already lives on.
# - weights_only=True restricts unpickling to tensors and plain containers,
#   so a tampered checkpoint cannot execute arbitrary code on load.
state_dict = torch.load("outdir/model_199.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Inference mode: BatchNorm layers use their running statistics.
model = model.eval()

  model.load_state_dict(torch.load(f"outdir/model_199.pth"))


## Samples for class labels 0-9

In [121]:
# Draw 10 samples of class 0 and keep the intermediate denoising snapshots.
out0, snapshots = sample(model, 10, 0, 1.0, True)

# Denoising strip for the first sample: six snapshots, the last one (index 20)
# being the final image.
frame_ids = (0, 5, 10, 13, 16, 20)
change_fig = torch.cat([snapshots[k][:1] for k in frame_ids])

# Stretch each frame to [0, 1] on its own so that early, noisy frames stay visible.
change_fig = torch.stack([min_max_normalize(frame) for frame in change_fig])

save_image(change_fig, 'p1_out/0.png')

100%|██████████| 400/400 [00:03<00:00, 112.39it/s]


In [122]:
# 10 samples for each of the class labels 1..9 (class 0 is already in out0).
out1_9 = [sample(model, 10, label, 1.0, False) for label in range(1, 10)]

100%|██████████| 400/400 [00:03<00:00, 117.96it/s]
100%|██████████| 400/400 [00:03<00:00, 121.46it/s]
100%|██████████| 400/400 [00:03<00:00, 119.79it/s]
100%|██████████| 400/400 [00:03<00:00, 120.41it/s]
100%|██████████| 400/400 [00:03<00:00, 118.59it/s]
100%|██████████| 400/400 [00:03<00:00, 119.20it/s]
100%|██████████| 400/400 [00:03<00:00, 118.09it/s]
100%|██████████| 400/400 [00:03<00:00, 122.34it/s]
100%|██████████| 400/400 [00:03<00:00, 120.61it/s]


In [123]:
# Stack all batches class by class; with 10 images per row, every grid row
# shows the 10 samples of one class.
out = torch.cat([out0, *out1_9])
grid = make_grid(out, nrow=10)
save_image(grid, 'p1_out/1.png')

## Samples for class labels 10-19

In [124]:
# Draw 10 samples of class 10 and keep the intermediate denoising snapshots.
out0, snapshots = sample(model, 10, 10, 1.0, True)

# Denoising strip for the first sample: six snapshots, the last one (index 20)
# being the final image.
frame_ids = (0, 5, 10, 13, 16, 20)
change_fig = torch.cat([snapshots[k][:1] for k in frame_ids])

# Stretch each frame to [0, 1] on its own so that early, noisy frames stay visible.
change_fig = torch.stack([min_max_normalize(frame) for frame in change_fig])

# All six frames side by side in a single row.
grid = make_grid(change_fig, nrow=6)
save_image(grid, 'p1_out/2.png')

100%|██████████| 400/400 [00:03<00:00, 118.15it/s]


In [125]:
# 10 samples for each of the class labels 11..19 (class 10 is already in out0).
out1_9 = [sample(model, 10, label, 1.0, False) for label in range(11, 20)]

# Stack all batches class by class; with 10 images per row, every grid row
# shows the 10 samples of one class.
out = torch.cat([out0, *out1_9])
grid = make_grid(out, nrow=10)
save_image(grid, 'p1_out/3.png')

100%|██████████| 400/400 [00:03<00:00, 117.79it/s]
100%|██████████| 400/400 [00:03<00:00, 120.11it/s]
100%|██████████| 400/400 [00:03<00:00, 117.86it/s]
100%|██████████| 400/400 [00:03<00:00, 117.42it/s]
100%|██████████| 400/400 [00:03<00:00, 118.27it/s]
100%|██████████| 400/400 [00:03<00:00, 118.74it/s]
100%|██████████| 400/400 [00:03<00:00, 118.87it/s]
100%|██████████| 400/400 [00:03<00:00, 119.31it/s]
100%|██████████| 400/400 [00:03<00:00, 118.01it/s]
