## Data Setup

In [None]:
# W&B run id to resume. The wandb.init cell resumes this run when the value is
# non-empty and starts a fresh run when it is "".
RESUME_ID = "1s3dxn44"

In [None]:
# Mount Google Drive so dataset archives and checkpoints are reachable as
# ordinary files under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

# The drive has to be mounted because the PyTorch code works on plain file
# paths. The trash, however, cannot be emptied through the mounted drive, so
# PyDrive is used for that (see `pdrive`), and it needs its own credentials.
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate the Colab user and hand the default application credentials
# to PyDrive.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
pdrive = GoogleDrive(gauth)


In [None]:
# Install PyTorch-Ignite (Engine, Events, metrics and idist are imported from it below).
!pip install pytorch-ignite

In [None]:
# Install Weights and Biases for experiment tracking.
!pip install wandb

In [None]:
from __future__ import print_function
#%matplotlib inline

import os

# Usual Suspects.
import numpy as np
import random
from time import sleep

# PyTorch
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

# Image Visualization.
from PIL import Image
from PIL import ImageEnhance

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1 import ImageGrid

# Ignite
from ignite.engine import Engine, Events
import ignite.distributed as idist

# Experiment tracking.
import wandb

# Set random seeds for reproducibility. Every RNG the notebook has imported is
# seeded: Python's `random` (label smoothing, rotation angle), NumPy, and PyTorch.
manualSeed = 42
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

In [None]:
# Setup data directory.
!mkdir -p "/content/data/clean_bold_magic_symbols/"
!mkdir "/content/output/"

# You will need to upload your data file manually Google Drive.
!tar -xf  '/content/gdrive/MyDrive/datasets/clean_bold_magic_symbols.tar.gz' -C '/content/data/clean_bold_magic_symbols/'

# Uncomment to include icons for extra training data.
!mkdir -p "/content/data/icons/"
!tar -xf  '/content/gdrive/MyDrive/datasets/icons_light.tar.gz' -C '/content/data/icons/'


In [None]:
# Number of files in the second directory visited by os.walk("/content/data")
# (entry 0 is /content/data itself, entry 1 is the first sub-directory walked).
# NOTE(review): which sub-directory comes first depends on the directory
# listing order, so with several dataset folders this may count either one —
# confirm the intended folder.
COUNT_OF_TRAINING_IMGS = len(list(os.walk("/content/data"))[1][2])

In [None]:
# Install GPUtil, used in the next cell to read the GPU name and memory.
!pip install gputil

In [None]:
import GPUtil

# Describe the first GPU reported by GPUtil. The name and total memory are
# recorded in experiment_settings and compared against the stored W&B config
# when a run is resumed.
# NOTE(review): indexing [0] raises IndexError when no GPU is visible.
gpu = GPUtil.getGPUs()
gpu_name = gpu[0].name
gpu_memory = gpu[0].memoryTotal

print(f"GPU: {gpu_name}")
print(f"GPU Memory: {gpu_memory}")

## Batch Size vs. Learning Rate

| batch_size |      lr  |
|------------|----------|
| 4          |  0.002   |
| 16         |  0.0001  |


## Parameters

In [None]:
# Root directory for dataset (passed to ImageFolder as `root`).
dataroot = f"{os.getcwd()}/data/"
# NOTE(review): output_folder is not referenced elsewhere in the visible notebook.
output_folder = f"{os.getcwd()}/output/"

# Where to save models.
# NOTE(review): not referenced elsewhere in the visible notebook; checkpoints
# are written under MODEL_DIR instead.
g_drive_models_dir = "dl_models/deep_arcane/"

# Root folder of saved checkpoints; each W&B run uses a sub-folder named
# after its run id.
MODEL_DIR = "/content/gdrive/MyDrive/dl_models/deep_arcane_models/"

# Number of WORKERS for dataloader
WORKERS = 2

# Number of training epochs
EPOCHS = 5000
# Epoch to start from; overwritten from the checkpoint when a run is resumed.
starting_epoch = 0

# Best-FID checkpoints are only written for epochs strictly after this one.
EPOCH_TO_START_SAVING = 0

# Run the FID / IS evaluation every EPOCH_TO_EVAL epochs.
EPOCH_TO_EVAL = 3

# How many epochs run before saving model
# NOTE(review): not referenced elsewhere in the visible notebook.
save_every = 10

# Batch size during training
BATCH_SIZE = 12

