# Imports

In [262]:
!export WANDB_API_KEY=440f1ee11bc00f4104c20e09df4fd8a0f51d1924

In [239]:
%load_ext lab_black

The lab_black extension is already loaded. To reload it, use:
  %reload_ext lab_black


In [280]:
from typing import Dict, Tuple, Union, List, Callable, Any

import os
import random
from tqdm import tqdm

import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from sklearn import model_selection
from sklearn import svm
from imblearn.ensemble import BalancedRandomForestClassifier
from xgboost import XGBClassifier
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB, CategoricalNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, accuracy_score, f1_score


import torch
from torch.utils.data import Dataset

import functools
from functools import partial
import itertools

from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch import Tensor

import mne
import pywt


import wandb

try:
    from dask.distributed import Client, progress
    import dask.bag as db

    # Make it True
    using_dask = False
except:
    using_dask = False

# /Data

In [281]:
# /Data


class KlinikDataset(Dataset):
    def __init__(
        self,
        eeg_electrode_positions: Dict[str, Tuple[int, int]],
        data_path: str,
        meta_data=None,
        length=1,
        transforms=None,
    ):

        self.eeg_electrode_positions = eeg_electrode_positions
        self.data_path = data_path

        if meta_data is None:
            self.meta_data = pd.read_csv(os.path.join(self.data_path, "meta_data.csv"))
        else:
            self.meta_data = meta_data

        df_label_1 = self.meta_data[self.meta_data.label.eq(1)]
        df_label_0 = self.meta_data[self.meta_data.label.eq(0)]

        self.meta_data = shuffle(
            pd.concat(
                (
                    df_label_0.sample(n=len(df_label_1), replace=False),
                    df_label_1,
                ),
                axis=0,
            )
        )

        self.meta_data.reset_index(drop=True, inplace=True)

        self.labels = self.meta_data["label"]

        self.length = length
        self.transforms = transforms

    def __repr__(self):
        return "KlinikDataset"

    def get_class_distribution(self):
        return self.meta_data["label"].value_counts().to_list()

    def __len__(self):
        return len(self.meta_data)

    def __getitem__(self, idx: int) -> Union[dict, torch.Tensor]:

        meta_data = self.meta_data.iloc[idx]

        # shape -> (num_channels, n_times)
        # eeg_data = np.load(
        #     os.path.join(self.data_path, os.path.split(meta_data["file_pathe"])[1])
        # )

        eeg_data = np.load(os.path.join("/", meta_data["file_pathe"]))

        # get the first 21 channels
        eeg_data = eeg_data[:21, :]

        label = int(meta_data["label"])

        wav = {
            key: np.expand_dims(eeg_data[i], axis=0)
            for i, key in enumerate(self.eeg_electrode_positions.keys())
        }

        if self.transforms is not None:
            wav, label = self.transforms(wav, label)

        return wav, label

    def get_class(self, index):
        return self.meta_data.iloc[index]["label"]

    def subset(self, indices):
        return self.__class__(
            eeg_electrode_positions=self.eeg_electrode_positions,
            data_path=self.data_path,
            meta_data=self.meta_data.iloc[indices],
            length=self.length,
            transforms=self.transforms,
        )

    @staticmethod
    def collate_fn(batch):
        imgs = {
            key: torch.vstack([item[0][key].unsqueeze(0) for item in batch])
            for key in batch[0][0].keys()
        }
        trgts = torch.vstack([item[1] for item in batch]).squeeze()

        return [imgs, trgts]

## /audio_features

In [247]:
class ToTensor(object):
    def __init__(self, device):
        self.device = device

    def __call__(self, data, label):
        data = apply_to_collection(
            data,
            dtype=(np.ndarray, int, float, np.int64),
            function=lambda a: torch.tensor(a).float(),
        )
        label = apply_to_collection(
            label,
            dtype=(np.ndarray, int, float, np.int64),
            function=lambda a: torch.tensor(a).float(),
        )

        return data, label


class LabelToDict(object):
    def __call__(self, data, label):
        return data, {"label": label}

    def __repr__(self):
        return "LabelToDict"


