In [1]:
import os
import warnings
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
#os.environ['CUDA_LAUNCH_BLOCKING'] = str(1)
#os.environ["TORCH_USE_CUDA_DSA"]= str(0)
warnings.filterwarnings('ignore') 


In [2]:
import copy
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import math
from collections import OrderedDict
import random
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset
import sys
import torch
import numpy as np
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mne
from sklearn.preprocessing import StandardScaler
import logging
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
import statistics
import torch.optim.lr_scheduler as lr_scheduler
from scipy.special import softmax
from sklearn.model_selection import StratifiedKFold

In [3]:
logging.getLogger('mne').setLevel(logging.WARNING)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import import_ipynb
#from Model import net
from Dataloader2 import EEGDataset

importing Jupyter notebook from Dataloader2.ipynb


In [4]:
def Accuracy(y_pred, y, train_count):
    max_values, _ = torch.max(y_pred, dim=1, keepdim=True)
    mask = y_pred == max_values
    y_pred = mask.int()
    correct_num = torch.sum(torch.all(torch.eq(y, y_pred), dim=1)).item()
    accuracy = correct_num / train_count
    return accuracy * 100

In [5]:
def test(model, test_path, test_class, verbose=True):
    

    x_test = mne.read_epochs(test_path, preload=False).get_data(picks='eeg');
    normals = []
    scaler = StandardScaler()
    for idx in range(len(x_test)):
        normals.append(scaler.fit_transform(x_test[idx]))
    normals = torch.tensor(normals).cuda().float()
    out = model(normals)
    result = torch.argmax(out, axis=1)
    unique_elements, counts = torch.unique(result, return_counts=True)

    votes = np.zeros([4])
    for i in range(len(unique_elements)):
        votes[unique_elements[i]] = counts[i]


    if(verbose):
        print(f"Test Accuracy: {(votes[test_class] / result.shape[0]) * 100}")
    return votes, out


In [6]:
class net(nn.Module):
    def __init__(self, T, C, input_size, hidden_size, num_layers, spatial_num, dropout, pool):
        super(net, self).__init__()
        
        self.T = T
        self.C = C
        self.spatial_num = spatial_num
        self.dropout = dropout
        self.pool = pool

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.cell_count = self.T // self.input_size

        self.fcn_in = (spatial_num * self.hidden_size)

        self._lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)

        self.lstm = nn.ModuleList([self._lstm for i in range(self.C)])

        self.cnn_block = nn.Sequential(nn.Conv2d(1, self.spatial_num, (self.C, 1)),
                                       nn.BatchNorm2d(self.spatial_num),
                                       nn.ELU(),



                                       nn.Dropout(self.dropout))

        
        self.fcn = nn.Sequential(nn.Linear(self.fcn_in, 128), 
                                 nn.ReLU(),
                                 nn.Linear(128, 16),
                                 nn.Dropout(self.dropout),
                                 nn.ReLU(),
                                 nn.Linear(16, 4))

        #self.fcn = nn.Linear(self.fcn_in, 4)
        self.results = nn.Softmax(dim=1)
    def forward(self, x):
        x = x.reshape(-1, 1, 19, 3000)
        self.N = x.shape[0]
        x = x.reshape(self.N, self.C, self.cell_count, self.input_size)
        _x = None

        for index, cell in enumerate(self.lstm):
            cell_out, _ = cell(x[:, index, :, :], None)
            last_layer_out = cell_out[:, -1, :]
            
            last_layer_out = last_layer_out.unsqueeze(0)
            if _x is None:
                _x = last_layer_out
            else:
                _x = torch.cat((_x, last_layer_out), dim=0)
            

        x = _x.permute(1, 0, 2).unsqueeze(1)

        x = self.cnn_block(x)


        x = x.reshape(self.N, -1)

        x = self.fcn(x)
        x = self.results(x)

        return x

