In [2]:
import torch
import pytorch_lightning as pl
from data.image_folder_dataset import ImageFolderDataset
from processing.XDoG import xdog
from processing.transforms import RandomSketch, Sketch
from data.image_data_module import ImageDataModule
from networks.modules import SketchColoringModule, PaintCorrectionModule
import matplotlib.pyplot as plt
import numpy as np

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize

import os

%load_ext autoreload
%autoreload

In [3]:
# Notebook-wide configuration.
import ssl

# NOTE(review): this swaps the default HTTPS context factory for the unverified
# one, i.e. certificate verification is turned off for every urllib-based HTTPS
# request made by this process. Presumably a workaround for downloads failing on
# certificate errors — confirm it is still required before sharing the notebook.
ssl._create_default_https_context = ssl._create_unverified_context

# Root folder of the image dataset.
DATASET_PATH = 'portraits/'

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Training Setup

In [None]:
# Build the data module for the dataset folder and run its setup step eagerly
# so the data is prepared before the trainer is created.
data_module = ImageDataModule(
  data_dir=DATASET_PATH,
  batch_size=4,
  image_size=256,
  hatch_pattern_path="./processing/textures/",
)
data_module.setup()

In [14]:
# Define the network hyperparameters here. Hyperparameters can include anything and will be saved in model checkpoints.
# These hyperparameters will be available in the neural network module.

# UNet encoder and decoder dimensions for the colorizer.
# !!! The output layer and the bottleneck are not listed here; they are added automatically in the code.
# Each block is described by (in_c, out_c, normalize, p); `affine` is False everywhere.
# NOTE(review): `p` is presumably a dropout probability — confirm against the network code.
colorizer_params = dict(
  encoder_blocks=[
    dict(in_c=c_in, out_c=c_out, normalize=use_norm, affine=False, p=drop_p)
    for c_in, c_out, use_norm, drop_p in (
      (1, 32, False, 0.5),
      (32, 64, True, 0),
      (64, 128, True, 0),
      (128, 256, True, 0),
    )
  ],
  decoder_blocks=[
    dict(in_c=c_in, out_c=c_out, normalize=True, affine=False, p=0)
    for c_in, c_out in ((512, 256), (256, 128), (128, 64), (64, 32))
  ],
  activation='sigmoid',
)

# Encoder-only configuration for the style network: four blocks taking a
# 3-channel input up to 256 channels. Only the first block skips normalization.
style_params = dict(
  encoder_blocks=[
    dict(in_c=c_in, out_c=c_out, normalize=use_norm, affine=False, p=0)
    for c_in, c_out, use_norm in (
      (3, 32, False),
      (32, 64, True),
      (64, 128, True),
      (128, 256, True),
    )
  ],
)

# Encoder-only configuration for the discriminator: five normalized blocks
# taking a 4-channel input up to 512 channels.
discriminator_params = dict(
  encoder_blocks=[
    dict(in_c=c_in, out_c=c_out, normalize=True, affine=False, p=0)
    for c_in, c_out in ((4, 32), (32, 64), (64, 128), (128, 256), (256, 512))
  ],
)

# Full hyperparameter set; it is unpacked as keyword arguments when the model is built.
# NOTE(review): the grouping comments below are inferred from the key names —
# confirm the exact meaning of each entry against SketchColoringModule.
hparams = dict(
  # Sub-network architectures
  colorizer_params=colorizer_params,
  discriminator_params=discriminator_params,
  style_params=style_params,
  # Exemplar handling
  num_exemplars=1,
  exemplar_method="self",
  # GAN training schedule
  train_gan=True,
  generator_frequency=1,
  discriminator_frequency=1,
  # Presumably loss-term weights
  g=1,
  rec=0,
  perc=0,
  color=100,
  l1_beta=0.1,
  # Learning rates
  colorizer_lr=1e-4,
  min_lr=1e-3,
  max_lr=1e-2,
  discriminator_lr=4e-4,
  # Presumably optimizer betas (colorizer, then discriminator) and weight decay
  b1=0.5,
  b2=0.99,
  disc_b1=0,
  disc_b2=0,
  weight_decay=1e-5,
  # Remaining loss configuration
  perceptual_layer=[17],
  gan_loss='BCE',
  texture=10,
)

In [15]:
# Instantiate the sketch-coloring model from the hyperparameters and call
# `.float()` on it, keeping whatever that call returns.
model = SketchColoringModule(
  device=device,
  **hparams,
).float()

In [None]:
# Load the TensorBoard notebook extension and launch TensorBoard on the
# `lightning_logs` directory so visualisations can be viewed in real time during training
%load_ext tensorboard
%tensorboard --logdir lightning_logs

In [None]:
''' 
PyTorch Lightning provides a Trainer class which handles the training loop.
The trainer automatically performs validation and training steps while also
logging key metrics, which enables effective training.

The trainer takes in data modules and a model and automatically sets up the training loop


The trainer automatically handles GPU or TPU training without the need to
manually move tensors to the device.

The trainer saves model checkpoints so that best models can later be recovered

The trainer supports a wide variety of options which make model training more efficient
and faster (16-bit precision, debugging, early stopping, etc.)

Full manual can be found here: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html
'''
# Getting a warning about ambigous batch size
# The warning is nothing serious, it happens because pytorch lightning
# does not handle dictionaries as data inputs well. In reality training loop
# works properly

# Warning: https://github.com/PyTorchLightning/pytorch-lightning/issues/10349 
import warnings
warnings.filterwarnings('ignore')

# TensorBoard logger writing into `lightning_logs`; graph logging is turned off.
classifier_logger = TensorBoardLogger(
  save_dir='lightning_logs',
  log_graph=False,
)

# Keep only the single best checkpoint, i.e. the one with the lowest monitored
# "val_loss", stored as models/colorizer.
# NOTE(review): assumes the model logs a metric named "val_loss" — confirm.
checkpoint_callback = ModelCheckpoint(
  dirpath="models/",
  filename="colorizer",
  monitor="val_loss",
  mode="min",
  save_top_k=1,
)

trainer = pl.Trainer(
  # --- Output and logging ---
  default_root_dir=os.getcwd(), # directory where training results are saved and logged
  logger=classifier_logger, # logger used to track training
  callbacks=[checkpoint_callback],
  # --- Hardware ---
  gpus=1 if torch.cuda.is_available() else None, # one GPU when CUDA is available, otherwise none
  # --- Schedule ---
  max_epochs=500,
  val_check_interval=0.1, # validate 10 times per epoch, frequent validation is helpful
  # --- Gradient clipping ---
  gradient_clip_algorithm="value",
  gradient_clip_val=1.0,
  # --- Debug options ---
  # NOTE(review): an integer here selects a number of batches to overfit on
  # (a float would be a fraction of the data) — remove for a real training run.
  overfit_batches=2,
  track_grad_norm=2, # tracks gradient norms in tensorboard
)

# Start training: fit `model` using the data supplied by `data_module`.
trainer.fit(model, data_module)