In [1]:
from ica_benchmark.io.load import OpenBMI_Dataset
from pathlib import Path
import pandas as pd
from mne.decoding import CSP
import numpy as np

In [2]:
from sklearn.base import BaseEstimator
from mne.time_frequency import psd_array_welch, psd_welch
import numpy as np
import mne

class ConcatenateChannelsPSD(BaseEstimator):
    def __init__(self):
        super(ConcatenateChannelsPSD).__init__()

    def fit(self, x, y=None):
        return self

    def transform(self, x, y=None):
        n = len(x)
        return x.reshape(n, -1)


class GetEpochsData(BaseEstimator):
    def __init__(self):
        super(GetEpochsData).__init__()

    def fit(self, x, y=None):
        return self

    def transform(self, x, y=None):
        return x.get_data()


class PSD(BaseEstimator):
    BANDS_DICT = {
    #         "delta": (1, 4),
    #         "theta": (4, 8),
    #         "mu": (8, 13),
        "mu": (8, 13),
    #         "beta": (13, 25),
        "beta": (13, 25),
    #         "gamma": (25, 40)
    }
    def __init__(self, **kwargs):
        super(PSD).__init__()
        self.kwargs = kwargs

    def set_params(**params):
        for param in params:
            assert params in ["picks", "n_fft", "n_overlap", "n_per_seg"]
        self.kwargs.update(params)

    def get_params(self, *args, **kwargs):
        return self.kwargs

    def fit(self, x, y=None):
        return self

    def transform(self, x, y=None):
        if isinstance(x, list):
            x = mne.concatenate_epochs(x)
            psds, freqs = psd_welch(x, ** self.kwargs)
        if isinstance(x, mne.Epochs):
    #             psds, freqs = psd_welch(x, ** self.kwargs)
            x = x.get_data()
        if isinstance(x, np.ndarray):
            psds, freqs = psd_array_welch(x, ** self.kwargs)
        if ("average" in self.kwargs) and (self.kwargs["average"] is None):
            psds = psds.sum(axis=3)
        self.freqs = freqs

        band_spectras = list()
        for band, (lfreq, hfreq) in self.BANDS_DICT.items():
            band_spectra = psds[:, :, (freqs >= lfreq) & (freqs < hfreq)]
            band_spectras.append(
                band_spectra.sum(axis=2, keepdims=True)
            )

        band_spectras = np.concatenate(band_spectras, axis=2)

        return band_spectras

In [3]:
dataset = OpenBMI_Dataset("/home/paulo/Documents/datasets/OpenBMI/edf")

![image.png](attachment:image.png)

In [19]:
from sklearn.pipeline import make_pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import cohen_kappa_score, balanced_accuracy_score
import mne
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

mne.set_log_level(False)

channels = ["FC" + str(s) for s in [5, 3, 1, 2, 4, 6]]
channels += ["C" + str(s) for s in [5, 3, 1, 2, 4, 6]]
channels += ["CP" + str(s) for s in [5, 3, 1, 2, 4, 6]]


filter_kwargs = dict(
    method="iir",
    iir_params=dict(
        order=5,
        ftype="butter"
    )
)

def run_experiment(train_epochs, test_epochs):

    train_epochs = train_epochs.copy()
    test_epochs = test_epochs.copy()
    
    train_epochs.load_data()
    test_epochs.load_data()

