In [18]:
import sys , os

# -----------------------------------------------------------------------------
# Project Path Setup
#    - Add project root to Python path for consistent imports
# -----------------------------------------------------------------------------



sys.path.append("..")




from model.MTCformerV3 import MTCFormer
import torch
import glob
import numpy as np
import pandas as pd
from torch.optim import Adam
from utils.augmentation import augment_data
from utils.training import  train_model , predict_optimized
from utils.CustomDataset import EEGDataset
from torch.utils.data import DataLoader
from utils.loader import Loader
from utils.preprocessing import SignalPreprocessor

print("extracting data for futher preprocessing...",end = "\n\n")

mapping_mi = {
    "Left":0,
    "Right":1
}
inverse_mapping_mi = {
    v:k for k , v in mapping_mi.items()
} 





from utils.preprocessing import SignalPreprocessor
train_mi_new = pd.read_csv("../data_my_head/MI/1/EEGdata.csv")
train_labels_new = pd.read_csv("../data_my_head/MI/1/trial_labels.csv")
val_mi = pd.read_csv("../data_my_head/MI/2/EEGdata.csv")
val_labels = pd.read_csv("../data_my_head/MI/2/trial_labels.csv")


mapping_mi = {
    "Left":0,
    "Right":1
}
trial_df = pd.concat([
    train_mi_new , val_mi
]).reset_index(drop=True)

labels_trial_df = pd.concat([
    train_labels_new , val_labels
]).reset_index(drop=True)








extracting data for futher preprocessing...



In [19]:
def preprocess_fast(
        trial_df,
        labels_df,
        mapping=None,
        task=None,
        signal_processer = None
        ):
    assert task in ["MI" , "SSVEP"]
    if task=="MI":
        eeg_col = ['C3', 'C4', 'CZ', 'FZ' ]
        time_len_per_trial = 2250
    else:
        eeg_col =  ['OZ', 'PO7', 'PO8', 'PZ']
        time_len_per_trial = 1750


    input_array = trial_df.drop(columns = ["Time" , "Battery" , "Counter"])[eeg_col+['AccX',
        'AccY', 'AccZ', 'Gyro1', 'Gyro2', 'Gyro3' , 'Validation']].to_numpy().T
    
    input_array = input_array.reshape(-1 , 11 , time_len_per_trial)
    input_trials = labels_df.direction.map(mapping).values
    subjects = np.full(input_trials.shape, "S13", dtype=object)
    acc_channel = np.linalg.norm(input_array[:,4:7,:],axis = 1)
    gyro_channel = np.linalg.norm(input_array[:,7:10,:],axis = 1)
    Validation_Channel = input_array[:,10,:]
    input_array[:,4,:] = acc_channel
    input_array[:,5,:] = gyro_channel
    input_array[:,6,:] = Validation_Channel

    input_array = input_array[:,:7,:]



    preprocessed_test_data , preprocessed_test_labels , preprocessed_test_subject_ids , weights_test = signal_processer.apply_preprocessing(input_array, input_trials , subjects)
    
    num_windows_per_trial = signal_processer.num_windows_per_trial

    
    return (
        torch.from_numpy(preprocessed_test_data).to(torch.float32),
        torch.from_numpy(preprocessed_test_labels).to(torch.long),
        torch.full(preprocessed_test_subject_ids.shape , 12).to(torch.long),
        torch.from_numpy(weights_test).to(torch.float32),
        num_windows_per_trial
        )

In [20]:
device = torch.device("cuda")

from torchinfo import summary
model_mi_1  = MTCFormer(depth=3,
                    kernel_size=50,
                    modulator_kernel_size=30,
                    n_times=600,
                    chs_num=7,
                    eeg_ch_nums=4,
                    class_num=2,
                    class_num_domain=30,
                    modulator_dropout=0.48929137963218305,
                    mid_dropout=0.5,
                    output_dropout=0.42685917257840517,
                    k=100,
                    projection_dimention=2,
                    seed=4224
                    ).to(device)

for i , param in enumerate(model_mi_1.parameters()):
    if i >80:
        break
    param.requires_grad = False



optimizer = Adam(model_mi_1.parameters(), lr=0.002)

checkpoint_path = os.path.join(
    "..",
    "checkpoints",
    "model_1_mi_checkpoint",
    "best_model_.pth"
    )

checkpoint = torch.load(checkpoint_path, weights_only=False)

model_mi_1.load_state_dict(checkpoint['model_state_dict'] , strict=True)

summary(model_mi_1 , input_size=(1 , 7 , 600))


Layer (type:depth-idx)                             Output Shape              Param #
MTCFormer                                          [1, 2]                    --
├─TemporalModulator: 1-1                           [1, 4, 600]               --
│    └─Conv1d: 2-1                                 [1, 4, 600]               (364)
│    └─ChannelWiseLayerNorm: 2-2                   [1, 4, 600]               (8)
│    └─Dropout: 2-3                                [1, 4, 600]               --
│    └─Sigmoid: 2-4                                [1, 4, 600]               --
├─Conv1d: 1-2                                      [1, 8, 600]               (40)
├─ChannelWiseLayerNorm: 1-3                        [1, 8, 600]               (16)
├─GELU: 1-4                                        [1, 8, 600]               --
├─Dropout: 1-5                                     [1, 8, 600]               --
├─ConvolutionalAttention: 1-6                      [1, 8, 600]               --
│    └─ModuleList: 2-5     

