In [None]:
# `!export ...` only sets the variable inside a throw-away subshell, so the
# kernel process never sees it. Set it in-process instead, before any CUDA work:
# torch.use_deterministic_algorithms(True) is enabled further down and cuBLAS
# needs this workspace setting to run deterministically.
import os

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [None]:
# Turn off jedi-based tab completion in the IPython completer.
%config Completer.use_jedi = False
# autoreload mode 2: re-import edited modules before each cell is executed.
%load_ext autoreload
# %reload_ext autoreload
%autoreload 2
# lab_black extension (presumably auto-formats cells with black - confirm it is installed).
%load_ext lab_black

In [None]:
import numpy as np
from sklearn.metrics import (
    roc_auc_score,
    precision_score,
    recall_score,
    accuracy_score,
    cohen_kappa_score,
)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
import random

In [None]:
# Seed every random number generator the notebook uses with the same value.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Ask PyTorch / cuDNN for reproducible kernels: no auto-tuned algorithm
# selection, deterministic cuDNN ops, and an error on non-deterministic ops.
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
class EEGNet(nn.Module):
    """Three-stage convolutional EEG classifier with a 4-way softmax output.

    ``forward`` expects a float tensor shaped ``(batch, 1, time, 22)``:
    ``conv1`` spans the whole 22-wide last axis, and the result is permuted so
    that its 16 feature maps become the height of a single-channel image.

    The classifier head takes ``4 * 4 * 99`` flattened features, which the
    padding/conv/pool arithmetic below yields for 765-772 time samples
    (e.g. 768). Any other length raises a shape error in ``fc1`` instead of
    being silently folded into the batch dimension.

    NOTE(review): ``forward`` returns probabilities (softmax). The notebook
    feeds them to ``nn.CrossEntropyLoss``, which applies log-softmax again.
    """

    def __init__(self):
        super().__init__()
        # Not read anywhere in this class; kept so existing attribute access
        # keeps working. The value does not match the length fc1 is sized for.
        self.T = 120

        # Layer 1
        self.conv1 = nn.Conv2d(1, 16, (1, 22), padding=0)
        # The second positional argument of BatchNorm2d is `eps`, so the former
        # `BatchNorm2d(16, False)` ran with eps=0 (no protection against a zero
        # variance). Use the default eps instead.
        self.batchnorm1 = nn.BatchNorm2d(16)

        # Layer 2
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 11))
        self.batchnorm2 = nn.BatchNorm2d(4)
        self.pooling2 = nn.MaxPool2d(2, 2)

        # Layer 3
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 2))
        self.batchnorm3 = nn.BatchNorm2d(4)
        self.pooling3 = nn.MaxPool2d((2, 4))

        # FC Layer
        # NOTE: the input width depends on the number of time samples per
        # trial; 4 * 4 * 99 corresponds to 765-772 samples.
        self.fc1 = nn.Linear(4 * 4 * 99, 4)

    def forward(self, x):
        """Return class probabilities of shape ``(batch, 4)``."""
        # Layer 1
        x = F.elu(self.conv1(x))
        x = self.batchnorm1(x)
        x = F.dropout(x, 0.25, training=self.training)
        # (batch, 16, time, 1) -> (batch, 1, 16, time)
        x = x.permute(0, 3, 1, 2)

        # Layer 2
        x = self.padding1(x)
        x = F.elu(self.conv2(x))
        x = self.batchnorm2(x)
        x = F.dropout(x, 0.5, training=self.training)
        x = self.pooling2(x)

        # Layer 3
        x = self.padding2(x)
        x = F.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = F.dropout(x, 0.5, training=self.training)
        x = self.pooling3(x)

        # FC Layer: flatten per sample so a wrong input length fails loudly in
        # fc1 rather than changing the batch size.
        x = x.flatten(start_dim=1)
        return F.softmax(self.fc1(x), dim=1)


# Baseline EEGNet on the first GPU, with its loss and optimiser.
# NOTE(review): EEGNet.forward already ends in a softmax, and nn.CrossEntropyLoss
# applies log-softmax to its input as well - confirm the double normalisation is
# intended.
net = EEGNet()
net = net.cuda(0)
# print(net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0))))
criterion = nn.CrossEntropyLoss()  # nn.BCELoss()
optimizer = optim.Adam(params=net.parameters(), weight_decay=1)
# lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer, max_lr=0.001, steps_per_epoch=1, epochs=200
# )

