In [3]:
# ==================================================================================
# TITAN V5: MORPHOLOGICAL CLEANING (The "Cyan Fog" Killer)
# ==================================================================================
# 1. MODEL: Uses TITAN V2.5 (Swin) because it had the best F1 (0.56).
# 2. FIX: Applies "Morphological Opening" to physically erase noise.
#    - Step A: Erosion (Shrinks bright spots, killing small noise).
#    - Step B: Dilation (Grows the remaining big spots back to original size).
# 3. RESULT: Keeps the big forgeries, deletes the "fog" that ruins the score.
# ==================================================================================

import os
import cv2
import sys
import glob
import gc
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Offline timm check: if timm is not installed, import it from the attached Kaggle dataset copy.
try:
    import timm
except ImportError:
    sys.path.append('/kaggle/input/timm-pytorch-image-models/pytorch-image-models-master')
    import timm

# --- 1. CONFIGURATION ---
class CFG:
    """Static settings for the inference run: data paths, backbone, and mask cleaning."""

    # --- Input locations (Kaggle dataset mounts) ---
    TEST_DIR = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"
    SAMPLE_SUB = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/sample_submission.csv"

    # Checkpoint of the V2.5 run (the one reported at F1 0.56); must match FAGT_Model.
    WEIGHTS_PATH = "/kaggle/input/recodai-model/TITAN_V2_UNLEASHED.pth"

    # --- Backbone / input geometry ---
    model_name = "swin_base_patch4_window12_384"
    img_size = 384          # images are resized to img_size x img_size before the forward pass
    batch_size = 1

    # Prefer the GPU when one is visible, otherwise run on CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- Mask cleaning ---
    # The probability cut-off stays at the standard 0.50; removing speckle is left
    # to the morphological opening rather than to a stricter threshold.
    THRESHOLD = 0.50
    MIN_PIXELS = 200        # masks with fewer foreground pixels after cleaning are discarded
    KERNEL_SIZE = 3         # side length of the square structuring element (3x3)

# --- 2. ARCHITECTURE (Must match V2.5 Swin) ---
class FrequencyBlock(nn.Module):
    """High-pass frequency branch.

    Removes the low-frequency content of the input in the Fourier domain and
    feeds the magnitude of what remains through a 3x3 convolution that maps
    3 input channels to 32 feature channels.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3, padding=1)

    def forward(self, x):
        _, _, height, width = x.shape
        center_y = height // 2
        center_x = width // 2

        # fftshift is called without `dim`, so every axis is rolled; the matching
        # ifftshift below undoes that, leaving only the spatial masking in effect.
        spectrum = torch.fft.fftshift(torch.fft.fft2(x))

        # Distance of every frequency bin from the centre of the shifted spectrum.
        rows, cols = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
        rows = rows.to(x.device)
        cols = cols.to(x.device)
        radius = torch.sqrt((rows - center_y) ** 2 + (cols - center_x) ** 2)

        # Keep only bins farther than 15 from the centre (i.e. drop low frequencies).
        high_pass = (radius > 15).float().unsqueeze(0).unsqueeze(0)

        filtered = torch.fft.ifftshift(spectrum * high_pass)
        residual = torch.abs(torch.fft.ifft2(filtered))
        return self.conv(residual)

class GraphModule(nn.Module):
    """Residual all-to-all mixing of spatial positions.

    Every spatial location of a (B, C, H, W) feature map is treated as a node.
    Node affinities come from a softmax over scaled dot products of a linear
    projection with itself; the aggregated nodes go through a two-layer MLP and
    are added back to the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.gcn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        batch, channels, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C): one row per spatial location.
        tokens = x.flatten(2).transpose(1, 2)

        # The same projection serves as query and key, scaled by sqrt(C).
        queries = self.proj(tokens)
        affinity = (queries @ queries.transpose(-2, -1)) / (channels ** 0.5)
        weights = F.softmax(affinity, dim=-1)

        mixed = self.gcn(weights @ tokens)

        # Back to (B, C, H, W) and add as a residual.
        update = mixed.transpose(1, 2).reshape(batch, channels, height, width)
        return x + update

