In [None]:
from pathlib import Path
import re
from collections import defaultdict
import csv
import sys
import math
from typing import List, Tuple, Optional

# Dataset locations and output file. Change paths if necessary.
ROOT = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
DIR_TRAIN_AUTH = Path(f"{ROOT}/train_images/authentic")
DIR_TRAIN_FORG = Path(f"{ROOT}/train_images/forged")
DIR_TEST       = Path(f"{ROOT}/test_images")  # not used, kept for compatibility
OUTPUT_CSV     = Path("/kaggle/working/pairs.csv")  # one row per (authentic, forged) pair

# Method selection
USE_PHASH = True           # Use pHash for forged images that could not be matched by name?
PHASH_THRESHOLD = 10       # Maximum Hamming distance accepted as a match (lower is stricter)

USE_SSIM = False           # (Optional) Use SSIM for images still unmatched after pHash?
SSIM_THRESHOLD = 0.70      # Minimum SSIM accepted as a match (higher is stricter)
SSIM_MAX_SIDE = 512        # Longest side, in pixels, images are downscaled to before SSIM (speed)

# If the dataset is too large, limit the number of forged images processed (None = no limit)
DEBUG_LIMIT_FORGED = None   # e.g. 200

# ------------------- Dependencies -------------------
# Install required packages (allowed in Kaggle). If already installed, this section will just pass.
# imagehash supplies the perceptual hash and PIL opens the images.
# NOTE: the `!pip` lines are IPython/Jupyter shell escapes, so this cell only
# runs inside a notebook kernel, not as a plain Python script.
try:
    import imagehash  # type: ignore
    from PIL import Image  # type: ignore
except Exception:
    !pip -q install imagehash
    import imagehash
    from PIL import Image

# cv2 and scikit-image are only needed for the optional SSIM stage.
if USE_SSIM:
    try:
        import cv2  # type: ignore
        from skimage.metrics import structural_similarity as ssim  # type: ignore
    except Exception:
        !pip -q install scikit-image opencv-python-headless
        import cv2
        from skimage.metrics import structural_similarity as ssim

# ------------------- Utilities -------------------
IMG_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}


def list_images(root: Path) -> List[Path]:
    """Return every path under *root* (recursively) with an image suffix, sorted.

    The suffix comparison is case-insensitive and uses ``IMG_EXTS``.
    """
    found = []
    for candidate in root.rglob("*"):
        if candidate.suffix.lower() in IMG_EXTS:
            found.append(candidate)
    found.sort()
    return found

def normalize_name(p: Path) -> str:
    """Build the lookup key used to pair files by name.

    The stem is lower-cased, ``-`` and spaces become ``_``, then three
    substitutions run in a fixed order: strip a trailing "authentic"-style
    tag, strip a trailing "forged"-style tag, and finally delete every run
    of two or more underscores as well as any trailing underscores.
    """
    key = p.stem.lower().replace("-", "_").replace(" ", "_")
    # Order matters: each pattern sees the result of the previous one.
    cleanup_patterns = (
        r"(authentic|auth|original|orig|clean|real|gt)$",
        r"(forg(ed)?|fake|tampered|edit(ed)?|manipulated?)$",
        r"(__+|_+$)",
    )
    for pattern in cleanup_patterns:
        key = re.sub(pattern, "", key)
    return key

def safe_open_image(path: Path) -> Optional[Image.Image]:
    """Open *path* and return it as an RGB image, or ``None`` if that fails.

    Failures are deliberately swallowed (best effort) so that one unreadable
    file does not abort the whole pairing run.
    """
    try:
        # The context manager closes the source image (and its file handle)
        # as soon as the RGB copy exists, instead of leaving that to the
        # garbage collector.
        with Image.open(path) as im:
            return im.convert("RGB")
    except Exception:
        return None

def phash_value(path: Path):
    """Return ``imagehash.phash`` of the image at *path*, or ``None``.

    ``None`` is returned both when the image cannot be opened and when
    hashing itself raises.
    """
    image = safe_open_image(path)
    if image is not None:
        try:
            return imagehash.phash(image)
        except Exception:
            pass
    return None

def hamming_distance(h1, h2) -> int:
    """Return the absolute value of ``h1 - h2``.

    NOTE(review): for ``imagehash`` objects the subtraction presumably yields
    the Hamming distance between the two hashes; confirm against imagehash.
    """
    difference = h1 - h2
    return abs(difference)

def load_gray_resized_cv2(path: Path, max_side=512):
    """Load *path* as a grayscale image, downscaled so its longest side is at most *max_side*.

    Images that already fit are returned unchanged (never upscaled).
    Returns ``None`` when OpenCV cannot read the file.
    """
    img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if img is None:
        return None
    h, w = img.shape[:2]
    scale = min(1.0, max_side / max(h, w))
    if scale < 1.0:
        # For very elongated images the short side can truncate to 0, which
        # cv2.resize rejects; keep every side at least one pixel.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return img

def best_match_by_phash(gf: Path, auth_hashes: List[Tuple[Path, object]]) -> Tuple[Optional[Path], Optional[int]]:
    """Find the authentic image whose pHash is closest to that of *gf*.

    *auth_hashes* holds ``(path, hash)`` pairs; entries whose hash is ``None``
    are ignored. Returns ``(path, distance)`` of the closest entry, or
    ``(None, None)`` if *gf* cannot be hashed or no candidate is usable.
    """
    query_hash = phash_value(gf)
    if query_hash is None:
        return (None, None)
    scored = [
        (auth_path, hamming_distance(query_hash, auth_hash))
        for auth_path, auth_hash in auth_hashes
        if auth_hash is not None
    ]
    if not scored:
        return (None, None)
    # min() keeps the first of several equally close candidates.
    return min(scored, key=lambda item: item[1])

def best_match_by_ssim(gf: Path, auth_imgs: List[Tuple[Path, 'np.ndarray']], max_side=512) -> Tuple[Optional[Path], Optional[float]]:
    """Find the authentic image with the highest SSIM against *gf*.

    *auth_imgs* holds ``(path, grayscale array)`` pairs; ``None`` arrays are
    skipped. A candidate whose size differs from the query is resized to the
    query's size first. Candidates for which SSIM raises are skipped (best
    effort). Returns ``(path, score)`` or ``(None, None)`` when *gf* cannot be
    loaded or no candidate could be scored.
    """
    gi = load_gray_resized_cv2(gf, max_side=max_side)
    if gi is None:
        return (None, None)
    gh, gw = gi.shape[:2]
    best_p, best_sc = None, None
    for ap, ai in auth_imgs:
        if ai is None:
            continue
        # Simple alignment if dimensions differ (cv2 expects (width, height)).
        if ai.shape[:2] != (gh, gw):
            ai = cv2.resize(ai, (gw, gh), interpolation=cv2.INTER_AREA)
        try:
            sc = ssim(gi, ai)
        except Exception:
            continue
        if best_sc is None or sc > best_sc:
            best_p, best_sc = ap, sc
    return (best_p, best_sc)

