In [None]:
from pathlib import Path
import re
from collections import defaultdict
import csv
import sys
import math
from typing import List, Tuple, Optional

# ------------------- Configuration -------------------
# Adjust these locations if the dataset is mounted somewhere else.
ROOT = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
DIR_TRAIN_AUTH = Path(ROOT) / "train_images" / "authentic"
DIR_TRAIN_FORG = Path(ROOT) / "train_images" / "forged"
DIR_TEST = Path(ROOT) / "test_images"  # unused here; kept for compatibility
OUTPUT_CSV = Path("/kaggle/working") / "pairs.csv"

# Fallback matching for forged images that found no partner by file name.
USE_PHASH = True        # try perceptual-hash matching
PHASH_THRESHOLD = 10    # max accepted Hamming distance (smaller = more similar)

USE_SSIM = False        # optional, slower: SSIM for images still unmatched after pHash
SSIM_THRESHOLD = 0.70   # min accepted SSIM (larger = more similar)
SSIM_MAX_SIDE = 512     # downscale images to this longest side before SSIM

# Cap on the number of forged images processed (None = all); handy for quick runs.
DEBUG_LIMIT_FORGED = None   # e.g. 200

# ------------------- Dependencies -------------------
# Import imagehash + Pillow (needed for pHash) and, only when USE_SSIM is True,
# OpenCV + scikit-image. If an import fails, the package is installed with pip
# and the import is retried; if everything is already installed nothing happens.
# NOTE: the `!pip ...` lines are IPython/Jupyter shell magics, not plain Python,
# so this cell only runs inside a notebook kernel; pip needs access to a package
# index for the install to succeed.
# NOTE: cv2 is imported only when USE_SSIM is True; load_gray_resized_cv2 and
# best_match_by_ssim depend on it.
try:
    import imagehash  # type: ignore
    from PIL import Image  # type: ignore
except Exception:
    !pip -q install imagehash
    import imagehash
    from PIL import Image

if USE_SSIM:
    try:
        import cv2  # type: ignore
        from skimage.metrics import structural_similarity as ssim  # type: ignore
    except Exception:
        !pip -q install scikit-image opencv-python-headless
        import cv2
        from skimage.metrics import structural_similarity as ssim

# ------------------- Utilities -------------------
IMG_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}

def list_images(root: Path) -> List[Path]:
    """Return every image file under *root* (searched recursively), sorted by path.

    A path counts as an image when it is a regular file whose lower-cased suffix
    is in ``IMG_EXTS``. Directories are skipped even if their name ends with an
    image extension (e.g. ``scans.png/``), because opening them later would fail.
    *root* may be a ``Path`` or a ``str``.
    """
    return sorted(
        p for p in Path(root).rglob("*")
        if p.suffix.lower() in IMG_EXTS and p.is_file()
    )

def normalize_name(p: Path) -> str:
    """Return the key used to pair authentic and forged files by name.

    The stem is lower-cased and ``-``/spaces become ``_``. Then one trailing
    "authentic"-style tag and one trailing "forged"-style tag are removed. Tags
    are matched without requiring a separator, so e.g. ``"credit"`` also loses
    ``"edit"``. Finally, runs of underscores are collapsed to a single ``_`` and
    trailing underscores are stripped.

    Example: ``"Img-01_forged.png"`` -> ``"img_01"``.
    """
    s = p.stem.lower()
    s = s.replace("-", "_").replace(" ", "_")
    # Remove common ending tags
    s = re.sub(r"(authentic|auth|original|orig|clean|real|gt)$", "", s)
    s = re.sub(r"(forg(ed)?|fake|tampered|edit(ed)?|manipulated?)$", "", s)
    # Collapse "__" runs to one "_" instead of deleting them; deleting merged
    # distinct names such as "a__b" and "ab" into the same key.
    s = re.sub(r"_{2,}", "_", s)
    return s.rstrip("_")

def safe_open_image(path: Path) -> Optional[Image.Image]:
    """Open *path* as an RGB PIL image, or return ``None`` if it cannot be read.

    The file is opened in a ``with`` block, so its handle is released even for
    formats (e.g. multi-frame TIFF) where Pillow keeps the file open after
    loading. ``convert`` returns a fully loaded copy that does not depend on
    the file.
    """
    try:
        with Image.open(path) as im:
            return im.convert("RGB")
    except Exception:
        # Best effort: unreadable or corrupt files are reported as "no image".
        return None

def phash_value(path: Path):
    """Return ``imagehash.phash`` of the image at *path*.

    Returns ``None`` when the file cannot be opened or hashed.
    """
    image = safe_open_image(path)
    if image is not None:
        try:
            return imagehash.phash(image)
        except Exception:
            pass
    return None

def hamming_distance(h1, h2) -> int:
    """Return the non-negative distance ``|h1 - h2|`` between two hashes.

    For ``imagehash.ImageHash`` objects the difference is the number of
    differing bits; the sign is dropped so plain integers work too.
    """
    delta = h1 - h2
    return -delta if delta < 0 else delta

def load_gray_resized_cv2(path: Path, max_side=512):
    """Load *path* as a single-channel grayscale array with max(h, w) <= max_side.

    Images already within the limit are returned unchanged; images are never
    upscaled. Returns ``None`` when OpenCV cannot read the file.
    """
    img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if img is None:
        return None
    h, w = img.shape[:2]
    scale = min(1.0, max_side / max(h, w))
    if scale < 1.0:
        # Clamp each side to >= 1 px. Very elongated images (e.g. 4000x3) would
        # otherwise get a 0-sized side, and cv2.resize would raise.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return img

def best_match_by_phash(gf: Path, auth_hashes: List[Tuple[Path, object]]) -> Tuple[Optional[Path], Optional[int]]:
    """Find the authentic image whose pHash is closest to forged image *gf*.

    Args:
        gf: path of the forged image.
        auth_hashes: ``(authentic_path, hash_or_None)`` pairs; ``None`` hashes are skipped.

    Returns:
        ``(best_path, distance)``, where ties keep the earliest entry. Returns
        ``(None, None)`` when *gf* cannot be hashed or no authentic hash is usable.
    """
    query = phash_value(gf)
    if query is None:
        return (None, None)
    scored = [(ap, hamming_distance(query, ah)) for ap, ah in auth_hashes if ah is not None]
    if not scored:
        return (None, None)
    # min() returns the first of several equal minima, like a strict "<" scan.
    return min(scored, key=lambda item: item[1])

def best_match_by_ssim(gf: Path, auth_imgs: List[Tuple[Path, 'np.ndarray']], max_side=512) -> Tuple[Optional[Path], Optional[float]]:
    """Find the authentic image most structurally similar (SSIM) to forged image *gf*.

    *gf* is loaded with ``load_gray_resized_cv2(gf, max_side)``. Each preloaded
    authentic array in *auth_imgs* (``None`` entries skipped) is stretched to the
    forged image's size when shapes differ, then scored with ``ssim``.

    Returns ``(best_path, best_score)``, where a higher score means more similar
    and ties keep the earliest entry. Returns ``(None, None)`` if *gf* is
    unreadable or no candidate could be scored.
    """
    gi = load_gray_resized_cv2(gf, max_side=max_side)
    if gi is None:
        return (None, None)
    gh, gw = gi.shape[:2]
    best_p, best_sc = None, None
    for ap, ai in auth_imgs:
        if ai is None:
            continue
        # Simple alignment: stretch the authentic image to the forged image's size.
        if ai.shape[:2] == (gh, gw):
            ai_r = ai
        else:
            ai_r = cv2.resize(ai, (gw, gh), interpolation=cv2.INTER_AREA)
        try:
            sc = ssim(gi, ai_r)
        except Exception:
            # e.g. an image smaller than SSIM's window: skip this candidate.
            continue
        if best_sc is None or sc > best_sc:
            best_p, best_sc = ap, sc
    return (best_p, best_sc)

# ------------------- Main Flow -------------------
def _pair_by_name(auth_files: List[Path], forg_files: List[Path]) -> Tuple[List[Tuple[Path, Path, str, float]], List[Path]]:
    """Pair forged files with authentic files that share the same ``normalize_name`` key.

    Returns ``(pairs, unmatched_forged)``; each pair is ``(auth, forged, "name", 0.0)``.
    When several authentic files share a key, the first one (in sorted order) is used.
    """
    auth_map = defaultdict(list)
    for p in auth_files:
        auth_map[normalize_name(p)].append(p)

    pairs: List[Tuple[Path, Path, str, float]] = []
    unmatched: List[Path] = []
    for gf in forg_files:
        candidates = auth_map.get(normalize_name(gf))
        if candidates:
            pairs.append((candidates[0], gf, "name", 0.0))
        else:
            unmatched.append(gf)
    return pairs, unmatched


