# LocJND - Patch-level JND Regression (Train)

**Goal.**
Train a small CNN to predict human noticeability of compression artifacts at patch level.
Target is a scalar `y ∈ [0,1]` where `y = 1 − α`. Higher = more noticeable.

In [13]:
# CELL 0 — Imports and Reproducibility
# This cell sets up all required libraries and fixes random seeds
# to ensure consistent results across runs.

import os
import random
from pydoc import pipepager

import numpy as np
import pandas as pd
import cv2
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from skimage.data import data_dir
from tensorboard.notebook import display

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

from scipy.stats import pearsonr

# --- reproducibility helper ---
def set_seed(seed: int = 2025):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True  # allow speedup on GPU

set_seed(2025)


In [None]:
# CELL 1 — Configuration
# Central place to define hyperparameters and paths.

from types import SimpleNamespace
from datetime import datetime

cfg = SimpleNamespace(
    # --- training ---
    epochs = 40,                 # 60–80 recommended for 20x20
    batch_size = 128,            # safe default for 20x20
    lr_init = 1e-4,
    weight_decay = 1e-4,
    img_size = 20,
    num_workers = 0,             # Jupyter-friendly
    device = "cuda" if torch.cuda.is_available() else "cpu",
    seed = 2025,
    aug_dup = 2,                 # duplicate train set via augmentation

    # --- data paths (relative to project root) ---
    data_dir = Path.cwd().parent / "dataset",
)

# derived paths
cfg.train_csv = cfg.data_dir / "blends_train.csv"
cfg.val_csv   = cfg.data_dir / "blends_val.csv"
cfg.test_csv  = cfg.data_dir / "blends_test.csv"
cfg.orig_dir  = cfg.data_dir / "orig_patches"
cfg.mixed_dir = cfg.data_dir / "mixed_patches"

# --- logging / checkpoints ---
cfg.run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
cfg.checkpoints_dir = Path("checkpoints")
cfg.runs_dir = Path("runs")

# ensure dirs
cfg.checkpoints_dir.mkdir(parents=True, exist_ok=True)
cfg.runs_dir.mkdir(parents=True, exist_ok=True)


In [None]:
# CELL 2 — Paths & Sanity
# Verifies dataset layout and sets advanced training toggles.

from types import SimpleNamespace

def _must_exist(p: Path, name: str):
    if not Path(p).exists():
        raise FileNotFoundError(f"Missing {name}: {p}")

# verify required folders/files
for p, n in [
    (cfg.data_dir, "data_dir"),
    (cfg.orig_dir, "orig_patches"),
    (cfg.mixed_dir, "mixed_patches"),
    (cfg.train_csv, "blends_train.csv"),
    (cfg.val_csv,   "blends_val.csv"),
    (cfg.test_csv,  "blends_test.csv"),
]:
    _must_exist(p, n)

# optional: LocJND.json (used later for references/metadata)
cfg.locjnd_json = Path.cwd().parent / "data" / "LocJND.json"
_must_exist(cfg.locjnd_json, "LocJND.json")

cfg.push99 = SimpleNamespace(
    use_heatmap_weighting=True,   # weighted MSE by heatmap if available
    use_gradient_channels=True,   # add Sobel |∇orig|, |∇mixed|, |∇delta|
    use_tta=True,                 # test-time augmentation (id,hflip,vflip,hv)
    ensemble_seeds=[2025, 2026, 2027],
    lambda_rank=0.2,              # weight for ranking loss
    lambda_pearson=0.1,           # weight for Pearson loss (1 - corr)
)

# enforce Jupyter-friendly settings
cfg.num_workers = 0
cfg.aug_dup = getattr(cfg, "aug_dup", 1)

print("✔ Paths OK")
print(f"data_dir      : {cfg.data_dir}")
print(f"orig/mixed    : {cfg.orig_dir} | {cfg.mixed_dir}")
print(f"train/val/test: {cfg.train_csv.name}, {cfg.val_csv.name}, {cfg.test_csv.name}")
print(f"LocJND.json   : {cfg.locjnd_json}")
print("push99        :", cfg.push99.__dict__)


✔ Paths OK
data_dir      : C:\Users\YehonatanR\OneDrive\Documents\HIT Computer Science\3rd\PixelQuality\PythonProject\dataset
orig/mixed    : C:\Users\YehonatanR\OneDrive\Documents\HIT Computer Science\3rd\PixelQuality\PythonProject\dataset\orig_patches | C:\Users\YehonatanR\OneDrive\Documents\HIT Computer Science\3rd\PixelQuality\PythonProject\dataset\mixed_patches
train/val/test: blends_train.csv, blends_val.csv, blends_test.csv
LocJND.json   : C:\Users\YehonatanR\OneDrive\Documents\HIT Computer Science\3rd\PixelQuality\LocJND\LocJND.json
push99        : {'use_heatmap_weighting': True, 'use_gradient_channels': True, 'use_tta': True, 'ensemble_seeds': [2025, 2026, 2027], 'lambda_rank': 0.2, 'lambda_pearson': 0.1}


In [16]:
# CELL 3 — Read Splits
# Load train/val/test CSVs, normalize paths, and print quick stats.

import pandas as pd

REQUIRED_COLS = {"mixed_path", "target"}