# ------------------- Main Flow -------------------
def main():
    """Pair every forged training image with an authentic one and write the pairs to CSV.

    Matching runs in up to three stages, each only on images the previous
    stage left unmatched: normalized file name, then pHash (if ``USE_PHASH``),
    then SSIM (if ``USE_SSIM``). The CSV has the columns ``auth_path``,
    ``forg_path``, ``method`` and ``score``; the score is ``0.0`` for name
    matches, the hash distance for pHash matches and the SSIM value for
    SSIM matches.
    """
    # 1) List image files
    auth_files = list_images(DIR_TRAIN_AUTH)
    forg_files = list_images(DIR_TRAIN_FORG)
    if DEBUG_LIMIT_FORGED is not None:
        forg_files = forg_files[:DEBUG_LIMIT_FORGED]

    print(f"#auth = {len(auth_files)}, #forg = {len(forg_files)}")

    # 2) Pairing based on file names
    # Several authentic files may share one normalized key, hence a list per key.
    auth_map = defaultdict(list)
    for p in auth_files:
        auth_map[normalize_name(p)].append(p)

    # Each entry is (authentic path, forged path, method, score).
    pairs: List[Tuple[Path, Path, str, float]] = []
    unmatched_forg: List[Path] = []

    for gf in forg_files:
        key = normalize_name(gf)
        if key in auth_map and len(auth_map[key]) > 0:
            # If multiple candidates exist, take the first for now
            ap = auth_map[key][0]
            pairs.append((ap, gf, "name", 0.0))
        else:
            unmatched_forg.append(gf)

    print(f"Name-matched pairs: {len(pairs)}")
    print(f"Unmatched forged by name: {len(unmatched_forg)}")

    # 3) pHash matching for unmatched images
    still_unmatched: List[Path] = unmatched_forg
    if USE_PHASH and len(still_unmatched) > 0:
        print("Computing pHash for authentic images...")
        # Hash every authentic image once; unreadable ones are stored as None.
        auth_hashes = []
        for ap in auth_files:
            try:
                ah = phash_value(ap)
            except Exception:
                ah = None
            auth_hashes.append((ap, ah))

        print("Matching unmatched forged by pHash...")
        newly_paired = 0
        next_unmatched = []
        for gf in still_unmatched:
            ap, dist = best_match_by_phash(gf, auth_hashes)
            # Accept the closest authentic image only if it is close enough.
            if ap is not None and dist is not None and dist <= PHASH_THRESHOLD:
                pairs.append((ap, gf, "phash", float(dist)))
                newly_paired += 1
            else:
                next_unmatched.append(gf)
        still_unmatched = next_unmatched
        print(f"pHash new pairs (<= {PHASH_THRESHOLD}): {newly_paired}")
        print(f"Remaining unmatched after pHash: {len(still_unmatched)}")

    # 4) Optional SSIM matching
    if USE_SSIM and len(still_unmatched) > 0:
        print("Preloading grayscale resized authentic images for SSIM...")
        # All authentic images are kept in memory at reduced size.
        auth_imgs = []
        for ap in auth_files:
            try:
                ai = load_gray_resized_cv2(ap, max_side=SSIM_MAX_SIDE)
            except Exception:
                ai = None
            auth_imgs.append((ap, ai))

        print("Matching unmatched forged by SSIM...")
        newly_paired = 0
        next_unmatched = []
        for gf in still_unmatched:
            try:
                ap, sc = best_match_by_ssim(gf, auth_imgs, max_side=SSIM_MAX_SIDE)
            except Exception:
                ap, sc = (None, None)
            if ap is not None and sc is not None and sc >= SSIM_THRESHOLD:
                pairs.append((ap, gf, "ssim", float(sc)))
                newly_paired += 1
            else:
                next_unmatched.append(gf)
        still_unmatched = next_unmatched
        print(f"SSIM new pairs (>= {SSIM_THRESHOLD}): {newly_paired}")
        print(f"Remaining unmatched after SSIM: {len(still_unmatched)}")

    # 5) Save output CSV
    OUTPUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["auth_path", "forg_path", "method", "score"])
        for ap, gf, m, s in pairs:
            writer.writerow([str(ap), str(gf), m, s])

    print(f"\nSaved pairs: {len(pairs)} -> {OUTPUT_CSV}")
    if len(still_unmatched) > 0:
        print("Sample unmatched forged (up to 10):")
        for g in still_unmatched[:10]:
            print(" -", g)

# Run the pairing when this cell/script is executed directly.
if __name__ == "__main__":
    main()

# ReCoDAI | Mask R-CNN + ELA (RGB+ELA 4ch)

In [None]:
###############################################
# ReCoDAI | Mask R-CNN + ELA (RGB+ELA 4ch)
###############################################

import os, gc, json, random, warnings
warnings.filterwarnings("ignore")
from pathlib import Path
from typing import List, Tuple

os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import cv2
import scipy.optimize

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset

# =============== Seeding & CuDNN ===============
def seed_all(s=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with *s*."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(s)
# Single seed reused for RNG seeding and for the train/val splits below.
GLOBAL_SEED = 42
seed_all(GLOBAL_SEED)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
import torch.backends.cudnn as cudnn
cudnn.benchmark = False
cudnn.deterministic = True

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# =============== Paths ===============
ROOT         = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
AUTH_DIR     = f"{ROOT}/train_images/authentic"
FORG_DIR     = f"{ROOT}/train_images/forged"
MASK_DIR     = f"{ROOT}/train_masks"          # ground-truth masks stored as <image stem>.npy
TEST_DIR     = f"{ROOT}/test_images"
SAMPLE_SUB   = f"{ROOT}/sample_submission.csv"
PAIRS_CSV    = "/kaggle/working/pairs.csv"    # read below to get the forged image paths

# =============== Hyperparams ===============
IMG_SIZE            = 128  # images, ELA maps and masks are all resized to IMG_SIZE x IMG_SIZE
MIN_INSTANCE_AREA   = 16   # minimum pixel count for a ground-truth or predicted instance to be kept

EPOCHS              = 4
BATCH_SIZE          = 2
NUM_WORKERS         = max(2, (os.cpu_count() or 4)//2)  # DataLoader workers: half the cores, at least 2


USE_AMP             = False  # NOTE(review): not referenced anywhere in the visible code


BASE_LR             = 4e-4   # learning rate once warm-up has finished
WARMUP_ITERS        = 500    # number of iterations of linear warm-up
WARMUP_START_LR     = 1e-5   # learning rate at iteration 0
WEIGHT_DECAY        = 1e-4
MOMENTUM            = 0.9
MAX_NORM            = 1.0     # grad clip

# Inference thresholds
SCORE_THR           = 0.15   # minimum detection score to keep a prediction
MASK_THR            = 0.30   # threshold applied to the predicted mask values
NMS_THR             = 0.30   # IoU threshold passed to torchvision.ops.nms

ELA_JPEG_Q          = 95     # JPEG quality used for the ELA re-compression

# =============== Official metric ===============
import numba
import numpy.typing as npt

class ParticipantVisibleError(Exception):
    """Raised by the RLE decoding code when a run-length string is malformed."""

@numba.jit(nopython=True)
def _rle_encode_jit(x: npt.NDArray, fg_val: int = 1):
    # Run-length encode the pixels of `x` equal to `fg_val`.
    # The mask is transposed before flattening, so pixels are numbered
    # column by column. The result is a flat list
    # [start_1, length_1, start_2, length_2, ...] with 1-based starts.
    dots = np.where(x.T.flatten() == fg_val)[0]
    run_lengths = []
    prev = -2  # sentinel: guarantees the first foreground pixel opens a new run
    for b in dots:
        if b > prev + 1:
            # Gap since the previous foreground pixel: start a new run.
            run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1  # extend the current run by this pixel
        prev = b
    return run_lengths

def rle_encode(masks: List[npt.NDArray], fg_val: int = 1) -> str:
    """Encode each mask as a JSON run-length list and join the results with ``;``."""
    encoded = []
    for mask in masks:
        runs = _rle_encode_jit(mask, fg_val)
        encoded.append(json.dumps(runs))
    return ';'.join(encoded)

@numba.njit
def _rle_decode_jit(mask_rle: npt.NDArray, height: int, width: int) -> npt.NDArray:
    # Decode a flat [start, length, start, length, ...] array (1-based starts)
    # into a flat boolean array of height * width pixels.
    # Raises ValueError for an odd number of values or for overlapping runs.
    if len(mask_rle) % 2 != 0:
        raise ValueError('One or more rows has an odd number of values.')
    starts, lengths = mask_rle[0::2], mask_rle[1::2]
    # Convert to 0-based indices. NOTE: `starts` is a slice of `mask_rle`,
    # so this also modifies the caller's array.
    starts -= 1
    ends = starts + lengths
    for i in range(len(starts) - 1):
        if ends[i] > starts[i + 1]:
            raise ValueError('Pixels must not be overlapping.')
    img = np.zeros(height * width, dtype=np.bool_)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img

def rle_decode(mask_rle: str, shape: Tuple[int,int]) -> npt.NDArray:
    """Decode one JSON run-length string into a boolean mask of *shape*.

    The flat decoded array is reshaped in column-major (``order='F'``) order.

    Raises:
        ParticipantVisibleError: if the run starts are not in ascending
            order, or if decoding/reshaping raises ``ValueError``.
    """
    runs = np.asarray(json.loads(mask_rle), dtype=np.int32)
    run_starts = runs[0::2]
    if list(run_starts) != sorted(run_starts):
        raise ParticipantVisibleError('Submitted values must be in ascending order.')
    height, width = shape[0], shape[1]
    try:
        flat_mask = _rle_decode_jit(runs, height, width)
        return flat_mask.reshape(shape, order='F')
    except ValueError as e:
        raise ParticipantVisibleError(str(e)) from e

def calculate_f1_score(pred_mask: npt.NDArray, gt_mask: npt.NDArray):
    """Return the pixel-wise F1 score between two masks of equal size.

    Pixels equal to 1 count as positive and pixels equal to 0 as negative.
    Precision, recall and the final score each fall back to 0 when their
    denominator is zero.
    """
    pred = pred_mask.flatten()
    truth = gt_mask.flatten()
    pred_pos, truth_pos = (pred == 1), (truth == 1)

    tp = np.sum(pred_pos & truth_pos)
    fp = np.sum(pred_pos & (truth == 0))
    fn = np.sum((pred == 0) & truth_pos)

    if (tp + fp) > 0:
        precision = tp / (tp + fp)
    else:
        precision = 0
    if (tp + fn) > 0:
        recall = tp / (tp + fn)
    else:
        recall = 0

    if (precision + recall) > 0:
        return 2 * (precision * recall) / (precision + recall)
    return 0

def calculate_f1_matrix(pred_masks: List[npt.NDArray], gt_masks: List[npt.NDArray]):
    """Return the matrix of F1 scores, rows = predictions, columns = ground truths.

    When there are fewer predictions than ground-truth masks, all-zero rows
    are appended so the matrix has at least as many rows as columns.
    """
    n_pred, n_gt = len(pred_masks), len(gt_masks)
    f1_matrix = np.zeros((n_pred, n_gt))
    for i, pred in enumerate(pred_masks):
        pred_flat = pred.flatten()
        for j, gt in enumerate(gt_masks):
            f1_matrix[i, j] = calculate_f1_score(pred_flat, gt.flatten())

    missing_rows = n_gt - f1_matrix.shape[0]
    if missing_rows > 0:
        padding = np.zeros((missing_rows, n_gt))
        f1_matrix = np.vstack((f1_matrix, padding))
    return f1_matrix

def oF1_score(pred_masks: List[npt.NDArray], gt_masks: List[npt.NDArray]):
    """Return the mean F1 of the optimal prediction/ground-truth assignment.

    The assignment maximizes total F1 (``linear_sum_assignment`` on the
    negated matrix). The mean is scaled by ``len(gt) / max(len(pred), len(gt))``,
    which penalizes having more predictions than ground-truth instances.
    Returns ``0.0`` when the assignment is empty.
    """
    f1_matrix = calculate_f1_matrix(pred_masks, gt_masks)
    rows, cols = scipy.optimize.linear_sum_assignment(-f1_matrix)
    if len(rows) == 0:
        return 0.0
    n_gt = len(gt_masks)
    larger_count = max(len(pred_masks), n_gt)
    if larger_count > 0:
        excess_predictions_penalty = n_gt / larger_count
    else:
        excess_predictions_penalty = 0.0
    return np.mean(f1_matrix[rows, cols]) * excess_predictions_penalty

def evaluate_single_image(label_rles: str, prediction_rles: str, shape_str: str) -> float:
    """Score one image from its ``;``-separated RLE strings.

    *shape_str* is a JSON list giving the mask shape used to decode both the
    label and the prediction instances.
    """
    mask_shape = tuple(json.loads(shape_str))
    gt_masks = [rle_decode(rle, shape=mask_shape) for rle in label_rles.split(';')]
    pred_masks = [rle_decode(rle, shape=mask_shape) for rle in prediction_rles.split(';')]
    return oF1_score(pred_masks, gt_masks)

def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str) -> float:
    """Return the mean per-image score of *submission* against *solution*.

    Both frames need an ``annotation`` column; *solution* also needs ``shape``.
    Predictions are attached by index alignment. If the label or the
    prediction is the string ``'authentic'``, the image scores 1.0 when both
    are equal and 0.0 otherwise. All other images are scored with
    ``evaluate_single_image``. *row_id_column_name* is accepted for interface
    compatibility but not used here.
    """
    df = solution.copy().rename(columns={'annotation': 'label'})
    df['prediction'] = submission['annotation']
    authentic = (df['label'] == 'authentic') | (df['prediction'] == 'authentic')
    df['image_score'] = ((df['label'] == df['prediction']) & authentic).astype(float)
    needs_mask_eval = ~authentic
    # Skip the row-wise apply when nothing needs mask scoring, so the result
    # never depends on how pandas handles apply/assignment on an empty frame.
    if needs_mask_eval.any():
        df.loc[needs_mask_eval, 'image_score'] = df.loc[needs_mask_eval].apply(
            lambda row: evaluate_single_image(row['label'], row['prediction'], row['shape']), axis=1
        )
    return float(np.mean(df['image_score']))

# =============== Data & Splits ===============
def list_images(root: str) -> List[str]:
    """Return the sorted string paths of all image files found recursively under *root*.

    A file counts as an image when its lower-cased suffix is one of
    .png, .jpg, .jpeg, .tif, .tiff or .bmp.
    """
    allowed = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
    matches = (
        str(entry)
        for entry in Path(root).rglob("*")
        if entry.suffix.lower() in allowed
    )
    return sorted(matches)

# Forged images: random 80/20 split of the rows of the pairs CSV.
pairs_df = pd.read_csv(PAIRS_CSV)
val_df = pairs_df.sample(frac=0.2, random_state=GLOBAL_SEED)
train_df = pairs_df.drop(val_df.index)

# Authentic images: separate random 80/20 split over the files on disk.
# NOTE(review): this split is drawn independently of the forged split, so an
# authentic image and its forged counterpart can land on different sides;
# confirm that this is acceptable.
auth_all = list_images(AUTH_DIR)
auth_val_count = max(1, int(0.2*len(auth_all)))
rng = np.random.default_rng(GLOBAL_SEED)
perm = rng.permutation(len(auth_all))
auth_val_idx = set(perm[:auth_val_count].tolist())
auth_train = [auth_all[i] for i in range(len(auth_all)) if i not in auth_val_idx]
auth_val   = [auth_all[i] for i in range(len(auth_all)) if i in auth_val_idx]

print(f"Train forged: {len(train_df)} | Val forged: {len(val_df)}")
print(f"Train authentic (info): {len(auth_train)} | Val authentic: {len(auth_val)}")

# =============== ELA helper ===============
def ela_map_gray_uint8(img_np_bgr, q=ELA_JPEG_Q):
    """Compute a grayscale error-level-analysis (ELA) map of a BGR image.

    The image is JPEG-encoded at quality *q* and decoded again. The absolute
    difference to the input is divided by its 99th percentile (at least 1.0),
    clipped to [0, 1], quantized to 8 bits and converted to grayscale.
    Despite the function name, the returned array is float32 in [0, 1].
    """
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), q]
    _ok, jpeg_buffer = cv2.imencode('.jpg', img_np_bgr, encode_params)
    recompressed = cv2.imdecode(jpeg_buffer, cv2.IMREAD_COLOR)
    if recompressed is None:
        # Decoding failed: compare the image with itself (all-zero residual).
        recompressed = img_np_bgr.copy()

    residual = cv2.absdiff(img_np_bgr, recompressed).astype(np.float32)
    scale = max(float(np.percentile(residual, 99)), 1.0)
    normalized = np.clip(residual / scale, 0, 1.0)

    quantized = (normalized * 255).astype(np.uint8)
    gray = cv2.cvtColor(quantized, cv2.COLOR_BGR2GRAY)
    return gray.astype(np.float32) / 255.0

