In [None]:
!pip install fastai

import os
import cv2
import glob
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from PIL import Image
from typing import List, Dict, Optional, Tuple, Any, Iterable
from pathlib import Path

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.transforms import InterpolationMode
import torchvision.models as models
from fastai.vision.learner import create_body
from fastai.data.external import untar_data, URLs
from fastai.vision.models.unet import DynamicUnet
from skimage.color import rgb2lab, lab2rgb

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Download (or reuse the cached copy of) fastai's COCO sample and locate its training images.
coco_root = untar_data(URLs.COCO_SAMPLE)
coco_path = Path(coco_root) / "train_sample"
# The message reads "path does not exist".
assert coco_path.exists(), f"Ne postoji putanja: {coco_path}"

# Collect every .jpg / .jpeg / .png file (case-sensitive patterns) in a deterministic order.
image_patterns = ("*.jpg", "*.jpeg", "*.png")
paths = sorted(
    found
    for pattern in image_patterns
    for found in glob.glob(str(coco_path / pattern))
)

def is_ok(p):
    """Return True if Pillow can open and verify the image file at ``p``, else False.

    The file is opened in a context manager, so its handle is closed right away.
    Any ``Exception`` raised while opening or verifying counts as "not a usable
    image". KeyboardInterrupt / SystemExit are no longer swallowed, so a long scan
    can still be interrupted.
    """
    try:
        with Image.open(p) as img:
            img.verify()
        return True
    except Exception:
        return False

# Keep only files that Pillow can verify. The message reads "total valid images".
paths = [p for p in paths if is_ok(p)]
print(f"Ukupno validnih slika: {len(paths)}")

# Reproducible random subset + train/val split using the legacy global NumPy RNG.
# The order of the np.random calls below determines the split, so keep it unchanged.
np.random.seed(268)
N_TARGET = 6_000
N_TOTAL  = min(N_TARGET, len(paths))
subset   = np.random.choice(paths, N_TOTAL, replace=False)

perm = np.random.permutation(N_TOTAL)
n_train = min(5_000, N_TOTAL)
n_val   = max(0, N_TOTAL - n_train)

train_idx = perm[:n_train]
val_idx   = perm[n_train:n_train + n_val]

train_paths = [subset[i] for i in train_idx]
val_paths   = [subset[i] for i in val_idx]

print(f"Train: {len(train_paths)} | Val: {len(val_paths)}")

# Preview up to 16 training images in a 4x4 grid.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax in axes.flatten():
    ax.axis("off")  # also hides unused panels when fewer than 16 images exist
for ax, img_path in zip(axes.flatten(), train_paths[:16]):
    # imshow copies the pixels, so the file can be closed right after.
    with Image.open(img_path) as preview:
        ax.imshow(preview)
plt.tight_layout()
plt.show()


