In [1]:
from torcheeg.datasets import SEEDDataset
from torcheeg import transforms

# Build the SEED dataset with both transforms disabled so that samples come
# back as (raw signal, metadata). io_path points at the torcheeg cache folder.
seed_dataset_kwargs = dict(
    root_path='./SEED/SEED_EEG/Preprocessed_EEG',
    io_path='E:/FYP/Egg-Based Emotion Recognition/EEg-based-Emotion-Recognition/.torcheeg/datasets_1733174610032_5iJyS',
    online_transform=None,
    label_transform=None,
    num_worker=4,
)
raw_dataset = SEEDDataset(**seed_dataset_kwargs)

# Sanity check on the first sample: the leading axis is expected to hold the
# 62 SEED channels.
raw_sample = raw_dataset[0]
print(f"Raw EEG data shape: {raw_sample[0].shape}")

[2024-12-08 00:00:08] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from E:/FYP/Egg-Based Emotion Recognition/EEg-based-Emotion-Recognition/.torcheeg/datasets_1733174610032_5iJyS.


Raw EEG data shape: (62, 200)


In [2]:
import numpy as np
from scipy.signal import butter, lfilter

# Define a bandpass filter (4-47 Hz for SEED dataset)
def bandpass_filter(data, lowcut=4, highcut=47, fs=200, order=4):
    """Band-pass filter ``data`` along its last axis with a Butterworth filter.

    Args:
        data: Array-like signal; filtering runs along the last axis.
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling rate of ``data``.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal as returned by ``scipy.signal.lfilter``.
    """
    half_rate = 0.5 * fs
    # butter() takes the band edges as fractions of the Nyquist frequency.
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = butter(order, band_edges, btype='band')
    # Single forward pass (causal filtering).
    return lfilter(numerator, denominator, data, axis=-1)

In [3]:
def process_dataset(dataset):
    """Band-pass filter every sample of ``dataset``.

    Args:
        dataset: Iterable of ``(eeg_data, metadata)`` pairs.

    Returns:
        A list of ``(filtered_data, metadata)`` tuples in iteration order; the
        metadata object is passed through untouched for later reference.
    """
    return [
        (bandpass_filter(signal), info)
        for signal, info in dataset
    ]
# Run the band-pass stage over every raw sample.
processed_data = process_dataset(dataset=raw_dataset)

In [4]:
import numpy as np

def incremental_normalization(process_dataset):
    """Z-score every segment with per-subject, per-channel statistics.

    Statistics are accumulated as running sums over all segments that share a
    ``metadata['subject_id']``, so the whole recording of a subject never has
    to be concatenated in memory.

    Args:
        process_dataset: Iterable of ``(eeg_data, metadata)`` pairs, where
            ``eeg_data`` is an array whose last axis is time and ``metadata``
            is a mapping containing ``'subject_id'``. (The parameter name is
            kept for backward compatibility even though it shadows the
            ``process_dataset`` function inside this body.)

    Returns:
        A list of ``(normalized_data, metadata)`` tuples in input order.
        Channels whose standard deviation is zero are only mean-centred
        (divided by 1) instead of producing NaN/inf.
    """
    # Two passes are needed; materialize once so that a one-shot iterator
    # (generator) is not exhausted after the first pass.
    segments = list(process_dataset)

    subject_data_stats = {}

    # First pass: accumulate sum, sum of squares and sample count per subject.
    for eeg_data, metadata in segments:
        subject_id = metadata['subject_id']
        if subject_id not in subject_data_stats:
            subject_data_stats[subject_id] = {'sum': 0, 'sum_sq': 0, 'count': 0}

        stats = subject_data_stats[subject_id]
        stats['sum'] = stats['sum'] + np.sum(eeg_data, axis=-1, keepdims=True)
        stats['sum_sq'] = stats['sum_sq'] + np.sum(eeg_data ** 2, axis=-1, keepdims=True)
        # Count along the same axis that the sums reduce over.
        stats['count'] += eeg_data.shape[-1]

    # Turn the running sums into mean / standard deviation.
    for stats in subject_data_stats.values():
        stats['mean'] = stats['sum'] / stats['count']
        variance = stats['sum_sq'] / stats['count'] - stats['mean'] ** 2
        # E[x^2] - E[x]^2 can dip marginally below zero through round-off,
        # which would make sqrt() return NaN.
        std = np.sqrt(np.maximum(variance, 0.0))
        # A flat channel has std == 0; divide by 1 so it is centred, not NaN.
        stats['std'] = np.where(std > 0, std, 1.0)

    # Second pass: normalize each segment with its subject's statistics.
    normalized_segments = []
    for eeg_data, metadata in segments:
        stats = subject_data_stats[metadata['subject_id']]
        normalized_data = (eeg_data - stats['mean']) / stats['std']
        normalized_segments.append((normalized_data, metadata))

    return normalized_segments

