# PHASE 1: SETUP & DEFINITIONS

#### Cell 1: Install Dependencies
#### Installs everything needed for Mamba, visualization, and efficiency tracking.

In [None]:
# --- Cell 1: Installs ---
# 1. System dependencies for visualization
!apt-get install -y graphviz

# 2. Python libraries
# Uninstall potential conflicts first
!pip uninstall -y mamba-ssm causal-conv1d

print("--- Installing Deep Learning Extensions... ---")
!pip install causal-conv1d==1.5.0 --no-deps --no-build-isolation
!pip install mamba-ssm==2.2.4 --no-deps --no-build-isolation

print("--- Installing Analysis Tools... ---")
!pip install -q graphviz captum thop

print("\n‚úÖ Environment Ready.")

#### Cell 2: Imports & Configuration
#### All libraries and global variables in one place.

In [None]:
# --- Cell 2: Imports & Configuration ---
import os
import sys
import time
import glob
import random
import warnings
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from io import BytesIO

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import graphviz
from IPython.display import Image as IPImage, display

# PyTorch & ML
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import (
    roc_auc_score, accuracy_score, average_precision_score, 
    f1_score, confusion_matrix, classification_report, 
    roc_curve, auc, precision_recall_curve
)
from thop import profile # For FLOPs counting

# Mamba Import
try:
    from mamba_ssm.modules.mamba_simple import Mamba
except ImportError as err:
    # Keep the notebook running, but surface the real cause: without `Mamba`,
    # building the model in Cell 4 fails with a NameError.
    print(f"⚠️ Warning: Mamba not imported ({err}). Ensure Cell 1 ran successfully.")

# Suppress warnings
# NOTE: this silences *all* Python warnings for the rest of the session
# (including e.g. deprecation warnings from third-party libraries).
warnings.filterwarnings('ignore')

# --- CONFIGURATION ---
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths (inputs under /kaggle/input, checkpoints written under /kaggle/working)
DATA_ROOT = "/kaggle/input/200k-real-vs-ai-visuals-by-mbilal/my_real_vs_ai_dataset/my_real_vs_ai_dataset"
TRAIN_CSV = "/kaggle/input/200k-real-vs-ai-visuals-by-mbilal/train_labels.csv"
VAL_CSV   = "/kaggle/input/200k-real-vs-ai-visuals-by-mbilal/val_labels.csv"
TEST_CSV  = "/kaggle/input/200k-real-vs-ai-visuals-by-mbilal/test_labels.csv"
CKPT_DIR  = "/kaggle/working/checkpoints"

# External Data (Optional): only used by Cell 9 when EXT_TEST_CSV exists
EXT_DATA_ROOT = "/kaggle/input/140k-real-and-fake-faces/real_vs_fake/real-vs-fake"
EXT_TEST_CSV  = "/kaggle/input/140k-real-and-fake-faces/test.csv"

os.makedirs(CKPT_DIR, exist_ok=True)
print(f"🚀 Device: {DEVICE}")

#### Cell 3: Data Loading & Transformations
#### Defines the Dataset class and transformations.

In [None]:
# --- Cell 3: Data Loading ---

# 1. Transformations
def _tensor_tail():
    """Fresh copies of the steps shared by every pipeline.

    Forces 3 channels, converts the PIL image to a [0, 1] tensor, then maps it to
    [-1, 1] with mean = std = 0.5 per channel.
    """
    return [
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]


# Training pipeline: resize, then random geometric + photometric augmentation.
_train_augment = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.0),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
    transforms.RandomAdjustSharpness(1.5, p=0.3),
    transforms.RandomApply([transforms.RandomAffine(degrees=0, translate=(0.05, 0.05))], p=0.4),
]
train_transform = transforms.Compose(_train_augment + _tensor_tail())

# Evaluation pipeline: deterministic resize only, no augmentation.
val_transform = transforms.Compose([transforms.Resize((224, 224))] + _tensor_tail())