In [None]:
class ColorizationDataset(Dataset):
    """Dataset yielding normalized Lab tensors for image colorization.

    Each item is ``{"L": L, "ab": ab}`` computed from the image resized to
    ``size x size``:
      * ``L``  has shape (1, size, size) and is mapped as ``L / 50 - 1``
        (CIE L is nominally in [0, 100], giving roughly [-1, 1]);
      * ``ab`` has shape (2, size, size) and is scaled by ``1 / 110``.

    Args:
        img_paths: iterable of image paths (stored as ``str``).
        split: "train" adds a random horizontal flip; any other value only resizes.
        size: output side length in pixels.
    """

    def __init__(self, img_paths, split="train", size=256):
        self.img_paths = [str(p) for p in img_paths]
        self.split = split
        self.size = int(size)

        # Use torchvision's documented interpolation enum rather than a raw PIL constant.
        self._tx_train = transforms.Compose([
            transforms.Resize((self.size, self.size), interpolation=InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
        ])
        self._tx_eval = transforms.Resize((self.size, self.size), interpolation=InterpolationMode.BICUBIC)
        # Built once and reused for every item.
        self._to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.img_paths)

    def _load_rgb(self, path):
        """Open ``path`` and return an RGB copy; the underlying file is closed on return."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _to_lab_tensor(self, pil_rgb):
        """Convert an RGB PIL image into normalized ``(L, ab)`` float32 tensors."""
        np_rgb = np.array(pil_rgb)
        lab = rgb2lab(np_rgb).astype("float32")
        # HWC float array -> CHW tensor; ToTensor does not rescale float input.
        tens = self._to_tensor(lab)
        L  = tens[[0], ...] / 50.0 - 1.0
        ab = tens[[1, 2], ...] / 110.0
        return L, ab

    def __getitem__(self, idx):
        img = self._load_rgb(self.img_paths[idx])
        transform = self._tx_train if self.split == "train" else self._tx_eval
        L, ab = self._to_lab_tensor(transform(img))
        return {"L": L, "ab": ab}


def make_dataloaders(*, paths, split="train", batch_size=16, num_workers=4, pin_memory=True, size=256,
                     shuffle=None):
    """Build a ``DataLoader`` over a ``ColorizationDataset``.

    Args:
        paths: image file paths.
        split: "train" enables random flips and, by default, per-epoch shuffling.
            Any other value uses the deterministic eval resize and keeps the order.
        batch_size: samples per batch.
        num_workers: worker processes used for loading.
        pin_memory: pin host memory for faster GPU transfer. It is only applied when
            CUDA is available, since pinning is meaningless (and warns) otherwise.
        size: side length images are resized to.
        shuffle: explicit shuffle flag. None (default) means shuffle only for a
            non-empty "train" split (RandomSampler rejects empty datasets).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``{"L", "ab"}`` batches.
    """
    ds = ColorizationDataset(img_paths=paths, split=split, size=size)
    if shuffle is None:
        # Without shuffling, every training epoch would see the same batch order.
        shuffle = split == "train" and len(ds) > 0
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=bool(pin_memory) and torch.cuda.is_available(),
    )

In [None]:
# Build the train/val loaders and sanity-check one training batch by decoding it back to RGB.
train_dl = make_dataloaders(paths=train_paths, split="train")
val_dl   = make_dataloaders(paths=val_paths,   split="val")

batch = next(iter(train_dl))
L_batch = batch["L"]
ab_batch = batch["ab"]

print(f"L shape:  {L_batch.shape} | ab shape: {ab_batch.shape}")
print(f"Num train batches: {len(train_dl)} | Num val batches: {len(val_dl)}")

n_show = 4
plt.figure(figsize=(12, 3*n_show))

for i in range(n_show):
    L = L_batch[i].numpy()[0]                    # single lightness plane (H, W)
    ab = ab_batch[i].numpy().transpose(1, 2, 0)  # chroma moved to channels-last (H, W, 2)

    # Undo the dataset normalization (L / 50 - 1 and ab / 110).
    L_img, ab_img = (L + 1.) * 50.0, ab * 110.

    rgb_img = lab2rgb(np.concatenate((L_img[..., None], ab_img), axis=2).astype("float32"))
    gray_img = lab2rgb(np.concatenate((L_img[..., None], np.zeros_like(ab_img)), axis=2))

    panels = (("Input (L channel)", gray_img), ("Ground Truth (L+ab)", rgb_img))
    for col, (title, img) in enumerate(panels, start=1):
        plt.subplot(n_show, 2, 2*i + col)
        plt.imshow(img)
        plt.title(title)
        plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
class UnetBlock(nn.Module):
    """One level of the recursive U-Net: down-conv -> submodule -> up-conv.

    Each level halves H and W with a stride-2 4x4 conv on the way down and doubles
    them with a stride-2 4x4 transposed conv on the way up. Non-outermost levels
    return ``torch.cat([input, output], dim=1)`` (the skip connection). The outermost
    level returns only its output, passed through Tanh.

    Args:
        nf: channels produced by this level's up-conv, and, unless ``input_c`` is
            given, the channel count of this level's input.
        ni: channels produced by this level's down-conv (the width fed to ``submodule``).
        submodule: the next-inner ``UnetBlock`` (ignored when ``innermost``). The
            up-conv expects ``ni * 2`` input channels, i.e. it assumes the submodule
            returns ``ni`` skip channels plus ``ni`` output channels (its ``nf`` equal
            to this ``ni``). TODO confirm for custom stacks.
        input_c: input channel count when it differs from ``nf`` (outermost level).
        dropout: append Dropout(0.5) to the up path (middle levels only).
        innermost: bottleneck level with no submodule.
        outermost: top level; no activation/norm on the input, Tanh on the output.
    """
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost

        if input_c is None:
            input_c = nf

        # Both activations are constructed with inplace=True (second positional arg).
        downrelu = nn.LeakyReLU(0.2, True)
        # InstanceNorm2d defaults to affine=False, i.e. no learnable scale/shift.
        downnorm = nn.InstanceNorm2d(ni)
        uprelu   = nn.ReLU(True)
        upnorm = nn.InstanceNorm2d(nf)

        # Stride-2, 4x4, padding-1 conv: halves the spatial size.
        downconv = nn.Conv2d(input_c, ni, kernel_size=4, stride=2, padding=1, bias=False)

        if outermost:
            # Top level: raw input -> downconv -> submodule -> ReLU -> upconv (with bias) -> Tanh.
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1)

            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up

        elif innermost:
            # Bottleneck: no submodule, so the up-conv sees only the ni down-conv channels.
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4, stride=2, padding=1, bias=False)

            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up

        else:
            # Middle level: the submodule returns [skip, output], hence ni * 2 up-conv inputs.
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False)

            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if dropout:
                up.append(nn.Dropout(0.5))
            model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        # Outermost: return the prediction only. Other levels: concatenate the level's
        # input with its output along the channel dimension (skip connection).
        # NOTE(review): for non-outermost levels self.model starts with the in-place
        # LeakyReLU, which mutates ``x`` before torch.cat reads it. The skip branch
        # therefore carries LeakyReLU(x), not the raw input. Changing this changes the model.
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], dim=1)


class Unet(nn.Module):
    """U-Net generator assembled inside-out from nested ``UnetBlock``s.

    Construction order (innermost first):
      1. an innermost block at ``num_filters * 8`` channels;
      2. ``n_down - 5`` further ``num_filters * 8`` blocks with dropout (none if negative);
      3. three tapering blocks whose (nf, ni) widths are (4x, 8x), (2x, 4x), (1x, 2x)
         of ``num_filters``;
      4. the outermost block mapping ``input_c`` -> ``output_c`` channels.
    Every block halves the spatial size once, so there are ``max(n_down, 5)`` downsamples.

    Args:
        input_c: channels of the input image (1 for the L channel).
        output_c: channels of the prediction (2 for ab).
        n_down: requested number of downsampling steps.
        num_filters: base channel width.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8

        # Start at the bottleneck; every new block wraps the previous one as its submodule.
        block = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)

        inner_width = widest
        for _ in range(3):
            outer_width = inner_width // 2
            block = UnetBlock(outer_width, inner_width, submodule=block)
            inner_width = outer_width

        self.model = UnetBlock(output_c, inner_width, input_c=input_c,
                               submodule=block, outermost=True)

    def forward(self, x):
        """Map an ``input_c``-channel batch to an ``output_c``-channel prediction."""
        return self.model(x)


