In [None]:
import sys
import random

import mne
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import os
import math
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import cohen_kappa_score
import importlib
from einops import rearrange, reduce, repeat
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
%matplotlib inline

In [None]:
# Run on the GPU whenever torch reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def unfold_output_size(n_times, patches, step):
    """Return ``(n_times - patches) / step + 1``.

    Args:
        n_times: total length being windowed.
        patches: window length.
        step: hop between consecutive windows.

    Returns:
        The quotient-based count as a float (true division is used).
    """
    # NOTE(review): true division yields fractional values when `step` does not
    # divide the span — confirm whether callers expect an integer count.
    span = n_times - patches
    return span / step + 1

In [None]:
from modules.BciDataHandler import BciDataHandler

# Build the project's data handler. Later cells read `subjects_epochs` and
# `subjects_labels` from it; `instantiate_dataset()` is project code not shown
# here — presumably it is what populates those attributes (confirm).
data_handler = BciDataHandler()
data_handler.instantiate_dataset()

In [None]:
# ------------------------------ bci competition dataset ------------------------------
# Pool every subject's epochs into one object; the class code is taken from the
# last column of `events` and shifted down by one.
all_subject_epochs = mne.concatenate_epochs(list(data_handler.subjects_epochs.values()))
all_labels = all_subject_epochs.events[:, -1] - 1

# Working set for the rest of the notebook: the entry stored under key 1 only,
# with its labels shifted down by one as well (presumably to make them
# zero-based — confirm against BciDataHandler).
epochs = data_handler.subjects_epochs[1]
labels = np.array(data_handler.subjects_labels[1]) - 1

# Uncomment to train on the pooled data instead of the single subject:
# epochs = all_subject_epochs
# labels = all_labels
# labels

In [None]:
# -------------------------------- ufjf dataset --------------------------------------
# from modules.EdfHandler import EdfHandler
#
# epochs, labels = EdfHandler.getAllData(["C:\\Users\\davi2\Desktop\\bci\\datasets_ufjf\\bci\\001.edf"])
# epochs = epochs[0]
# labels = np.array(labels[0])
# labels[labels == 6] = 0

In [None]:
#----------------------- physionet dataset -------------------------------------
# import mne
# from mne import Epochs, pick_types, events_from_annotations
# from mne.channels import make_standard_montage
# from mne.io import concatenate_raws, read_raw_edf
# from mne.datasets import eegbci
#
#
# #############################################################################
# # Set parameters and read data
#
# # avoid classification of evoked responses by using epochs that start 1s after
# # cue onset.
# tmin, tmax = -1., 4.
# event_id = dict(handsOrLeft=2, feetOrRight=3)
#
# def get_physionet_data(subject, runs):
#
#     raw_fnames = eegbci.load_data(subject, runs)
#     raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
#     eegbci.standardize(raw)  # set channel names
#     montage = make_standard_montage('standard_1005')
#     raw.set_montage(montage)
#
#     # Apply band-pass filter
#     raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge')
#
#     events, _ = events_from_annotations(raw)
#
#     picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
#                        exclude='bads')
#
#     # Read epochs (train will be done only between 1 and 2s)
#     # Testing will be done with a running classifier
#     epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
#                     baseline=None, preload=True)
#
#     epochs_data = epochs.copy().crop(tmin=1., tmax=2.)
#
#     labels = epochs.events[:, -1] - 2
#
#     return epochs_data, labels
#
#
# # [6, 10, 14] hands vs feet
# # [4, 8, 12] left vs right hand
# X_hf, y_hf = get_physionet_data(subject=1, runs=[6, 10, 14])
# X_lr, y_lr = get_physionet_data(subject=1, runs=[4, 8, 12])
#
# epochs = mne.concatenate_epochs([X_hf, X_lr])
# labels = np.concatenate([y_hf, y_lr+2])

In [None]:
# def plot_psd(data, axis, label, color):
#     psds, freqs = mne.time_frequency.psd_array_multitaper(data, sfreq=sfreq,
#                                                           fmin=0.1, fmax=100)
#     psds = 10. * np.log10(psds)
#     psds_mean = psds.mean(0).mean(0)
#     axis.plot(freqs, psds_mean, color=color, label=label)
#
#
# _, ax = plt.subplots()
# plot_psd(X, ax, 'original', 'k')
# plot_psd(X_tr.numpy(), ax, 'shifted', 'r')
#
# ax.set(title='Multitaper PSD (gradiometers)', xlabel='Frequency (Hz)',
#        ylabel='Power Spectral Density (dB)')
# ax.legend()
# plt.show()

