In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision
from torchvision.transforms import Compose, ToTensor, RandomResizedCrop
import torch.optim as optim
from torchvision.datasets import CIFAR10, MNIST, Flowers102
from torch.utils.data import DataLoader

from functools import partial

import ssl
# SECURITY WARNING: this swaps the standard library's default HTTPS context
# factory for the *unverified* one, so TLS certificates are no longer checked
# for HTTPS connections opened through the standard library (http.client /
# urllib) anywhere in this kernel session.
# NOTE(review): presumably a workaround for certificate errors while
# downloading the dataset -- confirm, and prefer fixing the local CA store
# (or downloading the data manually) over disabling verification globally.
ssl._create_default_https_context = ssl._create_unverified_context

%load_ext autoreload
%autoreload 2

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device: %s' % device)

# Create Dataset

In [None]:
from utils.diffusionDataset import DiffusionDataset

# --- Data pipeline settings -------------------------------------------------
image_size = (256, 256)
batch_size = 6

# Linear variance (beta) schedule with 1000 steps, following the linear
# schedule idea of Ho et al. (2020).
# NOTE(review): Ho et al. (2020) let beta grow from 1e-4 to 2e-2; the start
# value used here is 1e-3 -- confirm that this deviation is intended.
# Alternative schedules kept for reference:
#   constant:                         variance_schedule = np.ones(20)*0.0011
#   cosine (Nichol & Dhariwal, 2021): alpha_t = np.cos((t/T+s)/(1+s)*np.pi/2)**2
variance_schedule = np.linspace(start=1e-3, stop=2e-2, num=1000)

# Raw Flowers102 images, downloaded into ./datasets when not already present.
dataset_flowers = Flowers102(root='datasets', download=True)

# Convert each image to a tensor, then take a random resized crop of
# `image_size` so that all samples share the same spatial size.
to_fixed_size_tensor = Compose([ToTensor(), RandomResizedCrop(image_size)])

# Wrap the images in the project's diffusion dataset; later cells unpack its
# batches as ((noisy_image, t), noise).
dataset = DiffusionDataset(
    data=dataset_flowers,
    variance_schedule=variance_schedule,
    transform=to_fixed_size_tensor,
)

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

## Have a look at some elements in the dataset

In [None]:
# Draw one (shuffled) batch and keep only its first sample for inspection.
(noisy_batch, t_batch), noise_batch = next(iter(data_loader))
im_n, t, noise = noisy_batch[0], t_batch[0], noise_batch[0]

In [None]:
from utils.diffusionDataset import get_original_image

# Presumably inverts the forward diffusion step: rebuilds the clean image from
# the noisy sample, the noise that was added and the step index t.
im = get_original_image(im_n, noise, t, variance_schedule)

print(f"Forward diffusion step at stage t = {t}.")

w, h, dpi = 1500, 500, 100
fig, ax = plt.subplots(ncols=3, figsize=(w/dpi, h/dpi), dpi=dpi)

# One panel per tensor; permute(1, 2, 0) moves the channel axis last, which is
# the layout imshow expects.
panels = (
    (im_n, 'Noisy image'),
    (noise, 'Noise'),
    (im, 'Actual image'),
)
for axis, (chw_image, panel_title) in zip(ax, panels):
    axis.imshow(chw_image.permute(1, 2, 0))
    axis.set_title(panel_title)
plt.show()

## Plot the variance schedule and $\bar{\alpha}$

In [None]:
from utils.diffusionDataset import get_alpha_bar

# T, t and s are not used by the plot below; they are kept because they are
# notebook-level names (they match the symbols of the cosine-schedule formula
# noted in the dataset cell) and other cells may rely on them.
T = 1000
t = np.arange(T+1)
s = 1e-3

alpha_bar = get_alpha_bar(variance_schedule)

w, h, dpi = 1000, 500, 100
fig, ax1 = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)
ax2 = ax1.twinx()  # second y-axis sharing the same x-axis

ax1.plot(variance_schedule, color='C0')
ax2.plot(alpha_bar, color='C1')

