In [14]:
import os
import warnings
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
#os.environ['CUDA_LAUNCH_BLOCKING'] = str(1)
#os.environ["TORCH_USE_CUDA_DSA"]= str(0)
warnings.filterwarnings('ignore') 


In [15]:
import copy
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import math
from collections import OrderedDict
import random
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset
import sys
import torch
import numpy as np
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mne
from sklearn.preprocessing import StandardScaler
import logging
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
import statistics
import torch.optim.lr_scheduler as lr_scheduler
from scipy.special import softmax

In [16]:
logging.getLogger('mne').setLevel(logging.WARNING)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import import_ipynb
#from Model import net
from Dataloader2 import EEGDataset

In [17]:
def Accuracy(y_pred, y, train_count):
    max_values, _ = torch.max(y_pred, dim=1, keepdim=True)
    mask = y_pred == max_values
    y_pred = mask.int()
    correct_num = torch.sum(torch.all(torch.eq(y, y_pred), dim=1)).item()
    accuracy = correct_num / train_count
    return accuracy * 100

In [18]:
def test(model, test_path, test_class, verbose=True):
    

    x_test = mne.read_epochs(test_path, preload=False).get_data(picks='eeg');
    normals = []
    scaler = StandardScaler()
    for idx in range(len(x_test)):
        normals.append(scaler.fit_transform(x_test[idx]))
    normals = torch.tensor(normals).cuda().float()
    result = torch.argmax(model(normals), axis=1)
    unique_elements, counts = torch.unique(result, return_counts=True)

    votes = np.zeros([4])
    for i in range(len(unique_elements)):
        votes[unique_elements[i]] = counts[i]


    if(verbose):
        print(f"Test Accuracy: {(votes[test_class] / result.shape[0]) * 100}")
    return votes


In [19]:
class net(nn.Module):
    def __init__(self, T, C, input_size, hidden_size, num_layers, spatial_num, dropout, pool):
        super(net, self).__init__()
        
        self.T = T
        self.C = C
        self.spatial_num = spatial_num
        self.dropout = dropout
        self.pool = pool

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.cell_count = self.T // self.input_size

        self.fcn_in = (spatial_num * self.hidden_size)

        self._lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)

        self.lstm = nn.ModuleList([self._lstm for i in range(self.C)])

        self.cnn_block = nn.Sequential(nn.Conv2d(1, self.spatial_num, (self.C, 1)),
                                       nn.BatchNorm2d(self.spatial_num),
                                       nn.ELU(),



                                       nn.Dropout(self.dropout))

        
        self.fcn = nn.Sequential(nn.Linear(self.fcn_in, 128), 
                                 nn.ReLU(),
                                 nn.Linear(128, 16),
                                 nn.Dropout(self.dropout),
                                 nn.ReLU(),
                                 nn.Linear(16, 4))

        #self.fcn = nn.Linear(self.fcn_in, 4)
        self.results = nn.Softmax(dim=1)
    def forward(self, x):
        x = x.reshape(-1, 1, 19, 3000)
        self.N = x.shape[0]
        x = x.reshape(self.N, self.C, self.cell_count, self.input_size)
        _x = None

        for index, cell in enumerate(self.lstm):
            cell_out, _ = cell(x[:, index, :, :], None)
            last_layer_out = cell_out[:, -1, :]
            
            last_layer_out = last_layer_out.unsqueeze(0)
            if _x is None:
                _x = last_layer_out
            else:
                _x = torch.cat((_x, last_layer_out), dim=0)
            

        x = _x.permute(1, 0, 2).unsqueeze(1)

        x = self.cnn_block(x)


        x = x.reshape(self.N, -1)

        x = self.fcn(x)
        x = self.results(x)

        return x

In [20]:
MNE_Data = EEGDataset(root_dir=r"C:\Users\admin\Desktop\MNE Data")
#MNE_Data = EEGDataset(root_dir=r"D:\TEST MNE")
test_path = r"C:\Users\admin\Desktop\TEST\amirifateme.fif"
test_class = 0

In [21]:
batch_size = 1
train_dataloader = DataLoader(MNE_Data, batch_size=batch_size, shuffle=True)

