In [6]:
# ==================================================================================
# TITAN-FAGT V3.1: LOCAL HYSTERESIS AUDIT
# ==================================================================================
# 1. GOAL: Test the "Grandmaster" Hysteresis logic on local unseen data.
# 2. LOGIC: High Thresh 0.85 (Seed) + Low Thresh 0.45 (Growth).
# 3. METRIC: Computes pixel-level F1, precision and recall on the validation fold to check whether this strategy works.
# ==================================================================================

import matplotlib
matplotlib.use('Agg') # Safe Mode

import os
import cv2
import sys
import glob
import time
import gc
import io
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score, precision_score, recall_score
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from IPython.display import display, clear_output, Image
from tqdm.notebook import tqdm

# Install timm
try:
    import timm
except ImportError:
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "timm"])
    import timm

# --- 1. CONFIGURATION ---
class CFG:
    """Static settings for the local hysteresis audit (paths, model, thresholds).

    All values are read as class attributes (``CFG.X``).
    """
    BASE_DIR = '/Users/chanduchitikam/recodai/recodai-luc-scientific-image-forgery-detection'
    TRAIN_IMG_PATH = os.path.join(BASE_DIR, 'train_images')
    TRAIN_MASK_PATH = os.path.join(BASE_DIR, 'train_masks')

    # YOUR CHAMPION WEIGHTS (a state_dict file loaded into FAGT_Model)
    WEIGHTS_PATH = "TITAN_V2_UNLEASHED.pth"

    model_name = 'swin_base_patch4_window12_384'  # timm encoder architecture
    img_size = 384  # images and masks are resized to img_size x img_size
    batch_size = 1
    # Prefer Apple MPS (the original target), then CUDA, then CPU. A hard-coded
    # torch.device("mps") makes `model.to(CFG.device)` fail on machines without MPS.
    device = torch.device(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    )

    SEED = 42  # random_state for the KFold split
    NUM_VISUALS = 5  # number of validation images to plot

    # --- HYSTERESIS SETTINGS TO TEST ---
    HIGH_THRESH = 0.85  # seed threshold: a kept region must contain a pixel >= this
    LOW_THRESH = 0.45  # growth threshold: regions are connected pixels >= this
    MIN_PIXELS = 50  # minimum total count of seed pixels before anything is predicted

