In [1]:
import os
import warnings
# Allow a second copy of the OpenMP runtime to load instead of aborting
# (presumably the "OMP: Error #15" clash between torch and another library).
# NOTE(review): this override is an unsupported workaround; prefer fixing the duplicate runtime.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation/data warnings from torch, sklearn and mne.
warnings.simplefilter("ignore")


In [91]:
import copy
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import math
from collections import OrderedDict
import random
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset
import sys
import torch
import numpy as np
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mne
from sklearn.preprocessing import StandardScaler

In [3]:
# Use the default CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
import import_ipynb

In [4]:
class MyDataset(Dataset):
    """Map-style dataset that pairs ``x[i]`` with ``y[i]``.

    ``x`` and ``y`` may be any objects supporting ``len()`` and ``obj[index, ...]``
    indexing (e.g. NumPy arrays or torch tensors). Items are returned as stored;
    no transform is applied.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """

    def __init__(self, x, y):
        super().__init__()
        # Fail at construction instead of deep inside a DataLoader worker.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}")
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return self.x[index, ...], self.y[index, ...]

    def __len__(self):
        return len(self.x)

In [5]:
def Normalize(data, axis=-2):
    """Standardise every 2-D sample of ``data`` to zero mean and unit variance.

    Vectorised equivalent of fitting a fresh ``sklearn.preprocessing.StandardScaler``
    to each ``data[i]``. ``StandardScaler`` treats rows as observations and columns
    as features, so with the default ``axis=-2`` each column of a sample is
    standardised over its rows. For arrays shaped ``(n_epochs, n_channels, n_times)``
    (as produced by ``mne.Epochs.get_data``) that standardises every *time point*
    across channels.
    NOTE(review): per-channel normalisation over time, the more usual choice for
    EEG, is ``axis=-1``. Confirm which one is intended before training.

    Args:
        data: array-like of shape ``(n_samples, rows, cols)``.
        axis: axis along which mean and (population) std are taken. The default
            ``-2`` reproduces the per-sample ``StandardScaler`` behaviour.

    Returns:
        ``numpy.ndarray`` with the same shape as ``data``. Floating input keeps its
        dtype; integer input is promoted to float64.
    """
    arr = np.asarray(data)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    if arr.size == 0:
        return arr.copy()

    mean = arr.mean(axis=axis, keepdims=True)
    var = arr.var(axis=axis, keepdims=True)  # ddof=0, like StandardScaler
    n = arr.shape[axis]
    eps = np.finfo(np.float64).eps
    # Same near-constant test StandardScaler applies: such features get scale 1,
    # so they are only centred instead of being divided by a ~0 std.
    constant = var <= n * eps * var + (n * mean * eps) ** 2
    scale = np.where(constant, 1.0, np.sqrt(var))
    return (arr - mean) / scale

In [6]:
# Source recordings: one .fif epochs file per cohort.
# NOTE(review): absolute, machine-specific paths, and the file names look like patient
# names. Consider a configurable DATA_DIR and de-identified file names.
PNES_FIF = r"D:\MNE Data\PNES\RostamiAlireza.fif"
TLE_FIF = r"D:\MNE Data\TLE\AlipoorMohamadHakim.fif"

# Open each file without preloading and extract only the EEG channels as an array.
PNES_data = mne.read_epochs(PNES_FIF, preload=False).get_data(picks='eeg')
TLE_data = mne.read_epochs(TLE_FIF, preload=False).get_data(picks='eeg')

Reading D:\MNE Data\PNES\RostamiAlireza.fif ...
    Found the data of interest:
        t =       0.00 ...    9996.67 ms
        0 CTF compensation matrices available
Not setting metadata
693 matching events found
No baseline correction applied
0 projection items activated
Loading data for 693 events and 3000 original time points ...
Reading D:\MNE Data\TLE\AlipoorMohamadHakim.fif ...
    Found the data of interest:
        t =       0.00 ...    9996.67 ms
        0 CTF compensation matrices available
Not setting metadata
650 matching events found
No baseline correction applied
0 projection items activated
Loading data for 650 events and 3000 original time points ...


In [7]:
# Standardise each epoch and move it to the selected device as float32.
# Using `device` (not `.cuda()`) keeps the notebook runnable on CPU-only machines.
# NOTE: this cell overwrites its own inputs, so re-running it without reloading the data fails.
PNES_data = torch.as_tensor(Normalize(PNES_data), dtype=torch.float32, device=device)
TLE_data = torch.as_tensor(Normalize(TLE_data), dtype=torch.float32, device=device)

In [8]:
# Class 0 = PNES, class 1 = TLE; one float label per epoch, on the same device as the data.
PNES_labels = torch.zeros(PNES_data.shape[0], device=device)
TLE_labels = torch.ones(TLE_data.shape[0], device=device)

In [9]:
# Stack both cohorts and add a singleton plane axis: (N, C, T) -> (N, 1, C, T).
# unsqueeze() never re-chunks the data, unlike a hard-coded reshape to (19, 3000).
data = torch.cat((PNES_data, TLE_data), dim=0).unsqueeze(1)
labels = torch.cat((PNES_labels, TLE_labels))
# Explicit num_classes so the encoding width does not depend on which labels are present.
labels = F.one_hot(labels.to(torch.int64), num_classes=2)

In [10]:
# Hold out 20% for testing. The fixed seed makes the split reproducible across kernel
# restarts; stratifying on the class index keeps the class ratio equal in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    data,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels.cpu().argmax(dim=1).numpy(),
)

