In [1]:
from ica_benchmark.scoring import mutual_information, coherence, correntropy, apply_pairwise, apply_pairwise_parallel, SCORING_FN_DICT
from ica_benchmark.processing.ica import get_ica_transformers
import time
from ica_benchmark.io.load import join_gdfs_to_numpy, load_subjects_data, load_subject_data
from ica_benchmark.processing.label import get_annotations
from mne import find_events, events_from_annotations
from mne.viz import plot_events
from mne.io import read_raw_gdf
import numpy as np


In [2]:
from pathlib import Path

root = Path("/home/paulo/Documents/datasets/BCI_Comp_IV_2a/gdf/")

subjects =  {
    "A01": [
        root / "A01T.gdf",
#         root / "A01E.gdf",
    ]
}


In [3]:
data = load_subjects_data(root, subjects)


In [5]:
arr = data["A01"]["gdf"]._data.T
labels = data["A01"]["labels"]

In [6]:
from mne.time_frequency import psd_multitaper, tfr_array_multitaper, psd_array_multitaper
import numpy as np
import matplotlib.pyplot as plt
from ica_benchmark.processing.ica import create_gdf_obj

In [7]:
# arr = np.expand_dims(arr.T, axis=0)
# arr.shape

In [8]:
# res = tfr_array_multitaper(arr, 250, np.linspace(1, 12, 12), output="power", n_cycles=3, decim=10)
# res = res.squeeze()

In [9]:
# plt.figure(figsize=(30, 5))
# plt.imshow(res[0, :, :])

In [10]:
# plt.plot(arr[0, 0, :])

In [104]:
from torch.utils.data import IterableDataset, DataLoader, Dataset
import torch
from statistics import mode

DEFAULT_FREQUENCIES = np.linspace(3, 30, 10)

DEFAULT_TRF_KWARGS = dict(
    sfreq=250.0,
    freqs=DEFAULT_FREQUENCIES, 
#     n_cycles=7.0,
    n_cycles=3.0,
    zero_mean=True,
    time_bandwidth=4,
#     use_fft=True,
    decim=1,
    output='power',
    n_jobs=1,
)

DEFAULT_PSD_KWARGS = dict(
    sfreq=250.0,
    fmin=0,
    fmax=np.inf,
    bandwidth=None,
    verbose=0,
)

def tfr_multitaper(
    arr,
    epochs_mode=False,
    feature_format=None,
    cut_size=None,
    **mne_kwargs
):
    
    if not mne_kwargs:
        mne_kwargs = DEFAULT_TRF_KWARGS
    
    # arr is (n_times, n_channels)
    if not epochs_mode:
        assert arr.ndim == 2, "The input array must be of shape (n_times, n_channels)"
        # to (n_channels, n_times)
        arr = np.expand_dims(arr.T, axis=0)
    else:
        assert arr.ndim == 3, "The input array must be of shape (n_epochs, n_times, n_channels)"
        # to (n_epochs, n_channels, n_times)
        arr = arr.transpose(0, 2, 1)
    
    # input (n_epochs, n_channels, n_times)
    tfr_psd = tfr_array_multitaper(
        arr,
        **mne_kwargs
    )
    # output = (n_epochs, n_chans, n_freqs, n_times)
    
    if feature_format is None:
        return tfr_psd
    
    n_epochs, n_chans, n_freqs, n_times = tfr_psd.shape
    cut_size = n_times if cut_size is None else cut_size
    if feature_format:
        #(n_epochs, n_chans, n_freqs, n_times) -> (size, features)
        tfr_psd = tfr_psd\
            .transpose(0, 3, 1, 2)\
            .reshape(n_epochs, n_times, n_chans * n_freqs)[:, -n_times:, :]
            
    else:
        tfr_psd = tfr_psd.transpose(0, 3, 1, 2)[:, :, :, -n_times:]
    
        
    tfr_psd = tfr_psd.squeeze() if n_epochs == 1 else tfr_psd
    
    return tfr_psd

def psd_multitaper(
    arr,
    **mne_kwargs
    ):

    if not mne_kwargs:
        mne_kwargs = DEFAULT_PSD_KWARGS

    psd, freqs = psd_array_multitaper(
        arr.T,
        **mne_kwargs
    )
    psd = np.expand_dims(psd.T, axis=0)
    
    return psd, freqs

def with_default(value, default):
    return value if value is not None else default