In [None]:
class DepthwiseConv2d(torch.nn.Conv2d):
    """``Conv2d`` in which every input channel is filtered on its own.

    ``groups`` is fixed to ``in_channels`` and each input channel produces
    ``depth_multiplier`` output channels, so the layer has
    ``in_channels * depth_multiplier`` outputs. All other arguments are handed
    to ``torch.nn.Conv2d`` unchanged.
    """

    def __init__(
        self,
        in_channels,
        depth_multiplier=1,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        passthrough = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            padding_mode=padding_mode,
        )
        super().__init__(
            in_channels,
            in_channels * depth_multiplier,
            kernel_size,
            groups=in_channels,  # one filter group per input channel
            **passthrough,
        )


class SeparableConv2D(nn.Module):
    """Depthwise-separable convolution: per-channel filter, then 1x1 channel mixer.

    Adapted from
    https://github.com/seungjunlee96/Depthwise-Separable-Convolution_Pytorch/blob/master/DepthwiseSeparableConvolution/DepthwiseSeparableConvolution.py

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the pointwise (1x1) convolution.
        depth_multiplier: outputs per input channel in the depthwise stage.
        kernel_size: kernel of the depthwise stage.
        padding: padding of the depthwise stage.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        depth_multiplier=1,
        kernel_size=3,
        padding="valid",
        bias=False,
    ):
        super().__init__()
        self.depthwise = DepthwiseConv2d(
            in_channels, depth_multiplier, kernel_size, padding=padding, bias=bias
        )
        # The depthwise stage emits in_channels * depth_multiplier channels, so
        # the pointwise stage must accept that many (the previous `in_channels`
        # only worked for depth_multiplier == 1).
        self.pointwise = nn.Conv2d(
            in_channels * depth_multiplier, out_channels, kernel_size=1, bias=bias
        )

    def forward(self, x):
        """Apply the depthwise convolution followed by the pointwise one."""
        return self.pointwise(self.depthwise(x))


class EEGNetv2(nn.Module):
    """Compact EEG classifier: temporal conv, depthwise spatial conv, separable conv.

    ``forward`` expects ``(batch, 1, time, channels)``. The time axis is
    average-pooled by 4 and then by 8, and the electrode axis is collapsed by
    the depthwise convolution, leaving ``F2 * pooled_length`` features.

    Args:
        num_classes: number of output classes of ``fc1``.
        channels: size of the last input axis (width of the depthwise kernel).
        dropout_rate: dropout probability used after both pooling stages.
        kernel_length: length of the temporal kernel of ``conv1``.
        F1: number of temporal filters.
        D: depth multiplier of the depthwise convolution.
        F2: number of feature maps after the separable convolution.
        pooled_length: time length left after both pooling stages, i.e.
            ``(time // 4) // 8``; the default 24 corresponds to 768-799 samples.

    NOTE(review): ``forward`` returns probabilities (softmax). The notebook
    feeds them to ``nn.CrossEntropyLoss``, which applies log-softmax again.
    """

    def __init__(
        self,
        num_classes=4,
        channels=22,
        dropout_rate=0.5,
        kernel_length=32,
        F1=8,
        D=2,
        F2=16,
        pooled_length=24,
    ):
        super().__init__()
        self.dropout_rate = dropout_rate
        # Width of the vector returned by _forward_emb.
        self.embedding_dim = F2 * pooled_length

        # Layer 1
        self.conv1 = nn.Conv2d(1, F1, (kernel_length, 1), padding="same", bias=False)
        # The second positional argument of BatchNorm2d is `eps`; the former
        # `BatchNorm2d(n, False)` therefore ran with eps=0. Use the default.
        self.batchnorm1 = nn.BatchNorm2d(F1)
        self.dwconv2 = DepthwiseConv2d(
            in_channels=F1,
            depth_multiplier=D,
            kernel_size=(1, channels),  ## (1, num-channels)
            stride=1,
            padding="valid",
            bias=False,
        )
        # The depthwise stage emits D * F1 maps (was hard-coded as 2 * F1).
        self.batchnorm2 = nn.BatchNorm2d(D * F1)
        self.pooling1 = nn.AvgPool2d((4, 1))

        # Layer 2
        self.sepconv2 = SeparableConv2D(D * F1, F2, 1, (16, 1), padding="same")
        self.batchnorm3 = nn.BatchNorm2d(F2)
        self.pooling2 = nn.AvgPool2d((8, 1))

        # FC Layer (was hard-coded to 16 * 24 inputs and 4 outputs, ignoring
        # both F2 and num_classes).
        self.fc1 = nn.Linear(self.embedding_dim, num_classes)

    def _forward_emb(self, x):
        """Return the flattened feature vector, shape ``(batch, F2 * pooled_length)``."""
        # Layer 1
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.dwconv2(x)
        x = self.batchnorm2(x)
        x = F.elu(x)
        x = self.pooling1(x)
        x = F.dropout(x, self.dropout_rate, training=self.training)

        # Layer 2
        x = self.sepconv2(x)
        x = self.batchnorm3(x)
        x = F.elu(x)
        x = self.pooling2(x)
        x = F.dropout(x, self.dropout_rate, training=self.training)

        # Flatten per sample: an unexpected input length then fails in the
        # linear layer instead of being reshaped into extra batch rows.
        return x.flatten(start_dim=1)

    def forward(self, x):
        """Return class probabilities of shape ``(batch, num_classes)``."""
        return F.softmax(self.fc1(self._forward_emb(x)), dim=1)


# Single-branch EEGNetv2 on the first GPU with its loss, optimiser and LR schedule.
# The one-cycle schedule is sized for 200 steps, i.e. one scheduler step per
# epoch of the 200-epoch training loop.
# NOTE(review): the model output is already a softmax, and nn.CrossEntropyLoss
# applies log-softmax again - confirm this is intended.
net = EEGNetv2(
    F1=8,
    D=2,
    F2=16,
    kernel_length=32,
    dropout_rate=0.1,
    channels=22,
    num_classes=4,
)
net = net.cuda(0)
# print(net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0))))
criterion = nn.CrossEntropyLoss()  # nn.BCELoss()
optimizer = optim.Adam(params=net.parameters(), weight_decay=0.2)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer, max_lr=0.001, epochs=200, steps_per_epoch=1
)
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.001)

In [None]:
class MBEEGNetv2(nn.Module):
    """Three EEGNetv2 branches with different temporal kernels, fused by one linear layer.

    Every branch receives the same input; their flattened embeddings are
    concatenated and classified by ``fc1``. The branches' own ``fc1`` heads are
    not used in ``forward``; only their ``in_features`` is read to size the
    fusion layer.

    Args:
        num_classes: number of output classes.
        channels: size of the last input axis, forwarded to every branch.

    NOTE(review): ``forward`` returns probabilities (softmax). The notebook
    feeds them to ``nn.CrossEntropyLoss``, which applies log-softmax again.
    """

    def __init__(self, num_classes=4, channels=22):
        super().__init__()
        self.mbeegnet1 = EEGNetv2(
            num_classes=num_classes,
            channels=channels,
            dropout_rate=0,
            kernel_length=16,
            F1=4,
            D=2,
            F2=16,
        )
        self.mbeegnet2 = EEGNetv2(
            num_classes=num_classes,
            channels=channels,
            dropout_rate=0.1,
            kernel_length=32,
            F1=8,
            D=2,
            F2=16,
        )
        self.mbeegnet3 = EEGNetv2(
            num_classes=num_classes,
            channels=channels,
            dropout_rate=0.2,
            kernel_length=64,
            F1=16,
            D=2,
            F2=16,
        )

        # FC Layer: sized from the branches instead of a hard-coded 16 * 24 * 3,
        # so it stays correct if a branch's embedding width changes.
        fused_dim = sum(branch.fc1.in_features for branch in self._branches())
        self.fc1 = nn.Linear(fused_dim, num_classes)

    def _branches(self):
        """Return the three branch networks in concatenation order."""
        return (self.mbeegnet1, self.mbeegnet2, self.mbeegnet3)

    def forward(self, x):
        """Return class probabilities of shape ``(batch, num_classes)``."""
        fused = torch.cat(
            [branch._forward_emb(x) for branch in self._branches()],
            dim=1,
        )
        return F.softmax(self.fc1(fused), dim=1)


# Multi-branch model on the first GPU with its loss, optimiser and LR schedule.
# The one-cycle schedule is sized for 200 steps, i.e. one scheduler step per
# epoch of the 200-epoch training loop.
# NOTE(review): the model output is already a softmax, and nn.CrossEntropyLoss
# applies log-softmax again - confirm this is intended.
net = MBEEGNetv2()
net = net.cuda(0)
# print(net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0))))
criterion = nn.CrossEntropyLoss()  # nn.BCELoss()
optimizer = optim.Adam(params=net.parameters(), weight_decay=0.5)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer, max_lr=0.001, epochs=200, steps_per_epoch=1
)

In [None]:
def evaluate(model, dl, params=("acc",), loss_fn=None):
    """Score ``model`` on every batch of ``dl``.

    The model is switched to eval mode and run without gradient tracking, so
    dropout is off and BatchNorm statistics are not updated; its previous
    train/eval mode is restored afterwards.

    Args:
        model: network returning one row of class scores per sample.
        dl: iterable of ``(inputs, labels)`` batches, where ``inputs`` is a
            mapping of per-channel tensors (assumed ``(batch, 1, time)`` each -
            TODO confirm against the dataset) and ``labels`` holds class indices.
        params: metric names to compute, in order. Recognised names are
            ``"acc"``, ``"auc"``, ``"kappa"``, ``"recall"``, ``"precision"``
            and ``"fmeasure"``; any other name is skipped.
        loss_fn: callable ``(predictions, targets) -> loss``. Defaults to the
            notebook-level ``criterion``.

    Returns:
        list: one value per recognised metric, followed by the loss over the
        whole dataloader as the last element.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    if loss_fn is None:
        loss_fn = criterion  # notebook-level loss object

    # Run on whatever device the model lives on; fall back to the first GPU
    # (the previous hard-coded behaviour) for parameter-free models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda", 0)

    predicted = []
    targets = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in dl:
                # Stack the per-channel tensors along a new last axis (same
                # values as unsqueeze(0) + vstack + permute(1, 2, 3, 0)).
                inputs = torch.stack(list(inputs.values()), dim=-1)
                predicted.append(model(inputs.to(device)).cpu())
                targets.append(labels.long().cpu())
    finally:
        model.train(was_training)

    if not predicted:
        raise ValueError("evaluate() received an empty dataloader")

    predicted = torch.cat(predicted, 0)
    Y = torch.cat(targets, 0)
    loss = loss_fn(predicted, Y)

    predicted = predicted.numpy()
    Y = Y.numpy()
    predicted_class = np.argmax(predicted, axis=1)

    results = []
    for param in params:
        if param == "acc":
            results.append(accuracy_score(Y, predicted_class))
        elif param == "auc":
            results.append(roc_auc_score(Y, predicted, multi_class="ovr"))
        elif param == "kappa":
            results.append(cohen_kappa_score(Y, predicted_class))
        elif param == "recall":
            results.append(recall_score(Y, predicted_class, average="micro"))
        elif param == "precision":
            results.append(precision_score(Y, predicted_class, average="micro"))
        elif param == "fmeasure":
            precision = precision_score(Y, predicted_class, average="micro")
            recall = recall_score(Y, predicted_class, average="micro")
            denominator = precision + recall
            # Avoid 0/0 when nothing was predicted correctly.
            results.append(2 * precision * recall / denominator if denominator else 0.0)

    results.append(loss)
    return results

In [None]:
import sys

sys.path.append("../../")

from pase_eeg.lit_modules.simple_classifier_lit import EEGBCIIV2aDataLit

In [None]:
# Data module for the BCI Competition IV recordings stored under data_path.
# NOTE(review): data_path is an absolute, machine-specific location.
data_lit = EEGBCIIV2aDataLit(
    batch_size=32,
    leave_one_out=False,
    train_patients=[],
    test_patients=[9],
    data_path="/data/BCI_Competition_IV/",
    channels_config="../../configs/eeg_recording_standard/international_10_20_22.py",
    transforms=[
        {
            "class_path": "pase_eeg.data.transforms.ToTensor",
            "init_args": {"device": "cpu"},
        },
        # {
        #     "class_path": "pase_eeg.data.transforms.ZNorm",
        #     "init_args": {"stats": "./bci_comp_iv2a_stats.pkl", "mode": "mean-std"},
        # },
    ],
)
data_lit.setup()

#### Run

In [None]:
# One list per tracked quantity; the training loop appends one value per epoch.
# Key order matters: the plotting cell draws the first four entries.
history = {
    metric: []
    for metric in (
        "train_loss",
        "test_loss",
        "train_acc",
        "test_acc",
        "lr",
        "train_kappa",
        "test_kappa",
    )
}

In [None]:
batch_size = 32

# Train for 200 epochs; after each epoch record the learning rate, advance the
# scheduler once and score the model on the training and validation loaders.
for epoch in range(200):
    print("\nEpoch ", epoch)

    running_loss = 0.0
    for batch in data_lit.train_dataloader():
        inputs, labels = batch
        # Stack the per-channel tensors on a new leading axis, then move that
        # axis to the end (assumes every value is (batch, 1, time) - TODO confirm).
        per_channel = [signal.unsqueeze(0) for signal in inputs.values()]
        inputs = torch.vstack(per_channel).permute(1, 2, 3, 0)
        inputs = inputs.cuda(0)
        labels = labels.type(torch.LongTensor).cuda(0)

        # Standard optimisation step: reset grads, forward, backward, update.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Log the learning rate used during this epoch before moving the schedule on.
    history["lr"].append(optimizer.param_groups[0]["lr"])
    # print(optimizer.param_groups[0]["lr"])
    lr_scheduler.step()

    # evaluate() returns the requested metrics in order, with the loss last.
    params = ["acc", "kappa", "auc", "fmeasure", "loss"]
    print(params)
    print("Training Loss ", running_loss)
    tr = evaluate(net, data_lit.train_dataloader(), params)
    print("Train - ", tr)
    ev = evaluate(net, data_lit.val_dataloader(), params)
    print("Validation - ", ev)

    history["train_acc"].append(tr[0])
    history["train_kappa"].append(tr[1])
    history["train_loss"].append(tr[-1])

    history["test_acc"].append(ev[0])
    history["test_kappa"].append(ev[1])
    history["test_loss"].append(ev[-1])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import product


def draw_2d_plot(shape, plotter, data, num=None):
    """Draw ``data`` items on a ``shape[0]`` x ``shape[1]`` grid of subplots.

    Args:
        shape: ``(rows, cols)`` of the subplot grid.
        plotter: callable ``plotter(axes, item)`` that draws one item.
        data: indexable sequence of items, placed row by row.
        num: optional cap on the number of items drawn.

    Panels are filled until the grid, ``data`` or ``num`` runs out, whichever
    comes first; remaining panels stay empty.
    """
    n_rows, n_cols = shape
    # squeeze=False keeps `axs` two-dimensional for 1xN, Nx1 and 1x1 grids, so
    # the [row, col] indexing below always works.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)

    n_panels = min(len(data), n_rows * n_cols)
    if num is not None:
        n_panels = min(n_panels, num)

    grid_cells = product(range(n_rows), range(n_cols))
    for i, (row, col) in zip(range(n_panels), grid_cells):
        plotter(axs[row, col], data[i])


def tuple_plotter(axes, data):
    """Plot one ``(title, values)`` pair on ``axes`` against the value index."""
    title, series = data
    axes.title.set_text(title)
    positions = np.array(list(range(len(series))))
    axes.plot(positions, series)

In [None]:
# Highest held-out accuracy and kappa over all epochs (possibly from different epochs).
tuple(max(history[metric]) for metric in ("test_acc", "test_kappa"))

In [None]:
# Held-out accuracy and kappa at the epoch where kappa peaked.
best_kappa_epoch = np.argmax(history["test_kappa"])
history["test_acc"][best_kappa_epoch], history["test_kappa"][best_kappa_epoch]

In [None]:
# Plot the first four history series (train/test loss and train/test accuracy,
# in the dict's insertion order) on a 2x2 grid, one panel per series.
shape = [2, 2]
num = 4
data = list(history.items())
draw_2d_plot(shape, tuple_plotter, data, num)

In [None]:
# Learning rate recorded at the start of each epoch.
lr_values = history["lr"]
plt.plot(list(range(len(lr_values))), lr_values)

In [None]:
# Train a fresh model for every subject 1-9 (passed as test_patients) and keep
# each run's per-epoch history in `stat`.
idxs = list(range(1, 10))
stat = []

for i in idxs:
    print(i)

    # Fresh model, loss, optimiser and 200-step one-cycle schedule per subject.
    # Single-branch alternative:
    # net = EEGNetv2(num_classes=4, channels=22, dropout_rate=0.1, kernel_length=32, F1=8, D=2, F2=16).cuda(0)
    net = MBEEGNetv2()
    net = net.cuda(0)
    criterion = nn.CrossEntropyLoss()  # nn.BCELoss()
    optimizer = optim.Adam(params=net.parameters(), weight_decay=0.5)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer, max_lr=0.001, epochs=200, steps_per_epoch=1
    )

    data_lit = EEGBCIIV2aDataLit(
        batch_size=32,
        leave_one_out=False,
        train_patients=[],
        test_patients=[i],
        data_path="/data/BCI_Competition_IV/",
        channels_config="../../configs/eeg_recording_standard/international_10_20_22.py",
        transforms=[
            {
                "class_path": "pase_eeg.data.transforms.ToTensor",
                "init_args": {"device": "cpu"},
            },
            # {
            #     "class_path": "pase_eeg.data.transforms.ZNorm",
            #     "init_args": {"stats": "./bci_comp_iv2a_stats.pkl", "mode": "mean-std"},
            # },
        ],
    )
    data_lit.setup()

    history = {
        metric: []
        for metric in (
            "train_loss",
            "test_loss",
            "train_acc",
            "test_acc",
            "lr",
            "train_kappa",
            "test_kappa",
        )
    }

    for epoch in tqdm(range(200)):  # one pass over the training set per epoch
        running_loss = 0.0
        for batch in data_lit.train_dataloader():
            inputs, labels = batch
            # Stack the per-channel tensors on a new leading axis, then move
            # that axis to the end (assumes (batch, 1, time) values - TODO confirm).
            per_channel = [signal.unsqueeze(0) for signal in inputs.values()]
            inputs = torch.vstack(per_channel).permute(1, 2, 3, 0)
            inputs = inputs.cuda(0)
            labels = labels.type(torch.LongTensor).cuda(0)

            # Reset grads, forward, backward, update.
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Log the learning rate used this epoch, then advance the schedule.
        history["lr"].append(optimizer.param_groups[0]["lr"])
        lr_scheduler.step()

        # evaluate() returns the requested metrics in order, with the loss last.
        params = ["acc", "kappa", "auc", "fmeasure", "loss"]
        tr = evaluate(net, data_lit.train_dataloader(), params)
        ev = evaluate(net, data_lit.val_dataloader(), params)

        history["train_acc"].append(tr[0])
        history["train_kappa"].append(tr[1])
        history["train_loss"].append(tr[-1])

        history["test_acc"].append(ev[0])
        history["test_kappa"].append(ev[1])
        history["test_loss"].append(ev[-1])
    stat.append(history)

In [None]:
# Best held-out accuracy and kappa reached in each subject's run.
for his in stat:
    print(*(max(his[metric]) for metric in ("test_acc", "test_kappa")))

In [None]:
# After the cells above `net` is a MBEEGNetv2, which has no top-level `conv1`
# (each branch owns one), so `net.conv1` raised AttributeError. Collect the
# weight shape of every `conv1` found in the model instead; this also works
# for a plain EEGNetv2.
{
    name: module.weight.size()
    for name, module in net.named_modules()
    if name.rsplit(".", 1)[-1] == "conv1"
}