def _load_split(csv_path: Path) -> pd.DataFrame:
    df = pd.read_csv(csv_path, low_memory=False)
    missing = REQUIRED_COLS - set(df.columns)
    if missing:
        raise KeyError(f"CSV {csv_path.name} missing columns: {sorted(missing)}")

    # normalize types
    df["mixed_path"] = df["mixed_path"].astype(str)
    df["target"] = df["target"].astype("float32")

    # resolve absolute paths for mixed patches
    def _resolve_mixed(p: str) -> str:
        pth = Path(p)
        if pth.is_absolute() and pth.exists():
            return str(pth)
        # try relative to dataset root
        cand = (cfg.data_dir / pth).resolve()
        if cand.exists():
            return str(cand)
        # try relative to mixed_patches directory using full relative and then just filename
        cand2 = (cfg.mixed_dir / pth).resolve()
        if cand2.exists():
            return str(cand2)
        cand3 = (cfg.mixed_dir / pth.name).resolve()
        if cand3.exists():
            return str(cand3)
        # last resort: keep as-is (dataset check will catch later)
        return str(cand)

    df["mixed_path"] = df["mixed_path"].map(_resolve_mixed)

    # optional: ensure unique key if provided
    if "image_id" in df.columns:
        df["image_id"] = df["image_id"].astype(str)

    return df

def _quick_stats(name: str, df: pd.DataFrame):
    n = len(df)
    tgt = df["target"].to_numpy()
    print(f"{name:>5} | n={n} | target mean={tgt.mean():.4f} std={tgt.std():.4f} "
          f"min={tgt.min():.4f} max={tgt.max():.4f}")

def _check_files_exist(df: pd.DataFrame, col: str = "mixed_path", limit: int = 300):
    # Lightweight existence check
    sample = df[col].head(limit).tolist()
    missing = [p for p in sample if not Path(p).exists()]
    if missing:
        raise FileNotFoundError(f"{len(missing)} missing files in first {len(sample)} rows (showing 3): {missing[:3]}")
    print(f"✓ {col} existence check passed on first {len(sample)} rows.")

# --- load splits ---
df_train = _load_split(cfg.train_csv)
df_val   = _load_split(cfg.val_csv)
df_test  = _load_split(cfg.test_csv)

# --- basic stats ---
_quick_stats("train", df_train)
_quick_stats("  val", df_val)
_quick_stats(" test", df_test)

# --- light file checks (limit to keep fast in Jupyter) ---
_check_files_exist(df_train)
_check_files_exist(df_val)
_check_files_exist(df_test)

# --- leakage check if image_id is available ---
if all(("image_id" in d.columns) for d in (df_train, df_val, df_test)):
    s_tr, s_va, s_te = set(df_train.image_id), set(df_val.image_id), set(df_test.image_id)
    inter_tr_va = s_tr & s_va
    inter_tr_te = s_tr & s_te
    inter_va_te = s_va & s_te
    print("Leakage check by image_id:",
          f"train∩val={len(inter_tr_va)} | train∩test={len(inter_tr_te)} | val∩test={len(inter_va_te)}")
    if inter_tr_va or inter_tr_te or inter_va_te:
        raise AssertionError("Image leakage detected between splits.")
else:
    print("Leakage check skipped (image_id column not found in all splits).")

# keep DataFrames in cfg for downstream cells
cfg.df_train = df_train
cfg.df_val   = df_val
cfg.df_test  = df_test


train | n=672 | target mean=0.5000 std=0.3416 min=0.0000 max=1.0000
  val | n=336 | target mean=0.5000 std=0.3416 min=0.0000 max=1.0000
 test | n=336 | target mean=0.5000 std=0.3416 min=0.0000 max=1.0000
✓ mixed_path existence check passed on first 300 rows.
✓ mixed_path existence check passed on first 300 rows.
✓ mixed_path existence check passed on first 300 rows.
Leakage check by image_id: train∩val=0 | train∩test=0 | val∩test=0


In [17]:
# CELL 4 — Transforms (pair-synchronized for orig/mixed; no resize; 20x20 native)

NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD  = (0.5, 0.5, 0.5)

def make_pair_transform(train: bool) -> A.Compose:
    ops = []
    if train:
        ops += [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.3),
        ]
    ops += [A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=255.0)]
    # Synchronized transforms for two correlated images
    return A.Compose(ops, additional_targets={"mixed": "image"})

train_tf = make_pair_transform(train=True)
val_tf   = make_pair_transform(train=False)
test_tf  = val_tf

def apply_pair_tf(tf: A.Compose, orig_img, mixed_img, heatmap=None):
    """
    Apply identical geometry/photometric ops to orig/mixed (and optionally heatmap).
    Works with uint8 H×W×C images.
    Returns numpy arrays after normalization.
    """
    if heatmap is None:
        out = tf(image=orig_img, mixed=mixed_img)
        return out["image"], out["mixed"], None
    # if heatmap sync is needed:
    tf3 = A.Compose(tf.transforms, additional_targets={"mixed": "image", "heatmap": "image"})
    out = tf3(image=orig_img, mixed=mixed_img, heatmap=heatmap)
    return out["image"], out["mixed"], out["heatmap"]

cfg.train_tf = train_tf
cfg.val_tf   = val_tf
cfg.test_tf  = test_tf

print("Transforms ready.",
      f"train_ops={len(train_tf.transforms)} | val_ops={len(val_tf.transforms)}")


Transforms ready. train_ops=4 | val_ops=1


In [18]:
# CELL 5 — Dataset (orig, mixed, delta [+ optional gradients], y, w)
# Loads 20×20 RGB patches, applies pair-synced transforms, builds delta and optional Sobel gradients.

def _imread_rgb(path: str) -> np.ndarray:
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f"Failed to read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img  # uint8 H×W×3

def _infer_orig_path(mixed_path: str) -> str:
    p = Path(mixed_path)
    # 1) swap 'mixed_patches' -> 'orig_patches'
    try1 = Path(str(p).replace(str(cfg.mixed_dir), str(cfg.orig_dir)))
    if try1.exists():
        return str(try1)
    # 2) same filename under orig_dir
    try2 = cfg.orig_dir / p.name
    if try2.exists():
        return str(try2)
    # 3) if CSV already has orig_path column we will override in __getitem__
    return str(try2)  # last guess; existence checked later

