In [None]:
import os
# Change directory to the package folder
# The first entry of IPython's directory history (`_dh`) is the directory the kernel
# started in; chdir back to it so the relative 'Data/...' paths used below resolve
# the same way on every re-run of this cell.
dir_history = globals()['_dh']
wd = dir_history[0]
os.chdir(wd)

/content/drive/MyDrive/Project_CS7643
Data	  Main_Final.ipynb  prepare-data.ipynb
__MACOSX  output	    project-proposal-black_white.pdf


In [None]:
import numpy as np
import cv2
import tqdm
import re
import matplotlib.pyplot as plt
import time
import math
import random

# Pytorch package
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader

In [None]:
#!unzip '/content/drive/MyDrive/Project_CS7643/Data/places-data.zip'

# Colorization part

We experiment with **three models**. The first is the baseline, which replicates Baldassarre et al. (2017) with Inception V3 as the feature extractor. The second uses ConvNeXt-Tiny as the feature extractor. The third keeps the baseline (Inception V3) feature extractor but modifies the fusion layer: instead of fusing all 1000 extracted features, it fuses a random sample of 256 of them.

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [None]:
# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
class CustomImageDataset(Dataset):
    """Colorization dataset built from every image file in a directory.

    Each item is the tuple
    ``(encoder_img_lumin, encoder_img_ab, inception_img_lumin, encoder_img, name)``:

    * ``encoder_img_lumin`` -- 3 x 224 x 224 tensor: the Lab L channel rescaled from
      [0, 100] to [-1, 1] and repeated three times.
    * ``encoder_img_ab`` -- 2 x 224 x 224 tensor: the Lab a and b channels divided by 128.
    * ``inception_img_lumin`` -- 3 x S x S tensor built like ``encoder_img_lumin``, with
      S = 299 for the Inception-based model types and S = 224 otherwise.
    * ``encoder_img`` -- the 224 x 224 x 3 BGR image (values in [0, 1]) as a tensor.
    * ``name`` -- the image file name.

    Args:
        dir: directory containing the images.
        model_type: experiment name; selects the feature-extractor input size.
    """

    # model types whose feature extractor is Inception V3, which takes 299 x 299 inputs
    INCEPTION_MODEL_TYPES = ('inception_v3', 'fusion_experiment', 'inception_class')

    def __init__(self, dir, model_type):
        self.dir = dir
        self.model_type = model_type
        # Sorted so the index -> file mapping (and therefore the seeded shuffle order)
        # does not depend on the arbitrary order os.listdir returns. Hidden files
        # (e.g. .DS_Store) and sub-directories are skipped: cv2 cannot decode them.
        self.images = sorted(
            name for name in os.listdir(self.dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(self.dir, name))
        )
        print('Number of images:', len(self.images))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _lumin_tensor(img_cie):
        """Return the L channel of a float Lab image, scaled to [-1, 1], repeated 3 times (3 x H x W)."""
        lumin = img_cie[:, :, 0] / 50 - 1
        return torch.from_numpy(np.repeat(lumin[np.newaxis, ...], 3, axis=0))

    def __getitem__(self, idx):
        name = self.images[idx]
        path = os.path.join(self.dir, name)

        # cv2 reads the image in BGR order and returns None (instead of raising)
        # when the file cannot be decoded, so fail loudly with the offending path.
        color_img = cv2.imread(path)
        if color_img is None:
            raise IOError('Could not read image: ' + path)
        # float32 in [0, 1]: cv2's float Lab conversion then yields L in [0, 100]
        color_img = color_img.astype(np.float32) / 255

        # the encoder takes 224 x 224 images (Baldassarre et al., 2017); the
        # Inception V3 feature extractor takes 299 x 299, ConvNeXt 224 x 224
        encoder_img = cv2.resize(color_img, (224, 224))
        extractor_size = 299 if self.model_type in self.INCEPTION_MODEL_TYPES else 224
        inception_img = cv2.resize(color_img, (extractor_size, extractor_size))

        # BGR -> CIE L*a*b*
        encoder_img_cie = cv2.cvtColor(encoder_img, cv2.COLOR_BGR2Lab)
        encoder_img_lumin = self._lumin_tensor(encoder_img_cie)
        # a and b channels, divided by 128, stacked to 2 x H x W
        encoder_img_ab = torch.from_numpy(
            np.stack([encoder_img_cie[:, :, 1] / 128, encoder_img_cie[:, :, 2] / 128], axis=0)
        )

        inception_img_lumin = self._lumin_tensor(cv2.cvtColor(inception_img, cv2.COLOR_BGR2Lab))

        return encoder_img_lumin, encoder_img_ab, inception_img_lumin, torch.from_numpy(encoder_img), name

