<a href="https://colab.research.google.com/github/MatthewYancey/GANime/blob/master/src/inference_all_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Globally and Locally Consistent Images Inference
This notebook runs batches of test images through the global-and-local, context-encoder, and LaMa generators and saves each model's output images as a zip archive for review.

## Imports and Parameters

In [1]:
import os
import sys
import shutil
import glob
import random
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import itertools

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

# Mount Google Drive; the project sources, data archives and checkpoints used
# below are all read from paths under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

# Add the repository's src directory to the import path before importing the
# project modules below (presumably they live in this directory — it must be
# appended first, so keep this ordering).
sys.path.append('/content/gdrive/MyDrive/repos/GANime/src')
from model_helper_functions import apply_mask, apply_padding, apply_comp, apply_scale, load_checkpoint_inference, checkpoint
from model_data_loaders import create_dataloaders
# Every model module exposes a class named Generator; alias each one so all
# three can be used side by side.
from model_gal import Generator as _gen_gal
from model_context_encoders import Generator as _gen_ce
from model_lama import Generator as _gen_lama
# from model_gal import gal_Generator, Discriminator, weights_init

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
# --- network ---
BATCH_SIZE = 15
N_EPOCHS = 100
ALPHA_WEIGHT = 0.0004

# --- hardware ---
N_GPU = 1
N_WORKERS = 1

# --- image geometry ---
IMG_HEIGHT = 288
IMG_WIDTH = 512
SINGLE_SIDE = 64

TEST_REFERENCES = [2800, 8000, 17850, 3000]

# --- directories ---
# Shared prefixes, so each location below is spelled out only once.
_DATA_OUT = '/content/gdrive/My Drive/repos/GANime/data_out'
_POKEMON = f'{_DATA_OUT}/pokemon'
_LOGS = f'{_DATA_OUT}/logs'
_FRAMES = '/content/frames'

# zipped frame archives on Drive and the local directories they are unpacked into
ZIP_PATH_TRAIN = f'{_POKEMON}/train.zip'
IMG_DIR_TRAIN = f'{_FRAMES}/train/'
ZIP_PATH_VAL = f'{_POKEMON}/validate.zip'
IMG_DIR_VAL = f'{_FRAMES}/validate/'
ZIP_PATH_TEST = f'{_POKEMON}/test.zip'
IMG_DIR_TEST = f'{_FRAMES}/test/'

# one checkpoint file per generator
CHECKPOINT_CE = f'{_LOGS}/model_context_encoders/checkpoint.pt'
CHECKPOINT_GAL = f'{_LOGS}/global_and_local/checkpoint.pt'
CHECKPOINT_LAMA = f'{_LOGS}/ffc/checkpoint.pt'

# local scratch directory for generated frames, and the Drive directory for the archives
TEMP_DIR = '/content/saved_frames/'
# NOTE: this path uses the 'MyDrive' spelling, unlike the 'My Drive' paths above.
OUTPUT_DIR = '/content/gdrive/MyDrive/repos/GANime/data_out/test_output/'

In [3]:
# Unpack each frame archive into its local directory.
# Every archive is checked on its own, so a missing validate/test directory is
# still extracted even when the train directory already exists.
for zip_path, img_dir in ((ZIP_PATH_TRAIN, IMG_DIR_TRAIN),
                          (ZIP_PATH_VAL, IMG_DIR_VAL),
                          (ZIP_PATH_TEST, IMG_DIR_TEST)):
    if not os.path.exists(img_dir):
        shutil.unpack_archive(zip_path, img_dir, 'zip')

In [4]:
# Use the first CUDA device when CUDA is available and GPUs are enabled via N_GPU;
# otherwise fall back to the CPU.
if torch.cuda.is_available() and N_GPU > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Device: {device}')

Device: cuda:0


In [5]:
# Build the train / validation / test loaders with shuffling disabled.
_loaders = create_dataloaders(
    BATCH_SIZE,
    N_WORKERS,
    IMG_DIR_TRAIN,
    IMG_DIR_VAL,
    IMG_DIR_TEST,
    shuffle_images=False,
)
dataloader_train, dataloader_val, dataloader_test = _loaders

Training Dataset
Number of images: 429979
Size of dataset: 429979
Validation Dataset
Number of images: 122851
Size of dataset: 122851
Testing Dataset
Number of images: 61426
no transform
Size of dataset: 61426