In [21]:


print("Preprocessing data for model 1 ..... ")


preprocessor = SignalPreprocessor(
    fs=250,                                             
    bandpass_low=6.0,                     
    bandpass_high=24.0,                  
    n_cols_to_filter=4,                   
    window_size=600,                      
    window_stride=70,                    
    idx_to_ignore_normalization=-1,        
    crop_range=(2.5 , 7)              
)



input_array , labels , subject_ids , weigths  , num_windows_per_trial = preprocess_fast(
        trial_df,
        labels_trial_df,
        mapping=mapping_mi,
        task="MI",
        signal_processer = preprocessor
        )

Preprocessing data for model 1 ..... 


In [22]:
from utils.CustomDataset import EEGDataset
from utils.augmentation import augment_data


print("Data Preparation.... Wrapping preprocessed data inside tensor datasets....",end = "\n\n")


from torch.utils.data import DataLoader
from utils.CustomDataset import EEGDataset
from torch.nn import *
from torch.optim.lr_scheduler import *
n_trials_to_choose = 12
n_windows_to_choose = num_windows_per_trial*n_trials_to_choose
training_dataset = EEGDataset(
    data_tensor=input_array[:n_windows_to_choose, : , :],
    weigths=weigths[:n_windows_to_choose],
    label_tensor = labels[:n_windows_to_choose],
    subject_labels=subject_ids[:n_windows_to_choose]
    
    )

val_dataset = EEGDataset(
    data_tensor=input_array[n_windows_to_choose:, : , :],
    weigths=weigths[n_windows_to_choose:],
    label_tensor = labels[n_windows_to_choose:],
    subject_labels=subject_ids[n_windows_to_choose:]
)

trial_level_labels = labels[n_windows_to_choose:].clone()

train_loader = DataLoader(training_dataset , batch_size=3 , shuffle = True)
val_loader = DataLoader(val_dataset , batch_size=val_dataset.__len__() , shuffle = False)

Data Preparation.... Wrapping preprocessed data inside tensor datasets....



In [23]:
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import *
save_path=os.path.join("tuned_checkpoints","model_1_mi")
training_loop_params = {
    "criterion": CrossEntropyLoss(reduction="none"),


    "optimizer_class": Adam,
    "optimizer_config": {
        "lr": 0.002,
        "weight_decay":1.6581884226239174e-05
    },


    "scheduler_class": CosineAnnealingLR,
    "scheduler_config":  {
            "T_max": 121,
            "eta_min": 0.00019632120929158656
        },

    "window_len": num_windows_per_trial,
    "n_epochs": 600,
    "patience": 100 ,
    "domain_lambda": 0.05,
    "lambda_scheduler_fn": None,
    "adversarial_steps": 1,
    "adversarial_epsilon": 0.1,
    "adversarial_alpha": 0.005,
    "adversarial_training": True,
    "save_best_only": True,
    "save_path": save_path,
    "n_classes": 2,
    "device": device,
    "update_loader": (10, 100),
    "scheduler_fn": None ,
    "tensorboard":True,
}

train_model(model_mi_1,
        train_loader=train_loader,
        val_loader=val_loader,
        original_val_labels=trial_level_labels,
        **training_loop_params
    )


Launching TensorBoard at http://localhost:6006/ ...
Path Exists. Contents of this folder will be modified save_path is :  tuned_checkpoints/model_1_mi
--- Starting Training Loop ---
✅ Best checkpoint updated (save_best_only=True). at  tuned_checkpoints/model_1_mi
🟢 Validation Balanced Accuracy improved. Saving best model state...
Epoch [1/600] - Train Loss: 138168.0279, Train F1: 0.4353 | adversarial F1 : 0.4103 | Val Loss: 0.6832, Val F1: 0.3826 | Time: 2.3s | Balanced Accuracy Val: 0.5750 | - Domain Loss: 0.2827
✅ Best checkpoint updated (save_best_only=True). at  tuned_checkpoints/model_1_mi
🟢 Validation Balanced Accuracy improved. Saving best model state...
Epoch [2/600] - Train Loss: 120262.3506, Train F1: 0.3615 | adversarial F1 : 0.3407 | Val Loss: 0.6711, Val F1: 0.4333 | Time: 2.5s | Balanced Accuracy Val: 0.6000 | - Domain Loss: 0.0057
🟡 No improvement for 1 epochs.
Epoch [3/600] - Train Loss: 114511.6159, Train F1: 0.3939 | adversarial F1 : 0.4534 | Val Loss: 0.6800, Val F1:

np.float64(0.5166666666666666)

In [24]:
print("extracting data for futher preprocessing...",end = "\n\n")

mapping_ssvep = {
    "Backward":0,
    "Forward":1,
    "Left":2,
    "Right":3
}
inverse_mapping_mi = {
    v:k for k , v in mapping_mi.items()
} 





