In [1]:
import os
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from sklearn.model_selection import GroupKFold

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset, Subset, random_split
from torch.utils.data import ConcatDataset

import pywt
from scipy.signal import welch
from scipy.stats import chi2_contingency, ttest_ind

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans, AgglomerativeClustering, SpectralClustering, DBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.metrics import (
    r2_score,
    accuracy_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    silhouette_score,
)

from sklearn.svm import SVC
from xgboost import XGBClassifier
from catboost import CatBoostClassifier

import shap
from imblearn.over_sampling import SMOTE
from interpret.glassbox import ExplainableBoostingClassifier
from interpret import show

from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import Axes3D

In [2]:
class EEGUtil:
    @staticmethod
    def load_data(file_path, column_names=['spike hz']):
        """Load multi-channel EEG data from a CSV file and normalize it.

        Parameters:
        - file_path: Path to the CSV file.
        - column_names: List of EEG signal column names to load.

        Returns:
        - signal: 2D numpy array (channels × time), normalized.
        """
        data = pd.read_csv(file_path)

        available_columns = [col for col in column_names if col in data.columns]
        if not available_columns:
            raise ValueError(f"None of the specified columns {column_names} exist in {file_path}")

        signal = data[available_columns].values.T  # (channels, time)

        mean = np.mean(signal, axis=1, keepdims=True)
        std = np.std(signal, axis=1, keepdims=True) 
        std[std == 0] = 1e-8  
        
        return signal

    @staticmethod
    def padding(signal, target_length, mode="constant", constant_value=0):
        """Pad or truncate multi-channel signal to the target length."""
        signal = np.array(signal)  # Ensure it's a NumPy array
        channels, current_length = signal.shape
        
        if current_length >= target_length:
            return signal[:, :target_length]  # Truncate

        padding_size = target_length - current_length
        if mode == "constant":
            pad_values = np.full((channels, padding_size), constant_value)
        elif mode == "reflect":
            pad_values = np.pad(signal, ((0, 0), (0, padding_size)), mode='reflect')[:, -padding_size:]
        elif mode == "cyclic":
            pad_values = np.pad(signal, ((0, 0), (0, padding_size)), mode='wrap')[:, -padding_size:]
        elif mode == "edge":
            pad_values = np.pad(signal, ((0, 0), (0, padding_size)), mode='edge')[:, -padding_size:]
        elif mode == "random":
            pad_values = np.random.uniform(low=np.min(signal), high=np.max(signal), size=(channels, padding_size))
        else:
            raise ValueError(f"Unsupported padding mode: {mode}")

        return np.hstack((signal, pad_values))  # Concatenate along time axis
    