# Per-subject z-scoring of the filtered segments.
normalized_data = incremental_normalization(process_dataset=processed_data)

        

In [5]:
# Function to sample data into fixed length segments
def sample_data(normalized_data, time_length=30, step_size=15):
    """Cut every trial into fixed-length, possibly overlapping windows.

    Args:
        normalized_data: Iterable of ``(eeg_data, metadata)`` pairs, with
            ``eeg_data`` indexed as ``[channel, time]``.
        time_length: Number of time points per window.
        step_size: Offset between the starts of consecutive windows.

    Returns:
        A list of ``(segment, metadata)`` tuples. Each metadata is a shallow
        copy of the trial's metadata extended with ``'segment_start'`` and
        ``'segment_end'``. Trials shorter than ``time_length`` yield nothing.
    """

    def _windows(signal, info):
        # Only windows that fit completely inside the trial are produced.
        last_start = signal.shape[1] - time_length
        for begin in range(0, last_start + 1, step_size):
            stop = begin + time_length
            window_info = info.copy()
            window_info['segment_start'] = begin
            window_info['segment_end'] = stop
            yield signal[:, begin:stop], window_info

    sampled_segments = []
    for signal, info in normalized_data:
        sampled_segments.extend(_windows(signal, info))
    return sampled_segments

# Window the normalized data: 30-point windows advancing by 15 points.
time_length, step_size = 30, 15
sampled_data = sample_data(normalized_data, time_length, step_size)
  

In [6]:
# Function to compute DE features for each segment
def compute_de_features(eeg_segments):
    """Differential entropy (in bits) of a segment under a Gaussian model.

    Args:
        eeg_segments: Array whose last axis is time, e.g. ``[channels, timepoints]``.

    Returns:
        ``0.5 * log2(2 * pi * e * variance)`` with the variance taken over the
        last axis, i.e. one value per leading index.
    """
    per_channel_variance = np.var(eeg_segments, axis=-1)
    gaussian_argument = 2 * np.pi * np.e * per_channel_variance
    return 0.5 * np.log2(gaussian_argument)

# Pair every window's DE feature vector with that window's metadata.
de_features_data = [
    (compute_de_features(segment), metadata)
    for segment, metadata in sampled_data
]
    

In [7]:
import torch.nn as nn

class BaseEncoder(nn.Module):
    """Spatial (1x1) convolution followed by a temporal convolution.

    Args:
        input_channels: Number of input channels fed to the spatial convolution.
        temporal_filter_length: Kernel size of the temporal convolution; it is
            padded by ``temporal_filter_length // 2`` on both sides.
        spatial_filters: Output channels of the spatial convolution.
        temporal_filters: Output channels of the temporal convolution.
    """

    def __init__(self, input_channels=62, temporal_filter_length=48, spatial_filters=16, temporal_filters=16):
        super().__init__()
        # Point-wise convolution: mixes channels, leaves the time axis alone.
        self.spatial_conv = nn.Conv1d(
            in_channels=input_channels,
            out_channels=spatial_filters,
            kernel_size=1,
        )
        # Convolution along time on the spatially filtered signal.
        self.temporal_conv = nn.Conv1d(
            in_channels=spatial_filters,
            out_channels=temporal_filters,
            kernel_size=temporal_filter_length,
            padding=temporal_filter_length // 2,
        )

    def forward(self, x):
        """Encode ``x`` of shape ``[batch_size, input_channels, time_points]``."""
        print("Entered BaseEncoder")
        spatial_out = nn.functional.relu(self.spatial_conv(x))
        temporal_out = nn.functional.relu(self.temporal_conv(spatial_out))
        print("Exit BaseEncoder")
        return temporal_out


