<a href="https://colab.research.google.com/github/AyeshaAnzerBCIT/Multisource/blob/main/GATmodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import re
import gcsfs
import pandas as pd
import numpy as np

# GCS authentication
KEY_PATH = "Key.json"
fs = gcsfs.GCSFileSystem(token=KEY_PATH)

# Feature columns
eeg_features = ["Mobility", "Complexity", "Spectral_Entropy"]
eye_features = [
    "mean_pupil_size", "std_pupil_size",
    "mean_latency", "std_latency",
    "mean_gaze_vector", "std_gaze_vector"
]
beh_features = ["mean_rt", "accuracy"]

def extract_patient_id(name):
    match = re.findall(r'(A\d{5,})', str(name))
    return match[0] if match else None

def load_csv(fs, path, id_col):
    with fs.open(path, 'r') as f:
        df = pd.read_csv(f)
    df["patient_id"] = df[id_col].apply(extract_patient_id)
    return df

def group_features(df, features):
    return df.groupby("patient_id")[features].mean().reset_index()

def impute_missing_rows(modality_df, features, missing_ids):
    avg_values = modality_df[features].mean()
    imputed = pd.DataFrame([{**{"patient_id": pid}, **avg_values.to_dict()} for pid in missing_ids])
    return pd.concat([modality_df, imputed], ignore_index=True)

# -----------------------------
# 1) Load Raw Modality Data from GCS
# -----------------------------
eeg_df = load_csv(fs, "gs://eegchild/processed_features/merged_features.csv", "file_name")
eye_df = load_csv(fs, "gs://eegchild/processed_asd_features.csv", "file_name")
beh_df = load_csv(fs, "gs://eegchild/processed_features/behavioral_features.csv", "file")
label_df = load_csv(fs, "gs://eegchild/MIPDB_PublicFile.csv", "ID")

# Binarize labels: 0 => no diagnosis, 1 => diagnosis (2 => 1)
label_df = label_df.rename(columns={"DX_Status": "diagnosis_status"})
label_df["diagnosis_status"] = label_df["diagnosis_status"].replace({2: 1})

# Convert patient IDs to string
for df in [eeg_df, eye_df, beh_df, label_df]:
    df["patient_id"] = df["patient_id"].astype(str)

# -----------------------------
# 2) Group & Impute Missing
# -----------------------------
grouped_eeg = group_features(eeg_df, eeg_features)
grouped_eye = group_features(eye_df, eye_features)
grouped_beh = group_features(beh_df, beh_features)

expected_ids = set(label_df["patient_id"])
grouped_eeg = impute_missing_rows(grouped_eeg, eeg_features, expected_ids - set(grouped_eeg["patient_id"]))
grouped_eye = impute_missing_rows(grouped_eye, eye_features, expected_ids - set(grouped_eye["patient_id"]))
grouped_beh = impute_missing_rows(grouped_beh, beh_features, expected_ids - set(grouped_beh["patient_id"]))

# Merge each with label info
eeg_merged = grouped_eeg.merge(label_df, on="patient_id")
eye_merged = grouped_eye.merge(label_df, on="patient_id")
beh_merged = grouped_beh.merge(label_df, on="patient_id")

# Find common patient IDs
common_ids = set(eeg_merged["patient_id"]) & set(eye_merged["patient_id"]) & set(beh_merged["patient_id"])
eeg_final = eeg_merged[eeg_merged["patient_id"].isin(common_ids)].reset_index(drop=True)
eye_final = eye_merged[eye_merged["patient_id"].isin(common_ids)].reset_index(drop=True)
beh_final = beh_merged[beh_merged["patient_id"].isin(common_ids)].reset_index(drop=True)

# -----------------------------
# 3) Merge into One DataFrame
# -----------------------------
merged_df = pd.DataFrame({
    "patient_id": eeg_final["patient_id"],
    "label": eeg_final["diagnosis_status"]
})

for feat in eeg_features:
    merged_df[f"eeg_{feat}"] = eeg_final[feat]
for feat in eye_features:
    merged_df[f"eye_{feat}"] = eye_final[feat]
for feat in beh_features:
    merged_df[f"beh_{feat}"] = beh_final[feat]

# -----------------------------
# 4) Balance Dataset (63 each)
# -----------------------------
class_0 = merged_df[merged_df["label"] == 0]
class_1 = merged_df[merged_df["label"] == 1]
balanced_df = pd.concat([
    class_0.sample(n=63, replace=True, random_state=42),
    class_1.sample(n=63, replace=True, random_state=42)
]).sample(frac=1, random_state=42).reset_index(drop=True)

print("Final balanced dataset shape:", balanced_df.shape)
print("Class distribution:", balanced_df['label'].value_counts())

