# FORSMITH Roof Image Classifier — ViT (DINOv2) Fine‑Tuning

**Objective:** Train a model that takes a roof image and predicts the correct `observation_id` (class) from the `forsmith_roof_labels.json` taxonomy.  
**Artifacts:** `model_best.pt`, `label_map.json`, `calibration.json`, `metrics.json`, ONNX export (optional), and an inference wrapper with optional **sheet-aware** masking.

> Dataset: 1,616 images. CSV columns required: `image_file`, `label`, `observation_id`. The filename contains `report_id` as `<report_id>_pageXX_imgY.png`, enabling **GroupKFold** by report to avoid leakage.


### Section 0 - Dependency Installs
Ensures all required Python packages are available inside the Vertex AI Workbench kernel. Missing dependencies (Albumentations, timm, OpenCV, scikit-learn, torchmetrics, and the Google Cloud Storage client) are installed into the local `_deps` directory so future cells can import them without elevated permissions.


In [1]:
# %% [markdown]
# ### Environment setup (Pinned versions for CUDA 12.1)
# This cell ensures `torch`, `torchvision`, `torchaudio`, and `numpy` are mutually compatible.
# - torch/vision/audio = 2.5.1 / 0.20.1 / 2.5.1 (cu121 wheels)
# - numpy < 2.2 (e.g., 2.1.4) to satisfy numba 0.61 and ydata-profiling.
# If your runtime has a different CUDA minor (e.g., 12.2), change `cu121` to `cu122`.

# %%
# 0) Inspect GPU


# 1) Remove mismatched wheels
!pip -q uninstall -y numpy || true

# 2) Install compatible triplet and safe numpy
!pip -q install --index-url https://download.pytorch.org/whl/cu121   torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1
!pip -q install "numpy<2.2"

# 3) Print versions
import torch, torchvision, torchaudio, numpy as _np
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("torchaudio:", torchaudio.__version__)
print("numpy:", _np.__version__)
print("cuda available:", torch.cuda.is_available())

# (Optional) If you rely on torchmetrics, install a version known to pair with torch 2.5.x
# !pip -q install torchmetrics==1.4.0


torch: 2.5.1+cu121
torchvision: 0.20.1+cu121
torchaudio: 2.5.1+cu121
numpy: 2.1.2
cuda available: False


In [2]:
# %%
# Quick sanity check for CUDA tensor ops
import torch
x = torch.randn(1024, 1024, device='cuda' if torch.cuda.is_available() else 'cpu')
y = x @ x.T
print("OK - matmul result shape:", y.shape, "on device:", y.device)


OK - matmul result shape: torch.Size([1024, 1024]) on device: cpu


### Section 1 - Configuration & Paths
Defines the master `CONFIG` dictionary: Cloud Storage locations for images and labels, the local cache directory, runtime hyperparameters, and inference flags. Update these values to point at new buckets, prefixes, or experiment settings before running the pipeline.


In [3]:
# =========================
# 1) CONFIG & ENVIRONMENT
# =========================
from pathlib import Path

DATA_ROOT = Path('/home/jupyter/forsmith_roof_data')

CONFIG = {
    # Cloud Storage locations
    'GCS_BUCKET': 'forsmith-report-bucket',
    'GCS_IMAGES_PREFIX': 'images/',
    'GCS_LABELS_CSV': 'labels/labels.csv',
    'GCS_LABELS_JSON': 'labels/forsmith_roof_labels.json',  # update if your JSON lives elsewhere

    # Paths on the notebook instance (populated from GCS)
    'DATA_ROOT': str(DATA_ROOT),
    'IMAGES_DIR': str(DATA_ROOT / 'images'),
    'CSV_PATH': str(DATA_ROOT / 'labels.csv'),
    'LABELS_JSON': str(DATA_ROOT / 'forsmith_roof_labels.json'),

    # Run control
    'SEED': 1337,
    'IMAGE_SIZE': 518,
    'BATCH_SIZE': 16,
    'ACCUM_STEPS': 1,
    'EPOCHS_LINPROBE': 8,
    'EPOCHS_UNFREEZE': 20,
    'EPOCHS_FULLFT': 10,
    'BASE_LR': 2e-4,
    'WEIGHT_DECAY': 0.05,
    'WARMUP_PCT': 0.05,
    'LABEL_SMOOTH': 0.1,
    'USE_AMP': True,

    # Cross-validation
    'N_SPLITS': 5,
    'FOLD_INDEX': 0,
    'TEST_SIZE': 0.15,

    # Checkpointing
    'OUT_DIR': str(DATA_ROOT / 'outputs'),
    'SAVE_ON_BEST': 'macro_f1',

    # Inference aids
    'ENABLE_SHEET_MASK': True,
}

print(CONFIG)


{'GCS_BUCKET': 'forsmith-report-bucket', 'GCS_IMAGES_PREFIX': 'images/', 'GCS_LABELS_CSV': 'labels/labels.csv', 'GCS_LABELS_JSON': 'labels/forsmith_roof_labels.json', 'DATA_ROOT': '/home/jupyter/forsmith_roof_data', 'IMAGES_DIR': '/home/jupyter/forsmith_roof_data/images', 'CSV_PATH': '/home/jupyter/forsmith_roof_data/labels.csv', 'LABELS_JSON': '/home/jupyter/forsmith_roof_data/forsmith_roof_labels.json', 'SEED': 1337, 'IMAGE_SIZE': 518, 'BATCH_SIZE': 16, 'ACCUM_STEPS': 1, 'EPOCHS_LINPROBE': 8, 'EPOCHS_UNFREEZE': 20, 'EPOCHS_FULLFT': 10, 'BASE_LR': 0.0002, 'WEIGHT_DECAY': 0.05, 'WARMUP_PCT': 0.05, 'LABEL_SMOOTH': 0.1, 'USE_AMP': True, 'N_SPLITS': 5, 'FOLD_INDEX': 0, 'TEST_SIZE': 0.15, 'OUT_DIR': '/home/jupyter/forsmith_roof_data/outputs', 'SAVE_ON_BEST': 'macro_f1', 'ENABLE_SHEET_MASK': True}


