In [None]:
# Code adapted from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/ac5dcd03a40a08a8af7e1a67ade37f28cf88db43/ML/Pytorch/GANs/2.%20DCGAN/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfms
from torch.utils.data import DataLoader

import os, math
import sys
import shutil
import random
import numpy as np
import skfmm

import GAN as GAN
from GAN import Generator
from GAN import Critic

import wandb

import data_manager as dm

## Initialize Weights & Biases (W&B) Experiment Tracking

#### Configure the run

In [None]:
# Toggle Weights & Biases logging, checkpoint saving and image uploads.
RECORD_METRICS = True

# Reproducibility: seed Python, NumPy and PyTorch (torch.manual_seed covers all
# devices). Set SEED = None for an unseeded run. cuDNN kernels may still be
# non-deterministic.
SEED = 0
if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

# Inputs
DATASET = 'random_40_density'
SUBSET = 'train'
BATCH_SIZE = 10


# Structure
# NOTE(review): the "all ..." strings are handed verbatim to GAN.Critic /
# GAN.Generator (and logged to W&B) — confirm how GAN.py interprets them.
# Layer counts are derived from the feature lists (len - 1, which matches the
# values 6 and 16) so the two settings cannot drift apart.
FEATURES_CRIT = [3,64,128,256,512,512,1]
NUM_LAYERS_CRIT = len(FEATURES_CRIT) - 1
KERNEL_CRIT = ["all 4"]
STRIDE_CRIT = [2,2,2,2,1,1]
PAD_CRIT = ["all 'same' (as seen in Keras)"]
assert len(STRIDE_CRIT) == NUM_LAYERS_CRIT, "need exactly one critic stride per layer"

FEATURES_GEN = [2,64,128,256,512,512,512,512,512,512,512,512,512,256,128,64,1]
NUM_LAYERS_GEN = len(FEATURES_GEN) - 1
KERNEL_GEN = ["all 4"]
STRIDE_GEN = ["all 2"]
PAD_GEN = ["all 'same' (as seen in Keras)"]


# Hyperparameters
LR_CRIT = 2e-4
LR_GEN = 2e-4
CRIT_ITERATIONS = 5   # critic updates per generator update
LAMBDA = 10           # gradient-penalty coefficient


# Internal Data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): MAP_SHAPE is only referenced by the commented-out NOISE_SHAPE.
MAP_SHAPE = (64,64)
# NOISE_SHAPE = (BATCH_SIZE, 1, MAP_SHAPE[0], MAP_SHAPE[1])

NUM_EPOCHS = 20
START_EPOCH = 0  # offset added to the displayed/logged epoch counter

#### Initialize the W&B run

In [None]:
# W&B run group name, passed to wandb.init.
GROUP = ''

# Run configuration: recorded with the W&B run (and handed to dm.save_gan in
# the "save" cell below).
CONFIG = {
    'dataset': DATASET,
    'subset': SUBSET,

    # critic architecture
    'layers_crit': NUM_LAYERS_CRIT,
    'kernels_crit': KERNEL_CRIT,
    'stride_crit': STRIDE_CRIT,
    'padding_crit': PAD_CRIT,
    'features_crit': FEATURES_CRIT,

    # generator architecture
    'layers_gen': NUM_LAYERS_GEN,
    'kernels_gen': KERNEL_GEN,
    'stride_gen': STRIDE_GEN,
    'padding_gen': PAD_GEN,
    'features_gen': FEATURES_GEN,

    # optimisation
    'batch_size': BATCH_SIZE,
    'learning_rate_crit': LR_CRIT,
    'learning_rate_gen': LR_GEN,
    'crit_iterations': CRIT_ITERATIONS,
    'gp_coefficient': LAMBDA,
}

# `run` is only created when metrics are recorded; every later use of it is
# guarded by RECORD_METRICS.
if RECORD_METRICS:
    run = wandb.init(
        project='wgan-gp',
        entity='aicv-lab',
        group=GROUP,
        config=CONFIG,
    )

## Save the GAN's Structure and Hyperparameters

In [None]:
# Save the GAN's definitions and hyperparams (CONFIG) via data_manager, keyed
# by the W&B run name (presumably used as the save location — confirm in
# data_manager.save_gan). Guarded because `run` only exists when
# RECORD_METRICS is set.
if RECORD_METRICS:
    dm.save_gan(run.name, CONFIG)