# 2. Dataset Class
class GravexDataset(Dataset):
    """Image dataset driven by a CSV with 'filename' and 'label' columns.

    Each item is ``(image, label)`` where ``label`` is a float32 scalar tensor.
    Files are looked up directly under ``img_root`` first, then under the
    ``real`` and ``ai_images`` sub-folders. Images that cannot be opened are
    replaced by a zero tensor of shape (3, 224, 224) with the sentinel label -1.
    """

    def __init__(self, csv_file, img_root, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_root = img_root
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def _resolve_path(self, filename):
        """Return the first existing candidate path, or the direct path if none exists."""
        direct = os.path.join(self.img_root, filename)
        if os.path.exists(direct):
            return direct
        # Fallback for subfolders
        fallbacks = (os.path.join(self.img_root, sub, filename) for sub in ["real", "ai_images"])
        return next((cand for cand in fallbacks if os.path.exists(cand)), direct)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        path = self._resolve_path(record['filename'])

        try:
            picture = Image.open(path).convert("RGB")
        except Exception:
            # Best-effort: unreadable/missing images become a dummy sample with label -1,
            # which downstream loops filter out.
            return torch.zeros((3, 224, 224)), torch.tensor(-1.0, dtype=torch.float32)

        target = torch.tensor(int(record['label']), dtype=torch.float32)
        if self.transform:
            picture = self.transform(picture)
        return picture, target

# 3. Initialize Loaders
# Validation and test share the deterministic transform; only training shuffles.
train_dataset = GravexDataset(TRAIN_CSV, DATA_ROOT, transform=train_transform)
val_dataset = GravexDataset(VAL_CSV, DATA_ROOT, transform=val_transform)
test_dataset = GravexDataset(TEST_CSV, DATA_ROOT, transform=val_transform)

_loader_opts = {"batch_size": BATCH_SIZE, "num_workers": 2, "pin_memory": True}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

print(f"Data Ready: Train({len(train_dataset)}), Val({len(val_dataset)}), Test({len(test_dataset)})")

#### Cell 4: Model Architecture
#### Defines the ResNet+Mamba model.

In [None]:
# --- Cell 4: Model Architecture ---

class ResNetBackbone(nn.Module):
    """ResNet-18 feature extractor followed by a linear projection to `out_dim`.

    NOTE(review): `ResNetBackbone` is used by ResNetMambaDetector but was not defined
    anywhere in this notebook, so building the model raised NameError. This definition
    mirrors the backbone of ResNetBaseline (Cell 14.0, "exact same backbone setup") and
    the Cell 5 diagram (ResNet-18 -> Linear 512->128). Confirm it matches the intended original.
    """

    def __init__(self, out_dim=128):
        super().__init__()
        self.net = models.resnet18(weights=None)  # randomly initialised, no pretrained weights
        self.net.fc = nn.Identity()               # expose the 512-d pooled features
        self.proj = nn.Linear(512, out_dim)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='relu')
        if self.proj.bias is not None:
            nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        return self.proj(self.net(x))  # (B, 512) -> (B, out_dim)


class ResNetMambaDetector(nn.Module):
    """ResNet-18 embedding -> single Mamba layer -> MLP head producing one logit per image.

    Args:
        embed_dim: width of the projected embedding and of the Mamba layer.
        d_state: Mamba SSM state size.
        device: forwarded to Mamba's constructor (needed for CPU/GPU switching).

    forward(x) returns raw logits of shape (B,); pair with BCEWithLogitsLoss / sigmoid.
    """

    def __init__(self, embed_dim=128, d_state=64, device=None):
        super().__init__()
        self.backbone = ResNetBackbone(out_dim=embed_dim)
        # Pass device to Mamba if provided (needed for CPU/GPU switching in efficiency tests)
        self.mamba = Mamba(d_model=embed_dim, d_state=d_state, expand=1, d_conv=4, device=device)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 1)
        )
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        emb = self.backbone(x)            # (B, embed_dim)
        # NOTE(review): Mamba receives a length-1 sequence, so nothing is mixed across
        # positions; presumably intended as a gated per-sample transform. Confirm.
        seq = emb.unsqueeze(1)            # (B, 1, embed_dim)
        out = self.mamba(seq).squeeze(1)  # (B, embed_dim)
        return self.classifier(out).squeeze(-1)  # (B,) raw logits


# Instantiate Model
model = ResNetMambaDetector(embed_dim=128, d_state=64, device=DEVICE).to(DEVICE)
print("✅ Model Architecture Defined.")

#### Cell 5: Visualize Architecture
#### Generates the Graphviz diagram.

In [None]:
# --- Cell 5: Visualize Architecture ---
def generate_architecture_diagram():
    """Draw the ResNet-Mamba pipeline with Graphviz and show it inline.

    Renders 'model_arch' (DOT source) and 'model_arch.png' into the current
    directory. Any render/display error is printed instead of raised, so a
    missing Graphviz install does not stop the notebook.
    """
    def filled(fillcolor, shape='box'):
        # Attribute set shared by every node: a filled shape in a single colour.
        return {'shape': shape, 'style': 'filled', 'fillcolor': fillcolor}

    node_style = {
        'input': filled('#a9def9'),
        'backbone': filled('#e4c1f9'),
        'mamba': filled('#f694c1'),
        'head': filled('#c3f73a'),
        'op': filled('lightgrey', shape='oval'),
    }

    graph = graphviz.Digraph(comment='ResNet-Mamba Detector')
    graph.attr(rankdir='LR', label='ResNet-Mamba Architecture', fontsize='20')

    # Feature-extraction stage
    for node_id, caption, kind in (
        ('In', 'Input\n(B, 3, 224, 224)', 'input'),
        ('RN', 'ResNet-18\nBackbone', 'backbone'),
        ('Pr', 'Linear Proj\n(512->128)', 'backbone'),
    ):
        graph.node(node_id, caption, **node_style[kind])

    # Dashed clusters; consecutive members of a cluster are chained with edges
    clusters = (
        ('cluster_mamba', 'Mamba Block', (
            ('Unsq', 'Unsqueeze', 'op'),
            ('Mamba', 'Mamba Layer', 'mamba'),
            ('Sq', 'Squeeze', 'op'),
        )),
        ('cluster_head', 'Classifier', (
            ('FC', 'Linear+ReLU', 'head'),
            ('Out', 'Logit', 'head'),
        )),
    )
    for cluster_name, cluster_label, members in clusters:
        with graph.subgraph(name=cluster_name) as cluster:
            cluster.attr(label=cluster_label, style='dashed')
            for node_id, caption, kind in members:
                cluster.node(node_id, caption, **node_style[kind])
            for (tail, _, _), (head, _, _) in zip(members, members[1:]):
                cluster.edge(tail, head)

    # Links between the stages
    for tail, head in (('In', 'RN'), ('RN', 'Pr'), ('Pr', 'Unsq'), ('Sq', 'FC')):
        graph.edge(tail, head)

    try:
        graph.render('model_arch', format='png', view=False)
        display(IPImage(filename='model_arch.png'))
    except Exception as exc:
        print(f"Graphviz Error: {exc}")