# --- 2. ARCHITECTURE ---
class FrequencyBlock(nn.Module):
    """High-pass frequency branch.

    Zeroes every spectral coefficient within radius 15 of the (shifted) spectrum
    centre, inverts the FFT, takes the magnitude, and projects the 3 input
    channels to 32 with a 3x3 convolution. Expects a 4-D (B, 3, H, W) input.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3, padding=1)

    def forward(self, x):
        _, _, height, width = x.shape
        centre_row, centre_col = height // 2, width // 2

        rows = torch.arange(height).to(x.device)
        cols = torch.arange(width).to(x.device)
        row_grid, col_grid = torch.meshgrid(rows, cols, indexing='ij')
        radius = torch.sqrt((row_grid - centre_row) ** 2 + (col_grid - centre_col) ** 2)
        high_pass = (radius > 15).float()[None, None]  # (1, 1, H, W), broadcast over B and C

        # Only the spatial axes need shifting; the mask is constant along B and C.
        spectrum = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
        filtered = torch.fft.ifft2(torch.fft.ifftshift(spectrum * high_pass, dim=(-2, -1)))
        return self.conv(filtered.abs())

class GraphModule(nn.Module):
    """Residual attention over spatial positions ("graph" message passing).

    Each spatial location is a node. Node affinities are the softmax of scaled
    dot products of a shared linear projection. Aggregated node features pass
    through a small MLP and are added back onto the input (B, C, H, W) map.
    """

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.gcn = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (B, H*W, C)

        queries = self.proj(tokens)
        scores = queries @ queries.transpose(-2, -1) / channels ** 0.5
        weights = scores.softmax(dim=-1)

        messages = self.gcn(weights @ tokens)
        residual = messages.transpose(1, 2).reshape(batch, channels, height, width)
        return x + residual

class FAGT_Model(nn.Module):
    """timm encoder + frequency (high-pass) branch + graph-attention head, decoded to one logit map.

    forward(x) returns a single-channel map of raw logits (the audit applies
    ``torch.sigmoid`` itself) resized to the spatial size of ``x``.

    NOTE: attribute names and the layer order inside ``self.decoder`` determine the
    state_dict keys. The checkpoint is loaded with ``load_state_dict`` (strict by
    default), so these must not be renamed or reordered.
    """
    def __init__(self):
        super().__init__()
        # pretrained=False: every weight is expected to come from the loaded checkpoint.
        self.encoder = timm.create_model(CFG.model_name, pretrained=False, features_only=True)
        # Channel count of each encoder feature stage; only the deepest one is used below.
        self.dims = self.encoder.feature_info.channels()
        last_dim = self.dims[-1]
        self.physics = FrequencyBlock(); self.graph = GraphModule(last_dim)
        # 1x1 conv fusing the deepest graph features with the 32-channel frequency features.
        self.fusion = nn.Conv2d(last_dim + 32, 256, 1)
        # Three upsampling stages (x4, x4, x2 = x32 total) while reducing 256 -> 1 channels;
        # the last layer has no activation, so the output is logits.
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(32, 1, 1))
    def forward(self, x):
        enc_feats = self.encoder(x); deep_feats = enc_feats[-1]
        # If the last axis equals the channel count, treat the map as channels-last
        # (B, H, W, C) and permute to (B, C, H, W). Presumably needed because timm's
        # Swin feature maps are NHWC -- TODO confirm for the installed timm version.
        if deep_feats.ndim == 4 and deep_feats.shape[-1] == self.dims[-1]: deep_feats = deep_feats.permute(0, 3, 1, 2)
        phys_feats = self.physics(x); graph_feats = self.graph(deep_feats)
        # Bring the full-resolution frequency features down to the deep feature grid before concatenation.
        phys_resized = F.interpolate(phys_feats, size=graph_feats.shape[-2:], mode='bilinear', align_corners=False)
        fused = self.fusion(torch.cat([graph_feats, phys_resized], dim=1))
        # Final resize guarantees the logit map matches the input H x W exactly.
        return F.interpolate(self.decoder(fused), size=x.shape[-2:], mode='bilinear', align_corners=False)

# --- 3. HYSTERESIS LOGIC ---
def apply_hysteresis(prob_map, high, low, min_size):
    """Binarise a probability map with two-threshold (hysteresis) region growing.

    Pixels with ``prob >= low`` are grouped into 8-connected components. A
    component is kept only if it contains at least one "strong" pixel
    (``prob >= high``). If the whole map has fewer than ``min_size`` strong
    pixels, nothing is seeded and an all-zero mask is returned.

    Args:
        prob_map: 2-D array of per-pixel probabilities.
        high: Seed threshold.
        low: Growth threshold.
        min_size: Minimum total number of strong pixels required to emit any region.

    Returns:
        ``uint8`` array with the shape of ``prob_map`` and values in {0, 1}.
    """
    h, w = prob_map.shape
    strong = prob_map >= high

    # Fast fail: too few strong seeds anywhere -> all black.
    if strong.sum() < min_size:
        return np.zeros((h, w), dtype=np.uint8)

    weak_mask = (prob_map >= low).astype(np.uint8)
    _, labels = cv2.connectedComponents(weak_mask, connectivity=8)

    # Labels of weak components that contain a strong seed. Label 0 is background
    # and is never kept. This is one pass over the image instead of building a
    # full-size mask per component.
    seeded = np.unique(labels[strong])
    seeded = seeded[seeded != 0]
    if seeded.size == 0:
        return np.zeros((h, w), dtype=np.uint8)
    return np.isin(labels, seeded).astype(np.uint8)

# --- 4. DATA SETUP ---
def get_validation_data():
    """Rebuild the fold-0 validation split.

    Authentic images get label 0 and no mask. Forged images get label 1 and a
    mask path, looked up as ``<stem>.npy`` and then ``<stem>_mask.npy`` in
    ``CFG.TRAIN_MASK_PATH``. The mask path is ``None`` when neither file exists,
    and a warning is printed. The frame (authentic first, then forged, each
    sorted by path) is split with a shuffled 5-fold ``KFold`` seeded by
    ``CFG.SEED``. The validation rows of the first fold are returned.

    Returns:
        DataFrame copy with columns ``image_path``, ``mask_path``, ``label``.
    """
    print(">>> Replicating Fold 0 Split...")
    auth = sorted(glob.glob(f'{CFG.TRAIN_IMG_PATH}/authentic/*.*'))
    forg = sorted(glob.glob(f'{CFG.TRAIN_IMG_PATH}/forged/*.*'))
    data = [{'image_path': f, 'mask_path': None, 'label': 0} for f in auth]

    missing = []
    for f in forg:
        # Stem = text before the first '.', kept identical to the original lookup rule.
        base = os.path.basename(f).split('.')[0]
        candidates = (f'{CFG.TRAIN_MASK_PATH}/{base}.npy', f'{CFG.TRAIN_MASK_PATH}/{base}_mask.npy')
        mp = next((c for c in candidates if os.path.exists(c)), None)
        if mp is None:
            # Previously a nonexistent path was stored and the load failure was
            # silently swallowed later; make the gap explicit instead.
            missing.append(f)
        data.append({'image_path': f, 'mask_path': mp, 'label': 1})
    if missing:
        print(f"!!! WARNING: {len(missing)} forged image(s) have no mask file; "
              f"their ground truth will be empty (e.g. {os.path.basename(missing[0])}).")

    df = pd.DataFrame(data)
    # KFold partitions by row position, so row order and count must stay as built
    # above (no rows dropped) for the split to match the training run.
    kf = KFold(n_splits=5, shuffle=True, random_state=CFG.SEED)
    _, val_idx = next(kf.split(df))
    return df.iloc[val_idx].copy()

def visualize_result(img_tensor, mask_tensor, pred_mask, fname):
    """Show source image, ground-truth mask and predicted mask side by side.

    The module forces the non-interactive 'Agg' backend. Under it,
    ``display(fig)`` only emits the figure's text repr
    (``<Figure size 1500x500 with 3 Axes>``), as the recorded output of this
    cell shows. The figure is therefore rendered to PNG in memory and displayed
    as an image.

    Args:
        img_tensor: Normalised (3, H, W) image tensor; undone with the mean/std below.
        mask_tensor: Ground-truth mask tensor (squeezed to 2-D for plotting).
        pred_mask: 2-D predicted binary mask (numpy).
        fname: Name shown in the first panel's title.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_np = np.clip(img_tensor.detach().cpu().permute(1, 2, 0).numpy() * std + mean, 0, 1)
    mask_np = mask_tensor.detach().cpu().squeeze().numpy()

    fig = plt.figure(figsize=(15, 5), facecolor='#0f0f0f')
    gs = gridspec.GridSpec(1, 3)
    panels = (
        (img_np, {}, f"SOURCE: {fname}"),
        (mask_np, {'cmap': 'gray'}, "TRUTH"),
        (pred_mask, {'cmap': 'jet', 'vmin': 0, 'vmax': 1}, "HYSTERESIS PRED"),
    )
    for col, (data, imshow_kwargs, title) in enumerate(panels):
        ax = fig.add_subplot(gs[0, col])
        ax.imshow(data, **imshow_kwargs)
        ax.axis('off')
        ax.set_title(title, color='white')
    fig.tight_layout()

    buf = io.BytesIO()
    try:
        # Keep the dark background; savefig would otherwise use its default facecolor.
        fig.savefig(buf, format='png', facecolor=fig.get_facecolor())
    finally:
        plt.close(fig)
    display(Image(data=buf.getvalue()))
    gc.collect()