### Section 1a - Sync Data From Cloud Storage
Connects to Google Cloud Storage, mirrors the `images/` hierarchy onto the notebook instance, and pulls down the label CSV/JSON artifacts. Existing local files are left untouched so repeated syncs are fast.


In [4]:
# =========================
# 1a) SYNC DATA FROM GCS
# =========================
from pathlib import Path
from google.cloud import storage
import json
import io

data_root = Path(CONFIG['DATA_ROOT'])
data_root.mkdir(parents=True, exist_ok=True)

images_dir = Path(CONFIG['IMAGES_DIR'])
images_dir.mkdir(parents=True, exist_ok=True)

bucket_name = CONFIG.get('GCS_BUCKET')
if not bucket_name:
    print('No GCS bucket configured; skipping sync.')
else:
    try:
        client = storage.Client()
        bucket = client.bucket(bucket_name)
    except Exception as exc:
        raise RuntimeError(
            'Failed to create a Cloud Storage client. Confirm your Vertex AI '
            'Workbench environment has access to the project (service account/IAM).'
        ) from exc

    def _normalize(prefix: str) -> str:
        if not prefix:
            return ''
        return prefix.strip('/') + '/'

    def _strip_prefix(text: str, prefix: str) -> str:
        if prefix and text.startswith(prefix):
            return text[len(prefix):]
        return text

    images_prefix = _normalize(CONFIG.get('GCS_IMAGES_PREFIX', ''))
    downloaded_images = 0
    listed_any = False
    for blob in client.list_blobs(bucket_name, prefix=images_prefix):
        listed_any = True
        if blob.name.endswith('/'):
            continue
        rel_path = _strip_prefix(blob.name, images_prefix).lstrip('/')
        target = images_dir / rel_path
        if target.exists():
            continue
        target.parent.mkdir(parents=True, exist_ok=True)
        print(f'Downloading {blob.name} -> {target}')
        blob.download_to_filename(target)
        downloaded_images += 1
    if not listed_any:
        print(f'No objects found under gs://{bucket_name}/{images_prefix}')
    elif downloaded_images == 0:
        print('Images already present locally.')
    else:
        print(f'Downloaded {downloaded_images} new image files.')

    def _download_file(gcs_key: str, local_key: str, force: bool = False):
        remote_path = CONFIG.get(gcs_key)
        if not remote_path:
            return None
        blob_name = remote_path.strip('/')
        target_path = Path(CONFIG[local_key])
        target_path.parent.mkdir(parents=True, exist_ok=True)

        if force and target_path.exists():
            print(f'Force re-sync on {target_path.name} … deleting local copy.')
            try:
                target_path.unlink()
            except FileNotFoundError:
                pass

        if target_path.exists():
            print(f'{target_path.name} already present locally.')
            return target_path

        print(f'Downloading gs://{bucket_name}/{blob_name} -> {target_path}')
        blob = bucket.blob(blob_name)
        if not blob.exists():
            raise FileNotFoundError(
                f'Blob gs://{bucket_name}/{blob_name} does not exist. '
                'Check CONFIG paths.'
            )
        blob.download_to_filename(target_path)
        return target_path

    # CSV and JSON (taxonomy)
    _download_file('GCS_LABELS_CSV', 'CSV_PATH')

    # Allow forcing a clean pull of the JSON if it was previously corrupted/empty
    FORCE_RESYNC_JSON = False  # set True if you keep hitting JSON issues

    json_path = _download_file('GCS_LABELS_JSON', 'LABELS_JSON', force=FORCE_RESYNC_JSON)

    # --- Sanity-check taxonomy JSON locally; if bad, auto re-download once ---
    def _is_valid_json_file(p: Path) -> tuple[bool, str]:
        if not p or not p.exists():
            return False, 'file missing'
        size = p.stat().st_size
        if size == 0:
            return False, 'file is empty (0 bytes)'
        try:
            with p.open('r', encoding='utf-8') as f:
                json.load(f)
            return True, f'valid JSON ({size} bytes)'
        except Exception as e:
            return False, f'json load failed: {type(e).__name__}: {e}'

    if json_path and json_path.exists():
        ok, msg = _is_valid_json_file(json_path)
        # Print basic diagnostics to help you debug bucket contents vs local cache
        head = Path(json_path).read_bytes()[:128]
        print(f'[Taxonomy JSON check] {json_path.name}: {msg}. First bytes: {head!r}')
        if not ok:
            print('Attempting one forced re-sync of taxonomy JSON from GCS…')
            json_path = _download_file('GCS_LABELS_JSON', 'LABELS_JSON', force=True)
            ok2, msg2 = _is_valid_json_file(json_path)
            print(f'[After re-sync] {json_path.name}: {msg2}')
            if not ok2:
                print('Warning: taxonomy JSON still invalid. Section 3 will try CSV/XLSX fallback.')


Images already present locally.
labels.csv already present locally.
forsmith_roof_labels.json already present locally.
[Taxonomy JSON check] forsmith_roof_labels.json: valid JSON (451763 bytes). First bytes: b'{\r\n  "version": 1,\r\n  "source_sheets": [\r\n    "1.0 - B.U.R",\r\n    "2.0 - Mod. Bit.",\r\n    "3.0 - Thermoplastic",\r\n    "4.0 - IRM'


### Section 2 - Import Libraries & Seed RNGs
Adds the `_deps` directory to `sys.path`, imports PyTorch, Albumentations, scikit-learn, and helper utilities, and seeds Python/NumPy/Torch for reproducible runs while checking whether a GPU is available.


In [5]:

# ================================
# 2) ENVIRONMENT, INSTALLS, SEEDS
# ================================
# Dependencies installed above if needed.

import os, json, math, random, time, shutil, pathlib
from pathlib import Path
import sys

_DEPS_PATH = Path.cwd() / '_deps'
if _DEPS_PATH.exists() and str(_DEPS_PATH) not in sys.path:
    sys.path.insert(0, str(_DEPS_PATH))


import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix, f1_score

