In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from skimage import io, transform, color
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.models as models
from torchmetrics.image import StructuralSimilarityIndexMeasure

# Seed every random number generator this notebook relies on. Weight
# initialisation and DataLoader shuffling draw from PyTorch's generator, so
# seeding NumPy alone does not make a run repeatable.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Directories (relative to the notebook's working directory) that hold the
# image collections. ColorizationDataset lists one of these folders and parses
# the part of each file name before the first '.' as an integer index.
ImagePath_flower = "flower/"        # passed to ColorizationDataset in the training cell
ImagePath_face = "aniF/"
ImagePath_landscape = "landscape/"

In [None]:
# Spatial size every image is resized to before conversion to Lab.
HEIGHT = 256
WIDTH = 256

class ColorizationDataset(Dataset):
    """Dataset of (L, ab) tensor pairs built from a folder of numbered images.

    Files are expected to be named ``<integer>.<ext>``. Only files whose
    integer index is below ``max_index`` are used; directory entries whose
    name does not start with an integer (``.DS_Store``,
    ``.ipynb_checkpoints``, ...) are skipped instead of raising.

    Args:
        path: Directory containing the images.
        max_index: Exclusive upper bound on the numeric file index.

    Each item is a dict with:
        ``'L_res'``: float32 tensor ``(1, HEIGHT, WIDTH)``, L channel / 100.
        ``'ab'``:    float32 tensor ``(2, HEIGHT, WIDTH)``, ab channels / 128.
    """

    def __init__(self, path, max_index=1500):
        self.path = path
        self.max_index = max_index

        indexed_names = []
        for name in os.listdir(path):
            index = self._numeric_index(name)
            if index is not None and index < max_index:
                indexed_names.append((index, name))

        # os.listdir returns entries in arbitrary order; sort by numeric index
        # so that dataset[i] refers to the same file on every run and machine.
        indexed_names.sort()
        self.image_list = [name for _, name in indexed_names]

    @staticmethod
    def _numeric_index(name):
        """Return the integer before the first '.' in ``name``, or None if it is not an integer."""
        try:
            return int(name.split('.')[0])
        except ValueError:
            return None

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_name = self.image_list[idx]

        # Original image
        img_rgb_orig = io.imread(os.path.join(self.path, image_name))

        # rgb2lab requires exactly three colour channels: expand grayscale
        # files to RGB and drop the alpha channel of RGBA files.
        if img_rgb_orig.ndim == 2:
            img_rgb_orig = color.gray2rgb(img_rgb_orig)
        elif img_rgb_orig.shape[-1] == 4:
            img_rgb_orig = img_rgb_orig[..., :3]

        # Resize to the fixed network input size
        img_rgb_res = transform.resize(img_rgb_orig, (HEIGHT, WIDTH), anti_aliasing=True)

        # Convert to Lab color space
        img_lab_res = color.rgb2lab(img_rgb_res)

        # Separate the L and ab channels and normalise them
        l_res = img_lab_res[:, :, 0] / 100.0   # L channel scaled to [0, 1]
        ab_res = img_lab_res[:, :, 1:] / 128.0  # ab channels scaled to roughly [-1, 1]

        # Convert to torch tensors in (C, H, W) layout
        l_res_tensor = torch.tensor(l_res, dtype=torch.float32).unsqueeze(0)
        ab_res_tensor = torch.tensor(ab_res, dtype=torch.float32).permute(2, 0, 1)

        return {'L_res': l_res_tensor, 'ab': ab_res_tensor}

