<a href="https://colab.research.google.com/github/MatthewYancey/GANime/blob/master/src/train_lama.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LaMa Training Notebook

## Imports and Parameters

In [None]:
!pip install kornia

import os
import sys
import shutil
import glob

import torch
import torch.nn as nn
import torch.optim as optim

# Mount Google Drive inside the Colab VM; the project sources live on it.
from google.colab import drive
drive.mount('/content/gdrive')

# Make the project's helper modules (imported right below) importable.
# NOTE(review): this path is spelled 'MyDrive' while the data paths in the
# parameters cell use 'My Drive' - confirm both resolve on the mounted drive.
sys.path.append('/content/gdrive/MyDrive/repos/GANime/src')
from model_helper_functions import gpu_memory, load_checkpoint, apply_mask, apply_comp, checkpoint
from model_data_loaders import create_dataloaders
from model_lama import Generator, Discriminator, weights_init

In [None]:
# --- training hyperparameters ---
BATCH_SIZE = 15
# NOTE(review): the two learning rates below are not passed to the optimizers
# in the visible cells (Adadelta is built with its defaults) - confirm intent.
LEARNING_RATE_GEN = 1e-3
LEARNING_RATE_DISC = 1e-4
N_EPOCHS = 100
# multiplies the adversarial (BCE) term of the generator loss
ALPHA_WEIGHT = 4e-4

# --- hardware ---
N_GPU, N_WORKERS = 1, 1

# --- image geometry ---
IMG_HEIGHT, IMG_WIDTH = 288, 512
# handed to the mask/composite helpers and the discriminator together with
# IMG_WIDTH; presumably the pixel width of each generated side band - TODO confirm
SINGLE_SIDE = 64

# passed through to checkpoint(); presumably indices of test frames to render - TODO confirm
TEST_REFERENCES = [2800, 8000, 17850, 3000]

# --- directories ---
_DATA_OUT = '/content/gdrive/My Drive/repos/GANime/data_out'  # archives and logs on Drive
_FRAMES = '/content/frames'                                   # local unpack target

ZIP_PATH_TRAIN = f'{_DATA_OUT}/pokemon/train.zip'
IMG_DIR_TRAIN = f'{_FRAMES}/train/'
ZIP_PATH_VAL = f'{_DATA_OUT}/pokemon/validate.zip'
IMG_DIR_VAL = f'{_FRAMES}/validate/'
ZIP_PATH_TEST = f'{_DATA_OUT}/pokemon/test.zip'
IMG_DIR_TEST = f'{_FRAMES}/test/'
LOG_DIR = f'{_DATA_OUT}/logs/ffc/'

# Checkpoint to resume from. The second assignment wins, so as written no
# checkpoint is loaded; remove it to resume from LOG_DIR/checkpoint.pt.
PREV_CHECKPOINT = f'{LOG_DIR}checkpoint.pt'
PREV_CHECKPOINT = None

In [None]:
# Unpack each dataset split from Drive onto the local disk. Every split is
# checked on its own, so a run that was interrupted after the train archive
# was extracted still gets its validate/test frames, and re-running the cell
# skips whatever is already in place.
for zip_path, img_dir in (
    (ZIP_PATH_TRAIN, IMG_DIR_TRAIN),
    (ZIP_PATH_VAL, IMG_DIR_VAL),
    (ZIP_PATH_TEST, IMG_DIR_TEST),
):
    if not os.path.exists(img_dir):
        shutil.unpack_archive(zip_path, img_dir, 'zip')

In [None]:
# Count the entries in the training directory and derive the number of full
# batches that gives (floor division, so a trailing partial batch is not counted).
train_frame_paths = glob.glob(f'{IMG_DIR_TRAIN}*')
dataset_size = len(train_frame_paths)
n_batches = dataset_size // BATCH_SIZE
print(
    f'Number of images: {dataset_size}',
    f'Number of batches: {n_batches}',
    sep='\n',
)

In [None]:
# sets what device to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and N_GPU > 0) else "cpu")
print(f'Device: {device}')
!nvidia-smi -L

## Data Loaders

In [None]:
# Build the train / validation / test loaders from the unpacked frame directories.
dataloader_train, dataloader_val, dataloader_test = create_dataloaders(
    BATCH_SIZE,
    N_WORKERS,
    IMG_DIR_TRAIN,
    IMG_DIR_VAL,
    IMG_DIR_TEST,
    dataset_size,
)

## Networks, Loss Functions, and Optimizers

In [None]:
# Generator: build, move to the compute device, then apply the custom weight init.
gen = Generator()
gen.to(device)
gen.apply(weights_init)

# Discriminator: same sequence; it is constructed from the image geometry.
disc = Discriminator(IMG_WIDTH, SINGLE_SIDE)
disc.to(device)
disc.apply(weights_init)

# report memory now that both networks are on the device
gpu_memory()

In [None]:
# Loss functions: binary cross-entropy for the adversarial terms and mean
# squared error for the reconstruction term.
loss_bce, loss_mse = nn.BCELoss(), nn.MSELoss()

# One Adadelta optimizer per network, both constructed with Adadelta's default
# hyperparameters (no learning rate is passed here).
optimizer_gen = optim.Adadelta(params=gen.parameters())
optimizer_disc = optim.Adadelta(params=disc.parameters())