from utils.preprocessing import SignalPreprocessor
train_mi_new = pd.read_csv("../data_my_head/SSVEP/1/EEGdata.csv")
train_labels_new = pd.read_csv("../data_my_head/SSVEP/1/trial_labels.csv")
val_mi = pd.read_csv("../data_my_head/SSVEP/2/EEGdata.csv")
val_labels = pd.read_csv("../data_my_head/SSVEP/2/trial_labels.csv")


mapping_ssvep = {
    "Backward":0,
    "Forward":1,
    "Left":2,
    "Right":3
}
trial_df = pd.concat([
    train_mi_new , val_mi
]).reset_index(drop=True)

labels_trial_df = pd.concat([
    train_labels_new , val_labels
]).reset_index(drop=True)


preprocessor = SignalPreprocessor(
    fs=250,                                                 
    bandpass_low=8,                     
    bandpass_high=14,                  
    n_cols_to_filter=4,                   
    window_size=500,                      
    window_stride=50,                    
    idx_to_ignore_normalization=-1,        
    crop_range=(1.5 , 6)            
)


input_array , labels , subject_ids , weigths  , num_windows_per_trial = preprocess_fast(
        trial_df,
        labels_trial_df,
        mapping=mapping_ssvep,
        task="SSVEP",
        signal_processer = preprocessor
        )


extracting data for futher preprocessing...



array([2, 0, 3, 3, 3, 2, 3, 3, 2, 1])

In [109]:
from utils.CustomDataset import EEGDataset
from utils.augmentation import augment_data


print("Data Preparation.... Wrapping preprocessed data inside tensor datasets....",end = "\n\n")


from torch.utils.data import DataLoader
from utils.CustomDataset import EEGDataset
from torch.nn import *
from torch.optim.lr_scheduler import *
n_trials_to_choose = 12
n_windows_to_choose = num_windows_per_trial*n_trials_to_choose
training_dataset = EEGDataset(
    data_tensor=input_array[:n_windows_to_choose, : , :],
    weigths=weigths[:n_windows_to_choose],
    label_tensor = labels[:n_windows_to_choose],
    subject_labels=subject_ids[:n_windows_to_choose]
    
    )

val_dataset = EEGDataset(
    data_tensor=input_array[n_windows_to_choose:, : , :],
    weigths=weigths[n_windows_to_choose:],
    label_tensor = labels[n_windows_to_choose:],
    subject_labels=subject_ids[n_windows_to_choose:]
)

trial_level_labels = val_labels.direction.map(mapping_ssvep).values

train_loader = DataLoader(training_dataset , batch_size=3 , shuffle = True)
val_loader = DataLoader(val_dataset , batch_size=val_dataset.__len__() , shuffle = False)

Data Preparation.... Wrapping preprocessed data inside tensor datasets....



In [59]:
labels[:n_windows_to_choose:]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [63]:
from model.MTCformerV2 import MTCFormer
from torch.optim.lr_scheduler import *
from utils.training import train_model 
model_ssvep = MTCFormer(
    depth=5,
    kernel_size=50,
    n_times=500,
    chs_num=7,
    eeg_ch_nums=4,
    class_num=4,
    class_num_domain=30,
    modulator_kernel_size=30,
    domain_dropout=0.4,
    modulator_dropout=0.4,
    mid_dropout=0.4,
    output_dropout=0.6,
    k=100,
    projection_dimention=2,
    seed = 5445
).to(device)
optimizer = Adam(model_ssvep.parameters(), lr=0.002)

checkpoint_path = os.path.join(
    "..",
    "checkpoints",
    "model_ssvep_checkpoint",
    "best_model_.pth"
    )

checkpoint = torch.load(checkpoint_path, weights_only=False)

model_ssvep.load_state_dict(checkpoint['model_state_dict'] , strict=False)


for i , param in enumerate(model_mi_1.parameters()):
    if i >50:
        break
    param.requires_grad = False

In [64]:
save_path = os.path.join("","tuned_checkpoints","model_ssvep")

training_loop_params = {
    "criterion": CrossEntropyLoss(reduction="none"),


    "optimizer_class": Adam,
    "optimizer_config": {
        "lr": 0.001,
        "weight_decay":0
    },


    "scheduler_class": None,
    "scheduler_config":  {
            "T_max": 121,
            "eta_min": 0.00019632120929158656
        },

    "window_len": num_windows_per_trial,
    "n_epochs": 600,
    "patience": 600 ,
    "domain_lambda": 0.2,
    "lambda_scheduler_fn": None,
    "adversarial_steps": 1,
    "adversarial_epsilon": 0.05,
    "adversarial_alpha": 0.005,
    "adversarial_training": True,
    "save_best_only": True,
    "save_path": save_path,
    "n_classes": 4,
    "device": device,
    "update_loader": (20, 100),
    "scheduler_fn": None  # Optional: your own dynamic LR adjustment function
}

train_model(model_ssvep,
            train_loader=train_loader,
            val_loader=val_loader,
            original_val_labels=trial_level_labels,
            **training_loop_params
    )

Path Exists. Contents of this folder will be modified save_path is :  tuned_checkpoints/model_ssvep
--- Starting Training Loop ---