# Learning rate for optimizers (generator / discriminator)
# GLR = 0.00001
# DLR = 0.00001
GLR = 0.0001
DLR = 0.0003

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
IMAGE_SIZE = 128

# Number of channels in the training images. For color images this is 3
NC = 1

# Size of z latent vector (i.e. size of generator input)
latent_dim = 18

# Size of feature maps in generator
NGF = 300

# Size of feature maps in discriminator
NDF = 300

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Dropout probabilities handed to the Generator / Discriminator constructors.
# The Discriminator accepts the value but does not use it.
G_DROPOUT = 0.0
D_DROPOUT = 0.0

# Label ranges: each training step draws its real / fake target with
# random.uniform from these (low, high) pairs.
# NOTE(review): real targets are 0.0 and fake targets ~0.9-0.95, the reverse of
# the common "1 = real" convention; train_step uses them consistently — confirm
# this is intended.
real_range = (0.0, 0.0)
fake_range = (0.90, 0.95)

# NOTE(review): sample_every is not referenced elsewhere in the visible notebook.
sample_every = 1
# Number of latent vectors in `static_noise`, i.e. images logged per epoch.
samples = 8

# Image transformations.
# NOTE(review): the three flags below are only recorded in experiment_settings;
# no visible code branches on them.
convert_to_grayscale = True
resize_images = True
crop_images = True
# (low, high) bounds, in degrees, for the random rotation in add_margins.
# random_rotation_degrees = (-30, 30)
random_rotation_degrees = (-0, 0)


# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# Freshly generated W&B id recorded with the settings.
# NOTE(review): the name `id` shadows the Python builtin for the rest of the
# notebook; it is only used for the "wandb_run_id" entry below.
id = wandb.util.generate_id()

# Snapshot of the hyper-parameters and environment, passed to wandb.init as
# the run config when a fresh run is started.
experiment_settings = {
    "wandb_run_id": id,
    "dataroot": dataroot,
    "WORKERS": WORKERS,
    "BATCH_SIZE": BATCH_SIZE,
    "IMAGE_SIZE":IMAGE_SIZE,
    "NC":NC,
    "latent_dim":latent_dim,
    "NGF":NGF,
    "NDF":NDF,
    "EPOCHS":EPOCHS,
    "DLR":DLR,
    "GLR":GLR,
    "beta1":beta1,
    "G_DROPOUT":G_DROPOUT,
    "D_DROPOUT":D_DROPOUT,
    "real_range":real_range,
    "fake_range":fake_range,
    "ngpu":ngpu,
    "convert_to_grayscale": convert_to_grayscale,
    "resize_images": resize_images,
    "crop_images": crop_images,
    "random_rotation_degrees": random_rotation_degrees,
    "torch_version": torch.__version__,
    "gpu_name": gpu_name,
    "gpu_memory": gpu_memory,
    "cuda_ver": torch.version.cuda
}

In [None]:
# Initialize Weights and Biases. You'll need an account.
if RESUME_ID == "":
    # Fresh run: register the hyper-parameters as the run config.
    wandb.init(config=experiment_settings, project="deep-arcane")
else:
    # Resume the given run; no config is passed, so the stored one is read back.
    wandb.init(id=RESUME_ID, project="deep-arcane", resume="allow")
    if wandb.config["gpu_name"] != gpu_name:
        # Different GPU than the one recorded for the run: warn and wait for
        # the user to press Enter before continuing.
        print("WOW! This is not the GPU you are looking for....")
        input()

# Per-run checkpoint folder on Google Drive.
model_path = f"/content/gdrive/MyDrive/dl_models/deep_arcane_models/{wandb.run.id}/"

In [None]:
# Image used further down in this cell to preview what add_margins produces.
sample_image_path = "/content/data/clean_bold_magic_symbols/2.png"

