In [1]:
%cd "/home/hotson/kaggle_work/recodai-luc-scientific-image-forgery-detection/code"

/home/hotson/kaggle_work/recodai-luc-scientific-image-forgery-detection/code


# Modules

In [2]:
import argparse
import json
import math
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

try:
    import cv2  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    cv2 = None

from utils import evaluate_single_image, rle_encode, score as competition_score
import itertools

# Dataset

In [3]:
class InferenceDataset(Dataset):
    def __init__(self, df: pd.DataFrame, img_size: int = 448, processor=None):
        self.df = df.reset_index(drop=True)
        self.img_size = img_size
        self.processor = processor
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        self.img_transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            self.to_tensor,
            self.normalize,
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        case_id = row["case_id"]
        mask_path = row.get("mask_path", "")
        label = int(row.get("label", 0))
        image_path = row["image_path"]

        with open(image_path, "rb") as f:
            img = Image.open(f).convert("RGB")

        orig_size = img.size  # (W,H)

        if self.processor is not None:
            img_rs = transforms.functional.resize(img, (self.img_size, self.img_size))
            processed = self.processor(
                images=img_rs,
                return_tensors="pt",
                do_resize=False,
                do_center_crop=False,
                do_normalize=True,
                do_rescale=True,
            )
            img_t = processed["pixel_values"].squeeze(0)
        else:
            img_t = self.img_transform(img)

        return img_t, case_id, orig_size, mask_path, image_path, label



# Helper functions

In [4]:
def normalize_case_id(raw) -> int | str:
    if torch.is_tensor(raw):
        raw = raw.item()
    return int(raw) if isinstance(raw, (np.integer, int)) else str(raw)


def enhanced_adaptive_mask(prob: np.ndarray, alpha_grad=0.45):
    prob = np.asarray(prob, dtype=np.float32)
    if prob.ndim == 3:
        prob = prob[..., 0]

    gx = cv2.Sobel(prob, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(prob, cv2.CV_32F, 0, 1, ksize=3)
    grad_mag = np.sqrt(gx**2 + gy**2)
    grad_norm = grad_mag / (grad_mag.max() + 1e-6)

    enhanced = (1 - alpha_grad) * prob + alpha_grad * grad_norm
    enhanced = cv2.GaussianBlur(enhanced, (3, 3), 0)

    thr = float(np.mean(enhanced) + 0.3 * np.std(enhanced))
    mask = (enhanced > thr).astype(np.uint8)

    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))
    return mask, thr


def finalize_mask(prob: np.ndarray, orig_size_wh: tuple[int, int]):
    # orig_size_wh MUST be (W,H)
    mask, thr = enhanced_adaptive_mask(prob)
    w, h = orig_size_wh
    mask = cv2.resize(mask, (int(w), int(h)), interpolation=cv2.INTER_NEAREST)
    return mask, thr


def apply_overlay(img_np: np.ndarray, mask01: np.ndarray, color=(255, 0, 0), alpha=0.5):
    overlay = img_np.copy()
    mask_bool = mask01.astype(bool)
    overlay[mask_bool] = (
        alpha * np.array(color, dtype=np.float32)
        + (1 - alpha) * overlay[mask_bool].astype(np.float32)
    ).astype(np.uint8)
    return overlay


def collate_fn(batch):
    images, case_ids, orig_sizes, mask_paths, image_paths, labels = zip(*batch)
    images = default_collate(images)
    return images, list(case_ids), list(orig_sizes), list(mask_paths), list(image_paths), list(labels)


def infer_img_size_from_state(state_dict: dict, fallback: int) -> int:
    pos_embed = state_dict.get("backbone.pos_embed")
    patch_weight = state_dict.get("backbone.patch_embed.proj.weight")
    if pos_embed is None or patch_weight is None:
        return fallback

    tokens = pos_embed.shape[1]
    grid_tokens = tokens - 1
    grid = int(math.sqrt(grid_tokens))
    if grid * grid != grid_tokens:
        return fallback

    patch = int(patch_weight.shape[-1])
    return grid * patch


