In [None]:
!pip install -q -U fastai

import os
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict
from pathlib import Path

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from skimage.color import rgb2lab, lab2rgb

from fastai.data.external import untar_data, URLs

In [None]:
# Resolve the COCO sample dataset root through fastai's `untar_data` and point
# at its "train_sample" sub-folder.
coco_root = untar_data(URLs.COCO_SAMPLE)
coco_path = Path(coco_root).joinpath("train_sample")
assert coco_path.exists(), f"Ne postoji putanja: {coco_path}"

# Collect every *.jpg / *.jpeg / *.png file directly inside that folder, sorted.
paths = sorted(
    match
    for pattern in ("*.jpg", "*.jpeg", "*.png")
    for match in glob.glob(str(coco_path / pattern))
)

def is_ok(p):
    """Return True when Pillow can open and verify the image at ``p``.

    Args:
        p: Path (or any object accepted by ``Image.open``) of the image to check.

    Returns:
        bool: ``True`` if opening the file and calling ``verify()`` raise no
        ``Exception``; ``False`` otherwise.
    """
    try:
        # The context manager closes the image on every exit path.
        with Image.open(p) as img:
            img.verify()
    except Exception:
        # Best-effort check: any ordinary failure marks the file as unusable.
        # Catching ``Exception`` instead of using a bare ``except`` lets
        # KeyboardInterrupt / SystemExit propagate, so a long scan over many
        # files can still be interrupted.
        return False
    return True

# Keep only the files that pass the `is_ok` check.
paths = [p for p in paths if is_ok(p)]
print(f"Ukupno validnih slika: {len(paths)}")

# Fixed seed so the sampled subset and the train/val split are reproducible.
np.random.seed(987)
N_TARGET = 3_000
N_TOTAL  = min(N_TARGET, len(paths))
subset   = np.random.choice(paths, N_TOTAL, replace=False)

# First (up to) 2,000 shuffled indices go to training, the remainder to validation.
perm = np.random.permutation(N_TOTAL)
n_train = min(2_000, N_TOTAL)
n_val   = max(0, N_TOTAL - n_train)

train_idx = perm[:n_train]
val_idx   = perm[n_train:n_train + n_val]

# `subset` is a NumPy array; convert its elements to plain `str`
# so downstream code receives ordinary Python paths.
train_paths = [str(subset[i]) for i in train_idx]
val_paths   = [str(subset[i]) for i in val_idx]

print(f"Train: {len(train_paths)} | Val: {len(val_paths)}")

# Preview up to 16 training images in a 4x4 grid.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax in axes.flatten():
    # Hide every cell up front so unused cells stay blank when fewer than
    # 16 training images are available.
    ax.axis("off")
for ax, img_path in zip(axes.flatten(), train_paths[:16]):
    with Image.open(img_path) as img:
        ax.imshow(img)
plt.tight_layout()
plt.show()


