In [6]:
import torch
from torch.nn import functional as F
import torch.optim as optim
from torchvision.utils import save_image
import torch.utils.data as data_utils
import torch.nn as nn
import numpy as np
import torchaudio

In [8]:



class STFT_CNN(nn.Module):
    """CNN classifier operating on a normalized spectrogram of its input.

    The input is converted with ``torchaudio.transforms.Spectrogram``
    (``n_fft=128``, ``win_length=128``, ``hop_length=64``) and then passed
    through three conv -> batch-norm -> ReLU -> max-pool stages (the last one
    followed by dropout).  The result is flattened per sample, fed to a single
    linear layer and turned into class probabilities with a softmax.

    Args:
        x_dim: Currently unused; kept so existing callers keep working.
        y_dim: Currently unused; kept so existing callers keep working.
        channel_dim: Number of input channels expected by the first conv layer.
        class_num: Number of output classes.

    NOTE(review): the original comment gave the input shape as "N, C, H, W",
    but the spectrogram transform appends a frequency axis, so the raw input
    is presumably (N, C, T) -- confirm against the data loader.
    """

    def __init__(self, x_dim, y_dim, channel_dim, class_num):
        super(STFT_CNN, self).__init__()

        self.specgram = torchaudio.transforms.Spectrogram(
            normalized=True, n_fft=128, win_length=128, hop_length=64
        )
        self.conv1 = nn.Conv2d(in_channels=channel_dim, out_channels=32, kernel_size=12)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=8)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4)
        self.batchnorm1 = nn.BatchNorm2d(num_features=32)
        self.batchnorm2 = nn.BatchNorm2d(num_features=64)
        self.batchnorm3 = nn.BatchNorm2d(num_features=64)
        self.relu = nn.ReLU()
        self.pooling = nn.MaxPool2d(2)
        # NOTE(review): 1024 input features is hard-coded and x_dim / y_dim are
        # not used to derive it; the input length must be chosen so that the
        # flattened output of conv_block3 has exactly 1024 values per sample.
        self.fc = nn.Linear(1024, class_num)
        # Explicit dim: normalize over the class axis of the (N, class_num) logits.
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(0.3)

        # The blocks must be attributes: this registers them as sub-modules
        # (so parameters(), .to(), train()/eval() see them) and makes them
        # reachable from forward().
        # conv1 belongs in the first block: batchnorm1 expects its 32 channels.
        self.conv_block1 = nn.Sequential(
            self.specgram, self.conv1, self.batchnorm1, self.relu, self.pooling
        )
        self.conv_block2 = nn.Sequential(
            self.conv2, self.batchnorm2, self.relu, self.pooling
        )
        self.conv_block3 = nn.Sequential(
            self.conv3, self.batchnorm3, self.relu, self.pooling, self.dropout
        )

    def forward(self, x):
        """Return class probabilities of shape (N, class_num) for batch ``x``."""
        h = self.conv_block1(x)
        h = self.conv_block2(h)
        h = self.conv_block3(h)
        # Keep the batch axis; a bare flatten() would merge all samples together.
        h = h.flatten(start_dim=1)
        return self.softmax(self.fc(h))
    

##Input shape:    
        
class EEGNet(nn.Module):
    """Three-stage convolutional classifier for multi-channel time series.

    Shape flow for an input of shape (N, 1, T, 64):

    * block 1: Conv2d(1, 16, (1, 64))      -> (N, 16, T, 1)
    * permute(0, 3, 1, 2)                  -> (N, 1, 16, T)
    * block 2: pad, Conv2d(1, 4, (2, 32)),
      MaxPool2d(2, stride=4)               -> (N, 4, 4, T // 4 + 1)
    * block 3: pad, Conv2d(4, 4, (8, 4)),
      MaxPool2d((2, 4))                    -> (N, 4, 2, (T // 4 + 1) // 4)
    * flatten + Linear + softmax           -> (N, class_num)

    The last input axis must be exactly 64 wide so that block 1 collapses it
    to size 1, which then becomes the single input channel of block 2.

    Args:
        timepoints: Length T of the time axis (third input dimension). Used
            to size the final linear layer; 120 gives the 4*2*7 = 56 features
            that were previously hard-coded.
        class_num: Number of output classes.

    Raises:
        ValueError: If ``timepoints`` is too short (< 12) for the pooling
            stages to produce at least one output column.
    """

    def __init__(self, timepoints, class_num):
        super(EEGNet, self).__init__()

        # Width left after both pooling stages (see the shape flow above).
        pooled_width = (timepoints // 4 + 1) // 4
        if pooled_width < 1:
            raise ValueError(
                "timepoints must be at least 12, got {}".format(timepoints)
            )

        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding=0)
        # The second positional argument of BatchNorm2d is eps; the original
        # passed False there, i.e. eps=0, which can divide by zero for
        # constant features.  Use the default eps instead.
        self.batchnorm1 = nn.BatchNorm2d(16)
        self.elu = nn.ELU()
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 32))
        self.batchnorm2 = nn.BatchNorm2d(4)
        self.pooling2 = nn.MaxPool2d(2, 4)

        self.dropout = nn.Dropout(0.25)
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4)
        self.pooling3 = nn.MaxPool2d((2, 4))
        # Explicit dim: normalize over the class axis of the (N, class_num) logits.
        self.softmax = nn.Softmax(dim=1)

        # 4 channels x 2 rows x pooled_width columns per sample.
        self.fc1 = nn.Linear(4 * 2 * pooled_width, class_num)

        # Stored as attributes so they are registered sub-modules and are
        # reachable from forward().
        self.conv_block1 = nn.Sequential(
            self.conv1, self.elu, self.batchnorm1, self.dropout
        )
        self.conv_block2 = nn.Sequential(
            self.padding1, self.conv2, self.elu, self.batchnorm2,
            self.dropout, self.pooling2,
        )
        self.conv_block3 = nn.Sequential(
            self.padding2, self.conv3, self.elu, self.batchnorm3,
            self.dropout, self.pooling3,
        )

    def forward(self, x):
        """Return class probabilities of shape (N, class_num) for batch ``x``."""
        h = self.conv_block1(x)
        # Move the 16 feature maps into the spatial plane: (N, 16, T, 1) -> (N, 1, 16, T).
        h = h.permute(0, 3, 1, 2)
        h = self.conv_block2(h)
        h = self.conv_block3(h)
        # Flatten everything but the batch axis before the linear layer.
        h = h.flatten(start_dim=1)
        return self.softmax(self.fc1(h))
        
        
        