In [None]:
# 200k training images (81% of total)
# 45k test images (18% of total)

In [None]:
%%capture
!pip install pretrainedmodels

In [None]:
import numpy as np
from tqdm.notebook import tqdm
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb, rgb2gray
import pretrainedmodels
from pretrainedmodels import utils
from datetime import datetime
from pathlib import Path
from ipywidgets import FloatProgress


In [None]:
# https://modelzoo.co/model/pretrained-modelspytorch
# Instantiate the SE-ResNet-152 classifier with its ImageNet weights. It is used
# further down (in Network.forward) as a fixed feature extractor whose 1000-way
# output is fused with the encoder features.
se_resnet = pretrainedmodels.__dict__["se_resnet152"](
    num_classes=1000, 
    pretrained="imagenet"
)
# Show every architecture name that pretrainedmodels exposes.
print(pretrainedmodels.model_names)

In [None]:
# A torch.device is an object representing the device on which a torch.Tensor is or will be allocated.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Hyper-parameters

# Fraction of the dataset held out for validation when the indices are split.
VALIDATION_SIZE = 0.2
# BATCH_SIZE defines the number of samples that will be propagated through the network.
BATCH_SIZE = 128 #64 #128 #250 #64

# Upper bound on the number of training epochs.
EPOCHS = 50
# Step size handed to the Adam optimizer.
LEARNING_RATE = 0.001

In [None]:
# Load a "model" from the checkpoint saved at "path" along with the optimizer and the stored epoch.
def load_checkpoint(path, model, optimizer, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        path: File written with ``torch.save`` holding at least the keys
            ``"model_state_dict"``, ``"optimizer_state_dict"`` and ``"epoch"``.
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.
        map_location: Forwarded to ``torch.load``. When ``None`` the tensors
            are mapped onto the device of the model's first parameter (CPU for
            a parameter-less model), so a checkpoint written on a GPU machine
            can still be opened on a CPU-only one.

    Returns:
        Tuple ``(model, optimizer, epoch)`` where ``epoch`` is the value stored
        in the checkpoint.

    Note:
        Recent PyTorch releases default ``torch.load`` to ``weights_only=True``.
        Checkpoints that also embed pickled Python objects need
        ``weights_only=False`` there (only for files you trust).
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else torch.device("cpu")

    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return model, optimizer, checkpoint["epoch"]

In [None]:
# The grayscale tensor plays the role of the L channel and ab_input holds the A/B channels;
# stacking them gives a LAB image, which is then converted to RGB.
def convert_to_LAB(grayscale_input, ab_input):
    """Assemble a LAB image from L and AB tensors and return it as RGB.

    Despite the name, the returned array is the RGB rendering of the stacked
    LAB image.

    Args:
        grayscale_input: Tensor concatenated as the first channel(s); it is
            rescaled by 100 to serve as the L channel.
        ab_input: Tensor concatenated after it; channels 1 and 2 of the result
            are mapped back with ``value * 255 - 128``.

    Returns:
        Array in H, W, C layout produced by ``lab2rgb``.
    """
    # torch layout is C, H, W; matplotlib and skimage want H, W, C.
    lab = torch.cat((grayscale_input, ab_input), axis=0).numpy().transpose((1, 2, 0))
    # Undo the normalisation of each channel group.
    lab[..., 0:1] = lab[..., 0:1] * 100
    lab[..., 1:3] = lab[..., 1:3] * 255 - 128
    return lab2rgb(lab.astype(np.float64))

In [None]:
# https://lukemelas.github.io/image-colorization.html
# Turns the grayscale tensor, the predicted AB tensor and the ground-truth AB tensor
# into arrays that can be shown side by side.
def convert_to_rgb(grayscale_input, ab_input, ab_ground_truth):
    """Build displayable arrays for the input, the prediction and the ground truth.

    Args:
        grayscale_input: Grayscale tensor used as the L channel for both images.
        ab_input: Predicted AB channels.
        ab_ground_truth: Ground-truth AB channels.

    Returns:
        Tuple ``(grayscale, predicted_image, ground_truth_image)``: the squeezed
        grayscale array followed by the two results of ``convert_to_LAB``.
    """
    predicted_image, ground_truth_image = (
        convert_to_LAB(grayscale_input, ab_channels)
        for ab_channels in (ab_input, ab_ground_truth)
    )
    return grayscale_input.squeeze().numpy(), predicted_image, ground_truth_image