def _pair_by_phash(auth_files: List[Path], forg_files: List[Path]) -> Tuple[List[Tuple[Path, Path, str, float]], List[Path]]:
    """pHash fallback: pair each forged file with its nearest authentic hash.

    A pair is accepted when the distance is <= PHASH_THRESHOLD. Returns
    ``(pairs, unmatched_forged)``; the pair score is the distance (lower is better).
    """
    print("Computing pHash for authentic images...")
    # phash_value already returns None on any failure, so no extra try/except is needed.
    auth_hashes = [(ap, phash_value(ap)) for ap in auth_files]

    print("Matching unmatched forged by pHash...")
    pairs: List[Tuple[Path, Path, str, float]] = []
    unmatched: List[Path] = []
    for gf in forg_files:
        ap, dist = best_match_by_phash(gf, auth_hashes)
        if ap is not None and dist is not None and dist <= PHASH_THRESHOLD:
            pairs.append((ap, gf, "phash", float(dist)))
        else:
            unmatched.append(gf)
    return pairs, unmatched


def _pair_by_ssim(auth_files: List[Path], forg_files: List[Path]) -> Tuple[List[Tuple[Path, Path, str, float]], List[Path]]:
    """SSIM fallback: pair each forged file with its most similar authentic image.

    A pair is accepted when the score is >= SSIM_THRESHOLD. Returns
    ``(pairs, unmatched_forged)``; the pair score is the SSIM (higher is better).
    """
    print("Preloading grayscale resized authentic images for SSIM...")
    auth_imgs = []
    for ap in auth_files:
        try:
            ai = load_gray_resized_cv2(ap, max_side=SSIM_MAX_SIDE)
        except Exception:
            ai = None
        auth_imgs.append((ap, ai))

    print("Matching unmatched forged by SSIM...")
    pairs: List[Tuple[Path, Path, str, float]] = []
    unmatched: List[Path] = []
    for gf in forg_files:
        try:
            ap, sc = best_match_by_ssim(gf, auth_imgs, max_side=SSIM_MAX_SIDE)
        except Exception:
            ap, sc = (None, None)
        if ap is not None and sc is not None and sc >= SSIM_THRESHOLD:
            pairs.append((ap, gf, "ssim", float(sc)))
        else:
            unmatched.append(gf)
    return pairs, unmatched


def _write_pairs_csv(pairs: List[Tuple[Path, Path, str, float]], out_path: Path) -> None:
    """Write *pairs* to *out_path* as CSV with header auth_path,forg_path,method,score."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["auth_path", "forg_path", "method", "score"])
        writer.writerows([str(ap), str(gf), m, s] for ap, gf, m, s in pairs)


def main():
    """Pair every forged training image with its authentic source and save to OUTPUT_CSV.

    Matching runs in stages, and each stage only sees what is still unmatched:
    file name -> pHash (if USE_PHASH) -> SSIM (if USE_SSIM).
    NOTE: the "score" column mixes semantics. It is 0.0 for name matches, a
    Hamming distance (lower is better) for pHash, and an SSIM (higher is
    better) for SSIM.
    """
    # 1) List image files
    auth_files = list_images(DIR_TRAIN_AUTH)
    forg_files = list_images(DIR_TRAIN_FORG)
    if DEBUG_LIMIT_FORGED is not None:
        forg_files = forg_files[:DEBUG_LIMIT_FORGED]
    print(f"#auth = {len(auth_files)}, #forg = {len(forg_files)}")

    # 2) Pairing based on file names
    pairs, still_unmatched = _pair_by_name(auth_files, forg_files)
    print(f"Name-matched pairs: {len(pairs)}")
    print(f"Unmatched forged by name: {len(still_unmatched)}")

    # 3) pHash matching for unmatched images
    if USE_PHASH and still_unmatched:
        new_pairs, still_unmatched = _pair_by_phash(auth_files, still_unmatched)
        pairs.extend(new_pairs)
        print(f"pHash new pairs (<= {PHASH_THRESHOLD}): {len(new_pairs)}")
        print(f"Remaining unmatched after pHash: {len(still_unmatched)}")

    # 4) Optional SSIM matching
    if USE_SSIM and still_unmatched:
        new_pairs, still_unmatched = _pair_by_ssim(auth_files, still_unmatched)
        pairs.extend(new_pairs)
        print(f"SSIM new pairs (>= {SSIM_THRESHOLD}): {len(new_pairs)}")
        print(f"Remaining unmatched after SSIM: {len(still_unmatched)}")

    # 5) Save output CSV
    _write_pairs_csv(pairs, OUTPUT_CSV)
    print(f"\nSaved pairs: {len(pairs)} -> {OUTPUT_CSV}")
    if still_unmatched:
        print("Sample unmatched forged (up to 10):")
        for g in still_unmatched[:10]:
            print(" -", g)

if __name__ == "__main__":
    main()

In [None]:
###############################################
# ReCoDAI | Paired EfficientNet-B0 — OFFICIAL METRIC + Improvements + Retrieval
# - Official metric only (score/evaluate_single_image/oF1_score)
# - IMG_SIZE=256, preserve masks (dilate before resize)
# - Seg + Cls; background loss for empty masks
# - Edge channel (Sobel) + amplified diff
# - TTA (horizontal flip) for seg & cls
# - Morphological postproc + grid search on seg thresholds (validation)
# - Retrieval (name → pHash → SSIM) to build realistic diff at TEST (and optional VALIDATION)
# - Submission: case_id,annotation | "authentic" or JSON-RLE via rle_encode
###############################################

import os, gc, json, random, re
from pathlib import Path
from typing import List, Tuple, Optional
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import cv2
import scipy.optimize

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score

# -------------------------
# Repro & Device
# -------------------------
def seed_everything(s=42):
    """Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs with *s*."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)


seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# -------------------------
# Paths
# -------------------------
ROOT = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
AUTH_DIR, FORG_DIR = f"{ROOT}/train_images/authentic", f"{ROOT}/train_images/forged"
MASK_DIR = f"{ROOT}/train_masks"
TEST_DIR = f"{ROOT}/test_images"
SAMPLE_SUB = f"{ROOT}/sample_submission.csv"
PAIRS_CSV = "/kaggle/working/pairs.csv"  # same file the pairing cell writes (OUTPUT_CSV)

# -------------------------
# Hyperparams
# -------------------------
IMG_SIZE = 256          # square side images and masks are resized to
BATCH_FORGED = 2        # forged/authentic pairs per step (seg + cls)
BATCH_AUTH_CLS = 8      # authentic-only images per step (cls negatives)
EPOCHS = 8
LR = 1e-4
WEIGHT_DECAY = 1e-4
ALPHA_CLS = 0.5         # weight of the classification loss terms
NEG_BG_WEIGHT = 0.25    # background loss weight when mask is empty

# Post-processing (initial values; tuned later on validation)
THRESH_SEG = 0.20
MIN_INSTANCE_AREA = 4
TOPK_INSTANCES = 5

# Detection (classification) threshold; overwritten by the value tuned on validation
BEST_CLS_THR = 0.5

# Retrieval flags: look up the authentic counterpart to build a realistic diff
USE_RETR_IN_TEST = True
# If set to True, the validation diff is also built via retrieval
# (the Dataset must then return the file paths).
USE_RETR_IN_VAL = False
USE_RETR_PHASH = True
PHASH_THRESHOLD = 10
USE_RETR_SSIM = False   # set to True if needed (slower)
SSIM_THRESHOLD = 0.70
SSIM_MAX_SIDE = 512

# -------------------------
# Utilities
# -------------------------
def list_images(root: str) -> List[str]:
    """Return the paths (as ``str``) of all image files under *root*, searched recursively, sorted.

    Only regular files with a known image extension (case-insensitive) are
    returned. Directories named like ``x.png`` are skipped, so ``Image.open`` is
    never handed a directory.
    """
    img_exts = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")
    return sorted(
        str(p) for p in Path(root).rglob("*")
        if p.suffix.lower() in img_exts and p.is_file()
    )

# =========================================================
# OFFICIAL METRIC (EXACTLY as provided)
# =========================================================
import numba
import numpy.typing as npt

class ParticipantVisibleError(Exception):
    """Error raised by the official metric for malformed submissions.

    rle_decode raises it for unsorted run starts, odd-length or overlapping runs.
    Presumably its message is meant to be shown to the submitter (per the name).
    """
    pass

