### This is the notebook that help use evaluate the performance from the result we are getting including precision, sensitivity, etc.

In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
from PIL import Image
import os
import copy
import pickle
from sklearn.svm import SVC
import sklearn.metrics as metrics
from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import confusion_matrix


In [2]:
with open('../models/SVM/svm_20230412-163721.pickle','rb') as f:
  clf = pickle.load(f)

with open('../models/SVM/svm_feature_20230412-163721.pickle','rb') as f:
  feature_dict = pickle.load(f)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [3]:
train_features = feature_dict['train_features']
train_labels = feature_dict['train_labels']
test_features = feature_dict['test_features']
test_labels = feature_dict['test_labels']

svm = SVC(kernel='linear') # use sigmoid, rbf, poly will result in prediction number of 0
svm.fit(train_features, train_labels)

test_predictions = svm.predict(test_features)
accuracy = accuracy_score(test_labels, test_predictions)
balanced_accuracy = balanced_accuracy_score(test_labels, test_predictions)
print(accuracy)
print(balanced_accuracy)
confusion_matrix(test_labels, test_predictions)

0.4110429447852761
0.36804177005789906


array([[22, 28,  2],
       [33, 39,  8],
       [ 8, 17,  6]])

In [7]:
num_params = len(svm.coef_[0]) + 1

print("Number of trainable parameters for linear SVM:", num_params)

Number of trainable parameters for linear SVM: 1961


In [10]:
def get_stats(predictions, labels, num_classes):
  recalls = []

  TP = np.zeros(3)
  TN = np.zeros(3)
  FP = np.zeros(3)
  FN = np.zeros(3)

  for i in range(num_classes):
    y_true = labels == i
    y_pred = predictions == i

    TN[i], FP[i], FN[i], TP[i] = confusion_matrix(y_true, y_pred).ravel()

  return TP, TN, FP, FN

def get_eval(TP, TN, FP, FN, num_classes):
  macro_avg_precision = 0
  macro_avg_sensitivity = 0
  macro_avg_f1 = 0
  macro_avg_specification = 0
  
  
  for i in range(num_classes):
    macro_avg_precision += TP[i] / (TP[i] + FP[i])
    macro_avg_sensitivity += TP[i] / (TP[i] + FN[i])
    macro_avg_specification += TN[i] / (FP[i] + TN[i])

  macro_avg_precision /= num_classes
  macro_avg_sensitivity /= num_classes
  macro_avg_specification /= num_classes

  macro_avg_f1 = 2 * (macro_avg_precision * macro_avg_sensitivity) / (macro_avg_precision + macro_avg_sensitivity)

  return macro_avg_precision, macro_avg_sensitivity, macro_avg_specification, macro_avg_f1

In [11]:
TP, TN, FP, FN = get_stats(test_predictions, test_labels, 3)
macro_avg_precision, macro_avg_sensitivity, macro_avg_specification, macro_avg_f1 = get_eval(TP, TN, FP, FN, 3)
print('macro_avg_precision: ', macro_avg_precision)
print('macro_avg_sensitivity/recall: ', macro_avg_sensitivity)
print('macro_avg_specificity: ', macro_avg_specification)
print('macro_avg_f1: ', macro_avg_f1)

macro_avg_precision:  0.39616402116402116
macro_avg_sensitivity/recall:  0.36804177005789906
macro_avg_specificity:  0.6709014600580865
macro_avg_f1:  0.38158545579542874


In [6]:
resnet18 = torch.load('../models/ResNet18/restnet18_20230418-204835.pt')
resnet_log = open('../models/ResNet18/restnet18_20230418-204835.log', "r")

with open('../models/ResNet18/restnet18_predictlabel_20230418-204835.pickle','rb') as f:
    resnet_prediction = pickle.load(f)

In [7]:
test_predictions = np.asarray(resnet_prediction['prediction'])
test_labels = np.asarray(resnet_prediction['label'])

accuracy = accuracy_score(test_labels, test_predictions)
balanced_accuracy = balanced_accuracy_score(test_labels, test_predictions)
TP, TN, FP, FN = get_stats(test_predictions, test_labels, 3)
macro_avg_precision, macro_avg_sensitivity, macro_avg_specification, macro_avg_f1 = get_eval(TP, TN, FP, FN, 3)
print('Test Accuracy: ', accuracy)
print('Test Balanced Accuracy: ', balanced_accuracy)
print('macro_avg_precision: ', macro_avg_precision)
print('macro_avg_sensitivity/recall: ', macro_avg_sensitivity)
print('macro_avg_specificity: ', macro_avg_specification)
print('macro_avg_f1: ', macro_avg_f1)

Test Accuracy:  0.48466257668711654
Test Balanced Accuracy:  0.5075682382133996
macro_avg_precision:  0.49445465060896776
macro_avg_sensitivity/recall:  0.5075682382133996
macro_avg_specificity:  0.7453532393291429
macro_avg_f1:  0.5009256349043463