def create_collage(
    image_path: str,
    pred_bin: np.ndarray,
    case_id: str,
    score: float,
    out_dir: Path,
    gt_label: str,
):
    with open(image_path, "rb") as f:
        orig_img = Image.open(f).convert("RGB")
    orig_np = np.array(orig_img)
    vis_w, vis_h = orig_img.size  # (W,H)

    if pred_bin.shape[:2] != (vis_h, vis_w):
        pred_bin_vis = cv2.resize(pred_bin, (vis_w, vis_h), interpolation=cv2.INTER_NEAREST)
    else:
        pred_bin_vis = pred_bin

    gt_mask = load_gt_mask_from_image_path(image_path = image_path, w = vis_w, h = vis_h)

    if gt_mask.shape[:2] != (vis_h, vis_w):
        gt_mask_vis = cv2.resize(gt_mask, (vis_w, vis_h), interpolation=cv2.INTER_NEAREST)
    else:
        gt_mask_vis = gt_mask

    pred_mask_rgb = (pred_bin_vis[..., None] * 255).repeat(3, axis=2).astype(np.uint8)
    gt_mask_rgb = (gt_mask_vis[..., None] * 255).repeat(3, axis=2).astype(np.uint8)

    pred_overlay = apply_overlay(orig_np, pred_bin_vis, color=(255, 0, 0))
    gt_overlay = apply_overlay(orig_np, gt_mask_vis, color=(0, 255, 0))

    fig, axes = plt.subplots(1, 5, figsize=(18, 4))
    titles = [f"Input_{case_id}", "Pred mask", "GT mask", "Pred overlay", "GT overlay"]
    imgs = [orig_np, pred_mask_rgb, gt_mask_rgb, pred_overlay, gt_overlay]
    for ax, img, title in zip(axes, imgs, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()

    score_str = f"{score:.3f}".replace(".", "pt")
    save_path = out_dir / f"{gt_label}_{case_id}_{score_str}.png"
    plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)


def load_gt_mask_from_image_path(image_path: str, w: int, h: int) -> np.ndarray:
    mask_path = (
        image_path
        .replace("train_images", "train_masks")
        .replace("supplemental_images", "supplemental_masks")
        .replace("/forged", "")
        .replace(".png", ".npy")
    )
    if not Path(mask_path).exists():
        return np.zeros(shape = (h, w), dtype = np.uint8)
    
    arr = np.load(mask_path, allow_pickle=True)
    if arr.ndim == 3:
        arr = arr.max(axis=0)
    gt_mask = (arr > 0).astype(np.uint8)
    if gt_mask.shape[:2] != (h, w):
        gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    return gt_mask


def save_hist(values, title, out_path: Path, bins=50):
    plt.figure(figsize=(8, 4))
    plt.hist(values, bins=bins)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close()

# Inference

#### Params

In [13]:
class args:
    csv_path = "../analysis/area_splits/val_fold0.csv"
    weights_path = "../exps/20260110_131204/vit_base_patch14_dinov2.lvd142m_fold0.pt"
    arch: str = "dinov2_uperhead"
    dinov2_id: str = "facebook/dinov2-base"
    model_name: str = "vit_base_patch14_dinov2.lvd142m"
    img_size: int = 532
    use_hf_processor: bool = True
    outdir: Path = "../analysis/preds_nposw"
    save_collages: bool = True

In [14]:
# adding suffix
outdir = args.outdir + f"_{args.arch}"
outdir = Path(outdir)
if outdir.exists():
    shutil.rmtree(outdir)
outdir.mkdir(parents=True, exist_ok=True)

df = pd.read_csv(args.csv_path)

hf_processor = None
if args.use_hf_processor:
    from transformers import AutoImageProcessor
    hf_processor = AutoImageProcessor.from_pretrained(args.dinov2_id)

