In [None]:
import sys
import random

import mne
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import os
import math
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import cohen_kappa_score
import importlib
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort that
# some NumPy/PyTorch installs trigger — confirm it is still required on this machine.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
%matplotlib inline

In [None]:
# Run on the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
from modules.BciDataHandler import BciDataHandler

# Create the project's BCI data handler and let it build its dataset. Later cells read
# `subjects_epochs` and `subjects_labels` from this object (and `selected_events` in a
# commented-out model line).
data_handler = BciDataHandler()
data_handler.instantiate_dataset()

In [None]:
# ------------------------------ bci competition dataset ------------------------------
# Pool the epochs of every subject held by the handler; the pooled labels are the event
# codes (last column of `events`) shifted down by 2.
all_subject_epochs = mne.concatenate_epochs(list(data_handler.subjects_epochs.values()))
all_labels = all_subject_epochs.events[:, -1] - 2

# Single-subject data (key 1); the handler's stored labels are shifted down by 1.
# NOTE(review): the pooled labels use an offset of 2 while these use 1 — confirm both
# end up on the same class numbering.
# NOTE(review): the UFJF cell below reassigns `epochs` and `labels`; run only the
# dataset cell you intend to train on.
epochs = data_handler.subjects_epochs[1]
labels = np.array(data_handler.subjects_labels[1]) - 1

In [None]:
# -------------------------------- ufjf dataset --------------------------------------
from modules.EdfHandler import EdfHandler

# Raw string so every backslash of the Windows path is literal. The previous literal
# mixed escaped and unescaped backslashes and contained the invalid escape "\D".
# NOTE(review): this is an absolute, machine-specific path; consider a configurable
# data directory.
epochs, labels = EdfHandler.getAllData([r"C:\Users\davi2\Desktop\bci\datasets_ufjf\bci\001.edf"])

# getAllData was given a single file, so keep only the first entry of each result.
epochs = epochs[0]
labels = np.array(labels[0])

# Remap label value 6 to class 0.
labels[labels == 6] = 0

In [None]:
# Display the current label array as a quick sanity check.
labels

In [None]:
import mne
from mne import Epochs, pick_types, events_from_annotations
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci


# #############################################################################
# # Set parameters and read data

# Epoch window in seconds relative to each event: from 1 s before to 4 s after.
# To avoid classifying evoked responses, get_physionet_data() crops the epochs it
# returns to the 1-2 s interval after the event.
tmin, tmax = -1., 4.
# Event codes kept when epoching. The names cover both run types used below:
# code 2 is "hands" or "left", code 3 is "feet" or "right".
event_id = dict(handsOrLeft=2, feetOrRight=3)

def get_physionet_data(subject, runs):
    """Load, band-pass filter and epoch EEGBCI recordings for one subject.

    Reads the module-level ``tmin``, ``tmax`` and ``event_id`` settings.

    Args:
        subject: Subject identifier forwarded to ``eegbci.load_data``.
        runs: Run numbers forwarded to ``eegbci.load_data``.

    Returns:
        A tuple ``(epochs_data, labels)``: a copy of the epochs cropped to the
        1-2 s interval, and the event codes of the epochs shifted down by 2.
    """
    edf_paths = eegbci.load_data(subject, runs)
    recordings = [read_raw_edf(path, preload=True) for path in edf_paths]
    raw = concatenate_raws(recordings)

    # Standardise the channel names, then attach the standard 10-05 montage.
    eegbci.standardize(raw)
    raw.set_montage(make_standard_montage('standard_1005'))

    # Band-pass filter between 7 and 30 Hz.
    raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge')

    # Only the events array is needed; the annotation-to-id mapping is discarded.
    events = events_from_annotations(raw)[0]

    eeg_picks = pick_types(
        raw.info,
        meg=False,
        eeg=True,
        stim=False,
        eog=False,
        exclude='bads',
    )

    # Epoch over the full [tmin, tmax] window; the returned copy is restricted to 1-2 s.
    full_epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        proj=True,
        picks=eeg_picks,
        baseline=None,
        preload=True,
    )

    return full_epochs.copy().crop(tmin=1., tmax=2.), full_epochs.events[:, -1] - 2