def add_margins(image, image_size, random_rotation_degrees, color = (255, 255, 255)):
  """Pad `image` onto a larger canvas, rotate, crop the margin off and sharpen.

  Args:
    image: PIL image pasted at the centre of the canvas.
    image_size: Edge length used to size both the canvas and the crop.
    random_rotation_degrees: (low, high) bounds handed to random.randint to
      pick the rotation angle.
    color: Fill colour of the "RGB" canvas.

  Returns:
    The processed PIL image. It is converted to mode "1" whenever `color`
    has more than one component.
  """
  # Draw the rotation angle for this call.
  angle = random.randint(random_rotation_degrees[0], random_rotation_degrees[1])

  # The canvas is the image size plus an equally sized margin on each axis.
  canvas_side = image_size + image_size
  canvas = Image.new("RGB", (canvas_side, canvas_side), color = color)

  # Centre the source image on the canvas.
  canvas_w, canvas_h = canvas.size
  offset = (
      round((canvas_w - image.size[0]) / 2),
      round((canvas_h - image.size[1]) / 2),
  )
  canvas.paste(image, offset)

  # Rotate while the margin is still in place.
  canvas = canvas.rotate(angle)

  # Trim round(image_size / 2) pixels from every edge.
  inset = round(image_size / 2)
  canvas = canvas.crop(
      (inset, inset, canvas.size[0] - inset, canvas.size[1] - inset)
  )

  # Strong sharpening (factor 50).
  canvas = ImageEnhance.Sharpness(canvas).enhance(50)

  if len(color) > 1:
    canvas = canvas.convert("1")

  return canvas

# Preview the preprocessing on the sample image.
img = add_margins(Image.open(sample_image_path), 128, random_rotation_degrees)

plt.imshow(img)
plt.axis("off")
plt.show()

def image_loader(path):
  """Open the image at `path` and run it through add_margins.

  Used as the `loader` of the ImageFolder dataset; IMAGE_SIZE and
  random_rotation_degrees are read from the notebook's parameters.
  """
  return add_margins(Image.open(path), IMAGE_SIZE, random_rotation_degrees)

def save_model(path, netG, netD, optimizerG, optimizerD, EPOCHS, criterion, prefix=""):
  """Write a training checkpoint to `<path>/<prefix>deep_arcane.model`.

  Args:
    path: Directory to save into; created when missing. A trailing slash is
      optional.
    netG, netD: Generator / discriminator whose state_dict() is stored.
    optimizerG, optimizerD: Their optimizers; state_dict() is stored.
    EPOCHS: Value stored under the "epoch" key.
    criterion: Object stored as-is under the "loss" key.
    prefix: Optional file-name prefix (e.g. "tmp_").

  Raises:
    OSError: If the directory cannot be created.
  """
  # exist_ok=True tolerates an existing folder while still surfacing real
  # failures (permissions, invalid path, ...) instead of hiding them.
  if os.path.isdir(path):
    print("Folder exists")
  os.makedirs(path, exist_ok=True)

  # os.path.join yields the same file whether or not `path` ends with a slash.
  model_file_path = os.path.join(path, f"{prefix}deep_arcane.model")

  # Save Model
  torch.save(
      {
        "generator": netG.state_dict(),
        "discriminator": netD.state_dict(),
        "generator_optimizer": optimizerG.state_dict(),
        "discriminator_optimizer": optimizerD.state_dict(),
        "epoch": EPOCHS,
        "loss": criterion,
      }, model_file_path
  )

## Settings

In [None]:

# Normalize the images with mean 0.5 and std 0.5 per channel.
normalize = (
    transforms.Normalize((0.5), (0.5))
    if NC == 1
    else transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
)

# The extracted folders can be read by an image-folder dataset as they are,
# so the dataset is built straight from `dataroot` with the custom loader.
transforms_to_apply = [transforms.ToTensor(), normalize]

dataset = dset.ImageFolder(
    root=dataroot,
    loader=image_loader,
    transform=transforms.Compose(transforms_to_apply),
)
# Evaluation subset: the first COUNT_OF_TRAINING_IMGS samples of the dataset.
test_dataset = torch.utils.data.Subset(dataset, torch.arange(COUNT_OF_TRAINING_IMGS))

print(f"Found {len(dataset.imgs)} images in {dataroot}")

# Create the dataloaders. The worker count comes from the WORKERS parameter so
# that the value logged in experiment_settings is the one actually used.
dataloader = idist.auto_dataloader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=WORKERS,
    shuffle=True,
    drop_last=True,
)

# Evaluation loader: fixed order, same batch size, incomplete batches dropped.
test_dataloader = idist.auto_dataloader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=WORKERS,
    shuffle=False,
    drop_last=True,
)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images: take one batch, tile up to 64 of its images into
# a grid, then show the grid with channels moved last for matplotlib.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(
    real_batch[0].to(device)[:64], padding=5, normalize=True
).cpu()
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1,2,0)))

## Display Methods

