In [1]:
from torcheeg.datasets import SEEDDataset
from torcheeg import transforms

raw_dataset = SEEDDataset(
    root_path='./SEED/SEED_EEG/Preprocessed_EEG',
    io_path = 'E:/FYP/Egg-Based Emotion Recognition/EEg-based-Emotion-Recognition/.torcheeg/datasets_1733174610032_5iJyS',
    online_transform=None,  # Disable transforms
    label_transform=None,
    num_worker=4
)

raw_sample = raw_dataset[0]
print(f"Raw EEG data shape: {raw_sample[0].shape}")  # Should be [62, ...] for SEED
print(f"Label: {raw_sample[1]}")  # Should be a number between 0 and 3

[2024-12-12 01:04:12] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from E:/FYP/Egg-Based Emotion Recognition/EEg-based-Emotion-Recognition/.torcheeg/datasets_1733174610032_5iJyS.


Raw EEG data shape: (62, 200)
Label: {'start_at': 0, 'end_at': 200, 'clip_id': '10_20131130.mat_0', 'subject_id': 10, 'trial_id': 'ww_eeg1', 'emotion': 1, 'date': 20131130, '_record_id': '_record_0'}


In [2]:
from collections import Counter

# Collect all emotion labels
emotion_counts = Counter()

for sample in raw_dataset:
    emotion = sample[1]['emotion']
    emotion_counts[emotion] += 1

# Display the counts for each emotion
print("Emotion Distribution:")
for emotion, count in sorted(emotion_counts.items()):
    label = "Negative" if emotion == -1 else "Neutral" if emotion == 0 else "Positive"
    print(f"{label} ({emotion}): {count}")


Emotion Distribution:
Negative (-1): 50400
Neutral (0): 49680
Positive (1): 52650


In [2]:
import numpy as np
from scipy.signal import butter, filtfilt

def bandpass_filter(data, lowcut=4, highcut=47, fs=200, order=4):
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, data, axis=1)

In [3]:
filtered_data_list = []
labels_list = []

for idx in range(len(raw_dataset)):
    eeg_data, label = raw_dataset[idx]
    
    # Apply the bandpass filter
    filtered_data = bandpass_filter(eeg_data, lowcut=4, highcut=47, fs=200)
    
    # Store the filtered data and corresponding label
    filtered_data_list.append(filtered_data)
    labels_list.append(label)

    # Print progress
    if idx % 5 == 0:
        print(f"Processed sample {idx}/{len(raw_dataset)}")

print("Bandpass filtering completed for all samples.")

# Verify the shape of the filtered data
print(f"Filtered EEG Data Shape: {filtered_data_list[0].shape}")
print(f"Label: {labels_list[0]}")


Processed sample 0/152730
Processed sample 5/152730
Processed sample 10/152730
Processed sample 15/152730
Processed sample 20/152730
Processed sample 25/152730
Processed sample 30/152730
Processed sample 35/152730
Processed sample 40/152730
Processed sample 45/152730
Processed sample 50/152730
Processed sample 55/152730
Processed sample 60/152730
Processed sample 65/152730
Processed sample 70/152730
Processed sample 75/152730
Processed sample 80/152730
Processed sample 85/152730
Processed sample 90/152730
Processed sample 95/152730
Processed sample 100/152730
Processed sample 105/152730
Processed sample 110/152730
Processed sample 115/152730
Processed sample 120/152730
Processed sample 125/152730
Processed sample 130/152730
Processed sample 135/152730
Processed sample 140/152730
Processed sample 145/152730
Processed sample 150/152730
Processed sample 155/152730
Processed sample 160/152730
Processed sample 165/152730
Processed sample 170/152730
Processed sample 175/152730
Processed samp

In [36]:
import numpy as np
import random
from collections import defaultdict

