In [1]:
import numpy as np
import mne

annotation_desc_dict = {
    276: "Idling EEG (eyes open)",
    277: "Idling EEG (eyes closed)",
    768: "Start of a trial",
    769: "Cue onset left (class 1)",
    770: "Cue onset right (class 2)",
    771: "Cue onset foot (class 3)",
    772: "Cue onset tongue (class 4)",
    783: "Cue unknown",
    1023: "Rejected trial",
    1072: "Eye movements",
    32766: "Start of a new run",
}

annotation_encode_dict = {
    276: 0,
    277: 1,
    768: 2,
    769: 3,
    770: 4,
    771: 5,
    772: 6,
    783: 7,
    1023: 8,
    1072: 9,
    32766: 10,
}

def get_annotations(data):
    sr = data.info["sfreq"]
    n_samples = data._raw_extras[0]["n_records"]

    onsets = np.trunc(data.annotations.onset * sr).astype(np.uint32, casting="unsafe")
    durations = np.trunc(data.annotations.duration * sr).astype(np.uint32, casting="unsafe")
    
    desc = data.annotations.description.astype(np.uint32)
    labels_codes = np.vectorize(annotation_encode_dict.get)(desc)
    
    n_codes = len(annotation_encode_dict)
    labels = np.zeros((n_samples, n_codes))
    
    for code, onset, duration in zip(labels_codes, onsets, durations):
        labels[onset:onset+duration, code] = 1
    
    return labels



In [2]:
from mne.io import read_raw_gdf
from pathlib import Path
import numpy as np

root = Path("C:/Users/paull/Documents/GIT/BCI_MsC/notebooks/BCI_Comp_IV_2a/BCICIV_2a_gdf")

dataset_folder = root
mat_files = list(dataset_folder.iterdir())

PRELOAD = False

def load_gdf_file(filepath):
    gdf_data = read_raw_gdf(filepath, preload=PRELOAD)

    chs = gdf_data.ch_names

    gdf_data = read_raw_gdf(
        filepath,
        preload=True,
        eog=["EOG-left", "EOG-central", "EOG-right"],
        exclude=[x for x in chs if "EOG" in x]
    )
    ch_names = gdf_data.ch_names
    info = parse_info(
        gdf_data._raw_extras[0]["subject_info"]
    )
    
    labels = get_annotations(gdf_data)
    
    return gdf_data, labels, ch_names, info

def parse_info(info_dict):
    cols = ['id', 'smoking', 'alcohol_abuse', 'drug_abuse', 'medication', 'weight', 'height', 'sex', 'handedness', 'age']
    parsed_info = {k:v for k, v in info_dict.items() if k in cols}
    return parsed_info
     
def load_subject_data(root, subject, mode=None):
    if mode is None:
        mode = "train"
    
    if mode == "train":
        filepath = root / f"{subject}T.gdf"
        gdf_data, labels, ch_names, info = load_gdf_file(filepath)
    elif mode == "test":
        filepath = root / f"{subject}E.gdf"
        gdf_data, labels, ch_names, info = load_gdf_file(filepath)
    elif mode == "both":
        filepath_t = root / f"{subject}T.gdf"
        filepath_e = root / f"{subject}E.gdf"
        gdf_data_t, labels_t, ch_names_t, info_t = load_gdf_file(filepath_t)
        gdf_data_e, labels_e, ch_names, info = load_gdf_file(filepath_e)
        
        
        assert np.all(ch_names_t == ch_names)
        assert np.all(info_t == info)
        
        gdf_data = gdf_data_t.copy()
        gdf_data._data = np.concatenate(
            [
                gdf_data_t._data,
                gdf_data_e._data,
            ],
            axis=1
        )
        
        labels = np.concatenate(
            [
                labels_t,
                labels_e
            ],
            axis=0
        )
    
    return gdf_data, labels, ch_names, info

def load_subjects_data(root, datasets=None, mode="train"):
    if datasets is None:
#         data_dict = {
#             "all": {
#                 filepath.name[:3]: None for filepath in root.glob("*T.gdf")
#             }
#         }
        data_dict = {
            filepath.name[:3]: {
                filepath.name[:3]: None
            } for filepath in root.glob("*T.gdf")
        }
    else:
        data_dict = {
            dataset: {
                subject_id: {} for subject_id in datasets[dataset]
            } for dataset in datasets
        }
    
    chs_ = None
    for dataset in data_dict:
        for subject_id in data_dict[dataset]:
            gdf, labels, chs, info = load_subject_data(root, subject_id, mode=mode)
            if chs_ is None:
                chs_ = chs
            else:
                assert chs_ == chs
            data_dict[dataset][subject_id] = {
                "gdf": gdf,
                "chs": chs,
                "info": info,
                "labels": labels
            }
    
    
    return data_dict



In [3]:
dataset_dict = {
    "train": ["A02", "A07", "A09", "A01"],
    "validation": ["A03", "A06"],
    "test": ["A04", "A05"],
}
all_subjects = [f"A0{i}" for i in range(10)]

dataset_dict = {
    "train": ["A02", "A07"],
    "validation": ["A03"],
    "test": ["A04"],
}

In [4]:
import mne

def get_kwargs(m, is_extended=False):
    if is_extended:
        return dict(method=m, fit_params=dict(extended=True))
    return dict(method=m)

ica_kwargs_dict = {
    "fastica": get_kwargs("fastica"),
    "infomax": get_kwargs("infomax"),
    "picard": get_kwargs("picard"),
    "ext_infomax": get_kwargs("infomax", is_extended=True),
    "ext_picard": get_kwargs("picard", is_extended=True)
}



In [5]:
from scoring import mutual_information, coherence, correntropy, apply_pairwise, apply_pairwise_parallel

In [None]:
import time

def join_gdfs(data_dict, datasets_names=None):
    new_dict = {}
    if datasets_names is None:
        datasets_names = data_dict.keys()
    
    for dataset_name in datasets_names:
        all_gdfs = []
        all_labels = []
        for subject_id in data_dict[dataset_name]:
            all_gdfs.append(data_dict[dataset_name][subject_id]["gdf"])
            all_labels.append(data_dict[dataset_name][subject_id]["labels"])

        labels = np.concatenate(all_labels, axis=0)
        gdf_base = all_gdfs[0].copy()
        for gdf in all_gdfs[1:]:
            gdf_base._data = np.concatenate(
                [
                    gdf_base._data,
                    gdf._data
                ],
                axis=1
            )
        new_dict[dataset_name] = {
            "all": {
                "gdf": gdf_base,
                "labels": labels,
                "info": None,
                "chs": gdf_base.ch_names
            }
        }
        
    return new_dict
        
N_RUNS = 3

results = {}

fn_dict = {
    "MI": mutual_information,
    "correntropy": correntropy,
    "coherence": coherence
}

n_components_list = [4, 8, 12, 16, 20, 22]



try:
    datasets
except:
    datasets = load_subjects_data(root, datasets=dataset_dict, mode="both")

score_calculated_before = {}