In [6]:
# Instantiate each generator on the selected device and pass it, together with
# its checkpoint path, to load_checkpoint_inference.

# global-and-local generator
gen_gal = load_checkpoint_inference(
    CHECKPOINT_GAL,
    _gen_gal(IMG_WIDTH, SINGLE_SIDE).to(device),
)

# context-encoder generator
gen_ce = load_checkpoint_inference(
    CHECKPOINT_CE,
    _gen_ce(IMG_WIDTH, SINGLE_SIDE).to(device),
)

# LaMa generator (constructed without size arguments)
gen_lama = load_checkpoint_inference(
    CHECKPOINT_LAMA,
    _gen_lama().to(device),
)

Loaded checkpoint from /content/gdrive/My Drive/repos/GANime/data_out/logs/global_and_local/checkpoint.pt
Loaded checkpoint from /content/gdrive/My Drive/repos/GANime/data_out/logs/model_context_encoders/checkpoint.pt
Loaded checkpoint from /content/gdrive/My Drive/repos/GANime/data_out/logs/ffc/checkpoint.pt


In [8]:
def save_images(gen, model_name, max_images=10000):
    """Run ``gen`` over the test loader and archive the resulting frames as JPEGs.

    Each batch from ``dataloader_test`` is passed through ``apply_mask``, the
    generator, ``apply_comp`` and ``apply_scale``. Every resulting image is
    written to ``TEMP_DIR`` as ``{model_name}_{index}.jpg``. When done, the
    directory is zipped to ``{OUTPUT_DIR}{model_name}.zip``.

    Args:
        gen: generator module, called on the masked batch.
        model_name: short tag used in the image file names and the archive name.
        max_images: maximum number of images to write. Fewer are written if the
            test loader runs out first.

    Raises:
        OSError: if an image file cannot be written.

    Side effects:
        Deletes and recreates ``TEMP_DIR``; writes a zip archive under
        ``OUTPUT_DIR``.
    """
    print(f'Processing {model_name}')

    # Start from an empty scratch directory so frames written for a previous
    # model do not end up in this model's archive.
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    os.makedirs(TEMP_DIR)

    img_count = 0
    with torch.no_grad():
        # Iterate the loader directly so consecutive batches are used. Slicing
        # an already-advancing iterator with a growing offset skipped ever more
        # batches and exhausted the loader early.
        for batch in dataloader_test:
            if img_count >= max_images:
                break

            torch.cuda.empty_cache()
            # A separate device copy is handed to each helper, in case
            # apply_mask modifies its input in place — TODO confirm.
            gen_output = gen(apply_mask(batch.to(device), IMG_WIDTH, SINGLE_SIDE))
            frames = apply_comp(batch.to(device), gen_output, IMG_WIDTH, SINGLE_SIDE)
            frames = apply_scale(frames)
            frames = frames.detach().cpu().numpy()

            # Keep only as many frames as are still needed to reach max_images.
            frames = frames[:max_images - img_count]

            for img in frames:
                # move the leading (channel) axis last, as the OpenCV calls expect
                img = np.transpose(img, (1, 2, 0))
                # Presumably apply_scale yields values in [0, 1] — TODO confirm.
                # Clip before the cast so values slightly outside that range do
                # not wrap around in uint8.
                img = np.clip(img * 255, 0, 255).astype(np.uint8)
                # cv2.imwrite expects BGR channel order
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                img_name = f'{TEMP_DIR}{model_name}_{img_count}.jpg'
                # imwrite reports failure through its return value, not an exception
                if not cv2.imwrite(img_name, img):
                    raise OSError(f'Could not write image {img_name}')
                img_count += 1

    shutil.make_archive(f'{OUTPUT_DIR}{model_name}', 'zip', TEMP_DIR)


# Generate and archive test images for every model, tagged by a short model name.
gens = [
    [gen_gal, 'gal'],
    [gen_ce, 'ce'],
    [gen_lama, 'lama'],
]
for generator, tag in gens:
    save_images(generator, tag)

Processing gal
min batch: -1.0
max batch: 1.0
min gen: -0.9995692372322083
max gen: 1.0
min batch: -1.0
max batch: 1.0