In [None]:
class ColorizationResNet(nn.Module):
    """Encoder-decoder network that predicts the two ab channels from the L channel.

    The encoder reuses the ImageNet-pretrained ResNet-18 stages behind a new
    single-channel stem convolution. The decoder is a stack of five
    stride-2 transposed convolutions (each doubles the spatial size)
    followed by a 1x1 convolution and a Tanh, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Pretrained ResNet-18 supplies everything after the first convolution.
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Fresh stem convolution: same geometry as ResNet's, but 1 input
        # channel (grayscale L) instead of 3.
        stem_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.encoder = nn.Sequential(
            stem_conv,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        )

        # Decoder: 512 -> 256 -> 128 -> 64 -> 32 -> 16 channels, doubling the
        # resolution at every step (8x8 -> 256x256 for a 256x256 input).
        channel_plan = (512, 256, 128, 64, 32, 16)
        decoder_layers = []
        for in_channels, out_channels in zip(channel_plan[:-1], channel_plan[1:]):
            decoder_layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            )
            decoder_layers.append(nn.ReLU())
        decoder_layers.append(nn.Conv2d(16, 2, kernel_size=1))  # 2 output channels (ab)
        decoder_layers.append(nn.Tanh())                        # squash to [-1, 1]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, input_l):
        """Map an L-channel batch through the encoder and then the decoder."""
        features = self.encoder(input_l)
        return self.decoder(features)


In [None]:
# Shared SSIM metric object; the training function derives its SSIM loss
# (1 - SSIM) from it.
# NOTE(review): data_range is 1.0, while the network output (Tanh) and the
# targets (ab / 128) span roughly [-1, 1], i.e. a range of about 2 -- confirm
# that 1.0 is intended.
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0)

def _resolve_device(device):
    """Return a usable torch.device, falling back to CPU when CUDA is requested but unavailable."""
    resolved = torch.device(device)
    if resolved.type == 'cuda' and not torch.cuda.is_available():
        print("CUDA is not available - training on CPU instead.")
        return torch.device('cpu')
    return resolved


def _make_loss_fn(loss_type, device):
    """Map a loss name ('ssim', 'l1' or 'mse') to a callable (outputs, targets) -> scalar tensor."""
    if loss_type == 'mse':
        return nn.MSELoss()
    if loss_type == 'l1':
        return nn.L1Loss()
    if loss_type == 'ssim':
        import copy

        # Work on a private copy so the shared module-level metric is neither
        # moved to another device nor left holding state from training batches.
        metric = copy.deepcopy(ssim_metric).to(device)

        def ssim_loss(outputs, targets):
            # A metric object also accumulates running state on every call;
            # only the per-batch value is needed here, so clear that state
            # first to keep it from building up over the whole training run.
            metric.reset()
            return 1 - metric(outputs, targets)

        return ssim_loss
    raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'ssim', 'l1' or 'mse'.")


def train_colorization_model(model, dataset, epochs, batch_size, learning_rate, device='cuda', loss_type='ssim'):
    """Train ``model`` to predict ab channels from the L channel.

    Args:
        model: Network mapping an L batch to an ab batch.
        dataset: Dataset whose items are dicts with tensors under the keys
            ``'L_res'`` (input) and ``'ab'`` (target).
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size for the shuffled DataLoader.
        learning_rate: Initial Adam learning rate; halved by
            ReduceLROnPlateau after 3 epochs without improvement.
        device: Device to train on. Falls back to CPU when CUDA is requested
            but not available.
        loss_type: ``'ssim'`` (default, ``1 - SSIM``), ``'l1'`` or ``'mse'``.

    Returns:
        ``(model, losses)`` where ``losses`` holds the average loss of each
        epoch. The model is returned on the device it was on when passed in.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """
    device = _resolve_device(device)
    criterion = _make_loss_fn(loss_type, device)

    # Remember where the caller keeps the model so training on another device
    # does not change what the caller gets back.
    first_param = next(model.parameters(), None)
    home_device = first_param.device if first_param is not None else device

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    losses = []

    try:
        for epoch in range(epochs):
            model.train()
            total_loss = 0.0

            for batch in data_loader:
                # Inputs and targets must live on the same device as the model.
                l_res = batch['L_res'].to(device)  # L channel (input)
                ab_res = batch['ab'].to(device)    # ab channels (target)

                optimizer.zero_grad()
                outputs = model(l_res)
                loss = criterion(outputs, ab_res)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(data_loader)
            losses.append(avg_loss)
            print(f"Epoch [{epoch+1}/{epochs}] - Average Loss: {avg_loss:.4f}")
            scheduler.step(avg_loss)
    finally:
        model.to(home_device)

    print("Training complete!")
    return model, losses


In [None]:
# Training hyperparameters.
batch_size, epochs, learning_rate = 128, 60, 0.001

# Build the flower dataset and a fresh network, then train it.
dataset = ColorizationDataset(ImagePath_flower)
ResNet = ColorizationResNet()
model, losses = train_colorization_model(
    ResNet,
    dataset,
    epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
)