In [1]:
from ica_benchmark.scoring import mutual_information, coherence, correntropy, apply_pairwise, apply_pairwise_parallel, SCORING_FN_DICT
from ica_benchmark.processing.ica import get_ica_transformers
import time
from ica_benchmark.io.load import join_gdfs_to_numpy, load_subjects_data, load_subject_data
from ica_benchmark.processing.label import get_annotations
from mne import find_events, events_from_annotations
from mne.viz import plot_events
from mne.io import read_raw_gdf


In [2]:
from pathlib import Path

# Directory containing the GDF recordings (presumably the BCI Competition IV 2a
# dataset, judging by the folder name).
# NOTE(review): absolute, machine-specific Windows path; the notebook cannot run on
# another machine without editing it.
root = Path("C:/Users/paull/Documents/GIT/BCI_MsC/notebooks/BCI_Comp_IV_2a/BCICIV_2a_gdf")

# Maps a subject id to the list of GDF files to load for that subject.
# NOTE(review): every entry repeats the directory already stored in `root`;
# whether the loader accepts paths relative to `root` is not visible here.
subjects =  {
    "A01": [
        "C:\\Users\\paull\\Documents\\GIT\\BCI_MsC\\notebooks\\BCI_Comp_IV_2a\\BCICIV_2a_gdf\\A01E.gdf",
        "C:\\Users\\paull\\Documents\\GIT\\BCI_MsC\\notebooks\\BCI_Comp_IV_2a\\BCICIV_2a_gdf\\A01T.gdf"
    ]
}


In [3]:
# Load the recordings listed in `subjects`. The result is indexed in the next cell
# by subject id and then by the keys "gdf" and "labels".
data = load_subjects_data(root, subjects)

In [4]:
# Transpose the sample matrix of subject A01 so that rows are time samples
# (the displayed shape is (n_samples, 22)).
# NOTE(review): `_data` is a private attribute of the loaded object. If that object
# is an MNE Raw, `get_data()` would be the public accessor; confirm before changing.
arr = data["A01"]["gdf"]._data.T
labels = data["A01"]["labels"]
# Display both shapes; their first dimensions are expected to match.
arr.shape, labels.shape

((1359528, 22), (1359528, 11))

In [5]:
from mne.time_frequency import psd_multitaper, tfr_array_multitaper, psd_array_multitaper
import numpy as np
import matplotlib.pyplot as plt
from ica_benchmark.processing.ica import create_gdf_obj

In [6]:
# arr = np.expand_dims(arr.T, axis=0)
# arr.shape

In [7]:
# res = tfr_array_multitaper(arr, 250, np.linspace(1, 12, 12), output="power", n_cycles=3, decim=10)
# res = res.squeeze()

In [8]:
# plt.figure(figsize=(30, 5))
# plt.imshow(res[0, :, :])

In [9]:
# plt.plot(arr[0, 0, :])

In [15]:
from torch.utils.data import IterableDataset, DataLoader, Dataset
import torch
from statistics import mode

# Frequencies handed to the time-frequency transform: 30 evenly spaced values
# running from 1 to 30 inclusive.
DEFAULT_FREQUENCIES = np.linspace(1, 30, num=30)

# Keyword arguments used by `tfr_multitaper` when the caller passes none.
# Alternatives that were left commented out: n_cycles=7.0 and use_fft=True.
DEFAULT_TRF_KWARGS = {
    "sfreq": 250.0,
    "freqs": DEFAULT_FREQUENCIES,
    "n_cycles": 3.0,
    "zero_mean": True,
    "time_bandwidth": 4,
    "decim": 1,
    "output": "power",
    "n_jobs": 1,
}

# Keyword arguments used by `psd_multitaper` when the caller passes none.
DEFAULT_PSD_KWARGS = {
    "sfreq": 250.0,
    "fmin": 0,
    "fmax": np.inf,
    "bandwidth": None,
    "verbose": 0,
}

