In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
from torch import nn
from tqdm import tqdm

from matplotlib import pyplot as plt
import numpy as np

import math
import sys

# Mount Google Drive if executed on Google Colab.
# Only a missing `google.colab` package means "not on Colab"; any other failure
# (e.g. the Drive mount itself failing) is left to propagate so it is not
# misreported as "Not running on Google Colab".
try:
    from google.colab import drive
except ImportError:
    print("Not running on Google Colab")
else:
    drive.mount('/content/gdrive/')
    sys.path.append('/content/gdrive/MyDrive/GenAI')

from images import show_grid
from model import Model

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("Device:", device)

In [None]:
# Number of diffusion steps
NB_STEPS = 1000 - 1
# Number of passes over the training images
NB_EPOCHS = 10
# Learning rate of the optimizer
LEARNING_RATE = 1e-4
# Number of images per training batch
BATCH_SIZE = 512

# Seed the torch and numpy random generators so that a fresh
# "Restart & Run All" reproduces the same noise, shuffling and samples.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Fetch the CIFAR-10 training split, downloading it into ./data when needed
cifar10 = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
)

# List the available class names
print("Classes:", *cifar10.classes)

In [None]:
# Extract every image of one category and rescale the pixel values with
# x / 127.5 - 1 (maps 0 -> -1 and 255 -> 1; assumes 8-bit pixel data).
automobile_label = cifar10.classes.index('automobile')
is_automobile = np.array(cifar10.targets) == automobile_label
real_images = cifar10.data[is_automobile] / (255 / 2) - 1

# Use floats
real_images = np.array(real_images, dtype=np.float32)

# Exchange axes 1 and 3. Assuming cifar10.data is laid out as (N, H, W, C), this
# moves the channels to axis 1 and gives (N, C, W, H); note that the two spatial
# axes trade places as well.
real_images = np.swapaxes(real_images, 1, 3)

show_grid(real_images[:25])

In [None]:
def get_beta(step):
    """Return the noise variance (beta) used at diffusion step `step`.

    The value grows linearly with `step`: 0.0001 at step 0, increasing by
    0.02 / NB_STEPS per step (so it would reach 0.0201 at step == NB_STEPS).
    """
    progress = step / NB_STEPS
    return 0.0001 + progress * 0.02

# Applies one or several consecutive noising steps to an image in a single draw
def add_noise(img, first_step, last_step = -1):
    """Return `img` after noising steps first_step .. last_step - 1.

    img        -- numpy array holding the image (any shape).
    first_step -- index of the first noising step to apply.
    last_step  -- index one past the last step to apply; the default -1 means
                  "apply only `first_step`".

    The steps are combined into one Gaussian draw: the image is scaled by
    sqrt(kept) and noise of standard deviation sqrt(1 - kept) is added, where
    `kept` is the product of (1 - beta) over the applied steps.
    """
    # -1 is the sentinel for "a single step starting at first_step"
    stop = first_step + 1 if last_step == -1 else last_step

    # Product of (1 - beta) over the requested steps (1 when the range is empty)
    kept = math.prod(1 - get_beta(step) for step in range(first_step, stop))

    gaussian = np.random.normal(scale=math.sqrt(1 - kept), size=img.shape)
    return math.sqrt(kept) * img + gaussian

In [None]:
# Progressive noising: every frame is the previous frame pushed one step further
current = real_images[0]
frames = [current]

for step in range(NB_STEPS):
    current = add_noise(current, step)
    frames.append(current)

# Show one frame out of every 20
show_grid(np.array(frames[::20]))

del frames, current

In [None]:
# Direct noising: every frame is produced from the clean image in one draw,
# covering steps 0 .. last - 1
clean = real_images[0]
frames = [clean] + [add_noise(clean, 0, last) for last in range(1, NB_STEPS + 1)]

# Show one frame out of every 20
show_grid(np.array(frames[::20]))

del frames, clean

## Training

In [None]:
# Network to train, placed on the selected device
model = Model()
model = model.to(device)

