In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.models import vit_b_16, ViT_B_16_Weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from fastai.data.external import untar_data, URLs
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

## 1- Dataset
We are using a subset of COCO from fastai (~21K images)

In [None]:
IMG_SIZE = 224 # img resolution
SEED = 42

coco_path = untar_data(URLs.COCO_SAMPLE)
image_files = list((coco_path / 'train_sample').glob("*.jpg"))
print(f"Total images found in COCO_SAMPLE: {len(image_files)}")

np.random.seed(SEED)
np.random.shuffle(image_files)

# pick 10,000 images
N = min(len(image_files), 10000) # just in case it downloads less than 10K images, use min
image_files = image_files[:N]  # random (seeded) subset
print(f"Using {len(image_files)} images for demonstration.")

# 80/20 train/val split
split_idx = int(0.8 * len(image_files))
train_files = image_files[:split_idx]
val_files   = image_files[split_idx:]

print(f"Train set size: {len(train_files)}")
print(f"Val set size:   {len(val_files)}")

Total images found in COCO_SAMPLE: 21837
Using 10000 images for demonstration.
Train set size: 8000
Val set size:   2000


## 2- DATASET: On-the-fly LAB

In [None]:
class LABDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE),  Image.BICUBIC),
                transforms.RandomHorizontalFlip(), # added small data augmentation, maybe add more later
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((IMG_SIZE, IMG_SIZE),  Image.BICUBIC) # only resize for validation set, dont do augmentation

        self.size = IMG_SIZE
        self.paths = paths
        self.split = split

    def __getitem__(self, idx):
        path = self.paths[idx]
        with Image.open(path).convert("RGB") as img:
            # Apply transformations
            img = self.transforms(img)
            # Convert PIL -> NumPy
            rgb_np = np.array(img)  # (H,W,3) in [0..255]

        # Convert RGB to LAB using skimage
        img_lab = rgb2lab(img).astype(np.float32)  # Convert to LAB
        L_channel, A_channel, B_channel = img_lab[..., 0], img_lab[..., 1], img_lab[..., 2]

        # Normalize L and AB channels
        L = torch.from_numpy(L_channel / 50.0 - 1.0).unsqueeze(0)  # Normalize L to [-1, 1] and add channel dimension
        ab = torch.from_numpy(np.stack((A_channel / 110.0, B_channel / 110.0), axis=0))  # Normalize AB to [-1, 1] and stack (We used 110 because we don't generally see extreme values like 128 and also some sources say this works better)

        # Output as dictionary
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

In [None]:
# dataloaders
train_ds = LABDataset(train_files, split='train')
val_ds   = LABDataset(val_files, split='val')

batch_size = 16
num_of_workers = 4
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=num_of_workers, pin_memory=True) # appearently pin memory can increase data speed from CPU to GPU
val_dl   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=num_of_workers, pin_memory=True)

### ViT based generator

In [3]:
class UpsampleBlock(nn.Module):
    """
    A basic CNN block that doubles spatial resolution
    and reduces the feature dimension if needed.
    """
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        x = self.conv(x)
        return x


class ViTColorizerWithTransformerDecoder(nn.Module):
    def __init__(self,
                 pretrained=True,
                 num_decoder_layers=2,
                 nhead=8,
                 dim_feedforward=2048,
                 final_upsample=False):
        super().__init__()
        if pretrained:
            weights = ViT_B_16_Weights.IMAGENET1K_V1
        else:
            weights = None

        # 1) Load a pretrained ViT (b_16)
        self.vit = vit_b_16(weights=weights)

        # Remove classification head and get encoder
        self.vit.heads = nn.Identity()
        self.encoder = self.vit.encoder

        # 2) Get hidden dim from conv_proj
        hidden_dim = self.vit.conv_proj.weight.shape[0]  # typically 768
        num_patches = (IMG_SIZE // 16) * (IMG_SIZE // 16) + 1  # 196 + 1 for 224x224 input

        self.hidden_dim = hidden_dim
        self.num_tokens = num_patches - 1  # ignoring [CLS] => 196

        # 3) Transformer Decoder Query Embeddings
        self.decoder_queries = nn.Parameter(torch.randn(self.num_tokens, hidden_dim))
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer,
                                                       num_layers=num_decoder_layers)

        # 4) CNN upsampling from (B,768,14,14) -> (B,2,224,224)
        self.up_net = nn.Sequential(
            UpsampleBlock(768, 512),  # 14×14 -> 28×28
            UpsampleBlock(512, 256),  # 28×28 -> 56×56
            UpsampleBlock(256, 128),  # 56×56 -> 112×112
            UpsampleBlock(128, 64),   # 112×112 -> 224×224
            nn.Conv2d(64, 2, kernel_size=3, padding=1)  # => (B,2,224,224)
        )

        self.final_upsample = final_upsample
        if self.final_upsample:
            self.final_layer = nn.Conv2d(2, 2, kernel_size=1)

    def forward(self, L):
        B = L.size(0)

        # 1) Replicate L -> 3ch and process through ViT's patch embedding
        x_3ch = L.repeat(1, 3, 1, 1)  # (B,3,H,W)

        # 2) Get patch embeddings directly
        x = self.vit.conv_proj(x_3ch)  # (B,768,14,14)
        x = x.flatten(2).transpose(1, 2)  # (B,196,768)

        # Add class token
        cls_token = self.vit.class_token.expand(B, -1, -1)
        x = torch.cat([cls_token, x], dim=1)  # (B,197,768)

        # Add position embeddings
        x = x + self.vit.encoder.pos_embedding  # Add position embeddings directly

        # 3) Process through transformer encoder
        encoder_output = self.encoder(x)  # (B,197,768)

        # 4) Remove CLS token and use as memory for decoder
        memory = encoder_output[:, 1:, :]  # (B,196,768)

        # 5) Expand queries for batch
        queries = self.decoder_queries.unsqueeze(0).expand(B, -1, -1)  # (B,196,768)

        # 6) Transformer decoder
        decoded = self.transformer_decoder(queries, memory)  # (B,196,768)

        # 7) Reshape for CNN upsampling
        decoded = decoded.transpose(1, 2).view(B, self.hidden_dim, 14, 14)  # (B,768,14,14)

        # 8) Upsample to target size
        out_224 = self.up_net(decoded)  # (B,2,224,224)

        if self.final_upsample:
            out_256 = F.interpolate(out_224, size=(L.size(2), L.size(3)),
                                  mode='bilinear', align_corners=False)
            out_256 = self.final_layer(out_256)
            return out_256

        return out_224

