In [None]:
!pip install fastai --upgrade
!pip install pytorch-ignite --upgrade
!pip install scipy --upgrade
!pip install scikit-image --upgrade

In [None]:
# Run to mount Google Drive to Colab instance
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Sep  8 19:38:18 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Library imports
import os
import random
import warnings
import datetime

import torch
from torch import nn, optim
from torchvision import datasets
from torchvision import transforms
from torchvision.models.resnet import resnet18
from torch.utils.data import DataLoader

from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet

from ignite.engine import Events, Engine, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, PSNR, SSIM, FID

from skimage.color import rgb2lab, lab2rgb
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Hyperparameters
NET_IMG_SIZE = 256
LEARNING_RATE = 1e-3
BATCH_SIZE = 25

# Settings
CHECKPOINT_PATH = 'drive/MyDrive/DeepPaint/Checkpoints'
METRICS_LOG_PATH = 'drive/MyDrive/DeepPaint/Metrics'

In [None]:
# Auxiliary functions
def rgbfromlab(L, ab):
    L = (L + 1.) * 50.
    ab = ab * 110.
    
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)

    
    return torch.from_numpy(np.stack(rgb_imgs, axis=0)).permute(0,3,1,2)

def save_checkpoint(model, optimizer, epoch, model_name, path):
    print('Creating checkpoint...')
    
    timestamp_str = get_current_timestamp()
    filename = f'{model_name}_E{epoch}_{timestamp_str}.pth'

    fullpath = os.path.join(path, filename)

    torch.save({
       'epoch': epoch,
       'model_state_dict': model.state_dict(),
       'optimizer_state_dict': optimizer.state_dict()
    }, fullpath)

    print(f'Done! Created checkpoint file: {fullpath}\n')

def log_metrics_to_csv(epoch, iter, metrics, header, filename, path, ext = '.csv', verbose = False):
    if verbose: print('Saving metrics...')

    fullpath = os.path.join(path, filename + ext)

    line = np.insert(metrics, 0, [epoch, iter])

    if not os.path.exists(fullpath):
        f = open(fullpath, mode='w')
        f.write(f'{header}\n')
        f.close()

    with open(fullpath, 'ab') as f:
        np.savetxt(f, [line], fmt='%7.5f', delimiter=',')

    if verbose: print(f'Done! Saved to {fullpath}')

# TODO: find out how to always print timestamp for UTC-6
def get_current_timestamp():
    return datetime.datetime.now().strftime("%d-%m-%Y_%H-%M-%S")


In [None]:
class VOCColorization(datasets.VOCDetection):
    def __init__(
        self, 
        root = 'data', 
        year = '2012', 
        image_set = 'train', 
        download = True, 
        transform = None, 
        target_transform = None, 
        transforms = None):

        super().__init__(root, year=year, image_set=image_set, download=download, transform=transform, target_transform=target_transform, transforms=transforms)

    def __getitem__(self, index):
        # For now we can discard the annotation/label, we can modify this method later should we need it
        # Note that the variable length of the annotations causes problems with the dataloader when retrieving
        # a batch
        target, label = super().__getitem__(index)

        lab_image = self.preprocess_image(target) # target is the original PIL RGB Image

        return lab_image, transforms.ToTensor()(target) # Convert PIL Image to Tensor to use a dataloader

    """
    Takes a PIL Image in RGB mode and transfers it to CIELab color space.

    """
    def preprocess_image(self, img):
        #resize_transform = transforms.Resize((NET_IMG_SIZE, NET_IMG_SIZE))

        #img = resize_transform(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32") # Convert RGB to Lab color space
        img_lab = transforms.ToTensor()(img_lab)

        # Adjust all channels to range [-1,1]
        img_lab[[0], ...] = img_lab[[0], ...] / 50. - 1. # L
        img_lab[[1,2], ...] = img_lab[[1,2], ...] / 110. # ab

        return img_lab



In [None]:
# Download the Pascal VOC2012 datasets
# For now, we'll use the 'train' image subset as training data and 'val' as the testing set.
training_data = VOCColorization(
    'data', 
    year='2012', 
    image_set='train',
    transform=transforms.Resize((NET_IMG_SIZE, NET_IMG_SIZE)), 
    download=True)