class EEGDataset(Dataset):
    def __init__(self, data_folder, outcome_file=None, ssd_file=None,start_time=16,target_length=600, strategy='padding', 
                 padding_mode="constant", use_labels=False, augment=False,num_good=1,num_bad=1,column_names=['spike hz']):
        
        """
        EEG signal datasets with support for data enhancement (random fill). 

        Parameters: 
        - data_folder: path of the EEG data folder 
        - outcome_file: CSV file with patient ID and outcome (optional) 
        - target_length: indicates the padding length of the target 
        - strategy: wavelet ('padding', 'wavelet', 'psd') 
        - padding_mode: indicates the fill mode ('constant', 'reflect',...). 
        - use_labels: indicates whether to use labels 
        - augment: Whether data enhancement is enabled (randomly fill different lengths) 
        """

        self.use_labels = use_labels
        self.data_folder = data_folder
        self.target_length = target_length
        self.strategy = strategy
        self.padding_mode = padding_mode
        self.augment = augment 
        self.column_names = column_names #Feature Chosen
        
        self.num_good = num_good
        self.num_bad = num_bad
        
        self.start_time=start_time
        self.end_time=int((target_length/12)+self.start_time)

        self.file_list = [f for f in os.listdir(data_folder) if f.endswith('.csv')]
       
        # Read 'rosc sec' start time
        ssd_df = pd.read_csv(ssd_file)
        # Get `pat_ID`
        ssd_df["pat_ID"] = ssd_df["fn"].str.extract(r"(ICARE_\d+)")
        #  Take the smallest 'rosc sec' of each 'pat ID' 
        self.rosc_dict = ssd_df.groupby("pat_ID")["rosc_sec"].min().to_dict()
        
        # LOAD Labels
        self.outcome_dict = {}
        if use_labels and outcome_file:
            self.outcome_data = pd.read_csv(outcome_file)
            self.outcome_dict = self.outcome_data.set_index('pat_ID')['outcome'].to_dict()
            self.file_list = [f for f in self.file_list if f.split('.')[0] in self.outcome_dict]
        
        self.valid_files = []
        # Filter EEG data conforming to 16h-68h rules
        for f in self.file_list:
            pat_id = f.split('.')[0]

            if pat_id in self.rosc_dict:
                file_path = os.path.join(self.data_folder, f)
                signal = EEGUtil.load_data(file_path, column_names=self.column_names)  #Load Multiple Channel
                rosc_sec = float(self.rosc_dict[pat_id])  # Ensure `rosc_sec` is a float

                # **EEG recording time range**
                start_time = rosc_sec
                end_time = start_time + signal.shape[1] * 300  # Each point represents 5 minutes (=300 seconds)
                # **Skip if the data is completely outside the 16h-68h observation window**
                if end_time < self.start_time * 3600 or start_time > self.end_time * 3600:
                    #print(f"❌ Skipping {pat_id}: EEG data is out of 16h-68h range ({start_time/3600:.1f}h - {end_time/3600:.1f}h)")
                    continue  

                # **Align to the 16h-68h window**
                aligned_signal = self.align_signal(signal, rosc_sec)
                self.valid_files.append((f, aligned_signal))

        print(f"✅ Loaded {len(self.valid_files)} valid EEG files (filtered from {len(self.file_list)} total)")

        # **Count Good/Bad Outcome samples**
        self.good_outcome_count = sum(1 for f, _ in self.valid_files if self.get_label(f.split('.')[0]) == 1)
        self.bad_outcome_count = len(self.valid_files) - self.good_outcome_count

        print(f"Good Outcome: {self.good_outcome_count}, Bad Outcome: {self.bad_outcome_count}")

        # **Data Augmentation: Expanding indices**
        self.expanded_indices = []
        for idx, (filename, signal) in enumerate(self.valid_files):
            patient_id = filename.split('.')[0]
            label = self.get_label(patient_id) if self.use_labels else -1

            if self.augment:
                if self.use_labels:
                    # Good Outcome ×10, Bad Outcome ×2
                    if label == 1:
                        repeat_times = self.num_good
                    else:
                        repeat_times = self.num_bad
                else:
                    repeat_times = 1  # Data augmentation for unlabeled data
            else:
                repeat_times = 1  

            for _ in range(repeat_times):
                self.expanded_indices.append((idx, label))  # ✅ Store index & label
    
    def __len__(self):
        # print(f"📏 Dataset __len__: {len(self.expanded_indices)}")  # Ensure `expanded_indices` length is correct
        return len(self.expanded_indices)  # ✅ Must return the number of samples after data augmentation
    
    def __getitem__(self, idx):
        original_idx, label = self.expanded_indices[idx]
        filename, signal = self.valid_files[original_idx]  # Directly retrieve the **aligned** signal
        patient_id = filename.split('.')[0]

        # Get label
        label = -1
        if self.use_labels:
            label = self.get_label(patient_id)

        # Perform data augmentation (varies each time)
        if self.augment:
            augmented_signal = self.augment_signal(signal)  # ✅ Apply augmentation directly to the **aligned signal**
        else:
            augmented_signal = signal  # ✅ Use the aligned signal directly

        return torch.tensor(augmented_signal, dtype=torch.float32), label
    
    def align_signal(self, signal, rosc_sec):
        """ Align EEG data to the 16h-68h observation period """

        target_length = self.target_length  # Number of `5min` windows for 52 hours (624)
        total_signal_length = signal.shape[1]  # Total length of the EEG recording
        
        rosc_sec = float(rosc_sec)  # ✅ Ensure `rosc_sec` is a float
        # print(f"🔍 Processing patient data: rosc_sec={rosc_sec}, total_signal_length={total_signal_length}")

        # **Calculate the starting position of '16h' in the EEG recording**
        start_sec = (self.start_time * 3600) - rosc_sec  
        if start_sec < 0:
            pad_size = abs(start_sec) / 300  # Calculate the number of windows to pad
            start_index = 0  # Start extracting data from the beginning of the EEG recording
        else:
            pad_size = 0  # No padding needed
            start_index = int(start_sec // 300)  # ✅ Convert to integer

        # **Calculate the endpoint index for '68h'**
        end_index = int(min(start_index + target_length, total_signal_length))  # ✅ Convert to integer

        # **Extract EEG data for the 16h-68h observation period**
        aligned_signal = signal[:, start_index:end_index]

        # **Pre-padding (if `rosc_sec > 16h`)**
        if pad_size > 0:
            aligned_signal = EEGDataset.pad_signal(aligned_signal, target_length, self.padding_mode, padding_position="pre")

        # **Post-padding (if data is less than 52 hours)**
        aligned_signal = EEGDataset.pad_signal(aligned_signal, target_length, self.padding_mode, padding_position="post")

        # print(f"✅ Aligned signal length: {len(aligned_signal)}")
        return aligned_signal
    
    def pad_signal(signal, target_length, mode="constant", constant_value=0, padding_position="post"):
        """ Pad EEG signal to ensure it reaches `target_length`.

        Parameters:
        - signal: Original EEG signal (numpy array)
        - target_length: Target length (52h = 624 `5min` windows)
        - mode: Padding mode:
            - `constant`: Fill with a fixed value (`constant_value`)
            - `reflect`: Mirror padding
            - `cyclic`: Cyclic padding
            - `edge`: Edge padding
            - `random`: Fill with random values between [min, max]
        - padding_position: `"pre"` (pad at the beginning) or `"post"` (pad at the end)

        Returns:
        - Padded EEG signal (numpy array)
        """

        channels, current_length = signal.shape
        
        if current_length >= target_length:
            return signal[:, :target_length]  # Truncate if already long enough

        padding_size = target_length - current_length  # Number of elements to pad

        if mode == "constant":
            pad_values = np.full((channels, padding_size), constant_value)  # Make sure padding shape matches
        elif mode == "reflect":
            pad_values = np.pad(signal, ((0, 0), (0, padding_size)), mode='reflect')[:, -padding_size:]
        elif mode == "cyclic":
            pad_values = np.pad(signal, ((0, 0), (0, padding_size)), mode='wrap')[:, -padding_size:]
        elif mode == "edge":
            pad_values = np.pad(signal, ((0, 0), (0, padding_size)), mode='edge')[:, -padding_size:]
        elif mode == "random":
            pad_values = np.random.uniform(low=np.min(signal), high=np.max(signal), size=(channels, padding_size))
        else:
            raise ValueError(f"Unsupported padding mode: {mode}")

        # Ensure proper concatenation along the time axis
        if padding_position == "pre":
            padded_signal = np.hstack((pad_values, signal))  # Pad at the beginning
        else:
            padded_signal = np.hstack((signal, pad_values))  # Pad at the end

        return padded_signal[:, :target_length]  # Ensure exact target length
    
    def augment_signal(self, signal):
        """ Data augmentation: Shift EEG data within the 16h-68h observation period """

        target_length = self.target_length  # Number of `5min` windows for 52 hours (624)
        channels, current_length = signal.shape  # Current EEG recording length

        # **Check if already aligned to 16h-68h before augmentation**
        if current_length != target_length:
            raise ValueError(f"Before augmentation, signal length should be {target_length}, but received {current_length}")

        # **Augmentation Strategy 1: Random time shift within ±60min**
        max_shift = 12  # `5min` windows, 60 minutes = 12 data points
        shift = np.random.randint(-max_shift, max_shift + 1)  # Random shift in [-12, 12]

        # **Compute new starting index and ensure it remains within bounds**
        start_index = max(0, min(current_length - target_length, shift))
        end_index = min(start_index + target_length, current_length)

        # **Extract EEG data after shifting**
        augmented_signal = signal[:, start_index:end_index]

        # **Augmentation Strategy 2: Apply padding based on `shift` direction**
        if augmented_signal.shape[1] < target_length:
            padding_size = target_length - augmented_signal.shape[1]

            if shift > 0:
                pad_values = self.pad_signal(np.zeros((channels, padding_size)), target_length, self.padding_mode)
                augmented_signal = np.hstack((pad_values, augmented_signal))  # **确保形状正确**
            elif shift < 0:
                pad_values = self.pad_signal(np.zeros((channels, padding_size)), target_length, self.padding_mode)
                augmented_signal = np.hstack((augmented_signal, pad_values))
            else:
                augmented_signal = self.pad_signal(augmented_signal, target_length, self.padding_mode)

        return augmented_signal

    def get_label(self, patient_id):
        """ Get sample label (1 = Good Outcome, 0 = Bad Outcome) """
        return 1 if self.outcome_dict.get(patient_id, 'Bad Outcome') == 'Good Outcome' else 0

    def compare_data_augmentation(self):
        """ Compare the number of samples before and after data augmentation. """
        original_count = len(self.valid_files)  # Count only files that meet the 16h condition
        augmented_count = len(self.expanded_indices)  # Count the number of augmented samples

        print(f"Data count before augmentation: {original_count}")
        print(f"Data count after augmentation: {augmented_count}")
        print(f"Augmentation ratio: {augmented_count / original_count:.2f}x")

        if self.use_labels:
            # Count Good Outcome and Bad Outcome samples in original data (filtered_files)
            good_outcome_original = sum(1 for f, _ in self.valid_files if self.get_label(f.split('.')[0]) == 1)
            bad_outcome_original = original_count - good_outcome_original  # Remaining are Bad Outcome samples

            # Count Good Outcome and Bad Outcome samples after augmentation
            good_outcome_augmented = sum(
                1 for (idx, _) in self.expanded_indices  # ✅ Use only idx, ignore label
                if self.get_label(self.valid_files[idx][0].split('.')[0]) == 1
            )
            bad_outcome_augmented = augmented_count - good_outcome_augmented  # Remaining are Bad Outcome samples

            print(f"Good Outcome before augmentation: {good_outcome_original}, after augmentation: {good_outcome_augmented}")
            print(f"Bad Outcome before augmentation: {bad_outcome_original}, after augmentation: {bad_outcome_augmented}")

        return original_count, augmented_count
    
    
# AugmentedEEGDataset 
class AugmentedEEGDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, eeg_dataset_instance, augment=True, num_good=10, num_bad=2):
        self.base_dataset = base_dataset
        self.eeg_dataset_instance = eeg_dataset_instance 
        self.augment = augment
        self.num_good = num_good
        self.num_bad = num_bad
        self.expanded_data = []

        for i in range(len(base_dataset)):
            signal, label = base_dataset[i]
            repeat = num_good if label == 1 else num_bad

            for _ in range(repeat):
                signal_np = signal.numpy()
                if signal_np.ndim == 2:
                    augmented_signal = self.eeg_dataset_instance.augment_signal(signal_np)
                else:
                    augmented_signal = self.eeg_dataset_instance.augment_signal(signal_np[np.newaxis, :])  # **单通道兼容**

                self.expanded_data.append((augmented_signal, label))

    def __len__(self):
        return len(self.expanded_data)

    def __getitem__(self, idx):
        signal, label = self.expanded_data[idx]
        return torch.tensor(signal, dtype=torch.float32), label

