The following example notebook implements standard diffusion
with a simple CNN model to generate realistic MNIST digits.

This is a modified implementation of `minDiffusion`
which implements [DDPM](https://arxiv.org/abs/2006.11239).

To run this example notebook,
install requirements as in `requirements.txt` (for example, `pip install -r requirements.txt`).
You may also wish to follow system-dependent PyTorch instructions
[here](https://pytorch.org/) to install accelerated
versions of PyTorch, but note they are not needed
(I am testing this on my laptop).

If you do use accelerated hardware, make sure that your code
is still compatible with CPU-only installs.

First, let's create a folder to store example images:

In [2]:
!mkdir -p contents

In [1]:
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
from accelerate import Accelerator
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
from utils import ddpm_schedules, CNNBlock, CNN, DDPM, save_pickle

We will run this on MNIST. We perform some basic preprocessing, and set up the data loader:

In [2]:
tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0))])
dataset = MNIST("./data", train=True, download=True, transform=tf)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, drop_last=True)

Here we define the train function.

In [18]:
def train_ddpm(
        ddpm: nn.Module, 
        optim: torch.optim.Optimizer, 
        dataloader: DataLoader, 
        accelerator: Accelerator, 
        save_folder: str,
        n_epoch: int = 100,
        start_epoch: int = 0
    ):
    losses = []
    
    for i in range(n_epoch):
        ddpm.train()
    
        pbar = tqdm(dataloader)  # Wrap our loop with a visual progress bar
        for x, _ in pbar:
            optim.zero_grad()
    
            loss = ddpm(x)
    
            loss.backward()
            # ^Technically should be `accelerator.backward(loss)` but not necessary for local training
    
            losses.append(loss.item())
            avg_loss = np.average(losses[min(len(losses)-100, 0):])
            pbar.set_description(f"loss: {avg_loss:.3g}")  # Show running average of loss in progress bar
    
            optim.step()
    
        ddpm.eval()
        with torch.no_grad():
            xh = ddpm.sample(16, (1, 28, 28), accelerator.device)  # Can get device explicitly with `accelerator.device`
            grid = make_grid(xh, nrow=4)
    
            # Save samples to `./contents` directory
            save_image(grid, f"./{save_folder}/ddpm_sample_{start_epoch + i:04d}.png")
    
            # save model
            torch.save(ddpm.state_dict(), f"./{save_folder}/ddpm_mnist_{start_epoch + i}.pth")
            
    save_pickle(losses, f"./{save_folder}/ddpm_mnist_losses_{start_epoch}.pkl")


Here we define each of the models we will be testing.

We essentially just vary the noise schedule but keep it linear in all of them. We keep the initial noise the same but change the final noise so that we add noise at a quicker or slower rate.

Here, we use HuggingFace's `accelerate` library, which abstracts away all the `.to(device)` calls for us.
This lets us focus on the model itself rather than data movement.
It also does a few other tricks to speed up calculations.

In [27]:
# this is the default model
accelerator_1 = Accelerator()
gt_1 = CNN(in_channels=1, expected_shape=(28, 28), n_hidden=(16, 32, 32, 16), act=nn.GELU)
ddpm_1 = DDPM(gt=gt_1, betas=(1e-4, 0.02), n_T=1000)
optim_1 = torch.optim.Adam(ddpm_1.parameters(), lr=2e-4)
ddpm_1, optim_1, dataloader_1 = accelerator_1.prepare(ddpm_1, optim_1, dataloader)

In [28]:
accelerator_2 = Accelerator()
gt_2 = CNN(in_channels=1, expected_shape=(28, 28), n_hidden=(16, 32, 32, 16), act=nn.GELU)
ddpm_2 = DDPM(gt=gt_2, betas=(1e-4, 0.1), n_T=200)
optim_2 = torch.optim.Adam(ddpm_2.parameters(), lr=2e-4)
ddpm_2, optim_2, dataloader_2 = accelerator_2.prepare(ddpm_2, optim_2, dataloader)

