# In [4]:
import torch
import torch.nn as nn
import torchvision

class EEGNet(nn.Module):
    """Three-stage convolutional network followed by a 2-way linear classifier.

    The stages are ``firstConv`` (temporal convolution + batch norm),
    ``depthwiseConv`` (grouped convolution collapsing two rows into one,
    then activation, average pooling and dropout) and ``separableConv``
    (temporal convolution, activation, average pooling and dropout).

    The classifier is hard-wired to 736 input features (32 channels x 23
    columns after the two pooling steps).
    NOTE(review): this matches an input of shape (N, 1, 2, 750); confirm
    against the dataset actually fed to the model.

    Args:
        activation: Activation module inserted after the batch norm of the
            depthwise and separable stages. The same instance is used in
            both stages. ``None`` (the default) builds a new ``nn.ReLU()``
            for this model.
    """

    def __init__(self, activation=None):
        super().__init__()

        # The default is created here instead of in the signature: a default
        # written as ``nn.ReLU()`` is evaluated once when the class is
        # defined, so every model would end up sharing one module object.
        if activation is None:
            activation = nn.ReLU()

        # Layer 1: convolution along the last axis; padding 25 with a
        # 51-wide kernel keeps the width unchanged.
        self.firstConv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(1, 51), stride=(1, 1),
                      padding=(0, 25), bias=False),
            nn.BatchNorm2d(16, eps=1e-5, momentum=0.1, affine=True,
                           track_running_stats=True),
        )

        # Depthwise layer: groups=16 gives each input channel its own pair
        # of (2, 1) filters; stride (2, 1) halves the row dimension.
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=(2, 1), stride=(2, 1), groups=16,
                      bias=False),
            nn.BatchNorm2d(32, eps=1e-5, momentum=0.1, affine=True,
                           track_running_stats=True),
            activation,
            nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0),
            nn.Dropout(p=0.25),
        )

        # Separable layer: width-preserving (1, 15) convolution followed by
        # an 8x reduction of the width.
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=(1, 15), stride=(1, 1),
                      padding=(0, 7), bias=False),
            nn.BatchNorm2d(32, eps=1e-5, momentum=0.1, affine=True,
                           track_running_stats=True),
            activation,
            nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0),
            nn.Dropout(p=0.25),
        )

        # in_features = 736, out_features = 2. Kept wrapped in a Sequential
        # so the parameter names stay ``classify.0.weight`` / ``classify.0.bias``.
        self.classify = nn.Sequential(nn.Linear(736, 2, bias=True))

    def forward(self, x):
        """Run the three convolutional stages and return the class scores.

        Args:
            x: Input batch with a single channel (the first convolution
                takes 1 input channel).

        Returns:
            Tensor of shape (batch, 2) produced by the linear classifier.
        """
        features = self.firstConv(x)
        features = self.depthwiseConv(features)
        features = self.separableConv(features)

        # Collapse every axis except the batch axis for the linear layer.
        flat = features.flatten(start_dim=1)

        return self.classify(flat)


class DeepConvNet(nn.Module):
    """Four convolutional stages followed by a 2-way linear classifier.

    Layer layout (input = [1, 1, C=2, T=750]):

    * conv0: 25 filters, kernel_size=(1, 5)
    * conv1: 25 filters, kernel_size=(2, 1)
    * conv2: 50 filters, kernel_size=(1, 5)
    * conv3: 100 filters, kernel_size=(1, 5)
    * conv4: 200 filters, kernel_size=(1, 5)

    ``convnet1`` holds conv0 and conv1; ``convnet2`` .. ``convnet4`` hold
    conv2 .. conv4. Every stage ends with batch norm, the activation, a
    (1, 2) max pool and dropout with p=0.5. The classifier is hard-wired to
    8600 input features (200 channels x 43 columns for the input above).

    Args:
        activation: Activation module placed after the batch norm of every
            stage. The same instance is used in all four stages. ``None``
            (the default) builds a new ``nn.ReLU()`` for this model.
    """

    def __init__(self, activation=None):
        super().__init__()

        # The default is created here instead of in the signature: a default
        # written as ``nn.ReLU()`` is evaluated once when the class is
        # defined, so every model would end up sharing one module object.
        if activation is None:
            activation = nn.ReLU()

        self.convnet1 = nn.Sequential(
            nn.Conv2d(1, 25, kernel_size=(1, 5)),
            nn.Conv2d(25, 25, kernel_size=(2, 1)),
            nn.BatchNorm2d(25, eps=1e-5, momentum=0.1),
            activation,
            nn.MaxPool2d(kernel_size=(1, 2)),
            nn.Dropout(p=0.5),
        )

        # Channel widths of consecutive stages; each adjacent pair becomes
        # the (in, out) channels of convnet2, convnet3 and convnet4.
        filters = (25, 50, 100, 200)

        for stage, (in_channels, out_channels) in enumerate(
                zip(filters, filters[1:]), start=2):
            # nn.Module.__setattr__ registers the Sequential as a submodule,
            # so these stages appear in parameters() and state_dict().
            setattr(self, f'convnet{stage}', nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 5)),
                nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.1),
                activation,
                nn.MaxPool2d(kernel_size=(1, 2)),
                nn.Dropout(p=0.5),
            ))

        self.classify = nn.Linear(8600, 2)

    def forward(self, x):
        """Run the four convolutional stages and return the class scores.

        Args:
            x: Input batch with a single channel (the first convolution
                takes 1 input channel).

        Returns:
            Tensor of shape (batch, 2) produced by the linear classifier.
        """
        features = self.convnet1(x)
        features = self.convnet2(features)
        features = self.convnet3(features)
        features = self.convnet4(features)

        # Collapse every axis except the batch axis for the linear layer.
        flat = features.flatten(start_dim=1)

        return self.classify(flat)