In [5]:
# NOTE(review): hardcoded absolute, machine-specific path. Switching the working
# directory here is presumably what lets the relative paths in `args`
# ("../analysis/...", "../exps/...") and the local imports (`utils`,
# `dinov2_uperhead`) resolve — consider deriving it from a configurable root.
%cd "/home/hotson/kaggle_work/recodai-luc-scientific-image-forgery-detection/code"

/home/hotson/kaggle_work/recodai-luc-scientific-image-forgery-detection/code


# Modules

In [11]:
import argparse
import json
import math
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

try:
    import cv2  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    cv2 = None

from utils import evaluate_single_image, rle_encode, score as competition_score
import itertools
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    average_precision_score,
    confusion_matrix,
    classification_report,
    roc_curve,
    precision_recall_curve,
)
from dinov2_uperhead import DinoV2_UPerNet
from transformers import AutoImageProcessor, AutoModel
import joblib

In [12]:
def normalize_case_id(raw) -> int | str:
    if torch.is_tensor(raw):
        raw = raw.item()
    return int(raw) if isinstance(raw, (np.integer, int)) else str(raw)


def enhanced_adaptive_mask(
    prob: np.ndarray,
    alpha_grad=0.45,
    std_scale: float = 0.3,
    close_ksize: int = 5,
    open_ksize: int = 3,
    blur_ksize: int = 3,
):
    """Binarize a probability map with an adaptive, edge-boosted threshold.

    The map is blended with its normalized Sobel gradient magnitude, lightly
    Gaussian-blurred, thresholded at ``mean + std_scale * std`` of the blended
    map, then cleaned with a morphological close followed by an open.

    Args:
        prob: 2-D probability map; for a 3-D input only index 0 of the last axis
            is used.
        alpha_grad: weight of the gradient term in the blend (0 = probabilities only).
        std_scale: multiple of the blended map's std added to its mean to form
            the threshold.
        close_ksize: side of the square kernel for ``MORPH_CLOSE``.
        open_ksize: side of the square kernel for ``MORPH_OPEN``.
        blur_ksize: Gaussian kernel side (OpenCV requires an odd value).

    Returns:
        ``(mask, thr)``: uint8 mask with values {0, 1} at the input resolution,
        and the float threshold applied to the blended map.

    Raises:
        ImportError: if OpenCV (``cv2``) is not installed.
    """
    if cv2 is None:
        raise ImportError("enhanced_adaptive_mask requires OpenCV (cv2), which is not installed.")

    prob = np.asarray(prob, dtype=np.float32)
    if prob.ndim == 3:
        prob = prob[..., 0]

    gx = cv2.Sobel(prob, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(prob, cv2.CV_32F, 0, 1, ksize=3)
    grad_mag = np.sqrt(gx**2 + gy**2)
    # epsilon keeps a flat map (zero gradient everywhere) from dividing by zero
    grad_norm = grad_mag / (grad_mag.max() + 1e-6)

    enhanced = (1 - alpha_grad) * prob + alpha_grad * grad_norm
    enhanced = cv2.GaussianBlur(enhanced, (blur_ksize, blur_ksize), 0)

    thr = float(np.mean(enhanced) + std_scale * np.std(enhanced))
    mask = (enhanced > thr).astype(np.uint8)

    # close first to fill small holes/gaps, then open to remove isolated specks
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((close_ksize, close_ksize), np.uint8))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((open_ksize, open_ksize), np.uint8))
    return mask, thr


def finalize_mask(prob: np.ndarray, orig_size_wh: tuple[int, int]):
    """Binarize ``prob`` adaptively and resize the mask to the original image size.

    Args:
        prob: probability map at model resolution.
        orig_size_wh: target size as ``(width, height)`` — PIL ``Image.size``
            order, NOT numpy ``(rows, cols)``.

    Returns:
        ``(mask, thr)``: the mask resized to shape ``(height, width)`` and the
        threshold returned by ``enhanced_adaptive_mask``.
    """
    binary, threshold = enhanced_adaptive_mask(prob)
    target_w, target_h = orig_size_wh
    # nearest-neighbour interpolation keeps the mask values unchanged (no blending)
    resized = cv2.resize(
        binary,
        (int(target_w), int(target_h)),
        interpolation=cv2.INTER_NEAREST,
    )
    return resized, threshold


