# In [None]:  (notebook cell marker left over from the export)
# ==================================================================================
# TITAN-FAGT: LOCAL END-TO-END VALIDATOR (Mac Edition)
# ==================================================================================
# 1. GOAL: Calculate F1, Precision, and Recall on unseen LOCAL training data.
# 2. DATA: Matches 'train_images' with corresponding 'train_masks' (.npy).
# 3. FIX: Restored missing BASE_PATH and path recursion logic.
# ==================================================================================

import os
import cv2
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Offline TIMM Check
try:
    import timm
except ImportError:
    import sys
    sys.path.append('/kaggle/input/timm-pytorch-image-models/pytorch-image-models-master')
    import timm

class CFG:
    """Static configuration for the local validation run.

    Every value is a class attribute and is read as ``CFG.<name>``.
    """

    # --- PATHS (Adjusted for your local folder tree) ---
    # Dataset root; the image and mask folders are resolved relative to it.
    BASE_PATH = './recodai-luc-scientific-image-forgery-detection'
    IMG_DIR = os.path.join(BASE_PATH, 'train_images')
    MASK_DIR = os.path.join(BASE_PATH, 'train_masks')
    # Checkpoint file that is loaded into the model before evaluation.
    WEIGHTS_PATH = './TITAN_V2_UNLEASHED.pth'

    # Side length (in pixels) of the square that inputs are resized to.
    img_size = 384
    # Use the MPS backend when it reports itself available, otherwise the CPU.
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    # Probability cut-off above which a predicted pixel counts as positive.
    THRESHOLD = 0.5

# --- 2. ARCHITECTURE ---
class FrequencyBlock(nn.Module):
    """High-pass frequency branch: FFT -> drop low frequencies -> conv.

    The input is transformed with a 2-D FFT over its last two axes, every
    frequency bin within ``cutoff_radius`` of the centred DC component is
    zeroed, the spectrum is transformed back, and the magnitude image is fed
    to a 3x3 convolution that produces 32 feature maps.

    Args:
        cutoff_radius: Radius (in frequency bins) of the low-frequency disc
            that is removed. Defaults to 15, the previously hard-coded value.

    Raises:
        ValueError: If ``cutoff_radius`` is negative.
    """

    def __init__(self, cutoff_radius=15):
        super().__init__()
        if cutoff_radius < 0:
            raise ValueError(f"cutoff_radius must be non-negative, got {cutoff_radius}")
        # Plain attribute (not a parameter/buffer), so checkpoint keys are unchanged.
        self.cutoff_radius = cutoff_radius
        self.conv = nn.Conv2d(3, 32, kernel_size=3, padding=1)

    def forward(self, x):
        """Return convolutional features of the high-pass filtered input.

        Args:
            x: Tensor whose last two axes are spatial and whose channel axis
                has 3 entries (required by the convolution), e.g. (B, 3, H, W).

        Returns:
            Tensor with 32 channels and the same spatial size as ``x``.
        """
        height, width = x.shape[-2:]
        spatial = (-2, -1)

        # Shift only the spatial axes. Without ``dim`` fftshift also rolls the
        # batch/channel axes, which was harmless only because ifftshift undid it.
        spectrum = torch.fft.fftshift(torch.fft.fft2(x, dim=spatial), dim=spatial)

        # Build the distance grid directly on the input's device instead of
        # creating it on the CPU and copying it over on every call.
        rows = torch.arange(height, device=x.device) - height // 2
        cols = torch.arange(width, device=x.device) - width // 2
        dist_sq = rows.unsqueeze(1) ** 2 + cols.unsqueeze(0) ** 2

        # Comparing squared distances avoids the sqrt; (H, W) broadcasts over
        # the leading batch/channel axes.
        high_pass = (dist_sq > self.cutoff_radius ** 2).float()

        filtered = torch.fft.ifft2(
            torch.fft.ifftshift(spectrum * high_pass, dim=spatial), dim=spatial
        )
        return self.conv(torch.abs(filtered))

class GraphModule(nn.Module):
    """Attention-weighted mixing of spatial positions with a residual add.

    Each spatial location of the feature map is treated as a node carrying a
    ``dim``-entry feature vector.

    Args:
        dim: Channel count of the incoming feature map.
    """

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.gcn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        """Return ``x`` plus the node update, in the same (B, C, H, W) layout."""
        batch, channels, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C): one row per spatial location.
        nodes = x.flatten(2).permute(0, 2, 1)

        # Similarity between projected nodes, scaled by sqrt(C) and
        # normalised over each row.
        projected = self.proj(nodes)
        scores = torch.matmul(projected, projected.transpose(-2, -1)) / (channels ** 0.5)
        weights = F.softmax(scores, dim=-1)

        # Aggregate the un-projected nodes with those weights, then transform.
        update = self.gcn(torch.matmul(weights, nodes))

        # Back to (B, C, H, W) and add onto the input.
        return x + update.transpose(1, 2).reshape(batch, channels, height, width)

