#1 - Implementing the model





##1.1- Preparing to start
The implementation in this project utilizes a training set consisting of 8,000 images from the COCO dataset, which is a smaller subset compared to the extensive ImageNet dataset used in the referenced paper. The training set size in this project accounts for approximately 0.6% of the dataset size used in the paper. It is worth noting that for the colorization task, you can employ various datasets as long as they encompass diverse scenes and locations, which are expected to enable the model to learn colorization effectively. For instance, even a subset of 8,000 images from ImageNet would be sufficient for this particular project.

In [None]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import torch.nn.functional as F

from fastai.data.external import untar_data, URLs
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# When True, the split cell below takes the dataset location from `coco_path`.
use_colab = True

Loading images

In [None]:
# Notebook shell command: upgrade fastai, which provides `untar_data` / `URLs` used by the download cells.
# NOTE(review): fastai is already imported in the first cell, so an upgrade done here
# presumably only takes effect after the kernel is restarted -- confirm.
!pip install fastai --upgrade

In [None]:
# Download and extract the COCO train2017 archive with fastai's `untar_data`.
from fastai.data.external import untar_data
coco_url = "http://images.cocodataset.org/zips/train2017.zip"
coco_path = untar_data(coco_url)
# NOTE(review): `coco_train_path` is not read by any visible cell, and the next cell
# rebinds `coco_path` to the COCO sample -- presumably only one of the two download
# cells is meant to be run; confirm.
coco_train_path = coco_path / "train2017"
use_colab = True 

In [None]:
# Download and extract fastai's `URLs.COCO_SAMPLE` archive and point `coco_path`
# at its `train_sample` image folder.
from fastai.data.external import untar_data, URLs
coco_path = untar_data(URLs.COCO_SAMPLE)
# Converted to `str` so later cells can treat the dataset location as a plain path string.
coco_path = str(coco_path) + "/train_sample"
use_colab = True

In [None]:
# Resolve the dataset folder. When `use_colab` is False, `path` must already have
# been assigned by hand before this cell runs.
if use_colab:
    path = coco_path

# `path` may be a `str` or a `pathlib.Path` depending on which download cell ran,
# so build the pattern with os.path.join instead of `+`.
# Sort before shuffling: glob order depends on the filesystem, and the seeded shuffle
# below is only reproducible when it starts from a deterministic order.
paths = sorted(glob.glob(os.path.join(str(path), "*.jpg")))  # all image file names

np.random.seed(123)
np.random.shuffle(paths)
n_images = len(paths)
n_train = int(n_images * 0.8)  # 80/20 train/validation split
train_paths = paths[:n_train]
val_paths = paths[n_train:]

print(len(train_paths), len(val_paths))

In [None]:
# Preview up to 16 training images in a 4x4 grid.
_, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, img_path in zip(axes.flatten(), train_paths):
    # Use a context manager so each image file is closed once it has been drawn,
    # instead of leaving 16 handles open until garbage collection.
    with Image.open(img_path) as img:
        ax.imshow(img)
    ax.axis("off")

##1.2- Making Datasets and DataLoaders (using Data Augmentation)
The code implementation incorporates several important techniques for handling and augmenting image data. First, the images undergo resizing and, in the case of the training set, horizontal flipping. These transformations ensure consistency in input dimensions and introduce diversity in the training data by simulating different orientations.

Next, a significant step involves converting the RGB images to the Lab color space. This conversion separates the grayscale (L) channel from the color (a, b) channels. The grayscale channel serves as the input to the model, while the color channels act as the target for colorization. This representation enables the model to learn the mapping from grayscale to color.

To further enhance the training process, data augmentation is employed. Data augmentation is a powerful technique for increasing the effective size of the training set and improving model generalization. It involves applying various transformations to the images, either during training or as a preprocessing step. In the case of colorization, potential augmentation techniques include rotations, zooms, and horizontal/vertical flips. These augmentations introduce variations and increase the diversity of the training data.

In the context of colorizing grayscale images, a particularly useful augmentation is adding noise. By adding random noise to the grayscale images, the model learns to handle and correct small perturbations in the input. This augmentation technique enhances the model's robustness and can contribute to better performance on real-world data.

Overall, the code implementation encompasses essential steps such as resizing, flipping, color space conversion, and data augmentation to prepare and augment the image data for effective training of colorization models. These techniques contribute to improved model performance, generalization, and robustness when handling grayscale images.

In [None]:
import torch