In [7]:
class CNN_LSTM(nn.Module):
    def __init__(self, num_classes, hidden_size):
        super(CNN_LSTM, self).__init__()
        
        self.num_classes = num_classes
        self.hidden_size =hidden_size
        self.dropout = 0.5
        # CNN layers
        self.conv1 = nn.Conv2d(19, 8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(8, 4, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2)
        
        # LSTM layers
        self.lstm = nn.LSTM(3000, self.hidden_size, num_layers=1, batch_first=True)
        
        # Fully connected layer
        self.fc = nn.Sequential(nn.Linear(self.hidden_size , 16),
                                 nn.ReLU(),
                                 nn.Linear(16, self.num_classes))
        

        self.results = nn.Softmax(dim=1)
        
        
    def forward(self, x):
        x = x.reshape(-1, 19, 3000)
        self.N = x.shape[0]

        x = x.reshape(-1, 19, 1, 3000)
        # Apply CNN
        
        x = F.relu(self.conv1(x))
        #x = self.pool(x)
        xx = x
        
        x = F.relu(self.conv2(x))
        #x = self.pool(x)
        
        x = F.relu(self.conv3(x))
        # Reshape for LSTM
        
        #x = x.transpose(1, 2)
        
        # Apply LSTM
        
        x = x.reshape(-1, 1, 3000)
        
        h_n, (__, _) = self.lstm(x)

        x = h_n[:, :, :]
        
        
        # Apply fully connected layer
        x = x.reshape(self.N, -1)
        x = self.fc(x)
        x = self.results(x)
        """
        """
        return x
    
model = CNN_LSTM(num_classes=4, hidden_size=100).cuda()
model(torch.rand(7, 19, 3000).cuda()).shape

torch.Size([7, 4])

In [8]:
MNE_Data = EEGDataset(root_dir=r"C:\Users\admin\Desktop\MNE_Data")
labels = MNE_Data.labels
#MNE_Data = EEGDataset(root_dir=r"D:\TEST MNE")
test_path = r"C:\Users\admin\Desktop\TEST\amirifateme.fif"
test_class = 0

In [9]:
train_dataloader = DataLoader(MNE_Data, batch_size=1, shuffle=True)

In [10]:
config = net(T = 3000, C = 19, input_size = 3000, hidden_size = 30, num_layers=1, spatial_num= 300, dropout=0.5, pool=1).to(device)
#config = CNN_LSTM(num_classes=4, hidden_size=30).cuda()
num_params = sum(p.numel() for p in model.parameters())
print("Number of parameters:", num_params)

Number of parameters: 1244189


In [11]:
#criterion = nn.CrossEntropyLoss(weight = torch.Tensor([5.3125, 1.8333, 3.5417, 6.6667]).cuda())
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)
#optimizer = optim.Adagrad(model.parameters(), lr=0.001)


scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
epochs = 60
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
model_list = []

In [12]:
for train_indices, val_indices in kf.split(MNE_Data, labels):
    train_dataset = torch.utils.data.Subset(MNE_Data, train_indices)
    val_dataset = torch.utils.data.Subset(MNE_Data, val_indices)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)

    test_log = []
    log = []


    model = config
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in trange(epochs):
        
        model.train()   
        for index, data in enumerate(train_dataloader):
                
                x, y = data
                y = y.to(torch.float64)
                x = x.reshape(-1, 1, 19, 3000).float()
                x = x[torch.randperm(x.shape[0])]
                y = F.one_hot(torch.tensor(torch.tensor([y.item()]).to(torch.int64)), num_classes=4).expand(x.shape[0], -1).float()
                
                train_count = x.shape[0]
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                loss = criterion(y_pred, y)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #scheduler.step()
                torch.cuda.empty_cache()

    
        model.eval()
        val_loss = 0
        val_log = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(val_dataloader):
                        data = data.reshape(-1, 1, 19, 3000).float().cuda()
                        target = F.one_hot(torch.tensor(torch.tensor([target.item()]).to(torch.int64)), num_classes=4).expand(data.shape[0], -1).float().cuda()
                        

                        output = model(data)

                        

                        loss = criterion(output, target)
                        
                        val_loss += loss.item()

                        val_log.append(Accuracy(output, target, target.shape[0]))

        print(test(test_path=r"C:\Users\admin\Desktop\TEST\IGE\borzoimohamad.fif", model=model, test_class= 1)[0])     
        print('Epoch: {}, Validation Loss: {:.4f}, Validation Accuracy: {:.2f}%'.format(
                    epoch, val_loss / len(val_dataloader), statistics.mean(val_log)))

    model_list.append(model)
        


    


  2%|▏         | 1/60 [00:57<56:07, 57.07s/it]

Test Accuracy: 0.0
[  0.   0. 361.  10.]
Epoch: 0, Validation Loss: 1.3892, Validation Accuracy: 22.24%


  3%|▎         | 2/60 [01:53<54:56, 56.84s/it]