# =============== Box helpers (fix) ===============
from torchvision.ops import masks_to_boxes, nms

def _fix_boxes_inplace(boxes: torch.Tensor, H: int, W: int) -> torch.Tensor:
    if boxes.numel() == 0: return boxes
    boxes[:, 0] = boxes[:, 0].clamp(0, W - 1)  # x1
    boxes[:, 2] = boxes[:, 2].clamp(0, W - 1)  # x2
    boxes[:, 1] = boxes[:, 1].clamp(0, H - 1)  # y1
    boxes[:, 3] = boxes[:, 3].clamp(0, H - 1)  # y2
    bad_w = boxes[:, 2] <= boxes[:, 0]
    if bad_w.any():
        boxes[bad_w, 0] = (boxes[bad_w, 0] - 1).clamp(0, W - 2)
        boxes[bad_w, 2] = (boxes[bad_w, 0] + 1).clamp(1, W - 1)
    bad_h = boxes[:, 3] <= boxes[:, 1]
    if bad_h.any():
        boxes[bad_h, 1] = (boxes[bad_h, 1] - 1).clamp(0, H - 2)
        boxes[bad_h, 3] = (boxes[bad_h, 1] + 1).clamp(1, H - 1)
    return boxes

def _instances_to_valid_boxes(masks_np: np.ndarray, H: int, W: int) -> torch.Tensor:
    """Return one repaired bounding box per instance mask in *masks_np*.

    Boxes come from ``torchvision.ops.masks_to_boxes`` and are then passed
    through ``_fix_boxes_inplace``. An empty input gives a ``(0, 4)`` float32
    tensor.
    """
    if masks_np.size != 0:
        mask_tensor = torch.from_numpy(masks_np.astype(np.uint8))
        raw_boxes = masks_to_boxes(mask_tensor)
        return _fix_boxes_inplace(raw_boxes, H, W)
    return torch.zeros((0,4), dtype=torch.float32)