#     train_epochs = train_epochs.apply_baseline().pick(channels).filter(8, 30, **filter_kwargs)
#     test_epochs = test_epochs.apply_baseline().pick(channels).filter(8, 30, **filter_kwargs)
    train_epochs = train_epochs.pick(channels).filter(8, 30, **filter_kwargs)
    test_epochs = test_epochs.pick(channels).filter(8, 30, **filter_kwargs)
    
    x_train = train_epochs.get_data()
    x_test = test_epochs.get_data()
    y_train = train_epochs.events[:, 2]
    y_test = test_epochs.events[:, 2]
    
    csp = CSP(n_components=len(channels) // 2, log=True)
    csp.fit(x_train, y_train)

    x_train = csp.transform(x_train)
    x_test = csp.transform(x_test)

    lr = LinearDiscriminantAnalysis().fit(x_train, y_train)
    
    predictions = lr.predict(x_test)
    
    acc = balanced_accuracy_score(y_test, predictions)
    return acc

In [30]:
import pandas as pd
from sklearn.model_selection import KFold
from warnings import warn, filterwarnings
import numpy as np
from mne import Epochs
import mne
from ica_benchmark.io.load import OpenBMI_Dataset
from pathlib import Path
from ica_benchmark.utils.itertools import group_iterator, constrained_group_iterator
from ica_benchmark.split.split import Split, Splitter
from warnings import filterwarnings
from tqdm import tqdm

filterwarnings("ignore", category=RuntimeWarning)

openbmi_dataset_folderpath = Path('/home/paulo/Documents/datasets/OpenBMI/edf/')
dataset = OpenBMI_Dataset(openbmi_dataset_folderpath)
fold_sizes = None
splitter = Splitter(
    dataset,
    uids=dataset.list_uids()[:12],
    sessions=dataset.SESSIONS,
    runs=dataset.RUNS,
    load_kwargs=dict(
        reject=False,
        tmin=1,
        tmax=3.5
    ),
    splitter=KFold(4),
    intra_session_shuffle=True,
    fold_sizes=fold_sizes
)


results_dict = dict()

mode = "inter_subject"
print(mode.upper())
splits_iterable = splitter.yield_splits_epochs(mode=mode)
results = list()
for i, (train_split, test_split) in enumerate(splits_iterable):
    train_epochs, test_epochs = splitter.load_from_splits((train_split, test_split), fold_sizes=fold_sizes)
    acc = run_experiment(train_epochs, test_epochs)
    results.append(
        dict(
            mode=mode,
            train_uids=np.unique(train_split["uid"]),
            test_uids=np.unique(test_split["uid"]),
            accuracy=acc
        )
    )
    del train_epochs, test_epochs
results_dict[mode] = pd.DataFrame.from_records(results)
display(results_dict[mode])
    
print("Changing splitter uids to all")
splitter.uids = dataset.list_uids()

mode = "inter_session"
print(mode.upper())
results = list()
for sessions in [[2, 1], [1, 2]]:
    splitter.sessions = sessions
    splits_iterable = splitter.yield_splits_epochs(mode=mode)
    for i, (train_split, test_split) in enumerate(splits_iterable):
        print(train_split)
        train_epochs, test_epochs = splitter.load_from_splits((train_split, test_split), fold_sizes=fold_sizes)
        acc = run_experiment(train_epochs, test_epochs)
        results.append(
            dict(
                mode=mode,
                train_uid=np.unique(train_split["uid"]),
                train_session=np.unique(train_split["session"]),
                train_run=np.unique(train_split["run"]),
                test_uid=np.unique(test_split["uid"]),
                test_session=np.unique(test_split["session"]),
                test_run=np.unique(test_split["run"]),
                accuracy=acc
            )
        )
        del train_epochs, test_epochs
results_dict[mode] = pd.DataFrame.from_records(results)
display(results_dict[mode])


mode = "intra_session_inter_run"
print(mode.upper())
results = list()
for runs in [[2, 1], [1, 2]]:
    splitter.runs = runs
    splits_iterable = splitter.yield_splits_epochs(mode=mode)
    for i, (train_split, test_split) in enumerate(splits_iterable):
        train_epochs, test_epochs = splitter.load_from_splits((train_split, test_split), fold_sizes=fold_sizes)
        acc = run_experiment(train_epochs, test_epochs)
        results.append(
            dict(
                mode=mode,
                train_uid=np.unique(train_split["uid"]),
                train_session=np.unique(train_split["session"]),
                train_run=np.unique(train_split["run"]),
                test_uid=np.unique(test_split["uid"]),
                test_session=np.unique(test_split["session"]),
                test_run=np.unique(test_split["run"]),
                accuracy=acc
            )
        )
        del train_epochs, test_epochs
results_dict[mode] = pd.DataFrame.from_records(results)
display(results_dict[mode])


mode = "intra_session_intra_run"
print(mode.upper())
splits_iterable = splitter.yield_splits_epochs(mode=mode)
results = list()
for i, splits in enumerate(splits_iterable):
    split = splits[0]
    for trial_n in range(5):
        train_epochs, test_epochs = splitter.load_from_splits(splits, fold_sizes=[.75, .25])
        acc = run_experiment(train_epochs, test_epochs)
        results.append(
            dict(
                mode=mode,
                exp_number=trial_n,
                uid=np.unique(split["uid"]),
                session=np.unique(split["session"]),
                run=np.unique(split["run"]),
                accuracy=acc
            )
        )
        del train_epochs, test_epochs
results_dict[mode] = pd.DataFrame.from_records(results)
display(results_dict[mode])

mode = "intra_session_intra_run_merged"
print(mode.upper())
splits_iterable = splitter.yield_splits_epochs(mode=mode)
results = list()
for i, splits in enumerate(splits_iterable):
    split = splits[0]
    for trial_n in range(5):
        train_epochs, test_epochs = splitter.load_from_splits(splits, fold_sizes=[.75, .25])
        acc = run_experiment(train_epochs, test_epochs)
        results.append(
            dict(
                mode=mode,
                exp_number=trial_n,
                uid=np.unique(split["uid"]),
                session=np.unique(split["session"]),
                run=np.unique(split["run"]),
                accuracy=acc
            )
        )
        del train_epochs, test_epochs
results_dict[mode] = pd.DataFrame.from_records(results)
display(results_dict[mode])

Changing splitter uids to all
INTER_SESSION
Split({'uid': '25', 'session': 2, 'run': 1},{'uid': '25', 'session': 2, 'run': 2})
Split({'uid': '15', 'session': 2, 'run': 1},{'uid': '15', 'session': 2, 'run': 2})
Split({'uid': '41', 'session': 2, 'run': 1},{'uid': '41', 'session': 2, 'run': 2})
Split({'uid': '25', 'session': 1, 'run': 1},{'uid': '25', 'session': 1, 'run': 2})
Split({'uid': '15', 'session': 1, 'run': 1},{'uid': '15', 'session': 1, 'run': 2})
Split({'uid': '41', 'session': 1, 'run': 1},{'uid': '41', 'session': 1, 'run': 2})


Unnamed: 0,mode,train_uid,train_session,train_run,test_uid,test_session,test_run,accuracy
0,inter_session,[25],[2],"[1, 2]",[25],[1],"[1, 2]",0.59
1,inter_session,[15],[2],"[1, 2]",[15],[1],"[1, 2]",0.49
2,inter_session,[41],[2],"[1, 2]",[41],[1],"[1, 2]",0.585
3,inter_session,[25],[1],"[1, 2]",[25],[2],"[1, 2]",0.535
4,inter_session,[15],[1],"[1, 2]",[15],[2],"[1, 2]",0.5
5,inter_session,[41],[1],"[1, 2]",[41],[2],"[1, 2]",0.565


INTRA_SESSION_INTER_RUN


KeyboardInterrupt: 

In [31]:
for k, v in results_dict.items():
    if isinstance(v, list):
        results_dict[k] = pd.DataFrame.from_records(v)

In [32]:
results_dict["inter_session"]


Unnamed: 0,mode,train_uid,train_session,train_run,test_uid,test_session,test_run,accuracy
0,inter_session,[25],[2],"[1, 2]",[25],[1],"[1, 2]",0.59
1,inter_session,[15],[2],"[1, 2]",[15],[1],"[1, 2]",0.49
2,inter_session,[41],[2],"[1, 2]",[41],[1],"[1, 2]",0.585
3,inter_session,[25],[1],"[1, 2]",[25],[2],"[1, 2]",0.535
4,inter_session,[15],[1],"[1, 2]",[15],[2],"[1, 2]",0.5
5,inter_session,[41],[1],"[1, 2]",[41],[2],"[1, 2]",0.565


In [None]:
accs = run_experiment(.5, 3.5)
plt.hist(accs)