In [31]:
   %load_ext autoreload
   %autoreload 2

In [32]:

from ErrP_decoder.data_io import select_eeg_files, load_selected_files

# Choose the EEG recordings to analyse and load each one into a dataset object.
# NOTE(review): if select_eeg_files() is interactive, the chosen files are not
# recorded in code, so a fresh re-run may load a different selection.
file_paths = select_eeg_files()
datasets = load_selected_files(file_paths)

# Quick look at what was loaded for each run.
for run in datasets:
    print("Data shape:", run.data.shape)
    print("Info:", run.info)

In [33]:
from ErrP_decoder.data_io import set_params

CONFIG_PATH = "../configs/default.yaml"

# Without this guard an empty selection would leave `cfg` undefined, or would
# silently keep a stale `cfg` from an earlier execution in this kernel.
if len(datasets) == 0:
    raise ValueError("No EEG files were loaded; cannot build the configuration.")

# NOTE(review): set_params runs once per dataset, but only the last result is
# kept in `cfg`, and the rest of the notebook shares that single `cfg` across
# all runs. This assumes every run has the same recording parameters
# (ds.info). Confirm this assumption.
for ds in datasets:
    cfg = set_params(ds.info, config_path=CONFIG_PATH)

In [34]:
from ErrP_decoder.data_io import preprocess_dataset

# Preprocess every run and store the result back into `datasets`. Rebinding the
# loop variable (`ds = preprocess_dataset(ds, cfg)`) would throw away the
# returned object whenever preprocess_dataset builds a new dataset rather than
# mutating its argument.
# NOTE(review): re-running this cell preprocesses the data a second time.
preprocessed = []
for ds in datasets:
    result = preprocess_dataset(ds, cfg)
    # Fall back to `ds` in case preprocess_dataset works in place and returns None.
    preprocessed.append(ds if result is None else result)
datasets = preprocessed

# Number of trials per trigger type in each run.
for run_idx, ds in enumerate(datasets):
    neu = sum(ds.trig['typ'] == 0)
    pos = sum(ds.trig['typ'] == 1)
    neg = sum(ds.trig['typ'] == 2)
    print(f"run {run_idx}: typ 0 (neu)={neu}, typ 1 (pos)={pos}, typ 2 (neg)={neg}")
    


In [35]:
from ErrP_decoder.data_io import bandpass_filter
from ErrP_decoder.plots import plot_eeg_interactive

# Band-pass filter every run, show the result for visual inspection, then keep
# the filtered signal on the dataset (as `eeg_filt`) for epoching.
for run in datasets:
    eeg_bp = bandpass_filter(run.eeg, cfg)
    plot_eeg_interactive(eeg_bp, cfg)
    run.eeg_filt = eeg_bp

In [36]:
from ErrP_decoder.data_io import epoch_data

# Cut each filtered run into epochs using its triggers (run.trig). The
# transpose moves the leading axis of epoch_data's output to the end, so each
# run ends up with epochs shaped (samples, channels, trials), e.g. (768, 32, 200),
# and epoch_labels shaped (trials,), e.g. (200,).
for run in datasets:
    epochs_raw, labels_raw = epoch_data(run.eeg_filt, cfg, run.trig)
    run.epochs = epochs_raw.transpose(1, 2, 0)
    run.epoch_labels = labels_raw

In [37]:
import numpy as np

# Pool the (not yet balanced) epochs of all runs for the ERP plots.
# Epoch arrays are (samples, channels, trials), so trials are stacked on axis 2;
# the label vectors are 1-D and stacked on axis 0.
all_epochs = [run.epochs for run in datasets]
all_labels = [run.epoch_labels for run in datasets]
epochs_combined = np.concatenate(all_epochs, axis=2)
labels_combined = np.concatenate(all_labels, axis=0)

In [38]:
from ErrP_decoder.plots import plot_erp

# Electrodes at which the pooled ERPs are inspected.
ERP_CHANNELS = ['CZ', 'FCZ', 'FZ', 'PZ', 'CPZ']

# One plot_erp call per channel on the pooled epochs of all runs
# (replaces five copy-pasted calls that differed only in the channel name).
for channel in ERP_CHANNELS:
    plot_erp(epochs_combined, labels_combined, cfg, channel)

In [39]:
from ErrP_decoder.data_io import balance_runs

# Build a balanced copy of each run's epochs/labels; the originals are kept
# untouched next to them (`epochs`, `epoch_labels`).
for run in datasets:
    balanced_epochs, balanced_labels = balance_runs(run.epochs, run.epoch_labels)
    run.epochs_balanced = balanced_epochs
    run.epoch_labels_balanced = balanced_labels

In [40]:
# Shape of each run's balanced epoch array (trials on the last axis).
for run in datasets:
    print(run.epochs_balanced.shape)


In [41]:

def stack_runs(runs, epochs_attr, labels_attr):
    """Concatenate the epochs and labels of several runs into one dataset dict.

    Parameters
    ----------
    runs : sequence of dataset objects
        Each run stores an epoch array of shape (samples, channels, trials)
        and a label vector of shape (trials,) under the attributes below.
    epochs_attr, labels_attr : str
        Attribute names to read, e.g. ``'epochs_balanced'`` and
        ``'epoch_labels_balanced'``.

    Returns
    -------
    dict
        ``'data'``   -- epochs of all runs stacked on the trial axis (axis 2),
        ``'labels'`` -- labels of all runs stacked on axis 0,
        ``'run'``    -- 0-based index of the run each trial came from.

    Raises
    ------
    ValueError
        If a run's epoch count differs from its label count; otherwise the
        labels and run ids would silently misalign with the data.
    """
    epochs_per_run, labels_per_run, run_ids = [], [], []
    for run_idx, run in enumerate(runs):
        epochs = getattr(run, epochs_attr)
        labels = getattr(run, labels_attr)
        if epochs.shape[2] != labels.shape[0]:
            raise ValueError(
                f"run {run_idx}: {epochs.shape[2]} epochs but {labels.shape[0]} labels"
            )
        epochs_per_run.append(epochs)
        labels_per_run.append(labels)
        run_ids.append(np.full(labels.shape[0], run_idx))
    return {
        'data': np.concatenate(epochs_per_run, axis=2),
        'labels': np.concatenate(labels_per_run, axis=0),
        'run': np.concatenate(run_ids, axis=0),
    }


# Training set: the balanced epochs of every run, tagged with their run index.
train_data = stack_runs(datasets, 'epochs_balanced', 'epoch_labels_balanced')
training_epochs = train_data['data']        # (samples, channels, trials), e.g. (768, 32, 306)
training_labels = train_data['labels']      # (trials,)
training_run_labels = train_data['run']     # (trials,)



In [42]:
# Evaluation set: every (unbalanced) epoch of every run, with each trial
# tagged by the 0-based index of the run it came from.
all_epochs = [ds.epochs for ds in datasets]
all_labels = [ds.epoch_labels for ds in datasets]
run_labels = [np.full(labels.shape[0], run_idx)
              for run_idx, labels in enumerate(all_labels)]

test_epochs = np.concatenate(all_epochs, axis=2)        # (samples, channels, trials), e.g. (768, 32, 1000)
test_labels = np.concatenate(all_labels, axis=0)        # (trials,)
test_run_labels = np.concatenate(run_labels, axis=0)    # (trials,)
test_data = dict(data=test_epochs, labels=test_labels, run=test_run_labels)

In [43]:
from ErrP_decoder.modeling import leave_one_run_out_cv
# Cross-validate the decoder: train_data holds the balanced epochs, test_data
# holds all epochs, and both carry a per-trial 'run' index (presumably used to
# hold out one run at a time).
# NOTE(review): train and test are built from the same runs, so confirm that
# leave_one_run_out_cv only scores a run's test trials with a model that did
# not train on that run.
posteriors = leave_one_run_out_cv(train_data, test_data, cfg)

In [45]:
from ErrP_decoder.evaluation import evaluate_classifier
# Compare the cross-validated posteriors with the true labels of every test
# epoch; the return value (if any) is this cell's displayed output.
evaluate_classifier(test_data['labels'], posteriors)

In [None]:
from ErrP_decoder.evaluation import plot_posteriors_by_class
# Distribution of the posteriors split by true class (20 bins).
# NOTE(review): only two class names are passed, while the trigger counts
# earlier distinguish typ 0/1/2. Confirm that the test labels are binary and
# ordered as listed here.
plot_posteriors_by_class(test_data['labels'], posteriors, 
                         class_names=['No Feedback', 'Negative Feedback'], bins=20)

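                            Testing the decoder-computation functions (below): the following cells call `compute_decoder` and then run its apparent stages (CCA spatial filter, spatial filtering, resampling, LDA) one at a time.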

In [29]:
from ErrP_decoder.modeling import compute_decoder

# Fit the complete decoder on the balanced training epochs
# (samples, channels, trials) built in the training-set cell.
# `training_data` was never defined in this notebook (stale kernel state);
# `training_epochs` is the epoch array that matches training_labels.
decoder = compute_decoder(training_epochs, training_labels, cfg)

In [12]:
from ErrP_decoder.modeling import get_cca_spatialfilter

# Learn a CCA spatial filter from the balanced training epochs.
# `training_data` was never defined in this notebook; `training_epochs` is the
# (samples, channels, trials) array that matches training_labels.
# NOTE(review): the original example shape was (32, 2), i.e. 2 components,
# but the resampling cell's example shows 3. Confirm cfg's n_comp.
spatial_filter = get_cca_spatialfilter(training_epochs, training_labels,
                                       n_components=cfg['spatial_filter']['n_comp'])  # e.g. (channels, n_comp)



In [13]:

from ErrP_decoder.modeling import apply_spatial_filter

# Project the training epochs onto the learned spatial components.
# `training_data` was never defined in this notebook; `training_epochs` is the
# (samples, channels, trials) array the filter was learned from.
filtered_data = apply_spatial_filter(training_epochs, spatial_filter)  # shape: (samples, n_components, trials)

In [14]:
from ErrP_decoder.modeling import resample_epochs
# Resample the spatially filtered epochs along the time axis (settings from cfg).
# NOTE(review): the example below shows 3 components, but the spatial-filter
# cell's example shows 2. One of the two comments is stale; confirm n_comp.
resampled = resample_epochs(filtered_data,cfg) # shape (resampled samples, n_comp, trials), e.g. (20, 3, 306)


In [26]:
# Flatten each trial's (time sample x component) block into one feature vector
# so a standard sklearn classifier can consume it; rows of X are trials.
n_samples, n_comp, n_trials = resampled.shape
features_per_trial = n_samples * n_comp
X = resampled.reshape(features_per_trial, n_trials).T  # (trials, features), e.g. (306, 60)
y = training_labels                                    # (trials,), e.g. (306,)


In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Stand-alone LDA on the flattened features, to inspect the fitted parameters
# (presumably mirroring the classifier stage inside compute_decoder).
lda = LinearDiscriminantAnalysis().fit(X, y)

# coef_ is (1, n_features) for two classes and (n_classes, n_features) otherwise;
# intercept_ is (1,) or (n_classes,) correspondingly.
print("LDA coefficients shape:", lda.coef_.shape)
print("LDA intercepts shape:", lda.intercept_.shape)