In [None]:
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import RandomOverSampler
from scipy.signal import resample
import torch
from collections import Counter
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from imblearn.under_sampling import RandomUnderSampler
from tqdm import tqdm
import warnings

# Silence library warnings globally so the training logs stay readable.
# NOTE(review): this also hides deprecation and numerical warnings — consider a narrower filter.
warnings.filterwarnings("ignore")

# Seed every RNG this notebook relies on. NumPy alone is not enough: torch
# drives weight initialisation, dropout masks and DataLoader shuffling.
np.random.seed(42)
torch.manual_seed(42)

# === Load EEG Data ===
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load pickles that you produced yourself or otherwise trust.
with open("/content/drive/MyDrive/output_eeg_data/processed_eeg_data.pkl", "rb") as f:
    eeg_data_dict = pickle.load(f)
# The downstream cells iterate over this object one subject at a time.
eeg_data = eeg_data_dict["data"]

In [None]:
# === Signal Utilities ===
def clean_signals(X):
    """Sanitise non-finite values and standardise each row of ``X``.

    ``np.nan_to_num`` is applied first (with its defaults, NaN becomes 0 and
    infinities become large finite numbers). Every row is then shifted by its
    mean and divided by its standard deviation, both taken along axis 1.

    Args:
        X: Array-like with at least two dimensions; axis 1 is the axis that
            gets normalised.

    Returns:
        Array of the same shape as ``X`` holding the standardised values.
    """
    finite = np.nan_to_num(X)
    centered = finite - np.mean(finite, axis=1, keepdims=True)
    # The small epsilon keeps constant rows (zero std) from dividing by zero.
    scale = np.std(finite, axis=1, keepdims=True) + 1e-8
    return centered / scale

def downsample_signals(X, target_len=2000):
    """Resample every signal in ``X`` to ``target_len`` samples.

    Each element yielded by iterating over ``X`` is passed to
    ``scipy.signal.resample`` on its own, so signals are processed one at a
    time.

    Args:
        X: Iterable of signals (e.g. a 2-D array with one signal per row).
        target_len: Number of samples each output signal should have.

    Returns:
        ``np.ndarray`` built from the list of resampled signals.
    """
    resampled = []
    for sig in X:
        resampled.append(resample(sig, target_len))
    return np.array(resampled)

# === Custom Dataset ===
class ECGDataset(Dataset):
    """In-memory dataset pairing float32 signals with integer class labels.

    Attributes:
        X: ``float32`` tensor built from the input signals with an extra axis
            inserted at position 1.
        y: ``long`` tensor of labels.
    """

    def __init__(self, X, y):
        signals = torch.tensor(X, dtype=torch.float32)
        # Insert a singleton axis at position 1; the models in this notebook
        # take input shaped (batch, 1, seq_len).
        self.X = signals.unsqueeze(1)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

# === Model 1: 1D CNN ===
class ECG1DCNN(nn.Module):
    """Three-stage 1-D convolutional classifier for single-channel signals.

    Each stage is Conv1d -> ReLU -> BatchNorm1d followed by pooling; the last
    stage pools adaptively to length 1, so the 128 resulting features feed a
    dropout + linear head that emits ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Layers are instantiated in execution order and unpacked into a
        # Sequential, which keeps the same submodule indices (cnn.0, cnn.1, ...).
        feature_layers = [
            # Stage 1: 1 -> 44 channels, halve the length.
            nn.Conv1d(1, 44, 4, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(44),
            nn.MaxPool1d(2),
            # Stage 2: 44 -> 57 channels, halve the length again.
            nn.Conv1d(44, 57, 7, padding=3),
            nn.ReLU(),
            nn.BatchNorm1d(57),
            nn.MaxPool1d(2),
            # Stage 3: 57 -> 128 channels, then average over the whole length.
            nn.Conv1d(57, 128, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.AdaptiveAvgPool1d(1),
        ]
        self.cnn = nn.Sequential(*feature_layers)

        head_layers = [nn.Flatten(), nn.Dropout(0.3), nn.Linear(128, num_classes)]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map input of shape (batch, 1, seq_len) to logits (batch, num_classes)."""
        features = self.cnn(x)
        return self.fc(features)

