### Making Spectrograms  
  
Spectrograms will be made from the EEG data rather than using the provided spectrograms and these will be used in a neural network.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import fastparquet, pyarrow
import mne
from mne.decoding import Scaler
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score
from utils import *
import h5py
from scipy.signal import spectrogram

In [2]:
mne.set_log_level('WARNING')

In [3]:
df = pd.read_csv('by_patient.csv')

In [4]:
random_indexes = pd.read_csv('overall_randoms.csv')['idx']

In [5]:
other_df = activity_df(df, 'Other', 'expert_consensus')
seizure_df = activity_df(df, 'Seizure', 'expert_consensus')
gpd_df = activity_df(df, 'GPD', 'expert_consensus')
lpd_df = activity_df(df, 'LPD', 'expert_consensus')
grda_df = activity_df(df, 'GRDA', 'expert_consensus')
lrda_df = activity_df(df, 'LRDA', 'expert_consensus')
activity_df_list = [other_df, seizure_df, gpd_df, lpd_df, grda_df, lrda_df]

In [6]:
#other_indexes = get_indexes(other_df, 2000)
#seizure_indexes = get_indexes(seizure_df, 2000)
#gpd_indexes = get_indexes(gpd_df, 2000)
#lpd_indexes = get_indexes(lpd_df, 2000)
#grda_indexes = get_indexes(grda_df, 2000)
#lrda_indexes = get_indexes(lrda_df, 2000)

In [7]:
#activity_indexes = [other_indexes, seizure_indexes, gpd_indexes, lpd_indexes,
#                   grda_indexes, lrda_indexes]

In [8]:
indexes = pd.read_csv('activity_indexes.csv')
activity_types = indexes.columns

In [9]:
y = get_yvals(2000)['activity']

### Spectrograms  
  
I will be using the guide found at the link below.  
  
