In [None]:
import sys
sys.path.append("../scripts")

In [None]:
import numpy as np
import torchmetrics
import Utils_00 as utils
import torch

In [None]:
# Load the predictions written by the Inference Notebook.
# map_location="cpu": the tensors may have been saved from a GPU run; mapping them
# to CPU lets this notebook run on CPU-only machines and keeps later numpy
# conversions (e.g. np.argmax over the accuracy list) working.
# NOTE(review): torch.load unpickles the file - only load trusted prediction files.
preds = torch.load("../eval/GESPIC_preds.pt", map_location="cpu")

In [None]:
# For Validation Dataset only: determine the cut-off on the column-1 (positive-class)
# score that maximises macro-averaged (balanced) accuracy over the two classes.
# Do not tune the cut-off on a test dataset - use the fixed value in the next cell.
# Rounding removes np.arange float drift (e.g. 0.5900000000000001) so the printed
# cut-off can be copied verbatim into the test-dataset cell.
thresholds = np.round(np.arange(0.0, 1.01, 0.01), 2)
predictions = preds["preds"]

# Loop-invariant slices. Column 1 is used as the positive class for both tensors;
# targets appear to be one-hot - TODO confirm against the inference notebook.
pos_scores = predictions[:, 1]
pos_targets = preds["targets"][:, 1].long()

acc_list = [
    float(
        torchmetrics.functional.classification.multiclass_accuracy(
            (pos_scores > t).long(), pos_targets, num_classes=2, average='macro'
        )
    )
    for t in thresholds
]
best_idx = int(np.argmax(acc_list))  # first (lowest) threshold wins on ties
cut_off = thresholds[best_idx]
print(f"best accuracy is: {acc_list[best_idx]:.4f}, at a cut-off value of {cut_off}")

# Reuse exactly the binarisation from the sweep (score > cut_off). Passing integer
# labels stops torchmetrics from re-thresholding. Its float path would apply a
# sigmoid if any score fell outside [0, 1], and the sweep above does not.
pos_labels = (pos_scores > cut_off).long()
specificity = torchmetrics.functional.classification.binary_specificity(pos_labels, pos_targets)
sensitivity = torchmetrics.functional.classification.binary_recall(pos_labels, pos_targets)
print(f"->resulting specificity: {float(specificity):.4f}")
print(f"->resulting sensitivity: {float(sensitivity):.4f}")

In [None]:
# For Test Datasets only: reuse the cut-off found on the validation dataset, so no
# threshold is ever tuned on test data. Skip this cell when evaluating the
# validation dataset itself; otherwise it replaces the value the sweep just found.
VALIDATION_CUT_OFF = 0.59  # copied from the validation run's sweep output
cut_off = VALIDATION_CUT_OFF

## Regular Inference

In [None]:
# Classification report at the chosen cut-off (rendered by Utils_00).
# The keyword really is spelled "threshhold" here, unlike metric_w_CI's
# "threshold". Presumably that matches Utils_00's signature; confirm before
# renaming it.
utils.classification_report(
    preds["preds"],
    preds["targets"],
    threshhold=cut_off,
)

In [None]:
# Report each metric via utils.metric_w_CI (presumably point estimate plus
# confidence interval, per the helper's name). Every threshold-dependent metric
# is evaluated at the same cut-off.
# "Raw Accuracy" passes average="micro" explicitly. The task-wrapper `Accuracy`
# defaults to "micro", but MulticlassAccuracy and the functional multiclass_accuracy
# default to "macro", so relying on the implicit default is fragile.
threshold_metrics = [
    (torchmetrics.classification.Accuracy(num_classes=2, task='multiclass', average="macro"), "Balanced Accuracy"),
    (torchmetrics.classification.Accuracy(num_classes=2, task='multiclass', average="micro"), "Raw Accuracy"),
    (torchmetrics.classification.BinaryRecall(), "Sensitivity"),
    (torchmetrics.classification.BinarySpecificity(), "Specificity"),
    (torchmetrics.classification.BinaryMatthewsCorrCoef(), "MCC"),
    (torchmetrics.classification.BinaryPrecision(), "Precision"),
    (torchmetrics.classification.BinaryCohenKappa(), "Cohen Kappa"),
]
for metric, metric_name in threshold_metrics:
    utils.metric_w_CI(preds, metric, threshold=cut_off, metric_name=metric_name)

# AUROC is threshold-free, so no cut-off is passed. It stays the last expression,
# as in the original cell, so its return value is still the cell's displayed output.
utils.metric_w_CI(
    preds,
    torchmetrics.classification.BinaryAUROC(),
    metric_name="AUROC",
)

In [None]:
# Sensitivity vs. false-positive-rate plot at the chosen cut-off; save_path points
# at ../Sensitivity_vs_FPR.svg. `loc` presumably sets the legend position -
# confirm in Utils_00.
utils.Sensitivity_vs_FPR_val(
    preds["preds"], preds["targets"],
    save_path="../Sensitivity_vs_FPR.svg",
    loc='lower left',
    threshhold=cut_off,
)

In [None]:
# Histogram of predicted class probabilities at the chosen cut-off; save_path
# points at ../Confidence.svg. NOTE(review): y_ticks=10 is passed straight
# through - its exact meaning (tick count vs. spacing) is defined in Utils_00.
utils.class_probs_hist(
    preds["preds"], preds["targets"],
    save_path="../Confidence.svg",
    y_ticks=10,
    legend_pos="upper right",
    threshhold=cut_off,
)

In [None]:
# ROC curve over all thresholds (no cut-off argument is passed); save_path points
# at ../ROC.svg.
utils.roc_plot(preds["preds"], preds["targets"], save_path="../ROC.svg")

In [None]:
# Confusion matrix at the chosen cut-off; save_path points at
# ../confusionmatrix.svg.
utils.confusion_matrix(
    preds["preds"], preds["targets"],
    save_path="../confusionmatrix.svg",
    threshhold=cut_off,
)