class WindowTransformer():

    def __init__(
        self,
        feature_transform_fn,
        label_transform_fn=mode,
        window_size=250,
        stride=125,
        iterator_mode=False,
        ):

        self.feature_transform_fn = feature_transform_fn
        self.label_transform_fn = label_transform_fn
        self.window_size = window_size
        self.stride = stride
        self.iterator_mode = iterator_mode

    def transform(self, x, y=None, start=None, end=None):
        
        size = len(x)
        start, end = with_default(start, 0), with_default(end, size)

        if y is not None:
            assert len(x) == len(y), "X and Y must have same sizes"

        if self.iterator_mode:
            return self._transform_iter(x, y=y, start=start, end=end)
        else:
            return self._transform_list(x, y=y, start=start, end=end)

    def _transform_list(self, x, y=None, start=None, end=None):
        
        with_y = y is not None

        output_x, output_y = list(), list()

        for step in range(start, end, self.stride):

            if step + self.window_size > end:
                break

            item_x = self.feature_transform_fn(x[step : step + self.window_size])
            if with_y:
                item_y = self.label_transform_fn(y[step : step + self.window_size])
            else:
                item_y = None

            return_items = item_x if with_y else (item_x, item_y)

            output_x.append(item_x)
            output_y.append(item_y)
        
        output_x = np.concatenate(output_x, axis=0)

        if with_y:
            output_y = np.array(output_y)

        return (output_x, output_y) if with_y else output_x

    def _transform_iter(self, x, y=None, start=None, end=None):
        
        with_y = y is not None
        
        if y is not None:
            assert len(x) == len(y), "X and Y must have same sizes"

        output_x, output_y = list(), list()

        for step in range(start, end, self.stride):

            if step + self.window_size > end:
                break

            item_x = self.feature_transform_fn(x[step : step + self.window_size])
            if with_y:
                item_y = self.label_transform_fn(y[step : step + self.window_size])
            else:
                item_y = None

            return_items = (item_x, item_y) if with_y else item_x

            yield return_items


class IterDataset(IterableDataset):
    def __init__(self, x, y, transformer_instance):
        super(IterDataset).__init__()
        self.transformer_instance = transformer_instance
        self.start = 0
        assert len(x) == len(y), "Lengths must be equal"
        self.end = len(y)
        self.x, self.y = x, y
        
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start, iter_end = self.start, self.end
        else:  # in a worker process
            # split workload
            per_worker = int(np.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        
        for x, y in self.transformer_instance.transform(self.x, self.y, start=iter_start, end=iter_end):
            yield x, y

            
class MapDataset(Dataset):
    
    def __init__(self, x, y, transformer_instance):
        super(MapDataset).__init__()
        self.transformer_instance = transformer_instance
        assert len(x) == len(y), "Lengths must be equal"
        self.x, self.y = self.transformer_instance.transform(self.x, self.y)
    
    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i], self.i[i]

class WindowTransformerDataset(Dataset):

    def __init__(
        self,
        X,
        Y,
        feature_transform_fn,
        label_transform_fn=mode,
        window_size=500,
        stride=250,
        start=None,
        end=None
        ):
        super(WindowTransformerDataset).__init__()

        self.feature_transform_fn = feature_transform_fn
        self.label_transform_fn = label_transform_fn
        self.window_size = window_size
        self.stride = stride
        
        assert len(X) == len(Y), "X and Y must have same sizes."
        assert len(X) >= window_size, "Window size must be smaller than the array size."
        self.X, self.Y = X, Y
        
        self.start = with_default(start, 0)
        self.end = with_default(end, len(X))
        
    def __len__(self):
        return (len(self.X) - self.window_size) // self.stride - 2
    
    def __getitem__(self, i):
        
        idx = i * self.stride

        x = self.feature_transform_fn(self.X[idx : idx + self.window_size])
        y = self.label_transform_fn(self.Y[idx : idx + self.window_size])

        return x, y

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]
        

In [141]:
def label_transform(x):
    l = x[:, [3, 4, 5, 6, 7]]
    l[:, -1] = (l[:, :4].max(axis=1) == 1).astype(np.uint32)
    label = mode(l.argmax(1))
    return label