Test Accuracy: 0.0
[  0.   0. 364.   7.]
Epoch: 1, Validation Loss: 1.3892, Validation Accuracy: 22.03%


  5%|▌         | 3/60 [02:51<54:13, 57.09s/it]

Test Accuracy: 0.0
[  0.   0. 295.  76.]
Epoch: 2, Validation Loss: 1.3873, Validation Accuracy: 23.30%


  7%|▋         | 4/60 [03:49<53:40, 57.52s/it]

Test Accuracy: 0.0
[ 13.   0. 258. 100.]
Epoch: 3, Validation Loss: 1.3823, Validation Accuracy: 25.72%


  8%|▊         | 5/60 [04:47<52:53, 57.70s/it]

Test Accuracy: 0.0
[  8.   0. 270.  93.]
Epoch: 4, Validation Loss: 1.3800, Validation Accuracy: 25.99%


 10%|█         | 6/60 [05:45<52:07, 57.92s/it]

Test Accuracy: 0.0
[115.   0. 220.  36.]
Epoch: 5, Validation Loss: 1.3696, Validation Accuracy: 30.20%


 12%|█▏        | 7/60 [06:45<51:34, 58.39s/it]

Test Accuracy: 0.0
[ 31.   0. 303.  37.]
Epoch: 6, Validation Loss: 1.3698, Validation Accuracy: 29.77%


 13%|█▎        | 8/60 [07:44<50:46, 58.59s/it]

Test Accuracy: 0.0
[ 46.   0. 276.  49.]
Epoch: 7, Validation Loss: 1.3633, Validation Accuracy: 31.19%


 15%|█▌        | 9/60 [08:43<49:56, 58.75s/it]

Test Accuracy: 0.0
[ 27.   0. 147. 197.]
Epoch: 8, Validation Loss: 1.3642, Validation Accuracy: 32.67%


 17%|█▋        | 10/60 [09:43<49:19, 59.19s/it]

Test Accuracy: 0.0
[ 19.   0. 276.  76.]
Epoch: 9, Validation Loss: 1.3749, Validation Accuracy: 32.34%


 18%|█▊        | 11/60 [10:42<48:24, 59.27s/it]

Test Accuracy: 0.0
[ 12.   0. 274.  85.]
Epoch: 10, Validation Loss: 1.3763, Validation Accuracy: 33.28%


 20%|██        | 12/60 [11:42<47:31, 59.41s/it]

Test Accuracy: 0.0
[ 58.   0. 190. 123.]
Epoch: 11, Validation Loss: 1.3560, Validation Accuracy: 35.47%


 22%|██▏       | 13/60 [12:42<46:37, 59.53s/it]

Test Accuracy: 0.0
[ 26.   0. 170. 175.]
Epoch: 12, Validation Loss: 1.3660, Validation Accuracy: 34.85%


 23%|██▎       | 14/60 [13:47<46:50, 61.10s/it]

Test Accuracy: 0.0
[ 22.   0. 283.  66.]
Epoch: 13, Validation Loss: 1.3757, Validation Accuracy: 33.39%


 25%|██▌       | 15/60 [14:47<45:35, 60.80s/it]

Test Accuracy: 0.2695417789757413
[ 27.   1. 274.  69.]
Epoch: 14, Validation Loss: 1.3662, Validation Accuracy: 34.48%


 27%|██▋       | 16/60 [15:47<44:29, 60.66s/it]

Test Accuracy: 4.8517520215633425
[ 30.  18. 218. 105.]
Epoch: 15, Validation Loss: 1.3452, Validation Accuracy: 36.63%


 28%|██▊       | 17/60 [16:47<43:23, 60.54s/it]

Test Accuracy: 4.5822102425876015
[ 32.  17. 228.  94.]
Epoch: 16, Validation Loss: 1.3437, Validation Accuracy: 36.93%


 30%|███       | 18/60 [17:47<42:17, 60.43s/it]

Test Accuracy: 8.89487870619946
[ 26.  33. 245.  67.]
Epoch: 17, Validation Loss: 1.3456, Validation Accuracy: 36.39%


 32%|███▏      | 19/60 [18:47<41:12, 60.31s/it]

Test Accuracy: 10.781671159029651
[ 26.  40. 215.  90.]
Epoch: 18, Validation Loss: 1.3416, Validation Accuracy: 37.52%


 33%|███▎      | 20/60 [19:47<40:08, 60.22s/it]