SIZE = 256
class ColorizationDatasetAugumentation(Dataset):
    """Colorization dataset yielding Lab tensors, with flip and noise augmentation.

    Each item is a dict with:
      * ``'L'``  -- the lightness channel, scaled as ``L / 50 - 1``;
      * ``'ab'`` -- the two colour channels, scaled as ``ab / 110``.

    Args:
        paths: sequence of image file paths.
        split: ``'train'`` (resize + random horizontal flip + Gaussian noise on
            the L input) or ``'val'`` (resize only).
        noise_factor: standard deviation of the Gaussian noise added to the
            scaled L channel of training samples; ``0`` disables the noise.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train', noise_factor=0.05):
        resize = transforms.Resize((SIZE, SIZE),  Image.BICUBIC)
        if split == 'train':
            # The noise is NOT part of this pipeline: these transforms run on PIL
            # images, and torch.randn_like() only accepts tensors. It is applied to
            # the L tensor in __getitem__ instead.
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = resize
        else:
            # Fail here rather than with an AttributeError on the first __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        self.paths = paths
        self.noise_factor = noise_factor

    def __getitem__(self, idx):
        # convert() returns an independent image, so the file can be closed right away.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32") # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1. # Between -1 and 1
        ab = img_lab[[1, 2], ...] / 110. # Between -1 and 1

        if self.split == 'train' and self.noise_factor:
            # Perturb only the grayscale input; the colour targets stay clean.
            # Clamp so the noisy input keeps the same value range as clean inputs.
            L = (L + self.noise_factor * torch.randn_like(L)).clamp_(-1., 1.)

        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs): # A handy function to make our dataloaders
    """Build a DataLoader over the augmented colorization dataset.

    Args:
        batch_size: samples per batch.
        n_workers: number of DataLoader worker processes.
        pin_memory: forwarded to the DataLoader.
        **kwargs: forwarded to ``ColorizationDatasetAugumentation``
            (``paths``, ``split``, ``noise_factor``).

    Returns:
        A ``DataLoader`` over a ``ColorizationDatasetAugumentation``.

    Note:
        The non-augmented cell defines a function with the same name; whichever
        cell runs last decides which dataset class the loaders use.
    """
    # This cell's helper must build the augmented dataset defined alongside it;
    # it previously instantiated the plain ColorizationDataset.
    dataset = ColorizationDatasetAugumentation(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory)
    return dataloader



In [None]:
# without noise augmentation
SIZE = 256
class ColorizationDataset(Dataset):
    """Colorization dataset yielding Lab tensors (resize, plus flip for training).

    Each item is a dict with:
      * ``'L'``  -- the lightness channel, scaled as ``L / 50 - 1``;
      * ``'ab'`` -- the two colour channels, scaled as ``ab / 110``.

    Args:
        paths: sequence of image file paths.
        split: ``'train'`` (resize + random horizontal flip) or ``'val'``
            (resize only).

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        resize = transforms.Resize((SIZE, SIZE),  Image.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(), # A little data augmentation!
            ])
        elif split == 'val':
            self.transforms = resize
        else:
            # Fail here rather than with an AttributeError on the first __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        # convert() returns an independent image, so the file can be closed right away.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32") # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1. # Between -1 and 1
        ab = img_lab[[1, 2], ...] / 110. # Between -1 and 1

        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs): # A handy function to make our dataloaders
    """Wrap a ``ColorizationDataset`` in a ``DataLoader``.

    ``batch_size``, ``n_workers`` and ``pin_memory`` configure the loader; every
    other keyword argument is forwarded to ``ColorizationDataset``.
    """
    loader_options = {
        'batch_size': batch_size,
        'num_workers': n_workers,
        'pin_memory': pin_memory,
    }
    return DataLoader(ColorizationDataset(**kwargs), **loader_options)

In [None]:

# Build the train/validation loaders from the path lists produced by the split cell.
train_dl = make_dataloaders(paths=train_paths, split='train')
val_dl = make_dataloaders(paths=val_paths, split='val')
# Pull a single batch as a sanity check on the tensor shapes.
data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
print(Ls.shape, abs_.shape)
# Number of batches per epoch for each loader.
print(len(train_dl), len(val_dl))

##1.3-Improving Gradient Flow in U-Net: Incorporating Residual Connections
Here we implemented a U-Net architecture for image colorization tasks using PyTorch. The U-Net follows a U-shaped architecture, where down-sampling and up-sampling modules are added iteratively around the middle part of the U shape. This design allows for the effective representation of both global and local features in the input data. The U-Net consists of multiple custom module classes, each serving a specific purpose in the overall architecture.

The U-Net architecture starts with the ConvLayer and TransposeConvLayer classes, which perform convolutional and transposed convolutional operations, respectively. These layers are responsible for extracting features, capturing spatial information, and upsampling the feature maps to reconstruct the output.

The core building block of the U-Net architecture is the UnetBlock class. It represents a single module within the U-Net and plays a crucial role in both the encoding and decoding processes. Each UnetBlock consists of down-convolution, down-normalization, up-relu, up-normalization, and up-convolution operations.

During the encoding phase, the down-convolution operation reduces the spatial dimensions and extracts lower-level features from the input data. Batch normalization is then applied for normalization and stability. On the other hand, the up-convolution operation upsamples the feature maps, recovering spatial details and expanding the resolution. The combination of down-convolution and up-convolution operations allows the U-Net to capture both global and local information in a multi-scale manner.

To enhance the learning capability of the U-Net, residual connections inspired by the paper "Deep Residual Learning for Image Recognition" (He et al., 2015) are incorporated. These connections mitigate the vanishing gradient problem, which can occur when training deeper neural networks. By adding skip connections between corresponding layers of the encoder and decoder, the model can directly transfer learned features, enabling effective gradient propagation and preventing the loss of information.

The Unet class represents the complete U-Net architecture. It sequentially stacks multiple UnetBlock instances, creating a comprehensive and deep architecture for image colorization. The U-Net progressively encodes the input data, capturing global and local features, and then decodes it to generate detailed colorization results.

Overall, the U-Net architecture, along with the specific operations performed by each layer, facilitates the effective encoding and decoding of features, capturing both global and local information to achieve accurate and detailed colorization of grayscale images. The incorporation of residual connections enhances the learning capability of the model, allowing for better gradient propagation and preservation of intricate details.

Reference:

- Deep Residual Learning for Image Recognition- https://arxiv.org/pdf/1512.03385.pdf

In [None]:
class ConvLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> optional activation.

    With the default kernel 4 / stride 2 / padding 1 the spatial size is halved.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv2d``.
        activation: optional module appended after the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, activation=None):
        # Zero-argument super(): the discriminator cell rebinds the global name
        # `ConvLayer`, and super(ConvLayer, self) would then look up the wrong class.
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.BatchNorm2d(out_channels),
        ]
        if activation is not None:
            layers.append(activation)
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class TransposeConvLayer(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> optional activation.

    With the default kernel 4 / stride 2 / padding 1 the spatial size is doubled.

    Args:
        in_channels, out_channels: channel counts of the transposed convolution.
        kernel_size, stride, padding, bias: forwarded to ``nn.ConvTranspose2d``.
        activation: optional module appended after the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, activation=None):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.BatchNorm2d(out_channels),
        ]
        if activation is not None:
            layers.append(activation)
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


# Private handles on the generator layers. UnetBlock is instantiated long after this
# cell runs (when the full model is built), by which time the discriminator cell has
# rebound the public name `ConvLayer` to a class with a different signature.
_UnetConvLayer = ConvLayer
_UnetTransposeConvLayer = TransposeConvLayer