In [None]:
class ColorizationDataset(Dataset):
    """Dataset yielding normalised Lab channels of resized images.

    Each item is a dict with two tensors:
        "L":  the single lightness channel, computed as ``L / 50 - 1``.
        "ab": the two colour channels, computed as ``ab / 110``.

    Args:
        img_paths: Iterable of image paths; every entry is converted with ``str``.
        split: ``"train"`` adds a random horizontal flip (p=0.5) after the
            resize; any other value applies the resize only.
        size: Side length in pixels of the square every image is resized to.
    """

    def __init__(self, img_paths, split="train", size=256):
        self.img_paths = [str(p) for p in img_paths]
        self.split = split
        self.size = int(size)

        # Use torchvision's own InterpolationMode enum rather than PIL's integer
        # constant; some torchvision releases warn when given the PIL constant.
        resize = transforms.Resize(
            (self.size, self.size), interpolation=InterpolationMode.BICUBIC
        )
        self._tx_train = transforms.Compose([
            resize,
            transforms.RandomHorizontalFlip(p=0.5),
        ])
        self._tx_eval = resize

        # Built once here instead of on every `_to_lab_tensor` call.
        self._to_tensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.img_paths)

    def _load_rgb(self, path):
        """Open ``path`` with Pillow and return a detached RGB copy."""
        # `convert` produces a new image, so the opened file can be closed here.
        with Image.open(path) as img:
            return img.convert("RGB")

    def _to_lab_tensor(self, pil_rgb):
        """Convert an RGB PIL image into normalised ``(L, ab)`` tensors.

        Returns:
            tuple: ``(L, ab)`` where ``L`` keeps channel 0 scaled by
            ``/ 50 - 1`` and ``ab`` keeps channels 1-2 scaled by ``/ 110``.
        """
        np_rgb = np.array(pil_rgb)
        lab = rgb2lab(np_rgb).astype("float32")
        tens = self._to_tensor(lab)
        # List indexing keeps the channel dimension on both outputs.
        L  = tens[[0], ...] / 50.0 - 1.0
        ab = tens[[1, 2], ...] / 110.0
        return L, ab

    def __getitem__(self, idx):
        """Load, resize (and for the train split maybe flip) image ``idx``.

        Returns:
            dict: ``{"L": L, "ab": ab}`` as produced by `_to_lab_tensor`.
        """
        img = self._load_rgb(self.img_paths[idx])

        if self.split == "train":
            img = self._tx_train(img)
        else:
            img = self._tx_eval(img)

        L, ab = self._to_lab_tensor(img)
        return {"L": L, "ab": ab}


def make_dataloaders(*, paths, split="train", batch_size=16, num_workers=4, pin_memory=True, size=256, shuffle=None):
    """Build a ``DataLoader`` over a `ColorizationDataset` for one split.

    Args:
        paths: Image paths forwarded to ``ColorizationDataset``.
        split: Split name forwarded to the dataset (``"train"`` or anything else).
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        pin_memory: Forwarded to ``DataLoader``.
        size: Square image side length forwarded to the dataset.
        shuffle: Whether to reshuffle every epoch. ``None`` (default) shuffles
            only a non-empty ``"train"`` split; pass ``True``/``False`` to
            override.

    Returns:
        DataLoader: Loader over the constructed dataset.
    """
    ds = ColorizationDataset(img_paths=paths, split=split, size=size)
    if shuffle is None:
        # Training batches should be reshuffled each epoch; evaluation order
        # stays fixed. An empty dataset is left unshuffled so the loader can
        # still be constructed.
        shuffle = split == "train" and len(ds) > 0
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


In [None]:
train_dl = make_dataloaders(paths=train_paths, split="train")
val_dl   = make_dataloaders(paths=val_paths,   split="val")

# Pull one training batch to check tensor shapes and the Lab -> RGB round trip.
batch = next(iter(train_dl))
L_batch, ab_batch = batch["L"], batch["ab"]

print(f"L shape:  {L_batch.shape} | ab shape: {ab_batch.shape}")
print(f"Num train batches: {len(train_dl)} | Num val batches: {len(val_dl)}")

# Cap at the batch size: indexing 15 samples out of a smaller batch would
# raise IndexError.
n_show = min(15, len(L_batch))
plt.figure(figsize=(12, 3*n_show))

for i in range(n_show):
    L = L_batch[i].numpy()[0]
    ab = ab_batch[i].numpy().transpose(1,2,0)

    # Undo the dataset normalisation (L / 50 - 1 and ab / 110).
    L_img = (L + 1.) * 50.0
    ab_img = ab * 110.

    lab_img = np.concatenate((L_img[...,None], ab_img), axis=2)
    rgb_img = lab2rgb(lab_img.astype("float32"))

    # Same lightness with zeroed colour channels gives the grayscale input view.
    gray_img = lab2rgb(np.concatenate((L_img[...,None], np.zeros_like(ab_img)), axis=2))

    plt.subplot(n_show, 2, 2*i+1)
    plt.imshow(gray_img)
    plt.title("Input (L channel)")
    plt.axis("off")

    plt.subplot(n_show, 2, 2*i+2)
    plt.imshow(rgb_img)
    plt.title("Ground Truth (L+ab)")
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Locations of the three files the colorization network needs.
PROTOTXT = "/content/colorization_deploy_v2.prototxt"
MODEL    = "/content/colorization_release_v2.caffemodel"
POINTS   = "/content/pts_in_hull.npy"

