Import the packages. To get all of them at the appropriate versions, create a conda environment from `environment.yml` (for example with `conda env create -f environment.yml`).

In [None]:
import os
import random
import torch
import torchvision

from diffusion_model import Diffusion
from utils import set_device, plot_num, plot_process

Download the MNIST dataset through `torchvision` and save it in a local folder (`../data/`). Only the training split is used to train the models; however, it is possible to merge the training and test sets, or to augment them with slight random rotations, to enlarge the dataset.

In [None]:
mnist_train = torchvision.datasets.MNIST('../data/', train=True, transform=torchvision.transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST('../data/', train=False, transform=torchvision.transforms.ToTensor(), download=True)

trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(mnist_test, batch_size=128, shuffle=True)

Set the device for the experiment. If the model is to be trained from scratch, GPU acceleration is recommended. Alternatively, one can use a pretrained model; in that case, a CPU should be sufficient.

In [None]:
dvc = set_device()
dvc
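`set_device` comes from the repository's `utils` module, whose code is not shown here. A plausible minimal equivalent, together with loading a GPU-trained checkpoint on a CPU-only machine, is sketched below. `pick_device` and `load_weights` are hypothetical names, and the small `torch.nn.Linear` layer only stands in for the U-Net.

```python
import os
import tempfile

import torch


def pick_device() -> torch.device:
    """Hypothetical stand-in for utils.set_device: prefer CUDA, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    mps = getattr(torch.backends, 'mps', None)  # absent in old PyTorch versions
    if mps is not None and mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def load_weights(model: torch.nn.Module, path: str, device: torch.device) -> torch.nn.Module:
    # map_location moves tensors saved from a GPU onto `device`,
    # so a GPU-trained checkpoint also loads on a CPU-only machine
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    return model.to(device)


device = pick_device()

# save/load round trip with a tiny layer standing in for the U-Net
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'demo.pth')
    original = torch.nn.Linear(4, 2)
    torch.save(original.state_dict(), path)
    restored = load_weights(torch.nn.Linear(4, 2), path, torch.device('cpu'))
```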

The following variables choose the size of the U-Net at the core of the diffusion model and whether to train a new model or load a pretrained one. Pretrained models are available for `DIM` equal to 16 or 32.

The parameter `DIM` is the number of channels at the top (i.e. first and last) level of the U-Net. In the architecture used in this experiment, the bottom level therefore has `4*DIM` channels, each a 7x7 feature map.

Although minimal, this U-Net performs decently on a dataset of such small images.
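To make the channel arithmetic concrete, here is a hypothetical three-level U-Net in which the channels double (`DIM`, `2*DIM`, `4*DIM`) while the spatial size halves (28, 14, 7). It is only a sketch consistent with the description above, not the repository's architecture: `TinyUNet` and `block` are made-up names, and it omits any time-step (or label) conditioning that a diffusion U-Net needs.

```python
import torch
import torch.nn as nn


def block(c_in, c_out):
    # two 3x3 convolutions; padding=1 keeps the spatial size unchanged
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(),
        nn.Conv2d(c_out, c_out, 3, padding=1), nn.ReLU(),
    )


class TinyUNet(nn.Module):
    """Hypothetical 3-level U-Net: DIM -> 2*DIM -> 4*DIM channels, 28 -> 14 -> 7 pixels."""

    def __init__(self, dim=32, in_ch=1):
        super().__init__()
        self.pool = nn.MaxPool2d(2)                                    # halves H and W
        self.enc1 = block(in_ch, dim)                                  # 28x28, DIM channels
        self.enc2 = block(dim, 2 * dim)                                # 14x14, 2*DIM channels
        self.bottom = block(2 * dim, 4 * dim)                          # 7x7, 4*DIM channels
        self.up2 = nn.ConvTranspose2d(4 * dim, 2 * dim, 2, stride=2)   # 7x7 -> 14x14
        self.dec2 = block(4 * dim, 2 * dim)                            # upsampled + enc2 skip
        self.up1 = nn.ConvTranspose2d(2 * dim, dim, 2, stride=2)       # 14x14 -> 28x28
        self.dec1 = block(2 * dim, dim)                                # upsampled + enc1 skip
        self.head = nn.Conv2d(dim, in_ch, 1)                           # back to image channels

    def forward(self, x):
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool(s1))
        b = self.bottom(self.pool(s2))
        d2 = self.dec2(torch.cat([self.up2(b), s2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), s1], dim=1))
        return self.head(d1)


net = TinyUNet(dim=32)
x_demo = torch.randn(8, 1, 28, 28)
with torch.no_grad():
    bottom = net.bottom(net.pool(net.enc2(net.pool(net.enc1(x_demo)))))
    out = net(x_demo)
print(bottom.shape)  # torch.Size([8, 128, 7, 7]), i.e. (batch, 4*DIM, 7, 7)
```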

In [None]:
DIM = 32
USE_PRETRAINED = True

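Let us initialize an instance of class `Diffusion`.

`T` is the number of steps of the Markov chain, `S` is the side length (in pixels) of the square images, `betas` is the noise schedule (one value per step) and `UNet_dim` is the `DIM` chosen above.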

In [None]:
diffusion = Diffusion(T=100, S=28, betas=torch.linspace(1e-04, 2e-02, 100), UNet_dim=DIM)

The noise schedule `betas` has to be chosen so that, as the step $t$ approaches $T=100$, the noised picture becomes unrecognizable, while for $t$ close to $0$ it still resembles the original.
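Assuming `Diffusion` follows the standard DDPM forward process (the cell below computes `x` with exactly this closed form, `alphas_hat` playing the role of $\bar\alpha_t$), the schedule acts on an image through cumulative products:

$$
\alpha_t = 1-\beta_t, \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s, \qquad x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I).
$$

The picture is therefore unrecognizable at $t = T$ when $\bar\alpha_T \approx 0$, and close to the original for small $t$, where $\bar\alpha_t \approx 1$.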

These hyperparameters ($T$ and `betas`) are crucial for a successful diffusion model. The following cell noises a random training image at a random step $t$ and plots it, which helps to get an idea of the range of values to choose.

In [None]:
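imgs, labs = next(iter(trainloader))

x0, lab = imgs[0].to(dvc), labs[0]
# random.randint includes both endpoints, so the upper bound must be T - 1
t = random.randint(0, diffusion.T - 1)

# noise with the same shape as the image
eps = diffusion.sample_eps(torch.Size([x0.numel()])).reshape_as(x0)

# closed-form forward process: x_t = sqrt(alpha_hat_t) * x0 + sqrt(1 - alpha_hat_t) * eps
alpha_hat = diffusion.alphas_hat[t]
x = x0*torch.sqrt(alpha_hat) + eps*torch.sqrt(1-alpha_hat)

print(f'original label {lab.item()}\t time: {t}')

x = torch.squeeze(x.to('cpu'))
plot_num(x, cmap='viridis')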
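To compare schedules numerically rather than by eye, one can tabulate the signal weight $\sqrt{\bar\alpha_t}$ at a few steps. The sketch below assumes, as in standard DDPM, that `alphas_hat` is the cumulative product of $1-\beta_t$; `alphas_hat_from_betas` is a hypothetical helper, and the second schedule (ending at `5e-2`) is only an alternative for comparison, not a recommendation.

```python
import torch


def alphas_hat_from_betas(betas: torch.Tensor) -> torch.Tensor:
    # assumed DDPM convention: alpha_hat_t = prod_{s <= t} (1 - beta_s)
    return torch.cumprod(1.0 - betas, dim=0)


n_steps = 100
schedules = {
    'notebook: linspace(1e-4, 2e-2)': torch.linspace(1e-4, 2e-2, n_steps),
    'hypothetical: linspace(1e-4, 5e-2)': torch.linspace(1e-4, 5e-2, n_steps),
}
for name, betas in schedules.items():
    a_hat = alphas_hat_from_betas(betas)
    # sqrt(alpha_hat_t) multiplies the image, sqrt(1 - alpha_hat_t) multiplies the noise
    steps = (0, n_steps // 4, n_steps // 2, n_steps - 1)
    weights = {t: round(a_hat[t].sqrt().item(), 3) for t in steps}
    print(f'{name:36s} signal weight at t = {weights}')
```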

Load a pretrained model, or train a new one.

In [None]:
model_dir = os.path.join(os.pardir, 'pretrained_models')
model_name = f'UNet_dim{DIM}.pth'

if USE_PRETRAINED:
    assert model_name in os.listdir(model_dir), 'pretrained model not available for chosen DIM'
    diffusion.load_UNet(os.path.join(model_dir, model_name))

else:
    # Adam, with the learning rate annealed (cosine) from 1e-3 towards 1e-4 over T_max=100 scheduler steps
    optimizer = torch.optim.Adam(diffusion.UNet.parameters(), lr=1e-03)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-04)
    diffusion.train(trainloader, 100, optimizer, scheduler)

    save = True
    if save:
        # create the folder if needed, then store only the U-Net weights
        os.makedirs(model_dir, exist_ok=True)
        torch.save(diffusion.UNet.state_dict(), os.path.join(model_dir, model_name))
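The body of `diffusion.train` lives in `diffusion_model` and is not shown. For orientation, one step of the standard DDPM training objective (predict the added noise with a mean-squared error) is sketched below. This is an assumption about what the method does, not its actual code: `TinyEpsNet` and `train_step` are hypothetical names, the toy network replaces the U-Net, and it conditions on the step and on the digit label (since `generate` takes the digit to draw) in a deliberately crude way, as extra constant input channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEpsNet(nn.Module):
    """Hypothetical toy noise predictor standing in for the U-Net; inputs: x_t, step t, digit label."""

    def __init__(self, n_steps=100, n_classes=10):
        super().__init__()
        self.n_steps, self.n_classes = n_steps, n_classes
        self.net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 1, 3, padding=1))

    def forward(self, x_t, t, labels):
        b, _, h, w = x_t.shape
        # crude conditioning: normalised step and label broadcast as two constant channels
        t_map = (t.float() / self.n_steps).view(b, 1, 1, 1).expand(b, 1, h, w)
        y_map = (labels.float() / (self.n_classes - 1)).view(b, 1, 1, 1).expand(b, 1, h, w)
        return self.net(torch.cat([x_t, t_map, y_map], dim=1))


def train_step(model, optimizer, x0, labels, alphas_hat):
    """One DDPM-style step: noise x0 at random steps, regress the added noise with an MSE loss."""
    t = torch.randint(0, alphas_hat.shape[0], (x0.shape[0],), device=x0.device)  # one step per image
    eps = torch.randn_like(x0)
    a_hat = alphas_hat[t].view(-1, 1, 1, 1)
    x_t = a_hat.sqrt() * x0 + (1 - a_hat).sqrt() * eps  # closed-form forward process
    loss = F.mse_loss(model(x_t, t, labels), eps)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# demo on random MNIST-shaped data
demo_alphas_hat = torch.cumprod(1 - torch.linspace(1e-4, 2e-2, 100), dim=0)
toy_model = TinyEpsNet()
toy_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
demo_x0 = torch.rand(16, 1, 28, 28)
demo_labels = torch.randint(0, 10, (16,))
losses = [train_step(toy_model, toy_opt, demo_x0, demo_labels, demo_alphas_hat) for _ in range(3)]
```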

Now it is possible to generate new figures. Select the digit to draw with `num`. The additional parameter is a list of integers: the steps at which the partially denoised image is saved, to show how the digit progressively emerges from the noise.

In [None]:
num = 8

x0, process = diffusion.generate(num, [i*10 for i in range(1, 10)])
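`diffusion.generate` is also defined in `diffusion_model`. Assuming it performs standard DDPM ancestral sampling, its loop can be sketched as follows. `sample` is a hypothetical function, the variance choice $\sigma_t^2=\beta_t$ is one common option, whether `generate` indexes the saved steps the same way is not known, and a dummy noise predictor replaces the trained U-Net so the snippet runs on its own.

```python
import torch


@torch.no_grad()
def sample(eps_model, label, betas, save_at=(), shape=(1, 1, 28, 28)):
    """Hypothetical DDPM ancestral sampler: start from pure noise, denoise from t = T-1 down to 0,
    and keep a copy of x_t whenever t is listed in `save_at`."""
    alphas = 1 - betas
    alphas_hat = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape)
    labels = torch.full((shape[0],), label, dtype=torch.long)
    saved = {}
    for t in reversed(range(len(betas))):
        if t in save_at:
            saved[t] = x.clone()
        eps = eps_model(x, torch.full((shape[0],), t, dtype=torch.long), labels)
        # posterior mean: (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / (1 - alphas_hat[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)  # fresh noise, sigma_t^2 = beta_t
        else:
            x = mean  # no noise at the final step
    return x, saved


def zero_noise(x, t, y):
    # dummy predictor that always returns zero noise; replace with the trained network
    return torch.zeros_like(x)


img, frames = sample(zero_noise, label=8, betas=torch.linspace(1e-4, 2e-2, 100),
                     save_at=[i * 10 for i in range(1, 10)])
```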

Finally, display both the generated image and the process that led to it. Change `cmap` for a different color map (the default is black and white).

In [None]:
plot_num(x0, cmap='viridis')
plot_process(process, cmap='viridis')
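`plot_process` comes from `utils`, and the structure of `process` is not documented here. Assuming the saved frames can be arranged as a mapping from step index to image tensor, a side-by-side plot could look like this; `show_process` is a hypothetical name and the random images only stand in for real frames.

```python
import matplotlib.pyplot as plt
import torch


def show_process(frames, cmap='gray'):
    """Hypothetical plot: one panel per saved step, from the noisiest (largest t) to the cleanest."""
    steps = sorted(frames, reverse=True)
    fig, axes = plt.subplots(1, len(steps), figsize=(1.5 * len(steps), 1.8), squeeze=False)
    for ax, t in zip(axes[0], steps):
        ax.imshow(frames[t].detach().cpu().squeeze().numpy(), cmap=cmap)
        ax.set_title(f't = {t}')
        ax.axis('off')
    return fig


# demo with random images standing in for the saved steps
demo_frames = {t: torch.rand(1, 1, 28, 28) for t in range(10, 100, 10)}
fig = show_process(demo_frames, cmap='viridis')
```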