In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
from torch import nn
from tqdm import tqdm

from matplotlib import pyplot as plt
import numpy as np

import math
import sys

# Mount Google Drive if executed on Google Colab
try:
    from google.colab import drive

    drive.mount('/content/gdrive/')
    sys.path.append('/content/gdrive/MyDrive/GenAI')
except ImportError:
    # Only a missing `google.colab` module means "not on Colab". Any other failure
    # (e.g. the drive mount itself) is a real error and is allowed to surface here,
    # instead of being misreported and then failing at the project imports below.
    print("Not running on Google Colab")

# Project-local modules; on Colab they are found via the Drive path appended above.
from images import show_grid
from model import Model

# Seed NumPy's global RNG (used by add_noise) and PyTorch's RNG (weight init,
# DataLoader shuffling) so Restart & Run All gives repeatable results
# (some CUDA kernels may still be nondeterministic).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

In [None]:
# Load the CIFAR-10 training split (download=True fetches it into ./data)
# and list the available class names.
cifar10 = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
)

class_names = cifar10.classes
print("Classes:", *class_names)

In [None]:
# Keep only the images of one class, as float32 pixel values in [0, 1].
CLASS_NAME = 'automobile'
class_idx = cifar10.classes.index(CLASS_NAME)
class_mask = np.asarray(cifar10.targets) == class_idx

# Divide by 255 (the uint8 maximum) so a 255 pixel maps to exactly 1.0; the
# training code later centres images with `- 0.5`, which assumes a [0, 1] range.
real_images = (cifar10.data[class_mask] / 255).astype(np.float32)

# Move the channel axis to position 1 (channel-first, as the model input
# np.zeros((3, 32, 32)) in generate_image expects).
# NOTE(review): assuming cifar10.data is (N, H, W, C) as torchvision documents,
# swapaxes(1, 3) yields (N, C, W, H), i.e. height and width are transposed as
# well. The comparison plot undoes it with np.swapaxes(..., 0, 2), so the swap is
# kept for consistency; show_grid presumably does the same -- confirm.
real_images = np.swapaxes(real_images, 1, 3)

show_grid(real_images)

In [None]:
# Forward diffusion step, see https://theaisummer.com/diffusion-models/#forward-diffusion
def add_noise(img, beta):
    """Apply one forward-diffusion noising step to ``img``.

    Returns ``sqrt(1 - beta) * img + eps`` where ``eps`` is drawn element-wise
    from N(0, beta) using NumPy's global RNG.

    Args:
        img: NumPy array (needs ``.shape``) to noise.
        beta: noise variance for this step; must lie in [0, 1], otherwise
            ``math.sqrt`` raises ``ValueError``.

    Returns:
        A new array with ``img``'s shape. It is float64, because the drawn noise is.
    """
    signal_scale = math.sqrt(1 - beta)
    noise_std = math.sqrt(beta)
    eps = np.random.normal(scale=noise_std, size=img.shape)
    return signal_scale * img + eps

In [None]:
# Visualise the forward process: noise the first image repeatedly with a
# linearly increasing beta (0.005 -> 0.055 over 25 steps) and show all 26 frames.
noisy = [real_images[0]]
for step in range(25):
    beta = 0.005 + step / 24 * 0.05
    noisy.append(add_noise(noisy[-1], beta))

show_grid(np.array(noisy))

In [None]:
# Build (x_t, x_{t+1}, time) training triples by running the forward-noising
# chain on every image. For step t the model receives x_{t+1} and the time value
# t / N_STEPS, and must predict x_t.
# The chain runs N_STEPS = 25 steps so that it covers every time value
# generate_image() denoises from (24/25 down to 0/25). With only 24 steps, the
# 24/25 pair used for the first (noisiest) denoising step was never trained.
N_STEPS = 25

pairs = []

for i in tqdm(range(len(real_images))):

    # Centre pixel values around 0
    im = real_images[i] - 0.5
    # float32 copy of the current chain state. It is stored as this pair's target
    # and reused as the next pair's target, instead of converting the same array
    # twice. The chain itself keeps running on the full-precision `im`.
    im_f32 = np.array(im, dtype=np.float32)

    for t in range(N_STEPS):
        noised = add_noise(im, 0.005 + t / (N_STEPS - 1) * 0.05)
        noised_f32 = np.array(noised, dtype=np.float32)
        pairs.append((im_f32, noised_f32, np.array(t / N_STEPS, dtype=np.float32)))
        im, im_f32 = noised, noised_f32

