In [1]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, auc
from scipy import interp
import numpy as np
from itertools import cycle
from sklearn.metrics import recall_score


In [3]:
def _compute_roc_multi(y_true, y_pred, n_classes=5):
    """
    Compute per-class, micro-average and macro-average ROC curves.

    :param y_true: 2-D array of true labels; column ``i`` is read as class ``i``
    :param y_pred: 2-D array of predicted scores, indexed like ``y_true``
    :param n_classes: number of leading columns to evaluate as classes
    :return: ``(roc_auc, fpr, tpr)`` -- three dicts keyed by the class index
             ``0 .. n_classes - 1`` plus the string keys ``"micro"`` and ``"macro"``
    """
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_pred[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average: flatten both arrays and treat every (sample, column)
    # entry as one binary decision. Note this uses the whole arrays, not
    # only the first ``n_classes`` columns.
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_pred.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: collect every FPR value seen in any per-class curve ...
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # ... interpolate each per-class curve onto that common grid.
    # ``np.interp`` is used directly: ``scipy.interp`` was only a deprecated
    # alias of it and is not available in newer SciPy releases.
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

    # ... and average the interpolated curves before computing the AUC.
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
    return roc_auc, fpr, tpr


def _plot_roc_multi(y_true, y_pred, n_classes=5):
    """
    Draw the micro-average, macro-average and per-class ROC curves in one figure.

    :param y_true: true labels, forwarded to ``_compute_roc_multi``
    :param y_pred: predicted scores, forwarded to ``_compute_roc_multi``
    :param n_classes: number of per-class curves to draw
    :return: None; the figure is displayed through ``plt.show()``
    """
    roc_auc, fpr, tpr = _compute_roc_multi(y_true, y_pred, n_classes)
    line_width = 2

    plt.figure(figsize=(8, 8))

    # Averaged curves first, so they lead the legend.
    micro_label = 'micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"])
    plt.plot(fpr["micro"], tpr["micro"], label=micro_label,
             color='deeppink', linestyle=':', linewidth=4)

    macro_label = 'macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"])
    plt.plot(fpr["macro"], tpr["macro"], label=macro_label,
             color='navy', linestyle=':', linewidth=4)

    # One solid curve per class; the palette repeats when there are more
    # classes than colours.
    palette = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green', 'red'])
    for class_idx in range(n_classes):
        class_label = 'ROC curve of class {0} (area = {1:0.2f})'.format(class_idx, roc_auc[class_idx])
        plt.plot(fpr[class_idx], tpr[class_idx], color=next(palette),
                 lw=line_width, label=class_label)

    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], 'k--', lw=line_width)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.show()

def _compute_best_point(fpr, tpr, thresholds):
    """
    Compute the nearest point to the left corner.
    :param fpr: False positive rate
    :param tpr: True positive rate
    :param thresholds: Threshold
    :return: sensitivity, specificity, threshold
    """
    dis = fpr ** 2 + (1 - tpr) ** 2
    best_index = np.argmin(dis)
    sensitivity = 1 - fpr[best_index]
    specificity = tpr[best_index]
    threshold = thresholds[best_index]
    return sensitivity, specificity, threshold

def _compute_best_threshold(y_true, y_pred):
    """
    Select one decision threshold per column of ``y_pred``.

    For each column a ROC curve is computed from the matching columns of
    ``y_true`` and ``y_pred``; the point chosen by ``_compute_best_point`` is
    printed and its threshold is collected.

    :param y_true: 2-D array of true labels
    :param y_pred: 2-D array of predicted scores
    :return: list of thresholds, one per column, in column order
    """
    n_columns = np.shape(y_pred)[1]
    best_thresholds = []
    for col in range(n_columns):
        col_fpr, col_tpr, col_thr = roc_curve(y_true[:, col], y_pred[:, col])
        point = _compute_best_point(col_fpr, col_tpr, col_thr)
        print("sensitivity:{} specificity:{} threshold:{}".format(*point))
        best_thresholds.append(point[2])
    return best_thresholds
    
def sigmoid(array):
    """Element-wise logistic function ``1 / (1 + exp(-x))``."""
    exp_of_negated = np.exp(-1 * array)
    return 1 / (1 + exp_of_negated)


