In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import mne
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import time

In [None]:
def load_eeg_data(file_path, channels=('Oz..', 'T7..', 'Cz..'), l_freq=1., h_freq=40.,
                  duration=1, stride=4 / 160):
    """Load an EDF recording and cut it into overlapping fixed-length epochs.

    The recording is read fully into memory, restricted to ``channels``,
    band-pass filtered between ``l_freq`` and ``h_freq`` Hz, and then split
    into windows of ``duration`` seconds whose start times are ``stride``
    seconds apart (i.e. consecutive windows overlap by ``duration - stride``).

    Parameters
    ----------
    file_path : str
        Path to the ``.edf`` file.
    channels : sequence of str
        Channel names to keep, spelled exactly as they appear in the file.
    l_freq, h_freq : float
        Lower / upper pass-band edges in Hz passed to ``raw.filter``.
    duration : float
        Length of each epoch in seconds.
    stride : float
        Time in seconds between the starts of consecutive epochs.
        The default ``4 / 160`` presumably means 4 samples at a 160 Hz
        sampling rate -- TODO confirm against the recording.

    Returns
    -------
    mne.Epochs
        The preloaded fixed-length epochs.

    Raises
    ------
    ValueError
        If ``stride`` is not in ``(0, duration]``.

    Notes
    -----
    NOTE(review): ``make_fixed_length_epochs`` is called without an ``id``
    argument, so every epoch presumably carries the same event id. Labels
    derived from ``epochs.events`` would then contain a single class --
    confirm against the MNE documentation.
    """
    if not 0 < stride <= duration:
        raise ValueError(
            f"stride must be in (0, duration]; got stride={stride}, duration={duration}"
        )

    raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
    raw.pick(list(channels))
    raw.filter(l_freq, h_freq, fir_design='firwin', verbose=False)

    # MNE expresses the sliding window as an overlap between consecutive
    # epochs rather than as a stride.
    overlap = duration - stride

    return mne.make_fixed_length_epochs(
        raw,
        duration=duration,
        overlap=overlap,
        preload=True,
        verbose=False,
    )

# Relative path to the single EDF recording used throughout this notebook.
eeg_file_path = 'data/files/eegmmidb/1.0.0/S003/S003R01.edf'
# Read, filter and window the recording into fixed-length epochs.
epochs = load_eeg_data(eeg_file_path)
# Bare expression: the notebook displays the object's summary as the cell output.
epochs

In [None]:
# Visual sanity check of the epoched signal: 10 epochs and 3 channels per view,
# with the display scaling left to MNE ('auto').
epochs.plot(n_epochs = 10, n_channels = 3, scalings = 'auto')

In [None]:
class EEGMotorImageryDataset(Dataset):
    """Per-epoch EEG samples, z-scored per channel, with their event labels.

    Parameters
    ----------
    epochs : object
        Anything exposing ``get_data()`` (a 3-D array indexed as
        ``[epoch, channel, time]``) and ``events`` (a 2-D array whose last
        column is used as the label), e.g. an ``mne.Epochs`` instance.

    Attributes
    ----------
    data : np.ndarray of float32
        Normalized samples, same shape as ``epochs.get_data()``.
    labels : np.ndarray
        Last column of ``epochs.events``.
    indices_by_class : dict
        Maps each label to the list of sample indices carrying that label.
    """

    def __init__(self, epochs):
        raw_data = epochs.get_data()
        self.labels = epochs.events[:, -1]

        # Z-score every channel of every epoch along the time axis.
        mean = np.mean(raw_data, axis=2, keepdims=True)
        std = np.std(raw_data, axis=2, keepdims=True)
        # A perfectly flat channel has std == 0; dividing by it would fill the
        # sample with NaN/inf and poison training. Dividing by 1 instead
        # leaves such a channel as all zeros after mean removal.
        std = np.where(std == 0, 1.0, std)
        self.data = ((raw_data - mean) / std).astype(np.float32)

        # Index samples by label so pair-sampling code can look them up quickly.
        self.indices_by_class = {}
        for i, label in enumerate(self.labels):
            self.indices_by_class.setdefault(label, []).append(i)

    def __len__(self):
        """Return the number of epochs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` where ``sample`` has a leading singleton axis.

        The extra axis turns the ``(channel, time)`` array into
        ``(1, channel, time)``.
        """
        sample = self.data[idx][np.newaxis, :, :]
        label = self.labels[idx]
        return sample, label

# Wrap the epochs in the normalized single-sample dataset; later cells build on this object.
base_dataset = EEGMotorImageryDataset(epochs)

