In [8]:
import random
import pennylane as qml
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import copy
from sklearn.metrics import roc_auc_score
import copy
import time
from typing import Any, Optional, Tuple, Callable
import mne
from sklearn.model_selection import train_test_split
import math

print('Pennylane Version :', qml.__version__)
print('Pytorch Version :', torch.__version__)

Pennylane Version : 0.41.1
Pytorch Version : 2.7.1+cu126


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
print("Running on ", device)

Running on  cuda


# Prepare PhysioNet EEG Dataset

In [None]:
def load_eeg_ts(seed, device, batch_size, sampling_freq):
    # Set random seed for reproducibility
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        
    # Load and preprocess the PhysioNet EEG Motor Imagery data
    N_SUBJECT = 50
    IMAGINE_OPEN_CLOSE_LEFT_RIGHT_FIST = [4, 8, 12]

    # Load data from PhysioNet (example assumes data is downloaded locally)
    physionet_paths = [
        mne.datasets.eegbci.load_data(
            subjects=subj_id,
            runs=IMAGINE_OPEN_CLOSE_LEFT_RIGHT_FIST,
            path="PhysioNet_EEG",
        ) for subj_id in range(1, N_SUBJECT+1)
    ]
    physionet_paths = np.concatenate(physionet_paths)

    # Ensuring that all subjects share same sampling frequency
    # TARGET_SFREQ = 160  # 160 Hz sampling rate
    TARGET_SFREQ = sampling_freq

    # Concatenate all loaded raw data
    parts = []
    for path in physionet_paths:
        raw = mne.io.read_raw_edf(
            path,
            preload=True,
            stim_channel='auto',
            verbose='WARNING',
        )
        # Resample raw data to ensure consistent sfreq
        raw.resample(TARGET_SFREQ, npad="auto")
        parts.append(raw)
        
    # Concatenate resampled raw data
    raw = mne.concatenate_raws(parts)

    # Pick EEG channels and extract events
    eeg_channel_inds = mne.pick_types(
        raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads'
    )
    events, _ = mne.events_from_annotations(raw)

    # Epoch the data
    epoched = mne.Epochs(
        raw, events, dict(left=2, right=3), tmin=1, tmax=4.1,
        proj=False, picks=eeg_channel_inds, baseline=None, preload=True
    )

    # Convert data to NumPy arrays
    X = (epoched.get_data() * 1e3).astype(np.float32)  # Convert to millivolts
    y = (epoched.events[:, 2] - 2).astype(np.int64)  # 0: left, 1: right
    
    # Train-validation-test split
    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=seed)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=seed)
    
    def MakeTensorDataset(X, y):
        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        y_tensor = torch.tensor(y, dtype=torch.float32).to(device)
        tensordataset = TensorDataset(X_tensor, y_tensor)
        return tensordataset
    
    # Create datasets and dataloaders
    train_dataset = MakeTensorDataset(X_train, y_train)
    val_dataset = MakeTensorDataset(X_val, y_val)
    test_dataset = MakeTensorDataset(X_test, y_test)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    input_dim = X_train.shape
    
    return train_loader, val_loader, test_loader, input_dim

In [10]:
train_loader, val_loader, test_loader, input_dim = load_eeg_ts(seed=2025, device=device, batch_size=32, sampling_freq=4)

Used Annotations descriptions: [np.str_('T0'), np.str_('T1'), np.str_('T2')]
Not setting metadata
2250 matching events found
No baseline correction applied
Using data from preloaded Raw for 2250 events and 13 original time points ...
116 bad epochs dropped


In [20]:
next(iter(test_loader))