# Run-length encode the pixels of mask `x` that equal `fg_val` (official metric).
# Pixels are scanned in column-major order (x.T.flatten()). The result is a flat
# list [start_1, length_1, start_2, length_2, ...] with 1-based starts.
@numba.jit(nopython=True)
def _rle_encode_jit(x: npt.NDArray, fg_val: int = 1) -> list[int]:
    dots = np.where(x.T.flatten() == fg_val)[0]
    run_lengths = []
    # Sentinel: guarantees the first foreground pixel opens a new run.
    prev = -2
    for b in dots:
        if b > prev + 1:
            # Gap since the previous foreground pixel: open a new run (1-based start, length 0).
            run_lengths.extend((b + 1, 0))
        # Extend the current run by one pixel.
        run_lengths[-1] += 1
        prev = b
    return run_lengths

def rle_encode(masks: list[npt.NDArray], fg_val: int = 1) -> str:
    """Encode instance masks as ';'-joined JSON run-length lists (official metric format).

    Each mask is encoded with _rle_encode_jit and dumped with json.dumps, giving
    e.g. ``"[s1, l1, ...];[s1, l1, ...]"``. An empty *masks* list yields ``""``.
    """
    return ';'.join([json.dumps(_rle_encode_jit(x, fg_val)) for x in masks])

# Decode a flat [start, length, start, length, ...] RLE (1-based starts) into a
# flat boolean array of height*width pixels; callers reshape it (rle_decode uses
# Fortran order). Raises ValueError for an odd number of values, or when a run
# overlaps the next one. Only adjacent runs are compared, so this relies on the
# starts being sorted, which rle_decode checks first.
# NOTE(review): under NumPy semantics `mask_rle[0::2]` is a view, so
# `starts -= 1` likely also shifts the caller's array in place.
@numba.njit
def _rle_decode_jit(mask_rle: npt.NDArray, height: int, width: int) -> npt.NDArray:
    if len(mask_rle) % 2 != 0:
        raise ValueError('One or more rows has an odd number of values.')
    starts, lengths = mask_rle[0::2], mask_rle[1::2]
    # Convert to 0-based starts; each run covers [start, end).
    starts -= 1
    ends = starts + lengths
    for i in range(len(starts) - 1):
        if ends[i] > starts[i + 1]:
            raise ValueError('Pixels must not be overlapping.')
    img = np.zeros(height * width, dtype=np.bool_)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img

def rle_decode(mask_rle: str, shape: tuple[int, int]) -> npt.NDArray:
    """Decode one JSON run-length list into a boolean mask of *shape* (official metric).

    The flat pixel array is reshaped in Fortran (column-major) order, matching
    the column-major scan in _rle_encode_jit.

    Raises:
        ParticipantVisibleError: if run starts are not in ascending order, or if
            _rle_decode_jit rejects the runs (odd length / overlap).
        JSON decoding errors from ``json.loads`` are not converted.
    """
    mask_rle = json.loads(mask_rle)
    mask_rle = np.asarray(mask_rle, dtype=np.int32)
    starts = mask_rle[0::2]
    if sorted(starts) != list(starts):
        raise ParticipantVisibleError('Submitted values must be in ascending order.')
    try:
        return _rle_decode_jit(mask_rle, shape[0], shape[1]).reshape(shape, order='F')
    except ValueError as e:
        raise ParticipantVisibleError(str(e)) from e

def calculate_f1_score(pred_mask: npt.NDArray, gt_mask: npt.NDArray):
    """Pixel-level F1 between two binary masks (pixels equal to 1 are foreground).

    Returns 0 when precision + recall is 0 (including when both masks are empty).
    """
    pred_flat = pred_mask.flatten()
    gt_flat = gt_mask.flatten()
    tp = np.sum((pred_flat == 1) & (gt_flat == 1))
    fp = np.sum((pred_flat == 1) & (gt_flat == 0))
    fn = np.sum((pred_flat == 0) & (gt_flat == 1))
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall    = tp / (tp + fn) if (tp + fn) > 0 else 0
    return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

def calculate_f1_matrix(pred_masks: list[npt.NDArray], gt_masks: list[npt.NDArray]):
    """Pairwise pixel F1: entry [i, j] is F1(pred_masks[i], gt_masks[j]).

    When there are fewer predictions than ground-truth instances, zero rows are
    appended. The matrix then has at least len(gt_masks) rows, so every GT
    instance gets an assignment (unmatched GT scores 0).
    """
    num_instances_pred = len(pred_masks)
    num_instances_gt = len(gt_masks)
    f1_matrix = np.zeros((num_instances_pred, num_instances_gt))
    for i in range(num_instances_pred):
        for j in range(num_instances_gt):
            pred_flat = pred_masks[i].flatten()
            gt_flat = gt_masks[j].flatten()
            f1_matrix[i, j] = calculate_f1_score(pred_flat, gt_flat)
    if f1_matrix.shape[0] < len(gt_masks):
        f1_matrix = np.vstack((f1_matrix, np.zeros((len(gt_masks) - len(f1_matrix), num_instances_gt))))
    return f1_matrix

def oF1_score(pred_masks: list[npt.NDArray], gt_masks: list[npt.NDArray]):
    """Optimal-matching instance F1 (official metric).

    Predicted and GT instances are matched one-to-one to maximise total F1
    (Hungarian assignment on -F1), and the matched F1 values are averaged. The
    mean is multiplied by len(gt) / max(len(pred), len(gt)), which penalises
    excess predictions.
    """
    f1_matrix = calculate_f1_matrix(pred_masks, gt_masks)
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(-f1_matrix)
    excess_predictions_penalty = len(gt_masks) / max(len(pred_masks), len(gt_masks))
    return np.mean(f1_matrix[row_ind, col_ind]) * excess_predictions_penalty

def evaluate_single_image(label_rles: str, prediction_rles: str, shape_str: str) -> float:
    """Score one non-authentic image with oF1_score.

    *label_rles* and *prediction_rles* are ';'-separated JSON RLE lists (one per
    instance). *shape_str* is the JSON-encoded mask shape used for decoding.
    """
    shape = json.loads(shape_str)
    label_rles = [rle_decode(x, shape=shape) for x in label_rles.split(';')]
    prediction_rles = [rle_decode(x, shape=shape) for x in prediction_rles.split(';')]
    return oF1_score(prediction_rles, label_rles)

def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str) -> float:
    """Official competition score: the mean per-image score.

    *solution* needs columns 'annotation' and 'shape'; *submission* needs 'annotation'.
    Per row:
      * if the label or the prediction is "authentic", the image scores 1.0 only
        when both are "authentic" (exact string equality), otherwise 0.0;
      * otherwise the ';'-separated RLE instances are compared with
        evaluate_single_image at the row's JSON 'shape'.
    ``row_id_column_name`` is accepted but not used in this body.

    NOTE(review): despite the "assumes same order" comment, pandas column
    assignment aligns on index labels. Both frames must therefore share the
    same index, not merely the same row order.
    """
    df = solution
    df = df.rename(columns={'annotation': 'label'})
    df['prediction'] = submission['annotation']  # assumes same order
    authentic_indices = (df['label'] == 'authentic') | (df['prediction'] == 'authentic')
    df['image_score'] = ((df['label'] == df['prediction']) & authentic_indices).astype(float)
    df.loc[~authentic_indices, 'image_score'] = df.loc[~authentic_indices].apply(
        lambda row: evaluate_single_image(row['label'], row['prediction'], row['shape']), axis=1
    )
    return float(np.mean(df['image_score']))
# =========================================================

# -------------------------
# Transforms
# -------------------------
# Training pipeline: resize, then light geometric + photometric augmentation.
# NOTE: the random ops draw new parameters on every call, so two images passed
# through train_tf separately are augmented differently.
train_tf = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(5),
        T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.02),
        T.ToTensor(),
    ]
)
# Evaluation pipeline: deterministic resize + conversion to a float tensor.
val_tf = T.Compose([T.Resize((IMG_SIZE, IMG_SIZE)), T.ToTensor()])