class UnetBlock(nn.Module):
    """One nesting level of the U-Net: down-conv -> submodule -> up-conv.

    Non-outermost blocks return ``cat([residual(x), model(x)], dim=1)``, i.e.
    ``2 * nf`` channels, where ``residual`` is a 1x1 conv + batch norm projection
    of the block input.

    Args:
        nf: channels produced by the up-convolution (and by the residual path).
        ni: channels produced by the down-convolution.
        submodule: the nested inner block (``None`` for the innermost block).
        input_c: channels of the block input; defaults to ``nf``.
        dropout: append ``Dropout(0.5)`` to a middle block.
        innermost: the block has no submodule.
        outermost: the block ends in ``Tanh`` and returns the network output.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False, innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None:
            input_c = nf
        downconv = _UnetConvLayer(input_c, ni, activation=nn.LeakyReLU(0.2, inplace=True))
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(inplace=True)
        upnorm = nn.BatchNorm2d(nf)

        if outermost:
            # No ReLU here: ReLU followed by Tanh can only produce values in [0, 1),
            # while the ab targets built by the dataset cells span [-1, 1].
            upconv = _UnetTransposeConvLayer(ni * 2, nf)
            down = [downconv]
            up = [upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # `uprelu` was previously created but never used, which left the decoder
            # without any non-linearity. It is attached as the layer's post-norm
            # activation, which keeps the parameter names in state_dict unchanged.
            upconv = _UnetTransposeConvLayer(ni, nf, bias=False, activation=uprelu)
            down = [downconv]
            up = [upconv, upnorm]
            model = down + up
        else:
            upconv = _UnetTransposeConvLayer(ni * 2, nf, bias=False, activation=uprelu)
            down = [downconv, downnorm]
            up = [upconv, upnorm]
            if dropout:
                up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)
        # Registered on every block (unused by the outermost one) so that the set of
        # parameter names stays the same as before.
        self.residual = nn.Sequential(
            nn.Conv2d(input_c, nf, kernel_size=1, stride=1),
            nn.BatchNorm2d(nf)
        )

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # Neither branch starts with an in-place op, so both can read `x` directly;
        # no defensive clone is needed.
        return torch.cat([self.residual(x), self.model(x)], 1)


class Unet(nn.Module):
    """U-Net generator mapping ``input_c`` channels to ``output_c`` channels.

    The network is built inside-out: one innermost block, ``n_down - 5`` middle
    blocks with dropout at ``num_filters * 8`` channels, three blocks that halve
    the channel count each time, and the outermost block. Every level halves the
    spatial size on the way down and doubles it on the way up.

    Args:
        input_c: channels of the input (1 for the L channel).
        output_c: channels of the output (2 for the ab channels).
        n_down: number of down-sampling levels.
        num_filters: channels after the first down-convolution.

    Returns (from ``forward``):
        A tensor with ``output_c`` channels and the input's spatial size, with
        values in ``[-1, 1]`` (the outermost block ends in ``Tanh``).
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)

    def forward(self, x):
        # The outermost block already applies Tanh. The former extra
        # `tanh(...) * 0.5 + 0.5` squeezed the output into roughly [0.5, 0.88],
        # so negative ab values could never be predicted.
        return self.model(x)


##1.4- Discriminator (with Dilated Convolutions)
Here is the description of how we constructed the discriminator:

ConvLayer: The ConvLayer module is used to extract features from the input data. It performs a convolutional operation, followed by instance normalization and a leaky ReLU activation function. The convolution operation captures spatial patterns and important image features. Instance normalization improves the stability and performance of the model by normalizing the features within each instance. The leaky ReLU activation function introduces non-linearity, allowing the discriminator to learn complex and discriminative representations of real and fake images.

FinalConvLayer: The FinalConvLayer module is the last convolutional layer in the PatchGAN discriminator. Unlike the ConvLayer, it does not use normalization or an activation function. This layer is responsible for producing the final classification score for determining whether an input image is real or fake. The absence of normalization and activation in this layer ensures that the discriminator's output is not constrained by these operations and can directly influence the classification decision.

Dilated Convolutions: Dilated convolutions are used in the PatchGAN discriminator to allow the network to access a wider field of view or larger contextual information without significantly increasing the number of parameters or computational load. By using dilated convolutions, the model can incorporate larger-scale information, which can be particularly beneficial for tasks like image colorization. In image colorization, understanding colors in a larger region can help inform the appropriate colors in a specific location. Additionally, dilated convolutions help preserve the resolution of the images and avoid the loss of fine details that can occur with pooling or strided convolutions. This is crucial for generating high-quality colorized images.

Stacked Architecture: The PatchGAN discriminator employs a stacked architecture by sequentially stacking ConvLayer blocks. Each ConvLayer block increases the number of filters and reduces the spatial dimensions of the feature maps. This architecture enables the discriminator to learn hierarchical representations of the input images, capturing both local and global features. The stacking of ConvLayer blocks allows the discriminator to perform multi-scale analysis, capturing discriminative information at different levels of abstraction.

In summary, the ConvLayer and FinalConvLayer modules are used to extract features and produce classification scores, respectively, in the PatchGAN discriminator. Dilated convolutions are employed to incorporate larger-scale contextual information and preserve image resolution. The stacked architecture of ConvLayer blocks enables the discriminator to learn hierarchical representations and perform multi-scale analysis of the input images. Together, these features and layers contribute to the PatchGAN discriminator's ability to effectively classify real and fake images for tasks such as image colorization.

References:

- Isola, P., Zhu, J. Y., Zhou, T., & Efros, A. A. (2017). Image-to-Image Translation with Conditional Adversarial Networks. https://arxiv.org/pdf/1611.07004.pdf

- Yu, F., & Koltun, V. (2016). Multi-Scale Context Aggregation by Dilated Convolutions. https://arxiv.org/pdf/1511.07122.pdf