✅ Best checkpoint updated (save_best_only=True). at  tuned_checkpoints/model_ssvep
🟢 Validation Balanced Accuracy improved. Saving best model state...
Epoch [1/600] - Train Loss: 202359.8548, Train F1: 0.3177 | adversarial F1 : 0.3061 | Val Loss: 1.4450, Val F1: 0.0000 | Time: 2.3s | Balanced Accuracy Val: 0.0000 | - Domain Loss: 0.1665




✅ Best checkpoint updated (save_best_only=True). at  tuned_checkpoints/model_ssvep
🟢 Validation Balanced Accuracy improved. Saving best model state...
Epoch [2/600] - Train Loss: 128734.8483, Train F1: 0.3544 | adversarial F1 : 0.3563 | Val Loss: 1.4602, Val F1: 0.0373 | Time: 1.9s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0291




✅ Best checkpoint updated (save_best_only=True). at  tuned_checkpoints/model_ssvep
🟢 Validation Balanced Accuracy improved. Saving best model state...
Epoch [3/600] - Train Loss: 97142.1109, Train F1: 0.4125 | adversarial F1 : 0.4509 | Val Loss: 1.4292, Val F1: 0.0725 | Time: 2.1s | Balanced Accuracy Val: 0.0205 | - Domain Loss: 0.0135




🟡 No improvement for 1 epochs.
Epoch [4/600] - Train Loss: 67687.9149, Train F1: 0.4591 | adversarial F1 : 0.4684 | Val Loss: 1.5307, Val F1: 0.0189 | Time: 2.0s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0304




🟡 No improvement for 2 epochs.
Epoch [5/600] - Train Loss: 37785.4634, Train F1: 0.5075 | adversarial F1 : 0.4889 | Val Loss: 1.5702, Val F1: 0.0189 | Time: 2.0s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0369




🟡 No improvement for 3 epochs.
Epoch [6/600] - Train Loss: 32222.4058, Train F1: 0.4875 | adversarial F1 : 0.4766 | Val Loss: 1.5354, Val F1: 0.0189 | Time: 1.9s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0048




🟡 No improvement for 4 epochs.
Epoch [7/600] - Train Loss: 34035.0912, Train F1: 0.5040 | adversarial F1 : 0.4899 | Val Loss: 1.5369, Val F1: 0.0373 | Time: 2.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0005




🟡 No improvement for 5 epochs.
Epoch [8/600] - Train Loss: 11665.4046, Train F1: 0.5324 | adversarial F1 : 0.4802 | Val Loss: 1.5499, Val F1: 0.0373 | Time: 2.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0034




🟡 No improvement for 6 epochs.
Epoch [9/600] - Train Loss: 10454.8047, Train F1: 0.5356 | adversarial F1 : 0.5203 | Val Loss: 1.5754, Val F1: 0.0373 | Time: 2.2s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0012




🟡 No improvement for 7 epochs.
Epoch [10/600] - Train Loss: 18607.9756, Train F1: 0.5245 | adversarial F1 : 0.5327 | Val Loss: 1.6062, Val F1: 0.0189 | Time: 2.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0012




🟡 No improvement for 8 epochs.
Epoch [11/600] - Train Loss: 13839.9081, Train F1: 0.5336 | adversarial F1 : 0.4988 | Val Loss: 1.6185, Val F1: 0.0189 | Time: 1.9s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0015




🟡 No improvement for 9 epochs.
Epoch [12/600] - Train Loss: 3212.2289, Train F1: 0.5327 | adversarial F1 : 0.5459 | Val Loss: 1.5497, Val F1: 0.0725 | Time: 1.9s | Balanced Accuracy Val: 0.0205 | - Domain Loss: 0.0003




🟡 No improvement for 10 epochs.
Epoch [13/600] - Train Loss: 7170.4608, Train F1: 0.5446 | adversarial F1 : 0.5249 | Val Loss: 1.5592, Val F1: 0.0725 | Time: 1.7s | Balanced Accuracy Val: 0.0205 | - Domain Loss: 0.0009




🟡 No improvement for 11 epochs.
Epoch [14/600] - Train Loss: 17949.6081, Train F1: 0.5505 | adversarial F1 : 0.5229 | Val Loss: 1.6644, Val F1: 0.0189 | Time: 2.2s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0046




🟡 No improvement for 12 epochs.
Epoch [15/600] - Train Loss: 14983.4355, Train F1: 0.5124 | adversarial F1 : 0.5247 | Val Loss: 1.6401, Val F1: 0.0189 | Time: 2.0s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002




🟡 No improvement for 13 epochs.
Epoch [16/600] - Train Loss: 3774.4271, Train F1: 0.5311 | adversarial F1 : 0.5349 | Val Loss: 1.6341, Val F1: 0.0189 | Time: 1.7s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0001




🟡 No improvement for 14 epochs.
Epoch [17/600] - Train Loss: 2939.9902, Train F1: 0.5393 | adversarial F1 : 0.5311 | Val Loss: 1.6292, Val F1: 0.0189 | Time: 1.8s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0012




🟡 No improvement for 15 epochs.
Epoch [18/600] - Train Loss: 2455.7423, Train F1: 0.5299 | adversarial F1 : 0.5274 | Val Loss: 1.6308, Val F1: 0.0189 | Time: 2.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004