# -----------------------------
# 5) Ensure All Feature Columns Are float32
# -----------------------------
all_feature_cols = [
    col for col in balanced_df.columns
    if col.startswith("eeg_") or col.startswith("eye_") or col.startswith("beh_")
]

for c in all_feature_cols:
    # Convert forcibly to float32, dropping/coercing non-numerics
    balanced_df[c] = pd.to_numeric(balanced_df[c], errors='coerce').fillna(0).astype(np.float32)

# (Optional) Drop any leftover non-feature columns if needed
extra_cols = [c for c in balanced_df.columns if c not in all_feature_cols + ["patient_id", "label"]]
if extra_cols:
    print("Dropping leftover columns:", extra_cols)
    balanced_df.drop(columns=extra_cols, inplace=True)

# Double-check
print("\nColumn dtypes after numeric conversion:")
print(balanced_df.dtypes)

print("\nSample rows:\n", balanced_df.head(3))


  df = pd.read_csv(f)


Final balanced dataset shape: (126, 13)
Class distribution: label
1    63
0    63
Name: count, dtype: int64

Column dtypes after numeric conversion:
patient_id               object
label                     int64
eeg_Mobility            float32
eeg_Complexity          float32
eeg_Spectral_Entropy    float32
eye_mean_pupil_size     float32
eye_std_pupil_size      float32
eye_mean_latency        float32
eye_std_latency         float32
eye_mean_gaze_vector    float32
eye_std_gaze_vector     float32
beh_mean_rt             float32
beh_accuracy            float32
dtype: object

Sample rows:
   patient_id  label  eeg_Mobility  eeg_Complexity  eeg_Spectral_Entropy  \
0  A00063558      1      0.192332        2.461543              2.219917   
1  A00055956      0      0.192332        2.461543              2.219917   
2  A00057599      1      0.192332        2.461543              2.219917   

   eye_mean_pupil_size  eye_std_pupil_size  eye_mean_latency  eye_std_latency  \
0            18.102869  

In [None]:
# Identify all feature columns from the balanced DataFrame
all_feature_cols = []
for col in balanced_df.columns:
    if col.startswith("eeg_") or col.startswith("eye_") or col.startswith("beh_"):
        all_feature_cols.append(col)

print("All feature columns:")
for c in all_feature_cols:
    print("-", c)

# Example numeric conversion step (optional) to guarantee float32
for c in all_feature_cols:
    balanced_df[c] = pd.to_numeric(balanced_df[c], errors='coerce').fillna(0).astype('float32')

# Now you have a list of your final numeric features
print("\nFinal check (dtypes):")
print(balanced_df[all_feature_cols].dtypes)


All feature columns:
- eeg_Mobility
- eeg_Complexity
- eeg_Spectral_Entropy
- eye_mean_pupil_size
- eye_std_pupil_size
- eye_mean_latency
- eye_std_latency
- eye_mean_gaze_vector
- eye_std_gaze_vector
- beh_mean_rt
- beh_accuracy

Final check (dtypes):
eeg_Mobility            float32
eeg_Complexity          float32
eeg_Spectral_Entropy    float32
eye_mean_pupil_size     float32
eye_std_pupil_size      float32
eye_mean_latency        float32
eye_std_latency         float32
eye_mean_gaze_vector    float32
eye_std_gaze_vector     float32
beh_mean_rt             float32
beh_accuracy            float32
dtype: object


In [None]:
# part2_gat_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    """
    Minimal GAT-like layer for a 3-node graph (EEG, Eye, Beh).
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=False)
        self.att_src = nn.Parameter(torch.zeros((1, 1, out_dim)))
        self.att_dst = nn.Parameter(torch.zeros((1, 1, out_dim)))
        nn.init.xavier_uniform_(self.lin.weight, gain=1.414)
        nn.init.xavier_uniform_(self.att_src, gain=1.414)
        nn.init.xavier_uniform_(self.att_dst, gain=1.414)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        """
        Args:
          x: shape (B, N=3, in_dim)
        Returns:
          shape (B, N=3, out_dim)
        """
        B, N, _ = x.shape
        h = self.lin(x)  # (B, N, out_dim)

        src = (h * self.att_src).sum(dim=-1).unsqueeze(2)  # (B, N, 1)
        dst = (h * self.att_dst).sum(dim=-1).unsqueeze(1)  # (B, 1, N)
        e = self.leaky_relu(src + dst)  # (B, N, N)

        alpha = F.softmax(e, dim=-1)     # attention across the N nodes
        out = torch.bmm(alpha, h)        # (B, N, out_dim)
        return F.elu(out)                # final activation