In [None]:
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator producing a spatial map of real/fake logits.

    Layout: an un-normalized stem conv (stride 2) with LeakyReLU, then ``n_down``
    conv + BatchNorm + LeakyReLU blocks that double the channel count (stride 2,
    except the last which is stride 1), then a final 1-channel conv with no norm and
    no activation (raw logits).

    Args:
        input_c: number of input channels.
        num_filters: channel count produced by the stem.
        n_down: number of normalized blocks after the stem.
    """

    def __init__(self, input_c: int, num_filters: int = 64, n_down: int = 3) -> None:
        super().__init__()
        stem = self._block(in_ch=input_c, out_ch=num_filters, k=4, s=2, p=1,
                           use_norm=False, use_act=True)
        body = [
            self._block(
                in_ch=num_filters * 2 ** level,
                out_ch=num_filters * 2 ** (level + 1),
                k=4,
                s=(1 if level == n_down - 1 else 2),
                p=1,
                use_norm=True,
                use_act=True,
            )
            for level in range(n_down)
        ]
        head = self._block(in_ch=num_filters * 2 ** n_down, out_ch=1, k=4, s=1, p=1,
                           use_norm=False, use_act=False)
        self.model = nn.Sequential(stem, *body, head)

    @staticmethod
    def _block(in_ch: int, out_ch: int, k: int = 4, s: int = 2, p: int = 1,
               use_norm: bool = True, use_act: bool = True) -> nn.Sequential:
        """Conv2d, optionally followed by BatchNorm2d and LeakyReLU(0.2, inplace).

        The conv has a bias only when no BatchNorm follows.
        """
        conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=k,
                         stride=s, padding=p, bias=not use_norm)
        norm_part: List[nn.Module] = [nn.BatchNorm2d(out_ch)] if use_norm else []
        act_part: List[nn.Module] = [nn.LeakyReLU(0.2, inplace=True)] if use_act else []
        return nn.Sequential(conv, *norm_part, *act_part)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the patch logit map for ``x``."""
        return self.model(x)