# Determinism
def set_seed(seed: int = 1337):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(CONFIG["SEED"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cpu')

### Section 3 - Load Taxonomy JSON
Reads the taxonomy JSON, builds mappings between observation IDs and class indices, records sheet metadata, and prepares lookup tables that power masking and reporting later in the workflow.


In [6]:
# =========================================
# 3) LOAD TAXONOMY & BUILD LABEL MAPS
# =========================================
import json
import pandas as pd
from pathlib import Path

def _try_load_json(path: Path):
    """Strict JSON loader, then a BOM/lenient pass, else None."""
    if not path or not path.exists():
        return None, 'missing'
    try:
        with path.open('r', encoding='utf-8') as f:
            return json.load(f), 'json_ok'
    except Exception as e1:
        # try UTF-8-SIG (handles BOM) and a small repair pass
        try:
            raw = path.read_text(encoding='utf-8-sig')
            # minimal repair: sometimes people save with single quotes or trailing commas
            repaired = raw.replace("'", '"')  # safest minimal tweak
            return json.loads(repaired), f'json_repaired ({type(e1).__name__})'
        except Exception as e2:
            return None, f'json_failed ({type(e1).__name__} -> {type(e2).__name__})'

def _try_load_csv(path: Path):
    """Load CSV schema: expects at least observation_id, label, sheet."""
    if not path or not path.exists():
        return None, 'missing'
    try:
        df = pd.read_csv(path)
        return df, 'csv_ok'
    except Exception as e:
        return None, f'csv_failed ({type(e).__name__})'

def _try_load_xlsx(path: Path):
    """If your taxonomy lives in an .xlsx instead, try to parse it here."""
    if not path or not path.exists():
        return None, 'missing'
    try:
        df = pd.read_excel(path)
        return df, 'xlsx_ok'
    except Exception as e:
        return None, f'xlsx_failed ({type(e).__name__})'

def _normalize_records_from_df(df: pd.DataFrame):
    """
    Normalize a dataframe with columns like:
      observation_id / id, label, sheet, cause_effect, recommendation, default_text, etc.
    Only id/label/sheet are required for the classifier mapping; others are optional.
    """
    # Flexible column mapping
    colmap = {
        'id': None,
        'observation_id': None,
        'label': None,
        'sheet': None,
        'cause_effect': None,
        'recommendation': None,
        'default_text': None,
    }
    lower_cols = {c.lower(): c for c in df.columns}

    def pick(*names):
        for n in names:
            if n in lower_cols:
                return lower_cols[n]
        return None

    colmap['id'] = pick('id', 'observation_id', 'obs_id')
    colmap['observation_id'] = colmap['id']  # treat as same
    colmap['label'] = pick('label', 'name', 'class')
    colmap['sheet'] = pick('sheet', 'category', 'group')
    colmap['cause_effect'] = pick('cause_effect', 'cause/effect', 'cause')
    colmap['recommendation'] = pick('recommendation', 'recommendations')
    colmap['default_text'] = pick('default_text', 'text', 'boilerplate')

    required = [colmap['id'], colmap['label'], colmap['sheet']]
    if any(c is None for c in required):
        missing = [k for k, v in {'id': colmap['id'], 'label': colmap['label'], 'sheet': colmap['sheet']}.items() if v is None]
        raise ValueError(f'Missing required columns in taxonomy table: {missing}. '
                         f'Found columns: {list(df.columns)}')

    recs = []
    for _, row in df.iterrows():
        recs.append({
            'id': str(row[colmap['id']]).strip(),
            'label': str(row[colmap['label']]).strip(),
            'sheet': str(row[colmap['sheet']]).strip(),
            'cause_effect': (str(row[colmap['cause_effect']]).strip()
                             if colmap['cause_effect'] else ''),
            'recommendation': (str(row[colmap['recommendation']]).strip()
                               if colmap['recommendation'] else ''),
            'default_text': (str(row[colmap['default_text']]).strip()
                             if colmap['default_text'] else ''),
        })
    return recs

# --- Load order: JSON → CSV → XLSX ---
labels_json_path = Path(CONFIG["LABELS_JSON"])
labels_csv_path  = Path(CONFIG["CSV_PATH"])

labels_raw, src = _try_load_json(labels_json_path)
if labels_raw is None:
    print(f'[Taxonomy] JSON load failed: {src}. Trying CSV fallback…')
    df_csv, src_csv = _try_load_csv(labels_csv_path)
    if df_csv is not None:
        labels_raw = _normalize_records_from_df(df_csv)
        src = f'csv_fallback ({src_csv})'
    else:
        # Optional XLSX fallback if you keep a spreadsheet version
        # Put its local path in CONFIG like CONFIG["LABELS_XLSX"] if you want this path to be dynamic
        xlsx_path = Path(CONFIG.get('LABELS_XLSX', ''))
        df_xlsx, src_xlsx = _try_load_xlsx(xlsx_path) if xlsx_path else (None, 'missing')
        if df_xlsx is not None:
            labels_raw = _normalize_records_from_df(df_xlsx)
            src = f'xlsx_fallback ({src_xlsx})'
        else:
            raise RuntimeError(
                f'Could not load taxonomy from JSON ({src}), CSV ({src_csv}), '
                f'or XLSX ({src_xlsx}). Fix your artifacts and rerun.'
            )

# Validate shape and normalize minimal fields
if isinstance(labels_raw, dict):
    # Some exports store items under a key like "items"
    labels_raw = labels_raw.get('items', [])
if not isinstance(labels_raw, list):
    raise TypeError(f'Expected taxonomy to be a list; got {type(labels_raw)}')

def _coerce_record(it):
    return {
        'id': str(it.get('id') or it.get('observation_id') or '').strip(),
        'label': str(it.get('label') or '').strip(),
        'sheet': str(it.get('sheet') or '').strip(),
        'cause_effect': str(it.get('cause_effect') or ''),
        'recommendation': str(it.get('recommendation') or ''),
        'default_text': str(it.get('default_text') or ''),
    }

labels_norm = [_coerce_record(it) for it in labels_raw if (it.get('id') or it.get('observation_id')) and it.get('label')]
if not labels_norm:
    raise ValueError('Taxonomy normalization produced 0 items—check your JSON/CSV/XLSX content.')

# Build maps
obs_ids = [it["id"] for it in labels_norm]
id_to_idx = {oid: i for i, oid in enumerate(obs_ids)}
idx_to_info = {
    i: {
        "observation_id": it["id"],
        "label": it["label"],
        "sheet": it["sheet"],
        "cause_effect": it.get("cause_effect", ""),
        "recommendation": it.get("recommendation", ""),
        "default_text": it.get("default_text", ""),
    }
    for i, it in enumerate(labels_norm)
}

# Sheet → set of obs_ids (for inference masking)
sheet_to_ids = {}
for it in labels_norm:
    sheet_to_ids.setdefault(it["sheet"], set()).add(it["id"])

print(f"[Taxonomy] Source: {src}")
print(f"[Taxonomy] Total classes: {len(obs_ids)}")
print("[Taxonomy] Example entries:", labels_norm[:3])
print("[Taxonomy] Sheets:", list(sheet_to_ids.keys())[:5], "...")

[Taxonomy] Source: json_ok
[Taxonomy] Total classes: 424
[Taxonomy] Example entries: [{'id': '1.01.01', 'label': 'Metal Edge Flashing', 'sheet': '1.0 - B.U.R', 'cause_effect': 'Improper installation, deterioration, external damage, etc. has caused the metal edge to not function as intended, which could lead to moisture ingress into the roof system and/or create a safety hazard.', 'recommendation': '', 'default_text': ''}, {'id': '1.01.02', 'label': 'Perimeter Membrane Flashings', 'sheet': '1.0 - B.U.R', 'cause_effect': 'Improper installation, deterioration, external damage, etc. could allow moisture migration to occur into the roof and/or wall assembly.', 'recommendation': '', 'default_text': ''}, {'id': '1.01.03', 'label': 'Perimeter Metal Flashings', 'sheet': '1.0 - B.U.R', 'cause_effect': 'Improper installation, deterioration, external damage, etc. has caused the metal flashings to not function as intended, which could lead to damage of the roof system and/or create a safety hazard.

### Section 4 - Load Dataset Manifest
Ingests the labels CSV, verifies schema, derives a `report_id` grouping from each filename, filters to taxonomy-supported observations, and attaches numeric class indices required for model training.


In [7]:

# =====================================================
# 4) LOAD DATASET CSV & EXTRACT report_id FROM FILENAME
# =====================================================
df = pd.read_csv(CONFIG["CSV_PATH"])
expected_cols = {"image_file","label","observation_id"}
missing = expected_cols - set(df.columns)
assert not missing, f"CSV missing columns: {missing}"

# Derive report_id: everything before '_page'
def get_report_id(fname: str) -> str:
    base = Path(fname).stem
    # e.g., '20-063_page41_img2' -> '20-063'
    return base.split("_page")[0]

df["report_id"] = df["image_file"].map(get_report_id)

# Filter to only obs_ids present in JSON (defensive)
df = df[df["observation_id"].isin(id_to_idx.keys())].copy()
df["class_index"] = df["observation_id"].map(id_to_idx)
print(df.head())
print("Unique reports:", df["report_id"].nunique(), " | images:", len(df), " | classes in use:", df["class_index"].nunique())


                  image_file                                   label  \
0  18-053-12_page12_img2.png                    Unprotected Openings   
1  18-053-12_page17_img3.png             Redundant roof penetrations   
2   23-023R1_page20_img2.png                     Subsurface Moisture   
3    23-023R1_page5_img3.png  Conduit Penetration Through Mech. Unit   
4     21-009_page18_img2.png  Conduit Penetration Through Mech. Unit   

  observation_id  confidence  report_id  class_index  
0        2.11.02       0.502  18-053-12          130  
1        2.06.01       0.504  18-053-12           96  
2        2.12.01       0.505   23-023R1          132  
3        2.04.04       0.506   23-023R1           88  
4        2.04.04       0.506     21-009           88  
Unique reports: 106  | images: 1616  | classes in use: 121


### Section 5 - Create Group-Aware Splits
Uses group-based splitters to hold out a test set and generate cross-validation folds without leaking report-level context between train, validation, and test partitions.


In [8]:

# ==================================================
# 5) GROUP-AWARE SPLITS (TEST; then CV folds for TR/VAL)
# ==================================================
gss = GroupShuffleSplit(n_splits=1, test_size=CONFIG["TEST_SIZE"], random_state=CONFIG["SEED"])
trainval_idx, test_idx = next(gss.split(df, groups=df["report_id"]))
df_trainval = df.iloc[trainval_idx].reset_index(drop=True)
df_test = df.iloc[test_idx].reset_index(drop=True)

print("Train+Val:", len(df_trainval), " | Test:", len(df_test))

gkf = GroupKFold(n_splits=CONFIG["N_SPLITS"])
folds = list(gkf.split(df_trainval, df_trainval["class_index"], groups=df_trainval["report_id"]))
train_idx, val_idx = folds[CONFIG["FOLD_INDEX"]]

train_df = df_trainval.iloc[train_idx].reset_index(drop=True)
val_df = df_trainval.iloc[val_idx].reset_index(drop=True)

print(f"Fold {CONFIG['FOLD_INDEX']}: train={len(train_df)}, val={len(val_df)}")


Train+Val: 1231  | Test: 385
Fold 0: train=984, val=247


### Section 6 - Compute Class Weights
Calculates inverse-frequency weights from the training data so the loss emphasises under-represented classes during optimisation.


In [11]:
# ============================
# Section 6a – Active Class Discovery & Remap (FINAL)
# ============================
import torch, numpy as np
from torch.utils.data import Subset, Dataset

# ---------- Helper: safely extract labels ----------
def _get_labels(dset):
    """Return list[int] of labels without heavy transform calls."""
    # Common attributes first
    if hasattr(dset, "targets") and len(getattr(dset, "targets")) == len(dset):
        return [int(t) for t in dset.targets]
    if hasattr(dset, "labels") and len(getattr(dset, "labels")) == len(dset):
        return [int(t) for t in dset.labels]
    for name in ("samples", "imgs"):
        if hasattr(dset, name):
            items = getattr(dset, name)
            try:
                return [int(lbl) for _, lbl in items]
            except Exception:
                pass

    # Fallback: light index read
    labels = []
    for i in range(len(dset)):
        it = dset[i]
        if isinstance(it, tuple) and len(it) >= 2:
            labels.append(int(it[1]))
        elif isinstance(it, dict):
            for k in ("label", "target", "y"):
                if k in it:
                    val = it[k]
                    if isinstance(val, (list, tuple)): val = val[0]
                    labels.append(int(val))
                    break
        else:
            raise TypeError(f"Unsupported item type {type(it)} at index {i}")
    return labels


# ---------- Build mapping for active classes ----------
train_labels_raw = _get_labels(train_dataset_raw)
val_labels_raw   = _get_labels(val_dataset_raw)

active_old_ids = sorted(set(train_labels_raw) | set(val_labels_raw))
old2new = {old: i for i, old in enumerate(active_old_ids)}
new2old = {i: old for old, i in old2new.items()}
K = len(active_old_ids)
print(f"[6a] Active classes in this run: {K} / 424")


# ---------- Dataset wrapper to remap labels ----------
class _Remap(Dataset):
    def __init__(self, base, id_map):
        self.base, self.id_map = base, id_map
    def __len__(self): return len(self.base)
    def __getitem__(self, idx):
        item = self.base[idx]
        # Handle tuple (x,y)
        if isinstance(item, tuple) and len(item) >= 2:
            x, y_old = item[0], int(item[1])
            return x, self.id_map[y_old]
        # Handle dict-like item
        if isinstance(item, dict):
            out = dict(item)
            y_old = None
            for k in ("label", "target", "y"):
                if k in out:
                    val = out[k]
                    if isinstance(val, (list, tuple)): val = val[0]
                    y_old = int(val)
                    out[k] = self.id_map[y_old]
                    break
            if y_old is None:
                raise ValueError("Dict item missing label/target/y key.")
            return out
        raise TypeError(f"Unsupported dataset item type: {type(item)}")


# ---------- Filter to active classes (no transform calls) ----------
train_kept = [i for i, y in enumerate(train_labels_raw) if int(y) in old2new]
val_kept   = [i for i, y in enumerate(val_labels_raw)   if int(y) in old2new]

train_dataset = _Remap(Subset(train_dataset_raw, train_kept), old2new)
val_dataset   = _Remap(Subset(val_dataset_raw,   val_kept),   old2new)


# ---------- Compute class weights ----------
train_labels_new = np.array([
    int(train_dataset[i][1]) if isinstance(train_dataset[i], tuple)
    else int(next(v for k,v in train_dataset[i].items() if k in ("label","target","y")))
    for i in range(len(train_dataset))
])
counts = np.bincount(train_labels_new, minlength=K)
class_weights = 1.0 / np.clip(counts, 1, None)

print(f"[6a] After remap → K={K} | train={len(train_dataset)} | val={len(val_dataset)}")
print(f"[6a] Non-empty train classes: {(counts>0).sum()} / {K}")

# Optional quick probe
sample = train_dataset[0]
if isinstance(sample, tuple):
    x, y = sample
    print(f"[6a] Example tuple item → label={y}, type(x)={type(x)}, shape={getattr(x, 'shape', None)}")
else:
    print(f"[6a] Example dict item keys → {list(sample.keys())}")


NameError: name 'train_dataset_raw' is not defined

In [None]:

# ===========================================
# 6) CLASS WEIGHTS (inverse frequency simple)
# ===========================================
class_counts = train_df["class_index"].value_counts().sort_index()
num_classes = len(obs_ids)
freq = np.ones(num_classes)
freq[class_counts.index.values] = class_counts.values
inv_freq = 1.0 / np.clip(freq, 1.0, None)
class_weights = inv_freq / inv_freq.sum() * num_classes
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
print("Example weights (first 10):", class_weights[:10])


### Section 7 - Define Image Transforms
Configures Albumentations pipelines for training, validation, and inference: resizing, normalization, and light augmentations matched to the configured input size.


In [None]:
# =====================================
# 7) IMAGE TRANSFORMS (Albumentations)
# =====================================
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

IMG = CONFIG["IMAGE_SIZE"]
PAD_COLOR = (128, 128, 128)

# ---- Helpers with conservative, cross-version-safe kwargs -------------------
def _pad_kwargs():
    # Use 'value' universally to satisfy 1.x; harmless in 2.x when accepted via SSR.
    return dict(border_mode=cv2.BORDER_CONSTANT, value=PAD_COLOR)

def _image_compression_kwargs():
    # Prefer 2.x API, fallback to 1.x gracefully.
    try:
        _ = A.ImageCompression(quality_range=(80, 100), p=0.0)
        return dict(quality_range=(80, 100))
    except TypeError:
        return dict(quality_lower=80, quality_upper=100)

def _random_resized_crop(height, width, **common):
    # Some builds want size=(H,W); others height=..., width=...
    try:
        return A.RandomResizedCrop(size=(height, width), **common)
    except TypeError:
        return A.RandomResizedCrop(height=height, width=width, **common)

def _mild_geo():
    """
    Prefer ShiftScaleRotate across versions (translation via shift_limit,
    small scale/rotate). If SSR isn't present, fall back to Affine WITHOUT
    mode/cval (keeps signature-compatible).
    """
    if hasattr(A, "ShiftScaleRotate"):
        return A.ShiftScaleRotate(
            shift_limit=0.02,     # ~±2% translation
            scale_limit=0.05,     # ~±5% scale
            rotate_limit=5,       # ±5 degrees
            border_mode=cv2.BORDER_CONSTANT,
            value=PAD_COLOR,      # universal kw for fill
            p=0.5,
        )
    else:
        # Minimal-arg Affine that avoids 'mode'/'cval' complaints
        return A.Affine(
            translate_percent={"x": (-0.02, 0.02), "y": (-0.02, 0.02)},
            scale=(0.95, 1.05),
            rotate=(-5, 5),
            p=0.5,
        )

# ---- Pipelines --------------------------------------------------------------
train_tfms = A.Compose([
    A.LongestMaxSize(max_size=IMG, interpolation=cv2.INTER_AREA),

    A.PadIfNeeded(min_height=IMG, min_width=IMG, **_pad_kwargs()),

    _random_resized_crop(
        IMG, IMG, scale=(0.90, 1.0), ratio=(0.9, 1.1), p=0.7
    ),

    A.HorizontalFlip(p=0.5),

    _mild_geo(),

    A.RandomBrightnessContrast(p=0.5),

    A.ImageCompression(p=0.3, **_image_compression_kwargs()),

    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)),
    ToTensorV2(),
])

eval_tfms = A.Compose([
    A.LongestMaxSize(max_size=IMG, interpolation=cv2.INTER_AREA),
    A.PadIfNeeded(min_height=IMG, min_width=IMG, **_pad_kwargs()),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)),
    ToTensorV2(),
])