# =============== Datasets (4ch: RGB+ELA) ===============
class ForgeryInstancesDataset(Dataset):
    """Forged images with instance targets in the format Mask R-CNN expects.

    Each item is ``(img4, target)``: ``img4`` is a 4-channel tensor (RGB plus
    ELA map) of size IMG_SIZE x IMG_SIZE, and ``target`` is a dict with the
    keys boxes, labels, masks, image_id, area and iscrowd.
    """

    def __init__(self, df_pairs: pd.DataFrame, mask_dir: str, train: bool):
        # df_pairs must provide a 'forg_path' column (see __getitem__).
        self.df = df_pairs.reset_index(drop=True)
        self.mask_dir = mask_dir
        self.train = train  # stored, but not read anywhere in this class
        self.tf_resize = T.Resize((IMG_SIZE, IMG_SIZE))
        self.to_tensor = T.ToTensor()

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        pf = row['forg_path']  # path of the forged (tampered) image
        img_pil = Image.open(pf).convert("RGB")

        # GT mask: expected at <mask_dir>/<image stem>.npy
        stem = Path(pf).stem
        mp = os.path.join(self.mask_dir, stem + ".npy")
        if os.path.exists(mp):
            m = np.load(mp)
            # A 3-D mask stack is collapsed into one 2-D map.
            # NOTE(review): instances are re-derived below by connected
            # components, so the original per-channel grouping is not kept;
            # confirm that this is intended.
            if m.ndim==3: m = np.max(m, axis=0)
        else:
            # No mask file: treat the image as having no forged pixels.
            m = np.zeros(img_pil.size[::-1], dtype=np.uint8)  # (H,W)
        m_bin = (m > 0).astype(np.uint8)

        # ELA is computed on the full-resolution image, before resizing.
        img_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
        ela = ela_map_gray_uint8(img_bgr, q=ELA_JPEG_Q)

        # Resize image, ELA map and mask to the network input size
        # (nearest-neighbour for the mask keeps it binary).
        img_pil = self.tf_resize(img_pil)
        ela_rs  = cv2.resize(ela, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
        mask_rs = cv2.resize(m_bin, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

        # Stack RGB and ELA into a single 4-channel input
        rgb   = self.to_tensor(img_pil)                         # [3,H,W] 0..1
        ela_t = torch.from_numpy(ela_rs).float().unsqueeze(0)   # [1,H,W]
        img4  = torch.cat([rgb, ela_t], dim=0)                  # [4,H,W]

        # Split the binary mask into instances: one per 8-connected component
        # (label 0 is the background), dropping components that are too small.
        num, labels = cv2.connectedComponents(mask_rs, connectivity=8)
        insts = []
        for k in range(1, num):
            inst = (labels == k).astype(np.uint8)
            if int(inst.sum()) >= MIN_INSTANCE_AREA:
                insts.append(inst)

        if len(insts) == 0:
            # No usable instance: empty target with correctly shaped tensors.
            target = {
                "boxes": torch.zeros((0,4), dtype=torch.float32),
                "labels": torch.zeros((0,), dtype=torch.int64),
                "masks": torch.zeros((0, IMG_SIZE, IMG_SIZE), dtype=torch.uint8),
                "image_id": torch.tensor([idx]),
                "area": torch.zeros((0,), dtype=torch.float32),
                "iscrowd": torch.zeros((0,), dtype=torch.int64),
            }
        else:
            masks_t = torch.from_numpy(np.stack(insts, 0)).to(torch.uint8)  # [N,H,W]
            boxes_t = _instances_to_valid_boxes(masks_t.numpy(), IMG_SIZE, IMG_SIZE)  # [N,4]
            # Drop any instance whose box still has no positive width/height.
            keep = (boxes_t[:,2] > boxes_t[:,0]) & (boxes_t[:,3] > boxes_t[:,1])
            if keep.sum().item() < boxes_t.size(0):
                masks_t = masks_t[keep]
                boxes_t = boxes_t[keep]
            labels_t= torch.ones((masks_t.size(0),), dtype=torch.int64)   # forgery=1
            area_t  = masks_t.sum(dim=(1,2)).float()
            iscrowd = torch.zeros((masks_t.size(0),), dtype=torch.int64)
            target = {
                "boxes": boxes_t, "labels": labels_t, "masks": masks_t,
                "image_id": torch.tensor([idx]), "area": area_t, "iscrowd": iscrowd
            }

        return img4, target

class AuthenticEmptyDataset(Dataset):
    """Authentic images served as negatives: 4-channel input plus an empty target.

    Items have the same ``(image, target)`` layout as ``ForgeryInstancesDataset``,
    but every target tensor has zero instances.
    """

    def __init__(self, paths: List[str]):
        self.paths = paths
        self.tf_resize = T.Resize((IMG_SIZE, IMG_SIZE))
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        pil_img = Image.open(self.paths[idx]).convert("RGB")

        # ELA is computed on the full-resolution image, then resized.
        bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        ela_full = ela_map_gray_uint8(bgr, q=ELA_JPEG_Q)

        rgb_t = self.to_tensor(self.tf_resize(pil_img))
        ela_small = cv2.resize(ela_full, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
        ela_t = torch.from_numpy(ela_small).float().unsqueeze(0)
        x4 = torch.cat([rgb_t, ela_t], dim=0)

        empty_target = {
            "boxes": torch.zeros((0,4), dtype=torch.float32),
            "labels": torch.zeros((0,), dtype=torch.int64),
            "masks": torch.zeros((0, IMG_SIZE, IMG_SIZE), dtype=torch.uint8),
            "image_id": torch.tensor([idx]),
            "area": torch.zeros((0,), dtype=torch.float32),
            "iscrowd": torch.zeros((0,), dtype=torch.int64),
        }
        return x4, empty_target

def collate_fn(batch):
    """Turn a list of ``(image, target)`` pairs into a list of images and a list of targets."""
    images, targets = zip(*batch)
    return list(images), list(targets)

# Positive (forged) and negative (authentic) datasets for both splits.
train_ds_pos = ForgeryInstancesDataset(train_df, MASK_DIR, train=True)
val_ds_pos   = ForgeryInstancesDataset(val_df,   MASK_DIR, train=False)
neg_train    = AuthenticEmptyDataset(auth_train)
neg_val      = AuthenticEmptyDataset(auth_val)


# Subsample the training negatives (without replacement) so there are at
# most as many of them as training positives.
k = min(len(train_ds_pos), len(neg_train))
rng = np.random.default_rng(GLOBAL_SEED)
idxs = rng.choice(len(neg_train), k, replace=False)
neg_train_small = Subset(neg_train, idxs)

# Training mixes positives with the subsampled negatives; validation uses
# all validation negatives.
train_mix = ConcatDataset([train_ds_pos, neg_train_small])
val_mix   = ConcatDataset([val_ds_pos, neg_val])

train_loader = DataLoader(train_mix, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          collate_fn=collate_fn, persistent_workers=False)
val_loader   = DataLoader(val_mix, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          collate_fn=collate_fn, persistent_workers=False)

print(f"[RCNN] Train samples (pos+neg): {len(train_mix)} | Val samples (pos+neg): {len(val_mix)}")

# =============== Model (4ch, from scratch) ===============
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def maskrcnn_resnet50_fpn_4ch_from_scratch(num_classes=2):
    """Build an untrained Mask R-CNN (ResNet-50 FPN) that takes 4-channel input.

    No pretrained weights are loaded. The first backbone convolution is
    replaced by a 4-input-channel one, the box and mask predictors are
    replaced by heads for *num_classes* classes, and the model's input
    transform is set to identity normalization with a fixed size of IMG_SIZE.
    """
    model = maskrcnn_resnet50_fpn(weights=None, weights_backbone=None)

    # Four-channel conv1 with the same geometry as the original stem conv.
    stem = model.backbone.body.conv1
    stem_4ch = nn.Conv2d(
        4,
        stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=False,
    )
    with torch.no_grad():
        nn.init.kaiming_normal_(stem_4ch.weight, mode='fan_out', nonlinearity='relu')
    model.backbone.body.conv1 = stem_4ch

    # Replace both prediction heads so they output `num_classes` classes.
    roi_heads = model.roi_heads
    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)
    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    mask_hidden = 256
    roi_heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden, num_classes)

    # Small-variance weights and zero biases for the new head layers.
    head_layers = (
        roi_heads.box_predictor.cls_score,
        roi_heads.box_predictor.bbox_pred,
        roi_heads.mask_predictor.conv5_mask,
        roi_heads.mask_predictor.mask_fcn_logits,
    )
    for layer in head_layers:
        if getattr(layer, 'weight', None) is not None:
            nn.init.normal_(layer.weight, std=0.01)
        if getattr(layer, 'bias', None) is not None:
            nn.init.constant_(layer.bias, 0)

    # Identity normalization for all four channels and a fixed input size.
    transform = model.transform
    transform.image_mean = [0.0, 0.0, 0.0, 0.0]
    transform.image_std  = [1.0, 1.0, 1.0, 1.0]
    transform.min_size = (IMG_SIZE,)
    transform.max_size = IMG_SIZE
    return model

# build
# Drop a model left over from a previous run of this cell so its GPU memory
# can be reused. On the first run `model` does not exist yet, which is the
# only error expected here (a bare `except:` would also hide interrupts and
# real CUDA errors).
try:
    del model
except NameError:
    pass
else:
    torch.cuda.empty_cache()

model = maskrcnn_resnet50_fpn_4ch_from_scratch(num_classes=2).to(device)

# True when more than one CUDA device is visible.
MULTI_GPU = (torch.cuda.is_available() and torch.cuda.device_count() > 1)

# =============== Optimizer + Warmup ===============
# The optimizer starts at the warm-up learning rate; the training loop
# overwrites the rate on every iteration via warmup_lr().
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad],
                      lr=WARMUP_START_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

# Number of training iterations performed so far (drives the warm-up).
global_iter = 0
def warmup_lr(it):
    """Return the learning rate for iteration *it*.

    The rate rises linearly from WARMUP_START_LR to BASE_LR over the first
    WARMUP_ITERS iterations and stays at BASE_LR afterwards.
    """
    if it < WARMUP_ITERS:
        progress = it / max(1, WARMUP_ITERS)
        return WARMUP_START_LR + (BASE_LR - WARMUP_START_LR) * progress
    return BASE_LR

# =============== Utilities ===============
def is_finite_tensor(x: torch.Tensor) -> bool:
    """Return True when every element of *x* is finite (no NaN and no infinity)."""
    all_finite = torch.isfinite(x).all()
    return bool(all_finite)

@torch.no_grad()
def quick_val_stats(model, loader):
    """Print quick detection statistics for *model* over *loader*.

    Puts the model in eval mode, counts detections scoring at least
    SCORE_THR per image, and prints the image count, the average number of
    kept detections per image and the share of images with none.
    """
    model.eval()
    image_count = 0
    detection_count = 0
    empty_count = 0
    for batch_images, _targets in loader:
        predictions = model([image.to(device) for image in batch_images])
        for prediction in predictions:
            scores = prediction.get("scores", torch.tensor([]))
            kept = int((scores.detach().cpu() >= SCORE_THR).sum())
            image_count += 1
            detection_count += kept
            if kept == 0:
                empty_count += 1
    denom = max(1, image_count)
    print(f"[VAL] images={image_count} | avg_preds_per_image={detection_count/denom:.2f} | authentic_preds%={100*empty_count/denom:.1f}%")

# RLE sanity test
# Round-trip a random binary 64x64 mask through rle_encode / rle_decode and
# fail fast if the decoded mask differs from the original.
m = (np.random.rand(64,64) > 0.7).astype(np.uint8)
enc = rle_encode([m])
_dec = rle_decode(enc.split(';')[0], (64,64))
assert np.all(m == _dec), "RLE encode/decode mismatch!"

