In [1]:
import pandas as pd
import numpy as np
import mne
import re
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchsummary import summary
import torch.autograd as autograd

In [2]:
DATA_PATH = '../data/event_data.csv'
MFF_DIR = '../data/'
STIM_CHANNEL_NAMES = ['101' + str(i) for i in range(10)]
EEG_CHANNEL_NAMES = ['E'+ str(i) for i in range(1,33)]

EVENT_LENGTHS = [160] + [200]*9
NUMPY_X_FNAME, NUMPY_Y_FNAME = MFF_DIR + 'X_small.npy',MFF_DIR + 'y_small.npy'

In [3]:
######### NO NEED TO RUN ###################

In [4]:
class EEGDataset(Dataset):
    def __init__(self):
        self.df = pd.read_csv(DATA_PATH)
        self.current_file = None
        self.current_raw = None
        self.current_events = None
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        fname = self.df.iloc[idx]['fname']
        if self.current_file != fname or self.current_file is None:
            self.current_file = fname
            self.current_raw = mne.io.read_raw_egi(MFF_DIR + fname,verbose=False,preload=True).pick_channels(
                STIM_CHANNEL_NAMES+EEG_CHANNEL_NAMES)
            self.current_events = mne.find_events(self.current_raw,verbose=False)
        s_time = self.df.iloc[idx]['s_time']/1000.0
        if idx < 160:
            fake_idx = idx
        else:
            fake_idx = (idx - 160)%200 
        epoch = mne.Epochs(
            self.current_raw.copy().pick_types(eeg=True),
            self.current_events[fake_idx].reshape(1,-1),tmin=s_time,tmax=s_time+4.0-1.0/self.current_raw.info['sfreq'],
            baseline=None,verbose=False,preload=True)
        try:
            X = epoch.load_data().filter(l_freq = 0, h_freq = 30).resample(100).get_data()
        except ValueError:
            X = np.array([]).reshape(0,32,400)
        y = self.df.iloc[idx]['label']
        
        return X, y

In [5]:
eeg_dataset = EEGDataset()

In [6]:
X = []
y = []
i = 0
for xi,yi in eeg_dataset:
    if i % 100 == 0:
        print(i)
    if xi.shape == (1,32,400):
        X.append(xi)
        y.append(yi)
    i += 1

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900


In [7]:
X_np = np.concatenate(X, axis=0)
y_np = np.array(y)
print(X_np.shape,y_np.shape)

(967, 32, 400) (967,)


In [8]:
np.save(NUMPY_X_FNAME,X_np)
np.save(NUMPY_Y_FNAME,y_np)

In [9]:
#####################################

In [6]:
X_np = np.load(NUMPY_X_FNAME)
y_np = np.load(NUMPY_Y_FNAME)

In [8]:
X_torch = torch.Tensor(X_np)
y_torch = torch.Tensor(y_np)
dset = TensorDataset(X_torch[:10],y_torch[:10]) # create your datset
dloader = DataLoader(dset,shuffle=True,batch_size=32) # create your dataloader

In [25]:
print(X_torch.shape,y_torch.shape)
print(X_torch[0].shape)

torch.Size([967, 32, 400]) torch.Size([967])
torch.Size([32, 400])


In [43]:
class MultiLSTM(nn.Module):
    def __init__(self,n_i):
        super(MultiLSTM,self).__init__()
        self.lstms  = [None]*n_i
        for i in range(n_i):
            self.lstms[i] = nn.LSTM(1, n_i, 1, bidirectional=False)
        
    def forward(self,x):
        print(x.shape)
        y = []
        # (N,n_i,400)
        for i,lstm in enumerate(self.lstms):
            x_ = x[:,i,:]
            print(x_.shape)
            y.append(lstm(x[:,i,:])[1][0])
        y = torch.stack(y, dim=1)
        
        return y

In [44]:
mlt = MultiLSTM(32)

In [36]:
# build model
class EEGEncoder(nn.Module):
    def __init__(self, n_i):
        super(EEGEncoder,self).__init__();
        self.n_i = n_i
        self.multilstm = MultiLSTM(n_i)
        #self.lstm1 = nn.LSTM(1, n_i, 1, bidirectional=False)
        
        self.linear2 = nn.Linear(32*n_i, n_i)
        self.linear3 = nn.Linear(n_i,n_i)
        self.linear4 = nn.Linear(n_i,10)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        print(x.shape)
        x = self.multilstm(x)
        enc = self.linear2(x)
        x = F.relu(enc)
        x = self.linear3(x)
        x = self.linear4(x)
        x = self.softmax(x)
        
        return x

In [None]:
# Testing dimensions
eeg_classifier = EEGEncoder(32)
inp,l = torch.randn(400,1,32),torch.randn(1)
out = eeg_classifier(inp)
print(out.shape, inp.shape)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = opt.Adam(eeg_classifier.parameters(),lr=0.001)

In [None]:
losses = []
for epoch in range(5):
    running_loss = 0.0
    
    for i, data in enumerate(dloader,0):
        inp,lab = data
        optimizer.zero_grad()
        
        out = eeg_classifier(inp)
        out = out.reshape(-1,10)
                        
        loss = criterion(out.float(),lab.long())
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 1 == 0:
            print('[{}/{}]: loss = {}'.format(epoch+1,i+1,running_loss))
            losses.append(running_loss)
            running_loss = 0.0
            
print('Finished training')

In [22]:
torch.save(eeg_classifier.state_dict(), '../ckpt/save1.pth')

In [43]:
for i, data in enumerate(dloader,0):
        inp,lab = data
        out = eeg_classifier(inp);
        print(out.argmax(dim=2),lab)

tensor([[6]]) tensor([5.])
tensor([[6]]) tensor([6.])
tensor([[6]]) tensor([0.])
tensor([[6]]) tensor([9.])
tensor([[6]]) tensor([7.])
tensor([[6]]) tensor([6.])
tensor([[6]]) tensor([9.])
tensor([[6]]) tensor([2.])
tensor([[6]]) tensor([2.])
tensor([[6]]) tensor([0.])