# -------------------------
# Datasets
# -------------------------
class PairedForgeryDataset(Dataset):
    """Forged images paired with their authentic source, for seg + cls training (label=1).

    Each item is ``(xf, diff, mask, label)``, or ``(xf, diff, mask, label,
    forg_path, auth_path)`` when ``return_paths=True`` (validation retrieval
    needs the paths). ``xf`` is the forged image tensor and ``diff`` is
    ``|xf - xa|`` with ``xa`` the authentic image. ``mask`` is a binary float
    tensor ``[1, IMG_SIZE, IMG_SIZE]`` and ``label`` is always 1. A missing mask
    file gives an all-zero mask; it is kept, and the training loop applies a
    background loss to such items.

    With ``train=True``, the forged image, the authentic image and the mask get
    the SAME random flip and rotation, and both images get the same color
    jitter. Independent per-image augmentation would misalign the pair, turning
    the diff channel into noise, and would leave the mask out of register with
    the flipped/rotated forged image.
    """

    # Same parameters as train_tf's ColorJitter; only its sampling ranges are used.
    _JITTER = T.ColorJitter(0.05, 0.05, 0.05, 0.02)
    _MAX_ANGLE = 5.0  # degrees, as in train_tf's RandomRotation(5)

    def __init__(self, df_pairs: pd.DataFrame, mask_dir: str, train: bool, return_paths: bool=False):
        self.df = df_pairs.reset_index(drop=True)
        self.mask_dir = mask_dir
        self.train = train
        self.return_paths = return_paths

    def __len__(self): return len(self.df)

    @staticmethod
    def _load_rgb(path):
        """Open *path* as an RGB PIL image and close the file handle."""
        with Image.open(path) as im:
            return im.convert("RGB")

    def _load_mask(self, forg_path, fallback_hw):
        """Load ``<mask_dir>/<forged stem>.npy`` as a binary float tensor ``[1, IMG_SIZE, IMG_SIZE]``.

        A zero mask of shape *fallback_hw* is used when the file does not exist.
        """
        mp = os.path.join(self.mask_dir, Path(forg_path).stem + ".npy")
        if os.path.exists(mp):
            m = np.load(mp)
            if m.ndim == 3:
                # Merge stacked instance masks. Assumes instances lie on axis 0 -- TODO confirm layout.
                m = np.max(m, axis=0)
        else:
            m = np.zeros(fallback_hw, dtype=np.uint8)

        H, W = m.shape
        scale = max(H / IMG_SIZE, W / IMG_SIZE)
        m_bin = (m > 0).astype(np.uint8)
        if scale > 1:
            # Dilate before nearest-neighbour downsampling so small/thin positives survive.
            k = int(np.ceil(scale))
            m_bin = cv2.dilate(m_bin, np.ones((k, k), np.uint8), iterations=1)
        m_res = cv2.resize(m_bin, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
        return torch.from_numpy(m_res.astype(np.float32)).unsqueeze(0)

    @classmethod
    def _joint_augment(cls, xa, xf, mask):
        """Apply one random flip + rotation to xa, xf and mask, and one color jitter to xa and xf."""
        from torchvision.transforms import InterpolationMode
        from torchvision.transforms import functional as TF

        if torch.rand(1).item() < 0.5:
            xa, xf, mask = xa.flip(-1), xf.flip(-1), mask.flip(-1)

        angle = float(torch.empty(1).uniform_(-cls._MAX_ANGLE, cls._MAX_ANGLE).item())
        # NEAREST is RandomRotation's default and keeps the mask binary.
        nearest = InterpolationMode.NEAREST
        xa = TF.rotate(xa, angle, interpolation=nearest)
        xf = TF.rotate(xf, angle, interpolation=nearest)
        mask = TF.rotate(mask, angle, interpolation=nearest)

        j = cls._JITTER
        order, b, c, s, h = T.ColorJitter.get_params(j.brightness, j.contrast, j.saturation, j.hue)
        adjust = {
            0: (TF.adjust_brightness, b),
            1: (TF.adjust_contrast, c),
            2: (TF.adjust_saturation, s),
            3: (TF.adjust_hue, h),
        }
        for fn_id in order.tolist():
            fn, factor = adjust[fn_id]
            if factor is not None:
                xa, xf = fn(xa, factor), fn(xf, factor)
        return xa, xf, mask

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        pa, pf = row['auth_path'], row['forg_path']
        img_a = self._load_rgb(pa)
        img_f = self._load_rgb(pf)

        # Deterministic resize + ToTensor for both images. Augmentation (train
        # only) is applied jointly afterwards so the pair and the mask stay aligned.
        xa, xf = val_tf(img_a), val_tf(img_f)
        mask = self._load_mask(pf, (img_f.height, img_f.width))
        if self.train:
            xa, xf, mask = self._joint_augment(xa, xf, mask)

        # difference channel (paired authentic vs forged)
        diff = torch.abs(xf - xa)

        label = torch.tensor(1, dtype=torch.long)
        if self.return_paths:
            return xf, diff, mask, label, pf, pa
        return xf, diff, mask, label

class AuthenticOnlyClsDataset(Dataset):
    """Authentic images as negatives for classification (label=0).

    Items are ``(image_tensor, 0)``. ``train_tf`` (random augmentation) is used
    when ``train`` is True, ``val_tf`` otherwise.
    """
    def __init__(self, paths: List[str], train: bool):
        self.paths = paths
        self.train = train

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        tf = train_tf if self.train else val_tf
        # `with` closes the file handle deterministically; DataLoader workers open many files.
        with Image.open(self.paths[idx]) as im:
            img = im.convert("RGB")
        return tf(img), torch.tensor(0, dtype=torch.long)

# -------------------------
# Build loaders
# -------------------------
pairs_df = pd.read_csv(PAIRS_CSV)

# Hold out 20% of the AUTHENTIC images (seeded), then route each forged pair to
# the split of its authentic source. With independent splits, an authentic
# image could be a train negative while its forged copy was a val positive (or
# vice versa), and pairs sharing one authentic source could straddle the split.
auth_all = list_images(AUTH_DIR)
auth_val_count = max(1, int(0.2*len(auth_all)))
rng = np.random.default_rng(42)
perm = rng.permutation(len(auth_all))
auth_val_idx = set(perm[:auth_val_count].tolist())
auth_train = [auth_all[i] for i in range(len(auth_all)) if i not in auth_val_idx]
auth_val   = [auth_all[i] for i in range(len(auth_all)) if i in auth_val_idx]

_auth_val_keys = {os.path.normpath(p) for p in auth_val}
_pair_in_val = pairs_df['auth_path'].astype(str).map(os.path.normpath).isin(_auth_val_keys)
if _pair_in_val.any():
    val_df = pairs_df[_pair_in_val]
    train_df = pairs_df[~_pair_in_val]
else:
    # Paths in pairs.csv don't match the AUTH_DIR listing: fall back to a random pair split.
    print("[Split] No pair matches a validation authentic image; using a random 20% pair split.")
    val_df = pairs_df.sample(frac=0.2, random_state=42)
    train_df = pairs_df.drop(val_df.index)

train_forged_ds = PairedForgeryDataset(train_df, MASK_DIR, train=True, return_paths=False)
val_forged_ds   = PairedForgeryDataset(val_df,   MASK_DIR, train=False, return_paths=USE_RETR_IN_VAL)
train_auth_ds   = AuthenticOnlyClsDataset(auth_train, train=True)
val_auth_ds     = AuthenticOnlyClsDataset(auth_val,   train=False)

train_forged_loader = DataLoader(train_forged_ds, batch_size=BATCH_FORGED, shuffle=True,  num_workers=2, pin_memory=True)
val_forged_loader   = DataLoader(val_forged_ds,   batch_size=BATCH_FORGED, shuffle=False, num_workers=2, pin_memory=True)
train_auth_loader   = DataLoader(train_auth_ds,   batch_size=BATCH_AUTH_CLS, shuffle=True,  num_workers=2, pin_memory=True)
val_auth_loader     = DataLoader(val_auth_ds,     batch_size=BATCH_AUTH_CLS, shuffle=False, num_workers=2, pin_memory=True)

print(f"Train forged: {len(train_forged_ds)} | Val forged: {len(val_forged_ds)}")
print(f"Train authentic (cls neg): {len(train_auth_ds)} | Val authentic (cls neg): {len(val_auth_ds)}")

# -------------------------
# Model (EfficientNet-B0) + heads
# -------------------------
class EffNetB0Multi(nn.Module):
    """EfficientNet-B0 trunk shared by a segmentation head and a 2-class head.

    ``forward(x_forg, diff)`` returns ``(logits_seg, logits_cls)``:

    * ``logits_cls`` ``[B, 2]``: a linear layer on the globally average-pooled
      backbone features.
    * ``logits_seg`` ``[B, 1, H, W]`` at the input resolution: a conv head on
      the bilinearly upsampled features, concatenated with ``2 * diff``
      (3 channels) and a Sobel edge map of ``x_forg`` (1 channel).

    The backbone is randomly initialised (``weights=None``).
    """

    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.efficientnet_b0(weights=None)
        # torchvision's classifier is unused; forward() only calls backbone.features.
        self.backbone.classifier = nn.Identity()
        seg_in_channels = 1280 + 3 + 1  # backbone features + diff RGB + edge map
        self.seg_head = nn.Sequential(
            nn.Conv2d(seg_in_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
        )
        self.cls_head = nn.Linear(1280, 2)

    @staticmethod
    def _sobel_edges(img):
        """Return clamp(|Sobel_x| + |Sobel_y|, 0, 1) of the channel-mean image, ``[B, 1, H, W]``."""
        gray = img.mean(dim=1, keepdim=True)
        kernel_x = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], device=img.device, dtype=img.dtype)
        # The transpose is the vertical Sobel kernel [[-1,-2,-1],[0,0,0],[1,2,1]].
        kernel_y = kernel_x.transpose(2, 3).contiguous()
        magnitude = F.conv2d(gray, kernel_x, padding=1).abs() + F.conv2d(gray, kernel_y, padding=1).abs()
        return magnitude.clamp(0, 1)

    def forward(self, x_forg, diff):
        feats = self.backbone.features(x_forg)  # [B, 1280, h, w]
        logits_cls = self.cls_head(F.adaptive_avg_pool2d(feats, 1).flatten(1))

        # The edge map is a fixed input feature; no gradient flows through it.
        with torch.no_grad():
            edges = self._sobel_edges(x_forg)

        upsampled = F.interpolate(feats, size=x_forg.shape[2:], mode="bilinear", align_corners=False)
        seg_in = torch.cat([upsampled, diff * 2.0, edges], dim=1)  # 1280 + 3 + 1 channels
        return self.seg_head(seg_in), logits_cls

