In [16]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
import mne
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import time
import logging
import warnings
from tqdm import tqdm
from pytorch_metric_learning import losses
import pandas as pd

In [3]:
# Silence all logging output (every level up to and including CRITICAL) and
# suppress every Python warning for the rest of the session.
# NOTE(review): this also hides diagnostics emitted by the libraries used
# below — confirm that a blanket mute is really intended.
logging.disable(logging.CRITICAL)
warnings.filterwarnings("ignore")

# Load Data

### Sliding Window

In [4]:
def get_subject_id(file_path):
    """Return the numeric subject ID encoded in the parent directory of a file path.

    The directory that directly contains the file is expected to be named with
    one leading character followed by digits (e.g. ``"S003"``); the digits are
    returned as an ``int`` (``3``).

    Both ``/`` and ``\\`` are accepted as separators, so mixed paths such as
    ``"./files\\S001\\S001R01.edf"`` resolve identically on every platform.

    Args:
        file_path: Path to a recording, e.g. ``"./files/S003/S003R03.edf"``.

    Returns:
        The subject number as an integer.

    Raises:
        ValueError: If the path has no parent directory component or the
            directory name does not end in an integer.
    """
    # Normalise Windows separators so the split does not depend on os.sep.
    parts = [part for part in file_path.replace("\\", "/").split("/") if part]
    if len(parts) < 2:
        raise ValueError(f"Cannot find a subject directory in path: {file_path!r}")

    subject = parts[-2]  # e.g. "S003"
    try:
        return int(subject[1:])  # "003" -> 3
    except ValueError:
        raise ValueError(
            f"Directory name {subject!r} in {file_path!r} does not encode a subject number"
        ) from None

def load_eeg_data(file_path):
    """Read one EDF recording, keep three channels and band-pass filter it.

    Args:
        file_path: Path to the ``.edf`` file.

    Returns:
        The preloaded MNE raw object restricted to the channels
        ``'Oz..'``, ``'T7..'`` and ``'Cz..'`` and filtered between 1 Hz and
        40 Hz with a ``firwin`` FIR design. The full recording is kept
        (no cropping).
    """
    channels = ['Oz..', 'T7..', 'Cz..']
    recording = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
    recording.pick(channels)
    recording.filter(l_freq=1., h_freq=40., fir_design='firwin', verbose=False)
    return recording

# def make_epochs_with_offset(raw, duration, overlap, offset=0.0): #try loading without this function
#     if offset != 0.0:
#         raw_offset = raw.copy()
#         raw_offset.crop(tmin=offset)
#         epochs = mne.make_fixed_length_epochs(raw_offset, duration=duration, overlap=overlap, preload=True, verbose=False)
#     else:
#         epochs = mne.make_fixed_length_epochs(raw, duration=duration, overlap=overlap, preload=True, verbose=False)
#     return epochs

# Sanity check on a single recording: load one run of subject S003 with the
# same preprocessing used for the full dataset and print its annotations.
eeg_file_path = "./files/S003/S003R03.edf"
raw = load_eeg_data(eeg_file_path)
print(raw.annotations)
# sfreq = raw.info['sfreq']
# T_samples = 160
# stride_samples = 4
# offset_samples = 8

# T_sec = T_samples / sfreq
# stride_sec = stride_samples / sfreq
# overlap = T_sec - stride_sec

# epochs1 = make_epochs_with_offset(raw, duration=T_sec, overlap=overlap, offset=0.0)
# epochs2 = make_epochs_with_offset(raw, duration=T_sec, overlap=overlap, offset=offset_samples/sfreq)
# epochs_combined = mne.concatenate_epochs([epochs1, epochs2])


<Annotations | 30 segments: T0 (15), T1 (7), T2 (8)>


## Generate Dataset

