<a href="https://colab.research.google.com/github/MatthewYancey/GANime/blob/master/src/model_GANsl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GANime GANs Model
This notebook trains the GANime generator and discriminator networks, printing losses and saving checkpoints as training progresses.

## Imports and Parameters

In [10]:
import os
import sys
import shutil
import glob
import random
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from google.colab import drive
# Mount Google Drive at /content/gdrive; the data archives and log directory
# configured below are all read from / written to paths under this mount.
drive.mount('/content/gdrive')

# Add the project's source folder to the import path so the project-local
# imports that follow can be found (presumably helper_functions, data_loaders
# and networks live there — confirm).
# NOTE(review): this path spells the Drive folder 'MyDrive' while the data paths
# in the parameter cell use 'My Drive' — confirm both resolve to the same folder.
sys.path.append('/content/gdrive/MyDrive/GANime/src')
from helper_functions import apply_mask, apply_padding, apply_comp, apply_scale, load_checkpoint, checkpoint
from data_loaders import create_dataloaders
from networks import Generator, Discriminator, weights_init

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [11]:
# network parameters
BATCH_SIZE = 15
DATASET_SIZE = 100000
N_BATCHES = DATASET_SIZE // BATCH_SIZE  # whole batches per epoch; shown in the progress print
N_GPU = 1
N_WORKERS = 1
N_EPOCHS = 100
LEARNING_RATE = 0.0002
DROPOUT_RATE = 0.2  # passed to the Discriminator constructor

# image
IMG_HEIGHT = 256
IMG_WIDTH = 456
SINGLE_SIDE = 57

# reference image indices handed to checkpoint() for each dataset split
TRAIN_REFERENCE_INDEX = 200
VAL_REFERENCE_INDEX = 100
TEST_REFERENCE_INDEX = 20

# weight decay applied to the discriminator's optimizer
WEIGHT_DECAY = 0.05

# directories
# Every Drive path hangs off one root and every local frame folder off another,
# so relocating the project means editing a single line instead of five literals.
DATA_OUT_DIR = '/content/gdrive/My Drive/GANime/data_out'
FRAMES_DIR = '/content/frames'

ZIP_PATH_TRAIN = f'{DATA_OUT_DIR}/train.zip'
IMG_DIR_TRAIN = f'{FRAMES_DIR}/train/'
ZIP_PATH_VAL = f'{DATA_OUT_DIR}/validate.zip'
IMG_DIR_VAL = f'{FRAMES_DIR}/validate/'
ZIP_PATH_TEST = f'{DATA_OUT_DIR}/test.zip'
IMG_DIR_TEST = f'{FRAMES_DIR}/test/'
LOG_DIR = f'{DATA_OUT_DIR}/logs/model_gans/'

# Checkpoint to resume from; it lives inside LOG_DIR, so it is derived from it.
# Set to None to skip loading and create a new log folder.
PREV_CHECKPOINT = f'{LOG_DIR}checkpoint.pt'
# PREV_CHECKPOINT = None

In [12]:
# Unzips the image archives into their local frame folders.
# Each split is checked on its own: previously only the training folder was
# tested, so a session where it existed but the validation or test folder did
# not (e.g. the cell was interrupted between archives) never unpacked the rest.
for zip_path, img_dir in (
    (ZIP_PATH_TRAIN, IMG_DIR_TRAIN),
    (ZIP_PATH_VAL, IMG_DIR_VAL),
    (ZIP_PATH_TEST, IMG_DIR_TEST),
):
    if not os.path.exists(img_dir):
        shutil.unpack_archive(zip_path, img_dir, 'zip')

In [13]:
# Picks the compute device: the first CUDA GPU when one is available and
# N_GPU asks for at least one, otherwise the CPU.
if torch.cuda.is_available() and N_GPU > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Device: {device}')

Device: cuda:0


## Data Loaders