In [29]:
accelerator_3 = Accelerator()
gt_3 = CNN(in_channels=1, expected_shape=(28, 28), n_hidden=(16, 32, 32, 16), act=nn.GELU)
ddpm_3 = DDPM(gt=gt_3, betas=(1e-4, 0.004), n_T=5000)
optim_3 = torch.optim.Adam(ddpm_3.parameters(), lr=2e-4)
ddpm_3, optim_3, dataloader_3 = accelerator_3.prepare(ddpm_3, optim_3, dataloader)

Now we train the models.

In [30]:
!mkdir -p contents_2
train_ddpm(ddpm_2, optim_2, dataloader_2, accelerator_2, "contents_2")

loss: 0.114: 100%|██████████| 468/468 [00:29<00:00, 15.67it/s]
loss: 0.0804: 100%|██████████| 468/468 [00:28<00:00, 16.43it/s]
loss: 0.0666: 100%|██████████| 468/468 [00:28<00:00, 16.37it/s]
loss: 0.0586: 100%|██████████| 468/468 [00:28<00:00, 16.29it/s]
loss: 0.0532: 100%|██████████| 468/468 [00:28<00:00, 16.35it/s]
loss: 0.0492: 100%|██████████| 468/468 [00:28<00:00, 16.37it/s]
loss: 0.0461: 100%|██████████| 468/468 [00:28<00:00, 16.44it/s]
loss: 0.0437: 100%|██████████| 468/468 [00:28<00:00, 16.29it/s]
loss: 0.0417: 100%|██████████| 468/468 [00:28<00:00, 16.53it/s]
loss: 0.04: 100%|██████████| 468/468 [00:28<00:00, 16.43it/s]  
loss: 0.0385: 100%|██████████| 468/468 [00:28<00:00, 16.38it/s]
loss: 0.0373: 100%|██████████| 468/468 [00:28<00:00, 16.40it/s]
loss: 0.0362: 100%|██████████| 468/468 [00:28<00:00, 16.36it/s]
loss: 0.0352: 100%|██████████| 468/468 [00:28<00:00, 16.57it/s]
loss: 0.0343: 100%|██████████| 468/468 [00:28<00:00, 16.54it/s]
loss: 0.0335: 100%|██████████| 468/468 [0

In [31]:
!mkdir -p contents_3
train_ddpm(ddpm_3, optim_3, dataloader_3, accelerator_3, "contents_3")

loss: 0.124: 100%|██████████| 468/468 [00:28<00:00, 16.39it/s]
loss: 0.0841: 100%|██████████| 468/468 [00:27<00:00, 16.83it/s]
loss: 0.068: 100%|██████████| 468/468 [00:27<00:00, 16.78it/s] 
loss: 0.0589: 100%|██████████| 468/468 [00:27<00:00, 16.75it/s]
loss: 0.0528: 100%|██████████| 468/468 [00:30<00:00, 15.49it/s]
loss: 0.0486: 100%|██████████| 468/468 [00:29<00:00, 15.98it/s]
loss: 0.0452: 100%|██████████| 468/468 [00:28<00:00, 16.31it/s]
loss: 0.0427: 100%|██████████| 468/468 [00:28<00:00, 16.66it/s]
loss: 0.0405: 100%|██████████| 468/468 [00:29<00:00, 15.96it/s]
loss: 0.0387: 100%|██████████| 468/468 [00:29<00:00, 15.97it/s]
loss: 0.0372: 100%|██████████| 468/468 [00:29<00:00, 15.89it/s]
loss: 0.0359: 100%|██████████| 468/468 [00:29<00:00, 15.87it/s]
loss: 0.0348: 100%|██████████| 468/468 [00:28<00:00, 16.19it/s]
loss: 0.0338: 100%|██████████| 468/468 [00:28<00:00, 16.21it/s]
loss: 0.0329: 100%|██████████| 468/468 [00:29<00:00, 16.10it/s]
loss: 0.0321: 100%|██████████| 468/468 [0