In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
from torch import nn
from tqdm import tqdm

from matplotlib import pyplot as plt
import numpy as np

from PIL import Image

import math
import sys
import os

# Mount Google Drive if executed on Google Colab.
# Only the import is guarded: a missing `google.colab` package means the
# notebook runs locally, whereas a failure while mounting the drive is a real
# error that must surface instead of silently falling back to local paths.
try:
    from google.colab import drive
except ImportError:
    print("Not running on Google Colab")
    onColab = False
    ROOT_PATH = './'
else:
    drive.mount('/content/gdrive/')
    # Make the project modules stored on the drive importable
    sys.path.append('/content/gdrive/MyDrive/GenAI')

    ROOT_PATH = '/content/gdrive/MyDrive/GenAI/'
    onColab = True

from images import show_grid
from model import Model, PositionalEmbedding

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("Device:", device)

In [None]:
# Number of noising steps of the diffusion process
NB_STEPS = 1000 - 1
LEARNING_RATE = 1e-4
BATCH_SIZE = 128

DATASET = 'hands32x32' # 'cifar10', 'hands32x32', 'hands64x64', 'hands128x128'

# Number of training epochs for the selected dataset.
# NOTE(review): there is no entry for 'hands64x64' / 'hands128x128', so
# selecting one of them raises a KeyError here; add their epoch counts first.
NB_EPOCHS = {
    'cifar10': 1000,
    'hands32x32': 50,
}[DATASET]

# A checkpoint and sample images are written every SAVE_EVERY epochs
SAVE_EVERY = 10
EXPERIMENT_NAME = 'hands32x32'

# Keep CPU-only runs short. `device` is a torch.device object, so its `type`
# attribute is compared: the object itself does not compare equal to 'cpu'.
if device.type == 'cpu':
    NB_EPOCHS = 3

In [None]:
# Create the output folder (and any missing parent). os.makedirs is portable
# and, unlike a shell `mkdir -p`, is not affected by spaces or shell
# metacharacters in the path.
OUTPUT_PATH = ROOT_PATH + f'output/{EXPERIMENT_NAME}/'
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [None]:
# Load the training images as an array of shape (N, height, width, 3)
if DATASET == 'cifar10':
    # Download the dataset
    cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

    print("Classes:", *cifar10.classes)

    # Extract a category of images
    automobile = cifar10.classes.index('automobile')
    real_images = cifar10.data[ [i for i, t in enumerate(cifar10.targets) if t == automobile] ]

elif DATASET in ['hands32x32', 'hands64x64', 'hands128x128']:
    s = int(DATASET.split('x')[1])

    # Every path is resolved from ROOT_PATH so that the cache that is looked up,
    # the image folder and the cache that is written all refer to the same place
    cache_file = ROOT_PATH + f'data/hands/{s}x{s}.npy'

    if os.path.isfile(cache_file):
        real_images = np.load(cache_file)
    else:
        folder = ROOT_PATH + f'data/hands/{s}x{s}/'
        # Sorted so that the image order does not depend on the file system
        files = sorted(os.listdir(folder))

        real_images = np.empty((len(files), s, s, 3), dtype=np.uint8)

        for i, f in enumerate(files):
            # The context manager closes the image file once it has been copied
            with Image.open(os.path.join(folder, f)) as img:
                real_images[i] = np.array(img)

        # Cache the decoded images for the next run
        np.save(cache_file, real_images)

else:
    # Fail here with a clear message rather than with a NameError further down
    raise ValueError(f'Error: dataset does not exist: {DATASET!r}')

# Rescale the pixel values from [0, 255] to [-1, 1]
real_images = real_images / (255 / 2) - 1

In [None]:
# Allow testing without GPU: outside Colab only the first two images are kept
if not onColab:
    real_images = real_images[:2]

# Data augmentation: every image and its left-right mirror are added in all
# four 90-degree rotations, i.e. 8 variants per source image
rotated = []
for img in real_images:
    for base in (img, np.fliplr(img)):
        for quarter_turns in range(4):
            rotated.append(np.rot90(base, quarter_turns, (0, 1)))

# Stack the variants into a single float32 array
real_images = np.array(rotated, dtype=np.float32)

# Swap axes 1 and 3, which moves the channel axis right after the batch axis
real_images = np.swapaxes(real_images, 1, 3)

show_grid(real_images[:30])

In [None]:
def get_beta(step):
    """Return the noise variance (beta) applied at diffusion step `step`.

    The schedule is linear in `step`: 0.0001 at step 0, growing by
    0.02 / NB_STEPS at every step.
    """
    progress = step / NB_STEPS
    return 0.0001 + progress * 0.02

# Applies one or several consecutive noising steps to an image
def add_noise(img, first_step, last_step = -1):
    """Return `img` after the noising steps first_step .. last_step - 1.

    When `last_step` is left at -1, only the single step `first_step` is
    applied. All the steps are applied at once: the image is scaled by
    sqrt(alpha) and Gaussian noise of standard deviation sqrt(1 - alpha) is
    added, where alpha is the product of (1 - beta) over the steps.
    """
    stop = first_step + 1 if last_step == -1 else last_step

    # Share of the original signal that survives the requested steps
    alpha = math.prod(1 - get_beta(k) for k in range(first_step, stop))

    signal = math.sqrt(alpha) * img
    noise = np.random.normal(scale=math.sqrt(1 - alpha), size=img.shape)
    return signal + noise

In [None]:
# Noise an image one step at a time: each entry is obtained by applying a
# single noising step to the previous entry
noisy = [real_images[0]]

for step in range(NB_STEPS):
    # noisy[step] is the most recent image at this point
    noisy.append(add_noise(noisy[step], step))

# Display one image out of 20
show_grid(np.array(noisy[0::20]))

del noisy

In [None]:
# Noise the clean image directly to every noise level: entry n is the clean
# image after noising steps 0 .. n - 1 applied in a single call
noisy = [real_images[0]]
noisy.extend(add_noise(noisy[0], 0, nb_applied) for nb_applied in range(1, NB_STEPS + 1))

# Display one image out of 20
show_grid(np.array(noisy[0::20]))

del noisy

## Time encoding

In [None]:
# Time embedding table: one row per diffusion step, 64 values per row
pos_emb = PositionalEmbedding(NB_STEPS, 64)()

# Heat map of the table (embedding dimension on x, time step on y)
fig, ax = plt.subplots(figsize=(8, 2.5))
mesh = ax.pcolormesh(pos_emb, cmap='viridis')
ax.set_xlabel("Embedding dimension")
ax.set_ylabel("Time")
ax.set_title("Positional Encoding")
fig.colorbar(mesh, ax=ax, label='Embedding value')
plt.show()

# Keep the table on the same device as the model
pos_emb = pos_emb.to(device)

## Training

In [None]:
# Enable cuDNN benchmark mode (cuDNN auto-tunes its convolution algorithms,
# which pays off when the input shapes stay the same between iterations)
torch.backends.cudnn.benchmark = True

In [None]:
# Network, moved to the selected device
model = Model(32).to(device)

# Adam optimiser; the scheduler halves the learning rate (never below 1e-6)
# when the value passed to scheduler.step() stops decreasing (patience of 4)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=4,
    min_lr=1e-6,
)
loss_fn = nn.MSELoss()

# pin_memory improves performance on GPU
data_loader = torch.utils.data.DataLoader(
    real_images,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
)

In [None]:
def generate_image():
    """Sample one image by running the reverse diffusion process with `model`.

    Starts from standard Gaussian noise of shape (1, 3, 32, 32) and walks the
    steps from NB_STEPS - 1 down to 0. At each step the noise predicted by the
    model is removed and a small amount of fresh noise is injected.

    The model is switched to eval mode while sampling and put back in the mode
    it was in before the call.

    Returns:
        A list of CPU tensors of shape (1, 3, 32, 32): snapshots of the image
        taken every (NB_STEPS + 1) // 40 steps. The last entry is the image
        obtained after step 0.
    """
    shape = (1, 3, 32, 32)

    # Cumulative products of (1 - beta), built once instead of being recomputed
    # from scratch at every step: alpha_bars[k] covers the steps 0 .. k
    alpha_bars = []
    running = 1
    for i in range(NB_STEPS):
        running *= (1 - get_beta(i))
        alpha_bars.append(running)

    # Remember the current mode so that sampling in the middle of training
    # does not leave the model in eval mode afterwards
    was_training = model.training
    model.eval()

    hist = [] # The tensors are on the CPU

    try:
        img = torch.randn(shape).to(device)

        for k in tqdm(range(NB_STEPS-1, -1, -1)):
            beta = get_beta(k)

            with torch.no_grad():
                # Noise predicted by the network for the current image and step
                pred = model(img, pos_emb[k:k+1, :])

            noise = torch.randn(shape).to(device)

            # Remove the predicted noise, rescale, then add fresh noise
            img = (img - beta / math.sqrt(1 - alpha_bars[k]) * pred) / math.sqrt(1 - beta) + \
                math.sqrt(beta) * noise

            if k % ((NB_STEPS+1) // 40) == 0:
                hist.append(img.detach().cpu())
    finally:
        model.train(was_training)

    return hist

In [None]:
# Mixed precision (autocast + gradient scaling) is only used on a CUDA device
use_amp = device.type == 'cuda'

loss_hist, lr_hist = [], []
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# alphas_cumprod[s] is the product of (1 - beta_k) for k = 0 .. s, i.e. the
# share of the signal left after the noising steps 0 .. s. It is computed once
# (in float64 for accuracy) instead of being rebuilt with a NB_STEPS-long
# loop for every batch.
alphas_cumprod = torch.cumprod(
    torch.tensor([1 - get_beta(k) for k in range(NB_STEPS)], dtype=torch.float64), dim=0
).to(torch.float32).to(device)

for epoch in (pbar := tqdm(range(NB_EPOCHS))):
    epoch_str = str(epoch).zfill(5)

    # generate_image() puts the model in eval mode; make sure every epoch
    # trains in train mode
    model.train()

    lr = optimizer.param_groups[0]['lr']
    lr_hist.append(lr)
    pbar.set_description(f"lr = {lr}")

    sum_loss = 0

    for images in data_loader:
        images = images.to(device)

        # Generate noisy images
        err = torch.randn_like(images)

        # One random step per image. Step s applies the noising steps 0 .. s,
        # so every training sample contains some noise and the noise level
        # paired with pos_emb[s] is the cumulative level generate_image() uses
        # for step s. The indices are kept untouched so that pos_emb is
        # indexed with the very steps that were drawn.
        steps = torch.randint(0, NB_STEPS, size=(len(images),), device=device)

        # Shape (batch, 1, 1, 1): broadcasts over channels and pixels
        alphas = alphas_cumprod[steps].view(-1, 1, 1, 1)

        noisy_images = torch.sqrt(alphas) * images + torch.sqrt(1 - alphas) * err

        # Train the model on them
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            pred_err = model(noisy_images, pos_emb[steps])
            loss = loss_fn(pred_err, err)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        sum_loss += loss.item()

        del images, err, steps, alphas, noisy_images

    # Reduce the learning rate if needed
    scheduler.step(sum_loss)
    loss_hist.append(sum_loss)

    if epoch % SAVE_EVERY == SAVE_EVERY - 1:
        # Create an output folder
        os.makedirs(f'{OUTPUT_PATH}{epoch_str}', exist_ok=True)

        # Save the model
        torch.save(model.state_dict(), OUTPUT_PATH + epoch_str + '/model')

        # Generate a few images
        generated_images = []

        for i in range(12):
            hist = generate_image()
            show_grid(np.array([img.numpy().squeeze(0) for img in hist]), 10, f'{OUTPUT_PATH}{epoch_str}/hist-' + str(i).zfill(2))
            np.save(f'{OUTPUT_PATH}{epoch_str}/{str(i).zfill(2)}.npy', hist[-1])
            generated_images.append(hist[-1])

        show_grid(np.array([img.numpy().squeeze(0) for img in generated_images]), 4, f'{OUTPUT_PATH}{epoch_str}/all')

        # The metrics file is rewritten with the full history at every save
        with open(OUTPUT_PATH + 'metrics', 'w') as f:
            f.write(str(loss_hist) + '\n')
            f.write(str(lr_hist) + '\n')

In [None]:
# Summed training loss of each epoch
fig, ax = plt.subplots()
ax.plot(loss_hist, label="Train loss")
ax.legend()
plt.show()

## Evaluation

In [None]:
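# Optional: reload a saved checkpoint before evaluating. The training loop
# writes checkpoints to f'{OUTPUT_PATH}{epoch_str}/model'; adjust the path
# below accordingly before uncommenting.
# model.load_state_dict(torch.load("00079", map_location=torch.device('cpu')))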