<a href="https://colab.research.google.com/github/MatthewYancey/GANime/blob/master/src/model_context_encoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GANime - Context Encoder
This notebook applies the context encoder paper.

## Imports and Parameters

In [1]:
import os
import sys
import shutil
import glob
import random
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from google.colab import drive
drive.mount('/content/gdrive')

sys.path.append('/content/gdrive/MyDrive/repos/GANime/src')
from helper_functions import apply_mask, apply_padding, apply_comp, apply_scale, load_checkpoint, checkpoint, gpu_memory
from data_loaders import create_dataloaders
from networks_context_encoders import Generator, Discriminator, weights_init

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
# network parameters
BATCH_SIZE = 15
DATASET_SIZE = 1000000
N_BATCHES = DATASET_SIZE // BATCH_SIZE
N_GPU = 1
N_WORKERS = 1
N_EPOCHS = 100
GEN_LEARNING_RATE = 0.001
DISC_LEARNING_RATE = 0.0001
ALPHA_WEIGHT = 0.9
DROPOUT_RATE = 0.1

# image
IMG_HEIGHT = 288
IMG_WIDTH = 512
SINGLE_SIDE = 64

# tensorboard
# TRAIN_REFERENCE_INDEX = 200
# VAL_REFERENCE_INDEX = 100
TEST_REFERENCES = [2800, 8000, 17850, 3000]

# cost weights
WEIGHT_DECAY = 0.05

# directories
ZIP_PATH_TRAIN = '/content/gdrive/My Drive/repos/GANime/data_out/pokemon/train.zip'
IMG_DIR_TRAIN = '/content/frames/train/'
ZIP_PATH_VAL = '/content/gdrive/My Drive/repos/GANime/data_out/pokemon/validate.zip'
IMG_DIR_VAL = '/content/frames/validate/'
ZIP_PATH_TEST = '/content/gdrive/My Drive/repos/GANime/data_out/pokemon/test.zip'
IMG_DIR_TEST = '/content/frames/test/'
LOG_DIR = '/content/gdrive/My Drive/repos/GANime/data_out/logs/model_context_encoders/'
# PREV_CHECKPOINT = '/content/gdrive/My Drive/repos/GANime/data_out/logs/model_context_encoders/checkpoint.pt' # set to none to not load and create a new log folder
PREV_CHECKPOINT = None # set to none to not load and create a new log folder

In [3]:
# unzips images
if os.path.exists(IMG_DIR_TRAIN) == False:
    shutil.unpack_archive(ZIP_PATH_TRAIN, IMG_DIR_TRAIN, 'zip')
    shutil.unpack_archive(ZIP_PATH_VAL, IMG_DIR_VAL, 'zip')
    shutil.unpack_archive(ZIP_PATH_TEST, IMG_DIR_TEST, 'zip')

In [4]:
# sets what device to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and N_GPU > 0) else "cpu")
print(f'Device: {device}')
!nvidia-smi -L

Device: cuda:0
GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-ed9a6da3-733c-fcb8-5afb-96e175173c35)


In [5]:
dataloader_train, dataloader_val, dataloader_test = create_dataloaders(BATCH_SIZE, N_WORKERS, IMG_DIR_TRAIN, IMG_DIR_VAL, IMG_DIR_TEST, DATASET_SIZE)

Training Dataset
Number of images: 429979
Size of dataset: 429979
Validation Dataset
Number of images: 122851
Size of dataset: 122851
Testing Dataset
Number of images: 61426
Size of dataset: 61426


In [6]:
gen = Generator(IMG_WIDTH, SINGLE_SIDE).to(device)
gen.apply(weights_init)
disc = Discriminator().to(device)
disc.apply(weights_init)
gpu_memory()

Allocated memory: 2.3949


In [7]:
loss_bce = nn.BCELoss()
loss_mse = nn.MSELoss()
optimizer_gen = optim.Adam(gen.parameters(), lr=GEN_LEARNING_RATE, betas=(0.5, 0.9))
optimizer_disc = optim.Adam(disc.parameters(), lr=DISC_LEARNING_RATE, betas=(0.5, 0.9), weight_decay=WEIGHT_DECAY)

In [8]:
# loads the checkpoint
gen, optimizer_gen, disc, optimizer_disc, batch_counter = load_checkpoint(PREV_CHECKPOINT, LOG_DIR, gen, optimizer_gen, disc, optimizer_disc)
gpu_memory()

Folders removed
Allocated memory: 2.3949