def count_labels(dataset):
    labels = [dataset[i][1] for i in range(len(dataset))] 
    unique, counts = np.unique(labels, return_counts=True)
    return dict(zip(unique, counts))

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameter settings
latent_dim = 10  # Dimension of the latent space
initial_channels = 128  # Initial number of channels for CNN
start_time=16 #Set initial hour
seq_length = (68 - start_time) * 12  # Number of `5min` windows for 52 hours (624)
batch_size = 32
epochs = 100
alpha = 1e-12  # Adjust KL divergence weight
patience = 10  # Define the number of epochs without improvement before stopping training
padding_method = 'constant'  # Use constant padding
is_augment = False  # Do not use
is_VAEtrain_aug=False # Do not use
num_good_train=1
num_bad_train=1
num_good_test=1
num_bad_test=1
column_names=['ssd','BCI','avgspectent', 'lv_l5']
num_eeg_channels=len(column_names)

# Define dataset folder paths
data_folder = '5min_smoothed_data/'  # Replace with the actual data folder path
valid_outcome_data = 'valid_patients_outcome.csv'
ssd_file = 'files_art_ssd_fts_predictions.csv'

# Create EEG dataset (labels are not used when training VAE)
eeg_dataset = EEGDataset(
    data_folder='5min_smoothed_data/',
    outcome_file='valid_patients_outcome.csv',
    ssd_file=ssd_file,
    start_time=start_time,
    target_length=seq_length,
    strategy='padding',
    padding_mode=padding_method,
    use_labels=True,  # Use labels to determine class-based augmentation
    augment=is_augment,  
    num_good=1,
    num_bad=1,
    column_names=column_names
)