# Both curves are plotted against their array index, i.e. the diffusion step t
# (not the normalised t/T), so the x-axis is labelled accordingly.
ax1.set_xlabel('$t$')
# Raw strings with backslashes so mathtext renders the Greek symbols; the
# right axis shows the quantity returned by get_alpha_bar, hence the bar.
# Each y-label is coloured like its curve to tell the twin axes apart.
ax1.set_ylabel(r'$\beta_t$', color='C0')
ax2.set_ylabel(r'$\bar{\alpha}_t$', color='C1')

plt.show()

# Model

## Define Model

In [None]:
from models.firstModel import FirstModel

# NOTE(review): the dataset cell builds its crops with image_size = (256, 256),
# while the model below is created for 128x128 inputs. Rebinding image_size
# here does not change the transform that was already constructed, so the
# data loader still yields 256x256 images. Confirm that FirstModel copes with
# that, or align the two sizes.
image_size = (128, 128)

# img_shape is (channels, height, width) with 3 colour channels.
net = FirstModel(img_shape=(3, *image_size), device=device)
net.to(device)

# Loss history that the training cell passes to and receives back from
# training(); re-running this cell resets it together with the model.
training_loss_list = []

Run a forward pass on a random input as a quick smoke test that the model works.

In [None]:
# Smoke test: a single forward pass on random data to check that the model
# runs end to end.
input_size = (1, 3, 128, 128)  # (batch, channels, height, width)

# Named `dummy_input` rather than `input` so that the built-in input() is not
# shadowed for the rest of the kernel session.
dummy_input = torch.randn(input_size, device=device)

# No gradients are needed for a smoke test; skip building the autograd graph.
with torch.no_grad():
    output = net(dummy_input, 1)  # second argument is the diffusion step t

# Display the result's shape so the check has a visible outcome.
output.shape

## Train Model

In [None]:
from utils.training import training

# Train for one epoch with an MSE loss; the loss history is handed in and the
# returned list replaces it.
training_loss_list = training(
    net=net,
    data_loader=data_loader,
    loss_function=nn.MSELoss(),
    epochs=1,
    device=device,
    training_loss_list=training_loss_list,
)

# Loss curve on a logarithmic y-axis.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(training_loss_list)
loss_ax.set_yscale('log')
plt.show()

## Test Performance

In [None]:
def _first_sample_hwc(batch):
    """Return the first sample of a (N, C, H, W) tensor batch as an (H, W, C) NumPy array."""
    return batch.detach().cpu().numpy()[0].transpose(1, 2, 0)


# Draw a fresh (shuffled) batch to evaluate the model on.
(noisy_image, t), noise = next(iter(data_loader))

# Inference only: disable autograd so no graph is built and kept in memory.
# NOTE(review): the network is left in whatever train/eval mode it is in; if
# FirstModel contains dropout or batch-norm layers, consider net.eval() here.
with torch.no_grad():
    pred_noise = net(noisy_image.to(device), t.to(device))

# Keep only the first sample of the batch, converted for plotting.
pred_noise = _first_sample_hwc(pred_noise)
noise = _first_sample_hwc(noise)
t = t.detach().cpu()[0]

noisy_image = _first_sample_hwc(noisy_image)

In [None]:
# Reconstruct the image twice from the same noisy input: once with the
# predicted noise and once with the noise that was actually added.
shared_args = dict(noisy_image=noisy_image, t=t, variance_schedule=variance_schedule)
pred_rec = get_original_image(noise=pred_noise, **shared_args)
rec = get_original_image(noise=noise, **shared_args)

print(t)

w, h, dpi = 1500, 500, 100
fig, ax = plt.subplots(ncols=3, figsize=(w/dpi, h/dpi), dpi=dpi)

# Left to right: ground-truth reconstruction, network input, model reconstruction.
panels = (
    (rec, 'Original image'),
    (noisy_image, 'Noisy input image'),
    (pred_rec, 'Reconstructed image'),
)
for axis, (hwc_image, panel_title) in zip(ax, panels):
    axis.imshow(hwc_image)
    axis.set_title(panel_title)
plt.show()