# Diffusion Models

### Background

Forward Process
$$q(x_t| x_{t-1}) = \mathcal{N}(x_t, \sqrt{1-\beta}x_{t-1}, \beta I)$$
*The $\sqrt{1-\beta}x_{t-1}$ is to retain information about previous step to be able to reconstruct in reverse process, and ensure the contribution of $x_{t-1}$ to the variance scales properly*
Generally, $$q(x_t|x_0) = \mathcal{N}(x_t, \sqrt(\prod_0^t1-\beta_t)x_0, \beta_t I)$$


Reverse Process 

*Predict the noise distribution at timestep $t$ to get $x_t$ from $x_{t-1}$*
$$p_\theta (x_{t-1}|x_t) = \mathcal{N}(x_{t-1}|\mu_\theta(x_t,t), \Sigma_\theta(x_t, t))$$

Maximize expected probability of reconstructing the original image over your true distribution:
$$
\begin{aligned}
\argmax_\theta \mathbb{E}_{q(x_0)}[\log p_\theta(x_0)] \quad & \\
= \argmax_\theta \mathbb{E}_{q(x_0)}[\log \int p_\theta(x_{0:T}) \, dx_{1:T}] \quad & \text{Represent as marginal} \\
= \argmax_\theta \mathbb{E}_{q(x_0)}[\log \int q(x_{1:T}|x_0) \frac{p_\theta(x_{0:T}) }{q(x_{1:T}| x_0)} \, dx_{1:T}] \quad & \text{Importance sampling trick from known distribution, forward process $q(x_{1:T}| x_0)$} \\
= \argmax_\theta \mathbb{E}_{q(x_0)}[ \log \mathbb{E}_{q(x_{1:T}|x_0)}\left[\frac{p_\theta(x_{0:T})}{q(x_{1:T}| x_0)}\right]] \quad & \text{Definition of expectation} \\
= \argmin_\theta \mathbb{E}_{q(x_0)}\left[ -\log \mathbb{E}_{q(x_{1:T}|x_0)}\left[\frac{p_\theta(x_{0:T})}{q(x_{1:T}| x_0)}\right]\right] \quad &  \\
\leq \argmin_\theta \mathbb{E}_{q(x_{0:T})}\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}| x_0)}\right] \quad & \text{Jensen's Inequality, $-\log$ is convex} \\
\ldots \quad & \text{Simplifies to...} \\
\argmin_\theta \mathbb{E}_{t,x_0,\epsilon}\left[\left|\left|\epsilon - \epsilon_\theta(x_t, t)\right|\right|^2 \right]  \quad &  \\
\argmin_\theta \mathbb{E}_{t,x_0,\epsilon}\left[\left|\left|\epsilon - \epsilon_\theta\left(\sqrt{\alpha_{t}}x_0 + \sqrt{1-\alpha_{t}}\epsilon, t\right)\right|\right|^2 \right] \quad & 
\end{aligned} 
$$


In [1]:
import torch
import cv2
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# from PIL import Image
from torchvision.io import read_image
from tqdm.notebook import tqdm

In [2]:
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, file_list):
        super().__init__()
        self.file_list = file_list
        
    def __len__(self,):
        return len(self.file_list)
 
    def __getitem__(self, index):
        image_path = f"./data/celeb_faces/img_align_celeba/{self.file_list[index]}"
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (224,224))

        image = torch.tensor(image, dtype=torch.float32) / 255

        #scale from 0,1 to -1,1
        image = image * 2 - 1

        return image

In [3]:
class DownSample(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=(3,3), padding=1)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=(3,3), padding=1)
        self.pool = torch.nn.MaxPool2d(kernel_size=(2,2), stride=2)
        self.bn = torch.nn.BatchNorm2d(out_channels) 

    def forward(self, x): # returns (feature map before pool, feature map after max pooling)
        pre_pool = self.conv1(x)
        pre_pool = self.conv2(pre_pool)
        # pre_pool = torch.nn.functional.batch_norm(pre_pool)
        pre_pool = torch.nn.functional.relu(pre_pool)


        post_pool = self.pool(pre_pool)

        return pre_pool, post_pool

