# In [None]:
"""
fetal_brain_pipeline_multiclass_with_crossattn_gradcam.py

Full pipeline (multi-class) — updated to include:
- Cross-Attention Fusion (CNN <-> ViT)
- Grad-CAM explanations (replaces LIME example at end)

Other pipeline parts are unchanged: Dataset, CBAM, DenseNet + Swin, training loop, metrics, saving.

Author: ChatGPT (integrated into your existing script)
"""

import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm

from sklearn.metrics import roc_auc_score, f1_score, precision_recall_fscore_support
from sklearn.preprocessing import label_binarize

# -------------------------
# Config / hyperparameters
# -------------------------
# One seed drives every random number generator the pipeline relies on
# (Python, NumPy, and PyTorch on CPU and on all CUDA devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# --- Paths ---
DATA_ROOT = "Classification_Dataset"            # parent folder of the split directories
MODEL_SAVE_PATH = "best_model_multiclass.pth"   # checkpoint file written during training

# --- Data loading ---
IMAGE_SIZE = 224
BATCH_SIZE = 2
NUM_WORKERS = 0

# --- Optimisation ---
NUM_EPOCHS = 10
LR = 1e-4
WEIGHT_DECAY = 1e-4

# --- Runtime ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 11  # will be overridden after reading CSV if different

# -------------------------
# Helper: read CSV, get classes
# -------------------------
def read_classes_from_csv(csv_path: str):
    """Return the class names stored in the header of a ``_classes.csv`` file.

    The first column is skipped (the dataset class in this file reads image
    filenames from it); every remaining column label is returned, in file
    order, as a list.
    """
    # sep=None with the python engine lets pandas sniff the delimiter.
    header = pd.read_csv(csv_path, sep=None, engine="python").columns
    return header[1:].tolist()