In [None]:
# Helper that plots the grayscale input, the model output and the ground truth
# next to each other and hands back the matplotlib figure.
def display_images(grayscale, pred_image, ground_truth_image):
    """Plot the three images in a single row and return the figure.

    Args:
        grayscale: Model input, drawn with the "gray" colormap.
        pred_image: Colorized model output.
        ground_truth_image: Reference color image.

    Returns:
        The matplotlib figure holding the three panels.
    """
    fig, axes = plt.subplots(1, 3, figsize=(20, 10))
    panels = (
        (grayscale, "Grayscale Image (Model Input)", {"cmap": "gray"}),
        (pred_image, "RGB Image (Model Output)", {}),
        (ground_truth_image, "RGB Image (Ground-truth)", {}),
    )
    for ax, (image, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title, fontsize=20)
    return fig

In [None]:
# Encoder network designed as per the paper. Takes in a single-channel grayscale
# image (L channel) and returns a 256-channel feature map.
class Encoder(nn.Module):
    """Stack of eight 3x3 convolutions, each followed by ReLU.

    Three of the layers use stride 2, so every spatial dimension shrinks by a
    factor of 8 (224x224 in -> 28x28 out). The output has 256 channels.
    """

    # https://d2l.ai/chapter_convolutional-neural-networks/padding-and-strides.html
    def __init__(self):
        super().__init__()
        # Every layer shares the same 3x3 kernel with padding 1; only the
        # channel counts and the stride vary.
        common = dict(kernel_size=3, padding=1)
        self.input_ = nn.Conv2d(1, 64, stride=2, **common)
        self.conv1 = nn.Conv2d(64, 128, **common)
        self.conv2 = nn.Conv2d(128, 128, stride=2, **common)
        self.conv3 = nn.Conv2d(128, 256, **common)
        self.conv4 = nn.Conv2d(256, 256, stride=2, **common)
        self.conv5 = nn.Conv2d(256, 512, **common)
        self.conv6 = nn.Conv2d(512, 512, **common)
        self.conv7 = nn.Conv2d(512, 256, **common)

    def forward(self, x):
        """Apply the convolutions in order, with a ReLU after each one."""
        pipeline = (
            self.input_, self.conv1, self.conv2, self.conv3,
            self.conv4, self.conv5, self.conv6, self.conv7,
        )
        for conv in pipeline:
            x = F.relu(conv(x))
        return x

In [None]:
class Decoder(nn.Module):
    """Turns the fused feature block into two tanh-bounded output channels.

    The input must have 1256 channels. A 1x1 convolution reduces them to 256,
    then 3x3 convolutions are interleaved with three x2 upsampling steps, so
    the output is 8 times larger than the input along each spatial dimension
    and has 2 channels in [-1, 1].
    """

    # https://d2l.ai/chapter_convolutional-neural-networks/padding-and-strides.html
    def __init__(self):
        super().__init__()
        self.input_1 = nn.Conv2d(1256, 256, 1)
        self.input_ = nn.Conv2d(256, 128, 3, padding=1)
        self.conv1 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv4 = nn.Conv2d(32, 2, 3, padding=1)

    def forward(self, x):
        """Decode a (N, 1256, H, W) block into a (N, 2, 8H, 8W) tensor."""
        # Channel reduction, then the first upsampling stage.
        fused = F.relu(self.input_(F.relu(self.input_1(x))))
        stage1 = F.interpolate(fused, scale_factor=2)

        stage1 = F.relu(self.conv2(F.relu(self.conv1(stage1))))
        stage2 = F.interpolate(stage1, scale_factor=2)

        # tanh bounds the two output channels before the final upsampling.
        ab = torch.tanh(self.conv4(F.relu(self.conv3(stage2))))
        return F.interpolate(ab, scale_factor=2)