In [None]:
# Draw the first three dataset samples, one figure each, with every channel overlaid.
for sample_idx in range(3):
    trace, trace_label = base_dataset[sample_idx]
    # Drop the leading singleton axis: (1, n_channels, n_times) -> (n_channels, n_times)
    trace = trace.squeeze(0)
    plt.figure(figsize=(10, 4))
    for ch, channel_trace in enumerate(trace):
        plt.plot(channel_trace, label=f"Channel {ch}")
    plt.title(f"Sample {sample_idx} - Label: {trace_label}")
    plt.xlabel("Time points")
    plt.ylabel("Normalized Amplitude")
    plt.legend()
    plt.show()

In [None]:
class SiameseEEGMotorImageryDataset(Dataset):
    """Pairs of samples from a labelled base dataset, for Siamese training.

    Each item is ``(sample1, sample2, similarity)``: ``sample1`` is the base
    sample at the requested index, ``sample2`` is a randomly drawn partner,
    and ``similarity`` is a float32 array of shape ``(1,)`` holding 1.0 when
    both samples share a label and 0.0 otherwise.

    The partner is drawn with ``np.random`` on every access, so the same index
    yields different pairs on different calls unless the NumPy global seed is
    reset.

    Parameters
    ----------
    base_dataset : Dataset
        Must expose ``data``, ``labels`` and ``indices_by_class`` attributes
        and return ``(sample, label)`` from ``__getitem__``.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.data = base_dataset.data
        self.labels = base_dataset.labels
        self.indices_by_class = base_dataset.indices_by_class

    def __len__(self):
        """Return the number of anchor samples (same as the base dataset)."""
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(sample1, sample2, similarity)`` for the anchor at ``idx``."""
        sample1, label1 = self.base_dataset[idx]

        same_class = self.indices_by_class[label1]
        other_labels = [c for c in self.indices_by_class if c != label1]
        has_partner = len(same_class) > 1   # another sample with the same label exists
        has_other = len(other_labels) > 0   # at least one different label exists

        # Coin flip between a similar and a dissimilar pair ...
        want_similar = bool(np.random.randint(0, 2))
        # ... overridden when the requested kind of pair cannot be formed, so
        # that sampling neither loops forever nor draws from an empty list.
        if want_similar and not has_partner:
            want_similar = not has_other
        elif not want_similar and not has_other:
            want_similar = True

        if want_similar:
            idx2 = idx
            # Rejection-sample a same-class partner other than the anchor.
            # When the anchor is the only sample overall, it is paired with itself.
            while has_partner and idx2 == idx:
                idx2 = np.random.choice(same_class)
        else:
            chosen_label = np.random.choice(other_labels)
            idx2 = np.random.choice(self.indices_by_class[chosen_label])

        sample2, label2 = self.base_dataset[idx2]
        similarity = 1 if label1 == label2 else 0
        return sample1, sample2, np.array([similarity], dtype=np.float32)

# Pair-producing view over base_dataset; this is what the data loaders consume.
siamese_dataset = SiameseEEGMotorImageryDataset(base_dataset)

In [None]:
# Shuffled mini-batch loader over the whole pair dataset (no train/validation split here).
batch_size = 8
loader = DataLoader(siamese_dataset, batch_size = batch_size, shuffle = True)

In [None]:
# Hyperparameters read by EEGNetFeature when it builds its layers.
num_input = 1  # input channels of the first convolution
F1 = 8  # output channels of the first convolution
D = 2  # depth multiplier: the depthwise convolution outputs D * F1 channels
F2 = 16  # output channels of the pointwise convolution

kernel_size_1 = (1, 64)  # first convolution kernel
kernel_padding_1 = (0, 32)  # first convolution padding
kernel_size_2 = (2, 32)  # depthwise convolution kernel
kernel_avgpool_1 = (1, 8)  # first average-pooling window
dropout_rate = 0.5  # probability used by the shared Dropout2d layer
kernel_size_3 = (1, 16)  # depthwise half of the separable convolution
kernel_padding_3 = (0, 8)  # padding of that depthwise half
kernel_size_4 = (1, 1)   # pointwise half of the separable convolution
kernel_avgpool_2 = (1, 4)  # second average-pooling window
# Used only to size the final Linear layer.
# NOTE(review): confirm this matches the real number of time points per epoch.
signal_length = 256      
num_class = 4  # output size of the final Linear layer, i.e. the embedding dimension