In [4]:
# Tag vocabulary. The position of a name in this list is what the cells below
# use to turn a label/prediction column index into a readable tag
# (``static_tags_name[tag]``), so the order must not be changed.
static_tags_name = ['nodule', 'borderline', 'spinal fusion', 'cardiac shadow', 'interstitial', 'pulmonary congestion',
                    'technical quality of image unsatisfactory', 'bronchiectasis', 'cervical vertebrae',
                    'hypoinflation', 'medical device', 'prominent', 'mass', 'breast implants', 'calcinosis',
                    'aortic aneurysm', 'aorta, thoracic', 'lower lobe', 'scattered', 'left', 'lung, hyperlucent',
                    'pneumoperitoneum', 'enlarged', 'foreign bodies', 'epicardial fat', 'reticular', 'abnormal',
                    'irregular', 'obscured', 'large', 'diaphragm', 'right', 'breast', 'pulmonary edema',
                    'hyperostosis, diffuse idiopathic skeletal', 'airspace disease', 'stents', 'mild', 'volume loss',
                    'shift', 'sulcus', 'humerus', 'lucency', 'blunted', 'osteophyte', 'blood vessels',
                    'lumbar vertebrae', 'flattened', 'tortuous', 'small', 'healed', 'hypertension, pulmonary',
                    'bone diseases, metabolic', 'trachea', 'atherosclerosis', 'mediastinum', 'coronary vessels', 'lung',
                    'chronic', 'multiple', 'ribs', 'pulmonary disease, chronic obstructive', 'apex', 'hilum',
                    'spondylosis', 'diffuse', 'paratracheal', 'pneumothorax', 'clavicle', 'retrocardiac', 'lymph nodes',
                    'bronchovascular', 'azygos lobe', 'pulmonary emphysema', 'granulomatous disease',
                    'calcified granuloma', 'normal', 'thoracic vertebrae', 'funnel chest', 'thorax', 'aorta',
                    'adipose tissue', 'anterior', 'arthritis', 'emphysema', 'fractures, bone', 'hernia, hiatal',
                    'implanted medical device', 'sutures', 'granuloma', 'pleura', 'thickening', 'cysts', 'upper lobe',
                    'middle lobe', 'pleural effusion', 'deformity', 'contrast media', 'pulmonary atelectasis',
                    'hyperdistention', 'pericardial effusion', 'spine', 'mastectomy', 'surgical instruments',
                    'nipple shadow', 'heart', 'streaky', 'blister', 'catheters, indwelling', 'bilateral', 'neck',
                    'cavitation', 'density', 'scoliosis', 'pulmonary artery', 'round', 'opacity',
                    'lung diseases, interstitial', 'sternum', 'heart ventricles', 'lingula', 'aortic valve',
                    'heart failure', 'heart atria', 'sarcoidosis', 'bullous emphysema', 'sclerosis',
                    'costophrenic angle', 'kyphosis', 'hydropneumothorax', 'consolidation', 'dislocations', 'markings',
                    'abdomen', 'tube, inserted', 'no indexing', 'pneumonectomy', 'posterior', 'patchy',
                    'diaphragmatic eventration', 'pulmonary fibrosis', 'pneumonia', 'cardiomegaly', 'focal', 'cicatrix',
                    'elevated', 'infiltrate', 'moderate', 'degenerative', 'base', 'trachea, carina', 'severe',
                    'bronchi', 'pulmonary alveoli', 'shoulder', 'cystic fibrosis']

# Validation-split labels, and validation-split model outputs squashed to
# (0, 1) with ``sigmoid`` right after loading.
y_true = np.load('./results/{}'.format('y_true_m-20180330-DenseNet201-BCE.pth.tar_val.npz'))['arr_0']
y_pred = sigmoid(
    np.load('./results/{}'.format('y_pred_m-20180330-DenseNet201-BCE.pth.tar_val.npz'))['arr_0']
)


In [5]:
# One decision threshold per column of ``y_pred``; the call also prints the
# selected operating point for every column.
# NOTE: the name is misspelled ("threholds"), but later cells reference it
# under exactly this spelling, so it is kept as-is.
threholds = _compute_best_threshold(y_true, y_pred)