# ---- quick probe ------------------------------------------------------------
if True:
    import numpy as _np
    _fake = (_np.random.rand(320, 500, 3) * 255).astype("uint8")
    _out = train_tfms(image=_fake)["image"]; assert _out.shape[-2:] == (IMG, IMG)
    _out2 = eval_tfms(image=_fake)["image"];  assert _out2.shape[-2:] == (IMG, IMG)


Sanity Probe

In [None]:
import albumentations as _A, numpy as _np, cv2 as _cv2
print("Albumentations:", _A.__version__)

# Make a fake RGB uint8 image with non-square shape to exercise pad/crop
_fake = (_np.random.rand(320, 500, 3) * 255).astype("uint8")
try:
    _out = train_tfms(image=_fake)["image"]  # torch.Tensor CxHxW
    print("train_tfms ok:", tuple(_out.shape))
    assert _out.shape[-2:] == (IMG, IMG), "train_tfms must output IMGxIMG"
except Exception as e:
    print("train_tfms FAILED:", repr(e))
    raise

try:
    _out2 = eval_tfms(image=_fake)["image"]
    print("eval_tfms ok:", tuple(_out2.shape))
    assert _out2.shape[-2:] == (IMG, IMG), "eval_tfms must output IMGxIMG"
except Exception as e:
    print("eval_tfms FAILED:", repr(e))
    raise