weights_path = Path(args.weights_path)
if not weights_path.exists():
    raise FileNotFoundError(f"No weights file at: {weights_path}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
save_collages = args.save_collages

#### Collecting model predictions

In [10]:
state = torch.load(weights_path, map_location=device)
model_state = state["model_state"]
img_size = args.img_size

if args.arch == "dino_seg":
    inferred_img_size = infer_img_size_from_state(model_state, img_size)
    if inferred_img_size != img_size:
        print(f"Adjusting img_size from {img_size} to {inferred_img_size} based on checkpoint pos_embed.")
        img_size = inferred_img_size
else:
    if img_size % 14 != 0:
        raise ValueError(f"{args.arch} requires img_size divisible by 14 (Dinov2 patch size).")

# -------------------------
# Dataset / loader
# -------------------------
dataset = InferenceDataset(df, img_size=img_size, processor=hf_processor)
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_fn,
)

if args.arch == "dino_seg":
    from dinov2_seg import DinoSegModel
    model = DinoSegModel(model_name=args.model_name, pretrained=False, img_size=img_size).to(device)
elif args.arch == "dinov2_uperhead":
    from dinov2_uperhead import DinoV2_UPerNet
    model = DinoV2_UPerNet(dinov2_id=args.dinov2_id, num_classes=1).to(device)
elif args.arch == "dinov2_unet":
    from dinov2_unet import DinoV2UNet
    model = DinoV2UNet(dinov2_id=args.dinov2_id, out_classes=1).to(device)
else:
    raise ValueError(f"Unknown arch: {args.arch}")

model.load_state_dict(model_state)
model.eval()
print("[INFO] Loaded model")

# -------------------------
# Run inference once and cache per-image stats (for grid + final submission)
# -------------------------
predictions = []
solution_rows = []

with torch.no_grad():
    for images, case_ids, orig_sizes, mask_paths, image_paths, labels in tqdm(loader, desc="Predicting masks"):
        images = images.to(device, non_blocking=True)
        out = model(images)
        logits = out if torch.is_tensor(out) else out[0]
        probs = torch.sigmoid(logits).float().cpu().numpy()  # (B,1,H,W)

        for i, case_id in enumerate(case_ids):
            case_id_val = str(normalize_case_id(case_id))
            orig_w, orig_h = orig_sizes[i]
            prob = probs[i, 0]  # (img_size,img_size)

            mask_full, thr = finalize_mask(prob, (orig_w, orig_h))  # mask in orig size (H,W)

            gt_mask = load_gt_mask_from_image_path(image_paths[i], orig_w, orig_h)
            gt_label = "authentic" if int(labels[i]) == 0 else "forged"
            shape_str = json.dumps([int(orig_h), int(orig_w)])  # (H,W) for scorer

            predictions.append(dict(
                case_id=case_id_val,
                image_path=image_paths[i],
                orig_width=int(orig_w),
                orig_height=int(orig_h),
                prob=prob,
                pred_mask=mask_full.astype(np.uint8),
                gt_mask=gt_mask.astype(np.uint8),
                gt_label=gt_label,
                thr=float(thr),
                shape_str=shape_str,
            ))

            # scorer needs 'annotation' and 'shape' columns in solution
            solution_rows.append(dict(
                case_id=case_id_val,
                annotation="authentic" if gt_label == "authentic" else rle_encode([gt_mask.astype(np.uint8)]),
                shape=shape_str,
            ))

solution_df = pd.DataFrame(solution_rows)

[INFO] Loaded model


Predicting masks: 100%|██████████| 513/513 [00:36<00:00, 13.87it/s]


#### Grid Search

In [12]:
# -------------------------
# FAST Grid Search (precompute CC variants once)
# Replace your whole slow triple-loop block with this.
# Assumes you already have:
#   - predictions: list[dict] with keys: case_id, pred_mask (H,W), prob (S,S)
#   - solution_df
#   - competition_score(...)
#   - rle_encode(...)
#   - cv2 imported
# -------------------------

area_range = [i * 100 for i in range(1, 11)]
mean_range = [round(x, 2) for x in np.arange(0.20, 0.9, 0.05)]
cc_range = [0, 50, 100, 200, 400, 800]

best_score = -1.0
best_params = {"area_thres": None, "mean_thres": None, "min_cc_area": None}


