In [None]:
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import RandomOverSampler
from scipy.signal import resample
import torch
from collections import Counter
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from imblearn.under_sampling import RandomUnderSampler
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")

# Seed every RNG the notebook draws from. NumPy alone is not enough: model
# weight initialisation and DataLoader shuffling use PyTorch's generator.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# === Load ECG Data ===
# Location of the preprocessed recordings (Google Drive mount in Colab).
ECG_DATA_PATH = "/content/drive/MyDrive/output_ecg_data/processed_ecg_data.pkl"

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that were produced by a trusted pipeline.
with open(ECG_DATA_PATH, "rb") as f:
    data_dict = pickle.load(f)

# Later cells iterate over `data` and index each element as
# [condition, subcondition, trial].
data = data_dict["data"]

In [None]:
# === Signal Utilities ===
def clean_signals(X):
    """Replace non-finite samples and z-score every row of ``X``.

    NaN/inf values are first substituted via ``np.nan_to_num``; each row is then
    shifted to zero mean and scaled by its standard deviation (plus a small
    epsilon so constant rows do not divide by zero).

    Args:
        X: 2-D array-like, one signal per row.

    Returns:
        Array of the same shape holding the normalised signals.
    """
    finite = np.nan_to_num(X)
    row_mean = np.mean(finite, axis=1, keepdims=True)
    row_std = np.std(finite, axis=1, keepdims=True)
    centered = finite - row_mean
    return centered / (row_std + 1e-8)

def downsample_signals(X, target_len=2000):
    """Resample every signal in ``X`` to ``target_len`` samples.

    Each row is passed independently through ``scipy.signal.resample``.

    Args:
        X: iterable of 1-D signals.
        target_len: number of samples each output signal should have.

    Returns:
        ``np.ndarray`` built from the resampled signals, in input order.
    """
    resampled = []
    for sig in X:
        resampled.append(resample(sig, target_len))
    return np.array(resampled)

# === Custom Dataset ===
class ECGDataset(Dataset):
    """In-memory dataset pairing signals with integer class labels.

    Signals are stored as float32 with a size-1 axis inserted at dim 1;
    labels are stored as int64.
    """

    def __init__(self, X, y):
        signals = torch.tensor(X, dtype=torch.float32)
        # Insert the extra axis at position 1 (used as the channel axis by the models).
        self.X = signals.unsqueeze(1)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        """Number of samples, taken from the label tensor."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(signal, label)`` pair stored at ``idx``."""
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# === Model 1: 1D CNN ===
class ECG1DCNN(nn.Module):
    """Three-stage 1-D convolutional classifier.

    Every stage is Conv1d -> ReLU -> BatchNorm1d followed by a pooling layer:
    MaxPool1d(2) after the first two stages and AdaptiveAvgPool1d(1) after the
    last, which collapses the time axis. A dropout + linear head maps the 128
    pooled features to ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # (in_channels, out_channels, kernel_size, padding) for each conv stage.
        stage_specs = [(1, 44, 4, 2), (44, 57, 7, 3), (57, 128, 3, 1)]
        final_stage = len(stage_specs) - 1

        layers = []
        for stage_idx, (c_in, c_out, kernel, pad) in enumerate(stage_specs):
            layers.append(nn.Conv1d(c_in, c_out, kernel, padding=pad))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(c_out))
            if stage_idx == final_stage:
                layers.append(nn.AdaptiveAvgPool1d(1))
            else:
                layers.append(nn.MaxPool1d(2))
        self.cnn = nn.Sequential(*layers)

        head = [nn.Flatten(), nn.Dropout(0.3), nn.Linear(128, num_classes)]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Map a ``[batch, 1, length]`` tensor to ``[batch, num_classes]`` logits."""
        features = self.cnn(x)
        return self.fc(features)