🟡 No improvement for 16 epochs.
Epoch [19/600] - Train Loss: 1976.1682, Train F1: 0.5320 | adversarial F1 : 0.5342 | Val Loss: 1.6428, Val F1: 0.0189 | Time: 2.8s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 17 epochs.
Epoch [20/600] - Train Loss: 1598.5415, Train F1: 0.5205 | adversarial F1 : 0.5252 | Val Loss: 1.6553, Val F1: 0.0189 | Time: 2.2s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 18 epochs.
Epoch [21/600] - Train Loss: 981.0979, Train F1: 0.5251 | adversarial F1 : 0.5357 | Val Loss: 1.6554, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0005




🟡 No improvement for 19 epochs.
Epoch [22/600] - Train Loss: 2193.4669, Train F1: 0.5350 | adversarial F1 : 0.5204 | Val Loss: 1.6555, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 20 epochs.
Epoch [23/600] - Train Loss: 2708.1291, Train F1: 0.5376 | adversarial F1 : 0.5241 | Val Loss: 1.6556, Val F1: 0.0189 | Time: 0.2s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0009




🟡 No improvement for 21 epochs.
Epoch [24/600] - Train Loss: 2128.9395, Train F1: 0.5420 | adversarial F1 : 0.5331 | Val Loss: 1.6556, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0006
🟡 No improvement for 22 epochs.
Epoch [25/600] - Train Loss: 1751.2299, Train F1: 0.5280 | adversarial F1 : 0.5435 | Val Loss: 1.6556, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0005
🟡 No improvement for 23 epochs.
Epoch [26/600] - Train Loss: 2002.3940, Train F1: 0.5547 | adversarial F1 : 0.5226 | Val Loss: 1.6556, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0008




🟡 No improvement for 24 epochs.
Epoch [27/600] - Train Loss: 3127.1494, Train F1: 0.5293 | adversarial F1 : 0.5278 | Val Loss: 1.6557, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0007
🟡 No improvement for 25 epochs.
Epoch [28/600] - Train Loss: 2683.9096, Train F1: 0.5199 | adversarial F1 : 0.5317 | Val Loss: 1.6549, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004




🟡 No improvement for 26 epochs.
Epoch [29/600] - Train Loss: 5997.8944, Train F1: 0.5350 | adversarial F1 : 0.5319 | Val Loss: 1.6538, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 27 epochs.
Epoch [30/600] - Train Loss: 1892.2319, Train F1: 0.5561 | adversarial F1 : 0.5448 | Val Loss: 1.6528, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 28 epochs.
Epoch [31/600] - Train Loss: 3312.1884, Train F1: 0.5325 | adversarial F1 : 0.5313 | Val Loss: 1.6518, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 29 epochs.
Epoch [32/600] - Train Loss: 2762.7992, Train F1: 0.5315 | adversarial F1 : 0.5414 | Val Loss: 1.6511, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002




🟡 No improvement for 30 epochs.
Epoch [33/600] - Train Loss: 1360.0338, Train F1: 0.5480 | adversarial F1 : 0.5264 | Val Loss: 1.6505, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0006
🟡 No improvement for 31 epochs.
Epoch [34/600] - Train Loss: 1596.6107, Train F1: 0.5339 | adversarial F1 : 0.5232 | Val Loss: 1.6499, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0007




🟡 No improvement for 32 epochs.
Epoch [35/600] - Train Loss: 1853.7323, Train F1: 0.5255 | adversarial F1 : 0.5246 | Val Loss: 1.6493, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 33 epochs.
Epoch [36/600] - Train Loss: 1324.5178, Train F1: 0.5425 | adversarial F1 : 0.5241 | Val Loss: 1.6488, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004




🟡 No improvement for 34 epochs.
Epoch [37/600] - Train Loss: 4391.3284, Train F1: 0.5333 | adversarial F1 : 0.5336 | Val Loss: 1.6486, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0005
🟡 No improvement for 35 epochs.
Epoch [38/600] - Train Loss: 1581.1160, Train F1: 0.5352 | adversarial F1 : 0.5380 | Val Loss: 1.6484, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 36 epochs.
Epoch [39/600] - Train Loss: 2234.2411, Train F1: 0.5423 | adversarial F1 : 0.5326 | Val Loss: 1.6480, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 37 epochs.
Epoch [40/600] - Train Loss: 2137.2474, Train F1: 0.5308 | adversarial F1 : 0.5434 | Val Loss: 1.6478, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 38 epochs.
Epoch [41/600] - Train Loss: 2718.6392, Train F1: 0.5375 | adversarial F1 : 0.5324 | Val Loss: 1.6476, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0006
🟡 No improvement for 39 epochs.
Epoch [42/600] - Train Loss: 1718.5934, Train F1: 0.5400 | adversarial F1 : 0.5280 | Val Loss: 1.6474, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002




🟡 No improvement for 40 epochs.
Epoch [43/600] - Train Loss: 1471.4797, Train F1: 0.5274 | adversarial F1 : 0.5426 | Val Loss: 1.6472, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002
🟡 No improvement for 41 epochs.
Epoch [44/600] - Train Loss: 2790.3940, Train F1: 0.5297 | adversarial F1 : 0.5370 | Val Loss: 1.6474, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0008




