In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
from torch import nn
from tqdm import tqdm

from matplotlib import pyplot as plt
import numpy as np

import math
import sys

# Mount Google Drive if executed on Google Colab
try:
    from google.colab import drive
except ImportError:
    # The google.colab package could not be imported: treat this as a local run
    print("Not running on Google Colab")
else:
    # Kept outside the try block so that a failed mount surfaces as an error
    # instead of being reported as "Not running on Google Colab"
    drive.mount('/content/gdrive/')
    # Make the project modules stored on Drive importable
    sys.path.append('/content/gdrive/MyDrive/GenAI')

from images import show_grid
from model import Model

# Use the first CUDA device when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print("Device:", device)

In [None]:
# Number of noising steps; the demo cells show the clean image plus one frame
# per step, i.e. NB_STEPS + 1 = 100 frames
NB_STEPS = 99

# Number of passes over the training pairs
NB_EPOCHS = 10

In [None]:
# Fetch the CIFAR-10 training split into ./data (download=True asks
# torchvision to download it when needed)
cifar10 = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
)

# List the available class names
print("Classes:", *cifar10.classes)

In [None]:
# Extract a category of images
automobile_label = cifar10.classes.index('automobile')  # looked up once, not once per sample
automobile_rows = [i for i, t in enumerate(cifar10.targets) if t == automobile_label]

# Rescale linearly so that a pixel value of 0 maps to -1 and 255 maps to 1
real_images = cifar10.data[automobile_rows] / (255 / 2) - 1

# Use floats
real_images = np.array(real_images, dtype=np.float32)

# Swap axes 1 and 3. Assuming cifar10.data is laid out as (N, H, W, C), the
# result is (N, C, W, H): the channel axis ends up right after the batch axis
# and the two spatial axes trade places.
# NOTE(review): presumably show_grid expects exactly this layout - confirm
# before replacing this with a transpose that keeps (H, W) in order.
real_images = np.swapaxes(real_images, 1, 3)

show_grid(real_images[:25])

In [None]:
def get_beta(step, beta_start=0.0001, beta_span=0.15):
    """Return the noise variance (beta) of one diffusion step, on a linear schedule.

    Args:
        step: Index of the diffusion step; 0 is the first step and NB_STEPS
            yields the largest scheduled value.
        beta_start: Value returned for step 0.
        beta_span: Amount added on top of ``beta_start`` when ``step`` reaches
            NB_STEPS. 0.15 is the value in use; 0.02 is the alternative that
            was previously tried.

    Returns:
        ``beta_start + (step / NB_STEPS) * beta_span`` as a float.
    """
    return beta_start + (step / NB_STEPS) * beta_span

def add_noise(img, first_step, last_step = -1):
    """Apply one or several consecutive noising steps to an image in a single draw.

    Args:
        img: Array holding the image to corrupt.
        first_step: Index of the first step to apply.
        last_step: Index one past the last step to apply. The default -1 means
            "only first_step", i.e. a single step.

    Returns:
        ``sqrt(alpha) * img`` plus Gaussian noise of standard deviation
        ``sqrt(1 - alpha)``, where ``alpha`` is the product of
        ``1 - get_beta(k)`` over the applied steps. An empty step range gives
        ``alpha == 1``, so the image values come back unchanged.
    """
    # -1 is the sentinel for "apply exactly one step"
    stop = first_step + 1 if last_step == -1 else last_step

    # Share of the original signal that survives the selected steps
    kept = math.prod(1 - get_beta(step) for step in range(first_step, stop))

    signal = math.sqrt(kept) * img
    noise = np.random.normal(scale=math.sqrt(1 - kept), size=img.shape)
    return signal + noise

In [None]:
# Step-by-step view: every frame is the previous frame with one more noising step
frame = real_images[0]
chain = [frame]

for k in range(NB_STEPS):
    frame = add_noise(frame, k)
    chain.append(frame)

show_grid(np.array(chain))

# The frames are only needed for the figure
del chain, frame

In [None]:
# Direct view: every frame is drawn straight from the clean image by applying
# steps 0..k in a single call
clean = real_images[0]
jumps = [clean]

for k in range(NB_STEPS):
    jumps.append(add_noise(clean, 0, last_step=k + 1))

show_grid(np.array(jumps))

# The frames are only needed for the figure
del jumps, clean