class GATFusion(nn.Module):
    """
    GAT-based fusion of EEG, Eye, and Behavioral modalities.
    Each modality is encoded to hidden_dim, then the GAT layer merges them.
    """
    def __init__(self, eeg_dim, eye_dim, beh_dim, hidden_dim=64, out_dim=64):
        super().__init__()
        # Per-modality FC encoders
        self.eeg_fc = nn.Linear(eeg_dim, hidden_dim)
        self.eye_fc = nn.Linear(eye_dim, hidden_dim)
        self.beh_fc = nn.Linear(beh_dim, hidden_dim)

        # Graph Attention
        self.gat = GATLayer(in_dim=hidden_dim, out_dim=out_dim)

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(out_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2)  # binary classification
        )

    def forward(self, eeg, eye, beh):
        # (B, eeg_dim) => (B, hidden_dim)
        eeg_embed = F.relu(self.eeg_fc(eeg))
        eye_embed = F.relu(self.eye_fc(eye))
        beh_embed = F.relu(self.beh_fc(beh))

        # Stack: (B, 3, hidden_dim)
        stack = torch.stack([eeg_embed, eye_embed, beh_embed], dim=1)

        # GAT => (B, 3, out_dim)
        gat_out = self.gat(stack)
        fused = gat_out.mean(dim=1)  # (B, out_dim)

        return self.classifier(fused)


In [None]:
###############################################
# Part 1: Load & Balance Data from GCS
###############################################

import os
import re
import gcsfs
import pandas as pd
import numpy as np

KEY_PATH = "Key.json"
fs = gcsfs.GCSFileSystem(token=KEY_PATH)

eeg_features = ["Mobility", "Complexity", "Spectral_Entropy"]
eye_features = [
    "mean_pupil_size", "std_pupil_size",
    "mean_latency", "std_latency",
    "mean_gaze_vector", "std_gaze_vector"
]
beh_features = ["mean_rt", "accuracy"]

def extract_patient_id(name):
    match = re.findall(r'(A\d{5,})', str(name))
    return match[0] if match else None

def load_csv(fs, path, id_col):
    with fs.open(path, 'r') as f:
        df = pd.read_csv(f)
    df["patient_id"] = df[id_col].apply(extract_patient_id)
    return df

def group_features(df, features):
    return df.groupby("patient_id")[features].mean().reset_index()

def impute_missing_rows(modality_df, features, missing_ids):
    avg_values = modality_df[features].mean()
    imputed = pd.DataFrame([{**{"patient_id": pid}, **avg_values.to_dict()} for pid in missing_ids])
    return pd.concat([modality_df, imputed], ignore_index=True)

# 1) Load
eeg_df = load_csv(fs, "gs://eegchild/processed_features/merged_features.csv", "file_name")
eye_df = load_csv(fs, "gs://eegchild/processed_asd_features.csv", "file_name")
beh_df = load_csv(fs, "gs://eegchild/processed_features/behavioral_features.csv", "file")
label_df = load_csv(fs, "gs://eegchild/MIPDB_PublicFile.csv", "ID")

label_df = label_df.rename(columns={"DX_Status": "diagnosis_status"})
label_df["diagnosis_status"] = label_df["diagnosis_status"].replace({2: 1})

for df in [eeg_df, eye_df, beh_df, label_df]:
    df["patient_id"] = df["patient_id"].astype(str)

# 2) Group & Impute
grouped_eeg = group_features(eeg_df, eeg_features)
grouped_eye = group_features(eye_df, eye_features)
grouped_beh = group_features(beh_df, beh_features)

expected_ids = set(label_df["patient_id"])
grouped_eeg = impute_missing_rows(grouped_eeg, eeg_features, expected_ids - set(grouped_eeg["patient_id"]))
grouped_eye = impute_missing_rows(grouped_eye, eye_features, expected_ids - set(grouped_eye["patient_id"]))
grouped_beh = impute_missing_rows(grouped_beh, beh_features, expected_ids - set(grouped_beh["patient_id"]))

# Merge label
eeg_merged = grouped_eeg.merge(label_df, on="patient_id")
eye_merged = grouped_eye.merge(label_df, on="patient_id")
beh_merged = grouped_beh.merge(label_df, on="patient_id")

# 3) Filter to common patients
common_ids = set(eeg_merged["patient_id"]) & set(eye_merged["patient_id"]) & set(beh_merged["patient_id"])
eeg_final = eeg_merged[eeg_merged["patient_id"].isin(common_ids)].reset_index(drop=True)
eye_final = eye_merged[eye_merged["patient_id"].isin(common_ids)].reset_index(drop=True)
beh_final = beh_merged[beh_merged["patient_id"].isin(common_ids)].reset_index(drop=True)

# 4) Merge into one DataFrame
merged_df = pd.DataFrame({"patient_id": eeg_final["patient_id"], "label": eeg_final["diagnosis_status"]})
for feat in eeg_features:
    merged_df[f"eeg_{feat}"] = eeg_final[feat]
