In [1]:
import sys , os

# -----------------------------------------------------------------------------
# Project Path Setup
#    - Add project root to Python path for consistent imports
# -----------------------------------------------------------------------------



sys.path.append( "..")


# -----------------------------------------------------------------------------
#  Core Imports
#   - Model.MTCFormer
#   - data extractors: extract_data,extract_subject_labels
#   - training utils: predict
#   - augmentation: augment_data
#   - datasets: The custom EEGDataset
# -----------------------------------------------------------------------------



from model.MTCformerV3 import MTCFormer
import torch
import glob
import numpy as np
import pandas as pd
from torch.optim import Adam
from utils.extractors import extract_data,extract_subject_labels
from utils.preprocessing import preprocess_data,preprocess_one_file
from utils.augmentation import augment_data
from utils.training import  predict
from utils.CustomDataset import EEGDataset
from torch.utils.data import DataLoader


In [4]:
DATA_FIF_DIR = "../data_fif"


test_file_paths_mi = glob.glob(os.path.join(DATA_FIF_DIR, "test/MI/*.fif"))
train_file_paths_mi = glob.glob(os.path.join(DATA_FIF_DIR, "train/MI/*.fif"))
val_file_paths_mi = glob.glob(os.path.join(DATA_FIF_DIR, "validation/MI/*.fif"))





mapping_mi = {
    "Left":0,
    "Right":1
}   

test_data_mi , test_labels_mi , ids_mi = extract_data(test_file_paths_mi , return_id = True)
train_data_mi , train_labels_mi , ids = extract_data(train_file_paths_mi , return_id = True)
val_data_mi , val_labels_mi , ids = extract_data(val_file_paths_mi , return_id = True)




train_labels_mi_mapped = np.array([mapping_mi[x] for x in train_labels_mi])
val_labels_mi_mapped = np.array([mapping_mi[x] for x in val_labels_mi])


train_subject_labels = extract_subject_labels(train_data_mi)
val_subject_labels = extract_subject_labels(val_data_mi)
test_subject_labels = extract_subject_labels(test_data_mi)


# -----------------------------------------------------------------------------
# [Optional] Test Pipeline Mode: --test_pipeline
#
#    - When this flag is passed, the script enables a lightweight debug mode
#      intended for fast testing and development.
#
#    - Instead of using the full dataset, it truncates the number of samples
#      from each split train to only the **first 50 trials**.
#
#    - This affects:
#        - Raw EEG data (`train_data_mi`, etc.)
#        - Corresponding labels (`train_labels_mi_mapped`, etc.)
#        - Subject ID labels (`train_subject_labels`, etc.)
#
#    - This mode is useful for verifying pipeline correctness quickly, without
#      waiting for full data to load or process.
#
# -----------------------------------------------------------------------------



from utils.preprocessing import preprocess_data,preprocess_one_file
print("Preprocessing data, This may take a while... ",end = "\n\n")


cols_to_pick = [
        'C3',
        'C4',
        'CZ',
        'FZ',
        'Acc_norm',
        'gyro_norm',
        'Validation'
          ]

params = {
    "cols_to_pick":cols_to_pick,
    "l_freq": 6,
    "h_freq": 30,
    "notch_freqs": [50, 100],
    "notch_width": 1.0,
    "window_size": 600,
    "window_stride": 35
}
train_data,weights_train,windowed_train_labels,subject_label_train_, WINDOW_LEN = preprocess_data(
    train_data_mi,
    labels =train_labels_mi_mapped,
    subject_labels = train_subject_labels,
    preprocess_func=preprocess_one_file,
    params = params,
    n_jobs=4
)


val_data,weights_val,windowed_val_labels,subject_label_val_, WINDOW_LEN = preprocess_data(
    val_data_mi,
    labels = val_labels_mi_mapped,
    subject_labels = val_subject_labels,
    preprocess_func=preprocess_one_file,
    params = params,
    n_jobs=4
)
test_data,weights_test, _ ,subject_label_test_, WINDOW_LEN= preprocess_data(
    test_data_mi,
    labels = test_labels_mi,
    subject_labels = test_subject_labels,
    preprocess_func=preprocess_one_file,
    params = params,
    n_jobs=4
    )