def remove_small_components(mask01: np.ndarray, min_area: int) -> np.ndarray:
    """
    mask01: uint8/bool (H,W) with values {0,1}
    min_area: remove connected components smaller than this many pixels
    """
    if min_area is None or min_area <= 0:
        return (mask01 > 0).astype(np.uint8)

    m = (mask01 > 0).astype(np.uint8)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)

    out = np.zeros_like(m, dtype=np.uint8)
    for lab in range(1, num_labels):
        if stats[lab, cv2.CC_STAT_AREA] >= min_area:
            out[labels == lab] = 1
    return out


# -------------------------
# 1) Precompute per-image stats/RLE for each CC threshold
# -------------------------
N = len(predictions)
K = len(cc_range)

area_cc = np.zeros((N, K), dtype=np.int32)
mean_cc = np.zeros((N, K), dtype=np.float32)
rle_cc = [[None] * K for _ in range(N)]  # precomputed string per (image, cc)

for i, pred in enumerate(tqdm(predictions, desc="Precomputing CC variants", leave=True)):
    prob = pred["prob"]  # (S,S) model-res prob map
    base = (pred["pred_mask"] > 0).astype(np.uint8)  # (H,W) bin mask in orig size

    # We'll reuse the same resize target each time
    prob_w, prob_h = prob.shape[1], prob.shape[0]

    for k, min_cc in enumerate(cc_range):
        m = remove_small_components(base, int(min_cc))
        a = int(m.sum())
        area_cc[i, k] = a

        if a > 0:
            m_small = cv2.resize(m, (prob_w, prob_h), interpolation=cv2.INTER_NEAREST)
            mean_cc[i, k] = float(prob[m_small == 1].mean()) if (m_small == 1).any() else 0.0
            rle_cc[i][k] = rle_encode([m])  # NOTE: RLE of CC-filtered mask
        else:
            mean_cc[i, k] = 0.0
            rle_cc[i][k] = "authentic"


# -------------------------
# 2) Grid search without cv2 / RLE inside the hot loop
# -------------------------
case_ids = [p["case_id"] for p in predictions]

for k, min_cc in enumerate(cc_range):
    combos_2d = list(itertools.product(area_range, mean_range))
    pbar = tqdm(combos_2d, desc=f"Grid Search (CC={min_cc})", leave=True)

    for a_thr, m_thr in pbar:
        is_forged = (area_cc[:, k] >= a_thr) & (mean_cc[:, k] >= m_thr)

        submission_rows = [
            {
                "case_id": case_ids[i],
                "annotation": (rle_cc[i][k] if is_forged[i] else "authentic"),
            }
            for i in range(N)
        ]

        submission_df = pd.DataFrame(submission_rows)
        curr_score = competition_score(
            solution=solution_df,
            submission=submission_df,
            row_id_column_name="case_id",
        )

        if curr_score > best_score:
            best_score = float(curr_score)
            best_params = {
                "area_thres": int(a_thr),
                "mean_thres": float(m_thr),
                "min_cc_area": int(min_cc),
            }
            pbar.set_postfix(best=f"{best_score:.6f}", A=a_thr, M=m_thr, CC=min_cc)

print("BEST:", best_score, best_params)

Precomputing CC variants:   0%|          | 0/1026 [00:00<?, ?it/s]

Precomputing CC variants: 100%|██████████| 1026/1026 [00:21<00:00, 47.86it/s]
Grid Search (CC=0): 100%|██████████| 140/140 [02:40<00:00,  1.15s/it, A=900, CC=0, M=0.75, best=0.523972]
Grid Search (CC=50): 100%|██████████| 140/140 [02:41<00:00,  1.15s/it, A=100, CC=50, M=0.55, best=0.525308]
Grid Search (CC=100): 100%|██████████| 140/140 [02:42<00:00,  1.16s/it, A=100, CC=100, M=0.55, best=0.530177]
Grid Search (CC=200): 100%|██████████| 140/140 [02:41<00:00,  1.15s/it]
Grid Search (CC=400): 100%|██████████| 140/140 [02:39<00:00,  1.14s/it]
Grid Search (CC=800): 100%|██████████| 140/140 [02:40<00:00,  1.15s/it, A=100, CC=800, M=0.75, best=0.532601]

