<a href="https://colab.research.google.com/github/AyeshaAnzerBCIT/Multisource/blob/main/MLP%2C_GMU%2C_crossmodel%2C_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# models/mlp_fusion.py (Optimized for Google Colab Pro+ GPU)
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import re
import gcsfs
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# --- Runtime / storage setup -------------------------------------------------
# Let cuDNN benchmark and cache the fastest kernels for the input sizes seen.
torch.backends.cudnn.benchmark = True
# Environment setting only; nothing in this notebook uses a tokenizer.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# GCS filesystem handle authenticated with the local credentials file.
KEY_PATH = "Key.json"
fs = gcsfs.GCSFileSystem(token=KEY_PATH)

# --- Feature columns read from each modality table ---------------------------
eeg_features = [
    "Mobility",
    "Complexity",
    "Spectral_Entropy",
]
eye_features = [
    "mean_pupil_size",
    "std_pupil_size",
    "mean_latency",
    "std_latency",
    "mean_gaze_vector",
    "std_gaze_vector",
]
beh_features = [
    "mean_rt",
    "accuracy",
]

# Data preparation pipeline
def extract_patient_id(name):
    """Return the first patient ID ("A" followed by 5+ digits) found in *name*.

    *name* is passed through ``str()`` first, so non-string values such as
    None or numbers are accepted. Returns None when no ID is present.
    """
    hit = re.search(r'(A\d{5,})', str(name))
    if hit is None:
        return None
    return hit.group(1)

def load_csv(fs, path, id_col):
    """Read the CSV at *path* through the filesystem *fs* and tag rows with IDs.

    A ``patient_id`` column is added by applying ``extract_patient_id`` to the
    values of *id_col* (None where no ID can be parsed).
    """
    with fs.open(path, 'r') as handle:
        table = pd.read_csv(handle)
    parsed_ids = table[id_col].apply(extract_patient_id)
    table["patient_id"] = parsed_ids
    return table

def group_features(df, features):
    """Collapse *df* to one row per patient by averaging *features*.

    Returns a frame with a ``patient_id`` column followed by *features*, one
    row per distinct ID (pandas' default sorted group order), with a fresh
    0..n-1 index.
    """
    by_patient = df.groupby("patient_id")[features]
    per_patient_means = by_patient.agg("mean")
    return per_patient_means.reset_index()

def impute_missing_rows(modality_df, features, missing_ids):
    """Append mean-filled rows for patients that have no data for this modality.

    Each ID in *missing_ids* gets one new row whose *features* are the column
    means of *modality_df*. IDs are appended in sorted order so the result is
    reproducible (iterating a set has no stable order).

    Args:
        modality_df: per-patient table with a ``patient_id`` column and *features*.
        features: names of the feature columns to fill.
        missing_ids: iterable of patient IDs to add (callers pass a set).

    Returns:
        A new frame with a fresh 0..n-1 index. When there is nothing to add it
        is a re-indexed copy of *modality_df*, so no empty frame is concatenated.
    """
    ids = sorted(missing_ids)
    if not ids:
        return modality_df.reset_index(drop=True)
    fill_values = modality_df[features].mean().to_dict()
    imputed = pd.DataFrame([{"patient_id": pid, **fill_values} for pid in ids])
    return pd.concat([modality_df, imputed], ignore_index=True)

# --- Load per-modality tables and the label sheet; tag rows with patient IDs ---
eeg_df = load_csv(fs, "gs://eegchild/processed_features/merged_features.csv", "file_name")
eye_df = load_csv(fs, "gs://eegchild/processed_asd_features.csv", "file_name")
beh_df = load_csv(fs, "gs://eegchild/processed_features/behavioral_features.csv", "file")
label_df = load_csv(fs, "gs://eegchild/MIPDB_PublicFile.csv", "ID")

# Binary target: diagnosis code 2 is folded into class 1.
label_df = label_df.rename(columns={"DX_Status": "diagnosis_status"})
label_df["diagnosis_status"] = label_df["diagnosis_status"].replace({2: 1})

# Rows whose identifier holds no patient ID would otherwise be stringified to
# "None" and pooled into one bogus patient, so drop them. The remaining IDs
# are normalized to str so joins across tables match.
eeg_df, eye_df, beh_df, label_df = (
    df.loc[df["patient_id"].notna()].assign(patient_id=lambda d: d["patient_id"].astype(str))
    for df in (eeg_df, eye_df, beh_df, label_df)
)
# One label row per patient: duplicate label rows would turn into duplicate
# samples that can end up on both sides of the train/test split.
label_df = label_df.drop_duplicates(subset="patient_id", keep="first")

grouped_eeg = group_features(eeg_df, eeg_features)
grouped_eye = group_features(eye_df, eye_features)
grouped_beh = group_features(beh_df, beh_features)

# Labelled patients with no recordings for a modality get that modality's means.
expected_ids = set(label_df["patient_id"])
grouped_eeg = impute_missing_rows(grouped_eeg, eeg_features, expected_ids - set(grouped_eeg["patient_id"]))
grouped_eye = impute_missing_rows(grouped_eye, eye_features, expected_ids - set(grouped_eye["patient_id"]))
grouped_beh = impute_missing_rows(grouped_beh, beh_features, expected_ids - set(grouped_beh["patient_id"]))

# Join the modalities on patient_id. The grouped tables are NOT in a common
# row order (imputed rows are appended per modality), so combining them by
# position mixed features from different patients into one sample. Inner joins
# keep exactly the patients present in the labels and in every modality.
merged_df = (
    label_df[["patient_id", "diagnosis_status"]]
    .rename(columns={"diagnosis_status": "label"})
    .merge(grouped_eeg.rename(columns={f: f"eeg_{f}" for f in eeg_features}), on="patient_id")
    .merge(grouped_eye.rename(columns={f: f"eye_{f}" for f in eye_features}), on="patient_id")
    .merge(grouped_beh.rename(columns={f: f"beh_{f}" for f in beh_features}), on="patient_id")
)

# Balance the classes by undersampling the majority class WITHOUT replacement.
# Sampling with replacement here, before create_loaders splits the data, puts
# copies of the same patient into train, validation and test at once and
# inflates the reported scores.
class_0 = merged_df[merged_df["label"] == 0]
class_1 = merged_df[merged_df["label"] == 1]
n_per_class = min(len(class_0), len(class_1))
balanced_df = pd.concat([
    class_0.sample(n=n_per_class, random_state=42),
    class_1.sample(n=n_per_class, random_state=42),
]).sample(frac=1, random_state=42).reset_index(drop=True)

# Force every feature column to float32; unparseable values become 0.
all_feature_cols = [c for c in balanced_df.columns if c.startswith(("eeg_", "eye_", "beh_"))]
for col in all_feature_cols:
    balanced_df[col] = pd.to_numeric(balanced_df[col], errors='coerce').fillna(0).astype(np.float32)

class MultimodalDataset(Dataset):
    """Map-style dataset yielding ``(eeg, eye, beh, label)`` tensors for each row of *df*.

    Feature columns are routed to a modality by their prefix ("eeg_", "eye_",
    "beh_") and coerced to float32, with unparseable values set to 0. The
    ``label`` column is cast to int. The coerced frame is kept on ``self.df``.

    The feature matrices and labels are converted to tensors once in
    ``__init__``, so ``__getitem__`` only slices them instead of building a
    pandas Series per item. Returned feature tensors are views into those
    cached tensors; treat them as read-only.
    """

    def __init__(self, df):
        self.df = df.copy()
        self.eeg_cols = [c for c in self.df.columns if c.startswith("eeg_")]
        self.eye_cols = [c for c in self.df.columns if c.startswith("eye_")]
        self.beh_cols = [c for c in self.df.columns if c.startswith("beh_")]
        for c in self.eeg_cols + self.eye_cols + self.beh_cols:
            self.df[c] = pd.to_numeric(self.df[c], errors="coerce").fillna(0).astype(np.float32)
        self.df["label"] = self.df["label"].astype(int)

        # copy=True yields owned, writable arrays, which torch.from_numpy expects.
        self._eeg = torch.from_numpy(self.df[self.eeg_cols].to_numpy(dtype=np.float32, copy=True))
        self._eye = torch.from_numpy(self.df[self.eye_cols].to_numpy(dtype=np.float32, copy=True))
        self._beh = torch.from_numpy(self.df[self.beh_cols].to_numpy(dtype=np.float32, copy=True))
        self._labels = torch.from_numpy(self.df["label"].to_numpy(dtype=np.int64, copy=True))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional indexing, like the previous ``df.iloc[idx]``; label is a 0-d int64 tensor.
        return self._eeg[idx], self._eye[idx], self._beh[idx], self._labels[idx]