🟡 No improvement for 42 epochs.
Epoch [45/600] - Train Loss: 2678.1515, Train F1: 0.5350 | adversarial F1 : 0.5251 | Val Loss: 1.6475, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0005
🟡 No improvement for 43 epochs.
Epoch [46/600] - Train Loss: 3248.0278, Train F1: 0.5289 | adversarial F1 : 0.5268 | Val Loss: 1.6473, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0008




🟡 No improvement for 44 epochs.
Epoch [47/600] - Train Loss: 3444.5768, Train F1: 0.5399 | adversarial F1 : 0.5236 | Val Loss: 1.6472, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 45 epochs.
Epoch [48/600] - Train Loss: 3540.9200, Train F1: 0.5337 | adversarial F1 : 0.5383 | Val Loss: 1.6469, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004




🟡 No improvement for 46 epochs.
Epoch [49/600] - Train Loss: 1845.3062, Train F1: 0.5523 | adversarial F1 : 0.5264 | Val Loss: 1.6468, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 47 epochs.
Epoch [50/600] - Train Loss: 1626.9333, Train F1: 0.5255 | adversarial F1 : 0.5232 | Val Loss: 1.6468, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0005




🟡 No improvement for 48 epochs.
Epoch [51/600] - Train Loss: 3016.0574, Train F1: 0.5359 | adversarial F1 : 0.5381 | Val Loss: 1.6474, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0005
🟡 No improvement for 49 epochs.
Epoch [52/600] - Train Loss: 1348.7395, Train F1: 0.5366 | adversarial F1 : 0.5236 | Val Loss: 1.6479, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0005




🟡 No improvement for 50 epochs.
Epoch [53/600] - Train Loss: 2328.3300, Train F1: 0.5336 | adversarial F1 : 0.5184 | Val Loss: 1.6486, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 51 epochs.
Epoch [54/600] - Train Loss: 2680.2032, Train F1: 0.5434 | adversarial F1 : 0.5229 | Val Loss: 1.6492, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002




🟡 No improvement for 52 epochs.
Epoch [55/600] - Train Loss: 366.5923, Train F1: 0.5320 | adversarial F1 : 0.5281 | Val Loss: 1.6497, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0006
🟡 No improvement for 53 epochs.
Epoch [56/600] - Train Loss: 1591.4360, Train F1: 0.5347 | adversarial F1 : 0.5414 | Val Loss: 1.6502, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004




🟡 No improvement for 54 epochs.
Epoch [57/600] - Train Loss: 2197.3213, Train F1: 0.5377 | adversarial F1 : 0.5309 | Val Loss: 1.6506, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 55 epochs.
Epoch [58/600] - Train Loss: 1924.0615, Train F1: 0.5315 | adversarial F1 : 0.5268 | Val Loss: 1.6510, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002




🟡 No improvement for 56 epochs.
Epoch [59/600] - Train Loss: 2090.0186, Train F1: 0.5268 | adversarial F1 : 0.5293 | Val Loss: 1.6514, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 57 epochs.
Epoch [60/600] - Train Loss: 1934.6904, Train F1: 0.5401 | adversarial F1 : 0.5212 | Val Loss: 1.6521, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 58 epochs.
Epoch [61/600] - Train Loss: 1534.5809, Train F1: 0.5247 | adversarial F1 : 0.5300 | Val Loss: 1.6527, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 59 epochs.
Epoch [62/600] - Train Loss: 955.5557, Train F1: 0.5458 | adversarial F1 : 0.5359 | Val Loss: 1.6532, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 60 epochs.
Epoch [63/600] - Train Loss: 2106.1999, Train F1: 0.5386 | adversarial F1 : 0.5281 | Val Loss: 1.6536, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 61 epochs.
Epoch [64/600] - Train Loss: 1512.0547, Train F1: 0.5458 | adversarial F1 : 0.5397 | Val Loss: 1.6539, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 62 epochs.
Epoch [65/600] - Train Loss: 1472.6157, Train F1: 0.5379 | adversarial F1 : 0.5212 | Val Loss: 1.6543, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0008
🟡 No improvement for 63 epochs.
Epoch [66/600] - Train Loss: 3661.6154, Train F1: 0.5199 | adversarial F1 : 0.5437 | Val Loss: 1.6547, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0008




🟡 No improvement for 64 epochs.
Epoch [67/600] - Train Loss: 2312.8200, Train F1: 0.5509 | adversarial F1 : 0.5260 | Val Loss: 1.6552, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 65 epochs.
Epoch [68/600] - Train Loss: 2503.5032, Train F1: 0.5255 | adversarial F1 : 0.5395 | Val Loss: 1.6555, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0008