class UpSample(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(2,2), stride=2)
        self.conv1 = torch.nn.Conv2d(out_channels*2, out_channels, kernel_size=(3,3), padding=1)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=(3,3), padding=1)
        self.bn = torch.nn.BatchNorm2d(out_channels)

    def forward(self, x, pre_pool): 
        x = self.upconv(x)

        # x and pre_pool are (B, C, H, W)
        # we concat across channels

        # # copy and crop
        # x_h, x_w = x.shape[2], x.shape[3]
        # pre_pool_h, pre_pool_w = pre_pool.skhape[2], pre_pool.shape[3]
        # dh, dw = pre_pool_h - x_h, pre_pool_w - x_w
        # pre_pool_cropped = pre_pool[:,:, dh//2: dh//2 + x_h, dw//2: dw//2 + x_w]

        # pad
        x_h, x_w = x.shape[2], x.shape[3]
        pre_pool_h, pre_pool_w = pre_pool.shape[2], pre_pool.shape[3]
        dh, dw = pre_pool_h - x_h, pre_pool_w - x_w
        pad = [
            dh // 2, math.ceil(dh / 2),
            dw // 2, math.ceil(dw / 2),
        ]
        padded_x = torch.nn.functional.pad(x, pad)

        x = torch.cat([padded_x, pre_pool], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)

        x = self.bn(x)
        x = torch.nn.functional.relu(x)
        
        return x
    
class UNet(torch.nn.Module):
    def __init__(self, image_shape=(3,224,244)):
        super().__init__()
        im_channels, im_height, im_width = image_shape
        self.down1 = DownSample(im_channels, 64)
        self.down2 = DownSample(64, 128)
        self.down3 = DownSample(128, 256)
        self.down4 = DownSample(256, 512)
    
        self.middle_conv1 = torch.nn.Conv2d(512, 1024, kernel_size=(3,3), padding=0)
        self.middle_conv2 = torch.nn.Conv2d(1024, 1024, kernel_size=(3,3), padding=0)

        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)

        self.final_layer = torch.nn.Conv2d(64, 3, kernel_size=(3,3), padding=1)

    def forward(self,x):
        down1_feat, down1 = self.down1(x)
        down2_feat, down2 = self.down2(down1)
        down3_feat, down3 = self.down3(down2)
        down4_feat, down4 = self.down4(down3)

        middle1 = self.middle_conv1(down4)
        middle2 = self.middle_conv2(middle1)

        up1 = self.up1(middle2, down4_feat)
        up2 = self.up2(up1, down3_feat)
        up3 = self.up3(up2, down2_feat)
        up4 = self.up4(up3, down1_feat)

        output = self.final_layer(up4)
        output = torch.nn.functional.relu(output)

        return output

class DiffusionModel(torch.nn.Module):
    BETA_START = 1e-4
    BETA_END = 0.02
    MAX_TIMESTEPS = 1000
    IMG_SIZE = 224

    def __init__(self, device):
        super().__init__()


        self.betas = torch.linspace(self.BETA_START, self.BETA_END, self.MAX_TIMESTEPS)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0).to(device)

        # architecture
        self.net = UNet(image_shape=(3, self.IMG_SIZE, self.IMG_SIZE))

    def forward(self, input):
        return self.net(input)

    def apply_noise(self, image_batch, noise_batch, timestep):
        # noised image
        # sqrt(alpha_t) * img + sqrt(1 - alpha_t) * noise
        return torch.sqrt(self.alpha_bars[timestep]) * image_batch + torch.sqrt(1 - self.alpha_bars[timestep]) * noise_batch


In [7]:
splits = pd.read_csv("./data/celeb_faces/list_eval_partition.csv")

train_list = splits[splits["partition"] == 0]["image_id"].values
val_list = splits[splits["partition"] == 1]["image_id"].values

train_ds = ImageDataset(train_list[:100])
val_ds = ImageDataset(val_list[:10])

train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=16,
    shuffle=True,
    # num_workers=4
)
val_loader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=16,
    shuffle=True,
)


In [None]:
# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    
    device = torch.device('cpu')
else:
    device = torch.device("mps")

model = DiffusionModel(device)
model.to(device)

optim = torch.optim.Adam(model.parameters(), lr = 1e-4)