def create_loaders(df, train_ratio=0.7, val_ratio=0.15, batch_size=16):
    """Split *df* into stratified train/val/test sets and wrap each in a DataLoader.

    The test share is whatever remains after ``train_ratio`` and ``val_ratio``.
    Both splits stratify on ``label`` with ``random_state=42``, so the split is
    reproducible. Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: if the ratios do not leave a positive share for every split.
    """
    if not (0 < train_ratio < 1 and val_ratio > 0 and train_ratio + val_ratio < 1):
        raise ValueError(
            "need 0 < train_ratio < 1, val_ratio > 0 and train_ratio + val_ratio < 1; "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )
    train_df, temp_df = train_test_split(df, test_size=1 - train_ratio, stratify=df["label"], random_state=42)
    # Fraction of the held-out remainder that goes to validation; the rest is test.
    val_size = val_ratio / (1 - train_ratio)
    val_df, test_df = train_test_split(temp_df, test_size=1 - val_size, stratify=temp_df["label"], random_state=42)

    # Pinned host memory only speeds up host-to-GPU copies; without CUDA it is
    # useless, and recent PyTorch versions warn about it.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(MultimodalDataset(train_df), batch_size=batch_size, shuffle=True, pin_memory=pin)
    val_loader = DataLoader(MultimodalDataset(val_df), batch_size=batch_size, pin_memory=pin)
    test_loader = DataLoader(MultimodalDataset(test_df), batch_size=batch_size, pin_memory=pin)
    return train_loader, val_loader, test_loader

# Model definitions, training and evaluation are defined in the following cells.


  df = pd.read_csv(f)


In [None]:
# Add MLP model + training + evaluation (continued from previous)

class MLPFusion(nn.Module):
    """Late-fusion MLP: one encoder per modality, concatenated, then classified.

    Each modality is mapped to 64 features (Linear -> ReLU -> Dropout(0.2)).
    The three 64-d codes are concatenated (192-d) and fed to a two-layer head
    (128 hidden units, Dropout(0.3)) that emits 2 logits.
    """

    def __init__(self, eeg_dim, eye_dim, beh_dim):
        super().__init__()

        def encoder(in_features):
            # Same recipe for every modality; only the input width differs.
            return nn.Sequential(nn.Linear(in_features, 64), nn.ReLU(), nn.Dropout(0.2))

        self.eeg_fc = encoder(eeg_dim)
        self.eye_fc = encoder(eye_dim)
        self.beh_fc = encoder(beh_dim)
        self.classifier = nn.Sequential(
            nn.Linear(3 * 64, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 2),
        )

    def forward(self, eeg, eye, beh):
        codes = (self.eeg_fc(eeg), self.eye_fc(eye), self.beh_fc(beh))
        return self.classifier(torch.cat(codes, dim=1))