In [None]:
# Move the pretrained classifier to the compute device and switch it to
# evaluation mode; it is only run under torch.no_grad() in Network.forward.
se_resnet = se_resnet.to(device)
se_resnet.eval()

class Network(nn.Module):
    """Colorization network: Encoder + pretrained-classifier embedding + Decoder.

    The module-level ``se_resnet`` is called from ``forward`` but is not
    registered as a sub-module, so it is neither part of ``state_dict()`` nor
    affected by ``train()`` / ``eval()`` on this network.
    """

    def __init__(self):
        super().__init__()

        # Both halves are placed on the compute device at construction time.
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)

    def forward(self, encoder_input, feature_input):
        """Predict the output channels for a batch.

        Args:
            encoder_input: Grayscale batch fed to the encoder.
            feature_input: The same images, preprocessed for ``se_resnet``.

        Returns:
            The decoder output for the fused feature block.
        """
        encoded_img = self.encoder(encoder_input)

        # The classifier is a fixed feature extractor: no gradients flow into it.
        with torch.no_grad():
            embedding = se_resnet(feature_input)

        # (N, C) -> (N, C, 1, 1), then broadcast the vector over every spatial
        # position of the encoder output. Using the encoder's own height/width
        # (instead of a hard-coded 28x28) keeps the fusion valid for other
        # input sizes, and expand() creates a view rather than tiled copies.
        embedding = embedding.view(embedding.size(0), -1, 1, 1)
        embedding_block = embedding.expand(
            -1, -1, encoded_img.size(2), encoded_img.size(3)
        )
        fusion_block = torch.cat([encoded_img, embedding_block], dim=1)

        return self.decoder(fusion_block)

In [None]:
# Initializing the pretrainedmodels.utils helpers: an image loader and the
# preprocessing transform built from the se_resnet model.
load_img = utils.LoadImage()
tf_img = utils.TransformImage(se_resnet) 

# The encoder branch and the se_resnet branch each get their own crop pipeline;
# as written, both take a 224x224 centre crop.
encoder_transform = transforms.Compose([transforms.CenterCrop(224)])
se_resnet_transform = transforms.Compose([transforms.CenterCrop(224)])

class ImageDataset(datasets.ImageFolder):
    """
    ImageFolder variant that yields colorization samples instead of (image, label).

    Each item is a tuple ``(img_gray, img_ab, img_se_resnet)``:
      * img_gray      - 1 x H x W float tensor from ``rgb2gray`` of the cropped image,
      * img_ab        - 2 x H x W float tensor with the A/B channels of the LAB
                        image, rescaled with ``(value + 128) / 255``,
      * img_se_resnet - the same file preprocessed for se_resnet.
    The class label of the folder is ignored.
    """
    def __getitem__(self, index):
        img_path = self.imgs[index][0]

        # Branch 1: input for the pretrained classifier.
        img_se_resnet = tf_img(se_resnet_transform(load_img(img_path)))

        # Branch 2: cropped RGB pixels for the encoder input and the target.
        rgb_pixels = np.asarray(encoder_transform(self.loader(img_path)))

        lab_scaled = (rgb2lab(rgb_pixels) + 128) / 255
        # Keep A and B only and switch from H, W, C to C, H, W.
        ab_chw = lab_scaled[:, :, 1:3].transpose((2, 0, 1))
        img_ab = torch.from_numpy(ab_chw).float()

        img_gray = torch.from_numpy(rgb2gray(rgb_pixels)).unsqueeze(0).float()

        return img_gray, img_ab, img_se_resnet

In [None]:
# Dataset over the training image folder; it is split into training and
# validation subsets in the next cell.
train_data = ImageDataset("../data/imagenet_images")


In [None]:
# Seed for the train/validation partition. With a fixed seed the split is the
# same on every kernel restart, so resuming training from a checkpoint never
# validates on images the model has already been trained on.
SPLIT_SEED = 42

num_train = len(train_data)
indices = list(range(num_train))
print(num_train)
# A local RNG keeps the split reproducible without touching NumPy's global state.
np.random.RandomState(SPLIT_SEED).shuffle(indices)
split = int(np.floor(VALIDATION_SIZE * num_train))