def _to_tensor(img_f32_hwc: np.ndarray) -> torch.Tensor:
    # input float32 H×W×C in [-1,1] after Normalize; output C×H×W
    t = torch.from_numpy(img_f32_hwc).permute(2, 0, 1).contiguous()
    return t

def _sobel_mag(gray: np.ndarray) -> np.ndarray:
    # gray float32 H×W in [-1,1] → magnitude in [0,1]
    gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    mag = np.sqrt(gx * gx + gy * gy)
    # normalize robustly
    m = mag.max()
    if m > 0:
        mag = mag / m
    return mag

def _to_gray(img_f32_hwc: np.ndarray) -> np.ndarray:
    # img in [-1,1]; convert to gray with linear RGB weights
    r, g, b = img_f32_hwc[..., 0], img_f32_hwc[..., 1], img_f32_hwc[..., 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return gray.astype(np.float32)

class BlendsPatchDataset(Dataset):
    def __init__(
        self,
        df,
        orig_dir: Path,
        mixed_dir: Path,
        tf,                       # Albumentations compose (pair-synced)
        expect_hw=(20, 20),
        use_gradient_channels: bool = True,
        weight_col_candidates=("w", "weight", "h_c"),
    ):
        self.df = df.reset_index(drop=True)
        self.orig_dir = Path(orig_dir)
        self.mixed_dir = Path(mixed_dir)
        self.tf = tf
        self.expect_hw = expect_hw
        self.use_grads = bool(use_gradient_channels and getattr(cfg, "push99", SimpleNamespace(use_gradient_channels=False)).use_gradient_channels)
        self.weight_cols = [c for c in weight_col_candidates if c in self.df.columns]
        # precompute simple min/max if we will normalize a weight column not in [0,1]
        self._norm_stats = {}
        for c in self.weight_cols:
            v = self.df[c].astype("float32").to_numpy()
            vmin, vmax = float(np.nanmin(v)), float(np.nanmax(v))
            self._norm_stats[c] = (vmin, vmax)

    def __len__(self):
        return len(self.df)

    def _resolve_paths(self, row) -> tuple[str, str]:
        mixed_path = str(row["mixed_path"])
        if "orig_path" in row and isinstance(row["orig_path"], str) and len(row["orig_path"]) > 0:
            orig_path = row["orig_path"]
        else:
            orig_path = _infer_orig_path(mixed_path)
        if not Path(mixed_path).exists():
            raise FileNotFoundError(f"mixed_path not found: {mixed_path}")
        if not Path(orig_path).exists():
            raise FileNotFoundError(f"orig_path not found: {orig_path}")
        return orig_path, mixed_path

    def _weight_from_row(self, row) -> float:
        if not self.weight_cols:
            return 1.0
        c = self.weight_cols[0]
        val = float(row[c])
        vmin, vmax = self._norm_stats[c]
        if vmax > vmin:
            val = (val - vmin) / (vmax - vmin)
        return float(np.clip(val, 0.0, 1.0))

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        y = float(row["target"])
        w = self._weight_from_row(row) if getattr(cfg, "push99", SimpleNamespace(use_heatmap_weighting=False)).use_heatmap_weighting else 1.0

        orig_path, mixed_path = self._resolve_paths(row)
        orig = _imread_rgb(orig_path)
        mixed = _imread_rgb(mixed_path)

        # sanity: expect exact 20×20
        H, W = orig.shape[:2]
        if (H, W) != self.expect_hw:
            raise AssertionError(f"Unexpected orig patch size {orig.shape} at {orig_path}")
        if mixed.shape[:2] != (H, W):
            raise AssertionError(f"Size mismatch orig {orig.shape} vs mixed {mixed.shape}")

        # pair-synchronized transforms → float32 H×W×C in [-1,1] after Normalize
        orig_f, mixed_f, _ = apply_pair_tf(self.tf, orig, mixed, heatmap=None)

        # delta after identical transforms
        delta_f = (mixed_f - orig_f).astype(np.float32)

        # optional gradient channels
        grads_t = None
        if self.use_grads:
            g_orig = _sobel_mag(_to_gray(orig_f))
            g_mixed = _sobel_mag(_to_gray(mixed_f))
            g_delta = _sobel_mag(_to_gray(delta_f))
            # stack to C×H×W
            g_stack = np.stack([g_orig, g_mixed, g_delta], axis=-1).astype(np.float32)
            grads_t = _to_tensor(g_stack)

        # to tensors C×H×W
        orig_t  = _to_tensor(orig_f.astype(np.float32))
        mixed_t = _to_tensor(mixed_f.astype(np.float32))
        delta_t = _to_tensor(delta_f)

        sample = {
            "orig": orig_t,
            "mixed": mixed_t,
            "delta": delta_t,
            "y": torch.tensor([y], dtype=torch.float32),
            "w": torch.tensor([w], dtype=torch.float32),
        }
        if grads_t is not None:
            sample["grads"] = grads_t  # 3×H×W

        # optional metadata for debugging
        if "image_id" in self.df.columns:
            sample["image_id"] = row["image_id"]

        return sample


In [19]:
# CELL 6 — DataSets & DataLoaders
# Build train/val/test datasets. Duplicate train set via stochastic augmentation.
# Jupyter-friendly: num_workers=0. Prints quick loader stats.

from torch.utils.data import ConcatDataset, DataLoader

# --- datasets ---
train_parts = [BlendsPatchDataset(
    cfg.df_train, cfg.orig_dir, cfg.mixed_dir, tf=cfg.train_tf,
    use_gradient_channels=True
) for _ in range(1 + int(cfg.aug_dup))]  # base + aug duplicates

train_ds = ConcatDataset(train_parts)
val_ds   = BlendsPatchDataset(cfg.df_val,  cfg.orig_dir, cfg.mixed_dir, tf=cfg.val_tf,  use_gradient_channels=True)
test_ds  = BlendsPatchDataset(cfg.df_test, cfg.orig_dir, cfg.mixed_dir, tf=cfg.test_tf, use_gradient_channels=True)

# --- loaders ---
pin = bool(cfg.device == "cuda")
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=pin, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=pin)
test_loader  = DataLoader(test_ds,  batch_size=cfg.batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=pin)