## Define Essential Functions

In [None]:
def initialize_weights(model):
    """Re-initialise conv / transposed-conv / batch-norm weights in place.

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` found in
    ``model.modules()`` has its ``weight`` refilled from a zero-centred normal
    distribution with standard deviation 0.02 (the DCGAN paper's scheme).
    Biases and all other layer types are left untouched.

    Args:
        model: Any ``nn.Module``; it is modified in place.
    """
    init_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    # Visit matching layers in model.modules() order so the RNG stream is
    # consumed in a deterministic order.
    targets = (layer for layer in model.modules() if isinstance(layer, init_types))
    for layer in targets:
        nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)

In [None]:
def gradient_penalty(coeff, critic, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty.

    Points x_hat are sampled uniformly on the straight line between each
    paired real and fake sample. The critic is penalised when the L2 norm of
    its gradient at x_hat deviates from 1.

    Args:
        coeff: Penalty weight (lambda).
        critic: Callable mapping a batch shaped like ``real`` to critic scores.
        real: Batch of real samples; dimension 0 is the batch dimension.
        fake: Batch of generated samples, same shape as ``real``. It may be
            attached to the generator's graph or detached.
        device: Kept for backward compatibility. The interpolation weights are
            always created on ``real.device`` so they match the inputs.

    Returns:
        Scalar tensor ``coeff * mean((||d critic(x_hat) / d x_hat||_2 - 1)^2)``.
        It is built with ``create_graph=True`` so it can be back-propagated
        into the critic's parameters.
    """
    batch_size = real.shape[0]

    # One interpolation weight per sample, eps ~ U[0, 1), broadcast over every
    # non-batch dimension. (torch.randn would draw from N(0, 1) and place x_hat
    # outside the real/fake segment.)
    eps_shape = (batch_size,) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device)
    x_hat = eps * real + (1 - eps) * fake
    if not x_hat.requires_grad:
        # Happens when neither input carries a graph (e.g. fake.detach()).
        # autograd.grad needs something to differentiate with respect to.
        x_hat.requires_grad_(True)

    critic_output = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=critic_output,
        inputs=x_hat,
        grad_outputs=torch.ones_like(critic_output),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample L2 norm over all non-batch dimensions.
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return coeff * torch.mean((grad_norm - 1) ** 2)

In [None]:
# Map-style dataset: torch.utils.data.Dataset subclasses must provide
# __init__, __len__ and __getitem__.
class PathsDataset(torch.utils.data.Dataset):
    """Dataset over every path of one split, as loaded by ``dm.load_input``.

    All items are loaded once in ``__init__``. Each item is moved to
    ``device`` only when it is indexed.

    NOTE(review): items are presumably tensors (``.to`` is called on them).
    Moving them to a CUDA device inside ``__getitem__`` only works with
    ``DataLoader(num_workers=0)`` — revisit before adding worker processes.
    """

    def __init__(self, dataset, subset, device='cpu'):
        """
        Args:
            dataset: Dataset name forwarded to ``dm.load_input``.
            subset: Split name (e.g. ``'train'``) forwarded to ``dm.load_input``.
            device: Device each item is moved to when fetched.
        """
        self.device = device
        # Load every path in the requested dataset/subset up front.
        self.paths = dm.load_input(dataset, subset)

    def __len__(self):
        """Number of loaded paths."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return path ``idx`` moved to ``self.device``."""
        return self.paths[idx].to(self.device)

## Initialize Model & Data

In [None]:
# The SUBSET split of DATASET, served in shuffled batches. drop_last=True
# discards a trailing partial batch, so every batch holds exactly BATCH_SIZE items.
train_dataset = PathsDataset(DATASET, SUBSET, device=device)
dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

In [None]:
curr_epoch = START_EPOCH

# Build both networks, then refill their Conv/ConvTranspose/BatchNorm weights
# with initialize_weights.
gen = Generator(FEATURES_GEN, KERNEL_GEN, STRIDE_GEN, PAD_GEN, device=device)
critic = Critic(FEATURES_CRIT, KERNEL_CRIT, STRIDE_CRIT, PAD_CRIT, device=device)
initialize_weights(gen)
initialize_weights(critic)

# Adam with beta1 = 0, beta2 = 0.9 (the settings used in the WGAN-GP paper).
# initialize_weights refills the parameter tensors in place, so building the
# optimisers afterwards registers exactly the same parameters.
ADAM_BETAS = (0.0, 0.9)
opt_gen = optim.Adam(gen.parameters(), lr=LR_GEN, betas=ADAM_BETAS)
opt_critic = optim.Adam(critic.parameters(), lr=LR_CRIT, betas=ADAM_BETAS)

In [None]:
# fixed_noise = torch.rand(NOISE_SHAPE, device=device)  # Uncomment when giving gen input with random noise

# Put both networks in training mode (affects layers such as dropout and batch
# norm). Module.train() returns the module itself, so the trailing ';' keeps
# Jupyter from printing the critic's entire repr as the cell output.
gen.train()
critic.train();

## Train the Model

In [None]:
for epoch in range(NUM_EPOCHS):
    curr_epoch += 1
    for batch_idx, real in enumerate(dataloader):
        # Upsample (F.interpolate's default nearest-neighbour mode);
        # pix2pix requires 256x256 inputs.
        real = F.interpolate(real, size=(256,256))

        # Generator conditioning = every channel of `real` except channel 0.
        # NOTE(review): presumably channel 0 is the target path map — confirm.
        # fixed_input = torch.concat((fixed_noise, real[:,1:,:,:]), axis=1)  # Uncomment when giving gen input with random noise
        fixed_input = real[:,1:,:,:]  # Uncomment when giving gen input without random noise

        ### Train critic: minimise E[critic(fake)] - E[critic(real)] + GP
        for _ in range(CRIT_ITERATIONS):
            # noise = torch.randn_like(real[:,-1:,:,:], device=device).abs()  # Uncomment when giving gen input with random noise
            # noise = torch.concat((noise, real[:,1:,:,:]), axis=1)  # Uncomment when giving gen input with random noise
            noise = real[:,1:,:,:]  # Uncomment when giving gen input without random noise

            fake = gen(noise)
            # Re-attach the conditioning channels so fake samples have the
            # same channel layout as `real`.
            fake = torch.concat((fake, real[:,1:,:,:]), axis=1)

            critic_real = critic(real)
            critic_fake = critic(fake)
            gp = gradient_penalty(LAMBDA, critic, real, fake, device=device)
            loss_critic = torch.mean(critic_fake) - torch.mean(critic_real) + gp

            critic.zero_grad()
            # retain_graph: the generator step below back-propagates through
            # the graph that produced the last `fake`.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        ### Train generator: min -E[critic(gen_fake)]
        output = critic(fake)
        loss_gen = -torch.mean(output)
        # `fake` is not detached above, so loss_critic.backward() also left
        # gradients on the generator; clear them before the generator step.
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Report, checkpoint and log sample images every 100 batches.
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{curr_epoch}/{START_EPOCH + NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} " +
                  f"Loss D: {loss_critic:.4f}, Lambda GP: {gp:.4f}, loss G: {loss_gen:.4f}"
            )

            if RECORD_METRICS:
                dm.save_checkpoint(run.name, run.step, gen, critic)

                # Visualise at most 8 samples ([:8] is a no-op for smaller
                # batches). These tensors are only logged, so skip building an
                # autograd graph.
                with torch.no_grad():
                    outputs = gen(fixed_input[:8])
                    # outputs = torch.concat((outputs, fixed_input[:8,1:,:,:]), axis=1)  # Uncomment when giving gen input with random noise
                    outputs = torch.concat((outputs, fixed_input[:8]), axis=1)  # Uncomment when giving gen input without random noise
                inputs = real[:8]

                wandb.log({
                    'epoch': curr_epoch,
                    'generator loss': loss_gen.item(),
                    'critic loss': loss_critic.item(),
                    'gradient penalty': gp.item(),
                    'fake': wandb.Image(outputs),
                    'real': wandb.Image(inputs)
                })

In [None]:
# Close the W&B run opened by wandb.init (only opened when RECORD_METRICS is
# set); wandb.finish marks the run finished and completes the data upload.
if RECORD_METRICS:
    wandb.finish()