class ZNorm(object):
    def __init__(
        self,
        stats: str,
        mode: str = "min-max",
        max_clip_val: int = 0,
        min_clip_val: int = None,
    ):
        self.stats_name = stats
        self.mode = mode
        self.min_clip_val = min_clip_val if min_clip_val is not None else -max_clip_val
        self.max_clip_val = max_clip_val
        with open(stats, "rb") as stats_f:
            self.stats = pickle.load(stats_f)

    def __call__(self, pkg: Tuple[Dict[str, Tensor], List[int]], target: Any):
        for k, st in self.stats.items():
            if k in pkg:
                if self.mode == "min-max":
                    minx = st["min"].unsqueeze(0).to(pkg[k].device)
                    maxx = st["max"].unsqueeze(0).to(pkg[k].device)
                    pkg[k] = (pkg[k] - minx) / (maxx - minx)
                if self.mode == "mean-std":
                    mean = st["mean"].unsqueeze(0).to(pkg[k].device)
                    std = st["std"].unsqueeze(0).to(pkg[k].device)
                    pkg[k] = (pkg[k] - mean) / std
                if self.max_clip_val > 0 or self.min_clip_val is not None:
                    pkg[k] = torch.clip(
                        pkg[k], min=self.min_clip_val, max=self.max_clip_val
                    )
            else:
                raise ValueError(f"couldn't find stats key {k} in package")
        return pkg, target


class Compose(object):
    def __init__(self, transforms: List[Callable]) -> None:
        self.transforms = transforms

    def __call__(self, data: Any, target: Any):
        for t in self.transforms:
            data, target = t(data, target)
        return data, target

    def __repr__(self):
        return "Compose(" + f"-".join(str(t) for t in self.transforms) + ")"


class FlattenChannel(object):
    def __call__(self, data: Any, target: Any):
        for i, worker in enumerate(target.keys()):

            if worker == "label":
                pass
            else:
                flatten_worker = []
                for j, ch in enumerate(target[worker].keys()):
                    flatten_worker.append(target[worker][ch])

                target[worker] = np.array(flatten_worker)
                # print(f"target flattened : {target[worker].shape}")

        return data, target

    def __repr__(self):
        return f"FlattenChannel"


class ConcatenateWorker(object):
    def __init__(self, transforms: List[Callable]) -> None:
        self.transforms = transforms

    def __call__(self, data: Any, target: Any):

        concatenate_worker = []

        for i, worker in enumerate(target.keys()):
            if worker == "label":
                pass
            else:
                # print(f"worker : {target[worker].shape}")
                concatenate_worker.append(target[worker])

        target["concat"] = np.concatenate(concatenate_worker, axis=-1)

        return data, target

    def __repr__(self):
        return f"ConcatenateWorker({'-'.join(str(t) for t in self.transforms)})"


class WTE(object):
    def __init__(self, level=4, wavelet="db1", name="wte"):
        self.level = level
        self.wavelet = wavelet
        self.name = name

    def __call__(
        self,
        data: Dict[str, np.ndarray],
        label: Dict[str, Union[Any, Dict[str, np.ndarray]]],
    ):

        label[self.name] = apply_to_collection(
            data,
            dtype=np.ndarray,
            function=partial(
                self.wavelet_transform_energy, level=self.level, wavelet=self.wavelet
            ),
        )

        return data, label

    def __repr__(self):
        attrs = "(level={}, wavelet={})".format(self.level, self.wavelet)
        return self.__class__.__name__ + attrs

    @staticmethod
    def wavelet_transform_energy(signal: np.ndarray, level: int, wavelet: str = "db1"):
        """calculates wavelet transform energy of a 1d signal

        Parameters
        ----------
        signal : numpy.ndarray
            raw signal. (eg. audio signal)
        level : int
            wavelet transform maximum level
        wavelet : str, optional
            wavelet type. one of the type available in :code:`pywt.wavelist()`,
             by default "db1"

        Returns
        -------
        numpy.ndarray
            The energy vector with shape of (level + 1,)

        Notes
        -----
        The WT energy can be calculated in different ways. Here we have implemented
        the method proposed in [1] equation (1) and [2] equation (2):

        .. math::
            \tilde{\mathbf{E}}_{\mathbf{V}_{j}} =
            \frac{\sum_{n} (\mathbf{w}_{j,n})^2}{\sum_{j=1}^{J_{max}} \sum_{n} (\mathbf{w}_{j,n})^2}

        Where :math:`\mathbf{w}_{j,n}` are the coefficients generated by DWT at the
        jth decomposition level.

        .. [1] K. Qian et al., “A bag of wavelet features for snore sound classification,”
           Ann. Biomed. Eng., vol. 47, no. 4, pp. 1000–1011, 2019.
        .. [2] Qian, K., C. Janott, Z. Zhang, C. Heiser, and B. Schuller.
           Wavelet features for classification of VOTE snore sounds.
           In: Proceedings of ICASSP, Shanghai, China, 2016, pp.221–225.
        """
        wt = pywt.wavedec(data=signal, wavelet=wavelet, mode="symmetric", level=level)

        ps_wt = np.array([np.sum(np.power(wt_j, 2)) for wt_j in wt])
        energy_vector = ps_wt / np.sum(ps_wt)

        return energy_vector