def train_model(model, train_loader, val_loader, device, epochs=30, lr=1e-3, patience=3,
                ckpt_path="results/best_mlp_model.pt"):
    """Train *model* with AdamW and cross-entropy, early-stopping on validation accuracy.

    After each epoch the validation accuracy is computed with
    ``evaluate_model(model, val_loader, device, return_acc=True)``. Whenever it
    improves, the model's ``state_dict`` is saved to ``ckpt_path``. Training
    stops after ``patience`` consecutive epochs without improvement.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        train_loader: iterable of ``(eeg, eye, beh, labels)`` training batches.
        val_loader: iterable of validation batches in the same format.
        device: device every batch is moved to.
        epochs: maximum number of epochs.
        lr: AdamW learning rate.
        patience: epochs without validation improvement before stopping early.
        ckpt_path: file the best ``state_dict`` is written to.

    Returns:
        The best validation accuracy observed (also printed).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint.
    # Starting at 0 meant a run stuck at 0.0 accuracy saved nothing, and the
    # caller then loaded a missing or stale file.
    best_val_acc = float("-inf")
    patience_counter = 0
    # torch.save does not create parent directories.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for eeg, eye, beh, labels in train_loader:
            eeg = eeg.to(device, non_blocking=True)
            eye = eye.to(device, non_blocking=True)
            beh = beh.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(eeg, eye, beh)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The loss is a per-batch mean; weight by batch size for an epoch average.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        # Guard against an empty training loader instead of dividing by zero.
        train_loss = running_loss / total if total else float("nan")
        train_acc = correct / total if total else float("nan")
        val_acc = evaluate_model(model, val_loader, device, return_acc=True)
        print(f"[Epoch {epoch}] Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break
    print("Best Val Accuracy:", best_val_acc)
    return best_val_acc

def evaluate_model(model, loader, device, return_acc=False, save_path="results/mlp_confmat.png"):
    """Run *model* over *loader* in eval mode and score its argmax predictions.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        loader: iterable of ``(eeg, eye, beh, labels)`` batches.
        device: device every batch is moved to.
        return_acc: if true, return the accuracy and skip the report/plot.
        save_path: where the confusion-matrix image is written.

    Returns:
        The accuracy when ``return_acc`` is true (NaN for an empty loader).
        Otherwise None: a classification report is printed and the confusion
        matrix is saved to ``save_path``.
    """
    model.eval()
    all_preds, all_labels = [], []
    num_classes = 0
    with torch.no_grad():
        for eeg, eye, beh, labels in loader:
            eeg, eye, beh, labels = eeg.to(device), eye.to(device), beh.to(device), labels.to(device)
            logits = model(eeg, eye, beh)
            num_classes = logits.shape[1]
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    preds = np.asarray(all_preds)
    targets = np.asarray(all_labels)
    # Explicit NaN for an empty loader instead of np.mean's empty-slice warning.
    acc = float(np.mean(preds == targets)) if targets.size else float("nan")
    if return_acc:
        return acc

    # List every class the model can emit, so the matrix stays full-size even
    # when a split happens to contain (or predict) only one class.
    class_ids = list(range(num_classes)) or None
    print(classification_report(targets, preds, labels=class_ids, zero_division=0))
    cm = confusion_matrix(targets, preds, labels=class_ids)
    disp = ConfusionMatrixDisplay(cm, display_labels=class_ids)
    disp.plot()
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Save and close the display's own figure rather than whatever is "current".
    disp.figure_.savefig(save_path)
    plt.close(disp.figure_)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(" Using device:", device)
    # Training checkpoints into results/; torch.save does not create missing
    # directories, so make sure it exists before the first epoch.
    os.makedirs("results", exist_ok=True)
    train_loader, val_loader, test_loader = create_loaders(balanced_df)

    model = MLPFusion(
        eeg_dim=len(eeg_features),
        eye_dim=len(eye_features),
        beh_dim=len(beh_features)
    ).to(device)

    train_model(model, train_loader, val_loader, device)

    # The checkpoint is a plain state_dict. weights_only=True restricts
    # unpickling to tensors and plain containers, which also silences the
    # torch.load FutureWarning on newer PyTorch.
    model.load_state_dict(
        torch.load("results/best_mlp_model.pt", map_location=device, weights_only=True)
    )
    evaluate_model(model, test_loader, device)


 Using device: cuda
[Epoch 1] Loss: 69.8826, Train Acc: 0.5455, Val Acc: 0.5000
[Epoch 2] Loss: 45.3643, Train Acc: 0.5341, Val Acc: 0.5000
[Epoch 3] Loss: 45.7964, Train Acc: 0.5227, Val Acc: 0.5556
[Epoch 4] Loss: 42.8756, Train Acc: 0.5000, Val Acc: 0.5000
[Epoch 5] Loss: 37.3129, Train Acc: 0.5227, Val Acc: 0.5000
[Epoch 6] Loss: 33.1963, Train Acc: 0.5682, Val Acc: 0.5000
Early stopping.
Best Val Accuracy: 0.5555555555555556
              precision    recall  f1-score   support

           0       0.50      1.00      0.67        10
           1       0.00      0.00      0.00        10

    accuracy                           0.50        20
   macro avg       0.25      0.50      0.33        20
weighted avg       0.25      0.50      0.33        20



  model.load_state_dict(torch.load("results/best_mlp_model.pt", map_location=device))


In [None]:
# MLP fusion, revised: standardized features, BatchNorm layers and a class-weighted loss

from sklearn.preprocessing import StandardScaler

# Feature normalization before DataLoader creation.
# The scaler is fit on the training rows only, so validation/test statistics do
# not leak into training. The split below mirrors the first stage of
# create_loaders at its defaults (test_size = 1 - 0.7, stratified on "label",
# random_state=42), which selects the same rows as the training loader.
# Re-running the cell is safe: refitting on already-standardized training rows
# gives approximately an identity transform.
scaler_train_idx, _ = train_test_split(
    balanced_df.index, test_size=1 - 0.7, stratify=balanced_df["label"], random_state=42
)
scaler = StandardScaler()
scaler.fit(balanced_df.loc[scaler_train_idx, all_feature_cols])
balanced_df[all_feature_cols] = scaler.transform(balanced_df[all_feature_cols])

# Define weighted loss based on label distribution: weight_c = N / count_c.
# A class absent from the data counts as 1 to avoid division by zero.
label_counts = balanced_df['label'].value_counts().to_dict()
total = sum(label_counts.values())
weights = [total / label_counts.get(i, 1) for i in range(2)]
class_weights = torch.tensor(weights, dtype=torch.float32).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

class MLPFusion(nn.Module):
    """Late-fusion MLP with BatchNorm in every encoder and in the head.

    Each modality is mapped to 64 features by
    Linear -> BatchNorm1d -> ReLU -> Dropout(0.2). The three codes are
    concatenated (192-d) and classified by
    Linear(192, 128) -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear(128, 2).
    BatchNorm1d in training mode needs batches with more than one sample.
    """

    def __init__(self, eeg_dim, eye_dim, beh_dim):
        super().__init__()

        def branch(in_features):
            # Identical layer stack for every modality; only the input width differs.
            return nn.Sequential(
                nn.Linear(in_features, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.2)
            )

        self.eeg_fc = branch(eeg_dim)
        self.eye_fc = branch(eye_dim)
        self.beh_fc = branch(beh_dim)
        self.classifier = nn.Sequential(
            nn.Linear(3 * 64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 2),
        )

    def forward(self, eeg, eye, beh):
        fused = torch.cat((self.eeg_fc(eeg), self.eye_fc(eye), self.beh_fc(beh)), dim=1)
        return self.classifier(fused)

def train_model(model, train_loader, val_loader, device, epochs=50, lr=1e-3, patience=6,
                ckpt_path="results/best_mlp_model.pt"):
    """Train *model* with AdamW and class-weighted cross-entropy, early-stopping on val accuracy.

    The loss uses the module-level ``class_weights`` tensor defined earlier in
    this cell. After each epoch the validation accuracy is computed with
    ``evaluate_model(model, val_loader, device, return_acc=True)``. Whenever it
    improves, the model's ``state_dict`` is saved to ``ckpt_path``. Training
    stops after ``patience`` consecutive epochs without improvement.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        train_loader: iterable of ``(eeg, eye, beh, labels)`` training batches.
        val_loader: iterable of validation batches in the same format.
        device: device every batch is moved to.
        epochs: maximum number of epochs.
        lr: AdamW learning rate.
        patience: epochs without validation improvement before stopping early.
        ckpt_path: file the best ``state_dict`` is written to.

    Returns:
        The best validation accuracy observed (also printed).
    """
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint.
    # Starting at 0 meant a run stuck at 0.0 accuracy saved nothing, and the
    # caller then loaded a missing or stale file.
    best_val_acc = float("-inf")
    patience_counter = 0
    # torch.save does not create parent directories.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for eeg, eye, beh, labels in train_loader:
            eeg = eeg.to(device, non_blocking=True)
            eye = eye.to(device, non_blocking=True)
            beh = beh.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(eeg, eye, beh)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The loss is a per-batch (weighted) mean; weight by batch size for an epoch average.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        # Guard against an empty training loader instead of dividing by zero.
        train_loss = running_loss / total if total else float("nan")
        train_acc = correct / total if total else float("nan")
        val_acc = evaluate_model(model, val_loader, device, return_acc=True)
        print(f"[Epoch {epoch}] Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break
    print("Best Val Accuracy:", best_val_acc)
    return best_val_acc

def evaluate_model(model, loader, device, return_acc=False, save_path="results/mlp_confmat.png"):
    """Run *model* over *loader* in eval mode and score its argmax predictions.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        loader: iterable of ``(eeg, eye, beh, labels)`` batches.
        device: device every batch is moved to.
        return_acc: if true, return the accuracy and skip the report/plot.
        save_path: where the confusion-matrix image is written.

    Returns:
        The accuracy when ``return_acc`` is true (NaN for an empty loader).
        Otherwise None: a classification report is printed and the confusion
        matrix is saved to ``save_path``.
    """
    model.eval()
    all_preds, all_labels = [], []
    num_classes = 0
    with torch.no_grad():
        for eeg, eye, beh, labels in loader:
            eeg, eye, beh, labels = eeg.to(device), eye.to(device), beh.to(device), labels.to(device)
            logits = model(eeg, eye, beh)
            num_classes = logits.shape[1]
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    preds = np.asarray(all_preds)
    targets = np.asarray(all_labels)
    # Explicit NaN for an empty loader instead of np.mean's empty-slice warning.
    acc = float(np.mean(preds == targets)) if targets.size else float("nan")
    if return_acc:
        return acc

    # List every class the model can emit, so the matrix stays full-size even
    # when a split happens to contain (or predict) only one class.
    class_ids = list(range(num_classes)) or None
    print(classification_report(targets, preds, labels=class_ids, zero_division=0))
    cm = confusion_matrix(targets, preds, labels=class_ids)
    disp = ConfusionMatrixDisplay(cm, display_labels=class_ids)
    disp.plot()
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Save and close the display's own figure rather than whatever is "current".
    disp.figure_.savefig(save_path)
    plt.close(disp.figure_)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(" Using device:", device)
    # Training checkpoints into results/; torch.save does not create missing
    # directories, so make sure it exists before the first epoch.
    os.makedirs("results", exist_ok=True)
    train_loader, val_loader, test_loader = create_loaders(balanced_df)

    model = MLPFusion(
        eeg_dim=len(eeg_features),
        eye_dim=len(eye_features),
        beh_dim=len(beh_features)
    ).to(device)

    train_model(model, train_loader, val_loader, device)

    # The checkpoint is a plain state_dict. weights_only=True restricts
    # unpickling to tensors and plain containers, which also silences the
    # torch.load FutureWarning on newer PyTorch.
    model.load_state_dict(
        torch.load("results/best_mlp_model.pt", map_location=device, weights_only=True)
    )
    evaluate_model(model, test_loader, device)


 Using device: cuda
[Epoch 1] Loss: 0.7296, Train Acc: 0.5227, Val Acc: 0.5556
[Epoch 2] Loss: 0.6187, Train Acc: 0.6364, Val Acc: 0.6111
[Epoch 3] Loss: 0.5669, Train Acc: 0.7159, Val Acc: 0.6111
[Epoch 4] Loss: 0.5593, Train Acc: 0.6477, Val Acc: 0.7222
[Epoch 5] Loss: 0.4962, Train Acc: 0.7045, Val Acc: 0.7222
[Epoch 6] Loss: 0.5151, Train Acc: 0.7386, Val Acc: 0.6111
[Epoch 7] Loss: 0.5884, Train Acc: 0.6477, Val Acc: 0.6111
[Epoch 8] Loss: 0.5019, Train Acc: 0.7386, Val Acc: 0.6111
[Epoch 9] Loss: 0.4879, Train Acc: 0.7614, Val Acc: 0.6111
[Epoch 10] Loss: 0.5145, Train Acc: 0.7159, Val Acc: 0.6111
Early stopping.
Best Val Accuracy: 0.7222222222222222
              precision    recall  f1-score   support

           0       1.00      0.50      0.67        10
           1       0.67      1.00      0.80        10

    accuracy                           0.75        20
   macro avg       0.83      0.75      0.73        20
weighted avg       0.83      0.75      0.73        20



  model.load_state_dict(torch.load("results/best_mlp_model.pt", map_location=device))


In [None]:
# Add Gated Multimodal Unit Fusion (GMUFusion)

from sklearn.preprocessing import StandardScaler

# Feature normalization before DataLoader creation.
# The scaler is fit on the training rows only, so validation/test statistics do
# not leak into training. The split below mirrors the first stage of
# create_loaders at its defaults (test_size = 1 - 0.7, stratified on "label",
# random_state=42), which selects the same rows as the training loader.
# Re-running the cell is safe: refitting on already-standardized training rows
# gives approximately an identity transform.
scaler_train_idx, _ = train_test_split(
    balanced_df.index, test_size=1 - 0.7, stratify=balanced_df["label"], random_state=42
)
scaler = StandardScaler()
scaler.fit(balanced_df.loc[scaler_train_idx, all_feature_cols])
balanced_df[all_feature_cols] = scaler.transform(balanced_df[all_feature_cols])

# Define weighted loss based on label distribution: weight_c = N / count_c.
# A class absent from the data counts as 1 to avoid division by zero.
label_counts = balanced_df['label'].value_counts().to_dict()
total = sum(label_counts.values())
weights = [total / label_counts.get(i, 1) for i in range(2)]
class_weights = torch.tensor(weights, dtype=torch.float32).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

class GMUFusion(nn.Module):
    """Gated fusion of EEG, eye-tracking and behavioural features into 2 logits.

    Each modality is linearly projected to 64 dims. A sigmoid gate computed
    from the concatenation of the three projections mixes the EEG and eye
    projections elementwise (``z * eeg + (1 - z) * eye``). The behavioural
    projection bypasses the gate and is added with a fixed weight of 0.5. The
    result goes through a small head: Linear -> ReLU -> Dropout(0.3) -> Linear(64, 2).

    NOTE(review): the projections have no tanh and the behavioural branch is
    not gated, unlike the textbook GMU. Confirm this is intended.
    """

    def __init__(self, eeg_dim, eye_dim, beh_dim):
        super().__init__()
        width = 64
        self.eeg_fc = nn.Linear(eeg_dim, width)
        self.eye_fc = nn.Linear(eye_dim, width)
        self.beh_fc = nn.Linear(beh_dim, width)
        self.z_gate = nn.Sequential(nn.Linear(3 * width, width), nn.Sigmoid())
        self.fusion_fc = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(width, 2),
        )

    def forward(self, eeg, eye, beh):
        eeg_h, eye_h, beh_h = self.eeg_fc(eeg), self.eye_fc(eye), self.beh_fc(beh)
        z = self.z_gate(torch.cat((eeg_h, eye_h, beh_h), dim=1))
        mixed = z * eeg_h + (1 - z) * eye_h + beh_h * 0.5
        return self.fusion_fc(mixed)

def train_model(model, train_loader, val_loader, device, epochs=50, lr=1e-3, patience=6,
                ckpt_path="results/best_model.pt"):
    """Train *model* with AdamW and class-weighted cross-entropy, early-stopping on val accuracy.

    The loss uses the module-level ``class_weights`` tensor defined earlier in
    this cell. After each epoch the validation accuracy is computed with
    ``evaluate_model(model, val_loader, device, return_acc=True)``. Whenever it
    improves, the model's ``state_dict`` is saved to ``ckpt_path``. Training
    stops after ``patience`` consecutive epochs without improvement.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        train_loader: iterable of ``(eeg, eye, beh, labels)`` training batches.
        val_loader: iterable of validation batches in the same format.
        device: device every batch is moved to.
        epochs: maximum number of epochs.
        lr: AdamW learning rate.
        patience: epochs without validation improvement before stopping early.
        ckpt_path: file the best ``state_dict`` is written to.

    Returns:
        The best validation accuracy observed (also printed).
    """
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint.
    # Several cells share the same checkpoint path, so a run that never beat
    # 0.0 previously left another model's weights behind to be loaded.
    best_val_acc = float("-inf")
    patience_counter = 0
    # torch.save does not create parent directories.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for eeg, eye, beh, labels in train_loader:
            eeg = eeg.to(device, non_blocking=True)
            eye = eye.to(device, non_blocking=True)
            beh = beh.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(eeg, eye, beh)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The loss is a per-batch (weighted) mean; weight by batch size for an epoch average.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        # Guard against an empty training loader instead of dividing by zero.
        train_loss = running_loss / total if total else float("nan")
        train_acc = correct / total if total else float("nan")
        val_acc = evaluate_model(model, val_loader, device, return_acc=True)
        print(f"[Epoch {epoch}] Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break
    print("Best Val Accuracy:", best_val_acc)
    return best_val_acc

def evaluate_model(model, loader, device, return_acc=False, save_path="results/test_confmat.png"):
    """Run *model* over *loader* in eval mode and score its argmax predictions.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        loader: iterable of ``(eeg, eye, beh, labels)`` batches.
        device: device every batch is moved to.
        return_acc: if true, return the accuracy and skip the report/plot.
        save_path: where the confusion-matrix image is written.

    Returns:
        The accuracy when ``return_acc`` is true (NaN for an empty loader).
        Otherwise None: a classification report is printed and the confusion
        matrix is saved to ``save_path``.
    """
    model.eval()
    all_preds, all_labels = [], []
    num_classes = 0
    with torch.no_grad():
        for eeg, eye, beh, labels in loader:
            eeg, eye, beh, labels = eeg.to(device), eye.to(device), beh.to(device), labels.to(device)
            logits = model(eeg, eye, beh)
            num_classes = logits.shape[1]
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    preds = np.asarray(all_preds)
    targets = np.asarray(all_labels)
    # Explicit NaN for an empty loader instead of np.mean's empty-slice warning.
    acc = float(np.mean(preds == targets)) if targets.size else float("nan")
    if return_acc:
        return acc

    # List every class the model can emit, so the matrix stays full-size even
    # when a split happens to contain (or predict) only one class.
    class_ids = list(range(num_classes)) or None
    print(classification_report(targets, preds, labels=class_ids, zero_division=0))
    cm = confusion_matrix(targets, preds, labels=class_ids)
    disp = ConfusionMatrixDisplay(cm, display_labels=class_ids)
    disp.plot()
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Save and close the display's own figure rather than whatever is "current".
    disp.figure_.savefig(save_path)
    plt.close(disp.figure_)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(" Using device:", device)
    # Training checkpoints into results/; torch.save does not create missing
    # directories, so make sure it exists before the first epoch.
    os.makedirs("results", exist_ok=True)
    train_loader, val_loader, test_loader = create_loaders(balanced_df)

    model = GMUFusion(
        eeg_dim=len(eeg_features),
        eye_dim=len(eye_features),
        beh_dim=len(beh_features)
    ).to(device)

    train_model(model, train_loader, val_loader, device)

    # The checkpoint is a plain state_dict. weights_only=True restricts
    # unpickling to tensors and plain containers, which also silences the
    # torch.load FutureWarning on newer PyTorch.
    model.load_state_dict(
        torch.load("results/best_model.pt", map_location=device, weights_only=True)
    )
    evaluate_model(model, test_loader, device)

 Using device: cuda
[Epoch 1] Loss: 0.6801, Train Acc: 0.5795, Val Acc: 0.4444
[Epoch 2] Loss: 0.6341, Train Acc: 0.6023, Val Acc: 0.3889
[Epoch 3] Loss: 0.5941, Train Acc: 0.7159, Val Acc: 0.4444
[Epoch 4] Loss: 0.5881, Train Acc: 0.6136, Val Acc: 0.4444
[Epoch 5] Loss: 0.5650, Train Acc: 0.6932, Val Acc: 0.4444
[Epoch 6] Loss: 0.5445, Train Acc: 0.6705, Val Acc: 0.4444
[Epoch 7] Loss: 0.5392, Train Acc: 0.6818, Val Acc: 0.5000
[Epoch 8] Loss: 0.5205, Train Acc: 0.6932, Val Acc: 0.5000
[Epoch 9] Loss: 0.5238, Train Acc: 0.6591, Val Acc: 0.5000
[Epoch 10] Loss: 0.5140, Train Acc: 0.6818, Val Acc: 0.5000
[Epoch 11] Loss: 0.4920, Train Acc: 0.7159, Val Acc: 0.4444
[Epoch 12] Loss: 0.4958, Train Acc: 0.7045, Val Acc: 0.4444
[Epoch 13] Loss: 0.4759, Train Acc: 0.7386, Val Acc: 0.5556
[Epoch 14] Loss: 0.4913, Train Acc: 0.6932, Val Acc: 0.5000
[Epoch 15] Loss: 0.4601, Train Acc: 0.8068, Val Acc: 0.5000
[Epoch 16] Loss: 0.4708, Train Acc: 0.7614, Val Acc: 0.6111
[Epoch 17] Loss: 0.4531, Trai

  model.load_state_dict(torch.load("results/best_model.pt", map_location=device))


In [None]:
# Add CrossModal Attention Fusion

from sklearn.preprocessing import StandardScaler

# Feature normalization before DataLoader creation.
# The scaler is fit on the training rows only, so validation/test statistics do
# not leak into training. The split below mirrors the first stage of
# create_loaders at its defaults (test_size = 1 - 0.7, stratified on "label",
# random_state=42), which selects the same rows as the training loader.
# Re-running the cell is safe: refitting on already-standardized training rows
# gives approximately an identity transform.
scaler_train_idx, _ = train_test_split(
    balanced_df.index, test_size=1 - 0.7, stratify=balanced_df["label"], random_state=42
)
scaler = StandardScaler()
scaler.fit(balanced_df.loc[scaler_train_idx, all_feature_cols])
balanced_df[all_feature_cols] = scaler.transform(balanced_df[all_feature_cols])

# Define weighted loss based on label distribution: weight_c = N / count_c.
# A class absent from the data counts as 1 to avoid division by zero.
label_counts = balanced_df['label'].value_counts().to_dict()
total = sum(label_counts.values())
weights = [total / label_counts.get(i, 1) for i in range(2)]
class_weights = torch.tensor(weights, dtype=torch.float32).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

class CrossModalAttentionFusion(nn.Module):
    """Self-attention across three modality tokens, mean-pooled into 2 logits.

    EEG, eye and behavioural inputs are each projected to ``hidden`` dims and
    stacked into a batch-first sequence of length 3. A 4-head
    ``nn.MultiheadAttention`` attends across the three tokens. Its outputs are
    averaged over the token axis and classified by
    Linear(hidden, 64) -> ReLU -> Dropout(0.3) -> Linear(64, 2).
    ``hidden`` must be divisible by the 4 heads.
    """

    def __init__(self, eeg_dim, eye_dim, beh_dim, hidden=64):
        super().__init__()
        self.eeg_fc = nn.Linear(eeg_dim, hidden)
        self.eye_fc = nn.Linear(eye_dim, hidden)
        self.beh_fc = nn.Linear(beh_dim, hidden)
        self.attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=4, batch_first=True)
        head_layers = [nn.Linear(hidden, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 2)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, eeg, eye, beh):
        tokens = torch.stack(
            (self.eeg_fc(eeg), self.eye_fc(eye), self.beh_fc(beh)), dim=1
        )  # (B, 3, hidden)
        attended = self.attn(tokens, tokens, tokens)[0]
        return self.classifier(attended.mean(dim=1))

def train_model(model, train_loader, val_loader, device, epochs=50, lr=1e-3, patience=6,
                ckpt_path="results/best_model.pt"):
    """Train *model* with AdamW and class-weighted cross-entropy, early-stopping on val accuracy.

    The loss uses the module-level ``class_weights`` tensor defined earlier in
    this cell. After each epoch the validation accuracy is computed with
    ``evaluate_model(model, val_loader, device, return_acc=True)``. Whenever it
    improves, the model's ``state_dict`` is saved to ``ckpt_path``. Training
    stops after ``patience`` consecutive epochs without improvement.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        train_loader: iterable of ``(eeg, eye, beh, labels)`` training batches.
        val_loader: iterable of validation batches in the same format.
        device: device every batch is moved to.
        epochs: maximum number of epochs.
        lr: AdamW learning rate.
        patience: epochs without validation improvement before stopping early.
        ckpt_path: file the best ``state_dict`` is written to.

    Returns:
        The best validation accuracy observed (also printed).
    """
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint.
    # Several cells share the same checkpoint path, so a run that never beat
    # 0.0 previously left another model's weights behind to be loaded.
    best_val_acc = float("-inf")
    patience_counter = 0
    # torch.save does not create parent directories.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for eeg, eye, beh, labels in train_loader:
            eeg = eeg.to(device, non_blocking=True)
            eye = eye.to(device, non_blocking=True)
            beh = beh.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(eeg, eye, beh)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The loss is a per-batch (weighted) mean; weight by batch size for an epoch average.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        # Guard against an empty training loader instead of dividing by zero.
        train_loss = running_loss / total if total else float("nan")
        train_acc = correct / total if total else float("nan")
        val_acc = evaluate_model(model, val_loader, device, return_acc=True)
        print(f"[Epoch {epoch}] Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break
    print("Best Val Accuracy:", best_val_acc)
    return best_val_acc

def evaluate_model(model, loader, device, return_acc=False, save_path="results/test_confmat.png"):
    """Run *model* over *loader* in eval mode and score its argmax predictions.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        loader: iterable of ``(eeg, eye, beh, labels)`` batches.
        device: device every batch is moved to.
        return_acc: if true, return the accuracy and skip the report/plot.
        save_path: where the confusion-matrix image is written.

    Returns:
        The accuracy when ``return_acc`` is true (NaN for an empty loader).
        Otherwise None: a classification report is printed and the confusion
        matrix is saved to ``save_path``.
    """
    model.eval()
    all_preds, all_labels = [], []
    num_classes = 0
    with torch.no_grad():
        for eeg, eye, beh, labels in loader:
            eeg, eye, beh, labels = eeg.to(device), eye.to(device), beh.to(device), labels.to(device)
            logits = model(eeg, eye, beh)
            num_classes = logits.shape[1]
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    preds = np.asarray(all_preds)
    targets = np.asarray(all_labels)
    # Explicit NaN for an empty loader instead of np.mean's empty-slice warning.
    acc = float(np.mean(preds == targets)) if targets.size else float("nan")
    if return_acc:
        return acc

    # List every class the model can emit, so the matrix stays full-size even
    # when a split happens to contain (or predict) only one class.
    class_ids = list(range(num_classes)) or None
    print(classification_report(targets, preds, labels=class_ids, zero_division=0))
    cm = confusion_matrix(targets, preds, labels=class_ids)
    disp = ConfusionMatrixDisplay(cm, display_labels=class_ids)
    disp.plot()
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Save and close the display's own figure rather than whatever is "current".
    disp.figure_.savefig(save_path)
    plt.close(disp.figure_)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(" Using device:", device)
    # Training checkpoints into results/; torch.save does not create missing
    # directories, so make sure it exists before the first epoch.
    os.makedirs("results", exist_ok=True)
    train_loader, val_loader, test_loader = create_loaders(balanced_df)

    model = CrossModalAttentionFusion(
        eeg_dim=len(eeg_features),
        eye_dim=len(eye_features),
        beh_dim=len(beh_features)
    ).to(device)

    train_model(model, train_loader, val_loader, device)

    # The checkpoint is a plain state_dict. weights_only=True restricts
    # unpickling to tensors and plain containers, which also silences the
    # torch.load FutureWarning on newer PyTorch.
    model.load_state_dict(
        torch.load("results/best_model.pt", map_location=device, weights_only=True)
    )
    evaluate_model(model, test_loader, device)


 Using device: cuda
[Epoch 1] Loss: 0.6889, Train Acc: 0.4773, Val Acc: 0.5556
[Epoch 2] Loss: 0.6740, Train Acc: 0.4886, Val Acc: 0.5556
[Epoch 3] Loss: 0.6479, Train Acc: 0.5568, Val Acc: 0.7778
[Epoch 4] Loss: 0.6195, Train Acc: 0.6364, Val Acc: 0.6111
[Epoch 5] Loss: 0.5935, Train Acc: 0.6136, Val Acc: 0.5556
[Epoch 6] Loss: 0.5642, Train Acc: 0.6364, Val Acc: 0.4444
[Epoch 7] Loss: 0.5392, Train Acc: 0.6705, Val Acc: 0.5556
[Epoch 8] Loss: 0.5232, Train Acc: 0.6932, Val Acc: 0.6111
[Epoch 9] Loss: 0.5020, Train Acc: 0.7159, Val Acc: 0.6111
Early stopping.
Best Val Accuracy: 0.7777777777777778
              precision    recall  f1-score   support

           0       0.80      0.80      0.80        10
           1       0.80      0.80      0.80        10

    accuracy                           0.80        20
   macro avg       0.80      0.80      0.80        20
weighted avg       0.80      0.80      0.80        20



  model.load_state_dict(torch.load("results/best_model.pt", map_location=device))


In [None]:
# Add CrossModal Attention Fusion

from sklearn.preprocessing import StandardScaler

# Feature normalization before DataLoader creation.
# The scaler is fit on the training rows only, so validation/test statistics do
# not leak into training. The split below mirrors the first stage of
# create_loaders at its defaults (test_size = 1 - 0.7, stratified on "label",
# random_state=42), which selects the same rows as the training loader.
# Re-running the cell is safe: refitting on already-standardized training rows
# gives approximately an identity transform.
scaler_train_idx, _ = train_test_split(
    balanced_df.index, test_size=1 - 0.7, stratify=balanced_df["label"], random_state=42
)
scaler = StandardScaler()
scaler.fit(balanced_df.loc[scaler_train_idx, all_feature_cols])
balanced_df[all_feature_cols] = scaler.transform(balanced_df[all_feature_cols])

# Define weighted loss based on label distribution: weight_c = N / count_c.
# A class absent from the data counts as 1 to avoid division by zero.
label_counts = balanced_df['label'].value_counts().to_dict()
total = sum(label_counts.values())
weights = [total / label_counts.get(i, 1) for i in range(2)]
class_weights = torch.tensor(weights, dtype=torch.float32).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

class CrossModalAttentionFusion(nn.Module):
    """Self-attention across three modality tokens, mean-pooled into 2 logits.

    EEG, eye and behavioural inputs are each projected to ``hidden`` dims and
    stacked into a batch-first sequence of length 3. A 4-head
    ``nn.MultiheadAttention`` attends across the three tokens. Its outputs are
    averaged over the token axis and classified by
    Linear(hidden, 64) -> ReLU -> Dropout(0.3) -> Linear(64, 2).
    ``hidden`` must be divisible by the 4 heads.
    """

    def __init__(self, eeg_dim, eye_dim, beh_dim, hidden=64):
        super().__init__()
        self.eeg_fc = nn.Linear(eeg_dim, hidden)
        self.eye_fc = nn.Linear(eye_dim, hidden)
        self.beh_fc = nn.Linear(beh_dim, hidden)
        self.attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=4, batch_first=True)
        head_layers = [nn.Linear(hidden, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 2)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, eeg, eye, beh):
        tokens = torch.stack(
            (self.eeg_fc(eeg), self.eye_fc(eye), self.beh_fc(beh)), dim=1
        )  # (B, 3, hidden)
        attended = self.attn(tokens, tokens, tokens)[0]
        return self.classifier(attended.mean(dim=1))

def train_model(model, train_loader, val_loader, device, epochs=50, lr=1e-3, patience=6,
                ckpt_path="results/best_model.pt"):
    """Train *model* with AdamW and class-weighted cross-entropy, early-stopping on val accuracy.

    The loss uses the module-level ``class_weights`` tensor defined earlier in
    this cell. After each epoch the validation accuracy is computed with
    ``evaluate_model(model, val_loader, device, return_acc=True)``. Whenever it
    improves, the model's ``state_dict`` is saved to ``ckpt_path``. Training
    stops after ``patience`` consecutive epochs without improvement.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        train_loader: iterable of ``(eeg, eye, beh, labels)`` training batches.
        val_loader: iterable of validation batches in the same format.
        device: device every batch is moved to.
        epochs: maximum number of epochs.
        lr: AdamW learning rate.
        patience: epochs without validation improvement before stopping early.
        ckpt_path: file the best ``state_dict`` is written to.

    Returns:
        The best validation accuracy observed (also printed).
    """
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint.
    # Several cells share the same checkpoint path, so a run that never beat
    # 0.0 previously left another model's weights behind to be loaded.
    best_val_acc = float("-inf")
    patience_counter = 0
    # torch.save does not create parent directories.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for eeg, eye, beh, labels in train_loader:
            eeg = eeg.to(device, non_blocking=True)
            eye = eye.to(device, non_blocking=True)
            beh = beh.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(eeg, eye, beh)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The loss is a per-batch (weighted) mean; weight by batch size for an epoch average.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        # Guard against an empty training loader instead of dividing by zero.
        train_loss = running_loss / total if total else float("nan")
        train_acc = correct / total if total else float("nan")
        val_acc = evaluate_model(model, val_loader, device, return_acc=True)
        print(f"[Epoch {epoch}] Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break
    print("Best Val Accuracy:", best_val_acc)
    return best_val_acc

def evaluate_model(model, loader, device, return_acc=False, save_path="results/test_confmat.png"):
    """Run *model* over *loader* in eval mode and score its argmax predictions.

    Args:
        model: module called as ``model(eeg, eye, beh)`` that returns class logits.
        loader: iterable of ``(eeg, eye, beh, labels)`` batches.
        device: device every batch is moved to.
        return_acc: if true, return the accuracy and skip the report/plot.
        save_path: where the confusion-matrix image is written.

    Returns:
        The accuracy when ``return_acc`` is true (NaN for an empty loader).
        Otherwise None: a classification report is printed and the confusion
        matrix is saved to ``save_path``.
    """
    model.eval()
    all_preds, all_labels = [], []
    num_classes = 0
    with torch.no_grad():
        for eeg, eye, beh, labels in loader:
            eeg, eye, beh, labels = eeg.to(device), eye.to(device), beh.to(device), labels.to(device)
            logits = model(eeg, eye, beh)
            num_classes = logits.shape[1]
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    preds = np.asarray(all_preds)
    targets = np.asarray(all_labels)
    # Explicit NaN for an empty loader instead of np.mean's empty-slice warning.
    acc = float(np.mean(preds == targets)) if targets.size else float("nan")
    if return_acc:
        return acc

    # List every class the model can emit, so the matrix stays full-size even
    # when a split happens to contain (or predict) only one class.
    class_ids = list(range(num_classes)) or None
    print(classification_report(targets, preds, labels=class_ids, zero_division=0))
    cm = confusion_matrix(targets, preds, labels=class_ids)
    disp = ConfusionMatrixDisplay(cm, display_labels=class_ids)
    disp.plot()
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Save and close the display's own figure rather than whatever is "current".
    disp.figure_.savefig(save_path)
    plt.close(disp.figure_)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("✅ Using device:", device)
    # Training checkpoints into results/; torch.save does not create missing
    # directories, so make sure it exists before the first epoch.
    os.makedirs("results", exist_ok=True)
    train_loader, val_loader, test_loader = create_loaders(balanced_df)

    model = CrossModalAttentionFusion(
        eeg_dim=len(eeg_features),
        eye_dim=len(eye_features),
        beh_dim=len(beh_features)
    ).to(device)

    train_model(model, train_loader, val_loader, device)

    # The checkpoint is a plain state_dict. weights_only=True restricts
    # unpickling to tensors and plain containers, which also silences the
    # torch.load FutureWarning on newer PyTorch.
    model.load_state_dict(
        torch.load("results/best_model.pt", map_location=device, weights_only=True)
    )
    evaluate_model(model, test_loader, device)


✅ Using device: cuda
[Epoch 1] Loss: 0.6845, Train Acc: 0.5568, Val Acc: 0.6111
[Epoch 2] Loss: 0.6495, Train Acc: 0.6705, Val Acc: 0.6111
[Epoch 3] Loss: 0.6284, Train Acc: 0.6477, Val Acc: 0.5556
[Epoch 4] Loss: 0.5997, Train Acc: 0.6477, Val Acc: 0.5556
[Epoch 5] Loss: 0.5750, Train Acc: 0.6364, Val Acc: 0.5556
[Epoch 6] Loss: 0.5490, Train Acc: 0.6591, Val Acc: 0.5000
[Epoch 7] Loss: 0.5292, Train Acc: 0.6591, Val Acc: 0.5000
Early stopping.
Best Val Accuracy: 0.6111111111111112
              precision    recall  f1-score   support

           0       0.86      0.60      0.71        10
           1       0.69      0.90      0.78        10

    accuracy                           0.75        20
   macro avg       0.77      0.75      0.74        20
weighted avg       0.77      0.75      0.74        20



  model.load_state_dict(torch.load("results/best_model.pt", map_location=device))


In [None]:
# TransformerFusion for training
class TransformerFusion(nn.Module):
    """Fuse EEG, eye-tracking and behavioural features with self-attention.

    Each modality vector is projected to ``hidden_dim`` and passed through a
    ReLU, giving a sequence of three tokens (EEG, Eye, Beh). A stack of
    ``num_layers`` Transformer encoder layers mixes the tokens. Their mean is
    then fed to a small MLP head that outputs two logits.

    No positional or modality embedding is added. The tokens differ only
    through their separate input projections.
    """

    def __init__(self, eeg_dim, eye_dim, beh_dim, hidden_dim=64, nhead=4, num_layers=2):
        super().__init__()
        # One projection per modality. These attribute names are also read
        # directly by the attention visualiser and appear in saved state_dicts.
        self.eeg_fc = nn.Linear(eeg_dim, hidden_dim)
        self.eye_fc = nn.Linear(eye_dim, hidden_dim)
        self.beh_fc = nn.Linear(beh_dim, hidden_dim)

        layer_kwargs = dict(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=2 * hidden_dim,
            dropout=0.1,
            activation='relu',
            batch_first=True,  # tokens are laid out as (batch, token, feature)
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), num_layers=num_layers
        )

        head = [nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 2)]
        self.classifier = nn.Sequential(*head)

    def forward(self, eeg, eye, beh):
        """Return class logits for a batch of (eeg, eye, beh) feature vectors."""
        pairs = ((self.eeg_fc, eeg), (self.eye_fc, eye), (self.beh_fc, beh))
        # For 2-D (batch, features) inputs this yields (batch, 3, hidden_dim).
        tokens = torch.stack([F.relu(proj(feats)) for proj, feats in pairs], dim=1)
        encoded = self.transformer(tokens)
        # Average-pool over the three modality tokens before classifying.
        return self.classifier(encoded.mean(dim=1))

# GMUFusion for training
class GMUFusion(nn.Module):
    """Gated fusion of EEG, eye-tracking and behavioural features.

    The three modalities are projected linearly to ``hidden``. A sigmoid gate,
    computed from all three projections, interpolates element-wise between the
    EEG and eye projections. The behavioural projection is then added with a
    fixed weight of 0.5, and a two-layer MLP produces two logits.
    """

    def __init__(self, eeg_dim, eye_dim, beh_dim, hidden=64):
        super().__init__()
        self.eeg_fc = nn.Linear(eeg_dim, hidden)
        self.eye_fc = nn.Linear(eye_dim, hidden)
        self.beh_fc = nn.Linear(beh_dim, hidden)

        # Gate input is the concatenation of all three projections.
        self.gate = nn.Sequential(*[nn.Linear(3 * hidden, hidden), nn.Sigmoid()])
        self.classifier = nn.Sequential(*[nn.Linear(hidden, 64), nn.ReLU(), nn.Linear(64, 2)])

    def forward(self, eeg, eye, beh):
        """Return class logits for a batch of (eeg, eye, beh) feature vectors."""
        eeg_h, eye_h, beh_h = self.eeg_fc(eeg), self.eye_fc(eye), self.beh_fc(beh)
        z = self.gate(torch.cat((eeg_h, eye_h, beh_h), dim=1))
        # NOTE(review): only EEG vs. eye is gated; the behavioural branch
        # bypasses the gate with a constant 0.5 weight. Confirm this is the
        # intended fusion rule.
        fused = z * eeg_h + (1 - z) * eye_h + 0.5 * beh_h
        return self.classifier(fused)

if __name__ == "__main__":
    # Train and evaluate TransformerFusion, then GMUFusion, on the same splits.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(" Using device:", device)
    train_loader, val_loader, test_loader = create_loaders(balanced_df)

    # NOTE: both runs reload from the same results/best_model.pt (presumably
    # written by train_model). Each model must therefore be reloaded right
    # after its own training run; the Transformer checkpoint is overwritten by
    # the GMU run.
    # weights_only=True: the file is a plain state_dict of tensors. This avoids
    # unpickling arbitrary objects and silences torch.load's FutureWarning.
    print("\n Training TransformerFusion")
    model_t = TransformerFusion(len(eeg_features), len(eye_features), len(beh_features)).to(device)
    train_model(model_t, train_loader, val_loader, device)
    model_t.load_state_dict(
        torch.load("results/best_model.pt", map_location=device, weights_only=True)
    )
    evaluate_model(model_t, test_loader, device)

    print("\n Training GMUFusion")
    model_g = GMUFusion(len(eeg_features), len(eye_features), len(beh_features)).to(device)
    train_model(model_g, train_loader, val_loader, device)
    model_g.load_state_dict(
        torch.load("results/best_model.pt", map_location=device, weights_only=True)
    )
    evaluate_model(model_g, test_loader, device)


 Using device: cuda

 Training TransformerFusion
[Epoch 1] Loss: 0.7206, Train Acc: 0.4318, Val Acc: 0.5556
[Epoch 2] Loss: 0.6445, Train Acc: 0.6591, Val Acc: 0.5000
[Epoch 3] Loss: 0.6329, Train Acc: 0.6477, Val Acc: 0.5556
[Epoch 4] Loss: 0.5654, Train Acc: 0.7386, Val Acc: 0.6111
[Epoch 5] Loss: 0.5597, Train Acc: 0.7273, Val Acc: 0.6111
[Epoch 6] Loss: 0.4783, Train Acc: 0.7841, Val Acc: 0.5556
[Epoch 7] Loss: 0.4845, Train Acc: 0.8068, Val Acc: 0.5556
[Epoch 8] Loss: 0.4448, Train Acc: 0.7955, Val Acc: 0.6667
[Epoch 9] Loss: 0.4135, Train Acc: 0.8295, Val Acc: 0.6667
[Epoch 10] Loss: 0.3760, Train Acc: 0.8409, Val Acc: 0.7222
[Epoch 11] Loss: 0.3644, Train Acc: 0.8636, Val Acc: 0.7778
[Epoch 12] Loss: 0.3397, Train Acc: 0.8636, Val Acc: 0.8333
[Epoch 13] Loss: 0.3149, Train Acc: 0.8864, Val Acc: 0.7778
[Epoch 14] Loss: 0.3013, Train Acc: 0.8864, Val Acc: 0.7778
[Epoch 15] Loss: 0.2642, Train Acc: 0.9091, Val Acc: 0.7778
[Epoch 16] Loss: 0.2563, Train Acc: 0.8977, Val Acc: 0.7778


  model_t.load_state_dict(torch.load("results/best_model.pt", map_location=device))


[Epoch 1] Loss: 0.7377, Train Acc: 0.5114, Val Acc: 0.6111
[Epoch 2] Loss: 0.6575, Train Acc: 0.6136, Val Acc: 0.5000
[Epoch 3] Loss: 0.6054, Train Acc: 0.6364, Val Acc: 0.5000
[Epoch 4] Loss: 0.5779, Train Acc: 0.6477, Val Acc: 0.5000
[Epoch 5] Loss: 0.5611, Train Acc: 0.6477, Val Acc: 0.5000
[Epoch 6] Loss: 0.5451, Train Acc: 0.6591, Val Acc: 0.5000
[Epoch 7] Loss: 0.5323, Train Acc: 0.6591, Val Acc: 0.5000
Early stopping.
Best Val Accuracy: 0.6111111111111112
              precision    recall  f1-score   support

           0       1.00      0.50      0.67        10
           1       0.67      1.00      0.80        10

    accuracy                           0.75        20
   macro avg       0.83      0.75      0.73        20
weighted avg       0.83      0.75      0.73        20



  model_g.load_state_dict(torch.load("results/best_model.pt", map_location=device))


In [None]:
# Further Analysis for TransformerFusion and GATFusion
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, Circle
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve

# Self-attention connectogram visualizer (Transformer, first encoder layer, one plot per head)
@torch.no_grad()
def visualize_attention_weights(model, eeg, eye, beh):
    """Save one connectogram per attention head of the first encoder layer.

    Rebuilds the three modality tokens (EEG, Eye, Beh) for a single sample the
    same way ``TransformerFusion.forward`` does. It then queries the first
    encoder layer's self-attention module for its per-head weights and writes
    ``results/transformer_connectogram_<head>.png`` for each head.

    Args:
        model: TransformerFusion-style model exposing ``eeg_fc``, ``eye_fc``,
            ``beh_fc`` and ``transformer`` (an ``nn.TransformerEncoder``).
        eeg, eye, beh: feature tensors for one sample without a batch
            dimension; a batch dimension of 1 is added here.
    """
    model.eval()
    device = next(model.parameters()).device
    eeg = eeg.to(device).unsqueeze(0)
    eye = eye.to(device).unsqueeze(0)
    beh = beh.to(device).unsqueeze(0)

    eeg_h = F.relu(model.eeg_fc(eeg))
    eye_h = F.relu(model.eye_fc(eye))
    beh_h = F.relu(model.beh_fc(beh))
    x = torch.stack([eeg_h, eye_h, beh_h], dim=1)  # (1, 3, hidden)

    first_layer = model.transformer.layers[0]
    # A pre-norm layer (norm_first=True) feeds norm1(x) into self-attention;
    # the default post-norm layer feeds x unchanged.
    attn_in = first_layer.norm1(x) if getattr(first_layer, "norm_first", False) else x

    # average_attn_weights=False keeps the head axis: (1, nhead, 3, 3).
    # With the default (True) the heads are averaged into one (1, 3, 3) map.
    # The loop below would then walk matrix *rows* and hand 1-D vectors to the
    # plotter, which rejects them ("Invalid shape for matrix: (1, 3)").
    _, attn_weights = first_layer.self_attn(
        attn_in, attn_in, attn_in, need_weights=True, average_attn_weights=False
    )
    weights = attn_weights.squeeze(0).cpu().numpy()  # (nhead, 3, 3)

    for head_idx, head_weights in enumerate(weights, start=1):
        visualize_connectogram(
            head_weights,
            title=f"Transformer Head {head_idx} Connectogram",
            save_path=f"results/transformer_connectogram_{head_idx}.png",
        )

# GAT attention visualizer
@torch.no_grad()
def visualize_gat_attention(model, eeg, eye, beh):
    """Recompute a GAT layer's attention over the three modality nodes and plot it.

    Builds one node per modality (EEG, Eye, Beh) for a single sample, the same
    way TransformerFusion builds its tokens. It then re-derives attention
    coefficients from ``model.gat``'s ``lin``, ``att_src``, ``att_dst`` and
    ``leaky_relu`` attributes and saves a connectogram to
    ``results/gat_connectogram.png``.

    NOTE(review): the GAT model class is not visible here. This assumes
    ``model`` exposes ``eeg_fc`` / ``eye_fc`` / ``beh_fc`` and a ``gat``
    submodule with those four attributes; confirm against the GAT fusion class.

    Args:
        model: GAT-based fusion model (see note above).
        eeg, eye, beh: feature tensors for one sample without a batch
            dimension; a batch dimension of 1 is added here.
    """
    model.eval()
    device = next(model.parameters()).device
    eeg = eeg.to(device).unsqueeze(0)
    eye = eye.to(device).unsqueeze(0)
    beh = beh.to(device).unsqueeze(0)

    # One node embedding per modality, stacked along dim 1 -> (1, 3, hidden).
    eeg_h = F.relu(model.eeg_fc(eeg))
    eye_h = F.relu(model.eye_fc(eye))
    beh_h = F.relu(model.beh_fc(beh))
    x = torch.stack([eeg_h, eye_h, beh_h], dim=1)

    # Node projection; presumably (1, 3, heads * out_channels), so TODO confirm.
    h = model.gat.lin(x)
    # Attention vectors flattened to one long vector.
    # NOTE(review): with more than one head this sums scores across all heads
    # instead of keeping them separate.
    att_src = model.gat.att_src.view(1, 1, -1)
    att_dst = model.gat.att_dst.view(1, 1, -1)

    # Per-node scalar scores, shape (1, 3).
    src = (h * att_src).sum(dim=-1)
    dst = (h * att_dst).sum(dim=-1)
    # e[i, j] = leaky_relu(src[i] + dst[j]) for every node pair, self-pairs
    # included (i.e. a fully connected 3-node graph).
    # NOTE(review): in the standard GAT formulation the score for target i
    # attending to source j is a_dst.h_i + a_src.h_j. Here a_src is applied to
    # the row node, so check the orientation matches the layer's own forward.
    e = model.gat.leaky_relu(src.unsqueeze(2) + dst.unsqueeze(1))
    # Softmax over the last axis, so each row sums to 1; result is a (3, 3) array.
    alpha = F.softmax(e, dim=-1).squeeze(0).cpu().numpy()

    visualize_connectogram(alpha, title="GAT Connectogram", save_path="results/gat_connectogram.png")

# Connectogram-Style Visualization
def visualize_connectogram(matrix, title="Connectogram", save_path="connectogram.png"):
    """Draw a 3-node directed graph of pairwise weights between EEG, Eye and Beh.

    Args:
        matrix: 3x3 array-like. Entry ``[i][j]`` sets both the line width
            (``5 * weight``) and the viridis colour of the arrow from node i to
            node j; the diagonal is not drawn. The raw weight is passed to the
            colormap, so values are expected in [0, 1] (as softmax attention
            weights are).
        title: figure title.
        save_path: output image path; its parent directory is created if it
            does not exist.

    Any input that is not 3x3 (1-D inputs are first treated as a single row)
    is reported on stdout and nothing is plotted.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim == 1:
        matrix = np.expand_dims(matrix, axis=0)
    if matrix.shape != (3, 3):
        print(f"Invalid shape for matrix: {matrix.shape}, skipping plot.")
        return

    labels = ["EEG", "Eye", "Beh"]
    pos = {
        "EEG": (0, 1),
        "Eye": (-1, -1),
        "Beh": (1, -1)
    }
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_aspect('equal')
    ax.axis('off')

    for label in labels:
        x, y = pos[label]
        # facecolor/edgecolor rather than color=...: a `color` argument
        # overrides `ec`, so the black outline would be dropped (with a warning).
        ax.add_patch(Circle((x, y), 0.15, facecolor='skyblue', edgecolor='black', lw=1.5))
        ax.text(x, y, label, fontsize=12, ha='center', va='center', weight='bold')

    for i, src in enumerate(labels):
        for j, tgt in enumerate(labels):
            if i == j:
                continue  # self-connections are not drawn
            weight = matrix[i][j]
            arrow = FancyArrowPatch(pos[src], pos[tgt], connectionstyle="arc3,rad=0.2",
                                    arrowstyle='-|>', mutation_scale=15,
                                    lw=weight * 5, color=plt.cm.viridis(weight))
            ax.add_patch(arrow)

    ax.set_title(title)
    fig.tight_layout()
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.close(fig)