for feat in eye_features:
    merged_df[f"eye_{feat}"] = eye_final[feat]
for feat in beh_features:
    merged_df[f"beh_{feat}"] = beh_final[feat]

# 5) Balance classes to 63 each
class_0 = merged_df[merged_df["label"] == 0]
class_1 = merged_df[merged_df["label"] == 1]
balanced_df = pd.concat([
    class_0.sample(n=63, replace=True, random_state=42),
    class_1.sample(n=63, replace=True, random_state=42)
]).sample(frac=1, random_state=42).reset_index(drop=True)

print("Balanced shape:", balanced_df.shape)
print("Class distribution:", balanced_df["label"].value_counts())

# Force numeric float32 for all feature columns
all_feature_cols = []
for c in balanced_df.columns:
    if c.startswith("eeg_") or c.startswith("eye_") or c.startswith("beh_"):
        all_feature_cols.append(c)

for col in all_feature_cols:
    balanced_df[col] = pd.to_numeric(balanced_df[col], errors='coerce').fillna(0).astype(np.float32)

print("Check dtypes:\n", balanced_df[all_feature_cols].dtypes)


###############################################
# Part 2: GATFusion Model
###############################################
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=False)
        self.att_src = nn.Parameter(torch.zeros((1, 1, out_dim)))
        self.att_dst = nn.Parameter(torch.zeros((1, 1, out_dim)))
        nn.init.xavier_uniform_(self.lin.weight, gain=1.414)
        nn.init.xavier_uniform_(self.att_src, gain=1.414)
        nn.init.xavier_uniform_(self.att_dst, gain=1.414)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        B, N, _ = x.shape
        h = self.lin(x)
        src = (h * self.att_src).sum(dim=-1).unsqueeze(2)
        dst = (h * self.att_dst).sum(dim=-1).unsqueeze(1)
        e = self.leaky_relu(src + dst)  # (B, N, N)

        alpha = F.softmax(e, dim=-1)
        out = torch.bmm(alpha, h)  # (B, N, out_dim)
        return F.elu(out)

class GATFusion(nn.Module):
    def __init__(self, eeg_dim, eye_dim, beh_dim, hidden_dim=64, out_dim=64):
        super().__init__()
        self.eeg_fc = nn.Linear(eeg_dim, hidden_dim)
        self.eye_fc = nn.Linear(eye_dim, hidden_dim)
        self.beh_fc = nn.Linear(beh_dim, hidden_dim)
        self.gat = GATLayer(hidden_dim, out_dim)
        self.classifier = nn.Sequential(
            nn.Linear(out_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2)
        )

    def forward(self, eeg, eye, beh):
        eeg_embed = F.relu(self.eeg_fc(eeg))
        eye_embed = F.relu(self.eye_fc(eye))
        beh_embed = F.relu(self.beh_fc(beh))
        stack = torch.stack([eeg_embed, eye_embed, beh_embed], dim=1)
        gat_out = self.gat(stack)
        fused = gat_out.mean(dim=1)
        return self.classifier(fused)











  df = pd.read_csv(f)


Balanced shape: (126, 13)
Class distribution: label
1    63
0    63
Name: count, dtype: int64
Check dtypes:
 eeg_Mobility            float32
eeg_Complexity          float32
eeg_Spectral_Entropy    float32
eye_mean_pupil_size     float32
eye_std_pupil_size      float32
eye_mean_latency        float32
eye_std_latency         float32
eye_mean_gaze_vector    float32
eye_std_gaze_vector     float32
beh_mean_rt             float32
beh_accuracy            float32
dtype: object
Train/Val/Test sizes: 88, 18, 20


KeyError: 84

In [None]:
###############################################
# Part 3: Dataset + DataLoader  (final patch)
###############################################
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import defaultdict


class MultimodalGATDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        # Keep original index; we'll access rows by position (iloc)
        self.df = df.copy()

        # Column groups for each modality
        self.eeg_cols = [c for c in self.df.columns if c.startswith("eeg_")]
        self.eye_cols = [c for c in self.df.columns if c.startswith("eye_")]
        self.beh_cols = [c for c in self.df.columns if c.startswith("beh_")]

        # Ensure every feature is numeric float32
        for col in self.eeg_cols + self.eye_cols + self.beh_cols:
            self.df[col] = (
                pd.to_numeric(self.df[col], errors="coerce")
                .fillna(0)
                .astype(np.float32)
            )

        self.df["label"] = self.df["label"].astype(int)
        if "patient_id" not in self.df.columns:
            self.df["patient_id"] = ["unknown"] * len(self.df)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]

        eeg   = torch.from_numpy(row[self.eeg_cols].values.astype(np.float32))
        eye   = torch.from_numpy(row[self.eye_cols].values.astype(np.float32))
        beh   = torch.from_numpy(row[self.beh_cols].values.astype(np.float32))
        label = torch.tensor(row["label"], dtype=torch.long)
        pid   = row["patient_id"]

        return eeg, eye, beh, label, pid