# --- 5. EXECUTION ---
class _AuditDataset(torch.utils.data.Dataset):
    """Validation dataset yielding ``(image, mask, image_path)`` triples.

    ``image`` is resized to ``CFG.img_size`` and normalised with ``A.Normalize()``
    defaults. ``mask`` is a ``1 x H x W`` tensor of 0/1 values. It is all zeros
    for authentic images, or when the mask is missing or unreadable (a warning
    is printed in that case).
    """

    def __init__(self, df):
        self.df = df
        self.tf = A.Compose([A.Resize(CFG.img_size, CFG.img_size), A.Normalize(), ToTensorV2()])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = cv2.imread(row['image_path'])
        if img is None:
            print(f"!!! WARNING: unreadable image {row['image_path']}; using a black placeholder.")
            img = np.zeros((CFG.img_size, CFG.img_size, 3), dtype=np.uint8)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = np.zeros(img.shape[:2], dtype=np.float32)
        if row['label'] == 1 and row['mask_path']:
            # Best effort: a bad mask must not abort the audit, but it must not be silent either.
            try:
                m = np.load(row['mask_path'])
                if m.ndim == 3:
                    m = m.max(axis=2)  # collapse per-instance channels into one mask
                mask = cv2.resize((m > 0).astype(np.float32), (img.shape[1], img.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)
            except Exception as err:
                print(f"!!! WARNING: could not load mask {row['mask_path']}: {err!r}")

        aug = self.tf(image=img, mask=mask)
        return aug['image'], aug['mask'].unsqueeze(0), row['image_path']


def _predict_tta(model, img):
    """Return the H x W probability map for ``img[0]``, averaged with its horizontal flip."""
    p_base = torch.sigmoid(model(img))[0, 0]
    # Predict on the width-flipped image and flip the prediction back into place.
    p_flip = torch.flip(torch.sigmoid(model(torch.flip(img, [3]))), [3])[0, 0]
    return ((p_base + p_flip) / 2.0).cpu().numpy()


def _pixel_confusion(pred, gt):
    """Return ``(tp, fp, fn)`` pixel counts for two same-shape masks (nonzero = positive)."""
    pred = np.asarray(pred) != 0
    gt = np.asarray(gt) != 0
    tp = int(np.count_nonzero(pred & gt))
    fp = int(np.count_nonzero(pred & ~gt))
    fn = int(np.count_nonzero(gt & ~pred))
    return tp, fp, fn


def _precision_recall_f1(tp, fp, fn, zero_division=1.0):
    """Binary precision, recall and F1 from pooled pixel counts.

    Follows sklearn's ``zero_division`` convention: a ratio whose denominator is
    zero evaluates to ``zero_division``. F1 is ``2tp / (2tp + fp + fn)``.
    """
    prec = tp / (tp + fp) if (tp + fp) else zero_division
    rec = tp / (tp + fn) if (tp + fn) else zero_division
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom else zero_division
    return prec, rec, f1


def run_local_audit():
    """Evaluate hysteresis post-processing on the fold-0 validation split.

    Steps:
      1. Load the checkpoint.
      2. Predict each image with flip TTA.
      3. Binarise with ``apply_hysteresis``.
      4. Print pixel-level F1 / precision / recall, pooled over all validation
         pixels, and plot ``CFG.NUM_VISUALS`` images (the choice is seeded by
         ``CFG.SEED``).

    Pixel counts are accumulated per image. The alternative, concatenating every
    pixel (about 151M for 1026 images at 384x384) for sklearn, gives the same
    scores with ``zero_division=1``.

    Returns:
        ``{'f1', 'precision', 'recall'}`` dict, or ``None`` if the weights file is missing.
    """
    val_df = get_validation_data()
    print(f">>> Loading Weights: {CFG.WEIGHTS_PATH}")
    if not os.path.exists(CFG.WEIGHTS_PATH):
        print("!!! CRITICAL: Weights missing.")
        return None
    model = FAGT_Model().to(CFG.device)
    model.load_state_dict(torch.load(CFG.WEIGHTS_PATH, map_location=CFG.device))
    model.eval()

    # batch_size=1 is required: _predict_tta reads batch item 0 only.
    dl = torch.utils.data.DataLoader(_AuditDataset(val_df), batch_size=1, shuffle=False)
    print(f">>> Testing Hysteresis (High: {CFG.HIGH_THRESH}, Low: {CFG.LOW_THRESH}) on {len(dl)} images...")

    rng = random.Random(CFG.SEED)  # same images visualised on every run
    visual_indices = set(rng.sample(range(len(dl)), min(len(dl), CFG.NUM_VISUALS)))

    tp = fp = fn = 0
    with torch.no_grad():
        for i, (img, msk, path) in enumerate(tqdm(dl)):
            img = img.to(CFG.device)
            prob_map = _predict_tta(model, img)

            # --- APPLY HYSTERESIS ---
            final_mask = apply_hysteresis(prob_map, CFG.HIGH_THRESH, CFG.LOW_THRESH, CFG.MIN_PIXELS)

            d_tp, d_fp, d_fn = _pixel_confusion(final_mask, msk[0, 0].numpy().astype(np.uint8))
            tp += d_tp
            fp += d_fp
            fn += d_fn

            if i in visual_indices:
                visualize_result(img[0], msk[0], final_mask, os.path.basename(path[0]))

    prec, rec, f1 = _precision_recall_f1(tp, fp, fn, zero_division=1)

    print("\n" + "="*40)
    print(" HYSTERESIS AUDIT RESULTS")
    print("="*40)
    print(f" F1 SCORE:  {f1:.4f}")
    print(f" PRECISION: {prec:.4f}")
    print(f" RECALL:    {rec:.4f}")
    print("="*40)
    return {'f1': f1, 'precision': prec, 'recall': rec}

if __name__ == "__main__":
    # Runs both as a script and as a notebook cell (where __name__ is "__main__").
    run_local_audit()

>>> Replicating Fold 0 Split...
>>> Loading Weights: TITAN_V2_UNLEASHED.pth
>>> Testing Hysteresis (High: 0.85, Low: 0.45) on 1026 images...


  0%|          | 0/1026 [00:00<?, ?it/s]

<Figure size 1500x500 with 3 Axes>

<Figure size 1500x500 with 3 Axes>

<Figure size 1500x500 with 3 Axes>

<Figure size 1500x500 with 3 Axes>

<Figure size 1500x500 with 3 Axes>


 HYSTERESIS AUDIT RESULTS
 F1 SCORE:  0.4202
 PRECISION: 0.7207
 RECALL:    0.2965
