In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file shipped under the read-only Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for name in file_names:
        full_path = os.path.join(root_dir, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install fastai

# Imports

In [None]:
import warnings
warnings.filterwarnings('ignore')

## Import COCO Dataset from FastAI

In [None]:
from fastai.data.external import untar_data, URLs
# Download (or reuse the cached copy of) the COCO sample and point at its image folder.
coco_path = f"{untar_data(URLs.COCO_SAMPLE)}/train_sample"

## Import Libraries

In [None]:
'''
file: dependencies.py
author: @vincit0re
brief: This file contains the dependencies for the application.
date: 2023-05-05
'''

# All import statements and libraries
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from matplotlib.ticker import MaxNLocator
import copy
# Set to False to force CPU execution even when a GPU is present.
use_gpu = True
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using Device: {device}")

# Colorization

## Hyperparameters

In [None]:
# Hyperparameters
class Hyperparameters:
    '''Tunable settings shared by the data, training and inference cells.'''
    # --- data ---
    _DATA_DIR = coco_path        # folder scanned for *.jpg training images
    _SIZE = 256                  # images are resized to _SIZE x _SIZE
    # --- optimisation ---
    _BATCH_SIZE = 32
    _N_EPOCHS = 50
    # --- output ---
    _SAVE_PATH = 'basemodel.pt'  # file the trained model is torch.save'd to

## Dataset Images

In [None]:
def load_images(path, n_images=10_000, train_frac=0.8, seed=123):
    '''Randomly pick a subset of the .jpg images in `path` and split it into train/val.

    Args:
        path (str | os.PathLike): Directory containing the ``*.jpg`` files.
        n_images (int): Size of the random subset (default 10 000). If the folder
            holds fewer images, all of them are used instead of raising.
        train_frac (float): Fraction of the subset assigned to training
            (default 0.8, i.e. 8000 train / 2000 val for the default subset).
        seed (int): Seed of the private RNG used for sampling.

    Returns:
        tuple[np.ndarray, np.ndarray]: (train_paths, val_paths), disjoint arrays of paths.

    Raises:
        ValueError: if no ``.jpg`` file is found in `path`.
    '''
    # Sort so the seeded selection does not depend on the filesystem's listing order.
    paths = sorted(glob.glob(os.path.join(path, "*.jpg")))
    if not paths:
        raise ValueError(f"No .jpg images found in {path!r}")
    n_subset = min(n_images, len(paths))
    # A private RandomState gives the same draws as np.random.seed(seed) followed by
    # global calls, without clobbering the notebook-wide NumPy RNG state.
    rng = np.random.RandomState(seed)
    paths_subset = rng.choice(paths, n_subset, replace=False)
    rand_idxs = rng.permutation(n_subset)
    n_train = int(round(n_subset * train_frac))
    train_paths = paths_subset[rand_idxs[:n_train]]  # first part -> training set
    val_paths = paths_subset[rand_idxs[n_train:]]    # remainder -> validation set
    return train_paths, val_paths

# Split the sampled COCO images into training / validation path arrays.
train_imgs_path, val_imgs_path = load_images(path=Hyperparameters._DATA_DIR)

n_train_imgs, n_val_imgs = len(train_imgs_path), len(val_imgs_path)
print(f"Number of Training Images: {n_train_imgs}")
print(f"Number of Validation Images: {n_val_imgs}")

In [None]:
def plot_sample_images(images_path, n_rows=4, n_cols=4):
    '''Show up to ``n_rows * n_cols`` images from `images_path` in a grid.

    Args:
        images_path (Sequence[str]): Image file paths; only the first
            ``n_rows * n_cols`` are displayed.
        n_rows (int): Grid rows (default 4).
        n_cols (int): Grid columns (default 4).
    '''
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False)
    flat_axes = axes.flatten()
    for ax, img_path in zip(flat_axes, images_path[:n_rows * n_cols]):
        # The context manager closes the file; imshow has already converted the pixels.
        with Image.open(img_path) as img:
            ax.imshow(img)
    # Hide every axis, including empty cells when there are fewer images than slots.
    for ax in flat_axes:
        ax.axis("off")

    plt.suptitle("Sample Images")
    plt.tight_layout()
    plt.show()

