# Model Testing

## This notebook loads trained models and tests their performance



In [3]:
import os
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from dataset_spectrogram import EEGDataset
from torch.utils.data import random_split
import neptune.new as neptune
from torchinfo import summary
from dataset_spectrogram import load_dataset
import random
import torch.utils.data as data
from datetime import datetime

In [4]:
# Load a saved TorchScript model and prepare it for inference.

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Check for CUDA

# Path of the serialized (torch.jit) model to evaluate; change this to test another checkpoint.
MODEL_PATH = "../trained_models/model_05_11_2022_09_02_10"

model = torch.jit.load(MODEL_PATH, map_location=device)
# A loaded module keeps whatever train/eval flag it was saved with; force eval mode so
# layers such as dropout / batch-norm (if the model has any) behave deterministically.
model.eval()


### Load and test with the balanced (50/50 good/bad segments) test dataset

In [9]:
# Build the evaluation dataset from the first `testNights` nights found under `raw_data_dir`
# (normalized=False selects the un-normalized spectrogram files, as the output below shows).
raw_data_dir = '../data'
testNights = 8

print("\nTest set\n")
test_set = load_dataset(range(testNights), raw_data_dir, normalized=False)

# Deterministic order for evaluation. NOTE: drop_last=True discards the final partial
# batch, so not every sample in `test_set` reaches the model.
test_loader = DataLoader(
    test_set,
    batch_size=64,
    shuffle=False,
    drop_last=True,
)

# Test the model on the test set.
#
# Each prediction is rounded to 0/1 and compared with its label. Label 1 is the
# positive class; any other label counts as negative in the truth table.
# NOTE(review): rounding assumes the model emits one score in [0, 1] per sample
# (e.g. a sigmoid output) -- confirm against the model definition.

# Make sure inference behaviour is used even if the model was left in training mode.
model.eval()

num_batches = len(test_loader)
# Count only the samples that are actually evaluated. With drop_last=True the last
# partial batch is skipped, so len(test_loader.dataset) would inflate the denominator
# and understate accuracy.
size = 0
correct = 0

# Truth table variables
true_pos, true_neg, false_pos, false_neg = 0, 0, 0, 0

with torch.no_grad():
    for X, y in test_loader:
        X = X.to(device)
        # Flatten labels to match the flattened predictions; a (batch, 1) label tensor
        # would otherwise broadcast against (batch,) into a (batch, batch) comparison.
        y = y.to(device).reshape(-1)
        # Reshape to 1 dimension for binary classification (one score per sample)
        pred = model(X).reshape(-1)
        pred_label = pred.round()

        size += y.numel()
        correct += (pred_label == y).sum().item()

        # Vectorised truth table (same rules as comparing each element individually)
        is_pos = y == 1
        pred_pos = pred_label == 1
        true_pos += (is_pos & pred_pos).sum().item()
        false_neg += (is_pos & ~pred_pos).sum().item()
        false_pos += (~is_pos & pred_pos).sum().item()
        true_neg += (~is_pos & ~pred_pos).sum().item()

if size == 0:
    raise ValueError(
        "No test samples were evaluated; the test set may be smaller than one batch with drop_last=True"
    )

correct /= size  # `correct` now holds the accuracy as a fraction in [0, 1]
print(f"Test accuracy: {correct}")
print(f"Test size: {size}")
print(f"Samples skipped by drop_last: {len(test_loader.dataset) - size}")

# Print the truth table
print("\nTruth table\n")
print(f"True pos: {true_pos}")
print(f"True neg: {true_neg}")
print(f"False pos: {false_pos}")
print(f"False neg: {false_neg}")

print("\n\n")



Test set

../data/study_1A_mat_simple/S_01/night_1/spectrogram_bad_segments_unnormalized.npy
../data/study_1A_mat_simple/S_01/night_1/spectrogram_good_segments_unnormalized.npy
Memory usage: 15.235964 MB

Lengths:

Good data length: 7368
Bad data length: 7368
Caluculated length: 14735
../data/study_1A_mat_simple/S_01/night_2/spectrogram_bad_segments_unnormalized.npy
../data/study_1A_mat_simple/S_01/night_2/spectrogram_good_segments_unnormalized.npy
Memory usage: 15.239309 MB

Lengths:

Good data length: 4292
Bad data length: 4292
Caluculated length: 8583
../data/study_1A_mat_simple/S_01/night_3/spectrogram_bad_segments_unnormalized.npy
../data/study_1A_mat_simple/S_01/night_3/spectrogram_good_segments_unnormalized.npy
Memory usage: 15.2429 MB

Lengths:

Good data length: 6110
Bad data length: 6110
Caluculated length: 12219
../data/study_1A_mat_simple/S_01/night_4/spectrogram_bad_segments_unnormalized.npy
../data/study_1A_mat_simple/S_01/night_4/spectrogram_good_segments_unnormalized.n

### Load and test with full-night data