### Helper functions

In [4]:
def lab_to_rgb(L, ab):
    # Denormalize
    L = (L + 1.) * 50.0
    a = ab[:, [0], :, :] * 110.0
    b = ab[:, [1], :, :] * 110.0

    Lab = torch.cat([L, a, b], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = [lab2rgb(img.astype(np.float64)) for img in Lab]
    return np.stack(rgb_imgs, axis=0)

def visualize_colorization(model, data, max_images=5, save=False):
    model.eval()
    with torch.no_grad():
        L_in = data['L'].to(device)
        ab_gt = data['ab'].to(device)
        ab_pred = model(L_in)

    rgb_fake = lab_to_rgb(L_in, ab_pred)
    rgb_real = lab_to_rgb(L_in, ab_gt)
    L_np = L_in.cpu().numpy()[:, 0, :, :]

    plt.figure(figsize=(15, 8))
    for i in range(min(max_images, L_in.size(0))):
        ax = plt.subplot(3, max_images, i + 1)
        ax.imshow(L_np[i], cmap='gray', vmin=-1, vmax=1)
        ax.set_title("L Input")
        ax.axis("off")

        ax = plt.subplot(3, max_images, i + 1 + max_images)
        ax.imshow(rgb_fake[i])
        ax.set_title("Predicted")
        ax.axis("off")

        ax = plt.subplot(3, max_images, i + 1 + 2 * max_images)
        ax.imshow(rgb_real[i])
        ax.set_title("Ground Truth")
        ax.axis("off")

    plt.tight_layout()
    plt.show()
    if save:
        plt.savefig(f"colorization_{time.time()}.png")

### Pretrain the generator

In [None]:
def pretrain_vit_generator(
    generator,
    train_dl,
    val_dl=None,
    epochs=10,
    lr=1e-4,
    lambda_L1=1.0,  # How strongly to scale L1 (like your lambda_L1 in GAN)
    save_path=None,
    device='cuda',
    vis_interval=3  # Visualize every `vis_interval` epochs
):
    """
    Pretrain the ViT generator on the colorization task alone (no discriminator).
      - generator: an instance of ViTColorizerWithTransformerDecoder
      - train_dl: DataLoader for the training set
      - val_dl: DataLoader for validation (optional)
      - epochs: number of pretraining epochs
      - lr: learning rate for Adam
      - lambda_L1: weighting factor for L1 loss
      - save_path: if provided, saves the generator's weights at the end
      - device: "cuda" or "cpu"
    """
    generator.to(device)
    generator.train()

    # Optimizer (just for the generator)
    optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.9, 0.999))

    # L1 Loss
    l1_loss_fn = nn.L1Loss()

    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch in tqdm(train_dl, desc=f"Pretrain Epoch {epoch+1}/{epochs}"):
            # Move data to device
            L_in  = batch['L'].to(device)    # (B,1,H,W)
            ab_gt = batch['ab'].to(device)   # (B,2,H,W)

            # Forward pass
            ab_pred = generator(L_in)        # (B,2,H,W)

            # Compute L1 loss
            loss = l1_loss_fn(ab_pred, ab_gt) * lambda_L1

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        epoch_loss /= len(train_dl)
        print(f"[Pretrain Epoch {epoch+1}/{epochs}] L1 Loss: {epoch_loss:.4f}")

        # --- Optional Validation ---
        if val_dl is not None:
            generator.eval()
            val_loss = 0.0
            with torch.no_grad():
                for val_batch in val_dl:
                    L_in  = val_batch['L'].to(device)
                    ab_gt = val_batch['ab'].to(device)
                    ab_pred = generator(L_in)
                    val_loss += l1_loss_fn(ab_pred, ab_gt).item()
            val_loss /= len(val_dl)
            print(f"Validation L1 Loss: {val_loss:.4f}")
            # Visualize every `vis_interval` epochs
            if (epoch + 1) % vis_interval == 0:
                print(f"Visualizing predictions at epoch {epoch+1}...")
                visualize_colorization(generator, next(iter(val_dl)), max_images=4)
            generator.train()

    # --- Save generator weights ---
    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(generator, save_path)
        print(f"Pretrained generator saved to {save_path}")