🟡 No improvement for 66 epochs.
Epoch [69/600] - Train Loss: 1625.7704, Train F1: 0.5403 | adversarial F1 : 0.5268 | Val Loss: 1.6559, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 67 epochs.
Epoch [70/600] - Train Loss: 2268.3118, Train F1: 0.5278 | adversarial F1 : 0.5251 | Val Loss: 1.6559, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 68 epochs.
Epoch [71/600] - Train Loss: 1893.4830, Train F1: 0.5370 | adversarial F1 : 0.5296 | Val Loss: 1.6560, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 69 epochs.
Epoch [72/600] - Train Loss: 1342.2848, Train F1: 0.5348 | adversarial F1 : 0.5236 | Val Loss: 1.6561, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 70 epochs.
Epoch [73/600] - Train Loss: 1896.5152, Train F1: 0.5292 | adversarial F1 : 0.5275 | Val Loss: 1.6563, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 71 epochs.
Epoch [74/600] - Train Loss: 3432.8021, Train F1: 0.5352 | adversarial F1 : 0.5464 | Val Loss: 1.6562, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004




🟡 No improvement for 72 epochs.
Epoch [75/600] - Train Loss: 4231.6480, Train F1: 0.5325 | adversarial F1 : 0.5354 | Val Loss: 1.6561, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 73 epochs.
Epoch [76/600] - Train Loss: 832.3725, Train F1: 0.5299 | adversarial F1 : 0.5288 | Val Loss: 1.6559, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004




🟡 No improvement for 74 epochs.
Epoch [77/600] - Train Loss: 3006.5588, Train F1: 0.5303 | adversarial F1 : 0.5232 | Val Loss: 1.6558, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 75 epochs.
Epoch [78/600] - Train Loss: 1296.2320, Train F1: 0.5504 | adversarial F1 : 0.5209 | Val Loss: 1.6558, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 76 epochs.
Epoch [79/600] - Train Loss: 1256.5370, Train F1: 0.5281 | adversarial F1 : 0.5303 | Val Loss: 1.6559, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 77 epochs.
Epoch [80/600] - Train Loss: 1779.6318, Train F1: 0.5470 | adversarial F1 : 0.5222 | Val Loss: 1.6559, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 78 epochs.
Epoch [81/600] - Train Loss: 3132.1408, Train F1: 0.5423 | adversarial F1 : 0.5368 | Val Loss: 1.6560, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 79 epochs.
Epoch [82/600] - Train Loss: 3677.2286, Train F1: 0.5350 | adversarial F1 : 0.5315 | Val Loss: 1.6565, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002




🟡 No improvement for 80 epochs.
Epoch [83/600] - Train Loss: 3605.5646, Train F1: 0.5247 | adversarial F1 : 0.5288 | Val Loss: 1.6560, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002
🟡 No improvement for 81 epochs.
Epoch [84/600] - Train Loss: 2498.5545, Train F1: 0.5307 | adversarial F1 : 0.5271 | Val Loss: 1.6556, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0008




🟡 No improvement for 82 epochs.
Epoch [85/600] - Train Loss: 2616.3123, Train F1: 0.5404 | adversarial F1 : 0.5202 | Val Loss: 1.6554, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 83 epochs.
Epoch [86/600] - Train Loss: 1216.7101, Train F1: 0.5202 | adversarial F1 : 0.5366 | Val Loss: 1.6553, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002




🟡 No improvement for 84 epochs.
Epoch [87/600] - Train Loss: 2308.1695, Train F1: 0.5504 | adversarial F1 : 0.5425 | Val Loss: 1.6552, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 85 epochs.
Epoch [88/600] - Train Loss: 714.1942, Train F1: 0.5410 | adversarial F1 : 0.5445 | Val Loss: 1.6551, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003




🟡 No improvement for 86 epochs.
Epoch [89/600] - Train Loss: 1162.6927, Train F1: 0.5326 | adversarial F1 : 0.5239 | Val Loss: 1.6550, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0003
🟡 No improvement for 87 epochs.
Epoch [90/600] - Train Loss: 1420.3616, Train F1: 0.5341 | adversarial F1 : 0.5330 | Val Loss: 1.6551, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004




🟡 No improvement for 88 epochs.
Epoch [91/600] - Train Loss: 1772.3360, Train F1: 0.5344 | adversarial F1 : 0.5292 | Val Loss: 1.6551, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0002
🟡 No improvement for 89 epochs.
Epoch [92/600] - Train Loss: 1212.4930, Train F1: 0.5296 | adversarial F1 : 0.5279 | Val Loss: 1.6552, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0006




🟡 No improvement for 90 epochs.
Epoch [93/600] - Train Loss: 2178.7081, Train F1: 0.5315 | adversarial F1 : 0.5229 | Val Loss: 1.6536, Val F1: 0.0189 | Time: 0.1s | Balanced Accuracy Val: 0.0051 | - Domain Loss: 0.0004
🟡 No improvement for 91 epochs.
Epoch [94/600] - Train Loss: 1363.5454, Train F1: 0.5256 | adversarial F1 : 0.5354 | Val Loss: 1.6512, Val F1: 0.0373 | Time: 0.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0003




🟡 No improvement for 92 epochs.
Epoch [95/600] - Train Loss: 2156.1006, Train F1: 0.5317 | adversarial F1 : 0.5222 | Val Loss: 1.6492, Val F1: 0.0373 | Time: 0.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0002
🟡 No improvement for 93 epochs.
Epoch [96/600] - Train Loss: 2338.0721, Train F1: 0.5430 | adversarial F1 : 0.5445 | Val Loss: 1.6477, Val F1: 0.0373 | Time: 0.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0002