test_data = VOCColorization(
    'data', 
    year='2012', 
    image_set='val',
    transform=transforms.Resize((NET_IMG_SIZE, NET_IMG_SIZE)),
    download=True)

print(f'Training dataset size = {len(training_data)}')
print(f'Testing dataset size = {len(test_data)}')

Downloading http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar to data/VOCtrainval_11-May-2012.tar


  0%|          | 0/1999639040 [00:00<?, ?it/s]

Extracting data/VOCtrainval_11-May-2012.tar to data
Using downloaded and verified file: data/VOCtrainval_11-May-2012.tar
Extracting data/VOCtrainval_11-May-2012.tar to data
Training dataset size = 5717
Testing dataset size = 5823


In [None]:
# Sample code to visualize random Colorization dataset images
lab_img, rgb_img = training_data[random.randint(0, len(training_data))]
print(f'lab_img = {lab_img.shape}\nrgb_img={rgb_img.shape}')

# Slice off L and ab channels
L = lab_img[[0], ...]
ab = lab_img[[1,2], ...]

# Convert from 1xHxW array to HxW so we can display it with PyPlot
new_L = L[0, :, :]
print(new_L)

# Display our images using pyplot
rows, cols = 1, 2
fig = plt.figure(figsize=(12,12))

fig.add_subplot(rows, cols, 1)
plt.title(f'Original (Resized to {NET_IMG_SIZE}x{NET_IMG_SIZE})')
plt.axis("off")
plt.imshow(transforms.ToPILImage()(rgb_img))

fig.add_subplot(rows, cols, 2)
plt.title('L* Channel')
plt.axis("off")
plt.imshow(new_L, cmap='gray')

plt.show()

In [None]:
# Create dataloaders for our datasets
training_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
#lab_img, rgb_img = next(iter(training_dataloader))

#print(f'lab_img.shape = {lab_img.shape}\nrgb_img.shape = {rgb_img.shape}')
#print(f'lab_img = {lab_img}\n')
#print(f'rgb_img = {rgb_img}')

In [None]:
body = create_body(resnet18, pretrained=True, n_in=1, cut=-2)
model = DynamicUnet(body, n_out=2, img_size=(NET_IMG_SIZE, NET_IMG_SIZE))
model_name = 'dyunet' # Used for checkpointing
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f'Using device: {device}')

#print(model)
model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Using device: cuda


DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05

In [None]:
# Test to ensure that our model accepts inputs and returns outputs of the correct shape
lab_img, rgb_img = next(iter(training_dataloader))

L_channel = lab_img[:,[0], ...]
ab_channels = lab_img[:,[1,2], ...]

print(f'L_channel.shape = {L_channel.shape}\n')
print(f'ab_channel.shape = {ab_channels.shape}\n')


with torch.no_grad():
    L_channel = L_channel.to(device)

    ab_hat = model(L_channel)

    print(f'ab_hat = {ab_hat.shape}')

    colorized_img = rgbfromlab(L_channel, ab_hat)

    #pil_img = transforms.ToPILImage()(colorized_img[0])

    plt.imshow(colorized_img[0])


In [None]:
# Create some smaller datasets and loaders to test the following code
smaller_test_ds = [training_data[i] for i in range(100)]
smaller_test_dl = DataLoader(smaller_test_ds, batch_size=BATCH_SIZE, shuffle=True)