# Fail fast, naming the first required file that is absent.
for fp in (PROTOTXT, MODEL, POINTS):
    if Path(fp).exists():
        continue
    raise FileNotFoundError(f"Nedostaje fajl: {fp}")

In [None]:
# Load the Caffe network definition + weights and the saved points array.
net = cv2.dnn.readNetFromCaffe(PROTOTXT, MODEL)
pts = np.load(POINTS)

# Ids of the two layers whose blobs are filled in manually below.
class8_id = net.getLayerId("class8_ab")
conv8_id  = net.getLayerId("conv8_313_rh")

# Reshape the points to (2, 313, 1, 1) and install them as the "class8_ab" blob.
pts = pts.T.reshape(2, 313, 1, 1)
net.getLayer(class8_id).blobs = [pts.astype(np.float32)]
# "conv8_313_rh" receives a constant (1, 313) blob filled with 2.606
# NOTE(review): presumably the scaling constant from the reference colorization
# sample — confirm against the model's documentation.
net.getLayer(conv8_id).blobs  = [np.full((1, 313), 2.606, dtype=np.float32)]

In [None]:
def bgr_to_lab_float01(img_bgr: np.ndarray) -> np.ndarray:
    """Convert a BGR image to Lab after rescaling it by 1/255 as float32.

    Args:
        img_bgr: BGR image array (assumed to hold 0-255 values — confirm with callers).

    Returns:
        The result of ``cv2.cvtColor(..., cv2.COLOR_BGR2LAB)`` on the rescaled image.
    """
    unit_range = np.divide(img_bgr.astype(np.float32), 255.0)
    return cv2.cvtColor(unit_range, cv2.COLOR_BGR2LAB)

def lab_to_bgr_uint8(lab: np.ndarray) -> np.ndarray:
    """Convert a Lab image to an 8-bit BGR image.

    The OpenCV conversion result is clipped to [0, 1], multiplied by 255 and
    cast to ``uint8`` (the cast truncates the fractional part).

    Args:
        lab: Lab image array accepted by ``cv2.cvtColor``.

    Returns:
        ``uint8`` BGR array.
    """
    unit_bgr = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR).clip(0, 1)
    return (unit_bgr * 255).astype(np.uint8)

def bgr_to_rgb(img_bgr: np.ndarray) -> np.ndarray:
    """Return ``img_bgr`` converted with OpenCV's ``COLOR_BGR2RGB`` code.

    Used to hand OpenCV (BGR) images to matplotlib, which displays RGB.
    """
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    return img_rgb