In [None]:
# model type: 'inception_v3' (baseline), 'fusion_experiment', 'inception_class';
# any other value selects the ConvNeXt-Tiny feature extractor
modeltype = 'inception_v3'
# parameters
n_epochs = 20
batch_size = 64
lr = 0.001
lr_decrease_rate = 0.5  # factor applied to the learning rate at each scheduler milestone
output_dir = wd + '/output/' + modeltype + '/'
# The loss plot, sample colorizations and test loss are written here. cv2.imwrite does
# not create directories (it just returns False), so create the directory up front.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# make sure all the models shuffle the data the same way
SEED = 1
# cuDNN: use deterministic algorithms and disable the autotuner, which may otherwise
# pick different (non-deterministic) kernels from run to run
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
# get the datasets
# The training DataLoader gets its own seeded generator. A shuffling DataLoader draws its
# order (and its worker seeds) only when iteration starts, so with the global torch RNG
# the order would depend on everything that consumed that RNG in between, e.g. building
# models of different sizes. A dedicated generator gives every model type the same
# sequence of training batches for the same file list.
print('Training data')
train_dataset = CustomImageDataset('Data/places-data/train', model_type = modeltype)
train_generator = torch.Generator().manual_seed(1)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2,
                              generator=train_generator)

print('Validation data')
valid_dataset = CustomImageDataset('Data/places-data/valid', model_type = modeltype)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Training data
Number of images: 84312
Validation data
Number of images: 10510


In [None]:
class FusionLayer(nn.Module):
    """Fusion layer of Baldassarre et al. (2017).

    The per-image embedding vector is tiled over the spatial grid of the encoder output
    and concatenated to it along the channel dimension.

    For ``model_type == 'fusion_experiment'`` only ``n_selected`` of the ``n_features``
    embedding entries are fused. The subset is drawn ONCE, when the layer is built (with
    Python's ``random`` module, so it follows ``random.seed``), and is stored as a buffer
    so it is saved with the model and moved with ``.to(device)``. Re-drawing it on every
    forward pass would feed unrelated features to the following 1x1 convolution at each
    step and at evaluation time.

    Args:
        model_type: experiment name; only ``'fusion_experiment'`` changes the behaviour.
        n_features: length of the embedding produced by the feature extractor.
        n_selected: number of embedding entries kept for ``'fusion_experiment'``.
    """

    def __init__(self, model_type, n_features=1000, n_selected=256):
        super(FusionLayer, self).__init__()
        self.model_type = model_type
        if self.model_type == 'fusion_experiment':
            # fixed random subset of embedding indices (sampled without replacement)
            self.register_buffer(
                'feature_idx',
                torch.tensor(random.sample(range(n_features), n_selected), dtype=torch.long),
            )

    def forward(self, enc_img_and_emb, mask=None):
        """Return the encoder output with the tiled embedding appended as extra channels.

        Args:
            enc_img_and_emb: pair ``(img_after_enc, embedding)``; ``img_after_enc`` is
                ``[B, C, H, W]`` and ``embedding`` is ``[B, E]``.
            mask: unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``[B, C + E', H, W]``. With the 256-channel encoder output and
            a 1000-entry embedding this is 1256 channels (512 for ``'fusion_experiment'``).
        """
        img_after_enc, embedding = enc_img_and_emb
        if self.model_type == 'fusion_experiment':
            # select before tiling so only the kept features are copied
            embedding = embedding[:, self.feature_idx]
        # [B, E] -> [B, E, 1, 1] -> [B, E, H, W]
        embedding = embedding.reshape(embedding.shape[0], embedding.shape[1], 1, 1)
        embedding = embedding.expand(-1, -1, img_after_enc.shape[2], img_after_enc.shape[3])
        return torch.cat((img_after_enc, embedding), dim=1)