class WPTE(object):
    def __init__(
        self,
        level=4,
        wavelet="db1",
        include_raw=True,
        name="wpte",
    ):
        self.level = level
        self.wavelet = wavelet
        self.include_raw = include_raw
        self.name = name

    def __call__(
        self,
        data: Dict[str, np.ndarray],
        label: Dict[str, Union[Any, Dict[str, np.ndarray]]],
    ):

        label[self.name] = apply_to_collection(
            data,
            dtype=np.ndarray,
            function=partial(
                self.wavelet_packet_transform_energy,
                maxlevel=self.level,
                wavelet=self.wavelet,
                include_raw=self.include_raw,
            ),
        )

        return data, label

    def __repr__(self):
        attrs = "(level={}, wavelet={})".format(self.level, self.wavelet)
        return self.__class__.__name__ + attrs

    @staticmethod
    def wavelet_packet_energy(wpt_subspace: np.ndarray):
        """calculates the energy of a single subband from subspaces given by wpt

        Parameters
        ----------
        wpt_subspace : numpy.ndarray
            coefficients calculated by WPT from the signal at subspace V
            which is the kth subband at jth level.

        Returns
        -------
        int
            The energy of WPT sub space

        Notes
        -----
        The WPT energy can be calculated in different ways. Here we have implemented
        the method proposed in [1] equation (2):

        .. math:: \tilde{\mathbf{E}}_{\mathbf{V}_{j,k}} = log(\sqrt{\frac{\sum_{n=1}^{N_{j,k}} (\mathbf{w}_{j,k,n})^2}{N_{j,k}}})

        Where :math:`\mathbf{w}_{j,k,n}` represents the coefficients calculated by
        WPT from the signal at the subspace :math:`\mathbf{V}_{j,k}`.
        :math:`N_{j,k}` is the total number of wavelet coefficients in the kth subband
        at the jth level.

        .. [1] K. Qian et al., “A bag of wavelet features for snore sound classification,”
           Ann. Biomed. Eng., vol. 47, no. 4, pp. 1000–1011, 2019.
        """
        energy = np.log(
            np.sqrt(np.sum(np.power(wpt_subspace, 2) / wpt_subspace.shape[0]))
        )
        return energy

    @staticmethod
    def wavelet_packet_transform_energy(
        signal: np.ndarray,
        maxlevel: int,
        wavelet: str = "db1",
        include_raw: bool = True,
    ):
        """calculates wavelet packet transform energy of a 1d signal

        Parameters
        ----------
        signal : numpy.ndarray
            raw signal. (eg. audio signal)
        maxlevel : int
            wavelet packet transform maximum level
        wavelet : str, optional
            wavelet type. one of the types available in :code:`pywt.wavelist()`,
             by default "db1"

        Returns
        -------
        numpy.ndarray
            The energy vector with shape of (:math:`2^{maxlevel + 1} - 1`,)
        """
        energy_vector = []

        wp = pywt.WaveletPacket(
            data=signal, wavelet=wavelet, mode="symmetric", maxlevel=maxlevel
        )
        if include_raw:
            energy = WPTE.wavelet_packet_energy(signal)
            energy_vector.append(energy)

        for row in range(1, maxlevel + 1):
            for i in [node.path for node in wp.get_level(row, "freq")]:
                energy = WPTE.wavelet_packet_energy(wp[i].data)
                energy_vector.append(energy)

        return np.array(energy_vector)