In [None]:
class GANLoss(nn.Module):
    """Adversarial loss against a constant real/fake target shaped like the predictions.

    Args:
        gan_mode: 'vanilla' uses BCEWithLogitsLoss (predictions are raw logits);
            'lsgan' uses MSELoss.
        real_label: target value for "real".
        fake_label: target value for "fake".

    Raises:
        NotImplementedError: for any other ``gan_mode``.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers follow .to(device) with the module but are never optimized.
        for buffer_name, value in (('real_label', real_label), ('fake_label', fake_label)):
            self.register_buffer(buffer_name, torch.tensor(value))

        for mode_name, criterion_cls in (('vanilla', nn.BCEWithLogitsLoss), ('lsgan', nn.MSELoss)):
            if gan_mode == mode_name:
                self.loss = criterion_cls()
                break
        else:
            raise NotImplementedError(f"gan_mode '{gan_mode}' not implemented")

    def get_labels(self, preds, target_is_real):
        """Return the real or fake target expanded (as a view) to ``preds``' shape."""
        if target_is_real:
            return self.real_label.expand_as(preds)
        return self.fake_label.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Loss of ``preds`` against an all-real or all-fake target."""
        return self.loss(preds, self.get_labels(preds, target_is_real))


In [None]:
def init_weights(net, init_type="normal", gain=0.02):
    """Re-initialize conv and BatchNorm2d weights of ``net`` in place and return ``net``.

    Conv layers (any class whose name contains "Conv" and which has a non-None
    ``weight``) get their weight drawn per ``init_type`` and their bias, if any, zeroed:
      * "normal":  N(0, gain)
      * "xavier":  Xavier-normal with the given gain
      * "kaiming": Kaiming-normal (fan_in, a=0); ``gain`` is not used
    BatchNorm2d layers get weight ~ N(1, gain) and bias = 0. Norm layers without
    affine parameters (weight/bias None) are skipped instead of crashing.

    Raises:
        ValueError: for an unknown ``init_type`` (only when ``net`` contains a conv layer).
    """

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)

        if "Conv" in classname and weight is not None:
            # nn.init functions already run under no_grad, so .data is not needed.
            if init_type == "normal":
                nn.init.normal_(weight, mean=0.0, std=gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(weight, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(weight, a=0, mode="fan_in")
            else:
                raise ValueError(f"Nepoznata inicijalizacija: {init_type}")

            if bias is not None:
                nn.init.constant_(bias, 0.0)

        elif "BatchNorm2d" in classname:
            if weight is not None:
                nn.init.normal_(weight, mean=1.0, std=gain)
            if bias is not None:
                nn.init.constant_(bias, 0.0)

    net.apply(init_func)
    print(f"Model inicijalizovan ({init_type} inicijalizacija)")
    return net


def init_model(model, device, init_type="normal", gain=0.02):
    """Move ``model`` to ``device``, then re-initialize its weights via ``init_weights``.

    Returns whatever ``init_weights`` returns (the moved model).
    """
    on_device = model.to(device)
    initialized = init_weights(on_device, init_type=init_type, gain=gain)
    return initialized


In [None]:
class ColorizationModel(nn.Module):
    """pix2pix-style conditional GAN for colorization in Lab space.

    The generator ``net_G`` predicts ab from L. The PatchGAN discriminator ``net_D``
    scores the concatenated (L, ab) pair. The generator loss is the adversarial loss
    plus ``lambda_L1`` times the L1 distance to the true ab channels.

    Args:
        net_G: optional pre-built generator. If given, it is only moved to the device
            and its weights are left untouched, so pretrained weights survive. If None,
            a from-scratch ``Unet(1 -> 2, n_down=8, num_filters=64)`` is built and
            initialized with ``init_model``.
        lr_G: Adam learning rate of the generator.
        lr_D: Adam learning rate of the discriminator.
        beta1: Adam beta1, shared by both optimizers.
        beta2: Adam beta2, shared by both optimizers.
        lambda_L1: weight of the L1 term in the generator loss.
    """

    def __init__(
        self,
        net_G: Optional[nn.Module] = None,
        lr_G: float = 2e-4,
        lr_D: float = 2e-4,
        beta1: float = 0.5,
        beta2: float = 0.999,
        lambda_L1: float = 100.0,
    ):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = float(lambda_L1)

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            # Do NOT run the weight initializer on a caller-supplied generator: it may carry
            # pretrained weights, which re-initialization would silently overwrite.
            self.net_G = net_G.to(self.device)

        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)

        self.adv_criterion = GANLoss(gan_mode="vanilla").to(self.device)
        self.l1_criterion  = nn.L1Loss()

        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

        # Current batch and generator prediction, filled by setup_input()/forward().
        self.L: torch.Tensor  = torch.empty(0)
        self.ab: torch.Tensor = torch.empty(0)
        self.fake_color: torch.Tensor = torch.empty(0)

        # Most recent loss values; read by logging code via getattr(model, name).
        self.loss_D_fake = torch.tensor(0.0)
        self.loss_D_real = torch.tensor(0.0)
        self.loss_D      = torch.tensor(0.0)
        self.loss_G_GAN  = torch.tensor(0.0)
        self.loss_G_L1   = torch.tensor(0.0)
        self.loss_G      = torch.tensor(0.0)

    @staticmethod
    def _cat_L_ab(L: torch.Tensor, ab: torch.Tensor) -> torch.Tensor:
        """Concatenate L and ab along the channel dimension (dim 1)."""
        return torch.cat([L, ab], dim=1)

    @staticmethod
    def _toggle_grad(model: nn.Module, requires_grad: bool) -> None:
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, batch: Dict[str, torch.Tensor]) -> None:
        """Move the batch's "L" and "ab" tensors to the model device."""
        self.L  = batch["L"].to(self.device, non_blocking=True)
        self.ab = batch["ab"].to(self.device, non_blocking=True)

    def forward(self) -> None:
        """Predict ab for the current L and store it in ``self.fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self) -> None:
        """Discriminator loss: average of the fake-pair and real-pair BCE terms; backprop it."""
        # Detach so the discriminator update does not backpropagate into the generator.
        fake_pair = self._cat_L_ab(self.L, self.fake_color).detach()
        pred_fake = self.net_D(fake_pair)
        self.loss_D_fake = self.adv_criterion(pred_fake, target_is_real=False)

        real_pair = self._cat_L_ab(self.L, self.ab)
        pred_real = self.net_D(real_pair)
        self.loss_D_real = self.adv_criterion(pred_real, target_is_real=True)

        self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real)
        self.loss_D.backward()

    def backward_G(self) -> None:
        """Generator loss: fool the discriminator plus lambda_L1 * L1(fake ab, real ab); backprop it."""
        fake_pair = self._cat_L_ab(self.L, self.fake_color)
        pred_fake = self.net_D(fake_pair)

        self.loss_G_GAN = self.adv_criterion(pred_fake, target_is_real=True)
        self.loss_G_L1  = self.l1_criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G     = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self) -> None:
        """One training step: forward, then a discriminator update, then a generator update."""
        # Visualization/validation helpers may have left a network in eval mode, which
        # would disable dropout and switch BatchNorm layers to running statistics.
        self.net_G.train()
        self.net_D.train()
        self.forward()

        self._toggle_grad(self.net_D, True)
        self.opt_D.zero_grad(set_to_none=True)
        self.backward_D()
        self.opt_D.step()

        # Freeze D while updating G so the G loss does not accumulate grads into D.
        self._toggle_grad(self.net_D, False)
        self.opt_G.zero_grad(set_to_none=True)
        self.backward_G()
        self.opt_G.step()


In [None]:
class AverageMeter:
    """Running, count-weighted average of a scalar (``sum / count``, with count floored at 1)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all accumulated values."""
        self.count, self.sum, self.avg = 0, 0.0, 0.0

    def update(self, val: float, count: int = 1):
        """Add ``val`` with weight ``count`` and refresh ``avg``."""
        self.count = self.count + count
        self.sum = self.sum + val * count
        self.avg = self.sum / max(1, self.count)


def create_loss_meters() -> Dict[str, AverageMeter]:
    """Return a fresh meter for every loss attribute logged during training, in a fixed order."""
    tracked = (
        "loss_D_fake", "loss_D_real", "loss_D",
        "loss_G_GAN", "loss_G_L1", "loss_G",
    )
    meters: Dict[str, AverageMeter] = {}
    for loss_name in tracked:
        meters[loss_name] = AverageMeter()
    return meters


def update_losses(model: nn.Module, meters: Dict[str, AverageMeter], count: int) -> None:
    """Feed every meter the current value of the same-named attribute of ``model``.

    Each key of ``meters`` must name an attribute of ``model`` supporting ``.item()``
    (e.g. a scalar tensor). ``count`` is the weight passed to each meter.
    """
    for loss_name in meters:
        current_value = getattr(model, loss_name).item()
        meters[loss_name].update(current_value, count=count)


def log_results(meters: Dict[str, AverageMeter]) -> None:
    """Print one ``name: average`` line (5 decimals) per meter, in dict order."""
    report_lines = (f"{name}: {meter.avg:.5f}" for name, meter in meters.items())
    for line in report_lines:
        print(line)

def _lab_tensors_to_rgb_uint8(
    L: torch.Tensor, ab: torch.Tensor
) -> np.ndarray:
    """Decode normalized Lab tensors into a stacked batch of uint8 RGB images.

    Undoes the dataset normalization (``(L + 1) * 50`` and ``ab * 110``), concatenates
    along dim 1, moves channels last, converts each image with skimage ``lab2rgb``,
    clips to [0, 1] and truncates ``* 255`` to uint8.

    Args:
        L: 4-D tensor, presumably (B, 1, H, W).
        ab: 4-D tensor, presumably (B, 2, H, W).

    Returns:
        uint8 array stacked along a new first axis, one RGB image per batch element.
    """
    lab_batch = (
        torch.cat([(L + 1.0) * 50.0, ab * 110.0], dim=1)
        .permute(0, 2, 3, 1)
        .detach()
        .cpu()
        .numpy()
    )
    return np.stack(
        [(np.clip(lab2rgb(lab_img), 0, 1) * 255).astype(np.uint8) for lab_img in lab_batch],
        axis=0,
    )


def _gray_from_L_uint8(L: torch.Tensor) -> np.ndarray:
    """Render only the lightness channel as a uint8 RGB batch.

    Pairs ``L`` with all-zero ab (no chroma) and decodes it with
    ``_lab_tensors_to_rgb_uint8``, producing grayscale images in RGB layout.

    Args:
        L: normalized lightness, 4-D (B, C, H, W); presumably C == 1.

    Returns:
        The uint8 array produced by ``_lab_tensors_to_rgb_uint8``.
    """
    B, _, H, W = L.shape  # also enforces a 4-D input
    zeros_ab = torch.zeros((B, 2, H, W), device=L.device, dtype=L.dtype)
    return _lab_tensors_to_rgb_uint8(L, zeros_ab)

def _build_triptychs_from_batch(
    L: torch.Tensor, fake_ab: torch.Tensor, real_ab: torch.Tensor
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray, str]]:
    """Build one (ground truth, grayscale, prediction, stem) tuple per batch element.

    Stems are ``sample_000``, ``sample_001``, ... by position in the batch.
    """
    gt_rgb   = _lab_tensors_to_rgb_uint8(L, real_ab)
    gray_rgb = _gray_from_L_uint8(L)
    pred_rgb = _lab_tensors_to_rgb_uint8(L, fake_ab)

    return [
        (gt_img, gray_img, pred_img, f"sample_{idx:03d}")
        for idx, (gt_img, gray_img, pred_img) in enumerate(zip(gt_rgb, gray_rgb, pred_rgb))
    ]

def plot_triptych_grid(triptychs: list[tuple[np.ndarray, np.ndarray, np.ndarray, str]]) -> None:
    """Show triptychs as an n x 3 grid (Original | Grayscale | Colorized).

    An empty list is a no-op (``plt.subplots`` rejects zero rows). ``squeeze=False``
    always returns a 2-D axes array, so a single row needs no special casing.
    """
    n = len(triptychs)
    if n == 0:
        return

    fig, axes = plt.subplots(n, 3, figsize=(12, 4 * n), squeeze=False)
    titles = ("Original", "Grayscale", "Colorized (pix2pix)")

    for ax_row, (orig, gray, color, _stem) in zip(axes, triptychs):
        for ax, img, title in zip(ax_row, (orig, gray, color), titles):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis("off")

    fig.tight_layout()
    plt.show()


def save_triptychs(triptychs: list[tuple[np.ndarray, np.ndarray, np.ndarray, str]], out_dir: Path) -> None:
    """Save each triptych as its own 1x3 PNG at ``out_dir / "<stem>_triptych.png"`` (150 dpi).

    ``out_dir`` and its parents are created if missing. Existing files with the same
    stem are overwritten. Every figure is closed after saving.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    panel_titles = ("Original", "Grayscale", "Colorized (pix2pix)")

    for orig, gray, color, stem in triptychs:
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        for ax, (title, img) in zip(axes, zip(panel_titles, (orig, gray, color))):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis("off")

        fig.tight_layout()
        target_file = out_dir / f"{stem}_triptych.png"
        fig.savefig(target_file, dpi=150)
        plt.close(fig)