train_losses = []
val_losees = []
for epoch in range(50):
    print(f"Training Epoch {epoch}")
    losses = []
    model.train()
    print(len(train_loader))
    for i, input in tqdm(enumerate(train_loader)):
        optim.zero_grad()

        input = input.to(device)

        noise = torch.randn_like(input)
        noise = noise.to(device)
        noise_steps = torch.randint(0, model.MAX_TIMESTEPS, (1,))
        noised_input = model.apply_noise(input, noise, noise_steps)

        pred_noise = model(noised_input.permute(0, 3, 1, 2))
        loss = torch.nn.functional.mse_loss(pred_noise, noise.permute(0, 3, 1, 2))
        losses.append(loss.item())

        loss.backward()
        optim.step()

    train_losses.append(np.mean(losses))

    with torch.no_grad():
        print(f"Validation Epoch {epoch}")
        losses = []
        model.eval()
        for i, input in tqdm(enumerate(val_loader)):
            input = input.to(device)

            noise = torch.randn_like(input)
            noise = noise.to(device)
            noise_steps = torch.randint(0, model.MAX_TIMESTEPS, (1,))
            noised_input = model.apply_noise(input, noise, noise_steps)

            pred_noise = model(noised_input.permute(0, 3, 1, 2))
            loss = torch.nn.functional.mse_loss(pred_noise, noise.permute(0, 3, 1, 2))
            losses.append(loss.item())

        val_losees.append(np.mean(losses))
    
    print(f"Epoch {epoch} | Train Loss: {train_losses[-1]} Validation Loss: {val_losees[-1]}")
        




    # pred = model(noised_input.permute(0, 3, 1, 2))
    # pred = pred.permute(0, 2,3, 1)

    # original = input[0] / 2 + 0.5
    # noised = noised_input[0] / 2  + 0.5
    # pred_noise = pred[0] / 2 + 0.5

    # original, noised, pred_noise = original.detach().cpu(), noised.detach().cpu(), pred_noise.detach().cpu()
    # plt.imshow(torch.cat([original, noised, pred_noise], dim=1))
    # break



Training Epoch 0
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 0


0it [00:00, ?it/s]

Epoch 0 | Train Loss: 1.0030863285064697 Validation Loss: 0.9863555431365967
Training Epoch 1
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 1


0it [00:00, ?it/s]

Epoch 1 | Train Loss: 0.9033703207969666 Validation Loss: 0.9540436267852783
Training Epoch 2
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 2


0it [00:00, ?it/s]

Epoch 2 | Train Loss: 0.790685134274619 Validation Loss: 0.8901197910308838
Training Epoch 3
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 3


0it [00:00, ?it/s]

Epoch 3 | Train Loss: 0.6810774803161621 Validation Loss: 0.7975717186927795
Training Epoch 4
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 4


0it [00:00, ?it/s]

Epoch 4 | Train Loss: 0.6363945347922189 Validation Loss: 0.705755889415741
Training Epoch 5
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 5


0it [00:00, ?it/s]

Epoch 5 | Train Loss: 0.6135824918746948 Validation Loss: 0.6366470456123352
Training Epoch 6
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 6


0it [00:00, ?it/s]

Epoch 6 | Train Loss: 0.5735103743416923 Validation Loss: 0.5951613187789917
Training Epoch 7
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 7


0it [00:00, ?it/s]

Epoch 7 | Train Loss: 0.5747976132801601 Validation Loss: 0.8374192714691162
Training Epoch 8
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 8


0it [00:00, ?it/s]

Epoch 8 | Train Loss: 0.5642282792500087 Validation Loss: 0.5428347587585449
Training Epoch 9
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 9


0it [00:00, ?it/s]

Epoch 9 | Train Loss: 0.5427798203059605 Validation Loss: 0.5397014021873474
Training Epoch 10
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 10


0it [00:00, ?it/s]

Epoch 10 | Train Loss: 0.5370939203671047 Validation Loss: 0.5942836999893188
Training Epoch 11
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 11


0it [00:00, ?it/s]

Epoch 11 | Train Loss: 0.5342784779412406 Validation Loss: 0.5460073947906494
Training Epoch 12
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 12


0it [00:00, ?it/s]