class PSD(object):
    def __init__(
        self,
        sfreq=256,
        fmin=0,
        fmax=np.inf,
        n_fft=256,
        n_overlap=128,
        n_per_seg=256,
        average="mean",
        verbose=0,
        windowed=False,
        unit="Hz",  # Can be bin
        name="psd",
    ):
        self.sfreq = sfreq
        self.fmin = fmin
        self.fmax = fmax
        self.n_fft = n_fft
        self.n_overlap = n_overlap
        self.n_per_seg = n_per_seg
        self.average = average
        self.windowed = windowed
        self.unit = unit
        self.verbose = verbose
        self.name = name

    def __call__(
        self,
        data: Dict[str, np.ndarray],
        label: Dict[str, Union[Any, Dict[str, np.ndarray]]],
    ):

        label[self.name] = apply_to_collection(
            data,
            dtype=np.ndarray,
            function=partial(
                self.power_spectral_density,
                sfreq=self.sfreq,
                fmin=self.fmin,
                fmax=self.fmax,
                n_fft=self.n_fft,
                n_overlap=self.n_overlap,
                n_per_seg=self.n_per_seg,
                average=self.average,
                verbose=self.verbose,
                windowed=self.windowed,
                unit=self.unit,
            ),
        )

        return data, label

    def __repr__(self):
        attrs = "(sfreq={}, n_fft={}, n_overlap={}, n_per_seg={})".format(
            self.sfreq, self.n_fft, self.n_overlap, self.n_per_seg
        )
        return self.__class__.__name__ + attrs

    @staticmethod
    def power_spectral_density(
        signal: np.ndarray,
        sfreq: int = 256,
        fmin: float = 0,
        fmax: float = np.inf,
        n_fft: int = 256,
        n_overlap: int = 128,
        n_per_seg: int = 256,
        average: str = "mean",
        verbose: int = 0,
        windowed: bool = False,
        unit: str = "Hz",
    ):

        if windowed:
            average = None

        data, _ = mne.time_frequency.psd_array_welch(
            signal,
            sfreq=sfreq,
            fmin=fmin,
            fmax=fmax,
            n_fft=n_fft,
            n_overlap=n_overlap,
            n_per_seg=n_per_seg,
            average=average,
            verbose=verbose,
        )

        if not windowed:
            data = (10 * np.log10(data * sfreq / n_fft)).flatten()
            # data = data.flatten()

        # The Shape of data should be [trial, window, psds]
        # Handle Window data
        if unit == "bin" and windowed:
            data = np.apply_along_axis(
                lambda x: 10 * np.log10(x * sfreq / n_fft),
                axis=2,
                arr=np.transpose(data, (0, 2, 1)),
            )

        return data

In [248]:
def prepare_transforms(worker_configs):
    transforms = [LabelToDict()]
    worker_transform = []
    # transforms = []
    for worker in worker_configs:
        name = worker["name"]
        if name == "wte":
            transforms.append(WTE())
            worker_transform.append(WTE())
        elif name == "wpte":
            transforms.append(WPTE())
            worker_transform.append(WPTE())
        elif name == "psd":
            transforms.append(PSD())
            worker_transform.append(PSD())

    transforms.append(FlattenChannel())
    transforms.append(ConcatenateWorker(worker_transform))

    # transforms.append(ToTensor(device=torch.device("cpu")))

    return Compose(transforms)

# /Scripts

In [249]:
import numpy as np
from sklearn.base import ClassifierMixin


class BatchVotingClassifier(ClassifierMixin):
    def __init__(self, model, voting="hard"):
        super().__init__()

        self.model = model
        self.voting = voting

    def fit(self, x, y):
        X = np.vstack([np.vstack(batch) for batch in x])
        Y = np.array([y[i] for i, batch in enumerate(x) for j in range(len(batch))])

        self.model.fit(X, Y)

    def predict(self, X):
        res_maj = np.zeros((len(X),))

        for i, batch in enumerate(X):
            batch = np.vstack(batch)
            # code chunk from sklearn:
            # https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09bcc2eaeba98f7e737aac2ac782f0e5f1/sklearn/ensemble/_voting.py#L340
            if self.voting == "soft":
                maj = np.argmax(self.model.predict_proba(batch).reshape(-1, 1), axis=1)

            else:  # 'hard' voting
                predictions = self.model.predict(batch)
                predictions = predictions.reshape(1, -1)
                maj = np.apply_along_axis(
                    lambda x: np.argmax(np.bincount(x)),
                    axis=1,
                    arr=predictions,
                )
            res_maj[i] = maj

        return res_maj

    def predict_proba(self, x):
        raise NotImplementedError()