In [None]:
#----------------------- physionet dataset -------------------------------------
# [6, 10, 14] hands vs feet
#[4, 8, 12] left vs right hand
# X_hf, y_hf = get_physionet_data(subject=1, runs=[6, 10, 14])
# X_lr, y_lr = get_physionet_data(subject=1, runs=[4, 8, 12])
#
# epochs = mne.concatenate_epochs([X_hf, X_lr])
# labels = np.concatenate([y_hf, y_lr+2])

In [None]:
# Build the model inputs directly with their final dtype and device:
# float32 signals and int64 class targets.
X = torch.tensor(epochs.get_data(), dtype=torch.float32, device=device)
y = torch.tensor(labels, dtype=torch.long, device=device)

In [None]:
from torch import nn
from Models.SubModules import ViewConv, DepthWiseConv2d, SeparableConv2d, Unsqueeze, PositionalEncoding, ToTransformer, \
    MaxNormLayer
import torch.nn.functional as F


class EEGNET(nn.Module):
    """Convolutional front-end followed by a Transformer encoder and a softmax classifier.

    The whole model is one ``nn.Sequential`` (``self.net``): temporal convolution,
    depthwise convolution, separable convolution (each with batch norm, and ELU /
    average pooling / dropout after the last two), a linear projection to
    ``signal_size`` features, positional encoding, a 2-layer Transformer encoder,
    pooling, batch norm and a max-norm-constrained output layer ending in
    ``Softmax``. ``ViewConv``, ``DepthWiseConv2d``, ``SeparableConv2d``,
    ``Unsqueeze``, ``PositionalEncoding``, ``ToTransformer`` and ``MaxNormLayer``
    are project modules whose exact behaviour is not visible here.

    Args:
        n_channels: Used as ``in_channels`` of the first convolution and as the
            height of the depthwise kernel.
        n_times: Number of time samples; used to size the first linear layer.
        n_classes: Number of output classes.
        kernel_length: Width of the first (temporal) convolution kernel.
        F1: Number of filters of the first convolution.
        D: Kernels per input filter in the depthwise convolution.
        F2: Number of filters of the separable convolution.
        signal_size: Feature size fed to the Transformer encoder (``d_model``).
        pool1_stride: Kernel width and stride of the first average pooling.
        pool2_stride: Kernel width and stride of the second average pooling.
        dropout_rate: Dropout probability used by the ``nn.Dropout`` layers.
        norm_rate: ``max_norm`` passed to the output ``MaxNormLayer``.
        transformer_ffd: ``dim_feedforward`` of each Transformer encoder layer.
    """

    def __init__(
            self,
            n_channels,
            n_times,
            n_classes,
            kernel_length=64,
            F1=8,
            D=2,
            F2=16,
            signal_size=32,
            pool1_stride=4,
            pool2_stride=8,
            dropout_rate=0.5,
            norm_rate=0.25,
            transformer_ffd=516,
    ):
        super().__init__()
        print('model instantianting...')
        self.net = nn.Sequential(
            ViewConv(),
            nn.Conv2d(in_channels=n_channels, out_channels=F1, kernel_size=(1, kernel_length), bias=False, padding='same'),
            nn.BatchNorm2d(num_features=F1, momentum=0.01, eps=0.001, track_running_stats=False),
            DepthWiseConv2d(in_channels=F1, kernel_size=(n_channels, 1), kernels_per_layer=D, bias=False),
            nn.BatchNorm2d(num_features=F1*D, momentum=0.01, eps=0.001, track_running_stats=False),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool1_stride), stride=pool1_stride),
            nn.Dropout(dropout_rate),
            SeparableConv2d(in_channels=F1*D, kernel_size=(1, 16), out_channels=F2, bias=False),
            nn.BatchNorm2d(num_features=F2, momentum=0.01, eps=0.001, track_running_stats=False),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool2_stride), stride=pool2_stride),
            nn.Dropout(dropout_rate),
            nn.Flatten(),
            # in_features = F2 * (time length left after the two average-pooling stages).
            nn.Linear(in_features=F2 * ((((n_times - pool1_stride) // pool1_stride + 1) - pool2_stride) // pool2_stride + 1), out_features=signal_size, bias=False),
            nn.ELU(),
            Unsqueeze(),
            PositionalEncoding(d_model=signal_size, dropout=0.1),
            ToTransformer(),
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=signal_size, dim_feedforward=transformer_ffd, nhead=4, batch_first=True),
                num_layers=2,
            ),
            # ViewConv(),
            # nn.Conv2d(in_channels=signal_size, out_channels=1, kernel_size=(1, signal_size), bias=False, padding='same'),
            # nn.Flatten(),

            nn.AvgPool1d(kernel_size=4, stride=signal_size),
            nn.Dropout(dropout_rate),
            nn.Flatten(),
            nn.BatchNorm1d(num_features=signal_size, momentum=0.01, eps=0.001, track_running_stats=False),
            nn.ELU(),
            MaxNormLayer(in_features=signal_size, out_features=n_classes, max_norm=norm_rate),
            nn.Softmax(dim=1),
        )

    def forward(self, x, targets=None):
        """Run the network and optionally compute the classification loss.

        Args:
            x: Input batch, passed to the first layer of ``self.net``.
            targets: Class-index targets for ``F.cross_entropy``, or ``None``
                (default) to skip the loss.

        Returns:
            A tuple ``(out, loss, out_values)``:
            ``out`` is the output of the last layer (softmax probabilities),
            ``loss`` is the cross-entropy loss or ``None`` when no targets are given,
            ``out_values`` maps layer class names to a clone of that layer's output.
            Keys are class names, so layers sharing a class (the several ``ELU``,
            ``Dropout``, ``BatchNorm2d``, ... layers) overwrite each other and only
            the last output per class is kept.
        """
        out_values = {}
        out = x
        logits = None

        for layer in self.net.children():
            if isinstance(layer, nn.Softmax):
                # Keep the pre-softmax scores: cross_entropy expects unnormalized
                # logits and applies log-softmax itself.
                logits = out
            out = layer(out)
            out_values[layer.__class__.__name__] = out.clone()

        if logits is None:
            # No Softmax layer in the stack: the final output already is the logits.
            logits = out

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits, targets)

        return out, loss, out_values

