## Importing Libraries.

In [None]:
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML
from torchvision.utils import save_image, make_grid
import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image

### Neural Network modules and utility functions for image processing and data handling.

In [None]:
from Unet import UNet  #U_Net class for building UNet Encoder-Decoder
from custom_dataset import CustomDataset #Functions for Dataset Handling
from utils import  plot_sample, unorm, norm_all # Utility functions for Animation Processing and Data Handling

In [None]:
def plot_single_sample_with_intermidiate_steps(steps, ncols=4):
    """Show the stored denoising steps of one sample as a grid of images.

    Only sample index 0 of every stored step is drawn.

    Args:
        steps: Array of stored intermediate samples. The axis move below
            needs five dimensions; presumably
            (n_steps, n_samples, channels, height, width) -- confirm against
            the sampler that produced it.
        ncols: Number of grid columns. The default of 4 reproduces the
            original 8x4 layout when 32 steps are stored.

    Raises:
        ValueError: If ``ncols`` is smaller than 1.
    """
    if ncols < 1:
        raise ValueError("ncols must be at least 1")

    # Move the channel axis last: imshow wants (h, w, channels).
    sx_gen_store = np.moveaxis(steps, 2, 4)
    n_steps = sx_gen_store.shape[0]
    # norm_all comes from utils; it is expected to rescale each image into
    # a range imshow can display (one sample per step is normalised here).
    nsx_gen_store = norm_all(sx_gen_store, n_steps, 1)

    # Size the grid from the number of stored steps instead of assuming 32,
    # so a different save_rate / timesteps neither crashes nor drops frames.
    nrows = max(1, (n_steps + ncols - 1) // ncols)
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=True,
        figsize=(ncols, nrows),
        squeeze=False,  # keep a 2-D axes array even for a single row/column
    )
    for idx, ax in enumerate(axs.flat):
        if idx < n_steps:
            ax.imshow(nsx_gen_store[idx, 0])
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # Unused trailing cell when n_steps is not a multiple of ncols.
            ax.axis('off')

    plt.show()


In [None]:
# Per-image preprocessing handed to the dataset.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # (x - 0.5) / 0.5, i.e. values in [0, 1] end up in [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
        # Optional augmentations for smaller datasets (currently disabled):
        # transforms.Lambda(quantize_image),
        # transforms.RandomRotation(degrees=90),
        # transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
)

### Model Hyperparameters.


In [None]:
# ---- data ----
num_sprites = 25000  # handed to CustomDataset as nums_sprites

# ---- diffusion ----
timesteps = 500  # number of diffusion steps
beta1 = 1e-4     # lower end of the beta range used by the noise schedule
beta2 = 0.02     # upper end of the beta range used by the noise schedule

# ---- network ----
# Use the first GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
n_feat = 64        # hidden feature dimension of the UNet
context_size = 5   # length of the context vector
height = 16        # images are height x height (16x16)
save_dir = './model/'

# ---- training ----
batch_size = 100
n_epoch = 150
lrate = 1e-3  # initial learning rate

### Noise Schedules

In [None]:
# linear schedule 
# b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
# a_t = 1 - b_t
# ab_t = torch.cumsum(a_t.log(), dim=0).exp()    
# ab_t[0] = 1

# cosine schedule
def cosine_schedule(timesteps, s=0.008):
    """Evaluate a squared-cosine curve at the steps 0, 1, ..., ``timesteps``.

    For each step ``x`` the value is
    ``cos(((x / timesteps + s) / (1 + s)) * pi / 2) ** 2``.

    Args:
        timesteps: Index of the last step; the result has ``timesteps + 1``
            entries.
        s: Small offset added to the normalised step before scaling.

    Returns:
        A 1-D tensor on the module-level ``device``.
    """
    step_index = torch.linspace(0, timesteps, timesteps + 1, device=device)
    angle = (step_index / timesteps + s) / (1 + s) * torch.pi / 2
    return torch.cos(angle) ** 2