def apply_overlay(img_np: np.ndarray, mask01: np.ndarray, color=(255, 0, 0), alpha=0.5):
    """Alpha-blend a solid ``color`` onto ``img_np`` wherever ``mask01`` is truthy.

    The input image is left untouched; a blended copy is returned. Blending is
    computed in float32 and cast back to uint8 (fractional parts are truncated).

    Args:
        img_np: image array; assumed (H, W, 3) so that ``color`` broadcasts
            per pixel.
        mask01: (H, W) mask; any non-zero value selects a pixel.
        color: RGB tint to blend in.
        alpha: weight of ``color``; ``1 - alpha`` of the original pixel is kept.
    """
    blended = img_np.copy()
    selected = mask01.astype(bool)
    tint = np.array(color, dtype=np.float32)
    source_pixels = blended[selected].astype(np.float32)
    blended[selected] = (alpha * tint + (1 - alpha) * source_pixels).astype(np.uint8)
    return blended


def collate_fn(batch):
    """Collate 6-field samples ``(image, case_id, orig_size, mask_path, image_path, label)``.

    Images are stacked with ``default_collate``; every other field is returned as
    a plain list in batch order.

    NOTE(review): a later notebook cell redefines ``collate_fn`` for 5-field
    samples (``InferenceDataset`` yields no ``mask_path``), so whichever cell ran
    last wins. Consider deleting or renaming one of the two definitions.
    """
    (image_col, id_col, size_col, mask_path_col, image_path_col, label_col) = zip(*batch)
    return (
        default_collate(image_col),
        list(id_col),
        list(size_col),
        list(mask_path_col),
        list(image_path_col),
        list(label_col),
    )


def infer_img_size_from_state(state_dict: dict, fallback: int) -> int:
    """Infer the square input resolution of a ViT backbone from its state dict.

    The side length is ``grid * patch``:
      * ``patch`` is the last dim of ``backbone.patch_embed.proj.weight``;
      * ``grid`` is the side of the square positional-embedding grid in
        ``backbone.pos_embed`` (token axis = dim 1). Both a layout with one
        leading class-token embedding (``1 + grid**2`` tokens, tried first) and a
        grid-only layout (``grid**2`` tokens) are recognized.

    Args:
        state_dict: model state dict whose values expose ``.shape``.
        fallback: value returned when the size cannot be inferred (missing keys,
            fewer than 2 tokens, or a non-square token count).

    Returns:
        Inferred image side length in pixels, or ``fallback``.
    """
    pos_embed = state_dict.get("backbone.pos_embed")
    patch_weight = state_dict.get("backbone.patch_embed.proj.weight")
    if pos_embed is None or patch_weight is None or len(pos_embed.shape) < 2:
        return fallback

    tokens = int(pos_embed.shape[1])
    if tokens < 2:
        # a single token cannot describe a grid; the old code returned 0 here
        return fallback

    patch = int(patch_weight.shape[-1])
    for grid_tokens in (tokens - 1, tokens):
        # isqrt is exact for any int, unlike int(math.sqrt(...)) on large values
        grid = math.isqrt(grid_tokens)
        if grid * grid == grid_tokens:
            return grid * patch
    return fallback