In [None]:
def display_generated_images(tensor, samples):
    """Show up to `samples` generated images one by one and return them.

    Args:
        tensor: Batch of generated images. Each entry must hold exactly
            height x width values (a single channel), e.g. shape (N, 1, H, W).
        samples: Maximum number of images to display.

    Returns:
        List of detached CPU tensors of shape (H, W, 1), one per shown image.
    """
    # Never index past the end of the batch.
    count = min(samples, tensor.shape[0])

    images = []
    for i in range(count):
        single = tensor[i].detach().cpu()
        # Take the spatial size from the tensor instead of assuming 128x128.
        height, width = single.shape[-2], single.shape[-1]
        image = single.reshape(height, width, 1)
        images.append(image)
        plt.axis("off")
        plt.imshow(image.reshape([height, width]), cmap="gray")
        plt.show()
    return images

def create_display_grid(images):
    """Plot `images` on the smallest square grid of subplots that fits them all.

    Args:
        images: Sequence of 2-D images accepted by `imshow`.

    Nothing is drawn for an empty sequence. Grid cells left over after the
    last image are kept blank with their axes hidden.
    """
    import math

    count = len(images)
    if count == 0:
        return

    grid_size = math.ceil(math.sqrt(count))
    # squeeze=False keeps `axarr` two-dimensional even for a 1x1 grid.
    f, axarr = plt.subplots(grid_size, grid_size, squeeze=False)

    for index in range(grid_size * grid_size):
        row, col = divmod(index, grid_size)
        ax = axarr[row, col]
        if index < count:
            ax.imshow(images[index], cmap="gray")
        # Hide the frame on every cell, including unused ones.
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    
def convert_tensor_to_image(image, width, height):
    """Convert a single-channel image tensor into an RGB PIL image.

    Args:
        image: Tensor holding width * height values.
        width, height: Dimensions the values are reshaped to, in that order.

    Returns:
        A new "RGB" PIL image with the scaled data pasted in.

    Side effects:
        Also draws the intermediate image with plt.imshow.
    """
    # The values are multiplied by 255 on the assumption that they are
    # normalized to 0-1.
    # NOTE(review): the generator ends in Tanh, so inputs may lie in [-1, 1] —
    # confirm the intended scaling.
    pixels = image.detach().cpu().numpy().reshape([width, height]) * 255
    image = Image.fromarray(pixels)
    plt.imshow(image)
    color_img = Image.new("RGB", image.size)
    color_img.paste(image)
    return color_img


## Models

In [None]:
# Custom weight initialization, applied to netG and netD through Module.apply.
def weights_init(m):
    """Initialize module `m` in place, based on its class name.

    Classes whose name contains "Conv" get Xavier-uniform weights. Classes
    whose name contains "BatchNorm" get weights drawn from N(1.0, 0.02) and a
    zero bias. Every other module is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        # Alternative kept for reference: nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.xavier_uniform_(m.weight)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
        
# Generator Code
class Generator(nn.Module):
    """Transposed-convolution generator: latent vector -> NC x 128 x 128 image.

    The output passes through Tanh, so values lie in [-1, 1]. `latent_dim`,
    `NGF` and `NC` are read from the notebook's parameters.
    """

    def __init__(self, ngpu, dropout):
        """
        Args:
            ngpu: Number of GPUs; only stored on the instance.
            dropout: Probability used by the two Dropout2d layers.
        """
        super().__init__()
        self.ngpu = ngpu
        self.dropout = dropout

        def up_block(in_ch, out_ch, stride, padding, with_dropout):
            # ConvTranspose2d -> BatchNorm2d -> [Dropout2d] -> GELU
            block = [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
            ]
            if with_dropout:
                block.append(nn.Dropout2d(p=dropout))
            block.append(nn.GELU())
            return block

        # Layers are appended in a fixed order so the Sequential indices (and
        # therefore the state_dict keys) stay stable.
        layers = []
        # input is Z (latent_dim x 1 x 1) -> (NGF*16) x 4 x 4
        layers += up_block(latent_dim, NGF * 16, 1, 0, True)
        # -> (NGF*8) x 8 x 8
        layers += up_block(NGF * 16, NGF * 8, 2, 1, False)
        # -> (NGF*4) x 16 x 16
        layers += up_block(NGF * 8, NGF * 4, 2, 1, False)
        # -> (NGF*2) x 32 x 32
        layers += up_block(NGF * 4, NGF * 2, 2, 1, True)
        # -> (NGF) x 64 x 64
        layers += up_block(NGF * 2, NGF, 2, 1, False)
        # -> (NC) x 128 x 128
        layers += [nn.ConvTranspose2d(NGF, NC, 4, 2, 1, bias=False), nn.Tanh()]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a batch of latent vectors (N, latent_dim, 1, 1) to images."""
        return self.main(input)
    
    
