In [3]:
import os
import pandas as pd
import numpy as np
import scipy
import glob
import math
import matplotlib.pyplot as plt
from operator import *

import mne
from mne.preprocessing import ICA

from scipy.signal import butter, lfilter

import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from IPython.display import clear_output 


def butter_bandpass(lowcut, highcut, fs, order=6):
    return butter(order, [lowcut, highcut], fs=fs, btype='band')

def butter_bandpass_filter(data, lowcut, highcut, fs, order=6):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y
def calcsnr(prefilter, filtered):
    filteredsum = 0
    denomsum = 0
    for i in range(len(prefilter)):
        filteredsum = filteredsum+filtered[i]*filtered[i]
        denomsum = denomsum+ (filtered[i]-prefilter[i])*(filtered[i]-prefilter[i])
    return 10*math.log10(filteredsum/denomsum)


class EEGDataset(Dataset):
    def __init__(self, eeglist, labels, transform=None, target_transform=None):
        self.labels = torch.from_numpy(np.array(labels))
        self.labels = self.labels.to(torch.float32)
        self.eeglist = eeglist
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        
        label = self.labels[idx]
        raw = self.eeglist[idx]
        eeg = torch.from_numpy(raw.get_data())
        eeg = eeg.to(torch.float32)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return eeg, label

bandwidth_split_data = {'alpha': [], 'beta': [], 'delta': [], 'gamma': [],'alpha_processed': [], 'beta_processed': [], 'delta_processed': [], 'gamma_processed': []}
labels = []
readtsv = pd.read_csv('participants.tsv', sep = '\t')

def splitbandwidth(band, data):
    return data.filter(l_freq = band[0], h_freq=band[1])

channel_names = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz']
for i in range(1,89):
    curdir = os.getcwd()
    rawpath = curdir+ "\Raw Data\sub-%s\eeg"%(str(i).zfill(3))
    file = glob.glob(os.path.join(rawpath, '*.set'))
    raw = mne.io.read_raw_eeglab(file[0], preload = True)

    bandwidth_split_data['delta'].append(splitbandwidth([0.5, 4], raw.copy()))
    bandwidth_split_data['alpha'].append(splitbandwidth([8, 12], raw.copy()))
    bandwidth_split_data['beta'].append(splitbandwidth([13, 30], raw.copy()))
    bandwidth_split_data['gamma'].append(splitbandwidth([30, 80], raw.copy()))
    clear_output()

    label = readtsv['Group'][i-1]
    if label == 'A':
        labels.append([0,0,1])
    elif label == 'F':
        labels.append([0,1,0])
    elif label == 'C':
        labels.append([1,0,0])

In [4]:
from mne_icalabel import label_components
exclusion = ['line noise', 'heartbeat', 'eye blink']

def manual_process(raw, plotting = True):
    freq_low = 0.5
    freq_high = 45
    iirparams = dict(order = 4, ftype = 'butter')

    raw.filter(freq_low, freq_high, method = 'iir', iir_params = iirparams)

    # Create ICA object and fit it to the data
    ica = ICA(n_components=19, random_state=97, verbose = False)
    ica.fit(raw)

    # Plot ICA components to identify artifacts
    icalabels = label_components(raw, ica, method = 'iclabel')
    if plotting == True:
        ica.plot_components()
        picks = list(range(0,18))
        ica.plot_properties(raw, picks=picks)
        print(icalabels)
    ica.exclude = []
    for i, label in enumerate(icalabels):
        if label in exclusion:
            ica.exclude.append(i)
    ica.apply(raw)
            
    return raw

In [5]:
from IPython.display import clear_output 
for n in range (len(bandwidth_split_data['alpha'])):
    alpha = bandwidth_split_data['alpha'][n].copy().load_data()
    beta = bandwidth_split_data['beta'][n].copy().load_data()
    gamma = bandwidth_split_data['gamma'][n].copy().load_data()
    delta = bandwidth_split_data['delta'][n].copy().load_data()
    bandwidth_split_data['alpha_processed'].append(manual_process(alpha, plotting = False))
    bandwidth_split_data['beta_processed'].append(manual_process(beta, plotting = False))
    bandwidth_split_data['gamma_processed'].append(manual_process(gamma, plotting = False))
    bandwidth_split_data['delta_processed'].append(manual_process(delta, plotting = False))
    clear_output()

In [6]:
import gc
gc.collect()
totals = [485.5, 276.5, 402]
alphatrainlist, alphavallist, alphatestlist, alphatrainlabels, alphavallabels, alphatestlabels = [],[],[],[],[],[]
betatrainlist, betavallist, betatestlist, betatrainlabels, betavallabels, betatestlabels = [],[],[],[],[],[]
gammatrainlist, gammavallist, gammatestlist, gammatrainlabels, gammavallabels, gammatestlabels = [],[],[],[],[],[]
deltatrainlist, deltavallist, deltatestlist, deltatrainlabels, deltavallabels, deltatestlabels = [],[],[],[],[],[]
trainlist, vallist, testlist, trainlabels, vallabels, testlabels = [],[],[],[],[],[]