def create_collage(
    image_path: str,
    pred_bin: np.ndarray,
    case_id: str,
    score: float,
    out_dir: Path,
    gt_label: str,
):
    """Save a 5-panel figure comparing predicted and ground-truth masks for one image.

    Panels: input image, predicted mask, GT mask, predicted overlay (red), and GT
    overlay (green). Both masks are brought to the original image resolution with
    nearest-neighbour resizing. The figure is written to
    ``out_dir / f"{gt_label}_{case_id}_{score_str}.png"``, where ``score_str`` is
    ``score`` formatted to 3 decimals with ``.`` replaced by ``pt``.

    Args:
        image_path: path to the source image (loaded and converted to RGB).
        pred_bin: predicted mask; any non-zero value counts as foreground.
        case_id: identifier used in the first panel title and the file name.
        score: scalar embedded in the file name.
        out_dir: output directory (str or Path; created if missing).
        gt_label: file-name prefix (e.g. "authentic" / "forged").

    Returns:
        Path of the saved PNG.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    with open(image_path, "rb") as f:
        orig_img = Image.open(f).convert("RGB")
    orig_np = np.array(orig_img)
    vis_w, vis_h = orig_img.size  # PIL size is (W, H)

    # Normalize to {0,1} uint8: cv2.resize rejects bool arrays, and `* 255` below
    # would overflow uint8 for masks stored as {0, 255}.
    pred_bin_vis = (np.asarray(pred_bin) > 0).astype(np.uint8)
    if pred_bin_vis.shape[:2] != (vis_h, vis_w):
        pred_bin_vis = cv2.resize(pred_bin_vis, (vis_w, vis_h), interpolation=cv2.INTER_NEAREST)

    gt_mask_vis = (load_gt_mask_from_image_path(image_path=image_path, w=vis_w, h=vis_h) > 0).astype(np.uint8)
    if gt_mask_vis.shape[:2] != (vis_h, vis_w):
        gt_mask_vis = cv2.resize(gt_mask_vis, (vis_w, vis_h), interpolation=cv2.INTER_NEAREST)

    pred_mask_rgb = np.repeat(pred_bin_vis[..., None] * 255, 3, axis=2).astype(np.uint8)
    gt_mask_rgb = np.repeat(gt_mask_vis[..., None] * 255, 3, axis=2).astype(np.uint8)

    pred_overlay = apply_overlay(orig_np, pred_bin_vis, color=(255, 0, 0))
    gt_overlay = apply_overlay(orig_np, gt_mask_vis, color=(0, 255, 0))

    score_str = f"{score:.3f}".replace(".", "pt")
    save_path = out_dir / f"{gt_label}_{case_id}_{score_str}.png"

    fig, axes = plt.subplots(1, 5, figsize=(18, 4))
    try:
        titles = [f"Input_{case_id}", "Pred mask", "GT mask", "Pred overlay", "GT overlay"]
        imgs = [orig_np, pred_mask_rgb, gt_mask_rgb, pred_overlay, gt_overlay]
        for ax, img, title in zip(axes, imgs, titles):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # release the figure even if drawing/saving fails, so figures don't pile up
        plt.close(fig)
    return save_path


def load_gt_mask_from_image_path(image_path: str, w: int, h: int) -> np.ndarray:
    """Load the ground-truth mask that corresponds to an image path.

    The mask path is derived from ``image_path`` component by component:
    a ``train_images`` directory becomes ``train_masks``, ``supplemental_images``
    becomes ``supplemental_masks``, any directory named exactly ``forged`` is
    dropped, and the file suffix becomes ``.npy``. For example,
    ``.../train_images/forged/12.png`` maps to ``.../train_masks/12.npy``.
    Working on path components (instead of substring replaces) avoids corrupting
    directories such as ``forged_set`` and handles non-``.png`` images.

    A 3-D mask array is collapsed with ``max`` over axis 0 (a union of the stacked
    masks — presumably one per forged region; confirm against the data).

    Args:
        image_path: path to the image (str or Path).
        w: target mask width.
        h: target mask height.

    Returns:
        ``(h, w)`` uint8 array with values {0, 1}; all zeros when no mask file exists.
    """
    w, h = int(w), int(h)
    dir_map = {"train_images": "train_masks", "supplemental_images": "supplemental_masks"}

    src = Path(image_path)
    parent_parts = [dir_map.get(part, part) for part in src.parent.parts if part != "forged"]
    mask_path = Path(*parent_parts, src.name).with_suffix(".npy")

    if not mask_path.exists():
        return np.zeros(shape=(h, w), dtype=np.uint8)

    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; keep it
    # only for trusted data (it may be needed if masks were saved as object arrays).
    arr = np.load(mask_path, allow_pickle=True)
    if arr.ndim == 3:
        arr = arr.max(axis=0)
    gt_mask = (arr > 0).astype(np.uint8)
    if gt_mask.shape[:2] != (h, w):
        gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    return gt_mask


def save_hist(values, title, out_path: Path, bins=50, xlabel=None, ylabel=None):
    """Save a histogram of ``values`` to ``out_path`` at 200 dpi.

    The image format follows ``out_path``'s suffix (matplotlib default behavior).

    Args:
        values: sequence/array of numbers to bin.
        title: figure title.
        out_path: destination file.
        bins: bin specification forwarded to ``Axes.hist``.
        xlabel: optional x-axis label.
        ylabel: optional y-axis label.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.hist(values, bins=bins)
        ax.set_title(title)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        fig.tight_layout()
        fig.savefig(out_path, dpi=200, bbox_inches="tight")
    finally:
        # release the figure even if plotting/saving fails
        plt.close(fig)