# === Model 2: BiLSTM ===
class ECG_BiLSTM(nn.Module):
    """Two-layer bidirectional LSTM classifier for single-channel signals.

    The LSTM reads the signal one scalar per time step; its per-step outputs
    (forward and backward directions concatenated) are averaged over time and
    passed through a small MLP that emits ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            # 128 * 2: both directions are concatenated on the feature axis.
            nn.Linear(128 * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map input of shape (batch, 1, seq_len) to logits (batch, num_classes)."""
        # (batch, 1, seq_len) -> (batch, seq_len, 1) for the batch_first LSTM.
        sequence = x.transpose(1, 2)
        hidden_states, _ = self.lstm(sequence)
        # Pooling works on the last axis, so put time last before averaging.
        pooled = F.adaptive_avg_pool1d(hidden_states.transpose(1, 2), 1)
        return self.fc(pooled.squeeze(-1))

class ECG_Transformer(nn.Module):
    """Transformer-encoder classifier for 1-D signals.

    Every time step is linearly embedded to ``d_model`` features, a learned
    positional embedding is added, the sequence is run through a
    ``TransformerEncoder``, averaged over time and classified.

    Args:
        num_classes: Number of output logits.
        input_dim: Number of input channels per time step.
        seq_len: Maximum sequence length supported (size of the learned
            positional embedding).
        d_model: Embedding width used inside the encoder.
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability for the encoder layers and classifier.
    """

    def __init__(self, num_classes, input_dim=1, seq_len=2000, d_model=64, nhead=4, num_layers=2, dropout=0.2):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Learned positional embedding, one vector per time step.
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, d_model))
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )

    def forward(self, x):
        """Map input of shape (batch, input_dim, seq_len) to logits.

        A 2-D input of shape (batch, seq_len) is also accepted and treated as
        a single-channel signal.

        Raises:
            ValueError: If the input is longer than the ``seq_len`` the model
                was built with.
        """
        if x.dim() == 3:
            # [batch, input_dim, seq_len] -> [batch, seq_len, input_dim].
            # For input_dim == 1 this equals squeeze(1).unsqueeze(-1).
            x = x.transpose(1, 2)
        else:
            x = x.unsqueeze(-1)          # [batch, seq_len] -> [batch, seq_len, 1]
        x = self.embedding(x)            # [batch, seq_len, d_model]

        n_steps = x.size(1)
        max_steps = self.pos_embedding.size(1)
        if n_steps > max_steps:
            raise ValueError(
                f"Input sequence length {n_steps} exceeds the model's seq_len {max_steps}"
            )
        # Without this term self-attention followed by mean pooling is
        # invariant to the order of the samples, i.e. the temporal structure
        # of the signal would be ignored entirely.
        x = x + self.pos_embedding[:, :n_steps]

        x = self.transformer(x)          # [batch, seq_len, d_model]
        x = x.transpose(1, 2)            # [batch, d_model, seq_len] for pooling
        x = self.global_pool(x)          # [batch, d_model, 1]
        return self.classifier(x)        # [batch, num_classes]


