In [1]:
import numpy as np
from medpy import metric


def assert_shape(test, reference):

    assert test.shape == reference.shape, "Shape mismatch: {} and {}".format(
        test.shape, reference.shape)


class ConfusionMatrix:

    def __init__(self, test=None, reference=None):

        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.reference_empty = None
        self.reference_full = None
        self.test_empty = None
        self.test_full = None
        self.set_reference(reference)
        self.set_test(test)

    def set_test(self, test):

        self.test = test
        self.reset()

    def set_reference(self, reference):

        self.reference = reference
        self.reset()

    def reset(self):

        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.test_empty = None
        self.test_full = None
        self.reference_empty = None
        self.reference_full = None

    def compute(self):

        if self.test is None or self.reference is None:
            raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.")

        assert_shape(self.test, self.reference)

        self.tp = int(((self.test != 0) * (self.reference != 0)).sum())
        self.fp = int(((self.test != 0) * (self.reference == 0)).sum())
        self.tn = int(((self.test == 0) * (self.reference == 0)).sum())
        self.fn = int(((self.test == 0) * (self.reference != 0)).sum())
        self.size = int(np.prod(self.reference.shape, dtype=np.int64))
        self.test_empty = not np.any(self.test)
        self.test_full = np.all(self.test)
        self.reference_empty = not np.any(self.reference)
        self.reference_full = np.all(self.reference)

    def get_matrix(self):

        for entry in (self.tp, self.fp, self.tn, self.fn):
            if entry is None:
                self.compute()
                break

        return self.tp, self.fp, self.tn, self.fn

    def get_size(self):

        if self.size is None:
            self.compute()
        return self.size

    def get_existence(self):

        for case in (self.test_empty, self.test_full, self.reference_empty, self.reference_full):
            if case is None:
                self.compute()
                break

        return self.test_empty, self.test_full, self.reference_empty, self.reference_full


def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """2TP / (2TP + FP + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty and reference_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    return float(2 * tp / (2 * tp + fp + fn))


def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FP + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty and reference_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    return float(tp / (tp + fp + fn))


def precision(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FP)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0
        
    return float(tp / (tp + fp))


def sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if reference_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    return float(tp / (tp + fn))


def recall(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FN)"""

    return sensitivity(test, reference, confusion_matrix, nan_for_nonexisting, **kwargs)


def specificity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    return float(tn / (tn + fp))


def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):
    """(TP + TN) / (TP + FP + FN + TN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return float((tp + tn) / (tp + fp + tn + fn))


def fscore(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1., **kwargs):
    """(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)"""

    precision_ = precision(test, reference, confusion_matrix, nan_for_nonexisting)
    recall_ = recall(test, reference, confusion_matrix, nan_for_nonexisting)

    return (1 + beta*beta) * precision_ * recall_ /\
        ((beta*beta * precision_) + recall_)


def false_positive_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (FP + TN)"""

    return 1 - specificity(test, reference, confusion_matrix, nan_for_nonexisting)


def false_omission_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TN + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    return float(fn / (fn + tn))


def false_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TP + FN)"""

    return 1 - sensitivity(test, reference, confusion_matrix, nan_for_nonexisting)


def true_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP)"""

    return specificity(test, reference, confusion_matrix, nan_for_nonexisting)


def false_discovery_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (TP + FP)"""

    return 1 - precision(test, reference, confusion_matrix, nan_for_nonexisting)


def negative_predictive_value(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FN)"""

    return 1 - false_omission_rate(test, reference, confusion_matrix, nan_for_nonexisting)


def total_positives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FP"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return tp + fp


def total_negatives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FN"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return tn + fn


def total_positives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FN"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return tp + fn


def total_negatives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FP"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return tn + fp


def hausdorff_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.hd(test, reference, voxel_spacing, connectivity)