# -----------------------------------------------------------------------------
# 3. Convert preprocessed EEG windows into PyTorch-ready format
#
#    This stage converts numpy arrays from preprocessing into torch Tensors
#    with correct dtypes and wraps them into PyTorch-compatible Datasets and
#    DataLoaders. This makes the data pipeline ready for training.
#
#    Steps:
#    -------------------------------------------------------------------------
#    1. **Torch Conversion**:
#       - Input data (EEG windows) is cast to `torch.float32`.
#       - Labels and subject IDs are cast to `torch.long` (required for loss).
#       - Sample weights are cast to `torch.float32` for possible loss weighting.
#
#    2. **Test Data Handling**:
#       - Since test labels are unknown, placeholder zeros are used for compatibility.
#
#    3. **Custom Dataset**:
#       - We use a custom `EEGDataset` class that wraps:
#           - EEG windows
#           - Sample weights
#           - Class labels
#           - Subject labels (for subject-level analysis)
#       - Optional online data augmentation (enabled for training only).
#
#    4. **DataLoaders**:
#       - Train loader uses batching and shuffling for training.
#       - Val and Test loaders load the full set in one batch for deterministic evaluation.
#
#    5. **Device Setup**:
#       - Automatically selects GPU (`cuda`) if available, otherwise falls back to CPU.
# -----------------------------------------------------------------------------



from utils.CustomDataset import EEGDataset
from utils.augmentation import augment_data


print("Data Preparation.... Wrapping preprocessed data inside tensor datasets....",end = "\n\n")



batch_size=100

# Convert numpy arrays to PyTorch tensors with correct dtypes
orig_labels_val_torch = torch.from_numpy(val_labels_mi_mapped).to(torch.long) # Original labels for validation aggregation

train_mi_torch = torch.from_numpy(train_data).to(torch.float32)
train_mi_labels_torch = torch.from_numpy(windowed_train_labels).to(torch.long)
weights_train_torch = torch.from_numpy(weights_train).to(torch.float32) # Ensure float32
train_mi_torch_subject = torch.from_numpy(subject_label_train_).to(torch.long)


val_mi_torch = torch.from_numpy(val_data).to(torch.float32)
val_mi_labels_torch = torch.from_numpy(windowed_val_labels).to(torch.long)
weights_val_torch = torch.from_numpy(weights_val).to(torch.float32) # Ensure float32
val_mi_torch_subject = torch.from_numpy(subject_label_val_).to(torch.long)

test_mi_torch = torch.from_numpy(test_data).to(torch.float32)
weights_test_torch = torch.from_numpy(weights_test).to(torch.float32)
test_labels_placeholder = torch.zeros(test_mi_torch.shape[0], dtype=torch.long)
test_mi_torch_subject = torch.from_numpy(subject_label_test_).to(torch.long)


# Create TensorDatasets
train_dataset = EEGDataset(train_mi_torch, weights_train_torch, train_mi_labels_torch , train_mi_torch_subject,augment=True,augmentation_func=augment_data)
val_dataset = EEGDataset(val_mi_torch, weights_val_torch, val_mi_labels_torch , val_mi_torch_subject)
test_dataset = EEGDataset(test_mi_torch, weights_test_torch, test_labels_placeholder , test_mi_torch_subject)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False) # Full batch for validation
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False) # Full batch for test


device_to_work_on = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_to_work_on)


Preprocessing data, This may take a while... 



2217it [00:15, 144.49it/s]
47it [00:00, 219.61it/s]
50it [00:00, 240.96it/s]


Data Preparation.... Wrapping preprocessed data inside tensor datasets....



In [7]:
os.getcwd()

'/home/mohammed_ahmed/MTC_REPO/train'

In [29]:
model_mi_1 = MTCFormer(depth=2,
                    kernel_size=5,
                    n_times=600,
                    chs_num=7,
                    eeg_ch_nums=4,
                    class_num=2,
                    class_num_domain=30,
                    modulator_dropout=0.3,
                    mid_dropout=0.5,
                    output_dropout=0.5,
                    weight_init_mean=0,
                    weight_init_std=0.5,
                    ).to(device)


checkpoint_dicetory = "../"
optimizer = Adam(model_mi_1.parameters(), lr=0.002)

checkpoint_path = os.path.join(
    checkpoint_dicetory,
    "checkpoints",
    "model_1_mi_checkpoint",
    "best_model_.pth"
    )

checkpoint = torch.load(checkpoint_path, weights_only=False)

model_mi_1.load_state_dict(checkpoint['model_state_dict'] , strict=False)
preds_mi_one = predict(
    model_mi_1,
    window_len=WINDOW_LEN,
    loader=train_loader,
    num_samples_to_predict=2217,
    device = device,
    probability=True
    )
preds_mi_one