# --- debug peek ---
def _peek(loader, name):
    try:
        batch = next(iter(loader))
        o, m, d = batch["orig"], batch["mixed"], batch["delta"]
        g = batch.get("grads", None)
        y, w = batch["y"], batch["w"]
        print(f"{name}: n_batches≈{len(loader)} | batch={o.shape[0]} "
              f"| orig={tuple(o.shape)} mixed={tuple(m.shape)} delta={tuple(d.shape)} "
              f"| grads={'yes '+str(tuple(g.shape)) if g is not None else 'no'} "
              f"| y={tuple(y.shape)} w={tuple(w.shape)}")
    except StopIteration:
        print(f"{name}: empty.")

_peek(train_loader, "train_loader")
_peek(val_loader,   "val_loader")
_peek(test_loader,  "test_loader")

# expose in cfg
cfg.train_loader = train_loader
cfg.val_loader   = val_loader
cfg.test_loader  = test_loader


train_loader: n_batches≈15 | batch=128 | orig=(128, 3, 20, 20) mixed=(128, 3, 20, 20) delta=(128, 3, 20, 20) | grads=yes (128, 3, 20, 20) | y=(128, 1) w=(128, 1)
val_loader: n_batches≈3 | batch=128 | orig=(128, 3, 20, 20) mixed=(128, 3, 20, 20) delta=(128, 3, 20, 20) | grads=yes (128, 3, 20, 20) | y=(128, 1) w=(128, 1)
test_loader: n_batches≈3 | batch=128 | orig=(128, 3, 20, 20) mixed=(128, 3, 20, 20) delta=(128, 3, 20, 20) | grads=yes (128, 3, 20, 20) | y=(128, 1) w=(128, 1)


In [20]:
# CELL 7 — Model (20×20 dual-branch with optional gradient channels)
# Two CNN branches for ORIG and MIXED, a light path for DELTA, optional path for GRADS.
# Fusion → Conv → GAP → MLP regression head.

from types import SimpleNamespace

def conv_block(in_ch, out_ch, k=3, s=1, p=1):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=True),
        nn.ReLU(inplace=True)
    )