In [22]:
model = net(T = 3000, C = 19, input_size = 300, hidden_size = 100, num_layers=1, spatial_num= 30, dropout=0.5, pool=1).to(device)

In [23]:
criterion = nn.CrossEntropyLoss(weight = torch.Tensor([5.3125, 2.8333, 3.5417, 5.6667]).cuda())

optimizer = optim.Adam(model.parameters(), lr=0.001)
#optimizer = optim.Adagrad(model.parameters(), lr=0.001)


scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
epochs = 80

In [24]:
test_log = []
log = []
for epoch in trange(epochs):
        model.train()
        running_loss = 0.0
        correct_num = 0
        
        

        for index, data in enumerate(train_dataloader):
            
            x, y = data
            y = y.to(torch.float64)
            x = x.reshape(-1, 1, 19, 3000).float()
            x = x[torch.randperm(x.shape[0])]
            y = F.one_hot(torch.tensor(torch.tensor([y.item()]).to(torch.int64)), num_classes=4).expand(x.shape[0], -1).float()
            
            train_count = x.shape[0]
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #scheduler.step()
            log.append(Accuracy(y_pred, y, train_count))
            torch.cuda.empty_cache()
        
        votes = test(model, test_path, test_class)
        test_log.append(votes)
        print(f"mean train accuracy across epoch: {statistics.mean(log)}")
        
        #print(f"Train Accuracy:{correct_num / train_count}")
        """
        """

  1%|▏         | 1/80 [00:46<1:01:30, 46.72s/it]

Test Accuracy: 29.346092503987244
mean train accuracy across epoch: 29.777560828869706


  2%|▎         | 2/80 [01:34<1:01:25, 47.25s/it]

Test Accuracy: 23.923444976076556
mean train accuracy across epoch: 28.29943812061339


  4%|▍         | 3/80 [02:21<1:00:46, 47.36s/it]

Test Accuracy: 38.75598086124402
mean train accuracy across epoch: 28.364200325229696


  5%|▌         | 4/80 [03:11<1:01:06, 48.24s/it]

Test Accuracy: 23.923444976076556
mean train accuracy across epoch: 28.53071286655868


  6%|▋         | 5/80 [04:19<1:09:13, 55.38s/it]

Test Accuracy: 21.371610845295056
mean train accuracy across epoch: 28.820400253078013


  8%|▊         | 6/80 [05:07<1:05:21, 53.00s/it]

Test Accuracy: 56.45933014354066
mean train accuracy across epoch: 29.42675819513106


  9%|▉         | 7/80 [05:54<1:01:58, 50.94s/it]

Test Accuracy: 19.457735247208934
mean train accuracy across epoch: 29.903306562151194


 10%|█         | 8/80 [06:41<59:38, 49.70s/it]  

Test Accuracy: 30.303030303030305
mean train accuracy across epoch: 30.774760677836678


 11%|█▏        | 9/80 [07:28<57:48, 48.86s/it]

Test Accuracy: 24.720893141945773
mean train accuracy across epoch: 31.663214606278537


 12%|█▎        | 10/80 [08:15<56:15, 48.22s/it]

Test Accuracy: 13.237639553429027
mean train accuracy across epoch: 32.31637047306193


 14%|█▍        | 11/80 [09:01<54:52, 47.72s/it]

Test Accuracy: 13.716108452950559
mean train accuracy across epoch: 33.25371741424269


 15%|█▌        | 12/80 [09:48<53:37, 47.31s/it]

Test Accuracy: 17.384370015948964
mean train accuracy across epoch: 34.23390466551838


 16%|█▋        | 13/80 [10:45<56:10, 50.31s/it]

Test Accuracy: 8.293460925039872
mean train accuracy across epoch: 35.27303946737012


 18%|█▊        | 14/80 [11:32<54:05, 49.18s/it]

Test Accuracy: 8.771929824561402
mean train accuracy across epoch: 36.430282166004744


 19%|█▉        | 15/80 [12:19<52:34, 48.53s/it]

Test Accuracy: 15.47049441786284
mean train accuracy across epoch: 37.38801851301276