# =============== Train (NaN-safe) ===============
# Training loop: one pass over train_loader per epoch. Batches whose loss is
# not finite are skipped without an optimizer step.
for epoch in range(1, EPOCHS+1):
    model.train()
    loss_hist = []
    pbar = tqdm(train_loader, desc=f"RCNN Train {epoch}/{EPOCHS}")
    for imgs, tgts in pbar:
        # warmup lr: set the rate for this iteration on every param group.
        # The counter advances even when the batch is skipped below.
        lr_now = warmup_lr(global_iter)
        for g in optimizer.param_groups: g['lr'] = lr_now
        global_iter += 1

        # Move images and every tensor inside each target dict to the device.
        imgs = [im.to(device) for im in imgs]
        tgts = [{k: (v.to(device) if torch.is_tensor(v) else v) for k,v in t.items()} for t in tgts]

        # In train mode the model returns its losses instead of predictions.
        loss_dict = model(imgs, tgts)   # dict of losses

        # Skip the batch if any individual loss (or their sum) is NaN/inf.
        bad = False
        total = 0.0
        for k, v in loss_dict.items():
            if not is_finite_tensor(v):
                bad = True; break
            total += float(v.detach().cpu().item())
        if bad or not np.isfinite(total):
            optimizer.zero_grad(set_to_none=True)
            continue

        loss = sum(loss_dict.values())
        if not is_finite_tensor(loss):
            optimizer.zero_grad(set_to_none=True)
            continue

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Clip the global gradient norm to MAX_NORM before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_NORM)
        optimizer.step()

        # Show the running mean of the epoch's losses in the progress bar.
        lv = float(loss.detach().cpu().item())
        loss_hist.append(lv)
        pbar.set_postfix(loss=np.mean(loss_hist) if len(loss_hist) else lv)

    torch.cuda.empty_cache(); gc.collect()

    # Prints detection statistics; note that this leaves the model in eval
    # mode until model.train() runs at the start of the next epoch.
    quick_val_stats(model, val_loader)

