In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision
from torchvision.transforms import Compose, ToTensor, RandomResizedCrop
import torch.optim as optim
from torchvision.datasets import CIFAR10, MNIST, Flowers102
from torch.utils.data import DataLoader

from functools import partial

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

%load_ext autoreload
%autoreload 2

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: %s' % device)

# Seed the RNGs the notebook draws from (weight init, DataLoader shuffling,
# RandomResizedCrop and, presumably, DiffusionDataset's noise/timestep sampling)
# so that Restart & Run All is repeatable. torch.manual_seed seeds all devices.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create Dataset

In [None]:
from utils.diffusionDataset import DiffusionDataset


image_size = (128, 128)
batch_size = 64

# Linear beta schedule over 1000 steps ending at beta_T = 0.02, following the
# linear schedule of Ho et al. (2020).
# NOTE(review): Ho et al. start at beta_1 = 1e-4, not 1e-3. The value is kept
# because changing it changes the training distribution. TODO confirm which
# start value is intended.
# Alternative: the cosine schedule of Nichol & Dhariwal (2021),
#   alpha_bar_t = f(t) / f(0),  f(t) = cos((t/T + s) / (1 + s) * pi/2)**2
variance_schedule = np.linspace(1e-3, 2e-2, 1000)


dataset_flowers = Flowers102(root='datasets',
                             download=True)

dataset = DiffusionDataset(data=dataset_flowers,
                           variance_schedule=variance_schedule,
                           transform=Compose([ToTensor(),
                                              RandomResizedCrop(image_size)]))

# Page-locked host memory only speeds up host->GPU copies; on a CPU-only
# machine PyTorch just emits a warning, so enable it only for CUDA.
data_loader = DataLoader(dataset,
                         batch_size=batch_size,
                         pin_memory=device.type == 'cuda',
                         shuffle=True)

## Have a look at some elements in the dataset

In [None]:
# Draw one shuffled batch, (noisy images, timesteps) and noise, and keep its first sample.
(noisy_batch, t_batch), noise_batch = next(iter(data_loader))
im_n, t, noise = noisy_batch[0], t_batch[0], noise_batch[0]

In [None]:
from utils.diffusionDataset import get_original_image


def to_displayable(img, rescale=False):
    """Turn a (C, H, W) image tensor into an (H, W, C) array for ``imshow``.

    ``imshow`` clips float RGB data to [0, 1] (with a warning). Gaussian noise
    is centred on 0, so clipping hides its negative half. With ``rescale=True``
    the values are min-max normalised to [0, 1] so the whole sample is visible.
    Otherwise they are clamped explicitly, which gives the same picture
    ``imshow`` would show, minus the warning.
    """
    arr = img.detach().cpu().permute(1, 2, 0).float()
    if rescale:
        lo, hi = arr.min(), arr.max()
        # clamp_min guards against division by zero for a constant image
        arr = (arr - lo) / (hi - lo).clamp_min(1e-12)
    return arr.clamp(0, 1).numpy()


im = get_original_image(im_n, noise, t, variance_schedule)

# int() so the timestep prints as a plain number instead of "tensor(...)".
print(f"Forward diffusion step at stage t = {int(t)}.")

w, h, dpi = 1500, 500, 100
fig, ax = plt.subplots(ncols=3, figsize=(w/dpi, h/dpi), dpi=dpi)

panels = [(to_displayable(im_n), 'Noisy image'),
          (to_displayable(noise, rescale=True), 'Noise (min-max rescaled)'),
          (to_displayable(im), 'Actual image')]
for axis, (img, title) in zip(ax, panels):
    axis.imshow(img)
    axis.set_title(title)
    axis.axis('off')
plt.show()

## Plot the variance schedule $\beta_t$ and $\bar{\alpha}_t$

In [None]:
from utils.diffusionDataset import get_alpha_bar

alpha_bar = get_alpha_bar(variance_schedule)

w, h, dpi = 1000, 500, 100
fig, ax1 = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)
ax2 = ax1.twinx()

ax1.plot(variance_schedule, color='C0')
ax2.plot(alpha_bar, color='C1')

# The x axis is the raw array index (diffusion step), not the normalised t/T.
ax1.set_xlabel('diffusion step $t$')
# Raw strings keep the LaTeX backslashes. Each label is coloured like its
# curve so the two y axes can be told apart.
ax1.set_ylabel(r'$\beta_t$', color='C0')
ax2.set_ylabel(r'$\bar{\alpha}_t$', color='C1')
ax1.set_title(r'Variance schedule $\beta_t$ and $\bar{\alpha}_t$')

plt.show()

# Train Model

## Define Model
Optionally load previously saved weights to continue from an earlier run.

In [None]:
from models.firstModel import FirstModel
from models.secondModel import SecondModelSum, SecondModelConcat

# Architecture to train. SecondModelSum / SecondModelConcat take the same
# constructor arguments and can be swapped in here.
MODEL_CLASS = FirstModel
# Set to True to resume from weights/weights_<ClassName>.pth instead of
# training from scratch.
LOAD_WEIGHTS = False