### Section 8 - Dataset & DataLoaders
Implements the PyTorch dataset wrapper that reads images, applies transforms, and returns tensors plus class indices. Also instantiates the train/validation/test `DataLoader` objects with appropriate batching options.


In [None]:
# =========================
# Section 8 - Dataset & DataLoaders
# =========================
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np

BATCH_SIZE = CONFIG.get("BATCH_SIZE", 32)

# Balanced sampling to help tiny/imbalanced data
train_labels_new = np.array([int(train_dataset[i][1]) for i in range(len(train_dataset))])
counts = np.bincount(train_labels_new, minlength=len(class_weights))
class_weights = 1.0 / np.clip(counts, 1, None)  # ensure synced with Section 6a
sample_weights = class_weights[train_labels_new]
sampler = WeightedRandomSampler(sample_weights.tolist(), len(sample_weights), replacement=True)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, sampler=sampler,
    num_workers=CONFIG.get("NUM_WORKERS", 4), pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=CONFIG.get("NUM_WORKERS", 4), pin_memory=True
)

# Cross-entropy with class weights + light label smoothing
crit = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(class_weights, dtype=torch.float32, device=device),
    label_smoothing=0.10
)


### Section 9 - Build DINOv2 Model
Instantiates the timm DINOv2 ViT-S/14 backbone, swaps in a classifier sized for the dataset, and prepares the network for staged fine-tuning.