generate_architecture_diagram()

# PHASE 2: TRAINING

#### Cell 6: Training Loop & Helpers
#### Runs the training process and saves checkpoints.

In [None]:
# --- Cell 6: Training Loop ---

# Setup
criterion = nn.BCEWithLogitsLoss()  # the model emits raw logits; sigmoid is fused into the loss
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Loss scaling is only active for CUDA mixed precision; disabled (pass-through) otherwise.
scaler = torch.amp.GradScaler(enabled=(DEVICE.type == "cuda"))

_history_keys = ('train_loss', 'train_auc', 'val_loss', 'val_auc', 'val_acc')
history = {key: [] for key in _history_keys}  # one entry per epoch, filled by the loop below
best_val_auc = 0.0  # best validation AUC so far; a new best triggers a checkpoint

def save_checkpoint(epoch, model, val_auc, path):
    """Write `model`'s state dict plus bookkeeping to `path` via torch.save.

    The saved object is a dict with keys 'epoch', 'model_state_dict' and 'val_auc'.
    """
    payload = {}
    payload['epoch'] = epoch
    payload['model_state_dict'] = model.state_dict()
    payload['val_auc'] = val_auc
    torch.save(payload, path)

print(f"🚀 Starting Training for {EPOCHS} Epochs...")

for epoch in range(1, EPOCHS + 1):
    # ---- Train ----
    model.train()
    running_loss, n_train = 0.0, 0
    preds, trues = [], []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch} [Train]", leave=False)
    for xb, yb in pbar:
        # GravexDataset returns label -1 for unreadable images. That is not a valid
        # BCE target (expects [0, 1]) and would give roc_auc_score a third class,
        # so such samples are dropped before the forward pass.
        valid = yb != -1
        if not valid.any():
            continue
        if not valid.all():
            xb, yb = xb[valid], yb[valid]
        xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type='cuda', enabled=(DEVICE.type == "cuda")):
            out = model(xb)
            loss = criterion(out, yb)

        # GradScaler: scaled backward, unscale-and-step, then update the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * xb.size(0)
        n_train += xb.size(0)
        preds.extend(torch.sigmoid(out).detach().cpu().numpy().tolist())
        trues.extend(yb.detach().cpu().numpy().tolist())

    # Mean over the samples actually used (equals the dataset size when none were skipped).
    train_loss = running_loss / max(n_train, 1)
    train_auc = roc_auc_score(trues, preds)

    # ---- Validate ----
    model.eval()
    val_loss, n_val = 0.0, 0
    v_preds, v_trues = [], []

    with torch.no_grad():
        for xb, yb in tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False):
            valid = yb != -1  # same unreadable-image filter as training
            if not valid.any():
                continue
            if not valid.all():
                xb, yb = xb[valid], yb[valid]
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            with torch.amp.autocast(device_type='cuda', enabled=(DEVICE.type == "cuda")):
                out = model(xb)
                loss = criterion(out, yb)
            val_loss += loss.item() * xb.size(0)
            n_val += xb.size(0)
            v_preds.extend(torch.sigmoid(out).cpu().numpy().tolist())
            v_trues.extend(yb.cpu().numpy().tolist())

    val_loss /= max(n_val, 1)
    val_auc = roc_auc_score(v_trues, v_preds)
    val_acc = accuracy_score(v_trues, (np.array(v_preds) > 0.5).astype(int))

    print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, AUC={train_auc:.4f} | Val Loss={val_loss:.4f}, AUC={val_auc:.4f}")

    # Log History
    history['train_loss'].append(train_loss)
    history['train_auc'].append(train_auc)
    history['val_loss'].append(val_loss)
    history['val_auc'].append(val_auc)
    history['val_acc'].append(val_acc)

    # Save Best (file name encodes the AUC; Cell 8 parses it back to pick the best)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        save_checkpoint(epoch, model, val_auc, os.path.join(CKPT_DIR, f"best_model_auc{val_auc:.4f}.pth"))

print("✅ Training Complete.")

#### Cell 7: Plot Training History
#### Interpolates history to handle potential gaps and plots curves.