def hausdorff_distance_95(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.hd95(test, reference, voxel_spacing, connectivity)


def avg_surface_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.asd(test, reference, voxel_spacing, connectivity)


def avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.assd(test, reference, voxel_spacing, connectivity)


ALL_METRICS = {
    "False Positive Rate": false_positive_rate,
    "Dice": dice,
    "Jaccard": jaccard,
    "Hausdorff Distance": hausdorff_distance,
    "Hausdorff Distance 95": hausdorff_distance_95,
    "Precision": precision,
    "Recall": recall,
    "Avg. Symmetric Surface Distance": avg_surface_distance_symmetric,
    "Avg. Surface Distance": avg_surface_distance,
    "Accuracy": accuracy,
    "False Omission Rate": false_omission_rate,
    "Negative Predictive Value": negative_predictive_value,
    "False Negative Rate": false_negative_rate,
    "True Negative Rate": true_negative_rate,
    "False Discovery Rate": false_discovery_rate,
    "Total Positives Test": total_positives_test,
    "Total Negatives Test": total_negatives_test,
    "Total Positives Reference": total_positives_reference,
    "total Negatives Reference": total_negatives_reference
}


In [2]:
import timeit
import torch
import argparse
import skimage, os
from glob import glob
import numpy as np
from skimage.io import imread
import nibabel as nib
import csv

BASE_IMG_PATH=os.path.join('/','home','asma','Documents','GPU','final_results01','Task02_Heart','Task02_Heart')
gt=sorted(glob(os.path.join(BASE_IMG_PATH,'gt_3dsrn3','*.nii')))
out=sorted(glob(os.path.join(BASE_IMG_PATH,'resnnUNet_seg_output3','*.nii')))

f_metrics = [
        "False Positive Rate",
        "Dice",
        "Jaccard",
        "Hausdorff Distance",
        "Hausdorff Distance 95",
        "Precision",
        "Recall",
        "Avg. Symmetric Surface Distance",
        "Avg. Surface Distance",
        "Accuracy",
        "False Omission Rate",
        "Negative Predictive Value",
        "False Negative Rate",
        "True Negative Rate",
        "False Discovery Rate",
        "Total Positives Test",
        "Total Negatives Test",
        "Total Positives Reference",
        "total Negatives Reference"]
with open(os.path.join(BASE_IMG_PATH,'my_metrics_3DSRNet_test3.csv'), 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(f_metrics)

    metrics = []

    for i in range(len(gt)):
        test = nib.load(out[i]).get_fdata()
        reference = nib.load(gt[i]).get_fdata()
        test = np.atleast_1d(test.astype(np.bool))
        reference = np.atleast_1d(reference.astype(np.bool))
        confusion_matrix = ConfusionMatrix(test, reference)
        tp, fp, tn, fn = confusion_matrix.get_matrix()
    
        my_metrics = {
            false_positive_rate(test, reference,confusion_matrix),
            dice(test, reference,confusion_matrix),
            jaccard(test, reference,confusion_matrix),
            hausdorff_distance(test, reference,confusion_matrix),
            hausdorff_distance_95(test, reference,confusion_matrix),
            precision(test, reference,confusion_matrix),
            recall(test, reference,confusion_matrix),
            avg_surface_distance_symmetric(test, reference,confusion_matrix),
            avg_surface_distance(test, reference,confusion_matrix),
            accuracy(test, reference,confusion_matrix),
            false_omission_rate(test, reference,confusion_matrix),
            negative_predictive_value(test, reference,confusion_matrix),
            false_negative_rate(test, reference,confusion_matrix),
            true_negative_rate(test, reference,confusion_matrix),
            false_discovery_rate(test, reference,confusion_matrix),
            total_positives_test(test, reference,confusion_matrix),
            total_negatives_test(test, reference,confusion_matrix),
            total_positives_reference(test, reference,confusion_matrix),
            total_negatives_reference(test, reference,confusion_matrix)
            }
        writer.writerow(my_metrics)
#        writer.writerow({tp,fp,tn,fn})
        prec = float(tp / (tp + fn))
        rec = float(tp / (tp + fp))
        writer.writerow({prec,rec})
#        print(metric.binary.hd(test, reference))
        
#    metrics = np.hstack((metrics,my_metrics))

#    print(confusion_matrix.get_matrix())
#    tp, fp, tn, fn = confusion_matrix.get_matrix()
#    print(float(tp / (tp + fn)))
        print(metric.hd95(test,reference))


7.0
7.0
14.798648586948742
3.605551275463989
5.830951894845301
4.242640687119285


In [4]:
tp = ((test != 0) * (reference != 0)).sum()
tn = ((test == 0) * (reference == 0)).sum()
fp = ((test != 0) * (reference == 0)).sum()
fn = ((test == 0) * (reference != 0)).sum()

print(tp)
print(tn)
print(fp)
print(fn)

22253
11128392
601
10354


In [9]:
#my_metrics = {
#        "False Positive Rate": false_positive_rate(test, reference),
#        "Dice": dice(test, reference),
#        "Jaccard": jaccard(test, reference),
#        "Hausdorff Distance": hausdorff_distance(test, reference),
#        "Hausdorff Distance 95": hausdorff_distance_95(test, reference),
#        "Precision": precision(test, reference),
#        "Recall": recall(test, reference),
#        "Avg. Symmetric Surface Distance": avg_surface_distance_symmetric(test, reference),
#        "Avg. Surface Distance": avg_surface_distance(test, reference),
#        "Accuracy": accuracy(test, reference),
#        "False Omission Rate": false_omission_rate(test, reference),
#        "Negative Predictive Value": negative_predictive_value(test, reference),
#        "False Negative Rate": false_negative_rate(test, reference),
#        "True Negative Rate": true_negative_rate(test, reference),
#        "False Discovery Rate": false_discovery_rate(test, reference),
#        "Total Positives Test": total_positives_test(test, reference),
#        "Total Negatives Test": total_negatives_test(test, reference),
#        "Total Positives Reference": total_positives_reference(test, reference),
#        "total Negatives Reference": total_negatives_reference(test, reference)
#    } 
#my_metrics = np.array(my_metrics)
#metrics = np.array([
#        "False Positive Rate","Dice",
#        "Jaccard",
#        "Hausdorff Distance",
#        "Hausdorff Distance 95",
#        "Precision",
#        "Recall",
#        "Avg. Symmetric Surface Distance",
#        "Avg. Surface Distance",
#        "Accuracy",
#        "False Omission Rate",
#        "Negative Predictive Value",
#        "False Negative Rate",
#        "True Negative Rate",
#        "False Discovery Rate",
#        "Total Positives Test",
#        "Total Negatives Test",
#        "Total Positives Reference",
#        "total Negatives Reference"])    

#f_metrics = np.hstack((my_metrics,metrics))

print(metrics)


[{0.002147686882388311, 0.6909669804847254, 0.5278453409376834, 0.534355033856212, 0.0, 2.598223623397627, 2.9179019091403826, 7.0, 0.9978009765625, 0.9999429994773807, 0.022558726460881418, 11.445523142259598, 47111, 10214245, 0.46564496614378803, 5.7000522619249556e-05, 0.9978523131176117, 25755, 10192889}
 {0.0021754115045494737, 0.7383928221604044, 0.5852795031055901, 0.5858131631299734, 0.0, 3.0114992784993113, 3.3880742187589323, 7.0, 0.9978264973958333, 0.9999952005640647, 0.0015540564405043256, 48256, 9167744, 9187687, 15.297058540778355, 28313, 0.9978245884954505, 4.799435935383885e-06, 0.41418683687002655}
 {0.0030679614539380617, 0.5901727931243004, 0.4186135650142367, 0.4225391035206986, 0.0, 5.588831503392134, 0.99688935546875, 7.783421062628765, 0.9999500280834317, 0.02171131206278787, 0.5774608964793013, 54279, 10216556, 14.798648586948742, 113.81124724736128, 23444, 0.9969320385460619, 4.997191656830184e-05, 10185721}
 {0.0009734641830030366, 0.8037828421090056, 0.67193

In [4]:
from tabulate import tabulate

print(tabulate(metrics))

-----------  --------  --------  --------  --------  --------  --------  -------  --------  ---------  -----------  ------------  ---------------  ---------------  ------------  ---------------  -----------  ---------------  ---------------
0.00214769   0.690967  0.527845  0.534355  0.977441  2.59822   2.9179    7        0.997801  0.999943   0.0225587       11.4455    47111                1.02142e+07      0.465645      5.70005e-05  0.997852     25755                1.01929e+07
0.00217541   0.738393  0.58528   0.585813  0.998446  3.0115    3.38807   7        0.997826  0.999995   0.00155406   48256             9.16774e+06      9.18769e+06     15.2971    28313            0.997825         4.79944e-06      0.414187
0.00306796   0.590173  0.418614  0.422539  0.978289  5.58883   0.996889  7.78342  0.99995   0.0217113  0.577461     54279             1.02166e+07     14.7986         113.811     23444            0.996932         4.99719e-05      1.01857e+07
0.000973464  0.803783  0.671937  3.6055

In [5]:
import csv

f_metrics = [
        "False Positive Rate","Dice",
        "Jaccard",
        "Hausdorff Distance",
        "Hausdorff Distance 95",
        "Precision",
        "Recall",
        "Avg. Symmetric Surface Distance",
        "Avg. Surface Distance",
        "Accuracy",
        "False Omission Rate",
        "Negative Predictive Value",
        "False Negative Rate",
        "True Negative Rate",
        "False Discovery Rate",
        "Total Positives Test",
        "Total Negatives Test",
        "Total Positives Reference",
        "total Negatives Reference"]
with open(os.path.join(BASE_IMG_PATH,'my_metrics_3DSRNet_test1.csv'), 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(f_metrics)
    writer.writerows(metrics)

# Display Confusion matrix

In [24]:
import pandas as pd
import numpy as np
import altair as alt
critical_value_dict = {70:1.04, 75:1.15, 80:1.28, 85:1.44, 90:1.64 , 95:1.96 , 98:2.33 , 99:2.58}

def odds_ratio(a, b, c, d):
    if a==0 or np.isnan(a) or b==0 or np.isnan(b) or c==0 or np.isnan(c) or d==0 or np.isnan(d):
        a = 0.5 if np.isnan(a) else a + 0.5
        b = 0.5 if np.isnan(b) else b + 0.5
        c = 0.5 if np.isnan(c) else c + 0.5
        d = 0.5 if np.isnan(d) else d + 0.5

    return (a*d)/(b*c)

def odds_ratio_lower_ci(OR, a, b, c, d, confidence_level):
    if a==0 or np.isnan(a) or b==0 or np.isnan(b) or c==0 or np.isnan(c) or d==0 or np.isnan(d):
        a = 0.5 if np.isnan(a) else a + 0.5
        b = 0.5 if np.isnan(b) else b + 0.5
        c = 0.5 if np.isnan(c) else c + 0.5
        d = 0.5 if np.isnan(d) else d + 0.5

    return np.exp(np.log(OR) - critical_value_dict[confidence_level]*np.sqrt(1/a + 1/b + 1/c + 1/d))

def odds_ratio_upper_ci(OR, a, b, c, d, confidence_level):
    if a==0 or np.isnan(a) or b==0 or np.isnan(b) or c==0 or np.isnan(c) or d==0 or np.isnan(d):
        a = 0.5 if np.isnan(a) else a + 0.5
        b = 0.5 if np.isnan(b) else b + 0.5
        c = 0.5 if np.isnan(c) else c + 0.5
        d = 0.5 if np.isnan(d) else d + 0.5

    return np.exp(np.log(OR) + critical_value_dict[confidence_level]*np.sqrt(1/a + 1/b + 1/c + 1/d))

def confusion_matrix_data(Yy, Yn, Ny, Nn):
    CM = pd.DataFrame({'label':['Yy','Yn','Ny','Nn', 
                                'y|Y','n|Y','n|N','y|N',
                                'Y|y','N|y','N|n','Y|n',
                                'Y','N','y','n',
                                'Y*','N*','y*','n*',
                                'OR_lci90','OR_lci95','OR_lci99','OR','OR_uci90','OR_uci95','OR_uci99', '1',
                                'ACC','ACC-','F1','F1-'], 
                       'value':[Yy,  Yn,  Ny,  Nn,   
                                0 if Yy+Yn==0 else Yy/(Yy+Yn), 
                                0 if Yy+Yn==0 else Yn/(Yy+Yn), 
                                0 if Ny+Nn==0 else Nn/(Ny+Nn), 
                                0 if Ny+Nn==0 else Ny/(Ny+Nn),
                                0 if Yy+Ny==0 else Yy/(Yy+Ny), 
                                0 if Yy+Ny==0 else Ny/(Yy+Ny), 
                                0 if Yn+Nn==0 else Nn/(Yn+Nn), 
                                0 if Yn+Nn==0 else Yn/(Yn+Nn),
                                Yy+Yn, Ny+Nn, Yy+Ny, Yn+Nn, 
                                (Yy+Yn)/(Yy+Yn+Ny+Nn), (Ny+Nn)/(Yy+Yn+Ny+Nn), 
                                (Yy+Ny)/(Yy+Yn+Ny+Nn), (Yn+Nn)/(Yy+Yn+Ny+Nn),
                                odds_ratio_lower_ci(odds_ratio(Yy, Yn, Ny, Nn), Yy, Yn, Ny, Nn, 90), 
                                odds_ratio_lower_ci(odds_ratio(Yy, Yn, Ny, Nn), Yy, Yn, Ny, Nn, 95), 
                                odds_ratio_lower_ci(odds_ratio(Yy, Yn, Ny, Nn), Yy, Yn, Ny, Nn, 99), 
                                odds_ratio(Yy, Yn, Ny, Nn), 
                                odds_ratio_upper_ci(odds_ratio(Yy, Yn, Ny, Nn), Yy, Yn, Ny, Nn, 90), 
                                odds_ratio_upper_ci(odds_ratio(Yy, Yn, Ny, Nn), Yy, Yn, Ny, Nn, 95), 
                                odds_ratio_upper_ci(odds_ratio(Yy, Yn, Ny, Nn), Yy, Yn, Ny, Nn, 99), 
                                1,
                                (Yy+Nn)/(Yy+Yn+Ny+Nn), (Yn+Ny)/(Yy+Yn+Ny+Nn),
                                0 if Yy==0 or Yy+Yn==0 or Yy+Ny==0 else 2 * ((Yy/(Yy+Yn)) * (Yy/(Yy+Ny))) / ((Yy/(Yy+Yn)) + (Yy/(Yy+Ny))),
                                1 if Yy==0 or Yy+Yn==0 or Yy+Ny==0 else 1 - (2 * ((Yy/(Yy+Yn)) * (Yy/(Yy+Ny))) / ((Yy/(Yy+Yn)) + (Yy/(Yy+Ny))))
                               ]})


    colours = alt.Scale(domain=['Yy','Yn','Ny','Nn', 
                                'y|Y','n|Y','n|N','y|N',
                                'Y|y','N|y','N|n','Y|n',
                                'Y','N','y','n',
                                'Y*','N*',
                                'y*','n*',
                                'OR_lci90','OR_lci95','OR_lci99','OR','OR_uci90','OR_uci95','OR_uci99', '1',
                                'ACC','ACC-','F1','F1-'], 
                        range =['snow', 'snow','snow', 'snow',
                                'forestgreen','palegreen','powderblue','cadetblue',
                                'forestgreen','cadetblue','powderblue','palegreen',
                                'goldenrod','gold','goldenrod','gold',
                                'goldenrod','gold',
                                'goldenrod','gold',
                                'dodgerblue','deepskyblue','lightskyblue','blue',
                                'dodgerblue','deepskyblue','lightskyblue','darkorange',
                                'goldenrod','gold','goldenrod','gold'
                               ])
    return CM, colours


Confusion Matrix Chart

In [25]:
def cf_v_bar(CM, colours, label_list, sort_order, w_factor, h_factor, sf):
    bar = alt.Chart(CM.loc[CM['label'].isin(label_list)]).mark_bar(size=w_factor*sf).encode(
        y=alt.Y('sum(value)', stack='normalize', title=None, axis=None),
        color=alt.Color('label', scale = colours, legend=None),
        order=alt.Order('label', sort=sort_order),
        tooltip=['value']
    ).properties(width=w_factor*sf, height=h_factor*sf) 
    
    return bar

def cf_h_bar(CM, colours, label_list, sort_order, w_factor, h_factor, sf):
    bar = alt.Chart(CM.loc[CM['label'].isin(label_list)]).mark_bar(size=h_factor*sf).encode(
        x=alt.X('sum(value)', stack='normalize', title=None, axis=None),
        color=alt.Color('label', scale = colours, legend=None),
        order=alt.Order('label', sort=sort_order),
        tooltip=['value']
    ).properties(width=w_factor*sf, height=h_factor*sf) 
    
    return bar


def cf_text(CM, label, format, font_size, w_factor, dy_factor, sf):
    text = alt.Chart(CM.loc[CM['label']==label]).mark_text(fontSize=font_size, color='black').encode(
        text=alt.Text('sum(value)', format=format)
    ).properties(width=w_factor*sf, height=w_factor*sf) 

    return text


def confusion_matrix_chart(Yy, Yn, Ny, Nn):
    
    # Scaling factor
    sf = 15  
    
    
    # Derive chart data
    CM, colours = confusion_matrix_data(Yy, Yn, Ny, Nn)
    
    
    # FIRST ROW

    text_Yy = cf_text(CM, label='Yy', format='.0f', font_size=36, 
                      w_factor=10, dy_factor=5, sf=sf)

    bar_Y = cf_v_bar(CM, colours,
                     label_list=['n|Y','y|Y'], sort_order='descending', 
                     w_factor=2, h_factor=10, sf=sf)
    
    text_Yn = cf_text(CM, label='Yn', format='.0f', font_size=36, 
                      w_factor=10, dy_factor=5, sf=sf)

    # SECOND ROW
    
    bar_y = cf_h_bar(CM, colours,
                     label_list=['Y|y','N|y'], sort_order='ascending', 
                     w_factor=10, h_factor=2, sf=sf)
    
    bar_a = cf_v_bar(CM, colours,
                     label_list=['ACC','ACC-'], sort_order='ascending', 
                     w_factor=2, h_factor=2, sf=sf)
    
    bar_n = cf_h_bar(CM, colours,
                     label_list=['N|n','Y|n'], sort_order='ascending', 
                     w_factor=10, h_factor=2, sf=sf)
    
    # THIRD ROW
    
    text_Ny = cf_text(CM, label='Ny', format='.0f', font_size=36, 
                      w_factor=10, dy_factor=5, sf=sf)

    bar_N = cf_v_bar(CM, colours,
                     label_list=['n|N','y|N'], sort_order='descending', 
                     w_factor=2, h_factor=10, sf=sf)
    
    text_Nn = cf_text(CM, label='Nn', format='.0f', font_size=36, 
                      w_factor=10, dy_factor=5, sf=sf)

    
    # FRAMING BARS
    
    # Left bar
    bar_L = cf_v_bar(CM, colours,
                     label_list=['Y*','N*'], sort_order='ascending', 
                     w_factor=2, h_factor=25, sf=sf)
    
    # Top left corner bar
    bar_0 = cf_v_bar(CM, colours,
                     label_list=['F1','F1-'], sort_order='ascending', 
                     w_factor=2, h_factor=2, sf=sf)
    
    # Top bar
    bar_T = cf_h_bar(CM, colours,
                     label_list=['y*','n*'], sort_order='descending', 
                     w_factor=25, h_factor=2, sf=sf)
    
    # Top right corner text
    text_R = cf_text(CM, label='OR', format='.1f', font_size=12, w_factor=2, dy_factor=1, sf=sf)

    # Right bar
    bar_R = alt.Chart(CM.loc[
        CM['label'].isin(['1','OR_lci90','OR_lci95','OR_lci99','OR','OR_uci90','OR_uci95','OR_uci99'])]
                     ).mark_circle(opacity=0.8, stroke='black', strokeWidth=1, size=10*sf).encode(
        y=alt.Y('value', title=None, axis=None),
        color=alt.Color('label', scale = colours, legend=None),
        order=alt.Order('label', sort='descending'),
        tooltip=['value']
    ).properties(width=2*sf, height=25*sf) 


    # BUILD COMBINED CHART
    
    return (bar_0 | bar_T | text_R) & ( 
        bar_L | ( ( (text_Yy) | bar_Y | text_Yn) & (bar_y | bar_a | bar_n) & (text_Ny | bar_N | text_Nn) ) 
     | bar_R )


In [26]:
Yy, Yn, Ny, Nn = confusion_matrix.get_matrix()
confusion_matrix_chart(Yy, Yn, Ny, Nn)