# Coloring black and white images using ML algorithms

### Team Members and Contributions

Our project was a collaborative effort, with each team member playing a crucial role in its successful completion. Here’s a breakdown of our contributions:

- **Nukushev Daniyar**  
  Focused on **data preparation**, including collecting and preprocessing images, converting them to grayscale, and normalizing pixel values for the models. Daniyar also ensured the dataset was clean and ready for training.

- **Abishev Rauan**  
  Specialized in **model development**, designing the architecture of the U-Net and GAN models. Rauan worked on fine-tuning hyperparameters and implementing the training pipeline to achieve optimal performance.

- **Momynkul Nurzhan**  
  Handled **evaluation and visualization**. Nurzhan analyzed the model outputs using quantitative metrics and visual quality checks, creating insightful visualizations for comparison between the models.

- **Melissov Nurdaulet**  
  Oversaw **project management and integration**, ensuring seamless collaboration between team members. Nurdaulet also contributed to report writing, presentation design, and deploying the prototype for demonstration.


This notebook explains the steps taken to develop a machine learning model that colorizes black and white images. Our approach uses U-Net and GAN models for this task. The project was completed by a team of four members, each contributing to different aspects such as data preparation, model development, and evaluation. The rest of the notebook assumes basic knowledge of deep learning, GANs, and the PyTorch library. Let's begin!

## Problem Statement 

The primary goal is to restore colors in grayscale images using advanced machine learning techniques. This process is essential for tasks like reviving old memories through photos or adding color to creative works in a more automated and efficient manner.

### Why Lab Color Space?

Images are often represented in the RGB color space, where each pixel has values for the red, green, and blue channels. For this project, however, we chose the Lab color space. In this space:

- The L channel represents lightness (grayscale intensity).
- The a and b channels represent color components (green-red and yellow-blue, respectively).

![rgb image](./files/rgb.jpg) **RGB**

Using Lab simplifies the task because the model takes the L channel (grayscale) as input and only has to predict the two color channels (*a* and *b*). With RGB, the model would have to predict all three channels from the grayscale input.

![lab image](./files/lab.jpg) **L*A*B**
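
To make the channel roles concrete, here is a minimal pure-Python sketch of the standard sRGB to CIE Lab conversion (D65 white point) for a single pixel. The notebook itself relies on `skimage.color.rgb2lab`; the function name `srgb_to_lab` and the example pixels are assumptions made for illustration and are not part of the project code.

```python
# Sketch of sRGB -> CIE Lab for one pixel (D65 white point).
# `srgb_to_lab` is a hypothetical helper for illustration only.

def srgb_to_lab(r, g, b):
    """Convert 8-bit sRGB values (0-255) to (L, a, b)."""
    def to_linear(c):
        # Undo the sRGB gamma curve.
        c = c / 255.0
        return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

    rl, gl, bl = to_linear(r), to_linear(g), to_linear(b)

    # Linear RGB -> XYZ (sRGB primaries, D65).
    x = 0.4124564 * rl + 0.3575761 * gl + 0.1804375 * bl
    y = 0.2126729 * rl + 0.7151522 * gl + 0.0721750 * bl
    z = 0.0193339 * rl + 0.1191920 * gl + 0.9503041 * bl

    def f(t):
        # Cube root with a linear segment near zero, as in the CIE definition.
        delta = 6.0 / 29.0
        return t ** (1.0 / 3.0) if t > delta ** 3 else t / (3 * delta ** 2) + 4.0 / 29.0

    # Normalize by the D65 reference white before applying f.
    fx, fy, fz = f(x / 0.95047), f(y / 1.0), f(z / 1.08883)
    return 116.0 * fy - 16.0, 500.0 * (fx - fy), 200.0 * (fy - fz)

gray = srgb_to_lab(128, 128, 128)
red = srgb_to_lab(255, 0, 0)
print("gray:", [round(v, 2) for v in gray])
print("red: ", [round(v, 2) for v in red])
```

A neutral gray pixel comes out with *a* and *b* close to zero, while a saturated red has a large positive *a*. All of the color information lives in the two channels that the generator has to predict.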

## Solution Approach



Two types of losses were utilized to guide the model's learning:

1. **L1 Loss**  
   - A regression loss function used to minimize the difference between the predicted and actual colors in the image. This helps the model generate colorized images that are closer to the ground truth.

2. **Adversarial Loss (GAN Loss)**  
   - Used in the Generative Adversarial Network (GAN) to ensure the generated images appear realistic. The discriminator network classifies the images as "real" or "fake," and the adversarial loss helps improve the quality of generated images by encouraging the generator to produce more realistic outputs.
   
   
The generator in the GAN architecture predicts the color channels, while the discriminator evaluates their authenticity by comparing them against real images.


### A Deeper Dive into GANs

For this project, I designed a **conditional GAN**, combining it with an additional loss function, the **L1 loss**, to achieve high-quality results. Let’s break it down.

#### GAN Components

In a **Generative Adversarial Network (GAN)**, there are two primary components:

1. **The Generator**  
   The generator model produces data. In this case, it takes a grayscale (1-channel) image and generates a 2-channel output, corresponding to the *a* and *b* channels of the Lab color space.

2. **The Discriminator**  
   The discriminator model evaluates the authenticity of the generator's output. It takes the 2-channel output from the generator, concatenates it with the input grayscale image to form a 3-channel image, and determines whether the image is "real" or "fake." The discriminator is trained using real 3-channel images from the dataset for comparison.

#### Conditioning the GAN

The conditioning in this GAN refers to the grayscale input image, which is provided to both the generator and discriminator. This ensures that both models are conditioned on the same input, allowing the generator to learn context-specific colorization.

#### Mathematical Representation

Let:

- **𝑥** be the grayscale image (condition).
- **𝑧** represent input noise (if applicable) for the generator.
- **𝑦** denote the desired 2-channel output from the generator (*a* and *b* channels of a real image).
- **𝐺** and **𝐷** represent the generator and discriminator models, respectively.

The GAN loss function ensures the generator produces outputs that the discriminator cannot distinguish from real images. The generator’s task is to minimize this loss, while the discriminator tries to maximize it. Additionally, the **L1 loss** helps align the predicted colors with the actual colors, ensuring both realism and accuracy.

#### Incorporating Noise

Instead of feeding random noise vectors (**𝑧**) directly to the generator, I introduced noise through dropout layers within the generator’s architecture. This method injects variability during training, improving robustness and ensuring better generalization in the generated outputs.
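
As a rough illustration of how dropout can play the role of the noise **𝑧**, here is a minimal pure-Python sketch of inverted dropout with p = 0.5, the rate used in the generator code later in this notebook. The `dropout` helper and the toy activations are assumptions for illustration; the actual model uses `nn.Dropout`.

```python
import random

def dropout(values, p=0.5, training=True, rng=random):
    """Inverted dropout: zero each value with probability p, rescale the rest."""
    if not training:
        return list(values)          # identity at inference time
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in values]

activations = [0.2, -1.0, 0.7, 1.5, -0.3, 0.9]
rng = random.Random(0)               # seeded only to make the demo repeatable

pass_1 = dropout(activations, rng=rng)
pass_2 = dropout(activations, rng=rng)
eval_pass = dropout(activations, training=False)

print(pass_1)
print(pass_2)
print(eval_pass)
```

Each training pass draws a fresh random mask, so the same grayscale input can yield different intermediate activations. In evaluation mode the layer is the identity.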


## Loss Function Optimization

To achieve visually appealing and realistic colorizations, I designed a combined loss function for the model. The initial GAN loss ensures the generator produces outputs that the discriminator considers realistic. However, to introduce more supervision and guide the model toward accurate color predictions, I incorporated **L1 loss** (also known as mean absolute error).

### Role of L1 Loss

L1 loss measures the difference between the predicted and actual color channels. While effective, using L1 loss alone leads to overly conservative results—commonly gray or brown tones. This occurs because the model minimizes L1 loss by averaging colors when unsure, resulting in less vibrant outputs. Compared to L2 loss (mean squared error), L1 loss is better at reducing this "grayish" effect and produces more defined colors.
![Alt Text](l1_loss.jpg)
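
The averaging effect can be seen in a toy setting. Suppose a pixel's normalized *a* value is 0.8 in three training examples and -0.4 in two others (made-up numbers). The sketch below searches for the single constant prediction that minimizes each loss. It illustrates the general behavior of the two losses and is not an experiment from this project.

```python
# Toy targets for one pixel's normalized a-channel (hypothetical values).
targets = [0.8, 0.8, 0.8, -0.4, -0.4]

def l1(pred):
    """Mean absolute error of a constant prediction."""
    return sum(abs(pred - t) for t in targets) / len(targets)

def l2(pred):
    """Mean squared error of a constant prediction."""
    return sum((pred - t) ** 2 for t in targets) / len(targets)

# Candidate predictions on a grid over [-1, 1] in steps of 0.01.
candidates = [round(i / 100 - 1, 2) for i in range(201)]

best_l1 = min(candidates, key=l1)   # lands on the median of the targets
best_l2 = min(candidates, key=l2)   # lands on the mean of the targets

print("best constant under L1:", best_l1)
print("best constant under L2:", best_l2)
```

L2 settles on the mean (0.32), which matches neither observed value and is less saturated, while L1 settles on the median (0.8), one of the observed values.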

### Combined Loss Function

The final loss function combines:
1. **Adversarial Loss (GAN Loss)**: Encourages realistic outputs by fooling the discriminator.
2. **L1 Loss**: Ensures the predicted colors closely match the ground truth.

The combined loss function is defined as:

![Alt Text](loss.jpg)


Here, $\lambda$ is a balancing factor that determines the relative importance of the two loss terms. This approach allows the model to generate visually realistic outputs while maintaining high color accuracy.
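
Written out with the symbols defined earlier (**𝑥**, **𝑧**, **𝑦**, **𝐺**, **𝐷**), and assuming the model follows the standard conditional-GAN objective with an added L1 term, the two loss terms and their combination can be expressed as:

$$
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x, y}\big[\log D(x, y)\big] + \mathbb{E}_{x, z}\big[\log\big(1 - D(x, G(x, z))\big)\big]
$$

$$
\mathcal{L}_{L1}(G) = \mathbb{E}_{x, y, z}\big[\lVert y - G(x, z) \rVert_1\big]
$$

$$
G^{*} = \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G, D) + \lambda\, \mathcal{L}_{L1}(G)
$$

The generator minimizes the combined objective while the discriminator maximizes the adversarial term, matching the description in the Mathematical Representation section.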

---

With the theory in place, the next step is implementation. I'll start with the baseline method and then introduce refinements to achieve significantly improved results in just a few hours of training on a smaller dataset.


## 1 - Implementing the project


## 1.1 Loading Image Paths

For this project, we are using only 8,000 images from the COCO dataset that we had available locally. This makes the training set a small fraction of what larger-scale projects use.

You can use almost any dataset for this task, as long as it contains a variety of scenes and locations, which will help the model learn how to colorize images effectively. For example, you could use ImageNet, but for this project, only 8,000 images would be needed.


In [None]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_colab = None

## 1.1.x Preparing Colab for Running the Code

If you are running this on Google Colab, you can uncomment and execute the following code to install the `fastai` library. Most of the code in this project uses pure PyTorch, but we need `fastai` for downloading part of the COCO dataset and for one other step in the second section.

Also, make sure to set your Colab runtime to GPU to speed up the training process.


In [None]:
#!pip install fastai==2.4

The following will download about 20,000 images from the COCO dataset. Note that **we are going to use only 8000 of them** for training. You can also use any other dataset, such as ImageNet, as long as it contains various scenes and locations.

In [None]:
# from fastai.data.external import untar_data, URLs
# coco_path = untar_data(URLs.COCO_SAMPLE)
# coco_path = str(coco_path) + "/train_sample"
# use_colab = True

In [None]:
if use_colab:
    path = coco_path
else:
    path = "Your path to the dataset"
    
paths = glob.glob(path + "/*.jpg") # Grabbing all the image file names
np.random.seed(123)
paths_subset = np.random.choice(paths, 10_000, replace=False) # choosing 10,000 images randomly
rand_idxs = np.random.permutation(10_000)
train_idxs = rand_idxs[:8000] # choosing the first 8000 as training set
val_idxs = rand_idxs[8000:] # choosing last 2000 as validation set
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]
print(len(train_paths), len(val_paths))

In [None]:
_, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, img_path in zip(axes.flatten(), train_paths):
    ax.imshow(Image.open(img_path))
    ax.axis("off")

Although the dataset and the number of training samples are the same, the exact 8,000 images that you train on may differ from ours, even with the seed fixed: the downloadable sample contains only about 20,000 images, possibly in a different order, whereas we sampled our 10,000 images from the complete dataset.
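
One way to make the split independent of the order in which files are listed, which the cell above does not do, is to sort the paths before sampling. The sketch below uses only the standard library; `split_paths`, the fake file names, and the small sizes are assumptions for illustration.

```python
import random

def split_paths(paths, n_total=10, n_train=8, seed=123):
    """Sort, sample n_total paths, and split them into train/val lists."""
    rng = random.Random(seed)
    ordered = sorted(paths)                 # removes dependence on listing order
    subset = rng.sample(ordered, n_total)   # sampling without replacement
    return subset[:n_train], subset[n_train:]

# Hypothetical file names standing in for the COCO images.
paths = [f"img_{i:05d}.jpg" for i in range(20)]
shuffled = paths[:]
random.Random(0).shuffle(shuffled)          # simulate a different listing order

train_a, val_a = split_paths(paths)
train_b, val_b = split_paths(shuffled)

print(len(train_a), len(val_a))
print(train_a == train_b and val_a == val_b)
```

This only removes the dependence on listing order. Two machines still need the same set of files to end up with the same split.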

## 1.2 Making Datasets and DataLoaders

We hope the code is self-explanatory. The steps include resizing the images and applying horizontal flipping (only for the training set). Afterward, we read each image in RGB format, convert it to the Lab color space, and separate the first channel (grayscale) and the color channels. The grayscale channel becomes the input, and the color channels are used as the targets for the model. Finally, we create the data loaders for training and validation.


In [None]:
SIZE = 256
class ColorizationDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE),  Image.BICUBIC),
                transforms.RandomHorizontalFlip(), # A little data augmentation!
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((SIZE, SIZE),  Image.BICUBIC)
        
        self.split = split
        self.size = SIZE
        self.paths = paths
    
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32") # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1. # Between -1 and 1
        ab = img_lab[[1, 2], ...] / 110. # Between -1 and 1
        
        return {'L': L, 'ab': ab}
    
    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs): # A handy function to make our dataloaders
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory)
    return dataloader

In [None]:
train_dl = make_dataloaders(paths=train_paths, split='train')
val_dl = make_dataloaders(paths=val_paths, split='val')

data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
print(Ls.shape, abs_.shape)
print(len(train_dl), len(val_dl))
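
The scaling applied in `__getitem__` (and undone later in `lab_to_rgb`) can be checked in isolation. The sketch below works on plain Python numbers rather than tensors; the helper names and sample values are assumptions for illustration.

```python
def normalize(L, a, b):
    """Map L from [0, 100] and a/b from roughly [-110, 110] to [-1, 1]."""
    return L / 50.0 - 1.0, a / 110.0, b / 110.0

def denormalize(L_n, a_n, b_n):
    """Inverse of normalize, as needed before converting Lab back to RGB."""
    return (L_n + 1.0) * 50.0, a_n * 110.0, b_n * 110.0

# Hypothetical Lab triples covering the extremes and a mid-range value.
samples = [(0.0, 0.0, 0.0), (50.0, -110.0, 110.0), (100.0, 55.0, -27.5)]
for lab in samples:
    scaled = normalize(*lab)
    restored = denormalize(*scaled)
    print(lab, "->", scaled, "->", restored)
```

Keeping the targets in [-1, 1] matches the `Tanh` at the output of the generator defined in the next section.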

## 1.3 Generator

The generator is a U-Net, and the code that builds it deserves a short explanation. It constructs the network from the inside out: it starts with the middle part of the architecture (the "bottleneck" of the U-shape) and, at each iteration, wraps the current module in a down-sampling module on the left and an up-sampling module on the right. This process continues until it reaches the input and output modules.

To give a clearer understanding of what is happening in the code, consider the following illustration:

![U-Net Diagram](unet)

The blue rectangles in the diagram show the order in which the related modules are built with the code. While the U-Net we are implementing has more layers than shown here, this diagram gives a good sense of the structure. Notice that in the code, we go down 8 layers. So, if we start with a 256x256 image, by the time it reaches the middle of the U-Net, the image is reduced to a 1x1 (256 / 2⁸) image, and then it is up-sampled back to a 256x256 image with two color channels. 

This code snippet is quite fascinating, and we highly recommend experimenting with it to fully understand how each part functions.


In [None]:
class UnetBlock(nn.Module):
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None: input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)
        
        if outermost:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout: up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)
    
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

class Unet(nn.Module):
    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
    
    def forward(self, x):
        return self.model(x)
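
The halving described above can be verified with the usual convolution output-size formulas, without building the network. This sketch assumes the same kernel size, stride, and padding as `UnetBlock` (4, 2, 1); `conv_out` and `deconv_out` are helpers introduced here for illustration.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a transposed convolution along one dimension."""
    return (size - 1) * stride - 2 * padding + kernel

down_sizes = [256]
for _ in range(8):                      # n_down=8 down-sampling steps
    down_sizes.append(conv_out(down_sizes[-1]))

up_sizes = [down_sizes[-1]]
for _ in range(8):                      # mirrored up-sampling steps
    up_sizes.append(deconv_out(up_sizes[-1]))

print(down_sizes)
print(up_sizes)
```

Each down-sampling step halves the spatial size until a 1x1 map remains at the bottleneck, and each up-sampling step doubles it back to 256x256.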

## 1.4 Discriminator

The architecture of our discriminator is quite straightforward. The code implements a model by stacking blocks of Conv-BatchNorm-LeakyReLU to determine whether the input image is real or fake. Note that the first and last blocks do not use normalization, and the last block has no activation function, because the sigmoid is embedded in the loss function that we will use (`nn.BCEWithLogitsLoss`).


In [None]:
class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down-1) else 2) 
                          for i in range(n_down)] # the 'if' statement is taking care of not using
                                                  # stride of 2 for the last block in this loop
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)] # Make sure to not use normalization or
                                                                                             # activation for the last layer of the model
        self.model = nn.Sequential(*model)                                                   
        
    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True): # when needing to make some repetitive blocks of layers,
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]          # it's always helpful to make a separate method for that purpose
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

Let's take a look at its blocks:

In [None]:
PatchDiscriminator(3)

And its output shape:

In [None]:
discriminator = PatchDiscriminator(3)
dummy_input = torch.randn(16, 3, 256, 256) # batch_size, channels, size, size
out = discriminator(dummy_input)
out.shape

We are using a "Patch" Discriminator here. But what does that mean? In a vanilla discriminator, the model outputs a single number (a scalar) that indicates how real or fake the entire input image is. However, in a patch discriminator, the model outputs one number for every patch of the image (e.g., 70x70 pixels) and decides whether each patch is real or fake individually.

Using a patch discriminator for the task of colorization seems reasonable because the model needs to make local changes, and considering the entire image at once might overlook important subtle details. Instead, by evaluating smaller patches, the model can focus on more localized information, which is crucial for accurate colorization.

In this setup, the output shape of the model is 30x30, but this doesn’t mean the patches are 30x30. The actual patch size is determined by calculating the receptive field of each of the 900 (30x30) output values, which, in this case, corresponds to a patch size of 70x70 pixels.
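
Both numbers can be reproduced with a few lines of arithmetic. The sketch below lists the (kernel, stride, padding) of the five convolutions that `PatchDiscriminator` builds with its default `n_down=3`, then computes the output size and the receptive field. The helper code is illustrative and not part of the project.

```python
# (kernel, stride, padding) of each conv layer for n_down=3:
# three stride-2 layers followed by two stride-1 layers.
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

# Forward over the layers: how large is the output map for a 256x256 input?
size = 256
for k, s, p in layers:
    size = (size + 2 * p - k) // s + 1

# Backward over the layers: how many input pixels does one output value see?
receptive_field = 1
for k, s, _ in reversed(layers):
    receptive_field = (receptive_field - 1) * s + k

print("output size:", size)
print("receptive field:", receptive_field)
```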


## 1.5 GAN Loss

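The `GANLoss` class is used to calculate the GAN loss for the model. During initialization, the type of loss to be used (e.g., "vanilla" for Binary Cross-Entropy loss) is specified, and constant tensors for the "real" and "fake" labels are registered as buffers. When the module is called, it expands the appropriate label to the shape of the predictions and computes the loss.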


In [None]:
class GANLoss(nn.Module):
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
    
    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)
    
    def forward(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss
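
For the "vanilla" mode, the computation amounts to binary cross-entropy on raw logits, averaged over every patch prediction. The following pure-Python sketch mirrors that arithmetic on a tiny made-up grid of logits. It is a simplified stand-in for `nn.BCEWithLogitsLoss`, not a replacement for it, and the function names are assumptions.

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def gan_loss(logits, target_is_real):
    """Average BCE over all patch logits against a constant label."""
    label = 1.0 if target_is_real else 0.0
    flat = [x for row in logits for x in row]
    return sum(bce_with_logits(x, label) for x in flat) / len(flat)

# Hypothetical 2x2 patch map: the discriminator is fairly sure these are fake.
fake_logits = [[-2.0, -1.5], [-3.0, -0.5]]

loss_d_fake = gan_loss(fake_logits, target_is_real=False)  # discriminator's view
loss_g_gan = gan_loss(fake_logits, target_is_real=True)    # generator's view

print(round(loss_d_fake, 4), round(loss_g_gan, 4))
```

The same logits give a small loss when labeled "fake" and a large one when labeled "real", which is the tension between the two networks.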

### 1.x Model Initialization

Model initialization is a critical step in ensuring stable and effective training. With the default `init='norm'` option, convolution weights are drawn from a normal distribution with a mean of 0.0 and a standard deviation of 0.02, batch-norm weights from a normal distribution with a mean of 1.0 and the same standard deviation, and all biases are set to zero:

In [None]:
def init_weights(net, init='norm', gain=0.02):
    
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)
            
    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model

## Putting Everything Together

This class integrates all the components and implements methods to handle the training of the complete model.

### Key Steps:

1. **Initialization**:  
   The generator and discriminator are built from the `Unet` and `PatchDiscriminator` classes defined above and initialized using the `init_model` function. Additionally, two loss functions (adversarial loss and L1 loss) and optimizers for both the generator and discriminator are defined.

2. **Optimize Method**:  
   The training process is managed within the `optimize` method:
   - **Forward Pass**: The generator produces the fake colorization output, which is stored in the `fake_color` variable.
   - **Training the Discriminator**:  
     - Fake images from the generator are passed to the discriminator, detached from the generator's computation graph, and labeled as "fake."
     - Real images from the training set are labeled as "real."
     - Both losses (for real and fake inputs) are computed, averaged, and used to update the discriminator weights.
   - **Training the Generator**:  
     - Fake images are passed to the discriminator, and the generator tries to fool the discriminator by assigning them "real" labels.
     - The adversarial loss is combined with the L1 loss, which calculates the difference between the predicted and true color channels. The L1 loss is scaled by a coefficient (100 in this case) to balance the two losses.
     - The combined loss is used to update the generator weights.
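
To get a feel for the weighting in the generator step, the sketch below combines a made-up adversarial loss value with a mean absolute error scaled by 100, the default `lambda_L1`. All numbers are hypothetical and only illustrate the relative magnitudes.

```python
def l1_loss(pred, target):
    """Mean absolute error between two equally sized lists of values."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)

lambda_l1 = 100.0

# Hypothetical normalized ab values for four pixels.
fake_ab = [0.10, -0.20, 0.35, 0.00]
real_ab = [0.15, -0.10, 0.30, 0.10]

loss_g_gan = 0.69                         # made-up adversarial loss value
loss_g_l1 = l1_loss(fake_ab, real_ab) * lambda_l1
loss_g = loss_g_gan + loss_g_l1

print(f"GAN term: {loss_g_gan:.2f}, L1 term: {loss_g_l1:.2f}, total: {loss_g:.2f}")
```

Even a small per-pixel error contributes far more than the adversarial term at this setting, so the L1 term makes up most of the generator's total loss.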


In [None]:
class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4, 
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        
        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
    
    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad
        
    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)
        
    def forward(self):
        self.fake_color = self.net_G(self.L)
    
    def backward_D(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
    
    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
    
    def optimize(self):
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

### 1.xx Utility functions

These functions were not included in the explanations of the TDS article. They are utility functions that log the losses of our network and visualize the results during training:

In [None]:
class AverageMeter:
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3
    
    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()
    
    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)

def lab_to_rgb(L, ab):
    """
    Takes a batch of images
    """
    
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)
    
def visualize(model, data, save=True):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    fig = plt.figure(figsize=(15, 8))
    for i in range(5):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
        
def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.5f}")

## 1.7 Training function

We believe this code is self-explanatory. Each epoch takes approximately 4 minutes on a moderately powerful GPU, such as the Nvidia P5000. If a more powerful GPU, like the 1080Ti or higher, is used, the training should be faster.


In [None]:
def train_model(model, train_dl, epochs, display_every=200):
    vis_data = next(iter(val_dl)) # fixed validation batch (val_dl is defined above) for visualizing the model output at fixed intervals
    for e in range(epochs):
        loss_meter_dict = create_loss_meters() # function returning a dictionary of objects to
        i = 0                                  # log the losses of the complete network
        for data in tqdm(train_dl):
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0)) # function updating the log objects
            i += 1
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict) # function to print out the losses
                visualize(model, vis_data, save=False) # display outputs on the validation batch, not the current training batch

model = MainModel()
train_model(model, train_dl, 100)

Each epoch takes approximately 3 to 4 minutes on Colab. After around 20 epochs, some reasonable results should begin to appear.

To further evaluate the model, it was trained for a longer duration (about 100 epochs). Below are the results of the baseline model:

**Baseline Results:**

The baseline model demonstrates a basic understanding of common objects in images, such as the sky and trees. However, its output remains far from satisfactory:
- It struggles to decide on the color of less common or complex objects.
- The outputs exhibit noticeable color spillovers and circular color artifacts (e.g., the center of the first image in the second row), which detract from the overall quality.

These limitations suggest that with this small dataset, the current strategy is insufficient for achieving high-quality results. As a result, a different approach is required to improve performance.


## Final Result

Below are the results of the trained model, showcasing the **before** (grayscale input) and **after** (colorized output) images. The model has successfully added color to the grayscale images, demonstrating its ability to learn and apply colorization effectively.

### Before and After

#### Before:
<img src="sample.jpg" alt="Grayscale Image 1" width="256" height="176"/>  
<img src="sample_1.jpg" alt="Grayscale Image 2" width="256" height="176"/>  
<img src="sample_2.jpg" alt="Grayscale Image 3" width="256" height="176"/>

#### After:
<img src="colorized_sample.png" alt="Colorized Image 1" width="256" height="176"/>  
<img src="colorized_sample_1.png" alt="Colorized Image 2" width="256" height="176"/>  
<img src="colorized_sample_2.png" alt="Colorized Image 3" width="256" height="176"/>




---

These results highlight the model's potential for enhancing grayscale images by adding realistic colors. While some improvements could be made for rare or complex objects, the overall performance is promising.