def standardise_split(
    train_df: pd.DataFrame,
    *other_dfs: pd.DataFrame,
    feature_cols: list[str],
):
    """
    Z‑score every feature using statistics computed **only** on train_df.
    Returns the transformed DataFrames and a dict of means/stds.
    """
    stats = defaultdict(dict)
    for col in feature_cols:
        mu  = train_df[col].mean()
        sig = train_df[col].std() if train_df[col].std() > 0 else 1.0  # avoid /0
        stats[col]["mean"] = mu
        stats[col]["std"]  = sig

        train_df[col] = (train_df[col] - mu) / sig
        for df in other_dfs:
            df[col] = (df[col] - mu) / sig

    return (train_df, *other_dfs), stats


def create_gat_loaders(
    df: pd.DataFrame,
    train_ratio: float = 0.7,
    val_ratio: float   = 0.15,
    batch_size: int    = 16,
):
    # Stratified train/val/test split
    train_df, temp_df = train_test_split(
        df,
        test_size=1 - train_ratio,
        stratify=df["label"],
        random_state=42,
    )

    val_size = val_ratio / (1 - train_ratio)
    val_df, test_df = train_test_split(
        temp_df,
        test_size=1 - val_size,
        stratify=temp_df["label"],
        random_state=42,
    )

    # ── Feature standardisation (train stats only) ──
    feature_cols = [c for c in df.columns if c.startswith(("eeg_", "eye_", "beh_"))]
    (train_df, val_df, test_df), _ = standardise_split(
        train_df, val_df, test_df, feature_cols=feature_cols
    )

    # Build datasets & loaders
    train_ds = MultimodalGATDataset(train_df)
    val_ds   = MultimodalGATDataset(val_df)
    test_ds  = MultimodalGATDataset(test_df)

    print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds,   batch_size=batch_size),
        DataLoader(test_ds,  batch_size=batch_size),
    )


In [None]:
###############################################
# Part 4: Training + Evaluation  (quick‑fix v2)
###############################################
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt


class FocalLoss(nn.Module):
    """
    Binary / multi‑class focal loss.
    γ: focusing parameter (γ = 2 is common)
    α: weighting tensor for classes, shape = [C]
    """
    def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # torch tensor or None
        self.reduction = reduction
        self.ce = nn.CrossEntropyLoss(weight=alpha, reduction="none")

    def forward(self, logits, targets):
        ce_loss = self.ce(logits, targets)
        pt = torch.exp(-ce_loss)            # prob of the true class
        focal = (1 - pt) ** self.gamma * ce_loss
        if self.reduction == "mean":
            return focal.mean()
        elif self.reduction == "sum":
            return focal.sum()
        return focal


def train_gat_model(
        model,
        train_loader,
        val_loader,
        device,
        *,
        epochs: int = 50,
        lr: float = 5e-4,
        patience: int = 5,
        use_focal: bool = False,        # ← toggle focal loss
        class0_weight: float = 1.1,     # ← gentler bias toward class 0
):
    # ── 1. Build loss function ───────────────────────────────────────────────
    alpha = torch.tensor([class0_weight, 1.5], dtype=torch.float32).to(device)
    criterion = (
        FocalLoss(alpha=alpha, gamma=1.0)
        if use_focal
        else nn.CrossEntropyLoss(weight=alpha)
    )

    # ── 2. Optimiser ─────────────────────────────────────────────────────────
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val = 0.0
    patience_cnt = 0

    for epoch in range(1, epochs + 1):
        # ── train one epoch ──────────────────────────────────────────────────
        model.train()
        run_loss = 0.0
        correct = 0
        total = 0

        for eeg, eye, beh, labels, _ in train_loader:
            eeg, eye, beh, labels = (
                eeg.to(device), eye.to(device), beh.to(device), labels.to(device)
            )

            optimizer.zero_grad()
            out = model(eeg, eye, beh)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()

            run_loss += loss.item() * labels.size(0)
            preds = out.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_loss = run_loss / total
        train_acc = correct / total
        val_acc = evaluate_gat_model(model, val_loader, device, return_acc=True)

        print(f"[Ep {epoch:2d}/{epochs}] "
              f"Loss {train_loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f}")

        # ── early stopping ───────────────────────────────────────────────────
        if val_acc > best_val:
            best_val = val_acc
            patience_cnt = 0
            os.makedirs("results", exist_ok=True)
            torch.save(model.state_dict(), "results/gat_best_model.pt")
            print("  Best model updated!")
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print("  Early stopping triggered.")
                break

    print(f"\nBest Validation Accuracy: {best_val:.4f}")