plot_sample_images(images_path=train_imgs_path)

## Dataset Class

In [None]:
# dataset class
class ColorizationDataset(Dataset):
    '''Dataset yielding normalized L*a*b* channel pairs for image colorization.

    Each item is ``{'L': tensor(1, size, size), 'ab': tensor(2, size, size)}``,
    with L mapped from [0, 100] to [-1, 1] and a/b divided by 110 (roughly [-1, 1]).

    Args:
        paths (Sequence[str]): Image file paths.
        split (str): 'train' (resize + random horizontal flip) or 'val' (resize only).
        size (int): Side length the images are resized to (default 256).

    Raises:
        ValueError: if `split` is not 'train' or 'val'.
    '''

    def __init__(self, paths, split='train', size=256):
        self.SIZE = size
        resize = transforms.Resize(
            (self.SIZE, self.SIZE),
            interpolation=transforms.InterpolationMode.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),  # A little data augmentation!
            ])
        elif split == 'val':
            self.transforms = resize
        else:
            # Fail fast; otherwise self.transforms would be missing at the first __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = self.SIZE
        self.paths = paths

    def __getitem__(self, idx):
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        img = np.array(self.transforms(img))
        img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)  # HWC -> CHW; float input is not rescaled
        L = img_lab[[0], ...] / 50. - 1.  # Between -1 and 1
        ab = img_lab[[1, 2], ...] / 110.  # Between -1 and 1

        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)


# A handy function to make our dataloaders
def make_dataloaders(batch_size=16, n_workers=2, pin_memory=True, shuffle=False, **kwargs):
    '''Build a DataLoader over a ColorizationDataset.

    Args:
        batch_size (int): Samples per batch.
        n_workers (int): Number of DataLoader worker processes.
        pin_memory (bool): Forwarded to DataLoader(pin_memory=...).
        shuffle (bool): Reshuffle the samples every epoch (default False).
        **kwargs: Forwarded to ColorizationDataset (e.g. ``paths``, ``split``).

    Returns:
        torch.utils.data.DataLoader
    '''
    dataset = ColorizationDataset(size=Hyperparameters._SIZE, **kwargs)
    return DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                      pin_memory=pin_memory, shuffle=shuffle)


# get train, val dataloaders
def get_train_val_dataloaders(train_paths, val_paths, batch_size=64, shuffle=True):
    '''Return the (train, validation) dataloaders.

    Args:
        train_paths (Sequence[str]): Training image paths.
        val_paths (Sequence[str]): Validation image paths.
        batch_size (int): Samples per batch for both loaders.
        shuffle (bool): Shuffle the training set each epoch. The validation loader
            is never shuffled, so its first batch is the same every time.

    Returns:
        tuple[DataLoader, DataLoader]: (train_dataloader, val_dataloader)
    '''
    train_dataloader = make_dataloaders(
        batch_size=batch_size, paths=train_paths, split='train', shuffle=shuffle)
    val_dataloader = make_dataloaders(
        batch_size=batch_size, paths=val_paths, split='val', shuffle=False)
    return train_dataloader, val_dataloader

train_loader, val_loader = get_train_val_dataloaders(
    train_paths=train_imgs_path,
    val_paths=val_imgs_path,
    batch_size=Hyperparameters._BATCH_SIZE,
    shuffle=True,
)
print(f"Train Data: {len(train_loader.dataset)} ({len(train_loader)} batches)")
print(f"Validation Data: {len(val_loader.dataset)} ({len(val_loader)} batches)")

In [None]:
# Peek at the first training batch to confirm the tensor shapes.
batch = next(iter(train_loader), None)
if batch is not None:
    print(batch['L'].shape)
    print(batch['ab'].shape)

## U-Net

In [None]:
# block for UNet


