# Load SEED Dataset

### Load RAW EEG

In [None]:
from dataset_processing.seed_dataset_loader import SeedDatasetLoader

# Loader for the SEED dataset. It is only needed until labels, channel order
# and the raw EEG data have been pulled out of it (it is deleted further down).
_loader = SeedDatasetLoader()

In [None]:
# Labels exposed by the loader; the bare expression below echoes them as this
# cell's output for inspection.
labels = _loader.get_labels()
labels

In [None]:
# Channel ordering reported by the loader (presumably the electrode names in
# the column order of the EEG data — TODO confirm); echoed as cell output.
channel_order = _loader.get_channel_order()
channel_order

In [None]:
# Raw EEG recordings from the loader (presumably a pandas DataFrame, given the
# `_df` suffix — TODO confirm). Consumed by the augmentation step below.
_eeg_data_df = _loader.get_eeg_data_df()

In [None]:
# Visual sanity check: let the loader plot an EEG sample (presumably a randomly
# chosen one, per the method name).
_loader.plot_random_eeg()

In [None]:
# Drop the loader reference now that labels, channel order and the raw EEG
# data have been extracted from it.
del _loader

### Data Augmentation

In [None]:
from dataset_processing.eeg_augmentation import EEGAugmentation

# Build the augmented dataset from the raw EEG frame. The augmentor is only a
# temporary, so it is never bound to a name; afterwards the raw frame's global
# reference is dropped so it can be garbage-collected once nothing else holds it.
_augmented_df = EEGAugmentation(_eeg_data_df).augment_data()
del _eeg_data_df

### Dataset Loader

In [None]:
from torch.utils.data import DataLoader
from dataset_processing.eeg_dataset import EEGDataset

# Mini-batch size for contrastive pre-training.
batch_size = 32

# Wrap the augmented frame in a Dataset and batch it.
# `drop_last=True` discards the final partial batch. NT-Xent normally uses the
# other samples of the batch as negatives, so a tiny trailing batch (with
# shuffle=True it can be of size 1) yields few or no negatives and a loss on a
# different scale than full batches.
# NOTE(review): with fewer than `batch_size` samples this yields zero batches.
data_loader = DataLoader(
    EEGDataset(_augmented_df),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
# The DataLoader keeps its own reference to the dataset; only the global name
# for the augmented frame is released here.
del _augmented_df

# Training

In [None]:
from model.encoders import TimeFrequencyEncoder, CrossSpaceProjector
from model.loss import NTXentLoss
import torch.optim as optim
import torch

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation schedule: number of passes over `data_loader` and Adam learning rate.
epochs, lr = 10, 0.001

# Loss weighting used in the total loss: `alpha` scales the intra-domain
# contrastive terms (LT + LF), `beta` scales the time-frequency alignment term LA.
alpha, beta = 0.2, 1.0
# Temperature handed to NTXentLoss.
temperature = 0.05

In [None]:
# Initialize models: a time-domain and a frequency-domain encoder, plus one
# projector per domain into the shared latent space used for alignment.
ET = TimeFrequencyEncoder().to(device)  # time-domain encoder
EF = TimeFrequencyEncoder().to(device)  # frequency-domain encoder
PT = CrossSpaceProjector().to(device)  # projects time embeddings hT
PF = CrossSpaceProjector().to(device)  # projects frequency embeddings hF

# Define optimizers: one Adam per module, all reset and stepped together.
optimizer_ET = optim.Adam(ET.parameters(), lr=lr)
optimizer_EF = optim.Adam(EF.parameters(), lr=lr)
optimizer_PT = optim.Adam(PT.parameters(), lr=lr)
optimizer_PF = optim.Adam(PF.parameters(), lr=lr)
optimizers = (optimizer_ET, optimizer_EF, optimizer_PT, optimizer_PF)

nt_xent_calculator = NTXentLoss(temperature=temperature)

# Pre-training loop
for epoch in range(epochs):
    # Accumulate the (detached) loss on-device and read it back once per epoch:
    # calling `.item()` every batch forces a device->host sync, and printing
    # every batch floods the notebook output with noisy single-batch values.
    epoch_loss = torch.zeros((), device=device)
    num_batches = 0

    for xT, xT_augmented, xF, xF_augmented in data_loader:
        # Move batches of data to the `device`
        xT = xT.to(device)
        xT_augmented = xT_augmented.to(device)
        xF = xF.to(device)
        xF_augmented = xF_augmented.to(device)

        # Reset the gradients of every module before this batch's backward pass
        for optimizer in optimizers:
            optimizer.zero_grad()

        # Time Domain Contrastive Learning: contrast the encoded time data with
        # its augmented view (time-based contrastive loss LT, Eq. 1)
        hT = ET(xT)
        hT_augmented = ET(xT_augmented)
        LT = nt_xent_calculator.calculate_loss(hT, hT_augmented)

        # Frequency Domain Contrastive Learning: same for the frequency data
        # (frequency-based contrastive loss LF, Eq. 2)
        hF = EF(xF)
        hF_augmented = EF(xF_augmented)
        LF = nt_xent_calculator.calculate_loss(hF, hF_augmented)

        # Time-Frequency Domain Contrastive Learning: project both
        # non-augmented embeddings into the shared latent space and contrast
        # them with each other (alignment loss LA, Eq. 3)
        zT = PT(hT)
        zF = PF(hF)
        LA = nt_xent_calculator.calculate_loss(zT, zF)

        # Compute total loss
        L = alpha * (LT + LF) + beta * LA

        # Backpropagation through all four modules, then update each of them
        L.backward()
        for optimizer in optimizers:
            optimizer.step()

        # Detach so the running sum does not keep this batch's graph alive
        epoch_loss += L.detach()
        num_batches += 1

    if num_batches == 0:
        print(f"Epoch {epoch + 1}/{epochs}: data_loader produced no batches")
    else:
        mean_loss = epoch_loss.item() / num_batches
        print(f"Epoch {epoch + 1}/{epochs}, Mean loss: {mean_loss:.4f}")

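TODO: Compute a correlation matrix between the channels of the EEG signals.
Then, when training the joint time–frequency model (the step that aligns/joins the two domains), derive inter-channel "distances" from that matrix — a Hamming-like dissimilarity, though not literally the Hamming distance — and use them as weights when training the joining step.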