In [None]:
# ==================================================================================
# TITAN-FAGT V4.3: STABLE MAC TRAINING (CLEAN UI + NO FREEZES)
# ==================================================================================
import os
import cv2
import sys
import random
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score
from tqdm import tqdm  # Standard TQDM
import albumentations as A
from albumentations.pytorch import ToTensorV2

# timm supplies the backbone created in Titan_FAGT_Model. Check for it up front
# and stop with an install hint if it is missing.
try:
    import timm
except ImportError:
    print("CRITICAL: 'timm' library not found. Please run: pip install timm")
    # Exit with a non-zero status so the failure is visible to the caller.
    sys.exit(1)

# --- 1. CONFIGURATION ---
class CFG:
    """Static configuration shared by the model, dataset and training loop.

    All values are plain class attributes and are read as ``CFG.<name>``;
    the class is never instantiated in this file.
    """

    # Dataset root and the image/mask folders derived from it.
    BASE_DIR = '/Users/chanduchitikam/recodai/recodai-luc-scientific-image-forgery-detection'
    TRAIN_IMG_PATH = os.path.join(BASE_DIR, 'train_images')
    TRAIN_MASK_PATH = os.path.join(BASE_DIR, 'train_masks')
    # File name the best checkpoint (state_dict) is written to.
    WEIGHTS_NAME = "TITAN_FAGT_V4_FINAL.pth"

    seed = 42
    img_size = 384
    batch_size = 2        # Safe for Mac
    accum_iter = 8        # Effective Batch = 16
    epochs = 7
    lr = 2e-5

    # Prefer Apple's MPS backend, but fall back to CUDA and then CPU so the
    # script does not fail at `model.to(device)` on machines without MPS.
    # The getattr guard covers PyTorch builds that predate torch.backends.mps.
    device = torch.device(
        "mps"
        if getattr(torch.backends, "mps", None) is not None
        and torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    )
    # timm model identifier used for the backbone.
    model_name = 'swin_large_patch4_window12_384.ms_in22k'

# --- 2. ARCHITECTURE (MPS-SAFE) ---
class ConstrainedBayarConv2d(nn.Module):
    """Convolution whose kernel is re-constrained on every forward pass.

    For each (out, in) filter the centre tap is zeroed, the remaining taps
    are divided by their sum (plus 1e-7), and the centre tap is then fixed
    to -1 before the convolution is applied.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of filters.
        kernel_size: Side length of the square kernel.
        padding: Zero padding passed to ``F.conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, padding=2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding

        # Learnable raw kernel (Xavier-initialised) and zero-initialised bias.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.empty(out_channels))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

        # Ones everywhere except the centre tap; registered as a buffer so it
        # is saved in the state_dict under the name 'filter_mask'.
        mid = kernel_size // 2
        self.mask = torch.ones(kernel_size, kernel_size)
        self.mask[mid, mid] = 0
        self.register_buffer('filter_mask', self.mask)

    def forward(self, x):
        """Apply the constrained kernel to ``x`` with stride 1."""
        mid = self.kernel_size // 2
        off_centre = self.weight * self.filter_mask.to(x.device)
        tap_sum = off_centre.sum(dim=(2, 3), keepdim=True)
        # Clone before the in-place centre write so the normalised tensor that
        # autograd tracks is not modified.
        kernel = (off_centre / (tap_sum + 1e-7)).clone()
        kernel[:, :, mid, mid] = -1.0
        return F.conv2d(x, kernel, self.bias, stride=1, padding=self.padding)

class ResSobel(nn.Module):
    """Gradient-magnitude layer built from two 3x3 convolutions.

    Both convolutions are bias-free and start from the horizontal / vertical
    Sobel kernels, replicated across every (out, in) channel pair. The weights
    stay ordinary parameters, so they can be updated during training.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of each convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_x = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
        self.conv_y = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)

        horizontal = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        )
        vertical = torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32
        )
        # Overwrite the default initialisation with the Sobel kernels.
        with torch.no_grad():
            for conv, kernel in ((self.conv_x, horizontal), (self.conv_y, vertical)):
                conv.weight.copy_(kernel.expand_as(conv.weight))

    def forward(self, x):
        """Return sqrt(gx^2 + gy^2 + 1e-6) of the two convolution responses."""
        grad_x = self.conv_x(x)
        grad_y = self.conv_y(x)
        # The small constant keeps the sqrt gradient finite where both
        # responses are zero.
        return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6)