In [None]:
# =========================================
# Section 9 – Build DINOv2 Model
# =========================================
import torch
import torch.nn as nn

assert K > 0, "Active-class K not set. Run Section 6a before this."

# --- Load / create your DINOv2 backbone here ---
# If you already created `model` earlier, keep it. Otherwise, an example:
try:
    _ = model  # keep existing model if present
except NameError:
    from torchvision.models.vision_transformer import vit_b_16  # placeholder if you need a stub
    model = vit_b_16(weights=None)  # replace with your DINOv2 backbone load

# Ensure we have a "backbone" attribute or an equivalent forward_features:
if not hasattr(model, "backbone"):
    # Try to create a lightweight wrapper so we can freeze a backbone and attach a head
    class Wrap(torch.nn.Module):
        def __init__(self, base):
            super().__init__()
            self.backbone = base
            # temporary dummy head; will be replaced
            self.head = nn.Identity()
        def forward_features(self, x):
            # Try dino-like API; else fall back to forward
            if hasattr(self.backbone, "forward_features"):
                return self.backbone.forward_features(x)
            return self.backbone(x)
        def forward(self, x):
            feats = self.forward_features(x)
            if isinstance(feats, (list, tuple)):
                feats = feats[0]
            # If backbone already outputs logits, stop earlier; here we force a linear head usage
            if feats.ndim > 2:
                feats = feats.mean(dim=(-2, -1))  # global pool if needed
            return self.head(feats)

    model = Wrap(model)

# Freeze backbone for linear probe
for p in model.backbone.parameters():
    p.requires_grad = False

# Infer feature dim for the head
model = model.to(device).eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, IMG, IMG, device=device)
    if hasattr(model, "forward_features"):
        feats = model.forward_features(dummy)
    else:
        feats = model.backbone(dummy)
    if isinstance(feats, (list, tuple)):
        feats = feats[0]
    if feats.ndim > 2:  # e.g., feature map
        feats = feats.mean(dim=(-2, -1))
    feat_dim = feats.shape[-1]

# Replace head to match only active classes K
model.head = nn.Linear(feat_dim, K).to(device)
print(f"[Section 9] Linear head -> in_features={feat_dim}, out_features={K}")
model.train()


### Section 10 - Optimiser Helpers
Provides utility functions to construct the AdamW optimiser and cosine learning-rate scheduler with warm-up so each training phase can initialise its own optimiser stack consistently.


In [None]:
# ==============================================
# Section 10 – Optimiser & Scheduler 
# ==============================================
import torch.optim as optim

HEAD_LR = CONFIG.get("HEAD_LR", 5e-2)   # punchy LR for linear probe
WD      = CONFIG.get("WEIGHT_DECAY", 1e-4)
EPOCHS  = CONFIG.get("EPOCHS", 20)

optimizer = optim.SGD(model.head.parameters(), lr=HEAD_LR, momentum=0.9, weight_decay=WD)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

print(f"[Section 10] Optimizer=SGD(lr={HEAD_LR}, wd={WD}), Scheduler=Cosine(T_max={EPOCHS})")


