# Load SEED Dataset

### Load RAW EEG

In [None]:
from dataset_processing.seed_dataset_loader import SeedDatasetLoader

# Sampling rate of the SEED recordings, in Hz. It is also reused later as the
# encoders' input dimension.
sampling_frequency = 200

_loader = SeedDatasetLoader(
    fs=sampling_frequency,
)

In [None]:
# Fetch the dataset labels from the loader. The trailing bare name makes
# Jupyter display them as this cell's output.
labels = _loader.get_labels()
labels

In [None]:
# Fetch the EEG channel ordering reported by the loader, and display it as
# this cell's output.
channel_order = _loader.get_channel_order()
channel_order

In [None]:
# Raw EEG recordings from the loader. This is presumably a pandas DataFrame
# (TODO confirm the schema against SeedDatasetLoader). It is consumed by the
# augmentation step and deleted afterwards.
_eeg_data_df = _loader.get_eeg_data_df()

In [None]:
# Quick visual sanity check of the loaded data via the loader's plotting helper.
_loader.plot_random_eeg()

In [None]:
# The loader is not used after this point. Drop the notebook-level reference
# so it can be garbage-collected if nothing else holds it.
del _loader

### Data Augmentation

In [None]:
from dataset_processing.eeg_augmentation import EEGAugmentation

# Build the augmented dataset from the raw EEG, then release the raw frame.
# Only `_augmented_df` is needed downstream.
_augmented_df = EEGAugmentation(_eeg_data_df).augment_data()
del _eeg_data_df

### Dataset Loader

In [None]:
from torch.utils.data import DataLoader
from dataset_processing.eeg_dataset import EEGDataset

# Pre-training batch size, as reported in the paper.
pretraining_batch_size = 256

data_loader = DataLoader(
    EEGDataset(_augmented_df),
    batch_size=pretraining_batch_size,
    shuffle=True,
)
# Only `data_loader` is used from here on; drop the notebook-level reference
# to the augmented frame.
del _augmented_df

# Pre-Training

In [None]:
from model.encoders import TimeFrequencyEncoder, CrossSpaceProjector
from model.loss import NTXentLoss
import torch.optim as optim
import torch
from tqdm.auto import tqdm

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import os

# Where pre-training and fine-tuning checkpoints are written.
pretraining_model_save_dir = "model_params/pretraining"
finetuning_model_save_dir = "model_params/finetuning"

# exist_ok=True keeps the cell idempotent (safe to re-run). It also avoids the
# check-then-create race of an os.path.exists() guard. Parents are created as needed.
os.makedirs(pretraining_model_save_dir, exist_ok=True)
os.makedirs(finetuning_model_save_dir, exist_ok=True)

In [None]:
# Pre-training hyperparameters (values taken from the paper).

# Epoch budget: the paper trains for 1000 epochs; 10 is used here for quick
# initial tests.
pretraining_epochs = 10

# Optimisation settings.
pretraining_lr = 3e-4
l2_norm_penalty = 3e-4  # passed to Adam as weight_decay
pretraining_temperature = 0.05  # NT-Xent temperature

# Loss weights: total = alpha * (LT + LF) + beta * LA.
alpha = 0.2
beta = 1.0

# Embedding sizes: encoder outputs (h) and shared-space projections (z).
encoders_output_dim = 200
projectors_output_dim = 128

In [None]:
num_layers = 2
nhead = 8

# Initialize models: one encoder per domain (time: ET, frequency: EF), plus one
# projector per domain (PT, PF). Each projector maps its encoder's output into
# the shared latent space used by the alignment loss.
ET = TimeFrequencyEncoder(
    input_dim=sampling_frequency,
    output_dim=encoders_output_dim,
    num_layers=num_layers,
    nhead=nhead,
).to(device)
EF = TimeFrequencyEncoder(
    input_dim=sampling_frequency,
    output_dim=encoders_output_dim,
    num_layers=num_layers,
    nhead=nhead,
).to(device)
PT = CrossSpaceProjector(
    input_dim=encoders_output_dim,
    output_dim=projectors_output_dim,
).to(device)
PF = CrossSpaceProjector(
    input_dim=encoders_output_dim,
    output_dim=projectors_output_dim,
).to(device)