labelonehot = [[0,0,1], [0,1,0], [1,0,0]]
#make sure enough of each group in each partition of data
for x, label in enumerate(labelonehot):
  count = 0
  for i in range (len(labels)):
    if labels[i] == label:
      m = 0
      # print(len(n.get_data()[0]))
      while 15000*(m+1) < len(bandwidth_split_data['alpha'][i].copy().get_data()[0]):
        if totals[x]*2-count>100:
          A = bandwidth_split_data['alpha'][i].copy()
          B = bandwidth_split_data['beta'][i].copy()
          G = bandwidth_split_data['gamma'][i].copy()
          D = bandwidth_split_data['delta'][i].copy()
          alphatrainlist.append(A.crop(m*30,(m+1)*30))
          alphatrainlabels.append(labels[i])
          
          betatrainlist.append(B.crop(m*30,(m+1)*30))
          betatrainlabels.append(labels[i])
          gammatrainlist.append(G.crop(m*30,(m+1)*30))
          gammatrainlabels.append(labels[i])
          deltatrainlist.append(D.crop(m*30,(m+1)*30))
          deltatrainlabels.append(labels[i])
          trainlist.append(A)
          trainlist.append(B)
          trainlist.append(G)
          trainlist.append(D)
          for j in range(0,4):  
            trainlabels.append(labels[i])
        if totals[x]*2-count>20:
          A = bandwidth_split_data['alpha'][i].copy()
          B = bandwidth_split_data['beta'][i].copy()
          G = bandwidth_split_data['gamma'][i].copy()
          D = bandwidth_split_data['delta'][i].copy()
          alphavallist.append(A.crop(m*30,(m+1)*30))
          alphavallabels.append(labels[i])
          betavallist.append(B.crop(m*30,(m+1)*30))
          betavallabels.append(labels[i])
          gammavallist.append(G.crop(m*30,(m+1)*30))
          gammavallabels.append(labels[i])
          deltavallist.append(D.crop(m*30,(m+1)*30))
          deltavallabels.append(labels[i])
          vallist.append(A)
          vallist.append(B)
          vallist.append(G)
          vallist.append(D)
          for j in range(0,4):  
            vallabels.append(labels[i])
        else:
          A = bandwidth_split_data['alpha'][i].copy()
          B = bandwidth_split_data['beta'][i].copy()
          G = bandwidth_split_data['gamma'][i].copy()
          D = bandwidth_split_data['delta'][i].copy()
          alphatestlist.append(A.crop(m*30,(m+1)*30))
          alphatestlabels.append(labels[i])
          betatestlist.append(B.crop(m*30,(m+1)*30))
          betatestlabels.append(labels[i])
          gammatestlist.append(G.crop(m*30,(m+1)*30))
          gammatestlabels.append(labels[i])
          deltatestlist.append(D.crop(m*30,(m+1)*30))
          deltatestlabels.append(labels[i])
          testlist.append(A)
          testlist.append(B)
          testlist.append(G)
          testlist.append(D)
          for j in range(0,4):  
            testlabels.append(labels[i])
        count = count + 1
        m = m+1
        gc.collect()

alphatrainset = EEGDataset(alphatrainlist, alphatrainlabels)
betatrainset = EEGDataset(betatrainlist, betatrainlabels)
gammatrainset = EEGDataset(gammatrainlist, gammatrainlabels)
deltatrainset = EEGDataset(deltatrainlist, deltatrainlabels)
fulltrainset = EEGDataset(trainlist, trainlabels)

alphavalset = EEGDataset(alphavallist, alphavallabels)
betavalset = EEGDataset(betavallist, betavallabels)
gammavalset = EEGDataset(gammavallist, gammavallabels)
deltavalset = EEGDataset(deltavallist, deltavallabels)
fullvalset = EEGDataset(vallist, vallabels)

alphatestset = EEGDataset(alphatestlist, alphatestlabels)
betatestset = EEGDataset(betatestlist, betatestlabels)
gammatestset = EEGDataset(gammatestlist, gammatestlabels)
deltatestset = EEGDataset(deltatestlist, deltatestlabels)
fulltestset = EEGDataset(testlist, testlabels)

alphatrainloader = DataLoader(alphatrainset, batch_size = 64, shuffle = True)
betatrainloader = DataLoader(betatrainset, batch_size = 64, shuffle = True)
gammatrainloader = DataLoader(gammatrainset, batch_size = 64, shuffle = True)
deltatrainloader = DataLoader(deltatrainset, batch_size = 64, shuffle = True)

alphavalloader = DataLoader(alphavalset, batch_size = 64, shuffle = True)
betavalloader = DataLoader(betavalset, batch_size = 64, shuffle = True)
gammavalloader = DataLoader(gammavalset, batch_size = 64, shuffle = True)
deltavalloader = DataLoader(deltavalset, batch_size = 64, shuffle = True)

alphatestloader = DataLoader(alphatestset, batch_size = 64, shuffle = True)
betatestloader = DataLoader(betatestset, batch_size = 64, shuffle = True)
gammatestloader = DataLoader(gammatestset, batch_size = 64, shuffle = True)
deltatestloader = DataLoader(deltatestset, batch_size = 64, shuffle = True)

fulltrainloader = DataLoader(fulltrainset, batch_size = 64, shuffle = True)
fullvalloader = DataLoader(fullvalset, batch_size = 64, shuffle = True)
fulltestloader = DataLoader(fulltestset, batch_size = 64, shuffle = True)

In [72]:
combinedtrainsplitlist = []
combinedvalsplitlist = []
combinedtestsplitlist = []
combinedtrainlabels = []
combinedvallabels = []
combinedtestlabels = []

class CombEEGDataset(Dataset):
    def __init__(self, eeglist, labels, transform=None, target_transform=None):
        self.labels = torch.from_numpy(np.array(labels))
        self.labels = self.labels.to(torch.float32)
        self.eeglist = eeglist
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        
        label = self.labels[idx]
        raw = self.eeglist[idx]
        eeg = torch.from_numpy(raw)
        eeg = eeg.to(torch.float32)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return eeg, label