[**How To Make Spectrogram from EEG**](https://www.kaggle.com/code/cdeotte/how-to-make-spectrogram-from-eeg#Create-Spectrograms-with-Librosa)

In [10]:
names = ['LL','LP','RP','RR']

chains = [['Fp1','F7','T3','T5','O1'],
          ['Fp1','F3','C3','P3','O1'],
          ['Fp2','F8','T4','T6','O2'],
          ['Fp2','F4','C4','P4','O2']]

In [11]:
raw = load_preprocess(df, 0, 1, None, bandpass = True, notch = False, reref = True)

In [10]:
pd.DataFrame(raw.get_data(), index = raw.ch_names).transpose()

Unnamed: 0,Fp1,F3,C3,P3,F7,T3,T5,O1,Fz,Cz,Pz,Fp2,F4,C4,P4,F8,T4,T6,O2,EKG
0,6.098923e-17,-6.467780e-16,4.634451e-16,3.663006e-16,-2.581999e-16,1.165004e-16,3.940561e-16,2.830338e-16,-5.843279e-16,1.650726e-16,1.858893e-16,-1.471776e-16,-2.859555e-16,-8.399714e-18,1.511948e-16,-3.692222e-16,-2.165665e-16,2.552783e-16,7.486701e-17,-0.158840
1,-2.959683e-01,-9.579126e-02,1.826220e-01,-3.186127e-02,-9.071970e-02,-1.548448e-02,8.900903e-02,-1.479428e-02,-1.013828e-03,2.702369e-01,1.466023e-01,-2.878139e-01,5.164124e-02,1.635308e-02,4.239297e-03,7.051307e-03,-2.392405e-03,2.222939e-02,4.585478e-02,-0.093472
2,2.484796e-01,-6.886530e-02,-2.180245e-01,-1.533436e-01,1.291295e-01,-6.561683e-02,-2.552704e-02,-1.279294e-01,2.107432e-01,-1.611151e-01,-1.125558e-01,2.895415e-01,1.985302e-01,-1.294119e-01,-1.241651e-01,1.394871e-01,-5.330312e-02,-7.436669e-02,9.831327e-02,-0.402844
3,1.892762e-01,-6.175045e-02,-2.193049e-01,-5.970358e-02,1.484292e-02,1.642986e-02,2.404561e-02,-4.006557e-02,9.141178e-02,-8.545747e-02,-2.030920e-02,2.116358e-01,1.356586e-01,-1.675403e-01,-1.888674e-02,9.272253e-02,-8.818191e-02,-4.212377e-02,2.730053e-02,-0.577511
4,-3.798687e-01,-3.777700e-02,2.466656e-01,5.541902e-02,-2.322632e-01,8.935194e-02,1.092474e-01,2.141336e-02,-1.050048e-01,2.469289e-01,1.964642e-01,-4.492224e-01,6.911065e-02,1.071180e-01,2.249179e-02,-5.951915e-02,-4.766585e-02,7.847039e-02,6.863985e-02,-0.364735
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,-1.734111e-01,1.946201e+00,1.591217e+00,7.110342e-01,1.448383e+00,1.497626e+00,6.446492e-01,-2.939670e-02,8.873308e-01,7.365595e-02,1.577686e-01,-9.512450e-01,-1.852169e+00,-1.674549e+00,-6.974350e-01,-8.940025e-01,-1.160876e+00,-1.086074e+00,-4.387085e-01,0.577938
9996,2.940051e-01,1.574844e+00,1.207531e+00,6.844730e-01,1.495002e+00,1.138033e+00,5.298751e-01,2.379473e-02,6.097629e-01,-1.609777e-01,-3.096827e-02,-3.582997e-01,-1.800818e+00,-1.617724e+00,-5.950602e-01,-7.266458e-01,-1.049188e+00,-8.490244e-01,-3.686168e-01,0.187583
9997,-1.607435e-01,1.112967e+00,1.230602e+00,7.359271e-01,1.032293e+00,9.345460e-01,5.290299e-01,5.222980e-02,2.745289e-01,1.775637e-01,5.592745e-02,-6.599492e-01,-1.757246e+00,-1.129226e+00,-3.598756e-01,-7.171344e-01,-6.747651e-01,-4.975365e-01,-1.791389e-01,0.112159
9998,-6.045716e-01,4.557472e-01,9.274497e-01,4.729508e-01,2.949424e-01,5.907964e-01,3.933558e-01,5.994048e-02,-7.934264e-02,4.215874e-01,2.845534e-01,-8.672668e-01,-1.050253e+00,-3.358490e-01,-2.999604e-02,-5.160999e-01,-2.677085e-01,-1.197667e-01,-3.046964e-02,0.166941


In [27]:
raw = load_preprocess(df, 0, 1, None, bandpass = True, notch = False, reref = True)
sub_eeg = pd.DataFrame(raw.get_data(), index = raw.ch_names).transpose()

full_montage = []
for i in range(4):
    electrodes = chains[i]
    signal_diffs = []
    for j in range(4):
        x = sub_eeg[electrodes[j]].values - sub_eeg[electrodes[j + 1]].values
        signal_diffs.append(x)
    full_montage.append(signal_diffs)

In [28]:
np.asarray(full_montage).shape

(4, 4, 10000)

In [9]:
from scipy.signal import spectrogram

In [30]:
f, t, Sxx = spectrogram(full_montage[0][0], fs = 200, nperseg = 256, noverlap = 128,
                        scaling = 'density', mode = 'psd')

In [32]:
t.shape

(77,)

In [34]:

# Spectrograms params 
#fs = 200 
#window_length = 256     # or, this is number per segment (nperseg)
#window_overlap = 128    # or, n overlap 

# resulting spectrogram is going to be: window_length // 2 + 1 = 129 

#---------------------------------------

# Generate Spectrograms 

#f, t, Sxx = spectrogram( segments..., fs = fs, nperseg = window_length, noverlap = window_overlap)
#F = len(f)     # this is going to be 129 or so 
#T_s = len(t)   # number of windows = ( # of samples - window_overlap) / (window_length - window_overlap)  

F_trim = 128   # you can trim everything down to 128 to keep data even 
T_s = 77

spectrograms = np.zeros((4, 4, 128, 77), dtype = np.float32)

# loop over each segment
for i in range(4): 
    for j in range(4):
        f, t, Sxx = spectrogram(full_montage[i][j], fs = 200, nperseg = 256, 
                                noverlap = 128, scaling = 'density', mode = 'psd'
        )
        
        # now just trim or pad frequency axis to length F_trim
        Sxx = Sxx[:F_trim, :]
        
        # do you want to convert to decibel scale 
        epsilon = 1e-10
        Sxx_db = 10.0 * np.log10(Sxx + epsilon)
        
        # store these results back in the preallocated spectrograms array 
        spectrograms[i][j] = Sxx_db

In [25]:
def reformat_eeg(data, row, highpass, lowpass, bandpass, notch, reref):
    raw = load_preprocess(data, row, highpass, lowpass, bandpass, notch, reref)
    return pd.DataFrame(raw.get_data(), index = raw.ch_names).transpose()

In [26]:
def get_diffs(data, row, highpass, lowpass, bandpass, notch, reref, chains):
    sub_eeg = reformat_eeg(data, row, highpass, lowpass, bandpass, notch, reref)
    full_montage = []
    for i in range(4):
        electrodes = chains[i]
        signal_diffs = []
        for j in range(4):
            x = sub_eeg[electrodes[j]].values - sub_eeg[electrodes[j + 1]].values
            signal_diffs.append(x)
        full_montage.append(signal_diffs)
    return np.asarray(full_montage)

In [27]:
def get_spectrograms(data, row, highpass, lowpass, bandpass, notch, reref, chains):
    signal_diffs = get_diffs(data, row, highpass, lowpass, bandpass, notch, reref, chains)
    spectrograms = np.zeros((4, 4, 128, 77), dtype = np.float32)

    for i in range(4): 
        for j in range(4):
            f, t, Sxx = spectrogram(signal_diffs[i][j], fs = 200, nperseg = 256, 
                                    noverlap = 128, scaling = 'density', mode = 'psd'
            )

            Sxx = Sxx[:128, :]

            epsilon = 1e-10
            Sxx_db = 10.0 * np.log10(Sxx + epsilon)

            spectrograms[i][j] = Sxx_db
        
    return spectrograms

In [28]:
def avg_spectrograms(data, row, highpass, lowpass, bandpass, notch, reref, chains):
    spectrograms = get_spectrograms(data, row, highpass, lowpass, bandpass, notch, reref, chains)
    average_specs = []
    for i in range(4):
        sum_spec = np.zeros((128, 77), dtype = np.float32)
        for j in range(4):
            sum_spec += spectrograms[i][j]
        avg_spec = sum_spec / 4
        average_specs.append(avg_spec)
    return np.asarray(average_specs)

In [45]:
def no_averaging(data, row, highpass, lowpass, bandpass, notch, reref, chains):
    specs = get_spectrograms(data, row, highpass, lowpass, bandpass, notch, reref, chains)
    without_region = []
    for i in range(4):
        for j in range(4):
            without_region.append(specs[i][j])
    return np.asarray(without_region)

In [20]:
a1 = avg_spectrograms(df, 0, 1, None, bandpass = True, notch = False, reref = True, chains = chains)
a2 = avg_spectrograms(df, 1, 1, None, bandpass = True, notch = False, reref = True, chains = chains)

In [58]:
alist = [a1, a2]

In [59]:
np.asarray(alist).shape

(2, 4, 128, 77)

In [21]:
spec_list = []
for i in random_indexes:
    specgram = avg_spectrograms(df, i, 1, None, bandpass = True, notch = False,
                               reref = True, chains = chains)
    spec_list.append(specgram)
spec_list = np.asarray(spec_list)

In [22]:
spec_list.shape

(12000, 4, 128, 77)

In [24]:
spec0 = spec_list[0].copy()

### Storing Spectrograms

In [43]:
f = h5py.File('eeg_spectrograms/spec_{}.hdf5'.format(random_indexes[0]), 'w')
f.create_dataset('spectrogram', data = spec_list[0], dtype = 'float32')

<HDF5 dataset "spectrogram": shape (4, 128, 77), type "<f4">

In [46]:
f['spectrogram'][:].shape

(4, 128, 77)

In [32]:
f = h5py.File('eeg_spectrograms/hdf5_spec{}.hdf5'.format(random_indexes[0]), 'r')

In [35]:
f['spectrogram'][:].shape

(4, 128, 77)

In [48]:
for i in range(spec_list.shape[0]):
    f = h5py.File('eeg_spectrograms/spec_{}.hdf5'.format(random_indexes[i]), 'w')
    f.create_dataset('spectrogram', data = spec_list[i], dtype = 'float32')

### Next  
  
The above code should generate the spectrograms for the electrode chains specified. I need to write code that will loop over the sub EEGs I want spectrograms for, generate those spectrograms, and store those spectrograms in a folder to be worked with later. One question is whether or not I should store each separately? Each sub EEG has four spectrograms. They are currently RP, LP, RL, LL (not in that order). I can make four folders, one for each electrode chain. Or I could store each electrode chain's spectrogram together and just have an EEG spectrograms folder. I don't know which approach is more likely to add difficulty to the process of implementing a CNN with this data.

In [15]:
for i in range(5):
    f = h5py.File('eeg_spectrograms/spec_{}.hdf5'.format(random_indexes[i]), 'r')
    print(f['spectrogram'][:].shape)

(4, 128, 77)
(4, 128, 77)
(4, 128, 77)
(4, 128, 77)
(4, 128, 77)


### Convolutional Neural Network  
  
The idea for this is to have a CNN read over spectrograms like images and then classify them as representing one of six different categories of brain activity. To build my spectrogram dataset, I used the double banana bipolar montage (I'll link to resources below), calculated the signal differences within each electrode chain, calculated the spectrogram for each signal difference (four differences per electrode chain; four electrode chains without the central electrodes included), and then averaged the spectrograms in each electrode chain to get one spectrogram representing each electrode chain in the montage.  
  
This appears to be how the spectrograms which were provided with the EEGs were generated. Spectrograms are separated into four regions: Right Parasagittal, Left Parasagittal, Right Lateral, and Left Lateral. These are the chains of the double banana bipolar montage.  
  
The calculation to generate the spectrograms may be different. I used scipy's spectrogram function. One example on Kaggle used the librosa library. This at least seems to be how this organization of the spectrograms was arrived at, though. I won't be attempting other versions here due to time constraints, but there are many different montages I could use. I could also add in the central electrode chain so that Fz, Cz, and Pz aren't left out. I could also use the 16 spectrograms with four per electrode chain instead of using the average spectrogram for each electrode chain.  
  
I will attempt to make adjustments to the CNN. I can try different kernel sizes, different batch sizes, and even add a third convolutional layer. The main thing is to get initial results to use for comparison with my machine learning attempts and then to discuss what I might do next with more time.  
  
The kernel size establishes the dimensions of the kernel. The kernel is a square window which slides over the image (or spectrogram here) and calculates feature data from that. So a kernel size of three will produce a sliding window that is 3x3.  
  
Padding adds a boundary of pixels around the kernel. These are typically set to 0.  
  
The sliding window (the filter) generates a feature map which is used for the classification. In order to reduce dimensionality and guard against overfitting, pooling layers are used. Like with the kernel, the pooling filter is given a size (like 2x2) and it reads over the feature map. In every spot, it will output a single representation of the data within it. Max pooling takes the maximum value within the filter. Average pooling takes the average of the values within the filter.

In [11]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader 

class SpectrogramDataset(Dataset):
        
    def __init__(self, X, y):
        self.X = X 
        self.y = y
        
    def __len__(self):
        return len(self.y)
      
    def __getitem__(self, idx):
        x = self.X[idx]
        label = self.y[idx]
        
        return torch.tensor(x, dtype = torch.float32), torch.tensor(label, dtype = torch.long)
        
class EEGCNN(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()   
        
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size = 3, padding = 1) 
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)    
        
        self.conv2 = nn.Conv2d(16, 32, kernel_size = 3, padding = 1) 
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)   
        
        flattened_size = 32 * (128 // 4) * (77 // 4) 
        
        self.fc1 = nn.Linear(flattened_size, 64)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(64, num_classes) 
        
    def forward(self, x):
        
        x = self.conv1(x)      
        x = self.relu1(x)
        x = self.pool1(x)     
        
        x = self.conv2(x)    
        x = self.relu2(x)
        x = self.pool2(x)     
        
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu3(x)
        
        logits = self.fc2(x) 
        return logits

In [23]:
input_size = 9856
num_classes = 6
learning_rate = 0.001
num_epochs = 50

In [13]:
X = np.load('eeg_specs.npy')
y = np.load('eeg_labels.npy')

In [14]:
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y_encoded = le.fit_transform(y)

In [15]:
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size = 0.1, random_state = 42)