In [None]:
class ConvLayer(nn.Module):
    """Discriminator stage: 4x4 Conv2d (no bias) -> InstanceNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels, out_channels: channel counts of the convolution.
        stride: stride of the convolution.
        dilation_rate: dilation of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride=2, dilation_rate=1):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            dilation=dilation_rate,
            bias=False,
        )
        normalization = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.conv = nn.Sequential(convolution, normalization, activation)

    def forward(self, x):
        features = self.conv(x)
        return features


class FinalConvLayer(nn.Module):
    """Last discriminator stage: a single biased 4x4 Conv2d, no norm, no activation.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        stride: stride of the convolution.
        dilation_rate: dilation of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation_rate=1):
        super().__init__()
        conv_options = {
            'kernel_size': 4,
            'stride': stride,
            'padding': 1,
            'dilation': dilation_rate,
            'bias': True,
        }
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        scores = self.conv(x)
        return scores


class PatchDiscriminator(nn.Module):
    """PatchGAN-style discriminator built from ``ConvLayer`` stages.

    Layout: one stride-2 stage from ``input_c`` to ``num_filters`` channels, then
    ``n_down`` stages that each double the channel count (stride 2, except the
    last one which uses stride 1), then a ``FinalConvLayer`` producing 1 channel.

    Args:
        input_c: channels of the input image.
        num_filters: channels after the first stage.
        n_down: number of channel-doubling stages.
        dilation_rate: dilation used by every convolution.
    """

    def __init__(self, input_c, num_filters=64, n_down=3, dilation_rate=1):
        super().__init__()
        stages = [ConvLayer(input_c, num_filters, stride=2, dilation_rate=dilation_rate)]
        width = num_filters
        for depth in range(n_down):
            is_last = depth == (n_down - 1)
            stages.append(
                ConvLayer(width, width * 2,
                          stride=1 if is_last else 2,
                          dilation_rate=dilation_rate)
            )
            width *= 2
        # `width` is now num_filters * 2 ** n_down.
        stages.append(FinalConvLayer(width, 1, stride=1, dilation_rate=dilation_rate))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        patch_scores = self.model(x)
        return patch_scores


In [None]:
# Instantiate once so the notebook displays the layer structure of the default 3-channel discriminator.
PatchDiscriminator(3)

In [None]:
# Shape check: feed a random batch through the discriminator.
discriminator = PatchDiscriminator(3)
dummy_input = torch.randn(16, 3, 256, 256) # batch_size, channels, size, size
out = discriminator(dummy_input)
# With the default arguments the result is a (16, 1, 30, 30) grid of per-patch scores.
out.shape

##1.5- GAN Loss (Addressing Color Imbalance in Image Colorization with Weighted Loss Function)

The provided class, GANLossWeighted, serves as a convenient tool for calculating the GAN loss in the final model. During initialization, the desired type of loss is determined (e.g., "vanilla" in this case), and constant tensors are registered as the "real" and "fake" labels.

When the module is called, it generates an appropriate tensor filled with zeros or ones based on the current stage (real or fake) and computes the loss using the selected loss function.

In the context of image colorization tasks, the option to replace the standard cross-entropy loss with the proposed weighted loss in the GAN is significant. Richard Zhang et al., in their paper "Colorful Image Colorization," highlight the problem of color imbalance in traditional approaches. The standard loss calculation often favors desaturated or greyish colors due to the data distribution, while colorful colors are less favored as they are rarer.

To address this issue, the proposed weighted loss incorporates class weights based on the prior probabilities of the classes. These weights, derived from the inverse log-probability of the class distribution, are retrieved from the 'prior_probs.npy' file. By assigning more importance to underrepresented (more colorful) classes, the weighted loss helps rebalance the color bias and encourages the model to predict diverse and vibrant colors. As a result, the aesthetic quality of the generated images is improved.

 Reference:
- Zhang, R., Isola, P., Efros, A. A. (2016). Colorful Image Colorization- https://arxiv.org/pdf/1603.08511.pdf

In [None]:
# enhanced loss
class GANLossWeighted(nn.Module):
    """GAN loss with optional class-rebalancing weights.

    The weights follow the class-rebalancing scheme of Zhang et al.,
    "Colorful Image Colorization"::

        w  ∝  1 / ((1 - rebalance_factor) * p + rebalance_factor / num_classes)

    normalised so that ``sum(p * w) == 1``, where ``p`` is the prior loaded from
    ``prior_probs_path``. The previous expression ``1 / log(rebalance_factor + p)``
    is negative whenever ``rebalance_factor + p < 1`` (true for every bin with the
    default 0.5), which flipped the sign of the weighted loss.

    Args:
        gan_mode: ``'vanilla'`` (BCE with logits) or ``'lsgan'`` (MSE).
        real_label: target value for real samples.
        fake_label: target value for fake samples.
        num_classes: number of colour bins ``Q`` in the prior.
        rebalance_factor: mixing weight ``lambda`` between the prior and the
            uniform distribution.
        prior_probs_path: ``.npy`` file holding the prior class probabilities.

    Raises:
        ValueError: if ``gan_mode`` is not ``'vanilla'`` or ``'lsgan'``.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0, num_classes=313, rebalance_factor=0.5,
                 prior_probs_path='prior_probs.npy'):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        self.num_classes = num_classes
        self.rebalance_factor = rebalance_factor

        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss(reduction='none')
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss(reduction='none')
        else:
            # Without this, `self.loss` was left unset and the first call failed
            # with an unrelated AttributeError.
            raise ValueError(f"gan_mode must be 'vanilla' or 'lsgan', got {gan_mode!r}")

        class_distribution = np.asarray(np.load(prior_probs_path), dtype=np.float64)
        # Mix the prior with a uniform distribution, then invert: rare bins get
        # larger (and always positive) weights.
        smoothed = (1.0 - rebalance_factor) * class_distribution + rebalance_factor / num_classes
        class_weights = 1.0 / smoothed
        expected_weight = float(np.sum(class_distribution * class_weights))
        if expected_weight > 0:
            class_weights = class_weights / expected_weight
        self.register_buffer('class_weights', torch.tensor(class_weights, dtype=torch.float))

    def get_labels(self, preds, target_is_real):
        """Return the real/fake label broadcast to the shape of ``preds``."""
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    # Implemented as forward() rather than by overriding __call__, so nn.Module's
    # own __call__ machinery (hooks) keeps working. Calling the instance is unchanged.
    def forward(self, preds, target_is_real, use_class_weights=False):
        """Return the mean loss of ``preds`` against the real/fake target.

        Args:
            preds: discriminator outputs.
            target_is_real: whether the target is the real label.
            use_class_weights: multiply the element-wise loss by class weights
                before averaging.
        """
        labels = self.get_labels(preds, target_is_real)
        losses = self.loss(preds, labels)

        if not use_class_weights:
            return losses.mean()

        # NOTE(review): the weights are indexed by the real/fake label, so with the
        # default labels only entries 0 and 1 of the colour-bin weights are ever
        # selected. Presumably per-pixel colour-bin indices were intended -- confirm.
        weights = self.class_weights[labels.long()]
        return (losses * weights).mean()



In [None]:
# standard loss

class GANLoss(nn.Module):
    """Real/fake GAN loss with constant target labels.

    Args:
        gan_mode: ``'vanilla'`` (BCE with logits) or ``'lsgan'`` (MSE).
        real_label: target value for real samples.
        fake_label: target value for fake samples.

    Raises:
        ValueError: if ``gan_mode`` is not ``'vanilla'`` or ``'lsgan'``.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers follow the module across devices with .to(device).
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Without this, `self.loss` was left unset and the first call failed
            # with an unrelated AttributeError.
            raise ValueError(f"gan_mode must be 'vanilla' or 'lsgan', got {gan_mode!r}")

    def get_labels(self, preds, target_is_real):
        """Return the real/fake label broadcast to the shape of ``preds``."""
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    # Implemented as forward() rather than by overriding __call__, so nn.Module's
    # own __call__ machinery (hooks) keeps working. Calling the instance is unchanged.
    def forward(self, preds, target_is_real):
        """Return the loss of ``preds`` against the real or fake target label."""
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss

We are going to initialize the weights of our model with a mean of 0.0 and standard deviation of 0.02 which are the proposed hyperparameters in the article

In [None]:
def init_weights(net, init='norm', gain=0.02, print_message=True):
    """Initialise the parameters of *net* in place and return it.

    For every sub-module visited by ``net.apply``:
      * weights of modules whose class name contains ``'Conv'`` are drawn with the
        scheme selected by ``init`` (``'norm'``, ``'xavier'`` or ``'kaiming'``);
        any other value leaves them untouched;
      * every non-``None`` bias is set to 0;
      * ``BatchNorm2d`` weights are drawn from N(1, gain) and biases set to 0.

    When ``print_message`` is true, a confirmation line is printed.
    """
    conv_weight_initializers = {
        'norm': lambda weight: nn.init.normal_(weight, mean=0.0, std=gain),
        'xavier': lambda weight: nn.init.xavier_normal_(weight, gain=gain),
        'kaiming': lambda weight: nn.init.kaiming_normal_(weight, a=0, mode='fan_in'),
    }
    init_conv_weight = conv_weight_initializers.get(init)

    def init_func(m):
        layer_name = type(m).__name__
        is_conv_with_weight = hasattr(m, 'weight') and 'Conv' in layer_name

        if is_conv_with_weight and init_conv_weight is not None:
            init_conv_weight(m.weight.data)

        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0.0)

        if 'BatchNorm2d' in layer_name:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    if print_message:
        print(f"model initialized with {init} initialization")
    return net

def init_model(model, device, init_type='norm', gain=0.02):
    """Move *model* to *device*, initialise its weights with ``init_weights``, and return it."""
    placed = model.to(device)
    return init_weights(placed, init=init_type, gain=gain)


##1.6- Attention Layer

The provided code implements a Self-Attention module, which is a mechanism that allows neural networks to focus on different spatial locations within an input. The module consists of query, key, and value convolutional layers, which transform the input feature maps into query, key, and value representations using 1x1 convolutions. These representations are then used to calculate an attention map, which determines the importance of each spatial location. The attended value map is obtained by multiplying the value representation with the transposed attention map. This attended value map is then scaled by a learnable parameter called gamma and added to the original input.

Self-Attention is utilized to capture long-range dependencies and enhance the performance of neural networks. By allowing the network to focus on different spatial locations, it can effectively model relationships between distant elements in the input. This is particularly useful in tasks such as image generation and natural language processing, where capturing global dependencies is essential. The Self-Attention module enables the network to attend to relevant regions and selectively incorporate information from those regions, enhancing the model's ability to extract meaningful features and generate high-quality outputs.

References:

- Show, Attend and Tell: Neural Image Caption Generation with Visual Attention- https://arxiv.org/pdf/1502.03044.pdf
- Attention Is All You Need- https://arxiv.org/pdf/1706.03762.pdf

In [None]:

class SelfAttention(nn.Module):
    """Self-attention over all spatial positions of a feature map.

    Query and key are 1x1 convolutions to ``max(1, in_dim // 8)`` channels; value
    is a 1x1 convolution that keeps ``in_dim`` channels. The attended values are
    scaled by the learnable scalar ``gamma`` (initialised to 0) and added to the
    input.

    Args:
        in_dim: number of input channels.
        activation: stored on the module; not used by ``forward``.
    """

    def __init__(self, in_dim, activation):
        super().__init__()
        self.chanel_in = in_dim
        self.activation = activation

        # Never project down to zero channels.
        out_c = max(1, in_dim // 8)

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=out_c, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=out_c, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, channels, dim_a, dim_b = x.size()
        n_positions = dim_a * dim_b

        # (batch, positions, out_c) x (batch, out_c, positions) -> pairwise scores.
        queries = self.query_conv(x).view(batch, -1, n_positions).transpose(1, 2)
        keys = self.key_conv(x).view(batch, -1, n_positions)
        attention = self.softmax(torch.bmm(queries, keys))

        values = self.value_conv(x).view(batch, -1, n_positions)
        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(batch, channels, dim_a, dim_b)

        return self.gamma * attended + x



##1.7- Assembling the model
The provided `MainModel` class represents the implementation of a complete model for the task of image colorization. The model architecture and training procedures are designed to achieve effective and high-quality colorization results. 

The initialization of the `MainModel` class involves setting up the generator (`net_G`) and discriminator (`net_D`) networks. If a generator is not provided, a default Unet model is initialized. The Unet architecture is well-suited for image-to-image translation tasks and consists of an encoder and decoder with skip connections. The Unet generator takes a grayscale image as input and aims to generate the corresponding colorized image. The discriminator is implemented as a PatchGAN discriminator, which focuses on local image patches to distinguish between real and fake images. 

To facilitate the training process, appropriate learning rates (`lr_G` and `lr_D`), beta coefficients (`beta1` and `beta2`) for the Adam optimizer, and a coefficient for the L1 loss (`lambda_L1`) are defined. The Adam optimizer is commonly used in deep learning models for efficient parameter updates. The L1 loss is employed to ensure that the generated colorization is visually similar to the ground truth colorization.

The `forward` method performs the forward pass of the model. It takes the grayscale input image (`L`) and generates the corresponding colorization (`fake_color`) using the generator network (`net_G`). The generated colorization is further enhanced using a self-attention mechanism, which allows the model to focus on important spatial locations during colorization.

During the training process, the `backward_D` method trains the discriminator by calculating the GAN loss for both real and fake images. The fake images are obtained by concatenating the grayscale input image (`L`) with the generated colorization (`fake_color`). The real images consist of the grayscale input image (`L`) and the ground truth colorization (`ab`). The discriminator is optimized to distinguish between real and fake images, and the gradients are backpropagated to update the discriminator's parameters.

The `backward_G` method trains the generator by calculating the adversarial loss and the L1 loss between the generated colorization and the ground truth colorization. The adversarial loss encourages the generator to produce colorizations that can deceive the discriminator into classifying them as real. The L1 loss measures the pixel-wise difference between the generated colorization and the ground truth, ensuring the generated colorization is visually close to the ground truth. These two losses are combined, with the L1 loss scaled by the `lambda_L1` coefficient, and the gradients are backpropagated to update the generator's parameters.

The `optimize` method orchestrates the overall optimization process by calling the forward pass, training the discriminator (`net_D`), training the generator (`net_G`), and updating the optimizer parameters. This iterative training procedure allows the model to learn colorization patterns and adversarial strategies, leading to improved colorization performance.

The construction of the model and the training procedures aim to leverage the power of deep learning and adversarial training to generate high-quality colorized images from grayscale inputs. The Unet generator, PatchGAN discriminator, self-attention mechanism, and the utilization of GAN and L1 losses collectively contribute to the model's ability to produce visually pleasing and realistic colorizations.

References:
- Isola, P., Zhu, J. Y., Zhou, T., & Efros, A. A. (2017). Image-to-Image Translation with Conditional Adversarial Networks. https://arxiv.org/pdf/1611.07004.pdf

- Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2014). Generative Adversarial Nets. -https://papers.nips.cc/paper_files/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf




In [None]:
class MainModel(nn.Module):
    """Conditional GAN wrapper that trains a colorization generator against a discriminator.

    The generator maps the ``L`` channel to two color channels; its output is
    passed through a self-attention layer and squashed before being scored by
    the discriminator together with ``L``.

    Args:
        net_G: Optional generator module. When ``None`` a ``Unet`` is built
            and initialized with ``init_model``.
        lr_G: Learning rate of the generator optimizer.
        lr_D: Learning rate of the discriminator optimizer.
        beta1: First Adam beta, shared by both optimizers.
        beta2: Second Adam beta, shared by both optimizers.
        lambda_L1: Weight applied to the L1 term of the generator loss.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        # The attention layer runs on the generator output, so it has to live
        # on the same device as net_G; otherwise forward() fails on CUDA.
        self.self_attention = SelfAttention(in_dim=2, activation=nn.ReLU()).to(self.device)
        # The attention layer is part of the generator path: optimize any
        # learnable parameters it has together with net_G, or they would stay
        # at their initial values for the whole training run.
        generator_params = list(self.net_G.parameters()) + list(self.self_attention.parameters())
        self.opt_G = optim.Adam(generator_params, lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """Enable or disable gradient tracking for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Move the ``'L'`` and ``'ab'`` entries of ``data`` to the model device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Compute ``self.fake_color`` from the current ``self.L``."""
        self.fake_color = self.net_G(self.L)
        self.fake_color = self.self_attention(self.fake_color)
        # NOTE(review): this squashes the prediction into [0, 1], while
        # lab_to_rgb rescales ab with a plain `* 110.`, which suggests the
        # targets may also take negative values - confirm the intended range.
        self.fake_color = torch.tanh(self.fake_color) * 0.5 + 0.5

    def backward_D(self):
        """Accumulate discriminator gradients from one fake and one real batch."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach() keeps the discriminator loss from back-propagating into the generator.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Accumulate generator gradients from the adversarial and weighted L1 losses."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        # The generator is rewarded when the discriminator labels its output as real.
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training step: forward pass, discriminator update, generator update."""
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        self.net_G.train()
        # Freeze the discriminator so the generator step does not compute
        # gradients for its weights.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()


The MeasureClass class is a utility class for computing and tracking the average of a value over multiple iterations. It keeps track of the count, sum, and average of the values. The reset method resets the meter, while the update method updates the meter with a new value and count.

The create_loss_meters function creates a dictionary of MeasureClass objects to track different loss values during training. These meters include loss_D_fake, loss_D_real, loss_D, loss_G_GAN, loss_G_L1, and loss_G.

The update_losses function updates the loss meters with the corresponding loss values from the model. It takes the model, the loss meter dictionary, the number of samples in the current batch (used as the weight of the update), a SummaryWriter for TensorBoard logging, and the current step as inputs. It iterates through the loss meter dictionary, retrieves the corresponding loss value from the model, updates the corresponding loss meter, and logs the running average loss value to TensorBoard using the writer.

The lab_to_rgb function converts a batch of L\*a\*b\* images to RGB images. It takes the normalized L channel (L) and ab channels (ab) as inputs, rescales them back to L\*a\*b\* values (L via (L + 1) × 50 and ab via × 110), and converts each image in the batch to RGB.

The visualize function visualizes the colorization results of the model. It sets the generator network (net_G) to evaluation mode and generates colorized images from the provided data. It then sets the generator network back to training mode. The function converts the colorized images and the ground truth color images to RGB format using the lab_to_rgb function and visualizes them using matplotlib. The resulting plot shows the grayscale input images, generated colorized images, and ground truth color images side by side.

The log_results function logs the average loss values to the console and writes them to TensorBoard using the provided writer. It iterates through the loss meter dictionary, retrieves the average loss value from each meter, prints it to the console, and logs it to TensorBoard.

These utility functions play a crucial role in tracking and visualizing the progress of the colorization model during training, allowing for efficient monitoring and analysis of the model's performance.

In [None]:
from torch.utils.tensorboard import SummaryWriter

class MeasureClass:
    """Running weighted average of a scalar.

    Attributes:
        count: Total weight accumulated so far.
        sum: Weighted sum of all values seen so far.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set all three statistics back to ``0.``."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the average."""
        self.sum += count * val
        self.count += count
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return a dict with one fresh ``MeasureClass`` per tracked loss name."""
    meter_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: MeasureClass() for name in meter_names}

def update_losses(model, loss_meter_dict, count, writer, step):
    """Feed the model's current losses into the meters and log the running averages.

    For every key of ``loss_meter_dict`` the attribute of the same name is read
    from ``model``, added to the meter with weight ``count``, and the meter's
    updated average is written to ``writer`` at ``step``.
    """
    for name in loss_meter_dict:
        meter = loss_meter_dict[name]
        current_value = getattr(model, name).item()
        meter.update(current_value, count=count)
        writer.add_scalar(name, meter.avg, step)

def lab_to_rgb(L, ab):
    """
    Convert a batch of normalized L / ab tensors to a stacked array of RGB images.

    L is rescaled with (L + 1) * 50 and ab with ab * 110 before the channels
    are concatenated on dim 1, moved to channel-last order and handed to
    lab2rgb one image at a time.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
    