In [5]:
class EEGMotorImageryDataset(Dataset):
    """Per-subject dataset of z-scored EEG epochs labelled with the subject ID.

    Args:
        epochs: Object exposing ``get_data()`` that returns an array indexed as
            ``(epoch, channel, time)``.
        subject_id: Label assigned to every epoch of this dataset.
        expected_length: Required number of samples on the time axis. Epochs
            of any other length are discarded. Defaults to 160.

    Attributes:
        data: ``float32`` array of epochs, each channel of each epoch
            standardised to zero mean and unit variance over time.
        labels: Array holding ``subject_id`` once per epoch.
        indices_by_class: ``{subject_id: [0, 1, ..., len - 1]}``.
    """

    def __init__(self, epochs, subject_id, expected_length=160):
        data = np.asarray(epochs.get_data())
        # The epochs come back as one dense array, so they all share the same
        # time-axis length: either every epoch is valid or none is.
        if data.shape[-1] != expected_length:
            data = data[:0]

        self.labels = np.full((data.shape[0],), subject_id)

        mean = np.mean(data, axis=2, keepdims=True)
        std = np.std(data, axis=2, keepdims=True)
        # A flat (constant) channel has zero variance; dividing by it would
        # inject NaN/inf into training. Leave such channels at exactly zero.
        std = np.where(std == 0, 1.0, std)

        self.data = ((data - mean) / std).astype(np.float32)
        self.indices_by_class = {subject_id: list(range(len(self.data)))}

    def __len__(self):
        """Number of epochs kept."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` with a leading singleton axis added to the sample."""
        sample = self.data[idx]
        label = self.labels[idx]
        # Conv2d-style layout: (1, channel, time).
        sample = sample[np.newaxis, :, :]
        return sample, label



In [6]:
def load_file_as_dataset(file_path):
    """Turn one EDF recording into an ``EEGMotorImageryDataset`` of 1-second epochs.

    The recording is loaded with ``load_eeg_data``, cut into fixed-length
    1.0 s epochs, and labelled with the subject ID parsed from the path.
    The sampling frequency and the resulting expected epoch length are printed.

    Args:
        file_path: Path to the ``.edf`` file.

    Returns:
        An ``EEGMotorImageryDataset`` built from the epochs whose last axis has
        the expected number of samples.
    """
    recording = load_eeg_data(file_path)

    sampling_rate = recording.info['sfreq']
    samples_per_epoch = int(round(sampling_rate * 1.0))  # one-second windows
    print(f"Sampling frequency: {sampling_rate}, expected epoch length: {samples_per_epoch}")

    windows = mne.make_fixed_length_epochs(recording, duration=1.0, preload=True, verbose=False)

    # Keep only the windows that contain the full number of samples.
    keep = [
        position
        for position, window in enumerate(windows.get_data())
        if window.shape[-1] == samples_per_epoch
    ]

    return EEGMotorImageryDataset(windows[keep], subject_id=get_subject_id(file_path))

In [7]:
def load_all_data(folder_path):
    """Load every ``.edf`` file below ``folder_path`` into one concatenated dataset.

    Files are discovered recursively (extension match is case-insensitive) and
    processed in sorted path order, so the sample indices of the combined
    dataset do not depend on the directory traversal order of the operating
    system.

    Args:
        folder_path: Root directory to search.

    Returns:
        A ``ConcatDataset`` over one dataset per file, or ``None`` when no
        ``.edf`` file is found.
    """
    edf_files = sorted(
        os.path.join(root, name)
        for root, _dirs, names in os.walk(folder_path)
        for name in names
        if name.lower().endswith('.edf')
    )

    datasets = []
    for file_path in edf_files:
        print("Processing file:", file_path)
        datasets.append(load_file_as_dataset(file_path))

    if not datasets:
        return None
    return ConcatDataset(datasets)

In [8]:
# Build the combined dataset from every EDF file under the data folder.
folder_path = "./files"
dataset = load_all_data(folder_path)

# Report how many epochs were collected, or that nothing was found.
if dataset is not None:
    print("Total samples in combined_dataset:", len(dataset))
else:
    print('No EDF files found')

Processing file: ./files\S001\S001R01.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R02.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R03.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R04.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R05.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R06.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R07.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R08.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R09.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R10.edf
Sampling frequency: 160.0, expected epoch length: 160
Processing file: ./files\S001\S001R11.ed

## Siamese Dataset