In [16]:
train_ds = SpectrogramDataset(X_train, y_train)
test_ds = SpectrogramDataset(X_test, y_test)

In [17]:
train_loader = DataLoader(train_ds, batch_size = 32, shuffle = True)
test_loader = DataLoader(test_ds, batch_size = 32, shuffle = False)

In [18]:
model = EEGCNN(in_channels = 4, num_classes = num_classes)

In [19]:
from torch import optim

In [20]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = learning_rate)

In [21]:
from tqdm import tqdm

In [24]:
for epoch in range(num_epochs):
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    for batch_index, (data, targets) in enumerate(tqdm(train_loader)):
        
        scores = model(data)
        loss = criterion(scores, targets)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

Epoch [1/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:41<00:00,  8.13it/s]


Epoch [2/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:29<00:00, 11.56it/s]


Epoch [3/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:27<00:00, 12.29it/s]


Epoch [4/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.64it/s]


Epoch [5/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:24<00:00, 13.91it/s]


Epoch [6/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.33it/s]


Epoch [7/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:27<00:00, 12.26it/s]


Epoch [8/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.06it/s]


Epoch [9/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.07it/s]


Epoch [10/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.10it/s]


Epoch [11/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:27<00:00, 12.24it/s]


Epoch [12/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:30<00:00, 11.02it/s]