array([[0.46184656, 0.53815347],
       [0.5045754 , 0.4954247 ],
       [0.543059  , 0.45694107],
       ...,
       [0.49263743, 0.50736266],
       [0.4024053 , 0.59759474],
       [0.46077135, 0.5392287 ]], dtype=float32)

In [31]:

model_mi_two = MTCFormer(depth=2,
                    kernel_size=5,
                    n_times=600,
                    chs_num=7,
                    eeg_ch_nums=4,
                    class_num=2,
                    class_num_domain=30,
                    modulator_dropout=0.3,
                    mid_dropout=0.5,
                    output_dropout=0.5,
                    weight_init_mean=0,
                    weight_init_std=0.5,
                    ).to(device)


checkpoint_dicetory = "../"
optimizer = Adam(model_mi_two.parameters(), lr=0.002)

checkpoint_path = os.path.join(
    checkpoint_dicetory,
    "checkpoints",
    "model_2_mi_checkpoint",
    "best_model_.pth"
    )

checkpoint = torch.load(checkpoint_path, weights_only=False)

model_mi_two.load_state_dict(checkpoint['model_state_dict'] , strict=False)

preds_mi_two = predict(
    model_mi_two,
    window_len=WINDOW_LEN,
    loader=train_loader,
    num_samples_to_predict=2217,
    device = device,
    probability=True
    )
preds_mi_two

array([[0.47386113, 0.52613884],
       [0.51480097, 0.485199  ],
       [0.49435315, 0.5056469 ],
       ...,
       [0.5179802 , 0.4820198 ],
       [0.48719618, 0.51280373],
       [0.5185388 , 0.48146117]], dtype=float32)

In [56]:
test_preds_1 = np.array([
[0.53246975 ,0.46753022],
 [0.48687148, 0.51312846],
 [0.5387039  ,0.46129608],
 [0.5021071,  0.49789286],
 [0.47704634, 0.5229537 ],
 [0.46221194 ,0.53778803],
 [0.52266014 ,0.4773399 ],
 [0.42057058, 0.57942945],
 [0.50525504, 0.494745  ],
 [0.45534346, 0.5446566 ],
 [0.47970426 ,0.5202958 ],
 [0.4733704 , 0.5266297 ],
 [0.5423661 , 0.45763397],
 [0.49935672 ,0.50064325],
 [0.4460962 , 0.5539039 ],
 [0.5064874,  0.49351254],
 [0.5220652 , 0.4779348 ],
 [0.48356584, 0.51643413],
 [0.49317035 ,0.5068297 ],
 [0.55157775 ,0.44842222],
 [0.5027857 , 0.49721432],
 [0.44819015 ,0.55180985],
 [0.45454434 ,0.54545563],
 [0.46888864 ,0.53111136],
 [0.4830007 , 0.5169993 ],
 [0.47482154, 0.5251785 ],
 [0.42132378, 0.5786762 ],
 [0.48184648 ,0.5181535 ],
 [0.530921 ,  0.469079  ],
 [0.4922824 , 0.50771755],
 [0.47214556, 0.52785444],
 [0.5168218  ,0.48317823],
 [0.4571302 , 0.5428698 ],
 [0.4934551 , 0.5065449 ],
 [0.52833664 ,0.47166333],
 [0.4886873 , 0.5113127 ],
 [0.444473 ,  0.555527  ],
 [0.4575885 , 0.54241145],
 [0.51342475 ,0.48657525],
 [0.46953878 ,0.5304612 ],
 [0.5031684 , 0.4968316 ],
 [0.45879114 ,0.54120886],
 [0.5187844  ,0.4812156 ],
 [0.5151648 , 0.48483512],
 [0.54893726, 0.4510627 ],
 [0.4846798 , 0.5153202 ],
 [0.51223046, 0.4877695 ],
 [0.47033954, 0.52966046],
 [0.48824993, 0.5117501 ],
 [0.49464935, 0.5053506 ]])
import numpy as np
from sklearn.preprocessing import QuantileTransformer

from sklearn.preprocessing import QuantileTransformer
import numpy as np


# Fit separate transformers per column
qt_neg = QuantileTransformer(output_distribution='normal').fit(preds_mi_one[:, 0].reshape(-1, 1))
qt_pos = QuantileTransformer(output_distribution='normal').fit(preds_mi_one[:, 1].reshape(-1, 1))

# Transform new data similarly:
X_transformed_1 = np.zeros_like(test_preds_1)
X_transformed_1[:, 0] = qt_neg.transform(test_preds_1[:, 0].reshape(-1, 1)).flatten()
X_transformed_1[:, 1] = qt_pos.transform(test_preds_1[:, 1].reshape(-1, 1)).flatten()
X_transformed_1