# A single Adam optimizer over all four modules. weight_decay applies the
# L2-norm penalty.
optimizer = optim.Adam(
    list(ET.parameters()) + list(EF.parameters()) + list(PT.parameters()) + list(PF.parameters()),
    lr=pretraining_lr,
    weight_decay=l2_norm_penalty  # L2-norm penalty coefficient
)

# Divide the LR by 10 after `patience` epochs without improvement of the
# monitored loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=10,
)

nt_xent_calculator = NTXentLoss(temperature=pretraining_temperature)

# Fail fast with a clear message instead of a ZeroDivisionError when computing
# the epoch average below.
num_batches = len(data_loader)
if num_batches == 0:
    raise ValueError("data_loader yields no batches; nothing to pre-train on.")

# Pre-training loop
for epoch in range(1, pretraining_epochs + 1):
    pbar = tqdm(data_loader, desc=f"Epoch {epoch}", leave=False)
    epoch_loss = 0.0
    for xT, xT_augmented, xF, xF_augmented in pbar:
        # Move batches of data to the `device`
        xT = xT.to(device)
        xT_augmented = xT_augmented.to(device)
        xF = xF.to(device)
        xF_augmented = xF_augmented.to(device)

        # Clear gradients accumulated by the previous step
        optimizer.zero_grad()

        # Time Domain Contrastive Learning
        hT = ET(xT)  # Encode time data
        hT_augmented = ET(xT_augmented)  # Encode augmented time data
        LT = nt_xent_calculator.calculate_loss(  # Calculate the time-based contrastive loss LT in Eq. 1
            hT,
            hT_augmented
        )

        # Frequency Domain Contrastive Learning
        hF = EF(xF)  # Encode frequency data
        hF_augmented = EF(xF_augmented)  # Encode augmented frequency data
        LF = nt_xent_calculator.calculate_loss(  # Calculate the frequency-based contrastive loss LF in Eq. 2
            hF,
            hF_augmented
        )

        # Time-Frequency Domain Contrastive Learning
        zT = PT(hT)  # Project into shared latent space
        zF = PF(hF)  # Project into shared latent space
        LA = nt_xent_calculator.calculate_loss(  # Calculate the alignment loss LA in Eq. 3
            zT,
            zF
        )

        # Compute total loss
        L = alpha * (LT + LF) + beta * LA
        batch_loss = L.item()  # read the scalar once per batch
        epoch_loss += batch_loss

        # Backpropagation
        L.backward()
        optimizer.step()

        # Update tqdm progress bar with the current loss
        pbar.set_description_str(f"Epoch {epoch}, Loss: {batch_loss:.4f}")

    average_loss = epoch_loss / num_batches

    # Step the scheduler on the mean batch loss (the same value printed below).
    # The batch count is constant across epochs, so under ReduceLROnPlateau's
    # default relative threshold this matches stepping on the summed loss.
    scheduler.step(average_loss)

    # Save the model every 10 epochs and at the last epoch
    if epoch % 10 == 0 or epoch == pretraining_epochs:
        model_dicts = {
            "ET_state_dict": ET.state_dict(),
            "EF_state_dict": EF.state_dict(),
            "PT_state_dict": PT.state_dict(),
            "PF_state_dict": PF.state_dict(),
        }
        torch.save(model_dicts, os.path.join(pretraining_model_save_dir, f"pretrained_model__epoch_{epoch}.pt"))
        print(f"Saved model at epoch {epoch}")

    # ReduceLROnPlateau.get_last_lr() returns a list (and is missing on older
    # PyTorch releases), so formatting it with :.6f raised. Read the current LR
    # from the optimizer's first param group instead; all four modules share one group.
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"Epoch {epoch}, Average Loss: {average_loss:.4f}, Learning Rate: {current_lr:.6f}")

# Ideas

Compute a correlation matrix between the EEG channels.
Then, when training the joint model, use the resulting inter-channel "distances" as weights in the joint training. These would be a correlation-based dissimilarity, loosely analogous to a Hamming distance but not the same thing.

Alternatively, output a per-channel attribution that shows each channel's contribution to the final emotion prediction.