In [1]:
import numpy as np
from torch.utils.data import DataLoader, Dataset

%matplotlib inline

In [2]:
from bci_aic3.data import BCIDataset, load_raw_data
from bci_aic3.preprocess import apply_all_preprocessing_steps
from bci_aic3.config import load_processing_config
from bci_aic3.paths import (
    CONFIG_DIR,
    LABEL_MAPPING_PATH,
    RAW_DATA_DIR,
    PROCESSED_DATA_DIR,
)
from bci_aic3.util import read_json_to_dict

In [4]:
label_mapping = read_json_to_dict(LABEL_MAPPING_PATH)

In [5]:
train_mi = BCIDataset(
    "train.csv",
    base_path=RAW_DATA_DIR,
    task_type="MI",
    split="train",
    label_mapping=label_mapping,
)

val_mi = BCIDataset(
    "validation.csv",
    base_path=RAW_DATA_DIR,
    task_type="MI",
    split="validation",
    label_mapping=label_mapping,
)

2400it [02:25, 16.51it/s]
50it [00:02, 16.88it/s]


In [6]:
from bci_aic3.config import ProcessingConfig


def preprocessing_pipeline(dataset: BCIDataset,
                           task_type: str,
                           split: str,
                           processing_config: ProcessingConfig):
    
    data_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    data_batch, labels = next(iter(data_loader))
    
    data = data_batch.numpy()
    labels = labels.numpy()
    
    processed_data = apply_all_preprocessing_steps(data=data, settings=processing_config)
    
    processed_data_path = PROCESSED_DATA_DIR/ task_type.upper() / f"{split}_data.npy"
    processed_labels_path = PROCESSED_DATA_DIR/ task_type.upper() / f"{split}_labels.npy" 
    
    np.save(processed_data_path, processed_data)
    print(f"Processed data successfully saved at: {processed_data_path}")
    
    np.save(processed_labels_path, labels)
    print(f"Processed labels successfully saved at: {processed_labels_path}")
    

In [7]:
from bci_aic3.paths import MI_CONFIG_PATH


processing_config = load_processing_config(MI_CONFIG_PATH)

In [8]:
preprocessing_pipeline(val_mi, 
                       task_type="MI", 
                       split="validation",
                       processing_config=processing_config)

Processed data successfully saved at: P:\Programming\AIC3\repo\bci_aic3\data\processed\MI\validation_data.npy
Processed labels successfully saved at: P:\Programming\AIC3\repo\bci_aic3\data\processed\MI\validation_labels.npy