🟡 No improvement for 94 epochs.
Epoch [97/600] - Train Loss: 1059.3993, Train F1: 0.5504 | adversarial F1 : 0.5466 | Val Loss: 1.6469, Val F1: 0.0373 | Time: 0.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0004
🟡 No improvement for 95 epochs.
Epoch [98/600] - Train Loss: 908.9614, Train F1: 0.5288 | adversarial F1 : 0.5339 | Val Loss: 1.6463, Val F1: 0.0373 | Time: 0.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0002




🟡 No improvement for 96 epochs.
Epoch [99/600] - Train Loss: 1138.6711, Train F1: 0.5237 | adversarial F1 : 0.5258 | Val Loss: 1.6458, Val F1: 0.0373 | Time: 0.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0012
🟡 No improvement for 97 epochs.
Epoch [100/600] - Train Loss: 1457.7355, Train F1: 0.5423 | adversarial F1 : 0.5322 | Val Loss: 1.6454, Val F1: 0.0373 | Time: 0.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0005
🟡 No improvement for 98 epochs.
Epoch [101/600] - Train Loss: 1341.6254, Train F1: 0.5420 | adversarial F1 : 0.5266 | Val Loss: 1.6451, Val F1: 0.0373 | Time: 0.1s | Balanced Accuracy Val: 0.0103 | - Domain Loss: 0.0002
---BREAKING TRAINING----




np.float64(0.04648829431438127)

In [101]:
np.set_printoptions(suppress=True)



In [103]:
trial_level_labels

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])

In [102]:
predict_optimized(
    model_ssvep,
    windows_per_trial=num_windows_per_trial,
    loader=val_loader,
    probability=True
)

array([[0.32958537, 0.03263313, 0.38126412, 0.25651738],
       [0.3255843 , 0.18979754, 0.31974265, 0.16487542],
       [0.4849823 , 0.00704174, 0.48517603, 0.02280002],
       [0.5536281 , 0.04972099, 0.36318922, 0.03346165],
       [0.03661281, 0.8781843 , 0.03940997, 0.04579278],
       [0.01608884, 0.95102745, 0.0132519 , 0.01963184],
       [0.00012568, 0.99934286, 0.0001578 , 0.0003736 ],
       [0.00014678, 0.00019971, 0.0002484 , 0.99940515]], dtype=float32)

In [83]:
labels_trial_df

Unnamed: 0,trial,direction
0,1,Forward
1,2,Right
2,3,Forward
3,4,Left
4,5,Right
5,6,Left
6,7,Forward
7,8,Backward
8,9,Right
9,10,Backward


In [92]:
n = 8
predict(model_ssvep.to(torch.device("cpu")) , trial_df[n*1750:(n+1)*1750])

'Right'

In [78]:
def predict(model, df):
    """
    Preprocess the raw MI data and make a prediction.

    Args:
        model: The loaded model object.
        df (pd.DataFrame): The raw trial data.

    Returns:
        str: The prediction, which must be "left", "right", or "?".
    """
    def preprocess_optimized(
            trial_df,
            signal_processer = None
            ):
        eeg_col =  ['OZ', 'PO7', 'PO8', 'PZ']


        input_array = trial_df.drop(columns = ["Time" , "Battery" , "Counter"])[eeg_col+['AccX',
            'AccY', 'AccZ', 'Gyro1', 'Gyro2', 'Gyro3' , 'Validation']].to_numpy().T
        
        acc_channel = np.linalg.norm(input_array[4:7,:],axis = 0)
        gyro_channel = np.linalg.norm(input_array[7:10,:],axis = 0)
        Validation_Channel = input_array[10,:]
        input_array[4,:] = acc_channel
        input_array[5,:] = gyro_channel
        input_array[6,:] = Validation_Channel

        input_array = input_array[:7,:]


        preprocessed_test_data , _ , _ , weights_test = signal_processer.apply_preprocessing(np.expand_dims(input_array,axis=0), np.array([None]) , np.array([1]))
        
        num_windows_per_trial = signal_processer.num_windows_per_trial

        
        return (
            torch.from_numpy(preprocessed_test_data).to(torch.float32),
           -1,
            -1,
            torch.from_numpy(weights_test).to(torch.float32),
            num_windows_per_trial
            )

    preprocessor = SignalPreprocessor(
    fs=250,                                                 
    bandpass_low=8,                     
    bandpass_high=14,                  
    n_cols_to_filter=4,                   
    window_size=500,                      
    window_stride=50,                    
    idx_to_ignore_normalization=-1,        
    crop_range=(1.5 , 6)            
)
    data , _ , _ , weights , windows_per_trial = preprocess_optimized(df,signal_processer=preprocessor)

    
    index_to_label = {
        0: "Backward",
        1: "Forward",
        2: "Left",
        3: "Right"
    }

    prediction = predict_optimized(
        model=model,
        windows_per_trial=windows_per_trial,
        loader= (data , weights)

    )

    return index_to_label[prediction[0]]