# Build the network and move its parameters/buffers to `device` (Module.to is in-place).
model = EffNetB0Multi()
model.to(device)

# -------------------------
# Loss & Optim
# -------------------------
crit_cls = nn.CrossEntropyLoss()  # 2-class forged/authentic head
optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# -------------------------
# TTA helper (flip) for seg & cls
# -------------------------
def seg_with_tta(x, diff, net=None):
    """Horizontal-flip test-time augmentation for both model heads.

    Runs *net* (default: the global ``model``) without gradients on ``(x, diff)``
    and on their width-flipped copies (flip along dim 3). The second pass's
    segmentation probabilities are flipped back and averaged with the first,
    then converted back to logits. Classification logits are averaged directly.

    Args:
        x, diff: 4-D tensors; dim 3 (the width) is flipped.
        net: optional module with the same ``(x, diff) -> (seg_logits, cls_logits)``
            call signature. The default keeps the original behaviour.

    Returns:
        ``(seg_logits, cls_logits)``.

    Note: switches *net* to eval mode and leaves it there.
    """
    net = model if net is None else net
    net.eval()
    with torch.no_grad():
        seg1, cls1 = net(x, diff)
        seg2, cls2 = net(torch.flip(x, dims=[3]), torch.flip(diff, dims=[3]))
    p = (torch.sigmoid(seg1) + torch.flip(torch.sigmoid(seg2), dims=[3])) / 2.0
    # torch.logit clamps p to [eps, 1 - eps] before log(p / (1 - p)), so the
    # clamp is computed once.
    logits = torch.logit(p, eps=1e-6)
    return logits, (cls1 + cls2) / 2.0

# -------------------------
# Post-processing helpers
# -------------------------
def masks_from_probs(prob: np.ndarray, thr: float, min_area: int, use_morph: bool=True) -> List[np.ndarray]:
    """Split a probability map into binary (uint8) instance masks.

    Steps:
    1. Threshold ``prob > thr``.
    2. Optionally apply a 3x3 morphological close then open.
    3. Take 8-connected components and drop those smaller than *min_area* pixels.
    4. If more than the global ``TOPK_INSTANCES`` remain (read at call time),
       keep those with the highest mean probability, best first.
    """
    binary = (prob > thr).astype(np.uint8)
    if use_morph:
        kernel = np.ones((3, 3), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

    n_labels, label_img = cv2.connectedComponents(binary, connectivity=8)
    kept = []  # (mean probability, component mask)
    for lab in range(1, n_labels):  # label 0 is the background
        component = (label_img == lab).astype(np.uint8)
        if int(component.sum()) < min_area:
            continue
        kept.append((float(prob[component == 1].mean()), component))

    if TOPK_INSTANCES is not None and len(kept) > TOPK_INSTANCES:
        ranking = np.argsort([score for score, _ in kept])[::-1][:TOPK_INSTANCES]
        return [kept[i][1] for i in ranking]
    return [component for _, component in kept]

def logits_to_instances_rle(logits: torch.Tensor,
                            thr: float,
                            min_area: int,
                            use_morph: bool=True) -> str:
    """Convert segmentation logits into a ';'-joined JSON-RLE string of instances.

    Only the first element is used: ``[0, 0]`` of a 4-D tensor, or ``[0]`` of a
    3-D tensor. Instances come from ``masks_from_probs``. Returns ``"[]"`` when
    no instance survives post-processing.

    Raises:
        ValueError: if *logits* is neither 3-D nor 4-D.
    """
    if logits.ndim not in (3, 4):
        raise ValueError("logits must be [B,1,H,W] or [1,H,W]")
    probs = torch.sigmoid(logits)
    first = probs[0, 0] if logits.ndim == 4 else probs[0]
    prob_map = first.detach().cpu().numpy()
    instances = masks_from_probs(prob_map, thr=thr, min_area=min_area, use_morph=use_morph)
    if not instances:
        return "[]"
    return rle_encode(instances)

# -------------------------
# TRAIN LOOP (with bg loss for empty masks)
# -------------------------
# Each optimizer step combines two streams:
#   * a batch of forged pairs -> segmentation loss + ALPHA_CLS * cls loss (label 1);
#   * a batch of authentic-only images -> ALPHA_CLS * cls loss (label 0), computed
#     through backbone.features + cls_head only.
# Both losses are summed into one backward pass. An epoch runs for as many steps
# as the longer loader; once the shorter iterator is exhausted, its
# StopIteration is swallowed and that stream contributes nothing.
for epoch in range(1, EPOCHS+1):
    model.train()
    tr_loss = 0.0
    # Labels/predictions of both streams, used for the epoch's training ClsF1.
    cls_true, cls_pred = [], []

    forged_iter = iter(train_forged_loader)
    auth_iter   = iter(train_auth_loader)
    steps = max(len(train_forged_loader), len(train_auth_loader))

    for _ in range(steps):
        # Stays a Python float if neither stream yielded a batch this step.
        total_loss = 0.0

        # forged step (seg+cls)
        # NOTE(review): this try also swallows any StopIteration raised inside the model/loss code.
        try:
            xf, diff, mask, y_f = next(forged_iter)
            xf, diff, mask, y_f = xf.to(device), diff.to(device), mask.to(device), y_f.to(device)
            seg_logits, cls_logits = model(xf, diff)

            probs = torch.sigmoid(seg_logits)
            # Split the batch into items whose mask has >= 1 positive pixel and items with an empty mask.
            has_pos = (mask.view(mask.size(0), -1).sum(dim=1) > 0).float()
            pos_idx = (has_pos == 1)
            neg_idx = (has_pos == 0)

            # Positive items: per-item mean BCE + soft Dice loss, averaged over the positive items.
            seg_loss_pos = torch.tensor(0.0, device=device)
            if pos_idx.any():
                bce_pos = nn.BCEWithLogitsLoss(reduction="none")(seg_logits[pos_idx], mask[pos_idx]) \
                          .view(pos_idx.sum(), -1).mean(dim=1)
                inter = (probs[pos_idx]*mask[pos_idx]).view(pos_idx.sum(), -1).sum(dim=1)
                den   = (probs[pos_idx].view(pos_idx.sum(), -1).sum(dim=1) +
                         mask[pos_idx].view(pos_idx.sum(), -1).sum(dim=1) + 1e-6)
                dice_pos = 1 - (2*inter)/den
                seg_loss_pos = (bce_pos + dice_pos).mean()

            # Empty-mask items: BCE towards an all-background map, down-weighted by NEG_BG_WEIGHT.
            seg_loss_neg = torch.tensor(0.0, device=device)
            if neg_idx.any():
                zero_target = torch.zeros_like(seg_logits[neg_idx])
                bce_neg = nn.BCEWithLogitsLoss(reduction="mean")(seg_logits[neg_idx], zero_target)
                seg_loss_neg = NEG_BG_WEIGHT * bce_neg

            seg_loss = seg_loss_pos + seg_loss_neg
            cls_loss = crit_cls(cls_logits, y_f)
            loss_f = seg_loss + ALPHA_CLS*cls_loss
            total_loss += loss_f

            cls_true.extend(y_f.cpu().tolist())
            cls_pred.extend(cls_logits.argmax(1).detach().cpu().tolist())
        except StopIteration:
            pass

        # authentic-only step (cls)
        try:
            xa, y_a = next(auth_iter)
            xa, y_a = xa.to(device), y_a.to(device)
            # Same path as the cls branch of model.forward, without the seg head.
            feats = model.backbone.features(xa)
            pooled = F.adaptive_avg_pool2d(feats,1).flatten(1)
            cls_logits_a = model.cls_head(pooled)
            loss_a = ALPHA_CLS*crit_cls(cls_logits_a, y_a)
            total_loss += loss_a

            cls_true.extend(y_a.cpu().tolist())
            cls_pred.extend(cls_logits_a.argmax(1).detach().cpu().tolist())
        except StopIteration:
            pass

        # Only update when at least one stream produced a (tensor) loss this step.
        if isinstance(total_loss, torch.Tensor):
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            tr_loss += total_loss.item()

    # Binary F1 over this epoch's training predictions; 0.0 if only one class was seen.
    clsF1 = f1_score(cls_true, cls_pred) if len(set(cls_true))>1 else 0.0
    print(f"[Epoch {epoch}/{EPOCHS}] train_loss={tr_loss/steps:.4f} | ClsF1={clsF1:.3f}")
    torch.cuda.empty_cache(); gc.collect()

# -------------------------
# Validation — tune detection threshold (cls) with TRUE labels
# -------------------------
def cls_probs_on_val_true_labels() -> Tuple[np.ndarray, np.ndarray]:
    """Collect P(forged) on the validation split together with the true labels.

    Forged pairs (label 1) go through the full model on ``(xf, diff)``.
    Authentic-only images (label 0) go through backbone features + cls head only.
    No TTA is used. Returns ``(y_true, y_prob)`` numpy arrays, forged items first.
    """
    model.eval()
    labels, scores = [], []

    def _prob_forged(logit):
        # Softmax probability of class 1 ("forged") as a Python list.
        return torch.softmax(logit, 1)[:, 1].cpu().numpy().tolist()

    with torch.no_grad():
        # forged -> 1
        for batch in val_forged_loader:
            if USE_RETR_IN_VAL:
                xf, diff, _, _, _, _ = batch
            else:
                xf, diff, _, _ = batch
            _, logit = model(xf.to(device), diff.to(device))
            p = _prob_forged(logit)
            scores += p
            labels += [1] * len(p)
        # authentic -> 0
        for xa, _ in val_auth_loader:
            feats = model.backbone.features(xa.to(device))
            logit = model.cls_head(F.adaptive_avg_pool2d(feats, 1).flatten(1))
            p = _prob_forged(logit)
            scores += p
            labels += [0] * len(p)
    return np.array(labels), np.array(scores)

def best_det_threshold_for_f1(y_true, y_prob):
    """Grid-search the P(forged) threshold that maximises binary F1.

    Tries 37 evenly spaced thresholds in [0.05, 0.95] and predicts "forged"
    when ``y_prob >= t``. Returns ``(threshold, f1)`` for the first threshold
    that reaches the best F1, or ``(0.5, 0.0)`` if no threshold achieves a
    positive F1.
    """
    grid = np.linspace(0.05, 0.95, 37)
    results = [(float(t), float(f1_score(y_true, (y_prob >= t).astype(int)))) for t in grid]
    # max() keeps the first of several equal maxima.
    thr, f1 = max(results, key=lambda r: r[1])
    return (thr, f1) if f1 > 0.0 else (0.5, 0.0)

y_true_det, y_prob_det = cls_probs_on_val_true_labels()
BEST_CLS_THR, best_detF1_val = best_det_threshold_for_f1(y_true_det, y_prob_det)
print(f"[Validation] Best detection threshold = {BEST_CLS_THR:.3f} | DetF1={best_detF1_val:.4f}")

# -------------------------
# Retrieval index (for TEST and optional VAL)
# -------------------------
IMG_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}

