In [None]:
import os

from typing import Dict, Tuple
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as dset
from torchvision import models, transforms
import torchvision.utils as vutils
from torchvision.utils import save_image, make_grid

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML

from dataset import DatasetLoader
from models import ContextUnet
from diffusion_utilities import plot_sample, sample_ddpm, sample_ddim

## Parameter Setup

In [None]:
# Root directory for artifacts (model checkpoints are saved under it)
save_dir = "../../../artifacts/unconditioned/diffusion"

# Root directory for dataset
dataroot = "../../../cover_dataset"

# Seed torch's global RNG so that weight init, batch shuffling and the noise /
# timestep draws are repeatable between runs (GPU kernels and multi-worker
# loading may still introduce some nondeterminism).
seed = 0
torch.manual_seed(seed)

# diffusion hyperparameters: number of noising steps, plus the start (beta1)
# and end (beta2) values of the linear beta schedule
timesteps = 1000
beta1 = 1e-4
beta2 = 0.02

# Number of hidden feature channels passed to ContextUnet
n_feat = 64

# Size of the context vector passed to ContextUnet (the training loop below
# does not feed any context)
n_cfeat = 1

# Spatial size of training images. All images will be resized to this
#   size using a transform.
height = 64  # 64x64 image

# Number of workers for dataloader
workers = 4

# Batch size during training
batch_size = 100

# Number of training epochs
n_epoch = 200

# Learning rate for optimizer
lrate = 1e-3

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 0

## Dataset Loading

In [None]:
# Wrap the dataset root in the project's DatasetLoader helper, then ask it for
# the loader used below, passing the image size, batch size and worker count
# from the parameters cell.
dataset_loader = DatasetLoader(dataroot)
loader_args = (height, batch_size, workers)
dataloader = dataset_loader.create_dataset(*loader_args)

### Visualizing some images

In [None]:
# Decide which device we want to run on: CUDA only when it is available *and*
# ngpu > 0 was requested in the parameters cell.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot up to 64 images from one training batch as a grid.
# The batch is assumed to be an (images, labels) pair, matching how the
# training loop unpacks it. The images are only displayed, so the grid is built
# without moving them to `device` first (that would be a wasted round-trip,
# since they have to come back to the CPU for matplotlib anyway).
real_batch = next(iter(dataloader))
grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns a (C, H, W) tensor; imshow expects (H, W, C).
plt.imshow(grid.cpu().permute(1, 2, 0).numpy())
plt.show()

## Training

In [None]:
# Build the context U-Net for 3-channel input images and place it on `device`.
nn_model = ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    height=height,
).to(device)

# Adam over all model parameters, starting at the configured learning rate
# (the training loop reassigns the lr at the start of every epoch).
optim = torch.optim.Adam(params=nn_model.parameters(), lr=lrate)

In [None]:
# construct DDPM noise schedule with timesteps + 1 entries (index 0 .. timesteps)
# b_t: per-step beta, linearly spaced from beta1 to beta2
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
# ab_t[t] = a_t[0] * a_t[1] * ... * a_t[t]  (cumulative product, "alpha bar")
ab_t = torch.cumprod(a_t, dim=0)
# Index 0 is forced to 1 so perturbing at t = 0 returns the clean image.
ab_t[0] = 1