# Per-step betas: interpolate between beta1 and beta2 with weight
# (1 - squared-cosine curve), so they grow as the curve falls.
b_t = beta1 + (beta2 - beta1) * (1 - cosine_schedule(timesteps))
a_t = 1 - b_t
# Running product of a_t, accumulated as a sum of logs.
ab_t = a_t.log().cumsum(dim=0).exp()
# Index 0 is pinned to exactly 1.
ab_t[0] = 1

# quadratic schedule
# b_t = beta1 + (beta2 - beta1) * (torch.linspace(0, 1, timesteps + 1, device=device) ** 2)
# a_t = 1 - b_t
# ab_t = torch.cumsum(a_t.log(), dim=0).exp()
# ab_t[0] = 1

# exponential schedule
# b_t = beta1 * (beta2 / beta1) ** torch.linspace(0, 1, timesteps + 1, device=device)
# a_t = 1 - b_t
# ab_t = torch.cumsum(a_t.log(), dim=0).exp()
# ab_t[0] = 1

### Constructing Model.

In [None]:
# Build the UNet and move it to the selected device.
nn_model = UNet(
    in_channels=3,
    num_feature_maps=n_feat,
    context_size=context_size,
    image_size=height,
).to(device)

### Loading Dataset.

In [None]:
# Sprite images and their labels, preprocessed with `transform`.
dataset = CustomDataset(
    "./data_dir/sprites.npy",
    "./data_dir/sprites_labels.npy",
    transform,
    nums_sprites=num_sprites,
    null_context=False,
)

In [None]:
# Preview ten randomly chosen sprites side by side
sample_idx = np.random.choice(len(dataset.sprites), 10)
sample = [dataset.sprites[idx] for idx in sample_idx]
fig, ax = plt.subplots(1, 10, figsize=(20, 2))
for axis, sprite in zip(ax, sample):
    axis.imshow(sprite)
    axis.axis('off')
plt.show()


In [None]:
# Shuffled mini-batches for training.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
# Adam over all model parameters, starting at the base learning rate.
optim = torch.optim.Adam(nn_model.parameters(), lr=lrate)

In [None]:
# helper function: perturbs an image to a specified noise level
def perturb_input(x, t, noise):
    """Noise a batch of clean images to the levels selected by ``t``.

    Implements the DDPM forward process
    ``x_t = sqrt(ab_t[t]) * x + sqrt(1 - ab_t[t]) * noise``.

    Args:
        x: Batch of clean images. The per-sample coefficients are broadcast
            over three trailing dimensions.
        t: Index (or 1-D tensor of per-sample indices) into the
            module-level ``ab_t`` schedule.
        noise: Noise to mix in; expected to have the same shape as ``x``.

    Returns:
        The perturbed batch.
    """
    signal_scale = ab_t.sqrt()[t, None, None, None]
    # The noise must be scaled by the standard deviation sqrt(1 - ab_t),
    # not by the variance (1 - ab_t); the reverse step in this file divides
    # by (1 - ab_t).sqrt(), so training and sampling now agree.
    noise_scale = (1 - ab_t[t, None, None, None]).sqrt()
    return signal_scale * x + noise_scale * noise

### Model Training.

In [None]:
# switch the network to training mode
nn_model.train()

# one loss value is recorded per optimisation step
losses = []

for ep in range(n_epoch):
    print(f'epoch {ep}')

    # learning rate falls linearly from lrate towards 0 across the epochs
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for images, _ in pbar:  # the second item of each batch is not used
        optim.zero_grad()
        images = images.to(device)

        # draw noise and a random timestep in [1, timesteps] for every image
        true_noise = torch.randn_like(images)
        steps = torch.randint(1, timesteps + 1, (images.shape[0],)).to(device)
        noisy_images = perturb_input(images, steps, true_noise)

        # the network sees the noised images and the timestep scaled by 1/timesteps
        predicted_noise = nn_model(noisy_images, steps / timesteps)

        # mean squared error between predicted and true noise
        loss = F.mse_loss(predicted_noise, true_noise)
        losses.append(loss.item())

        loss.backward()
        optim.step()



In [None]:
# training-loss curve: one point per optimisation step
plt.plot(losses)
plt.title('Training loss')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.show()

### Save Model