In [18]:
class EmbeddingModel:
    """Frozen Hugging Face vision encoder that returns one embedding per image.

    The embedding is the first token of ``last_hidden_state`` (the CLS token for
    DINOv2 models).
    """

    def __init__(self, model_id: str, device: torch.device):
        # eval() switches off training-only behavior (e.g. dropout)
        self.encoder = AutoModel.from_pretrained(model_id).eval().to(device)
        self.device = device

    @torch.no_grad()
    def encode_from_pixel_values(self, pixel_values: torch.Tensor) -> np.ndarray:
        """Encode a preprocessed image batch.

        Args:
            pixel_values: batch tensor from the matching image processor,
                presumably (B, C, H, W); it is moved to the encoder's device here,
                so callers need not do it themselves.

        Returns:
            (B, D) float32 numpy array of first-token embeddings.
        """
        pixel_values = pixel_values.to(self.device)
        out = self.encoder(pixel_values=pixel_values)
        cls = out.last_hidden_state[:, 0, :]  # [B, D]
        return cls.float().cpu().numpy()

class InferenceDataset(Dataset):
    """Map-style dataset yielding processor-normalized images for inference.

    Each row of ``df`` must provide ``case_id`` and ``image_path``; ``label`` is
    optional and defaults to 0 when the column is absent.

    Items are ``(pixel_values, case_id, orig_size, image_path, label)``:
    ``pixel_values`` is the processor output with its batch dim squeezed,
    ``case_id`` and ``image_path`` are strings, ``orig_size`` is the PIL
    ``(W, H)`` of the source image, and ``label`` is an int.
    """

    def __init__(self, df: pd.DataFrame, img_size: int, processor):
        self.df = df.reset_index(drop=True)
        self.img_size = int(img_size)
        self.processor = processor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        record = self.df.iloc[idx]
        case_id, path_str = str(record["case_id"]), str(record["image_path"])
        label = int(record.get("label", 0))

        with open(path_str, "rb") as fh:
            rgb = Image.open(fh).convert("RGB")
        width_height = rgb.size  # PIL reports (W, H)

        # Resize to a fixed square here; the processor only rescales + normalizes.
        square = transforms.functional.resize(rgb, (self.img_size, self.img_size))
        encoded = self.processor(
            images=square,
            return_tensors="pt",
            do_resize=False,
            do_center_crop=False,
            do_normalize=True,
            do_rescale=True,
        )
        pixel_values = encoded["pixel_values"].squeeze(0)

        return pixel_values, case_id, width_height, path_str, label

def collate_fn(batch):
    """Collate ``(image, case_id, orig_size, image_path, label)`` samples into a batch.

    Returns ``(images, case_ids, orig_sizes, image_paths, labels)``: images stacked
    by ``default_collate``, labels as an int64 tensor, and the remaining fields as
    lists in batch order.
    """
    image_col, id_col, size_col, path_col, label_col = zip(*batch)
    return (
        default_collate(image_col),
        list(id_col),
        list(size_col),
        list(path_col),
        torch.tensor(label_col, dtype=torch.long),
    )

def remove_small_components(mask01: np.ndarray, min_area: int) -> np.ndarray:
    """Drop 8-connected foreground components smaller than ``min_area`` pixels.

    Args:
        mask01: (H, W) uint8/bool mask; any value > 0 is foreground.
        min_area: minimum component size in pixels to keep; ``None`` or a value
            <= 0 disables filtering.

    Returns:
        (H, W) uint8 mask with values {0, 1}.

    Raises:
        ImportError: if filtering is requested but OpenCV (``cv2``) is missing.
    """
    m = (mask01 > 0).astype(np.uint8)
    if min_area is None or min_area <= 0:
        return m

    if cv2 is None:
        raise ImportError("remove_small_components requires OpenCV (cv2) when min_area > 0.")

    _, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
    # Lookup table over component ids: one vectorized gather instead of a
    # full-image comparison per component (O(num_components * H * W)).
    keep = (stats[:, cv2.CC_STAT_AREA] >= min_area).astype(np.uint8)
    keep[0] = 0  # label 0 is the background
    return keep[labels]

