In [None]:
# Notebook tooling setup (IPython magics, not plain Python):
#   - turn off jedi-based tab completion,
#   - load autoreload in mode 2 so edited modules are re-imported before each cell runs,
#   - load lab_black so code cells are auto-formatted.
%config Completer.use_jedi = False
%load_ext autoreload
# %reload_ext autoreload
%autoreload 2
%load_ext lab_black

## Workers

In [None]:
import os
import mne

%matplotlib inline
import numpy as np

In [None]:
!pip install mne-features

In [None]:
import mne_features as mnf

In [None]:
# Synthetic 1-D test signal: 256 Gaussian samples (mean 0, std 20).
# Drawn from a seeded generator so the feature values printed by the cells
# below are identical on every Restart & Run All.
data = np.random.default_rng(0).normal(loc=0, scale=20, size=256)

In [None]:
# Try the mne-features kurtosis function on the synthetic 1-D signal.
# NOTE(review): mne-features documents its univariate functions for
# (n_channels, n_times) input — confirm that passing a 1-D array is intended.
mnf.univariate.compute_kurtosis(data)

In [None]:
# Try the mne-features peak-to-peak amplitude function on the same 1-D signal.
# NOTE(review): presumably max - min over the samples — confirm against the docs.
mnf.univariate.compute_ptp_amp(data)

In [None]:
# Try the mne-features skewness function on the same 1-D signal.
mnf.univariate.compute_skewness(data)

## Analyze Workers

In [None]:
import sys

# Put the parent directory at the front of the import search path.
# Presumably this lets the `pase_eeg` imports below resolve to the local
# checkout — the path is relative, so it depends on the kernel's working
# directory being the notebook's folder (TODO confirm).
sys.path.insert(0, "../")

In [None]:
import torch
from torch.utils.data import DataLoader
from pase_eeg.lit_modules.pase_lit import PaseEEGBCIIV2aDataLit
from pase_eeg.data.transforms import ToTensor, ZNorm, Compose
from pase_eeg.lit_modules.utils import eeg_electrode_configs
from pase_eeg.data.transforms import ToTensor, ZNorm

In [None]:
import matplotlib.pyplot as plt
from itertools import product
import numpy as np

In [None]:
# Lightning data module for BCI Competition IV 2a, one sample per batch.
# Training covers patients 1 through 9.
# NOTE(review): patient 9 is listed for both training and testing — confirm
# this overlap is intended.
dslit = PaseEEGBCIIV2aDataLit(
    data_path="/data/BCI_Competition_IV/",
    channels_config="../configs/eeg_recording_standard/international_10_20_22.py",
    train_patients=list(range(1, 10)),
    test_patients=[9],
    batch_size=1,
    workers_config="../configs/pase_base/workers.json",
    transforms=[
        dict(
            class_path="pase_eeg.data.transforms.ToTensor",
            init_args=dict(device="cpu"),
        ),
        # Z-normalisation is switched off here. To re-enable it, append:
        #   dict(
        #       class_path="pase_eeg.data.transforms.ZNorm",
        #       init_args=dict(stats="bci_comp_iv2a_stats.pkl", mode="mean-std"),
        #   ),
    ],
)
dslit.setup()

In [None]:
# Training-split loader; the worker statistics in the sections below are
# all collected by iterating over it.
dloader = dslit.train_dataloader()

In [None]:
def draw_2d_plot(shape, plotter, data, num=None):
    """Draw the leading items of ``data`` onto a grid of subplots.

    Panels are filled in row-major order: item ``i`` goes to the ``i``-th
    grid cell. Drawing stops at whichever runs out first: the grid cells,
    the items in ``data``, or the ``num`` cap.

    Args:
        shape: Two-item sequence ``(n_rows, n_cols)`` giving the grid size.
        plotter: Callable invoked as ``plotter(axes, item)`` for each panel.
        data: Sized, indexable sequence of items to plot.
        num: Optional upper bound on the number of panels to draw;
            ``None`` means no explicit cap.
    """
    n_rows, n_cols = shape
    # squeeze=False keeps `axs` a 2-D array even for 1x1, 1xN and Nx1 grids;
    # with the default squeezing, `axs[row, col]` fails for those shapes.
    _, axs = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)

    # Never index past the end of `data`, even when the grid (or `num`) is larger.
    n_items = len(data) if num is None else min(num, len(data))
    cells = product(range(n_rows), range(n_cols))
    for i, (row, col) in zip(range(n_items), cells):
        plotter(axs[row, col], data[i])


# test
def plotter(axes, data):
    """Line-plot ``data`` against its positional indices ``0..len(data)-1``."""
    positions = [*range(len(data))]
    axes.plot(positions, data)

In [None]:
def extract_label(dloader, label):
    """Collect one worker's label tensors over a full pass of a loader.

    Args:
        dloader: Sized iterable yielding ``(signal, labels)`` batches, where
            ``labels[label]`` is a mapping from key to tensor. The keys are
            presumably channel names — TODO confirm against the dataset.
        label: Name of the worker whose labels are collected (e.g. ``"psd"``).

    Returns:
        dict mapping each key to a list holding one tensor per batch, in
        iteration order. Each tensor has had ``.squeeze()`` applied, which
        removes every size-1 dimension (so the batch axis disappears when the
        batch size is 1).
    """
    collected = {}
    for bidx, (_signal, labels) in enumerate(dloader, start=1):
        if bidx % 100 == 0:
            # Progress is counted in batches, so the total is the number of
            # batches in the loader rather than the number of dataset samples.
            print("Bidx: {}/{}".format(bidx, len(dloader)))
        for key, value in labels[label].items():
            collected.setdefault(key, []).append(value.squeeze())

    return collected

