In [None]:
# Mount Google Drive at /content/drive so files stored there are reachable
# from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
import os
# Switch the working directory to the project folder: the local `model`
# module import and the relative 'ckpts/...' checkpoint paths used further
# down are resolved against it.
os.chdir('/content/drive/MyDrive/Colab Notebooks/EEG_analysis/cnn_lstm')
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from model import LSTMClassifier
from tqdm import tqdm
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import os

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, precision_recall_curve, auc, classification_report




In [None]:
class SeizureDataset(Dataset):
    """Row-per-sample dataset read from a features CSV and a labels CSV.

    Only the column slice ``[feature_start:feature_end]`` of every feature row
    is used. The slice is returned as a sequence of scalar readings, i.e. a
    tensor of shape ``[feature_end - feature_start, 1]``.

    Args:
        data_path: Path to a header-less CSV with one sample per row.
        label_path: Path to a header-less CSV with one label row per sample.
        feature_start: First column (inclusive) of the slice to keep.
            Defaults to 768.
        feature_end: Last column (exclusive) of the slice to keep.
            Defaults to 1280, giving 512 time steps with the default start.

    Raises:
        ValueError: If the requested slice is empty or starts below zero.
    """

    def __init__(self, data_path, label_path, feature_start=768, feature_end=1280):
        if feature_start < 0 or feature_end <= feature_start:
            raise ValueError(
                f"Invalid feature slice [{feature_start}:{feature_end}]"
            )

        # Load data and labels as plain NumPy arrays
        self.data = pd.read_csv(data_path, header=None).values
        self.labels = pd.read_csv(label_path, header=None).values

        # Column window of each row that is fed to the model
        self.feature_start = feature_start
        self.feature_end = feature_end

    def __len__(self):
        """Return the number of samples (rows of the features CSV)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for sample ``idx``.

        ``features`` is a float32 tensor of shape ``[sequence_length, 1]``;
        ``label`` is a float tensor built from row ``idx`` of the labels CSV.
        """
        window = self.data[idx, self.feature_start:self.feature_end]
        # Each selected column becomes one time step carrying a single feature
        features = torch.tensor(window, dtype=torch.float32).unsqueeze(1)
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        return features, label


def collate_fn(batch):
    """Collate ``(sequence, label)`` samples into one padded batch.

    Sequences are zero-padded on the right to the longest one in the batch
    (batch dimension first); labels are stacked along a new leading axis.

    Returns:
        Tuple ``(padded_sequences, stacked_labels)``.
    """
    sequences, targets = map(list, zip(*batch))
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, torch.stack(targets, dim=0)

In [None]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math

class LSTMClassifier1D(nn.Module):
    """Conv1d feature extractor followed by a stacked LSTM and an MLP head.

    Two Conv1d -> BatchNorm -> ReLU stages (kernel 3, padding 1, so the
    sequence length is preserved) produce ``hidden_size`` channels per time
    step. These are concatenated with the raw input along the feature axis
    and passed to the LSTM. The LSTM output at the last time step goes through
    two hidden fully-connected layers and a final linear layer that returns
    raw logits (no sigmoid/softmax is applied here).

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Conv channel count and LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        num_classes: Size of the output logit vector.
        dropout_rate: Dropout probability used between LSTM layers and in the
            fully-connected head.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, dropout_rate=0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Conv1d expects (batch, channels=input_size, seq_len)
        self.conv1 = nn.Conv1d(in_channels=input_size, out_channels=hidden_size, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.conv2 = nn.Conv1d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(hidden_size)

        # LSTM over [conv features | raw input]. Inter-layer dropout only
        # exists when there is more than one layer; passing a non-zero value
        # with a single layer merely triggers a warning, so it is zeroed.
        lstm_dropout = dropout_rate if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(hidden_size + input_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)

        # FC layers
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(hidden_size, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn4 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute logits for a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding raw logits.
        """
        # Permute to (batch, input_size, seq_len) for Conv1d
        x_conv = x.permute(0, 2, 1)
        x_conv = torch.relu(self.bn1(self.conv1(x_conv)))
        x_conv = torch.relu(self.bn2(self.conv2(x_conv)))  # (batch, hidden_size, seq_len)

        x_conv = x_conv.permute(0, 2, 1)  # Back to (batch, seq_len, hidden_size)

        # Skip connection by concatenation: keep the raw signal next to the
        # learned features -> (batch, seq_len, hidden_size + input_size)
        x_cat = torch.cat((x_conv, x), dim=2)

        # No explicit (h0, c0): nn.LSTM starts from zeros by default, created
        # with the input's dtype and device, so the model also works after
        # .double()/.half() and never allocates the state on the wrong device.
        out, _ = self.lstm(x_cat)
        out = out[:, -1, :]  # Take last time step

        out = self.dropout(out)
        out = torch.relu(self.bn3(self.fc1(out)))
        out = self.dropout(out)
        out = torch.relu(self.bn4(self.fc2(out)))
        out = self.fc3(out)
        return out

In [None]:

# Locations of the pre-processed feature / label CSV files for each split
data_path = '/content/drive/MyDrive/Colab Notebooks/EEG_analysis/processed_data/'
train_features_path = f"{data_path}concatenated_train_final_data_epoch_final.csv"
train_labels_path = f"{data_path}train_labels.csv"
test_features_path = f"{data_path}concatenated_test_final_data_epoch_final.csv"
test_labels_path = f"{data_path}test_labels.csv"

# One dataset per split
train_dataset = SeizureDataset(data_path=train_features_path, label_path=train_labels_path)
test_dataset = SeizureDataset(data_path=test_features_path, label_path=test_labels_path)

# Mini-batch iterators: the training split is reshuffled every epoch, the test
# split keeps file order. Batch size 32 (64, 128 or 256 are alternatives).
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)