# Inference

In [17]:
class args:
    """Notebook configuration: a plain namespace standing in for parsed CLI args.

    Relative paths are resolved against the current working directory.
    """

    # --- data & segmentation checkpoint ---
    csv_path: str = "../analysis/area_splits/val_fold0.csv"
    weights_path: str = "../exps/20260110_131204/vit_base_patch14_dinov2.lvd142m_fold0.pt"
    arch: str = "dinov2_uperhead"
    dinov2_id: str = "facebook/dinov2-base"
    model_name: str = "vit_base_patch14_dinov2.lvd142m"
    patch_size: int = 14  # DINOv2 patch size; img_size must be a multiple of it
    img_size: int = 532  # 38 patches per side
    use_hf_processor: bool = True
    # a real Path (matching the annotation) so `outdir / "name"` works
    outdir: Path = Path("../analysis/preds_nposw")
    save_collages: bool = True
    # --- image-level classifiers (.pkl files) ---
    logistic_regression_model: str = "../exps/classifiers/20260113_074223/models/logreg/model.pkl"
    svc_model: str = "../exps/classifiers/20260113_074223/models/linear_svm_calibrated/model.pkl"
    # evaluated once, when the class body runs
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # --- tuned thresholds / post-processing params (JSON) ---
    best_seg_params: str = "../analysis/preds_nposw_dinov2_uperhead/analysis/run_metadata.json"
    best_svc_params: str = "../exps/classifiers/20260113_074223/models/linear_svm_calibrated/metrics.json"
    best_lgr_params: str = "../exps/classifiers/20260113_074223/models/logreg/metrics.json"

if args.img_size % args.patch_size != 0:
    raise ValueError(f"{args.arch} requires img_size divisible by {args.patch_size} (Dinov2 patch size).")


In [None]:
# -------------------------
# Dataset / loader
# -------------------------

df = pd.read_csv(args.csv_path)

# Pin the slow processor explicitly: the `use_fast` default is slated to change,
# which would silently alter preprocessing (and the classifiers' input embeddings).
hf_processor = AutoImageProcessor.from_pretrained(args.dinov2_id, use_fast=False)
dataset = InferenceDataset(df, img_size=args.img_size, processor=hf_processor)
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,  # keep CSV order
    num_workers=2,
    # pinned host memory only helps host->GPU copies; on CPU it just warns
    pin_memory=args.device.type == "cuda",
    collate_fn=collate_fn,
)

# -------------------------
# Models
# -------------------------
model = DinoV2_UPerNet(dinov2_id=args.dinov2_id, num_classes=1).to(args.device)
# NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
state = torch.load(args.weights_path, map_location=args.device)
model_state = state["model_state"]
model.load_state_dict(model_state)
model.eval()
print("[INFO] Loaded Segmentation model")

# NOTE(review): joblib.load unpickles arbitrary objects — only load trusted files.
clf_logreg = joblib.load(args.logistic_regression_model)
clf_svc = joblib.load(args.svc_model)
print("[INFO] Loaded classifiers")

embedder = EmbeddingModel(args.dinov2_id, args.device)
print("[INFO] Loaded embedding model")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


[INFO] Loaded Segmentation model
[INFO] Loaded classifiers
[INFO] Loaded classifiers


In [20]:
# Tuned post-processing and decision thresholds from earlier analysis runs.
with open(args.best_seg_params, "r") as f:
    best_seg_params = json.load(f)["best_params"]

with open(args.best_lgr_params, "r") as f:
    best_clf_lg_params = json.load(f)

with open(args.best_svc_params, "r") as f:
    best_clf_svc_params = json.load(f)