In [None]:
test_log = np.array(test_log)
test_log[np.argmax(test_log[:, 0])] / test_log[np.argmax(test_log[:, 0])].sum()


array([0.46808511, 0.39209726, 0.10030395, 0.03951368])

In [None]:
test_log

array([[  0., 192., 123.,  14.],
       [  0.,  99., 179.,  51.],
       [  0., 129., 134.,  66.],
       [ 26.,  97., 152.,  54.],
       [  4., 146., 133.,  46.],
       [109.,  87., 107.,  26.],
       [ 78., 119., 103.,  29.],
       [ 33.,  79., 113., 104.],
       [112.,  99.,  84.,  34.],
       [ 73., 113.,  57.,  86.],
       [ 58., 129.,  67.,  75.],
       [ 49., 128.,  71.,  81.],
       [104., 151.,  30.,  44.],
       [ 81., 206.,  23.,  19.],
       [ 73., 208.,  42.,   6.],
       [ 73., 143.,  14.,  99.],
       [ 43., 195.,  69.,  22.],
       [ 63., 214.,  33.,  19.],
       [ 83., 221.,  22.,   3.],
       [154., 129.,  33.,  13.],
       [104., 154.,  33.,  38.],
       [ 69., 225.,  30.,   5.],
       [104., 184.,  33.,   8.],
       [ 93., 145.,  56.,  35.],
       [ 73., 197.,  55.,   4.],
       [ 74., 147.,  81.,  27.],
       [ 26., 175., 118.,  10.],
       [ 61., 219.,  43.,   6.],
       [ 59., 162., 102.,   6.],
       [ 54., 219.,  46.,  10.],
       [ 4

In [None]:
"""
for epoch in trange(epochs):
        model.train()
        running_loss = 0.0
        correct_num = 0
        log = []
        for index, data in enumerate(train_dataloader):
            
            x, y = data
            y = y.to(torch.float64)
            print(y)
            
            x = x.reshape(-1, 1, 19, 3000).float()
            y = torch.tensor([y.item()]).to(torch.int64).expand(x.shape[0], -1).float()
            
            
            train_count = x.shape[0]
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            print(y[-1])
            print(Accuracy(y_pred, y, train_count))
            log.append(Accuracy(y_pred, y, train_count))
            torch.cuda.empty_cache()
            
        
        """
        #print(f"mean accuracy across epoch: {statistics.mean(log)}")
        
        #print(f"Train Accuracy:{correct_num / train_count}")

'\nfor epoch in trange(epochs):\n        model.train()\n        running_loss = 0.0\n        correct_num = 0\n        log = []\n        for index, data in enumerate(train_dataloader):\n            \n            x, y = data\n            y = y.to(torch.float64)\n            print(y)\n            \n            x = x.reshape(-1, 1, 19, 3000).float()\n            y = torch.tensor([y.item()]).to(torch.int64).expand(x.shape[0], -1).float()\n            \n            \n            train_count = x.shape[0]\n            x, y = x.to(device), y.to(device)\n            y_pred = model(x)\n            loss = criterion(y_pred, y)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            scheduler.step()\n            print(y[-1])\n            print(Accuracy(y_pred, y, train_count))\n            log.append(Accuracy(y_pred, y, train_count))\n            torch.cuda.empty_cache()\n            \n        \n        '

In [None]:
file_path = 'D:\model weights\model_weights.pth'
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))

<All keys matched successfully>

In [None]:
"""
y = F.one_hot(torch.randint(low=0, high=4, size=(1,)), 4).float()
yhat = torch.rand((1, 4)).float()
criterion(y, yhat)
"""

'\ny = F.one_hot(torch.randint(low=0, high=4, size=(1,)), 4).float()\nyhat = torch.rand((1, 4)).float()\ncriterion(y, yhat)\n'

In [None]:
"""
__ = (torch.round(y_pred).to(torch.int64) == y)
preds = np.all(__.cpu().numpy(), axis=1)
correct_num += np.count_nonzero(preds)
"""

'\n__ = (torch.round(y_pred).to(torch.int64) == y)\npreds = np.all(__.cpu().numpy(), axis=1)\ncorrect_num += np.count_nonzero(preds)\n'