In [14]:
# Builds the train / validation / test loaders from the unpacked frame folders.
(dataloader_train,
 dataloader_val,
 dataloader_test) = create_dataloaders(
    BATCH_SIZE,
    N_WORKERS,
    IMG_DIR_TRAIN,
    IMG_DIR_VAL,
    IMG_DIR_TEST,
    DATASET_SIZE,
)

Training Dataset
Number of images: 114808
Size of dataset: 100000
Validation Dataset
Number of images: 36734
Size of dataset: 36734
Testing Dataset
Number of images: 2210
Size of dataset: 2210


## Networks, Loss Functions, and Optimizers

In [15]:
# Generator: construct, move to the selected device, then run weights_init
# over every submodule (nn.Module.to and nn.Module.apply both return the module).
gen = Generator(IMG_WIDTH, SINGLE_SIDE).to(device).apply(weights_init)

# Discriminator: same steps; the apply call stays the final expression so the
# cell still displays the discriminator's layer listing.
disc = Discriminator(IMG_WIDTH, SINGLE_SIDE, DROPOUT_RATE)
disc = disc.to(device)
disc.apply(weights_init)

Discriminator(
  (global_disc): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False)
    (1): Dropout(p=0.2, inplace=False)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False)
    (5): Dropout(p=0.2, inplace=False)
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False)
    (9): Dropout(p=0.2, inplace=False)
    (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Dropout(p=0.2, inplace=False)
    (14): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (15): ReLU()
    (16): Conv2d(51

In [16]:
# Loss criteria: binary cross-entropy for the discriminator's scores and
# mean squared error between generated and original images.
loss_bce, loss_mse = nn.BCELoss(), nn.MSELoss()

# Both networks are optimized with Adam using the same learning rate and betas;
# only the discriminator's optimizer additionally applies weight decay.
optimizer_gen = optim.Adam(
    gen.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.9),
)
optimizer_disc = optim.Adam(
    disc.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.9),
    weight_decay=WEIGHT_DECAY,
)

In [17]:
# Hands the freshly built networks and optimizers to load_checkpoint and rebinds
# them, together with the running batch counter, to whatever it returns.
(gen,
 optimizer_gen,
 disc,
 optimizer_disc,
 batch_counter) = load_checkpoint(
    PREV_CHECKPOINT,
    LOG_DIR,
    gen,
    optimizer_gen,
    disc,
    optimizer_disc,
)

Folders removed


### Training Loop

