In [1]:
# Interactive tooling setup (IPython magics; comments kept on their own lines
# because magic arguments are parsed up to the end of the line).
# Disable jedi-based tab completion.
%config Completer.use_jedi = False
# Re-import edited modules (e.g. the local `cspnn` package) automatically.
%load_ext autoreload
# If the extension is already loaded in this kernel, run `%reload_ext autoreload` instead.
%autoreload 2
# Format code cells with black in JupyterLab (requires the `lab_black` package).
%load_ext lab_black

In [2]:
from scipy.io import loadmat
import numpy as np

## Take a look at the dataset

In [3]:
# Load subject 1; struct_as_record=False exposes MATLAB structs as objects with
# attribute access (used below as `mat_file["eeg"].rest`, ...).
mat_file = loadmat(
    "./cho/s01.mat",
    struct_as_record=False,
    squeeze_me=True,
)

In [4]:
# Names of all fields stored in the `eeg` struct.
eeg_struct = mat_file["eeg"]
eeg_struct._fieldnames

['noise',
 'rest',
 'srate',
 'movement_left',
 'movement_right',
 'movement_event',
 'n_movement_trials',
 'imagery_left',
 'imagery_right',
 'n_imagery_trials',
 'frame',
 'imagery_event',
 'comment',
 'subject',
 'bad_trial_indices',
 'psenloc',
 'senloc']

#### rest: 
Resting state with eyes open, recorded for 60 seconds. (The stored array below is 66.5 s long: 34,048 samples at 512 Hz.)

In [5]:
# Resting-state recording: (channels, samples) and its duration in seconds.
rest_signal = mat_file["eeg"].rest
rest_signal.shape, rest_signal.shape[1] / mat_file["eeg"].srate

((68, 34048), 66.5)

#### noise:
- eye blinking, 5 seconds × 2
- eyeball movement up/down, 5 seconds × 2
- eyeball movement left/right, 5 seconds × 2
- jaw clenching, 5 seconds × 2
- head movement left/right, 5 seconds × 2

In [6]:
# First entry of the `noise` recordings: (channels, samples) and duration in seconds.
noise_signal = mat_file["eeg"].noise[0]
noise_signal.shape, noise_signal.shape[1] / mat_file["eeg"].srate

((68, 5120), 10.0)

#### imagery left: 
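100 or 120 trials of left-hand motor imagery (MI). For this subject the recording holds 100 trials of 7 s each (3,584 samples at 512 Hz).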

In [7]:
# Continuous left-hand imagery recording: (channels, samples) and duration in seconds.
imagery_left_signal = mat_file["eeg"].imagery_left
imagery_left_signal.shape, imagery_left_signal.shape[1] / mat_file["eeg"].srate

((68, 358400), 700.0)

#### imagery right: 
100 or 120 trials of right-hand motor imagery (MI).

In [8]:
# Continuous right-hand imagery recording: (channels, samples) and duration in seconds.
imagery_right_signal = mat_file["eeg"].imagery_right
imagery_right_signal.shape, imagery_right_signal.shape[1] / mat_file["eeg"].srate

((68, 358400), 700.0)

#### senloc & psenloc: 
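- `senloc`: 3D sensor locations
- `psenloc`: sensor locations projected onto the unit sphere

Both arrays cover 64 sensors, while the signals above have 68 channels.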

In [9]:
# Shapes of the raw and the unit-sphere-projected sensor locations.
tuple(getattr(mat_file["eeg"], field).shape for field in ("senloc", "psenloc"))

((64, 3), (64, 3))

### Getting trials

In [10]:
# Scratch check of the `0=` zero-padding format spec (also used for the sXX.mat file names).
print(format(12, "0=3"))

012


In [13]:
# Each imagery trial starts at a sample where `imagery_event == 1`.
eeg = mat_file["eeg"]
ind = np.where(eeg.imagery_event == 1)[0]
assert len(ind) == eeg.n_imagery_trials

# Onset-to-onset distances, i.e. the trial lengths in samples.
trial_ranges = np.diff(ind).tolist()
print(len(trial_ranges))
print(set(trial_ranges))

# Cut one fixed-length window per onset. Slicing `onset : onset + trial_len`
# (instead of `ind[i] : ind[i + 1]`) also keeps the final trial, which has no
# following onset; a window that would run past the recording is skipped.
trial_len = min(trial_ranges)
n_samples = eeg.imagery_left.shape[1]
trials = np.stack(
    [
        eeg.imagery_left[:, onset : onset + trial_len]
        for onset in ind
        if onset + trial_len <= n_samples
    ]
)
trials.shape

99
{3584}


(99, 68, 3584)

In [14]:
# Scratch check: stacking an int-filled array on a float array gives a float array.
np.concatenate([np.full((10, 2), 2), np.zeros((14, 2))], axis=0)

array([[2., 2.],
       [2., 2.],
       [2., 2.],
       [2., 2.],
       [2., 2.],
       [2., 2.],
       [2., 2.],
       [2., 2.],
       [2., 2.],
       [2., 2.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]])

In [15]:
# Scratch check: joining two 1-D arrays end to end.
np.concatenate((np.zeros(10), np.zeros(3)))

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

## Torch Dataset

In [16]:
import os
from sklearn.model_selection import RepeatedStratifiedKFold
import json
from tqdm.notebook import tqdm

ModuleNotFoundError: No module named 'sklearn'

In [None]:
class ChoDataset:
    """Left/right-hand motor-imagery trials from the Cho EEG dataset (``sXX.mat`` files).

    For every requested subject, the ``imagery_left`` and ``imagery_right``
    recordings are cut into fixed-length trials that start at each sample where
    ``imagery_event == 1``. Left-hand trials are labelled ``1`` and right-hand
    trials ``2``.

    Parameters
    ----------
    data_path : str
        Directory containing the ``s01.mat``, ``s02.mat``, ... files.
    patients : list of int, optional
        Subject numbers to load. ``None`` (default) loads subjects 1..52.
    paradigm : str
        Only ``"imagery"`` is implemented.
    transforms : callable, optional
        Applied as ``transforms(signal, label)`` in ``__getitem__``.
    data : dict, optional
        Pre-extracted ``{"signals": ..., "labels": ...}``. When given, no file is
        read (this is how :meth:`subset` builds views).
    """

    # Number of subjects in the dataset; subject files are numbered from s01.
    N_SUBJECTS = 52

    def __init__(
        self,
        data_path: str,
        patients=None,
        paradigm="imagery",  # ToDo: implement 'real' movement
        transforms=None,
        data: dict = None,
    ):
        self.data_path = data_path
        # Build the default per instance (no shared mutable default). Subject
        # numbering starts at 1, so the former range(52) asked for s00.mat.
        self.patients = (
            list(range(1, self.N_SUBJECTS + 1)) if patients is None else patients
        )
        self.paradigm = paradigm
        self.transforms = transforms

        if data is not None:
            # Pre-extracted arrays (used by ``subset``): skip file loading.
            self.signals = data["signals"]
            self.labels = data["labels"]
            return

        self._load_dataset()

    # def get_sampling_rate(self):
    #     return 250

    # def get_resampling_rate(self):
    #     return 256

    def _load_dataset(self):
        """Load every subject in ``self.patients`` and concatenate trials and labels.

        Raises
        ------
        NotImplementedError
            If ``self.paradigm`` is not ``"imagery"`` (it used to be ignored
            silently, returning imagery data for any paradigm).
        """
        if self.paradigm != "imagery":
            raise NotImplementedError(
                f"paradigm={self.paradigm!r} is not implemented; only 'imagery' is supported"
            )

        trials_list = []
        labels_list = []
        for patient in self.patients:
            mat_file = self._load_patient(patient)
            patient_trials, patient_labels = self._extract_imagery(mat_file)
            trials_list.append(patient_trials)
            labels_list.append(patient_labels)

        self.signals = np.vstack(trials_list)
        self.labels = np.hstack(labels_list)

    @staticmethod
    def _cut_trials(signal, onsets, trial_len):
        """Stack ``signal[:, t : t + trial_len]`` for every onset ``t`` whose window fits."""
        n_samples = signal.shape[-1]
        return np.stack(
            [signal[:, t : t + trial_len] for t in onsets if t + trial_len <= n_samples]
        )

    def _extract_imagery(self, mat_file):
        """Cut the left/right imagery recordings of one subject into trials.

        Returns
        -------
        (trials, labels)
            ``trials`` stacks left-hand trials then right-hand trials, shape
            (n_trials, channels, trial_len); ``labels`` is 1 (left) / 2 (right).

        The trial length is the smallest distance between consecutive onsets.
        Taking a fixed-length window from every onset also keeps the final
        trial, which the former ``ind[i]:ind[i + 1]`` slicing dropped.
        """
        eeg = mat_file["eeg"]
        onsets = np.where(eeg.imagery_event == 1)[0]
        assert len(onsets) == eeg.n_imagery_trials
        if len(onsets) == 0:
            raise ValueError("no imagery onsets (imagery_event == 1) found")

        if len(onsets) > 1:
            trial_len = int(np.min(np.diff(onsets)))
        else:
            trial_len = eeg.imagery_left.shape[-1] - int(onsets[0])

        trials_left = self._cut_trials(eeg.imagery_left, onsets, trial_len)
        trials_right = self._cut_trials(eeg.imagery_right, onsets, trial_len)
        trials = np.vstack([trials_left, trials_right])
        labels = np.hstack(
            [np.full(len(trials_left), 1), np.full(len(trials_right), 2)]
        )

        return trials, labels

    def _load_patient(self, idx):
        """Read ``{data_path}/s{idx:02}.mat`` with MATLAB structs exposed as objects."""
        return loadmat(
            f"{self.data_path}/s{idx:0=2}.mat", squeeze_me=True, struct_as_record=False
        )

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        """Return ``(signal, label)`` for trial ``idx``, passed through ``transforms`` if set."""
        wav, label = self.signals[idx], self.labels[idx]

        if self.transforms is not None:
            wav, label = self.transforms(wav, label)

        return wav, label

    def subset(self, indices):
        """Return a new dataset holding only ``indices`` (same settings and transforms)."""
        return self.__class__(
            data_path=self.data_path,
            patients=self.patients,
            paradigm=self.paradigm,
            transforms=self.transforms,
            data={"signals": self.signals[indices], "labels": self.labels[indices]},
        )

    def _splits_match_dataset(self, splits) -> bool:
        """True if every split's train + test indices are exactly ``0 .. len(self) - 1``."""
        if not splits:
            return False
        expected = list(range(len(self)))
        return all(
            sorted(list(split["train"]) + list(split["test"])) == expected
            for split in splits.values()
        )

    def _repeated_kfold_splits(self, n_splits=5, n_repeats=10):
        """Yield ``(key, train_indices, test_indices)`` for repeated stratified k-fold CV.

        Splits are cached as JSON in the working directory so that later runs
        reuse the same folds. A cached file is reused only when its indices
        partition *this* dataset (it may have been written for other patients or
        for a subset); otherwise fresh splits are drawn and the cache rewritten.
        """
        persist_path = f"./CHO_dataset_{n_splits}_splits_{n_repeats}_repeats.json"

        splits = None
        if os.path.exists(persist_path):
            with open(persist_path, "r") as f:
                cached = json.load(f)
            if self._splits_match_dataset(cached):
                splits = cached

        if splits is None:
            cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats)
            splits = {
                f"split_{i}": {"train": train.tolist(), "test": test.tolist()}
                for i, (train, test) in enumerate(cv.split(self.signals, self.labels))
            }
            with open(persist_path, "w") as f:
                json.dump(splits, f)

        self.splits = splits
        for key, split in splits.items():
            yield key, split["train"], split["test"]

    def get_train_test_subsets(self, with_key=False):
        """Yield ``(train, test)`` subsets (prefixed by the split key if ``with_key``)."""
        for key, train, val in self._repeated_kfold_splits():
            if with_key:
                yield key, self.subset(train), self.subset(val)
            else:
                yield self.subset(train), self.subset(val)

    @staticmethod
    def dict_to_2d_wave(dict_signals):
        return np.vstack([wav for wav in dict_signals.values()])

    @staticmethod
    def collate_fn(batch):
        """Collate ``(signal_tensor, label_dict)`` items for a DataLoader.

        Signals are concatenated along dim 0 with ``torch.vstack``. For each
        label key, values are vstacked and squeezed; nested dict labels are
        handled one level deep.
        """
        imgs = torch.vstack([item[0] for item in batch])

        trgts = {}
        sample_item_label = batch[0][1]
        for label_key in sample_item_label.keys():
            if isinstance(sample_item_label[label_key], dict):
                trgts[label_key] = {
                    key: torch.vstack(
                        [item[1][label_key][key].squeeze() for item in batch]
                    )
                    for key in sample_item_label[label_key].keys()
                }
            else:
                trgts[label_key] = torch.vstack(
                    [item[1][label_key] for item in batch]
                ).squeeze()

        return [imgs, trgts]

In [None]:
# Subject 1 only, no transforms: items are the raw per-trial arrays.
ds = ChoDataset("./cho/", patients=[1])

In [None]:
# Walk the cached repeated-k-fold splits and peek at one random training item per split.
for split_key, train_subset, test_subset in ds.get_train_test_subsets(with_key=True):
    print(f"{split_key} {len(train_subset)} {len(test_subset)}")
    sample_idx = np.random.randint(0, len(train_subset))
    sample_signal, sample_label = train_subset[sample_idx]
    print(sample_signal.shape, sample_label)
    # train and evaluate model here

### Data Utils

In [None]:
from typing import Dict, List, Tuple, Any, Union, Callable
from torch import Tensor
import torch
import numpy as np
import scipy
from torch.utils.data import DataLoader
from lightning.fabric.utilities.apply_func import apply_to_collection

from functools import partial

In [None]:
class ToTensor:
    """Convert numpy arrays and numeric scalars in ``data``/``label`` to tensors.

    - ``data``: ndarrays go through ``torch.from_numpy`` (dtype kept, memory
      shared with the array); Python/numpy scalars through ``torch.as_tensor``.
    - ``label``: converted to a ``torch.float64`` tensor.

    Nested collections are traversed by ``apply_to_collection``. Results are
    placed on ``device`` (left where they are when ``device`` is None).
    """

    _CONVERTIBLE = (np.ndarray, int, float, np.int64)

    def __init__(self, device):
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

    def _data_to_tensor(self, value):
        # torch.from_numpy only accepts ndarrays; the scalar types listed in
        # _CONVERTIBLE used to raise TypeError here.
        if isinstance(value, np.ndarray):
            tensor = torch.from_numpy(value)
        else:
            tensor = torch.as_tensor(value)
        return tensor if self.device is None else tensor.to(self.device)

    def _label_to_tensor(self, value):
        return torch.tensor(value, dtype=torch.float64, device=self.device)

    def __call__(self, data, label):
        data = apply_to_collection(
            data,
            dtype=self._CONVERTIBLE,
            function=self._data_to_tensor,
        )
        label = apply_to_collection(
            label,
            dtype=self._CONVERTIBLE,
            function=self._label_to_tensor,
        )

        return data, label


class ToNumpy:
    """Convert every torch tensor in ``data`` and ``label`` to a numpy array.

    Tensors nested in dicts, lists or tuples are converted too; any other value
    (including arrays that are already numpy) is passed through unchanged.

    NOTE: a class with this name is defined again further down in this notebook;
    once that cell runs, the later definition shadows this one.
    """

    @classmethod
    def _to_numpy(cls, value):
        if isinstance(value, Tensor):
            # detach() first so tensors that require grad can be converted.
            return value.detach().cpu().numpy()
        if isinstance(value, dict):
            return {key: cls._to_numpy(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            converted = [cls._to_numpy(item) for item in value]
            return converted if isinstance(value, list) else tuple(converted)
        return value

    def __call__(self, data, label):
        return self._to_numpy(data), self._to_numpy(label)


class DictToTensor:
    """Stack the tensors of a ``{name: tensor}`` dict into one tensor.

    Each value must be 2-D (batch, signal). The result is shaped
    ``[batch, channel, signal]``, with one channel per dict entry in dict order.
    """

    def __call__(self, data: Dict[str, Tensor], label):
        # stack -> (channel, batch, signal); then swap the first two axes.
        stacked = torch.stack(list(data.values()), dim=0)
        return stacked.permute(1, 0, 2), label


class DictToArray:
    """numpy version of DictToTensor: stack a ``{name: array}`` dict into one array.

    Each value must be 2-D (batch, signal). The result is shaped
    ``[batch, channel, signal]``, with one channel per dict entry in dict order.
    """

    def __call__(self, data, label):
        # stack -> (channel, batch, signal); then swap the first two axes.
        stacked = np.stack(list(data.values()), axis=0)
        return stacked.transpose(1, 0, 2), label


class Windowing:
    """Split the last (time) axis into ``n_segments`` equal, non-overlapping windows.

    Parameters
    ----------
    n_segments : int
        Number of windows to produce.
    sample_rate : float
        Stored for reference; not used by the transform.
    """

    def __init__(self, n_segments: int = 5, sample_rate: float = 250.0):
        self.n_segments = n_segments
        self.sample_rate = sample_rate

    # The Output of the signal is [batch, channels, windowed, band_filtered, signal]
    def __call__(self, data: Tensor, label):
        """Takes as input a signal tensor of shape [batch, channels, band_filtered, signal]
        and outputs a signal tensor of shape [batch, channels, windowed, band_filtered, signal]

        Each window holds ``signal // n_segments`` samples; a remainder at the
        end is dropped. When there is nothing to split (``n_segments <= 1`` or
        fewer samples than segments), a singleton window axis is inserted instead.
        """
        n_samples = data.size()[-1]
        step = n_samples // self.n_segments if self.n_segments > 0 else 0

        if self.n_segments <= 1 or step == 0:
            return data.unsqueeze(dim=2), label

        # Starts 0, step, ..., (n_segments - 1) * step -> exactly n_segments
        # windows (the former arange(0, end - step, step) stopped one early).
        windows = [
            data[..., start : start + step]
            for start in range(0, self.n_segments * step, step)
        ]
        return torch.stack(windows, dim=2), label


class Filtering:
    """Filter bank of Chebyshev type-II band-pass filters, one band per start frequency.

    Parameters
    ----------
    N : int
        Filter order passed to ``scipy.signal.cheby2``.
    rs : float
        Minimum stop-band attenuation in dB.
    Wns : sequence of float
        Lower band edges in Hz; band ``k`` spans ``[Wns[k], Wns[k] + bandwidth]``.
    bandwidth : float
        Width of each band in Hz.
    fs : float
        Sampling rate in Hz.
    """

    def __init__(self, N: int, rs: float, Wns: List[float], bandwidth, fs: float):
        self.N = N
        self.rs = rs
        # Band edges are stored normalised to the Nyquist frequency (fs / 2).
        # np.asarray lets callers pass a plain list, as the annotation says.
        self.Wns = np.asarray(Wns, dtype=float) / (fs / 2)
        self.bandwidth = bandwidth / (fs / 2)
        self.fs = fs

    # The Output of the signal is [batch, channels, band_filtered, signal]
    def __call__(self, data, label):
        filtered_data = []

        for wn in self.Wns:
            # The edges are already normalised, so `fs` must NOT be passed here:
            # with `fs` set, cheby2 reads Wn in Hz and would normalise it a
            # second time, moving the band far below the intended frequencies.
            b, a = scipy.signal.cheby2(
                N=self.N,
                rs=self.rs,
                Wn=[wn, wn + self.bandwidth],
                btype="bandpass",
            )
            filtered_data.append(scipy.signal.filtfilt(b, a, data, axis=-1))

        # (bands, batch, channels, signal) -> (batch, channels, bands, signal), float32
        stacked = torch.as_tensor(np.stack(filtered_data), dtype=torch.float32)
        return stacked.permute(1, 2, 0, 3), label


class ExpandDim:
    """Insert a size-1 axis at ``dim`` in ``data``; ``label`` is returned unchanged."""

    def __init__(self, dim):
        self.dim = dim

    def __call__(self, data, label):
        # Out-of-place unsqueeze returns a new view, so the caller's tensor keeps
        # its shape (the former unsqueeze_ reshaped the input tensor in place).
        return data.unsqueeze(self.dim), label


class LabelToDict:
    """Wrap the label as ``{"label": label}``; the data passes through untouched."""

    def __call__(self, data, label):
        wrapped = dict(label=label)
        return data, wrapped


class ToNumpy:
    """Convert every torch tensor in ``data`` and ``label`` to a numpy array.

    Tensors nested in dicts, lists or tuples (e.g. the ``{"label": ...}`` dict
    built by LabelToDict) are converted too; any other value is passed through
    unchanged.

    NOTE: an earlier cell of this notebook defines a class with the same name;
    this definition replaces it.
    """

    @classmethod
    def _to_numpy(cls, value):
        if isinstance(value, Tensor):
            # detach() first so tensors that require grad can be converted.
            return value.detach().cpu().numpy()
        if isinstance(value, dict):
            return {key: cls._to_numpy(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            converted = [cls._to_numpy(item) for item in value]
            return converted if isinstance(value, list) else tuple(converted)
        return value

    def __call__(self, data, label):
        return self._to_numpy(data), self._to_numpy(label)


class Compose:
    """Chain ``(data, target) -> (data, target)`` transforms, applied in list order.

    ``repr`` lists the class name of each transform, one per line.
    """

    def __init__(self, transforms: List[Callable]) -> None:
        self.transforms = transforms

    def __call__(self, data: Any, target: Any):
        for transform in self.transforms:
            data, target = transform(data, target)
        return data, target

    def __repr__(self):
        names = [type(transform).__name__ for transform in self.transforms]
        return "\n".join(names)


# TODO: complete this part
from scipy.signal import cheby2, filtfilt


def cheby_bandpass_filter(signal, attenuation, lowcut, highcut, fs, order=5):
    """Zero-phase Chebyshev type-II band-pass filter along the last axis.

    Parameters
    ----------
    signal : array-like
        Input signal(s); filtered along axis -1.
    attenuation : float
        Stop-band attenuation in dB (``rs`` of ``cheby2``).
    lowcut, highcut : float
        Band edges in Hz.
    fs : float
        Sampling rate in Hz.
    order : int
        Filter order.

    Returns
    -------
    numpy.ndarray
        Filtered signal, same shape as ``signal``.
    """
    nyquist = 0.5 * fs
    band_edges = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = cheby2(order, rs=attenuation, Wn=band_edges, btype="band")
    return filtfilt(numerator, denominator, signal, axis=-1)


def cheby_bandpass_one_subject(
    X, attenuation, lowcut, highcut, fs, interval=None, verbose=True
):
    """Band-pass ``X`` once, or as a filter bank of consecutive ``interval``-Hz bands.

    Parameters
    ----------
    X : array-like (numpy array or CPU torch tensor)
        Signal(s) to filter; filtering runs along the last axis.
    attenuation : float
        Stop-band attenuation in dB.
    lowcut, highcut : float
        Overall band edges in Hz.
    fs : float
        Sampling rate in Hz.
    interval : float, optional
        When given, each band ``[start, start + interval]`` for ``start`` in
        ``np.arange(lowcut, highcut, interval)`` is filtered separately, and the
        results are joined with ``np.vstack`` (along the first axis).
    verbose : bool
        Print the name of each band as it is filtered.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    ValueError
        If ``interval`` is given but produces no band (e.g. ``lowcut >= highcut``).
    """
    # np.asarray accepts CPU tensors as well as arrays; the former X.copy()
    # raised AttributeError for tensors, which BandPass explicitly passes in.
    # No copy is needed because filtfilt does not modify its input.
    signal = np.asarray(X)

    if interval is None:
        return cheby_bandpass_filter(signal, attenuation, lowcut, highcut, fs)

    bands = []
    for start in np.arange(lowcut, highcut, step=interval):
        if verbose:
            # ":02g" prints integer edges exactly like the old ":02d" ("08_12")
            # and, unlike ":02d", also accepts float edges.
            print(
                "Filtering through {:02g}_{:02g} Hz band".format(start, start + interval)
            )
        bands.append(
            cheby_bandpass_filter(signal, attenuation, start, start + interval, fs)
        )

    if not bands:
        raise ValueError(
            f"no band between lowcut={lowcut} and highcut={highcut} "
            f"with interval={interval}"
        )
    return np.vstack(bands)


from functools import partial


class BandPass:
    """Chebyshev band-pass (optionally a filter bank) applied as a dataset transform.

    Wraps ``cheby_bandpass_one_subject`` with fixed settings (``verbose=False``)
    and reorders its result to ``[batch, channels, band_filtered, signal]``.
    """

    def __init__(self, attenuation, lowcut, highcut, fs, interval=None):
        self.attenuation = attenuation
        self.lowcut = lowcut
        self.highcut = highcut
        self.fs = fs
        self.interval = interval

        filter_kwargs = dict(
            attenuation=self.attenuation,
            lowcut=self.lowcut,
            highcut=self.highcut,
            fs=self.fs,
            interval=self.interval,
            verbose=False,
        )
        self.bandpass_func = partial(cheby_bandpass_one_subject, **filter_kwargs)

    # The Output of the signal is [batch, channels, band_filtered, signal]
    def __call__(self, data, label):
        filtered = apply_to_collection(
            data,
            dtype=(np.ndarray, int, float, np.int64, Tensor),
            function=self.bandpass_func,
        )

        # assumes a 3-D result (bands, channels, signal), e.g. from an input
        # shaped (1, channels, signal) — TODO confirm for other inputs.
        reordered = filtered.transpose(1, 0, 2)
        return np.expand_dims(reordered, axis=0), label


class Whitening:
    """Spatial whitening transform estimated once from all trials of a data loader.

    Parameters
    ----------
    data_loader : iterable of (signal, label)
        Typically a ``torch.utils.data.DataLoader``. All ``signal`` batches are
        concatenated along dim 0 and squeezed; the result is assumed to be
        (trials, channels, samples) — TODO confirm for other pipelines.
    whitening_method : {"PCA", "ZCA"}
        Which whitening matrix to build.

    NOTE(review): the channel mean removed while estimating ``W`` is not
    subtracted in ``__call__``; confirm this is intended.
    """

    _METHODS = ("PCA", "ZCA")

    def __init__(self, data_loader, whitening_method="PCA"):
        self.ds = data_loader
        self.method = whitening_method

        self.W = self._generate_whitening_transformation(self.ds, self.method)

    def _generate_whitening_transformation(self, data_loader, whitening_method="PCA"):
        """extract whitening transformation from data

        Parameters
        ----------
        data_loader : torch.dataloader
            pytorch data loader
        whitening_method : str
            one of following values
            "PCA" for PCA whitening
            "ZCA" for ZCA whitening

        Returns
        -------
        torch.Tensor
            whitening transformation matrix (channels x channels)

        Raises
        ------
        ValueError
            If ``whitening_method`` is not "PCA" or "ZCA".
        """
        if whitening_method not in self._METHODS:
            raise ValueError(
                f"whitening_method must be one of {self._METHODS}, got {whitening_method!r}"
            )

        # get data
        signal = torch.vstack([sig for sig, _ in data_loader])

        # zero-centre every channel with its mean over all trials and samples
        x = signal.squeeze()  # assumed (trials, channels, samples)
        x_t = x.permute(0, 2, 1)  # (trials, samples, channels)
        channel_mean = torch.mean(torch.mean(x_t, dim=1), dim=0)
        x_zero_centered = (x_t - channel_mean).permute(0, 2, 1)

        # Calculate whitening matrix
        x_cov = self._calc_cov(x_zero_centered)

        # The covariance is symmetric, so eigh returns real eigenpairs directly
        # (eig + .real was a workaround for the same thing).
        eigvals, eigvecs = torch.linalg.eigh(x_cov)
        inv_sqrt = torch.diag(torch.rsqrt(eigvals))

        # Compare against the requested method: the former `if "PCA":` was
        # always true, so "ZCA" silently produced PCA whitening.
        if whitening_method == "PCA":
            return inv_sqrt @ eigvecs.T
        return eigvecs @ inv_sqrt @ eigvecs.T

    def _calc_cov(self, EEG_data):
        """Mean over trials of the trace-normalised covariance ``X X^T / tr(X X^T)``."""
        products = EEG_data @ EEG_data.transpose(-1, -2)  # (trials, C, C)
        traces = torch.diagonal(products, dim1=-2, dim2=-1).sum(dim=-1)
        return torch.mean(products / traces[:, None, None], 0)

    def __call__(self, data: Tensor, label):
        whitened_data = self.W @ data
        return whitened_data, label

## Model

In [None]:
import sys

# Make the repository root importable (for `cspnn`). Guarded so that
# re-running this cell does not keep prepending the same entry.
if "../../" not in sys.path:
    sys.path.insert(0, "../../")

In [None]:
import torch
import numpy as np
import torch.nn as nn
from cspnn.csp_nn import CSP, CSPNN

In [None]:
import torch.optim as optim

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import (
    roc_auc_score,
    precision_score,
    recall_score,
    accuracy_score,
    cohen_kappa_score,
)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim

# from tqdm import tqdm
from tqdm.notebook import tqdm
import random

In [None]:
class SeparableConv2D(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    https://github.com/seungjunlee96/Depthwise-Separable-Convolution_Pytorch/blob/master/DepthwiseSeparableConvolution/DepthwiseSeparableConvolution.py

    Parameters
    ----------
    in_channels, out_channels : int
        Channels entering the depthwise stage / leaving the pointwise stage.
    depth_multiplier : int
        Depthwise filters per input channel; the pointwise convolution therefore
        receives ``in_channels * depth_multiplier`` channels.
    kernel_size : int or tuple
        Kernel of the depthwise convolution.
    padding : str or int
        Padding of the depthwise convolution.
    bias : bool
        Whether both convolutions use a bias term.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        depth_multiplier=1,
        kernel_size=3,
        padding="valid",
        bias=False,
    ):
        super(SeparableConv2D, self).__init__()
        self.depthwise = DepthwiseConv2d(
            in_channels, depth_multiplier, kernel_size, padding=padding, bias=bias
        )
        # The depthwise stage outputs in_channels * depth_multiplier channels;
        # sizing this with plain in_channels broke any depth_multiplier != 1.
        self.pointwise = nn.Conv2d(
            in_channels * depth_multiplier, out_channels, kernel_size=1, bias=bias
        )

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class DepthwiseConv2d(nn.Conv2d):
    """2-D depthwise convolution (``groups == in_channels``).

    Each input channel is convolved with its own ``depth_multiplier`` filters, so
    the layer outputs ``in_channels * depth_multiplier`` channels. All other
    arguments are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        depth_multiplier=1,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(DepthwiseConv2d, self).__init__(
            in_channels,
            in_channels * depth_multiplier,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
        )


class EEGNetv2(nn.Module):
    """
    Note: Use this class carefully. It is specificaly adapted with BCI Comp. IV 2a
        (22 channels, sample rate of 250, 3 seconds). So it's not a general class.
        it is developed only for experiment purposes.
    Implemented based on keras/tensorflow implementation found here:
        https://github.com/vlawhern/arl-eegmodels
    Paper:
        https://iopscience.iop.org/article/10.1088/1741-2552/aace8c

    NOTE(review): conv1 takes ``num_classes`` input planes (not the single plane
    of the original EEGNet), matching how CSPNNCls feeds its CSP output.
    The classifier expects exactly 111 time steps after the last pooling layer:
    e.g. 3584 input samples with kernel_length=32 -> 3553 -> /4 = 888 -> /8 = 111.
    Other input lengths need a different value.
    """

    def __init__(
        self,
        num_classes=4,
        channels=22,
        dropout_rate=0.5,
        kernel_length=64,
        F1=8,
        D=2,
        F2=16,
    ):
        super(EEGNetv2, self).__init__()
        self.dropout_rate = dropout_rate
        # Size of the flattened embedding fed to fc1 (F2 maps x 111 time steps).
        self._emb_size = F2 * 111

        # Layer 1
        self.conv1 = nn.Conv2d(
            num_classes, F1, (1, kernel_length), padding="valid", bias=False
        )
        # BatchNorm2d's 2nd positional argument is eps; the former `False` meant
        # eps=0, which divides by zero for a constant feature map.
        self.batchnorm1 = nn.BatchNorm2d(F1)
        self.dwconv2 = DepthwiseConv2d(
            in_channels=F1,
            depth_multiplier=D,
            kernel_size=(channels, 1),
            stride=1,
            padding="valid",
            bias=False,
        )

        # The depthwise conv outputs F1 * D maps (was hard-coded as 2 * F1).
        self.batchnorm2 = nn.BatchNorm2d(F1 * D)
        # act elu
        self.pooling1 = nn.AvgPool2d((1, 4))
        # dropout

        # Layer 2
        self.sepconv2 = SeparableConv2D(F1 * D, F2, 1, (1, 16), padding="same")
        self.batchnorm3 = nn.BatchNorm2d(F2)
        # elu
        self.pooling2 = nn.AvgPool2d((1, 8))
        # dropout

        # FC Layer
        self.fc1 = nn.Linear(self._emb_size, num_classes)

    def _forward_emb(self, x, device=None):
        """Convolutional feature extractor; returns a (batch, F2 * 111) embedding."""
        # Layer 1
        x = self.batchnorm1(self.conv1(x))
        x = F.elu(self.batchnorm2(self.dwconv2(x)))
        x = self.pooling1(x)
        x = F.dropout(x, self.dropout_rate, training=self.training)

        # Layer 2
        x = F.elu(self.batchnorm3(self.sepconv2(x)))
        # Pool the 4-D (batch, F2, 1, time) tensor directly: the former
        # squeeze() also removed the batch axis when batch == 1, after which
        # AvgPool2d rejected the 2-D input. Flattening order is unchanged.
        x = self.pooling2(x)
        x = F.dropout(x, self.dropout_rate, training=self.training)

        # FC Layer
        return x.reshape((-1, self._emb_size))

    def forward(self, x, device=None):
        """Return class probabilities (softmax over dim 1)."""
        x = self._forward_emb(x)
        x = F.softmax(self.fc1(x), dim=1)
        return x


class CSPNNCls(nn.Module):
    """CSPNN front-end followed by an EEGNetv2 classifier.

    The constructor arguments are forwarded to ``cspnn.csp_nn.CSPNN``;
    ``dropout_rate`` goes to EEGNetv2, which is built with
    ``num_classes=num_labels`` and ``channels=num_channels``.
    ``forward`` returns ``(class_probabilities, csp_output)``.
    """

    def __init__(
        self,
        num_channels: int,
        num_features: int = None,
        num_bands: int = None,
        num_windows: int = 1,
        num_labels: int = None,
        csp_pow: bool = True,
        signal_len: int = None,
        mode: str = "constant",
        dropout_rate=0.5,
    ):
        super(CSPNNCls, self).__init__()
        self.num_channels = num_channels
        self.num_features = num_channels if num_features is None else num_features
        self.num_bands = num_bands
        self.num_windows = num_windows
        self.num_labels = num_labels
        self.csp_pow = csp_pow
        # Previously only set when csp_pow is False; storing it always is harmless.
        self.signal_len = signal_len
        self.mode = mode

        self.csp_nn = CSPNN(
            num_channels=num_channels,
            num_features=num_features,
            num_bands=num_bands,
            num_windows=num_windows,
            num_labels=num_labels,
            csp_pow=csp_pow,
            mode=self.mode,
        )

        self.eegnet = EEGNetv2(
            num_classes=self.num_labels,
            channels=self.num_channels,
            dropout_rate=dropout_rate,
            kernel_length=32,
            F1=8,
            D=2,
            F2=16,
        )

    def forward(self, x):
        """Run CSPNN, then classify its output with EEGNetv2.

        Returns ``(probabilities, csp)`` in both training and eval mode.
        """
        csp = self.csp_nn(x)
        # The CSP output is passed to EEGNetv2 unchanged (the leftover debug
        # print of its size and the two identical return branches were removed).
        x = self.eegnet(csp)
        return x, csp

In [None]:
# CSP front-end + EEGNet classifier sized for the Cho data: 68 channels,
# 3584-sample trials, two classes (left / right hand imagery).
net = CSPNNCls(
    num_channels=68,
    num_features=68,
    num_labels=2,
    num_bands=1,
    num_windows=1,
    signal_len=3584,
    csp_pow=False,
    mode="csp",
)

### Test Model

In [None]:
# Per-item pipeline: numpy -> tensor, add singleton axes so each (channels, samples)
# trial becomes (1, channels, 1, 1, samples), then wrap the label in a dict.
compose = Compose(
    transforms=[
        ToTensor(device="cpu"),
        ExpandDim(dim=0),
        ExpandDim(dim=2),
        ExpandDim(dim=2),
        LabelToDict(),
    ]
)
transforms = compose.transforms

ds = ChoDataset(data_path="./cho/", patients=[1], transforms=compose)

In [None]:
# Shape of the first transformed trial.
first_signal, first_label = ds[0]
first_signal.shape

In [None]:
for key, train_dataset, test_dataset in ds.get_train_test_subsets(with_key=True):
    # 1) Estimate the whitening matrix (method "ZCA") from this split's training
    #    trials, read one per batch through the dataset's current transforms.
    whitening_loader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=ChoDataset.collate_fn,
        num_workers=1,
    )
    whitening = Whitening(whitening_loader, whitening_method="ZCA")

    # 2) Rebuild the training pipeline with whitening applied right after the
    #    leading batch axis is added.
    transforms = [
        ToTensor(device="cpu"),  # "cuda"),
        ExpandDim(dim=0),
        whitening,
        ExpandDim(dim=2),
        ExpandDim(dim=2),
        LabelToDict(),
    ]
    train_dataset.transforms = Compose(transforms=transforms)

    # 3) Smoke-test the network on one shuffled mini-batch.
    # NOTE(review): `break` only leaves the batch loop; the outer loop still
    # visits every split.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=ChoDataset.collate_fn,
        num_workers=1,  # os.cpu_count(),
    )
    for signals, labels in train_dataloader:
        preds, csp = net(signals)
        print(preds.size(), labels["label"].size())
        break