In [None]:
# --- Cell 7: Training History ---
if history['train_loss']:
    # Fill any gaps (NaN entries) by linear interpolation in both directions.
    df_hist = pd.DataFrame(history).interpolate(method='linear', limit_direction='both')
    epoch_axis = df_hist.index + 1  # epochs are 1-based

    # (panel title, [(history column, line format, legend label), ...])
    panels = [
        ('Loss Curves', [
            ('train_loss', 'b-', 'Train Loss'),
            ('val_loss', 'r--', 'Val Loss'),
        ]),
        ('Metric Curves', [
            ('train_auc', 'b-', 'Train AUC'),
            ('val_auc', 'r--', 'Val AUC'),
            ('val_acc', 'g:', 'Val Acc'),
        ]),
    ]

    plt.figure(figsize=(14, 5))
    for position, (panel_title, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for column, line_fmt, legend_label in series:
            plt.plot(epoch_axis, df_hist[column], line_fmt, label=legend_label)
        plt.title(panel_title)
        plt.xlabel('Epoch')
        plt.legend()
        plt.grid(True, linestyle=':')

    plt.show()
else:
    print("No history to plot yet.")

# PHASE 3: MASTER INFERENCE

#### Cell 8: Master Inference (Run Once, Use Everywhere)
#### Loads the best model and runs inference on the Test Set. Stores results in memory for all subsequent plots.

In [None]:
# --- Cell 8: MASTER INFERENCE ---
# We run this ONCE so we don't have to re-predict for every plot.

# 1. Load Best Model
import re


def _checkpoint_auc(path):
    """Return the validation AUC encoded in a checkpoint file name, or -inf if absent.

    Handles 'best_model_auc0.9876.pth' (the name Cell 6 writes) as well as an
    underscore-separated variant such as 'best_epoch3_auc0_9876.pth'.
    """
    match = re.search(r"auc(\d+(?:[._]\d+)?)\.pth$", os.path.basename(path))
    return float(match.group(1).replace('_', '.')) if match else float('-inf')


# The pattern must match the name Cell 6 saves ("best_model_auc{val_auc:.4f}.pth").
model_files = glob.glob(os.path.join(CKPT_DIR, "best_*auc*.pth"))
if model_files:
    best_path = max(model_files, key=_checkpoint_auc)
    print(f"📥 Loading Best Model: {best_path}")

    # weights_only=False: local checkpoint written by this notebook (it also stores
    # non-tensor values). Do not use this on untrusted files.
    checkpoint = torch.load(best_path, map_location=DEVICE, weights_only=False)
    state = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state)
else:
    print("⚠️ No checkpoint found. Using current model state.")
# Always run inference in eval mode (dropout off, BatchNorm uses running stats).
model.to(DEVICE).eval()

# 2. Run Inference on 200k Test Set
print("🧪 Running Inference on 200k Test Set...")
all_preds_200k, all_labels_200k = [], []

with torch.no_grad():
    for imgs, lbls in tqdm(test_loader, desc="Inference"):
        valid = (lbls != -1)  # skip unreadable images (sentinel label -1)
        if not valid.any(): continue

        imgs, lbls = imgs[valid].to(DEVICE), lbls[valid].to(DEVICE)

        with torch.amp.autocast(device_type='cuda', enabled=(DEVICE.type == "cuda")):
            out = model(imgs)

        probs = torch.sigmoid(out).squeeze().cpu().numpy().tolist()
        # Handle single-item batch case where tolist() returns float
        if isinstance(probs, float): probs = [probs]

        all_preds_200k.extend(probs)
        all_labels_200k.extend(lbls.cpu().numpy().tolist())

# Convert to numpy for easy plotting later
all_labels_200k = np.array(all_labels_200k)
all_preds_probs_200k = np.array(all_preds_200k)
all_preds_classes_200k = (all_preds_probs_200k > 0.5).astype(int)

print(f"\n✅ Inference Done. Loaded {len(all_labels_200k)} predictions into memory.")

#### Cell 9: External Inference (Optional)
#### Runs inference on the secondary 140k dataset for generalization checks.