def tfr_multitaper(
    arr,
    epochs_mode=False,
    feature_format=None,
    **mne_kwargs
):
    """Run ``tfr_array_multitaper`` on time-major data and optionally reshape the result.

    Parameters
    ----------
    arr : array
        ``(n_times, n_channels)`` when ``epochs_mode`` is false, otherwise
        ``(n_epochs, n_times, n_channels)``.
    epochs_mode : bool
        Whether ``arr`` already carries a leading epoch axis.
    feature_format : bool or None
        * ``None``: return the transform output untouched, laid out as
          ``(n_epochs, n_chans, n_freqs, n_times)``.
        * truthy: return ``(n_epochs, n_times, n_chans * n_freqs)``.
        * falsy but not ``None``: return ``(n_epochs, n_times, n_chans, n_freqs)``.

        In the last two cases the leading epoch axis is removed when there is
        exactly one epoch.
    **mne_kwargs
        Forwarded verbatim to ``tfr_array_multitaper``. When no keyword is given,
        ``DEFAULT_TRF_KWARGS`` is used as a whole; partial overrides are NOT merged
        with the defaults.

    Returns
    -------
    array
        The time-frequency representation in the layout selected above.

    Raises
    ------
    AssertionError
        If ``arr`` does not have the rank implied by ``epochs_mode``.
    """
    if not mne_kwargs:
        mne_kwargs = DEFAULT_TRF_KWARGS

    # The transform expects (n_epochs, n_channels, n_times).
    if epochs_mode:
        assert arr.ndim == 3, "The input array must be of shape (n_epochs, n_times, n_channels)"
        arr = arr.transpose(0, 2, 1)
    else:
        assert arr.ndim == 2, "The input array must be of shape (n_times, n_channels)"
        arr = np.expand_dims(arr.T, axis=0)

    tfr_psd = tfr_array_multitaper(
        arr,
        **mne_kwargs
    )

    if feature_format is None:
        return tfr_psd

    n_epochs, n_chans, n_freqs, n_times = tfr_psd.shape

    # Bring time forward: (n_epochs, n_times, n_chans, n_freqs).
    tfr_psd = tfr_psd.transpose(0, 3, 1, 2)
    if feature_format:
        # Flatten channels x frequencies into one feature axis.
        tfr_psd = tfr_psd.reshape(n_epochs, n_times, n_chans * n_freqs)

    # Drop ONLY the epoch axis. A blanket squeeze() would also remove any other
    # length-1 axis (e.g. a single time step), silently changing the output rank.
    return tfr_psd[0] if n_epochs == 1 else tfr_psd

def psd_multitaper(
    arr,
    **mne_kwargs
    ):
    """Compute a multitaper PSD with ``psd_array_multitaper`` on transposed input.

    NOTE(review): this definition shadows the ``psd_multitaper`` name imported
    from ``mne.time_frequency`` earlier in the notebook.

    Parameters
    ----------
    arr : array
        Transposed before being handed to ``psd_array_multitaper``
        (assumed to be ``(n_times, n_channels)``; TODO confirm).
    **mne_kwargs
        Forwarded verbatim. When no keyword is given, ``DEFAULT_PSD_KWARGS`` is
        used as a whole.

    Returns
    -------
    (psd, freqs)
        ``psd`` is the transposed spectrum with a new leading axis of length 1;
        ``freqs`` is returned exactly as produced by ``psd_array_multitaper``.
    """
    kwargs = mne_kwargs if mne_kwargs else DEFAULT_PSD_KWARGS

    spectrum, freqs = psd_array_multitaper(arr.T, **kwargs)

    # Back to the input orientation, plus a leading singleton axis.
    return np.expand_dims(spectrum.T, axis=0), freqs

def with_default(value, default):
    """Return ``default`` when ``value`` is ``None``; otherwise return ``value`` unchanged.

    Only ``None`` triggers the fallback: falsy values such as ``0`` or ``""`` are kept.
    """
    if value is None:
        return default
    return value