def visualize_batch(
    model: nn.Module,
    data: dict,
    *,
    save: bool = False,
    out_dir: Path = Path("./colorization_viz"),
    max_show: int = 5
) -> None:
    """Colorize ``data`` with the generator and show (optionally save) up to ``max_show`` triptychs.

    The generator runs in eval mode without gradients. Its previous train/eval mode is
    restored afterwards, even on error, so calling this mid-training does not leave
    dropout/normalization layers in eval mode for subsequent training steps.

    Side effect: ``model.setup_input(data)`` and ``model.forward()`` overwrite the
    model's cached ``L``, ``ab`` and ``fake_color`` tensors.

    Args:
        model: object with ``net_G``, ``setup_input``, ``forward`` and the cached tensors above.
        data: batch dict accepted by ``model.setup_input`` ("L" and "ab").
        save: also write each triptych as a PNG into ``out_dir``.
        out_dir: destination directory for saved triptychs.
        max_show: maximum number of samples to display.
    """
    was_training = model.net_G.training
    model.net_G.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        model.net_G.train(was_training)

    # Slice first so only the displayed samples go through the Lab -> RGB conversion.
    L   = model.L[:max_show]
    ab  = model.ab[:max_show]
    fab = model.fake_color[:max_show]
    if L.shape[0] == 0:
        return

    all_triptychs = _build_triptychs_from_batch(L, fab, ab)

    plot_triptych_grid(all_triptychs)
    if save:
        save_triptychs(all_triptychs, out_dir)