Test Accuracy: 9.973045822102426
[ 17.  37. 245.  72.]
Epoch: 19, Validation Loss: 1.3497, Validation Accuracy: 37.02%


 35%|███▌      | 21/60 [20:49<39:18, 60.47s/it]

Test Accuracy: 16.711590296495956
[ 20.  62. 180. 109.]
Epoch: 20, Validation Loss: 1.3425, Validation Accuracy: 37.62%


 37%|███▋      | 22/60 [21:49<38:18, 60.49s/it]

Test Accuracy: 27.49326145552561
[ 29. 102. 174.  66.]
Epoch: 21, Validation Loss: 1.3315, Validation Accuracy: 38.63%


 38%|███▊      | 23/60 [22:49<37:13, 60.37s/it]

Test Accuracy: 28.30188679245283
[ 23. 105. 153.  90.]
Epoch: 22, Validation Loss: 1.3295, Validation Accuracy: 38.77%


 40%|████      | 24/60 [23:49<36:13, 60.36s/it]

Test Accuracy: 26.954177897574123
[ 29. 100. 151.  91.]
Epoch: 23, Validation Loss: 1.3334, Validation Accuracy: 38.97%


 42%|████▏     | 25/60 [24:50<35:18, 60.52s/it]

Test Accuracy: 19.67654986522911
[ 35.  73. 182.  81.]
Epoch: 24, Validation Loss: 1.3340, Validation Accuracy: 38.95%


 43%|████▎     | 26/60 [25:51<34:16, 60.50s/it]

Test Accuracy: 15.902964959568733
[ 31.  59. 195.  86.]
Epoch: 25, Validation Loss: 1.3332, Validation Accuracy: 39.21%


 45%|████▌     | 27/60 [26:52<33:23, 60.71s/it]

Test Accuracy: 15.09433962264151
[ 32.  56. 221.  62.]
Epoch: 26, Validation Loss: 1.3366, Validation Accuracy: 38.99%


 47%|████▋     | 28/60 [27:53<32:20, 60.65s/it]

Test Accuracy: 30.18867924528302
[ 33. 112. 163.  63.]
Epoch: 27, Validation Loss: 1.3273, Validation Accuracy: 39.85%


 48%|████▊     | 29/60 [28:53<31:21, 60.68s/it]

Test Accuracy: 29.380053908355798
[ 33. 109. 150.  79.]
Epoch: 28, Validation Loss: 1.3270, Validation Accuracy: 40.11%


 50%|█████     | 30/60 [29:54<30:22, 60.75s/it]

Test Accuracy: 41.77897574123989
[ 33. 155.  93.  90.]
Epoch: 29, Validation Loss: 1.3221, Validation Accuracy: 40.57%


 52%|█████▏    | 31/60 [30:54<29:16, 60.58s/it]

Test Accuracy: 34.77088948787062
[ 39. 129. 141.  62.]
Epoch: 30, Validation Loss: 1.3234, Validation Accuracy: 40.44%


 53%|█████▎    | 32/60 [31:54<28:09, 60.33s/it]

Test Accuracy: 33.9622641509434
[ 21. 126. 182.  42.]
Epoch: 31, Validation Loss: 1.3385, Validation Accuracy: 38.88%


 55%|█████▌    | 33/60 [32:53<26:56, 59.88s/it]

Test Accuracy: 39.35309973045822
[ 29. 146. 127.  69.]
Epoch: 32, Validation Loss: 1.3269, Validation Accuracy: 39.91%


 57%|█████▋    | 34/60 [33:52<25:52, 59.73s/it]

Test Accuracy: 32.34501347708895
[ 38. 120. 126.  87.]
Epoch: 33, Validation Loss: 1.3245, Validation Accuracy: 40.35%


 58%|█████▊    | 35/60 [34:51<24:48, 59.55s/it]

Test Accuracy: 27.49326145552561
[ 41. 102. 155.  73.]
Epoch: 34, Validation Loss: 1.3243, Validation Accuracy: 40.35%


 60%|██████    | 36/60 [35:51<23:48, 59.51s/it]

Test Accuracy: 26.41509433962264
[ 24.  98. 150.  99.]
Epoch: 35, Validation Loss: 1.3326, Validation Accuracy: 39.61%


 62%|██████▏   | 37/60 [36:51<22:53, 59.71s/it]