def _normalize_name(p: Path) -> str:
    s = p.stem.lower().replace("-", "_").replace(" ", "_")
    s = re.sub(r"(authentic|auth|original|orig|clean|real|gt)$", "", s)
    s = re.sub(r"(forg(ed)?|fake|tampered|edit(ed)?|manipulated?)$", "", s)
    s = re.sub(r"(__+|_+$)", "", s)
    return s

# imports for retrieval
try:
    import imagehash
except Exception:
    # Without imagehash, every pHash call below would fail silently (the
    # NameError is caught as "no hash"). Disable pHash retrieval explicitly
    # and report it instead.
    imagehash = None
    USE_RETR_PHASH = False
    print("[Retrieval] imagehash not available -> pHash retrieval disabled.")

from PIL import Image as _PIL_Image

AUTH_LIST = list_images(AUTH_DIR)
# (authentic_path, ImageHash or None) for each authentic image; None entries are never matched.
PHASH_DB = []
if USE_RETR_PHASH and len(AUTH_LIST):
    print("[Retrieval] Building pHash index for authentic...")
    for ap in AUTH_LIST:
        try:
            # `with` releases the file even for multi-frame images that Pillow keeps open.
            with _PIL_Image.open(ap) as im:
                h = imagehash.phash(im.convert("RGB"))
        except Exception:
            h = None
        PHASH_DB.append((ap, h))

# (authentic_path, downscaled grayscale array or None); filled only when USE_RETR_SSIM.
SSIM_DB = []
if USE_RETR_SSIM:
    try:
        from skimage.metrics import structural_similarity as ssim
    except ImportError:
        # _best_by_ssim imports scikit-image unguarded. Disable SSIM retrieval
        # here instead of failing later during test inference.
        print("[Retrieval] scikit-image not available -> SSIM retrieval disabled.")
        USE_RETR_SSIM = False

if USE_RETR_SSIM:
    def _load_gray_resized(path, max_side=SSIM_MAX_SIDE):
        """Read *path* as grayscale, downscaled so max(h, w) <= max_side; None if unreadable."""
        img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if img is None: return None
        h, w = img.shape[:2]
        sc = min(1.0, max_side / max(h, w))
        if sc < 1:
            # Clamp to >= 1 px so extreme aspect ratios don't make cv2.resize fail.
            img = cv2.resize(img, (max(1, int(w*sc)), max(1, int(h*sc))), interpolation=cv2.INTER_AREA)
        return img
    if len(AUTH_LIST):
        print("[Retrieval] Preloading grayscale authentic for SSIM...")
        for ap in AUTH_LIST:
            try:
                ai = _load_gray_resized(ap, SSIM_MAX_SIDE)
            except Exception:
                ai = None
            SSIM_DB.append((ap, ai))

# Cache for _best_by_name: normalized key -> first authentic path with that key.
_AUTH_NAME_INDEX = None
# The AUTH_LIST object the cache was built from; the index is rebuilt if AUTH_LIST is rebound.
_AUTH_NAME_INDEX_SRC = None

def _best_by_name(test_path):
    """Return the first authentic path whose normalized name equals test_path's, else None.

    The key -> path index is built once per AUTH_LIST object rather than
    re-normalizing every authentic name on each query. ``setdefault`` keeps the
    first path per key, matching a first-match linear scan.
    NOTE: in-place mutation of AUTH_LIST is not detected.
    """
    global _AUTH_NAME_INDEX, _AUTH_NAME_INDEX_SRC
    if _AUTH_NAME_INDEX is None or _AUTH_NAME_INDEX_SRC is not AUTH_LIST:
        index = {}
        for ap in AUTH_LIST:
            index.setdefault(_normalize_name(Path(ap)), ap)
        _AUTH_NAME_INDEX, _AUTH_NAME_INDEX_SRC = index, AUTH_LIST
    return _AUTH_NAME_INDEX.get(_normalize_name(Path(test_path)))

def _best_by_phash(test_path):
    """Nearest authentic image to *test_path* by pHash Hamming distance.

    Returns ``(path, distance)``, where ties keep the earliest PHASH_DB entry.
    Returns ``(None, None)`` when pHash retrieval is disabled, the index is
    empty, or *test_path* cannot be hashed.
    """
    if not USE_RETR_PHASH or len(PHASH_DB)==0: return (None, None)
    try:
        # `with` releases the file handle even for multi-frame images.
        with _PIL_Image.open(test_path) as im:
            h = imagehash.phash(im.convert("RGB"))
    except Exception:
        return (None, None)
    best_p, best_d = None, None
    for ap, ah in PHASH_DB:
        if ah is None:
            continue
        d = abs(h - ah)
        if best_d is None or d < best_d:
            best_p, best_d = ap, d
    return (best_p, best_d)