sensitivity:0.5727969348659003 specificity:0.76 threshold:0.01010719034820795
sensitivity:0.5347091932457786 specificity:0.5714285714285714 threshold:0.01978769339621067
sensitivity:0.6525735294117647 specificity:0.6666666666666666 threshold:0.0026810213457792997
sensitivity:0.8141263940520447 specificity:0.8888888888888888 threshold:0.043309565633535385
sensitivity:0.676923076923077 specificity:0.7037037037037037 threshold:0.05944222956895828
sensitivity:0.9291044776119403 specificity:0.7272727272727273 threshold:0.164999857544899
sensitivity:0.7696629213483146 specificity:0.5384615384615384 threshold:0.18877553939819336
sensitivity:0.7113970588235294 specificity:1.0 threshold:4.798668305738829e-05
sensitivity:0.49816176470588236 specificity:1.0 threshold:0.0026841468643397093
sensitivity:0.711764705882353 specificity:0.7297297297297297 threshold:0.5066084265708923
sensitivity:0.7845303867403315 specificity:0.5 threshold:0.0013272586511448026
sensitivity:0.638095238095238 specificity:

In [38]:
def softmax(x):
    """
    Softmax along axis 0, computed in a numerically stable way.

    The maximum along axis 0 is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (which would turn the quotient into ``nan``) for large inputs.

    :param x: array-like of scores; normalisation runs over axis 0
    :return: array shaped like ``x`` whose entries sum to 1 along axis 0
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; also avoids taking the max of an empty array.
        return np.exp(x)
    shifted_exp = np.exp(x - np.max(x, axis=0, keepdims=True))
    return shifted_exp / np.sum(shifted_exp, axis=0)


# Reload the validation labels and model outputs. ``y_pred`` is rebound to the
# stored values as-is here (no ``sigmoid`` is applied in this cell).
y_true, y_pred = (
    np.load('./results/{}'.format(npz_name))['arr_0']
    for npz_name in (
        'y_true_m-20180330-DenseNet201-BCE.pth.tar_val.npz',
        'y_pred_m-20180330-DenseNet201-BCE.pth.tar_val.npz',
    )
)

def get_top_k_index(array, k):
    """
    Return the indices of the ``k`` largest entries of ``array``.

    Indices are ordered from largest to smallest value; equal values keep
    their original relative order.

    :param array: 1-D sequence of comparable values
    :param k: number of indices to return
    :return: numpy array holding at most ``k`` indices
    """
    ranked = sorted(range(len(array)), key=lambda idx: array[idx], reverse=True)
    return np.array(ranked[:k])


def match(array_true, array_pred):
    """
    Count how many of the predicted indices are labelled 1.

    :param array_true: indexable label vector
    :param array_pred: iterable of indices into ``array_true``
    :return: number of indices ``i`` in ``array_pred`` with ``array_true[i] == 1``
    """
    return sum(1 for idx in array_pred if array_true[idx] == 1)


def get_correct_num(y_true, y_pred, k):
    """
    Total number of top-``k`` predictions that hit a label equal to 1.

    :param y_true: 2-D labels, one row per sample
    :param y_pred: 2-D scores, one row per sample
    :param k: number of top-scoring indices considered per sample
    :return: hits summed over all rows of ``y_true``
    """
    n_samples = np.shape(y_true)[0]
    return sum(
        match(y_true[row], get_top_k_index(y_pred[row], k))
        for row in range(n_samples)
    )

def recall_k(right, y_true, k):
    """
    Recall: ``right`` divided by the sum of all entries of ``y_true``.

    :param right: number of correct predictions
    :param y_true: 2-D label array
    :param k: accepted for a uniform signature; not used in the computation
    :return: ``right`` divided by the total of ``y_true``
    """
    per_sample_total = np.sum(y_true, axis=1)
    overall_total = np.sum(per_sample_total, axis=0)
    return right / overall_total

def precision_k(right, y_true, k):
    """
    Precision: ``right`` divided by the number of predictions made.

    :param right: number of correct predictions
    :param y_true: labels; only its length (number of samples) is used
    :param k: number of predictions per sample
    :return: ``right / (len(y_true) * k)``
    """
    n_predictions = len(y_true) * k
    return right / n_predictions

def F1_k(y_true, y_pred, k):
    """
    Print recall, precision and F1 of the top-``k`` predictions.

    :param y_true: 2-D labels, one row per sample
    :param y_pred: 2-D scores, one row per sample
    :param k: number of top-scoring indices considered per sample
    :return: None; the metrics are only printed
    """
    hits = get_correct_num(y_true, y_pred, k)
    rec = recall_k(hits, y_true, k)
    prec = precision_k(hits, y_true, k)
    # Harmonic mean of recall and precision.
    f_score = 2 * rec * prec / (rec + prec)
    print('Top {} results - Recall {} - Precision {} - F1 {}'.format(k, rec, prec, f_score))

# Report the sample count, then the top-k metrics for each cut-off.
print("Size {}".format(len(y_true)))
for top_k in (5, 10, 20):
    F1_k(y_true, y_pred, top_k)

Size 547
Top 5 results - Recall 0.23741258741258742 - Precision 0.24826325411334552 - F1 0.24271671134941913
Top 10 results - Recall 0.3583916083916084 - Precision 0.18738574040219377 - F1 0.24609843937575032
Top 20 results - Recall 0.5377622377622377 - Precision 0.14058500914076782 - F1 0.22289855072463766


In [None]:
# Side-by-side view of the 10 highest-scoring tags and the true tags per sample.
for i in range(len(y_pred)):
    print("Pred: {}".format([static_tags_name[tag] for tag in get_top_k_index(y_pred[i], 10)]))
    # List exactly the tags labelled 1. Taking the "top 10" of a 0/1 label
    # vector would pad the list with tags labelled 0 whenever a sample has
    # fewer than 10 positives, and cut it off when it has more.
    print("Truth:{}".format([static_tags_name[tag] for tag, flag in enumerate(y_true[i]) if flag == 1]))
    print()


In [6]:
# Binarise the validation scores with the per-tag thresholds.
# The scores are recomputed from the stored file instead of reusing the global
# ``y_pred``: that name is rebound elsewhere in the notebook (including to raw,
# un-squashed outputs), and overwriting it in place made this cell give a
# different result when run twice. The thresholds were selected on sigmoid
# outputs, so the same transform is applied here.
y_pred = sigmoid(
    np.load('./results/{}'.format('y_pred_m-20180330-DenseNet201-BCE.pth.tar_val.npz'))['arr_0']
) > np.asarray(threholds)
def mean_recall(y_true, y_pred):
    """
    Unweighted mean of the per-column recall scores.

    :param y_true: 2-D array of true labels
    :param y_pred: 2-D array of binary predictions
    :return: sum of the column-wise ``recall_score`` values divided by the number of columns
    """
    n_columns = np.shape(y_true)[1]
    per_column = [recall_score(y_true[:, col], y_pred[:, col]) for col in range(n_columns)]
    return np.sum(per_column) / n_columns

def overall_recall(y_true, y_pred):
    """
    Recall pooled over every entry of the label matrix.

    :param y_true: array of true labels
    :param y_pred: array of predictions, same shape as ``y_true``
    :return: count of entries equal to 1 in both arrays, divided by the sum of ``y_true``
    """
    both_positive = np.logical_and(y_true == 1, y_pred == 1)
    return np.sum(both_positive) / np.sum(y_true)

def overall_precision(y_true, y_pred):
    """
    Precision pooled over every entry of the prediction matrix.

    :param y_true: array of true labels
    :param y_pred: array of predictions, same shape as ``y_true``
    :return: count of entries equal to 1 in both arrays, divided by the sum of ``y_pred``
    """
    both_positive = np.logical_and(y_true == 1, y_pred == 1)
    return np.sum(both_positive) / np.sum(y_pred)

# Validation-split metrics, printed one per line.
for metric in (mean_recall, overall_recall):
    print(metric(y_true, y_pred))

0.5136699643142887
0.6230769230769231


In [26]:
# Test-split labels and model outputs; the outputs go through ``sigmoid`` and
# are then binarised with the thresholds stored in ``threholds``.
y_test_true = np.load('./results/{}'.format('y_true_m-20180330-DenseNet201-BCE.pth.tar.npz'))['arr_0']
y_test_pred = sigmoid(
    np.load('./results/{}'.format('y_pred_m-20180330-DenseNet201-BCE.pth.tar.npz'))['arr_0']
) > np.array(threholds)

# Test-split metrics, printed one per line.
for test_metric in (mean_recall, overall_recall, overall_precision):
    print(test_metric(y_test_true, y_test_pred))


0.6059518169736098
0.53713163064833
0.041899098878195304


In [27]:
def get_index(array, value):
    """
    Return the positions in ``array`` whose element equals ``value``.

    :param array: iterable of elements
    :param value: value to compare each element against with ``==``
    :return: list of matching positions, in ascending order
    """
    positions = []
    for position, element in enumerate(array):
        if element == value:
            positions.append(position)
    return positions

# For every test sample, print the column indices predicted positive and the
# column indices labelled positive, followed by a blank line.
for sample_idx in range(len(y_test_pred)):
    for template, rows in (("Pred:{}", y_test_pred), ("Truth:{}", y_test_true)):
        print(template.format(get_index(rows[sample_idx], 1)))
    print()


Pred:[10, 32, 72]
Truth:[76]

Pred:[0, 8, 10, 13, 14, 16, 26, 27, 32, 36, 41, 46, 48, 50, 52, 54, 59, 60, 64, 66, 70, 72, 77, 78, 79, 80, 81, 83, 85, 86, 89, 96, 97, 101, 102, 103, 105, 108, 112, 113, 114, 115, 118, 119, 121, 126, 128, 131, 133, 139, 140, 143, 148, 150, 154]
Truth:[76]

Pred:[0, 8, 10, 13, 14, 20, 27, 32, 46, 47, 49, 50, 59, 61, 62, 64, 70, 72, 73, 75, 78, 79, 84, 88, 89, 97, 99, 102, 104, 108, 110, 111, 114, 115, 120, 121, 125, 140, 143, 144, 154]
Truth:[76]

Pred:[0, 2, 7, 8, 10, 12, 13, 14, 16, 19, 20, 27, 31, 32, 36, 38, 44, 45, 46, 47, 48, 49, 50, 52, 54, 55, 59, 60, 61, 62, 63, 64, 66, 70, 72, 74, 75, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 96, 97, 99, 101, 102, 103, 104, 105, 108, 110, 111, 112, 113, 114, 115, 118, 119, 120, 121, 123, 125, 126, 127, 128, 131, 133, 136, 137, 139, 140, 143, 144, 148, 150, 154, 155]
Truth:[14, 16, 31, 37, 39, 53, 77, 148, 151]

Pred:[0, 8, 10, 13, 14, 16, 27, 32, 46, 50, 59, 60, 64, 70, 72, 74, 75, 77, 78, 79, 8


Pred:[0, 10, 13, 14, 20, 23, 27, 32, 46, 47, 49, 50, 60, 62, 64, 66, 70, 72, 78, 79, 81, 84, 85, 88, 89, 96, 97, 99, 102, 103, 104, 105, 108, 110, 111, 113, 114, 115, 118, 125, 133, 136, 137, 140, 143, 154]
Truth:[30, 31, 33, 37, 47, 57, 99, 109, 113]

Pred:[0, 2, 7, 8, 10, 12, 13, 14, 16, 18, 19, 20, 27, 29, 30, 31, 32, 43, 45, 46, 47, 49, 50, 52, 54, 55, 59, 60, 61, 62, 64, 66, 70, 72, 73, 74, 75, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 89, 90, 91, 93, 96, 97, 99, 101, 102, 103, 104, 105, 107, 108, 109, 110, 111, 112, 113, 114, 115, 118, 120, 121, 125, 127, 128, 133, 136, 137, 140, 143, 144, 148, 154, 155]
Truth:[4, 6, 20, 31, 49, 57, 58, 84, 90, 91, 94, 95, 116, 132, 144]

Pred:[10, 13, 27, 32, 46, 47, 50, 62, 66, 72, 78, 88, 89, 97, 99, 102, 104, 113, 115, 154]
Truth:[76]

Pred:[0, 7, 10, 13, 14, 20, 23, 27, 32, 45, 46, 47, 49, 50, 59, 60, 61, 62, 64, 66, 70, 72, 74, 78, 79, 80, 81, 83, 84, 85, 87, 88, 89, 91, 96, 97, 99, 102, 103, 104, 105, 108, 110, 111, 112, 113, 114, 115, 