smaller_val_ds = [test_data[i] for i in range(100)]
smaller_val_dl = DataLoader(smaller_val_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Train a Unet (ResNet18 backbone) with L2 Loss
# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.MSELoss()

# Define training loop step and training Engine
def train_step(engine, batch):
    model.train()
    optimizer.zero_grad()
    lab_img, rgb_img = batch[0].to(device), batch[1]

    L_channel = lab_img[:,[0], ...]
    ab_channels = lab_img[:,[1,2], ...]

    ab_prediction = model(L_channel)

    #print(f'L_channel.shape = {L_channel.shape}\tab_channels.shape = {ab_channels.shape}\tab_pred.shape = {ab_prediction.shape}\n')

    loss = loss_func(ab_prediction, ab_channels)
    loss.backward()

    optimizer.step()
    
    return loss.item()

trainer = Engine(train_step)

# Define validation loop step and validation Engine
def validation_step(engine, batch):
    model.eval()
    with torch.no_grad():
        lab_img, rgb_img = batch[0].to(device), batch[1]

        L_channel = lab_img[:,[0], ...]
        ab_channels = lab_img[:,[1,2], ...]

        ab_prediction = model(L_channel)

        # Disable skimage warnings when converting lab -> rgb due to out-of-range
        # values (this is expected)
        with warnings.catch_warnings():
          warnings.simplefilter('ignore')
          rgb_pred = rgbfromlab(L_channel, ab_prediction)

        return rgb_pred, rgb_img

evaluator = Engine(validation_step)

# Define and attach metrics to engines
l2_loss = Loss(loss_func)
l2_loss.attach(evaluator, 'l2_loss')

psnr = PSNR(data_range=255.0)
psnr.attach(evaluator, 'psnr')

ssim = SSIM(data_range=255.0)
ssim = ssim.attach(evaluator, 'ssim')

fid = FID()
fid.attach(evaluator, 'fid')

# Add event handlers to trainer engine
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    epoch = engine.state.Epoch
    iter = engine.state.iteration
    loss = engine.state.output

    metrics = np.array([loss])

    print(f"Epoch[{epoch}] Iter[{iter}] Loss: {loss:.5f}")
    log_metrics_to_csv(epoch, iter, metrics, header='epoch,iter,loss', filename='training_loss', path=METRICS_LOG_PATH)

@trainer.on(Events.EPOCH_COMPLETED)
def checkpoint(engine):
    save_checkpoint(model, optimizer, engine.state.epoch, model_name, CHECKPOINT_PATH)

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(smaller_val_dl)
    metrics = evaluator.state.metrics
    print(f"Validation Results - Epoch: {evaluator.state.epoch} Avg loss: {metrics['l2_loss']:.5f} PSNR: {metrics['psnr']:.5f} SSIM: {metrics['ssim']:.5f} FID: {metrics['fid']:.5f}")


Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

In [None]:
trainer.run(smaller_test_dl, max_epochs=3)

Epoch[1] Iter[1] Loss: 0.30342
Epoch[1] Iter[2] Loss: 16636.75781
Epoch[1] Iter[3] Loss: 0.46851
Epoch[1] Iter[4] Loss: 0.18212
Creating checkpoint...
Done! Created checkpoint file: drive/MyDrive/DeepPaint/Checkpoints/dyunet_E1_08-09-2021_16-22-38.pth

Validation Results - Epoch: 1 Avg loss: 0.01965 PSNR: 66.01184 SSIM: 0.99728 FID: 0.21401
Epoch[2] Iter[5] Loss: 0.11562
Epoch[2] Iter[6] Loss: 0.06530
Epoch[2] Iter[7] Loss: 0.04114
Epoch[2] Iter[8] Loss: 0.07822
Creating checkpoint...
Done! Created checkpoint file: drive/MyDrive/DeepPaint/Checkpoints/dyunet_E2_08-09-2021_16-23-43.pth

Validation Results - Epoch: 1 Avg loss: 0.00954 PSNR: 70.60732 SSIM: 0.99872 FID: 0.16551
Epoch[3] Iter[9] Loss: 0.02806
Epoch[3] Iter[10] Loss: 0.05620
Epoch[3] Iter[11] Loss: 0.04892
Epoch[3] Iter[12] Loss: 0.02630
Creating checkpoint...
Done! Created checkpoint file: drive/MyDrive/DeepPaint/Checkpoints/dyunet_E3_08-09-2021_16-24-47.pth

Validation Results - Epoch: 1 Avg loss: 0.00955 PSNR: 70.46869 SSI

State:
	iteration: 12
	epoch: 3
	epoch_length: 4
	max_epochs: 3
	output: 0.02630404382944107
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [None]:
# Clear GPU cache (useful after a crash during training)
torch.cuda.empty_cache()