# The first `split` shuffled indices are held out for validation.
train_idx, valid_idx = indices[split:], indices[:split]
# SubsetRandomSampler: Samples elements randomly from a given list of indices, without replacement.
train_samp = SubsetRandomSampler(train_idx)
valid_samp = SubsetRandomSampler(valid_idx)

train_dataloader = torch.utils.data.DataLoader(
  train_data, 
  batch_size=BATCH_SIZE, 
  sampler=train_samp,
  num_workers=4
)
print(len(train_dataloader))
# Both loaders read from the same dataset; the samplers keep them disjoint.
valid_dataloader = torch.utils.data.DataLoader(
  train_data, 
  batch_size=BATCH_SIZE, 
  sampler=valid_samp,
  num_workers=4
)

In [None]:
# Build the colorization network and place it on the compute device.
model = Network()
model = model.to(device)

# Root Mean Squared Logarithmic Error (RMSLE).
# Sometimes, one may not want to penalize the model too much for predicting unscaled quantities directly.
# Relaxing the penalty on huge differences can be done by comparing log(1 + x) of both tensors.
class MSLELoss(nn.Module):
    """Square root of the mean squared error between log(1 + pred) and log(1 + actual).

    Args:
        eps: Both inputs are clamped to at least ``eps - 1`` before the
            logarithm. Without the clamp a value of exactly -1 yields
            ``log(0) = -inf``, which makes the loss infinite and the gradients
            NaN.
    """

    # Class-level default so instances unpickled from older checkpoints
    # (created before `eps` existed) still work.
    eps = 1e-6

    def __init__(self, eps=1e-6):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, pred, actual):
        """Return the scalar RMSLE between ``pred`` and ``actual``."""
        lower_bound = self.eps - 1.0
        pred = torch.clamp(pred, min=lower_bound)
        actual = torch.clamp(actual, min=lower_bound)
        # log1p(x) == log(1 + x), but more accurate for x close to 0.
        return torch.sqrt(self.mse(torch.log1p(pred), torch.log1p(actual)))


# Loss used for both training and validation (nn.MSELoss() is the alternative
# that was tried).
criterion = MSLELoss() #nn.MSELoss()
#MSLELoss()

# Adam over every trainable parameter of the network.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [None]:
# Base path for "best so far" checkpoints; the training loop derives
# time-stamped file names from it.
checkpoint_path = "./checkpoints/SEnet/colorization_senet.pt"
# History of checkpoint paths, seeded with the base path; the newest entry is last.
checkpoints = [checkpoint_path]

In [None]:
# Root folder of the images read by the run_test_* helpers below.
test_root = "../data/test_data"

In [None]:
def run_test_without_groundtruth(location):
    """Colorize the first test image and save an input/output comparison.

    Args:
        location: Path handed to ``Figure.savefig`` for the two-panel figure
            (grayscale input next to the model output).
    """
    test_data = ImageDataset(test_root)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)
    # next(iter(...)) works on every PyTorch version; the iterator's .next()
    # method no longer exists in current releases.
    img_gray, img_ab, img_se_resnet = next(iter(test_dataloader))
    img_gray, img_se_resnet = img_gray.to(device), img_se_resnet.to(device)

    model.eval()
    with torch.no_grad():
        output = model(img_gray, img_se_resnet)

    # The batch holds exactly one image.
    grayscale, predicted_image, _ = convert_to_rgb(
      img_gray[0].cpu(),
      output[0].cpu(),
      img_ab[0].cpu()
    )

    f, axarr = plt.subplots(1, 2, figsize=(20, 10))
    axarr[0].imshow(grayscale, cmap="gray")
    axarr[0].set_title("Grayscale Image (Model Input)", fontsize=20)
    axarr[1].imshow(predicted_image)
    axarr[1].set_title("RGB Image (Model Output)", fontsize=20)
    # The figure is saved directly; passing a Figure to plt.imshow() raises
    # because a Figure is not image data.
    f.savefig(location, facecolor='white', transparent=False)