In [None]:

###############################################
# USAGE:
###############################################
if __name__ == "__main__":
    import torch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Now that balanced_df is produced, create loaders
    train_loader, val_loader, test_loader = create_gat_loaders(balanced_df)

    # Initialize model
    model = GATFusion(
        eeg_dim=len(eeg_features),
        eye_dim=len(eye_features),
        beh_dim=len(beh_features),
        hidden_dim=64,
        out_dim=64
    ).to(device)

    # Train
    train_gat_model(model, train_loader, val_loader, device, epochs=50, lr=5e-5, patience=5, use_focal = True, class0_weight=1.0)

    # Evaluate best model on test set
    model.load_state_dict(torch.load("results/gat_best_model.pt", map_location=device))
    evaluate_gat_model(model, test_loader, device, return_acc=False)

Train/Val/Test sizes: 88, 18, 20
[Ep  1/50] Loss 0.4728 | Train Acc 0.5795 | Val Acc 0.4444
  Best model updated!
[Ep  2/50] Loss 0.4810 | Train Acc 0.5568 | Val Acc 0.5000
  Best model updated!
[Ep  3/50] Loss 0.4799 | Train Acc 0.6136 | Val Acc 0.5000
[Ep  4/50] Loss 0.4769 | Train Acc 0.6023 | Val Acc 0.5000
[Ep  5/50] Loss 0.4762 | Train Acc 0.6023 | Val Acc 0.5000
[Ep  6/50] Loss 0.4761 | Train Acc 0.5455 | Val Acc 0.5000
[Ep  7/50] Loss 0.4569 | Train Acc 0.6364 | Val Acc 0.5000
  Early stopping triggered.

Best Validation Accuracy: 0.5000

--- Evaluation Metrics ---
              precision    recall  f1-score   support

           0       1.00      0.10      0.18        10
           1       0.53      1.00      0.69        10

    accuracy                           0.55        20
   macro avg       0.76      0.55      0.44        20
weighted avg       0.76      0.55      0.44        20



  model.load_state_dict(torch.load("results/gat_best_model.pt", map_location=device))


In [None]:
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
import numpy as np
import torch
import os

# ← make sure train_gat_model, evaluate_gat_model,
#    MultimodalGATDataset and GATFusion are imported/defined above

def cross_validate_gat(
    df: pd.DataFrame,
    n_splits: int = 5,
    epochs: int = 50,
    lr: float = 5e-4,
    patience: int = 5,
    use_focal: bool = False,
    class0_weight: float = 1.1,
    batch_size: int = 16,
):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_accs = []
    fold_cms  = []

    X = df.index.values
    y = df["label"].values

    for fold, (train_idx, test_idx) in enumerate(skf.split(X, y), 1):
        print(f"\n=== Fold {fold}/{n_splits} ===")
        # Build train+val vs test
        df_train_full = df.iloc[train_idx].reset_index(drop=True)
        df_test       = df.iloc[test_idx].reset_index(drop=True)

        # Inner split: train vs val for early stopping
        train_df, val_df = train_test_split(
            df_train_full,
            test_size=0.15,
            stratify=df_train_full["label"],
            random_state=42,
        )

        # DataLoaders
        train_loader = DataLoader(
            MultimodalGATDataset(train_df),
            batch_size=batch_size,
            shuffle=True,
        )
        val_loader = DataLoader(
            MultimodalGATDataset(val_df),
            batch_size=batch_size,
        )
        test_loader = DataLoader(
            MultimodalGATDataset(df_test),
            batch_size=batch_size,
        )

        # Init a fresh model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = GATFusion(
            eeg_dim=len(eeg_features),
            eye_dim=len(eye_features),
            beh_dim=len(beh_features),
            hidden_dim=64,
            out_dim=64
        ).to(device)

        # Train + early‑stop on val
        train_gat_model(
            model,
            train_loader,
            val_loader,
            device,
            epochs=epochs,
            lr=lr,
            patience=patience,
            use_focal=use_focal,
            class0_weight=class0_weight,
        )

        # Load best weights & evaluate on this fold’s test
        chk = "results/gat_best_model.pt"
        model.load_state_dict(torch.load(chk, map_location=device))

        # Collect preds / labels
        all_preds, all_labels = [], []
        model.eval()
        with torch.no_grad():
            for eeg, eye, beh, labels, _ in test_loader:
                eeg, eye, beh = eeg.to(device), eye.to(device), beh.to(device)
                out = model(eeg, eye, beh)
                all_preds.extend(out.argmax(1).cpu().numpy())
                all_labels.extend(labels.numpy())

        acc = np.mean(np.array(all_preds) == np.array(all_labels))
        cm  = confusion_matrix(all_labels, all_preds, labels=[0,1])

        print(f"Fold {fold}  —  Test Acc: {acc:.3f}")
        print(" Confusion Matrix:\n", cm)

        fold_accs.append(acc)
        fold_cms.append(cm)

    # Summary
    mean_acc = np.mean(fold_accs)
    std_acc  = np.std(fold_accs)
    print(f"\n=== CV Results ({n_splits}-fold) ===")
    print(f"Accuracy: {mean_acc:.3f} ± {std_acc:.3f}")

    # aggregated confusion matrix (sum over folds)
    agg_cm = sum(fold_cms)
    print("Aggregated Confusion Matrix (summing folds):\n", agg_cm)