def psd_feature_transform(x, freqs=DEFAULT_FREQUENCIES, bandwidth=3):
    psd, psd_freqs = psd_multitaper(x)
    feature_vector = list()
    for freq in freqs:
        top_freq = freq + bandwidth / 2
        bot_freq = freq - bandwidth / 2
        selected_freqs = np.bitwise_and(psd_freqs >= bot_freq, psd_freqs <= top_freq)
        feature = psd[:, selected_freqs, :].mean(axis=1)
        feature_vector.append(feature)
    
    features = np.concatenate(feature_vector, axis=0).flatten()
    return features
        
def feature_transform(x):
    kwargs = DEFAULT_TRF_KWARGS.copy()
    kwargs.update(dict(n_cycles=3, n_jobs=4))
    return tfr_multitaper(
        x,
        epochs_mode=False,
        feature_format=None,
        **kwargs
    )

psd_fn = lambda x: psd_multitaper(x)[0]

n = 30000

# dataset = WindowTransformerDataset(arr, labels, tfr_multitaper, label_transform_fn=label_transform)
dataset = WindowTransformerDataset(
    arr,
    labels,
#     psd_feature_transform,
    np.max,
    label_transform_fn=label_transform,
    window_size=250,
    stride=100
)

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4, 
    drop_last=True,
    prefetch_factor=2,
)


In [142]:
# print(len(dataset))
for x, y in dataset:
#     print(x.shape, y.shape)
    if y not in (0,):
        print(y)
#     plt.clf()
#     plt.figure(figsize=(30, 5))
#     p = x.numpy()[0, 0].transpose(1, 2, 0)[:, :300, :3]
#     p = (p - p.min())/ (p.max() - p.min())
#     plt.imshow(p)
#     plt.show()

3
3
3
2
2
2
1
1
1
1
1
1
2
2
2
3
3
3
1
1
1
2
2
2
3
3
3
1
1
1
1
1
1
2
2
2
1
1
1
3
3
3
3
3
3
2
2
2
3
3
3
3
3
3
1
1
1
3
3
3
3
3
3
1
1
1
1
1
1
1
1
2
2
2
2
2
2
2
2
2
2
2
3
3
3
2
2
2
2
3
3
3
1
1
1
2
2
2
1
1
1
2
2
2
3
3
3
1
1
1
2
2
2
2
3
3
3
3
1
1
1
2
2
2
2
2
2
1
1
1
3
3
3
2
2
2
2
2
2
2
2
2
1
1
1
3
3
3
3
3
3
3
3
3
2
2
2
3
3
3
3
1
1
1
3
3
3
1
1
1
1
2
2
2
1
1
1
2
2
2
2
2
2
2
2
2
3
3
3
3
3
3
1
1
1
1
1
1
3
3
3
1
1
1
3
3
3
3
2
2
2
1
1
1
1
1
1
1
1
1
2
2
2
3
3
3
1
1
1
3
3
3
2
2
2
2
2
2
2
3
3
3
2
2
2
1
1
1
3
3
3
3
3
3
3
3
3
1
1
1
2
2
2
1
1
1
3
3
3
3
3
3
2
2
2
1
1
1
3
3
3
3
3
3
1
1
1
1
1
1
1
1
2
2
2
2
3
3
3
1
1
1
1
3
3
3
1
1
1
2
2
2
1
1
1
1
1
1
2
2
2
3
3
3
2
2
2
2
2
2
2
2
2
2
2
2
2
1
1
1
1
1
1
2
2
2
2
2
2
1
1
1
2
2
2
3
3
3
3
3
3
1
1
1
3
3
3
2
2
2
1
1
1
3
3
3
2
2
2
3
3
3
2
2
2
3
3
3
1
1
1
1
1
1
3
3
3
1
1
1
1
1
1
1
1
1
2
2
2
3
3
3
3
3
3
2
2
2
2
3
3
3
2
2
2
2
1
1
1
2
2
2
2
2
2
3
3
3
1
1
1
3
3
3
1
1
1
1
2
2
2
2
2
2
3
3
3
1
1
1
3
3
3
2
2
2
2
2
2
1
1
1
3
3
3
1
1
1
1
1
1
1
1
1
3
3
3
3
3
3
1
1
1
1
1
1
1
1
1
1


In [127]:
np.unique(labels[96300:96600, [3, 4, 5, 6]].argmax(axis=1), return_counts=True)

(array([1]), array([300]))

In [99]:
y

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

In [None]:
0 200
200 300
400 500
600 700
800 90/;.ç0

In [208]:
psd_feature_transform(arr[:1000]).shape

(220,)