In [2]:
# Set to False when running outside Colab; paths are then relative to the
# current directory.
GOOGLE_COLAB = True

if not GOOGLE_COLAB:
  path_prefix = ''
else:
  from google.colab import drive  # only importable inside Colab
  # Make Google Drive visible under /content/gdrive/.
  drive.mount('/content/gdrive/')
  # Point this at the project folder inside your own "My Drive".
  path_prefix = '/content/gdrive/My Drive/CS182_Sketch2Img/Sketch2img/final/'

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


In [0]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import PIL
from timeit import default_timer as timer
import sys

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

# On Colab, work from the project folder on Drive so that relative paths used
# later in the notebook resolve inside it.
working_dir = path_prefix if GOOGLE_COLAB else None
if working_dir is not None:
  os.chdir(working_dir)

In [0]:
from generator import Generator
from discriminator import Discriminator
from dataset import load_sketchygan_dataset
from loss import discriminator_loss, generator_loss

In [0]:
# Train parameters: don't change the class size but you can change anything else.

# -- model --
num_classes = 125        # do not change
init_in_channels = 3     # passed to Discriminator as its input channel count (presumably RGB)

# -- optimisation --
# NOTE(review): the setup cell passes a literal 8 to load_sketchygan_dataset
# rather than this value — confirm which one is intended.
batch_size = 10
max_epochs = 10

# -- logging / checkpointing (all counted in training iterations) --
write_freq = 50          # print losses every this many iterations
image_save_freq = 500    # save a real/fake sample pair every this many iterations
model_save_freq = None   # None -> save only at the end; otherwise save every this many iterations

In [17]:
# adapted from https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py
def xavier_init(model):
  """Re-initialise every 2-D parameter of `model` in place with Xavier/Glorot normal init.

  Only parameters with exactly two dimensions (e.g. ``nn.Linear`` weights) are
  touched; 1-D parameters such as biases, and higher-rank ones such as
  ``nn.Conv2d`` kernels (4-D), keep their existing initialisation.

  Args:
    model: a ``torch.nn.Module`` whose parameters are modified in place.
  """
  for param in model.parameters():
    # NOTE(review): convolution kernels are skipped by this check, as in the
    # DRAGAN original (which only has linear layers) — confirm that is intended.
    if param.dim() == 2:
      # The trailing-underscore variant is the in-place, non-deprecated API;
      # the plain ``xavier_normal`` alias only forwards here with a warning.
      torch.nn.init.xavier_normal_(param)

# Build the data loader, both networks and their optimizers.
# NOTE(review): `8` is passed positionally to load_sketchygan_dataset while
# the config cell defines batch_size = 10 — confirm what this argument means.
ds, dl = load_sketchygan_dataset(8)
discriminator = Discriminator(num_classes, init_in_channels)
generator = Generator(num_classes)

xavier_init(generator)
xavier_init(discriminator)

# Adam with PyTorch's default hyper-parameters for both networks.
opt_g = torch.optim.Adam(generator.parameters())
opt_d = torch.optim.Adam(discriminator.parameters())


In [0]:
## NOTE: Edit these two parameters so saved files get reasonable names, or set
##       path_of_run_to_load to a saved checkpoint file to resume from it.
run_name = 'test_run'
path_of_run_to_load = None

models_dir = os.path.join('saved_models', run_name)
images_dir = os.path.join('saved_images', run_name)
os.makedirs(models_dir, exist_ok=True)
os.makedirs(images_dir, exist_ok=True)

if path_of_run_to_load:
  if not os.path.exists(path_of_run_to_load):
    raise ValueError("Path of run to load is invalid.")
  # map_location='cpu' lets a checkpoint written on a GPU be read on a
  # CPU-only machine; load_state_dict copies the tensors into the modules'
  # existing parameters, wherever those live.
  checkpoint = torch.load(path_of_run_to_load, map_location='cpu')
  discriminator.load_state_dict(checkpoint['discriminator'])
  generator.load_state_dict(checkpoint['generator'])
  # Optimizer state is optional so checkpoints that only hold the models
  # still load; when present, Adam resumes with its moment estimates intact.
  if 'opt_d' in checkpoint:
    opt_d.load_state_dict(checkpoint['opt_d'])
  if 'opt_g' in checkpoint:
    opt_g.load_state_dict(checkpoint['opt_g'])
  count = checkpoint['count']
  # The modules are not switched to eval mode here: restoring `count` means
  # the checkpoint is being loaded to resume training.
