In [10]:
import torch
from torch.nn import functional as F
import torch.optim as optim
from torchvision.utils import save_image
import torch.utils.data as data_utils
import torch.nn as nn
import numpy as np
import torchaudio
import mne

In [160]:
class STFT_CNN(nn.Module):
    def __init__(self, channel_dim, class_num ):

        
        super(STFT_CNN, self).__init__()
        
        self.specgram = torchaudio.transforms.Spectrogram(normalized = True, n_fft = 128, win_length = 128, hop_length = 16)
        # Input shape: N,C,H, W
        self.conv1 = nn.Conv2d(in_channels = 2, out_channels = 24, kernel_size = 12, padding = 2)
        self.conv2 = nn.Conv2d(in_channels = 24, out_channels = 48, kernel_size = 4, padding = 2)
        self.conv3 = nn.Conv2d(in_channels = 48, out_channels = 48, kernel_size = 4)
        self.batchnorm1 = nn.BatchNorm2d(num_features = 24)
        self.batchnorm2 = nn.BatchNorm2d(num_features = 48)
        self.batchnorm3 = nn.BatchNorm2d(num_features = 48)
        self.relu = nn.ReLU()
        self.pooling = nn.MaxPool2d(2)
        self.fc = nn.Linear(4032, class_num)
        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(0.3)
        

        self.conv_block1 = nn.Sequential(self.conv1, self.batchnorm1, self.relu, self.pooling)
        self.conv_block2 = nn.Sequential(self.conv2, self.batchnorm2, self.relu, self.pooling)
        self.conv_block3 = nn.Sequential(self.conv3, self.batchnorm3, self.relu, self.pooling, self.dropout)
        
    def forward(self, x):
        spec = self.specgram(x).float()
        h1 = self.conv_block1(spec)
        h2 = self.conv_block2(h1)
        h3 = self.conv_block3(h2)
        h = h3.flatten()
        print(h.shape)
        out = self.softmax(self.fc(h))
        
        return out
    

##Input shape:    
        
class EEGNet(nn.Module):
    def __init__(self, timepoints, class_num):
        # Input shape: N,C,T,1
        
        super(EEGNet, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)
        self.batchnorm1 = nn.BatchNorm2d(16, False)
        self.elu = nn.ELU()
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 32))
        self.batchnorm2 = nn.BatchNorm2d(4, False)
        self.pooling2 = nn.MaxPool2d(2, 4)
        
        self.dropout = nn.Dropout(0.25)
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4, False)
        self.pooling3 = nn.MaxPool2d((2, 4))
        self.softmax = nn.Softmax()
        
        self.fc1 = nn.Linear(4*2*7, class_num)
        
        conv_block1 = nn.Sequential(self.conv1, self.elu, self.batchnorm1, self.dropout)
        conv_block2 = nn.Sequential(self.padding1, self.conv2, self.elu, self.batchnorm2, self.dropout, self.pooling2)
        conv_block3 = nn.Sequential(self.padding2, self.conv3, self.elu, self.batchnorm3, self.dropout, self.pooling3)
        
        

    def forward(self, x):
        h1 = conv_block1(x)
        h1 = h1.permute(0, 3, 1, 2)
        h2 = conv_block2(h1)
        h3 = conv_block3(h2)
        out = self.softmax(self.fc1(h3))
        return out
        
        
        

Here we call sample eeg file using mne and check if the model runs

In [13]:
subject_data = mne.io.read_raw_edf('./A03T.gdf', preload=True)
event = mne.events_from_annotations(subject_data)

Extracting EDF parameters from /Users/jin/Desktop/20S/Hackathon/A03T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 660529  =      0.000 ...  2642.116 secs...


  subject_data = mne.io.read_raw_edf('./A03T.gdf', preload=True)
  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  subject_data = mne.io.read_raw_edf('./A03T.gdf', preload=True)


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']


In [32]:
tz = subject_data['EEG-Cz'][1]
t3 = subject_data['EEG-C3'][1]
signalz = subject_data['EEG-Cz'][0].reshape(-1)
signal3 = subject_data['EEG-C3'][0].reshape(-1)

In [33]:
segment1_cz = [tz[:250*5],signalz[:250*5]]
segment2_cz = [tz[250*5:250*10],signalz[250*5:250*10]]
segment1_c3 = [t3[:250*5],signal3[:250*5]]
segment2_c3 = [t3[250*5:250*10],signal3[250*5:250*10]]

In [154]:
stft_model = STFT_CNN(channel_dim = 2, class_num = 2)

In [155]:
out = stft_model(feed_stft)

torch.Size([4032])




In [161]:
EEGnet_model = EEGNet(1250,2)