Epoch 12 | Train Loss: 0.5325356210981097 Validation Loss: 0.5426895618438721
Training Epoch 13
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 13


0it [00:00, ?it/s]

Epoch 13 | Train Loss: 0.5293259450367519 Validation Loss: 0.5291503071784973
Training Epoch 14
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 14


0it [00:00, ?it/s]

Epoch 14 | Train Loss: 0.578303439276559 Validation Loss: 0.5402179956436157
Training Epoch 15
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 15


0it [00:00, ?it/s]

Epoch 15 | Train Loss: 0.5330976503235954 Validation Loss: 0.5326470732688904
Training Epoch 16
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 16


0it [00:00, ?it/s]

Epoch 16 | Train Loss: 0.5390392712184361 Validation Loss: 0.5427546501159668
Training Epoch 17
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 17


0it [00:00, ?it/s]

Epoch 17 | Train Loss: 0.5364804353032794 Validation Loss: 0.5992576479911804
Training Epoch 18
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 18


0it [00:00, ?it/s]

Epoch 18 | Train Loss: 0.5282705681664603 Validation Loss: 0.5418761968612671
Training Epoch 19
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 19


0it [00:00, ?it/s]

Epoch 19 | Train Loss: 0.5399659276008606 Validation Loss: 0.5252231955528259
Training Epoch 20
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 20


0it [00:00, ?it/s]

Epoch 20 | Train Loss: 0.5302597624914986 Validation Loss: 0.5435231328010559
Training Epoch 21
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 21


0it [00:00, ?it/s]

Epoch 21 | Train Loss: 0.5250600320952279 Validation Loss: 0.5304575562477112
Training Epoch 22
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 22


0it [00:00, ?it/s]

Epoch 22 | Train Loss: 0.5233299221311297 Validation Loss: 0.5246754884719849
Training Epoch 23
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 23


0it [00:00, ?it/s]

Epoch 23 | Train Loss: 0.5713819180216108 Validation Loss: 0.5407687425613403
Training Epoch 24
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 24


0it [00:00, ?it/s]

Epoch 24 | Train Loss: 0.5318551148687091 Validation Loss: 0.5451647043228149
Training Epoch 25
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 25


0it [00:00, ?it/s]

Epoch 25 | Train Loss: 0.5516894204275948 Validation Loss: 0.5375844836235046
Training Epoch 26
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 26


0it [00:00, ?it/s]

Epoch 26 | Train Loss: 0.5661568301064628 Validation Loss: 0.5284305810928345
Training Epoch 27
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 27


0it [00:00, ?it/s]

Epoch 27 | Train Loss: 0.572435200214386 Validation Loss: 0.5477023720741272
Training Epoch 28
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 28


0it [00:00, ?it/s]

Epoch 28 | Train Loss: 0.529661238193512 Validation Loss: 0.5422180891036987
Training Epoch 29
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 29


0it [00:00, ?it/s]

Epoch 29 | Train Loss: 0.5260694537843976 Validation Loss: 0.5354257225990295
Training Epoch 30
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 30


0it [00:00, ?it/s]

Epoch 30 | Train Loss: 0.5582549401691982 Validation Loss: 0.5350945591926575
Training Epoch 31
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 31


0it [00:00, ?it/s]

Epoch 31 | Train Loss: 0.5331936393465314 Validation Loss: 0.5815355777740479
Training Epoch 32
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 32


0it [00:00, ?it/s]

Epoch 32 | Train Loss: 0.541022607258388 Validation Loss: 0.5330294370651245
Training Epoch 33
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 33


0it [00:00, ?it/s]

Epoch 33 | Train Loss: 0.524385963167463 Validation Loss: 0.5841611623764038
Training Epoch 34
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 34


0it [00:00, ?it/s]

Epoch 34 | Train Loss: 0.5233293005398342 Validation Loss: 0.9852771759033203
Training Epoch 35
7


0it [00:00, ?it/s]

0
1
2
3
4
5
6
Validation Epoch 35


0it [00:00, ?it/s]

Epoch 35 | Train Loss: 0.5756667000906808 Validation Loss: 0.5579149127006531
Training Epoch 36
7


0it [00:00, ?it/s]

0
1
2
3
4
5