def generate_contrastive_minibatch(filtered_data_list, labels_list, batch_size=32, pairs_per_sample=4):
    """
    Generate a minibatch with multiple positive and negative pairs for each sample.
    Positive pairs are from trials with the same emotion.
    Negative pairs are from trials with different emotions.

    Args:
        filtered_data_list: List of filtered EEG data (one entry per segment).
        labels_list: List of corresponding labels (each label contains subject_id, trial_id, emotion).
        batch_size: Desired number of samples in each minibatch.
        pairs_per_sample: Number of pairs to generate per sample (both positive and negative).

    Returns:
        A tuple (minibatch, pair_labels), where:
        - minibatch: List of tuples (sample_A, sample_B)
        - pair_labels: List indicating if the pair is positive (1) or negative (0)
    """
  # Group data by emotion for efficient sampling
    emotion_map = defaultdict(list)
    for idx, label in enumerate(labels_list):
        emotion = label['emotion']
        emotion_map[emotion].append(idx)

    # List of unique emotions
    emotions = list(emotion_map.keys())

    # Select a random subset of samples for the minibatch
    all_indices = list(range(len(filtered_data_list)))
    minibatch_indices = random.sample(all_indices, batch_size)

    minibatch = []
    pair_labels = []

    for idx in minibatch_indices:
        sample_A = filtered_data_list[idx]
        label_A = labels_list[idx]
        emotion_A = label_A['emotion']

        # Generate positive pairs (same emotion)
        if len(emotion_map[emotion_A]) > 1:
            positive_indices = [i for i in emotion_map[emotion_A] if i != idx]
            pos_idx = random.choice(positive_indices)
            minibatch.append((sample_A, filtered_data_list[pos_idx]))
            pair_labels.append(1)

        # Generate negative pairs (different emotions)
        negative_emotions = [e for e in emotions if e != emotion_A]
        neg_emotion = random.choice(negative_emotions)
        neg_idx = random.choice(emotion_map[neg_emotion])
        minibatch.append((sample_A, filtered_data_list[neg_idx]))
        pair_labels.append(0)

    # Shuffle the pairs to mix positive and negative pairs
    combined = list(zip(minibatch, pair_labels))
    random.shuffle(combined)
    minibatch, pair_labels = zip(*combined)

    return minibatch, pair_labels

In [42]:

minibatch, pair_labels = generate_contrastive_minibatch(filtered_data_list, labels_list, batch_size=512, pairs_per_sample=4)

# Count positive and negative pairs
num_positive = sum(pair_labels)
num_negative = len(pair_labels) - num_positive

print(f"Number of Positive Pairs: {num_positive}")
print(f"Number of Negative Pairs: {num_negative}")


Number of Positive Pairs: 512
Number of Negative Pairs: 512


In [43]:
import torch
import torch.nn.functional as F

def contrastive_loss(embeddings_A, embeddings_B, pair_labels, temperature=0.1):
    """
    Compute the contrastive loss for a minibatch.

    Args:
        embeddings_A: Tensor of embeddings for the first element in each pair.
        embeddings_B: Tensor of embeddings for the second element in each pair.
        pair_labels: Tensor indicating if the pair is positive (1) or negative (0).
        temperature: Temperature scaling parameter for contrastive loss.

    Returns:
        Computed contrastive loss.
    """
    # Compute cosine similarity
    cosine_sim = F.cosine_similarity(embeddings_A, embeddings_B)
    
    # Scale the similarities by the temperature
    logits = cosine_sim / temperature
    
    # Compute binary cross-entropy loss
    loss = F.binary_cross_entropy_with_logits(logits, pair_labels.float())
    
    return loss


In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BaseEncoder(nn.Module):
    def __init__(self, num_channels=62, temporal_filter_length=48, num_spatial_filters=16, num_temporal_filters=16):
        super(BaseEncoder, self).__init__()

        # Spatial Convolution: [batch_size, num_channels, timepoints] -> [batch_size, num_spatial_filters, timepoints]
        self.spatial_conv = nn.Conv1d(
            in_channels=num_channels,
            out_channels=num_spatial_filters,
            kernel_size=1,
            bias=False
        )
        
        # Temporal Convolution: [batch_size, num_spatial_filters, timepoints] -> [batch_size, num_temporal_filters, timepoints]
        self.temporal_conv = nn.Conv1d(
            in_channels=num_spatial_filters,
            out_channels=num_temporal_filters,
            kernel_size=temporal_filter_length,
            padding='same',  # Ensure output length matches input length
            bias=False
        )

        # Batch Normalization for stability
        self.bn_spatial = nn.BatchNorm1d(num_spatial_filters)
        self.bn_temporal = nn.BatchNorm1d(num_temporal_filters)

        # Activation Function
        self.activation = nn.ELU()

    def forward(self, x):
        # x shape: [batch_size, num_channels, timepoints]
        
        # Spatial Convolution + BatchNorm + ReLU
        x = self.spatial_conv(x)
        x = self.bn_spatial(x)
        x = self.activation(x)
        
        # Temporal Convolution + BatchNorm + ReLU
        x = self.temporal_conv(x)
        x = self.bn_temporal(x)
        x = self.activation(x)
        
        return x

# Initialize the Base Encoder
base_encoder = BaseEncoder()