class SmallBackbone(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.b1 = conv_block(in_ch, 32)
        self.b2 = conv_block(32, 32)
        self.pool = nn.MaxPool2d(2)  # 20→10
        self.b3 = conv_block(32, 64)
    def forward(self, x):
        x = self.b1(x)
        x = self.b2(x)
        x = self.pool(x)
        x = self.b3(x)
        return x  # [B,64,10,10]

class LightPath(nn.Module):
    def __init__(self, in_ch, out_ch=16):
        super().__init__()
        self.net = conv_block(in_ch, out_ch)
    def forward(self, x):
        return self.net(x)  # keep spatial 20×20 or 10×10 per input

class DualBranchIQAModel(nn.Module):
    def __init__(self, use_grads: bool = False, in_ch_rgb: int = 3, in_ch_delta: int = 3, in_ch_grads: int = 3):
        super().__init__()
        # main branches (operate on 20→10 spatial)
        self.orig_backbone  = SmallBackbone(in_ch=in_ch_rgb)
        self.mixed_backbone = SmallBackbone(in_ch=in_ch_rgb)
        # delta path (keep 20, then down to match 10×10)
        self.delta_path = LightPath(in_ch=in_ch_delta, out_ch=16)
        self.delta_pool = nn.MaxPool2d(2)  # 20→10

        self.use_grads = bool(use_grads)
        if self.use_grads:
            self.grads_path = LightPath(in_ch=in_ch_grads, out_ch=16)
            self.grads_pool = nn.MaxPool2d(2)

        # fusion
        fuse_in = 64 + 64 + 16 + (16 if self.use_grads else 0)  # [orig,mixed,delta,(grads)]
        self.fuse = conv_block(fuse_in, 128)
        self.gap  = nn.AdaptiveAvgPool2d(1)  # → [B,128,1,1]

        # head
        self.head = nn.Sequential(
            nn.Flatten(),                 # [B,128]
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(64, 1)
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(self, orig, mixed, delta, grads=None):
        fo = self.orig_backbone(orig)     # [B,64,10,10]
        fm = self.mixed_backbone(mixed)   # [B,64,10,10]

        fd = self.delta_path(delta)       # [B,16,20,20]
        fd = self.delta_pool(fd)          # [B,16,10,10]

        parts = [fo, fm, fd]

        if self.use_grads:
            if grads is None:
                raise ValueError("Model configured with use_grads=True but no grads tensor was provided.")
            fg = self.grads_path(grads)   # [B,16,20,20]
            fg = self.grads_pool(fg)      # [B,16,10,10]
            parts.append(fg)

        x = torch.cat(parts, dim=1)       # [B, fuse_in, 10,10]
        x = self.fuse(x)                  # [B,128,10,10]
        x = self.gap(x)                   # [B,128,1,1]
        out = torch.sigmoid(self.head(x))               # [B,1]
        return out

# instantiate
push99 = getattr(cfg, "push99", SimpleNamespace())
use_grads = bool(getattr(push99, "use_gradient_channels", False))

model = DualBranchIQAModel(use_grads=use_grads).to(cfg.device)


def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(model.__class__.__name__, "| params:", count_params(model))


DualBranchIQAModel | params: 250945


In [21]:
# Cell 8 — Visualize DualBranchIQAModel (graph + architecture + layer table)
# Assumes `model = DualBranchIQAModel(...).to(cfg.device)` was created in Cell 7.

import torch
from pathlib import Path

model.eval()
dev = next(model.parameters()).device  # cpu or cuda:0

# ---------- dummy inputs on the same device ----------
B, C, H, W = 1, 3, 20, 20
orig  = torch.randn(B, C, H, W, device=dev)
mixed = torch.randn(B, C, H, W, device=dev)
delta = torch.abs(mixed - orig)

use_grads = getattr(model, "use_grads", False)
if use_grads:
    grads = torch.randn(B, C, H, W, device=dev)
    input_args = (orig, mixed, delta, grads)
else:
    input_args = (orig, mixed, delta)

# output dirs
Path("reports").mkdir(exist_ok=True)

# ---------- C) Layer table (torchinfo) ----------
try:
    from torchinfo import summary
    table = summary(model,
                    input_data=input_args,
                    depth=5,
                    col_names=("input_size","output_size","kernel_size","num_params","mult_adds"),
                    row_settings=("var_names","depth","param_numbers"))
    with open("reports/model_summary.txt", "w") as f:
        f.write(str(table))
    with open("reports/model_summary.md", "w") as f:
        f.write("```text\n"); f.write(str(table)); f.write("\n```")
    print("Saved: reports/model_summary.txt, reports/model_summary.md")
except Exception as e:
    print("[torchinfo] skipped:", e)

# ---------- A) Computation graph (torchviz → SVG/PNG) ----------
try:
    from torchviz import make_dot
    with torch.no_grad():
        y = model(*input_args)
    dot = make_dot(y, params=dict(model.named_parameters()),
                   show_attrs=False, show_saved=False)
    dot.format = "svg"
    dot.render("reports/locjnd_regression_graph", cleanup=True)  # -> .svg
    dot_png = make_dot(y, params=dict(model.named_parameters()),
                       show_attrs=False, show_saved=False)
    dot_png.format = "png"
    dot_png.render("reports/locjnd_regression_graph", cleanup=True)  # -> .png
    print("Saved: reports/locjnd_regression_graph.svg, reports/locjnd_regression_graph.png")
except Exception as e:
    print("[torchviz] skipped:", e)

# ---------- B) Architecture diagram (torchview → PNG) ----------
try:
    from torchview import draw_graph
    g = draw_graph(model,
                   input_data=input_args,
                   expand_nested=True,
                   graph_name="LocJND_Regression",
                   save_graph=True,
                   filename="locjnd_architecture",
                   directory="reports",
                   format="png")
    print("Saved: reports/locjnd_architecture.png")
except Exception as e:
    print("[torchview] skipped:", e)


[torchinfo] skipped: 'param_numbers' is not a valid RowSettings
[torchviz] skipped: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH
[torchview] skipped: Failed to run torchgraph see error message


In [22]:
# CELL 8 — Losses & Metrics

from types import SimpleNamespace

EPS = 1e-8
_rank = nn.MarginRankingLoss(margin=0.05)

def _flat(x):
    return x.view(-1)

def rmse(yhat, y):
    e = (_flat(yhat) - _flat(y)) ** 2
    return torch.sqrt(e.mean())

def plcc(yhat, y):
    x = _flat(yhat); t = _flat(y)
    x = x - x.mean()
    t = t - t.mean()
    vx = x.pow(2).mean()
    vt = t.pow(2).mean()
    denom = torch.sqrt(vx * vt) + EPS
    return (x * t).mean() / denom

def pearson_loss(yhat, y):
    return 1.0 - plcc(yhat, y)

def weighted_mse(yhat, y, w=None, use_weights=False):
    e = (_flat(yhat) - _flat(y)) ** 2
    if use_weights and w is not None:
        w = _flat(w)
        return (w * e).sum() / (w.sum() + EPS)
    return e.mean()

def ranking_loss(yhat, y):
    x = _flat(yhat); t = _flat(y)
    n = x.numel()
    if n < 2:
        return x.new_tensor(0.0)

    perm = torch.randperm(n, device=x.device)
    half = n // 2
    i = perm[:half]
    j = perm[half:half + half]
    if j.numel() < i.numel():
        i = i[:j.numel()]

    s = torch.sign(t[i] - t[j])  # +1 if y_i > y_j, -1 if y_i < y_j
    m = s != 0
    if m.sum() == 0:
        return x.new_tensor(0.0)
    return _rank(x[i][m], x[j][m], s[m])

push99 = getattr(cfg, "push99", SimpleNamespace(lambda_rank=0.1, lambda_pearson=0.1, use_heatmap_weighting=False))

class LossBundle:
    def __init__(self, push):
        self.lambda_rank = float(getattr(push, "lambda_rank", 0.2))
        self.lambda_pearson = float(getattr(push, "lambda_pearson", 0.0))
        self.use_w = bool(getattr(push, "use_heatmap_weighting", False))

    def __call__(self, yhat, y, w=None):
        mse = weighted_mse(yhat, y, w, self.use_w)
        rnk = ranking_loss(yhat, y)
        prc = pearson_loss(yhat, y) if self.lambda_pearson > 0 else yhat.new_tensor(0.0)
        total = mse + self.lambda_rank * rnk + self.lambda_pearson * prc
        parts = {"mse": mse.detach(), "rank": rnk.detach(), "pearson": prc.detach()}
        return total, parts

loss_bundle = LossBundle(push99)

def compute_metrics(yhat, y):
    return {"rmse": rmse(yhat, y).item(), "plcc": plcc(yhat, y).item()}


In [23]:
# CELL 9 — Optimizer, Scheduler (warmup+cosine), AMP

import math
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

# --- optimizer ---
optimizer = Adam(model.parameters(), lr=cfg.lr_init, weight_decay=cfg.weight_decay)

# --- scheduler: linear warmup → cosine decay ---
warmup_epochs = getattr(cfg, "warmup_epochs", 3)
total_epochs  = int(cfg.epochs)

def _lr_lambda(epoch: int):
    if warmup_epochs > 0 and epoch < warmup_epochs:
        return float(epoch + 1) / float(max(1, warmup_epochs))
    # cosine on the remaining epochs (inclusive of last epoch index)
    t = (epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, max(0.0, t))))