# ── USAGE ──
if __name__ == "__main__":
    cross_validate_gat(
        balanced_df,
        n_splits=5,
        epochs=50,
        lr=5e-4,
        patience=5,
        use_focal=False,      # or True
        class0_weight=1.0,    # tune as you like
        batch_size=16
    )



=== Fold 1/5 ===
[Ep  1/50] Loss 167.1951 | Train Acc 0.4471 | Val Acc 0.4000
  Best model updated!
[Ep  2/50] Loss 90.0487 | Train Acc 0.4353 | Val Acc 0.4667
  Best model updated!
[Ep  3/50] Loss 61.4607 | Train Acc 0.5882 | Val Acc 0.2667
[Ep  4/50] Loss 78.7587 | Train Acc 0.3882 | Val Acc 0.5333
  Best model updated!
[Ep  5/50] Loss 57.8758 | Train Acc 0.4706 | Val Acc 0.4667
[Ep  6/50] Loss 54.4799 | Train Acc 0.4353 | Val Acc 0.4667
[Ep  7/50] Loss 54.7765 | Train Acc 0.4824 | Val Acc 0.8000
  Best model updated!
[Ep  8/50] Loss 43.4903 | Train Acc 0.4706 | Val Acc 0.6000
[Ep  9/50] Loss 41.5417 | Train Acc 0.4706 | Val Acc 0.7333
[Ep 10/50] Loss 30.3839 | Train Acc 0.5294 | Val Acc 0.6000
[Ep 11/50] Loss 24.4924 | Train Acc 0.5529 | Val Acc 0.8000
[Ep 12/50] Loss 28.6295 | Train Acc 0.5647 | Val Acc 0.6000
  Early stopping triggered.

Best Validation Accuracy: 0.8000
Fold 1  —  Test Acc: 0.500
 Confusion Matrix:
 [[4 9]
 [4 9]]

=== Fold 2/5 ===


  model.load_state_dict(torch.load(chk, map_location=device))


[Ep  1/50] Loss 52.4649 | Train Acc 0.4471 | Val Acc 0.5000
  Best model updated!
[Ep  2/50] Loss 0.6817 | Train Acc 0.5059 | Val Acc 0.5000
[Ep  3/50] Loss 0.7035 | Train Acc 0.4706 | Val Acc 0.5000
[Ep  4/50] Loss 0.6741 | Train Acc 0.4941 | Val Acc 0.5000
[Ep  5/50] Loss 0.6745 | Train Acc 0.5176 | Val Acc 0.5000
[Ep  6/50] Loss 0.6785 | Train Acc 0.4706 | Val Acc 0.5000
  Early stopping triggered.

Best Validation Accuracy: 0.5000
Fold 2  —  Test Acc: 0.520
 Confusion Matrix:
 [[ 0 12]
 [ 0 13]]

=== Fold 3/5 ===


  model.load_state_dict(torch.load(chk, map_location=device))


[Ep  1/50] Loss 0.6933 | Train Acc 0.4471 | Val Acc 0.5000
  Best model updated!
[Ep  2/50] Loss 0.6768 | Train Acc 0.5059 | Val Acc 0.5000
[Ep  3/50] Loss 0.6820 | Train Acc 0.4824 | Val Acc 0.5000
[Ep  4/50] Loss 0.6814 | Train Acc 0.4941 | Val Acc 0.5000
[Ep  5/50] Loss 0.6882 | Train Acc 0.4941 | Val Acc 0.5000
[Ep  6/50] Loss 0.6820 | Train Acc 0.4941 | Val Acc 0.5000
  Early stopping triggered.

Best Validation Accuracy: 0.5000
Fold 3  —  Test Acc: 0.520
 Confusion Matrix:
 [[ 0 12]
 [ 0 13]]

=== Fold 4/5 ===


  model.load_state_dict(torch.load(chk, map_location=device))


[Ep  1/50] Loss 1.2772 | Train Acc 0.5059 | Val Acc 0.5000
  Best model updated!
