In [1]:
# Library imports
import random

import torch
from torch import nn, optim
from torchvision import datasets
from torchvision import transforms
from torchvision.models.resnet import resnet18
from torch.utils.data import DataLoader

from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet

from ignite.engine import Events, Engine, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, PSNR, SSIM, FID

from skimage.color import rgb2lab, lab2rgb
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Hyperparameters
NET_IMG_SIZE = 256
LEARNING_RATE = 1e-3
BATCH_SIZE = 16

In [3]:
def rgbfromlab(L, ab):
    L = (L + 1.) * 50.
    ab = ab * 110.
    
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)

In [4]:
class VOCColorization(datasets.VOCDetection):
    def __init__(
        self, 
        root = 'data', 
        year = '2012', 
        image_set = 'train', 
        download = True, 
        transform = None, 
        target_transform = None, 
        transforms = None):

        super().__init__(root, year=year, image_set=image_set, download=download, transform=transform, target_transform=target_transform, transforms=transforms)

    def __getitem__(self, index):
        # For now we can discard the annotation/label, we can modify this method later should we need it
        # Note that the variable length of the annotations causes problems with the dataloader when retrieving
        # a batch
        target, label = super().__getitem__(index)

        lab_image = self.preprocess_image(target) # target is the original PIL RGB Image

        return lab_image, transforms.ToTensor()(target) # Convert PIL Image to Tensor to use a dataloader

    """
    Takes a PIL Image in RGB mode and transfers it to CIELab color space.

    """
    def preprocess_image(self, img):
        #resize_transform = transforms.Resize((NET_IMG_SIZE, NET_IMG_SIZE))

        #img = resize_transform(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32") # Convert RGB to Lab color space
        img_lab = transforms.ToTensor()(img_lab)

        # Adjust all channels to range [-1,1]
        img_lab[[0], ...] = img_lab[[0], ...] / 50. - 1. # L
        img_lab[[1,2], ...] = img_lab[[1,2], ...] / 110. # ab

        return img_lab



In [5]:
# Download the Pascal VOC2012 datasets
# For now, we'll use the 'train' image subset as training data and 'val' as the testing set.
training_data = VOCColorization(
    'data', 
    year='2012', 
    image_set='train',
    transform=transforms.Resize((NET_IMG_SIZE, NET_IMG_SIZE)), 
    download=True)

test_data = VOCColorization(
    'data', 
    year='2012', 
    image_set='val',
    transform=transforms.Resize((NET_IMG_SIZE, NET_IMG_SIZE)),
    download=True)

print(f'Training dataset size = {len(training_data)}')
print(f'Testing dataset size = {len(test_data)}')

Using downloaded and verified file: data\VOCtrainval_11-May-2012.tar
Extracting data\VOCtrainval_11-May-2012.tar to data
Using downloaded and verified file: data\VOCtrainval_11-May-2012.tar
Extracting data\VOCtrainval_11-May-2012.tar to data
Training dataset size = 5717
Testing dataset size = 5823


In [None]:
# Sample code to visualize random Colorization dataset images
lab_img, rgb_img = training_data[random.randint(0, len(training_data))]
print(f'lab_img = {lab_img.shape}\nrgb_img={rgb_img.shape}')

# Slice off L and ab channels
L = lab_img[[0], ...]
ab = lab_img[[1,2], ...]

# Convert from 1xHxW array to HxW so we can display it with PyPlot
new_L = L[0, :, :]
print(new_L)

# Display our images using pyplot
rows, cols = 1, 2
fig = plt.figure(figsize=(12,12))

fig.add_subplot(rows, cols, 1)
plt.title(f'Original (Resized to {NET_IMG_SIZE}x{NET_IMG_SIZE})')
plt.axis("off")
plt.imshow(transforms.ToPILImage()(rgb_img))

fig.add_subplot(rows, cols, 2)
plt.title('L* Channel')
plt.axis("off")
plt.imshow(new_L, cmap='gray')

plt.show()

In [6]:
# Create dataloaders for our datasets
training_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
#lab_img, rgb_img = next(iter(training_dataloader))

#print(f'lab_img.shape = {lab_img.shape}\nrgb_img.shape = {rgb_img.shape}')
#print(f'lab_img = {lab_img}\n')
#print(f'rgb_img = {rgb_img}')

In [7]:
body = create_body(resnet18, pretrained=True, n_in=1, cut=-2)
model = DynamicUnet(body, n_out=2, img_size=(NET_IMG_SIZE, NET_IMG_SIZE))
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f'Using device: {device}')

print(model)
model.to(device)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Using device: cuda
DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchN

DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05

In [None]:
# Test to ensure that our model accepts inputs and returns outputs of the correct shape
lab_img, rgb_img = next(iter(training_dataloader))

L_channel = lab_img[:,[0], ...]
ab_channels = lab_img[:,[1,2], ...]

print(f'L_channel.shape = {L_channel.shape}\n')
print(f'ab_channel.shape = {ab_channels.shape}\n')


with torch.no_grad():
    L_channel = L_channel.to(device)

    ab_hat = model(L_channel)

    print(f'ab_hat = {ab_hat.shape}')

    colorized_img = rgbfromlab(L_channel, ab_hat)

    #pil_img = transforms.ToPILImage()(colorized_img[0])

    plt.imshow(colorized_img[0])


In [8]:
# Train a Unet (ResNet18 backbone) with L2 Loss
# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.MSELoss()

# Define training loop step and training Engine
def train_step(engine, batch):
    model.train()
    optimizer.zero_grad()
    lab_img, rgb_img = batch[0].to(device), batch[1]

    L_channel = lab_img[:,[0], ...]
    ab_channels = lab_img[:,[1,2], ...]

    ab_prediction = model(L_channel)

    #print(f'L_channel.shape = {L_channel.shape}\tab_channels.shape = {ab_channels.shape}\tab_pred.shape = {ab_prediction.shape}\n')

    loss = loss_func(ab_prediction, ab_channels)
    loss.backward()

    optimizer.step()
    
    return loss.item()

trainer = Engine(train_step)

# Define validation loop step and validation Engine
def validation_step(engine, batch):
    model.eval()
    with torch.no_grad():
        lab_img, rgb_img = batch[0].to(device), batch[1]

        L_channel = lab_img[:,[0], ...]
        ab_channels = lab_img[:,[1,2], ...]

        ab_prediction = model(L_channel)

        return ab_prediction, ab_channels

evaluator = Engine(validation_step)

# Define and attach metrics to engines
#accuracy = Accuracy()
#l2_loss = Loss(loss_func)

#accuracy.attach(trainer, "accuracy")
#l2_loss.attach(trainer, "l2_loss")

#accuracy.attach(evaluator, "accuracy")
#l2_loss.attach(evaluator, "l2_loss")

# Add event handlers to trainer engine
@trainer.on(Events.ITERATION_COMPLETED(every=10))
def log_training_loss(engine):
    print(f"Epoch[{engine.state.epoch}] Iter[{engine.state.iteration}] Loss: {engine.state.output:.5f}")

#@trainer.on(Events.EPOCH_COMPLETED)
#def log_training_results(engine):
 #   engine.run(training_dataloader)
  #  metrics = engine.state.metrics
   # print(f"Training Results - Epoch: {engine.state.epoch} Avg loss: {metrics['l2_loss']:.2f}")

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(test_dataloader)
    metrics = evaluator.state.metrics
    print(f"Validation Results - Epoch: {evaluator.state.epoch} Avg loss: {evaluator.state.output:.5f}")

In [None]:
trainer.run(training_dataloader, max_epochs=3)