## 0- Setup and Imports

In [None]:
# Mount Google Drive under /content/drive; the checkpoint paths used later in
# this notebook live below that mount point.
from google.colab import drive
drive.mount('/content/drive')

import os
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from fastai.data.external import untar_data, URLs
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda


## 1- Load the pretrained GAN model (with ResUnet generator)

###### Define the classes needed for the GAN model (cross-file imports are awkward in Google Colab, so the definitions are copied here)

In [None]:
def initialize_weights(model, init_type='normal', gain=0.02):
    """
    Initialize, in place, the weights of every Conv*/Linear/BatchNorm* layer of a model.

    Layers are recognised by substring match on their class name, so e.g.
    ``Conv2d``, ``ConvTranspose2d``, ``Linear`` and ``BatchNorm2d`` all qualify.

    Parameters:
        model (nn.Module): The PyTorch model to initialize.
        init_type (str): Initialization type ('normal', 'xavier', 'kaiming', 'orthogonal').
        gain (float): Std-dev for 'normal' and for the BatchNorm scale; gain for
            'xavier' and 'orthogonal'. Not used by 'kaiming'.

    Returns:
        nn.Module: The same model object, with initialized weights.

    Raises:
        ValueError: If `init_type` is not supported (raised when the first
            Conv/Linear layer is visited).
    """
    def init_func(m):
        classname = m.__class__.__name__
        # Affine-free norm layers and bias-free convs expose these attributes as None.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)

        # Initialize convolutional and linear layers
        if weight is not None and any(layer in classname for layer in ['Conv', 'Linear']):
            if init_type == 'normal':
                nn.init.normal_(weight.data, mean=0.0, std=gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(weight.data, gain=gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(weight.data, a=0, mode='fan_in', nonlinearity='leaky_relu')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(weight.data, gain=gain)
            else:
                raise ValueError(f"Unsupported initialization type: {init_type}")

            # Initialize biases to zero if they exist
            if bias is not None:
                nn.init.constant_(bias.data, 0.0)

        # Initialize BatchNorm layers (scale ~ N(1, gain), shift = 0).
        # BatchNorm built with affine=False has no parameters, so guard against None.
        elif 'BatchNorm' in classname:
            if weight is not None:
                nn.init.normal_(weight.data, mean=1.0, std=gain)
            if bias is not None:
                nn.init.constant_(bias.data, 0.0)

    print(f"Initializing model weights with {init_type} initialization")
    model.apply(init_func)
    return model

def initialize_model(model, device):
    """Move `model` to `device`, then apply the default weight initialization to it."""
    return initialize_weights(model.to(device))

class GANModel(nn.Module):
    """
    Conditional GAN for colorization.

    The generator maps the L channel to two ab channels; the discriminator
    judges the 3-channel concatenation (L, ab). The generator objective is the
    adversarial loss plus an L1 term weighted by `lambda_L1`.
    """

    def __init__(self, net_G, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        # Everything lives on the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Relative weight of the pixel-wise L1 term in the generator loss.
        self.lambda_L1 = lambda_L1

        # Generator is supplied by the caller; discriminator is built and initialized here.
        self.net_G = net_G.to(self.device)
        discriminator = PatchDiscriminator(input_channels=3, num_downsampling=3, num_filters=64)
        self.net_D = initialize_model(discriminator, self.device)

        # Adversarial criterion (BCE on logits) and pixel-wise criterion.
        self.GANcriterion = GANLoss().to(self.device)
        self.L1criterion = nn.L1Loss()

        # One Adam optimizer per network, sharing the same betas.
        adam_betas = (beta1, beta2)
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=adam_betas)
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=adam_betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Freeze (False) or unfreeze (True) every parameter of `model`."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def setup_input(self, data):
        """Move the 'L' and 'ab' entries of a batch dict to the model device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Run the generator on the current L input and store the result in `fake_color`."""
        self.fake_color = self.net_G(self.L)

    def _with_lightness(self, color):
        """Concatenate the current L channel with `color` along the channel axis."""
        return torch.cat([self.L, color], dim=1)

    def backward_D(self):
        """
        Discriminator step: fake pairs should be scored as fake, real pairs as real.
        Stores `loss_D_fake`, `loss_D_real`, `loss_D` and back-propagates `loss_D`.
        """
        # Fake pair first; detached so no gradient reaches the generator.
        generated_pair = self._with_lightness(self.fake_color)
        scores_fake = self.net_D(generated_pair.detach())
        self.loss_D_fake = self.GANcriterion(scores_fake, False)

        # Then the real pair.
        scores_real = self.net_D(self._with_lightness(self.ab))
        self.loss_D_real = self.GANcriterion(scores_real, True)

        # Average of the two terms.
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """
        Generator step: fool the discriminator and stay close to the ground truth.
        Stores `loss_G_GAN`, `loss_G_L1`, `loss_G` and back-propagates `loss_G`.
        """
        # Adversarial term: the discriminator should call the fake pair real.
        scores_fake = self.net_D(self._with_lightness(self.fake_color))
        self.loss_G_GAN = self.GANcriterion(scores_fake, True)

        # Pixel-wise term, scaled by lambda_L1.
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training iteration: generator forward, D update, then G update."""
        self.forward()

        # --- discriminator update ---
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # --- generator update (discriminator frozen so it only provides the signal) ---
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()


class GANLoss(nn.Module):
    """
    Binary real/fake adversarial loss on raw discriminator logits.

    The target is a constant map (all `real_label` or all `fake_label`) with
    the same shape as the predictions; the loss is `nn.BCEWithLogitsLoss`.
    The label values are registered as buffers so they follow `.to(device)`.
    """
    def __init__(self, real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))

        self.loss = nn.BCEWithLogitsLoss()

    def get_labels(self, preds, target_is_real):
        """Return a tensor shaped like `preds`, filled with the real or the fake label."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)  # broadcast view, no copy

    def forward(self, preds, target_is_real):
        """
        Compute the loss of `preds` (logits) against an all-real or all-fake target.

        Implemented as `forward` rather than by overriding `__call__`, so the
        standard `nn.Module` call machinery (e.g. hooks) keeps working;
        `criterion(preds, flag)` is called exactly as before.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)

class PatchDiscriminator(nn.Module):
    """
    PatchGAN-like discriminator with a configurable number of layers and filters.
    Takes an input with `input_channels` channels and outputs a single-channel
    map of raw logits, one real/fake score per patch. No sigmoid is applied
    here; the loss is expected to work on logits.
    """
    def __init__(self, input_channels, num_filters=64, num_downsampling=3):
        """
        Args:
            input_channels (int): Number of input channels (e.g., 3 for RGB or 1 for grayscale).
            num_filters (int): Number of filters for the first convolutional layer.
            num_downsampling (int): Number of conv blocks after the first one. All but
                the last of them use stride 2; the last uses stride 1. May be 0.
        """
        super().__init__()

        # Initial layer: no normalization for the input layer
        layers = [
            self._conv_block(input_channels, num_filters, normalize=False)  # input image
        ]

        # Channel count feeding the final layer. Set up-front so the final layer
        # is still well-defined when num_downsampling == 0 (empty loop below).
        out_channels = num_filters

        # Downsampling layers: the channel count doubles at every step
        for i in range(num_downsampling):
            in_channels = num_filters * 2**i
            out_channels = num_filters * 2**(i + 1)

            # Use stride=1 for the last downsampling layer
            stride = 1 if i == (num_downsampling - 1) else 2
            layers.append(self._conv_block(in_channels, out_channels, stride=stride))

        # Final layer: single-channel logit map (no norm, no activation: the
        # sigmoid is folded into the loss)
        layers.append(
            self._conv_block(out_channels, 1, normalize=False, activation=False)
        )

        # Combine all layers into a sequential model
        self.model = nn.Sequential(*layers)

    def _conv_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, normalize=True, activation=True):
        """
        Creates a single convolutional block with optional normalization and activation.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Convolution kernel size (default: 4).
            stride (int): Stride for the convolution (default: 2).
            padding (int): Padding for the convolution (default: 1).
            normalize (bool): Whether to use BatchNorm (default: True).
            activation (bool): Whether to use LeakyReLU activation (default: True).

        Returns:
            nn.Sequential: A sequential block containing the specified layers.
        """
        # The conv bias is dropped when BatchNorm follows, since BatchNorm has its own shift.
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not normalize)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        if activation:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass for the discriminator.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_channels, height, width).

        Returns:
            torch.Tensor: Single-channel map of per-patch logits.
        """
        return self.model(x)

###### Load the pretrained GAN

In [None]:
def load_entire_pretrained_gan(path, device='cuda'):
    """
    Load a whole pickled model object (not just a state_dict) from a .pt file.

    Args:
        path (str): Location of the checkpoint written with `torch.save(model, path)`.
        device (str | torch.device): Passed to `torch.load` as `map_location`.

    Returns:
        The unpickled object stored in the file.

    Raises:
        FileNotFoundError: If `path` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Pretrained model not found at {path}")
    # The file stores a full pickled model, which `weights_only=True` refuses to
    # load, so ask for full unpickling explicitly instead of relying on the
    # library default.
    # NOTE(security): full unpickling can execute arbitrary code. Only load
    # checkpoints you created yourself or otherwise trust.
    model = torch.load(path, map_location=device, weights_only=False)
    print(f"Loaded pretrained GAN from {path}")
    return model

# Path to the previously trained GAN, saved as a whole pickled model object.
PRETRAINED_MODEL_PATH = "/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/GAN_w_ResUnet/gan_colorizer.pt"
# Unpickling needs the GANModel / GANLoss / PatchDiscriminator definitions from
# the cell above to be present in this session.
pretrained_gan = load_entire_pretrained_gan(PRETRAINED_MODEL_PATH, device=device)
# Switch to inference mode; the training loop puts it back into train mode.
pretrained_gan.eval()

# The bare string below is the last expression of the cell, so the notebook
# echoes it as the cell output. It documents the plan for the next cells.
"""
  The loaded object 'pretrained_gan' has:
  pretrained_gan.net_G  (the generator)
  pretrained_gan.net_D  (the discriminator)
  .opt_G, .opt_D, etc.

But the old net_G was for a single channel in (L),
and the old net_D was for 3 channels in (L + ab).
We want to adapt them to user hints:
  - G: 4 channels in (L, ab_hints, hint_mask)
  - D: 6 channels in (L, ab, ab_hints, hint_mask)
We'll copy the pretrained weights for the original channels,
and initialize the new ones with e.g. Xavier normal.
"""

  model = torch.load(path, map_location=device)


Loaded pretrained GAN from /content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/GAN_w_ResUnet/gan_colorizer.pt


"\n  The loaded object 'pretrained_gan' has:\n  pretrained_gan.net_G  (the generator)\n  pretrained_gan.net_D  (the discriminator)\n  .opt_G, .opt_D, etc.\n\nBut the old net_G was for a single channel in (L),\nand the old net_D was for 3 channels in (L + ab).\nWe want to adapt them to user hints:\n  - G: 4 channels in (L, ab_hints, hint_mask)\n  - D: 6 channels in (L, ab, ab_hints, hint_mask)\nWe'll copy the pretrained weights for the original channels,\nand initialize the new ones with e.g. Xavier normal.\n"

## Modify pretrained net_G and net_D

In [None]:
class HintAttentionBlock(nn.Module):
    """
    Spatial attention gate driven by the user-hint mask.

    Given a feature map F of shape (N, C, H, W) and a hint mask of shape
    (N, 1, H, W), a two-layer CNN (3x3 conv -> ReLU -> 1x1 conv -> sigmoid)
    looks at the concatenation of both and produces one attention value in
    [0, 1] per pixel. The output is F scaled pixel-wise by that attention map.
    """
    def __init__(self, in_channels=3, hidden_channels=64):
        super().__init__()
        # First conv sees the features plus the one-channel mask.
        self.conv1 = nn.Conv2d(in_channels + 1, hidden_channels, kernel_size=3, padding=1)
        # Second conv collapses the hidden features to a single attention channel.
        self.conv2 = nn.Conv2d(hidden_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, F, mask):
        """
        Args:
            F:    (N, C, H, W) feature map to be gated.
            mask: (N, 1, H, W) user hint mask.

        Returns:
            (N, C, H, W) tensor: `F` multiplied by the (N, 1, H, W) attention map,
            broadcast over the channel axis.
        """
        stacked = torch.cat((F, mask), dim=1)        # (N, C+1, H, W)
        hidden = torch.relu(self.conv1(stacked))
        attention = self.sigmoid(self.conv2(hidden))  # (N, 1, H, W), values in [0, 1]
        return F * attention


class UserHintAttnGenerator(nn.Module):
    """
    Wrapper around a 4-channel-input generator (`base_unet`) that gates the
    input with a `HintAttentionBlock` before the U-Net sees it.

    Args:
        base_unet (nn.Module): Network taking a (N, 4, H, W) tensor.
        attn_channels (int): Hidden width of the attention block (default: 64).
    """
    def __init__(self, base_unet, attn_channels=64):
        super().__init__()
        self.base_unet = base_unet
        # Attention is applied at the input: the 3 gated channels are L and the
        # ab hints. `attn_channels` is forwarded here; previously it was accepted
        # but ignored in favour of a hard-coded 64.
        self.attn_block = HintAttentionBlock(in_channels=3, hidden_channels=attn_channels)

    def forward(self, L, ab_hints, hint_mask):
        """
        Args:
            L:         (N, 1, H, W) lightness channel.
            ab_hints:  (N, 2, H, W) ab values at hinted pixels.
            hint_mask: (N, 1, H, W) hint mask.

        Returns:
            Output of `base_unet` on the (N, 4, H, W) tensor made of the three
            attention-gated channels followed by the unmodified mask.
        """
        x = torch.cat([L, ab_hints, hint_mask], dim=1)  # (N, 4, H, W)

        # Channels 0..2 are gated; channel 3 drives the attention.
        features = x[:, :3, :, :]
        M = x[:, 3:4, :, :]

        F_att = self.attn_block(features, M)  # (N, 3, H, W)

        # Re-attach the raw mask so the U-Net still receives 4 channels.
        x_att = torch.cat([F_att, M], dim=1)
        return self.base_unet(x_att)

In [None]:
def build_res_unet_4_in(n_input=4, n_output=2, size=256):
    """
    Build a DynamicUnet on a pretrained ResNet-18 body.

    Args:
        n_input (int): Number of input channels requested from `create_body` via `n_in`.
        n_output (int): Number of output channels of the U-Net.
        size (int): Side length; the U-Net is built for (size, size) inputs.

    Returns:
        The constructed DynamicUnet.
    """
    # create_body adapts the first conv to `n_input` channels and cuts off the last two layers.
    encoder = create_body(resnet18(pretrained=True), n_in=n_input, cut=-2)
    return DynamicUnet(encoder, n_output, (size, size))

def expand_discriminator_input(net_D, old_in_channels=3, new_in_channels=6):
    """
    Widen, in place, the first Conv2d of `net_D` that takes `old_in_channels`
    inputs so that it takes `new_in_channels` inputs instead.

    With the defaults this turns a 3-channel discriminator (L + ab) into a
    6-channel one: L(1) + ab(2) + ab_hints(2) + hint_mask(1).

    Steps:
      1) find the first Conv2d with in_channels == old_in_channels
      2) build a replacement with in_channels == new_in_channels on the same
         device / dtype and with the same hyper-parameters
      3) copy the old weights (and bias) into the first `old_in_channels` slices
      4) Xavier-normal init the extra input channels
      5) swap the new layer into `net_D`

    Args:
        net_D (nn.Module): Discriminator to modify.
        old_in_channels (int): Input width of the conv to replace (default: 3).
        new_in_channels (int): Input width of the replacement (default: 6).

    Raises:
        ValueError: If `new_in_channels` is not larger than `old_in_channels`.
        RuntimeError: If no matching Conv2d exists, or if `net_D` itself is the
            matching conv (it then has no parent to be replaced in).
    """
    if new_in_channels <= old_in_channels:
        raise ValueError(
            f"new_in_channels ({new_in_channels}) must be larger than "
            f"old_in_channels ({old_in_channels})"
        )

    first_conv = next(
        ((name, module) for name, module in net_D.named_modules()
         if isinstance(module, nn.Conv2d) and module.in_channels == old_in_channels),
        None,
    )
    if first_conv is None:
        raise RuntimeError(
            f"Could not find a Conv2d with in_channels={old_in_channels} in net_D. Adapt code if needed."
        )

    name, old_conv = first_conv
    if not name:
        # named_modules() reports the root under the empty name; a root cannot replace itself.
        raise RuntimeError("net_D itself is the Conv2d to widen; wrap it in a container module first.")

    has_bias = old_conv.bias is not None
    new_conv = nn.Conv2d(
        new_in_channels,
        old_conv.out_channels,
        old_conv.kernel_size,
        old_conv.stride,
        old_conv.padding,
        old_conv.dilation,
        bias=has_bias,
        padding_mode=old_conv.padding_mode,
    ).to(device=old_conv.weight.device, dtype=old_conv.weight.dtype)
    # ^ keep the replacement on the same device/dtype as the layer it replaces,
    #   so the network is immediately usable without a later .to(...)

    with torch.no_grad():
        # pretrained filters go to the original input channels
        new_conv.weight[:, :old_in_channels].copy_(old_conv.weight)
        # random init for the extra channels
        nn.init.xavier_normal_(new_conv.weight[:, old_in_channels:], gain=1.0)
        if has_bias:
            new_conv.bias.copy_(old_conv.bias)

    _replace_module_by_name(net_D, name, new_conv)


def _replace_module_by_name(root_module, module_name, new_module):
    """
    Replace the sub-module found at the dotted path `module_name` (as reported
    by `named_modules()`, e.g. 'model.0.0') inside `root_module`.

    Raises:
        ValueError: If `module_name` is empty (that would denote the root itself).
    """
    if not module_name:
        raise ValueError("module_name must be a non-empty dotted path")
    components = module_name.split(".")
    # Walk down to the parent of the target, then rebind the last attribute.
    obj = root_module
    for comp in components[:-1]:
        obj = getattr(obj, comp)
    setattr(obj, components[-1], new_module)

In [None]:
# 1) Build a fresh 4-in unet and transfer the old net_G weights into it
old_net_G = pretrained_gan.net_G
old_sd    = old_net_G.state_dict()

model_new = build_res_unet_4_in(n_input=4, n_output=2, size=256).to(device)
new_sd    = model_new.state_dict()

# The old generator took 1 input channel (L); the new one takes 4
# (L, ab_hints x2, hint_mask), with L still at channel index 0.
OLD_G_IN_CHANNELS, NEW_G_IN_CHANNELS = 1, 4

matched_weights = {}
widened_keys = []
for k, v in old_sd.items():
    if k not in new_sd:
        continue
    target = new_sd[k]
    if v.shape == target.shape:
        # identical layer: take the pretrained tensor as-is
        matched_weights[k] = v
    elif (
        v.ndim == 4 and target.ndim == 4
        and v.shape[1] == OLD_G_IN_CHANNELS and target.shape[1] == NEW_G_IN_CHANNELS
        and v.shape[0] == target.shape[0] and v.shape[2:] == target.shape[2:]
    ):
        # Input conv: only the in-channel count differs. A plain shape filter
        # would drop it and lose the pretrained L-channel filters, so copy them
        # into channel 0 and keep the fresh initialization for the hint channels.
        widened = target.clone()
        widened[:, :OLD_G_IN_CHANNELS] = v.to(widened)
        matched_weights[k] = widened
        widened_keys.append(k)
new_sd.update(matched_weights)
model_new.load_state_dict(new_sd)
print(f"Transferred {len(matched_weights)}/{len(new_sd)} generator tensors "
      f"({len(widened_keys)} widened on the input-channel axis).")

# 2) wrap it in our attention generator
attn_gen = UserHintAttnGenerator(model_new, attn_channels=64).to(device)

# 3) set pretrained_gan.net_G to this new attention generator
pretrained_gan.net_G = attn_gen



In [None]:
# expand net_D from 3->6 in-ch
# expand_discriminator_input swaps the first conv inside the module it is given,
# so old_net_D is the widened discriminator after the call.
old_net_D = pretrained_gan.net_D
expand_discriminator_input(old_net_D)
# Re-assigning the same object; kept to make the hand-over explicit.
pretrained_gan.net_D = old_net_D

print("Modified the pretrained generator to 4 input channels.")
print("Modified the pretrained discriminator to 6 input channels.")

Modified the pretrained generator to 4 input channels.
Modified the pretrained discriminator to 6 input channels.


## Helper Functions

In [None]:
def lab_to_rgb(L, ab):
    """
    Convert normalized LAB tensors to RGB arrays.

    L in [-1,1] is mapped to [0..100]; ab in [-1,1] is mapped to [-110..110].

    Args:
        L:  (N,1,H,W) tensor, or (1,H,W) for a single image.
        ab: (N,2,H,W) tensor, or (2,H,W) for a single image.

    Returns:
        np.ndarray of shape (N,H,W,3), or (H,W,3) for a single image, holding
        the `lab2rgb` output of every image. An empty batch gives an empty
        (0,H,W,3) array.
    """
    single_image = (L.ndim == 3)  # (1,H,W), (2,H,W)
    if single_image:
        L = L.unsqueeze(0)
        ab = ab.unsqueeze(0)

    # undo the normalization
    L_den = (L + 1.) * 50.0
    ab_den = ab[:, :2, :, :] * 110.0

    Lab = torch.cat([L_den, ab_den], dim=1)  # (N,3,H,W)
    # detach(): .numpy() refuses tensors that still require grad
    Lab = Lab.permute(0, 2, 3, 1).detach().cpu().numpy()  # (N,H,W,3)

    if Lab.shape[0] == 0:
        # np.stack cannot stack an empty list; return an empty batch instead
        return np.zeros(Lab.shape, dtype=np.float64)

    out_array = np.stack(
        [lab2rgb(image.astype(np.float64)) for image in Lab], axis=0
    )  # (N,H,W,3)
    if single_image:
        out_array = out_array[0]
    return out_array

def single_lab_to_rgb(L_val, A_val, B_val):
    """
    Convert one normalized LAB pixel to RGB.

    The inputs use the same normalization as the dataset: L_val is mapped with
    (L + 1) * 50 and A_val / B_val are multiplied by 110 before conversion.

    Returns:
        np.ndarray of shape (3,): the `lab2rgb` result for that pixel.
    """
    # lab2rgb works on images, so wrap the pixel as a 1x1 image.
    lab_pixel = np.array(
        [[[(L_val + 1.0) * 50., A_val * 110., B_val * 110.]]],
        dtype=np.float64,
    )
    return lab2rgb(lab_pixel)[0, 0]

def visualize_colorization(model, batch, max_images=5, hint_radius=0):
    """
    Plot, for up to `max_images` items of `batch`, three panels per row:
    grayscale input with the hinted pixels painted in their hint color,
    the model prediction, and the ground truth.

    Args:
        model: GAN wrapper exposing `eval()`, `setup_input(batch)`, `forward()`
            and `fake_color`. It is left in eval mode.
        batch (dict): Must contain 'L' (N,1,H,W), 'ab' (N,2,H,W),
            'ab_hints' (N,2,H,W) and 'hint_mask' (N,1,H,W).
        max_images (int): Maximum number of rows to draw.
        hint_radius (int): Currently unused; kept for call compatibility. Every
            pixel with mask > 0.5 is painted, so whatever footprint the mask
            already has is what gets shown.
    """
    model.eval()
    with torch.no_grad():
        model.setup_input(batch)
        model.forward()
        fake_ab = model.fake_color.detach().cpu()  # (N,2,H,W)

    n_show = min(max_images, batch['L'].size(0))
    if n_show <= 0:
        # nothing to draw; plt.subplots cannot build a 0-row grid
        return

    # Only convert the images that are actually displayed.
    L       = batch['L'][:n_show].cpu()          # (n,1,H,W) in [-1,1]
    ab_gt   = batch['ab'][:n_show].cpu()         # (n,2,H,W)
    ab_h    = batch['ab_hints'][:n_show].cpu()   # (n,2,H,W)
    mask_h  = batch['hint_mask'][:n_show].cpu()  # (n,1,H,W)
    fake_ab = fake_ab[:n_show]

    rgb_fake = lab_to_rgb(L, fake_ab)
    rgb_real = lab_to_rgb(L, ab_gt)
    # Hint colors for the whole image in one conversion per image (instead of
    # one lab2rgb call per hinted pixel); only the masked pixels are used below.
    rgb_hint = lab_to_rgb(L, ab_h)

    # Denormalize grayscale to [0..1] for display
    L_np = ((L + 1.) * 0.5)[:, 0, :, :].numpy()
    hint_px = mask_h[:, 0, :, :].numpy() > 0.5  # (n,H,W) boolean

    # squeeze=False keeps `axs` 2-D even when a single row is drawn
    fig, axs = plt.subplots(n_show, 3, figsize=(12, 4*n_show), squeeze=False)

    for i in range(n_show):
        # Grayscale as 3-ch in [0..1], with hinted pixels replaced by their color
        gray3 = np.stack([L_np[i]]*3, axis=-1)  # (H,W,3)
        gray3[hint_px[i]] = rgb_hint[i][hint_px[i]]

        ax0, ax1, ax2 = axs[i]

        ax0.imshow(gray3, vmin=0, vmax=1)
        ax0.set_title("Grayscale + Hints")
        ax0.axis("off")

        ax1.imshow(rgb_fake[i])
        ax1.set_title("Predicted")
        ax1.axis("off")

        ax2.imshow(rgb_real[i])
        ax2.set_title("Ground Truth")
        ax2.axis("off")

    plt.tight_layout()
    plt.show()


## Define dataset

In [None]:
# We'll use the same COCO sample but we'll incorporate random n_hints in [1..5].
# This generally helps the model learn to handle variable number of user hints.

IMG_SIZE = 256
SEED = 42
np.random.seed(SEED)

coco_path = untar_data(URLs.COCO_SAMPLE)
# glob() returns files in filesystem order, which is not guaranteed to be stable
# across machines or runs. Sort first so the seeded shuffle below always
# produces the same train/val split.
all_files = sorted((coco_path / 'train_sample').glob("*.jpg"))
np.random.shuffle(all_files)

# Cap the dataset at 10,000 images.
N = min(len(all_files), 10000)
all_files = all_files[:N]

# 80 / 20 train / validation split
split_idx = int(0.8 * len(all_files))
train_files = all_files[:split_idx]
val_files   = all_files[split_idx:]

print(f"Train set size: {len(train_files)}")
print(f"Val set size:   {len(val_files)}")

Train set size: 8000
Val set size:   2000


In [None]:
class UserHintLABDataset(Dataset):
    """
    Images as normalized LAB tensors plus simulated user color hints.

    - Training split: fresh random hints on every `__getitem__` call, drawn
      from torch's RNG. The global NumPy RNG is avoided on purpose: DataLoader
      worker processes can start from identical copies of its state and would
      then all draw the same "random" hints.
    - Any other split: hints are a pure function of (`seed`, item index), so
      they are identical every epoch and in every worker process. They are
      additionally cached in `stored_hints`.

    Args:
        paths: Sequence of image file paths.
        split (str): 'train' enables random flips and per-call random hints.
        n_hints_min, n_hints_max (int): Inclusive range for the number of hints.
        hint_radius (int): Each hint covers a (2*radius+1)^2 square, clipped to the image.
        seed (int): Base seed for the deterministic (non-train) hints.
    """
    def __init__(self, paths, split='train', n_hints_min=1, n_hints_max=5, hint_radius=0, seed=0):
        self.paths = paths
        self.split = split
        self.n_min = n_hints_min
        self.n_max = n_hints_max
        self.hint_radius = hint_radius
        self.seed = seed

        resize = transforms.Resize((IMG_SIZE, IMG_SIZE), Image.BICUBIC)
        if split == 'train':
            self.transform = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),
            ])
        else:
            self.transform = resize

        # Per-item cache of hint coordinates (only filled for non-train splits).
        self.stored_hints = [None] * len(paths)

    def __len__(self):
        return len(self.paths)

    def _hint_coords(self, idx, H, W):
        """Return the list of (row, col) hint centers for item `idx` on an HxW image."""
        if self.split == 'train':
            # torch RNG: seeded differently in every DataLoader worker
            n_hints = int(torch.randint(self.n_min, self.n_max + 1, (1,)))
            rows = torch.randint(0, H, (n_hints,)).tolist()
            cols = torch.randint(0, W, (n_hints,)).tolist()
            return list(zip(rows, cols))

        if self.stored_hints[idx] is None:
            # Seed from (seed, index) so the result does not depend on call
            # order, on the global RNG state, or on which process computes it.
            rng = np.random.default_rng([self.seed, idx % len(self.stored_hints)])
            n_hints = int(rng.integers(self.n_min, self.n_max + 1))
            rows = rng.integers(0, H, size=n_hints).tolist()
            cols = rng.integers(0, W, size=n_hints).tolist()
            self.stored_hints[idx] = list(zip(rows, cols))
        return self.stored_hints[idx]

    def _stamp_hints(self, ab_np, coords):
        """
        Build the hint tensors for an image.

        Args:
            ab_np: (2,H,W) float32 array of normalized ab values.
            coords: Iterable of (row, col) hint centers.

        Returns:
            (ab_hints, mask): a (2,H,W) array holding the true ab values inside
            every hint square (zero elsewhere) and a (1,H,W) array that is 1
            inside the squares.
        """
        _, H, W = ab_np.shape
        ab_hints = np.zeros((2, H, W), dtype=np.float32)
        mask = np.zeros((1, H, W), dtype=np.float32)
        rad = self.hint_radius
        for rr, cc in coords:
            # square around the center, clipped to the image bounds
            r0, r1 = max(rr - rad, 0), min(rr + rad + 1, H)
            c0, c1 = max(cc - rad, 0), min(cc + rad + 1, W)
            ab_hints[:, r0:r1, c0:c1] = ab_np[:, r0:r1, c0:c1]
            mask[:, r0:r1, c0:c1] = 1.0
        return ab_hints, mask

    def __getitem__(self, idx):
        """
        Returns:
            dict with 'L' (1,H,W), 'ab' (2,H,W), 'ab_hints' (2,H,W) and
            'hint_mask' (1,H,W) float32 tensors.
        """
        path = self.paths[idx]
        # Close the file handle deterministically; the transform returns a new image.
        with Image.open(path) as raw:
            img = self.transform(raw.convert("RGB"))

        lab_np = rgb2lab(np.array(img)).astype(np.float32)
        # normalize: L -> [-1,1], a/b -> roughly [-1,1]
        L_norm = (lab_np[..., 0] / 50.) - 1.
        ab_np = np.stack([lab_np[..., 1] / 110., lab_np[..., 2] / 110.], axis=0)  # (2,H,W)

        H, W = L_norm.shape
        coords = self._hint_coords(idx, H, W)
        ab_hints, mask = self._stamp_hints(ab_np, coords)

        return {
            'L': torch.from_numpy(L_norm).unsqueeze(0),  # (1,H,W)
            'ab': torch.from_numpy(ab_np),               # (2,H,W)
            'ab_hints': torch.from_numpy(ab_hints),      # (2,H,W)
            'hint_mask': torch.from_numpy(mask)          # (1,H,W)
        }

In [None]:
# Hint-sampling settings shared by both splits: 1 to 5 hints, each an 11x11 patch.
_hint_kwargs = dict(n_hints_min=1, n_hints_max=5, hint_radius=5)
train_ds = UserHintLABDataset(train_files, split='train', **_hint_kwargs)
val_ds   = UserHintLABDataset(val_files,   split='val',   **_hint_kwargs)

# Loader settings shared by both splits; only the training loader shuffles.
_loader_kwargs = dict(batch_size=16, num_workers=4, pin_memory=True)
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dl   = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

## Modify the GAN model (setup_input and forward)

In [20]:
class UserHintGANModel(type(pretrained_gan)):
    """
    Hint-aware variant of the loaded GAN wrapper, derived from its runtime class.

    Differences from the base class:
      - setup_input also moves 'ab_hints' and 'hint_mask' to the device
      - forward calls net_G(L, ab_hints, hint_mask)
      - net_D scores the 6-channel stack (L, ab, ab_hints, hint_mask)
      - the pixel-wise generator term is SmoothL1 (beta=1.0) instead of L1
    """
    def setup_input(self, data):
        """Move all four batch entries to the model device."""
        for key in ('L', 'ab', 'ab_hints', 'hint_mask'):
            setattr(self, key, data[key].to(self.device))

    def forward(self):
        """Run the hint-conditioned generator and store its output in `fake_color`."""
        self.fake_color = self.net_G(self.L, self.ab_hints, self.hint_mask)

    def _disc_input(self, color):
        """Stack (L, color, ab_hints, hint_mask) into the (N,6,H,W) discriminator input."""
        return torch.cat([self.L, color, self.ab_hints, self.hint_mask], dim=1)

    def backward_D(self):
        """Discriminator step: real stack scored as real, detached fake stack as fake."""
        # real pass first, then fake
        scores_real = self.net_D(self._disc_input(self.ab))
        self.loss_D_real = self.GANcriterion(scores_real, True)

        scores_fake = self.net_D(self._disc_input(self.fake_color).detach())
        self.loss_D_fake = self.GANcriterion(scores_fake, False)

        self.loss_D = 0.5*(self.loss_D_real + self.loss_D_fake)
        self.loss_D.backward()

    def backward_G(self):
        """Generator step: adversarial term plus lambda_L1-weighted SmoothL1 term."""
        # adversarial term
        scores_fake = self.net_D(self._disc_input(self.fake_color))
        self.loss_G_GAN = self.GANcriterion(scores_fake, True)

        # pixel-wise term: SmoothL1 (the attribute keeps its historical "L1" name)
        self.loss_G_L1 = F.smooth_l1_loss(self.fake_color, self.ab, beta=1.0) * self.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

In [21]:
# Re-type the loaded object in place so it picks up the hint-aware methods
# while keeping its networks and state.
pretrained_gan.__class__ = UserHintGANModel
pretrained_gan.device = device
# Make sure both (modified) networks sit on the target device.
for _net in (pretrained_gan.net_G, pretrained_gan.net_D):
    _net.to(device)

print("Assigned new class to pretrained_gan for user hints. Now we can train it.")

Assigned new class to pretrained_gan for user hints. Now we can train it.


## Training loop

In [None]:
def train_user_guided_gan(
    gan_model, train_dl, val_dl=None, epochs=5, lr_G=2e-4, lr_D=2e-4,
    beta1=0.5, beta2=0.999, save_path=None, vis_interval=3, hint_radius=2
):
    """
    Fine-tune the pretrained GAN with user hints in G and D.

    Parameters:
      - gan_model: model exposing net_G / net_D plus train(), setup_input(batch),
        optimize() and forward(). After optimize() it must expose loss_D and loss_G;
        after forward() it must expose fake_color and ab.
      - train_dl: iterable of training batches accepted by gan_model.setup_input.
      - val_dl: optional iterable of validation batches; None skips validation.
      - epochs: number of passes over train_dl.
      - lr_G, lr_D: learning rates of the generator / discriminator Adam optimizers.
      - beta1, beta2: Adam betas shared by both optimizers.
      - save_path: if given, the whole gan_model object is written there with
        torch.save once all epochs are done.
      - vis_interval: every vis_interval epochs, visualize_colorization is called on
        the first validation batch (requires val_dl); 0 or None disables it.
      - hint_radius: forwarded to visualize_colorization.

    Note: both optimizers are re-created on every call, so optimizer state from an
    earlier call is not reused.
    """
    # Re-create the optimizers because the model architecture changed
    betas = (beta1, beta2)
    gan_model.opt_G = optim.Adam(gan_model.net_G.parameters(), lr=lr_G, betas=betas)
    gan_model.opt_D = optim.Adam(gan_model.net_D.parameters(), lr=lr_D, betas=betas)

    # Built once instead of once per validation batch
    val_criterion = nn.L1Loss()

    for epoch in range(epochs):
        gan_model.train()
        epoch_loss_D, epoch_loss_G = 0.0, 0.0
        n_train_batches = 0

        for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}/{epochs}"):
            gan_model.setup_input(batch)
            gan_model.optimize()

            epoch_loss_D += gan_model.loss_D.item()
            epoch_loss_G += gan_model.loss_G.item()
            n_train_batches += 1

        # Count batches while iterating: works for loaders without __len__ and
        # avoids a ZeroDivisionError when the loader yields nothing.
        if n_train_batches:
            epoch_loss_D /= n_train_batches
            epoch_loss_G /= n_train_batches
        print(f"[Epoch {epoch+1}/{epochs}] Loss_D: {epoch_loss_D:.4f}, Loss_G: {epoch_loss_G:.4f}")

        # `is not None` rather than truthiness: bool(loader) calls len(), which
        # raises for loaders that do not define a length.
        if val_dl is not None:
            gan_model.net_G.eval()
            val_loss = 0.0
            n_val_batches = 0
            with torch.no_grad():
                for val_batch in val_dl:
                    gan_model.setup_input(val_batch)
                    gan_model.forward()
                    val_loss += val_criterion(gan_model.fake_color, gan_model.ab).item()
                    n_val_batches += 1

            # An empty validation loader is treated like "no validation".
            if n_val_batches:
                val_loss /= n_val_batches
                print(f"         Val L1: {val_loss:.4f}")

                if vis_interval and (epoch+1) % vis_interval == 0:
                    sample_batch = next(iter(val_dl))
                    visualize_colorization(gan_model, sample_batch, max_images=3, hint_radius=hint_radius)

    if save_path:
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # NOTE: this pickles the whole model object (not just a state_dict), so the
        # class definitions must be available again wherever the file is loaded.
        torch.save(gan_model, save_path)
        print(f"Model saved to {save_path}")

## Run fine tuning

In [None]:
# Checkpoint file under the mounted Drive; train_user_guided_gan writes the model
# there once, after the final epoch.
SAVE_PATH = "/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/user_guided/user_hint_gan_finetuned_attention.pt"

# Fine-tune for 20 epochs; every 3rd epoch the first validation batch is visualized
# with hint_radius=5.
train_user_guided_gan(
    gan_model=pretrained_gan,
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=20,
    vis_interval=3,
    save_path=SAVE_PATH,
    hint_radius=5
)

print("Fine-tuning complete!")

Output hidden; open in https://colab.research.google.com to view.

## Test with user images

In [16]:
def test_user_hinted_colorization(gan_model, input_path, output_dir=None, img_size=256,
                                  n_hints=3, hint_radius=2):
    """
    Test the user-hinted colorization model on user-provided images (or a directory of images).
    For demonstration, we create random hints for each image.

    Parameters:
      - gan_model: The user-hinted GAN model (already loaded).
      - input_path: Path to an image or directory containing images.
      - output_dir: Where to save colorized images (optional).
      - img_size: resize for input.
      - n_hints: number of random hint points to place per image.
      - hint_radius: local radius around each hint pixel to fill with the same color.
    """
    gan_model.eval()

    input_path = Path(input_path)
    if input_path.is_file():
        image_paths = [input_path]
    elif input_path.is_dir():
        image_paths = sorted(list(input_path.glob("*.[pjPJ][pnPN][gG]")))
    else:
        raise ValueError(f"Invalid input path: {input_path}")

    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # We'll define a transform to produce the L-ab space ourselves
    # since we want to create random hints
    def to_lab_tensors(img_pil):
        # resize
        img_pil = img_pil.resize((img_size, img_size), Image.BICUBIC)
        rgb_np  = np.array(img_pil)
        lab_np  = rgb2lab(rgb_np).astype(np.float32)
        L_, a_, b_ = lab_np[...,0], lab_np[...,1], lab_np[...,2]

        L_norm = (L_/50.) - 1.
        a_norm = a_/110.
        b_norm = b_/110.

        L_t  = torch.from_numpy(L_norm).unsqueeze(0)   # (1,H,W)
        ab_t = torch.from_numpy(np.stack([a_norm,b_norm],axis=0))  # (2,H,W)
        return L_t, ab_t

    for idx, img_path in enumerate(tqdm(image_paths, desc="Processing images")):
        pil_img = Image.open(img_path).convert("RGB")
        W,H = pil_img.size

        L_t, ab_t = to_lab_tensors(pil_img)  # (1,H,W), (2,H,W)
        L_np = L_t.numpy()[0]  # (H,W)
        a_np = ab_t.numpy()[0] # (H,W)
        b_np = ab_t.numpy()[1] # (H,W)

        # create hints
        ab_hints = np.zeros_like(ab_t.numpy())  # shape (2,H,W)
        mask     = np.zeros((1,)+ab_hints.shape[1:], np.float32) # (1,H,W)
        HH,WW = ab_hints.shape[1:]

        # random coords
        for _ in range(n_hints):
            rr = np.random.randint(0,HH)
            cc = np.random.randint(0,WW)
            # fill the region
            for r_ in range(rr-hint_radius, rr+hint_radius+1):
                for c_ in range(cc-hint_radius, cc+hint_radius+1):
                    if 0 <= r_ < HH and 0 <= c_ < WW:
                        ab_hints[0, r_, c_] = a_np[r_, c_]
                        ab_hints[1, r_, c_] = b_np[r_, c_]
                        mask[0, r_, c_]     = 1.

        ab_h_t = torch.from_numpy(ab_hints)
        mask_t = torch.from_numpy(mask)

        data = {
            'L': L_t.unsqueeze(0).to(device),        # (1,1,H,W)
            'ab': ab_t.unsqueeze(0).to(device),      # (1,2,H,W)
            'ab_hints': ab_h_t.unsqueeze(0).to(device),  # (1,2,H,W)
            'hint_mask': mask_t.unsqueeze(0).to(device)  # (1,1,H,W)
        }

        with torch.no_grad():
            gan_model.setup_input(data)
            gan_model.forward()
            fake_ab = gan_model.fake_color.detach().cpu()  # (1,2,H,W)

        # convert to RGB
        rgb_fake = lab_to_rgb(data['L'].cpu(), fake_ab)[0]  # (H,W,3)

        # Visualization or saving
        if idx<5:
            # Show
            plt.figure(figsize=(12,5))

            # 1) Show grayscale + hints
            L_01 = (data['L'][0,0].cpu().numpy()+1.)*0.5  # [0..1]
            gray3 = np.stack([L_01]*3, axis=-1)
            coords = np.argwhere(mask[0]>0.5)
            for (r,c) in coords:
                # get ab from ab_hints
                A_val = ab_hints[0, r, c]
                B_val = ab_hints[1, r, c]
                L_val = data['L'][0,0,r,c].item()
                color_3 = single_lab_to_rgb(L_val, A_val, B_val)
                gray3[r, c] = color_3

            plt.subplot(1,3,1)
            plt.imshow(gray3, vmin=0, vmax=1)
            plt.title("Gray + Hints")
            plt.axis("off")

            # 2) Show predicted
            plt.subplot(1,3,2)
            plt.imshow(rgb_fake)
            plt.title("Predicted")
            plt.axis("off")

            # 3) Show original
            plt.subplot(1,3,3)
            plt.imshow(pil_img)
            plt.title("Original")
            plt.axis("off")

            plt.tight_layout()
            plt.show()

        if output_dir:
            out_dir = Path(output_dir)
            out_dir.mkdir(parents=True, exist_ok=True)
            out_file = out_dir / img_path.name
            plt.imsave(out_file, rgb_fake)
            print(f"Saved colorized image to {out_file}")



In [19]:
# Colorize the images in this folder using random hints with hint_radius=5.
# No output_dir is passed, so results are only plotted (first 5 images), not saved.
input_path = "/content/test_images"
test_user_hinted_colorization(pretrained_gan, input_path, hint_radius=5)

Output hidden; open in https://colab.research.google.com to view.

##### 20 more epochs

In [22]:
# Separate checkpoint file for the continued run, so the earlier checkpoint is kept.
SAVE_PATH = "/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/user_guided/user_hint_gan_finetuned_attention40.pt"

# Continue training the same in-memory model for another 20 epochs.
# NOTE: train_user_guided_gan re-creates both Adam optimizers on every call, so the
# optimizer state from the previous run is not carried over.
train_user_guided_gan(
    gan_model=pretrained_gan,
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=20,
    vis_interval=3,
    save_path=SAVE_PATH,
    hint_radius=5
)

print("Fine-tuning complete!")

Output hidden; open in https://colab.research.google.com to view.