class UnetBlock(nn.Module):
    '''One level of the U-Net, built recursively around an inner `submodule`.

    Data path: down-convolution (stride 2, halves H and W) -> submodule ->
    up-convolution (transposed, stride 2, doubles H and W). Every level except
    the outermost concatenates its input with the up-path output along dim 1
    (skip connection), so it outputs ``input_c + nf`` channels.

        Args:
            nf (int): Channels produced by this level's up-convolution (the
                level's "outer" width); also the default for `input_c`.
            ni (int): Channels produced by this level's down-convolution (the
                width handed to `submodule`).
            submodule (nn.Module): Next-deeper UnetBlock; ignored when `innermost`.
                The up-convolution expects it to return ``2 * ni`` channels,
                which holds when the submodule was built with ``nf == ni``.
            input_c (int): Channels of this level's input; defaults to `nf`.
            dropout (bool): Append Dropout(0.5) to the up path (middle levels only).
            innermost (bool): Bottom of the U: no submodule, no BatchNorm after
                the down-convolution.
            outermost (bool): Top of the U: no skip concatenation, no
                normalization, Tanh output.
    '''

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None:
            input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        # NOTE: inplace=True. For non-outermost levels this activation is applied to
        # the block input `x` in place, so the skip concatenation in forward() carries
        # the LeakyReLU-activated input rather than the raw one.
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)

        if outermost:
            # Input to the up path is the submodule's concatenated output (2 * ni channels).
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            # Tanh bounds the output to (-1, 1), the range the ab targets are scaled to.
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No submodule below, so the up path sees only the `ni` down channels.
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout:
                up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    # forward: the outermost level returns its output directly; every other level
    # returns cat([input, output]) along the channel dimension (skip connection).
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
        
# class for UNet