def train_model(X, y, label_names, model_class, model_name, num_epochs=15, batch_size=64, patience=5):
    """Train a classifier with early stopping and print held-out metrics.

    The data is split 75/25 (stratified, fixed seed) and the training part is
    randomly oversampled. The model is optimised with Adam on a class-weighted
    cross-entropy loss. After training, the weights of the epoch with the
    lowest held-out loss are restored and accuracy, a classification report
    and a confusion matrix are printed.

    Args:
        X: Array of signals, one row per sample (wrapped by ``ECGDataset``).
        y: Integer class labels aligned with ``X``.
        label_names: Class names; the length sets the model's output size and
            the list is passed as ``target_names`` to the classification report.
        model_class: Callable accepting ``num_classes=`` and returning an
            ``nn.Module``.
        model_name: Label used only in the progress message.
        num_epochs: Maximum number of training epochs.
        batch_size: Batch size for both data loaders.
        patience: Number of consecutive non-improving epochs tolerated before
            training stops.

    Returns:
        None. All results are printed.
    """
    print(f"\n⚙️ Training {model_name} with Early Stopping...")

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.25, random_state=42)
    # Oversample the training split only; seeded so reruns draw the same samples.
    X_train, y_train = RandomOverSampler(random_state=42).fit_resample(X_train, y_train)

    # Weights are computed on the oversampled labels.
    class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train)
    class_weights = torch.tensor(class_weights, dtype=torch.float32)

    # DataLoaders
    # NOTE(review): the same held-out split drives early stopping AND the final
    # metrics below, so the reported numbers are optimistic; a separate test
    # split would give an unbiased estimate.
    train_loader = DataLoader(ECGDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(ECGDataset(X_test, y_test), batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_class(num_classes=len(label_names)).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    def snapshot_weights():
        # state_dict() hands back references to the live tensors, which the
        # optimizer keeps updating in place; clone them to freeze a checkpoint.
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    # Early stopping
    best_loss = float('inf')
    patience_counter = 0
    # Start from the initial weights so a checkpoint always exists, even if no
    # epoch runs or the validation loss never becomes a comparable number.
    best_model_state = snapshot_weights()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # Validation loss
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                val_loss += criterion(model(X_batch), y_batch).item()
        avg_val_loss = val_loss / len(val_loader)

        print(f"🧠 Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # Check early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            best_model_state = snapshot_weights()  # Save best model weights
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⏹️ Early stopping triggered.")
                break

    # Load best model
    model.load_state_dict(best_model_state)

    # Final Evaluation
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            pred = model(X_batch).argmax(dim=1).cpu().numpy()
            y_true.extend(y_batch.numpy())
            y_pred.extend(pred)

    print(f"\n✅ Accuracy: {accuracy_score(y_true, y_pred):.4f}")
    print("\n✅ Classification Report:\n", classification_report(y_true, y_pred, target_names=label_names))
    print("✅ Confusion Matrix:\n", confusion_matrix(y_true, y_pred))

In [None]:
# === Prepare EEG Memory Classifier Data ===
# Collect every memory trial and label it by its condition (Five / Nine / Thirteen).
X_mem_eeg, y_mem_eeg = [], []
memory_label_map = { "Five-Memory": 0, "Nine-Memory": 1, "Thirteen-Memory": 2 }

for subj_data in eeg_data:
    for cond_idx, condition in enumerate(["Five", "Nine", "Thirteen"]):
        # Both memory sub-conditions receive the same condition label.
        for subcond_idx in [1, 2]:  # MemoryCorrect, MemoryIncorrect
            # NOTE(review): 54 is presumably the number of trial slots per
            # sub-condition — confirm against the layout of the pickled array.
            for trial_idx in range(54):
                signal = subj_data[cond_idx, subcond_idx, trial_idx]
                # Skip slots that contain nothing but NaN.
                if not np.isnan(signal).all():
                    X_mem_eeg.append(signal)
                    y_mem_eeg.append(memory_label_map[f"{condition}-Memory"])

# Standardise each trial, then resample it to the default target length.
X_mem_eeg = downsample_signals(clean_signals(np.array(X_mem_eeg)))
label_names_mem = list(memory_label_map.keys())

In [None]:
# === Train Models on EEG (Memory Classifier) ===
# Same data and labels for every architecture; only the model class and the
# name shown in the log differ.
for model_class, model_name in [
    (ECG1DCNN, "EEG_1D_CNN_Memory"),
    (ECG_BiLSTM, "EEG_BiLSTM_Memory"),
    (ECG_Transformer, "EEG_Transformer_Memory"),
]:
    train_model(X_mem_eeg, np.array(y_mem_eeg), label_names_mem, model_class, model_name)


⚙️ Training EEG_1D_CNN_Memory with Early Stopping...
🧠 Epoch 1/15 | Train Loss: 0.1161 | Val Loss: 0.2794
🧠 Epoch 2/15 | Train Loss: 0.0399 | Val Loss: 0.2900
🧠 Epoch 3/15 | Train Loss: 0.0345 | Val Loss: 0.2930
🧠 Epoch 4/15 | Train Loss: 0.0341 | Val Loss: 0.2950
🧠 Epoch 5/15 | Train Loss: 0.0313 | Val Loss: 0.2977
🧠 Epoch 6/15 | Train Loss: 0.0260 | Val Loss: 0.2988
⏹️ Early stopping triggered.

✅ Accuracy: 0.9940

✅ Classification Report:
                  precision    recall  f1-score   support

    Five-Memory       1.00      0.99      1.00       590
    Nine-Memory       1.00      0.99      0.99       560
Thirteen-Memory       0.98      1.00      0.99       515

       accuracy                           0.99      1665
      macro avg       0.99      0.99      0.99      1665
   weighted avg       0.99      0.99      0.99      1665

✅ Confusion Matrix:
 [[585   2   3]
 [  0 555   5]
 [  0   0 515]]

⚙️ Training EEG_BiLSTM_Memory with Early Stopping...
🧠 Epoch 1/15 | Train Loss: 0.

In [None]:
# === Prepare Balanced EEG Binary Classifier Data (JustListen = 0, Memory = 1) ===
# Memory draws from two sub-conditions versus one for JustListen, so no
# balancing is performed in this cell itself.
X_bin_eeg, y_bin_eeg = [], []
for subj_data in eeg_data:
    for cond_idx in range(3):  # Five, Nine, Thirteen
        for trial_idx in range(54):
            # Sub-condition 0 is JustListen; 1 and 2 (MemoryCorrect,
            # MemoryIncorrect) are both labelled as Memory.
            for subcond_idx in (0, 1, 2):
                trial_signal = subj_data[cond_idx, subcond_idx, trial_idx]
                if np.isnan(trial_signal).all():
                    continue  # slot holds nothing but NaN
                X_bin_eeg.append(trial_signal)
                y_bin_eeg.append(0 if subcond_idx == 0 else 1)

# Standardise each trial, then resample it to the default target length.
X_bin_eeg = downsample_signals(clean_signals(np.array(X_bin_eeg)))
label_names_bin = ["JustListen", "Memory"]


In [None]:
# === Train Models on EEG (Binary Classifier) ===
# Same data and labels for every architecture; only the model class and the
# name shown in the log differ.
for model_class, model_name in [
    (ECG1DCNN, "EEG_1D_CNN_Binary"),
    (ECG_BiLSTM, "EEG_BiLSTM_Binary"),
    (ECG_Transformer, "EEG_Transformer_Binary"),
]:
    train_model(X_bin_eeg, np.array(y_bin_eeg), label_names_bin, model_class, model_name)


⚙️ Training EEG_1D_CNN_Binary with Early Stopping...
🧠 Epoch 1/15 | Train Loss: 0.7001 | Val Loss: 0.6821
🧠 Epoch 2/15 | Train Loss: 0.6961 | Val Loss: 0.6960
🧠 Epoch 3/15 | Train Loss: 0.6952 | Val Loss: 0.6941
🧠 Epoch 4/15 | Train Loss: 0.6946 | Val Loss: 0.6873
🧠 Epoch 5/15 | Train Loss: 0.6935 | Val Loss: 0.6816
🧠 Epoch 6/15 | Train Loss: 0.6932 | Val Loss: 0.6813
🧠 Epoch 7/15 | Train Loss: 0.6930 | Val Loss: 0.6739
🧠 Epoch 8/15 | Train Loss: 0.6918 | Val Loss: 0.6934
🧠 Epoch 9/15 | Train Loss: 0.6922 | Val Loss: 0.6817
🧠 Epoch 10/15 | Train Loss: 0.6919 | Val Loss: 0.6871
🧠 Epoch 11/15 | Train Loss: 0.6922 | Val Loss: 0.6908
🧠 Epoch 12/15 | Train Loss: 0.6916 | Val Loss: 0.6708
🧠 Epoch 13/15 | Train Loss: 0.6902 | Val Loss: 0.6886
🧠 Epoch 14/15 | Train Loss: 0.6914 | Val Loss: 0.6922
🧠 Epoch 15/15 | Train Loss: 0.6910 | Val Loss: 0.6816

✅ Accuracy: 0.5573

✅ Classification Report:
               precision    recall  f1-score   support

  JustListen       0.34      0.42      0.37

In [None]:
# === Prepare EEG Four-Class Classifier Data ===
# JustListen trials share one class regardless of condition; memory trials are
# labelled by their condition.
four_class_label_map = {
    "JustListen": 0,
    "Five-Memory": 1,
    "Nine-Memory": 2,
    "Thirteen-Memory": 3
}
X_4class_eeg, y_4class_eeg = [], []

for subj_data in eeg_data:
    for cond_idx, condition in enumerate(["Five", "Nine", "Thirteen"]):
        memory_label = four_class_label_map[f"{condition}-Memory"]
        for trial_idx in range(54):
            # Sub-condition 0 is JustListen; 1 and 2 are MemoryCorrect and
            # MemoryIncorrect.
            for subcond_idx in (0, 1, 2):
                trial_signal = subj_data[cond_idx, subcond_idx, trial_idx]
                if np.isnan(trial_signal).all():
                    continue  # slot holds nothing but NaN
                X_4class_eeg.append(trial_signal)
                if subcond_idx == 0:
                    y_4class_eeg.append(four_class_label_map["JustListen"])
                else:
                    y_4class_eeg.append(memory_label)

# Standardise each trial, then resample it to the default target length.
X_4class_eeg = downsample_signals(clean_signals(np.array(X_4class_eeg)))
# Dict insertion order gives the names in label-index order.
label_names_4class = list(four_class_label_map)


In [None]:
# Train each architecture on the four-class data; only the model class and the
# name shown in the log differ between runs.
for model_class, model_name in [
    (ECG1DCNN, "EEG_1D_CNN_FourClass"),
    (ECG_BiLSTM, "EEG_BiLSTM_FourClass"),
    (ECG_Transformer, "EEG_Transformer_FourClass"),
]:
    train_model(X_4class_eeg, np.array(y_4class_eeg), label_names_4class, model_class, model_name)



⚙️ Training EEG_1D_CNN_FourClass with Early Stopping...
🧠 Epoch 1/15 | Train Loss: 0.7241 | Val Loss: 0.7144
🧠 Epoch 2/15 | Train Loss: 0.6568 | Val Loss: 0.6854
🧠 Epoch 3/15 | Train Loss: 0.6299 | Val Loss: 0.6688
🧠 Epoch 4/15 | Train Loss: 0.6171 | Val Loss: 0.6903
🧠 Epoch 5/15 | Train Loss: 0.6174 | Val Loss: 0.6654
🧠 Epoch 6/15 | Train Loss: 0.6181 | Val Loss: 0.6591
🧠 Epoch 7/15 | Train Loss: 0.6103 | Val Loss: 0.6765
🧠 Epoch 8/15 | Train Loss: 0.6071 | Val Loss: 0.6608
🧠 Epoch 9/15 | Train Loss: 0.6113 | Val Loss: 0.6623
🧠 Epoch 10/15 | Train Loss: 0.6052 | Val Loss: 0.6686
🧠 Epoch 11/15 | Train Loss: 0.6041 | Val Loss: 0.6536
🧠 Epoch 12/15 | Train Loss: 0.6004 | Val Loss: 0.6488
🧠 Epoch 13/15 | Train Loss: 0.6001 | Val Loss: 0.6628
🧠 Epoch 14/15 | Train Loss: 0.6016 | Val Loss: 0.6709
🧠 Epoch 15/15 | Train Loss: 0.6001 | Val Loss: 0.6629

✅ Accuracy: 0.6752

✅ Classification Report:
                  precision    recall  f1-score   support

     JustListen       0.30      0.01 