In [None]:
# ===========================
# 10.1) CHECKPOINT & RESUME
# ===========================
import json, time

CKPT_DIR = CONFIG["OUT_DIR"]
CKPT_LAST = os.path.join(CKPT_DIR, "last.pt")
CKPT_BEST = os.path.join(CKPT_DIR, "best.pt")
CKPT_META = os.path.join(CKPT_DIR, "meta.json")

def save_ckpt(state_dict, optimizer, epoch, phase, best_f1=None, is_best=False):
    os.makedirs(CKPT_DIR, exist_ok=True)
    payload = {
        "model": state_dict,
        "optimizer": optimizer.state_dict() if optimizer else None,
        "epoch": epoch,
        "phase": phase,
        "best_f1": best_f1,
        "ts": time.time(),
        "config": CONFIG,
    }
    torch.save(payload, CKPT_LAST)
    if is_best:
        torch.save(payload, CKPT_BEST)
    with open(CKPT_META, "w") as f:
        json.dump({k:v for k,v in payload.items() if k not in ("model","optimizer")}, f, indent=2)
    print(f"[CKPT] Saved {'BEST' if is_best else 'LAST'} at epoch {epoch} ({phase})")

def try_resume(net, optimizer=None, path=None):
    path = path or (CONFIG.get("RESUME_FROM") or CKPT_LAST)
    if not os.path.exists(path):
        print("[CKPT] No resume checkpoint found.")
        return 0, None, None
    ckpt = torch.load(path, map_location="cpu")
    net.load_state_dict(ckpt["model"], strict=True)
    if optimizer is not None and ckpt.get("optimizer"):
        optimizer.load_state_dict(ckpt["optimizer"])
    print(f"[CKPT] Resumed from {path} at epoch {ckpt['epoch']} (phase={ckpt['phase']}) best_f1={ckpt.get('best_f1')}")
    return ckpt["epoch"], ckpt.get("phase"), ckpt.get("best_f1")


### Section 11 - Training & Evaluation Loops
Implements the shared training loop, validation pass, and three-phase schedule (linear probe, partial unfreeze, full fine-tune) while tracking metrics with torchmetrics and supporting mixed precision.


In [None]:
# ====================================================
# Section 11 – Training & Evaluation Loops (REPLACED)
# ====================================================
import contextlib
from torch import amp
from tqdm import tqdm

USE_AMP   = CONFIG.get("USE_AMP", True)
AMP_DTYPE = torch.float16  # switch to torch.bfloat16 if your GPU supports it better
scaler = amp.GradScaler('cuda') if USE_AMP else None

def _forward(x, y):
    ctx = amp.autocast('cuda', dtype=AMP_DTYPE) if USE_AMP else contextlib.nullcontext()
    with ctx:
        logits = model(x)
        loss   = crit(logits, y)
    return logits, loss

def train_one_epoch(epoch, loader):
    model.train()
    running_loss, running_correct, seen = 0.0, 0, 0
    pbar = tqdm(loader, desc=f"[Train] Epoch {epoch}", leave=False)
    for x, y in pbar:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits, loss = _forward(x, y)
        if USE_AMP:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        preds = logits.argmax(1)
        running_correct += (preds == y).sum().item()
        seen += y.size(0)
        running_loss += loss.item() * y.size(0)
        pbar.set_postfix(loss=running_loss/seen, acc=running_correct/seen)

    scheduler.step()
    return running_loss/seen, running_correct/seen

@torch.no_grad()
def evaluate(epoch, loader):
    model.eval()
    total_loss, total_correct, seen = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits, loss = _forward(x, y)
        preds = logits.argmax(1)
        total_loss   += loss.item() * y.size(0)
        total_correct += (preds == y).sum().item()
        seen += y.size(0)
    return total_loss/seen, total_correct/seen

def fit(epochs=EPOCHS):
    best_val = -1.0
    for ep in range(1, epochs+1):
        tr_loss, tr_acc = train_one_epoch(ep, train_loader)
        va_loss, va_acc = evaluate(ep, val_loader)
        print(f"[Epoch {ep:02d}] train: loss {tr_loss:.4f} acc {tr_acc:.4f} | val: loss {va_loss:.4f} acc {va_acc:.4f}")
        if va_acc > best_val:
            best_val = va_acc
            torch.save({'model':model.state_dict(),
                        'old2new':old2new,'new2old':new2old,
                        'K':K, 'feat_dim':feat_dim}, "best_linear_probe.pt")
            print(f"  ↳ New best val acc {best_val:.4f} — saved to best_linear_probe.pt")

In [None]:
# ==========================================================
# Section 11a – One-Batch Overfit Sanity Check
# ==========================================================
xb, yb = next(iter(train_loader))
xb, yb = xb.to(device), yb.to(device)

_test_opt = torch.optim.SGD(model.head.parameters(), lr=0.05, momentum=0.9)
model.train()
for step in range(300):
    _test_opt.zero_grad(set_to_none=True)
    with amp.autocast('cuda', dtype=AMP_DTYPE) if USE_AMP else contextlib.nullcontext():
        out = model(xb); loss = crit(out, yb)
    loss.backward(); _test_opt.step()
    if step % 50 == 0:
        acc = (out.argmax(1) == yb).float().mean().item()
        print(f"[1-batch] step {step:3d} | loss {loss.item():.3f} | acc {acc:.3f}")


### Section 12 - Test Set Evaluation
Runs the final trained model against the held-out test split, generating macro F1 scores, a classification report, and a confusion matrix for post-training analysis.