# Loop-invariant thresholds: read once (and fail fast on missing keys before inference).
min_cc_area = int(best_seg_params.get("min_cc_area", 0))
area_thr = int(best_seg_params["area_thres"])
mean_thr = float(best_seg_params["mean_thres"])
lr_thr = float(best_clf_lg_params["threshold_best_f1"])
svc_thr = float(best_clf_svc_params["threshold_best_f1"])

predictions = []
solution_rows = []

with torch.no_grad():
    for images, case_ids, orig_sizes, image_paths, labels in tqdm(loader, desc="Predicting"):
        images = images.to(args.device, non_blocking=True)

        # image-level classifiers on the encoder embedding
        emb = embedder.encode_from_pixel_values(images)  # (B, D) numpy
        # column 1 = positive class; presumably "forged" — confirm classifier label order
        proba_lr = clf_logreg.predict_proba(emb)[:, 1]   # (B,)
        proba_svc = clf_svc.predict_proba(emb)[:, 1]     # (B,)

        # segmentation head
        out = model(images)
        logits = out if torch.is_tensor(out) else out[0]
        probs = torch.sigmoid(logits).float().cpu().numpy()  # (B, 1, S_h, S_w)

        for i, case_id in enumerate(case_ids):
            cid = str(normalize_case_id(case_id))
            orig_w, orig_h = (int(v) for v in orig_sizes[i])  # PIL (W, H)
            prob = probs[i, 0]  # (S_h, S_w)

            mask_full, thr = finalize_mask(prob, (orig_w, orig_h))  # (H, W)
            gt_mask = load_gt_mask_from_image_path(image_paths[i], orig_w, orig_h)
            gt_label = "authentic" if int(labels[i]) == 0 else "forged"
            shape_str = json.dumps([orig_h, orig_w])

            # binarize + drop small connected components
            pred_bin = remove_small_components((mask_full > 0).astype(np.uint8), min_cc_area).astype(np.uint8)

            # area at full resolution; mean probability inside the mask at model resolution
            area = int(pred_bin.sum())
            mean_inside = 0.0
            if area > 0:
                S_h, S_w = prob.shape
                # NOTE(review): nearest downsampling can erase tiny components, in which
                # case mean_inside stays 0.0 even though area > 0.
                mask_small = cv2.resize(pred_bin, (S_w, S_h), interpolation=cv2.INTER_NEAREST)
                inside = mask_small == 1
                if inside.any():
                    mean_inside = float(prob[inside].mean())

            # segmentation gate: large enough AND confident enough
            pred_mask_thres = "forged" if (area >= area_thr and mean_inside >= mean_thr) else "authentic"

            # RLE only when the gate says forged and the mask is non-empty. Any later rule
            # that marks a gate-rejected case as forged will therefore still submit "authentic".
            if pred_mask_thres == "forged" and area > 0:
                rle_str = rle_encode([pred_bin])
            else:
                rle_str = "authentic"

            p_lr = float(proba_lr[i])
            p_svc = float(proba_svc[i])

            predictions.append(dict(
                case_id=cid,
                image_path=image_paths[i],
                orig_width=orig_w,
                orig_height=orig_h,

                prob=prob.astype(np.float32),
                pred_mask=pred_bin,  # final CC-filtered mask (not the raw mask_full)
                thr=float(thr),

                gt_mask=gt_mask.astype(np.uint8),
                gt_label=gt_label,
                shape_str=shape_str,

                proba_lr=p_lr,
                proba_svc=p_svc,
                embedding=emb[i].astype(np.float32),

                pred_lr="forged" if p_lr >= lr_thr else "authentic",
                pred_svc="forged" if p_svc >= svc_thr else "authentic",
                pred_mask_thres=pred_mask_thres,

                mask_area=area,
                mask_mean_inside=float(mean_inside),

                rle_encode_str=rle_str,
            ))

            # NOTE(review): a forged case without a mask file gets an all-zero gt_mask —
            # confirm rle_encode handles an empty mask the way the metric expects.
            solution_rows.append(dict(
                case_id=cid,
                annotation="authentic" if gt_label == "authentic" else rle_encode([gt_mask.astype(np.uint8)]),
                shape=shape_str,
            ))

solution_df = pd.DataFrame(solution_rows)