In [None]:
for epoch in range(N_EPOCHS):
    # gets data for the generator
    for i, batch in enumerate(dataloader_train, 0):
        batch = batch.to(device)
        #############################
        # Discriminator
        #############################
        disc.zero_grad()
        disc_output = disc(batch)
        disc_loss_real = loss_bce(disc_output, torch.ones(disc_output.shape[0], 1).cuda())
        disc_accuracy = (disc_output.round() == torch.ones(disc_output.shape[0], 1).cuda()).sum()
        disc_loss_real.backward()

        # apply mask to the images
        batch_mask = batch.clone()
        batch_mask = apply_mask(batch_mask, IMG_WIDTH, SINGLE_SIDE)

        # passes fake images to feed the discriminator
        gen_output = gen(batch_mask)
        gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)
        disc_output = disc(gen_output)

        # calcualtes accuracy and loss on the fake images
        disc_accuracy += (disc_output.round() == torch.zeros(disc_output.shape[0], 1).cuda()).sum()
        disc_accuracy = disc_accuracy / (BATCH_SIZE * 2)
        disc_loss_fake = loss_bce(disc_output, torch.zeros(disc_output.shape[0], 1).to(device))
        disc_loss_fake.backward()

        # optimized the discriminator
        disc_loss = (disc_loss_real + disc_loss_fake) / 200  # scale the loss between 0 and 1
        optimizer_disc.step()

        #############################
        # Generater
        #############################
        gen.zero_grad()
        gen_output = gen(batch_mask)

        # combines the sides from the generator with the 4:3
        gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)
        disc_output = disc(gen_output)
        
        # calculates the loss
        gen_train_loss_l2 = loss_mse(gen_output, batch)
        gen_train_loss_bce = loss_bce(disc_output, torch.ones(disc_output.shape[0], 1).cuda())
        gen_train_loss = gen_train_loss_l2*ALPHA_WEIGHT + gen_train_loss_bce*(1-ALPHA_WEIGHT)

        # error and optimize
        gen_train_loss.backward()
        optimizer_gen.step()

        # prints the status and checkpoints every so often
        if i % 10 == 0:
            # gets the validation loss
            batch = next(iter(dataloader_val))
            batch = batch.to(device)
            batch_mask = batch.clone()
            batch_mask = apply_mask(batch_mask, IMG_WIDTH, SINGLE_SIDE)
            with torch.no_grad():
                gen_output = gen(batch_mask)
            gen_output = apply_comp(batch, gen_output, IMG_WIDTH, SINGLE_SIDE)

            # calcuates the validation loss
            disc_output = disc(gen_output)
            gen_val_loss_bce = loss_bce(disc_output, torch.ones(disc_output.shape[0], 1).cuda())
            gen_val_loss_l2 = loss_mse(gen_output, batch)
            gen_val_loss = gen_val_loss_l2*ALPHA_WEIGHT + gen_val_loss_bce*(1-ALPHA_WEIGHT)

            print(f'Epoch: {epoch}/{N_EPOCHS}, Batch: {i}/{N_BATCHES}, Total Images {batch_counter * BATCH_SIZE}, Gen Train Loss: {gen_train_loss:.4f}, Gen Val Loss: {gen_val_loss:.4f}, Disc Accuracy: {disc_accuracy:.4f}, CUDA Memory: {(torch.cuda.memory_allocated() / 10**9):.4f}')

            if i % 200 == 0:
                torch.cuda.empty_cache()
                checkpoint(i,
                           batch_counter,
                           disc_loss.item(),
                           disc_accuracy,
                           gen_train_loss.item(),
                           gen_train_loss_l2.item(),
                           gen_val_loss.item(),
                           gen_val_loss_l2.item(),
                           LOG_DIR,
                           gen,
                           optimizer_gen,
                           disc,
                           optimizer_disc,
                           dataloader_test,
                           TEST_REFERENCES,
                           IMG_HEIGHT,
                           IMG_WIDTH,
                           SINGLE_SIDE)
                torch.cuda.empty_cache()

        batch_counter += 1

Epoch: 0/100, Batch: 0/66666, Total Images 0, Gen Train Loss: 0.1812, Gen Val Loss: 0.1725, Disc Accuracy: 0.5333, CUDA Memory: 10.1336
Saving reference images
Saving reference images
Saving reference images
Saving reference images
Saving checkpoint at new epoch
Saved to tensorboard
Epoch: 0/100, Batch: 10/66666, Total Images 150, Gen Train Loss: 0.1586, Gen Val Loss: 0.1311, Disc Accuracy: 0.3000, CUDA Memory: 10.4727
Epoch: 0/100, Batch: 20/66666, Total Images 300, Gen Train Loss: 0.1784, Gen Val Loss: 0.1699, Disc Accuracy: 0.6333, CUDA Memory: 10.5529
Epoch: 0/100, Batch: 30/66666, Total Images 450, Gen Train Loss: 0.1559, Gen Val Loss: 0.1562, Disc Accuracy: 0.5000, CUDA Memory: 10.5529
Epoch: 0/100, Batch: 40/66666, Total Images 600, Gen Train Loss: 0.1480, Gen Val Loss: 0.1683, Disc Accuracy: 0.5000, CUDA Memory: 10.5529
Epoch: 0/100, Batch: 50/66666, Total Images 750, Gen Train Loss: 0.1794, Gen Val Loss: 0.1553, Disc Accuracy: 0.7333, CUDA Memory: 10.5248
Epoch: 0/100, Batch: 