In [1]:
import os
import warnings
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
#os.environ['CUDA_LAUNCH_BLOCKING'] = str(1)
#os.environ["TORCH_USE_CUDA_DSA"]= str(0)
warnings.filterwarnings('ignore') 


In [2]:
import copy
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import math
from collections import OrderedDict
import random
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset
import sys
import torch
import numpy as np
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mne
from sklearn.preprocessing import StandardScaler
import logging
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
import statistics
import torch.optim.lr_scheduler as lr_scheduler
from scipy.special import softmax
from sklearn.model_selection import StratifiedKFold

In [3]:
logging.getLogger('mne').setLevel(logging.WARNING)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import import_ipynb
#from Model import net
from Dataloader2 import EEGDataset

importing Jupyter notebook from Dataloader2.ipynb


In [4]:
def Accuracy(y_pred, y, train_count):
    max_values, _ = torch.max(y_pred, dim=1, keepdim=True)
    mask = y_pred == max_values
    y_pred = mask.int()
    correct_num = torch.sum(torch.all(torch.eq(y, y_pred), dim=1)).item()
    accuracy = correct_num / train_count
    return accuracy * 100

In [5]:
def test(model, test_path, test_class, verbose=True):
    

    x_test = mne.read_epochs(test_path, preload=False).get_data(picks='eeg');
    normals = []
    scaler = StandardScaler()
    for idx in range(len(x_test)):
        normals.append(scaler.fit_transform(x_test[idx]))
    normals = torch.tensor(normals).cuda().float()
    result = torch.argmax(model(normals), axis=1)
    unique_elements, counts = torch.unique(result, return_counts=True)

    votes = np.zeros([4])
    for i in range(len(unique_elements)):
        votes[unique_elements[i]] = counts[i]


    if(verbose):
        print(f"Test Accuracy: {(votes[test_class] / result.shape[0]) * 100}")
    return votes


In [6]:
class net(nn.Module):
    def __init__(self, T, C, input_size, hidden_size, num_layers, spatial_num, dropout, pool):
        super(net, self).__init__()
        
        self.T = T
        self.C = C
        self.spatial_num = spatial_num
        self.dropout = dropout
        self.pool = pool

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.cell_count = self.T // self.input_size

        self.fcn_in = (spatial_num * self.hidden_size)

        self._lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)

        self.lstm = nn.ModuleList([self._lstm for i in range(self.C)])

        self.cnn_block = nn.Sequential(nn.Conv2d(1, self.spatial_num, (self.C, 1)),
                                       nn.BatchNorm2d(self.spatial_num),
                                       nn.ELU(),



                                       nn.Dropout(self.dropout))

        
        self.fcn = nn.Sequential(nn.Linear(self.fcn_in, 128), 
                                 nn.ReLU(),
                                 nn.Linear(128, 16),
                                 nn.Dropout(self.dropout),
                                 nn.ReLU(),
                                 nn.Linear(16, 4))

        #self.fcn = nn.Linear(self.fcn_in, 4)
        self.results = nn.Softmax(dim=1)
    def forward(self, x):
        x = x.reshape(-1, 1, 19, 3000)
        self.N = x.shape[0]
        x = x.reshape(self.N, self.C, self.cell_count, self.input_size)
        _x = None

        for index, cell in enumerate(self.lstm):
            cell_out, _ = cell(x[:, index, :, :], None)
            last_layer_out = cell_out[:, -1, :]
            
            last_layer_out = last_layer_out.unsqueeze(0)
            if _x is None:
                _x = last_layer_out
            else:
                _x = torch.cat((_x, last_layer_out), dim=0)
            

        x = _x.permute(1, 0, 2).unsqueeze(1)

        x = self.cnn_block(x)


        x = x.reshape(self.N, -1)

        x = self.fcn(x)
        x = self.results(x)

        return x

In [7]:
MNE_Data = EEGDataset(root_dir=r"C:\Users\admin\Desktop\MNE Data")
labels = MNE_Data.labels
#MNE_Data = EEGDataset(root_dir=r"D:\TEST MNE")
test_path = r"C:\Users\admin\Desktop\TEST\amirifateme.fif"
test_class = 0

In [8]:
train_dataloader = DataLoader(MNE_Data, batch_size=1, shuffle=True)

In [9]:
model = net(T = 3000, C = 19, input_size = 300, hidden_size = 100, num_layers=1, spatial_num= 300, dropout=0.5, pool=1).to(device)