In [None]:
def dict_stats(data):
    """Summarise every entry of ``data`` with mean / std / min / max.

    Args:
        data: Mapping from key to a list of tensors. Each list is joined with
            ``torch.cat`` (along the first dimension) before being reduced.

    Returns:
        dict mapping each key to ``{"mean", "std", "min", "max"}``, each a
        scalar tensor computed over all values stored under that key.

    Prints the size of the first concatenated tensor as a quick shape check.
    """
    merged = {name: torch.cat(chunks) for name, chunks in data.items()}
    print(next(iter(merged.values())).size())

    def _summarise(values):
        # Reduce over every element, regardless of the original layout.
        flat = values.flatten()
        return {
            "mean": flat.mean(dim=0),
            "std": flat.std(dim=0),
            "min": flat.min(dim=0).values,
            "max": flat.max(dim=0).values,
        }

    return {name: _summarise(values) for name, values in merged.items()}


def dict_timeseries_stats(data):
    """Summarise every entry of ``data`` with a global mean / std / min / max.

    Args:
        data: Mapping from key to a list of tensors. Each list is joined with
            ``torch.stack`` (so the tensors in a list must share one shape)
            before being reduced.

    Returns:
        dict mapping each key to ``{"mean", "std", "min", "max"}``, each a
        scalar tensor computed over all values stored under that key.

    Prints the size of the first stacked tensor as a quick shape check.
    """
    stacked = {name: torch.stack(series) for name, series in data.items()}
    print(next(iter(stacked.values())).size())

    summary = {}
    for name, values in stacked.items():
        flat = values.flatten()
        summary[name] = dict(
            mean=flat.mean(),
            std=flat.std(),
            min=flat.min(),
            max=flat.max(),
        )
    return summary

In [None]:
def error_region_plotter(axes, data):
    """Plot a curve with a shaded band of one deviation on either side.

    Args:
        axes: Matplotlib-style axes providing ``plot`` and ``fill_between``.
        data: Pair ``(y, std)`` of equal-length 1-D arrays/tensors: the curve
            and the half-width of the band around it.
    """
    y, std = data
    # Sample indices 0..len(y)-1. np.linspace(0, len(y), len(y)) would space
    # the points by n/(n-1) and end at n, so positions would not be indices.
    x = np.arange(len(y))

    axes.plot(x, y, "k-")
    axes.fill_between(x, y - std, y + std)

## PSD Worker

In [None]:
# One full pass over the training loader, gathering the "psd" worker's
# label tensors per key (a list with one tensor per batch).
psd_data = extract_label(dloader, "psd")

In [None]:
# Mean / std / min / max per key over all concatenated PSD values; the bare
# `stats` on the last line displays the resulting dict.
stats = dict_stats(psd_data)
stats

In [None]:
# Stack each key's per-batch PSD tensors along a new leading axis.
data_stack = {k: torch.stack(v) for k, v in psd_data.items()}
print(next(iter(data_stack.values())).size())

# Smallest square grid that has one panel per key.
num = len(data_stack)
shape = [int(np.ceil(np.sqrt(num)))] * 2

# (mean, std) across the stacked axis for every key, in dict order.
plot_data = [
    (torch.mean(a, dim=0), torch.std(a, dim=0)) for a in data_stack.values()
]
draw_2d_plot(shape, error_region_plotter, plot_data, num)

### Normalize and clip

In [None]:
# Global mean / std / min / max per key over the stacked PSD tensors; these
# are the statistics used for normalisation and saved to disk further down.
ts_stats = dict_timeseries_stats(psd_data)
ts_stats

In [None]:
# Z-normalise each key's stacked PSD with that key's global mean/std, then
# keep only the first 123 entries along the second axis.
# NOTE(review): 123 is a hard-coded cut-off — presumably the number of PSD
# bins of interest; confirm and consider naming it.
data_stack = {
    k: torch.stack(v).sub(ts_stats[k]["mean"]).div(ts_stats[k]["std"])[:, :123]
    for k, v in psd_data.items()
}
print(next(iter(data_stack.values())).size())

# Smallest square grid that has one panel per key.
num = len(data_stack)
shape = [int(np.ceil(np.sqrt(num)))] * 2

# (mean, std) across the stacked axis for every key, in dict order.
plot_data = [
    (torch.mean(a, dim=0), torch.std(a, dim=0)) for a in data_stack.values()
]
draw_2d_plot(shape, error_region_plotter, plot_data, num)

In [None]:
import pickle

# Persist the per-key PSD statistics (`ts_stats`) next to the notebook.
# The stored values are torch tensors, so unpickling requires torch.
# NOTE(review): presumably this file is meant to be passed as the `stats`
# argument of the ZNorm transform — confirm the expected file name/format.
with open("bci_comp_iv2a_psd_stats.pkl", "wb") as stats_f:
    pickle.dump(ts_stats, stats_f)

## WTE Worker

In [None]:
# One full pass over the training loader, gathering the "wte" worker's
# label tensors per key (a list with one tensor per batch).
wte_data = extract_label(dloader, "wte")

In [None]:
# Mean / std / min / max per key over all concatenated WTE values.
# Note: this rebinds `stats`, replacing the PSD statistics computed earlier.
stats = dict_stats(wte_data)
stats

In [None]:
# Stack each key's per-batch WTE tensors along a new leading axis.
data_stack = {k: torch.stack(v) for k, v in wte_data.items()}
print(next(iter(data_stack.values())).size())

# Smallest square grid that has one panel per key.
num = len(data_stack)
shape = [int(np.ceil(np.sqrt(num)))] * 2

# (mean, std) across the stacked axis for every key, in dict order.
plot_data = [
    (torch.mean(a, dim=0), torch.std(a, dim=0)) for a in data_stack.values()
]
draw_2d_plot(shape, error_region_plotter, plot_data, num)