In [None]:
class Colorize(nn.Module):
    """Encoder / fusion / decoder colorization network (Baldassarre et al., 2017).

    Takes the 3-channel (repeated) luminance image and a per-image embedding from a
    feature extractor, and predicts the 2 ab channels through a final Tanh.

    Args:
        depth_after_fusion: number of channels produced by the 1x1 convolution that
            follows the fusion layer.
        model_type: experiment name; ``'fusion_experiment'`` fuses 256 embedding
            channels instead of 1000.
    """

    @staticmethod
    def _conv_relu_bn(in_channels, out_channels, stride=1):
        """3x3 convolution (padding 1), then in-place ReLU, then BatchNorm."""
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        ]

    def __init__(self, depth_after_fusion, model_type):
        super(Colorize, self).__init__()
        block = self._conv_relu_bn

        # encoder: three stride-2 convolutions shrink H and W by 8 (224 -> 28),
        # ending with 256 channels
        encoder_spec = [
            (3, 64, 2), (64, 128, 1), (128, 128, 2), (128, 256, 1),
            (256, 256, 2), (256, 512, 1), (512, 512, 1), (512, 256, 1),
        ]
        encoder_layers = []
        for in_c, out_c, stride in encoder_spec:
            encoder_layers.extend(block(in_c, out_c, stride))
        self.encoder = nn.Sequential(*encoder_layers)

        # fusion: 256 encoder channels + embedding channels, hence 1256 input
        # channels for the 1x1 convolution (512 for 'fusion_experiment')
        self.fusion = FusionLayer(model_type)
        fused_channels = 512 if model_type == 'fusion_experiment' else 1256
        self.after_fusion = nn.Conv2d(in_channels=fused_channels, out_channels=depth_after_fusion,
                                      kernel_size=1, stride=1, padding=0)
        # last BatchNorm before the decoder
        self.lastnorm = nn.BatchNorm2d(depth_after_fusion)

        # decoder: three 2x upsamplings bring the 28 x 28 grid back to 224 x 224
        self.decoder = nn.Sequential(
            *block(depth_after_fusion, 128),
            nn.Upsample(scale_factor=2.0),
            *block(128, 64),
            *block(64, 64),
            nn.Upsample(scale_factor=2.0),
            *block(64, 32),
            nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
            nn.Upsample(scale_factor=2.0),
        )

    def forward(self, img_lumin, img_embed):
        """Predict the ab channels of ``img_lumin`` given its embedding ``img_embed``."""
        encoded = self.encoder(img_lumin)
        fused = self.lastnorm(self.after_fusion(self.fusion([encoded, img_embed])))
        return self.decoder(fused)

In [None]:
# Initialize the pretrained feature extractors (Inception V3 and ConvNeXt-Tiny). They only
# supply image embeddings and are never optimized, so put them in eval mode and freeze
# their parameters; otherwise loss.backward() in the training loop also backpropagates
# through them and accumulates gradients in their .grad buffers on every batch.
inception_v3 = models.inception_v3(weights='DEFAULT').float().to(device).eval().requires_grad_(False)
convnext = models.convnext_tiny(weights='DEFAULT').float().to(device).eval().requires_grad_(False)

In [None]:
# set the model, Adam will be the optimizer
model = Colorize(256, model_type = modeltype).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)

# loss
if modeltype == 'inception_class':
    # NOTE(review): this branch is given the same continuous ab regression targets as the
    # MSE models; confirm CrossEntropyLoss is really intended for them.
    loss_criterion = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
else:
    loss_criterion = torch.nn.MSELoss(reduction='mean').to(device)

# Halve the learning rate every 2 epochs (after epochs 2, 4, 6, ...).
# The training loop calls scheduler.step() after every BATCH, so the MultiStepLR
# milestones have to be counted in optimizer steps, not epochs. Milestone 0 is left
# out because MultiStepLR would already apply it when the scheduler is constructed.
steps_per_epoch = len(train_dataloader)
milestone_list = [epoch * steps_per_epoch for epoch in range(2, n_epochs, 2)]
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestone_list, gamma=lr_decrease_rate)

In [None]:
# train and validate
train_loss_history = []
valid_loss_history = []

# model types whose feature extractor is Inception V3; all others use ConvNeXt-Tiny
inception_model_types = ('inception_v3', 'fusion_experiment', 'inception_class')
feature_extractor = inception_v3 if modeltype in inception_model_types else convnext


def run_epoch(dataloader, train):
    """Run one pass over ``dataloader`` and return ``(mean_loss, seconds)``.

    ``mean_loss`` is the per-image average of the batch losses (each batch loss is
    weighted by its number of images), i.e. the same scale as a loss computed with a
    batch size of 1. When ``train`` is True the model is updated and the LR scheduler
    is stepped after every batch; otherwise the model is in eval mode and no gradients
    are computed.
    """
    model.train(train)
    total_loss = 0.0
    n_images = 0
    start = time.time()
    for idx, (encoder_img_lumin, encoder_img_ab, inception_img_lumin, encoder_img, img_name) in enumerate(dataloader):
        # send to device
        encoder_img_lumin = encoder_img_lumin.to(device)
        encoder_img_ab = encoder_img_ab.to(device)
        inception_img_lumin = inception_img_lumin.to(device)

        # the pretrained feature extractor is not trained: build no autograd graph for it
        with torch.no_grad():
            img_embed = feature_extractor(inception_img_lumin.float())

        with torch.set_grad_enabled(train):
            pred_ab = model(encoder_img_lumin, img_embed)
            loss = loss_criterion(pred_ab, encoder_img_ab.float())

        if train:
            # Backpropagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # MultiStepLR is advanced once per batch here, so its milestones are
            # interpreted as batch (optimizer step) counts
            scheduler.step()

        batch_images = encoder_img_lumin.shape[0]
        total_loss += loss.item() * batch_images
        n_images += batch_images

        if idx % 1000 == 0:
            print('Batch: ', idx)

    return total_loss / max(n_images, 1), time.time() - start