num_params = sum(p.numel() for p in model.parameters())
print("Number of parameters:", num_params)

Number of parameters: 4009660


In [10]:
criterion = nn.CrossEntropyLoss(weight = torch.Tensor([5.3125, 2.8333, 3.5417, 5.6667]).cuda())

optimizer = optim.Adam(model.parameters(), lr=0.001)
#optimizer = optim.Adagrad(model.parameters(), lr=0.001)


scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
epochs = 80
kf = StratifiedKFold(n_splits=5)

In [11]:
for train_indices, val_indices in kf.split(MNE_Data, labels):
    train_dataset = torch.utils.data.Subset(MNE_Data, train_indices)
    val_dataset = torch.utils.data.Subset(MNE_Data, val_indices)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)

    test_log = []
    log = []

    for epoch in trange(epochs):

        model.train()   
        for index, data in enumerate(train_dataloader):
                
                x, y = data
                y = y.to(torch.float64)
                x = x.reshape(-1, 1, 19, 3000).float()
                #x = x[torch.randperm(x.shape[0])]
                y = F.one_hot(torch.tensor(torch.tensor([y.item()]).to(torch.int64)), num_classes=4).expand(x.shape[0], -1).float()
                
                train_count = x.shape[0]
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                loss = criterion(y_pred, y)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #scheduler.step()
                torch.cuda.empty_cache()


        model.eval()
        val_loss = 0
        val_log = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(val_dataloader):
                        data = data.reshape(-1, 1, 19, 3000).float().cuda()
                        target = F.one_hot(torch.tensor(torch.tensor([target.item()]).to(torch.int64)), num_classes=4).expand(data.shape[0], -1).float().cuda()
                        

                        output = model(data)

                        

                        loss = criterion(output, target)
                        
                        val_loss += loss.item()

                        val_log.append(Accuracy(output, target, target.shape[0]))

                        
        print('Epoch: {}, Validation Loss: {:.4f}, Validation Accuracy: {:.2f}%'.format(
                    epoch, val_loss / len(val_dataloader), statistics.mean(val_log)))

        


    


  1%|▏         | 1/80 [00:44<58:23, 44.35s/it]

Epoch: 0, Validation Loss: 5.5194, Validation Accuracy: 17.74%


  2%|▎         | 2/80 [01:27<57:01, 43.87s/it]

Epoch: 1, Validation Loss: 5.5273, Validation Accuracy: 18.27%


  4%|▍         | 3/80 [02:11<56:17, 43.86s/it]

Epoch: 2, Validation Loss: 5.5213, Validation Accuracy: 18.08%


  5%|▌         | 4/80 [02:56<56:07, 44.31s/it]

Epoch: 3, Validation Loss: 5.5267, Validation Accuracy: 18.08%


  6%|▋         | 5/80 [03:41<55:25, 44.34s/it]

Epoch: 4, Validation Loss: 5.5205, Validation Accuracy: 19.33%


  8%|▊         | 6/80 [04:25<54:44, 44.39s/it]

Epoch: 5, Validation Loss: 5.5143, Validation Accuracy: 18.18%


  9%|▉         | 7/80 [05:10<54:10, 44.52s/it]

Epoch: 6, Validation Loss: 5.5238, Validation Accuracy: 19.40%


 10%|█         | 8/80 [05:55<53:29, 44.57s/it]

Epoch: 7, Validation Loss: 5.5088, Validation Accuracy: 24.17%


 11%|█▏        | 9/80 [06:40<52:58, 44.76s/it]

Epoch: 8, Validation Loss: 5.5265, Validation Accuracy: 22.72%


 12%|█▎        | 10/80 [07:24<52:09, 44.71s/it]

Epoch: 9, Validation Loss: 5.5118, Validation Accuracy: 28.03%


 14%|█▍        | 11/80 [08:09<51:20, 44.64s/it]

Epoch: 10, Validation Loss: 5.5300, Validation Accuracy: 27.80%


 15%|█▌        | 12/80 [08:54<50:44, 44.77s/it]

Epoch: 11, Validation Loss: 5.5207, Validation Accuracy: 28.62%


 16%|█▋        | 13/80 [09:39<50:06, 44.88s/it]

Epoch: 12, Validation Loss: 5.5364, Validation Accuracy: 26.27%


 18%|█▊        | 14/80 [10:24<49:28, 44.98s/it]

