In [23]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
from PIL import Image
import os
import copy
import pickle
from sklearn.svm import SVC
import sklearn.metrics as metrics
from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import confusion_matrix


In [5]:
# Restore the pickled SVM and the pickled feature dictionary (both files share one run timestamp).
# pickle.load can execute arbitrary code -- only load files you produced yourself or fully trust.
with open('/content/svm_20230407-222019.pickle', 'rb') as model_file:
  clf = pickle.load(model_file)

with open('/content/svm_feature_20230407-222019.pickle', 'rb') as feature_file:
  feature_dict = pickle.load(feature_file)

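**Security note:** `pickle.load` can execute arbitrary code, so only load pickle files you created yourself or fully trust. See scikit-learn's guidance on model persistence: https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations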


In [56]:
# Unpack the cached feature matrices and labels for the train and test splits.
train_features, train_labels, test_features, test_labels = (
    feature_dict[split_key]
    for split_key in ('train_features', 'train_labels', 'test_features', 'test_labels')
)

# The original author notes that the 'sigmoid', 'rbf' and 'poly' kernels "result in
# prediction number of 0" -- presumably the fitted model then never predicts at least
# one class (TODO confirm); the linear kernel is used for that reason.
# NOTE(review): the `clf` unpickled earlier is not used here -- a fresh SVC is fit instead.
svm = SVC(kernel='linear').fit(train_features, train_labels)

test_predictions = svm.predict(test_features)
accuracy = accuracy_score(test_labels, test_predictions)
balanced_accuracy = balanced_accuracy_score(test_labels, test_predictions)
print(accuracy, balanced_accuracy, sep='\n')

# Displayed as the cell output: rows are true classes, columns are predicted classes.
confusion_matrix(test_labels, test_predictions)

0.4233128834355828
0.36754549214226634


array([[21, 27,  4],
       [30, 43,  7],
       [11, 15,  5]])

In [57]:
def get_stats(predictions, labels, num_classes):
  """Compute one-vs-rest confusion counts for every class.

  Args:
    predictions: predicted class ids (array-like; flattened before counting).
    labels: ground-truth class ids, same number of elements as ``predictions``.
    num_classes: number of classes; class ids ``0 .. num_classes - 1`` are counted.

  Returns:
    (TP, TN, FP, FN): four float arrays of length ``num_classes``. Entry ``i``
    treats class ``i`` as the positive class and all other classes as negative.

  Raises:
    ValueError: if ``predictions`` and ``labels`` contain a different number of elements.
  """
  predictions = np.asarray(predictions).ravel()
  labels = np.asarray(labels).ravel()
  if predictions.shape != labels.shape:
    raise ValueError(
        f'predictions ({predictions.size}) and labels ({labels.size}) differ in size')

  # Boolean (n_samples, num_classes) masks. Counting directly replaces
  # confusion_matrix(y_true, y_pred).ravel(), which returns a 1x1 matrix (breaking
  # the 4-way unpack) when a class is absent from both arrays or covers every sample.
  # Sizing by num_classes (instead of a hard-coded 3) supports any number of classes.
  classes = np.arange(num_classes)
  is_true = labels[:, None] == classes
  is_pred = predictions[:, None] == classes

  TP = np.sum(is_true & is_pred, axis=0).astype(float)
  TN = np.sum(~is_true & ~is_pred, axis=0).astype(float)
  FP = np.sum(~is_true & is_pred, axis=0).astype(float)
  FN = np.sum(is_true & ~is_pred, axis=0).astype(float)

  return TP, TN, FP, FN

def get_eval(TP, TN, FP, FN, num_classes):
  """Macro-averaged precision, sensitivity (recall), specificity and an F1 score.

  Args:
    TP, TN, FP, FN: per-class true-positive / true-negative / false-positive /
      false-negative counts indexed by class id (e.g. the output of ``get_stats``).
    num_classes: number of classes to average over; the first ``num_classes``
      entries of each count array are used.

  Returns:
    (macro_avg_precision, macro_avg_sensitivity, macro_avg_specification, macro_avg_f1)

  Raises:
    ValueError: if ``num_classes`` is not positive or a count array is shorter than it.

  Notes:
    - A per-class ratio with a zero denominator (e.g. precision of a class that was
      never predicted) counts as 0 instead of NaN, the same value scikit-learn's
      default ``zero_division`` handling produces.
    - ``macro_avg_f1`` is the harmonic mean of the macro-averaged precision and
      sensitivity. It is NOT scikit-learn's ``f1_score(average='macro')``, which
      averages the per-class F1 scores instead.
  """
  if num_classes <= 0:
    raise ValueError(f'num_classes must be positive, got {num_classes}')

  tp, tn, fp, fn = (np.asarray(c, dtype=float)[:num_classes] for c in (TP, TN, FP, FN))
  if min(tp.size, tn.size, fp.size, fn.size) < num_classes:
    raise ValueError(f'count arrays must have at least {num_classes} entries')

  def _safe_ratio(num, den):
    # Element-wise num / den, leaving 0 wherever den == 0 (avoids NaN + RuntimeWarning).
    return np.divide(num, den, out=np.zeros_like(num), where=den > 0)

  macro_avg_precision = _safe_ratio(tp, tp + fp).mean()
  macro_avg_sensitivity = _safe_ratio(tp, tp + fn).mean()
  macro_avg_specification = _safe_ratio(tn, fp + tn).mean()

  denom = macro_avg_precision + macro_avg_sensitivity
  macro_avg_f1 = (2 * macro_avg_precision * macro_avg_sensitivity / denom
                  if denom > 0 else np.float64(0.0))

  return macro_avg_precision, macro_avg_sensitivity, macro_avg_specification, macro_avg_f1

In [58]:
# Score the SVM test predictions (3 classes) with get_stats / get_eval and print each metric.
TP, TN, FP, FN = get_stats(test_predictions, test_labels, 3)
(macro_avg_precision, macro_avg_sensitivity,
 macro_avg_specification, macro_avg_f1) = get_eval(TP, TN, FP, FN, 3)

for metric_label, metric_value in (
    ('macro_avg_precision: ', macro_avg_precision),
    ('macro_avg_sensitivity: ', macro_avg_sensitivity),
    ('macro_avg_specification: ', macro_avg_specification),
    ('macro_avg_f1: ', macro_avg_f1),
):
  print(metric_label, metric_value)

macro_avg_precision:  0.3856973434535104
macro_avg_sensitivity:  0.36754549214226634
macro_avg_specification:  0.6804244003039184
macro_avg_f1:  0.376402703665835


In [7]:
# Load the saved ResNet-18 checkpoint and open the .log file that shares its run
# timestamp (note the files on disk are spelled "restnet18").
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# Whether this is a full model object or a state_dict cannot be told from here; TODO confirm.
resnet18 = torch.load('/content/restnet18_20230408-005105.pth')
# NOTE(review): this handle is not closed in the visible code; close it once read
# (or read it inside `with open(...) as f:`) to avoid leaking the file descriptor.
resnet_log = open('/content/restnet18_20230408-005105.log', "r")