In [278]:
def fit_eval(clf, X, Y, train_idx, test_idx, dataset):
    X_train, X_test = X[train_idx], X[test_idx]
    Y_train, Y_test = Y[train_idx], Y[test_idx]

    X_train = X_train.reshape(X_train.shape[0], -1)
    X_test = X_test.reshape(X_test.shape[0], -1)
    print(f"The data shape : {X_train.shape}")

    print("*** Training the Model...")
    clf.fit(X_train, Y_train)
    print("Done.")

    # classes = list(map(str, dataset.classes))

    print("Prediction on Train data")
    preds = clf.predict(X_train)
    print("Done.")
    print(
        classification_report(
            Y_train,
            preds,
            # labels=[i for i in range(len(classes))],
            # target_names=classes,
            zero_division=0,
        )
    )
    temp_log = {}
    temp_log["train accuracy"] = accuracy_score(preds, Y_train)
    temp_log["train F1 (micro)"] = f1_score(preds, Y_train, average="micro")
    temp_log["train F1 (macro)"] = f1_score(preds, Y_train, average="macro")

    print("Prediction on Test data")
    preds = clf.predict(X_test)
    print("Done.")
    print(
        classification_report(
            Y_test,
            preds,
            # labels=[i for i in range(len(classes))],
            # target_names=classes,
            zero_division=0,
        )
    )
    temp_log["test accuracy"] = accuracy_score(preds, Y_test)
    temp_log["test F1 (micro)"] = f1_score(preds, Y_test, average="micro")
    temp_log["test F1 (macro)"] = f1_score(preds, Y_test, average="macro")

    if len(temp_log) > 0:
        wandb.log(temp_log)


# wandb.sklearn.plot_confusion_matrix(Y_test, preds, classes)
# wandb.sklearn.plot_summary_metrics(clf, X_train, Y_train, X_test, Y_test)

In [253]:
def get_splitter(splitter="", n_splits=None):
    cv = getattr(model_selection, splitter)

    if splitter == "train_test_split":

        def wrapper(x, y, group):
            print(f"X : {len(x)}")
            print(f" Y : {len(y)}")
            print(f"Cross Validation Type : {type(cv)}")

            train_idx, test_idx = cv(
                x,
                stratify=y,
                test_size=0.2,
            )

            yield train_idx, test_idx

        return wrapper

    if splitter in ["LeaveOneGroupOut"]:
        return cv().split

    if splitter in ["GroupKFold", "StratifiedGroupKFold", "StratifiedKFold"]:
        return cv(n_splits=n_splits if n_splits is not None else 4).split

    return None


def select_dataset(opts, transform=None):
    if opts["dataset"] == "KlinikDataset":
        dataset = KlinikDataset(
            eeg_electrode_positions=opts["eeg_electrode_positions"],
            data_path=opts["data_path"],
            meta_data=None,
            length=opts["length"],
            transforms=transform,
        )

    return dataset


def select_model(opts):
    if opts["model"] == "svm":
        clf = svm.SVC(decision_function_shape="ovo", probability=True)
    elif opts["model"] == "BalancedRandomForestClassifier":
        clf = BalancedRandomForestClassifier(
            n_estimators=100, max_depth=4, n_jobs=os.cpu_count()
        )
    elif opts["model"] == "XGBoost":
        params = opts.copy()
        params.pop("model", None)
        clf = XGBClassifier(use_label_encoder=False, **params)
    elif opts["model"] == "GaussianNaiveBayes":
        clf = GaussianNB()
    elif opts["model"] == "K-NearestNeighbores":
        clf = KNeighborsClassifier()
    else:
        print(f"unknown model type {opts['model']}")
        clf = None

    return clf


def extract_sample_and_transform(args):
    dataset, idx = args
    sample_w, sample_l = dataset[idx]

    return sample_w, sample_l, idx


def find_nan_idx(x):
    nan_idx = np.argwhere(np.isnan(x))
    trials = np.array(list(set([i[0] for i in nan_idx])))
    mask_trials = np.zeros(x.shape[0], dtype=bool)
    mask_trials[trials] = True

    return mask_trials


def find_inf_idx(x):
    inf_idx = np.argwhere(np.isinf(x))
    trials = np.array(list(set([i[0] for i in inf_idx])))
    print(f"{trials = }")
    mask_trials = np.zeros(x.shape[0], dtype=bool)
    mask_trials[trials] = True

    return mask_trials


def prepare_data(dataset, extract_sample_and_transform, name="data"):

    X = []
    Y = []

    if not using_dask:
        for i in tqdm(range(len(dataset))):
            sample_w, sample_l, _ = extract_sample_and_transform((dataset, i))
            X.append(sample_l["concat"])
            Y.append(sample_l["label"])

        X = np.array(X)
        Y = np.array(Y)

        # Delete Nan from data
        if np.isnan(np.sum(X)):
            mask = find_nan_idx(X)
            print(f"mask true nan: {mask[mask == True].shape}")
            X = X[~mask]
            Y = Y[~mask]

        # Delete inf from data
        if np.isinf(np.sum(X)):
            mask = find_inf_idx(X)
            print(f"mask true inf: {mask[mask == True].shape}")
            X = X[~mask]
            Y = Y[~mask]

    return X, Y


