# Load SEED Dataset

In [None]:
from utils.seed_dataset_loader import SeedDatasetLoader

# Instantiate the project's SEED dataset loader with its default arguments.
loader = SeedDatasetLoader()

In [None]:
# Fetch the labels from the loader; the bare name on the last line makes the
# notebook display the value.
labels = loader.get_labels()
labels

In [None]:
# Fetch the channel ordering from the loader and display it in the notebook.
channel_order = loader.get_channel_order()
channel_order

In [None]:
# Load the EEG data; it is passed to EEGAugmentation in the augmentation section.
# Presumably a pandas DataFrame, judging by the name -- confirm in SeedDatasetLoader.
eeg_data_df = loader.get_eeg_data_df()

In [None]:
# Visual sanity check: ask the loader to plot an EEG sample
# (presumably picked at random, per the method name -- confirm).
loader.plot_random_eeg()

In [None]:
# Drop the notebook's reference to the loader; eeg_data_df, labels and
# channel_order stay bound. Presumably done to release memory held by the
# loader -- confirm that nothing else keeps it alive.
del loader

# Data Augmentation

In [None]:
import torch
import torch.nn.functional as F
from utils.eeg_augmentation import EEGAugmentation

In [None]:
# Wrap the loaded EEG data in the project's augmenter and produce the augmented
# dataset. The concrete augmentations applied are defined in
# utils.eeg_augmentation and are not visible from this notebook.
augmentor = EEGAugmentation(eeg_data_df)
augmented_df = augmentor.augment_data()

In [None]:
# Display the first entry (by position) of the augmented data as a quick check.
augmented_df.iloc[0]

In [None]:
def nt_xent_loss(z, z_augmented, temperature=0.05):
    """
    Compute an NT-Xent-style contrastive loss between embeddings and their augmented views.

    For sample ``i`` the positive logit is the cosine similarity between ``z[i]`` and
    ``z_augmented[i]``. The negative logits are the cosine similarities between ``z[i]``
    and every other original embedding ``z[j]`` (``j != i``) in the batch. All logits are
    divided by ``temperature``. The per-sample loss is::

        -(positive_i - logsumexp_{j != i}(negative_ij))

    and the mean over the batch is returned.

    Note: the positive logit is not included in the log-sum-exp term, unlike the usual
    softmax cross-entropy form of NT-Xent. The value is therefore not bounded below by
    zero and can be negative.

    Parameters:
    - z (torch.Tensor): Embeddings from the original EEG signals, shape
      (batch_size, embedding_dim) with batch_size >= 2.
    - z_augmented (torch.Tensor): Corresponding embeddings from the augmented EEG signals,
      same shape as ``z``. It is moved to ``z``'s device if it lives elsewhere.
    - temperature (float): Positive temperature scaling factor applied to the similarities.

    Returns:
    - torch.Tensor: Scalar tensor holding the average loss for the batch.

    Raises:
    - ValueError: If ``z`` is not 2-D, if the two inputs differ in shape, if the batch
      holds fewer than two samples (no negatives exist), or if ``temperature`` is not
      strictly positive.
    """
    if z.dim() != 2:
        raise ValueError(
            f"z must be 2-D (batch_size, embedding_dim), got shape {tuple(z.shape)}"
        )
    if z.shape != z_augmented.shape:
        # Without this check a (1, D) tensor would silently broadcast against z.
        raise ValueError(
            "z and z_augmented must have the same shape, got "
            f"{tuple(z.shape)} and {tuple(z_augmented.shape)}"
        )
    batch_size = z.size(0)
    if batch_size < 2:
        # With a single sample every negative is masked out, so the log-sum-exp
        # would be -inf and the returned loss would be -inf.
        raise ValueError(
            f"nt_xent_loss needs at least 2 samples to form negatives, got {batch_size}"
        )
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    device = z.device

    # L2-normalize each row so that dot products are cosine similarities.
    z = F.normalize(z, p=2, dim=1)
    z_augmented = F.normalize(z_augmented, p=2, dim=1).to(device)

    # Positive logits: similarity between each original and its own augmentation.
    positive_sim = torch.sum(z * z_augmented, dim=1) / temperature

    # Negative logits: similarity between each original and all originals.
    negative_sim_matrix = torch.mm(z, z.t()) / temperature
    # A sample is not its own negative: set the diagonal to -inf so it contributes
    # exp(-inf) = 0 to the log-sum-exp below.
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
    negative_sim_matrix = negative_sim_matrix.masked_fill(self_mask, float('-inf'))

    # torch.logsumexp applies the max-subtraction trick internally for stability.
    # Ref: https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
    logsumexp_negatives = torch.logsumexp(negative_sim_matrix, dim=1)

    # Log-ratio of the positive against the summed exponentiated negatives.
    log_prob = positive_sim - logsumexp_negatives

    # Mean loss across all samples
    return -torch.mean(log_prob)

In [None]:
# Build an augmenter without a dataset (None) so that its individual transform
# methods can be called directly on the dummy tensors in the cells below.
eeg_augmentation = EEGAugmentation(None)

In [None]:
# Smoke test of nt_xent_loss on the time-domain branch, using random data in
# place of real EEG.
# Seed first so the random input and the encoder's random initial weights are
# reproducible across runs of this cell.
torch.manual_seed(seed=1)
x = torch.randn(256, 128)  # 256 samples, 128 features per sample

# Augment the time-domain data.
# NOTE(review): this calls an underscore-prefixed (private) method of
# EEGAugmentation; presumably it adds Gaussian noise, per its name -- confirm.
xe = eeg_augmentation._time_gaussian_noise(x)

# Simulate the encoding step: one linear layer stands in for the real encoder.
# The same layer is applied to both views, so both embeddings share weights.
time_encoder = torch.nn.Linear(128, 64)  # Dummy encoder
h = time_encoder(x)
he = time_encoder(xe)

# Compute the contrastive loss between original and augmented embeddings
# (default temperature) and print it as a Python float.
loss_time = nt_xent_loss(h, he)
print(f"Time Domain - Contrastive Loss: {loss_time.item()}")

In [None]:
# Smoke test of nt_xent_loss on the frequency-domain branch.
# Re-seeding with the same seed and drawing the same shape first makes x
# identical to the tensor used in the time-domain cell.
torch.manual_seed(seed=1)
x = torch.randn(256, 128)  # Dummy EEG data, 256 samples, 128 features per sample

# Convert to frequency domain.
# NOTE(review): private method of EEGAugmentation; the shape and dtype of its
# output are not visible from this notebook.
xF = eeg_augmentation._freq_fourier_transform(x)

# Apply spectral perturbation to obtain the augmented frequency-domain view.
xeF = eeg_augmentation._freq_spectral_perturbation(xF)

# Stand-in encoder for the frequency branch.
# NOTE(review): it expects 128 real-valued input features, so this assumes the
# transform above keeps 128 features per sample and returns a real tensor -- confirm.
frequency_encoder = torch.nn.Linear(128, 64)  # Dummy encoder for frequency domain

# Generate embeddings for both views with the shared encoder.
hF = frequency_encoder(xF)
heF = frequency_encoder(xeF)

# Compute the contrastive loss (default temperature) and print it as a Python float.
loss_frequency = nt_xent_loss(hF, heF)
print(f"Frequency Domain - Contrastive Loss: {loss_frequency.item()}")

TODO: Compute a correlation matrix between the channels of the EEG signals.
Then, when training the joint model (exact design still to be decided), use the pairwise "distances" between channels (loosely analogous to a Hamming distance, though not the same thing) as weights in the joint-training step.