In [None]:
def exists(val):
    """Tell whether ``val`` is anything other than ``None``.

    Falsy values such as ``0`` or ``''`` still count as existing.
    """
    return not (val is None)

class MyViTransformerWrapper(nn.Module):
    """ViT-style wrapper that tokenises a multi-channel 1-D signal into patches.

    The input is cut along its last axis into non-overlapping windows of
    ``patch_size`` samples. Each window (all channels together) is normalised,
    linearly projected to the encoder width, given a learned positional
    embedding and passed through ``attn_layers``.

    Args:
        image_size: length of the input's last axis; must be a multiple of
            ``patch_size``.
        patch_size: number of samples per patch.
        attn_layers: an ``Encoder`` instance; its ``dim`` attribute sets the
            token width.
        channels: number of input channels (axis 1 of the input).
        num_classes: when given, a linear head maps the mean-pooled tokens to
            this many outputs; when ``None`` the pooled tokens are returned
            unchanged (identity head).
        dropout: accepted but not used by this class.
        post_emb_norm: when true, apply a LayerNorm after adding the positional
            embedding.
        emb_dropout: dropout probability applied to the embedded tokens.
    """

    def __init__(
            self,
            *,
            image_size,
            patch_size,
            attn_layers,
            channels,
            num_classes = None,
            dropout = 0.,
            post_emb_norm = False,
            emb_dropout = 0.
    ):
        super().__init__()
        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        dim = attn_layers.dim
        num_patches = (image_size // patch_size)
        patch_dim = channels * patch_size

        self.patch_size = patch_size

        # One learned vector per patch position: shape (num_patches, dim).
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))

        self.patch_to_embedding = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)
        self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()

    def forward(
            self,
            img,
            return_embeddings = False
    ):
        """Embed, encode and (optionally) pool the signal.

        Args:
            img: tensor indexed as (batch, channels, samples).
            return_embeddings: when true, return the normalised per-patch
                tokens instead of the pooled/head output.

        Returns:
            ``(batch, patches, dim)`` tokens when ``return_embeddings`` is
            true; otherwise ``mlp_head`` applied to the mean over patches.
        """
        p = self.patch_size
        # Insert a singleton "height" axis so the 2-D patch pattern applies.
        img = img.unsqueeze(-2)
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 1, p2 = p)
        x = self.patch_to_embedding(x)
        n = x.shape[1]

        # pos_embedding is (num_patches, dim): take the first n *positions*.
        # Slicing the second axis would cut the embedding width and fail to
        # broadcast against (batch, n, dim) whenever n < dim.
        x = x + self.pos_embedding[:n]

        x = self.post_emb_norm(x)
        x = self.dropout(x)

        x = self.attn_layers(x)
        x = self.norm(x)

        # mlp_head is always a module (Linear or Identity), so only the flag
        # decides whether raw token embeddings are returned.
        if return_embeddings:
            return x

        x = x.mean(dim = -2)
        return self.mlp_head(x)

In [None]:
import math

import torch
from torch import nn
import torch.nn.functional as F
import torch
from x_transformers import TransformerWrapper, Encoder, ViTransformerWrapper


class DepthWiseConv2d(nn.Module):
    """Depthwise 2-D convolution: every input channel is filtered on its own.

    ``groups`` equals ``in_channels``, so each input channel produces
    ``kernels_per_layer`` output maps. ``padding='same'`` keeps the spatial
    size of the input.
    """

    def __init__(self, in_channels, kernel_size, kernels_per_layer, bias=False):
        super().__init__()
        expanded_channels = in_channels * kernels_per_layer
        self.depthwise = nn.Conv2d(
            in_channels,
            expanded_channels,
            kernel_size,
            padding='same',
            groups=in_channels,
            bias=bias,
        )

    def forward(self, x):
        """Apply the grouped convolution to ``x``."""
        filtered = self.depthwise(x)
        return filtered