In [8]:
# Define the Projector
class Projector(nn.Module):
    """Projection head: average pooling, then spatial and temporal convolutions.

    Args:
        spatial_filters: Number of channels of the incoming feature map.
        pooling_kernel: Kernel size and stride of the average pooling.
        temporal_filter_size: Kernel size of the temporal convolution; it is
            padded by ``temporal_filter_size // 2`` on both sides.
        c: Channel multiplier applied by each of the two convolutions.
    """

    def __init__(self, spatial_filters=16, pooling_kernel=24, temporal_filter_size=4 , c=2):
        super().__init__()
        widened_channels = spatial_filters * c

        # Non-overlapping average pooling along time.
        self.avg_pool = nn.AvgPool1d(kernel_size=pooling_kernel, stride=pooling_kernel)
        # Point-wise convolution widening the channels by a factor of c.
        self.spatial_conv = nn.Conv1d(
            in_channels=spatial_filters,
            out_channels=widened_channels,
            kernel_size=1,
        )
        # Temporal convolution widening the channels by another factor of c.
        self.temporal_conv = nn.Conv1d(
            in_channels=widened_channels,
            out_channels=widened_channels * c,
            kernel_size=temporal_filter_size,
            padding=temporal_filter_size // 2,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Project the encoder output ``x`` (``[batch, spatial_filters, time]``)."""
        print("Entered Projector")
        pooled = self.avg_pool(x)
        mixed = self.relu(self.spatial_conv(pooled))
        projected = self.relu(self.temporal_conv(mixed))
        print("Exit Projector")
        return projected
    


# Instantiate the projection head with its default hyper-parameters.
projector = Projector(
    spatial_filters=16,
    pooling_kernel=24,
    temporal_filter_size=4,
    c=2,
)

# Disabled smoke test; it needs an `encoded_output` tensor to exist first.
# projected_output = projector(encoded_output)

# print(f"Encoded output shape: {encoded_output.shape}")
# print(f"Projected output shape: {projected_output.shape}")

In [52]:
import torch
import torch.nn as nn

class ContrastiveLoss(nn.Module):
    """Cross-entropy contrastive loss over a batch of paired embeddings.

    Row ``k`` of ``z_i`` is treated as matching row ``k`` of ``z_j``; every
    other row of ``z_j`` acts as a negative for it.

    Args:
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.2):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Return the scalar loss for embeddings ``z_i`` / ``z_j`` (``[batch, features]``)."""
        print("Entered ContrastiveLoss")

        # Unit-length rows, so the matrix product below is a cosine similarity.
        anchors = nn.functional.normalize(z_i, dim=1)
        candidates = nn.functional.normalize(z_j, dim=1)

        # Temperature-scaled similarities, limited to [-10, 10].
        scaled = torch.matmul(anchors, candidates.T) / self.temperature
        similarity_matrix = torch.clamp(scaled, min=-10, max=10)

        print(f"Similarity Matrix: {similarity_matrix.shape}")
        print(f"Similarity Matrix (First 5 Rows & Columns):\n{similarity_matrix[:5, :5]}")

        # The positive for row k sits on the diagonal, i.e. at column k.
        labels = torch.arange(anchors.size(0), device=anchors.device)
        print(f"Labels: {labels}")

        loss = nn.functional.cross_entropy(similarity_matrix, labels)
        print(f"Contrastive Loss: {loss.item()}")

        return loss


In [53]:
# Check unique subject IDs and their sample counts
from collections import Counter

# Tally how many windows each subject contributes.
subject_ids = [metadata['subject_id'] for _, metadata in sampled_data]
subject_distribution = Counter(subject_ids)

print(f"Unique Subject IDs: {list(subject_distribution.keys())}")
print(f"Subject Sample Counts: {subject_distribution}")

# Map the sorted original ids onto consecutive integers starting at 0.
unique_subject_ids = sorted(subject_distribution)
subject_id_map = dict(zip(unique_subject_ids, range(len(unique_subject_ids))))

# Rewrite the ids in place; this mutates the metadata dicts held by sampled_data.
for _, metadata in sampled_data:
    metadata['subject_id'] = subject_id_map[metadata['subject_id']]

# Recount after remapping to confirm the result.
updated_subject_ids = [metadata['subject_id'] for _, metadata in sampled_data]
updated_subject_distribution = Counter(updated_subject_ids)

print(f"Updated Subject IDs: {list(updated_subject_distribution.keys())}")
print(f"Updated Subject Sample Counts: {updated_subject_distribution}")


Unique Subject IDs: [9, 10, 11, 12, 13, 14, 0, 1, 2, 3, 4, 5, 6, 7, 8]
Subject Sample Counts: Counter({9: 122184, 10: 122184, 11: 122184, 12: 122184, 13: 122184, 14: 122184, 0: 122184, 1: 122184, 2: 122184, 3: 122184, 4: 122184, 5: 122184, 6: 122184, 7: 122184, 8: 122184})
Updated Subject IDs: [9, 10, 11, 12, 13, 14, 0, 1, 2, 3, 4, 5, 6, 7, 8]
Updated Subject Sample Counts: Counter({9: 122184, 10: 122184, 11: 122184, 12: 122184, 13: 122184, 14: 122184, 0: 122184, 1: 122184, 2: 122184, 3: 122184, 4: 122184, 5: 122184, 6: 122184, 7: 122184, 8: 122184})


In [11]:
def loso_split(dataset, num_subjects):
    """Build Leave-One-Subject-Out (LOSO) cross-validation folds.

    Subject ids are expected to be the integers ``0 .. num_subjects - 1``;
    fold ``k`` holds out the items whose ``metadata['subject_id'] == k``.

    Args:
        dataset: list of tuples ``(EEG_data, metadata)``.
        num_subjects: total number of unique subjects in the dataset.

    Returns:
        List of folds ``[(train_data, test_data), ...]``, one per subject, each
        preserving the order of ``dataset``.
    """
    splits = []
    for held_out in range(num_subjects):
        train_data, test_data = [], []
        for item in dataset:
            # Route the item to the test fold only when it belongs to the
            # held-out subject.
            target = test_data if item[1]['subject_id'] == held_out else train_data
            target.append(item)
        splits.append((train_data, test_data))
    return splits

# Count the distinct subject ids present in the windowed data.
num_subjects = len({metadata['subject_id'] for _, metadata in sampled_data})

# One (train, test) fold per held-out subject.
loso_splits = loso_split(sampled_data, num_subjects)

# Work with the first fold as an example.
train_data, test_data = loso_splits[0]
print(f"Training samples: {len(train_data)}, Testing samples: {len(test_data)}")


Training samples: 1710576, Testing samples: 122184


In [12]:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
import torch

class EEGDataset(Dataset):
    """Map-style dataset over a sequence of ``(eeg_data, metadata)`` pairs.

    Args:
        data: Indexable sequence whose items are ``(eeg_data, metadata)``;
            ``metadata`` must be a ``dict``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of samples in the wrapped sequence."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(float32 tensor of eeg_data, metadata dict)`` for ``idx``.

        Raises:
            ValueError: If the metadata stored at ``idx`` is not a dict.
        """
        eeg_data, metadata = self.data[idx]
        if not isinstance(metadata, dict):
            raise ValueError(f"Metadata must be a dictionary, but got {type(metadata)} at index {idx}.")
        # The metadata dict is returned whole, so consumers read whichever
        # keys they need (e.g. 'subject_id', 'emotion') from it directly.
        eeg_data_tensor = torch.tensor(eeg_data, dtype=torch.float32)
        return eeg_data_tensor, metadata

In [13]:
def stratified_normalization_minibatch(x, metadata):
    """Normalize a minibatch per (subject, emotion) group.

    Samples sharing the same ``(subject_id, emotion)`` pair form one group.
    For each group, a per-channel mean and standard deviation are computed
    over ALL samples and time points of that group, and every sample of the
    group is z-scored with those shared statistics.

    Args:
        x: Tensor of shape ``[batch_size, channels, time_points]``.
        metadata: Mapping with batch-aligned ``'subject_id'`` and ``'emotion'``
            entries (tensors or plain sequences of hashable values).

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """

    def _as_key(value):
        # Batched metadata may hold 0-d tensors or plain Python values.
        return value.item() if hasattr(value, 'item') else value

    subject_ids = metadata['subject_id']
    emotion_labels = metadata['emotion']

    # Collect the batch positions that belong to each (subject, emotion) group.
    group_to_indices = {}
    for idx, (subject, emotion) in enumerate(zip(subject_ids, emotion_labels)):
        group = (_as_key(subject), _as_key(emotion))
        group_to_indices.setdefault(group, []).append(idx)

    normalized_x = torch.zeros_like(x)

    for indices in group_to_indices.values():
        group_data = x[indices]
        # Pool over the samples of the group (dim 0) and over time (last dim)
        # so that the group shares one mean/std per channel.
        mean = group_data.mean(dim=(0, -1), keepdim=True)
        if group_data.shape[0] * group_data.shape[-1] > 1:
            std = group_data.std(dim=(0, -1), keepdim=True)
        else:
            # The sample standard deviation of a single value is undefined (NaN).
            std = torch.zeros_like(mean)
        # Small epsilon keeps flat channels from dividing by zero.
        normalized_x[indices] = (group_data - mean) / (std + 1e-8)

    return normalized_x


In [14]:
import os
import torch

def save_checkpoint(epoch, model, optimizer, loss, path="./Checkpoints/"):
    """Write a training checkpoint to disk.

    Args:
        epoch: Epoch index stored in the checkpoint (and used in the default
            file name).
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        loss: Loss value stored alongside the states.
        path: Either a directory, in which case the file is written as
            ``<path>/checkpoint_epoch_<epoch>.pth``, or a full file path
            ending in ``.pth`` / ``.pt``, which is used as-is.

    Returns:
        The path of the file that was written.
    """
    path = os.fspath(path)

    # A '*.pth' / '*.pt' argument is treated as the target file itself rather
    # than as a directory to create. An already existing directory with such a
    # name keeps being used as a directory.
    if path.endswith(('.pth', '.pt')) and not os.path.isdir(path):
        directory = os.path.dirname(path)
        save_path = path
    else:
        directory = path
        save_path = os.path.join(path, f"checkpoint_epoch_{epoch}.pth")

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    if directory:
        os.makedirs(directory, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, save_path)
    print(f"Checkpoint saved at: {save_path}")
    return save_path


In [15]:
def load_checkpoint(filename, model, optimizer, map_location=None):
    """
    Load a training checkpoint.
    Args:
        filename: Name of the checkpoint file.
        model: The model to load the state into.
        optimizer: The optimizer to load the state into.
        map_location: Optional device remapping forwarded to ``torch.load``.
            When omitted, tensors are mapped onto the device of the model's
            first parameter, so a checkpoint written on a GPU can be restored
            on a CPU-only machine.
    Returns:
        epoch: The epoch at which the checkpoint was saved.
        loss: The loss value at the time of saving.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            map_location = first_param.device

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded: {filename} (epoch {epoch})")
    return epoch, loss


In [54]:
import os
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

def train_model_with_stratified_normalization(train_data, base_encoder, projector, contrastive_loss, epochs=100, batch_size=256, lr=0.0007, patience=30, checkpoint_path="./Checkpoints/", log_dir="E:/FYP/Egg-Based Emotion Recognition/EEg-based-Emotion-Recognition/runs/ContrastiveLearning"):
    """Train the encoder and projector with a contrastive objective.

    Every minibatch is normalized with ``stratified_normalization_minibatch``,
    encoded, and projected twice: once as-is and once with small Gaussian
    noise added to the encoding. The two projections are the positive pair
    handed to ``contrastive_loss``.

    Args:
        train_data: Sequence of ``(eeg_data, metadata)`` pairs (wrapped in
            ``EEGDataset``).
        base_encoder: Encoder module; expected to already live on the device
            selected below (CUDA when available, otherwise CPU).
        projector: Projection head; same device expectation.
        contrastive_loss: Callable ``(z_i, z_j) -> scalar loss``.
        epochs: Total number of epochs (an upper bound when resuming).
        batch_size: Minibatch size.
        lr: Adam learning rate.
        patience: Number of consecutive epochs without improvement of the
            average loss after which training stops.
        checkpoint_path: Checkpoint FILE to resume from. Training resumes only
            when this is an existing file; anything else (including a
            directory) means training starts from scratch.
        log_dir: TensorBoard log directory.

    Notes:
        * Checkpoints are written to ``./Checkpoints/`` and contain the
          ``base_encoder`` and optimizer states only; the projector and the
          LR-scheduler state are not stored, so they restart on resume.
        * An exception raised while processing a batch is reported together
          with its batch number and re-raised.

    Raises:
        ValueError: If ``train_data`` is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_dir = "./Checkpoints/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_dataset = EEGDataset(train_data)
    print(f"Total samples in Dataset: {len(train_dataset)}")
    if len(train_dataset) == 0:
        raise ValueError("train_data is empty; nothing to train on.")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, timeout=0)
    num_batches = len(train_loader)
    print(f"Total batches in DataLoader: {num_batches}")

    # Optimizer and scheduler cover the encoder and the projector jointly.
    model_params = list(base_encoder.parameters()) + list(projector.parameters())
    optimizer = optim.Adam(model_params, lr=lr, weight_decay=0.015)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    base_encoder.train()
    projector.train()

    best_loss = float("inf")
    epochs_no_improve = 0

    # Resume only from an actual file: a directory (such as the checkpoint
    # folder created above) also "exists" but cannot be read by torch.load.
    start_epoch = 0
    if os.path.isfile(checkpoint_path):
        last_epoch, best_loss = load_checkpoint(checkpoint_path, base_encoder, optimizer)
        # The stored epoch has already been completed; continue with the next.
        start_epoch = last_epoch + 1

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(start_epoch, epochs):
            print(f"Entered Epoch {epoch + 1}")
            print(f"Starting Training Loop with {num_batches} batches.")
            total_loss = 0.0

            for batch_idx, (x, metadata) in enumerate(train_loader):
                try:
                    x = x.float().to(device)

                    # Stratified normalization of the minibatch.
                    x = stratified_normalization_minibatch(x, metadata)

                    encoded = base_encoder(x)

                    # Two views of the same sample: the clean projection and
                    # the projection of a slightly perturbed encoding. The
                    # noise is drawn directly on the encoding's device.
                    z_i = projector(encoded)
                    z_j = projector(encoded + 0.005 * torch.randn_like(encoded))

                    # Flatten to [batch, features] for the loss.
                    z_i = z_i.flatten(start_dim=1)
                    z_j = z_j.flatten(start_dim=1)

                    loss = contrastive_loss(z_i, z_j)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                except Exception as e:
                    # Report where it happened, then let the error surface
                    # instead of recording a partial epoch as if it succeeded.
                    print(f"Error in Batch {batch_idx + 1}: {e}")
                    raise

                batch_loss = loss.item()
                total_loss += batch_loss
                writer.add_scalar("Batch Loss", batch_loss, epoch * num_batches + batch_idx)

            # Average loss over all batches of the epoch.
            avg_loss = total_loss / num_batches
            writer.add_scalar("Avg Loss", avg_loss, epoch)

            improved = avg_loss < best_loss
            if epoch % 5 == 0 or improved:
                # save_checkpoint builds the file name from the epoch itself,
                # so it is given the checkpoint directory.
                save_checkpoint(epoch, base_encoder, optimizer, avg_loss, checkpoint_dir)

            # Early stopping bookkeeping.
            if improved:
                best_loss = avg_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

            scheduler.step()

            print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
    finally:
        # Flush and close the TensorBoard writer even when training fails.
        writer.close()

    print("Training complete!")


In [None]:
# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the three components and move them to the selected device.
base_encoder = BaseEncoder(
    input_channels=62,
    temporal_filter_length=48,
    spatial_filters=16,
    temporal_filters=16,
).to(device)
projector = Projector(
    spatial_filters=16,
    pooling_kernel=24,
    temporal_filter_size=4,
    c=2,
).to(device)
contrastive_loss = ContrastiveLoss(temperature=0.1).to(device)

# Launch contrastive training with stratified minibatch normalization.
train_model_with_stratified_normalization(
    train_data=train_data,
    base_encoder=base_encoder,
    projector=projector,
    contrastive_loss=contrastive_loss,
    epochs=100,
    batch_size=256,
    lr=0.0007,
    patience=30,
    checkpoint_path="./Checkpoints/check/",
)