In [None]:
# Hand the freshly built networks and optimizers to load_checkpoint and keep
# whatever it returns, together with the running batch counter used by the
# training loop.
restored = load_checkpoint(
    PREV_CHECKPOINT,
    LOG_DIR,
    gen,
    optimizer_gen,
    disc,
    optimizer_disc,
)
gen, optimizer_gen, disc, optimizer_disc, batch_counter = restored

## Training Loop

In [None]:
# Cadence, in batches, of the discriminator update, the status print, and the
# validation pass + checkpoint.
DISC_TRAIN_EVERY = 9
PRINT_EVERY = 100
CHECKPOINT_EVERY = 1000

for epoch in range(N_EPOCHS):
    for i, batch in enumerate(dataloader_train, 0):
        batch = batch.to(device)

        # masked copy of the batch; this is what the generator sees
        batch_mask = apply_mask(batch.clone(), IMG_WIDTH, SINGLE_SIDE)

        # The discriminator is only trained every DISC_TRAIN_EVERY batches.
        # i == 0 always qualifies, so disc_loss / disc_accuracy are defined
        # before the first status print; in between updates they keep the
        # value from the most recent discriminator step.
        if i % DISC_TRAIN_EVERY == 0:
            #############################
            # Discriminator
            #############################
            disc.zero_grad()

            # real images should be classified as 1
            disc_output = disc(batch)
            real_labels = torch.ones_like(disc_output)  # same device/dtype as the prediction
            disc_loss_real = loss_bce(disc_output, real_labels)
            disc_correct = (disc_output.round() == real_labels).sum()
            disc_seen = disc_output.shape[0]
            disc_loss_real.backward()

            # Generated images should be classified as 0. The generator is run
            # without autograd here: only the discriminator is updated in this
            # step, and any generator gradients would be discarded by
            # gen.zero_grad() below, so tracking them only costs memory.
            with torch.no_grad():
                gen_output = gen(batch_mask)
            gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)
            disc_output = disc(gen_output)
            fake_labels = torch.zeros_like(disc_output)
            disc_loss_fake = loss_bce(disc_output, fake_labels)
            disc_correct += (disc_output.round() == fake_labels).sum()
            disc_seen += disc_output.shape[0]
            disc_loss_fake.backward()

            # accuracy over the samples actually scored (the last batch of an
            # epoch can be smaller than BATCH_SIZE)
            disc_accuracy = disc_correct / disc_seen

            # BCELoss clamps its log terms at -100, so each loss is at most 100;
            # dividing the sum by 200 scales the logged value into [0, 1]
            disc_loss = (disc_loss_real + disc_loss_fake) / 200
            optimizer_disc.step()

        #############################
        # Generator
        #############################
        gen.zero_grad()
        gen_output = gen(batch_mask)

        # combines the sides from the generator with the 4:3 image; the MSE
        # loss is then taken against the original full image
        gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)
        disc_output = disc(gen_output)

        # reconstruction loss plus the weighted adversarial loss (the generator
        # wants the discriminator to answer 1 for its output)
        gen_train_loss_mse = loss_mse(gen_output, batch)
        gen_train_loss_bce = loss_bce(disc_output, torch.ones_like(disc_output))
        gen_train_loss = (gen_train_loss_mse + gen_train_loss_bce*ALPHA_WEIGHT) / 2

        # error and optimize
        gen_train_loss.backward()
        optimizer_gen.step()

        # prints the status every so often
        if i % PRINT_EVERY == 0:
            print(f'Epoch: {epoch}/{N_EPOCHS}, Batch: {i}/{n_batches}, Total Images {batch_counter * BATCH_SIZE}, Gen Train Loss: {gen_train_loss:.4f}, Disc Accuracy: {disc_accuracy:.4f}, CUDA Memory: {(torch.cuda.memory_allocated() / 10**9):.4f}')

        # validation pass and checkpoint
        if i % CHECKPOINT_EVERY == 0:
            # one validation batch, kept in its own names so the training
            # batch is not overwritten
            val_batch = next(iter(dataloader_val)).to(device)
            val_mask = apply_mask(val_batch.clone(), IMG_WIDTH, SINGLE_SIDE)

            # nothing is optimized here, so no autograd graph is needed for
            # either network
            with torch.no_grad():
                val_output = gen(val_mask)
                val_output = apply_comp(val_batch, val_output, IMG_WIDTH, SINGLE_SIDE)
                val_disc_output = disc(val_output)
                val_loss_bce = loss_bce(val_disc_output, torch.ones_like(val_disc_output))
                val_loss_mse = loss_mse(val_output, val_batch)
                val_loss = (val_loss_mse + val_loss_bce*ALPHA_WEIGHT) / 2

            checkpoint(i,
                       batch_counter,
                       disc_loss.item(),
                       disc_accuracy,
                       gen_train_loss.item(),
                       gen_train_loss_mse.item(),
                       val_loss.item(),
                       val_loss_mse.item(),
                       LOG_DIR,
                       gen,
                       optimizer_gen,
                       disc,
                       optimizer_disc,
                       dataloader_test,
                       TEST_REFERENCES,
                       IMG_HEIGHT,
                       IMG_WIDTH,
                       SINGLE_SIDE)
            torch.cuda.empty_cache()

        batch_counter += 1