In [None]:
def train_model(
    model: nn.Module,
    train_dl,
    epochs: int,
    val_dl=None,
    display_every: int = 100,
    viz_save: bool = False
):
    """Adversarially train ``model`` for ``epochs`` epochs, logging and visualizing periodically.

    Each iteration calls ``model.setup_input`` and ``model.optimize``, then accumulates
    the model's ``loss_*`` attributes into per-epoch meters weighted by batch size.
    Every ``display_every`` iterations, and at the end of each epoch, the running
    averages are printed and a batch is visualized.

    Args:
        model: object exposing ``setup_input``, ``optimize`` and the ``loss_*`` attributes.
        train_dl: iterable of ``{"L", "ab"}`` batches; ``len(train_dl)`` is used in logs.
        epochs: number of passes over ``train_dl``.
        val_dl: optional loader whose first batch serves as a fixed visualization batch.
            If it is None or yields nothing, the current training batch is shown instead
            and the end-of-epoch visualization is skipped.
        display_every: period (in iterations) of mid-epoch logging; non-positive disables it.
        viz_save: also save the visualizations to disk.
    """
    # next(..., None) avoids a StopIteration crash on an empty validation loader.
    viz_batch = next(iter(val_dl), None) if val_dl is not None else None

    for e in range(1, epochs + 1):
        meters = create_loss_meters()
        for i, data in enumerate(tqdm(train_dl, desc=f"Epoch {e}/{epochs}"), start=1):
            model.setup_input(data)
            model.optimize()

            bsize = data["L"].size(0)
            update_losses(model, meters, count=bsize)

            if display_every > 0 and i % display_every == 0:
                print(f"\n[Epoch {e}/{epochs}] Iter {i}/{len(train_dl)}")
                log_results(meters)
                visualize_batch(model, viz_batch if viz_batch is not None else data, save=viz_save)

        print(f"\n===> Epoch {e} done.")
        log_results(meters)
        if viz_batch is not None:
            visualize_batch(model, viz_batch, save=viz_save)