class FAGT_Model(nn.Module):
    """Segmentation network: timm feature encoder + frequency branch + graph mixing.

    The deepest encoder feature map is refined by ``GraphModule``, concatenated
    with the (resized) ``FrequencyBlock`` output, fused by a 1x1 convolution and
    decoded to a single-channel logit map at the input resolution.

    Submodule names and the layer order inside ``decoder`` define the
    state-dict keys, so they must stay as they are for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model(CFG.model_name, pretrained=False, features_only=True)
        self.dims = self.encoder.feature_info.channels()
        deepest = self.dims[-1]

        self.physics = FrequencyBlock()
        self.graph = GraphModule(deepest)

        # 32 extra channels come from the frequency branch.
        self.fusion = nn.Conv2d(deepest + 32, 256, 1)

        # Three conv stages with x4, x4 and x2 bilinear upsampling (x32 overall).
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(32, 1, 1),
        )

    def forward(self, x):
        deep_feats = self.encoder(x)[-1]

        # If the last axis has the channel count, treat the map as channels-last
        # and move channels to axis 1 (presumably needed for this backbone's
        # feature layout - confirm against the installed timm version).
        looks_channels_last = deep_feats.ndim == 4 and deep_feats.shape[-1] == self.dims[-1]
        if looks_channels_last:
            deep_feats = deep_feats.permute(0, 3, 1, 2)

        phys_feats = self.physics(x)
        graph_feats = self.graph(deep_feats)

        # Bring the full-resolution frequency features down to the encoder grid.
        target_hw = graph_feats.shape[-2:]
        phys_resized = F.interpolate(phys_feats, size=target_hw, mode='bilinear', align_corners=False)

        fused = self.fusion(torch.cat([graph_feats, phys_resized], dim=1))
        logits = self.decoder(fused)

        # Final resize guarantees the output matches the input's spatial size.
        return F.interpolate(logits, size=x.shape[-2:], mode='bilinear', align_corners=False)

# --- 3. HELPER FUNCTIONS ---
def rle_encode(mask):
    """Run-length encode the foreground of a mask as a JSON list string.

    Args:
        mask: array-like mask; every non-zero element counts as foreground.

    Returns:
        A JSON string ``"[start, length, start, length, ...]"`` where each
        ``start`` is the 1-based index of a foreground run in the flattened
        mask. An all-background mask yields ``"[]"``.
    """
    # Binarise first: with more than one distinct non-zero value, transitions
    # between those values would otherwise be counted as run boundaries and
    # break the start/length pairing.
    foreground = (np.asarray(mask).flatten() != 0).astype(np.uint8)

    # NOTE(review): pixels are flattened in row-major (C) order. Some RLE
    # scorers expect column-major order - confirm against the competition metric.

    # Zero padding on both sides guarantees every run has a start and an end edge.
    padded = np.concatenate([[0], foreground, [0]])
    runs = np.where(padded[1:] != padded[:-1])[0] + 1

    # Edges alternate start, end, start, end...; turn each end into a length.
    runs[1::2] -= runs[::2]
    return json.dumps([int(x) for x in runs])

def apply_morphological_cleaning(prob_map, threshold=0.5, min_pixels=200, kernel_size=None):
    """Threshold a probability map and remove speckle with a morphological opening.

    Opening is an erosion followed by a dilation: foreground regions smaller
    than the structuring element disappear, larger regions are grown back.

    Args:
        prob_map: 2-D array of per-pixel probabilities.
        threshold: pixels strictly above this value become foreground.
        min_pixels: if the TOTAL foreground pixel count after opening is below
            this value, the whole mask is discarded (this is not a
            per-component filter).
        kernel_size: side length of the square structuring element. ``None``
            (default) uses ``CFG.KERNEL_SIZE``, which was the previous
            hard-wired behaviour.

    Returns:
        A uint8 mask with values 0/1 and the same shape as ``prob_map``.
    """
    if kernel_size is None:
        kernel_size = CFG.KERNEL_SIZE

    h, w = prob_map.shape

    # 1. Binary threshold (0/1, so .sum() below is a pixel count).
    mask = (prob_map > threshold).astype(np.uint8)

    # 2. Square structuring element.
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # 3. Opening = erosion then dilation.
    clean_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    # 4. Too little foreground left: treat the image as having no detection.
    if clean_mask.sum() < min_pixels:
        return np.zeros((h, w), dtype=np.uint8)

    return clean_mask

def load_model(path):
    """Build ``FAGT_Model``, load weights from ``path`` if present, and return it in eval mode.

    When ``path`` does not exist only a message is printed and the model is
    returned with its freshly initialised weights, so callers should treat
    that message as a signal that predictions are not meaningful.
    """
    print(f">>> Loading V2.5 Swin Model: {path}")
    net = FAGT_Model()

    if not os.path.exists(path):
        print("!!! CRITICAL: Weights not found.")
    else:
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(path, map_location=CFG.device)
        # strict=True: every key must match the architecture exactly.
        net.load_state_dict(checkpoint, strict=True)
        print(">>> Weights Loaded Successfully.")

    net.to(CFG.device)
    net.eval()
    return net

# --- 4. EXECUTION ---
def _index_test_images(test_dir):
    """Map case ids (digits of the file name, as a string without leading zeros) to image paths."""
    image_exts = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
    id_map = {}
    # Sorted so that, if two files reduce to the same id, the winner is deterministic.
    for f in sorted(glob.glob(os.path.join(test_dir, '**', '*'), recursive=True)):
        stem, ext = os.path.splitext(os.path.basename(f))
        if ext.lower() not in image_exts:
            continue
        digits = ''.join(filter(str.isdigit, stem))
        if digits:
            id_map[str(int(digits))] = f
    return id_map


def _predict_label(model, transform, path):
    """Return the RLE annotation for one image, or "authentic" if no mask survives cleaning."""
    image = cv2.imread(path)
    if image is None:
        # OpenCV could not decode the file; fall back to PIL and close the handle promptly.
        with Image.open(path) as pil_img:
            image = np.array(pil_img.convert('RGB'))
    else:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h, w = image.shape[:2]

    # TTA: average the prediction with that of the horizontally flipped input.
    img_t = transform(image=image)['image'].unsqueeze(0).to(CFG.device)
    with torch.no_grad():
        p1 = torch.sigmoid(model(img_t))[0, 0]
        p2 = torch.flip(torch.sigmoid(model(torch.flip(img_t, [3]))), [3])[0, 0]
        prob_map = ((p1 + p2) / 2.0).cpu().numpy()

    # Back to the original resolution (cv2.resize takes (width, height)).
    pred_full = cv2.resize(prob_map, (w, h))

    mask = apply_morphological_cleaning(
        pred_full,
        threshold=CFG.THRESHOLD,
        min_pixels=CFG.MIN_PIXELS,
    )
    if mask.sum() > 0:
        return rle_encode(mask)
    return "authentic"


def run_inference():
    """Predict a mask for every test image and write ``submission.csv``.

    Steps: index the images under ``CFG.TEST_DIR``, load the model, predict one
    annotation per image (RLE string or "authentic"), then left-merge the
    predictions onto the case ids of the sample submission, filling any case
    without a prediction with "authentic".

    A failure on a single image does not abort the run: the case falls back to
    "authentic" and the error is printed so it is not lost.
    """
    print(">>> TITAN V5: MORPHOLOGICAL CLEANING INFERENCE...")

    id_map = _index_test_images(CFG.TEST_DIR)

    model = load_model(CFG.WEIGHTS_PATH)
    transform = A.Compose([
        A.Resize(CFG.img_size, CFG.img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()])

    preds_list = []
    failures = 0
    print(f">>> Processing {len(id_map)} images...")

    for case_id, path in tqdm(id_map.items()):
        try:
            label = _predict_label(model, transform, path)
        except Exception as e:
            # Best effort: keep going, but make the failure visible.
            failures += 1
            label = "authentic"
            print(f"!!! WARNING: case {case_id} ({path}) failed: {e!r}; using 'authentic'")

        preds_list.append({"case_id": case_id, "annotation": label})

    if failures:
        print(f"!!! WARNING: {failures} of {len(id_map)} images failed and were marked 'authentic'.")

    try:
        sample_sub = pd.read_csv(CFG.SAMPLE_SUB)
    except Exception as e:
        # Fall back to a one-row template so a submission file is still produced.
        print(f"!!! WARNING: could not read sample submission ({e!r}); using fallback template.")
        sample_sub = pd.DataFrame({'case_id': [45], 'annotation': ['authentic']})

    sample_sub['case_id'] = sample_sub['case_id'].astype(str)

    if len(preds_list) > 0:
        preds_df = pd.DataFrame(preds_list)
        preds_df['case_id'] = preds_df['case_id'].astype(str)
        # The sample submission defines which rows exist and in what order.
        submission = sample_sub[['case_id']].merge(preds_df, on='case_id', how='left')
        submission['annotation'] = submission['annotation'].fillna("authentic")
    else:
        submission = sample_sub
        submission['annotation'] = 'authentic'

    submission.to_csv('submission.csv', index=False)
    print(f">>> SUCCESS: Submission Saved. Rows: {len(submission)}")
    print(submission.head())

# Entry point: run the full inference pipeline when this code is executed directly.
if __name__ == "__main__":
    run_inference()

>>> TITAN V5: MORPHOLOGICAL CLEANING INFERENCE...
>>> Loading V2.5 Swin Model: /kaggle/input/recodai-model/TITAN_V2_UNLEASHED.pth
!!! CRITICAL: Weights not found.
>>> Processing 0 images...


0it [00:00, ?it/s]

>>> SUCCESS: Submission Saved. Rows: 1
  case_id annotation
0      45  authentic