Predicting: 100%|██████████| 513/513 [00:57<00:00,  8.89it/s]


### Competition scores

In [21]:
import pandas as pd

def make_submission_from_logic(predictions, decide_fn):
    """Build a submission frame by applying ``decide_fn`` to each prediction.

    Args:
        predictions: iterable of dicts with at least ``case_id`` and
            ``rle_encode_str`` keys.
        decide_fn: callable ``p -> truthy``; truthy means "submit as forged".

    Returns:
        DataFrame with columns ``case_id`` (str) and ``annotation``. Cases decided
        forged get ``p["rle_encode_str"]``; all others get ``"authentic"``. The
        columns are present even when ``predictions`` is empty.

    NOTE(review): the prediction loop stores ``rle_encode_str = "authentic"``
    whenever the mask gate rejected the case or the mask is empty, so a forged
    decision coming only from the classifiers still yields ``"authentic"``.
    OR-combinations with the mask gate therefore reduce to the mask gate alone.
    """
    rows = [
        {
            "case_id": str(p["case_id"]),
            "annotation": p["rle_encode_str"] if decide_fn(p) else "authentic",
        }
        for p in predictions
    ]
    return pd.DataFrame(rows, columns=["case_id", "annotation"])

def score_logic(name, predictions, solution_df, decide_fn):
    """Score one decision rule with the competition metric.

    Builds the submission via ``make_submission_from_logic`` and returns
    ``(name, score)``, with the score cast to a Python float.
    """
    return name, float(
        competition_score(
            solution=solution_df,
            submission=make_submission_from_logic(predictions, decide_fn),
            row_id_column_name="case_id",
        )
    )

# helpers: convert label strings to bool
def is_lr(p):
    """Return True if the logistic-regression vote stored in ``p`` is "forged"."""
    vote = p["pred_lr"]
    return vote == "forged"
def is_svc(p):
    """Return True if the calibrated-SVM vote stored in ``p`` is "forged"."""
    vote = p["pred_svc"]
    return vote == "forged"
def is_mask(p):
    """Return True if the segmentation (area/mean) gate stored in ``p`` says "forged"."""
    vote = p["pred_mask_thres"]
    return vote == "forged"

# Candidate decision rules: (name, predicate over one prediction dict).
# NOTE(review): annotations come from `rle_encode_str`, which is "authentic" whenever
# the mask gate rejects a case, so the OR-rules below cannot add forged cases beyond
# the mask gate (hence their identical scores).
logics = [
    ("lgr_and_pred_mask_thres",
     lambda p: is_lr(p) and is_mask(p)),
    ("lgr_or_pred_mask_thres",
     lambda p: is_lr(p) or is_mask(p)),

    ("svc_and_pred_mask_thres",
     lambda p: is_svc(p) and is_mask(p)),
    ("svc_or_pred_mask_thres",
     lambda p: is_svc(p) or is_mask(p)),

    ("lgr_and_svc_and_pred_mask_thres",
     lambda p: is_lr(p) and is_svc(p) and is_mask(p)),
    ("lgr_or_svc_or_pred_mask_thres",
     lambda p: is_lr(p) or is_svc(p) or is_mask(p)),

    # mixed combinations
    ("lgr_and_svc_or_pred_mask_thres",
     lambda p: (is_lr(p) and is_svc(p)) or is_mask(p)),
    ("lgr_or_svc_and_pred_mask_thres",
     lambda p: (is_lr(p) or is_svc(p)) and is_mask(p)),
]

results = [score_logic(name, predictions, solution_df, fn) for name, fn in logics]

results_df = (
    pd.DataFrame(results, columns=["logic", "score"])
    .sort_values("score", ascending=False)
)
print(results_df.to_string(index=False))


                          logic    score
        lgr_and_pred_mask_thres 0.540774
 lgr_or_svc_and_pred_mask_thres 0.540774
lgr_and_svc_and_pred_mask_thres 0.538667
        svc_and_pred_mask_thres 0.538667
         svc_or_pred_mask_thres 0.532601
         lgr_or_pred_mask_thres 0.532601
  lgr_or_svc_or_pred_mask_thres 0.532601
 lgr_and_svc_or_pred_mask_thres 0.532601