array([[ 0.70018483, -0.70018475],
       [-0.23019393,  0.23019185],
       [ 0.80687607, -0.80687655],
       [ 0.09439401, -0.0943943 ],
       [-0.41423813,  0.41423826],
       [-0.715663  ,  0.71566212],
       [ 0.49167994, -0.4916792 ],
       [-1.60755996,  1.60756591],
       [ 0.14777895, -0.14777832],
       [-0.84310176,  0.84310219],
       [-0.36593847,  0.3659389 ],
       [-0.48097119,  0.48097256],
       [ 0.88314872, -0.88314814],
       [ 0.03737948, -0.03738001],
       [-1.02151726,  1.02151898],
       [ 0.1684108 , -0.16841391],
       [ 0.48514855, -0.48514831],
       [-0.2830423 ,  0.28304226],
       [-0.08771035,  0.08771143],
       [ 1.06888014, -1.0688804 ],
       [ 0.10846223, -0.10845999],
       [-0.96520314,  0.96520359],
       [-0.85641167,  0.85641142],
       [-0.55816974,  0.55816944],
       [-0.29796649,  0.29796667],
       [-0.45401928,  0.45402072],
       [-1.59632829,  1.59632855],
       [-0.32450565,  0.32450518],
       [ 0.65894676,

In [None]:
test_preds_2 = np.array([
[0.5036396  ,0.49636036],
 [0.4728651 , 0.52713484],
 [0.5006888  ,0.49931118],
 [0.5313396 , 0.46866038],
 [0.48304752 ,0.5169525 ],
 [0.49555388 ,0.504446  ],
 [0.50834763, 0.49165243],
 [0.4771689 , 0.5228311 ],
 [0.5203398 , 0.47966024],
 [0.495307  , 0.50469303],
 [0.5008955 , 0.49910453],
 [0.48937204 ,0.5106279 ],
 [0.5252463 , 0.47475374],
 [0.5483076 , 0.45169237],
 [0.45992485, 0.5400751 ],
 [0.5295104 , 0.47048965],
 [0.54884744, 0.45115265],
 [0.5056256 , 0.4943744 ],
 [0.5074791  ,0.49252093],
 [0.54100263, 0.4589974 ],
 [0.49924478, 0.50075525],
 [0.46759537 ,0.53240466],
 [0.491207  , 0.50879306],
 [0.46828505 ,0.5317149 ],
 [0.5076705 , 0.49232942],
 [0.4747552,  0.5252448 ],
 [0.46640074, 0.5335993 ],
 [0.49615887, 0.50384116],
 [0.5209354 , 0.47906458],
 [0.48120984 ,0.5187902 ],
 [0.4874002  ,0.51259977],
 [0.54170877, 0.45829126],
 [0.46296924, 0.53703076],
 [0.525115 ,  0.47488502],
 [0.47401354 ,0.5259865 ],
 [0.50283843, 0.49716154],
 [0.4943745 , 0.50562555],
 [0.49801394 ,0.501986  ],
 [0.52885187, 0.47114813],
 [0.52234006, 0.47766   ],
 [0.4989587 , 0.50104123],
 [0.52421755, 0.47578242],
 [0.4901222 , 0.50987786],
 [0.5432616  ,0.45673832],
 [0.51199526, 0.48800474],
 [0.5419697  ,0.45803037],
 [0.5202589 , 0.4797411 ],
 [0.4716455 , 0.52835447],
 [0.5039931 , 0.49600688],
 [0.49536097 ,0.504639  ]])
# Fit separate transformers per column
qt_neg = QuantileTransformer(output_distribution='normal').fit(preds_mi_two[:, 0].reshape(-1, 1))
qt_pos = QuantileTransformer(output_distribution='normal').fit(preds_mi_two[:, 1].reshape(-1, 1))

# Transform new data similarly:
X_transformed_2 = np.zeros_like(test_preds_2)
X_transformed_2[:, 0] = qt_neg.transform(test_preds_2[:, 0].reshape(-1, 1)).flatten()
X_transformed_2[:, 1] = qt_pos.transform(test_preds_2[:, 1].reshape(-1, 1)).flatten()

X_transformed_2

array([[0.51177919, 0.48822046],
       [0.1125957 , 0.88740253],
       [0.4727341 , 0.52726561],
       [0.84920063, 0.15079939],
       [0.22328272, 0.77672153],
       [0.39736489, 0.60263307],
       [0.58940958, 0.41059441],
       [0.15306612, 0.84693386],
       [0.74182581, 0.25817431],
       [0.39337131, 0.60662904],
       [0.47508796, 0.52491217],
       [0.31232081, 0.68767892],
       [0.79642854, 0.20357154],
       [0.94961244, 0.05038734],
       [0.02834251, 0.97165711],
       [0.83202696, 0.16797312],
       [0.95059065, 0.04940945],
       [0.54627801, 0.45372281],
       [0.57872756, 0.42127252],
       [0.91493001, 0.08507044],
       [0.45662696, 0.5433736 ],
       [0.0680805 , 0.93191968],
       [0.33676808, 0.66323315],
       [0.07097424, 0.92902552],
       [0.58180376, 0.41819575],
       [0.12896415, 0.8710358 ],
       [0.06047154, 0.9395287 ],
       [0.40651265, 0.59348773],
       [0.74809399, 0.25190468],
       [0.19249875, 0.80750215],
       [0.

In [60]:
q_transform = QuantileTransformer(output_distribution='uniform').fit(preds_mi_one.reshape(-1, 1))
transformed_data_1 = q_transform.transform(test_preds_1.reshape(-1,1)).reshape(-1,2)
transformed_data_1

array([[0.73979391, 0.26020595],
       [0.39557187, 0.60442731],
       [0.7788764 , 0.2211235 ],
       [0.51689798, 0.48310194],
       [0.32462526, 0.67537494],
       [0.22617865, 0.77382069],
       [0.67370283, 0.32629757],
       [0.04913542, 0.95086468],
       [0.53959609, 0.46040388],
       [0.18878428, 0.81121621],
       [0.34123828, 0.65876206],
       [0.29952245, 0.70047809],
       [0.79959989, 0.20040027],
       [0.49433106, 0.50566873],
       [0.14160796, 0.85839237],
       [0.5498274 , 0.45017107],
       [0.67054742, 0.32945233],
       [0.37047389, 0.62952592],
       [0.44698653, 0.55301371],
       [0.84460897, 0.15539092],
       [0.52271737, 0.47728276],
       [0.15379278, 0.84620743],
       [0.18254559, 0.81745394],
       [0.2712957 , 0.72870435],
       [0.36513338, 0.63486676],
       [0.30870504, 0.69129529],
       [0.05050475, 0.94949523],
       [0.35621574, 0.64378371],
       [0.72793551, 0.27206441],
       [0.44061884, 0.55938074],
       [0.

In [61]:
q_transform = QuantileTransformer(output_distribution='uniform').fit(preds_mi_two.reshape(-1, 1))
transformed_data_2 = q_transform.transform(test_preds_2.reshape(-1,1)).reshape(-1,2)
transformed_data_2

array([[0.55198052, 0.44801877],
       [0.14954949, 0.85044856],
       [0.50777743, 0.49222223],
       [0.88725688, 0.11274308],
       [0.25980571, 0.74019442],
       [0.43628793, 0.56370995],
       [0.62277653, 0.37722481],
       [0.18950199, 0.8104978 ],
       [0.78384999, 0.21614999],
       [0.43234563, 0.5676541 ],
       [0.51092002, 0.48907981],
       [0.34520355, 0.65479507],
       [0.83365697, 0.16634303],
       [0.97066866, 0.02933115],
       [0.05894386, 0.94105598],
       [0.87197925, 0.12802068],
       [0.97152461, 0.02847547],
       [0.58294703, 0.41705482],
       [0.61115529, 0.3888454 ],
       [0.94505385, 0.05494662],
       [0.49141833, 0.50858189],
       [0.10728984, 0.89271019],
       [0.37096731, 0.62903336],
       [0.11012818, 0.88987159],
       [0.61339668, 0.38660211],
       [0.16635202, 0.83364774],
       [0.09749454, 0.90250535],
       [0.44495322, 0.55504616],
       [0.78870371, 0.21129598],
       [0.23378939, 0.76621144],
       [0.

In [68]:
ensemble = (X_transformed_2+X_transformed_1)/2

preds_ensemble = np.argmax(ensemble,axis=1)
preds_ensemble

array([0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1,
       1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0,
       0, 0, 0, 1, 1, 1])

In [69]:
(preds_ensemble==old).mean()

np.float64(0.86)

In [64]:
old = np.array([0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0,
       0, 0, 0, 1, 1, 1])

In [70]:
inv_mapping_mi = {v:k for k,v in mapping_mi.items()}
preds_mi_csv = pd.DataFrame({
    "id":ids_mi,
    "label": pd.Series(preds_ensemble).map(inv_mapping_mi).values
})

preds_mi_csv


Unnamed: 0,id,label
0,4939,Left
1,4948,Right
2,4945,Left
3,4920,Left
4,4921,Right
5,4937,Right
6,4906,Left
7,4919,Right
8,4922,Left
9,4912,Right


In [72]:
test_file_paths_ssvep = glob.glob(os.path.join(DATA_FIF_DIR, "test/SSVEP/*.fif"))

test_data_ssvep, test_labels_ssvep , ids_ssvep = extract_data(test_file_paths_ssvep , return_id = True)
test_subject_labels = extract_subject_labels(test_data_ssvep)


mapping_ssvep = {
    "Backward":0,
    "Forward":1,
    "Left":2,
    "Right":3
}

inv_mapping_ssvep = {
    v:k for k,v in mapping_ssvep.items()
}


model_ssvep = MTCFormer(
    depth=1,
    kernel_size=10,
    n_times=500,
    chs_num=7,
    eeg_ch_nums=4,
    class_num=4,
    class_num_domain=30,
    modulator_kernel_size=10,
    domain_dropout=0.7,
    modulator_dropout=0.7,
    mid_dropout=0.7,
    output_dropout=0.7,
    weight_init_std=0.05,
    weight_init_mean=0.0,
).to(device)

optimizer = Adam(model_ssvep.parameters(), lr=0.002)

checkpoint_path = os.path.join(
    "../",
    "checkpoints",
    "model_ssvep_checkpoint",
    "best_model_.pth"
    )

checkpoint = torch.load(checkpoint_path, weights_only=False)

model_ssvep.load_state_dict(checkpoint['model_state_dict'] , strict=False)


cols_to_pick = [
        'OZ',
        'PO7',
        'PO8',
        'PZ',
        'Acc_norm',
        'gyro_norm',
        'Validation'
          ]


params = {
    "cols_to_pick":cols_to_pick,
    "l_freq": 8,
    "h_freq": 14,
    "notch_freqs": [50, 100],
    "notch_width": 1.0,
    "window_size": 500,
    "window_stride": 50
}

test_data,weights_test, _ ,subject_label_test_, WINDOW_LEN= preprocess_data(
    test_data_ssvep,
    labels = test_labels_ssvep,
    subject_labels = test_subject_labels,
    preprocess_func=preprocess_one_file,
    params = params,
    n_jobs=4
    )



weights_test_torch = torch.from_numpy(weights_test).to(torch.float32)
test_ssvep_torch = torch.from_numpy(test_data).to(torch.float32)
test_labels_placeholder = torch.zeros(test_ssvep_torch.shape[0], dtype=torch.long)
test_ssvep_torch_subject = torch.from_numpy(subject_label_test_).to(torch.long)

test_dataset = EEGDataset(
    data_tensor=test_ssvep_torch,
    weigths=weights_test_torch,
    label_tensor=test_labels_placeholder,
    subject_labels=test_ssvep_torch_subject
    )

test_loader   = DataLoader(
    test_dataset,
    batch_size=len(test_dataset),
    shuffle=False
    )



final_preds_ssvep= predict(
    model_ssvep,
    window_len=WINDOW_LEN,
    loader=test_loader,
    num_samples_to_predict=50,
    device = device,
    probability=False,
    num_classes=4
    )


# -----------------------------------------------------------------------------
#  Final Submission Assembly
#
# - Convert SSVEP predictions to a DataFrame and map class indices to labels
# - Concatenate MI and SSVEP prediction DataFrames
# - Sort by ID and save as the final submission CSV
# -----------------------------------------------------------------------------

preds_ssvep_csv = pd.DataFrame({
    "id":ids_ssvep,
    "label": pd.Series(final_preds_ssvep).map(inv_mapping_ssvep).values
})



submission = pd.concat([
    preds_mi_csv,
    preds_ssvep_csv
]).sort_values(
    by="id"
).reset_index(
    drop=True
)
submission

50it [00:00, 437.45it/s]


Unnamed: 0,id,label
0,4901,Left
1,4902,Right
2,4903,Right
3,4904,Left
4,4905,Right
...,...,...
95,4996,Left
96,4997,Forward
97,4998,Right
98,4999,Backward


In [73]:
submission.to_csv("quantile_ensemble_uniform_2.csv" , index=False)

submission