class FAGT_Model(nn.Module):
    """Segmentation network combining a timm encoder with two side branches.

    The deepest encoder feature map goes through ``GraphModule``; the raw
    input goes through ``FrequencyBlock`` (32 channels). Both are concatenated,
    fused by a 1x1 convolution and decoded to a single-channel map that is
    resized to the input's spatial size. No activation is applied here, so the
    output is a logit map.

    The encoder name is ``CFG.model_name`` when that attribute exists and
    ``'swin_base_patch4_window12_384'`` otherwise.
    """

    def __init__(self):
        super().__init__()
        backbone_name = getattr(CFG, 'model_name', 'swin_base_patch4_window12_384')
        self.encoder = timm.create_model(backbone_name, pretrained=False, features_only=True)
        # Channel count of every feature level the encoder returns.
        self.dims = self.encoder.feature_info.channels()
        last_dim = self.dims[-1]

        self.physics = FrequencyBlock()
        self.graph = GraphModule(last_dim)
        # 1x1 conv over [graph features | 32 frequency features].
        self.fusion = nn.Conv2d(last_dim + 32, 256, 1)

        # Three conv stages with 4x, 4x and 2x upsampling (32x in total),
        # followed by a 1x1 conv down to one channel.
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(32, 1, 1),
        )

    def forward(self, x):
        """Return a single-channel map with the same height/width as ``x``."""
        deep_feats = self.encoder(x)[-1]

        # If the last axis equals the expected channel count the map is moved
        # to channels-first; presumably the encoder emitted (B, H, W, C).
        looks_channels_last = deep_feats.ndim == 4 and deep_feats.shape[-1] == self.dims[-1]
        if looks_channels_last:
            deep_feats = deep_feats.permute(0, 3, 1, 2)

        phys_feats = self.physics(x)
        graph_feats = self.graph(deep_feats)

        # Bring the full-resolution frequency features down to the encoder grid.
        phys_resized = F.interpolate(
            phys_feats, size=graph_feats.shape[-2:], mode='bilinear', align_corners=False
        )
        fused = self.fusion(torch.cat([graph_feats, phys_resized], dim=1))
        decoded = self.decoder(fused)

        # Final resize guarantees the output matches the input resolution.
        return F.interpolate(decoded, size=x.shape[-2:], mode='bilinear', align_corners=False)

# --- 3. METRIC HELPERS ---
def calculate_pixel_stats(pred_mask, gt_mask, threshold=None):
    """Count pixel-level true positives, false positives and false negatives.

    Args:
        pred_mask: Array of predicted scores; a pixel is positive when its
            score is strictly greater than ``threshold``.
        gt_mask: Ground-truth array of the same shape; a pixel is positive
            when its value is strictly greater than 0.
        threshold: Score cut-off. When ``None`` (the default) the global
            ``CFG.THRESHOLD`` is used.

    Returns:
        Tuple ``(tp, fp, fn)`` of Python ints.

    Raises:
        ValueError: If the two arrays do not have the same shape. Without
            this check NumPy could silently broadcast them and report
            meaningless counts.
    """
    if threshold is None:
        threshold = CFG.THRESHOLD

    pred_mask = np.asarray(pred_mask)
    gt_mask = np.asarray(gt_mask)
    if pred_mask.shape != gt_mask.shape:
        raise ValueError(
            f"shape mismatch: pred_mask {pred_mask.shape} vs gt_mask {gt_mask.shape}"
        )

    pred = pred_mask > threshold
    gt = gt_mask > 0
    tp = int(np.count_nonzero(pred & gt))
    fp = int(np.count_nonzero(pred & ~gt))
    fn = int(np.count_nonzero(~pred & gt))
    return tp, fp, fn

# --- 4. EXECUTION ---
# File extensions (lower-case) that are treated as input images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')


def _find_mask_path(image_path):
    """Return ``(case_id, mask_path)``; ``mask_path`` is None when no mask exists."""
    # splitext keeps dots inside the stem ("a.b.png" -> "a.b"), unlike split('.')[0].
    case_id = os.path.splitext(os.path.basename(image_path))[0]
    for candidate in (f"{case_id}.npy", f"{case_id}_mask.npy"):
        mask_path = os.path.join(CFG.MASK_DIR, candidate)
        if os.path.exists(mask_path):
            return case_id, mask_path
    return case_id, None