In [None]:
# Train on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The model outputs raw logits; BCEWithLogitsLoss applies the sigmoid itself
criterion = nn.BCEWithLogitsLoss()

model = LSTMClassifier1D(
    input_size=1,      # 1 channel EEG
    hidden_size=128,   # Hidden layer size
    num_layers=4,      # Number of LSTM layers
    num_classes=1,     # Binary classification
    dropout_rate=0.2
)

model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


# NOTE(review): there is no validation pass in this loop, so the three
# early-stopping variables below (best_val_loss, tolerance_counter, tolerance)
# are currently never read.
best_val_loss = float("inf")
tolerance_counter = 0

epochs = 150
tolerance = 5  # Number of epochs to tolerate non-decreasing val loss before switching dataloader

# Checkpoints go to a relative folder; create it up front so torch.save
# cannot fail on a missing directory after the first epoch has been spent.
os.makedirs('ckpts', exist_ok=True)

for epoch in range(epochs):
    model.train()
    t_loss = 0
    for data, ground_truth in tqdm(train_loader, leave=False):
        data, ground_truth = data.to(device), ground_truth.to(device)
        optimizer.zero_grad()
        pred = model(data)
        # Match the (batch, 1) logit shape expected by the loss
        ground_truth = ground_truth.view(-1, 1).float()
        loss = criterion(pred, ground_truth)
        t_loss += loss.item()
        loss.backward()
        optimizer.step()

    # Report the mean batch loss once per epoch (outside the batch loop);
    # max(..., 1) keeps an empty loader from dividing by zero.
    print(f"Epoch {epoch+1}/{epochs}, Loss: {t_loss/max(len(train_loader), 1):.4f}")

    # Periodic checkpoint every 10th epoch (including epoch 0)
    if epoch % 10 == 0:
        torch.save(model.state_dict(), f'ckpts/weight_lstm_{epoch}.pth')

print("Training complete!")
torch.save(model.state_dict(), 'ckpts/weight_lstm_final.pth')


# Evaluate

In [None]:
# Run the model over the whole test split and keep hard 0/1 predictions
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for data, labels in test_loader:
        data = data.to(device)
        labels = labels.to(device)
        probabilities = torch.sigmoid(model(data))  # logits -> probabilities
        preds = (probabilities > 0.5).int()         # decision threshold 0.5
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Confusion matrix as an annotated heat map (rows: truth, columns: prediction)
class_names = ["Non-Seizure", "Seizure"]
cm = confusion_matrix(all_labels, all_preds)
ax = plt.gca()
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
plt.show()

In [None]:
# Get predicted probabilities (not thresholded) for the whole test split.
# eval() is set in this cell too, so the probabilities are correct even when
# the cell is run right after training (dropout off, BatchNorm using its
# running statistics) instead of relying on an earlier cell's state.
model.eval()
all_probs = []
all_labels = []

with torch.no_grad():
    for data, labels in test_loader:
        data, labels = data.to(device), labels.to(device)
        probs = torch.sigmoid(model(data)).cpu().numpy()
        # Flatten (batch, 1) -> (batch,) so the metric functions receive
        # plain 1-D score / label vectors.
        all_probs.extend(probs.ravel())
        all_labels.extend(labels.cpu().numpy().ravel())

# Compute ROC curve
fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
auc_score = roc_auc_score(all_labels, all_probs)

# Plot the curve against the chance diagonal
plt.plot(fpr, tpr, label=f"AUC = {auc_score:.2f}")
plt.plot([0, 1], [0, 1], linestyle="--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.show()

In [None]:
# Get some test samples (first batch of the un-shuffled test loader)
test_samples, test_labels = next(iter(test_loader))
test_samples = test_samples.to(device)

# Get predictions
model.eval()
with torch.no_grad():
    outputs = torch.sigmoid(model(test_samples))
    preds = (outputs > 0.5).float()

# Visualize up to the first 5 samples; the batch may hold fewer than 5, in
# which case indexing a fixed range(5) would raise an IndexError.
num_to_show = min(5, len(test_samples))
for i in range(num_to_show):
    plt.figure(figsize=(10, 4))
    plt.plot(test_samples[i].cpu().numpy())
    plt.title(f"True: {test_labels[i].item()}, Pred: {preds[i].item()}")
    plt.show()

In [None]:

# Precision-recall trade-off across thresholds, summarised by the area under
# the curve (uses all_labels / all_probs collected by the ROC cell).
precision, recall, _ = precision_recall_curve(all_labels, all_probs)
auprc = auc(recall, precision)

curve_label = f"AUPRC = {auprc:.2f}"
ax = plt.gca()
ax.plot(recall, precision, label=curve_label)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Precision-Recall Curve")
ax.legend()
plt.show()

In [None]:
# Select a sample from the test set
sample_idx = 0  # Change to visualize different samples
data, label = test_dataset[sample_idx]

# Generate prediction
with torch.no_grad():
    data = data.unsqueeze(0).to(device)  # Add batch dimension
    prob = torch.sigmoid(model(data)).item()
pred = "Seizure" if prob > 0.5 else "Non-Seizure"

# Plot the time-series data
plt.figure(figsize=(10, 4))
plt.plot(data.cpu().numpy()[:, 0], label="EEG Channel 1")  # Plot 1st feature
plt.title(f"True Label: {label.item()} | Predicted: {pred} (Prob: {prob:.2f})")
plt.xlabel("Time Step")
plt.ylabel("Amplitude")
plt.legend()
plt.show()

In [None]:

# Per-class precision / recall / F1 on the test split, from the labels and
# thresholded predictions collected in the evaluation cells above
class_names = ["Non-Seizure", "Seizure"]
report = classification_report(all_labels, all_preds, target_names=class_names)
print(report)