def visualize(model, data, save=True):
    """Plot grayscale inputs, generated colorizations and ground truth for one batch.

    Runs the model on ``data`` with the generator in eval mode and gradients
    disabled, switches the generator back to train mode, and draws up to five
    samples in a 3 x 5 grid (row 1: L channel, row 2: prediction, row 3: target).

    Args:
        model: Object exposing ``net_G``, ``setup_input``, ``forward`` and, after
            the forward pass, ``L``, ``ab`` and ``fake_color``.
        data: Batch passed unchanged to ``model.setup_input``.
        save: When true, the figure is also written to
            ``colorization_<timestamp>.png`` in the working directory.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    n_cols = 5
    # A batch may hold fewer than five images; never index past its end.
    n_shown = min(n_cols, len(fake_imgs))
    fig = plt.figure(figsize=(15, 8))
    for i in range(n_shown):
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
        
def log_results(loss_meter_dict, step, writer):
    """Print each meter's average with five decimals and log it to ``writer`` at ``step``."""
    for name in loss_meter_dict:
        average = loss_meter_dict[name].avg
        print(f"{name}: {average:.5f}")
        writer.add_scalar(name, average, step)

##1.8- Training function


The provided code defines the train_model function, which is responsible for training the MainModel using the specified dataloader (train_dl) for a specified number of epochs. It also includes additional functionalities for logging and visualization.

Within the function, a validation data sample is obtained from the validation dataloader (val_dl). A SummaryWriter is created to log the training progress to TensorBoard. The variable step is initialized to keep track of the current step in the training process.

The function then enters a loop over the number of epochs. Within each epoch, a new dictionary of loss meters is created using the create_loss_meters function. The variable i is used to keep track of the iteration count.

Next, the function iterates through the training dataloader (train_dl) using the tqdm progress bar to provide a visual indication of the training progress. For each batch of data, the model's input is set up, and the optimize method is called to perform optimization and update the model's parameters. The update_losses function is used to update the loss meters with the corresponding loss values from the model, and the losses are logged using the SummaryWriter. The i and step variables are incremented accordingly.

At regular intervals determined by the display_every parameter, the function prints the current epoch and iteration information, calls the log_results function to log the average loss values to the console and TensorBoard, and invokes the visualize function to visualize the colorization results of the model using the validation data.

After each epoch, the model's state dictionary is saved to a file named with the format "model_epoch{epoch}_fulldataset.pt" using torch.save.

Once all the epochs are completed, the SummaryWriter is closed, concluding the training process.

The last section of the code creates an instance of the MainModel, named model, and calls the train_model function to train the model using the provided dataloader (train_dl) for 100 epochs.

These combined functionalities allow for the training of the MainModel, logging of loss values, visualization of results, and saving of model checkpoints at the end of each epoch.

In [None]:
# Mount Google Drive at /content/drive; google.colab is only importable inside Colab.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
def train_model(model, train_dl, epochs, display_every=200):
    """Train ``model`` on ``train_dl`` and save its weights after every epoch.

    Per batch, the model's losses are pushed into fresh per-epoch meters and
    logged to TensorBoard. Every ``display_every`` iterations the averaged
    losses are printed and a fixed validation batch is visualized.

    Args:
        model: Object exposing ``setup_input``, ``optimize`` and ``state_dict``.
        train_dl: Iterable of batches; each batch is indexed with ``'L'``.
        epochs: Number of passes over ``train_dl``.
        display_every: Iteration interval between console logs/visualizations.

    Note:
        The validation batch is drawn from the module-level ``val_dl``, which
        must exist before this function is called.
    """
    val_data = next(iter(val_dl))
    writer = SummaryWriter()
    step = 0

    # try/finally guarantees the event file is flushed and closed even when
    # training is interrupted or raises part-way through.
    try:
        for epoch in range(1, epochs + 1):
            loss_meter_dict = create_loss_meters()

            for i, data in enumerate(tqdm(train_dl), start=1):
                model.setup_input(data)
                model.optimize()
                update_losses(model, loss_meter_dict, count=data['L'].size(0), writer=writer, step=step)
                step += 1

                if i % display_every == 0:
                    print(f"\nEpoch {epoch}/{epochs}")
                    print(f"Iteration {i}/{len(train_dl)}")
                    log_results(loss_meter_dict, step, writer)
                    visualize(model, val_data, save=False)

            epoch_filename = f"model_epoch{epoch}_fulldataset.pt"
            torch.save(model.state_dict(), epoch_filename)
            print(f"Saved weights of epoch {epoch} to {epoch_filename}")
    finally:
        writer.close()