# Create the generator
netG = Generator(ngpu, G_DROPOUT).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply weights_init to every sub-module: Xavier-uniform weights for conv
# layers, N(1.0, 0.02) weights and zero bias for batch-norm layers.
netG.apply(weights_init)

# Load Model
# TODO:

# Print the model
print(netG)


class Discriminator(nn.Module):
    """Convolutional discriminator: NC x 128 x 128 image -> one sigmoid score.

    `NC` and `NDF` are read from the notebook's parameters.
    """

    def __init__(self, ngpu, dropout):
        """
        Args:
            ngpu: Number of GPUs; only stored on the instance.
            dropout: Accepted but not used by this class.
        """
        super().__init__()
        self.ngpu = ngpu

        # NOTE(review): the attributes below are not read by `main` or
        # `forward`; they are kept so the instance exposes the same fields.
        self.kernel_size = 216

        (self.output_0, self.output_1, self.output_2,
         self.output_3, self.output_4) = 32, 64, 128, 256, 1

        (self.input_1, self.input_2,
         self.input_3, self.input_4) = (self.output_0, self.output_1,
                                        self.output_2, self.output_3)

        self.stride_0 = self.stride_1 = 3
        self.stride_2 = self.stride_3 = self.stride_4 = 2

        self.padding_0 = self.padding_1 = self.padding_2 = 1
        self.padding_3 = self.padding_4 = 1

        def down_block(in_ch, out_ch, with_norm=True):
            # Conv2d (halves the resolution) -> [BatchNorm2d] -> LeakyReLU
            block = [nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1, bias=False)]
            if with_norm:
                block.append(nn.BatchNorm2d(out_ch))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        # input is (NC) x 128 x 128 -> (NDF) x 64 x 64, no batch norm here
        layers = down_block(NC, NDF, with_norm=False)
        # (NDF) x 64 x 64 -> (NDF*2) x 32 x 32 -> (NDF*4) x 16 x 16
        # -> (NDF*8) x 8 x 8 -> (NDF*16) x 4 x 4
        widths = [NDF, NDF * 2, NDF * 4, NDF * 8, NDF * 16]
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += down_block(in_ch, out_ch)
        # (NDF*16) x 4 x 4 -> 1 x 1 x 1 score squashed by a sigmoid
        layers += [
            nn.Conv2d(NDF * 16, 1, 4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score a batch of images; returns a tensor of shape (N, 1, 1, 1)."""
        return self.main(input)
    
    
# Create the Discriminator
netD = Discriminator(ngpu, D_DROPOUT).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply weights_init to every sub-module: Xavier-uniform weights for conv
# layers, N(1.0, 0.02) weights and zero bias for batch-norm layers.
netD.apply(weights_init)

# Print the model
print(netD)

In [None]:
# Log metrics with wandb
wandb.watch([netG, netD])

# Initialize BCELoss function (mean over the batch)
criterion = nn.BCELoss(reduction="mean")

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator (consumed by the store_images handler).
# NOTE(review): the batch dimension is IMAGE_SIZE, i.e. 128 latent vectors —
# this looks accidental; confirm the intended number of samples.
fixed_noise = torch.randn(IMAGE_SIZE, latent_dim, 1, 1, device=device)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=DLR, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=GLR, betas=(beta1, 0.999))


# Training bookkeeping

# Lists to keep track of progress. img_list, G_losses and D_losses are
# re-created in the handler cells further down.
img_list = []
G_losses = []
D_losses = []
# NOTE(review): `iters` is not referenced elsewhere in the visible notebook.
iters = 0


# Continue Training

In [None]:
# Echo the checkpoint root folder as the cell output.
MODEL_DIR

In [None]:
# Checkpoint file of the current W&B run; echoed as the cell output.
model_path_to_load = f"{MODEL_DIR}{wandb.run.id}/deep_arcane.model"
model_path_to_load

In [None]:
# When W&B resumed an existing run, restore the models, the optimizers and the
# epoch counters from that run's checkpoint.
if wandb.run.resumed:

    # map_location moves the saved tensors onto the current device, so a
    # checkpoint written on one device can be resumed on another.
    # NOTE(review): the checkpoint also stores the `criterion` object under
    # "loss"; PyTorch releases that default torch.load to weights_only=True may
    # refuse to unpickle it — confirm against the installed version.
    checkpoint = torch.load(model_path_to_load, map_location=device)

    netG.load_state_dict(checkpoint["generator"])
    netD.load_state_dict(checkpoint["discriminator"])
    optimizerG.load_state_dict(checkpoint["generator_optimizer"])
    optimizerD.load_state_dict(checkpoint["discriminator_optimizer"])

    # Train for the configured number of epochs on top of those already done.
    # The budget is read from experiment_settings rather than accumulated into
    # EPOCHS, so re-running this cell does not keep growing the target.
    starting_epoch = checkpoint["epoch"]
    EPOCHS = experiment_settings["EPOCHS"] + starting_epoch
    print("Resumed model")
    del checkpoint

In [None]:
# Fixed batch of `samples` latent vectors, reused every epoch so the logged
# images are comparable across the training session.
static_noise = torch.randn(samples, latent_dim, 1, 1, device=device)

## Training Step

In [None]:
 
 def train_step(engine, data):
      """Run one GAN training iteration on a batch (the Engine's process function).

      The discriminator is updated first, on a real batch and on a detached
      fake batch; the generator is then updated through the freshly updated
      discriminator.

      Args:
          engine: Engine invoking the step (not used here).
          data: Batch from the dataloader; data[0] holds the real images.

      Returns:
          dict with "Loss_G" and "Loss_D" (loss values) plus the mean
          discriminator outputs "D_x" (real batch), "D_G_z1" (fakes, before
          the D update) and "D_G_z2" (fakes, after the D update).
      """
      # Smooth labels: one target value per step, drawn from the configured ranges.
      real_label = random.uniform(real_range[0], real_range[1])
      fake_label = random.uniform(fake_range[0], fake_range[1])

      # Set the models for training
      netG.train()
      netD.train()

      ############################
      # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
      ###########################
      ## Train with all-real batch
      netD.zero_grad()
      # Format batch
      real = data[0].to(idist.device())
      b_size = real.size(0)
      label = torch.full((b_size,), real_label, dtype=torch.float, device=idist.device())
      # Forward pass real batch through D
      output1 = netD(real).view(-1)
      # Calculate loss on all-real batch
      errD_real = criterion(output1, label)
      # Calculate gradients for D in backward pass
      errD_real.backward()

      ## Train with all-fake batch
      # Generate batch of latent vectors
      noise = torch.randn(b_size, latent_dim, 1, 1, device=idist.device())
      # Generate fake image batch with G
      fake = netG(noise)
      # Reuse the label tensor in place, now holding the fake target.
      label.fill_(fake_label)
      # Classify all fake batch with D (detached: no gradient flows into G here)
      output2 = netD(fake.detach()).view(-1)
      # Calculate D's loss on the all-fake batch
      errD_fake = criterion(output2, label)
      # Calculate the gradients for this batch, accumulated (summed) with previous gradients
      errD_fake.backward()
      # Compute error of D as sum over the fake and the real batches
      errD = errD_real + errD_fake
      # Update D
      optimizerD.step()

      ############################
      # (2) Update G network: maximize log(D(G(z)))
      ###########################
      netG.zero_grad()
      label.fill_(real_label)  # fake labels are real for generator cost
      # Since we just updated D, perform another forward pass of all-fake batch through D
      output3 = netD(fake).view(-1)
      # Calculate G's loss based on this output
      errG = criterion(output3, label)
      # Calculate gradients for G
      errG.backward()
      # Update G
      optimizerG.step()
      
      return {
        "Loss_G" : errG.item(),
        "Loss_D" : errD.item(),
        "D_x": output1.mean().item(),
        "D_G_z1": output2.mean().item(),
        "D_G_z2": output3.mean().item(),
      }

  

## Train Model

In [None]:
trainer = Engine(train_step)

# If continuing a training session, restore the engine's progress.
# epoch_length is the number of iterations in one epoch, i.e. the number of
# batches in the dataloader — the same value the engine infers on a fresh run.
if wandb.run.resumed:
    trainer.load_state_dict({"epoch": starting_epoch, "max_epochs": EPOCHS, "epoch_length": len(dataloader), "rng_state": None})

In [None]:
# def initialize_fn(m):
#     classname = m.__class__.__name__
#     if classname.find('Conv') != -1:
#         nn.init.normal_(m.weight.data, 0.0, 0.02)
#     elif classname.find('BatchNorm') != -1:
#         nn.init.normal_(m.weight.data, 1.0, 0.02)
#         nn.init.constant_(m.bias.data, 0)

# @trainer.on(Events.STARTED)
# def init_weights():
#     netD.apply(initialize_fn)
#     netG.apply(initialize_fn)

In [None]:
# Per-iteration loss history of the generator and the discriminator.
G_losses, D_losses = [], []


@trainer.on(Events.ITERATION_COMPLETED)
def store_losses(engine):
    """Append the latest step's Loss_G / Loss_D to the history lists."""
    step_output = engine.state.output
    for history, key in ((G_losses, "Loss_G"), (D_losses, "Loss_D")):
        history.append(step_output[key])

In [None]:
# Generator outputs for `fixed_noise`, captured periodically during training.
img_list = []


@trainer.on(Events.ITERATION_COMPLETED(every=500))
def store_images(engine):
    """Every 500 iterations, keep the generator's output for fixed_noise (on CPU)."""
    with torch.no_grad():
        snapshot = netG(fixed_noise).cpu()
    img_list.append(snapshot)

## Log Images
Using a static distribution of noise, log the images across the training session.

In [None]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_images():
    """After each epoch, render the static-noise samples and log them to W&B."""
    # Sample in eval mode and without autograd: the images are only displayed
    # and logged, so no graph needs to be kept for them.
    netG.eval()
    try:
        with torch.no_grad():
            generated_images = netG(static_noise)
    finally:
        # Always hand the generator back in training mode.
        netG.train()

    images = display_generated_images(generated_images, samples)
    pil_images = [convert_tensor_to_image(image, IMAGE_SIZE, IMAGE_SIZE) for image in images]
    wandb.log({"example": [wandb.Image(img) for img in pil_images]})



In [None]:
from ignite.metrics import FID, InceptionScore
# Metrics computed by the evaluator. evaluation_step returns (fake, real):
# FID receives the whole pair, while the Inception Score's output_transform
# keeps only element 0, the generated images.
fid_metric = FID(device=idist.device())
is_metric = InceptionScore(device=idist.device(), output_transform=lambda x: x[0])

In [None]:
import PIL.Image as Image

def interpolate(batch):
    """Convert a batch of image tensors to 3-channel 299x299 tensors.

    Each image goes tensor -> PIL -> "RGB" -> bilinear resize to 299x299 ->
    tensor, so black-and-white images become comparable to Inception images.
    The results are stacked into a single batch tensor.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    return torch.stack([
        to_tensor(to_pil(img).convert('RGB').resize((299,299), Image.BILINEAR))
        for img in batch
    ])


def evaluation_step(engine, batch):
    """Produce one (fake, real) pair of image batches for the FID / IS metrics.

    Args:
        engine: Evaluator engine invoking the step (not used here).
        batch: Batch from the evaluation dataloader; batch[0] holds the real
            images.

    Returns:
        (fake, real): generated and real images, both passed through
        `interpolate`, with the same number of images in each.
    """
    with torch.no_grad():
        real_images = batch[0]
        # Generate exactly as many images as the real batch holds instead of
        # assuming BATCH_SIZE, so the pair stays aligned for any batch size.
        noise = torch.randn(real_images.size(0), latent_dim, 1, 1, device=idist.device())
        netG.eval()
        fake_batch = netG(noise)
        fake = interpolate(fake_batch)
        real = interpolate(real_images)
        return fake, real

In [None]:
# Evaluation engine: runs evaluation_step on every batch and accumulates both
# metrics, available afterwards as evaluator.state.metrics["fid"] / ["is"].
evaluator = Engine(evaluation_step)
fid_metric.attach(evaluator, "fid")
is_metric.attach(evaluator, "is")

In [None]:

# FID / IS history, one entry per evaluation.
fid_values = []
is_values = []

@trainer.on(Events.EPOCH_COMPLETED(every=(EPOCH_TO_EVAL)))
def log_training_results(engine):
    """Evaluate FID / IS, log the metrics, and checkpoint on a new best FID.

    Runs every EPOCH_TO_EVAL epochs. A checkpoint is written only when the
    epoch is past EPOCH_TO_START_SAVING and the current FID is the lowest
    seen so far in this session.

    Args:
        engine: The training engine; its state supplies the epoch number and
            the last step's losses.
    """
    # Get current epoch.
    epoch = engine.state.epoch

    # Evaluate model.
    evaluator.run(test_dataloader, max_epochs=1)
    metrics = evaluator.state.metrics
    fid_score = metrics['fid']
    is_score = metrics['is']
    fid_values.append(fid_score)
    is_values.append(is_score)

    # Grab the NN losses.
    loss_g = engine.state.output["Loss_G"]
    loss_d = engine.state.output["Loss_D"]

    print(f"Epoch [{epoch}/{EPOCHS}] Metric Scores")
    print(f"*    FID : {fid_score:4f}")
    print(f"*     IS : {is_score:4f}")
    print(f"* Loss_G : {loss_g:4f}")
    print(f"* Loss_D : {loss_d:4f}")

    # Log to Weights and Biases
    wandb.log({"epoch": epoch,
               "FID": fid_score,
               "IS": is_score,
               "Loss_G": loss_g,
               "Loss_D": loss_d})

    # Save only when this is the lowest FID so far. fid_score was appended
    # above, so it is the best exactly when it is not above the minimum.
    if epoch > EPOCH_TO_START_SAVING and fid_score <= min(fid_values):
        print("Saving model: Started")

        # To prevent corruption, we save a temporary model, and then move it in
        # place once we are assured it has completely saved.  This gets around
        # sessions dying during a model save.
        save_model(model_path, netG, netD, optimizerG, optimizerD, epoch, criterion, prefix="tmp_")

        # Seems like the code continues to execute before results are secure
        # in the Google drive. Let's wait a bit before trying to move it.
        sleep(15)

        run_dir = f"{MODEL_DIR}{wandb.run.id}/"
        tmp_file = f"{run_dir}tmp_deep_arcane.model"
        final_file = f"{run_dir}deep_arcane.model"

        # Only promote the temporary file if it actually made it to disk.
        if os.path.exists(tmp_file):

            # Since we're here, we assume the model saved correctly and can make
            # it the recorded model. os.replace avoids building a shell command
            # from interpolated paths.
            try:
                os.replace(tmp_file, final_file)
            except OSError as err:
                # Keep training even if the move fails; just report it.
                print(f"Saving model: could not move checkpoint into place ({err})")
                return

            if gauth.access_token_expired:
                gauth.Refresh()

            # After moving the files, we need to empty the trash to keep the
            # GDrive from filling up.
            for gfile in pdrive.ListFile({'q': "trashed = True"}).GetList():
                print(f"Deleted: {gfile.Delete()}")

            print("Saving model: Complete")


In [None]:
from ignite.metrics import RunningAverage


# Track running averages of both losses on the trainer; each one reads its
# value from the dict returned by train_step and is exposed under the same name.
RunningAverage(output_transform=lambda x: x["Loss_G"]).attach(trainer, 'Loss_G')
RunningAverage(output_transform=lambda x: x["Loss_D"]).attach(trainer, 'Loss_D')

In [None]:
from ignite.contrib.handlers import ProgressBar


# Progress bars: the trainer's bar also shows the 'Loss_G' / 'Loss_D' metrics;
# the evaluator's bar shows progress only.
ProgressBar().attach(trainer, metric_names=['Loss_G','Loss_D'])
ProgressBar().attach(evaluator)


In [None]:
def training(*_unused):
    """Entry point handed to idist.Parallel: run the trainer for EPOCHS epochs.

    Any positional arguments supplied by the launcher are accepted and ignored.
    """
    trainer.run(dataloader, max_epochs=EPOCHS)

# Launch the training function through Ignite's Parallel helper (NCCL backend).
with idist.Parallel(backend='nccl') as parallel:
    parallel.run(training)

# Save Model

In [None]:
# save_model(model_path, netG, netD, optimizerG, optimizerD, EPOCHS, criterion)

## Ignite
https://pytorch-ignite.ai/blog/gan-evaluation-with-fid-and-is/

In [None]:
# TODO:
# 1. Modify lr per generator and discriminator
# 2. Update image plotter to be neater.  I.e., image grid.
# 3. Use 
#       pip install pipreqs
#       pipreqs /path/to/project
#    To preserve the requirements.txt as a WandB object.
# 4. Modify sampler to use static input.
# 5. Pre-train with SVGs
# 6. Create data augmentation schedule.  I.e., for the first half of the training,
#    make augmentation heavy, and taper off closer to finalization.



## Good GAN Reads
https://towardsdatascience.com/why-do-gans-need-so-much-noise-1eae6c0fb177

## Great Article on Activation
https://himanshuxd.medium.com/activation-functions-sigmoid-relu-leaky-relu-and-softmax-basics-for-neural-networks-and-deep-8d9c70eed91e

## Adding FID
https://colab.research.google.com/github/pytorch-ignite/pytorch-ignite.ai/blob/gh-pages/blog/2021-08-11-GAN-evaluation-using-FID-and-IS.ipynb