class ForensicPyramid(nn.Module):
    """Three-level feature pyramid built on Bayar and Sobel responses.

    The input is passed through ``ConstrainedBayarConv2d`` and ``ResSobel``
    (32 channels each), the two responses are concatenated (64 channels) and
    then reduced by three conv + GroupNorm + ReLU stages with strides 4, 2, 2.

    Args:
        dim_list: Output channel counts of the three stages; only the first
            three entries are used. Each must be divisible by 32 because the
            stages use ``nn.GroupNorm(32, dim)``.
    """

    # The default is an immutable tuple (previously a list) so it cannot be
    # mutated and shared across instances.
    def __init__(self, dim_list=(192, 384, 768)):
        super().__init__()
        d1, d2, d3 = dim_list[0], dim_list[1], dim_list[2]
        self.bayar = ConstrainedBayarConv2d(3, 32)
        self.sobel = ResSobel(3, 32)
        self.layer1 = nn.Sequential(
            nn.Conv2d(64, d1, 3, stride=4, padding=1), nn.GroupNorm(32, d1), nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(d1, d2, 3, stride=2, padding=1), nn.GroupNorm(32, d2), nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(d2, d3, 3, stride=2, padding=1), nn.GroupNorm(32, d3), nn.ReLU()
        )

    def forward(self, x):
        """Return the three pyramid levels ``(f1, f2, f3)``, finest first."""
        b = self.bayar(x)
        s = self.sobel(x)
        raw = torch.cat([b, s], dim=1)  # 32 + 32 = 64 channels
        f1 = self.layer1(raw)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        return f1, f2, f3

