In [None]:
!pip install torch torcheeg
!pip install torch-scatter -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__.split('+')[0])").html
!pip install --upgrade --force-reinstall numpy torch torchvision torchaudio
!pip uninstall -y numpy scipy scikit-learn torcheeg torch torchvision torchaudio jax jaxlib
!pip install numpy==1.26.4 scipy==1.10.1 scikit-learn torch torchvision torchaudio torcheeg --force-reinstall


In [None]:
# Mount Google Drive at /content/drive so files stored on Drive can be read
# through the normal filesystem.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
!pip uninstall torch torchvision torchaudio torch-scatter -y
!pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
!pip install shap
!pip install captum

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torcheeg.models import EEGNet
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import numpy as np
import os
import matplotlib.pyplot as plt
import shap
from captum.attr import LRP

# Base folder of this project on the mounted Google Drive.
drive_path = "/content/drive/MyDrive/FileStore/"
# Folder that is scanned for the preprocessed EEG arrays and their label files.
preprocessed_folder = f"{drive_path}preprocessed_eeg/"



In [None]:
# Load every "<name>.npy" data file together with its "<name>_label.npy"
# label file from `preprocessed_folder` and merge them along axis 0.
X_list, y_list = [], []

# Sorted listing keeps the merge order identical from run to run.
for file_name in sorted(os.listdir(preprocessed_folder)):
    # Label files are reached through their matching data file, never directly.
    if not file_name.endswith(".npy") or file_name.endswith("_label.npy"):
        continue

    # Swap only the trailing extension; str.replace would also rewrite any
    # ".npy" that happens to appear earlier in the file name.
    label_file = file_name[: -len(".npy")] + "_label.npy"
    label_path = os.path.join(preprocessed_folder, label_file)

    if not os.path.exists(label_path):  # Ensure both X and y exist
        print(f"⚠️ Warning: Label file missing for {file_name}, skipping.")
        continue

    X = np.load(os.path.join(preprocessed_folder, file_name))  # EEG data
    y = np.load(label_path)  # Labels

    # Catch misaligned pairs per file. Without this check, mismatches in
    # different files could cancel out after concatenation and silently pair
    # trials with the wrong labels.
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"{file_name} has {X.shape[0]} entries along axis 0 but "
            f"{label_file} has {y.shape[0]}"
        )

    X_list.append(X)
    y_list.append(y)

# np.concatenate on an empty list fails with an unhelpful message; report the
# actual problem instead.
if not X_list:
    raise FileNotFoundError(
        f"No data/label .npy pairs found in {preprocessed_folder}"
    )

# Merge all data into single NumPy arrays
X_all = np.concatenate(X_list, axis=0)
y_all = np.concatenate(y_list, axis=0)

print(f"✅ Merged Data Shape: X = {X_all.shape}, y = {y_all.shape}")

# Stratified 80/20 split with a fixed seed so the partition is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X_all,
    y_all,
    test_size=0.2,
    random_state=42,
    stratify=y_all,
)

# Convert to PyTorch tensors. Inputs become float32 and gain a singleton
# dimension at position 1; labels become int64 as CrossEntropyLoss expects
# class indices.
# assumes X_all is (trials, electrodes, timepoints) — TODO confirm
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).unsqueeze(dim=1)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32).unsqueeze(dim=1)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# Wrap the tensors in datasets, then in loaders. Only the training loader
# shuffles; the test loader keeps a fixed order.
batch_size = 32
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



# Pick the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# EEGNet classifier for two classes.
model = EEGNet(
    num_electrodes=5,  # EEG channels
    chunk_size=250,  # Timepoints per trial
    num_classes=2,
    F1=64,
    D=8,
    F2=512,  # equals F1 * D
    dropout=0.4,
)
model = model.to(device)

# Loss function and optimizer (AdamW with a small weight decay).
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-4, lr=0.001)

# Halve the learning rate once the monitored loss has stopped decreasing
# for 5 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=5,
    factor=0.5,
    verbose=True,
)

# Training Loop
# Trains for up to `num_epochs` epochs, records a per-sample mean loss for the
# training and validation passes each epoch, lowers the learning rate on
# plateaus, stops early after `patience` epochs without improvement, and
# finally restores the weights of the best epoch.
num_epochs = 200
best_val_loss = float('inf')
best_state = None  # copy of the weights from the best validation epoch
early_stop_counter = 0
patience = 15
train_losses = []
val_losses = []
print("\n🚀 Starting Training...")

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_X, batch_y in train_loader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        # `loss` is the mean over this batch; weight it by the batch size so
        # the epoch figure is a true per-sample mean even when the last batch
        # is smaller.
        batch_n = batch_y.size(0)
        running_loss += loss.item() * batch_n
        _, predicted = torch.max(outputs, 1)
        total += batch_n
        correct += (predicted == batch_y).sum().item()

    # Per-sample mean, so it is on the same scale as the validation loss
    # (a sum over batches would grow with the number of batches).
    train_loss = running_loss / total
    train_acc = 100 * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

    # Validation check (using test set loss here as proxy).
    # NOTE: because the test set drives the scheduler, early stopping and
    # checkpoint selection, the test accuracy reported afterwards is
    # optimistically biased; a separate validation split would avoid that.
    # The pass is batched so the whole test set never has to fit on the
    # device in a single forward call.
    model.eval()
    val_loss_sum = 0.0
    val_total = 0
    with torch.no_grad():
        for val_X, val_y in test_loader:
            val_X, val_y = val_X.to(device), val_y.to(device)
            val_n = val_y.size(0)
            val_loss_sum += criterion(model(val_X), val_y).item() * val_n
            val_total += val_n
    val_loss = val_loss_sum / val_total

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0
        # Clone the tensors: state_dict() returns references that later
        # optimizer steps would otherwise overwrite in place.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"⏹️ Early stopping triggered at epoch {epoch+1}!")
            break

# Without this, the model left in memory is the one from the final epoch,
# i.e. up to `patience` epochs past the best validation loss.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights (validation loss {best_val_loss:.4f})")

# Evaluation on Test Set: count correct predictions batch by batch and report
# the overall accuracy as a percentage.
print("\n🔍 Evaluating Model...")
model.eval()
correct, total = 0, 0

# No gradients are needed for scoring.
with torch.no_grad():
    for batch_X, batch_y in test_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)
        outputs = model(batch_X)
        # Predicted class = index of the largest output along dim 1.
        predicted = torch.max(outputs, dim=1).indices
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

test_acc = 100 * correct / total
print(f"✅ Model Evaluation Complete! Test Accuracy: {test_acc:.2f}%")

# Plot the per-epoch loss histories recorded during training on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training vs Validation Loss")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()