def train(opts):
    transform = prepare_transforms(opts["workers"])
    dataset = select_dataset(opts, transform=transform)
    print(f"Dataset : {dataset}, Transformer : {transform}")

    classes = np.array([dataset.get_class(i) for i in range(len(dataset))])

    groups = np.array([dataset.get_class(i) for i in range(len(dataset))])

    X, Y = prepare_data(
        dataset,
        extract_sample_and_transform,
        # name=f"{opts['transform']}_{opts['dataset']}",
    )

    print(f"X: {X.shape}, Y : {Y.shape}")

    for train_idx, test_idx in get_splitter(
        splitter=opts["data_splitter"], n_splits=opts["n_splits"]
    )(list(range(len(X))), Y, Y):

        wandb_run = wandb.init(
            project=opts["proj"], group=opts["wb_group"], reinit=True
        )
        wandb.config.update(opts)

        print(f"Train Samples : {len(train_idx)}")
        print(f"Test Samples : {len(test_idx)}")

        clf = select_model(opts)
        print(f"Model : {clf}")
        # clf = BatchVotingClassifier(clf)
        fit_eval(clf, X, Y, train_idx, test_idx, dataset)

In [259]:
# /Config

eeg_electrode_positions = {
    "Fp1": (-1, -2),
    "Fp2": (1, -2),
    "F3": (-1, -1),
    "F4": (1, -1),
    "C3": (-1, 0),
    "C4": (1, 0),
    "P3": (-1, 1),
    "P4": (1, 1),
    "O1": (-1, 2),
    "O2": (1, 2),
    "F7": (-2, -1),
    "F8": (2, -1),
    "T3": (-2, 0),
    "T4": (2, 0),
    "T5": (-2, 1),
    "T6": (2, 1),
    "A1": (-3, 0),
    "A2": (3, 0),
    "Fz": (0, -1),
    "Cz": (0, 0),
    "Pz": (0, 1),
}

eeg_electrods_plane_shape = (5, 7)

In [270]:
conf = {
    "dataset": "KlinikDataset",
    "eeg_electrode_positions": eeg_electrode_positions,
    "data_path": "/data/",
    "length": 1,
    "n_splits": 5,
    "proj": "baseline_models_eeg",
}

In [271]:
configs = {
    "transforms": [
        {
            # "workers": [{"name": "psd"}, {"name": "wte"}],
            "workers": [{"name": "psd"}, {"name": "wte"}],
        },
        # {
        #     "transform": "PASE_orig_pt",
        #     "fe_cfg": "/usr/src/app/cfg/frontend/PASE+.cfg",
        #     "fe_ckpt": "/experiments/pase_original_chpt/FE_e199.ckpt",
        # },
        # {"transform": "scat"},
    ],
    "models": [
        {"model": "BalancedRandomForestClassifier"},
        {"model": "XGBoost", "objective": "binary:logistic"},  # "multi:softprob"
        {"model": "GaussianNaiveBayes"},
        {"model": "K-NearestNeighbores"},
        {"model": "svm"},
    ],
    "data_splitters": [{"data_splitter": "train_test_split"}],
}

In [279]:
for model, transform, splitter in itertools.product(
    configs["models"], configs["transforms"], configs["data_splitters"]
):

    temp_conf = conf.copy()
    temp_conf.update(model)
    temp_conf.update(transform)
    temp_conf.update(splitter)
    temp_conf[
        "wb_group"
    ] = f"{model['model']}_{'_'.join(worker['name'] for worker in transform['workers'])}_{splitter['data_splitter']}"

    # print(temp_conf)

    train(temp_conf)

Dataset : KlinikDataset, Transformer : Compose(LabelToDict-PSD(sfreq=256, n_fft=256, n_overlap=128, n_per_seg=256)-WTE(level=4, wavelet=db1)-FlattenChannel-ConcatenateWorker(PSD(sfreq=256, n_fft=256, n_overlap=128, n_per_seg=256)-WTE(level=4, wavelet=db1)))


  data = (10 * np.log10(data * sfreq / n_fft)).flatten()
  energy_vector = ps_wt / np.sum(ps_wt)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 848/848 [00:15<00:00, 55.53it/s]