In [None]:
class EEGNetFeature(nn.Module):
    """EEGNet-style convolutional encoder that maps one EEG sample to an embedding.

    Layer hyperparameters (``num_input``, ``F1``, ``D``, ``F2``, the kernel /
    padding / pooling sizes, ``dropout_rate`` and ``num_class``) are read from
    the module-level configuration.

    Parameters
    ----------
    n_channels : int
        Height of the expected input, i.e. number of EEG channels per sample.
    n_times : int
        Width of the expected input, i.e. number of time points per sample.

    The input size of the final Linear layer is measured by pushing a dummy
    tensor of shape ``(1, num_input, n_channels, n_times)`` through the
    convolution and pooling layers, so it accounts for both spatial dimensions
    and for padding. The defaults are meant to match the 3-channel, 1-second
    epochs built earlier in this notebook, assuming a 160 Hz recording --
    TODO confirm. With the current configuration they give 128 input features.

    Input to ``forward``: ``(batch, num_input, n_channels, n_times)``.
    Output: ``(batch, num_class)`` raw embeddings (no softmax).
    """

    def __init__(self, n_channels=3, n_times=160):
        super().__init__()
        # Layer 1: convolution along the time axis
        self.conv2d = nn.Conv2d(num_input, F1, kernel_size=kernel_size_1, padding=kernel_padding_1)
        self.Batch_normalization_1 = nn.BatchNorm2d(F1)
        # Layer 2: depthwise convolution (groups=F1), D output maps per input map
        self.Depthwise_conv2D = nn.Conv2d(F1, D * F1, kernel_size=kernel_size_2, groups=F1)
        self.Batch_normalization_2 = nn.BatchNorm2d(D * F1)
        self.Elu = nn.ELU()
        self.Average_pooling2D_1 = nn.AvgPool2d(kernel_avgpool_1)
        self.Dropout = nn.Dropout2d(dropout_rate)
        # Layer 3: separable convolution = depthwise followed by pointwise
        self.Separable_conv2D_depth = nn.Conv2d(D * F1, D * F1, kernel_size=kernel_size_3,
                                                 padding=kernel_padding_3, groups=D * F1)
        self.Separable_conv2D_point = nn.Conv2d(D * F1, F2, kernel_size=kernel_size_4)
        self.Batch_normalization_3 = nn.BatchNorm2d(F2)
        self.Average_pooling2D_2 = nn.AvgPool2d(kernel_avgpool_2)
        # Layer 4: flatten and project to the embedding
        self.Flatten = nn.Flatten()
        self.Dense = nn.Linear(self._flattened_size(n_channels, n_times), num_class)
        # No softmax: the siamese branch consumes the raw embeddings.

    def _flattened_size(self, n_channels, n_times):
        """Return the number of features per sample entering ``self.Dense``.

        Only the shape-changing layers are applied. Batch norm is skipped on
        purpose so that its running statistics are not touched by the probe.
        """
        with torch.no_grad():
            probe = torch.zeros(1, num_input, n_channels, n_times)
            probe = self.conv2d(probe)
            probe = self.Depthwise_conv2D(probe)
            probe = self.Average_pooling2D_1(probe)
            probe = self.Separable_conv2D_depth(probe)
            probe = self.Separable_conv2D_point(probe)
            probe = self.Average_pooling2D_2(probe)
        return probe[0].numel()

    def forward(self, x):
        """Encode a batch ``x`` of shape ``(batch, num_input, n_channels, n_times)``."""
        # Layer 1
        y = self.conv2d(x)
        y = self.Batch_normalization_1(y)
        # Layer 2
        y = self.Depthwise_conv2D(y)
        y = self.Batch_normalization_2(y)
        y = self.Elu(y)
        y = self.Average_pooling2D_1(y)
        y = self.Dropout(y)
        # Layer 3
        y = self.Separable_conv2D_depth(y)
        y = self.Separable_conv2D_point(y)
        y = self.Batch_normalization_3(y)
        y = self.Elu(y)
        y = self.Average_pooling2D_2(y)
        y = self.Dropout(y)
        # Layer 4
        y = self.Flatten(y)
        y = self.Dense(y)
        return y  # embeddings (or logits)

In [None]:
class SiameseEEGNet(nn.Module):
    """Two-branch network that scores a pair of inputs by cosine similarity.

    Both inputs go through the same ``EEGNetFeature`` instance, so the two
    branches share all weights.
    """

    def __init__(self):
        super().__init__()
        # One encoder, applied to each input in turn.
        self.feature_extractor = EEGNetFeature()

    def forward(self, x1, x2):
        """Return ``(embed1, embed2, cos_sim)``; ``cos_sim`` has shape ``(batch, 1)``."""
        first_embedding, second_embedding = (
            self.feature_extractor(branch_input) for branch_input in (x1, x2)
        )
        similarity = F.cosine_similarity(first_embedding, second_embedding, dim=1, eps=1e-6)
        # Add a trailing axis: (batch,) -> (batch, 1).
        return first_embedding, second_embedding, similarity.unsqueeze(1)


