In [56]:
import sys
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
%matplotlib inline

In [57]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [61]:
class DepthWiseConv2d(nn.Module):
    def __init__(self, in_channels, kernel_size, kernels_per_layer, bias=False):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels=in_channels, out_channels=in_channels*kernels_per_layer, kernel_size=kernel_size, groups=in_channels, bias=bias, padding='same')
    def forward(self, x):
        return self.depthwise(x)

class PointWiseConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernels_per_layer=1, bias=False):
        super().__init__()
        self.pointwise = nn.Conv2d(in_channels=in_channels*kernels_per_layer, out_channels=out_channels,
                                   kernel_size=(1,1), bias=bias, padding="valid")

    def forward(self, x):
        return self.pointwise(x)

class MaxNormLayer(nn.Linear):
    def __init__(self, in_features, out_features, max_norm=1.0,):
        super(MaxNormLayer, self).__init__(in_features=in_features, out_features=out_features)
        self.max_norm = max_norm

    def forward(self, x):
        if self.max_norm is not None:
            with torch.no_grad():
                self.weight.data = torch.renorm(
                    self.weight.data, p=2, dim=0, maxnorm=self.max_norm
                )
        return super(MaxNormLayer, self).forward(x)

class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, kernels_per_layer=1, bias=False):
        super().__init__()
        self.depthwise = DepthWiseConv2d(in_channels=in_channels, kernels_per_layer=kernels_per_layer, kernel_size=kernel_size, bias=bias)
        self.pointwise = PointWiseConv2d(in_channels=in_channels, out_channels=out_channels, kernels_per_layer=kernels_per_layer, bias=bias)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

possíveis melhorias:
1 - trocar batchnorm por layernorm
2 - aplicação de transformer

In [144]:
class EEGNET(nn.Module):
    def __init__(
        self,
        n_channels,
        n_times,
        n_classes,
        kernel_length=64,
        F1=8,
        D=2,
        F2=16,
        pool1_stride=4,
        pool2_stride=8,
        dropout_rate=0.5,
        norm_rate=0.25,
    ):
        super().__init__()
        #block 1
        self.conv2d = nn.Conv2d(in_channels=n_channels, out_channels=F1, kernel_size=(1, kernel_length), bias=False, padding='same')
        self.batchNorm = nn.BatchNorm2d(num_features=F1, momentum=0.01, eps=0.001, track_running_stats=False)
        self.depthWise = DepthWiseConv2d(in_channels=F1, kernel_size=(n_channels, 1), kernels_per_layer=D, bias=False) #equivalente a convolução depth wise
        #---------------------------------------------------------------------

        #block 2
        self.batchNorm2 = nn.BatchNorm2d(num_features=F1*D, momentum=0.01, eps=0.001, track_running_stats=False)
        self.elu1 = nn.ELU()
        self.avgPool2d = nn.AvgPool2d(kernel_size=(1, pool1_stride), stride=pool1_stride)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.separableConv2d = SeparableConv2d(in_channels=F1*D, kernel_size=(1, 16), out_channels=F2, bias=False)
        self.batchNorm3 = nn.BatchNorm2d(num_features=F2, momentum=0.01, eps=0.001, track_running_stats=False)
        self.elu2 = nn.ELU()
        self.avgPool2d_2 = nn.AvgPool2d(kernel_size=(1, pool2_stride), stride=pool2_stride)
        self.dropout2 = nn.Dropout(dropout_rate)
        #---------------------------------------------------------------------

        #final block
        self.flatten = nn.Flatten()
        self.maxNormLayer = MaxNormLayer(in_features = F2 * ((((n_times - pool1_stride) // pool1_stride + 1) - pool2_stride) // pool2_stride + 1), out_features=n_classes, max_norm=norm_rate)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        out = x.view((1,) + x.shape) # channel, batch, time (conv_depth, C, B, T)
        out = self.conv2d(out) # (conv_depth, F1, F1//2, T)
        out = self.batchNorm(out) # (conv_depth, F1, F1//2, T)
        out = self.depthWise(out) # (conv_depth, F1*D, F1//2, T)
        out = self.batchNorm2(out) # (conv_depth, F1*D, F1//2, T)
        out = self.elu1(out) # (conv_depth, F1*D, F1//2, T)
        out = self.avgPool2d(out) # (conv_depth, F1*D, max(1, F1//(2*pool1_stride)), max(1, T//(2*pool1_stride)))
        out = self.dropout1(out) # (conv_depth, F1*D, max(1, F1//(2*pool1_stride)), max(1, T//(2*pool1_stride)))
        out = self.separableConv2d(out) # (conv_depth, F1*D, max(1, F1//(2*pool1_stride)), max(1, T//(2*pool1_stride)))
        out = self.batchNorm3(out) # (conv_depth, F1*D, max(1, F1//(2*pool1_stride)), max(1, T//(2*pool1_stride)))
        out = self.elu2(out) # (conv_depth, F1*D, max(1, F1//(2*pool1_stride)), max(1, T//(2*pool1_stride)))
        print(out.shape)
        out = self.avgPool2d_2(out) # (conv_depth, F1*D, F1//(2*pool1_stride), T//F2)
        print(out.shape)
        out = self.dropout2(out) # (conv_depth, F1*D, F1//(2*pool1_stride), T//F2)
        print(out.shape)
        out = self.flatten(out) # (conv_depth, (F1*D) * (F1//(2*4)) * (T//F2) )
        print(out.shape)
        out = self.maxNormLayer(out) # (conv_depth, n_classes)
        out = self.softmax(out) # (conv_depth, n_classes)
        return out



In [150]:
model = EEGNET(n_channels=16, n_times=256, n_classes=16)
example = torch.randn((16, 1, 256), device=device) #channel, batch, time (C, B, T)
out = model(example)
out

torch.Size([1, 16, 1, 64])
torch.Size([1, 16, 1, 8])
torch.Size([1, 16, 1, 8])
torch.Size([1, 128])


tensor([[0.0657, 0.0591, 0.0584, 0.0696, 0.0593, 0.0756, 0.0620, 0.0523, 0.0547,
         0.0734, 0.0629, 0.0599, 0.0590, 0.0571, 0.0555, 0.0755]],
       grad_fn=<SoftmaxBackward0>)