def _best_by_ssim(test_path):
    """Most SSIM-similar authentic image to *test_path*.

    The test image is read as grayscale and downscaled to SSIM_MAX_SIDE. Each
    SSIM_DB array is stretched to that size when the shapes differ. Returns
    ``(path, ssim)``, where higher means more similar and ties keep the
    earliest entry. Returns ``(None, None)`` when SSIM retrieval is disabled,
    the index is empty, scikit-image is missing, or the image is unreadable.
    """
    if not USE_RETR_SSIM or len(SSIM_DB)==0: return (None, None)
    try:
        from skimage.metrics import structural_similarity as ssim
    except ImportError:
        return (None, None)

    gi = cv2.imread(str(test_path), cv2.IMREAD_GRAYSCALE)
    if gi is None:
        return (None, None)
    h, w = gi.shape[:2]
    scale = min(1.0, SSIM_MAX_SIDE / max(h, w))
    if scale < 1:
        # Clamp to >= 1 px so extreme aspect ratios don't make cv2.resize fail.
        gi = cv2.resize(gi, (max(1, int(w * scale)), max(1, int(h * scale))), interpolation=cv2.INTER_AREA)

    gh, gw = gi.shape[:2]
    best_p, best_sc = None, None
    for ap, ai in SSIM_DB:
        if ai is None:
            continue
        ai_r = ai if ai.shape[:2] == (gh, gw) else cv2.resize(ai, (gw, gh), interpolation=cv2.INTER_AREA)
        try:
            sim = ssim(gi, ai_r)
        except Exception:
            # `Exception`, not a bare except: Ctrl-C / SystemExit must still abort.
            continue
        if best_sc is None or sim > best_sc:
            best_p, best_sc = ap, sim
    return (best_p, best_sc)

def find_best_authentic_for_path(test_path: str) -> Optional[str]:
    """Retrieve the authentic counterpart of *test_path*, trying cheaper methods first.

    1. Exact normalized-name match.
    2. Nearest pHash, if its distance is <= PHASH_THRESHOLD.
    3. Best SSIM, if its score is >= SSIM_THRESHOLD.

    Later methods run only if the earlier ones found nothing acceptable.
    Returns the authentic path, or ``None``.
    """
    found = _best_by_name(test_path)
    if found is None:
        candidate, distance = _best_by_phash(test_path)
        if candidate is not None and distance is not None and distance <= PHASH_THRESHOLD:
            found = candidate
    if found is None:
        candidate, similarity = _best_by_ssim(test_path)
        if candidate is not None and similarity is not None and similarity >= SSIM_THRESHOLD:
            found = candidate
    return found

def build_diff_from_retrieved(test_img_path: str) -> Optional[torch.Tensor]:
    """Load the authentic image retrieved for ``test_img_path`` as a model-ready tensor.

    Despite the name, no difference is computed here. The retrieved authentic
    image is returned with ``val_tf`` applied, a batch dimension added and the
    tensor moved to ``device`` (``[1, 3, H, W]`` per ``val_tf``). Callers form
    ``|x - xa|`` themselves.

    Returns:
        The authentic tensor, or ``None`` when retrieval finds no match or the
        matched file cannot be opened.
    """
    ap = find_best_authentic_for_path(test_img_path)
    if ap is None:
        return None
    try:
        # Context manager releases the file handle once pixels are decoded.
        with Image.open(ap) as im:
            authentic = im.convert("RGB")
    except Exception:
        # Best-effort: an unreadable match behaves like "not found".
        return None
    return val_tf(authentic).unsqueeze(0).to(device)