# Per-Modality Ablation Evaluation
@torch.no_grad()
def ablation_test(model_class, base_model, device, loader=None):
    """Measure a trained fusion model's performance with modalities removed.

    A modality is "removed" by replacing its input tensor with zeros before the
    forward pass. Seven settings are evaluated, in this order: "eeg", "eye",
    "beh", "eeg+eye", "eeg+beh", "eye+beh" and "all". Each name lists the
    modalities that are kept.

    Args:
        model_class: unused; kept so existing calls keep working. Earlier
            versions re-created ``model_class`` with default hyperparameters
            and copied ``base_model``'s weights into it. That fails for models
            built with non-default sizes and built seven throwaway models.
            ``base_model`` is now evaluated directly.
        base_model: trained model called as ``base_model(eeg, eye, beh)`` that
            returns per-class logits. Its train/eval mode is restored on exit.
        device: device the input batches are moved to.
        loader: iterable of ``(eeg, eye, beh, labels)`` batches. Defaults to
            the notebook-level ``test_loader``.

    Returns:
        Dict mapping each setting name to ``{"acc": ..., "f1": ...}``, where
        ``f1`` is the macro-averaged F1 score.
    """
    if loader is None:
        loader = test_loader  # notebook global, as in the original version

    modalities = ("eeg", "eye", "beh")  # positional order expected by the model
    settings = ("eeg", "eye", "beh", "eeg+eye", "eeg+beh", "eye+beh", "all")

    was_training = base_model.training
    base_model.eval()
    test_results = {}
    try:
        for combo in settings:
            kept = set(modalities) if combo == "all" else set(combo.split("+"))
            all_preds, all_labels = [], []
            for eeg, eye, beh, labels in loader:
                batch = [t.to(device) for t in (eeg, eye, beh)]
                # Zero out every modality not kept in this setting.
                batch = [t if name in kept else torch.zeros_like(t)
                         for name, t in zip(modalities, batch)]
                preds = base_model(*batch).argmax(dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

            acc = np.mean(np.array(all_preds) == np.array(all_labels))
            report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
            test_results[combo] = {"acc": acc, "f1": report["macro avg"]["f1-score"]}
    finally:
        base_model.train(was_training)

    return test_results

# Additional visualizations: ROC and PR Curve
@torch.no_grad()
def plot_roc_pr_curves(model, loader, device, title_prefix="model", output_dir="results"):
    """Save ROC and precision-recall curves for a binary classifier.

    Runs ``model(eeg, eye, beh)`` over every batch in ``loader`` and uses the
    softmax probability of class 1 as the score. It writes
    ``<output_dir>/<title_prefix lower-cased>_roc.png`` and ``..._pr.png``.

    Args:
        model: model returning logits of shape (batch, num_classes); class 1
            is treated as the positive class.
        loader: iterable of ``(eeg, eye, beh, labels)`` batches.
        device: device the inputs are moved to.
        title_prefix: used in plot titles and, lower-cased, in file names.
        output_dir: directory for the images; created if missing.
    """
    model.eval()
    all_labels, all_probs = [], []

    for eeg, eye, beh, labels in loader:
        logits = model(eeg.to(device), eye.to(device), beh.to(device))
        all_probs.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
        # .cpu() so this also works if the loader yields device tensors.
        all_labels.extend(labels.cpu().numpy())

    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    prec, rec, _ = precision_recall_curve(all_labels, all_probs)

    roc_auc = auc(fpr, tpr)
    # Trapezoidal area under the PR curve. NOTE: sklearn's
    # average_precision_score gives the step-wise (less optimistic) summary if
    # a standard AP figure is needed.
    pr_auc = auc(rec, prec)

    os.makedirs(output_dir, exist_ok=True)
    stem = os.path.join(output_dir, title_prefix.lower())

    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.2f})')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'{title_prefix} ROC Curve')
    ax.legend(loc='lower right')
    fig.savefig(f"{stem}_roc.png")
    plt.close(fig)

    fig, ax = plt.subplots()
    ax.plot(rec, prec, label=f'PR Curve (AUC = {pr_auc:.2f})')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title(f'{title_prefix} Precision-Recall Curve')
    ax.legend(loc='lower left')
    fig.savefig(f"{stem}_pr.png")
    plt.close(fig)