for ica_method in ica_kwargs_dict:
    for n_components in n_components_list:
        for run_n in range(N_RUNS):
            
            joined_dataset = join_gdfs(datasets, ["train"])

            gdf_data = joined_dataset["train"]["all"]["gdf"]
            ica_transform = mne.preprocessing.ICA(n_components=n_components, **ica_kwargs_dict[ica_method])
            ica_transform = ica_transform.fit(gdf_data)
            
            del joined_dataset

            for dataset_name in ("test", "validation", "train"):

                for subject_id in datasets[dataset_name]:
                    

                    gdf_data = datasets[dataset_name][subject_id]["gdf"]   
                    
                    data_after = ica_transform.get_sources(gdf_data).get_data().T

                    for fn_name in fn_dict:

                        print((fn_name, ica_method, dataset_name, subject_id, run_n, n_components))
                        
                        if (n_components > 5) or len(data_after) > 2e6:
                            apply_fn = apply_pairwise_parallel
                        else:
                            apply_fn = apply_pairwise
                        
                        if not (subject_id, fn_name) in score_calculated_before:
                            data_before = gdf_data.get_data().T
                            score_before = apply_pairwise_parallel(data_before, fn_dict[fn_name])
                            score_calculated_before[(subject_id, fn_name)] = score_before

                        start = time.time()
                        score_after = apply_fn(data_after, fn_dict[fn_name])
                        duration = time.time() - start
                        
                        results[(fn_name, ica_method, dataset_name, subject_id, run_n, n_components)] = {
                            "score_before": score_calculated_before[(subject_id, fn_name)],
                            "score_after": score_after,
                            "time": duration
                        }

Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A02T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A02T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(filepath, preload=PRELOAD)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(


Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A02E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A02E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 662665  =      0.000 ...  2650.660 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(filepath, preload=PRELOAD)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(


Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A07T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A07T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 681070  =      0.000 ...  2724.280 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(filepath, preload=PRELOAD)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(


Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A07E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A07E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 673134  =      0.000 ...  2692.536 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(filepath, preload=PRELOAD)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(


Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A03T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A03T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 660529  =      0.000 ...  2642.116 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(filepath, preload=PRELOAD)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(


Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A03E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A03E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 648774  =      0.000 ...  2595.096 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(filepath, preload=PRELOAD)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(


Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A04T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A04T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 600914  =      0.000 ...  2403.656 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(filepath, preload=PRELOAD)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(


Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A04E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from C:\Users\paull\Documents\GIT\BCI_MsC\notebooks\BCI_Comp_IV_2a\BCICIV_2a_gdf\A04E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 660046  =      0.000 ...  2640.184 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(filepath, preload=PRELOAD)
  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  gdf_data = read_raw_gdf(


Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 4 components
Fitting ICA took 5.3s.
('MI', 'fastica', 'test', 'A04', 0, 4)
('correntropy', 'fastica', 'test', 'A04', 0, 4)
('coherence', 'fastica', 'test', 'A04', 0, 4)
('MI', 'fastica', 'validation', 'A03', 0, 4)
('correntropy', 'fastica', 'validation', 'A03', 0, 4)
('coherence', 'fastica', 'validation', 'A03', 0, 4)
('MI', 'fastica', 'train', 'A02', 0, 4)
('correntropy', 'fastica', 'train', 'A02', 0, 4)
('coherence', 'fastica', 'train', 'A02', 0, 4)
('MI', 'fastica', 'train', 'A07', 0, 4)
('correntropy', 'fastica', 'train', 'A07', 0, 4)
('coherence', 'fastica', 'train', 'A07', 0, 4)
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 4 components
Fitting ICA took 3.8s.
('MI', 'fastica', 'test', 'A04', 1, 4)
('correntropy', 'fastica', 'test', 'A04', 1, 4)
('coherence', 'fastica', 'test', 'A04', 1, 4)
('MI', 'fastica', 'validation', 'A0

('coherence', 'fastica', 'train', 'A02', 2, 16)
('MI', 'fastica', 'train', 'A07', 2, 16)
('correntropy', 'fastica', 'train', 'A07', 2, 16)
('coherence', 'fastica', 'train', 'A07', 2, 16)
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 20 components
Fitting ICA took 15.8s.
('MI', 'fastica', 'test', 'A04', 0, 20)
('correntropy', 'fastica', 'test', 'A04', 0, 20)
('coherence', 'fastica', 'test', 'A04', 0, 20)
('MI', 'fastica', 'validation', 'A03', 0, 20)
('correntropy', 'fastica', 'validation', 'A03', 0, 20)
('coherence', 'fastica', 'validation', 'A03', 0, 20)
('MI', 'fastica', 'train', 'A02', 0, 20)
('correntropy', 'fastica', 'train', 'A02', 0, 20)
('coherence', 'fastica', 'train', 'A02', 0, 20)
('MI', 'fastica', 'train', 'A07', 0, 20)
('correntropy', 'fastica', 'train', 'A07', 0, 20)
('coherence', 'fastica', 'train', 'A07', 0, 20)
Fitting ICA to data using 22 channels (please be patient, this may take a while)
Selecting by number: 20 

In [None]:
import pandas as pd

df = []
cols = ["scoring", "algorithm", "dataset", "subject_id", "run", "n_components", "score_before", "score_after", "time"]

for k, v in results.items():
    df.append(list(k) + list(v.values()))
pd.DataFrame(df, columns=cols).to_csv("results.csv")

In [None]:
df = pd.DataFrame(df, columns=cols)
df.groupby(["scoring", "algorithm", "dataset", "subject_id", "n_components"]).mean().query(""" (dataset == "test") """)

In [None]:
10 + 20