class WindowTransformer():
    """Apply a feature transform (and optionally a label transform) over sliding windows.

    Windows of ``window_size`` items start every ``stride`` items inside
    ``[start, end)``. A window that would extend past ``end`` is not produced.

    Parameters
    ----------
    feature_transform_fn : callable
        Called with ``x[step : step + window_size]`` for every window.
    label_transform_fn : callable
        Called with ``y[step : step + window_size]`` when labels are given.
        Defaults to ``statistics.mode``.
    window_size : int
        Number of items per window.
    stride : int
        Distance between the starts of consecutive windows.
    iterator_mode : bool
        When true, ``transform`` returns a lazy generator; otherwise it returns
        fully materialised arrays.
    """

    def __init__(
        self,
        feature_transform_fn,
        label_transform_fn=mode,
        window_size=250,
        stride=125,
        iterator_mode=False,
        ):

        self.feature_transform_fn = feature_transform_fn
        self.label_transform_fn = label_transform_fn
        self.window_size = window_size
        self.stride = stride
        self.iterator_mode = iterator_mode

    def transform(self, x, y=None, start=None, end=None):
        """Transform every complete window of ``x`` (and ``y``) within ``[start, end)``.

        ``start`` defaults to 0 and ``end`` to ``len(x)``.

        Returns
        -------
        In iterator mode, a generator yielding ``item_x`` (no labels) or
        ``(item_x, item_y)`` per window. Otherwise the per-window features
        concatenated along axis 0, paired with an array of per-window labels
        when ``y`` is given.

        Raises
        ------
        AssertionError
            If ``y`` is given and its length differs from ``x``.
        ValueError
            In list mode, when no complete window fits in the range.
        """
        size = len(x)
        start, end = with_default(start, 0), with_default(end, size)

        if y is not None:
            assert len(x) == len(y), "X and Y must have same sizes"

        if self.iterator_mode:
            return self._transform_iter(x, y=y, start=start, end=end)
        return self._transform_list(x, y=y, start=start, end=end)

    def _window_starts(self, size, start, end):
        """Yield the first index of each complete window inside ``[start, end)``."""
        # Resolve missing bounds here so the private helpers also work when
        # called directly with their default ``None`` arguments.
        first = 0 if start is None else start
        last = size if end is None else end

        for step in range(first, last, self.stride):
            if step + self.window_size > last:
                # Every later window would overrun the range as well.
                break
            yield step

    def _transform_list(self, x, y=None, start=None, end=None):
        """Eager variant: collect every window and return concatenated arrays."""
        with_y = y is not None

        # Reuse the generator so both modes share one window loop.
        items = list(self._transform_iter(x, y=y, start=start, end=end))
        if not items:
            raise ValueError(
                "No complete window of size {} fits in the requested range".format(self.window_size)
            )

        if not with_y:
            return np.concatenate(items, axis=0)

        output_x = [item_x for item_x, _ in items]
        output_y = [item_y for _, item_y in items]
        return np.concatenate(output_x, axis=0), np.array(output_y)

    def _transform_iter(self, x, y=None, start=None, end=None):
        """Lazy variant: yield ``item_x`` or ``(item_x, item_y)`` one window at a time."""
        with_y = y is not None

        if with_y:
            assert len(x) == len(y), "X and Y must have same sizes"

        for step in self._window_starts(len(x), start, end):
            window = slice(step, step + self.window_size)
            item_x = self.feature_transform_fn(x[window])
            if with_y:
                yield item_x, self.label_transform_fn(y[window])
            else:
                yield item_x


class IterDataset(IterableDataset):
    """Iterable dataset that streams ``(x, y)`` items from a transformer.

    Parameters
    ----------
    x, y : sequences of equal length
        Raw samples and their labels.
    transformer_instance : object
        Must expose ``transform(x, y, start=..., end=...)`` returning an iterable
        of 2-item results (e.g. a ``WindowTransformer`` with ``iterator_mode=True``).

    Notes
    -----
    With several DataLoader workers the index range is split into contiguous
    shards and each worker transforms only its own shard.
    NOTE(review): with a window-based transformer, windows that straddle a
    shard boundary are therefore not produced. Confirm this is acceptable.
    """

    def __init__(self, x, y, transformer_instance):
        # `super(IterDataset).__init__()` built an unbound super object and never
        # ran the parent initialiser; the zero-argument form does.
        super().__init__()
        assert len(x) == len(y), "Lengths must be equal"
        self.transformer_instance = transformer_instance
        self.x, self.y = x, y
        self.start = 0
        self.end = len(y)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: iterate over the whole range.
            iter_start, iter_end = self.start, self.end
        else:
            # Worker process: take this worker's contiguous share of the range.
            per_worker = int(np.ceil((self.end - self.start) / float(worker_info.num_workers)))
            iter_start = self.start + worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, self.end)

        for x, y in self.transformer_instance.transform(self.x, self.y, start=iter_start, end=iter_end):
            yield x, y

            
class MapDataset(Dataset):
    """Map-style dataset over data that is transformed once, at construction time.

    Parameters
    ----------
    x, y : sequences of equal length
        Raw samples and their labels.
    transformer_instance : object
        Must expose ``transform(x, y)`` returning an ``(x, y)`` pair of indexable
        results (e.g. a ``WindowTransformer`` with ``iterator_mode=False``).
    """

    def __init__(self, x, y, transformer_instance):
        # `super(MapDataset).__init__()` never ran the parent initialiser.
        super().__init__()
        self.transformer_instance = transformer_instance
        assert len(x) == len(y), "Lengths must be equal"
        # Transform the constructor arguments. The original read `self.x` and
        # `self.y` here, which do not exist yet and raised AttributeError.
        self.x, self.y = self.transformer_instance.transform(x, y)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        # The original returned `self.i[i]`, an attribute that is never defined.
        return self.x[i], self.y[i]

