In [None]:
!pip install scikit-learn==1.2.2 imbalanced-learn==0.10.1

In [None]:
#!/usr/bin/env python3
"""
important
DTCA-Net: Dual-Transformer Cross Attention Network
Complete pipeline for AD/FTD detection from EEG signals
FIXED VERSION - Ready for full dataset with NaN handling
"""

import os
import re
import glob
import random
import math
from pathlib import Path
from collections import Counter, defaultdict
from typing import Tuple, List

import numpy as np
import matplotlib.pyplot as plt

import mne
import pywt
from scipy.signal import hilbert

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import (
    f1_score, accuracy_score, precision_score, 
    recall_score, roc_auc_score, roc_curve
)
from sklearn.utils import check_random_state
from imblearn.over_sampling import SMOTE

import numba as nb

# ═══════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════

# Paths
# Raw recordings are discovered by extract_features() with the glob
# "<DATA_DIR>/sub-*/eeg/*.set".
DATA_DIR = "/kaggle/input/eye-open-eeg-alzheimers/eye-open-dataset"
# Per-subject .npz feature files (PTE and DE) are written to, and later read
# from, this directory.
FEATURES_DIR = "./features"
RESULTS_DIR = "./results"

# Create directories
# NOTE: this is a side effect at import / cell-execution time.
os.makedirs(FEATURES_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# DWT Configuration
# Decomposition depth and mother wavelet passed to pywt.wavedec / pywt.waverec.
MAX_LVL = 8
WAVELET = 'db4'
# Values are positions in the coefficient list returned by pywt.wavedec
# (reconstruct_band_dwt does `kept[lv] = coeffs[lv]`), NOT decomposition
# levels. With level=8 that list is [cA8, cD8, cD7, ..., cD1], so index 1 is
# the coarsest detail band and index 7 is cD2.
# NOTE(review): at SFREQ = 256 Hz the nominal dyadic ranges would be roughly
# delta 0.5-4 Hz, theta 4-8 Hz, alpha 8-16 Hz, beta 16-32 Hz, gamma 32-64 Hz
# - confirm these are the intended band edges.
band2levels = {
    'delta': [1, 2, 3],
    'theta': [4],
    'alpha': [5],
    'beta': [6],
    'gamma': [7]
}
# Band order used for the band axis of every feature array.
band_list = list(band2levels.keys())

# Window Configuration - FIXED TO MATCH
# Recordings are resampled to SFREQ, cut into whole minutes, and every minute
# is split into N_SUBWINS_PER_MINUTE equal sub-windows.
MINUTE_LEN = 60
SFREQ = 256
MINUTE_SAMPLES = int(MINUTE_LEN * SFREQ)  # 15360 samples per minute
# 15360 // 11 = 1396 samples per sub-window; the last 4 samples of each
# minute are therefore not covered by any sub-window.
N_SUBWINS_PER_MINUTE = 11  # Fixed number of sub-windows per minute

# Model Configuration
# Channel names must come from the ch_names list in load_data_stratified_kfold.
# run_experiment sets pte_input_dim = 36 (6 * 6), so changing the number of
# channels requires updating that value as well.
SELECTED_CHANNELS = ['O1', 'O2', 'T4', 'T5', 'F7', 'F8']
# BATCH_SIZE, N_SPLITS and N_REPETITIONS are forwarded to
# load_data_stratified_kfold by run_experiment.
BATCH_SIZE = 32
N_SPLITS = 10
N_REPETITIONS = 10
NUM_EPOCHS = 100
LEARNING_RATE = 0.0001

# ═══════════════════════════════════════════════════════════════════════════
# UTILITY FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════

def set_seed(seed: int):
    """Seed every global random number generator used by the pipeline.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus all CUDA devices), and puts cuDNN in deterministic
    mode with autotuning disabled.

    scikit-learn has no global seed: its estimators and splitters take an
    explicit ``random_state`` argument, so nothing is done for it here.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every CUDA device, including the current one; silently ignored
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Trade speed for run-to-run reproducibility of cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Seed set to: {seed}")

def get_subject_id(filepath: str) -> "int | None":
    """Extract the numeric subject ID from a BIDS-style path.

    The path is split on ``os.sep`` and the first component that starts with
    ``sub-`` followed by digits determines the ID. Only the leading digits
    are used, so both a directory such as ``sub-012`` and a file name such
    as ``sub-012_task-rest_eeg.set`` yield ``12``.

    Args:
        filepath: Path to a recording.

    Returns:
        The subject ID as an int, or ``None`` when no component matches.
    """
    for part in filepath.split(os.sep):
        found = re.match(r"sub-\s*(\d+)", part)
        if found:
            return int(found.group(1))
    return None

# ═══════════════════════════════════════════════════════════════════════════
# FEATURE EXTRACTION: PTE
# ═══════════════════════════════════════════════════════════════════════════

@nb.njit(fastmath=True, cache=True)
def _entropy(counts, length):
    """Shannon entropy, in bits, of a histogram.

    Each entry of ``counts`` is divided by ``length`` to obtain a
    probability; entries that are not strictly positive contribute nothing.
    """
    total = 0.0
    for k in range(counts.shape[0]):
        occurrences = counts[k]
        if occurrences <= 0:
            continue
        prob = occurrences / length
        total -= prob * np.log2(prob)
    return total

@nb.njit(fastmath=True, cache=True)
def compute_PTE_numba(phase, delay):
    """Compute the raw phase-transfer-entropy matrix with Numba.

    For every ordered pair (i, j), with ``x`` = row i, ``y`` = row j and
    ``ypr`` = row j shifted forward by ``delay`` samples, the entry is

        raw[i, j] = H(ypr, y) + H(y, x) - H(y) - H(ypr, y, x)

    where each entropy is estimated from histogram counts over the
    ``n - delay`` overlapping samples.

    Args:
        phase: 2-D array (rows x samples) of discretised phase values.
            They are fed to ``np.bincount``, so they are expected to be
            non-negative integers.
        delay: Shift, in samples, between ``y`` and its future ``ypr``.

    Returns:
        (m, m) float64 array. All zeros when ``delay`` leaves no
        overlapping samples (``n - delay <= 0``).
    """
    m, n = phase.shape
    raw = np.zeros((m, m), np.float64)
    L = n - delay

    # Without any overlap the slices below would be empty and the max()
    # reductions would raise; report "no information transfer" instead.
    if L <= 0:
        return raw

    for i in range(m):
        x = phase[i, :L]
        for j in range(m):
            y = phase[j, :L]
            ypr = phase[j, delay:]
            # Number of distinct bin values; used as the radix when packing
            # two or three bin indices into a single joint-histogram index.
            vmax = int(max(x.max(), y.max(), ypr.max()) + 1)

            cnt_y = np.bincount(y, minlength=vmax)
            idx_ypr_y = ypr + vmax * y
            cnt_ypr_y = np.bincount(idx_ypr_y, minlength=vmax * vmax)
            idx_y_x = y + vmax * x
            cnt_y_x = np.bincount(idx_y_x, minlength=vmax * vmax)
            idx_3d = ypr + vmax * (y + vmax * x)
            cnt_3d = np.bincount(idx_3d, minlength=vmax * vmax * vmax)

            H_y = _entropy(cnt_y, L)
            H_ypr_y = _entropy(cnt_ypr_y, L)
            H_y_x = _entropy(cnt_y_x, L)
            H_ypr_y_x = _entropy(cnt_3d, L)

            raw[i, j] = H_ypr_y + H_y_x - H_y - H_ypr_y_x

    return raw

@nb.njit(fastmath=True, cache=True)
def dPTE_from_raw(raw):
    """Normalise a raw PTE matrix into a directed-PTE matrix.

    Each entry is divided by the sum of itself and its transposed
    counterpart; pairs whose sum does not exceed 1e-10 are set to 0 to avoid
    dividing by zero. The returned matrix keeps the strict upper triangle of
    those ratios and mirrors it into the lower triangle, so it is symmetric
    with a zero diagonal.

    NOTE(review): the ratios of the lower triangle are discarded by the
    mirroring step - confirm this is intended.
    """
    eps = 1e-10
    pair_sum = raw + raw.T
    n_rows, n_cols = raw.shape
    ratio = np.zeros_like(raw)
    for r in range(n_rows):
        for c in range(n_cols):
            denom = pair_sum[r, c]
            if denom > eps:
                ratio[r, c] = raw[r, c] / denom
    upper = np.triu(ratio, 1)
    mirrored = np.tril(ratio.T, -1)
    return upper + mirrored

def reconstruct_band_dwt(data: np.ndarray, levels: List[int]) -> np.ndarray:
    """Reconstruct a signal from a subset of its DWT coefficient arrays.

    ``data`` is decomposed along axis 1 with ``WAVELET`` down to ``MAX_LVL``
    levels; every coefficient array except those at the positions listed in
    ``levels`` is zeroed before the inverse transform.

    Args:
        data: 2-D array, transformed along axis 1.
        levels: Positions in the list returned by ``pywt.wavedec`` (index 0
            is the approximation, followed by details from coarsest to
            finest) - not decomposition level numbers.

    Returns:
        Array with the same shape as ``data``; NaN/Inf are replaced by 0.
        If the transform raises, a warning is printed and zeros are returned.
    """
    try:
        coeffs = pywt.wavedec(data, WAVELET, axis=1, level=MAX_LVL)
        kept = [np.zeros_like(c) for c in coeffs]
        for lv in levels:
            kept[lv] = coeffs[lv]
        reconstructed = pywt.waverec(kept, WAVELET, axis=1)

        # waverec may return one extra sample for odd-length input; trim so
        # the result always has the input's shape, like the fallback below.
        reconstructed = reconstructed[:, :data.shape[1]]

        # Handle any NaN or Inf values
        return np.nan_to_num(reconstructed, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception as e:
        # Deliberate best-effort: one bad segment must not abort the run.
        print(f"Warning: DWT reconstruction failed: {e}. Returning zeros.")
        return np.zeros_like(data)

def get_delay(phase: np.ndarray) -> int:
    """Estimate the PTE analysis delay from the rate of phase sign changes.

    Counts the positions where a sample and its predecessor along axis 1
    (taken circularly, via ``np.roll``) have opposite signs, then returns
    the total number of samples divided by that count, rounded and floored
    at 1. Returns 1 when there is no sign change at all.
    """
    n_channels, n_samples = phase.shape
    total = n_channels * n_samples
    previous = np.roll(phase, 1, axis=1)
    sign_flips = int(np.count_nonzero(phase * previous < 0))
    if not sign_flips:
        return 1
    return max(1, int(round(total / sign_flips)))

def get_binsize(phase: np.ndarray, c: float = 3.49) -> float:
    """Choose the histogram bin width used to discretise the phase.

    The width is ``c * s * n ** (-1/3)``, where ``s`` is the mean over rows
    of the sample standard deviation (``ddof=1``) and ``n`` is the number of
    samples per row. It is floored at 0.01, and 0.1 is returned when ``s``
    is zero or NaN.
    """
    _, n_samples = phase.shape
    spread = np.std(phase, axis=1, ddof=1).mean()
    if np.isnan(spread) or spread == 0:
        return 0.1  # fallback width for degenerate input
    width = c * spread * n_samples ** (-1 / 3)
    return max(width, 0.01)  # never go below the minimum width

def discretize_phase(phase: np.ndarray, binsz: float) -> np.ndarray:
    """Map phase values to integer bin indices.

    Each value is divided by ``binsz`` and rounded up; the result is
    returned as an ``int32`` array.
    """
    bin_index = np.ceil(phase / binsz)
    return bin_index.astype(np.int32)

def process_pte_subject(filepath: str, label: str):
    """Compute directed phase-transfer-entropy matrices for one recording.

    The recording is loaded with MNE, resampled to ``SFREQ`` and cut into
    whole minutes (trailing samples are dropped). For every minute and band
    the signal is band-limited via the DWT, its instantaneous phase is taken
    from the Hilbert transform, and the delay and bin width are estimated on
    the whole minute. The discretised phase is then split into
    ``N_SUBWINS_PER_MINUTE`` sub-windows and a dPTE matrix is computed for
    each one.

    Args:
        filepath: Path to an EEGLAB ``.set`` file.
        label: Class label; returned unchanged.

    Returns:
        ``(subject_id, dp_subject, label)`` where ``dp_subject`` has shape
        ``(n_minutes, N_SUBWINS_PER_MINUTE, n_bands, n_ch, n_ch)``.
    """
    print(f"Processing PTE: {filepath}")
    raw = mne.io.read_raw_eeglab(filepath, preload=True, verbose='ERROR')
    raw.resample(SFREQ)

    # Zero non-finite samples before filtering, as process_de_subject does,
    # so a few bad samples do not contaminate the whole wavelet transform.
    data_full = np.nan_to_num(raw.get_data(), nan=0.0, posinf=0.0, neginf=0.0)
    n_ch, total_samples = data_full.shape

    n_minutes = total_samples // MINUTE_SAMPLES
    n_bands = len(band_list)
    subwin_samples = MINUTE_SAMPLES // N_SUBWINS_PER_MINUTE
    # The delay is estimated on a full minute but applied to a sub-window;
    # cap it so at least one sample pair remains inside each sub-window.
    max_delay = max(1, subwin_samples - 1)

    # Shape: (n_minutes, n_subwins, n_bands, n_ch, n_ch)
    dp_subject = np.zeros((n_minutes, N_SUBWINS_PER_MINUTE, n_bands, n_ch, n_ch), dtype=np.float64)

    for mi in range(n_minutes):
        seg60 = data_full[:, mi * MINUTE_SAMPLES:(mi + 1) * MINUTE_SAMPLES]

        for bi, band in enumerate(band_list):
            band_data = reconstruct_band_dwt(seg60, band2levels[band])
            phase = np.angle(hilbert(band_data, axis=1))

            # Handle NaN in phase
            phase = np.nan_to_num(phase, nan=0.0, posinf=0.0, neginf=0.0)

            delay = min(get_delay(phase), max_delay)
            binsz = get_binsize(phase)
            # Shift from [-pi, pi] to [0, 2*pi] so the bin indices are
            # non-negative, as the histogram counting requires.
            dph = discretize_phase(phase + np.pi, binsz)

            for wi in range(N_SUBWINS_PER_MINUTE):
                start = wi * subwin_samples
                end = start + subwin_samples
                rawP = compute_PTE_numba(dph[:, start:end], delay)
                dp = dPTE_from_raw(rawP)

                # Ensure no NaN or Inf
                dp_subject[mi, wi, bi, :, :] = np.nan_to_num(dp, nan=0.0, posinf=0.0, neginf=0.0)

    subj_id = get_subject_id(filepath)
    return subj_id, dp_subject, label

# ═══════════════════════════════════════════════════════════════════════════
# FEATURE EXTRACTION: DIFFERENTIAL ENTROPY
# ═══════════════════════════════════════════════════════════════════════════

def compute_DE(signal: np.ndarray) -> float:
    """Differential entropy of a signal under a Gaussian assumption.

    Non-finite samples are discarded, then ``0.5 * ln(2 * pi * e * var)`` is
    evaluated with the sample variance (``ddof=1``). Returns 0.0 when fewer
    than two finite samples remain, when the variance is at most 1e-10 or
    non-finite, or when the result itself is non-finite.
    """
    finite = signal[np.isfinite(signal)]
    if finite.size < 2:
        return 0.0

    variance = np.var(finite, ddof=1)
    if not (np.isfinite(variance) and variance > 1e-10):
        return 0.0

    entropy = 0.5 * math.log(2 * math.pi * math.e * variance)
    return entropy if np.isfinite(entropy) else 0.0

def process_de_subject(filepath: str, label: str):
    """Compute differential-entropy features for one recording.

    Uses the same windowing as process_pte_subject: the recording is
    resampled to ``SFREQ``, cut into whole minutes, band-limited per minute
    via the DWT, and each minute is split into ``N_SUBWINS_PER_MINUTE``
    sub-windows. One DE value is computed per (sub-window, channel, band).

    Args:
        filepath: Path to an EEGLAB ``.set`` file.
        label: Class label; returned unchanged.

    Returns:
        ``(subject_id, DE_values, label)`` where ``DE_values`` has shape
        ``(n_minutes * N_SUBWINS_PER_MINUTE, n_ch, n_bands)``.
    """
    print(f"Processing DE: {filepath}")

    recording = mne.io.read_raw_eeglab(filepath, preload=True, verbose='ERROR')
    recording.resample(SFREQ)

    # Scale by 1e6 (presumably volts -> microvolts; confirm against the
    # recording's units) and zero any non-finite samples.
    scaled = np.nan_to_num(recording.get_data() * 1e6, nan=0.0, posinf=0.0, neginf=0.0)

    n_ch, n_samp = scaled.shape
    n_minutes = n_samp // MINUTE_SAMPLES
    subwin_samples = MINUTE_SAMPLES // N_SUBWINS_PER_MINUTE

    # One row per sub-window, ordered minute by minute.
    DE_values = np.zeros(
        (n_minutes * N_SUBWINS_PER_MINUTE, n_ch, len(band_list)), dtype=float
    )

    for mi in range(n_minutes):
        minute = scaled[:, mi * MINUTE_SAMPLES:(mi + 1) * MINUTE_SAMPLES]

        # Band-limit the whole minute once per band, in band_list order.
        filtered = [
            np.nan_to_num(
                reconstruct_band_dwt(minute, band2levels[band]),
                nan=0.0, posinf=0.0, neginf=0.0,
            )
            for band in band_list
        ]

        for wi in range(N_SUBWINS_PER_MINUTE):
            lo = wi * subwin_samples
            hi = lo + subwin_samples
            row = mi * N_SUBWINS_PER_MINUTE + wi

            for bi, band_sig in enumerate(filtered):
                for ch in range(n_ch):
                    DE_values[row, ch, bi] = compute_DE(band_sig[ch, lo:hi])

    # Last line of defence against non-finite features.
    DE_values = np.nan_to_num(DE_values, nan=0.0, posinf=0.0, neginf=0.0)

    return get_subject_id(filepath), DE_values, label

# ═══════════════════════════════════════════════════════════════════════════
# DATA LOADING
# ═══════════════════════════════════════════════════════════════════════════

def load_data_stratified_kfold(
    pte_directory: str,
    DE_directory: str,
    batch_size: int,
    selected_classes=("alz", "ctrl"),
    selected_channels=None,
    n_splits: int = 10,
    n_repetitions: int = 5,
):
    """Build repeated, subject-level stratified k-fold loaders from saved features.

    PTE files (``sub-<id>_PTE_<class>.npz``, key ``pte_data``) and DE files
    (``sub-<id>_DE_<class>.npz``, key ``DE_features``) are loaded in
    ascending subject order, after dropping the five lowest subject IDs of
    every class. Windows are split by subject, so no subject appears in both
    the training and validation part of a fold. Within each fold the
    training windows are balanced with SMOTE and both feature sets are
    min-max scaled using statistics of the (balanced) training data only.

    Args:
        pte_directory: Directory containing the PTE ``.npz`` files.
        DE_directory: Directory containing the DE ``.npz`` files.
        batch_size: Batch size of every returned DataLoader.
        selected_classes: Class names; the position in this sequence is the
            integer label.
        selected_channels: Channel names to keep (all 19 when ``None``).
        n_splits: Number of folds per repetition.
        n_repetitions: Number of differently seeded repetitions.

    Returns:
        A list with one entry per repetition; each entry is a list of
        ``(train_loader, val_loader)`` tuples, one per fold. Every batch is
        ``(pte, de, label, participant_id)``.

    Raises:
        ValueError: If no feature files are found, or if the PTE and DE
            files do not describe the same subjects with the same number of
            windows, in the same order.
    """
    ch_names = [
        "Fp1", "Fp2", "F3", "F4", "C3", "C4", "P3", "P4", "O1", "O2",
        "F7", "F8", "T3", "T4", "T5", "T6", "Fz", "Cz", "Pz",
    ]

    if selected_channels is None:
        selected_channels = ch_names

    sel_idx = [ch_names.index(ch) for ch in selected_channels]
    label_map = {c: i for i, c in enumerate(selected_classes)}

    def parse_info(fname):
        """Return (subject_id, class) for a feature file, or None if it is not wanted."""
        m = re.match(r"sub-(\d+)_.*_(\w+)\.npz", fname)
        if not m:
            return None
        sid, lbl = int(m.group(1)), m.group(2).lower()
        if lbl not in selected_classes:
            return None
        return sid, lbl

    def collect_files(directory, file_type='PTE'):
        """Collect files of specific type (PTE or DE), sorted by subject ID."""
        info = []
        for f in os.listdir(directory):
            if not (f.endswith(".npz") and f"_{file_type}_" in f):
                continue
            parsed = parse_info(f)
            # Unparseable names and unselected classes are skipped rather
            # than crashing the sort below.
            if parsed is not None:
                info.append(parsed + (f,))
        info.sort(key=lambda item: item[0])

        # Drop first 5 subjects from each class
        drop_ids = {}
        for cls in selected_classes:
            ids = sorted({sid for sid, lbl, _ in info if lbl == cls})
            drop_ids[cls] = set(ids[:5])

        return [
            fname
            for sid, lbl, fname in info
            if sid not in drop_ids[lbl]
        ]

    def make_loader(x1, x2, y_, p_, shuffle):
        """Wrap the four aligned arrays in a TensorDataset-backed DataLoader."""
        t1 = torch.from_numpy(x1).float()
        t2 = torch.from_numpy(x2).float()
        ty = torch.from_numpy(y_).long()
        tp = torch.from_numpy(p_).long()
        ds = TensorDataset(t1, t2, ty, tp)
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, drop_last=False)

    pte_files = collect_files(pte_directory, file_type='PTE')
    psd_files = collect_files(DE_directory, file_type='DE')

    if not pte_files or not psd_files:
        raise ValueError(
            f"No usable feature files found (PTE: {len(pte_files)}, DE: {len(psd_files)})"
        )

    pte_list, psd_list, labels_list, pid_list = [], [], [], []
    # (subject_id, n_windows) per file, used to verify PTE/DE row alignment.
    pte_layout, de_layout = [], []

    for fname in pte_files:
        sid, lbl = parse_info(fname)
        lbl_int = label_map[lbl]
        arr = np.load(Path(pte_directory) / fname, allow_pickle=True)

        pte = arr["pte_data"]
        # Clean PTE data
        pte = np.nan_to_num(pte, nan=0.0, posinf=0.0, neginf=0.0)

        # Merge the (minute, sub-window) axes:
        # (n_minutes, n_subwins, bands, ch, ch) -> (n_windows, bands, ch, ch)
        pte = pte.reshape(-1, *pte.shape[2:])
        # Select channels on both channel axes
        pte = pte[:, :, sel_idx, :][:, :, :, sel_idx]

        N = pte.shape[0]
        pte_list.append(pte)
        labels_list.append(np.full(N, lbl_int, dtype=int))
        pid_list.extend([sid] * N)
        pte_layout.append((sid, N))

    for fname in psd_files:
        sid, _ = parse_info(fname)
        arr = np.load(Path(DE_directory) / fname, allow_pickle=True)

        psd = arr["DE_features"]
        # Clean DE data
        psd = np.nan_to_num(psd, nan=0.0, posinf=0.0, neginf=0.0)

        psd = psd[:, sel_idx, :]

        psd_list.append(psd)
        de_layout.append((sid, psd.shape[0]))

    # Labels and participant IDs come from the PTE files only, so the DE rows
    # must line up with them subject by subject - equal totals are not enough.
    if pte_layout != de_layout:
        raise ValueError(
            "PTE and DE features are not aligned per subject. "
            f"PTE (subject, windows): {pte_layout}; DE (subject, windows): {de_layout}"
        )

    X_pte = np.concatenate(pte_list, axis=0)
    X_psd = np.concatenate(psd_list, axis=0)
    y = np.concatenate(labels_list, axis=0)
    pid = np.asarray(pid_list, dtype=int)

    # Final safety check
    X_pte = np.nan_to_num(X_pte, nan=0.0, posinf=0.0, neginf=0.0)
    X_psd = np.nan_to_num(X_psd, nan=0.0, posinf=0.0, neginf=0.0)

    print(f"Data shapes - PTE: {X_pte.shape}, DE: {X_psd.shape}, Labels: {y.shape}")
    print(f"NaN check - PTE: {np.isnan(X_pte).sum()}, DE: {np.isnan(X_psd).sum()}")

    assert X_pte.shape[0] == X_psd.shape[0] == y.shape[0] == pid.shape[0], \
        f"Shape mismatch! PTE: {X_pte.shape[0]}, DE: {X_psd.shape[0]}, Labels: {y.shape[0]}, PID: {pid.shape[0]}"

    # Stratify on one label per subject (the majority label of its windows).
    unique_pids = np.unique(pid)
    subj_labels = np.array(
        [Counter(y[pid == sid]).most_common(1)[0][0] for sid in unique_pids]
    )

    all_reps = []
    for rep in range(n_repetitions):
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=rep)
        rep_folds = []

        for subj_tr_idx, subj_va_idx in skf.split(unique_pids, subj_labels):
            train_pids = unique_pids[subj_tr_idx]
            val_pids = unique_pids[subj_va_idx]

            tr_mask = np.isin(pid, train_pids)
            va_mask = np.isin(pid, val_pids)

            Xp_tr, Xp_va = X_pte[tr_mask], X_pte[va_mask]
            Xs_tr, Xs_va = X_psd[tr_mask], X_psd[va_mask]
            y_tr, y_va = y[tr_mask], y[va_mask]
            pid_tr, pid_va = pid[tr_mask], pid[va_mask]

            # SMOTE works on flat vectors: concatenate both feature sets so
            # synthetic PTE and DE values stem from the same interpolation.
            flat_pte_tr = Xp_tr.reshape(len(y_tr), -1)
            flat_psd_tr = Xs_tr.reshape(len(y_tr), -1)
            X_train_flat = np.hstack([flat_pte_tr, flat_psd_tr])

            # Additional NaN check before SMOTE
            X_train_flat = np.nan_to_num(X_train_flat, nan=0.0, posinf=0.0, neginf=0.0)

            sm = SMOTE(random_state=rep)
            X_bal, y_bal = sm.fit_resample(X_train_flat, y_tr)

            # NOTE(review): SMOTE creates synthetic rows, so it may not expose
            # sample indices at all; in that case the fallback resamples the
            # row indices themselves and synthetic rows get the participant
            # ID of an interpolated index - confirm this approximation is
            # acceptable for the participant-level statistics.
            if hasattr(sm, "sample_indices_"):
                res_idx = sm.sample_indices_
            elif hasattr(sm, "_sample_indices"):
                res_idx = sm._sample_indices
            else:
                idx = np.arange(len(y_tr)).reshape(-1, 1)
                idx_bal, _ = SMOTE(random_state=rep).fit_resample(idx, y_tr)
                res_idx = idx_bal.ravel()

            pid_bal = pid_tr[res_idx]

            # Split the balanced matrix back into its PTE and DE parts.
            split_at = flat_pte_tr.shape[1]
            flat_pte_bal = X_bal[:, :split_at]
            flat_psd_bal = X_bal[:, split_at:]

            # Scalers are fitted on training data only to avoid leakage.
            scaler_pte = MinMaxScaler()
            scaler_psd = MinMaxScaler()

            flat_pte_bal = scaler_pte.fit_transform(flat_pte_bal)
            flat_pte_val = scaler_pte.transform(Xp_va.reshape(len(y_va), -1))

            flat_psd_bal = scaler_psd.fit_transform(flat_psd_bal)
            flat_psd_val = scaler_psd.transform(Xs_va.reshape(len(y_va), -1))

            Xp_tr_bal = flat_pte_bal.reshape(-1, *Xp_tr.shape[1:])
            Xs_tr_bal = flat_psd_bal.reshape(-1, *Xs_tr.shape[1:])
            Xp_va = flat_pte_val.reshape(Xp_va.shape)
            Xs_va = flat_psd_val.reshape(Xs_va.shape)

            train_loader = make_loader(Xp_tr_bal, Xs_tr_bal, y_bal, pid_bal, shuffle=True)
            val_loader = make_loader(Xp_va, Xs_va, y_va, pid_va, shuffle=False)

            rep_folds.append((train_loader, val_loader))

        all_reps.append(rep_folds)

    return all_reps

# ═══════════════════════════════════════════════════════════════════════════
# MODEL ARCHITECTURE
# ═══════════════════════════════════════════════════════════════════════════

class MultiHeadCrossAttention(nn.Module):
    """Multi-head attention block with dropout, residual connection and LayerNorm.

    The query attends to ``key``/``value``; the attended output is passed
    through dropout, added back to the query and layer-normalised. All
    tensors are batch-first.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Return ``(normalised_output, attention_weights)``."""
        attended, weights = self.multihead_attn(
            query,
            key,
            value,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
        residual = query + self.dropout(attended)
        return self.layer_norm(residual), weights

class PteTransformer(nn.Module):
    """Transformer encoder over the band axis of PTE connectivity matrices.

    Input is ``(batch, num_bands, n_ch, n_ch)``. The two channel axes are
    flattened into one feature vector per band, a learned positional
    encoding is added, and the band sequence is encoded and projected to
    ``output_dim`` features per band.

    Args:
        input_dim: Flattened size of one connectivity matrix
            (``n_ch * n_ch``, e.g. 36 for 6 channels); used as the
            transformer width and must be divisible by ``num_heads``.
        hidden_dim: Width of the encoder's feed-forward sub-layer.
        num_layers: Number of encoder layers.
        num_heads: Number of attention heads.
        output_dim: Feature size per band after the final projection.
        dropout: Dropout rate inside the encoder layers.
        num_bands: Sequence length (number of frequency bands).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, output_dim, dropout,
                 num_bands=5):
        super().__init__()
        # (batch, bands, ch, ch) -> (batch, bands, ch * ch)
        self.flatten_spatial = nn.Flatten(start_dim=2)

        # Learned positional encoding, one vector per band.
        self.position_encoding = nn.Parameter(
            torch.randn(1, num_bands, input_dim), requires_grad=True
        )

        # nn.TransformerEncoder works on its own copies of this layer; the
        # attribute is kept so existing state_dict keys stay loadable.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
            activation="gelu"
        )
        self.transformer = nn.TransformerEncoder(encoder_layer=self.encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Encode ``(batch, num_bands, n_ch, n_ch)`` into ``(batch, num_bands, output_dim)``."""
        x = self.flatten_spatial(x)       # (batch, num_bands, input_dim)
        x = self.position_encoding + x    # broadcast over the batch axis
        x = self.transformer(x)           # (batch, num_bands, input_dim)
        return self.output_layer(x)       # (batch, num_bands, output_dim)

class PsdTransformer(nn.Module):
    """Transformer encoder over the channel axis of the PSD/DE features.

    Input is ``(batch, n_channels, input_dim)``: channels form the sequence
    and the per-band values are each token's features. No transpose is
    applied. The encoded sequence is projected to ``output_dim`` features
    per channel.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, output_dim, dropout):
        super().__init__()
        # Token width equals the number of bands, so num_heads must divide it.
        self.encoder_layer = nn.TransformerEncoderLayer(
            input_dim,
            num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Encode ``(batch, n_channels, input_dim)`` into ``(batch, n_channels, output_dim)``."""
        encoded = self.transformer(x)
        return self.output_layer(encoded)

class FinalModel(nn.Module):
    """Dual-transformer network fused by cross attention.

    PTE connectivity matrices and PSD/DE features are encoded by separate
    transformers. The PTE encoding (one token per band) queries the PSD/DE
    encoding (one token per channel) through multi-head cross attention,
    and the flattened result is classified into two classes.

    Both encoders must output ``cross_d_model`` features per token, because
    the attention block uses that single width for query, key and value.

    Raises:
        ValueError: If ``pte_output_dim`` or ``psd_output_dim`` differs from
            ``cross_d_model``.
    """

    # Number of query tokens (frequency bands) entering the classifier.
    _NUM_BANDS = 5

    def __init__(self,
                 pte_input_dim, pte_hidden_dim, pte_num_layers, pte_num_heads, pte_output_dim, pte_dropout,
                 psd_input_dim, psd_hidden_dim, psd_num_layers, psd_num_heads, psd_output_dim, psd_dropout,
                 cross_d_model, cross_num_heads):
        super().__init__()

        # Fail at construction instead of with a shape error in forward().
        if pte_output_dim != cross_d_model or psd_output_dim != cross_d_model:
            raise ValueError(
                "pte_output_dim and psd_output_dim must both equal cross_d_model "
                f"(got {pte_output_dim}, {psd_output_dim} and {cross_d_model})"
            )

        self.pte_transformer = PteTransformer(
            input_dim=pte_input_dim,
            hidden_dim=pte_hidden_dim,
            num_layers=pte_num_layers,
            num_heads=pte_num_heads,
            output_dim=pte_output_dim,
            dropout=pte_dropout
        )

        self.psd_transformer = PsdTransformer(
            input_dim=psd_input_dim,
            hidden_dim=psd_hidden_dim,
            num_layers=psd_num_layers,
            num_heads=psd_num_heads,
            output_dim=psd_output_dim,
            dropout=psd_dropout
        )

        self.cross_attention = MultiHeadCrossAttention(
            d_model=cross_d_model,
            num_heads=cross_num_heads,
            dropout=0.1
        )

        # Cross attention keeps the query length, so its output is
        # (batch, bands, cross_d_model); flatten it for classification.
        self.final_classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(self._NUM_BANDS * cross_d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2)
        )

    def forward(self, pte_input, psd_input):
        """Return ``(logits, attention_weights)``.

        Args:
            pte_input: PTE matrices, ``(batch, bands, n_ch, n_ch)``.
            psd_input: PSD/DE features, ``(batch, n_ch, bands)``.
        """
        pte_encoded = self.pte_transformer(pte_input)  # (batch, bands, d_model)
        psd_encoded = self.psd_transformer(psd_input)  # (batch, n_ch, d_model)

        # Query from PTE (band tokens), key/value from PSD (channel tokens).
        cross_attn_output, attn_weights = self.cross_attention(
            query=pte_encoded,
            key=psd_encoded,
            value=psd_encoded
        )

        label_pred = self.final_classifier(cross_attn_output)  # (batch, 2)

        return label_pred, attn_weights

# ═══════════════════════════════════════════════════════════════════════════
# TRAINING AND EVALUATION
# ═══════════════════════════════════════════════════════════════════════════

def train_model(
    model,
    source_dataloader,
    target_dataloader,
    criterion_label,
    optimizer,
    num_epochs=10,
    device="cuda",
    scheduler=None,
):
    """Train ``model`` on ``source_dataloader`` and return per-epoch accuracy.

    Each batch must provide at least ``(pte, psd, labels)``; any further
    element (such as participant IDs) is ignored. ``target_dataloader`` is
    accepted but not used. The scheduler, if given, is stepped once per
    epoch, and a progress line is printed every tenth epoch.

    Returns:
        List with the training accuracy (in percent) of every epoch.
    """
    model.to(device)
    model.train()

    accuracy_history = []

    for epoch in range(num_epochs):
        n_correct = 0
        n_seen = 0
        running_loss = 0.0

        for batch in source_dataloader:
            pte, psd, targets = batch[:3]
            pte = pte.to(device)
            psd = psd.to(device)
            targets = targets.to(device)

            logits, _ = model(pte, psd)
            loss = criterion_label(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sum (not mean) of the batch losses over the epoch.
            running_loss += loss.item()

            predicted = torch.max(logits, dim=1)[1]
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)

        if scheduler is not None:
            scheduler.step()

        if n_seen > 0:
            epoch_accuracy = 100.0 * n_correct / n_seen
        else:
            epoch_accuracy = 0
        accuracy_history.append(epoch_accuracy)

        if (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch+1}/{num_epochs}: Loss={running_loss:.4f}, Acc={epoch_accuracy:.2f}%")

    return accuracy_history

def test_model(
    model,
    test_dataloader,
    criterion_label,
    device="cuda",
    num_classes=2,
    alz_threshold=0.4
):
    """Evaluate ``model`` and aggregate window predictions per participant.

    Windows are classified individually by the arg-max of the softmax
    output. A participant is then assigned class 1 when the fraction of
    their windows predicted as class 1 is at least ``alz_threshold``, and
    class 0 otherwise. Batches without participant IDs are treated as
    belonging to a single participant with ID 0.

    Args:
        model: Network called as ``model(pte, psd)`` and returning
            ``(logits, _)``.
        test_dataloader: Yields ``(pte, psd, labels, pids)`` or
            ``(pte, psd, labels)``.
        criterion_label: Loss function applied to ``(logits, labels)``.
        device: Device used for evaluation.
        num_classes: Size of the participant-level confusion matrix.
        alz_threshold: Minimum class-1 window fraction for a class-1 verdict.

    Returns:
        Tuple ``(avg_loss, participant_accuracy_percent, participant_macro_f1,
        participant_confusion_matrix, window_probabilities, window_labels,
        participant_class1_ratios, participant_true_labels)``. With an empty
        dataloader the scalars are 0.0 and the arrays are empty.
    """
    model.to(device).eval()
    total_loss = 0.0

    all_preds = []
    all_labels = []
    all_probs = []
    all_pids = []

    with torch.no_grad():
        for batch in test_dataloader:
            if len(batch) == 4:
                pte, psd, labels, pids = batch
            else:
                pte, psd, labels = batch
                pids = torch.zeros_like(labels)

            pte, psd, labels = pte.to(device), psd.to(device), labels.to(device)
            logits, _ = model(pte, psd)
            loss = criterion_label(logits, labels)
            total_loss += loss.item()

            probs = F.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)

            all_probs.append(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_pids.extend(pids.cpu().numpy())

    n_batches = len(test_dataloader)
    avg_loss = total_loss / n_batches if n_batches else 0.0

    # np.vstack rejects an empty list; keep the empty-loader case consistent
    # with the guarded scalars by returning an empty probability matrix.
    if all_probs:
        all_probs = np.vstack(all_probs)
    else:
        all_probs = np.zeros((0, num_classes), dtype=np.float32)
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_pids = np.array(all_pids)

    part_ids = np.unique(all_pids)
    part_accs = []
    part_preds = []
    part_confs = np.zeros((num_classes, num_classes), dtype=int)
    part_ratios = []
    part_trues = []

    for pid in part_ids:
        mask = (all_pids == pid)
        labs = all_labels[mask]
        preds = all_preds[mask]

        # All windows of one participant are taken to share the first label.
        true_lbl = labs[0]
        alz_ratio = (preds == 1).sum() / max(len(preds), 1)
        pred_lbl = 1 if alz_ratio >= alz_threshold else 0

        part_confs[true_lbl, pred_lbl] += 1
        part_accs.append(100.0 if pred_lbl == true_lbl else 0.0)

        part_preds.append(pred_lbl)
        part_ratios.append(alz_ratio)
        part_trues.append(true_lbl)

    mean_acc = float(np.mean(part_accs)) if part_accs else 0.0
    mean_f1 = f1_score(part_trues, part_preds, average='macro', zero_division=0) if part_trues else 0.0

    return (
        avg_loss,
        mean_acc,
        mean_f1,
        part_confs,
        all_probs,
        all_labels,
        np.array(part_ratios),
        np.array(part_trues)
    )

def tune_threshold_on_source(
    model,
    source_dataloader,
    device="cuda",
    thresholds=(0.1, 0.2, 0.3, 0.4, 0.5),
    num_classes=2
):
    """Pick the participant-level decision threshold with the best macro F1.

    Every window in ``source_dataloader`` is classified; for each participant
    the fraction of windows predicted as class 1 is computed. For each
    candidate threshold a participant is labelled class 1 when that fraction
    is at least the threshold, and the macro F1 over participants is
    evaluated. Ties keep the earliest threshold in ``thresholds``.

    Args:
        model: Network called as ``model(pte, psd)`` and returning
            ``(logits, _)``.
        source_dataloader: Must yield ``(pte, psd, labels, pid)`` batches.
        device: Device used for inference.
        thresholds: Candidate thresholds, tried in order.
        num_classes: Accepted for interface compatibility; not used.

    Returns:
        The best threshold, or ``None`` when ``thresholds`` is empty.

    Raises:
        ValueError: If a batch does not contain exactly four elements.
    """
    model.eval()
    model.to(device)
    sample_preds = defaultdict(list)
    participant_label = {}

    with torch.no_grad():
        for batch in source_dataloader:
            if len(batch) != 4:
                raise ValueError("Expected Dataloader to return (pte, psd, labels, pid).")
            pte_batch, psd_batch, labels, pid_batch = batch

            label_preds, _ = model(pte_batch.to(device), psd_batch.to(device))
            softmax_output = F.softmax(label_preds, dim=1)
            _, predicted = torch.max(softmax_output, dim=1)

            # Labels and participant IDs are only used for bookkeeping on the
            # host, so they never need to be moved to the device.
            predicted = predicted.cpu().numpy()
            labels = labels.cpu().numpy()
            pid_batch = pid_batch.cpu().numpy()

            for pred, true_lbl, pid in zip(predicted, labels, pid_batch):
                sample_preds[pid].append(pred)
                # The first label seen for a participant is taken as truth.
                if pid not in participant_label:
                    participant_label[pid] = true_lbl

    # The class-1 fraction per participant does not depend on the threshold,
    # so compute it once.
    part_level_trues = []
    part_ratios = []
    for pid, preds_list in sample_preds.items():
        n_alz = sum(p == 1 for p in preds_list)
        part_ratios.append(float(n_alz) / len(preds_list))
        part_level_trues.append(participant_label[pid])

    best_threshold = None
    best_metric_val = -1.0

    for thr in thresholds:
        part_level_preds = [1 if ratio >= thr else 0 for ratio in part_ratios]

        f1 = f1_score(part_level_trues, part_level_preds, average='macro', zero_division=0)
        acc = accuracy_score(part_level_trues, part_level_preds)

        print(f"    [Threshold {thr}] -> F1={f1:.4f} | Acc={acc:.4f}")

        if f1 > best_metric_val:
            best_metric_val = f1
            best_threshold = thr

    print(f"    [Best Threshold] = {best_threshold} with F1={best_metric_val:.4f}")
    return best_threshold

# ═══════════════════════════════════════════════════════════════════════════
# MAIN PIPELINE
# ═══════════════════════════════════════════════════════════════════════════

def extract_features():
    """Compute and save PTE and DE features for every recording under DATA_DIR.

    Recordings are found with the glob ``<DATA_DIR>/sub-*/eeg/*.set`` and
    assigned to a group by subject ID: up to 36 -> ``alz``, 37 to 65 ->
    ``ctrl``, above 65 -> ``ftd``. Paths without a subject ID are skipped.
    One ``sub-<id>_PTE_<group>.npz`` and one ``sub-<id>_DE_<group>.npz`` file
    per recording is written to FEATURES_DIR.
    """
    banner = "=" * 80
    print(banner)
    print("FEATURE EXTRACTION")
    print(banner)

    # Get file paths
    all_paths = glob.glob(f"{DATA_DIR}/sub-*/eeg/*.set")
    print(f"Found {len(all_paths)} EEG files")

    groups = {'alz': [], 'ctrl': [], 'ftd': []}
    for fp in all_paths:
        sid = get_subject_id(fp)
        if sid is None:
            continue
        # Group membership is encoded in the subject ID ranges.
        if sid <= 36:
            group_name = 'alz'
        elif sid <= 65:
            group_name = 'ctrl'
        else:
            group_name = 'ftd'
        groups[group_name].append(fp)

    print(f"ALZ: {len(groups['alz'])}, CTRL: {len(groups['ctrl'])}, FTD: {len(groups['ftd'])}")

    # Extract PTE features
    print("\n--- Extracting PTE features ---")
    for grp in groups:
        for fp in groups[grp]:
            subj_id, dp, label = process_pte_subject(fp, grp)
            out_f = os.path.join(FEATURES_DIR, f"sub-{subj_id}_PTE_{grp}.npz")
            np.savez(out_f, pte_data=dp, subject_id=subj_id, label=label)
            print(f"  Saved {out_f}, shape={dp.shape}")

    # Extract DE features
    print("\n--- Extracting DE features ---")
    for grp in groups:
        for fp in groups[grp]:
            subj_id, de_vals, label = process_de_subject(fp, grp)
            out_f = os.path.join(FEATURES_DIR, f"sub-{subj_id}_DE_{grp}.npz")
            np.savez_compressed(out_f, DE_features=de_vals, subject_id=subj_id, label=label)
            print(f"  Saved {out_f}, shape={de_vals.shape}")

def run_experiment(task='cn_ad'):
    """Train and evaluate the model with repeated stratified k-fold CV.

    For every repetition and fold a fresh ``FinalModel`` is trained, a
    decision threshold is selected on the training loader, and the fold's
    validation loader is scored. The per-fold outputs are saved under
    ``RESULTS_DIR`` and summarised by ``compute_final_metrics``, which
    writes a task-specific ROC plot (``roc_curve_<task>.png``) so that
    running several tasks does not overwrite earlier plots.

    Args:
        task: ``'cn_ad'`` (classes ``ctrl``/``alz``, class-weighted loss) or
            ``'cn_ftd'`` (classes ``ctrl``/``ftd``, unweighted loss).

    Returns:
        dict with keys ``all_acc``, ``all_f1``, ``all_conf``,
        ``global_probs``, ``global_labels`` and ``best_thresholds``. Each
        value is a list with one entry per repetition, itself a list with
        one entry per fold.

    Raises:
        ValueError: If ``task`` is not one of the supported names.
    """
    print("\n" + "=" * 80)
    print(f"RUNNING EXPERIMENT: {task.upper()}")
    print("=" * 80)

    # Seed once per experiment, before data splitting and model creation.
    set_seed(0)

    # Configure task
    if task == 'cn_ad':
        selected_classes = ["ctrl", "alz"]
        # NOTE(review): weight order presumably follows selected_classes;
        # verify against the label encoding used by the data loader.
        class_weights = torch.tensor([1.0, 0.7])
        use_weights = True
    elif task == 'cn_ftd':
        selected_classes = ["ctrl", "ftd"]
        class_weights = None
        use_weights = False
    else:
        raise ValueError(f"Unknown task: {task}")

    # Model hyperparameters
    pte_input_dim = 36  # Spatial dimension after flattening (6*6)
    pte_hidden_dim = 512
    pte_num_layers = 2
    pte_num_heads = 4  # Must divide 36
    pte_output_dim = 128
    pte_dropout = 0.4

    psd_input_dim = 5  # Number of bands
    psd_hidden_dim = 512
    psd_num_layers = 2
    psd_num_heads = 5  # Must divide 5
    psd_output_dim = 128
    psd_dropout = 0.4

    cross_d_model = 128
    cross_num_heads = 8  # Must divide 128

    # Load data: one list of (train_loader, val_loader) pairs per repetition.
    print("\n--- Loading data ---")
    all_folds = load_data_stratified_kfold(
        pte_directory=FEATURES_DIR,
        DE_directory=FEATURES_DIR,  # Both in same directory now
        batch_size=BATCH_SIZE,
        selected_classes=selected_classes,
        selected_channels=SELECTED_CHANNELS,
        n_splits=N_SPLITS,
        n_repetitions=N_REPETITIONS,
    )

    # Results storage (outer list: repetitions, inner lists: folds)
    all_acc_final = []
    all_f1_final = []
    all_conf_final = []
    global_probs_final = []
    global_labels_final = []
    best_thresholds_final = []

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Run experiments
    for rep_idx, folds in enumerate(all_folds):
        print(f"\n{'=' * 80}")
        print(f"REPETITION {rep_idx + 1}/{len(all_folds)}")
        print(f"{'=' * 80}")

        all_acc = []
        all_f1 = []
        all_conf = []
        global_probs = []
        global_labels = []
        best_thresholds = []

        for fold_idx, (train_loader, val_loader) in enumerate(folds, 1):
            print(f"\n--- Fold {fold_idx}/{len(folds)} ---")

            # A new model per fold so no weights carry over between folds.
            model = FinalModel(
                pte_input_dim=pte_input_dim,
                pte_hidden_dim=pte_hidden_dim,
                pte_num_layers=pte_num_layers,
                pte_num_heads=pte_num_heads,
                pte_output_dim=pte_output_dim,
                pte_dropout=pte_dropout,
                psd_input_dim=psd_input_dim,
                psd_hidden_dim=psd_hidden_dim,
                psd_num_layers=psd_num_layers,
                psd_num_heads=psd_num_heads,
                psd_output_dim=psd_output_dim,
                psd_dropout=psd_dropout,
                cross_d_model=cross_d_model,
                cross_num_heads=cross_num_heads
            )
            model.to(device)

            # Loss and optimizer
            if use_weights:
                criterion_label = nn.CrossEntropyLoss(class_weights.to(device))
            else:
                criterion_label = nn.CrossEntropyLoss()

            optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

            # Train
            # NOTE(review): the fold's validation loader is passed to
            # train_model as target_dataloader; confirm it is used for
            # monitoring only and never for weight updates.
            print("Training...")
            label_acc_history = train_model(
                model=model,
                source_dataloader=train_loader,
                target_dataloader=val_loader,
                criterion_label=criterion_label,
                optimizer=optimizer,
                num_epochs=NUM_EPOCHS,
                device=device,
                scheduler=None,
            )
            print(f"  Final training accuracy: {label_acc_history[-1]:.2f}%")

            # Threshold tuning: candidates are scored on the training loader,
            # so the validation data plays no part in choosing the threshold.
            print("  Tuning threshold...")
            thresholds_to_try = [0.2, 0.3, 0.4, 0.5]
            best_thr = tune_threshold_on_source(
                model=model,
                source_dataloader=train_loader,
                device=device,
                thresholds=thresholds_to_try,
                num_classes=2
            )
            best_thresholds.append(best_thr)

            # Test on the held-out fold using the tuned threshold.
            print("  Testing...")
            test_loss, test_acc, test_f1, conf_mat, preds, labels, _, _ = test_model(
                model=model,
                test_dataloader=val_loader,
                criterion_label=criterion_label,
                device=device,
                num_classes=2,
                alz_threshold=best_thr
            )

            print(f"  Validation loss: {test_loss:.4f}")
            print(f"  Validation accuracy: {test_acc:.2f}%")
            print(f"  Validation F1: {test_f1:.4f}")

            # Store results. `preds` is kept as this fold's class scores;
            # compute_final_metrics later reads column 1 of it for the AUC.
            all_acc.append(test_acc)
            all_f1.append(test_f1)
            all_conf.append(conf_mat)
            global_probs.append(preds)
            global_labels.append(labels)

        # Repetition results
        all_acc_final.append(all_acc)
        all_f1_final.append(all_f1)
        all_conf_final.append(all_conf)
        global_probs_final.append(global_probs)
        global_labels_final.append(global_labels)
        best_thresholds_final.append(best_thresholds)

        print(f"\n  Repetition {rep_idx + 1} Results:")
        print(f"  Mean accuracy: {np.mean(all_acc):.2f}% ± {np.std(all_acc):.2f}%")
        print(f"  Mean F1: {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")

    # Save results
    final_results = {
        "all_acc": all_acc_final,
        "all_f1": all_f1_final,
        "all_conf": all_conf_final,
        "global_probs": global_probs_final,
        "global_labels": global_labels_final,
        "best_thresholds": best_thresholds_final
    }

    # The dict is stored as a pickled object array, so reading it back needs
    # np.load(results_file, allow_pickle=True)["final_results"].item().
    results_file = os.path.join(RESULTS_DIR, f"final_results_{task}_dtca.npz")
    np.savez(results_file, final_results=final_results)
    print(f"\n✓ Saved results to {results_file}")

    # Compute final metrics
    print("\n" + "=" * 80)
    print("FINAL RESULTS")
    print("=" * 80)

    # Task-specific plot name: a shared file name would let a later task
    # silently overwrite the ROC curve of an earlier one.
    compute_final_metrics(final_results, roc_filename=f"roc_curve_{task}.png")

    return final_results


def compute_final_metrics(final_results, roc_filename='roc_curve.png'):
    """Print aggregate fold metrics and the pooled AUC, and save a ROC plot.

    Accuracy, precision, recall and macro-F1 are recomputed from every 2x2
    confusion matrix in ``final_results["all_conf"]`` and reported as
    mean ± std over all folds of all repetitions; matrices of any other
    shape are skipped. The AUC and ROC curve use column 1 of the stacked
    ``global_probs`` against the concatenated ``global_labels``.

    Args:
        final_results: Dict as built by ``run_experiment``; the keys
            ``all_conf``, ``global_probs`` and ``global_labels`` are read.
        roc_filename: File name, inside ``RESULTS_DIR``, for the ROC plot.
            Defaults to ``'roc_curve.png'``; pass a distinct name per
            experiment so successive calls do not overwrite each other.

    Returns:
        None. If there are no fold predictions, or the pooled labels hold
        fewer than two classes, a message is printed and the AUC/ROC step
        is skipped because the curve is undefined in that situation.
    """
    acc_scores, precision_scores, recall_scores, f1_scores = [], [], [], []

    for run_cms in final_results["all_conf"]:
        for cm in run_cms:
            cm = np.asarray(cm)
            if cm.shape != (2, 2):
                continue

            # Row-major unpacking; assumes rows are true labels and columns
            # are predictions with class 0 first — TODO confirm against the
            # code that builds the confusion matrix.
            tn, fp, fn, tp = cm.ravel()

            # Rebuild label vectors that reproduce exactly this matrix.
            y_true = np.array([0] * (tn + fp) + [1] * (fn + tp))
            y_pred = np.array([0] * tn + [1] * fp + [0] * fn + [1] * tp)

            # Precision and recall use scikit-learn's default binary
            # averaging (positive class 1); F1 is macro-averaged.
            acc_scores.append(accuracy_score(y_true, y_pred))
            precision_scores.append(precision_score(y_true, y_pred, zero_division=0))
            recall_scores.append(recall_score(y_true, y_pred, zero_division=0))
            f1_scores.append(f1_score(y_true, y_pred, average="macro", zero_division=0))

    if acc_scores:
        metrics = {
            "Accuracy": (np.mean(acc_scores), np.std(acc_scores)),
            "Precision": (np.mean(precision_scores), np.std(precision_scores)),
            "Recall": (np.mean(recall_scores), np.std(recall_scores)),
            "F1-score": (np.mean(f1_scores), np.std(f1_scores)),
        }

        for name, (mean, std) in metrics.items():
            print(f"{name:12s}: {mean:.4f} ± {std:.4f}")
    else:
        # Averaging empty lists would only print NaN and emit a warning.
        print("No 2x2 confusion matrices available; fold-level metrics skipped.")

    # Flatten repetitions -> folds into one list of per-fold arrays.
    fold_probs = [fold for rep in final_results["global_probs"] for fold in rep]
    fold_labels = [fold for rep in final_results["global_labels"] for fold in rep]
    if not fold_probs or not fold_labels:
        print("\nNo fold predictions available; skipping AUC and ROC curve.")
        return

    global_probs = np.vstack(fold_probs)
    global_labels = np.hstack(fold_labels)

    # AUC is undefined unless both classes occur in the pooled labels.
    if np.unique(global_labels).size < 2:
        print("\nOnly one class present in the pooled labels; skipping AUC and ROC curve.")
        return

    # Column 1 is taken as the score of the positive class.
    positive_scores = global_probs[:, 1]
    global_auc = roc_auc_score(global_labels, positive_scores)
    print(f"\nGlobal AUC: {global_auc:.4f}")

    # Plot ROC curve on an explicit figure so no global pyplot state leaks.
    fpr, tpr, _ = roc_curve(global_labels, positive_scores)
    roc_path = os.path.join(RESULTS_DIR, roc_filename)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f"AUC = {global_auc:.4f}")
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve')
    ax.legend(loc="lower right")
    ax.grid(alpha=0.3)
    fig.savefig(roc_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"✓ Saved ROC curve to {roc_path}")

# ═══════════════════════════════════════════════════════════════════════════
# MAIN EXECUTION
# ═══════════════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    # Opening banner.
    print(
        "=" * 80,
        "DTCA-NET: DUAL-TRANSFORMER CROSS ATTENTION NETWORK",
        "EEG-based Alzheimer's and Frontotemporal Dementia Detection",
        "=" * 80,
        sep="\n",
    )

    # Stage 1: build the feature files (comment out if they already exist).
    print("\n[1/3] Extracting features from raw EEG data...")
    extract_features()

    # Stage 2: controls vs. AD.
    print("\n[2/3] Running CN vs AD experiment...")
    results_cn_ad = run_experiment(task='cn_ad')

    # Stage 3: controls vs. FTD.
    print("\n[3/3] Running CN vs FTD experiment...")
    results_cn_ftd = run_experiment(task='cn_ftd')

    # Closing banner with the output location.
    print(
        "\n" + "=" * 80,
        "EXPERIMENT COMPLETED SUCCESSFULLY",
        "=" * 80,
        f"Results saved in: {RESULTS_DIR}",
        "=" * 80,
        sep="\n",
    )


In [None]:
# Final notebook cell: bundle everything under the Kaggle working directory
# into a single downloadable zip.
import os
import shutil
import tempfile

output_dir = '/kaggle/working'
zip_filename = '/kaggle/working/all_results'
final_archive = f"{zip_filename}.zip"

# Remove an archive left by an earlier run so it is not nested in the new one.
if os.path.exists(final_archive):
    os.remove(final_archive)

# Build the archive outside output_dir: make_archive walks the source tree
# while it writes, so a target inside that tree would be swept into itself.
with tempfile.TemporaryDirectory() as staging_dir:
    staged_archive = shutil.make_archive(
        os.path.join(staging_dir, os.path.basename(zip_filename)),
        'zip',
        output_dir,
    )
    shutil.move(staged_archive, final_archive)

print(f"Created: {zip_filename}.zip")