Test Accuracy: 43.66576819407008
[ 20. 162. 136.  53.]
Epoch: 36, Validation Loss: 1.3311, Validation Accuracy: 39.94%


 63%|██████▎   | 38/60 [37:49<21:42, 59.19s/it]

Test Accuracy: 35.309973045822105
[ 31. 131. 122.  87.]
Epoch: 37, Validation Loss: 1.3186, Validation Accuracy: 41.15%


 65%|██████▌   | 39/60 [38:47<20:32, 58.69s/it]

Test Accuracy: 32.34501347708895
[ 23. 120. 157.  71.]
Epoch: 38, Validation Loss: 1.3204, Validation Accuracy: 41.10%


 67%|██████▋   | 40/60 [39:44<19:28, 58.42s/it]

Test Accuracy: 33.692722371967655
[ 17. 125. 151.  78.]
Epoch: 39, Validation Loss: 1.3247, Validation Accuracy: 40.54%


 68%|██████▊   | 41/60 [40:42<18:26, 58.23s/it]

Test Accuracy: 29.110512129380055
[ 30. 108. 162.  71.]
Epoch: 40, Validation Loss: 1.3233, Validation Accuracy: 40.71%


 70%|███████   | 42/60 [41:40<17:25, 58.06s/it]

Test Accuracy: 43.66576819407008
[ 25. 162. 105.  79.]
Epoch: 41, Validation Loss: 1.3221, Validation Accuracy: 40.91%


 72%|███████▏  | 43/60 [42:37<16:22, 57.77s/it]

Test Accuracy: 31.805929919137466
[ 35. 118. 149.  69.]
Epoch: 42, Validation Loss: 1.3242, Validation Accuracy: 40.59%


 73%|███████▎  | 44/60 [43:34<15:20, 57.55s/it]

Test Accuracy: 24.528301886792452
[ 23.  91. 155. 102.]
Epoch: 43, Validation Loss: 1.3275, Validation Accuracy: 40.53%


 75%|███████▌  | 45/60 [44:31<14:22, 57.51s/it]

Test Accuracy: 21.563342318059302
[ 32.  80. 177.  82.]
Epoch: 44, Validation Loss: 1.3274, Validation Accuracy: 40.36%


 77%|███████▋  | 46/60 [45:29<13:23, 57.40s/it]

Test Accuracy: 27.762803234501348
[ 34. 103. 140.  94.]
Epoch: 45, Validation Loss: 1.3151, Validation Accuracy: 41.71%


 78%|███████▊  | 47/60 [46:26<12:25, 57.35s/it]

Test Accuracy: 34.23180592991914
[ 25. 127. 147.  72.]
Epoch: 46, Validation Loss: 1.3248, Validation Accuracy: 40.59%


 80%|████████  | 48/60 [47:23<11:26, 57.23s/it]

Test Accuracy: 33.9622641509434
[ 26. 126. 144.  75.]
Epoch: 47, Validation Loss: 1.3222, Validation Accuracy: 41.04%


 82%|████████▏ | 49/60 [48:20<10:29, 57.26s/it]

Test Accuracy: 41.23989218328841
[ 27. 153. 126.  65.]
Epoch: 48, Validation Loss: 1.3208, Validation Accuracy: 41.17%


 83%|████████▎ | 50/60 [49:17<09:32, 57.26s/it]

Test Accuracy: 37.46630727762803
[ 31. 139. 128.  73.]
Epoch: 49, Validation Loss: 1.3177, Validation Accuracy: 41.49%


 85%|████████▌ | 51/60 [50:14<08:35, 57.24s/it]

Test Accuracy: 38.00539083557952
[ 28. 141. 127.  75.]
Epoch: 50, Validation Loss: 1.3113, Validation Accuracy: 41.94%


 87%|████████▋ | 52/60 [51:11<07:37, 57.15s/it]

Test Accuracy: 43.126684636118604
[ 30. 160. 126.  55.]
Epoch: 51, Validation Loss: 1.3173, Validation Accuracy: 41.55%


 88%|████████▊ | 53/60 [52:08<06:39, 57.04s/it]

Test Accuracy: 38.274932614555254
[ 23. 142. 130.  76.]
Epoch: 52, Validation Loss: 1.3179, Validation Accuracy: 41.37%


 90%|█████████ | 54/60 [53:06<05:42, 57.15s/it]

