<a href="https://colab.research.google.com/github/MatthewYancey/GANime/blob/master/src/model_GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GANime GANs Model
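This notebook trains the generator network adversarially against a global discriminator. It starts from the MSE-trained checkpoint. The generator minimizes the MSE reconstruction loss plus a small weighted (`GAN_WEIGHT`) adversarial BCE loss.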

## Imports and Parameters

In [1]:
import os
import sys
import shutil
import glob
import random
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from google.colab import drive
drive.mount('/content/gdrive')

sys.path.append('/content/gdrive/MyDrive/GANime/src')
from helper_functions import apply_mask, apply_padding, apply_comp, apply_scale, load_checkpoint, checkpoint
from data_loaders import create_dataloaders
from networks import Generator, GlobalDiscriminator, weights_init

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
# ---- training / hardware -------------------------------------------------
BATCH_SIZE = 15
DATASET_SIZE = 100000                   # images drawn from the training split
N_BATCHES = DATASET_SIZE // BATCH_SIZE  # used only for progress printing
N_EPOCHS = 100
LEARNING_RATE = 2e-4
N_GPU = 1                               # 0 forces CPU
N_WORKERS = 1

# ---- image geometry (pixels) ----------------------------------------------
IMG_HEIGHT = 256
IMG_WIDTH = 455
SINGLE_SIDE = 57                        # width of one masked side strip

# ---- reference image indices passed to checkpoint() ------------------------
TRAIN_REFERENCE_INDEX = 200
VAL_REFERENCE_INDEX = 100
TEST_REFERENCE_INDEX = 20

# ---- loss weighting --------------------------------------------------------
GAN_WEIGHT = 4e-4                       # scale of the adversarial BCE term

# ---- data: (zip archive on Drive, local extraction directory) pairs -------
ZIP_PATH_TRAIN = '/content/gdrive/My Drive/GANime/data_out/train.zip'
IMG_DIR_TRAIN = '/content/frames/train/'

ZIP_PATH_VAL = '/content/gdrive/My Drive/GANime/data_out/validate.zip'
IMG_DIR_VAL = '/content/frames/validate/'

ZIP_PATH_TEST = '/content/gdrive/My Drive/GANime/data_out/test.zip'
IMG_DIR_TEST = '/content/frames/test/'

# ---- logging / checkpoints -------------------------------------------------
LOG_DIR = '/content/gdrive/My Drive/GANime/data_out/logs/model_gans/'
# Checkpoint to warm-start from (here: the MSE-only model). Set this to None
# to skip loading and start a new log folder instead.
PREV_CHECKPOINT = '/content/gdrive/My Drive/GANime/data_out/saved_models/model_mse/checkpoint.pt'

In [3]:
# Unzip each image split into its local frames directory.
# Each split is checked on its own. Previously only the train directory was
# checked, so if an earlier session was interrupted after train finished,
# validate/test were never extracted.
# NOTE: a directory that exists but was only partially extracted is still
# skipped; delete it to force re-extraction.
for zip_path, img_dir in ((ZIP_PATH_TRAIN, IMG_DIR_TRAIN),
                          (ZIP_PATH_VAL, IMG_DIR_VAL),
                          (ZIP_PATH_TEST, IMG_DIR_TEST)):
    if not os.path.exists(img_dir):
        shutil.unpack_archive(zip_path, img_dir, 'zip')

In [4]:
# Use the first CUDA device when CUDA is available and GPUs are enabled
# (N_GPU > 0); otherwise fall back to the CPU.
if torch.cuda.is_available() and N_GPU > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Device: {device}')

Device: cuda:0


## Data Loaders

In [5]:
# Build the train / validation / test loaders from the extracted frame
# directories. The printed summary shows the training split limited to
# DATASET_SIZE images; see create_dataloaders for the exact sampling.
dataloader_train, dataloader_val, dataloader_test = create_dataloaders(
    BATCH_SIZE,
    N_WORKERS,
    IMG_DIR_TRAIN,
    IMG_DIR_VAL,
    IMG_DIR_TEST,
    DATASET_SIZE,
)

Training Dataset
Number of images: 114808
Size of dataset: 100000
Validation Dataset
Number of images: 36734
Size of dataset: 36734
Testing Dataset
Number of images: 2210
Size of dataset: 2210


## Networks, Loss Functions, and Optimizers

In [6]:
# Generator: fills in the masked side strips of each frame.
gen = Generator(N_GPU, IMG_WIDTH, SINGLE_SIDE)
gen = gen.to(device)
gen.apply(weights_init)

# Global discriminator: judges whole (composited) frames as real or fake.
# Its re-initialisation is the cell's last expression, so the module repr
# is displayed below.
global_disc = GlobalDiscriminator(N_GPU)
global_disc = global_disc.to(device)
global_disc.apply(weights_init)

GlobalDiscriminator(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv3): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv4): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv5): Conv2d(512, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv6): Conv2d(512, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv7): Conv2d(512, 1, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (batch64): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batch128): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batch256): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batch512): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
)

In [7]:
# BCE scores the discriminator's real/fake outputs (both networks);
# MSE scores the generator's reconstruction against the original frame.
loss_bce = nn.BCELoss()
loss_mse = nn.MSELoss()

# Both networks use Adam with the same learning rate and betas.
adam_settings = {'lr': LEARNING_RATE, 'betas': (0.5, 0.9)}
optimizer_gen = optim.Adam(gen.parameters(), **adam_settings)
optimizer_disc = optim.Adam(global_disc.parameters(), **adam_settings)

In [8]:
# Warm-start the generator and its optimizer from PREV_CHECKPOINT.
# The discriminator is not passed in and keeps its weights_init weights.
# The third return value seeds batch_counter for logging. The output below
# reports "No batch counter" for this checkpoint; see load_checkpoint for
# the value returned in that case.
gen, optimizer_gen, batch_counter = load_checkpoint(
    PREV_CHECKPOINT,
    LOG_DIR,
    gen,
    optimizer_gen,
)

Loaded checkpoint from /content/gdrive/My Drive/GANime/data_out/saved_models/model_mse/checkpoint.pt
No batch counter


### Training Loop

In [None]:
# Adversarial training loop. Each iteration performs:
#   1. one discriminator step (real frames -> 1, composited fakes -> 0);
#   2. one generator step on MSE + GAN_WEIGHT * adversarial BCE.
# Every 10 batches it prints a validation MSE on one validation batch;
# every 100 batches it calls checkpoint().
for epoch in range(N_EPOCHS):
    for i, batch in enumerate(dataloader_train, 0):
        batch = batch.to(device)

        #############################
        # Discriminator
        #############################
        global_disc.zero_grad()

        # Real frames should be classified as 1. ones_like keeps the target on
        # the same device/dtype as the output (works on the CPU fallback too).
        output_global_disc = global_disc(batch)
        disc_loss_real = loss_bce(output_global_disc, torch.ones_like(output_global_disc))
        disc_loss_real.backward()

        # Mask the side strips and let the generator fill them in, then
        # composite the generated sides with the original frame.
        batch_mask = apply_mask(batch.clone(), IMG_WIDTH, SINGLE_SIDE)
        _, gen_output_global = gen(batch_mask)
        gen_output_global = apply_comp(batch, gen_output_global, IMG_WIDTH, SINGLE_SIDE)

        # Composited fakes should be classified as 0. detach() keeps this
        # step from sending gradients into the generator.
        output_global_disc = global_disc(gen_output_global.detach())
        disc_loss_fake = loss_bce(output_global_disc, torch.zeros_like(output_global_disc))
        disc_loss_fake.backward()

        # Optimize the discriminator.
        disc_loss = disc_loss_real + disc_loss_fake
        optimizer_disc.step()

        #############################
        # Generator
        #############################
        gen.zero_grad()

        # Reuse the (non-detached) composited fake from the discriminator
        # step. The generator's weights have not changed since it was
        # produced, so a second generator forward pass would be redundant.
        output_global_disc = global_disc(gen_output_global)

        # MSE of the composite against the original full frame, plus the
        # adversarial term (fakes should fool the discriminator -> 1).
        gen_train_loss_mse = loss_mse(gen_output_global, batch)
        gen_train_loss_bce = loss_bce(output_global_disc, torch.ones_like(output_global_disc))
        gen_train_loss = gen_train_loss_mse + (gen_train_loss_bce * GAN_WEIGHT)

        # Backpropagate and optimize the generator.
        gen_train_loss.backward()
        optimizer_gen.step()

        # Print status every 10 batches; checkpoint every 100.
        if i % 10 == 0:
            # Validation MSE on one validation batch. Eval mode stops layers
            # such as BatchNorm/Dropout (if the Generator has any) from
            # updating state with validation data. The previous mode is
            # restored afterwards. Separate names leave the training
            # variables untouched.
            was_training = gen.training
            gen.eval()
            with torch.no_grad():
                val_batch = next(iter(dataloader_val)).to(device)
                val_mask = apply_mask(val_batch.clone(), IMG_WIDTH, SINGLE_SIDE)
                _, val_output = gen(val_mask)
                val_output = apply_comp(val_batch, val_output, IMG_WIDTH, SINGLE_SIDE)
                val_loss = loss_mse(val_output, val_batch)
            gen.train(was_training)

            print(f'Epoch: {epoch}/{N_EPOCHS}, Batch in Epoch: {i}/{N_BATCHES}, Total Images {batch_counter * BATCH_SIZE}, Gen Train Loss: {gen_train_loss.item():.4f}, Gen Val Loss: {val_loss.item():.4f}, Disc Train Loss: {disc_loss.item():.4f}')

            if i % 100 == 0:
                checkpoint(batch_counter,
                           disc_loss.item(),
                           gen_train_loss.item(),
                           val_loss.item(),
                           LOG_DIR,
                           gen,
                           optimizer_gen,
                           global_disc,
                           optimizer_disc,
                           dataloader_train,
                           TRAIN_REFERENCE_INDEX,
                           dataloader_val,
                           VAL_REFERENCE_INDEX,
                           dataloader_test,
                           TEST_REFERENCE_INDEX,
                           IMG_HEIGHT,
                           IMG_WIDTH,
                           SINGLE_SIDE)

        batch_counter += 1