In [None]:
# Generate the training triples (image at step k, image at step k + 1, k / NB_STEPS)
#
# Every triple stores two float32 images plus a scalar, and the full set holds
# len(real_images) * NB_STEPS of them, which may not fit in memory. The limits
# below make the truncation explicit (they replace two bare `break`s that
# silently stopped after the first triple). Set a limit to None to lift it.
MAX_PAIR_IMAGES = 1  # number of source images to use (None = all of real_images)
MAX_PAIR_STEPS = 1   # number of steps per image, starting at step 0 (None = all NB_STEPS)

nb_pair_images = len(real_images) if MAX_PAIR_IMAGES is None else min(MAX_PAIR_IMAGES, len(real_images))
nb_pair_steps = NB_STEPS if MAX_PAIR_STEPS is None else min(MAX_PAIR_STEPS, NB_STEPS)

pairs = []

for i in tqdm(range(nb_pair_images)):
    im = real_images[i]

    for k in range(nb_pair_steps):
        # Jump from the clean image to step k in one draw (k == 0 leaves it clean)...
        noised_k = add_noise(im, 0, k)
        # ...then apply step k alone to obtain the noisier image the model receives
        next_noised = add_noise(noised_k, k, k + 1)
        pairs.append((
            np.array(noised_k, dtype=np.float32),
            np.array(next_noised, dtype=np.float32),
            np.array(k / NB_STEPS, dtype=np.float32),
        ))

## Training

In [None]:
# Network to train, moved to the selected device
model = Model()
model = model.to(device)

# Adam over all model parameters, with an L1 (mean absolute error) objective
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = torch.nn.L1Loss()

# Shuffled mini-batches of up to 1024 triples taken from `pairs`
data_loader = torch.utils.data.DataLoader(dataset=pairs, shuffle=True, batch_size=1024)

In [None]:
# Put the model in training mode
model.train()

# Loss of every optimisation step, in chronological order
loss_hist = []

for epoch in tqdm(range(NB_EPOCHS)):
    # Each batch holds (less noisy image, noisier image, normalised step)
    for target, noised, timestep in data_loader:
        target, noised, timestep = (x.to(device) for x in (target, noised, timestep))

        optimizer.zero_grad()

        # The model sees the noisier image and must recover the less noisy one
        pred = model(noised, timestep)
        loss = loss_fn(pred, target)

        loss.backward()
        optimizer.step()

        loss_hist.append(loss.item())

In [None]:
# Training curve: one point per optimisation step recorded in loss_hist
fig, ax = plt.subplots()
ax.plot(loss_hist)
ax.set(title="Training loss", xlabel="Optimisation step (batch)", ylabel="L1 loss")
plt.show()

## Evaluation

In [None]:
def generate_image():
    """Sample one image by denoising random noise step by step, and display every step.

    Starts from a (3, 32, 32) array of standard normal noise, calls the model
    NB_STEPS times with the normalised step going from (NB_STEPS - 1) / NB_STEPS
    down to 0, and passes the initial noise plus all intermediate outputs to
    ``show_grid``. Returns nothing.

    The model is switched to evaluation mode and autograd is disabled while
    sampling; the model's previous training/eval mode is restored afterwards.
    """
    start = np.random.normal(size=(3, 32, 32))

    # Build the (1, 3, 32, 32) float32 batch straight from the array instead of
    # going through a Python list of ndarrays
    denoised = [torch.as_tensor(start, dtype=torch.float32).unsqueeze(0).to(device)]

    was_training = model.training
    model.eval()
    try:
        # Each output feeds the next call: without no_grad the autograd graph
        # would keep growing across all NB_STEPS calls
        with torch.no_grad():
            for k in range(NB_STEPS):
                t = (NB_STEPS - k - 1) / NB_STEPS
                step_input = torch.tensor([t], dtype=torch.float32, device=device)
                denoised.append(model(denoised[-1], step_input))
    finally:
        # Leave the model in whatever mode the caller had set
        model.train(was_training)

    show_grid(np.array([d.cpu().numpy().squeeze(0) for d in denoised]))

# Run one full sampling pass from random noise and display the intermediate images
generate_image()

In [None]:
# Sanity check on the first training triple: feed its noisier image to the
# model and compare the prediction with the less noisy target
im1, im2, t = pairs[0]

# Inference only: leave training mode (the training cell calls model.train()
# again before it runs) and skip autograd bookkeeping
model.eval()
with torch.no_grad():
    pred = model(
        torch.as_tensor(im2, dtype=torch.float32).unsqueeze(0).to(device),
        torch.as_tensor(t, dtype=torch.float32).unsqueeze(0).to(device),
    ).squeeze()
pred = pred.cpu().numpy()

# Displayed in order: target, model input, prediction
show_grid(np.array([im1, im2, pred]))