class WindowTransformerDataset(Dataset):
    """Map-style dataset whose items are transformed sliding windows of ``X`` / ``Y``.

    Item ``k`` covers ``[start + k * stride, start + k * stride + window_size)``.
    Only windows lying entirely inside ``[start, end)`` are exposed.

    Parameters
    ----------
    X, Y : sequences of equal length
        Raw samples and their labels.
    feature_transform_fn : callable
        Applied to the window of ``X``.
    label_transform_fn : callable
        Applied to the window of ``Y``. Defaults to ``statistics.mode``.
    window_size : int
        Number of items per window.
    stride : int
        Distance between the starts of consecutive windows.
    start, end : int or None
        Bounds of the usable range; default to ``0`` and ``len(X)``.
    """

    def __init__(
        self,
        X,
        Y,
        feature_transform_fn,
        label_transform_fn=mode,
        window_size=500,
        stride=250,
        start=None,
        end=None
        ):
        # `super(WindowTransformerDataset).__init__()` never ran the parent initialiser.
        super().__init__()

        self.feature_transform_fn = feature_transform_fn
        self.label_transform_fn = label_transform_fn
        self.window_size = window_size
        self.stride = stride
        assert len(X) == len(Y), "X and Y must have same sizes"
        self.X, self.Y = X, Y

        self.start = with_default(start, 0)
        self.end = with_default(end, len(X))

    def __len__(self):
        """Number of complete windows inside ``[start, end)``."""
        # Room left for window starts after the first window has been placed.
        slack = self.end - self.start - self.window_size
        if slack < 0:
            return 0
        # +1 counts the window that begins exactly at `start`.
        return slack // self.stride + 1

    def __getitem__(self, step):
        """Return the transformed ``(x, y)`` pair for window number ``step``.

        Negative indices count from the end.

        Raises
        ------
        IndexError
            If ``step`` is outside ``[-len(self), len(self))``.
        """
        n_windows = len(self)
        if step < 0:
            step += n_windows
        if not 0 <= step < n_windows:
            raise IndexError("window index out of range")

        # Convert the item index to a sample offset. The original used the index
        # itself as the offset, so consecutive items were one sample apart and
        # both `stride` and `start` were ignored.
        offset = self.start + step * self.stride
        window = slice(offset, offset + self.window_size)

        x = self.feature_transform_fn(self.X[window])
        y = self.label_transform_fn(self.Y[window])

        return x, y

In [18]:
def label_transform(x):
    """Return, for every row of ``x``, the index of its largest entry (argmax over axis 1).

    NOTE(review): applied to a whole window this yields one index per row, not a
    single label for the window. Confirm that is what the consumer expects.
    """
    row_argmax = x.argmax(1)
    return row_argmax

def feature_transform(x):
    """Compute the multitaper time-frequency representation of one window.

    Parameters
    ----------
    x : array
        The window to transform, passed to ``tfr_multitaper`` with
        ``epochs_mode=False`` (which requires a 2-D array).

    Returns
    -------
    The untouched ``tfr_multitaper`` output (``feature_format=None``).
    """
    # Start from the shared defaults and pin n_cycles without mutating them.
    kwargs = dict(DEFAULT_TRF_KWARGS, n_cycles=3)
    # Transform the window that was passed in. The original ignored `x` and
    # transformed the notebook-global `arr` (the whole recording) on every call.
    return tfr_multitaper(
        x,
        epochs_mode=False,
        feature_format=None,
        **kwargs
    )

def psd_fn(x):
    """Return only the spectrum from ``psd_multitaper``, dropping the frequency vector."""
    return psd_multitaper(x)[0]

# Only the first `n` samples are used.
n = 100000

# dataset = WindowTransformerDataset(arr, labels, tfr_multitaper, label_transform_fn=label_transform)
dataset = WindowTransformerDataset(arr[:n], labels[:n], feature_transform, label_transform_fn=label_transform)
# `prefetch_factor` is omitted on purpose: it is only meaningful with worker
# processes, and recent PyTorch releases raise ValueError when it is set while
# num_workers == 0.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

In [None]:
# Smoke test: pull every batch from the loader and print the shape of its features.
# Note that `x` and `y` stay bound in the notebook namespace after the loop.
for x, y in dataloader:
    print(x.shape)

In [14]:
# Integer number of (32 * 1000)-sample chunks in the full recording.
# NOTE(review): 32 presumably mirrors the DataLoader batch_size; the meaning of
# 1000 is not evident from this notebook. Confirm or name both constants.
arr.shape[0] // (32 * 1000)

42