def colorize(img_bgr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Colorize a BGR image with the module-level ``net``.

    The image's own lightness channel is kept at full resolution; only the two
    colour channels come from the network.

    Args:
        img_bgr: Colour image in OpenCV BGR layout.

    Returns:
        tuple: ``(color_bgr, gray_bgr)`` — the colorized result as ``uint8``
        BGR, and the input converted to grayscale and back to 3-channel BGR.
    """
    height, width = img_bgr.shape[:2]

    # Lightness plane of the input (channel 0 of its Lab representation).
    lightness = bgr_to_lab_float01(img_bgr)[:, :, 0]

    # Network input: lightness resized to 224x224 with 50 subtracted.
    net_input = cv2.resize(lightness, (224, 224)) - 50.0
    net.setInput(cv2.dnn.blobFromImage(net_input))

    # First output of the forward pass, moved to channel-last layout and
    # resized back to the input resolution.
    ab_pred = np.transpose(net.forward()[0], (1, 2, 0))
    ab_full = cv2.resize(ab_pred, (width, height))

    # Original lightness + predicted colour channels -> BGR uint8.
    lab_pred = np.dstack((lightness, ab_full))
    color_bgr = lab_to_bgr_uint8(lab_pred)

    # Grayscale rendition of the input, expanded back to three channels.
    gray_bgr = cv2.cvtColor(
        cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),
        cv2.COLOR_GRAY2BGR,
    )
    return color_bgr, gray_bgr

def plot_triptych_grid(triptychs: list[tuple[np.ndarray, np.ndarray, np.ndarray, str]]) -> None:
    """Show every triptych as one row of a single figure.

    Each row holds three panels: "Original", "Grayscale" and
    "Colorized (Zhang)".

    Args:
        triptychs: Sequence of ``(orig_rgb, gray_rgb, color_rgb, stem)`` tuples;
            ``stem`` is ignored here.

    Returns:
        None. Nothing is drawn when ``triptychs`` is empty.
    """
    rows = len(triptychs)
    if rows == 0:
        # Avoid creating a zero-height figure when there is nothing to show.
        return

    panel_titles = ("Original", "Grayscale", "Colorized (Zhang)")
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows), squeeze=False)
    for row_axes, (orig_rgb, gray_rgb, color_rgb, _) in zip(axes, triptychs):
        for ax, image, title in zip(row_axes, (orig_rgb, gray_rgb, color_rgb), panel_titles):
            ax.imshow(image)
            ax.axis("off")
            ax.set_title(title)
    fig.tight_layout()
    plt.show()

def save_triptychs(triptychs: list[tuple[np.ndarray, np.ndarray, np.ndarray, str]], out_dir: Path) -> None:
    """Write each triptych to ``<out_dir>/<stem>_triptych.png``.

    Args:
        triptychs: Sequence of ``(orig_rgb, gray_rgb, color_rgb, stem)`` tuples.
            The file name depends on ``stem`` only, so entries sharing a stem
            overwrite one another.
        out_dir: Destination directory (``Path`` or path string); created with
            parents if missing.

    Returns:
        None.
    """
    # Accept plain strings as well as Path objects.
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    panel_titles = ("Original", "Grayscale", "Colorized (Zhang)")
    for orig_rgb, gray_rgb, color_rgb, stem in triptychs:
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        try:
            for ax, image, title in zip(axes, (orig_rgb, gray_rgb, color_rgb), panel_titles):
                ax.imshow(image)
                ax.axis("off")
                ax.set_title(title)
            fig.tight_layout()
            fig.savefig(out_dir / f"{stem}_triptych.png", dpi=150)
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)

In [None]:
# Output folder for the saved comparison figures.
OUT_DIR = Path("./zhang_examples")

# Use at most the first 15 validation images.
picked_paths = val_paths if len(val_paths) < 15 else val_paths[:15]
print(f"Pripremam {len(picked_paths)} triptiha...")

triptychs = []
for p in picked_paths:
    img_bgr = cv2.imread(p)
    if img_bgr is not None:
        color_bgr, gray_bgr = colorize(img_bgr)
        # matplotlib expects RGB, OpenCV delivers BGR.
        orig_rgb, gray_rgb, color_rgb = map(bgr_to_rgb, (img_bgr, gray_bgr, color_bgr))
        triptychs.append((orig_rgb, gray_rgb, color_rgb, Path(p).stem))
    else:
        # `cv2.imread` returned None for this path; report it and move on.
        print(f"Preskačem (ne može da se učita): {p}")

plot_triptych_grid(triptychs)

save_triptychs(triptychs, OUT_DIR)
print(f"Snimljeni triptisi u: {OUT_DIR.resolve()}")