class PointWiseConv2d(nn.Module):
    """1x1 convolution that mixes channels without touching spatial extent.

    The layer expects ``in_channels * kernels_per_layer`` input maps and
    produces ``out_channels`` maps.
    """

    def __init__(self, in_channels, out_channels, kernels_per_layer=1, bias=False):
        super().__init__()
        incoming_maps = in_channels * kernels_per_layer
        self.pointwise = nn.Conv2d(
            incoming_maps,
            out_channels,
            (1, 1),
            padding="valid",
            bias=bias,
        )

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        mixed = self.pointwise(x)
        return mixed


class MaxNormLayer(nn.Linear):
    """Linear layer whose weight rows are clipped to a maximum L2 norm.

    Before every forward pass each slice of the weight along dim 0 is rescaled
    so that its L2 norm does not exceed ``max_norm``. Passing
    ``max_norm=None`` disables the clipping.
    """

    def __init__(self, in_features, out_features, max_norm=1.0):
        super().__init__(in_features=in_features, out_features=out_features)
        self.max_norm = max_norm

    def forward(self, x):
        """Clip the weights in place (outside autograd), then apply the layer."""
        limit = self.max_norm
        if limit is not None:
            with torch.no_grad():
                clipped = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=limit)
                self.weight.data = clipped
        return super().forward(x)


class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The depthwise stage expands ``in_channels`` by ``kernels_per_layer``; the
    pointwise stage then maps those maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, kernels_per_layer=1, bias=False):
        super().__init__()
        self.depthwise = DepthWiseConv2d(in_channels, kernel_size, kernels_per_layer, bias=bias)
        self.pointwise = PointWiseConv2d(
            in_channels,
            out_channels,
            kernels_per_layer=kernels_per_layer,
            bias=bias,
        )

    def forward(self, x):
        """Run the depthwise stage, then the pointwise stage."""
        return self.pointwise(self.depthwise(x))


class ViewConv(nn.Module):
    """Insert a singleton axis at position 2 so a 3-axis tensor suits Conv2d.

    An input indexed as (d0, d1, d2) is viewed as (d0, d1, 1, d2).
    """

    def forward(self, x):
        lead, mid, last = x.shape[0], x.shape[1], x.shape[2]
        target_shape = (lead, mid, 1, last)
        return x.view(target_shape)