if __name__ == "__main__":
    # Post-hoc analysis of the trained TransformerFusion (model_t). It relies on
    # test_loader, device and model_t bound by earlier cells.
    # Attention is visualised for the first sample of the first test batch;
    # `label` is unpacked but not used.
    eeg, eye, beh, label = next(iter(test_loader))
    visualize_attention_weights(model_t, eeg[0], eye[0], beh[0])
    # Zero-out ablation over modality subsets on the test set.
    ablation_scores = ablation_test(TransformerFusion, model_t, device)
    print("\n▶ Transformer Ablation Study Results:")
    for k, v in ablation_scores.items():
        print(f"{k.upper():9} — Accuracy: {v['acc']:.2f}, F1: {v['f1']:.2f}")

    # GAT analysis only runs if a `model_gat` exists in the notebook namespace.
    # NOTE(review): no cell in this section defines model_gat.
    if 'model_gat' in globals():
        visualize_gat_attention(model_gat, eeg[0], eye[0], beh[0])
        plot_roc_pr_curves(model_gat, test_loader, device, title_prefix="GAT")

    plot_roc_pr_curves(model_t, test_loader, device, title_prefix="Transformer")


Invalid shape for matrix: (1, 3), skipping plot.
Invalid shape for matrix: (1, 3), skipping plot.
Invalid shape for matrix: (1, 3), skipping plot.