In [None]:
# =========================================
# Section 12 – Test Set Evaluation
# =========================================
@torch.no_grad()
def evaluate_loader(loader):
    model.eval()
    total_loss, total_correct, seen = 0.0, 0, 0
    all_preds, all_tgts = [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits, loss = _forward(x, y)
        preds = logits.argmax(1)
        total_loss   += loss.item() * y.size(0)
        total_correct += (preds == y).sum().item()
        seen += y.size(0)
        all_preds.append(preds.detach().cpu())
        all_tgts.append(y.detach().cpu())
    import torch as _t
    return (total_loss/seen, total_correct/seen,
            _t.cat(all_preds).numpy(), _t.cat(all_tgts).numpy())

# If you have a test_loader, run it; otherwise, run on val.
loader_to_eval = globals().get("test_loader", val_loader)

te_loss, te_acc, te_preds, te_tgts = evaluate_loader(loader_to_eval)
print(f"[Test] loss {te_loss:.4f} | acc {te_acc:.4f}")

# Optional: basic classification report & confusion matrix
try:
    from sklearn.metrics import classification_report, confusion_matrix
    print("\nClassification report (new-index space):")
    print(classification_report(te_tgts, te_preds, digits=3, zero_division=0))
except Exception as e:
    print("sklearn metrics unavailable:", e)

### Section 13 - Calibrate Probabilities
Performs temperature scaling on validation logits to calibrate predicted probabilities and saves the resulting temperature for downstream inference.


In [None]:
# ===================================
# 13) TEMPERATURE SCALING (Calibrate)
# ===================================
def temperature_scale(logits, T):
    return logits / T

def find_temperature(val_loader):
    net.eval()
    logits_list, y_list = [], []
    with torch.no_grad():
        for x,y in val_loader:
            x = x.to(device)
            logits,_ = net(x, None)
            logits_list.append(logits.cpu().numpy())
            y_list.append(y.numpy())
    L = np.concatenate(logits_list); Y = np.concatenate(y_list)

    T = 1.0
    for _ in range(100):
        # simple 1D search via gradient-free update
        temps = np.linspace(0.5, 3.0, 26)
        nll = []
        for t in temps:
            z = torch.tensor(L/t).float()
            y = torch.tensor(Y).long()
            nll.append(F.cross_entropy(z, y).item())
        T = float(temps[int(np.argmin(nll))])
    return T

T = find_temperature(val_loader)
print("Best temperature:", T)

os.makedirs(CONFIG["OUT_DIR"], exist_ok=True)
with open(os.path.join(CONFIG["OUT_DIR"], "calibration.json"), "w") as f:
    json.dump({"temperature": T}, f)

### Section 14 - Inference Helpers
Supplies convenience functions to load single images, apply optional sheet-based masking, and return top-k predictions with metadata for interactive use or deployment.


In [None]:
# ==========================================
# Section 14 – Inference Helpers
# ==========================================
import torch

def argmax_to_original_id(pred_new_idx):
    """
    Convert predicted new-index classes (0..K-1) back to original 424-id labels.
    """
    if torch.is_tensor(pred_new_idx):
        pred_new_idx = pred_new_idx.detach().cpu().tolist()
    return [new2old[int(i)] for i in pred_new_idx]

@torch.no_grad()
def predict_images(img_batch_tensor):
    """
    img_batch_tensor: (B,3,H,W) normalized as in training.
    Returns:
      preds_new: tensor of shape (B,) in 0..K-1
      preds_orig: list of original class ids (subset of 424 space)
      logits: raw logits (B,K)
    """
    model.eval()
    x = img_batch_tensor.to(device)
    logits, _ = _forward(x, torch.zeros(x.size(0), dtype=torch.long, device=device))
    preds_new = logits.argmax(1)
    preds_orig = argmax_to_original_id(preds_new)
    return preds_new.cpu(), preds_orig, logits.cpu()

### Section 15 - Save Metadata Artefacts
Exports label mappings, sheet groupings, and other metadata so future sessions can decode predictions without recomputing the taxonomy processing.


In [None]:
# =============================
# 15) SAVE LABEL MAP & METRICS
# =============================
out = {
    "id_to_idx": id_to_idx,
    "idx_to_info": idx_to_info,
    "sheet_to_ids": {k:list(v) for k,v in sheet_to_ids.items()},
    "image_size": CONFIG["IMAGE_SIZE"],
}
with open(os.path.join(CONFIG["OUT_DIR"], "label_map.json"), "w", encoding="utf-8") as f:
    json.dump(out, f, indent=2, ensure_ascii=False)
print("Saved label_map.json")

### Section 16 - Export to ONNX (Optional)
Shows how to export the trained PyTorch model to ONNX for serving in inference runtimes that expect the format.


In [None]:
# ========================
# 16) OPTIONAL: ONNX EXPORT
# ========================
example = torch.randn(1,3,CONFIG["IMAGE_SIZE"],CONFIG["IMAGE_SIZE"]).to(device)
torch.onnx.export(net, (example, None), os.path.join(CONFIG["OUT_DIR"], "model_best.onnx"),
                  input_names=["image","targets"], output_names=["logits","loss"],
                  opset_version=17, do_constant_folding=True, verbose=False)
print("Exported ONNX.")

### Section 17 - Run Summary
Prints a concise recap of the paths, image size, and compute device used for the current run, making experiment tracking easier.


In [None]:
# ====================
# 17) RUN SUMMARY
# ====================
print("\n================ SUMMARY ================")
print("Images dir       :", CONFIG["IMAGES_DIR"])
print("CSV path         :", CONFIG["CSV_PATH"])
print("Labels JSON      :", CONFIG["LABELS_JSON"])
print("Output dir       :", CONFIG["OUT_DIR"])
print("Image size       :", CONFIG["IMAGE_SIZE"])
print("Device           :", device)
print("=========================================")

### Appendix - Optional Image Inspection
Commented example code for plotting a handful of sample images with PIL and Matplotlib. Uncomment when you need a quick visual check of the synced dataset.


In [None]:

# from pathlib import Path
# from PIL import Image
# import matplotlib.pyplot as plt
# samples = [
#     Path(CONFIG["IMAGES_DIR"]) / "17-034_page6_img3.png",
#     Path(CONFIG["IMAGES_DIR"]) / "17-057-4_page17_img3.png",
#     Path(CONFIG["IMAGES_DIR"]) / "17-057-16_page16_img2.png",
#     Path(CONFIG["IMAGES_DIR"]) / "17-067_page21_img1.png",
#     Path(CONFIG["IMAGES_DIR"]) / "19-025_page24_img3.png",
#     Path(CONFIG["IMAGES_DIR"]) / "21-009_page14_img2.png",
# ]
# ]
# for p in samples:
#     try:
#         plt.figure(); plt.imshow(Image.open(p)); plt.axis('off'); plt.title(Path(p).name)
#     except Exception as e:
#         print("Missing:", p)