# -------------------------
# Dataset
# -------------------------
class FetalUSDataset(Dataset):
    """Image-classification dataset described by a ``_classes.csv`` file.

    The first CSV column holds image filenames; every remaining column is cast
    to ``int`` and forms one entry of a per-image label vector. Each item is
    ``(image_tensor, class_index, filename)`` where ``class_index`` is the
    argmax of that image's label vector.

    Args:
        split_dir: Directory that contains ``_classes.csv`` and the images
            (either directly or inside an ``images`` sub-folder).
        image_size: Side length used by the default transform.
        transform: Callable applied to the loaded RGB ``PIL.Image``. When
            falsy, :meth:`default_transform` is used instead.

    Raises:
        FileNotFoundError: If ``_classes.csv`` does not exist in ``split_dir``.
    """

    def __init__(self, split_dir: str, image_size=224, transform=None):
        csv_path = os.path.join(split_dir, "_classes.csv")
        # Raise explicitly rather than `assert`: assertions are stripped under
        # `python -O`, and __getitem__ already reports missing files this way.
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV not found: {csv_path}")
        self.df = pd.read_csv(csv_path, sep=None, engine="python")
        self.dir = split_dir
        self.filenames = self.df.iloc[:, 0].astype(str).values
        self.labels = self.df.iloc[:, 1:].astype(int).values
        self.transform = transform if transform else self.default_transform(image_size)
        self.image_size = image_size

    def default_transform(self, image_size):
        """Build the deterministic resize / grayscale / normalise transform."""
        return T.Compose([
            T.Resize((image_size, image_size)),
            T.Grayscale(num_output_channels=3),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.filenames)

    def _resolve_path(self, fn):
        """Return the path of ``fn`` in the split directory or its ``images`` sub-folder."""
        direct = os.path.join(self.dir, fn)
        if os.path.exists(direct):
            return direct
        nested = os.path.join(self.dir, "images", fn)
        if os.path.exists(nested):
            return nested
        raise FileNotFoundError(f"File {fn} not found in {self.dir} or {os.path.join(self.dir, 'images')}")

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(transformed_image, class_index, filename)``.

        Raises:
            FileNotFoundError: If the image is in neither searched location.
        """
        fn = self.filenames[idx]
        img_path = self._resolve_path(fn)
        # The context manager releases the file handle as soon as the pixel
        # data has been converted, instead of waiting for garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        x = self.transform(img)
        # one-hot -> index. NOTE: a row without any positive entry resolves to
        # class 0, because argmax of an all-zero vector is 0.
        y = torch.tensor(np.argmax(self.labels[idx]), dtype=torch.long)
        return x, y, fn

# -------------------------
# CBAM implementation
# -------------------------
class BasicConv(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied by the convolution.
        bias: Whether the convolution carries a bias term.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, bias=False):
        super().__init__()
        conv_kwargs = dict(stride=stride, padding=padding, bias=bias)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Apply the three stages in sequence: conv -> batch norm -> ReLU.
        for stage in (self.conv, self.bn, self.relu):
            x = stage(x)
        return x

class ChannelAttention(nn.Module):
    """Channel gate built from globally average- and max-pooled descriptors.

    Both pooled descriptors go through the same two 1x1 convolutions
    (``fc1`` -> ReLU -> ``fc2``); the two results are summed and squashed with
    a sigmoid, giving one weight per channel with shape ``[B, C, 1, 1]``.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Reduction factor of the bottleneck; the hidden width is
            ``max(in_planes // ratio, 1)``.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        hidden_planes = max(in_planes // ratio, 1)  # never collapse to zero channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _bottleneck(self, pooled):
        """Run one pooled descriptor through the shared fc1 -> ReLU -> fc2 stack."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        from_avg = self._bottleneck(self.avg_pool(x))
        from_max = self._bottleneck(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)

class SpatialAttention(nn.Module):
    """Spatial gate computed from channel-wise mean and max maps.

    The per-pixel mean and max over the channel axis are stacked into a
    2-channel map, passed through a ``BasicConv`` (conv + batch norm + ReLU)
    and squashed with a sigmoid, giving a ``[B, 1, H, W]`` weight map.

    Args:
        kernel_size: Odd, positive size of the convolution kernel. "Same"
            padding (``kernel_size // 2``) is derived from it so the gate
            always matches the spatial size of the input.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd number, since no
            symmetric padding could then preserve the spatial size.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        # Previously hard-coded to 3 for k=7 and 1 otherwise, which only kept
        # the spatial size for k in {3, 7}; k // 2 gives the same values for
        # those two and is correct for every other odd kernel as well.
        padding = kernel_size // 2
        self.conv = BasicConv(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        # NOTE(review): BasicConv ends in a ReLU, so the sigmoid input is
        # non-negative and the gate lies in [0.5, 1). Confirm this is intended;
        # changing it would alter the behaviour of already-trained checkpoints.
        return self.sigmoid(self.conv(x_cat))

class CBAM(nn.Module):
    """Apply channel attention and then spatial attention to a feature map.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Reduction ratio forwarded to ``ChannelAttention``.
        kernel_size: Kernel size forwarded to ``SpatialAttention``.
    """

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(in_planes, ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        # Stage 1: re-weight every channel.
        channel_refined = x * self.channel_att(x)
        # Stage 2: re-weight every spatial position of the stage-1 result.
        spatial_gate = self.spatial_att(channel_refined)
        return channel_refined * spatial_gate

# -------------------------
# Cross-Attention Fusion module
# -------------------------
class CrossAttentionFusion(nn.Module):
    """Bidirectional cross-attention between two spatial feature maps.

    Steps performed by :meth:`forward`:

    1. project each map to ``embed_dim`` channels with a 1x1 convolution;
    2. flatten the spatial grid of each map into a token sequence;
    3. let the CNN tokens attend to the ViT tokens and vice versa;
    4. layer-normalise each attended sequence and average over its tokens.

    Args:
        cnn_channels: Channel count of the CNN feature map.
        vit_channels: Channel count of the ViT feature map.
        embed_dim: Shared embedding width of both token sequences.
        num_heads: Number of heads in each ``nn.MultiheadAttention``.
    """

    def __init__(self, cnn_channels, vit_channels, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.cnn_proj = nn.Conv2d(cnn_channels, embed_dim, kernel_size=1)
        self.vit_proj = nn.Conv2d(vit_channels, embed_dim, kernel_size=1)
        self.attn_c2v = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.attn_v2c = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm_c = nn.LayerNorm(embed_dim)
        self.norm_v = nn.LayerNorm(embed_dim)

    @staticmethod
    def _as_tokens(feat_map):
        """Reshape ``[B, E, H, W]`` into a contiguous ``[B, H*W, E]`` sequence."""
        return feat_map.flatten(2).transpose(1, 2).contiguous()

    def forward(self, cnn_feat_map, vit_feat_map):
        """Fuse the two maps and return one pooled vector per branch.

        Args:
            cnn_feat_map: Tensor of shape ``[B, Cc, Hc, Wc]``.
            vit_feat_map: Tensor of shape ``[B, Cv, Hv, Wv]``.

        Returns:
            Tuple ``(cnn_vec, vit_vec)``, each of shape ``[B, embed_dim]``.
        """
        cnn_tokens = self._as_tokens(self.cnn_proj(cnn_feat_map))  # [B, Hc*Wc, E]
        vit_tokens = self._as_tokens(self.vit_proj(vit_feat_map))  # [B, Hv*Wv, E]

        # Each branch supplies the queries once and the keys/values once.
        cnn_attended, _ = self.attn_c2v(query=cnn_tokens, key=vit_tokens, value=vit_tokens)
        vit_attended, _ = self.attn_v2c(query=vit_tokens, key=cnn_tokens, value=cnn_tokens)

        # Normalise, then mean-pool over the token axis.
        cnn_vec = self.norm_c(cnn_attended).mean(dim=1)
        vit_vec = self.norm_v(vit_attended).mean(dim=1)
        return cnn_vec, vit_vec

# -------------------------
# Grad-CAM utility
# -------------------------
class GradCAM:
    """Grad-CAM heatmaps for one named sub-module of a model.

    A forward hook stores the output of the target module and a full backward
    hook stores the gradient flowing into that output. ``generate_cam``
    combines the two into a heatmap resized to the input resolution.

    Constructing the object puts ``model`` into eval mode. Call
    :meth:`remove_hooks` when finished so the hooks do not stay attached.

    Args:
        model: Model to explain.
        target_module_name: Name, as listed by ``model.named_modules()``, of
            the module whose output is explained. Its output is expected to be
            a single 4-D tensor ``[B, C, H, W]``.
        device: Device the input tensor is moved to before the forward pass.

    Raises:
        ValueError: If no module with ``target_module_name`` exists.
    """

    def __init__(self, model: nn.Module, target_module_name="cnn_features", device="cpu"):
        self.model = model
        self.model.eval()
        self.device = device
        self.gradients = None
        self.activations = None

        # Module names are unique, so a dict lookup finds the same module as a scan.
        target_module = dict(self.model.named_modules()).get(target_module_name)
        if target_module is None:
            raise ValueError(f"Target module '{target_module_name}' not found in model.")

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_in, grad_out):
            # grad_out is a tuple with gradients w.r.t. the module outputs
            self.gradients = grad_out[0].detach()

        self.f_hook = target_module.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated; the "full" variant is the
        # supported way to observe gradients w.r.t. a module's outputs.
        self.b_hook = target_module.register_full_backward_hook(backward_hook)

    def generate_cam(self, input_tensor: torch.Tensor, target_index: int = None):
        """Compute the Grad-CAM heatmap for a single image.

        Args:
            input_tensor: Image tensor of shape ``[1, 3, H, W]``, normalised
                the way the model expects.
            target_index: Class index to explain. If ``None``, the class with
                the highest logit is used.

        Returns:
            ``numpy`` array of shape ``[H, W]`` with values in ``[0, 1]``.

        Raises:
            ValueError: If ``input_tensor`` is not a 4-D batch of exactly one image.
            RuntimeError: If the hooks did not capture activations or gradients.
        """
        if input_tensor.dim() != 4 or input_tensor.shape[0] != 1:
            raise ValueError(
                f"input_tensor must have shape [1, C, H, W], got {tuple(input_tensor.shape)}"
            )

        # Forget anything captured by an earlier call so the check below
        # cannot be satisfied by stale tensors.
        self.gradients = None
        self.activations = None

        self.model.zero_grad()
        input_tensor = input_tensor.to(self.device)
        # Grad-CAM needs autograd even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            logits = self.model(input_tensor)  # [1, C]
            if target_index is None:
                target_index = torch.argmax(logits, dim=1).item()
            score = logits[0, target_index]
            score.backward()
        # The backward pass filled parameter gradients as a side effect; clear
        # them so they cannot leak into a later optimizer step.
        self.model.zero_grad()

        grads = self.gradients  # [1, Cc, Hc, Wc]
        acts = self.activations  # [1, Cc, Hc, Wc]
        if grads is None or acts is None:
            raise RuntimeError("Gradients or activations not captured. Make sure forward() was called on model and hooks are registered.")

        # Channel weights = spatially averaged gradients; the map is the
        # weighted sum of the activation channels (vectorised over channels).
        weights = grads.mean(dim=(2, 3))  # [1, Cc]
        cam = (weights[0, :, None, None] * acts[0]).sum(dim=0)  # [Hc, Wc]
        cam = F.relu(cam)

        # normalize cam to [0,1]; an all-zero map is left untouched
        cam_min = cam.min()
        cam_max = cam.max()
        if cam_max > 0:
            if cam_max > cam_min:
                cam = (cam - cam_min) / (cam_max - cam_min)
            else:
                cam = cam / cam_max

        # upsample to input size
        cam_up = F.interpolate(
            cam[None, None],  # [1, 1, Hc, Wc]
            size=tuple(input_tensor.shape[2:]),
            mode='bilinear',
            align_corners=False,
        )
        # Index instead of squeeze() so a 1-pixel-wide input keeps both axes.
        return cam_up[0, 0].cpu().numpy()

    def remove_hooks(self):
        """Detach both hooks from the target module."""
        self.f_hook.remove()
        self.b_hook.remove()

# -------------------------
# HybridNet: DenseNet + CBAM + Swin + Cross-Attn Fusion
# -------------------------
class HybridNet(nn.Module):
    """Two-branch classifier: DenseNet + CBAM and Swin, fused by cross-attention.

    Args:
        densenet_model_name: timm model name of the CNN branch.
        swin_model_name: timm model name of the transformer branch.
        num_classes: Number of output logits.
        image_size: Accepted for interface compatibility; not used here.
        device: Stored on ``self.device``; the module is not moved by this class.
        cross_embed_dim: Embedding width used inside ``CrossAttentionFusion``.
        cross_heads: Number of attention heads in ``CrossAttentionFusion``.
    """

    def __init__(
        self,
        densenet_model_name="densenet121",
        swin_model_name="swin_base_patch4_window7_224",
        num_classes=11,
        image_size=224,
        device="cpu",
        cross_embed_dim=256,
        cross_heads=8,
    ):
        super().__init__()
        self.device = device

        # --- CNN branch: keep only the convolutional trunk of DenseNet ---
        densenet = timm.create_model(densenet_model_name, pretrained=True)
        self.cnn_features = densenet.features  # produces [B, Cc, Hc, Wc]
        cnn_channels = densenet.num_features
        self.cbam = CBAM(cnn_channels)
        # Fallback pooling layer; forward() does not call it.
        self.cnn_gap = nn.AdaptiveAvgPool2d((1, 1))

        # --- Transformer branch: no head and no pooling, so a feature map comes out ---
        self.swin = timm.create_model(swin_model_name, pretrained=True, num_classes=0, global_pool="")
        swin_channels = self.swin.num_features
        # Kept as an attribute, but forward() does not call it.
        self.swin_pool = nn.AdaptiveAvgPool2d((1, 1))

        # --- Fusion of the two branches ---
        self.cross_fusion = CrossAttentionFusion(
            cnn_channels=cnn_channels,
            vit_channels=swin_channels,
            embed_dim=cross_embed_dim,
            num_heads=cross_heads,
        )

        # --- Classification head on the concatenated [cnn_vec, vit_vec] ---
        self.classifier = nn.Sequential(
            nn.Linear(2 * cross_embed_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]`` for an image batch ``x``."""
        # CNN branch with CBAM re-weighting.
        cnn_map = self.cbam(self.cnn_features(x))

        # Swin branch. The permute below assumes forward_features returns a
        # channels-last [B, H, W, C] tensor -- TODO confirm for the installed
        # timm version. The 1x1 convs in cross_fusion need channels-first.
        swin_map = self.swin.forward_features(x).permute(0, 3, 1, 2).contiguous()

        # One pooled vector per branch, each [B, E].
        cnn_vec, vit_vec = self.cross_fusion(cnn_map, swin_map)

        # Concatenate to [B, 2*E] and classify.
        return self.classifier(torch.cat([cnn_vec, vit_vec], dim=1))

# -------------------------
# Training / validation (unchanged)
# -------------------------
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean sample loss.

    Each batch loss is weighted by its batch size, and the total is divided by
    ``len(loader.dataset)``.
    """
    model.train()
    loss_sum = 0.0
    for batch_imgs, batch_labels, _ in tqdm(loader, desc="Train batches"):
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_imgs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        # criterion yields a per-batch value; scale it back to a per-sample sum.
        loss_sum += batch_loss.item() * batch_imgs.size(0)
    return loss_sum / len(loader.dataset)

@torch.no_grad()
def validate(model, loader, device):
    """Run ``model`` in eval mode over ``loader`` without tracking gradients.

    Returns:
        Tuple ``(labels, logits, filenames)``: the per-batch labels joined
        with ``np.hstack``, the per-batch logits stacked row-wise with
        ``np.vstack``, and the filenames in iteration order.
    """
    model.eval()
    logit_chunks = []
    label_chunks = []
    seen_files = []
    for imgs, labels, fns in tqdm(loader, desc="Validation batches"):
        batch_logits = model(imgs.to(device))
        logit_chunks.append(batch_logits.detach().cpu().numpy())
        label_chunks.append(labels.numpy())
        seen_files.extend(fns)
    stacked_logits = np.vstack(logit_chunks)
    stacked_labels = np.hstack(label_chunks)
    return stacked_labels, stacked_logits, seen_files

# -------------------------
# Main
# -------------------------
def _save_gradcam_example(model, test_ds, test_dir, class_names, target_idx=0):
    """Save one Grad-CAM overlay PNG for the first test image of class ``target_idx``.

    Candidate images are located through ``test_ds.labels`` so that only
    matching images are loaded and transformed. The Grad-CAM hooks are always
    removed from ``model`` again, even if heatmap generation raises.
    """
    target_class_name = class_names[target_idx]

    # Create GradCAM with hook on cnn_features
    gradcam = GradCAM(model, target_module_name="cnn_features", device=DEVICE)
    try:
        if len(test_ds) == 0:
            print("Test dataset is empty, cannot generate Grad-CAM.")
            return

        # Same rule the dataset uses for its targets: argmax of the label row.
        matching = np.flatnonzero(np.argmax(test_ds.labels, axis=1) == target_idx)
        for i in matching:
            img_tensor, _, fn = test_ds[int(i)]

            # Find the correct full path to the original image
            img_path = os.path.join(test_dir, fn)
            if not os.path.exists(img_path):
                img_path = os.path.join(test_dir, "images", fn)
            if not os.path.exists(img_path):
                print(f"Warning: Could not find original image {fn} for Grad-CAM. Skipping.")
                continue

            # generate cam for the true class
            input_tensor = img_tensor.unsqueeze(0).to(DEVICE)  # [1,3,H,W]
            cam = gradcam.generate_cam(input_tensor, target_index=target_idx)  # HxW numpy

            # read the un-normalised original image for the overlay
            with Image.open(img_path) as raw:
                orig_pil = raw.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
            orig = np.array(orig_pil).astype(np.uint8)

            # create heatmap overlay
            cmap = plt.get_cmap("jet")
            heatmap = (cmap(cam)[:, :, :3] * 255).astype(np.uint8)  # HxWx3
            overlay = (0.6 * orig.astype(float) + 0.4 * heatmap.astype(float)).astype(np.uint8)

            out_img = f"gradcam_{os.path.basename(fn)}_{target_class_name.strip()}.png"
            plt.imsave(out_img, overlay)
            print("Saved Grad-CAM visualization:", out_img)
            return

        print(f"Could not find a test example for class '{target_class_name}' to generate Grad-CAM.")
    finally:
        # remove hooks
        gradcam.remove_hooks()


def main():
    """Train, select the best checkpoint, evaluate on test, and save a Grad-CAM example.

    Reads the ``train``, ``valid`` and ``test`` folders under ``DATA_ROOT``,
    trains ``HybridNet`` for ``NUM_EPOCHS`` epochs, keeps the checkpoint with
    the best validation macro-F1 in ``MODEL_SAVE_PATH``, writes test-set class
    probabilities to ``test_predictions_multiclass.csv`` and finally saves one
    Grad-CAM overlay image. Overwrites the module-level ``NUM_CLASSES`` with
    the number of classes found in the training CSV.
    """
    global NUM_CLASSES

    # splits
    train_dir, valid_dir, test_dir = [os.path.join(DATA_ROOT, s) for s in ["train", "valid", "test"]]

    # Ensure train directory and its _classes.csv exist before proceeding
    train_csv_path = os.path.join(train_dir, "_classes.csv")
    if not os.path.exists(train_csv_path):
        print(f"Error: Training CSV not found at {train_csv_path}")
        print("Please ensure DATA_ROOT is set correctly and your 'train' folder contains '_classes.csv'")
        return

    class_names = read_classes_from_csv(train_csv_path)
    NUM_CLASSES = len(class_names)
    print("Classes:", class_names)
    print(f"Found {NUM_CLASSES} classes.")

    # transforms: light augmentation for training, deterministic for val/test
    train_transform = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.ColorJitter(0.1, 0.1),
        T.Grayscale(3),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        T.Grayscale(3),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_ds = FetalUSDataset(train_dir, IMAGE_SIZE, train_transform)
    val_ds = FetalUSDataset(valid_dir, IMAGE_SIZE, val_transform)
    test_ds = FetalUSDataset(test_dir, IMAGE_SIZE, val_transform)

    print(f"Dataset loaded: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test examples.")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    # model, criterion, optimizer
    model = HybridNet(num_classes=NUM_CLASSES, image_size=IMAGE_SIZE, device=DEVICE).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    print(f"Starting training for {NUM_EPOCHS} epochs on {DEVICE}...")
    best_val_f1 = -1.0
    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        print(f"Train loss: {train_loss:.4f}")

        val_labels, val_logits, _ = validate(model, val_loader, DEVICE)
        val_preds = np.argmax(val_logits, axis=1)
        # zero_division=0 avoids errors/warnings when a class is never predicted
        f1_macro = f1_score(val_labels, val_preds, average="macro", zero_division=0)
        f1_micro = f1_score(val_labels, val_preds, average="micro", zero_division=0)
        print(f"Val F1 macro: {f1_macro:.4f} | F1 micro: {f1_micro:.4f}")

        # Model selection is driven by validation macro-F1.
        if f1_macro > best_val_f1:
            best_val_f1 = f1_macro
            torch.save({"model_state": model.state_dict(), "class_names": class_names, "epoch": epoch}, MODEL_SAVE_PATH)
            print(f"Saved best model to {MODEL_SAVE_PATH} (F1 macro: {best_val_f1:.4f})")

        scheduler.step()

    # Test evaluation
    print(f"\nLoading best model from {MODEL_SAVE_PATH} for test evaluation...")
    if not os.path.exists(MODEL_SAVE_PATH):
        print("Error: Best model was not saved. Skipping test evaluation.")
        return

    ckpt = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])
    test_labels, test_logits, test_fns = validate(model, test_loader, DEVICE)
    test_preds = np.argmax(test_logits, axis=1)
    f1_macro = f1_score(test_labels, test_preds, average="macro", zero_division=0)
    f1_micro = f1_score(test_labels, test_preds, average="micro", zero_division=0)
    print(f"Test F1 macro: {f1_macro:.4f} | F1 micro: {f1_micro:.4f}")

    # Save test predictions (softmax probabilities, one column per class)
    test_probs = torch.softmax(torch.tensor(test_logits), dim=1).numpy()
    out_df = pd.DataFrame(test_probs, columns=class_names)
    out_df["filename"] = test_fns
    out_df = out_df[["filename"] + class_names]
    out_df.to_csv("test_predictions_multiclass.csv", index=False)
    print("Saved test_predictions_multiclass.csv")

    # -------------------------
    # Example Grad-CAM explanation (replaces previous LIME example)
    # -------------------------
    print("\nGenerating Grad-CAM example...")
    # One Grad-CAM image for a test example of the first class in class_names (if present).
    _save_gradcam_example(model, test_ds, test_dir, class_names, target_idx=0)
    print("Pipeline complete.")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()