# Example usage:
# Create the Siamese model instance; both branches share its single EEGNetFeature encoder.
siamese_model = SiameseEEGNet()

In [None]:
# Shape check: push two random batches through the Siamese network.
# Inputs have shape [batch_size, num_input, n_channels, n_times].
batch_size = 8
# Take the per-sample shape from a real dataset item so the random input always
# matches what the network is fed during training. A hand-written shape with a
# single row fails, because the depthwise kernel spans two rows.
example_sample, _ = base_dataset[0]
input_shape = tuple(example_sample.shape)

x1 = torch.randn(batch_size, *input_shape)
x2 = torch.randn(batch_size, *input_shape)

# Forward pass only; no gradients are needed for a shape check.
with torch.no_grad():
    embed1, embed2, cos_sim = siamese_model(x1, x2)
print("Embedding 1 shape:", embed1.shape)
print("Embedding 2 shape:", embed2.shape)
# The third output is a cosine similarity (higher = more alike), not a distance.
print("Cosine similarity shape:", cos_sim.shape)

In [None]:
def cosine_contrastive_loss(y_true, cos_sim, margin = 0.275):
    """Contrastive loss on cosine similarity, averaged over all elements.

    Parameters
    ----------
    y_true : torch.Tensor
        1 for pairs that should be similar, 0 for pairs that should not.
        Must broadcast against ``cos_sim``.
    cos_sim : torch.Tensor
        Cosine similarity of each pair.
    margin : float
        Dissimilar pairs are only penalized while their similarity exceeds
        this value.

    Returns
    -------
    torch.Tensor
        Scalar: mean of ``y * (1 - s)^2 + (1 - y) * max(s - margin, 0)^2``.
    """
    # Similar pairs: pull the similarity towards 1.
    loss_similar = y_true * torch.pow(1 - cos_sim, 2)
    # Dissimilar pairs: push the similarity below the margin, then stop caring.
    loss_dissimilar = (1 - y_true) * torch.pow(torch.clamp(cos_sim - margin, min = 0.0), 2)
    return torch.mean(loss_similar + loss_dissimilar)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Instantiate a fresh model and move it to the device.
siamese_model = SiameseEEGNet().to(device)

# Hyperparameters
batch_size = 8
learning_rate = 1e-3
num_epochs = 20
margin = 0.275  # same value as the default margin of cosine_contrastive_loss

# Split the dataset into training and validation (80%/20% split)
dataset_size = len(siamese_dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size
train_dataset, val_dataset = random_split(siamese_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Define optimizer (using Adam here)
optimizer = optim.Adam(siamese_model.parameters(), lr=learning_rate)

# Optionally, define a scheduler (e.g., ReduceLROnPlateau) if desired:
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

# For tracking best validation loss to save the best model
best_val_loss = float('inf')

# Training loop
for epoch in range(num_epochs):
    siamese_model.train()  # enable dropout and batch-norm statistics updates
    running_loss = 0.0
    epoch_start_time = time.time()

    # Training step
    for batch_idx, (x1, x2, labels) in enumerate(train_loader):
        # Move data to device
        x1 = x1.to(device)
        x2 = x2.to(device)
        labels = labels.to(device)  # Expected shape: (batch_size, 1)

        optimizer.zero_grad()

        # Forward pass: compute embeddings and cosine similarity
        _, _, cos_sim = siamese_model(x1, x2)

        # Compute loss using cosine-based contrastive loss
        loss = cosine_contrastive_loss(labels, cos_sim, margin=margin)

        # Backpropagation and optimization step
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight it by batch size so the epoch average
        # is per sample even when the last batch is smaller.
        running_loss += loss.item() * x1.size(0)

    # Compute average training loss for the epoch
    train_loss = running_loss / len(train_loader.dataset)

    # Validation step (without gradient computations)
    siamese_model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for x1, x2, labels in val_loader:
            x1 = x1.to(device)
            x2 = x2.to(device)
            labels = labels.to(device)

            # Forward pass on validation data
            _, _, cos_sim = siamese_model(x1, x2)
            loss = cosine_contrastive_loss(labels, cos_sim, margin=margin)
            val_running_loss += loss.item() * x1.size(0)

    val_loss = val_running_loss / len(val_loader.dataset)

    epoch_duration = time.time() - epoch_start_time
    print(f"Epoch {epoch+1}/{num_epochs} | Time: {epoch_duration:.2f}s | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Optionally, step the scheduler:
    # scheduler.step(val_loss)

    # Save the model if validation loss decreases
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(siamese_model.state_dict(), 'best_siamese_model.pth')
        print("  --> Best model saved.")

print("Training complete.")