class FeatureExtraction(nn.Module):
    """Convolutional feature extractor that ends in a flat feature vector.

    Pipeline, in order: reshape to 4-D via ``ViewConv`` -> temporal Conv2d
    (``F1`` filters, kernel ``(1, kernel_length)``) -> BatchNorm -> depthwise
    conv (multiplier ``D``, kernel ``(n_channels, 1)``) -> BatchNorm -> ELU ->
    average pool -> separable conv (``F2`` filters, kernel ``(1, 16)``) ->
    BatchNorm -> ELU -> average pool -> flatten.

    Args:
        n_channels: size of axis 1 of the input; used as the first
            convolution's input-channel count.
        kernel_length: width of the first convolution's kernel.
        F1: number of filters of the first convolution.
        D: depth multiplier of the depthwise convolution.
        F2: number of output maps of the separable convolution.
        pool1_stride: window and stride of the first average pool.
        pool2_stride: window and stride of the second average pool.
    """

    def __init__(
            self,
            n_channels,
            kernel_length,
            F1,
            D,
            F2,
            pool1_stride,
            pool2_stride,
    ):
        super().__init__()
        # All three batch-norm layers share these settings; batch statistics
        # are used at all times because running stats are not tracked.
        bn_kwargs = dict(momentum=0.01, eps=0.001, track_running_stats=False)
        widened = F1 * D

        stages = [
            ViewConv(),
            nn.Conv2d(n_channels, F1, (1, kernel_length), padding='same', bias=False),
            nn.BatchNorm2d(F1, **bn_kwargs),
            DepthWiseConv2d(in_channels=F1, kernel_size=(n_channels, 1), kernels_per_layer=D, bias=False),
            nn.BatchNorm2d(widened, **bn_kwargs),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool1_stride), stride=pool1_stride),
            SeparableConv2d(in_channels=widened, out_channels=F2, kernel_size=(1, 16), bias=False),
            nn.BatchNorm2d(F2, **bn_kwargs),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool2_stride), stride=pool2_stride),
            nn.Flatten(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the whole stack and return the flattened features."""
        features = self.net(x)
        return features


In [None]:
from einops.layers.torch import Rearrange, Reduce

class PatchEmbedding(nn.Module):
    """Turn a 4-D input into a token sequence with a class token and positions.

    A strided convolution with kernel ``(1, patch_size)`` projects each patch
    to ``emb_size`` features; the result is flattened to ``(batch, tokens,
    emb_size)``, a learned class token is prepended and a learned positional
    table of ``signal_size // patch_size + 1`` rows is added.
    """

    def __init__(self, in_channels, patch_size, emb_size, signal_size):
        super().__init__()
        self.patch_size = patch_size
        n_positions = (signal_size // patch_size) + 1

        # A convolution stands in for the per-patch linear projection.
        patch_conv = nn.Conv2d(in_channels, emb_size, kernel_size=(1, patch_size), stride=patch_size)
        to_sequence = Rearrange('b e (h) (w) -> b (h w) e')
        self.projection = nn.Sequential(patch_conv, to_sequence)

        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        self.positions = nn.Parameter(torch.randn(n_positions, emb_size))

    def forward(self, x):
        """Project ``x`` to tokens, prepend the class token, add positions."""
        # Four-way unpacking also rejects inputs that are not 4-D.
        n_batch, _, _, _ = x.shape
        tokens = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=n_batch)
        # The class token goes first in the sequence.
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens += self.positions
        return tokens

In [None]:
class EEGNET(nn.Module):
    """Two-branch classifier: a convolutional branch plus a transformer branch.

    Both branches see the same input and each produces an ``n_hidden``-wide
    vector; the two vectors are concatenated and mapped to ``n_classes``
    logits by a single linear layer.

    Args:
        n_times: length of the input's last axis.
        n_classes: number of output logits.
        n_channels: size of axis 1 of the input.
        patches_size: patch length for the transformer branch; ``n_times``
            must be divisible by it (asserted by ``MyViTransformerWrapper``).
        embedding_dim: accepted but not referenced by this class.
        transformer_ffd: passed to ``Encoder`` as ``dim`` (the token width).
        transformer_layers: encoder depth.
        transformer_heads: number of attention heads.
        dropout_rate: dropout probability after each branch's projection.
        max_norm: accepted but not referenced by this class.
        F1, F2, D, pool1_stride, pool2_stride, kernel_length: forwarded to
            ``FeatureExtraction``.
        n_hidden: width of each branch's output.
    """

    def __init__(
            self,
            n_times,
            n_classes,
            n_channels,
            patches_size=9,
            embedding_dim=32,
            transformer_ffd=128,
            transformer_layers=2,
            transformer_heads=2,
            dropout_rate=0.1,
            max_norm=0.25,
            F1=8,
            F2=16,
            D=2,
            pool1_stride=4,
            pool2_stride=8,
            kernel_length=4,
            n_hidden=32,
    ):
        super().__init__()

        # Flattened size after the two average pools: each pool maps a length
        # L to (L - stride) // stride + 1, and F2 maps remain at the end.
        self.feature_extraction_output = F2 * ((((n_times - pool1_stride) // pool1_stride + 1) - pool2_stride) // pool2_stride + 1)

        self.feature_extraction = nn.Sequential(
            FeatureExtraction(n_channels=n_channels, kernel_length=kernel_length, F1=F1, D=D, F2=F2, pool1_stride=pool1_stride, pool2_stride=pool2_stride),
            nn.Linear(in_features=self.feature_extraction_output, out_features=n_hidden),
            nn.ELU(),
            nn.Dropout(dropout_rate)
        )

        self.transformer = nn.Sequential(
            # No num_classes is given, so the wrapper returns mean-pooled tokens.
            MyViTransformerWrapper(
                image_size = n_times,
                patch_size = patches_size,
                channels=n_channels,
                attn_layers = Encoder(
                    dim = transformer_ffd,
                    depth = transformer_layers,
                    heads = transformer_heads
                )
            ),
            nn.Linear(in_features=transformer_ffd, out_features=n_hidden),
            nn.ELU(),
            nn.Dropout(dropout_rate),
        )

        self.head = nn.Sequential(
            nn.Linear(in_features=n_hidden*2, out_features=n_classes),
        )

    def forward(self, x, targets=None):
        """Compute logits and, when targets are supplied, the cross-entropy loss.

        Args:
            x: input tensor indexed as (batch, channels, samples).
            targets: class indices for ``F.cross_entropy``, or ``None`` (the
                default) to skip the loss, e.g. at inference time.

        Returns:
            ``(logits, loss, out_values)`` where ``loss`` is ``None`` when no
            targets were given and ``out_values`` is an empty dict.
        """
        out_values = {}
        feature_extraction_result = self.feature_extraction(x)
        transformer_result = self.transformer(x)
        logits = self.head(torch.cat([
            feature_extraction_result,
            transformer_result
        ], dim=-1))

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits, targets)

        return logits, loss, out_values

In [None]:
#data augmentation
from braindecode.augmentation import FrequencyShift
from braindecode.augmentation import GaussianNoise

# Sampling rate of the recording, read from the epochs' metadata.
sfreq = epochs.info['sfreq']

# `probability` is the chance that a transform actually modifies its input;
# `max_delta_freq=2.` means shifts are sampled between -2 and 2 Hz.
freq_shift = FrequencyShift(probability=0.5, sfreq=sfreq, max_delta_freq=2.)

gauss_noise = GaussianNoise(probability=0.5, std=0.01)

# Name -> transform lookup.
transforms = dict(freq=freq_shift, gauss=gauss_noise)

In [None]:
# Pull the epoch array out of MNE, then build the model inputs/targets directly
# with their final dtype on the selected device.
data = epochs.get_data()
X = torch.tensor(data, dtype=torch.float32, device=device)
y = torch.tensor(labels, dtype=torch.long, device=device)

In [None]:
# Size the network from the data (last axis of X, picked channels, distinct
# labels) and place it on the selected device in one step.
model = EEGNET(
    n_times=X.shape[-1],
    n_channels=len(epochs.picks),
    n_classes=len(set(labels)),
).to(device=device)

In [None]:
# Settings for the supervised run: CV seed, fold count and learning rate.
seed, splits, lr = 1330, 5, 3e-3

# Shuffled, label-stratified folds; the seed fixes the fold assignment.
skf = StratifiedKFold(n_splits=splits, shuffle=True, random_state=seed)
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.3, lr=lr)

In [1211]:
from modules.TrainTester import TrainerTester

# Accumulator list handed to the trainer as `ud`; what it collects is defined
# by TrainerTester (project code, not shown here).
ud = []

# Main training loop.
# Fold indices are generated from the NumPy `labels` rather than from the torch
# tensors: scikit-learn converts its inputs with NumPy, which raises for
# tensors living on a CUDA device. Only the labels drive the stratification,
# so a zero placeholder of the right length stands in for X.
for train_index, test_index in skf.split(np.zeros(len(labels)), labels):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    TrainerTester.train_loop(model, optimizer, X_train, y_train, X_test, y_test, lr, ud, batch_size=16, iterations=2000)

    out_values = TrainerTester.test_loop(model, X_test, y_test)
    # Only the first fold is run. The model and optimizer are created outside
    # the loop, so they would have to be re-initialised per fold before
    # removing this `break`.
    break

KeyboardInterrupt: 

In [None]:
# Parameter audit: one line per trainable tensor, then the grand total.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(name, np.prod(param.size()))

model_parameters = (p for p in model.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
print('total: ', params)

In [None]:
def unison_shuffled_copies(a, b):
    """Shuffle two equally long indexables with one shared random permutation.

    Args:
        a: first container; must accept an index array (``a[perm]``).
        b: second container of the same length.

    Returns:
        ``(a[perm], b[perm])`` so that paired elements stay aligned.

    Raises:
        AssertionError: when the lengths differ.
    """
    n_items = len(a)
    assert n_items == len(b)
    order = np.random.permutation(n_items)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b

In [None]:
# ---- Meta-learning hyper-parameters ----
# Starting outer-loop step; the training cell scales it by (1 - progress).
meta_step_size = 0.25

# Number of outer (meta) iterations.
meta_iters = 1000

eval_interval = 1   # evaluate every `eval_interval` meta-iterations
train_shots = 40    # samples drawn per class for a training task
eval_shots = 4      # samples drawn per class for an evaluation task
classes = len(set(labels))

# Number of chunks each sampled task is cut into (see Dataset.get_mini_dataset).
batch_size = 1
# obs: total shots = classes * shots

n_times = X.shape[-1]
n_channels = len(epochs.picks)

seed = 1330
splits = 5
lr = 1e-3

skf = StratifiedKFold(n_splits=splits, random_state=seed, shuffle=True)

# Take the first stratified fold as the train/test partition. The split is
# computed from the NumPy `labels` because scikit-learn converts its inputs
# with NumPy, which fails for CUDA tensors; only the labels matter for
# stratification, so zeros stand in for X.
train_index, test_index = next(skf.split(np.zeros(len(labels)), labels))
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]

In [None]:
class Dataset:
    """Per-class sample store that draws small few-shot tasks.

    Reads the module-level ``X_train``/``y_train`` or ``X_test``/``y_test``
    tensors and groups the samples by their integer class label.
    """

    def __init__(self, training):
        """Index one split by class.

        Args:
            training: ``True`` to wrap the training split, ``False`` to wrap
                the held-out split.
        """
        # Branch on the boolean itself: testing a "train"/"test" string would
        # always be truthy and select the same split for both datasets.
        if training:
            X_dataset, y_dataset = X_train, y_train
        else:
            X_dataset, y_dataset = X_test, y_test

        # Sampled tasks are allocated on the same device as the source data.
        self.device = X_dataset.device

        self.data = {}
        for value, label in zip(X_dataset, y_dataset):
            # 0-d tensors hash by identity, so every sample would get its own
            # key; a plain int lets samples of one class share a bucket.
            self.data.setdefault(int(label), []).append(value)
        self.labels = list(self.data.keys())

    def get_mini_dataset(self, shots, num_classes, split=False):
        """Sample one task of ``num_classes`` classes with ``shots`` items each.

        Classes are relabelled ``0..num_classes-1`` in draw order for the
        duration of the task. Uses the module-level ``n_channels``,
        ``n_times`` and ``batch_size``.

        Args:
            shots: number of training samples drawn per class.
            num_classes: number of classes drawn for the task.
            split: when true, one extra sample per class is drawn and returned
                separately as a held-out set.

        Returns:
            An iterator of ``(X_chunk, label_chunk)`` pairs; when ``split`` is
            true, a tuple ``(iterator, test_X, test_labels)`` where the
            held-out tensors have exactly ``num_classes`` rows.
        """
        temp_labels = torch.zeros((num_classes * shots), device=self.device)
        temp_X = torch.zeros((num_classes * shots, n_channels, n_times), device=self.device)
        if split:
            # Exactly one held-out sample per class is filled below, so the
            # buffers hold num_classes rows; any extra rows would stay as
            # all-zero inputs labelled 0 and distort the measured accuracy.
            test_labels = torch.zeros((num_classes), device=self.device)
            test_X = torch.zeros((num_classes, n_channels, n_times), device=self.device)

        # Get a random subset of labels from the entire label set.
        # NOTE(review): random.choices samples with replacement, so the same
        # class can be drawn twice under two temporary labels — confirm that
        # this is intended for a fixed, small label set.
        label_subset = random.choices(self.labels, k=num_classes)
        for class_idx, class_obj in enumerate(label_subset):
            rows = slice(class_idx * shots, (class_idx + 1) * shots)
            # Use the enumerated index as the temporary label for this task.
            temp_labels[rows] = class_idx
            if split:
                # Draw one extra sample; the last one becomes the held-out
                # example for this class.
                test_labels[class_idx] = class_idx
                X_to_split = torch.stack(random.choices(self.data[class_obj], k=shots + 1))
                test_X[class_idx] = X_to_split[-1]
                temp_X[rows] = X_to_split[:-1]
            else:
                temp_X[rows] = torch.stack(random.choices(self.data[class_obj], k=shots))

        temp_X, temp_labels = unison_shuffled_copies(temp_X, temp_labels)
        # `batch_size` is used as the number of chunks; torch.stack requires
        # the chunks to be equally sized.
        temp_X, temp_labels = torch.stack(temp_X.chunk(batch_size)), torch.stack(temp_labels.chunk(batch_size))
        dataset = zip(temp_X, temp_labels)

        if split:
            test_X, test_labels = unison_shuffled_copies(test_X, test_labels)
            return dataset, test_X, test_labels
        return dataset

# Two task samplers, requested for the training and the held-out split
# respectively; the meta-training cell draws tasks from both.
train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)

In [None]:
def _snapshot_state(module):
    """Return the module's state dict with every tensor replaced by a detached copy.

    ``state_dict()`` hands back references to the live tensors, which the
    optimizer then updates in place; without copying, a "saved" state silently
    follows the model.
    """
    state = module.state_dict()
    for key in list(state.keys()):
        state[key] = state[key].detach().clone()
    return state


# Accuracy after adaptation on tasks from each dataset, one entry per evaluation.
training = []
testing = []
for meta_iter in range(meta_iters):
    # Fresh inner-loop optimizer for every meta-iteration.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    frac_done = meta_iter / meta_iters

    # Outer step size decays linearly from meta_step_size towards 0.
    cur_meta_step_size = (1 - frac_done) * meta_step_size

    old_vars = _snapshot_state(model)

    mini_dataset = train_dataset.get_mini_dataset(
        train_shots, classes
    )

    # Inner loop: ordinary gradient steps on the sampled task.
    model.train()
    for X_values, y_labels in mini_dataset:
        X_values = X_values.to(device)
        y_labels = y_labels.to(device=device, dtype=torch.long)
        preds, loss, out_values = model(X_values, y_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Outer step: move the pre-task weights a fraction of the way towards the
    # adapted weights. Non-floating entries (e.g. integer counters) are taken
    # as-is because interpolating them is meaningless.
    new_vars = model.state_dict()
    for key, var in list(new_vars.items()):
        if torch.is_floating_point(var):
            new_vars[key] = old_vars[key] + ((var - old_vars[key]) * cur_meta_step_size)

    model.load_state_dict(new_vars)

    # Evaluation loop
    if meta_iter % eval_interval == 0:
        accuracies = []
        for dataset in (train_dataset, test_dataset):
            # Sample a mini dataset from the full dataset.
            train_set, test_X, test_labels = dataset.get_mini_dataset(
                eval_shots, classes, split=True
            )
            # Keep a real copy so the fine-tuning below can be undone.
            old_vars = _snapshot_state(model)

            model.train()
            for X_values, y_labels in train_set:
                X_values = X_values.to(device)
                y_labels = y_labels.to(device=device, dtype=torch.long)

                preds, test_loss, out_values = model(X_values, y_labels)

                optimizer.zero_grad()
                test_loss.backward()
                optimizer.step()

            test_X = test_X.to(device)
            test_labels = test_labels.to(device=device, dtype=torch.long)
            # Score without dropout and without building a graph.
            model.eval()
            with torch.no_grad():
                test_preds, test_loss, test_out_values = model(test_X, test_labels)
            model.train()

            accuracy = (test_preds.argmax(1) == test_labels).type(torch.float32).sum().item() / test_labels.shape[0]
            accuracies.append(accuracy)

            # Undo the evaluation-time fine-tuning.
            model.load_state_dict(old_vars)

        training.append(accuracies[0])
        testing.append(accuracies[1])

        if meta_iter % 100 == 0:
            print(f"batch {meta_iter}: train={accuracies[0]} test={accuracies[1]}")

In [None]:
# Smooth both accuracy traces for display: mirror `window_length - 1` points at
# each end, then convolve with a unit-sum Hamming window.
window_length = 100
train_s = np.concatenate([
    training[window_length - 1 : 0 : -1],
    training,
    training[-1:-window_length:-1],
])
test_s = np.concatenate([
    testing[window_length - 1 : 0 : -1],
    testing,
    testing[-1:-window_length:-1],
])
w = np.hamming(window_length)
train_y = np.convolve(w / w.sum(), train_s, mode="valid")
test_y = np.convolve(w / w.sum(), test_s, mode="valid")

# Draw the test curve first, then the training curve, so the legend order matches.
x = np.arange(0, len(test_y), 1)
plt.plot(x, test_y)
plt.plot(x, train_y)
plt.legend(["test", "train"])
plt.grid()

In [None]:
# model.load_state_dict(torch.load('model_states/test_model_states.txt'))
# torch.save(model.state_dict(), 'model_states/test_model_states.txt')