for epoch in range(n_epochs):
    print('Epoch:',epoch+1)

    train_loss, train_time = run_epoch(train_dataloader, train=True)
    train_loss_history.append(train_loss)
    print('Training Loss:',train_loss,'-------- Time passed ', round(train_time, 2))

    val_loss, val_time = run_epoch(valid_dataloader, train=False)
    valid_loss_history.append(val_loss)
    print('Validation Loss:', val_loss,'-------- Time passed ', round(val_time,2))


In [None]:
# training / validation loss per epoch (epochs numbered from 1)
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss_history) + 1), train_loss_history, color='blue', label='Train loss')
ax.plot(range(1, len(valid_loss_history) + 1), valid_loss_history, color='orange', label='Valid loss')
ax.set(title=modeltype + ': training and validation loss', xlabel='Epoch', ylabel='Loss')
ax.legend()
fig.savefig(output_dir + modeltype + '_TrainValidLoss.jpg')

In [None]:
test_dataset = CustomImageDataset('Data/places-data/test', model_type=modeltype)
# batch size 1: every test image is evaluated (and can be saved) on its own
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Evaluate on the test set and save a colorized version of every `save_every`-th image.
save_every = 500
average_loss = 0.0
start = time.time()
model.eval()

# model types whose feature extractor is Inception V3; all others use ConvNeXt-Tiny
inception_model_types = ('inception_v3', 'fusion_experiment', 'inception_class')
feature_extractor = inception_v3 if modeltype in inception_model_types else convnext

# inference only: no autograd graphs
with torch.no_grad():
    for idx, (encoder_img_lumin, encoder_img_ab, inception_img_lumin, encoder_img, img_name) in enumerate(test_dataloader):
        # send to device
        encoder_img_lumin = encoder_img_lumin.to(device)
        encoder_img_ab = encoder_img_ab.to(device)
        inception_img_lumin = inception_img_lumin.to(device)

        img_embed = feature_extractor(inception_img_lumin.float())
        pred_ab = model(encoder_img_lumin, img_embed)

        # colorize some of the test images
        if idx % save_every == 0:
            # the luminance channel was repeated 3 times: take one copy and undo the
            # [-1, 1] scaling back to [0, 100]; shape 1 x 224 x 224
            lum = ((encoder_img_lumin[0, 0] + 1) * 50).unsqueeze(0)
            # undo the /128 scaling of the predicted ab channels; shape 2 x 224 x 224
            ab = pred_ab[0] * 128
            # 3 x 224 x 224 Lab -> 224 x 224 x 3 numpy array
            cie = torch.cat([lum, ab], dim=0).cpu().detach().numpy().transpose(1, 2, 0)
            # Lab -> BGR in [0, 1] -> rounded, clipped 8-bit image for saving
            bgr = cv2.cvtColor(cie, cv2.COLOR_Lab2BGR) * 255
            bgr = np.clip(np.rint(bgr), 0, 255).astype(np.uint8)
            out_path = output_dir + modeltype + '_' + img_name[0]
            # cv2.imwrite reports failure (e.g. a missing directory) only via its return value
            if not cv2.imwrite(out_path, bgr):
                raise IOError('Could not write image: ' + out_path)

        # calculate the losses
        loss = loss_criterion(pred_ab, encoder_img_ab.float())
        average_loss += loss.item()

# batch size is 1, so this is the mean per-image loss
test_loss = average_loss / max(len(test_dataloader), 1)
print('Test Loss:', test_loss,'-------- Time passed ', round(time.time()-start,2))


In [None]:
# record the final test loss next to the other outputs of this model type
result_path = output_dir + modeltype + '.txt'
with open(result_path, 'w') as result_file:
    result_file.write(f'The test loss: {test_loss:f}')

In [None]:
# take the ground truth images and make them grayscale, then save
gt_dir = wd + '/output/ground_truth/'
gr_dir = wd + '/output/grayscale/'
# cv2.imwrite does not create directories (it just returns False), so make sure it exists
os.makedirs(gr_dir, exist_ok=True)

gt = sorted(os.listdir(gt_dir))
for image in gt:
    img = cv2.imread(gt_dir + image)
    # cv2.imread returns None for anything it cannot decode (e.g. .DS_Store, sub-folders)
    if img is None:
        print('Skipping unreadable file:', image)
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    out_path = gr_dir + 'grayscale_' + image
    if not cv2.imwrite(out_path, img):
        raise IOError('Could not write image: ' + out_path)