In [None]:
def run_test_with_groundtruth(location):
    """Colorize the first batch of test images and save one figure per image.

    Each figure shows the grayscale input, the model output and the ground
    truth side by side.

    Args:
        location: Path prefix for the saved figures. Image ``idx`` of the batch
            is written to ``f"{location}_{idx}.png"`` so that the figures no
            longer overwrite each other.
    """
    test_data = ImageDataset(test_root)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=5)
    # next(iter(...)) works on every PyTorch version; the iterator's .next()
    # method no longer exists in current releases.
    img_gray, img_ab, img_se_resnet = next(iter(test_dataloader))
    img_gray, img_se_resnet = img_gray.to(device), img_se_resnet.to(device)

    model.eval()
    with torch.no_grad():
        output = model(img_gray, img_se_resnet)

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(location) or ".", exist_ok=True)

    # Iterate over the actual batch size: the folder may hold fewer than 5 images.
    for idx in range(img_gray.size(0)):
        grayscale, predicted_image, ground_truth = convert_to_rgb(
          img_gray[idx].cpu(),
          output[idx].cpu(),
          img_ab[idx].cpu()
        )
        f = display_images(grayscale, predicted_image, ground_truth)
        f.savefig("{}_{}.png".format(location, idx), facecolor='white', transparent=False)

### Run the below cell only for training 

In [None]:
min_validation = np.inf  # lowest validation loss seen so far in this run
val_losses = []          # mean validation loss per epoch (plain floats)
train_losses = []        # mean training loss per epoch (plain floats)
final_model = './checkpoints/colorization_final_seResnet.pt'
results_dir = './results/senet/'

# torch.save / savefig do not create missing directories.
for out_dir in (results_dir, os.path.dirname(final_model), os.path.dirname(checkpoint_path)):
    os.makedirs(out_dir, exist_ok=True)

# Resume from the rolling checkpoint when a non-empty one exists. The stored
# epoch is the index of the last completed epoch, so training continues with
# the next one. EPOCHS itself is left untouched, which keeps this cell
# re-runnable.
start_epoch = 0
if os.path.exists(final_model) and os.path.getsize(final_model) != 0:
    model, optimizer, epochs = load_checkpoint(final_model, model, optimizer)
    start_epoch = epochs + 1

for i in tqdm(range(start_epoch, EPOCHS), desc="Epoch"):
    location = results_dir + 'with_groundtruth' + 'e' + str(i)

    # ---- training pass ----
    train_loss = 0.0
    model.train()
    for img_gray, img_ab, img_se_resnet in tqdm(train_dataloader, desc="Training"):
        img_gray, img_ab, img_se_resnet = img_gray.to(device), img_ab.to(device), img_se_resnet.to(device)
        # reset the gradients accumulated by the previous step
        optimizer.zero_grad()
        # forward pass
        output = model(img_gray, img_se_resnet)
        loss = criterion(output, img_ab)
        # backward pass
        loss.backward()
        # update the weights
        optimizer.step()
        train_loss += loss.item()

    # ---- validation pass ----
    valid_loss = 0.0
    model.eval()
    with torch.no_grad():
        for img_gray, img_ab, img_se_resnet in tqdm(valid_dataloader, desc="Validating"):
            img_gray, img_ab, img_se_resnet = img_gray.to(device), img_ab.to(device), img_se_resnet.to(device)

            output = model(img_gray, img_se_resnet)
            # .item() keeps the running total a Python float, so the history
            # can be plotted even when the model runs on the GPU.
            valid_loss += criterion(output, img_ab).item()

    # Mean loss per batch; max(..., 1) guards against an empty loader.
    train_loss = train_loss / max(len(train_dataloader), 1)
    valid_loss = valid_loss / max(len(valid_dataloader), 1)
    train_losses.append(train_loss)
    val_losses.append(valid_loss)
    print("Epoch: {}/{}.. ".format(i+1, EPOCHS))
    print("-------------------------------------------------")
    print("Training Loss: {:.6f}.. ".format(train_loss))
    print("Validation Loss: {:.6f}.. ".format(valid_loss))

    # Same payload for the "best" and the rolling checkpoint. The "model" and
    # "loss" entries are pickled Python objects; load_checkpoint only needs the
    # two state dicts and the epoch.
    checkpoint_state = {
        "epoch": i,
        "model": model,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": criterion
    }

    if valid_loss <= min_validation:
        # Time-stamped copy of the best model so far.
        checkpoints.append(checkpoint_path[:-3]+datetime.now().strftime('%Y-%m-%d %H:%M:%S')+'.pt')

        print("Validation loss decreased ({:.6f} --> {:.6f}). Saving model." \
        .format(
          min_validation,
          valid_loss
        ))

        torch.save(checkpoint_state, checkpoints[-1])
        min_validation = valid_loss

    # Rolling checkpoint, overwritten every epoch; this is what a resume reads.
    torch.save(checkpoint_state, final_model)

    run_test_with_groundtruth(location)
    plt.figure(figsize=(10,5))
    plt.title("Training and Validation Loss")
    plt.plot(val_losses,label="val")
    plt.plot(train_losses,label="train")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