class Unet(nn.Module):
    '''U-Net generator assembled from nested UnetBlocks, innermost level first.

    Levels from the bottom up: one innermost block, ``n_down - 5`` full-width
    (8 * num_filters) blocks with dropout, three blocks whose width halves at
    each step, and the outermost block that maps back to `output_c` channels.

        Args:
            input_c (int): Channels of the input image.
            output_c (int): Channels produced by the network.
            n_down (int): Total number of stride-2 downsamplings (for n_down >= 5).
            num_filters (int): Width of the outermost level.
    '''

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8
        # Bottom of the U.
        block = UnetBlock(widest, widest, innermost=True)
        # Extra full-width levels (with dropout) stacked above the bottom.
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)
        # Three levels on the way out, each half as wide as the one beneath it.
        width = widest
        for _ in range(3):
            block = UnetBlock(width // 2, width, submodule=block)
            width //= 2
        # Top level: input_c -> output_c, without skip concatenation.
        self.model = UnetBlock(output_c, width, input_c=input_c,
                               submodule=block, outermost=True)

    def forward(self, x):
        '''Return the generator output for `x`.'''
        return self.model(x)

In [None]:
print(Unet(input_c=1, output_c=2, n_down=8, num_filters=64))

## Discriminator

In [None]:
# discriminator class


class PatchDiscriminator(nn.Module):
    '''Convolutional discriminator that outputs a one-channel spatial map of scores.

        Args:
            input_c (int): Channels of the image being judged.
            num_filters (int): Channels of the first convolution; doubled by each
                following block.
            n_down (int): Number of normalized conv blocks after the first one;
                all of them except the last use stride 2.
    '''

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # First block: no BatchNorm directly on the input image.
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for i in range(n_down):
            # The last of these blocks keeps stride 1 instead of halving the map again.
            stride = 1 if i == n_down - 1 else 2
            blocks.append(self.get_layers(num_filters * 2 ** i,
                                          num_filters * 2 ** (i + 1),
                                          s=stride))
        # Output head: one channel, with neither normalization nor activation.
        blocks.append(self.get_layers(num_filters * 2 ** n_down, 1,
                                      s=1, norm=False, act=False))
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        '''Conv2d(ni -> nf, kernel k, stride s, padding p) [+ BatchNorm2d] [+ LeakyReLU(0.2)].

        The convolution carries a bias only when no BatchNorm follows it.
        '''
        conv = nn.Conv2d(ni, nf, k, s, p, bias=not norm)
        extras = []
        if norm:
            extras.append(nn.BatchNorm2d(nf))
        if act:
            extras.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(conv, *extras)

    def forward(self, x):
        return self.model(x)


In [None]:
print(PatchDiscriminator(input_c=1, num_filters=64, n_down=3))

## GAN Loss

In [None]:
# gan loss


class GANLoss(nn.Module):
    '''Adversarial loss comparing discriminator outputs with constant real/fake targets.

        Args:
            gan_mode (str): 'vanilla' -> BCEWithLogitsLoss (predictions are raw logits),
                'lsgan' -> MSELoss.
            real_label (float): Target value when the prediction should be "real".
            fake_label (float): Target value when the prediction should be "fake".

        Raises:
            ValueError: if `gan_mode` is not 'vanilla' or 'lsgan'.
    '''

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Registered as buffers so they follow the module across .to(device).
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Fail fast: otherwise self.loss would be unset and the first call would break.
            raise ValueError(
                f"Unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'")

    def get_labels(self, preds, target_is_real):
        '''Return a target tensor shaped like `preds`, filled with the real or fake label.'''
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        '''Loss of `preds` against an all-real (True) or all-fake (False) target.

        Defined as forward() so nn.Module.__call__ (and its hooks) dispatches here.
        '''
        return self.loss(preds, self.get_labels(preds, target_is_real))

## Model Initialization

In [None]:
# weights initialization


def init_weights(net, init='norm', gain=0.02):
    '''Re-initialize the convolution and BatchNorm2d layers of `net` in place.

    Conv* layers (matched by class name, so ConvTranspose* too) get weights from
    N(0, gain) for 'norm', Xavier-normal with `gain` for 'xavier', or Kaiming-normal
    (fan_in, `gain` unused) for 'kaiming', plus a zero bias. BatchNorm2d layers get
    weight ~ N(1, gain) and a zero bias.

    Args:
        net (nn.Module): Network to initialize (modified in place).
        init (str): 'norm', 'xavier' or 'kaiming'.
        gain (float): Std-dev / gain used by the chosen scheme.

    Returns:
        nn.Module: `net` itself, for chaining.

    Raises:
        ValueError: if `init` is not one of the supported schemes.
    '''
    if init not in ('norm', 'xavier', 'kaiming'):
        # Fail loudly instead of silently leaving conv weights at PyTorch defaults.
        raise ValueError(
            f"Unknown init {init!r}; expected 'norm', 'xavier' or 'kaiming'")

    def init_func(m):
        classname = m.__class__.__name__
        if 'Conv' in classname and getattr(m, 'weight', None) is not None:
            # nn.init functions already run under no_grad; no need for `.data`.
            if init == 'norm':
                nn.init.normal_(m.weight, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=gain)
            else:  # 'kaiming'
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')

            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # affine=False BatchNorm has weight/bias = None: nothing to initialize.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1., gain)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

# model initialization


def init_model(model, device):
    '''Move `model` to `device`, apply init_weights' default scheme, and return it.'''
    return init_weights(model.to(device))

# Main Model

In [None]:
# main model


class MainModel(nn.Module):
    '''Colorization GAN: the generator predicts ab from L, the discriminator judges (L, ab).

    Holds both networks, their Adam optimizers and the loss criteria. `setup_input()`
    stores a batch, and `optimize()` runs one discriminator step followed by one
    generator step on it. Afterwards the attributes loss_D_fake, loss_D_real, loss_D,
    loss_G_GAN, loss_G_L1 and loss_G hold the latest loss tensors.

        Args:
            net_G (nn.Module): Generator network. If None, a Unet (1 -> 2 channels) is
                built and initialized with init_model; a supplied network is only
                moved to the device, not re-initialized.
            lr_G (float): Learning rate for the generator.
            lr_D (float): Learning rate for the discriminator.
            beta1 (float): Beta1 for Adam optimizer.
            beta2 (float): Beta2 for Adam optimizer.
            lambda_L1 (float): Weight of the L1 (ab reconstruction) term in the
                generator loss.
    '''

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        # NOTE(review): picks CUDA whenever it is available, independently of the
        # notebook-level `use_gpu` flag.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(
                Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        # The discriminator sees L (1 channel) concatenated with ab (2 channels).
        self.net_D = init_model(PatchDiscriminator(
            input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(),
                                lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(),
                                lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        '''Set requires_grad on every parameter of `model` (used to freeze/unfreeze net_D).'''
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        '''Move the batch's 'L' and 'ab' tensors to self.device and store them.'''
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        '''Run the generator on self.L and store the result in self.fake_color (returns None).'''
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        '''Compute the discriminator loss (mean of the fake and real terms) and backprop it.'''
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach(): the discriminator loss must not send gradients into net_G.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        '''Compute the generator loss (GAN term + lambda_L1 * L1 to the true ab) and backprop it.'''
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        # The generator is rewarded when the discriminator labels its output as real.
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(
            self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        '''One training step: update net_D, then net_G, from a single generator forward pass.'''
        # NOTE(review): forward() runs before net_G.train() below, so it uses whatever
        # mode net_G is currently in.
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        self.net_G.train()
        # Freeze net_D so the generator step does not accumulate gradients in it.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

## Metrics

In [None]:
# average meter class for loss


class AverageMeter:
    '''Running count-weighted mean of a scalar, such as a loss value.'''

    def __init__(self):
        self.reset()

    def reset(self):
        # All three statistics restart at the float 0.0.
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        '''Fold in `val`, counted `count` times (e.g. a batch mean over `count` samples).'''
        self.sum += count * val
        self.count += count
        self.avg = self.sum / self.count

# create loss meters


def create_loss_meters():
    '''Return a fresh AverageMeter for each loss name, in a fixed order.'''
    names = ('loss_D_fake', 'loss_D_real', 'loss_D',
             'loss_G_GAN', 'loss_G_L1', 'loss_G')
    return {name: AverageMeter() for name in names}

# update loss meters in training


def update_losses(model, loss_meter_dict, count):
    '''For every meter, read the same-named tensor attribute of `model` and record it.'''
    for name in loss_meter_dict:
        meter = loss_meter_dict[name]
        meter.update(getattr(model, name).item(), count=count)

# convert normalized L*a*b* tensors back to RGB


def lab_to_rgb(L, ab):
    """Convert a batch of normalized L*a*b* tensors back to RGB images.

    Undoes the dataset scaling: L in [-1, 1] -> [0, 100], ab -> ab * 110.

    Args:
        L (torch.Tensor): (B, 1, H, W) lightness channel scaled to [-1, 1].
        ab (torch.Tensor): (B, 2, H, W) colour channels divided by 110.

    Returns:
        np.ndarray: (B, H, W, 3) stack of the per-image results of skimage's lab2rgb.
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    # detach() lets tensors that still require grad be converted to NumPy;
    # permute NCHW -> NHWC because lab2rgb expects channels last.
    Lab = torch.cat([L, ab], dim=1).detach().permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(img) for img in Lab], axis=0)

# visualize results


def visualize(model, data, save=True, n_images=5):
    '''Plot grayscale input, predicted colour and ground truth for one batch.

    Runs the generator once in eval mode (no grad) on `data`, then draws up to
    `n_images` columns: L channel (top row), prediction (middle), real colour (bottom).
    net_G is switched back to train mode afterwards.

    Args:
        model (MainModel): Model exposing setup_input/forward, net_G, L, ab, fake_color.
        data (dict): Batch with 'L' and 'ab' tensors, as produced by the dataloaders.
        save (bool): Also write the figure to ``results/colorization_<timestamp>.png``.
        n_images (int): Maximum number of columns to draw (default 5).
    '''
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    # Never index past the batch: the last batch may hold fewer than n_images samples.
    n_cols = min(n_images, L.size(0))
    fig = plt.figure(figsize=(15, 8))
    for i in range(n_cols):
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.tight_layout()
    if save:
        os.makedirs("results", exist_ok=True)
        # Save through the figure handle before show(); some backends close it on display.
        fig.savefig(f"results/colorization_{time.time()}.png")
    plt.show()
    plt.close(fig)  # called repeatedly during training; avoid piling up open figures

# print results


def log_results(loss_meter_dict):
    '''Print each meter's running average and append the same lines to losses.txt.'''
    lines = [f"{loss_name}: {loss_meter.avg:.5f}"
             for loss_name, loss_meter in loss_meter_dict.items()]
    for line in lines:
        print(line)
    if lines:
        # Mode 'a' creates the file when it is missing, so no separate first-write path.
        with open('losses.txt', 'a') as log_file:
            log_file.writelines(line + "\n" for line in lines)

## Train Model

In [None]:
def train_model(model, train_dl, val_dl, epochs, display_every=200, save_path='model.pt'):
    '''Train `model`, restore the weights of its best epoch, and torch.save it to `save_path`.

    "Best" is the epoch with the lowest mean training generator loss (``loss_G``).
    No validation loss is computed: `val_dl` only supplies a fixed batch for the
    periodic visualizations.

    Args:
        model (MainModel): Model exposing setup_input/optimize and the loss_* attributes.
        train_dl (DataLoader): Training batches (dicts with 'L' and 'ab').
        val_dl (DataLoader): Source of the batch shown by visualize().
        epochs (int): Number of passes over `train_dl`.
        display_every (int): Log the losses and visualize every this many iterations.
        save_path (str): Destination of ``torch.save(model, ...)``.

    Returns:
        MainModel: `model` itself, with the best epoch's weights loaded. Optimizer
        state is not rolled back.
    '''
    best_loss = float('inf')
    best_state = None
    # Fetched once so every visualization compares the same images across epochs.
    data_val = next(iter(val_dl))

    for e in range(epochs):
        # Fresh meters each epoch, so the averages cover this epoch only.
        loss_meter_dict = create_loss_meters()
        for i, data in enumerate(tqdm(train_dl), start=1):
            model.setup_input(data)
            model.optimize()
            # Weight each batch by its size to get a per-sample running mean.
            update_losses(model, loss_meter_dict, count=data['L'].size(0))
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict)
                visualize(model, data_val, save=True)

        epoch_loss = loss_meter_dict['loss_G'].avg
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            # Snapshot detached CPU copies of the weights. Keeping a reference to
            # `model` would alias the object that keeps training, so the "best"
            # model would always be the final one.
            best_state = {k: v.detach().cpu().clone()
                          for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
    torch.save(model, save_path)
    return model

### Training 

In [None]:
model = MainModel()
trained_model = train_model(
    model,
    train_loader,
    val_loader,
    epochs=Hyperparameters._N_EPOCHS,
    save_path=Hyperparameters._SAVE_PATH,
)

In [None]:
# get losses from losses.txt


def get_losses(results_path):
    '''Parse a log written by log_results into six per-loss lists of floats.

    Each line "name: value" is assigned to the first matching loss name; lines that
    contain none of the names are ignored.

    Returns:
        tuple[list[float], ...]: (D_fake, D_real, D, G_GAN, G_L1, G) series.
    '''
    # Order matters: longer names must be tested before their prefixes
    # ("loss_D_fake" before "loss_D", "loss_G_GAN" before "loss_G").
    keys = ("loss_D_fake", "loss_D_real", "loss_D",
            "loss_G_GAN", "loss_G_L1", "loss_G")
    series = {key: [] for key in keys}

    with open(results_path, 'r') as handle:
        rows = handle.readlines()

    for row in rows:
        match = next((key for key in keys if key in row), None)
        if match is not None:
            series[match].append(float(row.split(":")[1]))

    return tuple(series[key] for key in keys)

# plot loss curves


def plot_losses(losses, save_path=None):
    '''Plot the discriminator (left) and generator (right) loss curves.

    Args:
        losses (Sequence[Sequence[float]]): Six series in the order
            D_fake, D_real, D, G_GAN, G_L1, G.
        save_path (str | None): If given, the figure is also written there.
    '''
    labels = ["loss_D_fake", "loss_D_real", "loss_D",
              "loss_G_GAN", "loss_G_L1", "loss_G"]
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    panels = ((ax[0], range(0, 3), "Discriminator Losses"),
              (ax[1], range(3, 6), "Generator Losses"))
    for axis, series_ids, title in panels:
        for idx in series_ids:
            axis.plot(losses[idx], "*-", label=labels[idx])
        axis.set_xlabel("Iterations")
        axis.set_ylabel("Loss")
        axis.xaxis.set_major_locator(MaxNLocator(integer=True))
        axis.legend()
        axis.set_title(title)
    plt.suptitle("Loss Curves")
    plt.tight_layout()
    plt.show()
    if save_path:
        fig.savefig(save_path)

In [None]:
losses = get_losses(results_path="losses.txt")
plot_losses(losses=losses, save_path="losses.png")

## Testing

In [None]:
def predict_on_test_data(test_loader, model_path, n_batches=2):
    '''Visualize and save colorizations for the first batches of `test_loader`.

    Nothing is returned; each batch is drawn by visualize(model, data, save=True).

        Args:
            test_loader (torch.utils.data.DataLoader): Loader yielding {'L', 'ab'} batches.
            model_path (str): Path of a whole MainModel saved with ``torch.save``.
            n_batches (int): Number of batches to visualize (default 2).
    '''
    # The file holds a pickled MainModel object, not just weights, so weights_only must
    # be False (its default became True in PyTorch 2.6). Unpickling can run arbitrary
    # code: only load files you trust. map_location lets a GPU-trained model load here.
    model = torch.load(model_path, map_location=device, weights_only=False)
    # setup_input() moves batches to model.device; keep it in sync with the weights.
    model.device = device
    for count, data in enumerate(test_loader, start=1):
        visualize(model=model, data=data, save=True)
        if count >= n_batches:
            break

In [None]:
# Load the model from the same path train_model saved it to.
predict_on_test_data(val_loader, Hyperparameters._SAVE_PATH)

# Grayscale Images

In [None]:
# get colored image from a given image


def get_colorized_image(img_path, model_path, SIZE=256, save_path=None):
    '''Colorize a single image with a saved MainModel.

    The image is resized to SIZE x SIZE and converted to L*a*b*; only its L channel
    is fed to the generator.

        Args:
            img_path (str): Path to the image (converted to RGB on load).
            model_path (str): Path of a whole MainModel saved with ``torch.save``.
            SIZE (int): Side length the image is resized to (default 256).
            save_path (str | None): If given, the colorized image is also written there.
        Returns:
            tuple[numpy.ndarray, numpy.ndarray]: (colorized, reference) RGB images of
            shape (SIZE, SIZE, 3); the reference is the resized input rebuilt from its
            own L*a*b* channels.
    '''
    transform = transforms.Resize(
        (SIZE, SIZE), interpolation=transforms.InterpolationMode.BICUBIC)

    # convert() returns a loaded copy, so the file can be closed right away.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")
    img = np.array(transform(img))
    img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
    img_lab = transforms.ToTensor()(img_lab)  # HWC -> CHW; float input is not rescaled
    L = img_lab[[0], ...] / 50. - 1.  # Between -1 and 1
    ab = img_lab[[1, 2], ...] / 110.  # Between -1 and 1

    data = {"L": L.unsqueeze(0), "ab": ab.unsqueeze(0)}
    # Pickled MainModel object -> weights_only=False (default True since PyTorch 2.6);
    # only load trusted files. map_location + model.device let it run on this host.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.device = device
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    fake_imgs = lab_to_rgb(model.L, model.fake_color.detach())
    real_imgs = lab_to_rgb(model.L, model.ab)

    if save_path is not None:
        plt.imsave(save_path, np.clip(fake_imgs[0], 0, 1))
    return fake_imgs[0], real_imgs[0]

# show the image


def plot_comparison(fake_img, real_img, save_path=None):
    '''Show the colorized image next to the reference, optionally saving the figure.

    Args:
        fake_img (np.ndarray): Colorized RGB image (left panel, "Colored").
        real_img (np.ndarray): Reference RGB image (right panel, "Original").
        save_path (str | None): If given, the two-panel figure is written there.
    '''
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    ax[0].imshow(fake_img)
    ax[0].set_title("Colored")
    ax[1].imshow(real_img)
    ax[1].set_title("Original")
    ax[0].axis("off")
    ax[1].axis("off")
    plt.tight_layout()
    # Save through the figure handle *before* show(): the inline backend closes the
    # figure on show(), after which plt.savefig would write a new, blank figure.
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)  # called once per image in a loop; avoid accumulating open figures


## Results on Grayscale Images

In [None]:
img_dir = '/kaggle/input/grayscale-imgs/grayscale'
# Same file train_model saved to (relative to the working directory).
model_path = Hyperparameters._SAVE_PATH
os.makedirs("outputs", exist_ok=True)
# Sorted for a stable output order; skip anything that is not a regular file.
for img in sorted(os.listdir(img_dir)):
    img_path = os.path.join(img_dir, img)
    if not os.path.isfile(img_path):
        continue
    col_img, real_img = get_colorized_image(img_path, model_path)
    plot_comparison(col_img, real_img, save_path=os.path.join("outputs", img))