In [None]:
vit_generator = ViTColorizerWithTransformerDecoder(
    pretrained=True,
    num_decoder_layers=2,
    nhead=8,
    dim_feedforward=2048,
    final_upsample=False
)

In [None]:
pretrain_vit_generator(
    generator=vit_generator,
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=10,
    lr=1e-4,
    lambda_L1=1.0,
    save_path="/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/alihan_vitgan/pretrained_vit_generator.pt",
    device=device
)

Output hidden; open in https://colab.research.google.com to view.

###### Train for 20 more epochs

In [None]:
pretrain_vit_generator(
    generator=vit_generator,
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=20,
    lr=1e-4,
    lambda_L1=1.0,
    save_path="/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/alihan_vitgan/pretrained_vit_generator30.pt",
    device=device
)

##### Train 20 more epochs

In [None]:
vit_generator = torch.load("/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/alihan_vitgan/pretrained_vit_generator30.pt")
pretrain_vit_generator(
    generator=vit_generator,
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=20,
    lr=1e-4,
    lambda_L1=1.0,
    save_path="/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/alihan_vitgan/pretrained_vit_generator50.pt",
    device=device
)

Output hidden; open in https://colab.research.google.com to view.

##### in total we trained it for 50 epochs

## Test on user images

In [5]:
vit_generator = torch.load("/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/alihan_vitgan/pretrained_vit_generator50.pt")

  vit_generator = torch.load("/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/alihan_vitgan/pretrained_vit_generator50.pt")


In [6]:
def test_colorization(model, input_path, output_dir=None, img_size=224):
    """
    Test the colorization model on user-provided images or directories of images.

    Parameters:
      - model: Trained colorization model (already loaded).
      - input_path: Path to an image or directory containing images.
      - output_dir: Directory to save the colorized images (optional).
      - img_size: Image size to resize input images to (default: 224).
    """
    # Ensure the model is in evaluation mode
    model.eval()

    # Determine if input_path is an image or a directory
    input_path = Path(input_path)
    if input_path.is_file():  # Single image
        image_paths = [input_path]
    elif input_path.is_dir():  # Directory of images
        image_paths = list(input_path.glob("*.jpg")) + \
                      list(input_path.glob("*.jpeg")) + \
                      list(input_path.glob("*.png"))
        image_paths = sorted(image_paths)  # Sort to ensure consistent order
    else:
        raise ValueError(f"Invalid input path: {input_path}")

    # Ensure the output directory exists, if provided
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Transformation for input images
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
    ])

    # Process images
    for idx, img_path in enumerate(tqdm(image_paths, desc="Processing images")):
        # Load and preprocess the grayscale version
        grayscale_img = Image.open(img_path).convert("L")  # Open as grayscale
        L = transform(grayscale_img).unsqueeze(0).to(device)  # Add batch dimension and send to device

        # Load the ground truth (original RGB image)
        ground_truth_img = Image.open(img_path).convert("RGB")
        ground_truth_img_resized = ground_truth_img.resize((img_size, img_size), Image.BICUBIC)

        # Predict colorization
        with torch.no_grad():
            ab_pred = model(L)

        # Convert LAB to RGB
        L = L.cpu()
        ab_pred = ab_pred.cpu()
        rgb_fake = lab_to_rgb(L, ab_pred)[0]  # Extract the first (and only) image from the batch

        # Visualize the first 5 images
        if idx < 5:
            plt.figure(figsize=(15, 5))

            # Grayscale input
            plt.subplot(1, 3, 1)
            plt.imshow(L[0, 0].numpy(), cmap="gray", vmin=-1, vmax=1)
            plt.title("Grayscale Input")
            plt.axis("off")

            # Predicted colorization
            plt.subplot(1, 3, 2)
            plt.imshow(rgb_fake)
            plt.title("Predicted Colorization")
            plt.axis("off")

            # Ground truth (original color image)
            plt.subplot(1, 3, 3)
            plt.imshow(ground_truth_img_resized)
            plt.title("Ground Truth (Color)")
            plt.axis("off")

            plt.tight_layout()
            plt.show()

        # Save the processed image if output_dir is provided
        if output_dir:
            output_file = output_dir / img_path.name
            plt.imsave(output_file, rgb_fake)
            print(f"Saved colorized image to {output_file}")


In [7]:
input_path = "/content/test_images"

test_colorization(vit_generator, input_path, img_size=224)

Output hidden; open in https://colab.research.google.com to view.