# =============== Validation (official metric) ===============
@torch.no_grad()
def build_val_gt_and_pred_rcnn() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Run the model over ``val_loader`` and build solution/submission frames.

    Returns ``(gt_df, sub_df)``. ``gt_df`` has the columns row_id, annotation
    and shape; ``sub_df`` has row_id and annotation. An annotation is either
    the string "authentic" or a ``;``-joined RLE string. Both ground truth and
    predictions are encoded at the resized input resolution, not at the
    original image resolution.
    """
    model.eval()
    rows_gt, rows_sub = [], []
    for imgs, tgts in tqdm(val_loader, desc="RCNN Val"):
        imgs = [im.to(device) for im in imgs]
        outs = model(imgs)

        for im, tgt, out in zip(imgs, tgts, outs):
            H, W = im.shape[-2], im.shape[-1]
            # Ground truth: an image without instance masks counts as authentic.
            masks_gt = tgt["masks"].numpy().astype(np.uint8) if tgt["masks"].numel()>0 else np.zeros((0,H,W), np.uint8)
            if masks_gt.shape[0]==0:
                rows_gt.append({"row_id": f"val_{len(rows_gt)}", "annotation": "authentic", "shape": "authentic"})
            else:
                rows_gt.append({"row_id": f"val_{len(rows_gt)}", "annotation": rle_encode([m for m in masks_gt]), "shape": json.dumps([H,W])})

            if ("scores" not in out) or (len(out["scores"])==0):
                pred = "authentic"
            else:
                boxes = out["boxes"].detach().cpu()
                scores= out["scores"].detach().cpu()
                # First NMS, then the score threshold; `keep2` indexes into
                # the NMS survivors.
                keep = nms(boxes, scores, NMS_THR)
                boxes, scores = boxes[keep], scores[keep]
                keep2 = scores >= SCORE_THR
                boxes, scores = boxes[keep2], scores[keep2]

                insts = []
                if "masks" in out and len(out["masks"])>0 and len(keep)>0:
                    masks = out["masks"].detach().cpu()[keep][:,0]    # [K,H,W]
                    masks = masks[keep2]
                    for m in masks:
                        # Binarize the mask and drop instances that are too small.
                        mb = (m.numpy() >= MASK_THR).astype(np.uint8)
                        if mb.sum() >= MIN_INSTANCE_AREA:
                            insts.append(mb)
                pred = rle_encode(insts) if len(insts)>0 else "authentic"

            rows_sub.append({"row_id": rows_gt[-1]["row_id"], "annotation": pred})

    gt_df = pd.DataFrame(rows_gt)
    sub_df= pd.DataFrame(rows_sub).reindex(range(len(rows_gt)))
    return gt_df, sub_df

# Score the validation predictions with the metric defined above and print it.
gt_df_val, sub_df_val = build_val_gt_and_pred_rcnn()
overall_val = score(gt_df_val.copy(), sub_df_val.copy(), row_id_column_name="row_id")
print("\n====== RCNN Validation (OFFICIAL METRIC) ======")
print(f"Overall: {overall_val:.6f}")
print("===============================================\n")

# =============== Test + submission ===============
@torch.no_grad()
def predict_one_test_rcnn(image_path: str) -> str:
    img = Image.open(image_path).convert("RGB")
    rgb = T.Resize((IMG_SIZE, IMG_SIZE))(img)
    rgb_t = T.ToTensor()(rgb)
    ela = ela_map_gray_uint8(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR), q=ELA_JPEG_Q)
    ela_rs = cv2.resize(ela, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    ela_t = torch.from_numpy(ela_rs).float().unsqueeze(0)
    x4 = torch.cat([rgb_t, ela_t], dim=0).unsqueeze(0).to(device)

    out = model(x4)[0]
    if ("scores" not in out) or (len(out["scores"])==0):
        return "authentic"
    boxes = out["boxes"].detach().cpu()
    scores= out["scores"].detach().cpu()
    keep = nms(boxes, scores, NMS_THR)
    boxes, scores = boxes[keep], scores[keep]
    keep2 = scores >= SCORE_THR
    if keep2.sum()==0: return "authentic"
    masks = out["masks"].detach().cpu()[keep][:,0][keep2]  # [K,H,W]
    insts = []
    for m in masks:
        mb = (m.numpy() >= MASK_THR).astype(np.uint8)
        if mb.sum() >= MIN_INSTANCE_AREA:
            insts.append(mb)
    return rle_encode(insts) if len(insts)>0 else "authentic"

def build_submission_rcnn(test_dir: str, sample_csv_path: str, out_csv_path: str = "submission_rcnn.csv"):
    """Predict every test image and write the submission CSV.

    When *sample_csv_path* exists, its first column supplies the row ids and
    its second column is overwritten with the predictions; ids with no
    matching file in *test_dir* are predicted as "authentic". Otherwise every
    image file in *test_dir* is predicted and a ``case_id``/``annotation``
    table is written. Returns *out_csv_path*.
    """
    exts = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")
    if os.path.exists(sample_csv_path):
        df_sub = pd.read_csv(sample_csv_path)
        id_col, target_col = df_sub.columns[0], df_sub.columns[1]
        test_files = [str(p) for p in Path(test_dir).glob("*")]
        by_stem = {Path(p).stem: p for p in test_files}
        by_name = {Path(p).name: p for p in test_files}
        preds = []
        for rid in tqdm(df_sub[id_col].astype(str), desc="Predicting test (RCNN)"):
            # Match the id against the full file name first, then the stem,
            # then the id with each known image extension appended.
            cand = by_name.get(rid) or by_stem.get(rid)
            if cand is None:
                cand = next((by_name.get(rid + ext) for ext in exts if by_name.get(rid + ext)), None)
            preds.append("authentic" if cand is None else predict_one_test_rcnn(cand))
        out = df_sub.copy()
        out[target_col] = preds
        out.to_csv(out_csv_path, index=False)
    else:
        # Path objects have no .lower(); filter on the lower-cased suffix.
        test_files = sorted(
            str(p) for p in Path(test_dir).glob("*") if p.suffix.lower() in exts
        )
        preds = [predict_one_test_rcnn(p) for p in tqdm(test_files, desc="Predicting test (RCNN)")]
        out = pd.DataFrame({"case_id": [Path(p).stem for p in test_files], "annotation": preds})
        out.to_csv(out_csv_path, index=False)
    print(f"Saved submission to {out_csv_path}")
    return out_csv_path

# Write submission.csv only when the test directory is present.
if os.path.exists(TEST_DIR):
    build_submission_rcnn(TEST_DIR, SAMPLE_SUB, out_csv_path="submission.csv")

# Release cached GPU memory and run a garbage-collection pass.
torch.cuda.empty_cache(); gc.collect()


In [None]:
# ============================================================
# ReCoDAI | Visualization for Mask R-CNN (RGB+ELA, 4ch)
# Shows Authentic | Forged | Predicted (and GT if exists)
# ============================================================
import os, cv2, torch, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
from torchvision.ops import nms

# Defaults for rcnn_predict_masks (visualization only).
VIS_SCORE_THR = 0.05     # instances scoring below this are not returned
VIS_MASK_THR  = 0.20     # per-pixel mask values >= this become 1 in the binary mask
VIS_NMS_THR   = 0.50     # threshold passed to torchvision.ops.nms
TOPK          = 5        # at most this many highest-scoring instances are kept


def make_4ch_tensor(pil_img, img_size=IMG_SIZE, ela_q=ELA_JPEG_Q, device=device):
    """Build the 4-channel (RGB + ELA) input tensor for a single PIL image.

    Args:
        pil_img: Source image; converted with ``COLOR_RGB2BGR`` for the ELA
            step, so it is presumably in RGB mode — confirm at call sites.
        img_size: Side length of the square output.
        ela_q: Quality value forwarded to ``ela_map_gray_uint8``.
        device: Device the resulting tensor is moved to.

    Returns:
        Tensor of shape ``[4, img_size, img_size]`` on ``device``: three RGB
        channels from ``ToTensor`` followed by one ELA channel.
    """
    side = (img_size, img_size)

    # RGB part: resize the PIL image, then ToTensor gives [3,H,W] in [0,1].
    rgb_chan = T.ToTensor()(T.Resize(side)(pil_img))

    # ELA part: computed on the full-resolution image, then shrunk to `side`.
    bgr_full = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    ela_full = ela_map_gray_uint8(bgr_full, q=ela_q)
    ela_small = cv2.resize(ela_full, side, interpolation=cv2.INTER_LINEAR)
    # NOTE(review): the ELA map is cast to float without rescaling; if
    # ela_map_gray_uint8 returns 0-255 values this channel is on a different
    # scale than the RGB channels — confirm it matches training preprocessing.
    ela_chan = torch.from_numpy(ela_small).float().unsqueeze(0)  # [1,H,W]

    stacked = torch.cat([rgb_chan, ela_chan], dim=0)  # [4,H,W]
    return stacked.to(device)

def to_rgb_disp(x4_tensor):
    """Turn a channel-first tensor into a displayable HxWx3 uint8 image.

    Only the first three channels are used; values are multiplied by 255,
    clipped to ``[0, 255]`` and truncated to ``uint8``.
    """
    chw = x4_tensor.detach().float().cpu().numpy()[:3]
    hwc = chw.transpose(1, 2, 0)
    scaled = hwc * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)

@torch.no_grad()
def rcnn_predict_masks(pil_img,
                       score_thr=VIS_SCORE_THR,
                       mask_thr=VIS_MASK_THR,
                       nms_thr=VIS_NMS_THR,
                       topk=TOPK):
    """Run the global ``model`` on one image and return its binary instance masks.

    The detections are passed through NMS, ranked by score and cut to the
    ``topk`` best; instances scoring below ``score_thr`` or whose binarized
    mask covers fewer than ``MIN_INSTANCE_AREA`` pixels are then discarded.

    Args:
        pil_img: Image handed to ``make_4ch_tensor``.
        score_thr: Minimum detection score for an instance to be returned.
        mask_thr: Mask values ``>=`` this become 1 in the binary mask.
        nms_thr: Threshold forwarded to ``torchvision.ops.nms``.
        topk: Maximum number of instances considered after ranking.

    Returns:
        List of ``(mask, score)`` tuples, where ``mask`` is a uint8 array of
        0/1 at the resolution the model outputs and ``score`` is a float.
        Empty when the model returns no scores or no masks.
    """
    assert 'model' in globals(), "Model not found. Make sure 'model' is defined."
    model.eval()
    batch = make_4ch_tensor(pil_img).unsqueeze(0)
    pred = model(batch)[0]

    if "scores" not in pred or len(pred["scores"]) == 0:
        return []

    all_boxes = pred["boxes"].detach().cpu()
    all_scores = pred["scores"].detach().cpu()

    # Suppress overlapping detections, then rank the survivors by score.
    kept = nms(all_boxes, all_scores, nms_thr)
    ranked_scores = all_scores[kept]
    ranking = torch.argsort(ranked_scores, descending=True)
    ranked_scores = ranked_scores[ranking]
    if len(ranked_scores) > topk:
        ranked_scores = ranked_scores[:topk]

    if "masks" not in pred or len(pred["masks"]) == 0:
        return []

    # Same selection as for the scores; [:, 0] drops the singleton channel axis.
    ranked_masks = pred["masks"].detach().cpu()[kept][:, 0][ranking]
    if len(ranked_masks) > topk:
        ranked_masks = ranked_masks[:topk]

    results = []
    for prob_map, score_t in zip(ranked_masks, ranked_scores):
        score = float(score_t.item())
        if score < score_thr:
            continue
        binary = (prob_map.numpy() >= mask_thr).astype(np.uint8)
        if int(binary.sum()) < MIN_INSTANCE_AREA:
            continue
        results.append((binary, score))
    return results

def resize_mask_to(img_size_wh, mask_small):
    """Resize a mask to ``(width, height)`` using nearest-neighbour sampling.

    Args:
        img_size_wh: Target size as a ``(W, H)`` pair.
        mask_small: Mask array; cast to uint8 before resizing.

    Returns:
        The resized uint8 mask.
    """
    width, height = img_size_wh
    mask_u8 = mask_small.astype(np.uint8)
    return cv2.resize(mask_u8, (width, height), interpolation=cv2.INTER_NEAREST)

def overlay_colored(base_img_bgr, masks_scores, alpha=0.45):
    """Blend coloured instance masks, their contours and scores onto an image.

    Args:
        base_img_bgr: Background image in BGR order; it is copied, not modified.
        masks_scores: Iterable of ``(mask, score)`` pairs. Each mask must have
            the same height and width as ``base_img_bgr``.
        alpha: Weight of the mask colour in the blend (``1 - alpha`` goes to
            the underlying pixels).

    Returns:
        A new BGR image with every mask tinted, outlined in white and labelled
        with its score formatted to two decimals.
    """
    # BGR colours, cycled when there are more masks than entries.
    palette = [
        (0, 0, 255),    # red
        (0, 255, 0),    # green
        (255, 0, 0),    # blue
        (0, 255, 255),  # yellow
        (255, 0, 255),  # magenta
        (255, 255, 0),  # cyan
        (0, 128, 255),  # orange
        (255, 0, 128),  # violet
    ]
    white = (255, 255, 255)

    canvas = base_img_bgr.copy()
    for idx, (mask, score) in enumerate(masks_scores):
        tint = np.full_like(canvas, palette[idx % len(palette)], dtype=np.uint8)
        mixed = (alpha * tint + (1 - alpha) * canvas).astype(np.uint8)
        inside = mask[..., None].astype(bool)
        canvas = np.where(inside, mixed, canvas)

        # Contour + score label.
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(canvas, contours, -1, white, 1)
        if len(contours) > 0:
            x, y, w, h = cv2.boundingRect(np.vstack(contours))
            # Keep the label inside the image when the box touches the top edge.
            label_origin = (x, max(12, y - 4))
            cv2.putText(canvas, f"{score:.2f}", label_origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.45, white, 1, cv2.LINE_AA)
    return canvas

def load_gt_mask_if_exists(forged_path, target_size=None):
    """Load the ground-truth mask stored next to a forged image's stem, if any.

    The mask is looked up as ``<MASK_DIR>/<stem of forged_path>.npy``.

    Args:
        forged_path: Path of the forged image; only its stem is used.
        target_size: Optional ``(W, H)`` to resize the mask to, using
            nearest-neighbour sampling.

    Returns:
        A uint8 0/1 mask, or ``None`` when no mask file exists.
    """
    mask_path = os.path.join(MASK_DIR, f"{Path(forged_path).stem}.npy")
    if not os.path.exists(mask_path):
        return None

    gt = np.load(mask_path)
    if gt.ndim == 3:
        # NOTE(review): collapsing over axis 0 presumes an instance-first
        # (N, H, W) layout — confirm against how the masks were saved.
        gt = np.max(gt, axis=0)
    gt = (gt > 0).astype(np.uint8)

    if target_size is None:
        return gt
    width, height = target_size
    return cv2.resize(gt, (width, height), interpolation=cv2.INTER_NEAREST)

def pil_to_bgr(img_pil):
    """Convert a PIL image to a NumPy array in BGR channel order.

    The conversion uses ``COLOR_RGB2BGR``, so the input is presumably an
    RGB-mode image — confirm at call sites.
    """
    rgb_arr = np.array(img_pil)
    return cv2.cvtColor(rgb_arr, cv2.COLOR_RGB2BGR)

def bgr_to_rgb(img_bgr):
    """Return a copy of a BGR image array with channels reordered to RGB."""
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    return img_rgb


@torch.no_grad()
def visualize_pairs_with_masks(df=None, n_samples=8,
                               upscale_to_original=True,
                               show_gt=True,
                               random_state=123,
                               save_path=None):
    """Plot authentic/forged pairs next to the model's predicted masks.

    One figure row is drawn per sampled pair, with the columns
    Authentic | Forged | Predicted and, when ``show_gt`` is true, GT mask.

    Args:
        df: DataFrame with ``auth_path`` and ``forg_path`` columns. When
            ``None``, the global ``val_df`` is used.
        n_samples: Maximum number of rows to sample from the frame.
        upscale_to_original: If true, predicted masks are resized to the
            forged image's original size and overlaid on the full-resolution
            image; otherwise they are overlaid on the image resized to
            ``IMG_SIZE`` x ``IMG_SIZE``.
        show_gt: If true, add a fourth column with the ground-truth mask
            overlay (a black placeholder when no mask file is found).
        random_state: Seed passed to ``DataFrame.sample``.
        save_path: If truthy, the figure is saved there (dpi 150) before
            being shown.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """

    assert 'model' in globals(), "Model not found. Make sure 'model' is defined."
    use_df = df if df is not None else val_df
    samp = use_df.sample(n=min(n_samples, len(use_df)), random_state=random_state).reset_index(drop=True)

    rows = len(samp)
    cols = 4 if show_gt else 3
    plt.figure(figsize=(14, 3.2*rows))

    # Subplot indices are 1-based and advance left-to-right, row by row.
    plot_idx = 1
    for _, row in samp.iterrows():
        ap, fp = row['auth_path'], row['forg_path']
        A = Image.open(ap).convert("RGB")
        F = Image.open(fp).convert("RGB")

        # Resized copies shown in the Authentic/Forged columns: the RGB part
        # of the model input (the original comment calls this the 128x128 version).
        a_disp_small = to_rgb_disp(make_4ch_tensor(A).unsqueeze(0)[0])
        f_disp_small = to_rgb_disp(make_4ch_tensor(F).unsqueeze(0)[0])

        # Predicted (mask, score) pairs for the forged image.
        masks_scores_small = rcnn_predict_masks(F)  

        # Choose the overlay resolution: original image size or IMG_SIZE.
        if upscale_to_original:
            
            W0, H0 = F.size
            masks_scores = [(resize_mask_to((W0, H0), m), sc) for (m, sc) in masks_scores_small]
            base_bgr = pil_to_bgr(F)
        else:
            masks_scores = masks_scores_small
            base_bgr = cv2.resize(pil_to_bgr(F), (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)

        overlay_pred_bgr = overlay_colored(base_bgr, masks_scores) if len(masks_scores) else base_bgr.copy()

        # Ground-truth overlay (green tint + white contour), built in RGB order.
        overlay_gt_rgb = None
        if show_gt:
            if upscale_to_original:
                gt = load_gt_mask_if_exists(fp, target_size=F.size)
                base_for_gt = np.array(F)  # RGB
            else:
                gt = load_gt_mask_if_exists(fp, target_size=(IMG_SIZE, IMG_SIZE))
                base_for_gt = cv2.resize(np.array(F), (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)

            if gt is not None:
                green = np.zeros_like(base_for_gt)
                green[...,1] = 255
                overlay_gt = np.where(gt[...,None].astype(bool),
                                      (0.45*green + 0.55*base_for_gt).astype(np.uint8),
                                      base_for_gt)
                cnts, _ = cv2.findContours(gt.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(overlay_gt, cnts, -1, (255,255,255), 1)
                overlay_gt_rgb = overlay_gt

        # Draw this row's panels.
        ax = plt.subplot(rows, cols, plot_idx); plot_idx += 1
        ax.imshow(a_disp_small); ax.set_title("Authentic"); ax.axis('off')

        ax = plt.subplot(rows, cols, plot_idx); plot_idx += 1
        ax.imshow(f_disp_small); ax.set_title("Forged"); ax.axis('off')

        ax = plt.subplot(rows, cols, plot_idx); plot_idx += 1
        ax.imshow(bgr_to_rgb(overlay_pred_bgr))
        ttl = "Predicted"
        if len(masks_scores):
            ttl += f" (n={len(masks_scores)})"
        ax.set_title(ttl); ax.axis('off')

        if show_gt:
            ax = plt.subplot(rows, cols, plot_idx); plot_idx += 1
            if overlay_gt_rgb is not None:
                ax.imshow(overlay_gt_rgb)
                ax.set_title("GT mask")
            else:
                # No mask file for this image: show a black placeholder panel.
                ax.imshow(np.zeros_like(a_disp_small))
                ax.set_title("GT mask (N/A)")
            ax.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved figure to: {save_path}")
    plt.show()


# Render the comparison figure for 8 pairs sampled from val_df; every option
# is spelled out so it can be tweaked in place.
visualize_pairs_with_masks(df=val_df,
                           n_samples=8,
                           upscale_to_original=True,
                           show_gt=True,
                           random_state=123,
                           save_path=None)            