In [None]:
# --- Cell 9: External Inference (140k Dataset) ---
if os.path.exists(EXT_TEST_CSV):
    print("🧪 Running Inference on External 140k Dataset...")

    # Helper Dataset class for external data
    class ExternalDataset(Dataset):
        """Images listed in a CSV with 'path' (relative to `root`) and 'label' columns.

        Unreadable images yield a zero tensor with label -1, the same sentinel
        GravexDataset uses, so the inference loop can filter them.
        """

        def __init__(self, csv, root, transform):
            self.df = pd.read_csv(csv)
            self.root = root
            self.transform = transform
            # Adjust path in CSV: make every entry a full path under `root`
            self.df['path'] = self.df['path'].apply(lambda x: os.path.join(root, x))

        def __len__(self):
            return len(self.df)

        def __getitem__(self, idx):
            row = self.df.iloc[idx]
            path, lbl = row['path'], row['label']
            try:
                img = Image.open(path).convert('RGB')
                if self.transform:
                    img = self.transform(img)
                return img, torch.tensor(lbl, dtype=torch.float32)
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
                return torch.zeros((3, 224, 224)), torch.tensor(-1.0, dtype=torch.float32)

    ext_dataset = ExternalDataset(EXT_TEST_CSV, EXT_DATA_ROOT, val_transform)
    ext_loader = DataLoader(ext_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    model.eval()  # do not rely on earlier cells having left the model in eval mode
    all_preds_ext, all_labels_ext = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(ext_loader, desc="Ext Inference"):
            valid = (lbls != -1)
            if not valid.any(): continue
            imgs = imgs[valid].to(DEVICE)
            with torch.amp.autocast(device_type='cuda', enabled=(DEVICE.type == "cuda")):
                out = model(imgs)
            probs = torch.sigmoid(out).squeeze().cpu().numpy().tolist()
            if isinstance(probs, float): probs = [probs]  # single-sample batch
            all_preds_ext.extend(probs)
            all_labels_ext.extend(lbls[valid].numpy().tolist())

    all_labels_ext = np.array(all_labels_ext)
    all_preds_probs_ext = np.array(all_preds_ext)
    print(f"✅ External Inference Done: {len(all_labels_ext)} samples.")
else:
    print("⚠️ External dataset not found. Skipping.")

# PHASE 4: ANALYSIS & VISUALIZATION

#### Cell 10: Standard Performance Metrics
#### ROC, Confusion Matrix, Classification Report.

In [None]:
# --- Cell 10: Standard Metrics ---
print("--- Test Results (200k) ---")
acc = accuracy_score(all_labels_200k, all_preds_classes_200k)
auc_score = roc_auc_score(all_labels_200k, all_preds_probs_200k)
print(f"Accuracy: {acc:.4f} | AUC: {auc_score:.4f}")
class_names = ['Fake', 'Real']  # label 0 -> 'Fake', label 1 -> 'Real'
print(classification_report(all_labels_200k, all_preds_classes_200k, target_names=class_names))

# One wide figure: ROC on the left, confusion matrix on the right.
plt.figure(figsize=(16, 6))

fpr, tpr, _ = roc_curve(all_labels_200k, all_preds_probs_200k)
roc_ax = plt.subplot(1, 2, 1)
roc_ax.plot(fpr, tpr, color='orange', lw=2, label=f'AUC={auc_score:.4f}')
roc_ax.plot([0, 1], [0, 1], 'k--')  # chance diagonal
roc_ax.set_title('ROC Curve')
roc_ax.legend()

# sklearn convention: rows = true label, columns = predicted label
cm = confusion_matrix(all_labels_200k, all_preds_classes_200k)
plt.subplot(1, 2, 2)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')

plt.show()

#### Cell 11: Advanced Metric Visualizations
#### PR Curve, Score Distributions, Pie Chart, Threshold Analysis.

# --- Cell 11: Advanced Metrics ---
# Uses `cm` (confusion matrix) computed in Cell 10 for the outcome pie chart.

prec, rec, _ = precision_recall_curve(all_labels_200k, all_preds_probs_200k)
ap = average_precision_score(all_labels_200k, all_preds_probs_200k)
is_fake = all_labels_200k == 0
is_real = all_labels_200k == 1

plt.figure(figsize=(18, 5))

# (a) Precision-recall curve
plt.subplot(1, 3, 1)
plt.plot(rec, prec, 'b-', label=f'AP={ap:.4f}')
plt.title('Precision-Recall Curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend()
plt.grid(True, ls=':')

# (b) Predicted-score density per true class
plt.subplot(1, 3, 2)
for class_mask, colour, name in ((is_fake, 'red', 'Fake'), (is_real, 'green', 'Real')):
    sns.histplot(all_preds_probs_200k[class_mask], color=colour, label=name, kde=True, stat='density')
plt.title('Score Distribution')
plt.legend()

# (c) Outcome breakdown; for binary labels cm.ravel() is [TN, FP, FN, TP] (label 1 positive)
plt.subplot(1, 3, 3)
tn, fp, fn, tp = cm.ravel()
plt.pie(
    [tn, tp, fp, fn],
    labels=['TN (Fake)', 'TP (Real)', 'FP (Fake)', 'FN (Real)'],
    colors=['#457B9D', '#A8DADC', '#E63946', '#F4A261'],
    explode=(0, 0, 0.1, 0.1),
    autopct='%1.1f%%',
)
plt.title('Outcome Breakdown')

plt.tight_layout()
plt.show()

# (d) Score spread per class as a box plot
plt.figure(figsize=(8, 4))
df_box = pd.DataFrame({'Score': all_preds_probs_200k, 'Label': all_labels_200k})
df_box['Label'] = df_box['Label'].map({0: 'Fake', 1: 'Real'})
sns.boxplot(x='Label', y='Score', data=df_box, palette={'Fake': '#E63946', 'Real': '#457B9D'})
plt.title('Score Spread by Class')
plt.grid(True, axis='y', ls=':')
plt.show()

#### Cell 12: Qualitative Analysis (Errors & Uncertainty)
#### Top 16 Errors and Most Uncertain Images.

In [None]:
# --- Cell 12: Qualitative Analysis ---
def get_img(idx):
    """Return test image `idx` as an HWC array in [0, 1] and its label as a Python float.

    Inverts the Normalize(mean=0.5, std=0.5) step of the test-set transform so the
    result can be passed straight to imshow.
    """
    tensor_img, label = test_dataset[idx]
    hwc = tensor_img.permute(1, 2, 0).cpu().numpy()
    return np.clip(hwc * 0.5 + 0.5, 0, 1), label.item()

errors = np.abs(all_labels_200k - all_preds_probs_200k)  # distance of each score from its true label
indices = np.arange(len(all_labels_200k))
# NOTE(review): these positions index the arrays built in Cell 8, which skipped
# unreadable images (label -1). They line up with test_dataset only if no test image
# was skipped. TODO: record the kept dataset indices in Cell 8 if that can happen.
misclassified = all_preds_classes_200k != all_labels_200k
n_cols = 4

# 1. Worst Mistakes (High Confidence Errors): real misclassifications only,
#    ranked by how far the score is from the true label.
fp_mask = (all_labels_200k == 0) & misclassified  # Fake called Real
top_fp = indices[fp_mask][np.argsort(errors[fp_mask])[::-1]][:n_cols]

fn_mask = (all_labels_200k == 1) & misclassified  # Real called Fake
top_fn = indices[fn_mask][np.argsort(errors[fn_mask])[::-1]][:n_cols]

fig, axes = plt.subplots(2, n_cols, figsize=(16, 8))
fig.suptitle("Worst Errors: Top Row=False Positives, Bottom Row=False Negatives")

# Row 0 = false positives, row 1 = false negatives; unused cells stay blank.
row_specs = [(top_fp, "FP (Fake->Real)"), (top_fn, "FN (Real->Fake)")]
for row_axes, (picked, type_err) in zip(axes, row_specs):
    for col, ax in enumerate(row_axes):
        if col < len(picked):
            idx = picked[col]
            img, lbl = get_img(idx)
            score = all_preds_probs_200k[idx]
            ax.imshow(img)
            ax.set_title(f"{type_err}\nScore: {score:.4f}", color='red')
        ax.axis('off')
plt.show()

# 2. Uncertainty (Scores near 0.5), closest to 0.5 first
uncertain_mask = (all_preds_probs_200k > 0.45) & (all_preds_probs_200k < 0.55)
closeness = np.abs(all_preds_probs_200k[uncertain_mask] - 0.5)
uncertain_idx = indices[uncertain_mask][np.argsort(closeness)][:8]

if len(uncertain_idx) > 0:
    # squeeze=False keeps a 2-D axes array even when only one image qualifies.
    fig, axes = plt.subplots(1, len(uncertain_idx), figsize=(16, 3), squeeze=False)
    fig.suptitle("Most Uncertain Images (Score ~ 0.5)")
    for ax, idx in zip(axes[0], uncertain_idx):
        img, lbl = get_img(idx)
        ax.imshow(img)
        ax.set_title(f"True: {lbl}\nScore: {all_preds_probs_200k[idx]:.4f}")
        ax.axis('off')
    plt.show()

#### Cell 13: Explainability (Saliency & FFT)
#### Uses Captum for heatmaps and standard FFT analysis.

In [None]:
# --- Cell 13: Explainability ---
from captum.attr import IntegratedGradients, visualization as viz

# 1. Saliency Maps
# Wrapper to handle shape mismatch in Captum
class ModelWrapper(nn.Module):
    """Give a (B,)-logit model the (B, 1) output shape Captum needs for `target=0`.

    Outputs that are already 2-D or higher are returned unchanged.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, x):
        logits = self.m(x)
        if logits.dim() != 1:
            return logits
        return logits.unsqueeze(-1)

print("Generating Saliency Maps...")
model.eval()
wrapped = ModelWrapper(model).to(DEVICE)
ig = IntegratedGradients(wrapped)  # attributions for output column 0 (the single logit)

viz_indices = random.sample(range(len(test_dataset)), 4)
fig, axes = plt.subplots(4, 2, figsize=(8, 16))

# One row per sampled image: original on the left, blended attribution heat map on the right.
for (ax_img, ax_attr), idx in zip(axes, viz_indices):
    t_img, lbl = test_dataset[idx]
    attr = ig.attribute(t_img.unsqueeze(0).to(DEVICE), target=0, n_steps=50)

    # Undo Normalize(0.5, 0.5) for display; attributions go to HWC numpy for Captum's viz.
    orig = np.clip(t_img.permute(1, 2, 0).numpy() * 0.5 + 0.5, 0, 1)
    hm = attr.cpu().squeeze().permute(1, 2, 0).detach().numpy()

    ax_img.imshow(orig)
    ax_img.set_title(f"Label: {lbl.item()}")
    ax_img.axis('off')

    viz.visualize_image_attr(
        hm, orig, method="blended_heat_map", sign="all",
        show_colorbar=True, plt_fig_axis=(fig, ax_attr), use_pyplot=False,
    )
plt.tight_layout()
plt.show()

# 2. FFT Analysis
print("Generating FFT Spectrum...")
def get_avg_fft(label_cls, n=300, dataset=None):
    """Average log-magnitude 2-D FFT spectrum over up to `n` images of one class.

    Args:
        label_cls: class label to average over (0 = Fake, 1 = Real in this notebook).
        n: maximum number of images to include.
        dataset: dataset yielding (normalized CHW tensor, label tensor) pairs;
            defaults to the global `test_dataset`.

    Returns:
        2-D array (H, W): mean of 20 * ln(|fftshift(fft2(gray))| + 1) over the images.

    Raises:
        ValueError: if no readable image with label `label_cls` was found.
    """
    if dataset is None:
        dataset = test_dataset

    # When the dataset exposes its CSV as a DataFrame (GravexDataset.data), pre-select
    # rows of the requested class so only those images get decoded and transformed.
    frame = getattr(dataset, 'data', None)
    if isinstance(frame, pd.DataFrame) and 'label' in frame.columns:
        csv_labels = pd.to_numeric(frame['label'], errors='coerce').to_numpy()
        candidates = np.flatnonzero(csv_labels == label_cls)
    else:
        candidates = range(len(dataset))

    accum = None
    count = 0
    for i in candidates:
        if count >= n:
            break
        t_img, lbl = dataset[int(i)]
        # Re-check the returned label: unreadable images come back as -1.
        if lbl.item() != label_cls:
            continue

        img = (t_img.cpu().numpy() * 0.5) + 0.5  # undo Normalize(0.5, 0.5)
        img = (np.clip(img, 0, 1).transpose(1, 2, 0) * 255).astype(np.uint8)
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        fshift = np.fft.fftshift(np.fft.fft2(gray))  # zero frequency moved to the centre
        mag = 20 * np.log(np.abs(fshift) + 1)        # natural log, +1 avoids log(0)

        accum = mag if accum is None else accum + mag
        count += 1

    if count == 0:
        raise ValueError(f"get_avg_fft: no readable images with label {label_cls} found")
    return accum / count

plt.figure(figsize=(10, 5))
# Side by side: average spectrum of label 0 (Fake) and label 1 (Real).
for position, (cls, panel_title) in enumerate(((0, 'Avg FFT (Fake)'), (1, 'Avg FFT (Real)')), start=1):
    plt.subplot(1, 2, position)
    plt.imshow(get_avg_fft(cls), cmap='magma')
    plt.title(panel_title)
plt.show()

#### Cell 14: Robustness & Benchmarking
#### Trains a ResNet-only baseline, then runs a blur-robustness check and the ablation comparison (ResNet+Mamba vs. baseline).

In [None]:
# --- Cell 14.0: Baseline Training & Inference ---

print("⏳ STARTING BASELINE TRAINING (ResNet-Only) for Comparison...")

# 1. Define Baseline Architecture (Exact same, just without Mamba)
class ResNetBaseline(nn.Module):
    """Ablation baseline: ResNet-18 features -> Linear projection -> MLP head, no Mamba.

    Built like the main model's backbone and classifier so that the Mamba block is
    the only architectural difference. forward(x) returns raw logits of shape (B,).
    """

    def __init__(self, out_dim=128):
        super().__init__()
        # Backbone: randomly initialised ResNet-18 whose fc head is replaced by Identity,
        # so it emits the 512-d pooled features, followed by a projection to `out_dim`.
        self.net = models.resnet18(weights=None)
        self.net.fc = nn.Identity()
        self.proj = nn.Linear(512, out_dim)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='relu')
        if self.proj.bias is not None:
            nn.init.zeros_(self.proj.bias)

        # Classifier head with the same layers and initialisation as the main model.
        self.classifier = nn.Sequential(
            nn.Linear(out_dim, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 1)
        )
        linear_layers = [layer for layer in self.classifier if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        features = self.net(x)               # (B, 512)
        projected = self.proj(features)      # (B, out_dim); no Mamba block here
        logits = self.classifier(projected)  # (B, 1)
        return logits.squeeze(-1)            # (B,)

# 2. Setup Baseline Training
model_baseline = ResNetBaseline(out_dim=128).to(DEVICE)
opt_baseline = optim.AdamW(model_baseline.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
crit_baseline = nn.BCEWithLogitsLoss()
scaler_bl = torch.amp.GradScaler(enabled=(DEVICE.type == "cuda"))

# 3. Training Loop (same optimiser, loss, AMP setup and EPOCHS as the main loop)
# NOTE(review): unlike Cell 6, the baseline keeps its final-epoch weights instead of
# the best-validation-AUC checkpoint, which may bias the ablation toward the main model.
for epoch in range(1, EPOCHS + 1):
    model_baseline.train()
    for xb, yb in tqdm(train_loader, desc=f"Baseline Epoch {epoch}", leave=False):
        # Drop unreadable images (label -1): they are not valid BCE targets.
        valid = yb != -1
        if not valid.any():
            continue
        if not valid.all():
            xb, yb = xb[valid], yb[valid]
        xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE, non_blocking=True)
        opt_baseline.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type='cuda', enabled=(DEVICE.type == "cuda")):
            out = model_baseline(xb)
            loss = crit_baseline(out, yb)

        scaler_bl.scale(loss).backward()
        scaler_bl.step(opt_baseline)
        scaler_bl.update()

print("✅ Baseline Training Complete.")

# 4. Run Inference for Baseline
print("🧪 Evaluating Baseline on Test Set...")
bl_preds, bl_trues = [], []
model_baseline.eval()

with torch.no_grad():
    for xb, yb in tqdm(test_loader, desc="Baseline Inference"):
        valid = (yb != -1)  # skip unreadable images (sentinel label -1)
        if not valid.any(): continue
        xb = xb[valid].to(DEVICE)

        with torch.amp.autocast(device_type='cuda', enabled=(DEVICE.type == "cuda")):
            out = model_baseline(xb)

        probs = torch.sigmoid(out).squeeze().cpu().numpy().tolist()
        # A single-sample batch squeezes to 0-d and tolist() returns a bare float.
        if isinstance(probs, float): probs = [probs]

        bl_preds.extend(probs)
        bl_trues.extend(yb[valid].cpu().numpy().tolist())

# 5. Calculate Metrics (class prediction = sigmoid probability > 0.5)
bl_trues = np.array(bl_trues)
bl_probs = np.array(bl_preds)
bl_classes = (bl_probs > 0.5).astype(int)

bl_auc = roc_auc_score(bl_trues, bl_probs)
bl_ap  = average_precision_score(bl_trues, bl_probs)
bl_f1  = f1_score(bl_trues, bl_classes)
bl_acc = accuracy_score(bl_trues, bl_classes)

# 6. Store in the variable needed for Cell 14.1 (order: AUC, AP, F1, Acc)
baseline_metrics = [bl_auc, bl_ap, bl_f1, bl_acc]

print("\n--- Baseline Results (ResNet-Only) ---")
print(f"AUC: {bl_auc:.4f} | AP: {bl_ap:.4f} | F1: {bl_f1:.4f} | Acc: {bl_acc:.4f}")
print("Values saved to 'baseline_metrics' for Plot 14.")

# Save the Baseline Model just in case
torch.save(model_baseline.state_dict(), os.path.join(CKPT_DIR, "baseline_resnet_final.pth"))
print("💾 Baseline model saved to disk.")

In [None]:
# --- Cell 14.1: Robustness Analysis & Ablation Study ---

# 1. Robustness (Blur)
# Blur level k uses a (2k+1) x (2k+1) Gaussian kernel; k = 0 is the clean test set.
levels = [0, 1, 3, 5]
blur_aucs = []

print("Running Robustness Check (Blur)...")
model.eval()
for k in levels:
    if k == 0:
        blur_aucs.append(auc_score)  # clean-test AUC computed in Cell 10
        continue

    preds, trues = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(test_loader, desc=f"Blur k={k}", leave=False):
            # Drop unreadable images (label -1) as Cell 8 does, so the AUC stays binary
            # and comparable with the k = 0 point.
            valid = lbls != -1
            if not valid.any():
                continue
            imgs, lbls = imgs[valid], lbls[valid]
            # Apply blur on the (already normalized) CPU tensor batch directly
            imgs = transforms.functional.gaussian_blur(imgs, kernel_size=(k*2+1, k*2+1))
            imgs = imgs.to(DEVICE)
            # Same autocast setting as Cell 8, so only the blur differs from k = 0.
            with torch.amp.autocast(device_type='cuda', enabled=(DEVICE.type == "cuda")):
                out = model(imgs)
            preds.extend(torch.sigmoid(out).cpu().numpy().tolist())
            trues.extend(lbls.numpy().tolist())
    blur_aucs.append(roc_auc_score(trues, preds))

plt.figure(figsize=(6, 4))
plt.plot(levels, blur_aucs, 'o-', color='purple')
plt.title('Robustness to Blur')
plt.xlabel('Blur Intensity')
plt.ylabel('AUC')
plt.grid(True)
plt.show()

# 2. Ablation Study
our_metrics = [auc_score, ap, f1_score(all_labels_200k, all_preds_classes_200k), acc]
labels = ['AUC', 'AP', 'F1', 'Acc']  # same order as baseline_metrics

x = np.arange(len(labels))
width = 0.35

plt.figure(figsize=(10, 6))
plt.bar(x - width/2, our_metrics, width, label='ResNet+Mamba (Ours)', color='navy')
plt.bar(x + width/2, baseline_metrics, width, label='ResNet Baseline', color='skyblue')
plt.xticks(x, labels)
plt.title('Ablation Study: Ours vs Baseline')
# Zoom on the top of the range (0.8-1.0), but lower the floor when any metric falls
# below it so no bar is clipped out of view.
y_floor = min(0.8, max(0.0, min(list(our_metrics) + list(baseline_metrics)) - 0.05))
plt.ylim(y_floor, 1.0)
plt.legend()
plt.grid(axis='y', ls=':')
plt.show()

#### Cell 15: Efficiency Analysis
#### Calculates Latency and FLOPs.

In [None]:
# --- Cell 15: Efficiency Analysis ---
def measure_efficiency(model, device, input_shape=(1, 3, 224, 224), warmup=50, iters=100):
    """Measure mean forward latency, GFLOPs and parameter count of `model`.

    Args:
        model: module to benchmark; it is switched to eval mode in place.
        device: torch.device (or device string) the model already lives on.
        input_shape: shape of the random dummy input.
        warmup: untimed forward passes run before timing.
        iters: timed forward passes; latency is their mean.

    Returns:
        (latency_ms, gflops, params_millions). gflops is 0 when thop profiling fails;
        it only counts modules thop has hooks for, so fused/custom kernels may be
        under-counted.
    """
    import copy

    device = torch.device(device)
    is_cuda = device.type == 'cuda'
    dummy = torch.randn(*input_shape, device=device)

    def _sync():
        # CUDA launches are asynchronous: wait for queued work so wall-clock time is real.
        if is_cuda:
            torch.cuda.synchronize(device)

    # 1. Latency: inference only, so no autograd graph is recorded.
    model.eval()
    iters = max(int(iters), 1)
    with torch.no_grad():
        for _ in range(warmup):
            model(dummy)
        _sync()  # don't let unfinished warm-up kernels leak into the timed window
        start = time.perf_counter()
        for _ in range(iters):
            model(dummy)
        _sync()
        latency = (time.perf_counter() - start) / iters * 1000  # ms

    # Parameter count straight from the model (independent of profiling success).
    params = sum(p.numel() for p in model.parameters())

    # 2. FLOPs: profile a deep copy on the same device. thop adds counter buffers/hooks
    #    to the modules it profiles, and Mamba's fused kernels cannot run on CPU.
    try:
        model_copy = copy.deepcopy(model)
        flops, _ = profile(model_copy, inputs=(dummy,), verbose=False)
    except Exception as err:
        print(f"FLOPs profiling failed ({type(err).__name__}: {err}); reporting 0.")
        flops = 0

    return latency, flops / 1e9, params / 1e6

lat, gflops, params = measure_efficiency(model, DEVICE)

print("--- Efficiency Stats ---")
print(f"Latency: {lat:.2f} ms")
print(f"GFLOPs:  {gflops:.2f}")
print(f"Params:  {params:.2f} M")

# One bar per metric (mixed units, so read each bar on its own).
efficiency_rows = [
    ('Latency (ms)', lat),
    ('GFLOPs', gflops),
    ('Params (M)', params),
]
df_eff = pd.DataFrame([{'Metric': metric, 'Value': value} for metric, value in efficiency_rows])
sns.barplot(x='Metric', y='Value', data=df_eff)
plt.title('Model Efficiency')
plt.show()