def _load_binary_mask(mask_path, height, width):
    """Load a ``.npy`` mask and return it as a float32 0/1 array of (height, width)."""
    binary = np.load(mask_path) > 0
    if binary.ndim == 3:
        # A stack of masks may be stored as (H, W, N) or (N, H, W). Collapse the
        # axis whose removal leaves the image's shape; when neither layout
        # matches, fall back to the last axis (the previous behaviour).
        if binary.shape[:2] == (height, width):
            stack_axis = 2
        elif binary.shape[1:] == (height, width):
            stack_axis = 0
        else:
            stack_axis = 2
        # ``any`` equals ``max(...) > 0`` but also copes with an empty stack.
        binary = binary.any(axis=stack_axis)
    return cv2.resize(binary.astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)


def _predict_probabilities(model, transform, image):
    """Run the model on one RGB image and return the sigmoid map as a NumPy array."""
    img_t = transform(image=image)['image'].unsqueeze(0).to(CFG.device)
    with torch.no_grad():
        return torch.sigmoid(model(img_t)).cpu().numpy().squeeze()


def run_local_test():
    """Evaluate the checkpoint on every image in ``CFG.IMG_DIR`` that has a mask.

    For each image (searched recursively) the matching ``<name>.npy`` or
    ``<name>_mask.npy`` in ``CFG.MASK_DIR`` is loaded, the model prediction is
    resized to the image's resolution, and pixel-level TP/FP/FN are
    accumulated. Per-image F1 is printed as it is computed; precision, recall
    and F1 over all accumulated pixels are printed at the end and the
    per-image rows are written to ``local_test_results.csv``.

    Images without a mask and images OpenCV cannot decode are skipped and
    counted. The function returns None in every case; problems such as a
    missing image folder or checkpoint are reported via ``print``.
    """
    print(f">>> RUNNING VALIDATOR ON: {CFG.device}")

    # Correct recursion to find images in subfolders; sorted so that the run
    # (and the CSV row order) is reproducible.
    pattern = os.path.join(CFG.IMG_DIR, '**', '*.*')
    img_paths = sorted(
        p for p in glob.glob(pattern, recursive=True)
        if p.lower().endswith(_IMAGE_EXTENSIONS)
    )
    if not img_paths:
        print(f"!!! ERROR: No images found in {CFG.IMG_DIR}")
        return

    # Check for the checkpoint before paying for model construction.
    if not os.path.exists(CFG.WEIGHTS_PATH):
        print(f"!!! CRITICAL: Weights missing at {CFG.WEIGHTS_PATH}")
        return
    model = FAGT_Model()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(CFG.WEIGHTS_PATH, map_location=CFG.device))
    print(f">>> Weights Loaded: {CFG.WEIGHTS_PATH}")
    model.to(CFG.device).eval()

    transform = A.Compose([A.Resize(CFG.img_size, CFG.img_size), A.Normalize(), ToTensorV2()])

    total_tp, total_fp, total_fn = 0, 0, 0
    results = []
    skipped_no_mask = 0
    skipped_unreadable = 0

    for path in tqdm(img_paths):
        base, mask_path = _find_mask_path(path)
        if mask_path is None:
            skipped_no_mask += 1
            continue

        # 1. Load Data
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"  [!] {base}: could not read image, skipping")
            skipped_unreadable += 1
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]
        gt_mask = _load_binary_mask(mask_path, h, w)

        # 2. Inference
        pred = _predict_probabilities(model, transform, image)
        pred_full = cv2.resize(pred, (w, h))

        # 3. Calculate Stats
        tp, fp, fn = calculate_pixel_stats(pred_full, gt_mask)
        total_tp += tp
        total_fp += fp
        total_fn += fn

        # Live feedback like Kaggle script
        img_f1 = (2 * tp) / (2 * tp + fp + fn + 1e-7)
        max_prob = float(np.max(pred))
        print(f"  [+] {base}: Pixel F1 {img_f1:.4f} | Max Prob: {max_prob:.4f}")
        results.append({"case_id": base, "f1": img_f1, "max_prob": max_prob})

    if skipped_no_mask or skipped_unreadable:
        print(f">>> Skipped {skipped_no_mask} image(s) without a mask, "
              f"{skipped_unreadable} unreadable image(s).")

    if not results:
        print(">>> No matching images and masks found.")
        return

    # Final Summary (the small epsilon guards against division by zero).
    precision = total_tp / (total_tp + total_fp + 1e-7)
    recall = total_tp / (total_tp + total_fn + 1e-7)
    final_f1 = (2 * precision * recall) / (precision + recall + 1e-7)

    print("\n" + "=" * 40)
    print("FINAL LOCAL PERFORMANCE (Unseen Data):")
    print(f"PRECISION: {precision:.4f}")
    print(f"RECALL:    {recall:.4f}")
    print(f"F1-SCORE:  {final_f1:.4f}")
    print("=" * 40)
    pd.DataFrame(results).to_csv("local_test_results.csv", index=False)

# Script entry point: run the validation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_local_test()