# helper function: perturbs an image to a specified noise level
def perturb_input(x, t, noise, schedule=None):
    """Noise a batch of images to timestep ``t`` with the DDPM forward process.

    Returns ``sqrt(ab[t]) * x + sqrt(1 - ab[t]) * noise``. When ``noise`` is
    standard normal, this is a draw from
    q(x_t | x_0) = N(sqrt(ab_t) * x_0, (1 - ab_t) * I).

    Args:
        x: batch of images; dim 0 is the batch dimension.
        t: integer tensor of timesteps, normally one per sample
            (shape ``(x.shape[0],)``), used to index the schedule.
        noise: tensor with the same shape as ``x``.
        schedule: optional 1-D tensor of cumulative alpha products to index
            with ``t``. Defaults to the module-level ``ab_t``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    ab = ab_t if schedule is None else schedule
    # One coefficient per sample, reshaped to (B, 1, ..., 1) so it broadcasts
    # over every non-batch dimension of x (works for any input rank).
    coef = ab[t].reshape(-1, *([1] * (x.dim() - 1)))
    # The noise term must be scaled by sqrt(1 - ab), not (1 - ab). Without the
    # sqrt, the noised sample's variance is not (1 - ab), which is inconsistent
    # with the standard DDPM noise-prediction objective and samplers.
    return coef.sqrt() * x + (1 - coef).sqrt() * noise

In [None]:
# Train the noise-prediction network without any context vector.
# Each step draws one random timestep per image, noises the batch to that level
# with perturb_input, and regresses the network output onto the injected noise.
nn_model.train()

for ep in range(n_epoch):
    print(f"epoch {ep}")

    # linearly decay learning rate from lrate (epoch 0) towards 0 (last epoch)
    for param_group in optim.param_groups:
        param_group["lr"] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:  # x: images; the second batch item is unused here
        optim.zero_grad()
        x = x.to(device)

        # perturb data: timesteps drawn uniformly from [1, timesteps],
        # allocated directly on `device` instead of created on CPU and copied
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],), device=device)
        x_pert = perturb_input(x, t, noise)

        # use network to recover noise; the timestep is passed scaled to (0, 1]
        pred_noise = nn_model(x_pert, t / timesteps)

        # loss is mean squared error between the predicted and true noise
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()

        optim.step()

    # save model every 4th epoch and after the final epoch
    if ep % 4 == 0 or ep == n_epoch - 1:
        # makedirs also creates missing parent directories; exist_ok keeps
        # re-runs from failing once the directory exists.
        os.makedirs(save_dir, exist_ok=True)
        # os.path.join inserts the separator that plain string concatenation
        # omitted (save_dir has no trailing slash), so checkpoints land inside
        # save_dir, which is where the loading cell looks for them.
        ckpt_path = os.path.join(save_dir, f"model_{ep}.pth")
        torch.save(nn_model.state_dict(), ckpt_path)
        print("saved model at " + ckpt_path)

### Sampling

In [None]:
# Load model weights and set the model to eval mode.
# The training loop saves model_{ep}.pth every 4th epoch and after the final
# epoch, so the final epoch (n_epoch - 1) is always a checkpoint the run wrote.
# A hard-coded model_31.pth is never produced when n_epoch = 200 (31 % 4 != 0).
# Change load_epoch to inspect an earlier checkpoint (0, 4, 8, ...).
load_epoch = n_epoch - 1
weights_path = os.path.join(save_dir, f"model_{load_epoch}.pth")
nn_model.load_state_dict(torch.load(weights_path, map_location=device))
nn_model.eval()
print("Loaded in Model")

In [None]:
# Visualize samples with DDPM sampling (the slow approach).
# NOTE(review): sample_ddpm / sample_ddim are imported from diffusion_utilities
# and are not handed nn_model here -- confirm they sample from the model loaded
# above rather than from some other model instance.
plt.clf()
samples, intermediate_ddpm = sample_ddpm(32)
animation_ddpm = plot_sample(
    intermediate_ddpm,
    32,  # presumably the number of samples, matching sample_ddpm(32) -- confirm
    4,  # presumably the number of grid rows -- confirm against plot_sample
    save_dir,
    "ani_run",
    None,
    save=False,
)
HTML(animation_ddpm.to_jshtml())

In [None]:
# Visualize samples with DDIM sampling (the fast approach).
plt.clf()
# n=25: presumably the number of DDIM sampling steps -- confirm in diffusion_utilities
samples, intermediate = sample_ddim(32, n=25)
animation_ddim = plot_sample(
    intermediate,
    32,
    4,
    save_dir,
    "ani_run",
    None,
    save=False,
)
HTML(animation_ddim.to_jshtml())