In [6]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import GroupKFold

# —— 0. 设备设置 —— 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# —— 1. 超参数与数据集初始化 —— 
start_time   = 16
seq_length   = (68 - start_time) * 12  # 624 个 5min 窗口
column_names = ['ssd', 'BCI', 'avgspectent', 'lv_l5']
batch_size   = 32
epochs       = 100
n_splits     = 5
learning_rate= 1e-3

# 创建 EEGDataset
eeg_dataset = EEGDataset(
    data_folder='5min_smoothed_data/',
    outcome_file='valid_patients_outcome.csv',
    ssd_file='files_art_ssd_fts_predictions.csv',
    start_time=start_time,
    target_length=seq_length,
    strategy='padding',
    padding_mode='constant',
    use_labels=True,
    augment=False,
    num_good=1,
    num_bad=1,
    column_names=column_names
)

# —— 2. 从 Dataset 中提取 X, y, groups —— 
signals, labels, groups = [], [], []
for idx, label in eeg_dataset.expanded_indices:
    filename, signal = eeg_dataset.valid_files[idx]
    signals.append(signal.T)               # (time_steps, features)
    labels.append(label)                   
    groups.append(filename.split('.')[0])  # pat_ID

X = np.stack(signals)    # shape: (n_samples, time_steps, features)
y = np.array(labels)     # shape: (n_samples,)