scheduler = LambdaLR(optimizer, lr_lambda=_lr_lambda)

# --- AMP scaler ---
use_amp = (cfg.device == "cuda")
scaler = torch.amp.GradScaler(enabled=use_amp)

def current_lr(optim=optimizer) -> float:
    return float(optim.param_groups[0]["lr"])

print(f"Optimizer: Adam | init_lr={cfg.lr_init} | weight_decay={cfg.weight_decay}")
print(f"Scheduler: warmup={warmup_epochs} epochs → cosine over {total_epochs - warmup_epochs}")
print(f"AMP enabled: {use_amp}")


Optimizer: Adam | init_lr=0.0001 | weight_decay=0.0001
Scheduler: warmup=3 epochs → cosine over 37
AMP enabled: True


In [24]:
# CELL 10 — Train/Val Loop
# Prints per-epoch progress.

import time

def _move_to_device(batch, device):
    orig  = batch["orig"].to(device, non_blocking=True)
    mixed = batch["mixed"].to(device, non_blocking=True)
    delta = batch["delta"].to(device, non_blocking=True)
    grads = batch.get("grads", None)
    if grads is not None:
        grads = grads.to(device, non_blocking=True)
    y = batch["y"].to(device)
    w = batch["w"].to(device)
    return orig, mixed, delta, grads, y, w

def train_one_epoch(epoch_idx: int):
    model.train()
    total_loss = []
    yh_all, y_all = [], []

    for batch in cfg.train_loader:
        orig, mixed, delta, grads, y, w = _move_to_device(batch, cfg.device)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_amp):
            yhat = model(orig, mixed, delta, grads)
            loss, _ = loss_bundle(yhat, y, w=w)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss.append(loss.detach().item())
        yh_all.append(yhat.detach())
        y_all.append(y.detach())

    scheduler.step()

    yhat_cat = torch.cat(yh_all, 0)
    y_cat    = torch.cat(y_all, 0)
    metrics  = compute_metrics(yhat_cat, y_cat)

    return float(np.mean(total_loss)), metrics

@torch.no_grad()
def evaluate(loader):
    model.eval()
    yh_all, y_all = [], []
    losses = []

    for batch in loader:
        orig, mixed, delta, grads, y, w = _move_to_device(batch, cfg.device)
        yhat = model(orig, mixed, delta, grads)
        l, _ = loss_bundle(yhat, y, w=w)
        losses.append(l.detach().item())
        yh_all.append(yhat.detach())
        y_all.append(y.detach())

    yhat_cat = torch.cat(yh_all, 0)
    y_cat    = torch.cat(y_all, 0)
    metrics  = compute_metrics(yhat_cat, y_cat)
    return float(np.mean(losses)), metrics

best_val_rmse = float("inf")
best_epoch = -1

for epoch in range(int(cfg.epochs)):
    t0 = time.time()
    train_loss, train_m = train_one_epoch(epoch)
    val_loss,   val_m   = evaluate(cfg.val_loader)

    lr_now = current_lr(optimizer)
    print(f"[Epoch {epoch+1:03d}] "
          f"train_loss={train_loss:.4f} | train_PLCC={train_m['plcc']:.3f} "
          f"| val_RMSE={val_m['rmse']:.3f} val_PLCC={val_m['plcc']:.3f} "
          f"| lr={lr_now:.3e}")

    if val_m["rmse"] < best_val_rmse:
        best_val_rmse = val_m["rmse"]
        best_epoch = epoch + 1

print(f"Best val_RMSE={best_val_rmse:.4f} at epoch {best_epoch}")


[Epoch 001] train_loss=0.2391 | train_PLCC=-0.026 | val_RMSE=0.343 val_PLCC=-0.046 | lr=6.667e-05
[Epoch 002] train_loss=0.2124 | train_PLCC=0.130 | val_RMSE=0.338 val_PLCC=0.250 | lr=1.000e-04
[Epoch 003] train_loss=0.1884 | train_PLCC=0.286 | val_RMSE=0.330 val_PLCC=0.306 | lr=1.000e-04
[Epoch 004] train_loss=0.1646 | train_PLCC=0.449 | val_RMSE=0.328 val_PLCC=0.544 | lr=9.982e-05
[Epoch 005] train_loss=0.1458 | train_PLCC=0.585 | val_RMSE=0.316 val_PLCC=0.635 | lr=9.928e-05
[Epoch 006] train_loss=0.1332 | train_PLCC=0.654 | val_RMSE=0.307 val_PLCC=0.700 | lr=9.839e-05
[Epoch 007] train_loss=0.1248 | train_PLCC=0.710 | val_RMSE=0.303 val_PLCC=0.754 | lr=9.714e-05
[Epoch 008] train_loss=0.1177 | train_PLCC=0.744 | val_RMSE=0.297 val_PLCC=0.775 | lr=9.556e-05
[Epoch 009] train_loss=0.1103 | train_PLCC=0.771 | val_RMSE=0.289 val_PLCC=0.802 | lr=9.365e-05
[Epoch 010] train_loss=0.1074 | train_PLCC=0.759 | val_RMSE=0.283 val_PLCC=0.791 | lr=9.143e-05
[Epoch 011] train_loss=0.1029 | train_