Test Accuracy: 35.84905660377358
[ 24. 133. 145.  69.]
Epoch: 53, Validation Loss: 1.3191, Validation Accuracy: 41.34%


 92%|█████████▏| 55/60 [54:03<04:45, 57.11s/it]

Test Accuracy: 34.23180592991914
[ 28. 127. 121.  95.]
Epoch: 54, Validation Loss: 1.3152, Validation Accuracy: 41.83%


 93%|█████████▎| 56/60 [55:00<03:49, 57.26s/it]

Test Accuracy: 30.727762803234505
[ 30. 114. 157.  70.]
Epoch: 55, Validation Loss: 1.3197, Validation Accuracy: 41.29%


 95%|█████████▌| 57/60 [55:57<02:51, 57.27s/it]

Test Accuracy: 46.36118598382749
[ 27. 172. 119.  53.]
Epoch: 56, Validation Loss: 1.3138, Validation Accuracy: 41.83%


 97%|█████████▋| 58/60 [56:54<01:54, 57.03s/it]

Test Accuracy: 44.20485175202156
[ 22. 164. 128.  57.]
Epoch: 57, Validation Loss: 1.3193, Validation Accuracy: 41.40%


 98%|█████████▊| 59/60 [57:51<00:57, 57.10s/it]

Test Accuracy: 36.65768194070081
[ 22. 136. 137.  76.]
Epoch: 58, Validation Loss: 1.3153, Validation Accuracy: 41.58%


100%|██████████| 60/60 [58:48<00:00, 58.81s/it]


Test Accuracy: 28.57142857142857
[ 26. 106. 175.  64.]
Epoch: 59, Validation Loss: 1.3229, Validation Accuracy: 40.97%


  2%|▏         | 1/60 [00:57<56:33, 57.52s/it]

Test Accuracy: 43.39622641509434
[ 27. 161. 117.  66.]
Epoch: 0, Validation Loss: 0.9053, Validation Accuracy: 83.74%


  3%|▎         | 2/60 [01:54<55:14, 57.15s/it]

Test Accuracy: 30.45822102425876
[ 22. 113. 156.  80.]
Epoch: 1, Validation Loss: 0.9144, Validation Accuracy: 82.73%


  5%|▌         | 3/60 [02:52<54:29, 57.36s/it]

Test Accuracy: 38.544474393531
[ 27. 143. 136.  65.]
Epoch: 2, Validation Loss: 0.9203, Validation Accuracy: 82.14%


  7%|▋         | 4/60 [03:49<53:28, 57.30s/it]

Test Accuracy: 33.9622641509434
[ 25. 126. 161.  59.]
Epoch: 3, Validation Loss: 0.9332, Validation Accuracy: 80.79%


In [None]:
test_model = model_list[0]
test(test_path=r"C:\Users\admin\Desktop\TEST\FOCAL\zahmatbin.fif", model=test_model, test_class= 0)

In [None]:
"""
state_dict = model.state_dict()
avg_state_dict = state_dict
for key in state_dict:
    avg_state_dict[key] += state_dict[key]
    print(state_dict[key].shape)
    print('lol')
    """

"\nstate_dict = model.state_dict()\navg_state_dict = state_dict\nfor key in state_dict:\n    avg_state_dict[key] += state_dict[key]\n    print(state_dict[key].shape)\n    print('lol')\n    "

In [None]:
test_model = model_list[0]

In [None]:
test(test_path=r"C:\Users\admin\Desktop\TEST\IGE\borzoimohamad.fif", model=model, test_class= 3)

Test Accuracy: 12.703583061889251


(array([260.,   6.,   2.,  39.]),
 tensor([[9.8578e-01, 5.7734e-03, 6.2023e-05, 8.3846e-03],
         [5.5934e-01, 1.9538e-01, 1.4391e-02, 2.3089e-01],
         [9.2552e-02, 3.8519e-01, 1.2033e-01, 4.0193e-01],
         ...,
         [9.8356e-01, 6.6635e-03, 8.8821e-05, 9.6893e-03],
         [9.9876e-01, 4.1187e-04, 9.2054e-07, 8.2272e-04],
         [9.9994e-01, 1.6301e-05, 4.7086e-09, 4.6685e-05]], device='cuda:0',
        grad_fn=<SoftmaxBackward0>))

In [None]:
file_path = 'D:\model weights\model_weights.pth'
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))

<All keys matched successfully>