In [None]:
import pandas as pd
import numpy as np
import mne
import re
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.utils.data import Dataset, DataLoader

In [None]:
# Event table (CSV) and the directory prefix of the recordings it refers to.
DATA_PATH = './data/event_data.csv'
MFF_DIR = '/home/ajays/Desktop/WBI-data/'
# Channel names kept when a recording is loaded: '2010'..'2019' and 'E1'..'E32'.
STIM_CHANNEL_NAMES = [f'201{digit}' for digit in range(10)]
EEG_CHANNEL_NAMES = [f'E{number}' for number in range(1, 33)]

In [None]:
class EEGDataset(Dataset):
    """Map-style dataset that serves one EEG epoch per row of an event table.

    Three columns of the CSV are read for each item: ``fname`` (recording
    file name, appended to the MFF directory), ``s_time`` (divided by 1000
    before being used as the epoch start offset -- presumably milliseconds,
    TODO confirm) and ``label`` (returned unchanged as the target).

    The most recently opened recording is cached on the instance
    (``current_file``, ``current_raw``, ``current_eeg``, ``current_events``),
    so consecutive indices that refer to the same file do not re-read it.

    Args:
        data_path: CSV file holding the event table. ``None`` (default)
            falls back to the module-level ``DATA_PATH``.
        mff_dir: Prefix that is concatenated directly with ``fname``, so it
            must end with a path separator. ``None`` (default) falls back to
            the module-level ``MFF_DIR``.
    """

    def __init__(self, data_path=None, mff_dir=None):
        # Defaults are resolved at call time so the notebook constants can be
        # edited (or overridden per instance) without redefining the class.
        self.data_path = DATA_PATH if data_path is None else data_path
        self.mff_dir = MFF_DIR if mff_dir is None else mff_dir
        self.df = pd.read_csv(self.data_path)
        # Single-recording cache, filled lazily by __getitem__.
        self.current_file = None
        self.current_raw = None
        self.current_eeg = None
        self.current_events = None

    def __len__(self):
        """Return the number of rows in the event table."""
        return len(self.df)

    def _load_recording(self, fname):
        """Read recording ``fname`` and replace the single-file cache.

        The cache attributes are assigned only after every step succeeded, so
        a failed read never leaves ``current_file`` pointing at a recording
        that was not actually loaded.
        """
        raw = mne.io.read_raw_egi(
            self.mff_dir + fname, verbose=False, preload=True
        ).pick_channels(STIM_CHANNEL_NAMES + EEG_CHANNEL_NAMES)
        events = mne.find_events(raw, verbose=False)
        # EEG-only copy is built once per file instead of once per item.
        eeg_only = raw.copy().pick_types(eeg=True)

        self.current_file = fname
        self.current_raw = raw
        self.current_eeg = eeg_only
        self.current_events = events

    def __getitem__(self, idx):
        """Return ``(X, y)`` for row ``idx`` of the event table.

        ``X`` is the array returned by ``mne.Epochs.get_data()`` for a single
        epoch built from the EEG channels only; ``y`` is the row's ``label``
        value.
        """
        row = self.df.iloc[idx]  # one positional lookup instead of three
        fname = row['fname']
        if self.current_file != fname:
            self._load_recording(fname)

        sfreq = self.current_raw.info['sfreq']
        s_time = row['s_time'] / 1000.0
        # NOTE(review): idx is the row position in the whole CSV, while
        # current_events only holds the events found in the currently loaded
        # file -- confirm these line up when the CSV spans several recordings.
        event = self.current_events[idx].reshape(1, -1)
        epoch = mne.Epochs(
            self.current_eeg,
            event,
            tmin=s_time,
            # The window ends one sample period before s_time + 2.0.
            tmax=s_time + 2.0 - 1.0 / sfreq,
            baseline=None,
            verbose=False)
        # TODO: apply filter to eeg data
        X = epoch.get_data()
        y = row['label']

        return X, y

In [None]:
# Build the dataset: the constructor only reads the event CSV; no recording
# is opened until an item is requested.
eeg_dataset = EEGDataset()

In [None]:
"""
Issues:
- a few data values are empty
- there are only 200 examples in the dataset
"""

In [None]:
# build model
class EEGEncoder(nn.Module):
    """Single-layer LSTM encoder followed by a linear head with 10 outputs.

    Args:
        n_i: Feature size of each input step. It is also used as the LSTM
            hidden size and as the width of the two intermediate linear
            layers.
    """

    def __init__(self, n_i):
        super().__init__()
        self.lstm1 = nn.LSTM(n_i, n_i, 1, bidirectional=False)
        self.linear2 = nn.Linear(n_i, n_i)
        self.linear3 = nn.Linear(n_i, n_i)
        self.linear4 = nn.Linear(n_i, 10)

    def forward(self, x, hidden=None):
        """Encode a sequence and map the final hidden state to 10 sigmoids.

        Args:
            x: Input sequence. The LSTM is built without ``batch_first``, so
                batched input is laid out as ``(seq_len, batch, n_i)``.
            hidden: Optional ``(h_0, c_0)`` initial state for the LSTM.
                ``None`` (default) lets ``nn.LSTM`` start from zeros.

        Returns:
            Tensor of values in ``[0, 1]`` with the shape of the LSTM's final
            hidden state except for the last dimension, which is 10
            (``(1, batch, 10)`` for batched input).
        """
        # Only the final hidden state feeds the head; the per-step outputs
        # and the final cell state are not used.
        _, (hn, _) = self.lstm1(x, hidden)
        enc = self.linear2(hn)
        out = F.relu(enc)
        out = self.linear3(out)
        out = self.linear4(out)

        return torch.sigmoid(out)