### Load the trained model, and run it on images in the new_images_root location.
Run all the above cells except the training cell and then run the below cells.

In [None]:
new_images_root = "../data/new_images"
# new_images_root = "../data/test_data"
def run_test_new_images(location):
    """Colorize the first batch (up to 10 images) found under ``new_images_root``.

    One three-panel figure (input, model output, ground truth) is saved per
    image.

    Args:
        location: Path prefix for the saved figures; image ``idx`` is written
            to ``location + str(idx) + '.png'``.
    """
    test_data = ImageDataset(new_images_root)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=10)
    # next(iter(...)) works on every PyTorch version; the iterator's .next()
    # method no longer exists in current releases.
    img_gray, img_ab, img_se_resnet = next(iter(test_dataloader))
    img_gray, img_se_resnet = img_gray.to(device), img_se_resnet.to(device)

    model.eval()
    with torch.no_grad():
        output = model(img_gray, img_se_resnet)

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(location) or ".", exist_ok=True)

    # Iterate over the actual batch size: the folder may hold fewer than 10 images.
    for idx in range(img_gray.size(0)):
        grayscale, predicted_image, ground_truth = convert_to_rgb(
          img_gray[idx].cpu(),
          output[idx].cpu(),
          img_ab[idx].cpu()
        )
        f = display_images(grayscale, predicted_image, ground_truth)
        f.savefig(location+str(idx)+'.png', facecolor='white', transparent=False)

In [None]:
# Alternative checkpoints that were evaluated:
# final_model = './colorization_final_efficientnet_old.pt'
# final_model ='./colorization_final.pt'
# final_model ='./colorization_authorimpl_final.pt'
final_model ='./colorization_final_seResnet.pt'
# NOTE(review): machine-specific absolute output path; adjust before running elsewhere.
# image_storage_location = '/home/aadityasp/Aditya/MS/sem2_fall2021/CV_PatternRecognition/Project/data/test_outputs/test_data_images_output_seResnet/'
image_storage_location = '/home/aadityasp/Aditya/MS/sem2_fall2021/CV_PatternRecognition/Project/data/new_images_outputs/colorization_final_seresnet/'

# Build the network first so that the optimizer is bound to the parameters of
# the model that actually receives the checkpoint (not to a previous instance).
model = Network()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Read the checkpoint once; map_location lets a checkpoint written on a GPU
# machine load on a CPU-only one.
# NOTE: recent PyTorch releases default torch.load to weights_only=True; a
# checkpoint that embeds pickled objects needs weights_only=False there.
checkpoint = torch.load(final_model, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epochs = checkpoint['epoch']
print(model)

run_test_new_images(image_storage_location)

In [None]:
#References
#https://arxiv.org/pdf/1712.03400.pdf
#https://github.com/baldassarreFe/deep-koalarization
#https://pytorch.org/docs/stable/data.html
#https://modelzoo.co/model/pretrained-modelspytorch
#https://arxiv.org/pdf/1905.11946.pdf
# https://lukemelas.github.io/image-colorization.html
# https://www.geeksforgeeks.org/training-neural-networks-with-validation-using-pytorch/
# https://github.com/lauradang/automatic-image-colorization