BEST: 0.5326009045463622 {'area_thres': 100, 'mean_thres': 0.75, 'min_cc_area': 800}





#### Submission run

In [15]:
# -------- final submission + optional collages
outdir.mkdir(parents=True, exist_ok=True)
analysis_dir = outdir / "analysis"
analysis_dir.mkdir(parents=True, exist_ok=True)
collage_dir = outdir / "collage"
if save_collages:
    collage_dir.mkdir(parents=True, exist_ok=True)

# find cc index
best_cc = int(best_params["min_cc_area"])
best_k = cc_range.index(best_cc)

final_rows = []
# For hists (use best CC filtered values)
best_area_vals = area_cc[:, best_k].tolist()
best_mean_vals = mean_cc[:, best_k].tolist()
thr_vals = [p["thr"] for p in predictions]

for i, p in enumerate(tqdm(predictions, desc="Final submission", leave=True)):
    a_thr = int(best_params["area_thres"])
    m_thr = float(best_params["mean_thres"])

    is_forged = (area_cc[i, best_k] >= a_thr) and (mean_cc[i, best_k] >= m_thr)
    annotation = (rle_cc[i][best_k] if is_forged else "authentic")
    final_rows.append({"case_id": p["case_id"], "annotation": annotation})

    if save_collages:
        # compute per-image score (only for naming)
        label_rles = solution_df.loc[solution_df["case_id"] == p["case_id"], "annotation"].iloc[0]
        pred_rles = annotation

        if (label_rles == "authentic") or (pred_rles == "authentic"):
            image_score = 1.0 if (label_rles == pred_rles) else 0.0
        else:
            image_score = float(evaluate_single_image(
                label_rles=label_rles,
                prediction_rles=pred_rles,
                shape_str=p["shape_str"],
            ))

        # for visualization, decode the chosen pred mask
        if annotation == "authentic":
            final_mask = np.zeros_like(p["gt_mask"], dtype=np.uint8)
        else:
            # We already have CC-filtered binary mask inside RLE, but not stored as array.
            # Recreate it quickly: apply CC filter again (only for collage, not for scoring)
            base = (p["pred_mask"] > 0).astype(np.uint8)
            final_mask = remove_small_components(base, best_cc)

        create_collage(
            image_path=p["image_path"],
            pred_bin=final_mask,
            case_id=p["case_id"],
            score=image_score,
            out_dir=collage_dir,
            gt_label=p["gt_label"],
        )

final_submission_df = pd.DataFrame(final_rows)
submission_path = outdir / "submission.csv"
final_submission_df.to_csv(submission_path, index=False)
print(f"[INFO] Saved: {submission_path}")

# -------- distributions
save_hist(best_area_vals, f"Mask Area (CC>={best_cc})", analysis_dir / "mask_area_hist.png")
save_hist(best_mean_vals, f"Mask Mean-Inside (CC>={best_cc})", analysis_dir / "mask_mean_inside_hist.png")
save_hist(thr_vals, "Finalize Threshold (thr)", analysis_dir / "thr_hist.png")

# -------- metadata json
run_meta = {
    "best_score": best_score,
    "best_params": best_params,
    "per_case": [
        {
            "case_id": p["case_id"],
            "image_path": p["image_path"],
            "orig_width": p["orig_width"],
            "orig_height": p["orig_height"],
            "gt_label": p["gt_label"],
            "thr": p["thr"],
            "area_cc_best": int(area_cc[i, best_k]),
            "mean_cc_best": float(mean_cc[i, best_k]),
        }
        for i, p in enumerate(predictions)
    ],
}
meta_path = analysis_dir / "run_metadata.json"
meta_path.write_text(json.dumps(run_meta, indent=2))
print(f"[INFO] Saved: {meta_path}")

Final submission: 100%|██████████| 1026/1026 [09:00<00:00,  1.90it/s] 


[INFO] Saved: ../analysis/preds_nposw_dinov2_uperhead/submission.csv
[INFO] Saved: ../analysis/preds_nposw_dinov2_uperhead/analysis/run_metadata.json