In [11]:
class net(nn.Module):
    """Per-channel LSTM encoder + spatial convolution + MLP classifier.

    Each of the ``C`` channels is cut into ``T // input_size`` consecutive windows of
    ``input_size`` samples. That window sequence is run through an LSTM, and the output
    of the last step (``hidden_size`` features) is kept. The resulting
    ``(N, 1, C, hidden_size)`` map is mixed across channels by a ``(C, 1)`` convolution
    with ``spatial_num`` filters. It is then flattened and classified by a three-layer
    MLP followed by a softmax, so ``forward`` returns class probabilities.

    Args:
        T: time samples per channel; must be divisible by ``input_size``.
        C: number of channels.
        input_size: samples per LSTM step (window length).
        hidden_size: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        spatial_num: number of spatial convolution filters.
        dropout: dropout probability after the spatial convolution.
        pool: stored as ``self.pool``; not used by the model.
        num_classes: width of the output layer (default 2).
        shared_lstm: if True (default, the original behaviour) ONE LSTM instance is
            reused for every channel, so all channels share weights even though
            ``self.lstm`` lists ``C`` entries. If False, each channel gets its own LSTM.

    Raises:
        ValueError: if ``T`` is not a multiple of ``input_size``.

    NOTE(review): the model ends in a softmax. If it is trained with
    ``nn.CrossEntropyLoss`` (which expects raw logits), softmax is applied twice.
    Confirm the training loss.
    """

    def __init__(self, T, C, input_size, hidden_size, num_layers, spatial_num, dropout, pool,
                 num_classes=2, shared_lstm=True):
        super().__init__()
        if T % input_size != 0:
            raise ValueError(
                f"T ({T}) must be divisible by input_size ({input_size}) "
                "so every channel splits into whole LSTM windows")

        self.T = T
        self.C = C
        self.spatial_num = spatial_num
        self.dropout = dropout
        self.pool = pool  # unused; kept for interface compatibility

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.shared_lstm = shared_lstm

        self.cell_count = T // input_size           # LSTM steps per channel
        self.fcn_in = spatial_num * hidden_size     # flattened size after the spatial conv

        if shared_lstm:
            # Single LSTM reused for all channels (created first, as before, so the
            # parameter-initialisation order and state_dict keys are unchanged).
            self._lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
            self.lstm = nn.ModuleList([self._lstm] * C)
        else:
            self.lstm = nn.ModuleList(
                nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
                for _ in range(C))
            self._lstm = self.lstm[0]

        self.cnn_block = nn.Sequential(nn.Conv2d(1, spatial_num, (C, 1)),
                                       nn.BatchNorm2d(spatial_num),
                                       nn.ELU(),
                                       nn.Dropout(dropout))

        self.fcn = nn.Sequential(nn.Linear(self.fcn_in, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 16),
                                 nn.ReLU(),
                                 nn.Linear(16, num_classes))

        self.results = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape ``(N, num_classes)``.

        ``x`` must hold ``C * T`` values per batch element (e.g. shape ``(N, 1, C, T)``).
        It is reshaped to ``(N, C, T // input_size, input_size)``.
        """
        n = x.shape[0]
        self.N = n  # kept so code that reads ``model.N`` after a forward pass still works
        x = x.reshape(n, self.C, self.cell_count, self.input_size)

        # Run each channel's window sequence through its LSTM; keep the last step's output.
        last_steps = []
        for channel, lstm in enumerate(self.lstm):
            out, _ = lstm(x[:, channel])
            last_steps.append(out[:, -1, :])

        # (N, C, hidden) -> (N, 1, C, hidden): one input plane for the (C, 1) spatial conv.
        x = torch.stack(last_steps, dim=1).unsqueeze(1)
        x = self.cnn_block(x)       # (N, spatial_num, 1, hidden)
        x = x.reshape(n, -1)        # (N, spatial_num * hidden) == (N, fcn_in)
        x = self.fcn(x)
        return self.results(x)

In [98]:
# Shape smoke test on random input. eval() + no_grad() keep the random batch from
# updating BatchNorm running statistics and avoid building an autograd graph for 479 samples.
a = torch.rand(479, 1, 19, 3000).to(device)

model = net(T=3000, C=19, input_size=60, hidden_size=30, num_layers=4, spatial_num=10, dropout=0.2, pool=1).to(device)
model.eval()
with torch.no_grad():
    print(model(a).shape)
model.train()  # restore the default training mode for subsequent cells


torch.Size([479, 2])


In [99]:
# Show the module tree (nn.Module's repr lists every registered submodule).
print(repr(model))

net(
  (_lstm): LSTM(60, 30, num_layers=4, batch_first=True)
  (lstm): ModuleList(
    (0-18): 19 x LSTM(60, 30, num_layers=4, batch_first=True)
  )
  (cnn_block): Sequential(
    (0): Conv2d(1, 10, kernel_size=(19, 1), stride=(1, 1))
    (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Dropout(p=0.2, inplace=False)
  )
  (fcn): Sequential(
    (0): Linear(in_features=300, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=2, bias=True)
  )
  (results): Softmax(dim=1)
)


In [101]:
# Inference-mode check on a smaller batch: eval() disables dropout and makes BatchNorm use
# its running statistics; no_grad() skips autograd bookkeeping.
# NOTE: the model is left in eval mode after this cell.
model.eval()
with torch.no_grad():
    a = torch.rand(49, 1, 19, 3000).to(device)

    out_data = model(a)
    print(out_data.shape)

torch.Size([49, 2])
