In [12]:
## Imports

import torch
import torchvision
from PIL import Image, ImageMath
import glob
import cv2
from torch.utils.data.dataset import Dataset
import random
import numpy as np
import matplotlib.pyplot as plt 
%matplotlib inline
import torch.nn as nn
import torch.nn.functional as F

In [16]:
# Runtime configuration.
# useTanh selects the optional-credit-1 model (tanh in the output layer).
# It defaults to False, which evaluates the plain colorization network
# (no tanh) instead.
useTanh = False

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [14]:
# Network for main assignment
class ColorNetwork(nn.Module):
  """Encoder-decoder CNN mapping a 1-channel L image to 2 ab channels.

  Encoder: five 3x3 stride-2 convolutions (1 -> 16 -> 32 -> 64 -> 128 -> 256
  channels), each followed by BatchNorm2d and LeakyReLU.
  Decoder: four 3x3 stride-2 transposed convolutions (256 -> 128 -> 64 -> 32
  -> 16), each followed by BatchNorm2d and LeakyReLU, then a final transposed
  convolution to 2 channels with output_padding=1.

  For a 128x128 input the spatial size goes 128 -> 63 -> 31 -> 15 -> 7 -> 3
  and back 3 -> 7 -> 15 -> 31 -> 63 -> 128.
  """
  def __init__(self):
    super(ColorNetwork, self).__init__()

    # Modules are created and registered in the same order as a flat
    # nn.Sequential listing, so the state_dict keys (layers.0 ... layers.27)
    # match previously saved checkpoints.
    stack = []

    down_widths = [1, 16, 32, 64, 128, 256]
    for width_in, width_out in zip(down_widths[:-1], down_widths[1:]):
      stack += [
          nn.Conv2d(width_in, width_out, kernel_size=3, stride=2),
          nn.BatchNorm2d(width_out),
          nn.LeakyReLU(),
      ]

    # Deconvolution
    up_widths = [256, 128, 64, 32, 16]
    for width_in, width_out in zip(up_widths[:-1], up_widths[1:]):
      stack += [
          nn.ConvTranspose2d(width_in, width_out, kernel_size=3, stride=2),
          nn.BatchNorm2d(width_out),
          nn.LeakyReLU(),
      ]

    # output_padding=1 turns 63 -> 127 into 63 -> 128 on the last upsample.
    stack.append(nn.ConvTranspose2d(16, 2, kernel_size=3, stride=2, output_padding=1))

    self.layers = nn.Sequential(*stack)

  def forward(self, inputVal):
    """Map an (N, 1, H, W) L batch to (N, 2, H', W') ab predictions."""
    return self.layers(inputVal)

In [None]:
# Load the trained model that will be used for testing.
model = ColorNetwork()
if useTanh:
    # The optional-credit model uses tanh in its output layer (see the config
    # cell), but ColorNetwork has none. Append it so predictions are bounded
    # to [-1, 1] like the /128-scaled targets. nn.Tanh has no parameters, so
    # the checkpoint's state_dict keys are unchanged and it still loads strictly.
    model.layers = nn.Sequential(*model.layers, nn.Tanh())
checkpoint_path = "Tanh.pt" if useTanh else "Network3.pt"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
criterion = nn.MSELoss()
model.eval()

In [None]:
# Utility function to convert RGB to lab image
# Post conversion, returns normalised L, a & b channels

def convertRGBToLAB(img):
  """Split an HxWx3 image into normalised (L, a, b) channel arrays.

  In this notebook `img` is a float32 array in [0, 1] produced by
  ToTensor + permute in ColorDataset.

  Returns:
    (L / 100, a, b) as float32 HxW arrays. When useTanh is set, a and b are
    also divided by 128 (optional credit). Otherwise they keep OpenCV's raw
    float Lab values.

  NOTE(review): despite the name, this uses COLOR_BGR2LAB, while the callers
  here pass RGB arrays (PIL -> ToTensor), so R and B are effectively swapped.
  DisplayFinalImage converts back with COLOR_LAB2BGR, so the round trip is
  consistent. The saved checkpoints were presumably trained on this
  convention, so confirm before "fixing" it.
  """
  lab_img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB).astype("float32")
  lightness, chan_a, chan_b = cv2.split(lab_img)
  lightness = lightness / 100
  # Optional credit (OC): scale ab to roughly [-1, 1] for a tanh output layer.
  if useTanh:
    chan_a, chan_b = chan_a / 128, chan_b / 128
  return lightness, chan_a, chan_b

In [None]:
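# Custom Dataset that converts every image matched by a glob pattern into normalised (L, a, b) channels.
# For training, a random transform (scale, horizontal flip, crop) is passed together with `ran`,
# and each image is augmented 10 times. The same class is reused for the test loader with one pass per image.
# Resizing to 128x128 is done by the transform that is passed in, not by the class itself.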

class ColorDataset(Dataset):
  """In-memory dataset of (L, a, b) channel tuples built from image files.

  Args:
    folder_path: glob pattern matching the image files (e.g. "dir/*").
    transform: callable applied to each PIL RGB image. It must return a CHW
      tensor, e.g. a torchvision Compose ending in ToTensor. If None, the
      image is only converted with ToTensor, instead of silently producing
      an empty dataset.
    ran: when not None, each image is passed through `transform`
      `n_augment` times (the transform is presumably random). Otherwise one
      sample is produced per image.
    n_augment: copies per image when `ran` is set (default 10, the original
      hard-coded value).

  All samples are converted with convertRGBToLAB at construction time, so
  __getitem__ is a plain list lookup.
  """
  def __init__(self, folder_path, transform=None, ran=None, n_augment=10):
    # glob order is filesystem dependent. Sorting makes the sample order
    # (and therefore the test batches that get plotted) reproducible.
    self.image_list = sorted(glob.glob(folder_path))
    self.transform = transform
    to_tensor = transform if transform is not None else torchvision.transforms.ToTensor()
    copies = n_augment if ran is not None else 1
    self.T = []
    for path in self.image_list:
      # convert("RGB") forces the pixel data to load, so the file handle can be
      # closed here. It also normalises grayscale/RGBA/palette images to the
      # 3 channels cvtColor expects.
      with Image.open(path) as handle:
        single_img = handle.convert("RGB")
      for _ in range(copies):
        tran = to_tensor(single_img)
        # CHW tensor -> HWC array for OpenCV.
        self.T.append(convertRGBToLAB(np.asarray(tran.permute(1, 2, 0))))

    self.data_len = len(self.T)

  def __getitem__(self, index):
    """Return the precomputed (L, a, b) tuple at `index`."""
    return self.T[index]

  def __len__(self):
    """Number of precomputed samples (images x copies)."""
    return self.data_len

In [None]:
# Load the test data and resize every image to exactly 128x128, the size
# that ColorNetwork's conv/deconv stack maps back onto itself.
TEST_IMAGE_SIZE = (128, 128)

test_transforms = torchvision.transforms.Compose([
    # An (h, w) tuple fixes both sides. A bare int would only scale the shorter
    # edge and keep the aspect ratio. Non-square images would then fail to batch,
    # and the network's output would no longer match the input size.
    torchvision.transforms.Resize(TEST_IMAGE_SIZE),
    torchvision.transforms.ToTensor()
])
test_dir = "blue_cis6930/nghosh/Test/*"
dataset_test = ColorDataset(test_dir, test_transforms)

# shuffle=False (explicit) keeps batches in file order so plotted results are reproducible.
test_loader = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=10, shuffle=False)


In [None]:
# Display Test predicted colored image after merging ab with rescaled L
# Plot grayscale L, predicted colored and original RGB from left to right

from skimage import color

def DisplayFinalImage(testgray, predValues, targetVals):
  """Plot grayscale input, predicted colorization and ground truth side by side.

  Args:
    testgray: batch of normalised L channels (L / 100), shaped (N, 1, H, W)
      by the evaluation loop.
    predValues: predicted ab channels, (N, 2, H, W).
    targetVals: ground-truth ab channels, (N, 2, H, W).

  Shows one figure with three panels for every image in the batch.
  """
  def _to_lab_batch(ab):
    # Re-attach L to the ab channels and move channels last, giving
    # (N, H, W, 3). Make it contiguous so each image can go straight to OpenCV.
    lab = torch.cat((testgray, ab), dim=1).detach().cpu().numpy()
    lab = np.ascontiguousarray(lab.transpose(0, 2, 3, 1))
    # Undo the normalisation applied in convertRGBToLAB.
    lab[..., 0:1] *= 100
    if useTanh:
      lab[..., 1:3] *= 128
    return lab

  lab_pred = _to_lab_batch(predValues)
  lab_tar = _to_lab_batch(targetVals)
  gray = testgray.detach().cpu().numpy()

  # Iterate over every image. The previous `range(shape[0] - 1)` silently
  # skipped the last image of each batch.
  for i in range(lab_pred.shape[0]):
    # LAB2BGR mirrors the BGR2LAB used when the data was built, so this
    # restores the original channel order for imshow. Clip out-of-gamut
    # predictions into imshow's valid float range [0, 1].
    pred_rgb = np.clip(cv2.cvtColor(lab_pred[i], cv2.COLOR_LAB2BGR), 0.0, 1.0)
    tar_rgb = np.clip(cv2.cvtColor(lab_tar[i], cv2.COLOR_LAB2BGR), 0.0, 1.0)
    panels = (
        (np.squeeze(gray[i]), "Grayscale(L)", "gray"),
        (pred_rgb, "Predicted Colorization", None),
        (tar_rgb, "Original image", None),
    )
    fig, axes = plt.subplots(1, 3)
    for ax, (image, title, cmap) in zip(axes, panels):
      ax.imshow(image, cmap=cmap)
      ax.set_title(title)
      # Hide ticks on every panel (the old plt.axis('off') ran before any
      # subplot existed and did not affect these axes).
      ax.axis("off")
    plt.show()
    plt.close(fig)

In [None]:
# Function to test the colorization network

def validate(grayImage, targetVals_test, criterion, model):
  """Run the model on one test batch, plot the results and return the batch loss.

  Args:
    grayImage: batch of normalised L channels, (N, 1, H, W), on the model's device.
    targetVals_test: ground-truth ab channels, (N, 2, H, W), on the same device.
    criterion: loss comparing predictions with targets (nn.MSELoss here).
    model: the colorization network.

  Returns:
    Whatever `criterion` returns, computed without an autograd graph. For
    nn.MSELoss with default reduction this is a 0-d tensor holding the batch mean.
  """
  model.eval()
  # Inference only. Without no_grad every returned loss would carry its
  # autograd graph, and the caller's running sum would keep all of them
  # alive for the whole test set.
  with torch.no_grad():
    predValues = model(grayImage)
    loss_test = criterion(predValues, targetVals_test)
  DisplayFinalImage(grayImage, predValues, targetVals_test)
  return loss_test

In [None]:
# For each batch of test images, run it through the trained model to get the
# predicted ab channels, then report the average loss over all test images.

total_loss = 0.0
n_images = 0
for test_img, a_test, b_test in test_loader:
  # Stack a and b into an (N, 2, H, W) target and give L a channel axis.
  targetValues = torch.stack((a_test, b_test), dim=1)
  test_img = torch.unsqueeze(test_img, 1)
  loss = validate(test_img.to(device), targetValues.to(device), criterion, model)
  # The criterion averages within a batch. Weighting by batch size gives a
  # true per-image mean even when the last batch is smaller. .item() keeps
  # the running total a plain float.
  batch_size = test_img.size(0)
  total_loss += loss.item() * batch_size
  n_images += batch_size
# Guard against an empty test folder instead of dividing by zero.
avg_loss = total_loss / n_images if n_images else float("nan")
print("Average loss incurred with Test data is ", avg_loss)