<a href="https://colab.research.google.com/github/MatthewYancey/GANime/blob/master/src/model_global_and_local_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GANime Globally and Locally Consistent Images

## Imports and Parameters

In [1]:
import os
import sys
import shutil
import glob
import random
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from google.colab import drive
drive.mount('/content/gdrive')

sys.path.append('/content/gdrive/MyDrive/repos/GANime/src')
from helper_functions import apply_mask, apply_padding, apply_comp, apply_scale, load_checkpoint, checkpoint, gpu_memory
from data_loaders import create_dataloaders
from networks_global_and_local import Generator, Discriminator, weights_init

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
# network parameters
BATCH_SIZE = 15
DATASET_SIZE = 100000
N_BATCHES = DATASET_SIZE // BATCH_SIZE
N_GPU = 1
N_WORKERS = 1
N_EPOCHS = 100
LEARNING_RATE = 0.0002
ALPHA_WEIGHT = 0.99
DROPOUT_RATE = 0.2

# image
IMG_HEIGHT = 288
IMG_WIDTH = 512
SINGLE_SIDE = 64

# tensorboard
TRAIN_REFERENCE_INDEX = 200
VAL_REFERENCE_INDEX = 100
TEST_REFERENCE_INDEX = 20

# cost weights
WEIGHT_DECAY = 0.05

# directories
ZIP_PATH_TRAIN = '/content/gdrive/My Drive/repos/GANime/data_out/pokemon/train.zip'
IMG_DIR_TRAIN = '/content/frames/train/'
ZIP_PATH_VAL = '/content/gdrive/My Drive/repos/GANime/data_out/pokemon/validate.zip'
IMG_DIR_VAL = '/content/frames/validate/'
ZIP_PATH_TEST = '/content/gdrive/My Drive/repos/GANime/data_out/pokemon/test.zip'
IMG_DIR_TEST = '/content/frames/test/'
# LOG_DIR = '/content/gdrive/My Drive/repos/GANime/data_out/logs/global_and_local/'
LOG_DIR = '/content/temp/global_and_local/'

# PREV_CHECKPOINT = '/content/gdrive/My Drive/repos/GANime/data_out/logs/global_and_local/checkpoint.pt' # set to none to not load and create a new log folder
PREV_CHECKPOINT = None # set to none to not load and create a new log folder

In [3]:
# unzips images
if os.path.exists(IMG_DIR_TRAIN) == False:
    shutil.unpack_archive(ZIP_PATH_TRAIN, IMG_DIR_TRAIN, 'zip')
    shutil.unpack_archive(ZIP_PATH_VAL, IMG_DIR_VAL, 'zip')
    shutil.unpack_archive(ZIP_PATH_TEST, IMG_DIR_TEST, 'zip')

In [4]:
# sets what device to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and N_GPU > 0) else "cpu")
print(f'Device: {device}')
!nvidia-smi -L

Device: cuda:0
GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-67eee1a6-f8fb-16fe-195b-c6ef33332e5d)


## Data Loaders

In [5]:
dataloader_train, dataloader_val, dataloader_test = create_dataloaders(BATCH_SIZE, N_WORKERS, IMG_DIR_TRAIN, IMG_DIR_VAL, IMG_DIR_TEST, DATASET_SIZE)

Training Dataset
Number of images: 21623
Size of dataset: 21623
Validation Dataset
Number of images: 6178
Size of dataset: 6178
Testing Dataset
Number of images: 3089
Size of dataset: 3089


## Networks, Loss Functions, and Optimizers

In [6]:
gen = Generator(IMG_WIDTH, SINGLE_SIDE).to(device)
gen.apply(weights_init)
disc = Discriminator(IMG_WIDTH, SINGLE_SIDE, DROPOUT_RATE).to(device)
disc.apply(weights_init)
gpu_memory()

Allocated memory: 0.420852224


In [7]:
loss_bce = nn.BCELoss()
loss_mse = nn.MSELoss()
# optimizer_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.9))
optimizer_gen = optim.Adadelta(gen.parameters())
optimizer_disc = optim.Adadelta(disc.parameters())

In [8]:
# loads the checkpoint
gen, optimizer_gen, disc, optimizer_disc, batch_counter = load_checkpoint(PREV_CHECKPOINT, LOG_DIR, gen, optimizer_gen, disc, optimizer_disc)

No log folder found


### Training Loop