# Verify the model with a sample input
sample_input = torch.randn(32, 62, 200)  # Batch size = 32, Channels = 62, Timepoints = 200
output = base_encoder(sample_input)

print(f"Input Shape: {sample_input.shape}")
print(f"Output Shape: {output.shape}")  # Expected: [32, 16, 200]


Input Shape: torch.Size([32, 62, 200])
Output Shape: torch.Size([32, 16, 200])


In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Projector(nn.Module):
    def __init__(self, num_input_filters=16, spatial_filters=32, temporal_filters=64, temporal_filter_length=4, pooling_kernel_length=24):
        super(Projector, self).__init__()

        # Spatial Convolution: [batch_size, num_input_filters, timepoints] -> [batch_size, spatial_filters, timepoints]
        self.spatial_conv = nn.Conv1d(
            in_channels=num_input_filters,
            out_channels=spatial_filters,
            kernel_size=1,
            bias=False
        )

        # Temporal Convolution: [batch_size, spatial_filters, timepoints] -> [batch_size, temporal_filters, timepoints]
        # Padding to maintain the same output length as input length
        self.temporal_conv = nn.Conv1d(
            in_channels=spatial_filters,
            out_channels=temporal_filters,
            kernel_size=temporal_filter_length,
            padding=temporal_filter_length // 2,
            bias=False
        )

        # Average Pooling: [batch_size, temporal_filters, timepoints] -> [batch_size, temporal_filters, reduced_timepoints]
        self.avg_pool = nn.AvgPool1d(kernel_size=pooling_kernel_length)

        # Batch Normalization
        self.bn_spatial = nn.BatchNorm1d(spatial_filters)
        self.bn_temporal = nn.BatchNorm1d(temporal_filters)

        # Activation Function
        self.activation = nn.ELU()

    def forward(self, x):
        # x shape: [batch_size, num_input_filters, timepoints]
        
        # Spatial Convolution + BatchNorm + ReLU
        x = self.spatial_conv(x)
        x = self.bn_spatial(x)
        x = self.activation(x)

        # Temporal Convolution + BatchNorm + ReLU
        x = self.temporal_conv(x)
        x = self.bn_temporal(x)
        x = self.activation(x)

        # Average Pooling
        x = self.avg_pool(x)

        return x

# Initialize the Projector with SEED hyperparameters
projector = Projector(
    num_input_filters=16,         # Output from the base encoder
    spatial_filters=32,           # CK2 = 32
    temporal_filters=64,          # C^2K2 = 64
    temporal_filter_length=4,     # P2 = 4
    pooling_kernel_length=24      # S = 24
)

# Verify the model with a sample input
sample_input = torch.randn(32, 16, 200)  # Batch size = 32, Input filters = 16, Timepoints = 200
output = projector(sample_input)

print(f"Input Shape: {sample_input.shape}")
print(f"Output Shape: {output.shape}")  # Expected: [32, 64, reduced_timepoints]


Input Shape: torch.Size([32, 16, 200])
Output Shape: torch.Size([32, 64, 8])


In [55]:
def generate_embeddings(minibatch, encoder, projector, device):
    embeddings_A = []
    embeddings_B = []

    for sample_A, sample_B in minibatch:
        tensor_A = torch.tensor(sample_A, dtype=torch.float32).unsqueeze(0).to(device)
        tensor_B = torch.tensor(sample_B, dtype=torch.float32).unsqueeze(0).to(device)

        # Get encoded features and projected embeddings
        embedding_A = projector(encoder(tensor_A))
        embedding_B = projector(encoder(tensor_B))

        embeddings_A.append(embedding_A)
        embeddings_B.append(embedding_B)

    # Stack embeddings
    embeddings_A = torch.cat(embeddings_A, dim=0)
    embeddings_B = torch.cat(embeddings_B, dim=0)

    return embeddings_A, embeddings_B


In [56]:
import torch.nn.functional as F

def contrastive_loss(embeddings_A, embeddings_B, pair_labels, temperature=0.1):
    """
    Compute the contrastive loss for a minibatch.

    Args:
        embeddings_A: Tensor of embeddings for the first element in each pair.
        embeddings_B: Tensor of embeddings for the second element in each pair.
        pair_labels: Tensor indicating if the pair is positive (1) or negative (0).
        temperature: Temperature scaling parameter for contrastive loss.

    Returns:
        Computed contrastive loss.
    """
    embeddings_A = F.normalize(embeddings_A, p=2, dim=1)
    embeddings_B = F.normalize(embeddings_B, p=2, dim=1)
    
    # Compute cosine similarity
    cosine_sim = F.cosine_similarity(embeddings_A, embeddings_B)
    
    # Scale the similarities by the temperature
    logits = cosine_sim / temperature
    
    # pair_labels = pair_labels.squeeze()
    
    # Compute binary cross-entropy loss
    loss = F.binary_cross_entropy_with_logits(logits, pair_labels.float())
    
    return loss