# Mean squared error between two tensors, minimised with Adam
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Shuffled mini-batches of the training images
data_loader = torch.utils.data.DataLoader(
    dataset=real_images,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
model.train()

loss_hist, diff_loss_hist = [], []

# Cumulative signal factors: alpha_bars[s] = prod_{i=0..s} (1 - beta_i).
# The product includes step s itself, matching the convention used when sampling
# (noise level of step k built from range(k + 1)). It is accumulated in float64
# to limit rounding over the long product, then stored as float32 on the device.
betas = torch.tensor([get_beta(k) for k in range(NB_STEPS)], dtype=torch.float64)
alpha_bars = torch.cumprod(1 - betas, dim=0).to(dtype=torch.float32, device=device)

for epoch in tqdm(range(NB_EPOCHS)):

    for images in data_loader:
        images = images.to(device)

        # Noise the model has to recover, drawn directly on the right device
        err = torch.randn_like(images)

        # One random diffusion step per image; the trailing singleton axes let
        # the per-image factors broadcast against the image tensor
        steps = torch.randint(0, NB_STEPS, size=(len(images), 1, 1, 1), device=device)

        # Noise level that corresponds to each sampled step
        alphas = alpha_bars[steps]

        noisy_images = torch.sqrt(alphas) * images + torch.sqrt(1 - alphas) * err

        # Step index rescaled by NB_STEPS, one value per image
        times = steps.flatten() / NB_STEPS

        # Train the model on them
        optimizer.zero_grad()

        pred_err = model(noisy_images, times)

        loss = loss_fn(pred_err, err)
        loss_value = loss.item()
        loss_hist.append(loss_value)

        loss.backward()
        optimizer.step()

        # Compare with the loss of a trivial prediction (all zeros)
        with torch.no_grad():
            loss_0 = loss_fn(torch.zeros_like(err), err)

            # The value should progressively become negative
            diff_loss_hist.append(loss_value - loss_0.item())

        del images, err, steps, alphas, noisy_images, times

In [None]:
# Training loss minus the loss of an all-zero prediction, one point per batch.
# Points below the dashed zero line mean the model beats that trivial baseline.
fig, ax = plt.subplots()
ax.plot(diff_loss_hist)
ax.axhline(0, color="gray", linestyle="--", linewidth=1)
ax.set(
    title="Train loss relative to an all-zero prediction",
    xlabel="Training batch",
    ylabel="Loss - baseline loss",
)
plt.show()

In [None]:
# MSE training loss, one point per batch
fig, ax = plt.subplots()
ax.plot(loss_hist, label="Train loss")
ax.set(title="Training loss", xlabel="Training batch", ylabel="MSE loss")
ax.legend()
plt.show()

## Evaluation

In [None]:
def generate_image():
    """Sample one image with the reverse diffusion process and display it.

    Starts from Gaussian noise of shape (1, 3, 32, 32) and applies NB_STEPS
    denoising updates with the global `model`, from step NB_STEPS - 1 down to 0.
    Every 20th intermediate image, plus the final one, is shown with `show_grid`.

    Side effect: the global `model` is left in evaluation mode (it stays on its
    current device). Returns None.
    """
    model.eval()

    # Run on the device the model already lives on instead of moving the model
    model_device = next(model.parameters()).device

    # Per-step noise variances and their cumulative products:
    # alpha_bars[k] = prod_{i=0..k} (1 - beta_i)
    betas = torch.tensor([get_beta(k) for k in range(NB_STEPS)], dtype=torch.float64)
    alpha_bars = torch.cumprod(1 - betas, dim=0)

    # Without no_grad, every stored image would keep the autograd graph of all
    # previous model calls alive, which makes memory grow with each step.
    with torch.no_grad():
        im = torch.randn(1, 3, 32, 32, device=model_device)
        denoised = [im.cpu()]

        for k in tqdm(range(NB_STEPS - 1, -1, -1)):
            beta = betas[k].item()
            alpha_bar = alpha_bars[k].item()

            t = torch.full((1,), k / NB_STEPS, device=model_device)
            pred = model(im, t)

            # Remove the predicted noise contribution of step k
            im = (im - beta / math.sqrt(1 - alpha_bar) * pred) / math.sqrt(1 - beta)

            # Re-inject fresh noise on every step except the last one, so the
            # final image is not noised again
            if k > 0:
                im = im + math.sqrt(beta) * torch.randn_like(im)

            denoised.append(im.cpu())

    # One frame out of every 20, and always the final generated image
    frames = denoised[::20]
    if (len(denoised) - 1) % 20 != 0:
        frames.append(denoised[-1])

    show_grid(np.array([d.numpy().squeeze(0) for d in frames]))

generate_image()