[Ep  2/50] Loss 0.6772 | Train Acc 0.5176 | Val Acc 0.5000
[Ep  3/50] Loss 0.6862 | Train Acc 0.5176 | Val Acc 0.5000
[Ep  4/50] Loss 0.6994 | Train Acc 0.4471 | Val Acc 0.5000
[Ep  5/50] Loss 0.6819 | Train Acc 0.5647 | Val Acc 0.5000
[Ep  6/50] Loss 0.6835 | Train Acc 0.5176 | Val Acc 0.5000
  Early stopping triggered.

Best Validation Accuracy: 0.5000
Fold 4  —  Test Acc: 0.480
 Confusion Matrix:
 [[ 0 13]
 [ 0 12]]

=== Fold 5/5 ===


  model.load_state_dict(torch.load(chk, map_location=device))


[Ep  1/50] Loss 0.6957 | Train Acc 0.4588 | Val Acc 0.5000
  Best model updated!
[Ep  2/50] Loss 0.6842 | Train Acc 0.5059 | Val Acc 0.5000
[Ep  3/50] Loss 0.6909 | Train Acc 0.5059 | Val Acc 0.5000
[Ep  4/50] Loss 0.6475 | Train Acc 0.5176 | Val Acc 0.5000
[Ep  5/50] Loss 0.6750 | Train Acc 0.5059 | Val Acc 0.5000
[Ep  6/50] Loss 0.6837 | Train Acc 0.5059 | Val Acc 0.5000
  Early stopping triggered.

Best Validation Accuracy: 0.5000
Fold 5  —  Test Acc: 0.480
 Confusion Matrix:
 [[ 0 13]
 [ 0 12]]

=== CV Results (5-fold) ===
Accuracy: 0.500 ± 0.018
Aggregated Confusion Matrix (summing folds):
 [[ 4 59]
 [ 4 59]]


  model.load_state_dict(torch.load(chk, map_location=device))


In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
import numpy as np

# 1. Pull out features & labels from your balanced_df
feature_cols = [c for c in balanced_df.columns if c.startswith(("eeg_", "eye_", "beh_"))]
X = balanced_df[feature_cols].values
y = balanced_df["label"].values

# 2. Stratified 5‑fold splitter
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# 3. Logistic Regression baseline
lr = LogisticRegression(solver="liblinear")
lr_scores = cross_val_score(lr, X, y, cv=skf, scoring="accuracy")
print(f"LogReg 5‑fold Acc: {lr_scores.mean():.3f} ± {lr_scores.std():.3f}")

# 4. Random Forest baseline
rf = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)
rf_scores = cross_val_score(rf, X, y, cv=skf, scoring="accuracy")
print(f"RandomForest 5‑fold Acc: {rf_scores.mean():.3f} ± {rf_scores.std():.3f}")


LogReg 5‑fold Acc: 0.555 ± 0.120
RandomForest 5‑fold Acc: 0.746 ± 0.037


In [None]:
"""
The GATFusion model you built first independently projects each modality’s 13 summary features (EEG, eye–tracking, behavioral) into a shared 64‑dim space, treats those three embeddings as nodes in a tiny fully connected graph, applies a single graph‐attention layer, and then classifies the fused output via a two‐layer MLP. Despite careful balancing, focal‐loss weighting, and both single‐split and 5‑fold cross‑validation, the network never learned to separate the classes: average CV accuracy hovered at 50 % (± 1.8 %), with the aggregated confusion matrix showing virtually random predictions. In contrast, a simple Random Forest trained on the same features achieved approximately 75 % accuracy under 5‑fold CV, demonstrating that the summary statistics do contain discriminative signal but that the GAT architecture and small sample size are a poor fit. To move forward, it’s best either to pivot to and tune the tree‐ensemble baseline or to engineer richer modality‐specific features before revisiting deep fusion models.





"""

'\nThe GATFusion model you built first independently projects each modality’s 13 summary features (EEG, eye–tracking, behavioral) into a shared 64‑dim space, treats those three embeddings as nodes in a tiny fully connected graph, applies a single graph‐attention layer, and then classifies the fused output via a two‐layer MLP. Despite careful balancing, focal‐loss weighting, and both single‐split and 5‑fold cross‑validation, the network never learned to separate the classes: average CV accuracy hovered at 50\xa0% (±\xa01.8\xa0%), with the aggregated confusion matrix showing virtually random predictions. In contrast, a simple Random Forest trained on the same features achieved approximately 75\xa0% accuracy under 5‑fold CV, demonstrating that the summary statistics do contain discriminative signal but that the GAT architecture and small sample size are a poor fit. To move forward, it’s best either to pivot to and tune the tree‐ensemble baseline or to engineer richer modality‐specific fea