In [None]:
# save the trained weights once, after training has finished
# torch.save does not create missing directories, so make sure it exists
os.makedirs(save_dir, exist_ok=True)
# os.path.join yields the same path whether or not save_dir ends in a slash,
# which keeps this in step with the "{save_dir}/model_..." path used to load
model_path = os.path.join(save_dir, f"model_{n_epoch}_{num_sprites}.pth")
torch.save(nn_model.state_dict(), model_path)
print('saved model at ' + model_path)

In [None]:
# one reverse step: subtract the predicted noise, then re-inject some noise
def denoise_add_noise(x, t, pred_noise, z=None):
    """Take one reverse-diffusion step from the current sample ``x``.

    Args:
        x: Current noisy sample(s).
        t: Index into the module-level schedules ``a_t``, ``b_t``, ``ab_t``.
        pred_noise: Noise predicted by the network for ``x``.
        z: Noise to add back, scaled by ``sqrt(b_t[t])``. Drawn from a
            standard normal when ``None``; the sampler in this file passes
            ``0`` on its last step to add nothing.

    Returns:
        The denoised mean plus the scaled extra noise.
    """
    if z is None:
        z = torch.randn_like(x)
    # how much of the predicted noise is removed at this step
    eps_coef = (1 - a_t[t]) / (1 - ab_t[t]).sqrt()
    mean = (x - pred_noise * eps_coef) / a_t[t].sqrt()
    extra_noise = b_t.sqrt()[t] * z
    return mean + extra_noise

## Sampling

In [None]:
# DDPM sampling loop that also records intermediate results
@torch.no_grad()
def sample_with_intermidiates_ddpm(n_sample, save_rate=20):
    """Generate ``n_sample`` images by iterating the reverse process.

    Starts from standard-normal noise of shape
    ``(n_sample, 3, height, height)`` and walks the step index from
    ``timesteps`` down to 1.

    Args:
        n_sample: Number of images to generate.
        save_rate: A snapshot is stored whenever the step index is a
            multiple of this value; the first step and every step below 8
            are stored as well.

    Returns:
        A tuple ``(samples, intermediate)``: the final tensor and a NumPy
        array stacking the stored snapshots along a new first axis.
    """
    samples = torch.randn(n_sample, 3, height, height).to(device)

    snapshots = []
    for i in reversed(range(1, timesteps + 1)):
        print(f'sampling timestep {i:3d}', end='\r')

        # step index scaled by 1/timesteps, shaped (1, 1, 1, 1)
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # fresh noise on every step except the last one
        if i > 1:
            z = torch.randn_like(samples)
        else:
            z = 0

        eps = nn_model(samples, t)
        samples = denoise_add_noise(samples, i, eps, z)

        if i % save_rate == 0 or i == timesteps or i < 8:
            snapshots.append(samples.detach().cpu().numpy())

    return samples, np.stack(snapshots)

In [None]:
# restore the saved weights and switch the network to eval mode
checkpoint_path = f"{save_dir}/model_{n_epoch}_{num_sprites}.pth"
state_dict = torch.load(checkpoint_path, map_location=device)
nn_model.load_state_dict(state_dict)
nn_model.eval()
print("Loaded in Model")

### Visualizing Samples

In [None]:
# Generate a single image and show the snapshots stored along the way
n_sample = 1
nrows = 5
save_rate = 20
samples_dir = 'samples/'

sample, intermediate = sample_with_intermidiates_ddpm(n_sample, save_rate)
plot_single_sample_with_intermidiate_steps(intermediate)

In [None]:
# Clear the current figure, sample 32 images and animate the stored steps.
plt.clf()
samples, intermediate_ddpm = sample_with_intermidiates_ddpm(32)
# plot_sample comes from utils; save=False is passed, so presumably nothing
# is written to save_dir here -- confirm against utils.plot_sample.
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False
)
HTML(animation_ddpm.to_jshtml())

# Acknowledgments
Sprites by ElvGames, [FrootsnVeggies](https://zrghr.itch.io/froots-and-veggies-culinary-pixels) and  [kyrise](https://kyrise.itch.io/)   
This code is modified from https://github.com/cloneofsimo/minDiffusion   
Diffusion model is based on [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)