# Baseline run: pix2pix with the from-scratch U-Net generator (net_G=None builds it internally).
model = ColorizationModel(net_G=None)
train_model(
    model,
    train_dl,
    epochs=30,
    val_dl=val_dl,
    display_every=100,
    viz_save=True,
)

In [None]:
def build_resnet_18(n_input: int = 1, n_output: int = 2, size: int = 256) -> nn.Module:
    """Build a fastai ``DynamicUnet`` generator on an ImageNet-pretrained ResNet-18 body.

    ``create_body`` cuts the ResNet at ``cut=-2`` and adapts it to ``n_input`` input
    channels (via ``n_in``). ``DynamicUnet`` adds a decoder producing ``n_output``
    channels for ``size x size`` inputs.

    Args:
        n_input: input channels (1 for the L channel).
        n_output: output channels (2 for ab).
        size: spatial size the U-Net is built for.

    Returns:
        The generator module, moved to CUDA when available, else CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torchvision >= 0.13 deprecates `pretrained=True` in favor of the `weights=` enum.
    # IMAGENET1K_V1 is the checkpoint the legacy flag loads; older releases lack the enum.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        resnet18_model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        resnet18_model = models.resnet18(pretrained=True)

    body = create_body(resnet18_model, n_in=n_input, cut=-2)

    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G


@torch.inference_mode()
def _valid_step_L1(
    net_G: nn.Module,
    val_dl: Iterable[Dict[str, Any]],
    criterion: nn.Module,
    device: torch.device
) -> float:
    prev_mode = net_G.training
    net_G.eval()

    total_loss = 0.0
    total_count = 0

    for batch in val_dl:
        L = batch["L"].to(device, non_blocking=True)
        ab = batch["ab"].to(device, non_blocking=True)

        pred = net_G(L)
        loss = criterion(pred, ab).item()

        batch_size = L.size(0)
        total_loss += loss * batch_size
        total_count += batch_size

    net_G.train(prev_mode)

    if total_count == 0:
        return 0.0
    return total_loss / total_count


In [None]:
def pretrain_generator(
    net_G: nn.Module,
    train_dl: Iterable[Dict[str, Any]],
    val_dl:   Iterable[Dict[str, Any]],
    opt: torch.optim.Optimizer,
    criterion: nn.Module,
    epochs: int,
    device: torch.device,
    *,
    use_amp: bool = True,
    grad_clip: Optional[float] = None
) -> List[Dict[str, float]]:
    """Supervised pre-training of the generator (L -> ab) with ``criterion`` only.

    Each epoch trains over ``train_dl``, then evaluates on ``val_dl`` with
    ``_valid_step_L1`` and prints both losses plus the wall time.

    Args:
        net_G: generator to train in place.
        train_dl: iterable of ``{"L", "ab"}`` batches.
        val_dl: iterable of ``{"L", "ab"}`` batches used for evaluation.
        opt: optimizer over ``net_G``'s parameters.
        criterion: reconstruction loss (an L1 loss in this notebook).
        epochs: number of epochs.
        device: device to run on (``torch.device`` or a device string).
        use_amp: request mixed precision; it only takes effect on CUDA devices.
        grad_clip: if set, clip the total gradient norm to this value before each step.

    Returns:
        One dict per epoch: ``{"epoch", "train_l1", "val_l1"}``.
    """
    # torch.cuda.amp autocast/GradScaler only act on CUDA; elsewhere they would just warn
    # and disable themselves, so make the decision once and explicitly.
    amp_enabled = bool(use_amp) and torch.device(device).type == "cuda"
    scaler = GradScaler(enabled=amp_enabled)
    history: List[Dict[str, float]] = []

    for epoch in range(1, epochs + 1):
        net_G.train()
        epoch_loss, seen = 0.0, 0
        t0 = time.time()

        pbar = tqdm(train_dl, desc=f"[Pretrain G] Epoch {epoch}/{epochs}", leave=False)
        for batch in pbar:
            L  = batch["L"].to(device, non_blocking=True)
            ab = batch["ab"].to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with autocast(enabled=amp_enabled):
                pred = net_G(L)
                loss = criterion(pred, ab)

            # A disabled GradScaler is a pass-through (scale / unscale_ / step / update),
            # so a single code path covers both mixed and full precision.
            scaler.scale(loss).backward()
            if grad_clip is not None:
                scaler.unscale_(opt)  # clip the true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(net_G.parameters(), grad_clip)
            scaler.step(opt)
            scaler.update()

            bsz = L.size(0)
            epoch_loss += loss.item() * bsz
            seen       += bsz
            pbar.set_postfix(loss=f"{epoch_loss/seen:.4f}")

        train_loss = epoch_loss / max(1, seen)
        val_loss   = _valid_step_L1(net_G, val_dl, criterion, device)
        dt = time.time() - t0

        print(f"Epoch {epoch:03d}/{epochs} | L1(train) {train_loss:.5f} | "
              f"L1(val) {val_loss:.5f} | {dt:.1f}s")

        history.append({"epoch": epoch, "train_l1": train_loss, "val_l1": val_loss})

    return history


In [None]:
# Stage 1: supervised (L1-only) pre-training of the ResNet-18 U-Net generator, then checkpoint it.
net_G = build_resnet_18(n_input=1, n_output=2, size=256)
opt_G = optim.Adam(net_G.parameters(), lr=1e-4)
crit  = nn.L1Loss()

pretrain_generator(
    net_G,
    train_dl,
    val_dl,
    opt=opt_G,
    criterion=crit,
    epochs=1,
    device=device,
)

torch.save(net_G.state_dict(), "resnet-18.pt")

In [None]:
# Stage 2: wrap the pre-trained generator in the pix2pix model and fine-tune it adversarially.
net_G = build_resnet_18(n_input=1, n_output=2, size=256)
model = ColorizationModel(net_G=net_G)
# Load the checkpoint only AFTER wrapping: the wrapper's constructor may run a weight
# initializer over the generator, which would otherwise silently discard the pre-trained
# weights. load_state_dict copies in place, so the wrapper's optimizer stays valid.
model.net_G.load_state_dict(torch.load("resnet-18.pt", map_location=device))
train_model(model, train_dl, epochs=30, val_dl=val_dl, display_every=100, viz_save=True)