mask true nan: (27,)
trials = array([130, 518, 136, 648, 649, 779,  12, 792, 153, 666, 155, 286, 673,
       546,  41, 425, 442, 444, 714, 216,  95, 224, 104, 616, 363, 368])
mask true inf: (26,)
X: (795, 21, 134), Y : (795,)
X : 795
 Y : 795
Cross Validation Type : <class 'function'>





VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.12.16 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Train Samples : 636
Test Samples : 159
Model : BalancedRandomForestClassifier(max_depth=4, n_jobs=6)
The data shape : (636, 2814)
*** Training the Model...
Done.
Prediction on Train data
Done.
              precision    recall  f1-score   support

           0       0.99      0.96      0.98       314
           1       0.96      0.99      0.98       322

    accuracy                           0.98       636
   macro avg       0.98      0.98      0.98       636
weighted avg       0.98      0.98      0.98       636

Prediction on Test data
Done.
              precision    recall  f1-score   support

           0       0.92      0.94      0.93        78
           1       0.94      0.93      0.93        81

    accuracy                           0.93       159
   macro avg       0.93      0.93      0.93       159
weighted avg       0.93      0.93      0.93       159

Dataset : KlinikDataset, Transformer : Compose(LabelToDict-PSD(sfreq=256, n_fft=256, n_overlap=128, n_per_seg=256)-WTE(leve

  data = (10 * np.log10(data * sfreq / n_fft)).flatten()
  energy_vector = ps_wt / np.sum(ps_wt)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 848/848 [00:15<00:00, 56.13it/s]

mask true nan: (26,)
trials = array([  1, 513, 641, 521, 147, 660, 149, 533, 285, 157, 286, 415, 672,
       551, 174, 303, 306, 437, 188, 195,  73, 721, 219,  91, 228, 363,
       626, 754, 247])
mask true inf: (29,)
X: (793, 21, 134), Y : (793,)
X : 793
 Y : 793
Cross Validation Type : <class 'function'>





VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test F1 (macro),▁
test F1 (micro),▁
test accuracy,▁
train F1 (macro),▁
train F1 (micro),▁
train accuracy,▁

0,1
test F1 (macro),0.93081
test F1 (micro),0.93082
test accuracy,0.93082
train F1 (macro),0.9764
train F1 (micro),0.97642
train accuracy,0.97642


[34m[1mwandb[0m: wandb version 0.12.16 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Train Samples : 634
Test Samples : 159
Model : XGBClassifier(base_score=None, booster=None, colsample_bylevel=None,
              colsample_bynode=None, colsample_bytree=None, data_path='/data/',
              data_splitter='train_test_split', dataset='KlinikDataset',
              eeg_electrode_positions={'A1': (-3, 0), 'A2': (3, 0),
                                       'C3': (-1, 0), 'C4': (1, 0),
                                       'Cz': (0, 0), 'F3': (-1, -1),
                                       'F4': (1, -1), 'F7': (-2, -1),
                                       'F8': (2, -1), 'Fp1': (-1, -2),
                                       'Fp2': (1, -2), 'Fz':...
              enable_categorical=False, gamma=None, gpu_id=None,
              importance_type=None, interaction_constraints=None,
              learning_rate=None, length=1, max_delta_step=None, max_depth=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              n_estimators=100, 

  data = (10 * np.log10(data * sfreq / n_fft)).flatten()
  energy_vector = ps_wt / np.sum(ps_wt)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 848/848 [00:15<00:00, 55.77it/s]

mask true nan: (29,)
trials = array([130,   5, 393, 522, 396, 151, 797, 422, 554, 813, 304, 560, 435,
        59,  61, 710, 457, 338, 471, 347, 356, 490, 116, 633])
mask true inf: (24,)
X: (795, 21, 134), Y : (795,)
X : 795
 Y : 795
Cross Validation Type : <class 'function'>





VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test F1 (macro),▁
test F1 (micro),▁
test accuracy,▁
train F1 (macro),▁
train F1 (micro),▁
train accuracy,▁

0,1
test F1 (macro),0.98742
test F1 (micro),0.98742
test accuracy,0.98742
train F1 (macro),1.0
train F1 (micro),1.0
train accuracy,1.0


[34m[1mwandb[0m: wandb version 0.12.16 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Train Samples : 636
Test Samples : 159
Model : GaussianNB()
The data shape : (636, 2814)
*** Training the Model...
Done.
Prediction on Train data
Done.
              precision    recall  f1-score   support

           0       0.76      0.93      0.83       314
           1       0.91      0.71      0.80       322

    accuracy                           0.82       636
   macro avg       0.83      0.82      0.82       636
weighted avg       0.83      0.82      0.82       636

Prediction on Test data
Done.
              precision    recall  f1-score   support

           0       0.79      0.94      0.86        78
           1       0.93      0.77      0.84        81

    accuracy                           0.85       159
   macro avg       0.86      0.85      0.85       159
weighted avg       0.86      0.85      0.85       159

Dataset : KlinikDataset, Transformer : Compose(LabelToDict-PSD(sfreq=256, n_fft=256, n_overlap=128, n_per_seg=256)-WTE(level=4, wavelet=db1)-FlattenChannel-Concaten

  data = (10 * np.log10(data * sfreq / n_fft)).flatten()
  energy_vector = ps_wt / np.sum(ps_wt)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 848/848 [00:15<00:00, 55.69it/s]

mask true nan: (26,)
trials = array([512, 388, 637, 141, 399,  19, 284, 159, 423,  48,  54, 698, 443,
       316, 324, 709, 582, 461, 208, 473, 346, 606, 631, 110, 371, 372,
       119, 509])
mask true inf: (28,)
X: (794, 21, 134), Y : (794,)
X : 794
 Y : 794
Cross Validation Type : <class 'function'>





VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test F1 (macro),▁
test F1 (micro),▁
test accuracy,▁
train F1 (macro),▁
train F1 (micro),▁
train accuracy,▁

0,1
test F1 (macro),0.84833
test F1 (micro),0.84906
test accuracy,0.84906
train F1 (macro),0.81586
train F1 (micro),0.81761
train accuracy,0.81761


[34m[1mwandb[0m: wandb version 0.12.16 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Train Samples : 635
Test Samples : 159
Model : KNeighborsClassifier()
The data shape : (635, 2814)
*** Training the Model...
Done.
Prediction on Train data
Done.
              precision    recall  f1-score   support

           0       1.00      0.93      0.97       313
           1       0.94      1.00      0.97       322

    accuracy                           0.97       635
   macro avg       0.97      0.97      0.97       635
weighted avg       0.97      0.97      0.97       635

Prediction on Test data
Done.
              precision    recall  f1-score   support

           0       1.00      0.92      0.96        78
           1       0.93      1.00      0.96        81

    accuracy                           0.96       159
   macro avg       0.97      0.96      0.96       159
weighted avg       0.96      0.96      0.96       159

Dataset : KlinikDataset, Transformer : Compose(LabelToDict-PSD(sfreq=256, n_fft=256, n_overlap=128, n_per_seg=256)-WTE(level=4, wavelet=db1)-FlattenChanne

  data = (10 * np.log10(data * sfreq / n_fft)).flatten()
  energy_vector = ps_wt / np.sum(ps_wt)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 848/848 [00:14<00:00, 56.54it/s]

mask true nan: (33,)
trials = array([260, 134, 135, 523, 273, 402,  19, 538, 411, 156,  29, 670, 287,
       794, 551, 174, 433, 693, 704, 321, 582, 466,  94, 734, 103, 752,
       242, 373, 374, 249, 507, 126])
mask true inf: (32,)
X: (783, 21, 134), Y : (783,)
X : 783
 Y : 783
Cross Validation Type : <class 'function'>





VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test F1 (macro),▁
test F1 (micro),▁
test accuracy,▁
train F1 (macro),▁
train F1 (micro),▁
train accuracy,▁

0,1
test F1 (macro),0.96214
test F1 (micro),0.96226
test accuracy,0.96226
train F1 (macro),0.96686
train F1 (micro),0.96693
train accuracy,0.96693


[34m[1mwandb[0m: wandb version 0.12.16 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Train Samples : 626
Test Samples : 157
Model : SVC(decision_function_shape='ovo', probability=True)
The data shape : (626, 2814)
*** Training the Model...
Done.
Prediction on Train data
Done.
              precision    recall  f1-score   support

           0       0.96      0.92      0.94       304
           1       0.93      0.97      0.95       322

    accuracy                           0.94       626
   macro avg       0.94      0.94      0.94       626
weighted avg       0.94      0.94      0.94       626

Prediction on Test data
Done.
              precision    recall  f1-score   support

           0       0.89      0.93      0.91        76
           1       0.94      0.89      0.91        81

    accuracy                           0.91       157
   macro avg       0.91      0.91      0.91       157
weighted avg       0.91      0.91      0.91       157

