The following example notebook implements standard diffusion
with a simple CNN model to generate realistic MNIST digits.

This is a modified implementation of `minDiffusion`
which implements [DDPM](https://arxiv.org/abs/2006.11239).

To run this example notebook,
install requirements as in `requirements.txt` (for example, `pip install -r requirements.txt`).
You may also wish to follow system-dependent PyTorch instructions
[here](https://pytorch.org/) to install accelerated
versions of PyTorch, but note they are not needed
(I am testing this on my laptop).

If you do use accelerated hardware, make sure that your code
is still compatible with CPU-only installs.

First, let's create a folder to store example images:

In [2]:
!mkdir -p contents

In [1]:
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
from accelerate import Accelerator
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
from utils import ddpm_schedules, CNNBlock, CNN, DDPM

We will run this on MNIST. We perform some basic preprocessing, and set up the data loader:

In [2]:
tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0))])
dataset = MNIST("./data", train=True, download=True, transform=tf)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, drop_last=True)

We create our model with a given choice of hidden layers and activation function. We also choose a learning rate.

In [3]:
gt = CNN(in_channels=1, expected_shape=(28, 28), n_hidden=(16, 32, 32, 16), act=nn.GELU)
# For testing: (16, 32, 32, 16)
# For more capacity (for example): (64, 128, 256, 128, 64)
ddpm = DDPM(gt=gt, betas=(1e-4, 0.02), n_T=1000)
optim = torch.optim.Adam(ddpm.parameters(), lr=2e-4)

We could set up a GPU if we have one, which is done below.

Here, we use HuggingFace's `accelerate` library, which abstracts away all the `.to(device)` calls for us.
This lets us focus on the model itself rather than data movement.
It also does a few other tricks to speed up calculations.

PyTorch Lightning, which we discussed during the course, is another option that also handles a lot more, but is a bit heavyweight.
`accelerate` is a simpler option closer to raw PyTorch.
However, if you prefer, you could choose to use Lightning for the coursework!

In [4]:
accelerator = Accelerator()

# We wrap our model, optimizer, and dataloaders with `accelerator.prepare`,
# which lets HuggingFace's Accelerate handle the device placement and gradient accumulation.
ddpm, optim, dataloader = accelerator.prepare(ddpm, optim, dataloader)

First, let's just make sure this works:

In [5]:
for x, _ in dataloader:
    break

with torch.no_grad():
    ddpm(x)

Now, let's train it. You can exit early by interrupting the kernel. Images
are saved to the `contents` folder.

In [None]:
n_epoch = 100
losses = []

for i in range(n_epoch):
    ddpm.train()

    pbar = tqdm(dataloader)  # Wrap our loop with a visual progress bar
    for x, _ in pbar:
        optim.zero_grad()

        loss = ddpm(x)

        loss.backward()
        # ^Technically should be `accelerator.backward(loss)` but not necessary for local training

        losses.append(loss.item())
        avg_loss = np.average(losses[min(len(losses)-100, 0):])
        pbar.set_description(f"loss: {avg_loss:.3g}")  # Show running average of loss in progress bar

        optim.step()

    ddpm.eval()
    with torch.no_grad():
        xh = ddpm.sample(16, (1, 28, 28), accelerator.device)  # Can get device explicitly with `accelerator.device`
        grid = make_grid(xh, nrow=4)

        # Save samples to `./contents` directory
        save_image(grid, f"./contents/ddpm_sample_{i:04d}.png")

        # save model
        torch.save(ddpm.state_dict(), f"./ddpm_mnist.pth")


loss: 0.132: 100%|██████████| 468/468 [00:34<00:00, 13.69it/s]
loss: 0.089: 100%|██████████| 468/468 [00:30<00:00, 15.48it/s] 
loss: 0.072: 100%|██████████| 468/468 [00:29<00:00, 16.05it/s] 
loss: 0.0623: 100%|██████████| 468/468 [00:29<00:00, 15.96it/s]
loss: 0.056: 100%|██████████| 468/468 [00:29<00:00, 15.93it/s] 
loss: 0.0514: 100%|██████████| 468/468 [00:29<00:00, 15.94it/s]
loss: 0.0479: 100%|██████████| 468/468 [00:29<00:00, 15.73it/s]
loss: 0.0453: 100%|██████████| 468/468 [00:30<00:00, 15.53it/s]
loss: 0.0431: 100%|██████████| 468/468 [00:28<00:00, 16.34it/s]
loss: 0.0412: 100%|██████████| 468/468 [00:28<00:00, 16.24it/s]
loss: 0.0397: 100%|██████████| 468/468 [00:28<00:00, 16.37it/s]