## Training

In [None]:
# Denoiser, loss and optimiser: the model learns to map a noisier image back to
# its less-noisy predecessor under a mean-absolute-error (L1) loss.
model = Model().to(device)

loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Shuffled mini-batches of 1024 (x_t, x_{t+1}, time) triples from `pairs`.
data_loader = torch.utils.data.DataLoader(
    dataset=pairs,
    batch_size=1024,
    shuffle=True,
)

In [None]:
# Train the denoiser: for each batch, predict `target` (x_t) from `noised`
# (x_{t+1}) and `time`, and record the per-batch loss for plotting.
N_EPOCHS = 20

model.train()

loss_hist = []

for _ in tqdm(range(N_EPOCHS)):

    for target, noised, time in data_loader:
        optimizer.zero_grad()

        target = target.to(device)
        noised = noised.to(device)
        time = time.to(device)

        pred = model(noised, time)

        loss = loss_fn(pred, target)
        loss_hist.append(loss.item())

        loss.backward()
        optimizer.step()

In [None]:
# Training curve: one L1 loss value per optimisation step (batch).
fig, ax = plt.subplots()
ax.plot(loss_hist)
ax.set(title="Training loss", xlabel="Batch", ylabel="L1 loss")
plt.show()

## Evaluation

In [None]:
def generate_image():
    """Sample one image by running the trained denoiser backwards, and show every step.

    Starting from an all-zero 3x32x32 image (mid-grey in the centred pixel range),
    the same 25-step forward noising schedule as in training is applied. The model
    is then called 25 times with time values 24/25 down to 0/25; each call
    predicts the previous, less-noisy state from the current one. All 26 frames
    (the starting noise plus 25 predictions) are shown with ``show_grid``.

    Uses the module-level ``model``, ``device``, ``add_noise`` and ``show_grid``.
    The model is put in eval mode while sampling, and its previous train/eval
    mode is restored afterwards.
    """
    noisy = [np.zeros((3, 32, 32))]
    for t in range(25):
        noisy.append(add_noise(noisy[-1], 0.005 + t/24 * 0.05))

    # Add a batch dimension: (3, 32, 32) -> (1, 3, 32, 32), as float32 on `device`.
    x = torch.as_tensor(noisy[-1][None], dtype=torch.float32, device=device)
    denoised = [x]

    was_training = model.training
    model.eval()
    try:
        # Inference only: avoid building an autograd graph across 25 chained calls.
        with torch.no_grad():
            for i in range(25):
                time = torch.tensor([(24 - i) / 25], dtype=torch.float32, device=device)
                x = model(x, time)
                denoised.append(x)
    finally:
        model.train(was_training)

    # Stack to (26, 3, 32, 32). The model works on images centred around 0
    # (training subtracts 0.5), so add 0.5 back for display.
    frames = torch.cat(denoised).cpu().numpy() + 0.5
    show_grid(frames)

generate_image()

In [None]:
# For the first training triple, show the target x_t, the model input x_{t+1},
# and the model's prediction of x_t side by side.
pair = pairs[0]
target_img, noised_img, time_val = pair

# Eval mode + no_grad for a single inference pass (matters if Model contains
# dropout / batch-norm layers -- not visible here).
model.eval()
with torch.no_grad():
    pred = model(
        torch.as_tensor(noised_img, device=device).unsqueeze(0),
        torch.as_tensor(time_val, device=device).unsqueeze(0),
    )[0]

# All three images live in the centred space (training subtracts 0.5), so add
# 0.5 back to each of them, including the prediction. np.swapaxes(..., 0, 2)
# undoes the channel-first swap made when loading. Clip to [0, 1], the range
# imshow accepts for float RGB.
panels = [
    ("Target", target_img),
    ("Noised input", noised_img),
    ("Prediction", pred.cpu().numpy()),
]

fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for ax, (title, img) in zip(axes, panels):
    ax.imshow(np.clip(np.swapaxes(img, 0, 2) + 0.5, 0, 1))
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()