Epoch [13/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.67it/s]


Epoch [14/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.03it/s]


Epoch [15/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:27<00:00, 12.31it/s]


Epoch [16/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:27<00:00, 12.49it/s]


Epoch [17/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.91it/s]


Epoch [18/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:23<00:00, 14.21it/s]


Epoch [19/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:24<00:00, 13.82it/s]


Epoch [20/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:23<00:00, 14.38it/s]


Epoch [21/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.35it/s]


Epoch [22/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:24<00:00, 13.74it/s]


Epoch [23/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.19it/s]


Epoch [24/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.02it/s]


Epoch [25/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.27it/s]


Epoch [26/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.58it/s]


Epoch [27/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.76it/s]


Epoch [28/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:24<00:00, 13.86it/s]


Epoch [29/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:22<00:00, 14.74it/s]


Epoch [30/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:22<00:00, 14.85it/s]


Epoch [31/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:22<00:00, 14.80it/s]


Epoch [32/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:23<00:00, 14.40it/s]


Epoch [33/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:23<00:00, 14.66it/s]


Epoch [34/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:22<00:00, 14.85it/s]


Epoch [35/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:23<00:00, 14.64it/s]


Epoch [36/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:22<00:00, 14.81it/s]


Epoch [37/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:23<00:00, 14.64it/s]