for i in range(len(alphatestlist)):
    temp1 = np.array(list(alphatrainlist[i].copy().get_data())+list(betatrainlist[i].copy().get_data())+list(gammatrainlist[i].copy().get_data())+list(deltatrainlist[i].copy().get_data()))
    temp2 = np.array(list(alphavallist[i].copy().get_data())+list(betavallist[i].copy().get_data())+list(gammavallist[i].copy().get_data())+list(deltavallist[i].copy().get_data()))
    temp3 = np.array(list(alphatestlist[i].copy().get_data())+list(betatestlist[i].copy().get_data())+list(gammatestlist[i].copy().get_data())+list(deltatestlist[i].copy().get_data()))

    combinedtrainsplitlist.append(temp1)
    combinedvalsplitlist.append(temp2)   
    combinedtestsplitlist.append(temp3)
    
    combinedtrainlabels.append(alphatrainlabels+betatrainlabels+gammatrainlabels+deltatrainlabels)
    combinedvallabels.append(alphavallabels+ betavallabels+ gammavallabels+deltavallabels)
    combinedtestlabels.append(alphatestlabels+ betatestlabels+ gammatestlabels+deltatestlabels)
print(combinedtrainlabels)
combtrainset = CombEEGDataset(combinedtrainsplitlist,combinedtrainlabels)
combvalset = CombEEGDataset(combinedvalsplitlist,combinedvallabels)
combtestset = CombEEGDataset(combinedtestsplitlist,combinedtestlabels)

combtrainloader = DataLoader(combtrainset, batch_size=64, shuffle = True)
combvalloader = DataLoader(combvalset, batch_size=64, shuffle = True)
combtestloader = DataLoader(combtestset, batch_size=64, shuffle = True)