# —— 3. 定义 PyTorch LSTM 模型 —— 
input_size = X.shape[2]
hidden1    = 64
hidden2    = 12

class LSTMClassifier(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size, hidden1, batch_first=True)
        self.drop1 = nn.Dropout(0.2)
        self.lstm2 = nn.LSTM(hidden1, hidden2, batch_first=True)
        self.drop2 = nn.Dropout(0.2)
        self.fc    = nn.Linear(hidden2, 1)
        # 注意：不在这里加 Sigmoid

    def forward(self, x):
        # x: (batch, time_steps, features)
        out, _ = self.lstm1(x)
        out     = self.drop1(out)
        out, _  = self.lstm2(out)
        out     = self.drop2(out)
        out     = out[:, -1, :]       # 取最后一个时间步，shape (batch, hidden2)
        logits  = self.fc(out)        # shape (batch, 1)
        return logits.squeeze(1)      # shape (batch,)

# —— 4. 分组交叉验证训练和评估 —— 
kf       = GroupKFold(n_splits=n_splits)
loss_fn  = nn.BCEWithLogitsLoss()
acc_list = []

for fold, (train_idx, test_idx) in enumerate(kf.split(X, y, groups=groups), 1):
    # 构造训练/测试张量
    X_train = torch.tensor(X[train_idx], dtype=torch.float32, device=device)
    y_train = torch.tensor(y[train_idx], dtype=torch.float32, device=device)
    X_test  = torch.tensor(X[test_idx],  dtype=torch.float32, device=device)
    y_test  = torch.tensor(y[test_idx],  dtype=torch.float32, device=device)

    # 初始化模型、优化器
    model = LSTMClassifier(input_size).to(device)
    opt   = optim.Adam(model.parameters(), lr=learning_rate)

    # 训练
    model.train()
    for epoch in range(1, epochs+1):
        opt.zero_grad()
        logits = model(X_train)         # shape (n_train,)
        loss   = loss_fn(logits, y_train)
        loss.backward()
        opt.step()

    # 测试
    model.eval()
    with torch.no_grad():
        logits    = model(X_test)
        preds     = (torch.sigmoid(logits) > 0.5).float()
        accuracy  = (preds == y_test).float().mean().item()
    acc_list.append(accuracy * 100)
    print(f"Fold {fold}: Accuracy = {accuracy*100:.2f}%")