Epoch [38/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:23<00:00, 14.59it/s]


Epoch [39/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:23<00:00, 14.18it/s]


Epoch [40/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.25it/s]


Epoch [41/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.90it/s]


Epoch [42/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.22it/s]


Epoch [43/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.22it/s]


Epoch [44/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.61it/s]


Epoch [45/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 13.00it/s]


Epoch [46/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.88it/s]


Epoch [47/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.90it/s]


Epoch [48/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:27<00:00, 12.39it/s]


Epoch [49/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:26<00:00, 12.72it/s]


Epoch [50/50]


100%|█████████████████████████████████████████████████████████████████████████████████| 338/338 [00:25<00:00, 13.16it/s]


In [25]:
def check_accuracy(loader, model):
    if loader == train_loader:
        print("Checking accuracy on training data")
    else:
        print("Checking accuracy on test data")

    num_correct = 0
    num_samples = 0
    model.eval()  

    with torch.no_grad():  
        for x, y in loader:
            scores = model(x)
            _, predictions = scores.max(1)  
            num_correct += (predictions == y).sum()  
            num_samples += predictions.size(0)  

        accuracy = float(num_correct) / float(num_samples) * 100
        print(f"Got {num_correct}/{num_samples} with accuracy {accuracy:.2f}%")
    
    model.train()  

check_accuracy(train_loader, model)
check_accuracy(test_loader, model)

Checking accuracy on training data
Got 10346/10800 with accuracy 95.80%
Checking accuracy on test data
Got 671/1200 with accuracy 55.92%