# === Model 2: BiLSTM ===
class ECG_BiLSTM(nn.Module):
    """Two-layer bidirectional LSTM classifier.

    The LSTM output is averaged over the time axis and passed through a small
    MLP head (256 -> 128 -> ``num_classes``) with ReLU and dropout.
    """

    def __init__(self, num_classes):
        super().__init__()
        hidden = 128
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=hidden,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.3,
        )
        # Forward and backward directions are concatenated, hence 2 * hidden inputs.
        head_layers = [
            nn.Flatten(),
            nn.Linear(hidden * 2, hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a ``[batch, 1, length]`` tensor to ``[batch, num_classes]`` logits."""
        sequence = x.transpose(1, 2)               # [batch, length, 1]
        lstm_out, _ = self.lstm(sequence)          # [batch, length, 2 * hidden]
        channels_first = lstm_out.transpose(1, 2)  # [batch, 2 * hidden, length]
        pooled = F.adaptive_avg_pool1d(channels_first, 1).squeeze(-1)
        return self.fc(pooled)

class ECG_Transformer(nn.Module):
    """Transformer-encoder classifier for 1-D signals.

    Each time step is linearly embedded to ``d_model`` features, a learned
    positional embedding is added, the sequence is run through a stack of
    Transformer encoder layers, averaged over time, and classified.

    Args:
        num_classes: number of output logits.
        input_dim: number of channels per time step.
        seq_len: number of positions in the learned positional table.
        d_model: embedding width used throughout the encoder.
        nhead: attention heads per encoder layer.
        num_layers: number of encoder layers.
        dropout: dropout probability for the encoder layers and the head.
    """

    def __init__(self, num_classes, input_dim=1, seq_len=2000, d_model=64, nhead=4, num_layers=2, dropout=0.2):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, d_model))
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )

    def _positional_encoding(self, length):
        """Return positional embeddings shaped ``[1, length, d_model]``.

        Sequences no longer than the learned table use its leading rows;
        longer sequences get a linearly interpolated table so they still run.
        """
        table = self.pos_embedding
        if length <= table.size(1):
            return table[:, :length]
        stretched = nn.functional.interpolate(
            table.transpose(1, 2), size=length, mode="linear", align_corners=False
        )
        return stretched.transpose(1, 2)

    def forward(self, x):
        """Map ``[batch, input_dim, seq_len]`` to ``[batch, num_classes]`` logits."""
        # Channels-last layout for the per-step linear embedding. For
        # input_dim == 1 this matches the former squeeze(1)/unsqueeze(-1),
        # and it also handles multi-channel input.
        x = x.transpose(1, 2)            # [batch, seq_len, input_dim]
        x = self.embedding(x)            # [batch, seq_len, d_model]
        # Self-attention followed by mean pooling is order-agnostic on its own.
        # Without this term the model cannot tell a signal from any reordering
        # of its samples. The parameter existed before but was never applied,
        # so weights saved earlier carry an untrained positional table.
        x = x + self._positional_encoding(x.size(1))
        x = self.transformer(x)          # [batch, seq_len, d_model]
        x = x.transpose(1, 2)            # [batch, d_model, seq_len] for pooling
        x = self.global_pool(x)          # [batch, d_model, 1]
        return self.classifier(x)        # [batch, num_classes]


def _clone_state_dict(model):
    """Return a detached copy of ``model``'s parameters and buffers."""
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def train_model(X, y, label_names, model_class, model_name, num_epochs=15, batch_size=64, patience=5):
    """Train ``model_class`` on ``(X, y)`` with early stopping and print hold-out metrics.

    The data is split 75/25 (stratified), the training part is randomly
    oversampled, and the model is optimised with Adam on a class-weighted
    cross-entropy loss. After training, the weights from the epoch with the
    lowest hold-out loss are restored and accuracy, a classification report and
    a confusion matrix (text and plot) are produced.

    NOTE(review): the same hold-out split drives early stopping / checkpoint
    selection and the final report, so the printed metrics are optimistic.
    A separate test split would be needed for an unbiased estimate.

    Args:
        X: 2-D array of signals, one per row.
        y: integer class labels aligned with ``X``.
        label_names: class names; its length sets the model's output size.
        model_class: callable accepting ``num_classes=`` and returning an ``nn.Module``.
        model_name: label used in log lines and the plot title.
        num_epochs: maximum number of training epochs.
        batch_size: mini-batch size for both loaders.
        patience: epochs without hold-out-loss improvement before stopping.

    Returns:
        None. Results are printed and plotted.
    """
    print(f"\n⚙️ Training {model_name} with Early Stopping...")

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.25, random_state=42)
    # Oversample the training portion only; seeded so repeated runs draw the same samples.
    X_train, y_train = RandomOverSampler(random_state=42).fit_resample(X_train, y_train)

    # Computed on the oversampled labels, so these are expected to be close to uniform.
    class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train)
    class_weights = torch.tensor(class_weights, dtype=torch.float32)

    # DataLoaders
    train_loader = DataLoader(ECGDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(ECGDataset(X_test, y_test), batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_class(num_classes=len(label_names)).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Early stopping
    best_loss = float('inf')
    patience_counter = 0
    # Stays None if no epoch ever improves (e.g. NaN losses or num_epochs=0).
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # Validation loss
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                val_loss += criterion(model(X_batch), y_batch).item()
        avg_val_loss = val_loss / len(val_loader)

        print(f"🧠 Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # Check early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            # state_dict() hands back references to the live tensors, which later
            # optimizer steps overwrite; clone them to keep a real snapshot.
            best_model_state = _clone_state_dict(model)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⏹️ Early stopping triggered.")
                break

    # Load best model (keep the current weights if no snapshot was ever taken)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Final Evaluation
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            pred = model(X_batch).argmax(dim=1).cpu().numpy()
            y_true.extend(y_batch.numpy())
            y_pred.extend(pred)

    # Compute once and reuse for both the printout and the plot.
    cm = confusion_matrix(y_true, y_pred)

    print(f"\n✅ Accuracy: {accuracy_score(y_true, y_pred):.4f}")
    print("\n✅ Classification Report:\n", classification_report(y_true, y_pred, target_names=label_names))
    print("✅ Confusion Matrix:\n", cm)

    # Confusion matrix visualization
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)
    fig, ax = plt.subplots(figsize=(6, 6))
    disp.plot(ax=ax, cmap='Blues', xticks_rotation=45)
    ax.set_title(f'Confusion Matrix - {model_name}')
    plt.show()


In [None]:
# === Prepare Memory Classifier Data ===
# Collect every memory trial (sub-conditions 1 and 2) and label it by memory
# load. Each subject's array is indexed [condition, subcondition, trial];
# trials that are entirely NaN are skipped.
memory_label_map = { "Five-Memory": 0, "Nine-Memory": 1, "Thirteen-Memory": 2 }
X_mem, y_mem = [], []
for subj_data in data:
    for cond_idx, condition in enumerate(("Five", "Nine", "Thirteen")):
        cond_label = memory_label_map[f"{condition}-Memory"]
        for subcond_idx in (1, 2):
            for trial_idx in range(54):
                signal = subj_data[cond_idx, subcond_idx, trial_idx]
                if np.isnan(signal).all():
                    continue
                X_mem.append(signal)
                y_mem.append(cond_label)

# Normalise each trial, then resample to the default fixed length.
X_mem = clean_signals(np.array(X_mem))
X_mem = downsample_signals(X_mem)
label_names_mem = list(memory_label_map.keys())

# === Train all 3 Models for Memory Classification ===
for model_class, model_name in (
    (ECG1DCNN, "1D_CNN_Memory"),
    (ECG_BiLSTM, "BiLSTM_Memory"),
    (ECG_Transformer, "Transformer_Memory"),
):
    train_model(X_mem, np.array(y_mem), label_names_mem, model_class, model_name)


⚙️ Training 1D_CNN_Memory with Early Stopping...
🧠 Epoch 1/15 | Train Loss: 0.2702 | Val Loss: 0.1833
🧠 Epoch 2/15 | Train Loss: 0.1652 | Val Loss: 0.1315
🧠 Epoch 3/15 | Train Loss: 0.1322 | Val Loss: 0.1283
🧠 Epoch 4/15 | Train Loss: 0.1102 | Val Loss: 0.0856
🧠 Epoch 5/15 | Train Loss: 0.1035 | Val Loss: 0.1019
🧠 Epoch 6/15 | Train Loss: 0.0866 | Val Loss: 0.0730
🧠 Epoch 7/15 | Train Loss: 0.0762 | Val Loss: 0.0878
🧠 Epoch 8/15 | Train Loss: 0.0772 | Val Loss: 0.0748
🧠 Epoch 9/15 | Train Loss: 0.0770 | Val Loss: 0.1181
🧠 Epoch 10/15 | Train Loss: 0.0855 | Val Loss: 0.0707
🧠 Epoch 11/15 | Train Loss: 0.0631 | Val Loss: 0.0984
🧠 Epoch 12/15 | Train Loss: 0.0611 | Val Loss: 0.0669
🧠 Epoch 13/15 | Train Loss: 0.0599 | Val Loss: 0.0492
🧠 Epoch 14/15 | Train Loss: 0.0624 | Val Loss: 0.0977
🧠 Epoch 15/15 | Train Loss: 0.0595 | Val Loss: 0.0533

✅ Accuracy: 0.9834

✅ Classification Report:
                  precision    recall  f1-score   support

    Five-Memory       0.99      1.00      0.

In [None]:
# === Prepare Balanced Binary Classifier Data (JustListen = 0, Memory = 1) ===
# Sub-condition 0 is JustListen (label 0); sub-conditions 1 and 2
# (MemoryCorrect / MemoryIncorrect) are merged into label 1.
# Trials that are entirely NaN are skipped.
X_bin, y_bin = [], []
for subj_data in data:
    for cond_idx in range(3):  # Conditions: Five, Nine, Thirteen
        for trial_idx in range(54):
            # Visit JustListen first, then the two memory sub-conditions.
            for subcond_idx in (0, 1, 2):
                sig = subj_data[cond_idx, subcond_idx, trial_idx]
                if np.isnan(sig).all():
                    continue
                X_bin.append(sig)
                y_bin.append(0 if subcond_idx == 0 else 1)

# Normalise each trial, then resample to the default fixed length.
X_bin = clean_signals(np.array(X_bin))
X_bin = downsample_signals(X_bin)
label_names_bin = ["JustListen", "Memory"]

# === Train Models ===
for model_class, model_name in (
    (ECG1DCNN, "1D_CNN_Binary"),
    (ECG_BiLSTM, "BiLSTM_Binary"),
    (ECG_Transformer, "Transformer_Binary"),
):
    train_model(X_bin, np.array(y_bin), label_names_bin, model_class, model_name)


⚙️ Training 1D_CNN_Binary with Early Stopping...
🧠 Epoch 1/15 | Train Loss: 0.6992 | Val Loss: 0.7023
🧠 Epoch 2/15 | Train Loss: 0.6940 | Val Loss: 0.6973
🧠 Epoch 3/15 | Train Loss: 0.6935 | Val Loss: 0.7061
🧠 Epoch 4/15 | Train Loss: 0.6923 | Val Loss: 0.6852
🧠 Epoch 5/15 | Train Loss: 0.6924 | Val Loss: 0.7030
🧠 Epoch 6/15 | Train Loss: 0.6925 | Val Loss: 0.6983
🧠 Epoch 7/15 | Train Loss: 0.6925 | Val Loss: 0.6869
🧠 Epoch 8/15 | Train Loss: 0.6923 | Val Loss: 0.6863
🧠 Epoch 9/15 | Train Loss: 0.6919 | Val Loss: 0.6940
⏹️ Early stopping triggered.

✅ Accuracy: 0.5029

✅ Classification Report:
               precision    recall  f1-score   support

  JustListen       0.33      0.56      0.42       813
      Memory       0.70      0.47      0.57      1748

    accuracy                           0.50      2561
   macro avg       0.52      0.52      0.49      2561
weighted avg       0.58      0.50      0.52      2561

✅ Confusion Matrix:
 [[459 354]
 [919 829]]

⚙️ Training BiLSTM_Binary

In [None]:
# === Prepare Combined 4-Class Classifier Data ===
# JustListen trials (sub-condition 0) share one class regardless of condition;
# memory trials (sub-conditions 1 and 2) are labelled by their memory load.
# Trials that are entirely NaN are skipped.
combined_label_map = {
    "JustListen": 0,
    "Five-Memory": 1,
    "Nine-Memory": 2,
    "Thirteen-Memory": 3
}

X_combined, y_combined = [], []
for subj_data in data:
    for cond_idx, condition in enumerate(("Five", "Nine", "Thirteen")):
        memory_label = combined_label_map[f"{condition}-Memory"]
        for trial_idx in range(54):
            # Visit JustListen first, then the two memory sub-conditions.
            for subcond_idx in (0, 1, 2):
                sig = subj_data[cond_idx, subcond_idx, trial_idx]
                if np.isnan(sig).all():
                    continue
                X_combined.append(sig)
                if subcond_idx == 0:
                    y_combined.append(combined_label_map["JustListen"])
                else:
                    y_combined.append(memory_label)

# Preprocess
X_combined = clean_signals(np.array(X_combined))
X_combined = downsample_signals(X_combined)
label_names_combined = list(combined_label_map.keys())

# === Train All 3 Models for Combined Classifier ===
for model_class, model_name in (
    (ECG1DCNN, "1D_CNN_Combined"),
    (ECG_BiLSTM, "BiLSTM_Combined"),
    (ECG_Transformer, "Transformer_Combined"),
):
    train_model(X_combined, np.array(y_combined), label_names_combined, model_class, model_name)



⚙️ Training 1D_CNN_Combined with Early Stopping...
🧠 Epoch 1/15 | Train Loss: 0.8164 | Val Loss: 0.8010
🧠 Epoch 2/15 | Train Loss: 0.7266 | Val Loss: 0.7408
🧠 Epoch 3/15 | Train Loss: 0.7014 | Val Loss: 0.7110
🧠 Epoch 4/15 | Train Loss: 0.6891 | Val Loss: 0.7277
🧠 Epoch 5/15 | Train Loss: 0.6674 | Val Loss: 0.7206
🧠 Epoch 6/15 | Train Loss: 0.6715 | Val Loss: 0.7034
🧠 Epoch 7/15 | Train Loss: 0.6658 | Val Loss: 0.7303
🧠 Epoch 8/15 | Train Loss: 0.6553 | Val Loss: 0.7066
🧠 Epoch 9/15 | Train Loss: 0.6517 | Val Loss: 0.6822
🧠 Epoch 10/15 | Train Loss: 0.6463 | Val Loss: 0.7005
🧠 Epoch 11/15 | Train Loss: 0.6440 | Val Loss: 0.6924
🧠 Epoch 12/15 | Train Loss: 0.6449 | Val Loss: 0.6954
🧠 Epoch 13/15 | Train Loss: 0.6399 | Val Loss: 0.6847
🧠 Epoch 14/15 | Train Loss: 0.6368 | Val Loss: 0.6880
⏹️ Early stopping triggered.

✅ Accuracy: 0.6669

✅ Classification Report:
                  precision    recall  f1-score   support

     JustListen       0.34      0.04      0.07       813
    Five-M