In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
from torch import nn
from tqdm import tqdm

from matplotlib import pyplot as plt
import numpy as np

import math

from images import show_grid
from model import Model

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs this notebook draws from: numpy's global generator (used for
# the diffusion noise) and torch's (DataLoader shuffling and, presumably, the
# model's weight initialisation), so Restart & Run All reproduces the same
# noise and batch order.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Fetch the CIFAR-10 training split into ./data and list its class names.
cifar10 = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
)

print("Classes:", *cifar10.classes)

In [None]:
# Extract every training image of a single class as floats in [0, 1].
CLASS_NAME = 'automobile'

# Look the class index up once rather than once per target, then select the
# matching images with a boolean mask (same images, same order).
class_idx = cifar10.classes.index(CLASS_NAME)
in_class = np.asarray(cifar10.targets) == class_idx

# Divide by 255 so the brightest pixel value (255) maps to exactly 1.0. The
# division also converts the (presumably uint8) pixels to floats.
real_images = cifar10.data[in_class] / 255.0

# torchvision stores CIFAR-10 images channel-last, (N, H, W, C). Move the
# channel axis to the front, giving (N, C, H, W), the layout torch conv layers
# use. Note that np.swapaxes(x, 1, 3) would give (N, C, W, H), i.e. every
# image transposed.
# NOTE(review): confirm `show_grid` and `Model` both expect (N, C, H, W).
real_images = np.transpose(real_images, (0, 3, 1, 2))

show_grid(real_images)

In [None]:
# Forward-diffusion step, as defined in https://theaisummer.com/diffusion-models/#forward-diffusion
def add_noise(img, beta, rng=None):
    """Apply one forward-diffusion step to ``img``.

    Samples x_t ~ N(sqrt(1 - beta) * x_{t-1}, beta * I): the input is scaled
    by sqrt(1 - beta) and isotropic Gaussian noise with variance ``beta``
    (standard deviation sqrt(beta)) is added.

    Args:
        img: array holding the previous step x_{t-1}; any shape.
        beta: noise variance for this step; must lie in [0, 1].
        rng: optional ``numpy.random.Generator`` for reproducible noise. When
            None, numpy's global RNG is used, so ``np.random.seed`` still
            controls the result.

    Returns:
        A new array with the same shape as ``img``.

    Raises:
        ValueError: if ``beta`` is outside [0, 1] (including NaN).
    """
    if not 0.0 <= beta <= 1.0:
        raise ValueError(f"beta must be in [0, 1], got {beta}")

    std = math.sqrt(beta)
    if rng is None:
        noise = np.random.normal(scale=std, size=img.shape)
    else:
        noise = rng.normal(scale=std, size=img.shape)
    return math.sqrt(1 - beta) * img + noise

In [None]:
# Visualise the forward process on one image. Use the same linear schedule the
# training pairs are generated with: 24 steps, beta_t = 0.005 + t/24 * 0.05 for
# t = 0..23. The grid therefore shows exactly the noise levels the model is
# trained on (the original image plus 24 noised versions).
N_DEMO_STEPS = 24

noisy = [real_images[0]]

for t in range(N_DEMO_STEPS):
    noisy.append(add_noise(noisy[-1], 0.005 + t / 24 * 0.05))

show_grid(np.array(noisy))

In [None]:
# Build training pairs by running each image through the forward-diffusion
# chain: pair k of an image is (x_k, x_{k+1}) with x_{k+1} = add_noise(x_k, beta_k).
# Index 0 of a pair is the less-noisy image, index 1 the more-noisy one.
N_STEPS = 24
# Linear schedule: beta_k = 0.005 + k/24 * 0.05 for k = 0 .. N_STEPS - 1.
BETAS = [0.005 + t / 24 * 0.05 for t in range(N_STEPS)]
# Debug cap on the number of source images; each one yields N_STEPS pairs
# (2 * N_STEPS stored images). Set to None to use every image of the class.
MAX_IMAGES = 12

source_images = real_images if MAX_IMAGES is None else real_images[:MAX_IMAGES]

# Preallocate the float32 result instead of accumulating a list of float64
# tuples and copying it at the end. Each chain is still computed in float64
# and cast on assignment, so the stored values are unchanged.
pairs = np.empty((len(source_images) * N_STEPS, 2) + source_images.shape[1:],
                 dtype=np.float32)

for i, im in enumerate(tqdm(source_images)):
    for t, beta in enumerate(BETAS):
        noised = add_noise(im, beta)
        pairs[i * N_STEPS + t, 0] = im
        pairs[i * N_STEPS + t, 1] = noised
        im = noised

## Training

In [None]:
# Optimisation setup: Adam over the model's parameters and plain MSE between
# the predicted and the target image.
LEARNING_RATE = 0.01
BATCH_SIZE = 4096  # TODO: tune the batch size

model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

data_loader = torch.utils.data.DataLoader(
    pairs,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Train the one-step denoiser. For every batch of pairs, predict the
# less-noisy image (index 0) from the more-noisy one (index 1) and minimise
# the loss between them.
N_EPOCHS = 10

model.train()

loss_hist = []  # one loss value per batch, across all epochs

for _ in range(N_EPOCHS):
    for batch_pairs in data_loader:
        batch_pairs = batch_pairs.to(device)
        noisy_input = batch_pairs[:, 1]
        clean_target = batch_pairs[:, 0]

        optimizer.zero_grad()

        loss = loss_fn(model(noisy_input), clean_target)
        loss_hist.append(loss.item())

        loss.backward()
        optimizer.step()

In [None]:
# Plot the per-batch training loss. The first WARMUP_BATCHES values are skipped
# (presumably because early losses are much larger and would flatten the curve),
# but only when points remain afterwards. Otherwise, e.g. with one batch per
# epoch, the plot would be empty.
WARMUP_BATCHES = 10
first_batch = WARMUP_BATCHES if len(loss_hist) > WARMUP_BATCHES else 0

fig, ax = plt.subplots()
ax.plot(range(first_batch, len(loss_hist)), loss_hist[first_batch:])
ax.set(title='Training loss', xlabel='Batch', ylabel='MSE loss');

## Evaluation