In [14]:
class SiameseEEGMotorImageryDataset(Dataset):
    """Pairs samples of a labelled base dataset for same/different-subject training.

    Each item is ``((sample1, sample2), pair_label)`` where ``sample1`` is the
    base item at the requested index, ``sample2`` is drawn at random, and
    ``pair_label`` is ``1`` when both samples share a label and ``0`` otherwise.
    Positive and negative pairs are requested with equal probability.

    Args:
        base_dataset: Indexable dataset whose items are ``(sample, label)``.

    Attributes:
        base_dataset: The wrapped dataset.
        indices_by_subject: Maps each label to the list of base indices carrying it.
        all_subjects: List of the distinct labels.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.indices_by_subject = {}
        for idx in range(len(self.base_dataset)):
            _, label = self.base_dataset[idx]
            self.indices_by_subject.setdefault(self._label_key(label), []).append(idx)
        self.all_subjects = list(self.indices_by_subject.keys())

    @staticmethod
    def _label_key(label):
        """Convert array/tensor scalar labels to plain Python values usable as dict keys."""
        return label.item() if hasattr(label, "item") else label

    def __len__(self):
        return len(self.base_dataset)

    def _draw_positive(self, idx, subject):
        """Pick a random index of the same subject that differs from ``idx``."""
        indices = self.indices_by_subject[subject]
        # Rejection sampling: O(1) expected work instead of rebuilding an
        # O(n) candidate list for every item. The caller guarantees
        # len(indices) > 1, so the loop terminates.
        while True:
            candidate = indices[np.random.randint(len(indices))]
            if candidate != idx:
                return candidate

    def _draw_negative(self, subject):
        """Pick a random index belonging to any subject other than ``subject``."""
        # Draw among the first n-1 subjects; if that hits ``subject`` itself,
        # swap in the last one. This is uniform over the other subjects.
        other = self.all_subjects[np.random.randint(len(self.all_subjects) - 1)]
        if other == subject:
            other = self.all_subjects[-1]
        indices = self.indices_by_subject[other]
        return indices[np.random.randint(len(indices))]

    def __getitem__(self, idx):
        """Return ``((sample1, sample2), pair_label)`` for base index ``idx``.

        Raises:
            ValueError: If the base dataset holds a single sample of a single
                subject, so that no pair can be formed at all.
        """
        sample1, label1 = self.base_dataset[idx]
        subject = self._label_key(label1)

        can_pair_positive = len(self.indices_by_subject[subject]) > 1
        can_pair_negative = len(self.all_subjects) > 1
        if not (can_pair_positive or can_pair_negative):
            raise ValueError("Cannot build a pair: the dataset has no second sample to pair with")

        want_positive = np.random.rand() < 0.5
        # Fall back to the other pair type when the requested one is
        # impossible (subject with one sample / dataset with one subject).
        if (want_positive and can_pair_positive) or not can_pair_negative:
            candidate = self._draw_positive(idx, subject)
            pair_label = 1
        else:
            candidate = self._draw_negative(subject)
            pair_label = 0

        sample2, _ = self.base_dataset[candidate]
        return (sample1, sample2), pair_label

# Wrap the combined dataset so each item becomes a randomly drawn sample pair
# with a same-subject flag. Construction reads every base item once to index
# the samples by label.
siamese_dataset = SiameseEEGMotorImageryDataset(dataset)

In [30]:
# Shuffled mini-batch loader over the paired dataset.
batch_size = 32
siamese_loader = DataLoader(dataset=siamese_dataset, shuffle=True, batch_size=batch_size)

## EEGNet

In [31]:
# Hyper-parameters read by EEGNetFeature when its layers are constructed.
num_input = 1  # in_channels of the first convolution
F1 = 8         # number of feature maps produced by the first convolution
D = 2          # depth multiplier: the depthwise stage outputs D * F1 maps
F2 = 16        # number of feature maps produced by the pointwise convolution

kernel_size_1 = (1, 64)      # first convolution kernel
kernel_padding_1 = (0, 32)   # padding of the first convolution
kernel_size_2 = (2, 32)      # depthwise convolution kernel (grouped by F1)
kernel_avgpool_1 = (1, 8)    # first average pooling window
dropout_rate = 0.5           # probability used by the Dropout2d layer
kernel_size_3 = (1, 16)      # depthwise half of the separable convolution
kernel_padding_3 = (0, 8)    # padding of that depthwise convolution
kernel_size_4 = (1, 1)       # pointwise half of the separable convolution
kernel_avgpool_2 = (1, 4)    # second average pooling window
signal_length = 160          # samples per epoch; matches the 160-sample epochs kept by EEGMotorImageryDataset
embedding_dim = 128          # output size of the final Linear layer (embedding width)

In [32]:
class EEGNetFeature(nn.Module):
    """EEGNet-style convolutional encoder that maps an EEG epoch to an embedding.

    The layer hyper-parameters (``num_input``, ``F1``, ``D``, ``F2``, the kernel
    and pooling sizes, ``dropout_rate``, ``embedding_dim``) are read from the
    module-level constants when the network is constructed.

    Args:
        n_channels: Height of the input (number of EEG channels). Must be at
            least the height of ``kernel_size_2``. Defaults to 3.
        n_samples: Width of the input (time samples per epoch). Defaults to
            the module-level ``signal_length``.

    With the defaults in this notebook (3 channels, 160 samples) the flattened
    feature size is 128, the same value that was previously hard-coded.
    """

    def __init__(self, n_channels=3, n_samples=None):
        super(EEGNetFeature, self).__init__()
        if n_samples is None:
            n_samples = signal_length
        # Layer 1
        self.conv2d = nn.Conv2d(num_input, F1, kernel_size=kernel_size_1, padding=kernel_padding_1)
        self.Batch_normalization_1 = nn.BatchNorm2d(F1)
        # Layer 2
        self.Depthwise_conv2D = nn.Conv2d(F1, D * F1, kernel_size=kernel_size_2, groups=F1)
        self.Batch_normalization_2 = nn.BatchNorm2d(D * F1)
        self.Elu = nn.ELU()
        self.Average_pooling2D_1 = nn.AvgPool2d(kernel_avgpool_1)
        self.Dropout = nn.Dropout2d(dropout_rate)
        # Layer 3
        self.Separable_conv2D_depth = nn.Conv2d(D * F1, D * F1, kernel_size=kernel_size_3,
                                                 padding=kernel_padding_3, groups=D * F1)
        self.Separable_conv2D_point = nn.Conv2d(D * F1, F2, kernel_size=kernel_size_4)
        self.Batch_normalization_3 = nn.BatchNorm2d(F2)
        self.Average_pooling2D_2 = nn.AvgPool2d(kernel_avgpool_2)
        # Layer 4
        self.Flatten = nn.Flatten()
        # The Linear input size depends on the input shape and on every kernel,
        # padding and pooling size above, so derive it from the layers
        # themselves instead of hard-coding it.
        self.Dense = nn.Linear(self._flattened_size(n_channels, n_samples), embedding_dim)
        # No Softmax: the raw embeddings are what the siamese branches compare.

    def _flattened_size(self, n_channels, n_samples):
        """Number of features per sample after the last pooling layer.

        Only the shape-changing layers (convolutions and poolings) are applied
        to a zero probe; batch-norm, ELU and dropout keep the shape and are
        skipped so no running statistics are touched.
        """
        with torch.no_grad():
            probe = torch.zeros(1, num_input, n_channels, n_samples)
            probe = self.conv2d(probe)
            probe = self.Depthwise_conv2D(probe)
            probe = self.Average_pooling2D_1(probe)
            probe = self.Separable_conv2D_depth(probe)
            probe = self.Separable_conv2D_point(probe)
            probe = self.Average_pooling2D_2(probe)
        return probe[0].numel()

    def forward(self, x):
        """Encode a batch shaped ``(batch, num_input, n_channels, n_samples)``.

        Returns:
            Tensor of shape ``(batch, embedding_dim)``.
        """
        # Layer 1
        y = self.conv2d(x)
        y = self.Batch_normalization_1(y)
        # Layer 2
        y = self.Depthwise_conv2D(y)
        y = self.Batch_normalization_2(y)
        y = self.Elu(y)
        y = self.Average_pooling2D_1(y)
        y = self.Dropout(y)
        # Layer 3
        y = self.Separable_conv2D_depth(y)
        y = self.Separable_conv2D_point(y)
        y = self.Batch_normalization_3(y)
        y = self.Elu(y)
        y = self.Average_pooling2D_2(y)
        y = self.Dropout(y)
        # Layer 4
        y = self.Flatten(y)
        y = self.Dense(y)
        return y  # embeddings (no activation applied)

## SiameseEEGNet

In [33]:
class SiameseEEGNet(nn.Module):
    """Twin-branch network that embeds two inputs with one shared encoder.

    A single ``EEGNetFeature`` instance processes both inputs, so the two
    branches always use identical weights.
    """

    def __init__(self):
        super().__init__()
        # One encoder, applied twice in forward().
        self.feature_extractor = EEGNetFeature()

    def forward(self, x1, x2):
        """Return the pair ``(embedding of x1, embedding of x2)``."""
        return self.feature_extractor(x1), self.feature_extractor(x2)


# Example usage:
# Create the Siamese model instance (on the default device).
# NOTE(review): the training cell re-creates `siamese_model` and moves it to
# the selected device, so this instance is replaced before training starts.
siamese_model = SiameseEEGNet()

### Loss Function

In [None]:
def contrastive_loss(y_true, cos_sim, margin = 1.0):
    """Cosine-similarity contrastive loss averaged over the batch.

    Similar pairs (``y_true == 1``) are penalised by ``(1 - cos_sim) ** 2``;
    dissimilar pairs (``y_true == 0``) by ``max(cos_sim - margin, 0) ** 2``.

    Args:
        y_true: Pair labels (1 = similar, 0 = dissimilar). Any dtype; any shape
            holding one value per entry of ``cos_sim`` (e.g. ``(B,)`` or
            ``(B, 1)``).
        cos_sim: Cosine similarities of the pairs.
        margin: Similarity above which a dissimilar pair is penalised.
            NOTE: cosine similarity never exceeds 1, so with the default of
            ``1.0`` the dissimilar term is always zero; pass a value below 1
            (the training loop uses 0.9) for it to take effect.

    Returns:
        Scalar tensor with the mean loss.
    """
    y_true = torch.as_tensor(y_true, dtype=cos_sim.dtype, device=cos_sim.device)
    # Align the label shape with cos_sim. Otherwise (B,) labels against
    # (B, 1) similarities broadcast to a (B, B) matrix that pairs every label
    # with every similarity and silently corrupts the loss.
    if y_true.numel() == cos_sim.numel():
        y_true = y_true.reshape(cos_sim.shape)

    loss_similar = y_true * torch.pow(1 - cos_sim, 2)
    loss_dissimilar = (1 - y_true) * torch.pow(torch.clamp(cos_sim - margin, min=0.0), 2)
    return torch.mean(loss_similar + loss_dissimilar)

### Training Loop and Validation

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Instantiate the model and move it to the selected device.
siamese_model = SiameseEEGNet().to(device)

# Hyperparameters
batch_size = 32
learning_rate = 1e-3
num_epochs = 10
threshold = 0.725   # cosine similarity above which a validation pair is predicted "same subject"
split_seed = 42     # fixes the train/validation partition across runs

# Split the dataset into training and validation (80%/20% split).
# NOTE(review): only the anchor indices are partitioned. The second sample of
# every pair is drawn from the whole base dataset, so validation pairs can
# contain samples that also appear in training pairs.
dataset_size = len(siamese_dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    siamese_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Define optimizer (using Adam here)
optimizer = torch.optim.Adam(siamese_model.parameters(), lr=learning_rate)

# Optionally, define a scheduler (e.g., ReduceLROnPlateau) if desired:
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

# For tracking best validation loss to save the best model
best_val_loss = float('inf')


def _unpack_pair_batch(batch):
    """Return ``(x1, x2, labels)`` on ``device`` from one collated batch.

    Accepts both the nested layout ``((x1, x2), labels)`` and the flat layout
    ``(x1, x2, labels)``. Labels come back as float with shape ``(B, 1)`` so
    they line up element-wise with the ``(B, 1)`` cosine similarities.
    """
    if len(batch) == 2:
        (x1, x2), labels = batch
    else:
        x1, x2, labels = batch
    labels = labels.to(device).float().reshape(-1, 1)
    return x1.to(device), x2.to(device), labels


# Training loop
for epoch in tqdm(range(num_epochs)):
    siamese_model.train()  # set model to training mode
    running_loss = 0.0
    epoch_start_time = time.time()

    # Training step
    for batch_idx, batch in enumerate(train_loader):
        x1, x2, labels = _unpack_pair_batch(batch)

        optimizer.zero_grad()

        # Forward pass: compute embeddings and cosine similarity
        embed1, embed2 = siamese_model(x1, x2)

        # Compute loss using cosine-based contrastive loss
        cos_sim = F.cosine_similarity(embed1, embed2, dim=1, eps=1e-6).unsqueeze(1)
        loss = contrastive_loss(labels, cos_sim, margin=0.9)

        # Backpropagation and optimization step
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per sample.
        running_loss += loss.item() * x1.size(0)

    # Compute average training loss for the epoch
    train_loss = running_loss / len(train_loader.dataset)

    # Validation step (without gradient computations)
    siamese_model.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            x1, x2, labels = _unpack_pair_batch(batch)

            # Forward pass on validation data
            embed1, embed2 = siamese_model(x1, x2)
            cos_sim = F.cosine_similarity(embed1, embed2, dim=1, eps=1e-6).unsqueeze(1)
            loss = contrastive_loss(labels, cos_sim, margin=0.9)
            val_running_loss += loss.item() * x1.size(0)

            # A pair is predicted "same subject" when its similarity exceeds the threshold.
            preds = (cos_sim > threshold).float()
            correct += (preds == labels).float().sum().item()
            total += labels.size(0)

    val_loss = val_running_loss / len(val_loader.dataset)
    val_accuracy = correct / total

    epoch_duration = time.time() - epoch_start_time
    print(f"Epoch {epoch+1}/{num_epochs} | Time: {epoch_duration:.2f}s | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f}")

    # Optionally, step the scheduler:
    # scheduler.step(val_loss)

    # Save the model if validation loss decreases
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(siamese_model.state_dict(), 'best_siamese_model.pth')
        print("  --> Best model saved.")

print("Training complete.")


Using device: cpu


  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [01:53<17:00, 113.43s/it]

Epoch 1/10 | Time: 113.43s | Train Loss: 0.0042 | Val Loss: 0.0036 | Val Acc: 0.5012
  --> Best model saved.


 20%|██        | 2/10 [03:44<14:58, 112.29s/it]

Epoch 2/10 | Time: 111.49s | Train Loss: 0.0026 | Val Loss: 0.0034 | Val Acc: 0.5014
  --> Best model saved.


 30%|███       | 3/10 [05:39<13:12, 113.22s/it]

Epoch 3/10 | Time: 114.32s | Train Loss: 0.0026 | Val Loss: 0.0033 | Val Acc: 0.5025
  --> Best model saved.


 40%|████      | 4/10 [07:32<11:20, 113.37s/it]

Epoch 4/10 | Time: 113.59s | Train Loss: 0.0026 | Val Loss: 0.0032 | Val Acc: 0.5010
  --> Best model saved.


 50%|█████     | 5/10 [09:26<09:28, 113.63s/it]

Epoch 5/10 | Time: 114.09s | Train Loss: 0.0026 | Val Loss: 0.0034 | Val Acc: 0.5034


 60%|██████    | 6/10 [11:19<07:33, 113.29s/it]

Epoch 6/10 | Time: 112.64s | Train Loss: 0.0026 | Val Loss: 0.0033 | Val Acc: 0.4970


 70%|███████   | 7/10 [13:11<05:38, 112.85s/it]

Epoch 7/10 | Time: 111.96s | Train Loss: 0.0026 | Val Loss: 0.0034 | Val Acc: 0.4992


 80%|████████  | 8/10 [15:06<03:47, 113.54s/it]

Epoch 8/10 | Time: 115.01s | Train Loss: 0.0026 | Val Loss: 0.0034 | Val Acc: 0.4990


 90%|█████████ | 9/10 [16:59<01:53, 113.48s/it]

Epoch 9/10 | Time: 113.36s | Train Loss: 0.0026 | Val Loss: 0.0033 | Val Acc: 0.5016


100%|██████████| 10/10 [18:53<00:00, 113.38s/it]

Epoch 10/10 | Time: 113.92s | Train Loss: 0.0026 | Val Loss: 0.0034 | Val Acc: 0.4978
Training complete.