In [9]:
# Adversarial training: for every training batch the discriminator is updated on
# real vs. generated images, then the generator is updated on a weighted sum of
# reconstruction (MSE) and adversarial (BCE) loss. Progress is printed every 10
# batches and a checkpoint is written every 100.
for epoch in range(N_EPOCHS):
    for i, batch in enumerate(dataloader_train):
        batch = batch.to(device)

        #############################
        # Discriminator
        #############################
        disc.zero_grad()

        # real images: target label 1
        disc_output = disc(batch)
        # ones_like / zeros_like create the targets with the same shape, dtype and
        # device as the discriminator output, so nothing here is tied to CUDA
        # (the previous hard-coded .cuda() failed on the CPU fallback device)
        disc_loss_real = loss_bce(disc_output, torch.ones_like(disc_output))
        disc_loss_real.backward()

        # apply mask to the images (on a copy, so the original batch stays intact)
        batch_mask = batch.clone()
        batch_mask = apply_mask(batch_mask, IMG_WIDTH, SINGLE_SIDE)

        # generated images: target label 0
        gen_output = gen(batch_mask)
        gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)
        # detach() stops this backward pass at the discriminator's input. The
        # discriminator gradients are unchanged; the generator gradients it used to
        # produce were thrown away by gen.zero_grad() below, so this only saves work.
        disc_output = disc(gen_output.detach())
        disc_loss_fake = loss_bce(disc_output, torch.zeros_like(disc_output))
        disc_loss_fake.backward()

        # optimize the discriminator
        # disc_loss is for reporting only. BCELoss clamps its log terms at -100, so
        # each of the two losses is at most 100 and dividing by 200 keeps it in [0, 1].
        disc_loss = (disc_loss_real + disc_loss_fake) / 200
        optimizer_disc.step()

        #############################
        # Generator
        #############################
        gen.zero_grad()
        gen_output = gen(batch_mask)

        # combines the sides from the generator with the 4:3 image and calculates
        # the mse loss against the original full image
        gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)
        disc_output = disc(gen_output)

        # calculates the loss: the generator wants its output scored as real (1)
        gen_train_loss_mse = loss_mse(gen_output, batch)
        gen_train_loss_bce = loss_bce(disc_output, torch.ones_like(disc_output))
        gen_train_loss = (gen_train_loss_mse + gen_train_loss_bce * 0.01) / 2

        # error and optimize
        gen_train_loss.backward()
        optimizer_gen.step()

        # prints the status and checkpoints every so often
        if i % 10 == 0:
            # gets the validation MSE; dedicated names keep the training batch and
            # its tensors from being overwritten inside the loop body
            # NOTE(review): a fresh iterator is created on every report, and no
            # eval()/train() switch is made around this forward pass — confirm
            # both are intended.
            val_batch = next(iter(dataloader_val))
            val_batch = val_batch.to(device)
            val_mask = val_batch.clone()
            val_mask = apply_mask(val_mask, IMG_WIDTH, SINGLE_SIDE)
            with torch.no_grad():
                val_output = gen(val_mask)
                val_output = apply_comp(val_batch, val_output, IMG_WIDTH, SINGLE_SIDE)
                val_loss = loss_mse(val_output, val_batch)

            print(f'Epoch: {epoch}/{N_EPOCHS}, Batch in Epoch: {i}/{N_BATCHES}, Total Images {batch_counter * BATCH_SIZE}, Gen Train Loss: {gen_train_loss:.4f}, Gen Val Loss: {val_loss:.4f}, Disc Train Loss: {disc_loss:.4f}')

            if i % 100 == 0:
                checkpoint(batch_counter,
                           disc_loss.item(),
                           gen_train_loss.item(),
                           val_loss.item(),
                           LOG_DIR,
                           gen,
                           optimizer_gen,
                           disc,
                           optimizer_disc,
                           dataloader_train,
                           TRAIN_REFERENCE_INDEX,
                           dataloader_val,
                           VAL_REFERENCE_INDEX,
                           dataloader_test,
                           TEST_REFERENCE_INDEX,
                           IMG_HEIGHT,
                           IMG_WIDTH,
                           SINGLE_SIDE)

        batch_counter += 1

Epoch: 0/100, Batch in Epoch: 0/6666, Total Images 0, Gen Train Loss: 0.0691, Gen Val Loss: 0.0760, Disc Train Loss: 0.0071
Saving reference images
Saved checkpoint
Epoch: 0/100, Batch in Epoch: 10/6666, Total Images 150, Gen Train Loss: 0.0250, Gen Val Loss: 0.0727, Disc Train Loss: 0.0064
Epoch: 0/100, Batch in Epoch: 20/6666, Total Images 300, Gen Train Loss: 0.0205, Gen Val Loss: 0.0374, Disc Train Loss: 0.0065
Epoch: 0/100, Batch in Epoch: 30/6666, Total Images 450, Gen Train Loss: 0.0180, Gen Val Loss: 0.0533, Disc Train Loss: 0.0059
Epoch: 0/100, Batch in Epoch: 40/6666, Total Images 600, Gen Train Loss: 0.0229, Gen Val Loss: 0.0309, Disc Train Loss: 0.0057
Epoch: 0/100, Batch in Epoch: 50/6666, Total Images 750, Gen Train Loss: 0.0179, Gen Val Loss: 0.0317, Disc Train Loss: 0.0052
Epoch: 0/100, Batch in Epoch: 60/6666, Total Images 900, Gen Train Loss: 0.0231, Gen Val Loss: 0.0320, Disc Train Loss: 0.0048
Epoch: 0/100, Batch in Epoch: 70/6666, Total Images 1050, Gen Train Loss:

ValueError: ignored