In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from fastai.data.external import untar_data, URLs
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

## 1- Dataset
We are using a subset of COCO from fastai (~21K images)

In [None]:
IMG_SIZE = 256 # img resolution
SEED = 42

coco_path = untar_data(URLs.COCO_SAMPLE)
image_files = list((coco_path / 'train_sample').glob("*.jpg"))
print(f"Total images found in COCO_SAMPLE: {len(image_files)}")

np.random.seed(SEED)
np.random.shuffle(image_files)

# pick 10,000 images
N = min(len(image_files), 10000) # just in case it downloads less than 10K images, use min
image_files = image_files[:N]  # random (seeded) subset
print(f"Using {len(image_files)} images for demonstration.")

# 80/20 train/val split
split_idx = int(0.8 * len(image_files))
train_files = image_files[:split_idx]
val_files   = image_files[split_idx:]

print(f"Train set size: {len(train_files)}")
print(f"Val set size:   {len(val_files)}")

Total images found in COCO_SAMPLE: 21837
Using 10000 images for demonstration.
Train set size: 8000
Val set size:   2000


## 2- DATASET: On-the-fly LAB

In [6]:
class LABDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE),  Image.BICUBIC),
                transforms.RandomHorizontalFlip(), # added small data augmentation, maybe add more later
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((IMG_SIZE, IMG_SIZE),  Image.BICUBIC) # only resize for validation set, dont do augmentation

        self.size = IMG_SIZE
        self.paths = paths
        self.split = split

    def __getitem__(self, idx):
        path = self.paths[idx]
        with Image.open(path).convert("RGB") as img:
            # Apply transformations
            img = self.transforms(img)
            # Convert PIL -> NumPy
            rgb_np = np.array(img)  # (H,W,3) in [0..255]

        # Convert RGB to LAB using skimage
        img_lab = rgb2lab(img).astype(np.float32)  # Convert to LAB
        L_channel, A_channel, B_channel = img_lab[..., 0], img_lab[..., 1], img_lab[..., 2]

        # Normalize L and AB channels
        L = torch.from_numpy(L_channel / 50.0 - 1.0).unsqueeze(0)  # Normalize L to [-1, 1] and add channel dimension
        ab = torch.from_numpy(np.stack((A_channel / 110.0, B_channel / 110.0), axis=0))  # Normalize AB to [-1, 1] and stack (We used 110 because we don't generally see extreme values like 128 and also some sources say this works better)

        # Output as dictionary
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

In [7]:
# dataloaders
train_ds = LABDataset(train_files, split='train')
val_ds   = LABDataset(val_files, split='val')

batch_size = 16
num_of_workers = 4
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=num_of_workers, pin_memory=True) # appearently pin memory can increase data speed from CPU to GPU
val_dl   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=num_of_workers, pin_memory=True)

NameError: name 'train_files' is not defined

## Load the Res-U-Net that we trained before

In [None]:
def load_model(load_path, device='cuda'):
    """
    Load a PyTorch model (architecture + weights) from the specified path.

    Parameters:
      - load_path: Path to the saved model file.
      - device: Device to load the model onto ('cuda' or 'cpu').

    Returns:
      - model: The loaded PyTorch model.
    """
    if not os.path.exists(load_path):
        raise FileNotFoundError(f"Model file not found at {load_path}")

    model = torch.load(load_path, map_location=device)
    print(f"Model loaded from {load_path}")
    return model


In [None]:
generator = load_model("/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/resUnet/resUnet_colorizer_coco.pt", device=device)

  model = torch.load(load_path, map_location=device)


Model loaded from /content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/resUnet/resUnet_colorizer_coco.pt


## Define patch discriminator