class Titan_FAGT_Model(nn.Module):
    """Segmentation network fusing a timm backbone with ``ForensicPyramid``.

    The backbone is created with ``features_only=True`` and is expected to
    return four feature maps. The three shallowest maps are each concatenated
    with the matching forensic level and mixed by a 1x1 convolution; the
    deepest map seeds a three-stage upsampling decoder. The output is a
    single-channel map (no activation applied) resized to
    ``CFG.img_size x CFG.img_size``.

    Args:
        model_name: timm model identifier for the backbone.
    """

    def __init__(self, model_name=CFG.model_name):
        super().__init__()
        print(f">>> Loading Backbone: {model_name}")
        self.swin = timm.create_model(model_name, pretrained=True, features_only=True)
        self.swin_ch = self.swin.feature_info.channels()
        ch = self.swin_ch

        self.forensic = ForensicPyramid(dim_list=ch[:3])

        # 1x1 convolutions that mix [backbone, forensic] features per level.
        self.gate1 = nn.Conv2d(ch[0] * 2, ch[0], 1)
        self.gate2 = nn.Conv2d(ch[1] * 2, ch[1], 1)
        self.gate3 = nn.Conv2d(ch[2] * 2, ch[2], 1)

        # Decoder: deepest features -> 512 channels, then three up-blocks.
        self.center = nn.Conv2d(ch[-1], 512, 1)
        self.up3 = self._up_block(512 + ch[-2], 256)
        self.up2 = self._up_block(256 + ch[-3], 128)
        self.up1 = self._up_block(128 + ch[-4], 64)
        self.final = nn.Conv2d(64, 1, 1)

    def _up_block(self, in_c, out_c):
        """Build a conv3x3 -> GroupNorm(32) -> ReLU -> 2x bilinear upsample stage."""
        layers = [
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.GroupNorm(32, out_c),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _match_size(t, ref):
        """Resize ``t`` (default interpolation mode) to ``ref``'s spatial size if they differ."""
        if t.shape[2:] != ref.shape[2:]:
            t = F.interpolate(t, size=ref.shape[2:])
        return t

    def forward(self, x):
        """Compute the single-channel output map for an image batch ``x``."""
        f1, f2, f3 = self.forensic(x)
        c1, c2, c3, c4 = self.swin(x)

        # If the last axis of the deepest map equals its channel count, the
        # backbone returned channels-last tensors; move channels to axis 1.
        if c4.shape[-1] == self.swin_ch[-1]:
            c1, c2, c3, c4 = (t.permute(0, 3, 1, 2) for t in (c1, c2, c3, c4))

        # Fuse backbone and forensic features, deepest level first.
        i3 = self.gate3(torch.cat([c3, self._match_size(f3, c3)], dim=1))
        i2 = self.gate2(torch.cat([c2, self._match_size(f2, c2)], dim=1))
        i1 = self.gate1(torch.cat([c1, self._match_size(f1, c1)], dim=1))

        # Decode, concatenating the fused level before every up-block.
        x = self.center(c4)
        x = F.interpolate(x, size=i3.shape[2:], mode='bilinear')
        x = self.up3(torch.cat([x, i3], dim=1))
        x = self.up2(torch.cat([self._match_size(x, i2), i2], dim=1))
        x = self.up1(torch.cat([self._match_size(x, i1), i1], dim=1))

        x = F.interpolate(x, size=(CFG.img_size, CFG.img_size), mode='bilinear')
        return self.final(x)

# --- 4. LOSS & UTILS ---
class CombinedLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    The Dice term is computed over the whole flattened batch, not per sample.

    Args:
        bce_weight: Multiplier of the BCE term (default 0.5).
        dice_weight: Multiplier of the ``1 - dice`` term (default 0.5).
        smooth: Constant added to the Dice numerator and denominator so the
            ratio stays defined when both prediction and target are empty.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5, smooth=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scalar loss for raw logits ``inputs`` and float ``targets``."""
        bce = self.bce(inputs, targets)
        # reshape (not view) so non-contiguous tensors, e.g. permuted or
        # transposed ones, are accepted instead of raising a RuntimeError.
        probs = torch.sigmoid(inputs).reshape(-1)
        flat_targets = targets.reshape(-1)
        intersection = (probs * flat_targets).sum()
        dice = (2. * intersection + self.smooth) / (
            probs.sum() + flat_targets.sum() + self.smooth
        )
        return self.bce_weight * bce + self.dice_weight * (1 - dice)

class F1ScoreMeter:
    """Running element-wise TP / FP / FN counter for binary predictions.

    Counts accumulate across ``update`` calls; create a new meter to reset.
    """

    def __init__(self):
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def update(self, preds, targets):
        """Add one batch: ``preds`` are logits, both are thresholded at 0.5."""
        predicted = torch.sigmoid(preds) > 0.5
        actual = targets > 0.5
        self.tp += (predicted & actual).sum().item()
        self.fp += (predicted & ~actual).sum().item()
        self.fn += (~predicted & actual).sum().item()

    def get_score(self):
        """Return ``(f1, precision, recall)`` from the accumulated counts."""
        eps = 1e-7  # keeps every ratio defined when a denominator is zero
        precision = self.tp / (self.tp + self.fp + eps)
        recall = self.tp / (self.tp + self.fn + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        return f1, precision, recall

class TitanDataset(Dataset):
    """Image / binary-mask dataset driven by a table of sample records.

    Each row of ``df`` must provide ``image_path``, ``mask_path`` and
    ``label``. Rows whose label is not 1, or whose mask file is missing or
    unreadable, get an all-zero mask.

    Args:
        df: Table of samples, indexed positionally with ``df.iloc``.
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _collapse_mask(m, hw):
        """Reduce a 3-D mask stack to one 2-D map by taking the maximum.

        NOTE(review): the on-disk mask layout is not visible from this file.
        A stack whose trailing two axes match the image size ``hw`` (and whose
        leading two do not) is treated as channels-first; everything else
        keeps the original channels-last reduction.
        """
        if m.ndim != 3:
            return m
        hw = tuple(hw)
        if m.shape[:2] != hw and m.shape[1:] == hw:
            return np.max(m, axis=0)
        return np.max(m, axis=2)

    def _load_mask(self, row, hw):
        """Return a float32 mask of shape ``hw`` for ``row`` (zeros on any failure)."""
        import warnings

        path = row['mask_path']
        if row['label'] != 1 or not path or not os.path.exists(path):
            return np.zeros(hw, dtype=np.float32)
        try:
            m = self._collapse_mask(np.load(path), hw)
            # cv2.resize takes (width, height); nearest keeps the mask binary.
            return cv2.resize(
                (m > 0).astype(np.float32), (hw[1], hw[0]),
                interpolation=cv2.INTER_NEAREST,
            )
        except Exception as err:
            # Best effort: a broken mask becomes an empty one, but the problem
            # is reported and KeyboardInterrupt / SystemExit still propagate.
            warnings.warn(f"Could not load mask {path!r}: {err}")
            return np.zeros(hw, dtype=np.float32)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` with the mask carrying a leading channel axis."""
        import warnings

        row = self.df.iloc[idx]
        img = cv2.imread(row['image_path'])
        if img is None:
            # Unreadable image: fall back to a fixed-size black placeholder.
            warnings.warn(f"Could not read image {row['image_path']!r}; using a blank image")
            img = np.zeros((384, 384, 3), dtype=np.uint8)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = self._load_mask(row, img.shape[:2])

        if self.transforms:
            aug = self.transforms(image=img, mask=mask)
            img, mask = aug['image'], aug['mask']

        # Without a tensor-producing transform the arrays are still NumPy;
        # convert them so `unsqueeze` below works (image becomes channels-first).
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask)
        return img, mask.unsqueeze(0)

# --- 5. TRAINING LOOP ---
def _collect_samples():
    """Scan the authentic / forged image folders and return one record per image."""
    auth_files = sorted(glob.glob(os.path.join(CFG.TRAIN_IMG_PATH, 'authentic', '*.*')))
    forg_files = sorted(glob.glob(os.path.join(CFG.TRAIN_IMG_PATH, 'forged', '*.*')))
    print(f">>> Found {len(auth_files)} Authentic, {len(forg_files)} Forged.")

    records = [{'image_path': f, 'mask_path': None, 'label': 0} for f in auth_files]
    for f in forg_files:
        base = os.path.basename(f).split('.')[0]
        # Try "<name>.npy" first, then "<name>_mask.npy".
        m_path = os.path.join(CFG.TRAIN_MASK_PATH, base + '.npy')
        if not os.path.exists(m_path):
            m_path = os.path.join(CFG.TRAIN_MASK_PATH, base + '_mask.npy')
        records.append({'image_path': f, 'mask_path': m_path, 'label': 1})
    return pd.DataFrame(records)


def _build_transforms():
    """Return ``(train_aug, val_aug)``; only the training pipeline flips/rotates."""
    normalize = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_aug = A.Compose([
        A.Resize(CFG.img_size, CFG.img_size),
        A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5),
        A.Normalize(**normalize),
        ToTensorV2(),
    ])
    val_aug = A.Compose([
        A.Resize(CFG.img_size, CFG.img_size),
        A.Normalize(**normalize),
        ToTensorV2(),
    ])
    return train_aug, val_aug


def _train_one_epoch(model, loader, optimizer, criterion, epoch):
    """Run one training epoch with gradient accumulation; return the mean batch loss."""
    model.train()
    optimizer.zero_grad()
    num_batches = len(loader)
    running_loss = 0.0

    # Clean TQDM Bar
    pbar = tqdm(loader, desc=f"Ep {epoch+1}/{CFG.epochs}", leave=False)
    for i, (images, masks) in enumerate(pbar):
        images, masks = images.to(CFG.device), masks.to(CFG.device)

        loss = criterion(model(images), masks)
        # Scale so that accum_iter micro-batches sum to one averaged gradient.
        (loss / CFG.accum_iter).backward()

        # Step on every full accumulation window AND on the last batch;
        # otherwise gradients of a trailing partial window would be thrown
        # away by the zero_grad() at the start of the next epoch. The tail
        # window keeps the 1/accum_iter scaling, so its step is slightly
        # under-weighted rather than dropped.
        if (i + 1) % CFG.accum_iter == 0 or (i + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        current_loss = loss.item()
        running_loss += current_loss
        pbar.set_postfix(loss=f"{current_loss:.4f}")

    return running_loss / num_batches


def _evaluate(model, loader, criterion):
    """Return ``(mean_loss, f1, precision, recall)`` over ``loader`` without gradients."""
    model.eval()
    meter = F1ScoreMeter()
    loss_sum = 0.0
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(CFG.device), masks.to(CFG.device)
            outputs = model(images)
            loss_sum += criterion(outputs, masks).item()
            meter.update(outputs, masks)
    f1, prec, rec = meter.get_score()
    return loss_sum / len(loader), f1, prec, rec


def train_loop():
    """Train on fold 0 of a 5-fold split and save the best-F1 checkpoint.

    Reads all settings from ``CFG``. After every epoch the model is evaluated
    on the validation fold and its ``state_dict`` is written to
    ``CFG.WEIGHTS_NAME`` whenever the F1 score improves.
    """
    print(">>> TITAN V4.3: STABLE MAC TRAINING STARTED")

    # Seed the Python, NumPy and PyTorch RNGs so shuffling and layer
    # initialisation are repeatable for a given CFG.seed.
    random.seed(CFG.seed)
    np.random.seed(CFG.seed)
    torch.manual_seed(CFG.seed)

    # 1. Data Prep
    df = _collect_samples()
    kf = KFold(n_splits=5, shuffle=True, random_state=CFG.seed)
    train_idx, val_idx = next(kf.split(df))  # only the first fold is used

    # 2. Transforms
    train_aug, val_aug = _build_transforms()

    # 3. Loaders (num_workers=0: load in the main process)
    train_loader = DataLoader(
        TitanDataset(df.iloc[train_idx], train_aug),
        batch_size=CFG.batch_size, shuffle=True, num_workers=0,
    )
    val_loader = DataLoader(
        TitanDataset(df.iloc[val_idx], val_aug),
        batch_size=CFG.batch_size, shuffle=False, num_workers=0,
    )

    # 4. Model
    model = Titan_FAGT_Model().to(CFG.device)
    optimizer = optim.AdamW(model.parameters(), lr=CFG.lr)
    criterion = CombinedLoss()
    best_f1 = 0.0

    print(f"\n>>> TRAINING START ({CFG.epochs} Epochs)")

    for epoch in range(CFG.epochs):
        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, epoch)
        avg_val_loss, f1, prec, rec = _evaluate(model, val_loader, criterion)

        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | F1: {f1:.4f}")

        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), CFG.WEIGHTS_NAME)
            print(f">>> SAVED BEST MODEL: {CFG.WEIGHTS_NAME}")

# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    train_loop()