In [9]:
for epoch in range(N_EPOCHS):
    # gets data for the generator
    for i, batch in enumerate(dataloader_train, 0):
        batch = batch.to(device)

        #############################
        # Discriminator
        #############################
        disc.zero_grad()
        disc_output = disc(batch)
        disc_loss_real = loss_bce(disc_output, torch.ones(disc_output.shape[0]).cuda())
        disc_accuracy = (disc_output.round() == torch.ones(disc_output.shape[0]).cuda()).sum()
        disc_loss_real.backward()

        # apply mask to the images
        batch_mask = batch.clone()
        batch_mask = apply_mask(batch_mask, IMG_WIDTH, SINGLE_SIDE)

        # passes fake images to feed the discriminator
        gen_output = gen(batch_mask)
        gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)
        disc_output = disc(gen_output) # try taking detach off
        disc_accuracy += (disc_output.round() == torch.zeros(disc_output.shape[0]).cuda()).sum()
        disc_accuracy = disc_accuracy / (BATCH_SIZE * 2)
        disc_loss_fake = loss_bce(disc_output, torch.zeros(disc_output.shape[0]).to(device))
        disc_loss_fake.backward()

        # optimized the discriminator
        disc_loss = (disc_loss_real + disc_loss_fake) / 200  # scale the loss between 0 and 1
        optimizer_disc.step()

        #############################
        # Generater
        #############################
        gen.zero_grad()
        gen_output = gen(batch_mask)

        # combines the sides from the generator with the 4:3 image and calculates the mse loss against the orginal full image
        gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)
        disc_output = disc(gen_output)
        
        # calculates the loss
        gen_train_loss_mse = loss_mse(gen_output, batch)
        gen_train_loss_bce = loss_bce(disc_output, torch.ones(disc_output.shape[0]).cuda())
        gen_train_loss = (gen_train_loss_mse + gen_train_loss_bce * 0.01) / 2

        # error and optimize
        gen_train_loss.backward()
        optimizer_gen.step()

        # prints the status and checkpoints every so often
        if i % 10 == 0:
            # gets the testing MSE
            batch = next(iter(dataloader_val))
            batch = batch.to(device)
            batch_mask = batch.clone()
            batch_mask = apply_mask(batch_mask, IMG_WIDTH, SINGLE_SIDE)
            with torch.no_grad():
                gen_output = gen(batch_mask)
            gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)

            # calculates the loss
            disc_output = disc(gen_output)
            val_loss_bce = loss_bce(disc_output, torch.ones(disc_output.shape[0]).cuda())
            val_loss_mse = loss_mse(gen_output, batch)
            val_loss = val_loss_mse*ALPHA_WEIGHT + val_loss_bce*(1-ALPHA_WEIGHT)

            print(f'Epoch: {epoch}/{N_EPOCHS}, Batch: {i}/{N_BATCHES}, Total Images {batch_counter * BATCH_SIZE}, Gen Train Loss: {gen_train_loss:.4f}, Gen Val Loss: {val_loss:.4f}, Disc Accuracy: {disc_accuracy:.4f}, CUDA Memory: {(torch.cuda.memory_allocated() / 10**9):.4f}')

            if i % 100 == 0:
                checkpoint(batch_counter,
                           disc_loss.item(),
                           disc_accuracy,
                           gen_train_loss.item(),
                           gen_train_loss_mse.item(),
                           val_loss.item(),
                           val_loss_mse.item(),
                           LOG_DIR,
                           gen,
                           optimizer_gen,
                           disc,
                           optimizer_disc,
                           dataloader_train,
                           TRAIN_REFERENCE_INDEX,
                           dataloader_val,
                           VAL_REFERENCE_INDEX,
                           IMG_HEIGHT,
                           IMG_WIDTH,
                           SINGLE_SIDE)

        batch_counter += 1

Epoch: 0/100, Batch: 0/6666, Total Images 0, Gen Train Loss: 0.0629, Gen Val Loss: 0.0908, Disc Accuracy: 0.5000, CUDA Memory: 2.534216192
Saving reference images
Saved checkpoint
Epoch: 0/100, Batch: 100/6666, Total Images 1500, Gen Train Loss: 0.0300, Gen Val Loss: 0.0505, Disc Accuracy: 0.5000, CUDA Memory: 3.332958208
Epoch: 0/100, Batch: 200/6666, Total Images 3000, Gen Train Loss: 0.0224, Gen Val Loss: 0.0462, Disc Accuracy: 0.5000, CUDA Memory: 3.329951744
Epoch: 0/100, Batch: 300/6666, Total Images 4500, Gen Train Loss: 0.0193, Gen Val Loss: 0.0318, Disc Accuracy: 0.5667, CUDA Memory: 3.329820672
Epoch: 0/100, Batch: 400/6666, Total Images 6000, Gen Train Loss: 0.0186, Gen Val Loss: 0.0392, Disc Accuracy: 0.7000, CUDA Memory: 3.328886784
Epoch: 0/100, Batch: 500/6666, Total Images 7500, Gen Train Loss: 0.0198, Gen Val Loss: 0.0285, Disc Accuracy: 0.6000, CUDA Memory: 3.327821824
Epoch: 0/100, Batch: 600/6666, Total Images 9000, Gen Train Loss: 0.0172, Gen Val Loss: 0.0302, Disc

KeyboardInterrupt: ignored