[[[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1

In [69]:
import torch.nn as nn
import torch.nn.functional as F

class CNN_EEG_Classifier(nn.Module):
    def __init__(self, conv_kernel_size = 3, conv1_size = 25, conv2_size = 30, conv_stride = 1):
        super(CNN_EEG_Classifier, self).__init__()
        self.conv1_size = conv1_size
        self.conv2_size = conv2_size
        self.conv_kernel_size = conv_kernel_size
        #USE LARGER KERNEL BECAUSE IMAGES ARE HIGH RESOLUTION, FEATURES TAKE UP MORE PIXELS
        self.conv1 = nn.Conv1d(19, conv1_size, conv_kernel_size)
        #BATCH NORMALIZATION TO PREVENT VANISHING GRADIENTS
        self.bn1 = nn.BatchNorm1d(conv1_size)
        #LARGE POOLING KERNEL TO REDUCE DIMENSIONALIZATION FASTER (JUST FOUND THIS TO BE HELPFUL BY EXPERIMENTING)
        self.pool = nn.MaxPool1d(4, 4)

        self.conv2 = nn.Conv1d(conv1_size, conv2_size, conv_kernel_size)
        
        self.bn2 = nn.BatchNorm1d(conv2_size)


        self.fc1 = nn.Linear(((15001-conv_kernel_size)//4-conv_kernel_size)//4*conv2_size, 256)
        self.fc2 = nn.Linear(256, 3)


    def forward(self, x):
        batch_size = x.shape[0]
        # print(x.shape)
        outconv1 = self.pool(F.relu(self.bn1(self.conv1(x))))
        # print(outconv1.shape)
        outconv2 = self.pool(F.relu(self.bn2(self.conv2(outconv1))))
        # print(outconv2.shape)
        outconv2 = outconv2.view(batch_size, -1)
        # print(outconv2.shape)

        outfc1 = F.relu(self.fc1(outconv2))
        # print(outfc1.shape)

        outfc2 = self.fc2(outfc1)
        # print(outfc2.shape)

        out = F.softmax(outfc2, dim = 1)

        return out
class CNN_EEG_Classifier2(nn.Module):
    def __init__(self, conv_kernel_size = 3, conv1_size = 25, conv2_size = 30, conv3_size = 35, conv4_size = 40, conv5_size = 50, n = 2,conv_stride = 1):
        super(CNN_EEG_Classifier2, self).__init__()
        self.conv1_size = conv1_size
        self.conv2_size = conv2_size
        self.conv3_size = conv3_size
        self.conv4_size = conv4_size
        self.conv5_size = conv5_size
        self.n = n
        self.conv_kernel_size = conv_kernel_size
        #USE LARGER KERNEL BECAUSE IMAGES ARE HIGH RESOLUTION, FEATURES TAKE UP MORE PIXELS
        self.conv1 = nn.Conv1d(19, conv1_size, conv_kernel_size)
        #BATCH NORMALIZATION TO PREVENT VANISHING GRADIENTS
        self.bn1 = nn.BatchNorm1d(conv1_size)
        #LARGE POOLING KERNEL TO REDUCE DIMENSIONALIZATION FASTER (JUST FOUND THIS TO BE HELPFUL BY EXPERIMENTING)
        self.pool = nn.MaxPool1d(4, 4)

        self.conv2 = nn.Conv1d(conv1_size, conv2_size, conv_kernel_size)
        
        self.bn2 = nn.BatchNorm1d(conv2_size)
        self.conv3 = nn.Conv1d(conv2_size, conv3_size, conv_kernel_size)
        self.bn3 = nn.BatchNorm1d(conv3_size)
        self.conv4 = nn.Conv1d(conv3_size, conv4_size, conv_kernel_size)
        self.bn4 = nn.BatchNorm1d(conv4_size)
        self.conv5 = nn.Conv1d(conv4_size, conv5_size, conv_kernel_size)
        self.bn5 = nn.BatchNorm1d(conv5_size)
        self.linearinput = 15001
        for i in range(n):
            self.linearinput = (self.linearinput - conv_kernel_size)//4
        self.fc1 = nn.Linear(self.linearinput*conv5_size, 256)
        self.fc2 = nn.Linear(256, 3)


    def forward(self, x):
        batch_size = x.shape[0]
        outconv1 = self.pool(F.relu(self.bn1(self.conv1(x))))
        outconv2 = self.pool(F.relu(self.bn2(self.conv2(outconv1))))
        outconv3 = self.pool(F.relu(self.bn3(self.conv3(outconv2))))
        outconv4 = self.pool(F.relu(self.bn4(self.conv4(outconv3))))
        outconv5 = self.pool(F.relu(self.bn5(self.conv5(outconv4))))
        outconv5 = outconv5.view(batch_size, -1)

        outfc1 = F.relu(self.fc1(outconv5))
        # print(outfc1.shape)

        outfc2 = self.fc2(outfc1)
        # print(outfc2.shape)

        out = F.softmax(outfc2, dim = 1)

        return out
class CNN_EEG_Classifier3(nn.Module):

    def __init__(self, pooling_layers = [1,1,1,1,1],conv_kernel_size = 3, conv1_size = 25, conv2_size = 30, conv3_size = 35, conv4_size = 40, conv5_size = 50, n = 2,conv_stride = 1):
        super(CNN_EEG_Classifier3, self).__init__()
        self.conv1_size = conv1_size
        self.conv2_size = conv2_size
        self.conv3_size = conv3_size
        self.conv4_size = conv4_size
        self.conv5_size = conv5_size
        self.pooling_layers = pooling_layers
        self.n = n
        self.conv_kernel_size = conv_kernel_size
        #USE LARGER KERNEL BECAUSE IMAGES ARE HIGH RESOLUTION, FEATURES TAKE UP MORE PIXELS
        self.conv1 = nn.Conv1d(19, conv1_size, conv_kernel_size)
        #BATCH NORMALIZATION TO PREVENT VANISHING GRADIENTS
        self.bn1 = nn.BatchNorm1d(conv1_size)
        #LARGE POOLING KERNEL TO REDUCE DIMENSIONALIZATION FASTER (JUST FOUND THIS TO BE HELPFUL BY EXPERIMENTING)
        self.pool = nn.MaxPool1d(4, 4)

        self.conv2 = nn.Conv1d(conv1_size, conv2_size, conv_kernel_size)
        
        self.bn2 = nn.BatchNorm1d(conv2_size)
        self.conv3 = nn.Conv1d(conv2_size, conv3_size, conv_kernel_size)
        self.bn3 = nn.BatchNorm1d(conv3_size)
        self.conv4 = nn.Conv1d(conv3_size, conv4_size, conv_kernel_size)
        self.bn4 = nn.BatchNorm1d(conv4_size)
        self.conv5 = nn.Conv1d(conv4_size, conv5_size, conv_kernel_size)
        self.bn5 = nn.BatchNorm1d(conv5_size)
        self.linearinput = 15001
        for num in pooling_layers:
            if num == 1:
                self.linearinput = (self.linearinput - conv_kernel_size)//4
        self.fc1 = nn.Linear(self.linearinput*conv5_size, 256)
        self.fc2 = nn.Linear(256, 3)


    def forward(self, x):
        batch_size = x.shape[0]
        if self.pooling_layers[0] == 1:
            outconv1 = self.pool(F.relu(self.bn1(self.conv1(x))))
        else:
            outconv1 = F.relu(self.bn1(self.conv1(x)))
        if self.pooling_layers[1] == 1:
            outconv2 = self.pool(F.relu(self.bn2(self.conv2(outconv1))))
        else:
            outconv2 = F.relu(self.bn2(self.conv2(outconv1)))
        if self.pooling_layers[2] == 1:
            outconv3 = self.pool(F.relu(self.bn3(self.conv3(outconv2))))
        else:
            outconv3 = F.relu(self.bn3(self.conv3(outconv2)))
        if self.pooling_layers[3] == 1:
            outconv4 = self.pool(F.relu(self.bn4(self.conv4(outconv3))))
        else:
            outconv4 = F.relu(self.bn4(self.conv4(outconv3)))
        if self.pooling_layers[4] == 1:
            outconv5 = self.pool(F.relu(self.bn5(self.conv5(outconv4))))
        else:
            outconv5 = F.relu(self.bn5(self.conv5(outconv4)))
        outconv5 = outconv5.view(batch_size, -1)

        outfc1 = F.relu(self.fc1(outconv5))
        # print(outfc1.shape)

        outfc2 = self.fc2(outfc1)
        # print(outfc2.shape)

        out = F.softmax(outfc2, dim = 1)

        return out
class CNN_EEG_Classifier_5Layer(nn.Module):
    def __init__(self, dropoutsize):
        super(CNN_EEG_Classifier_5Layer, self).__init__()
        
        
        conv_size = [25, 30, 35,40,45]
        conv_stride = [2, 2, 2,2,2]
        conv_kernel_size = [4, 4, 4,4,4]
        n_layers = len(conv_size)
        
        pool_kernel = [4, 1, 1,1,4]
        pool_stride = [4, 1, 1,1,4]
        
        
        self.conv1 = nn.Conv1d(19, conv_size[0], conv_kernel_size[0], conv_stride[0])
        self.bn1 = nn.BatchNorm1d(conv_size[0])
        self.pool1 = nn.MaxPool1d(pool_kernel[0], pool_stride[0])

        self.conv2 = nn.Conv1d(conv_size[0], conv_size[1], conv_kernel_size[1], conv_stride[1])
        self.bn2 = nn.BatchNorm1d(conv_size[1])
        self.pool2 = nn.MaxPool1d(pool_kernel[1], pool_stride[1])

        self.conv3 = nn.Conv1d(conv_size[1], conv_size[2], conv_kernel_size[2], conv_stride[2])
        self.bn3 = nn.BatchNorm1d(conv_size[2])
        self.pool3 = nn.MaxPool1d(pool_kernel[2], pool_stride[2])
        
        self.conv4 = nn.Conv1d(conv_size[2], conv_size[3], conv_kernel_size[3], conv_stride[3])
        self.bn4 = nn.BatchNorm1d(conv_size[3])
        self.pool4 = nn.MaxPool1d(pool_kernel[3], pool_stride[3])

        self.conv5 = nn.Conv1d(conv_size[3], conv_size[4], conv_kernel_size[4], conv_stride[4])
        self.bn5 = nn.BatchNorm1d(conv_size[4])
        self.pool5 = nn.MaxPool1d(pool_kernel[4], pool_stride[4])
        size = 15001
        for i in range(n_layers):
            size = ((size-conv_kernel_size[i])/conv_stride[i]) + 1
            size = ((size - pool_kernel[i])//pool_stride[i]) + 1  
            # print(size)
        
        linear1_size = int(size*conv_size[n_layers-1])
        # print(linear1_size)
        self.fc1 = nn.Linear(linear1_size, 500)
        self.fc2 = nn.Linear(500, 100)
        self.dropout1 = nn.Dropout(dropoutsize)
        self.fc3 = nn.Linear(100, 3)



    def forward(self, x):
        # print(x.shape)
        batch_size = x.shape[0]
        outconv1 = self.pool1(F.relu(self.bn1(self.conv1(x))))
        # print(outconv1.shape)
        outconv2 = self.pool2(F.relu(self.bn2(self.conv2(outconv1))))
        # print(outconv2.shape)
        outconv3 = self.pool3(F.relu(self.bn3(self.conv3(outconv2))))
        # print(outconv3.shape)
        outconv4 = self.pool4(F.relu(self.bn4(self.conv4(outconv3))))
        outconv5 = self.pool5(F.relu(self.bn5(self.conv5(outconv4))))
        outconv5 = outconv5.view(batch_size, -1)

        outfc1 = self.dropout1(F.relu(self.fc1(outconv5)))

        outfc2 = (self.fc2(outfc1))

        outfc3 = F.relu(self.fc3(outfc2))
        out = F.softmax(outfc3, dim = 1)

        return out
class CNN_EEG_Classifier_3cnn_medium_kernel(nn.Module):
    def __init__(self):
        """Trying kernel size 3 with 3 CNN model"""
        super(CNN_EEG_Classifier_3cnn_medium_kernel, self).__init__()
        
        
        conv_size = [25, 30, 35]
        conv_stride = [1, 1, 1]
        conv_kernel_size = [3, 3, 3]
        n_layers = len(conv_size)
        
        pool_kernel = [1, 4, 1]
        pool_stride = [1, 4, 1]
        
        
        self.conv1 = nn.Conv1d(19, conv_size[0], conv_kernel_size[0], conv_stride[0])
        self.bn1 = nn.BatchNorm1d(conv_size[0])
        self.pool1 = nn.MaxPool1d(pool_kernel[0], pool_stride[0])

        self.conv2 = nn.Conv1d(conv_size[0], conv_size[1], conv_kernel_size[1], conv_stride[1])
        self.bn2 = nn.BatchNorm1d(conv_size[1])
        self.pool2 = nn.MaxPool1d(pool_kernel[1], pool_stride[1])

        self.conv3 = nn.Conv1d(conv_size[1], conv_size[2], conv_kernel_size[2], conv_stride[2])
        self.bn3 = nn.BatchNorm1d(conv_size[2])
        self.pool3 = nn.MaxPool1d(pool_kernel[2], pool_stride[2])
        
        size = 15001
        for i in range(n_layers):
            size = ((size-conv_kernel_size[i])/conv_stride[i]) + 1
            size = ((size - pool_kernel[i])//pool_stride[i]) + 1  
            # print(size)
        
        linear1_size = int(size*conv_size[n_layers-1])
        # print(linear1_size)
        self.fc1 = nn.Linear(linear1_size, 7500)
        self.fc2 = nn.Linear(7500, 1000)
        self.fc3 = nn.Linear(1000, 100)
        self.dropout1 = nn.Dropout(0.75)
        self.fc4 = nn.Linear(100, 3)
    def forward(self, x):
        # print(x.shape)
        batch_size = x.shape[0]
        outconv1 = self.pool1(F.relu(self.bn1(self.conv1(x))))
        # print(outconv1.shape)
        outconv2 = self.pool2(F.relu(self.bn2(self.conv2(outconv1))))
        # print(outconv2.shape)
        outconv3 = self.pool3(F.relu(self.bn3(self.conv3(outconv2))))
        # print(outconv3.shape)
        outconv3 = outconv3.view(batch_size, -1)

        outfc1 = (F.relu(self.fc1(outconv3)))

        outfc2 = (self.fc2(outfc1))

        outfc3 = self.dropout1(F.relu(self.fc3(outfc2)))
        outfc4 = self.fc4(outfc3)
        out = F.softmax(outfc4, dim = 1)

        return out
class CNN_EEG_Classifier4(nn.Module):
    def __init__(self, conv_kernel_size = 3, conv1_size = 64, conv2_size = 48, conv3_size = 32, conv4_size = 24, conv5_size = 16, n = 2,conv_stride = 1):
        super(CNN_EEG_Classifier4, self).__init__()
        self.conv1_size = conv1_size
        self.conv2_size = conv2_size
        self.conv3_size = conv3_size
        self.conv4_size = conv4_size
        self.conv5_size = conv5_size
        self.n = n
        self.conv_kernel_size = conv_kernel_size
        #USE LARGER KERNEL BECAUSE IMAGES ARE HIGH RESOLUTION, FEATURES TAKE UP MORE PIXELS
        self.conv1 = nn.Conv1d(76, conv1_size, conv_kernel_size)
        #BATCH NORMALIZATION TO PREVENT VANISHING GRADIENTS
        self.bn1 = nn.BatchNorm1d(conv1_size)
        #LARGE POOLING KERNEL TO REDUCE DIMENSIONALIZATION FASTER (JUST FOUND THIS TO BE HELPFUL BY EXPERIMENTING)
        self.pool = nn.MaxPool1d(4, 4)

        self.conv2 = nn.Conv1d(conv1_size, conv2_size, conv_kernel_size)
        
        self.bn2 = nn.BatchNorm1d(conv2_size)
        self.conv3 = nn.Conv1d(conv2_size, conv3_size, conv_kernel_size)
        self.bn3 = nn.BatchNorm1d(conv3_size)
        self.conv4 = nn.Conv1d(conv3_size, conv4_size, conv_kernel_size)
        self.bn4 = nn.BatchNorm1d(conv4_size)
        self.conv5 = nn.Conv1d(conv4_size, conv5_size, conv_kernel_size)
        self.bn5 = nn.BatchNorm1d(conv5_size)
        self.linearinput = 15001
        for i in range(n):
            self.linearinput = (self.linearinput - conv_kernel_size)//4
        self.fc1 = nn.Linear(self.linearinput*conv5_size, 256)
        self.fc2 = nn.Linear(256, 3)


    def forward(self, x):
        batch_size = x.shape[0]
        outconv1 = self.pool(F.relu(self.bn1(self.conv1(x))))
        outconv2 = self.pool(F.relu(self.bn2(self.conv2(outconv1))))
        outconv3 = self.pool(F.relu(self.bn3(self.conv3(outconv2))))
        outconv4 = self.pool(F.relu(self.bn4(self.conv4(outconv3))))
        outconv5 = self.pool(F.relu(self.bn5(self.conv5(outconv4))))
        outconv5 = outconv5.view(batch_size, -1)

        outfc1 = F.relu(self.fc1(outconv5))
        # print(outfc1.shape)

        outfc2 = self.fc2(outfc1)
        # print(outfc2.shape)

        out = F.softmax(outfc2, dim = 1)

        return out

In [25]:
import torch.optim as optim
import time

def get_accuracy(model, train=True, train_data = fulltrainloader, val_data = fullvalloader):
    if train:
        dataloader = train_data
    else:
        dataloader = val_data
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            outputs = model(imgs.float())
            predicted = torch.argmax(outputs, dim=1)
            _, labels = torch.max(labels, dim=1)
            # print(predicted, labels)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total
torch.manual_seed(5)
def trainsgd(model,batch_size=1, traindata = fulltrainloader, valdata = fullvalloader, testdata = fulltestloader,num_epochs=10, rate = 0.001, name = 'model', adaptive = None):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=rate, momentum=0.9)

    iters, losses, train_acc, val_acc = [], [], [], []

    # training
    n = 0 # the number of iterations
    for epoch in range(num_epochs):
        for imgs, labels in iter(traindata):
            
            # print(imgs.shape)
            imgs = imgs.to(torch.float32)
            out = model(imgs)             # forward pass
            loss = criterion(out, labels) # compute the total loss
            loss.backward()               # backward pass (compute parameter updates)
            if adaptive == 'linear':
                rate = rate*0.90
            if adaptive == 'step':
                if epoch%4 == 0:
                    rate = rate*0.1
            optimizer.step()              # make the updates for each parameter
            optimizer.zero_grad()         # a clean up step for PyTorch

            # save the current training information
            iters.append(n)
            losses.append(float(loss)/batch_size)             # compute *average* loss
            train_acc.append(get_accuracy(model, train=True, train_data=traindata)) # compute training accuracy
            val_acc.append(get_accuracy(model, train=False, val_data = valdata))  # compute validation accuracy
            n += 1
        print("Epoch:", epoch)
        print("Training Accuracy:", train_acc[-1])
        print("Validation Accuracy:", val_acc[-1])

    print("Iterations:", n)
    # plotting
    plt.title("Training Curve")
    plt.plot(iters, losses, label="Train")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")

    plt.savefig("traincurve1batch%depochs%d%s.png"%(batch_size, num_epochs, name))
    plt.show()
    plt.title("Training Curve")
    plt.plot(iters, train_acc, label="Train")
    plt.plot(iters, val_acc, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Training Accuracy")
    plt.legend(loc='best')

    plt.savefig("traincurve2batch%depochs%d%s.png"%(batch_size, num_epochs,name))
    plt.show()
    print("Final Training Accuracy: {}".format(train_acc[-1]))
    print("Final Validation Accuracy: {}".format(val_acc[-1]))
    testacc = get_accuracy(model,train = True, train_data = testdata)
    f = open('results.txt', "a")
    f.write('\n%s model with %d batch size and %d epochs had a final validation accuracy of %f and a validation accuracy of %f.\nThe test accuracy was: %f'%(name,batch_size, num_epochs, train_acc[-1], val_acc[-1],testacc))
    f.close()
    clear_output()

# eegclassifier1 = CNN_LSTM_EEG(conv_kernel_size = 3, conv1_size = 64, conv2_size = 64, conv_stride = 1)
# eegclassifier2 = CNN_LSTM_EEG(conv_kernel_size = 10, conv1_size = 64, conv2_size = 64, conv_stride = 1)
# eegclassifier3 = CNN_LSTM_EEG(conv_kernel_size = 3, conv1_size = 64, conv2_size = 128, conv_stride = 1)
# eegclassifier4 = CNN_LSTM_EEG(conv_kernel_size = 10, conv1_size = 64, conv2_size = 128, conv_stride = 1)
# eegclassifier5 = CNN_LSTM_EEG(conv_kernel_size = 10, conv1_size = 64, conv2_size = 128, conv_stride = 3)
# eegcnn1 = CNN_EEG_Classifier()
# eegcnn2 = CNN_EEG_Classifier(conv1_size=32, conv2_size=64)
# eegcnn7 = CNN_EEG_Classifier2(conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
# eegcnn8 = CNN_EEG_Classifier3(pooling_layers=[1,0,1,0,1],conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
# eegcnn9 = CNN_EEG_Classifier3(pooling_layers=[1,0,0,0,1],conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
# eegcnn10 = CNN_EEG_Classifier3(pooling_layers=[1,0,0,0,1],conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
# dropout04 = CNN_EEG_Classifier_5Layer(dropoutsize=0.4)
# dropout03 = CNN_EEG_Classifier_5Layer(dropoutsize=0.3)
# dropout02 = CNN_EEG_Classifier_5Layer(dropoutsize=0.2)
# dropout01 = CNN_EEG_Classifier_5Layer(dropoutsize=0.1)
# eegcnn3 = CNN_EEG_Classifier(conv_kernel_size=10)
# eegcnn4 = CNN_EEG_Classifier(conv1_size=32, conv2_size=64, conv_kernel_size=10)
# eegcnn5 = CNN_EEG_Classifier(conv1_size=20, conv2_size= 25)
# eegcnn6 = CNN_EEG_Classifier(conv1_size=20, conv2_size= 25, conv_kernel_size=10)
# train(eegclassifier1, batch_size = 127, num_epochs = 5, rate = 0.1, name = '3-64-64-1')
# train(eegclassifier2, batch_size = 127, num_epochs = 5, rate = 0.1, name = '10-64-64-1')
# train(eegclassifier3, batch_size = 127, num_epochs = 5, rate = 0.1, name = '3-64-128-1')
# train(eegclassifier4, batch_size = 127, num_epochs = 5, rate = 0.1, name = '10-64-128-1')
# train(eegclassifier5, batch_size = 127, num_epochs = 5, rate = 0.1, name = '10-64-128-3')
# train(eegcnn1, batch_size = 64, num_epochs = 100, rate = 0.1, name = 'cnn-%d-%d-%d'%(eegcnn1.conv1_size, eegcnn1.conv2_size,eegcnn1.conv_kernel_size))
# train(eegcnn2, batch_size = 64, num_epochs = 100, rate = 0.1, name = 'cnn-%d-%d-%d'%(eegcnn2.conv1_size, eegcnn2.conv2_size,eegcnn2.conv_kernel_size))
# train(eegcnn3, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'cnn-%d-%d-%d'%(eegcnn3.conv1_size, eegcnn3.conv2_size,eegcnn3.conv_kernel_size))
# train(eegcnn4, batch_size = 64, num_epochs = 20, rate = 0.1,name = 'cnn-%d-%d-%d'%(eegcnn4.conv1_size, eegcnn4.conv2_size,eegcnn4.conv_kernel_size))
# train(eegcnn5, batch_size = 64, num_epochs = 5, rate = 0.1, name = 'cnn-%d-%d-%d'%(eegcnn5.conv1_size, eegcnn5.conv2_size,eegcnn5.conv_kernel_size))
# train(eegcnn7, batch_size = 4, num_epochs = 100, rate = 0.1, name = 'cnn-%d-%d-%d-%d-%d-%d'%(eegcnn7.conv1_size, eegcnn7.conv2_size,eegcnn7.conv3_size,eegcnn7.conv4_size, eegcnn7.conv5_size,eegcnn7.conv_kernel_size))

# train(eegcnn8, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'cnn-%d-%d-%d-%d-%d-%d1'%(eegcnn8.conv1_size, eegcnn8.conv2_size,eegcnn8.conv3_size,eegcnn8.conv4_size, eegcnn8.conv5_size,eegcnn8.conv_kernel_size))
# train(eegcnn9, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'cnn-%d-%d-%d-%d-%d-%d2'%(eegcnn9.conv1_size, eegcnn9.conv2_size,eegcnn9.conv3_size,eegcnn9.conv4_size, eegcnn9.conv5_size,eegcnn9.conv_kernel_size))
# train(eegcnn10, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'cnn-%d-%d-%d-%d-%d-%d3'%(eegcnn10.conv1_size, eegcnn10.conv2_size,eegcnn10.conv3_size,eegcnn10.conv4_size, eegcnn10.conv5_size,eegcnn10.conv_kernel_size))
# train(dropout04, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'dropout04')
# train(dropout03, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'dropout03')
# train(dropout02, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'dropout02')
# train(dropout01, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'dropout01')


In [65]:
print(combtrainloader.dataset[0])

(tensor([[ 1.3976e-20,  8.8234e-07,  1.7512e-06,  ...,  3.7328e-06,
          2.3207e-06,  9.1303e-07],
        [ 1.3208e-20,  2.4013e-07,  4.7462e-07,  ...,  3.0270e-06,
          1.7110e-06,  4.3016e-07],
        [ 4.5528e-21,  1.5676e-07,  3.1509e-07,  ...,  3.9428e-06,
          3.7744e-06,  3.5657e-06],
        ...,
        [ 6.7763e-21,  3.2229e-07,  6.4441e-07,  ...,  8.5883e-06,
          9.0247e-06,  9.4355e-06],
        [-6.7763e-21,  2.6644e-07,  5.3285e-07,  ...,  5.2440e-06,
          5.8875e-06,  6.5137e-06],
        [ 2.0329e-20,  3.7695e-07,  7.5367e-07,  ...,  9.7602e-07,
          1.8033e-06,  2.6227e-06]]), tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        ...,
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]]))


In [70]:
CNN1 = CNN_EEG_Classifier2(conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
CNN2 = CNN_EEG_Classifier2(conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
CNN3 = CNN_EEG_Classifier2(conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
CNN4 = CNN_EEG_Classifier2(conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
CNN5 = CNN_EEG_Classifier2(conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)
CNN6 = CNN_EEG_Classifier4(conv1_size=25, conv2_size=30, conv3_size = 35,conv4_size = 40, conv5_size = 45, n = 5)

# trainsgd(CNN1, batch_size = 64, num_epochs = 20, rate = 0.1, name = 'splitfull', adaptive = 'linear')
# trainsgd(CNN2, batch_size = 64, num_epochs = 20, rate = 0.1, traindata= alphatrainloader, valdata = alphavalloader, testdata=alphatestloader, name = 'splitalpha', adaptive = 'linear')
# trainsgd(CNN3, batch_size = 64, num_epochs = 20, rate = 0.1, traindata= betatrainloader, valdata = betavalloader, testdata=betatestloader,name = 'splitbeta', adaptive = 'linear')
# trainsgd(CNN4, batch_size = 64, num_epochs = 20, rate = 0.1, traindata= gammatrainloader, valdata = gammavalloader, testdata=gammatestloader,name = 'splitgamma', adaptive = 'linear')
# trainsgd(CNN5, batch_size = 64, num_epochs = 20, rate = 0.1, traindata= deltatrainloader, valdata = deltavalloader, testdata=deltatestloader,name = 'splitdelta', adaptive = 'linear')

trainsgd(CNN6, batch_size = 64, num_epochs = 20, rate = 0.1, traindata= combtrainloader, valdata = combvalloader, testdata=combtestloader,name = 'splitcomb', adaptive = 'linear')

# trainsgd(CNN1, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerbestsgd')
# trainsgd(CNN2, batch_size = 64, num_epochs = 20, rate = 0.1, name = '3layermediumsgd')
# trainsgd(CNN3, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerdownsamplesgd')
# trainadam(CNN4, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerbestadam')
# trainadam(CNN5, batch_size = 64, num_epochs = 20, rate = 0.1, name = '3layermediumadam')
# trainadam(CNN6, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerdownsampleadam')
# trainsgd(CNN7, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerbestsgdstep', adaptive = 'step')
# trainsgd(CNN8, batch_size = 64, num_epochs = 20, rate = 0.1, name = '3layermediumsgdstep', adaptive = 'step')
# trainsgd(CNN9, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerdownsamplesgdstep', adaptive = 'step')
# trainadam(CNN10, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerbestadamstep', adaptive = 'step')
# trainadam(CNN11, batch_size = 64, num_epochs = 20, rate = 0.1, name = '3layermediumadamstep', adaptive = 'step')
# trainadam(CNN12, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerdownsampleadamstep', adaptive = 'step')
# trainsgd(CNN13, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerbestsgdlinear', adaptive = 'linear')
# trainsgd(CNN14, batch_size = 64, num_epochs = 20, rate = 0.1, name = '3layermediumsgdlinear', adpative = 'linear')
# trainsgd(CNN15, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerdownsamplesgdlinear', adaptive = 'linear')
# trainadam(CNN16, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerbestadamlinear', adaptive = 'linear')
# trainadam(CNN17, batch_size = 64, num_epochs = 20, rate = 0.1, name = '3layermediumadamlinear', adpative = 'linear')
# trainadam(CNN18, batch_size = 64, num_epochs = 20, rate = 0.1, name = '5layerdownsampleadamlinear', adaptive = 'linear')


RuntimeError: 0D or 1D target tensor expected, multi-target not supported

In [27]:
print("Test accuracy:",get_accuracy(CNN2,train = True, train_data = alphatestloader))

Test accuracy: 12.5


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

# Define your ANN model
class ANN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(19*10001, 5000)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5000, 300)
        self.fc3 = nn.Linear(300,3)

    def forward(self, x):
        x = x.view(-1, 19*10001)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Define hyperparameters
hidden_size = 19
learning_rate = 0.001
num_epochs = 50
batch_size = 2

# Create DataLoader objects

# Initialize the ANN model
model = ANN(30001, hidden_size, 3)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy Loss with Logits
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in trainloader:
        # Forward pass
        inputs = inputs.to(torch.float32)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    # Print average loss for this epoch
    epoch_loss = running_loss / len(trainloader.dataset)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

# Evaluation
# Evaluation
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for imgs, labels in testloader:
        outputs = model(imgs)
        predicted = torch.argmax(outputs, dim=1)
        _, labels = torch.max(labels, dim=1)
        # print(predicted, labels)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    accuracy = correct / total
    print(f'Accuracy on test set: {accuracy:.4f}')

