<a href="https://colab.research.google.com/github/MatthewYancey/GANime/blob/master/src/model_GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GANime GANs Model
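This notebook trains and validates the generator network on its own. It uses an MSE loss on the composited output, where only the generator's edges are kept, and no discriminator.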

## Imports and Parameters

In [None]:
import os
import sys
import shutil
import glob
import random
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

# Mount Google Drive (Colab only): the project source and the data/log paths
# used below all live under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

# Make the project's own modules importable; must run after the mount and
# before the project imports below.
# NOTE(review): this path uses 'MyDrive' while the data/log paths in the
# parameter cell use 'My Drive' — confirm both resolve on the Colab runtime.
sys.path.append('/content/gdrive/MyDrive/GANime/src')
from helper_functions import apply_mask, apply_padding, apply_comp, apply_scale, load_checkpoint, checkpoint
from data_loaders import create_dataloaders
from networks import Generator, weights_init

In [None]:
# ---- training / network parameters ----
BATCH_SIZE, DATASET_SIZE = 30, 100_000
# number of whole batches contained in DATASET_SIZE images
N_BATCHES = DATASET_SIZE // BATCH_SIZE
N_GPU, N_WORKERS = 1, 1
N_EPOCHS = 100
LEARNING_RATE = 2e-4

# ---- image geometry (pixels) ----
IMG_HEIGHT, IMG_WIDTH, SINGLE_SIDE = 256, 455, 57

# ---- fixed sample indices handed to checkpoint() for reference images ----
TRAIN_REFERENCE_INDEX, VAL_REFERENCE_INDEX, TEST_REFERENCE_INDEX = 200, 100, 20

# ---- directories: zipped frames on Drive, unzipped into local /content ----
_DRIVE_OUT = '/content/gdrive/My Drive/GANime/data_out/'
_FRAMES = '/content/frames/'
ZIP_PATH_TRAIN, IMG_DIR_TRAIN = f'{_DRIVE_OUT}train.zip', f'{_FRAMES}train/'
ZIP_PATH_VAL, IMG_DIR_VAL = f'{_DRIVE_OUT}validate.zip', f'{_FRAMES}validate/'
ZIP_PATH_TEST, IMG_DIR_TEST = f'{_DRIVE_OUT}test.zip', f'{_FRAMES}test/'
LOG_DIR = f'{_DRIVE_OUT}logs/temp/'

# ---- checkpoint type: 'none' or 'prev_checkpoint' ----
CHECKPOINT_TYPE = 'prev_checkpoint'

In [None]:
# unzips images
# Unzip each frame archive into its working directory. Every split is checked
# on its own, so a missing validation/test directory is still extracted even
# when the train directory already exists (e.g. when re-running the cell).
# Only existence is checked: a partially-extracted directory is not detected.
for _zip_path, _img_dir in ((ZIP_PATH_TRAIN, IMG_DIR_TRAIN),
                            (ZIP_PATH_VAL, IMG_DIR_VAL),
                            (ZIP_PATH_TEST, IMG_DIR_TEST)):
    if not os.path.exists(_img_dir):
        shutil.unpack_archive(_zip_path, _img_dir, 'zip')

In [None]:
# sets what device to run on
# Use the first CUDA device when one is available and GPUs were requested;
# otherwise fall back to the CPU.
_use_cuda = N_GPU > 0 and torch.cuda.is_available()
device = torch.device("cuda:0") if _use_cuda else torch.device("cpu")
print(f'Device: {device}')

## Data Loaders

In [None]:
# Build the train / validation / test loaders from the unzipped frame folders.
# NOTE(review): DATASET_SIZE presumably limits how many images are used —
# confirm in data_loaders.create_dataloaders.
dataloader_train, dataloader_val, dataloader_test = create_dataloaders(
    BATCH_SIZE,
    N_WORKERS,
    IMG_DIR_TRAIN,
    IMG_DIR_VAL,
    IMG_DIR_TEST,
    DATASET_SIZE,
)

## Networks, Loss Functions, and Optimizers

In [None]:
# Create the generator, move it to the selected device, then apply the
# project's custom weight initialisation to every submodule.
gen = Generator(N_GPU, IMG_WIDTH, SINGLE_SIDE)
gen = gen.to(device)  # Module.to returns the module itself
gen.apply(weights_init)

In [None]:
# Loss functions. Only mse_loss is used by the training loop in this notebook;
# the BCE loss (`loss`) is not referenced in the visible cells.
mse_loss = nn.MSELoss()
loss = nn.BCELoss()

# Adam optimizer over all generator parameters.
optimizer_gen = optim.Adam(
    gen.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.9),
)

In [None]:
# loads the checkpoint
# load_checkpoint hands back the (possibly restored) generator, its optimizer
# and the running batch counter. NOTE(review): with CHECKPOINT_TYPE set to
# 'prev_checkpoint' it presumably restores state from LOG_DIR — confirm in
# helper_functions.
_restored = load_checkpoint(CHECKPOINT_TYPE, LOG_DIR, gen, optimizer_gen)
gen, optimizer_gen, batch_counter = _restored

### Training Loop

In [None]:
def _composite_mse(real_batch):
    """Return the MSE between the generator's composite output and ``real_batch``.

    The batch is masked with ``apply_mask`` and passed through the generator.
    ``apply_comp`` then keeps only the generator's edges on top of the real
    image. The loss is taken against the unmasked batch. Relies on the notebook
    globals ``gen``, ``device``, ``mse_loss``, ``IMG_WIDTH`` and ``SINGLE_SIDE``.
    """
    masked = apply_mask(real_batch.clone(), IMG_WIDTH, SINGLE_SIDE)
    _, gen_output_global = gen(masked.to(device))
    target = real_batch.to(device)
    # Cloning on-device replaces a second host-to-device copy of the batch and
    # keeps the loss target untouched in case apply_comp modifies its first
    # argument in place (not visible from here).
    composite = apply_comp(target.clone(), gen_output_global, IMG_WIDTH, SINGLE_SIDE)
    return mse_loss(composite, target)


for epoch in range(N_EPOCHS):
    for i, batch in enumerate(dataloader_train):
        # Generator update. Re-enter train mode in case the validation pass or
        # checkpoint() left the network in eval mode.
        gen.train()
        gen.zero_grad()
        train_loss = _composite_mse(batch)
        train_loss.backward()
        optimizer_gen.step()

        # prints the status and checkpoints every so often
        if i % 10 == 0:
            # Validation loss on one batch from a fresh iterator over the
            # validation loader. Eval mode stops layers such as BatchNorm or
            # Dropout (if Generator has any) from behaving as in training or
            # updating their statistics from validation data. A separate name
            # avoids clobbering the training batch.
            gen.eval()
            with torch.no_grad():
                val_loss = _composite_mse(next(iter(dataloader_val)))
            gen.train()

            print(f'Epoch: {epoch}/{N_EPOCHS}, Batch in Epoch: {i}/{N_BATCHES}, '
                  f'Total Images {batch_counter * BATCH_SIZE}, '
                  f'Train Loss: {train_loss.item():.4f}, '
                  f'Validation Loss: {val_loss.item():.4f}')

            if i % 100 == 0:
                # NOTE(review): the meaning of the literal 0 (second argument)
                # is defined by helper_functions.checkpoint — confirm there.
                checkpoint(batch_counter,
                           0,
                           train_loss.item(),
                           val_loss.item(),
                           LOG_DIR,
                           gen,
                           optimizer_gen,
                           dataloader_train,
                           TRAIN_REFERENCE_INDEX,
                           dataloader_val,
                           VAL_REFERENCE_INDEX,
                           dataloader_test,
                           TEST_REFERENCE_INDEX,
                           IMG_HEIGHT,
                           IMG_WIDTH,
                           SINGLE_SIDE)

        batch_counter += 1