# —— 5. 输出整体结果 —— 
mean_acc = np.mean(acc_list)
std_acc  = np.std(acc_list)
print(f"Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")

✅ Loaded 244 valid EEG files (filtered from 244 total)
Good Outcome: 44, Bad Outcome: 200
Fold 1: Accuracy = 77.55%
Fold 2: Accuracy = 79.59%
Fold 3: Accuracy = 79.59%
Fold 4: Accuracy = 87.76%
Fold 5: Accuracy = 85.42%
Mean Accuracy: 81.98% ± 3.90%


In [15]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, roc_auc_score
from torch.utils.data import TensorDataset, DataLoader
from collections import Counter

# —— 0. 设备设置 —— 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# —— 1. 超参数与数据集初始化 —— 
start_time    = 16
seq_length    = (68 - start_time) * 12   # 624 个 5min 窗口
column_names  = ['ssd', 'BCI', 'avgspectent', 'lv_l5']
batch_size    = 8    # 对应原代码 batch_size=8
epochs        = 1    # 对应原 model.fit 默认的 1 个 epoch
n_splits      = 5
learning_rate = 1e-3

# 创建 EEGDataset
eeg_dataset = EEGDataset(
    data_folder='5min_smoothed_data/',
    outcome_file='valid_patients_outcome.csv',
    ssd_file='files_art_ssd_fts_predictions.csv',
    start_time=start_time,
    target_length=seq_length,
    strategy='padding',
    padding_mode='constant',
    use_labels=True,
    augment=False,
    num_good=1,
    num_bad=1,
    column_names=column_names
)

# —— 2. 从 Dataset 中提取 X, y, groups —— 
signals, labels, groups = [], [], []
for idx, label in eeg_dataset.expanded_indices:
    fname, signal = eeg_dataset.valid_files[idx]
    signals.append(signal.T)               # (time_steps, features)
    labels.append(label)
    groups.append(fname.split('.')[0])     # pat_ID

X = np.stack(signals)    # shape: (n_samples, time_steps, features)
y = np.array(labels)     # shape: (n_samples,)

# 标签沿时间步复制，匹配模型输出维度 (n_samples, seq_length)
y_seq = np.tile(y[:, None], (1, seq_length))