In [25]:
# CELL 11 — Checkpointing + Early Stopping (self-contained training wrapper)
# Saves best (by min val_RMSE) and last checkpoints. Early-stops on patience.

import time
from types import SimpleNamespace

ckpt_dir = cfg.checkpoints_dir / cfg.run_id
ckpt_dir.mkdir(parents=True, exist_ok=True)
best_ckpt_path = ckpt_dir / "best.ckpt"
last_ckpt_path = ckpt_dir / "last.ckpt"

def save_checkpoint(path: Path, epoch: int, best_val_rmse: float):
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict() if 'scaler' in globals() else None,
        "best_val_rmse": best_val_rmse,
        "run_id": cfg.run_id,
    }, path)

class EarlyStopper:
    def __init__(self, patience: int = 10, min_delta: float = 1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.count = 0
    def step(self, val_rmse: float) -> bool:
        if val_rmse < self.best - self.min_delta:
            self.best = val_rmse
            self.count = 0
            return False
        self.count += 1
        return self.count >= self.patience  # True => stop

def train_with_ckpt(patience: int = 10):
    stopper = EarlyStopper(patience=patience)
    best_val = float("inf")
    best_epoch = -1

    for epoch in range(int(cfg.epochs)):
        t0 = time.time()
        train_loss, train_m = train_one_epoch(epoch)
        val_loss,   val_m   = evaluate(cfg.val_loader)

        # scheduler already stepped inside train_one_epoch; if moved out, keep scheduler.step() here.
        lr_now = current_lr(optimizer)
        print(f"[Epoch {epoch+1:03d}] "
              f"train_loss={train_loss:.4f} | train_PLCC={train_m['plcc']:.3f} "
              f"| val_RMSE={val_m['rmse']:.3f} val_PLCC={val_m['plcc']:.3f} "
              f"| lr={lr_now:.3e} | {time.time()-t0:.1f}s")

        # always save 'last'
        save_checkpoint(last_ckpt_path, epoch+1, best_val)

        # update 'best'
        if val_m["rmse"] < best_val:
            best_val = val_m["rmse"]
            best_epoch = epoch + 1
            save_checkpoint(best_ckpt_path, best_epoch, best_val)
            print(f"  ↳ saved best checkpoint @ epoch {best_epoch} (val_RMSE={best_val:.4f})")

        # early stop check
        if stopper.step(val_m["rmse"]):
            print(f"Early stopping after {epoch+1} epochs (no improvement for {stopper.patience}).")
            break

    print(f"Best val_RMSE={best_val:.4f} at epoch {best_epoch}")
    print(f"best: {best_ckpt_path}\nlast: {last_ckpt_path}")
    return SimpleNamespace(best_val_rmse=best_val, best_epoch=best_epoch,
                           best_path=best_ckpt_path, last_path=last_ckpt_path)

# run training with checkpoints + early stop
train_summary = train_with_ckpt(patience=10)


[Epoch 001] train_loss=0.0417 | train_PLCC=0.895 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.5s
  ↳ saved best checkpoint @ epoch 1 (val_RMSE=0.2138)
[Epoch 002] train_loss=0.0425 | train_PLCC=0.892 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.5s
[Epoch 003] train_loss=0.0433 | train_PLCC=0.893 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.6s
[Epoch 004] train_loss=0.0420 | train_PLCC=0.895 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.5s
[Epoch 005] train_loss=0.0415 | train_PLCC=0.895 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.5s
[Epoch 006] train_loss=0.0431 | train_PLCC=0.892 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.5s
[Epoch 007] train_loss=0.0421 | train_PLCC=0.896 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.5s
[Epoch 008] train_loss=0.0418 | train_PLCC=0.892 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.5s
[Epoch 009] train_loss=0.0426 | train_PLCC=0.893 | val_RMSE=0.214 val_PLCC=0.816 | lr=0.000e+00 | 2.5s
[Epoch 010] train_l

In [26]:
# CELL 12 — Restore best, Evaluate on VAL/TEST with TTA, optional linear calibration

import numpy as np
import torch

@torch.no_grad()
def _predict_loader(loader, tta: bool):
    model.eval()
    yh, yt = [], []
    for batch in loader:
        orig, mixed, delta, grads, y, w = _move_to_device(batch, cfg.device)

        def _fwd(o, m, d, g):  # single pass
            return model(o, m, d, g)

        # base
        y0 = _fwd(orig, mixed, delta, grads)

        if not tta:
            yh.append(y0.cpu())
            yt.append(y.cpu())
            continue

        # TTA: hflip, vflip, hvflip (flip inputs only; scalar output)
        y_list = [y0]
        for dims in [(3,), (2,), (2, 3)]:
            o = torch.flip(orig,  dims)
            m = torch.flip(mixed, dims)
            d = torch.flip(delta, dims)
            g = torch.flip(grads, dims) if grads is not None else None
            y_list.append(_fwd(o, m, d, g))
        y_avg = torch.stack(y_list, dim=0).mean(dim=0)

        yh.append(y_avg.cpu())
        yt.append(y.cpu())

    yhat = torch.cat(yh, 0).view(-1).numpy()
    ytrue = torch.cat(yt, 0).view(-1).numpy()
    return yhat, ytrue

def _metrics_np(yhat, ytrue):
    e = yhat - ytrue
    rmse_ = float(np.sqrt(np.mean(e**2)))
    x = yhat - yhat.mean()
    t = ytrue - ytrue.mean()
    denom = np.sqrt((x**2).mean() * (t**2).mean()) + 1e-8
    plcc_ = float((x * t).mean() / denom)
    return rmse_, plcc_

# --- load best checkpoint ---
ckpt_dir = cfg.checkpoints_dir / cfg.run_id
state = torch.load(best_ckpt_path, map_location=cfg.device)
res = model.load_state_dict(state["model"], strict=False)
print(f"Loaded best checkpoint: {best_ckpt_path}")
print("missing:", len(res.missing_keys), "unexpected:", len(res.unexpected_keys))


# --- toggles ---
use_tta = bool(getattr(getattr(cfg, "push99", type("X",(object,),{})()), "use_tta", True))
use_linear_calib = True  # small boost without heavy code; fits in this cell

# --- VAL (no calib) ---
yhat_v, y_v = _predict_loader(cfg.val_loader, tta=use_tta)
rmse_v, plcc_v = _metrics_np(yhat_v, y_v)
print(f"VAL  | TTA={use_tta} | RMSE={rmse_v:.4f} | PLCC={plcc_v:.4f}")

# --- optional linear calibration on VAL, then apply to TEST ---
if use_linear_calib:
    # ŷ* = a·ŷ + b (least squares on VAL)
    A = np.vstack([yhat_v, np.ones_like(yhat_v)]).T
    a, b = np.linalg.lstsq(A, y_v, rcond=None)[0]
    def _calib(x): return a * x + b
    print(f"Calib(linear): a={a:.4f}, b={b:.4f}")
else:
    _calib = lambda x: x

# --- TEST ---
yhat_t, y_t = _predict_loader(cfg.test_loader, tta=use_tta)
rmse_t, plcc_t = _metrics_np(yhat_t, y_t)
print(f"TEST | raw   | RMSE={rmse_t:.4f} | PLCC={plcc_t:.4f}")

if use_linear_calib:
    yhat_t_c = _calib(yhat_t)
    rmse_t_c, plcc_t_c = _metrics_np(yhat_t_c, y_t)
    print(f"TEST | calib | RMSE={rmse_t_c:.4f} | PLCC={plcc_t_c:.4f}")


Loaded best checkpoint: checkpoints\20250916-202151\best.ckpt
missing: 0 unexpected: 0
VAL  | TTA=True | RMSE=0.2148 | PLCC=0.8129
Calib(linear): a=1.2423, b=-0.1961
TEST | raw   | RMSE=0.2287 | PLCC=0.7954
TEST | calib | RMSE=0.2579 | PLCC=0.7954


In [30]:
"""
Export trained model to ONNX and TorchScript for Netron.
Handles both variants: with or without a 'grads' input.
Assumes `model` is defined and best weights are loaded.
"""
from pathlib import Path
import inspect
import torch

model.eval()
dev = next(model.parameters()).device

ART = Path("artifacts"); ART.mkdir(parents=True, exist_ok=True)
onnx_path = ART / "model.onnx"
ts_path   = ART / "model.ts.pt"

# Detect whether the model expects a 'grads' tensor
sig = inspect.signature(model.forward)
expects_grads = ("grads" in sig.parameters) or bool(getattr(model, "use_grads", False))

# Build dummy inputs: (N,C,H,W) = (1,3,20,20)
B, C, H, W = 1, 3, 20, 20
orig  = torch.randn(B, C, H, W, device=dev)
mixed = torch.randn(B, C, H, W, device=dev)
delta = torch.randn(B, C, H, W, device=dev)

if expects_grads:
    grads = torch.randn(B, C, H, W, device=dev)
    dummy_args  = (orig, mixed, delta, grads)
    input_names = ["orig", "mixed", "delta", "grads"]
else:
    dummy_args  = (orig, mixed, delta)
    input_names = ["orig", "mixed", "delta"]

# Sanity forward
with torch.no_grad():
    _ = model(*dummy_args)

# ONNX export
torch.onnx.export(
    model,
    dummy_args,
    onnx_path.as_posix(),
    input_names=input_names,
    output_names=["logits"],
    opset_version=17,
    do_constant_folding=True,
    dynamic_axes={name: {0: "batch"} for name in input_names} | {"logits": {0: "batch"}},
)
print("Saved ONNX:", onnx_path.resolve())

# TorchScript (trace)
ts = torch.jit.trace(model, dummy_args)
ts.save(ts_path.as_posix())
print("Saved TorchScript:", ts_path.resolve())


Saved ONNX: C:\Users\YehonatanR\OneDrive\Documents\HIT Computer Science\3rd\PixelQuality\PythonProject\artifacts\model.onnx
Saved TorchScript: C:\Users\YehonatanR\OneDrive\Documents\HIT Computer Science\3rd\PixelQuality\PythonProject\artifacts\model.ts.pt


In [31]:
"""
Open Netron viewer
If running locally, this launches http://localhost:8080
"""
try:
    import netron
    print("Launching Netron on port 8080...")
    netron.start(onnx_path.as_posix(), address="127.0.0.1", browse=True)
except Exception as e:
    print("Netron not available. Install with: pip install netron")
    print("Then open the exported file manually:", onnx_path.resolve())

Launching Netron on port 8080...
Netron not available. Install with: pip install netron
Then open the exported file manually: C:\Users\YehonatanR\OneDrive\Documents\HIT Computer Science\3rd\PixelQuality\PythonProject\artifacts\model.onnx
