Compute global statistics (top-1 accuracy, sensitivity and specificity) of the predictions made by the models trained on each fold of a cross-validation experiment.

In [7]:
import os
import json
import numpy as np

from imgclas.data_utils import load_image, load_class_names
from imgclas import paths, plot_utils

from imgclas import test_utils
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix

import warnings
# Silence warnings so the per-fold loop output stays readable. This was added
# for sklearn's UndefinedMetricWarning, but note the filter is global: it hides
# warnings from every library for the rest of the kernel session.
warnings.filterwarnings("ignore")

# User parameters to set
timestamp = ['2022-03-22_TortSp1_50ep_4Batch', '2022-03-23_TortSp2_50ep_4Batch',
            '2022-03-31_TortSp3_16ep_stop_8Batch', '2022-03-31_TortSp4_25ep_stop_8Batch',
            '2022-04-01_TortSp5_22ep_stop_8Batch']      # one model timestamp per cross-validation fold
SPLIT_NAME = 'test'                   # dataset split to predict
MODEL_NAME = 'final_model.h5'         # model used to make the prediction
TOP_K = 2                             # number of top class predictions stored in the predictions file


def _safe_ratio(num, den):
    """Return ``num / den`` as a float, or NaN when ``den`` is zero.

    Used for the confusion-matrix rates so that a fold with no samples in the
    denominator (e.g. no true positives and no false negatives) yields NaN
    explicitly, instead of relying on numpy's 0/0 result with its warning
    silenced by the global filter above.
    """
    return float(num) / float(den) if den else float('nan')


# Per-fold metrics, one entry per model, in the order of `timestamp`
accs = []    # top-1 accuracy
sens = []    # sensitivity, TP/(TP+FN), class 1 as positive
specs = []   # specificity, TN/(TN+FP)
ppvs = []    # positive predictive value, TP/(TP+FP)
npvs = []    # negative predictive value, TN/(TN+FN)
for TIMESTAMP in timestamp:

    # Select this fold's model; the paths.get_*_dir() helpers used below
    # presumably resolve their directories from paths.timestamp
    paths.timestamp = TIMESTAMP

    # Load class names. The confusion matrix below uses labels=[0, 1], so the
    # rates are only meaningful for a 2-class problem; fail loudly otherwise.
    class_names = load_class_names(splits_dir=paths.get_ts_splits_dir())
    assert len(class_names) == 2, (
        'sensitivity/specificity assume 2 classes, got {} for {}'.format(len(class_names), TIMESTAMP))

    # Load back the predictions
    pred_path = os.path.join(paths.get_predictions_dir(), '{}+{}+top{}.json'.format(MODEL_NAME, SPLIT_NAME, TOP_K))
    with open(pred_path) as f:
        pred_dict = json.load(f)

    # Top-1 accuracy
    true_lab, pred_lab = np.array(pred_dict['true_lab']), np.array(pred_dict['pred_lab'])
    top1 = test_utils.topK_accuracy(true_lab, pred_lab, K=1)
    accs.append(top1)

    # The first column of each top-K row is the top-1 predicted label
    y_pred = pred_lab[:, 0]
    # Standard binary confusion matrix; for labels=[0, 1] sklearn's ravel()
    # order is TN, FP, FN, TP
    TN, FP, FN, TP = confusion_matrix(true_lab, y_pred, labels=[0, 1]).ravel()
    sensitivity = _safe_ratio(TP, TP + FN)
    specificity = _safe_ratio(TN, TN + FP)
    pos_pred_val = _safe_ratio(TP, TP + FP)
    neg_pred_val = _safe_ratio(TN, TN + FN)
    sens.append(sensitivity)
    specs.append(specificity)
    ppvs.append(pos_pred_val)
    npvs.append(neg_pred_val)

Loading class names...
Loading class names...
Loading class names...
Loading class names...
Loading class names...


In [8]:
accs  # per-fold top-1 accuracy, one entry per model in `timestamp` (same order)

[0.825, 0.875, 0.825, 0.825, 0.825]

In [9]:
specs  # per-fold specificity TN/(TN+FP), one entry per model in `timestamp` (same order)

[0.8, 0.8, 0.9, 0.8571428571428571, 0.7619047619047619]

In [10]:
sens  # per-fold sensitivity TP/(TP+FN) with label 1 as the positive class, one entry per model in `timestamp`

[0.85, 0.95, 0.75, 0.7894736842105263, 0.8947368421052632]