In [8]:
class PatchDiscriminator(nn.Module):
    """
    PatchGAN-like discriminator with a configurable number of layers, filters, and strides.
    Takes an input with `input_channels` channels and outputs a probability map
    indicating real/fake for each patch.
    """
    def __init__(self, input_channels, num_filters=64, num_downsampling=3):
        """
        Args:
            input_channels (int): Number of input channels (e.g., 3 for RGB or 1 for grayscale).
            num_filters (int): Number of filters for the first convolutional layer.
            num_downsampling (int): Number of downsampling steps in the discriminator.
        """
        super().__init__()

        # Initial layer: no normalization for the input layer
        layers = [
            self._conv_block(input_channels, num_filters, normalize=False) # input image
        ]

        # Downsampling layers
        for i in range(num_downsampling):
            in_channels = num_filters * 2**i
            out_channels = num_filters * 2**(i + 1)

            # Use stride=1 for the last downsampling layer
            stride = 1 if i == (num_downsampling - 1) else 2
            layers.append(self._conv_block(in_channels, out_channels, stride=stride))

        # Final layer: outputs a single-channel probability map (logits)
        layers.append(
            self._conv_block(out_channels, 1, normalize=False, activation=False) # no activation because activation is in the loss
        )

        # Combine all layers into a sequential model
        self.model = nn.Sequential(*layers)

    def _conv_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, normalize=True, activation=True):
        """
        Creates a single convolutional block with optional normalization and activation.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Convolution kernel size (default: 4).
            stride (int): Stride for the convolution (default: 2).
            padding (int): Padding for the convolution (default: 1).
            normalize (bool): Whether to use BatchNorm (default: True).
            activation (bool): Whether to use LeakyReLU activation (default: True).

        Returns:
            nn.Sequential: A sequential block containing the specified layers.
        """
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not normalize)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        if activation:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass for the discriminator.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_channels, height, width).

        Returns:
            torch.Tensor: Output tensor (probability map).
        """
        return self.model(x)

## GAN Loss

In [9]:
class GANLoss(nn.Module):
    def __init__(self, real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))

        self.loss = nn.BCEWithLogitsLoss()

    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds) # tensor ful of 1 or 0s

    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss

### Weight initialization

In [10]:
def initialize_weights(model, init_type='normal', gain=0.02):
    """
    General weight initialization function for PyTorch models.

    Parameters:
        model (nn.Module): The PyTorch model to initialize.
        init_type (str): Initialization type ('normal', 'xavier', 'kaiming', 'orthogonal').
        gain (float): Gain value for certain initialization methods (e.g., Xavier, Kaiming).

    Returns:
        nn.Module: The model with initialized weights.
    """
    def init_func(m):
        classname = m.__class__.__name__
        # Initialize convolutional and linear layers
        if hasattr(m, 'weight') and any(layer in classname for layer in ['Conv', 'Linear']):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='leaky_relu')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise ValueError(f"Unsupported initialization type: {init_type}")

            # Initialize biases to zero if they exist
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

        # Initialize BatchNorm layers
        elif 'BatchNorm' in classname:
            nn.init.normal_(m.weight.data, mean=1.0, std=gain)
            nn.init.constant_(m.bias.data, 0.0)

    print(f"Initializing model weights with {init_type} initialization")
    model.apply(init_func)
    return model

def initialize_model(model, device):
    model = model.to(device)
    model = initialize_weights(model)
    return model

## Main GAN model

In [11]:
class GANModel(nn.Module):
    def __init__(self, net_G, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        # Device setup for GPU/CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1  # Weight for L1 loss to balance with adversarial loss (hyperparameter)

        # Initialize generator and discriminator models
        self.net_G = net_G.to(self.device)
        self.net_D = initialize_model(PatchDiscriminator(input_channels=3, num_downsampling=3, num_filters=64), self.device)

        # GAN loss (BCE loss)
        self.GANcriterion = GANLoss().to(self.device)
        self.L1criterion = nn.L1Loss()  # L1 loss for pixel-level accuracy

        # Optimizers for generator and discriminator
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """
        Enable or disable gradient computation for a given model.
        This is used to freeze/unfreeze a model during training.
        """
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """
        Prepare the input data (L-channel and ab channels)
        """
        self.L = data['L'].to(self.device)  # Grayscale input (L-channel)
        self.ab = data['ab'].to(self.device)  # Ground truth color channels (ab)

    def forward(self):
        """
        Forward pass through the generator to create fake color (ab channels).
        """
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """
        Backward pass for the discriminator:
        - Classifies real images as real.
        - Classifies fake images (from generator) as fake.
        """
        # Combine L-channel and generated color to form a fake image
        fake_image = torch.cat([self.L, self.fake_color], dim=1)

        # Use .detach() to ensure gradients are not calculated for the generator
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)  # Loss for fake images

        # Combine L-channel and real color to form a real image
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)  # Loss for real images

        # Average the losses and compute gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()  # Backpropagate discriminator loss

    def backward_G(self):
        """
        Backward pass for the generator:
        - Fool the discriminator (adversarial loss).
        - Minimize L1 loss between fake color and ground truth color (pixel-wise accuracy).
        """
        # Combine L-channel and generated color to form a fake image
        fake_image = torch.cat([self.L, self.fake_color], dim=1)

        # Evaluate the fake image with the discriminator
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)  # Generator's adversarial loss

        # L1 loss for pixel-level accuracy, scaled by lambda_L1
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1

        # Combine adversarial loss and L1 loss
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()  # Backpropagate generator loss

    def optimize(self):
        """
        Optimization step for both generator and discriminator
        """
        # Forward pass and generate fake images
        self.forward()

        # === Train Discriminator ===
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)  # Enable gradient computation for discriminator
        self.opt_D.zero_grad()
        self.backward_D()  # Compute discriminator loss
        self.opt_D.step()  # Update discriminator parameters

        # === Train Generator ===
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)  # Freeze discriminator during generator update
        self.opt_G.zero_grad()
        self.backward_G()  # Compute generator loss
        self.opt_G.step()  # Update generator parameters


## Helper Functions

In [3]:
def lab_to_rgb(L, ab):
    """
    Convert normalized LAB channels back to RGB.
    - L in [-1, 1] -> [0, 100]
    - ab in [-1, 1] -> [-128, 127]
    Returns RGB images in [0, 1].
    """
    # Rescale L and AB channels
    L = (L + 1.) * 50.0  # Scale L to [0, 100]
    a = ab[:, [0], :, :] * 110.0  # Scale A back
    b = ab[:, [1], :, :] * 110.0  # Scale B back

    # Concatenate LAB and convert to RGB
    Lab = torch.cat([L, a, b], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = [lab2rgb(img.astype(np.float64)) for img in Lab]  # LAB -> RGB

    return np.stack(rgb_imgs, axis=0)  # Stack into a batch


def visualize_colorization(model, data, max_images=5, save=False):
    """
    Visualizes grayscale input (L), predicted colorization, and ground truth colorization.
    - max_images: Number of images to display (default: 5).
    - save: Whether to save the plot as a PNG file.
    """
    model.eval()
    with torch.no_grad():
        L_in = data['L'].to(device)  # Grayscale input
        ab_gt = data['ab'].to(device)  # Ground truth AB channels
        ab_pred = model(L_in)  # Predicted AB channels

    # Convert LAB to RGB
    rgb_fake = lab_to_rgb(L_in, ab_pred)
    rgb_real = lab_to_rgb(L_in, ab_gt)
    L_np = L_in.cpu().numpy()[:, 0, :, :]  # Extract L channel as NumPy array

    # Plot the results
    plt.figure(figsize=(15, 8))
    for i in range(min(max_images, L_in.size(0))):
        # Grayscale input (L channel)
        ax = plt.subplot(3, max_images, i + 1)
        ax.imshow(L_np[i], cmap='gray', vmin=-1, vmax=1)
        ax.set_title("L Input")
        ax.axis("off")

        # Predicted colorization
        ax = plt.subplot(3, max_images, i + 1 + max_images)
        ax.imshow(rgb_fake[i])
        ax.set_title("Predicted")
        ax.axis("off")

        # Ground truth colorization
        ax = plt.subplot(3, max_images, i + 1 + 2 * max_images)
        ax.imshow(rgb_real[i])
        ax.set_title("Ground Truth")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

    # Save the plot if needed
    if save:
        plt.savefig(f"colorization_{time.time()}.png")


## Training Loop

In [None]:
def train_gan_model(
    gan_model,
    train_dl,
    val_dl=None,
    epochs=20,
    save_path=None,
    vis_interval=3
):
    """
    Train the GAN model for colorization.

    Parameters:
      - gan_model: Instance of GANModel.
      - train_dl: DataLoader for training data.
      - val_dl: DataLoader for validation data (optional).
      - epochs: Number of training epochs.
      - save_path: Path to save the model after training (optional).
      - vis_interval: Interval (in epochs) to visualize predictions.
    """
    for epoch in range(epochs):
        gan_model.train()
        epoch_loss_D, epoch_loss_G = 0.0, 0.0

        for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}/{epochs}"):
            gan_model.setup_input(batch)
            gan_model.optimize()

            # Accumulate losses
            epoch_loss_D += gan_model.loss_D.item()
            epoch_loss_G += gan_model.loss_G.item()

        # Average losses
        epoch_loss_D /= len(train_dl)
        epoch_loss_G /= len(train_dl)
        print(f"[Epoch {epoch+1}/{epochs}] Loss_D: {epoch_loss_D:.4f}, Loss_G: {epoch_loss_G:.4f}")

        # Optional: Validation
        if val_dl:
            gan_model.net_G.eval()
            val_loss = 0.0
            with torch.no_grad():
                for val_batch in val_dl:
                    gan_model.setup_input(val_batch)
                    gan_model.forward()
                    val_loss += nn.L1Loss()(gan_model.fake_color, gan_model.ab).item()
            val_loss /= len(val_dl)
            print(f"[Epoch {epoch+1}/{epochs}] Validation Loss: {val_loss:.4f}")

            # Visualize every `vis_interval` epochs
            if (epoch + 1) % vis_interval == 0:
                print(f"Visualizing predictions at epoch {epoch+1}...")
                visualize_colorization(gan_model.net_G, next(iter(val_dl)), max_images=4)

    # Save the model if `save_path` is provided
    if save_path:
        model_dir = os.path.dirname(save_path)
        os.makedirs(model_dir, exist_ok=True)
        torch.save(gan_model, save_path)
        print(f"Model saved at {save_path}")


#### Save and Load the GAN model

In [4]:
def save_gan_model(gan_model, save_path, save_entire_model=False):
    """
    Save the GAN model.

    Parameters:
      - gan_model: The GAN model to save.
      - save_path: Path to save the model.
      - save_entire_model: If True, saves the entire model (architecture + weights).
    """
    model_dir = os.path.dirname(save_path)
    os.makedirs(model_dir, exist_ok=True)

    if save_entire_model:
        # Save the entire model
        torch.save(gan_model, save_path)
        print(f"Entire GAN model saved to {save_path}")
    else:
        # Save only the state dictionary
        torch.save(gan_model.state_dict(), save_path)
        print(f"GAN model state dictionary saved to {save_path}")



def load_gan_model(load_path, generator=None, save_entire_model=False, device='cuda'):
    """
    Load the GAN model.

    Parameters:
      - load_path: Path to the saved model.
      - generator: Pre-trained generator to initialize the GAN model (ignored if save_entire_model=True).
      - save_entire_model: If True, loads the entire model.
      - device: Device to load the model onto.

    Returns:
      - gan_model: The loaded GAN model.
    """
    if not os.path.exists(load_path):
        raise FileNotFoundError(f"Model file not found at {load_path}")

    if save_entire_model:
        # Load the entire model
        gan_model = torch.load(load_path, map_location=device)
        print(f"Entire GAN model loaded from {load_path}")
    else:
        # Load the state dictionary
        if generator is None:
            raise ValueError("Generator must be provided when loading only the state dictionary.")
        gan_model = GANModel(generator).to(device)
        gan_model.load_state_dict(torch.load(load_path, map_location=device))
        print(f"GAN model state dictionary loaded from {load_path}")

    return gan_model

### Start training

In [None]:
generator = load_model("/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/resUnet/resUnet_colorizer_coco.pt", device=device)

gan_model = GANModel(generator)

  model = torch.load(load_path, map_location=device)


Model loaded from /content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/resUnet/resUnet_colorizer_coco.pt
Initializing model weights with normal initialization


In [None]:
train_gan_model(
    gan_model=gan_model,
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=20,
    save_path="/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/GAN_w_ResUnet/gan_colorizer.pt",
    vis_interval=3
)

Output hidden; open in https://colab.research.google.com to view.

## Test on user images

In [14]:
def test_gan_colorization(gan_model, input_path, output_dir=None, img_size=256):
    """
    Test the GAN colorization model on user-provided images or directories.

    Parameters:
      - gan_model: Trained GAN model (already loaded).
      - input_path: Path to an image or directory containing images.
      - output_dir: Directory to save the colorized images (optional).
      - img_size: Image size to resize input images to (default: 256).
    """
    gan_model.eval()

    # Determine if input_path is an image or a directory
    input_path = Path(input_path)
    if input_path.is_file():
        image_paths = [input_path]
    elif input_path.is_dir():
        image_paths = list(input_path.glob("*.[pjPJ][pnPN][gG]"))
        image_paths = sorted(image_paths)
    else:
        raise ValueError(f"Invalid input path: {input_path}")

    # Limit visualization to 5 images
    visualize_images = image_paths[:5]
    visualize_only = len(image_paths) > 5 and output_dir is None

    # Ensure output directory exists
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    for idx, img_path in enumerate(tqdm(image_paths, desc="Processing images")):
        grayscale_img = Image.open(img_path).convert("L")
        L = transform(grayscale_img).unsqueeze(0).to(device)

        # Predict colorization
        with torch.no_grad():
            ab_pred = gan_model.net_G(L)

        # Convert LAB to RGB
        L = L.cpu()
        ab_pred = ab_pred.cpu()
        rgb_fake = lab_to_rgb(L, ab_pred)[0]

        if idx < 5:
            plt.figure(figsize=(15, 5))
            plt.subplot(1, 3, 1)
            plt.imshow(L[0, 0].numpy(), cmap="gray", vmin=-1, vmax=1)
            plt.title("Grayscale Input")
            plt.axis("off")

            plt.subplot(1, 3, 2)
            plt.imshow(rgb_fake)
            plt.title("Predicted Colorization")
            plt.axis("off")

            plt.subplot(1, 3, 3)
            plt.imshow(Image.open(img_path).convert("RGB"))
            plt.title("Original")
            plt.axis("off")

            plt.tight_layout()
            plt.show()

        if output_dir:
            output_file = output_dir / img_path.name
            plt.imsave(output_file, rgb_fake)
            print(f"Saved colorized image to {output_file}")

        if visualize_only and idx >= 4:
            break


In [None]:
input_path = "/content/test_images"
test_gan_colorization(gan_model, input_path)

Output hidden; open in https://colab.research.google.com to view.

#### Train for 20 more epochs

In [None]:
train_gan_model(
    gan_model=gan_model,
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=20,
    save_path="/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/GAN_w_ResUnet/gan_colorizer40.pt",
    vis_interval=3
)

Output hidden; open in https://colab.research.google.com to view.

In [12]:
gan_model = load_gan_model("/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/GAN_w_ResUnet/gan_colorizer40.pt", device=device, save_entire_model=True)

  gan_model = torch.load(load_path, map_location=device)


Entire GAN model loaded from /content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/GAN_w_ResUnet/gan_colorizer40.pt


In [None]:
input_path = "/content/test_images"
test_gan_colorization(gan_model, input_path)