In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils import tensorboard

# Define the Encoder and Projector (assuming these are already defined)
encoder = BaseEncoder()    # Replace with your actual BaseEncoder class
projector = Projector()    # Replace with your actual Projector class
writer = tensorboard.SummaryWriter(log_dir='E:/FYP/Egg-Based Emotion Recognition/EEg-based-Emotion-Recognition/runs/contrastive_pakka')
# Move models to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder.to(device)
projector.to(device)

# Optimizer and Scheduler
optimizer = Adam(list(encoder.parameters()) + list(projector.parameters()), lr=0.0007, weight_decay=0.015)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# Early Stopping Parameters
best_loss = float('inf')
patience = 30
counter = 0

# Training Loop
num_epochs = 100
scaler = torch.amp.GradScaler()

for epoch in range(num_epochs):
    encoder.train()
    projector.train()

    # Generate a minibatch of pairs and labels
    minibatch, pair_labels = generate_contrastive_minibatch(filtered_data_list, labels_list, batch_size=256, pairs_per_sample=4)
    pair_labels = torch.tensor(pair_labels, dtype=torch.float32).to(device)

    # Generate embeddings
    embeddings_A, embeddings_B = [], []

    for sample_A, sample_B in minibatch:
        tensor_A = torch.tensor(sample_A.copy(), dtype=torch.float32).unsqueeze(0).to(device)
        tensor_B = torch.tensor(sample_B.copy(), dtype=torch.float32).unsqueeze(0).to(device)
        
        with torch.amp.autocast(device_type=device.type):
            embedding_A = projector(encoder(tensor_A))
            embedding_B = projector(encoder(tensor_B))

        embeddings_A.append(embedding_A.cpu())
        embeddings_B.append(embedding_B.cpu())
        
        del tensor_A, tensor_B, embedding_A, embedding_B
        torch.cuda.empty_cache()


    embeddings_A = torch.cat(embeddings_A, dim=0).to(device)
    embeddings_B = torch.cat(embeddings_B, dim=0).to(device)
    
    embeddings_A = torch.mean(embeddings_A, dim=2)  # Shape: [batch_size, channels]
    embeddings_B = torch.mean(embeddings_B, dim=2)  # Shape: [batch_size, channels]


    embeddings_A = embeddings_A.view(embeddings_A.size(0), -1)
    embeddings_B = embeddings_B.view(embeddings_B.size(0), -1)
    
    print(f"logits shape: {embeddings_A.shape}, pair_labels shape: {pair_labels.shape}")


    with torch.amp.autocast(device_type=device.type):
    # Compute contrastive loss
        loss = contrastive_loss(embeddings_A, embeddings_B, pair_labels)

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
    writer.add_scalar('Loss/train', loss.item(), epoch)

    # Early stopping
    if loss < best_loss:
        best_loss = loss
        torch.save({'encoder': encoder.state_dict(), 'projector': projector.state_dict()}, 'best_model.pth')
        counter = 0
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered")
        break

print("Training Complete.")
writer.close()

logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [1/100], Loss: 4.9301
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [2/100], Loss: 4.9055
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [3/100], Loss: 4.8734
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [4/100], Loss: 4.8415
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [5/100], Loss: 4.7819
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [6/100], Loss: 4.7698
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [7/100], Loss: 4.7257
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [8/100], Loss: 4.7043
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [9/100], Loss: 4.6883
logits shape: torch.Size([512, 64]), pair_labels shape: torch.Size([512])
Epoch [1

KeyboardInterrupt: 

In [38]:
print("Checking sample pairs:")
for i in range(20):
    print(f"Pair {i + 1}: Label = {pair_labels[i]}")


Checking sample pairs:
Pair 1: Label = 0
Pair 2: Label = 0
Pair 3: Label = 1
Pair 4: Label = 1
Pair 5: Label = 1
Pair 6: Label = 1
Pair 7: Label = 0
Pair 8: Label = 1
Pair 9: Label = 0
Pair 10: Label = 0
Pair 11: Label = 0
Pair 12: Label = 0
Pair 13: Label = 0
Pair 14: Label = 1
Pair 15: Label = 0
Pair 16: Label = 0
Pair 17: Label = 0
Pair 18: Label = 1
Pair 19: Label = 1
Pair 20: Label = 0