net = MODEL_CLASS(img_shape=(3,)+image_size, device=device)
if LOAD_WEIGHTS:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only kernel.
    net.load_state_dict(torch.load(f'weights/weights_{MODEL_CLASS.__name__}.pth',
                                   map_location=device))
net.to(device)

training_loss_list = []

Smoke test: run a random input through the network to check that the forward pass works.

In [None]:
# One forward pass on a random batch of size 1, with the same image size as the dataset.
input_size = (1, 3) + image_size
test_input = torch.randn(input_size).to(device)
# torch.Tensor(1) returns *uninitialised* memory (arbitrary, possibly NaN/inf),
# so use a well-defined timestep t = 0 with the same float dtype instead.
t_test = torch.zeros(1, device=device)
with torch.no_grad():  # no backward pass is needed for a smoke test
    output = net(test_input, t_test)

# The network is trained with MSE against the added noise (see Training), so
# its output is expected to have the input's shape.
assert output.shape == test_input.shape, (
    f'expected output shape {tuple(test_input.shape)}, got {tuple(output.shape)}')

## Training

In [None]:
from utils.training import training

EPOCHS = 5

# training_loss_list is passed back in, so re-running this cell keeps extending
# the same loss curve.
training_loss_list = training(net=net,
                              data_loader=data_loader,
                              loss_function=torch.nn.MSELoss(),
                              epochs=EPOCHS,
                              device=device,
                              training_loss_list=training_loss_list)

fig, ax = plt.subplots()
ax.plot(training_loss_list)
ax.set_yscale('log')
# Granularity (per batch / per epoch) is whatever `training` appends; TODO confirm.
ax.set_xlabel('logged step')
ax.set_ylabel('MSE loss')
ax.set_title(f'Training loss ({type(net).__name__})')
plt.show()

Optionally save the trained weights.

In [None]:
from pathlib import Path

# Set to True to persist the weights of the model trained above.
SAVE_WEIGHTS = False
if SAVE_WEIGHTS:
    # Derive the file name from the model class so one model's weights are never
    # written to another model's file. The comparison cells below load
    # weights/weights_<ClassName>.pth.
    weights_path = Path('weights') / f'weights_{type(net).__name__}.pth'
    weights_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(net.state_dict(), weights_path)

## Test Performance

Draw a batch from the dataset and keep the least-noisy image (smallest $t$) as the test case.

In [None]:
from utils.reconstruction import reconstruct_image_from_noise
from utils.eval import visualize_single_reconstruction

# Draw one batch and keep the sample with the smallest timestep. With the
# increasing beta schedule above, that is the least noisy one.
(noisy_image_batch, t_batch), noise_batch = next(iter(data_loader))
t_values = t_batch.numpy()
i = np.argmin(t_values)

t = t_values[i]
noisy_image, noise = (batch.numpy()[i] for batch in (noisy_image_batch, noise_batch))

# Ground-truth clean image, recovered from x_t and the noise that produced it.
original_image = reconstruct_image_from_noise(noisy_image=noisy_image,
                                              noise=noise,
                                              t=t,
                                              variance_schedule=variance_schedule)

print(f't = {t}')

Visualize results on a single image

In [None]:
# Visualise how the network trained above handles the test image selected in the previous cell.
reconstruction_inputs = (net, original_image, noisy_image, t, variance_schedule, device)
visualize_single_reconstruction(*reconstruction_inputs)


# Model comparison on a single image

## FirstModel

In [None]:
from models.firstModel import FirstModel

net = FirstModel(img_shape=(3,)+image_size, device=device)
# map_location: the checkpoint may have been saved on a GPU; load it onto
# whichever device this kernel selected.
net.load_state_dict(torch.load('weights/weights_FirstModel.pth', map_location=device))
net.to(device)
# Inference only: switch dropout / batch-norm layers (if any) to eval behaviour.
net.eval()

visualize_single_reconstruction(net, original_image, noisy_image, t, variance_schedule, device)

## SecondModelSum

In [None]:
from models.secondModel import SecondModelSum

net = SecondModelSum(img_shape=(3,)+image_size, device=device)
# map_location: the checkpoint may have been saved on a GPU; load it onto
# whichever device this kernel selected.
net.load_state_dict(torch.load('weights/weights_SecondModelSum.pth', map_location=device))
net.to(device)
# Inference only: switch dropout / batch-norm layers (if any) to eval behaviour.
net.eval()

visualize_single_reconstruction(net, original_image, noisy_image, t, variance_schedule, device)

## SecondModelConcat

In [None]:
from models.secondModel import SecondModelConcat

net = SecondModelConcat(img_shape=(3,)+image_size, device=device)
# map_location: the checkpoint may have been saved on a GPU; load it onto
# whichever device this kernel selected.
net.load_state_dict(torch.load('weights/weights_SecondModelConcat.pth', map_location=device))
net.to(device)
# Inference only: switch dropout / batch-norm layers (if any) to eval behaviour.
net.eval()

visualize_single_reconstruction(net, original_image, noisy_image, t, variance_schedule, device)