# Build the model with its default arguments and train it for 100 epochs.
model = MainModel()
train_model(model, train_dl, 100)


#2-Transfer Learning

##2.1- Using a new generator
We'll use the fastai library's Dynamic U-Net module to easily build a U-Net with a ResNet-18 backbone.

In [None]:
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

In [None]:
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a ``DynamicUnet`` on a pretrained ResNet-18 body and move it to the available device.

    Args:
        n_input: Number of input channels for the body.
        n_output: Number of output channels of the U-Net.
        size: Side length; the U-Net is built for ``(size, size)`` images.

    Returns:
        The U-Net on CUDA when available, otherwise on CPU.
    """
    backbone = resnet18(pretrained=True)
    # cut=-2 is forwarded to fastai's create_body together with the channel count.
    encoder = create_body(backbone, pretrained=True, n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return unet.to(target)

##2.2 Pretraining the generator for colorization task

In [None]:
def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    """Train ``net_G`` alone to predict ``ab`` from ``L`` for ``epochs`` epochs.

    Args:
        net_G: Generator called on the ``'L'`` entry of each batch.
        train_dl: Iterable of batches indexed with ``'L'`` and ``'ab'``.
        opt: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(prediction, ab)``.
        epochs: Number of passes over ``train_dl``.

    The batch-size-weighted average loss is printed after every epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for epoch_no in range(epochs):
        running_loss = MeasureClass()
        for batch in tqdm(train_dl):
            gray = batch['L'].to(device)
            target_ab = batch['ab'].to(device)
            batch_loss = criterion(net_G(gray), target_ab)

            opt.zero_grad()
            batch_loss.backward()
            opt.step()

            running_loss.update(batch_loss.item(), gray.size(0))

        print(f"Epoch {epoch_no + 1}/{epochs}")
        print(f"L1 Loss: {running_loss.avg:.5f}")

# Pretrain the ResNet-18 U-Net generator on its own for 5 epochs,
# using Adam (lr=1e-4) and an L1 loss, before any adversarial training.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
opt = optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()        
pretrain_generator(net_G, train_dl, opt, criterion, 5)

In [None]:
# Persist the pretrained generator weights to "res18-unet.pt" in the working directory.
torch.save(net_G.state_dict(), "res18-unet.pt")


##2.3- Final training

In [None]:
# Rebuild the ResNet-18 U-Net and restore the pretrained generator weights.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
# The file name must match the one written after pretraining ("res18-unet.pt", with a hyphen).
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
model = MainModel(net_G=net_G)
# NOTE(review): train_model saves checkpoints as "model_epoch{epoch}_fulldataset.pt";
# confirm that "model_epoch20.pt" exists under this exact name before running.
model.load_state_dict(torch.load("model_epoch20.pt", map_location=device))
# Continue adversarial training for 10 more epochs.
train_model(model, train_dl, 10)

# 3- Test the model

In [None]:
# Rebuild the generator, load its pretrained weights, wrap it in MainModel,
# then overwrite the full model state with the weights in "final_weights.pt".
net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
model = MainModel(net_G=net_G)
model.load_state_dict(torch.load("final_weights.pt", map_location=device))

In [None]:
import numpy as np
import cv2
from skimage import color

def colorize(gray_image_path, model):
    """Colorize the image stored at ``gray_image_path`` with ``model.net_G``.

    The image is read with OpenCV, converted to L*a*b*, and its L channel is
    normalized with the same convention that ``lab_to_rgb`` inverts
    (``L / 50 - 1`` for lightness, ``ab / 110`` for color). The generator's
    prediction is rescaled accordingly and combined with the original
    lightness.

    Args:
        gray_image_path: Path of the image file to colorize.
        model: Object exposing ``eval()``, ``device`` and ``net_G``.

    Returns:
        The colorized image as returned by ``skimage.color.lab2rgb``
        (height x width x 3).

    Raises:
        OSError: If OpenCV cannot read the file.

    NOTE(review): this calls ``net_G`` directly and therefore skips the
    self-attention/tanh step applied in ``MainModel.forward`` - confirm that
    is intended for inference.
    """
    bgr_image = cv2.imread(gray_image_path)
    if bgr_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"cv2.imread could not read image: {gray_image_path!r}")
    # OpenCV loads channels as BGR, while rgb2lab expects RGB.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    lab_image = color.rgb2lab(rgb_image)
    L = lab_image[:, :, 0]

    # Same lightness normalization that lab_to_rgb undoes with (L + 1) * 50.
    L = L / 50. - 1.
    L = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float()

    model.eval()
    with torch.no_grad():
        L = L.to(model.device)
        ab = model.net_G(L).cpu().numpy()

    ab = ab.transpose((0, 2, 3, 1))
    # Same color scale that lab_to_rgb applies (ab * 110).
    ab = ab[0] * 110.
    L = (L.cpu().numpy()[0] + 1.) * 50.
    L = L[0]

    lab_image = np.concatenate((L[:, :, np.newaxis], ab), axis=2)
    colorized_image = color.lab2rgb(lab_image)

    return colorized_image

# Colorize "img.jpg" and display the result.
colorized_image= colorize("img.jpg", model)
plt.imshow(colorized_image)
plt.show()

# Scale to 8-bit, convert RGB -> BGR for OpenCV and write the result to disk.
# NOTE(review): the output path is the same 'img.jpg' that was read above, so
# the input file is overwritten - confirm this is intended.
colorized_image_bgr = cv2.cvtColor((colorized_image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
cv2.imwrite('img.jpg', colorized_image_bgr)