[tensor([[[-9.3418e-03, -2.0892e-02, -2.5427e-02,  ...,  1.9003e-02,
            1.0521e-02,  6.7763e-03],
          [-1.8186e-02, -2.4984e-02, -3.0667e-02,  ...,  2.5125e-02,
            1.6405e-02,  2.9495e-02],
          [-1.0434e-02, -1.6612e-02, -2.2492e-02,  ...,  2.6761e-02,
            1.4599e-02,  2.1806e-02],
          ...,
          [ 1.9377e-03,  1.8254e-03,  5.9484e-03,  ..., -3.2511e-03,
           -1.0333e-02,  4.2859e-03],
          [-3.2623e-03, -1.1324e-02, -1.3641e-02,  ...,  1.9414e-02,
            1.0725e-02,  1.4776e-02],
          [ 2.0441e-03,  1.9807e-03, -2.6451e-03,  ...,  5.6664e-03,
           -1.1818e-02,  2.9063e-03]],
 
         [[-2.8618e-04,  9.8959e-03, -9.4629e-03,  ...,  4.9505e-03,
            5.0498e-03, -8.0845e-03],
          [ 1.4872e-03,  1.1819e-02, -7.5468e-03,  ...,  6.8991e-03,
            7.0454e-03, -6.1632e-03],
          [ 6.6202e-03,  1.6407e-02, -1.0013e-02,  ...,  2.6930e-03,
            1.9717e-03,  2.0611e-03],
          ...,
    

In [11]:
input_dim

(1493, 64, 13)

In [55]:
from tcn import TemporalConvNet
import torch.nn.functional as F

class EEGTCNClassifier(nn.Module):
    def __init__(self, in_ch, tcn_ch, num_classes, k=2, dropout=0.2):
        super().__init__()
        self.tcn = TemporalConvNet(in_ch, tcn_ch, k, dropout)
        self.linear = nn.Linear(tcn_ch[-1], num_classes)
    def forward(self, x):
        # x: (B, C, L)
        y = self.tcn(x)            # → (B, C_out, L)
        logit = self.linear(y[:, :, -1])
        return logit

In [56]:
################################# Calculate Running Time ########################################
def epoch_time(start_time: float, end_time: float) -> Tuple[float, float]:
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


################################# Performance & Density Matrices ################################
# Training loop
def train_perf(model, dataloader, optimizer, criterion):
    model.train()
    train_loss = 0.0
    all_labels = []
    all_outputs = []
    for inputs, labels in tqdm(dataloader):
        assert not torch.isnan(inputs).any(), "Inputs contain NaN!"
        assert not torch.isinf(inputs).any(), "Inputs contain Inf!"
        assert not torch.isnan(labels).any(), "Labels contain NaN!"
        inputs, labels = inputs.to(device), labels.to(device)  # Ensure that data is on the same device (GPU or CPU)
        labels = labels.float()   # Ensure labels are of type float for BCEWithLogitsLoss
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = outputs.squeeze(1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
        # Collect labels and outputs for AUROC
        all_labels.append(labels.cpu().numpy())
        all_outputs.append(outputs.detach().cpu().numpy())       
    # Calculate train AUROC
    all_labels = np.concatenate(all_labels)
    all_outputs = np.concatenate(all_outputs)
    print("labels NaNs:", np.isnan(all_labels).sum(), "/", all_labels.shape[0])
    print("outputs NaNs:", np.isnan(all_outputs).sum(), "/", all_outputs.shape[0])
    train_auroc = roc_auc_score(all_labels, all_outputs)
    
    return train_loss / len(dataloader), train_auroc


# Validation/Test loop
def evaluate_perf(model, dataloader, criterion):
    model.eval()
    running_loss = 0.0
    all_labels = []
    all_outputs = []
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            inputs, labels = inputs.to(device), labels.to(device)  # Ensure that data is on the same device (GPU or CPU)
            labels = labels.float()   # Ensure labels are of type float for BCEWithLogitsLoss
            outputs = model(inputs)
            outputs = outputs.squeeze(1)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            # Collect labels and outputs for AUROC
            all_labels.append(labels.cpu().numpy())
            all_outputs.append(outputs.cpu().numpy())

    all_labels = np.concatenate(all_labels)
    all_outputs = np.concatenate(all_outputs)
    auroc = roc_auc_score(all_labels, all_outputs)
    
    return running_loss / len(dataloader), auroc

In [57]:
num_epochs = 20  # Set the number of epochs for training

model = EEGTCNClassifier(in_ch=64, tcn_ch=[64, 64, 64], num_classes=1, k=2, dropout=0.2).to(device)
criterion = nn.BCEWithLogitsLoss()  # Use BCEWithLogitsLoss for binary classification
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=1e-4, eps=1e-8)

# Training process
train_metrics, valid_metrics, test_metrics = [], [], []

for epoch in range(num_epochs):
    start_time = time.time()
    
    train_loss, train_auc = train_perf(model, train_loader, optimizer, criterion)
    train_metrics.append({'epoch': epoch + 1, 'train_loss': train_loss, 'train_auc': train_auc})    

    valid_loss, valid_auc = evaluate_perf(model, val_loader, criterion)
    valid_metrics.append({'epoch': epoch + 1, 'valid_loss': valid_loss, 'valid_auc': valid_auc})

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f"Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs}s")
    print(f"Train Loss: {train_loss:.4f}, AUC: {train_auc:.4f} | Validation Loss: {valid_loss:.4f}, AUC: {valid_auc:.4f}")

# Final evaluation on the test set
test_loss, test_auc = evaluate_perf(model, test_loader, criterion)
print(f"Test Loss: {test_loss:.4f}, AUC: {test_auc:.4f}")
test_metrics.append({'epoch': num_epochs, 'test_loss': test_loss, 'test_auc': test_auc}) 

# Combine all metrics into a pandas DataFrame
metrics = []
for epoch in range(num_epochs):
    metrics.append({
        'epoch': epoch + 1,
        'train_loss': train_metrics[epoch]['train_loss'],
        'train_auc': train_metrics[epoch]['train_auc'],
        'valid_loss': valid_metrics[epoch]['valid_loss'],
        'valid_auc': valid_metrics[epoch]['valid_auc'],
        'test_loss': test_metrics[0]['test_loss'],
        'test_auc': test_metrics[0]['test_auc'],
    })
# Convert to DataFrame
metrics_df = pd.DataFrame(metrics)
# Save to CSV
csv_filename = f"TCN_performance.csv"
metrics_df.to_csv(csv_filename, index=False)
print(f"Metrics saved to {csv_filename}")

  WeightNorm.apply(module, name, dim)
 32%|███▏      | 15/47 [00:00<00:00, 142.70it/s]

100%|██████████| 47/47 [00:00<00:00, 150.12it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 461.07it/s]


Epoch: 01 | Time: 0m 0s
Train Loss: 0.6938, AUC: 0.4821 | Validation Loss: 0.6921, AUC: 0.6155


100%|██████████| 47/47 [00:00<00:00, 155.90it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 442.29it/s]


Epoch: 02 | Time: 0m 0s
Train Loss: 0.6899, AUC: 0.5609 | Validation Loss: 0.6892, AUC: 0.6330


100%|██████████| 47/47 [00:00<00:00, 142.03it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 480.54it/s]


Epoch: 03 | Time: 0m 0s
Train Loss: 0.6664, AUC: 0.6529 | Validation Loss: 0.6381, AUC: 0.7458


100%|██████████| 47/47 [00:00<00:00, 136.40it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 365.73it/s]


Epoch: 04 | Time: 0m 0s
Train Loss: 0.6149, AUC: 0.7297 | Validation Loss: 0.6070, AUC: 0.7530


100%|██████████| 47/47 [00:00<00:00, 149.66it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 365.73it/s]


Epoch: 05 | Time: 0m 0s
Train Loss: 0.5778, AUC: 0.7761 | Validation Loss: 0.6118, AUC: 0.7523


100%|██████████| 47/47 [00:00<00:00, 129.17it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 276.00it/s]


Epoch: 06 | Time: 0m 0s
Train Loss: 0.5578, AUC: 0.7940 | Validation Loss: 0.5926, AUC: 0.7654


100%|██████████| 47/47 [00:00<00:00, 113.20it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 321.75it/s]


Epoch: 07 | Time: 0m 0s
Train Loss: 0.5507, AUC: 0.7940 | Validation Loss: 0.5927, AUC: 0.7615


100%|██████████| 47/47 [00:00<00:00, 113.71it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 325.83it/s]


Epoch: 08 | Time: 0m 0s
Train Loss: 0.5227, AUC: 0.8197 | Validation Loss: 0.5907, AUC: 0.7632


100%|██████████| 47/47 [00:00<00:00, 119.65it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 398.00it/s]


Epoch: 09 | Time: 0m 0s
Train Loss: 0.5137, AUC: 0.8262 | Validation Loss: 0.5962, AUC: 0.7657


100%|██████████| 47/47 [00:00<00:00, 140.22it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 427.89it/s]


Epoch: 10 | Time: 0m 0s
Train Loss: 0.5035, AUC: 0.8338 | Validation Loss: 0.6153, AUC: 0.7651


100%|██████████| 47/47 [00:00<00:00, 159.31it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 440.39it/s]


Epoch: 11 | Time: 0m 0s
Train Loss: 0.4889, AUC: 0.8449 | Validation Loss: 0.5756, AUC: 0.7816


100%|██████████| 47/47 [00:00<00:00, 154.75it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 387.43it/s]


Epoch: 12 | Time: 0m 0s
Train Loss: 0.4688, AUC: 0.8589 | Validation Loss: 0.6116, AUC: 0.7643


100%|██████████| 47/47 [00:00<00:00, 137.93it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 385.86it/s]


Epoch: 13 | Time: 0m 0s
Train Loss: 0.4682, AUC: 0.8604 | Validation Loss: 0.6044, AUC: 0.7727


100%|██████████| 47/47 [00:00<00:00, 140.92it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 420.20it/s]


Epoch: 14 | Time: 0m 0s
Train Loss: 0.4563, AUC: 0.8667 | Validation Loss: 0.6131, AUC: 0.7750


100%|██████████| 47/47 [00:00<00:00, 140.75it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 372.13it/s]


Epoch: 15 | Time: 0m 0s
Train Loss: 0.4777, AUC: 0.8521 | Validation Loss: 0.5999, AUC: 0.7715


100%|██████████| 47/47 [00:00<00:00, 143.14it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 414.41it/s]


Epoch: 16 | Time: 0m 0s
Train Loss: 0.4457, AUC: 0.8745 | Validation Loss: 0.6261, AUC: 0.7779


100%|██████████| 47/47 [00:00<00:00, 152.16it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 446.72it/s]


Epoch: 17 | Time: 0m 0s
Train Loss: 0.4442, AUC: 0.8743 | Validation Loss: 0.6161, AUC: 0.7678


100%|██████████| 47/47 [00:00<00:00, 146.52it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 409.89it/s]


Epoch: 18 | Time: 0m 0s
Train Loss: 0.4312, AUC: 0.8827 | Validation Loss: 0.6315, AUC: 0.7714


100%|██████████| 47/47 [00:00<00:00, 146.39it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 420.73it/s]


Epoch: 19 | Time: 0m 0s
Train Loss: 0.4316, AUC: 0.8830 | Validation Loss: 0.6272, AUC: 0.7641


100%|██████████| 47/47 [00:00<00:00, 135.81it/s]


labels NaNs: 0 / 1493
outputs NaNs: 0 / 1493


100%|██████████| 10/10 [00:00<00:00, 380.37it/s]


Epoch: 20 | Time: 0m 0s
Train Loss: 0.4135, AUC: 0.8931 | Validation Loss: 0.6320, AUC: 0.7743


100%|██████████| 11/11 [00:00<00:00, 392.61it/s]


Test Loss: 0.5062, AUC: 0.8192
Metrics saved to TCN_performance.csv