▶ Transformer Ablation Study Results:
EEG       — Accuracy: 0.60, F1: 0.52
EYE       — Accuracy: 0.55, F1: 0.54
BEH       — Accuracy: 0.70, F1: 0.69
EEG+EYE   — Accuracy: 0.60, F1: 0.60
EEG+BEH   — Accuracy: 0.65, F1: 0.63
EYE+BEH   — Accuracy: 0.80, F1: 0.80
ALL       — Accuracy: 0.80, F1: 0.80


In [None]:
import os

import pandas as pd

# Test-set accuracy and macro-F1 per fusion model, exported to CSV.
# These values are hard-coded from earlier runs, not computed in this cell.
# NOTE(review): the classification report printed after GMU training in this
# notebook's output shows accuracy 0.75 / macro-F1 0.73, not 0.70 / 0.67. The
# Transformer row matches the "ALL" row of the ablation study. Verify which
# run each row comes from before reporting.
model_results = {
    "MLP": {"acc": 0.75, "f1": 0.73},
    "GMU": {"acc": 0.70, "f1": 0.67},
    "Transformer": {"acc": 0.80, "f1": 0.80},
    "GAT": {"acc": 0.75, "f1": 0.73},
    "CrossModal": {"acc": 0.75, "f1": 0.73}
}

# One row per model, columns "acc" and "f1".
df_results = pd.DataFrame(model_results).T
# Earlier cells usually create results/, but don't depend on them having run.
os.makedirs("results", exist_ok=True)
df_results.to_csv("results/model_comparison.csv", float_format="%.2f")