Epoch: 13, Validation Loss: 5.4431, Validation Accuracy: 32.03%


 19%|█▉        | 15/80 [11:10<48:57, 45.19s/it]

Epoch: 14, Validation Loss: 5.5144, Validation Accuracy: 30.50%


 20%|██        | 16/80 [11:55<48:10, 45.16s/it]

Epoch: 15, Validation Loss: 5.5139, Validation Accuracy: 29.96%


 21%|██▏       | 17/80 [12:41<47:37, 45.35s/it]

Epoch: 16, Validation Loss: 5.5134, Validation Accuracy: 28.88%


 22%|██▎       | 18/80 [13:26<46:56, 45.43s/it]

Epoch: 17, Validation Loss: 5.5511, Validation Accuracy: 26.51%


 24%|██▍       | 19/80 [14:12<46:08, 45.39s/it]

Epoch: 18, Validation Loss: 5.6279, Validation Accuracy: 30.20%


 25%|██▌       | 20/80 [14:57<45:25, 45.42s/it]

Epoch: 19, Validation Loss: 5.5638, Validation Accuracy: 28.20%


 26%|██▋       | 21/80 [15:42<44:37, 45.38s/it]

Epoch: 20, Validation Loss: 5.5678, Validation Accuracy: 29.18%


 28%|██▊       | 22/80 [16:28<43:50, 45.35s/it]

Epoch: 21, Validation Loss: 5.5725, Validation Accuracy: 32.16%


 29%|██▉       | 23/80 [17:14<43:11, 45.47s/it]

Epoch: 22, Validation Loss: 5.5489, Validation Accuracy: 31.91%


 30%|███       | 24/80 [17:59<42:21, 45.38s/it]

Epoch: 23, Validation Loss: 5.5423, Validation Accuracy: 31.48%


 31%|███▏      | 25/80 [18:45<41:50, 45.64s/it]

Epoch: 24, Validation Loss: 5.6251, Validation Accuracy: 31.26%


 32%|███▎      | 26/80 [19:31<41:07, 45.70s/it]

Epoch: 25, Validation Loss: 5.5264, Validation Accuracy: 32.77%


 34%|███▍      | 27/80 [20:16<40:20, 45.67s/it]

Epoch: 26, Validation Loss: 5.6226, Validation Accuracy: 31.42%


 35%|███▌      | 28/80 [21:02<39:27, 45.53s/it]

Epoch: 27, Validation Loss: 5.6095, Validation Accuracy: 33.81%


 36%|███▋      | 29/80 [21:47<38:39, 45.48s/it]

Epoch: 28, Validation Loss: 5.4866, Validation Accuracy: 33.87%


 38%|███▊      | 30/80 [22:33<37:56, 45.53s/it]

Epoch: 29, Validation Loss: 5.5684, Validation Accuracy: 33.86%


 39%|███▉      | 31/80 [23:18<37:09, 45.51s/it]

Epoch: 30, Validation Loss: 5.5871, Validation Accuracy: 31.47%


 40%|████      | 32/80 [24:04<36:26, 45.55s/it]

Epoch: 31, Validation Loss: 5.7402, Validation Accuracy: 30.76%


 41%|████▏     | 33/80 [24:50<35:47, 45.70s/it]

Epoch: 32, Validation Loss: 5.6862, Validation Accuracy: 31.13%


 42%|████▎     | 34/80 [25:35<35:01, 45.69s/it]

Epoch: 33, Validation Loss: 5.6624, Validation Accuracy: 32.10%


In [None]:
model.eval()
val_loss = 0
correct = 0
total = 0
epoch = 2
val_log = []
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(val_dataloader):
                data = data.reshape(-1, 1, 19, 3000).float().cuda()
                target = F.one_hot(torch.tensor(torch.tensor([target.item()]).to(torch.int64)), num_classes=4).expand(data.shape[0], -1).float().cuda()
                

                output = model(data)

                

                loss = criterion(output, target)
                
                val_loss += loss.item()

                val_log.append(Accuracy(output, target, target.shape[0]))

                """
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
                """
print('Epoch: {}, Validation Loss: {:.4f}, Validation Accuracy: {:.2f}%'.format(
            epoch, val_loss / len(val_dataloader), statistics.mean(val_log)))

Epoch: 2, Validation Loss: 5.9305, Validation Accuracy: 35.27%