# —— 3. 定义 PyTorch LSTM 模型 —— 
input_size = X.shape[2]
hidden1, hidden2 = 64, 12

class LSTMClassifier(nn.Module):
    def __init__(self, input_size, out_steps):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size, hidden1, batch_first=True)
        self.drop1 = nn.Dropout(0.2)
        self.lstm2 = nn.LSTM(hidden1, hidden2, batch_first=True)
        self.drop2 = nn.Dropout(0.2)
        self.fc    = nn.Linear(hidden2, out_steps)
    
    def forward(self, x):
        out, _ = self.lstm1(x)
        out     = self.drop1(out)
        out, _  = self.lstm2(out)
        out     = self.drop2(out)
        out     = out[:, -1, :]
        logits  = self.fc(out)      # (batch, seq_length)
        return logits               # 不再 sigmoid   # (batch, seq_length)

# —— 4. 分组交叉验证训练和评估 —— 
kf            = GroupKFold(n_splits=n_splits)
criterion = nn.BCEWithLogitsLoss()
cvscores      = []
auc_scores    = []
prediction_arr = np.array([])

print("Original class balance:", Counter(y))

for fold, (train_idx, test_idx) in enumerate(kf.split(X, y, groups=groups), 1):
    print(f"----- Fold {fold} -----")
    y_train, y_test = y[train_idx], y[test_idx]
    print("Train balance:", Counter(y_train), "Test balance:", Counter(y_test))
    
    # 准备训练/测试数据
    X_tr = torch.tensor(X[train_idx], dtype=torch.float32, device=device)
    y_tr = torch.tensor(y_seq[train_idx], dtype=torch.float32, device=device)
    X_te = torch.tensor(X[test_idx],  dtype=torch.float32, device=device)
    y_te = torch.tensor(y_seq[test_idx], dtype=torch.float32, device=device)
    
    train_ds     = TensorDataset(X_tr, y_tr)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    
    # 模型、优化器初始化
    model     = LSTMClassifier(input_size, seq_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # 训练（1 个 epoch，对应原 .fit 的默认行为）
    model.train()
    for xb, yb in train_loader:
        optimizer.zero_grad()
        outs = model(xb)
        loss = criterion(outs, yb)
        loss.backward()
        optimizer.step()
    
    # 测试及评估
    # —— 在测试及评估里替换 AUC 计算 —— 
    model.eval()
    with torch.no_grad():
        logits = model(X_te)                                 # (n_test, seq_length)
        probs  = torch.sigmoid(logits).cpu().numpy()         # (n_test, seq_length)

    # 1) 窗口级 Accuracy（保持原样）
    acc = accuracy_score(
        y_seq[test_idx].flatten(),
        (probs > 0.5).astype(int).flatten()
    )
    cvscores.append(acc * 100)

    # 2) 患者级 AUC
   # a) 按时间步平均每个病人的预测概率，NaN 替为 0.5
    preds_patient = np.nan_to_num(preds_patient, nan=0.5)
    labels_patient = y[test_idx]                    # (n_test,)

    # c) 计算 AUC
    auc = roc_auc_score(labels_patient, preds_patient)
    auc_scores.append(auc * 100)

    print(f"Fold {fold}: Accuracy = {acc*100:.2f}%, AUC = {auc*100:.2f}%")
# —— 5. 输出整体结果 —— 
print(f"Mean Accuracy: {np.mean(cvscores):.2f}% ± {np.std(cvscores):.2f}%")
print(f"Mean AUC:      {np.mean(auc_scores):.2f}% ± {np.std(auc_scores):.2f}%")

✅ Loaded 244 valid EEG files (filtered from 244 total)
Good Outcome: 44, Bad Outcome: 200
Original class balance: Counter({0: 200, 1: 44})
----- Fold 1 -----
Train balance: Counter({0: 162, 1: 33}) Test balance: Counter({0: 38, 1: 11})


ValueError: Found array with 0 sample(s) (shape=(0,)) while a minimum of 1 is required.