In [None]:
# from Models.ImprovedTransformer import EEGNET

#MOABB
# model = EEGNET(n_channels=len(epochs.picks), n_times=1251, n_classes=len(data_handler.selected_events))

#UFJF
# Build the network for the loaded epochs and place it on the selected device in one go.
model = EEGNET(
    n_channels=len(epochs.picks),
    n_times=X.shape[-1],
    n_classes=4,
).to(device=device)

In [None]:
# Cross-validation and optimisation settings.
seed = 1337   # random_state of the fold shuffling
splits = 5    # number of stratified folds
lr = 1e-3     # AdamW learning rate (also handed to the training loop)

# Use `splits` instead of a second hard-coded 5 so the setting above is the single
# source of truth for the number of folds.
skf = StratifiedKFold(n_splits=splits, random_state=seed, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [None]:
from modules.TrainTester import TrainerTester

# Handed to train_loop; the plotting cell further down reads it as ud[step][param_index].
ud = []

# The last 100 samples are kept out of cross-validation and evaluated separately.
# scikit-learn converts the labels to NumPy, which fails for CUDA tensors, so the
# split is driven by a CPU copy of the labels and a placeholder feature array
# (only the labels and the sample count matter for stratified splitting).
y_cv = y[:-100].detach().cpu().numpy()
X_cv_placeholder = np.zeros((len(y_cv), 1))

# main training loop
for train_index, test_index in skf.split(X_cv_placeholder, y_cv):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    TrainerTester.train_loop(model, optimizer, X_train, y_train, lr, ud)
    out_values = TrainerTester.test_loop(model, X_test, y_test)
    # Only the first fold is trained and evaluated.
    break

In [None]:
# Evaluate on the last 100 samples, which the training cell left out of the split.
X_holdout, y_holdout = X[-100:], y[-100:]

out_values = TrainerTester.test_loop(model, X_holdout, y_holdout)
accuracy = TrainerTester.test_and_show(model, X_holdout, y_holdout).tolist()

plt.plot(accuracy)

In [None]:
accuracy = []

# scikit-learn converts the labels to NumPy, which fails for CUDA tensors, so split on
# a CPU copy of the labels with a placeholder feature array.
# NOTE(review): this split covers all of X, so the evaluated fold can contain samples
# the model was trained on in the training cell — confirm that is intended.
y_all = y.detach().cpu().numpy()
X_all_placeholder = np.zeros((len(y_all), 1))

for train_index, test_index in skf.split(X_all_placeholder, y_all):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    out_values = TrainerTester.test_loop(model, X_test, y_test)
    # The last 10 samples of the fold are left out of test_and_show.
    accuracy.append(TrainerTester.test_and_show(model, X_test[:-10], y_test[:-10]).tolist())
    # Only the first fold is evaluated.
    break

plt.imshow(accuracy, cmap="Blues")

In [None]:
# Parameter count: one line per trainable tensor, then the grand total.
trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
for name, param in trainable:
    print(name, np.prod(param.size()))

model_parameters = [p for p in model.parameters() if p.requires_grad]
params = sum(np.prod(p.size()) for p in model_parameters)
print('total: ', params)

In [None]:
plt.figure(figsize=(40, 10))

# One legend entry and one curve per model parameter, in named_parameters() order;
# curve `idx` is column `idx` of every row recorded in `ud`.
legends = [name for name, _ in model.named_parameters()]
for idx in range(len(legends)):
    plt.plot([row[idx] for row in ud])

plt.plot([0, len(ud)], [-3, -3], 'k') # these ratios should be ~1e-3, indicate on plot
plt.legend(legends);

In [None]:
# visualize histograms of the layer outputs captured by the last test_loop call
plt.figure(figsize=(40, 30)) # width and height of the plot
legends = []
for name, values in out_values.items():
    if name == "Softmax":
        # exclude the final output value
        continue
    # torch.histogram needs a CPU tensor; detach so no autograd graph is carried along.
    t = values.detach().cpu()
    print(f'{name}: mean {t.mean()}, std {t.std()}')
    hy, hx = torch.histogram(t, density=True)
    # hx holds the bin edges (one more than the bin count), so drop the last edge.
    plt.plot(hx[:-1], hy)
    legends.append(f'layer ({name})')
plt.legend(legends)
# These are layer outputs, not gradients.
plt.title('activation distribution');


In [None]:
import re
# visualize histograms of the weight gradients
plt.figure(figsize=(40, 20)) # width and height of the plot
legends = []
for name, weight in model.named_parameters():
    if re.search('bias', name):
        continue
    if weight.grad is None:
        # No backward pass has populated this parameter (e.g. right after loading
        # a checkpoint), so there is nothing to plot for it.
        continue
    # torch.histogram needs a CPU tensor.
    t = weight.grad.detach().cpu()
    w = weight.detach().cpu()
    print(f'layer {name}: weight {tuple(weight.shape)} | mean {t.mean()} | std {t.std()} | grad:data ratio { t.std() / w.std()}')
    hy, hx = torch.histogram(t, density=True)
    # hx holds the bin edges (one more than the bin count), so drop the last edge.
    plt.plot(hx[:-1], hy)
    legends.append(f'{name} {tuple(weight.shape)}')
plt.legend(legends)
plt.title('weights gradient distribution');

In [None]:
# Restore saved weights. map_location remaps the stored tensors onto the current
# device, so a checkpoint written on a GPU machine also loads on a CPU-only one.
model.load_state_dict(torch.load('model_states/model_states.txt', map_location=device))

In [None]:
# torch.save(model.state_dict(), 'model_states/model_states.txt')