else:
  count = 0

In [0]:
def _save_first_sample(images, path, size=64):
  """Save the top-left ``size`` x ``size`` crop of ``images[0]`` as an image file.

  ``images`` is indexed as a (batch, channels, H, W) batch of images.
  NOTE(review): save_image expects values in [0, 1]; pass normalize=True if
  the images are in another range (e.g. [-1, 1] from a tanh output).
  """
  sample = images[0, :, :size, :size].detach().cpu().unsqueeze(0)
  torchvision.utils.save_image(sample, path)


# A checkpoint load or an evaluation cell may have left the networks in eval
# mode (frozen batch-norm statistics, dropout disabled); train in train mode.
generator.train()
discriminator.train()

for epoch in range(max_epochs):
  for batch_idx, (real_images, sketches, class_labels) in enumerate(dl):
    # Update discriminator.
    # The fakes only feed the discriminator step, so no generator graph is
    # needed: no_grad is the cheap equivalent of detaching both outputs.
    # (`noise`, the generator's second output, is not used in this cell.)
    discriminator.zero_grad()
    with torch.no_grad():
      fake_images, noise = generator(class_labels, sketches)
    loss_d = discriminator_loss(discriminator,
                                real_images,
                                fake_images,
                                class_labels)
    loss_d.backward()
    opt_d.step()

    # Update generator.
    # NOTE(review): generator_loss's signature is not visible here; these
    # arguments (the generator itself and the real images, not the sketches)
    # should be confirmed against loss.py. Gradients this leaves on the
    # discriminator are cleared by discriminator.zero_grad() next iteration.
    generator.zero_grad()
    loss_g = generator_loss(discriminator,
                            generator,
                            real_images,
                            class_labels)
    loss_g.backward()
    opt_g.step()

    # Print out progress periodically.
    if count % write_freq == 0:
      template = 'Epoch [%d/%d] Batch [%d/%d]:\n\tDiscriminator Loss = %.4f, \n\t Generator Loss = %.4f'
      status = template % (epoch, max_epochs, batch_idx, len(dl),
                           loss_d.item(), loss_g.item())
      print(status)

    # Save a real and a fake sample periodically.
    if count % image_save_freq == 0:
      _save_first_sample(
          real_images,
          os.path.join(images_dir, 'real_sample_e%d_b%d.png' % (epoch, batch_idx)))
      _save_first_sample(
          fake_images,
          os.path.join(images_dir, 'fake_sample_e%d_b%d.png' % (epoch, batch_idx)))

    # Save a checkpoint every model_save_freq iterations (disabled when None).
    if model_save_freq and count % model_save_freq == 0:
      model_path = os.path.join(models_dir, '{}_c{}'.format(run_name, count))
      torch.save({
          'discriminator': discriminator.state_dict(),
          'generator': generator.state_dict(),
          'opt_d': opt_d.state_dict(),
          'opt_g': opt_g.state_dict(),
          'count': count,
          }, model_path)

    count += 1

In [0]:
# Save model afterwards if you wish by running this cell.
# The checkpoint holds both networks, both optimizers and the iteration
# counter, so training can be resumed from it.
os.makedirs(models_dir, exist_ok=True)
model_path = os.path.join(models_dir, '{}_final'.format(run_name))
torch.save({
    'discriminator': discriminator.state_dict(),
    'generator': generator.state_dict(),
    'opt_d': opt_d.state_dict(),
    'opt_g': opt_g.state_dict(),
    'count': count,
    }, model_path)