# -------------------------
# Build GT & PRED dataframes for validation (order-aligned)
# -------------------------
@torch.no_grad()
def build_val_gt_and_pred_for_forged(thr_seg: float, min_area: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Run the model over ``val_forged_loader`` and build row-aligned GT / prediction frames.

    Args:
        thr_seg: segmentation probability threshold passed to ``logits_to_instances_rle``.
        min_area: minimum instance area passed to ``logits_to_instances_rle``.

    Returns:
        ``(gt_df, sub_df)`` with matching ``row_id`` values (``val_0``, ``val_1``, ...)
        in the same order:
        - ``gt_df`` has ``row_id``, ``annotation`` and ``shape``. Empty GT masks
          are ``"authentic"``; otherwise it holds the RLE of the 8-connected components.
        - ``sub_df`` has ``row_id`` and ``annotation``. A prediction is
          ``"authentic"`` when the classifier probability is below
          ``BEST_CLS_THR`` or no instance survives post-processing.
    """
    rows_gt, rows_sub = [], []
    # GT shape is the same for every row; serialize it once.
    shape_json = json.dumps([IMG_SIZE, IMG_SIZE])
    for batch in tqdm(val_forged_loader, desc="Val forged eval"):
        if USE_RETR_IN_VAL:
            xf, diff, mask, _, forg_path_batch, _ = batch
        else:
            xf, diff, mask, _ = batch
            forg_path_batch = None
        xf = xf.to(device)

        if USE_RETR_IN_VAL and forg_path_batch is not None:
            # Mimic test-time conditions: replace the loader's diff with one built from
            # the retrieved authentic image (all zeros when retrieval fails).
            diffs = []
            for i in range(xf.size(0)):
                x_i = xf[i:i + 1]
                xa = build_diff_from_retrieved(forg_path_batch[i])
                diffs.append(torch.zeros_like(x_i) if xa is None else torch.abs(x_i - xa))
            diff = torch.cat(diffs, dim=0)
        else:
            diff = diff.to(device)

        log_seg, log_cls = seg_with_tta(xf, diff)  # TTA
        probs = torch.softmax(log_cls, 1)[:, 1].cpu().numpy()
        # Binarize the whole batch of GT masks with one host transfer.
        gt_masks = (mask[:, 0].cpu().numpy() > 0.5).astype(np.uint8)

        for i in range(xf.size(0)):
            row_id = f"val_{len(rows_gt)}"
            gt_m = gt_masks[i]

            # GT row
            if gt_m.sum() == 0:
                rows_gt.append({"row_id": row_id,
                                "annotation": "authentic",
                                "shape": "authentic"})
            else:
                num, labels = cv2.connectedComponents(gt_m, connectivity=8)
                # Labels 1..num-1 are the (always non-empty) components; 0 is background.
                insts = [(labels == k).astype(np.uint8) for k in range(1, num)]
                rle_gt = rle_encode(insts) if insts else "authentic"
                rows_gt.append({"row_id": row_id,
                                "annotation": rle_gt,
                                "shape": shape_json})

            # PRED row: the classifier gates the segmentation output.
            if probs[i] < BEST_CLS_THR:
                pred = "authentic"
            else:
                pred = logits_to_instances_rle(log_seg[i:i + 1], thr=thr_seg, min_area=min_area, use_morph=True)
                if pred == "[]":
                    pred = "authentic"
            rows_sub.append({"row_id": row_id, "annotation": pred})

    gt_df = pd.DataFrame(rows_gt)
    sub_df = pd.DataFrame(rows_sub).reindex(range(len(rows_gt)))
    return gt_df, sub_df

# -------------------------
# Small grid-search for seg thresholds on validation (official score)
# -------------------------
best_official, best_thr_seg, best_min_area = -1.0, THRESH_SEG, MIN_INSTANCE_AREA
# Frames produced by the best setting. They are reused below so the final report does
# not re-run a full (TTA + retrieval) validation pass with identical parameters.
best_frames = None
for thr in [0.10, 0.15, 0.20, 0.25, 0.30]:
    for area in [1, 2, 4, 8]:
        gt_tmp, sub_tmp = build_val_gt_and_pred_for_forged(thr_seg=thr, min_area=area)
        # Pass copies so any in-place changes made by score() never reach the kept frames.
        s = score(gt_tmp.copy(), sub_tmp.copy(), row_id_column_name="row_id")
        if s > best_official:
            best_official, best_thr_seg, best_min_area = s, thr, area
            best_frames = (gt_tmp, sub_tmp)
        del gt_tmp, sub_tmp
        gc.collect()
THRESH_SEG, MIN_INSTANCE_AREA = best_thr_seg, best_min_area
print(f"[VAL seg-tune] Best official={best_official:.6f} at THRESH_SEG={THRESH_SEG} | MIN_INSTANCE_AREA={MIN_INSTANCE_AREA}")

# Final validation score with the best params. Recompute only if no grid point beat the
# initial sentinel (e.g. every score was NaN); otherwise reuse the best grid frames.
if best_frames is not None:
    gt_df_val, sub_df_val = best_frames
else:
    gt_df_val, sub_df_val = build_val_gt_and_pred_for_forged(thr_seg=THRESH_SEG, min_area=MIN_INSTANCE_AREA)
del best_frames
overall_val = score(gt_df_val.copy(), sub_df_val.copy(), row_id_column_name="row_id")
print("\n========== Validation (OFFICIAL METRIC ONLY) ==========")
print(f"Overall (official score): {overall_val:.6f}")
print("=======================================================\n")

# -------------------------
# TEST Inference & submission.csv (exact required format)
#   - Must have a row for each image
#   - For negatives: 'authentic'
#   - For positives: JSON-RLE via rle_encode
#   - Header: case_id,annotation
# -------------------------
@torch.no_grad()
def predict_one_test(image_path: str) -> str:
    """Predict the submission annotation for a single test image.

    The image goes through ``val_tf``. When ``USE_RETR_IN_TEST`` is set, the
    model's diff input is ``|x - xa|`` against the retrieved authentic image.
    Otherwise, or when retrieval finds nothing, the diff is all zeros.

    Returns:
        ``"authentic"`` when the classifier probability is below ``BEST_CLS_THR``
        or no instance survives post-processing; otherwise the RLE string from
        ``logits_to_instances_rle``.
    """
    # Context manager releases the file handle once pixels are decoded.
    with Image.open(image_path) as im:
        img = im.convert("RGB")
    x = val_tf(img).unsqueeze(0).to(device)

    # Use retrieval to build diff (or fall back to zeros).
    xa = build_diff_from_retrieved(image_path) if USE_RETR_IN_TEST else None
    diff = torch.abs(x - xa) if xa is not None else torch.zeros_like(x)

    log_seg, log_cls = seg_with_tta(x, diff)  # TTA
    prob_forg = torch.softmax(log_cls, 1)[0, 1].item()
    if prob_forg < BEST_CLS_THR:
        return "authentic"

    rle = logits_to_instances_rle(log_seg, thr=THRESH_SEG, min_area=MIN_INSTANCE_AREA, use_morph=True)
    return "authentic" if rle == "[]" else rle

def build_submission(test_dir: str, sample_csv_path: str, out_csv_path: str = "submission.csv"):
    """Predict every test image and write the submission CSV.

    If ``sample_csv_path`` exists, its rows and first two column names are
    reused. Column 0 holds the ids, which are matched against test file names,
    then stems, then ``id + known image extension``. Column 1 receives the
    prediction, and ids with no matching file get ``"authentic"``.

    Otherwise, every image file directly inside ``test_dir`` is predicted in
    sorted path order and written with columns ``case_id`` (file stem) and
    ``annotation``.

    Returns:
        ``out_csv_path``.
    """
    image_exts = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")
    # Output must be: case_id,annotation
    if os.path.exists(sample_csv_path):
        df_sub = pd.read_csv(sample_csv_path)
        id_col, target_col = df_sub.columns[0], df_sub.columns[1]
        test_files = [str(p) for p in Path(test_dir).glob("*")]
        by_stem = {Path(p).stem: p for p in test_files}
        by_name = {Path(p).name: p for p in test_files}
        preds = []
        for rid in tqdm(df_sub[id_col].astype(str), desc="Predicting test"):
            cand = by_name.get(rid) or by_stem.get(rid)
            if cand is None:
                cand = next((by_name[rid + ext] for ext in image_exts if rid + ext in by_name), None)
            preds.append("authentic" if cand is None else predict_one_test(cand))
        out = df_sub.copy()
        out[target_col] = preds
        out.to_csv(out_csv_path, index=False)
    else:
        # Path has no .lower(); filter on the lower-cased file name instead.
        test_files = sorted(str(p) for p in Path(test_dir).glob("*") if p.name.lower().endswith(image_exts))
        preds = [predict_one_test(p) for p in tqdm(test_files, desc="Predicting test")]
        out = pd.DataFrame({"case_id": [Path(p).stem for p in test_files], "annotation": preds})
        out.to_csv(out_csv_path, index=False)
    print(f"Saved submission to {out_csv_path}")
    return out_csv_path

# Write submission.csv only when the test image directory is present.
if os.path.exists(TEST_DIR):
    build_submission(TEST_DIR, SAMPLE_SUB, out_csv_path="submission.csv")

# Release cached GPU memory and collect garbage once inference is done.
torch.cuda.empty_cache()
gc.collect()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
from PIL import Image

def _to_disp(img_tensor):
    """Turn a ``(C, H, W)`` tensor into an ``(H, W, C)`` uint8 array for ``imshow``.

    Values are multiplied by 255 and clipped to ``[0, 255]``, which assumes the
    tensor holds intensities in ``[0, 1]``.
    NOTE(review): if ``val_tf`` applies mean/std normalization, displayed colours
    will be off; confirm.
    """
    chw = img_tensor.detach().cpu().numpy()
    hwc = chw.transpose(1, 2, 0)
    return np.clip(hwc * 255, 0, 255).astype(np.uint8)

def _union_mask_from_instances(prob, thr, min_area):
    """Binarize ``prob`` and keep only connected components of at least ``min_area`` pixels.

    Steps:
      1. threshold at ``thr`` (strict ``>``);
      2. 3x3 morphological close, then open;
      3. 8-connected component labelling;
      4. drop components with fewer than ``min_area`` pixels.

    Returns:
        A 0/1 uint8 mask with the same shape as ``prob``: the union of all
        surviving components.
    """
    kernel = np.ones((3, 3), np.uint8)
    bm = (prob > thr).astype(np.uint8)
    bm = cv2.morphologyEx(bm, cv2.MORPH_CLOSE, kernel)
    bm = cv2.morphologyEx(bm, cv2.MORPH_OPEN, kernel)
    num, labels = cv2.connectedComponents(bm, connectivity=8)
    # One bincount pass gives every component's area in O(H*W), instead of
    # rescanning the whole label image once per component.
    areas = np.bincount(labels.ravel(), minlength=num)
    keep = areas >= min_area
    keep[0] = False  # label 0 is the background
    return keep[labels].astype(np.uint8)

@torch.no_grad()
def show_auth_forged_seg(n_samples=10, df=None, figsize=(12, 3)):
    """Plot authentic / forged / predicted-segmentation triplets for sampled pairs.

    Args:
        n_samples: number of pairs to draw, capped at ``len(df)``. Sampling uses
            ``random_state=42``, so the selection is reproducible.
        df: DataFrame with ``auth_path`` and ``forg_path`` columns; defaults to ``val_df``.
        figsize: ``(width, height_per_row)``; the figure height is multiplied by
            the number of sampled rows.

    The diff input uses the given authentic image of each pair (no retrieval).
    A mask is drawn only when the classifier probability reaches ``BEST_CLS_THR``,
    using ``THRESH_SEG`` and ``MIN_INSTANCE_AREA``. Puts ``model`` in eval mode.
    """
    model.eval()
    use_df = df if df is not None else val_df
    samp = use_df.sample(n=min(n_samples, len(use_df)), random_state=42).reset_index(drop=True)

    rows = len(samp)
    if rows == 0:
        # Nothing to draw; avoid creating a zero-height figure.
        return
    plt.figure(figsize=(figsize[0], figsize[1] * rows))

    idx_plot = 1
    for _, row in samp.iterrows():
        ap, fp = row['auth_path'], row['forg_path']

        # Context managers release file handles once pixels are decoded.
        with Image.open(ap) as im_a:
            A = im_a.convert("RGB")
        with Image.open(fp) as im_f:
            F = im_f.convert("RGB")
        xa = val_tf(A).unsqueeze(0).to(device)
        xf = val_tf(F).unsqueeze(0).to(device)

        diff = torch.abs(xf - xa)
        log_seg, log_cls = seg_with_tta(xf, diff)
        prob_forg = torch.softmax(log_cls, 1)[0, 1].item()

        a_disp = _to_disp(xa[0])
        f_disp = _to_disp(xf[0])

        # The classifier gates whether any mask is shown.
        if prob_forg >= BEST_CLS_THR:
            prob = torch.sigmoid(log_seg)[0, 0].cpu().numpy()
            mask_bin = _union_mask_from_instances(prob, THRESH_SEG, MIN_INSTANCE_AREA)
        else:
            mask_bin = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)

        # --- Safe overlay (broadcast the 2-D mask instead of 1-D boolean indexing) ---
        overlay = f_disp.copy()
        if mask_bin.sum() > 0:
            alpha = 0.45
            red_img = np.zeros_like(overlay, dtype=np.uint8)
            red_img[..., 0] = 255  # R=255, G=B=0
            m3 = mask_bin[..., None].astype(bool)  # HxWx1 -> broadcasts to HxWx3
            blended = (alpha * red_img + (1 - alpha) * overlay).astype(np.uint8)
            overlay = np.where(m3, blended, overlay)

            # White outline around each predicted region.
            cnts, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(overlay, cnts, -1, (255, 255, 255), 1)

        # Authentic
        ax1 = plt.subplot(rows, 3, idx_plot); idx_plot += 1
        ax1.imshow(a_disp); ax1.set_title("Authentic", fontsize=11); ax1.axis('off')

        # Forged
        ax2 = plt.subplot(rows, 3, idx_plot); idx_plot += 1
        ax2.imshow(f_disp); ax2.set_title("Forged", fontsize=11); ax2.axis('off')

        # Predicted Seg
        ax3 = plt.subplot(rows, 3, idx_plot); idx_plot += 1
        ax3.imshow(overlay)
        ax3.set_title(f"Predicted Seg (p_forg={prob_forg:.2f})", fontsize=11)
        ax3.axis('off')

    plt.tight_layout()
    plt.show()

# Run the visualization on sampled validation pairs.
show_auth_forged_seg(df=val_df, n_samples=10)
