In [1]:
import pandas as pd
import numpy as np
from sklearn import metrics
from functools import partial

# Probability cut-off: rows with `rsna_densenet_proba` >= this value are
# labelled positive when `densenet_prediction` is derived.
TL_THRESH = 0.406

# Per-patient CSV that is loaded into `df_patients`.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
patient_data_path = (
    '/home/tstrebel/repos/umich-mads-capstone-project'
    '/assets/rsna-patient-details.csv'
)

In [2]:
# Age bands: bin edges in years (right-closed) and the label of each band.
AGE_BIN_EDGES = [0, 1, 5, 12, 18, 44, 64, 79, np.inf]
AGE_GROUP_LABELS = ['Infant 1 year-old',
                    'preschool (2-5)',
                    'Child (6-12)',
                    'Adolescent (13-18)',
                    'Adult (19-44)',
                    'Middle age (45-64)',
                    'Aged (65-79)',
                    'Aged 80+']

# Human-readable names for the view-position codes; any other code maps to NaN.
VIEW_POSITION_NAMES = {'AP': 'Anterior/Posterior', 'PA': 'Posterior/Anterior'}

# Build the frame in one expression so re-running the cell always starts from
# the CSV and never re-maps an already-mapped column.
df_patients = (
    pd.read_csv(patient_data_path, index_col='index')
    .assign(
        # include_lowest=True closes the first bin on the left so that age 0
        # falls into the infant band instead of becoming NaN.
        age_group=lambda d: pd.cut(d['patient_age'],
                                   AGE_BIN_EDGES,
                                   labels=AGE_GROUP_LABELS,
                                   include_lowest=True),
        view_position=lambda d: d['view_position'].map(VIEW_POSITION_NAMES),
        # Hard 0/1 label from the DenseNet probability at the TL_THRESH cut-off.
        densenet_prediction=lambda d: (d['rsna_densenet_proba'] >= TL_THRESH).astype(int),
    )
)

df_patients.head()

Unnamed: 0_level_0,patient_id,patient_age,patient_sex,view_position,class,target,split,rsna_baseline_proba,rsna_densenet_proba,age_group,densenet_prediction
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
0,0004cfab-14fd-4e49-80ba-63a80b6bddd6,51,F,Posterior/Anterior,No Lung Opacity / Not Normal,0,train,0.199968,0.272411,Middle age (45-64),0
1,000924cf-0f8d-42bd-9158-1af53881a557,19,F,Anterior/Posterior,Normal,0,train,0.156917,0.031587,Adult (19-44),0
2,000db696-cf54-4385-b10b-6b16fbb3f985,25,F,Anterior/Posterior,Lung Opacity,1,train,0.688685,0.698994,Adult (19-44),1
3,000fe35a-2649-43d4-b027-e67796d412e0,40,M,Anterior/Posterior,Lung Opacity,1,train,0.943956,0.984925,Adult (19-44),1
4,001031d9-f904-4a23-b3e5-2c088acd19c6,57,M,Posterior/Anterior,Lung Opacity,1,train,0.374111,0.407364,Middle age (45-64),1


In [3]:
# Show floats with two decimals in the tables rendered below.
pd.options.display.precision = 2

# Evaluation subset: the rows whose `split` column equals 'test'.
df_test = df_patients.loc[df_patients['split'].eq('test')]

def get_metrics_by_group(data, group):
    """Tabulate DenseNet classification metrics for every value of ``group``.

    Ground truth is read from ``data['target']``, hard predictions from
    ``data['densenet_prediction']`` and scores from
    ``data['rsna_densenet_proba']``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``target``, ``densenet_prediction``,
        ``rsna_densenet_proba`` and the ``group`` column.
    group : str
        Name of the column to stratify by.

    Returns
    -------
    pandas.DataFrame
        One row per group value, sorted by index, with the columns
        ``n-patients``, ``n-negative``, ``n-positive``, ``Precision``,
        ``Recall``, ``F1``, ``Accuracy`` and ``Average Precision``.
        Group values that have no rows (unobserved categories of a
        categorical column) are kept and report 0 for every metric.
    """
    metric_columns = ['Precision', 'Recall', 'F1', 'Accuracy', 'Average Precision']

    # Group sizes. The per-class counts are re-indexed against the overall
    # counts so a group that lacks one class shows 0 rather than NaN.
    counts = data[group].value_counts().rename('n-patients').to_frame()
    for label, column in ((0, 'n-negative'), (1, 'n-positive')):
        class_counts = data.loc[data['target'] == label, group].value_counts()
        counts[column] = class_counts.reindex(counts.index, fill_value=0)

    def score(grp):
        """Return every metric for one non-empty group as a dict."""
        y_true = grp['target']
        y_pred = grp['densenet_prediction']
        return {
            # zero_division=0 keeps undefined ratios at 0 without warnings.
            'Precision': metrics.precision_score(y_true, y_pred, zero_division=0),
            'Recall': metrics.recall_score(y_true, y_pred, zero_division=0),
            'F1': metrics.f1_score(y_true, y_pred, zero_division=0),
            'Accuracy': metrics.accuracy_score(y_true, y_pred),
            'Average Precision': metrics.average_precision_score(
                y_true, grp['rsna_densenet_proba']),
        }

    # Split once. observed=True is explicit so the result does not depend on
    # the pandas default for categorical groupers; empty categories are
    # handled below instead.
    observed_groups = dict(iter(data.groupby(group, observed=True)))
    empty_row = dict.fromkeys(metric_columns, 0.0)
    rows = [score(observed_groups[key]) if key in observed_groups else empty_row
            for key in counts.index]
    scores = pd.DataFrame(rows, index=counts.index, columns=metric_columns)

    return counts.join(scores).sort_index()

def get_f1_adjusted_metrics_by_group(data, group):
    """Tabulate per-group metrics after tuning one threshold per group.

    For each value of ``group`` the decision threshold on
    ``data['rsna_densenet_proba']`` that maximises F1 on that group's rows is
    selected, and precision/recall/F1/accuracy are computed with it.

    Note that each threshold is chosen and evaluated on the same rows, so the
    thresholded metrics are optimistic for whatever ``data`` is passed in.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``target``, ``rsna_densenet_proba`` and the ``group``
        column.
    group : str
        Name of the column to stratify by.

    Returns
    -------
    pandas.DataFrame
        One row per group value, sorted by index, with the columns
        ``n-patients``, ``n-negative``, ``n-positive``, ``Precision``,
        ``Recall``, ``F1``, ``Accuracy``, ``Average Precision`` and
        ``Adjusted Threshold``. Group values that have no rows (unobserved
        categories of a categorical column) are kept and report 0 everywhere.
    """
    result_columns = ['Precision', 'Recall', 'F1', 'Accuracy',
                      'Average Precision', 'Adjusted Threshold']

    # Group sizes. The per-class counts are re-indexed against the overall
    # counts so a group that lacks one class shows 0 rather than NaN.
    counts = data[group].value_counts().rename('n-patients').to_frame()
    for label, column in ((0, 'n-negative'), (1, 'n-positive')):
        class_counts = data.loc[data['target'] == label, group].value_counts()
        counts[column] = class_counts.reindex(counts.index, fill_value=0)

    def score(grp):
        """Pick the F1-optimal threshold for one group and score with it."""
        y_true = grp['target']
        y_score = grp['rsna_densenet_proba']

        precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_score)
        # The small constant guards against 0/0 where precision + recall == 0.
        f1_scores = (2 * precision * recall) / ((precision + recall) + 1e-8)
        # precision/recall carry one trailing point (recall 0) that has no
        # threshold; drop it so the argmax always indexes `thresholds`.
        best_thresh = thresholds[np.argmax(f1_scores[:-1])]

        y_pred = (y_score >= best_thresh).astype(int)
        return {
            # zero_division=0 keeps undefined ratios at 0 without warnings.
            'Precision': metrics.precision_score(y_true, y_pred, zero_division=0),
            'Recall': metrics.recall_score(y_true, y_pred, zero_division=0),
            'F1': metrics.f1_score(y_true, y_pred, zero_division=0),
            'Accuracy': metrics.accuracy_score(y_true, y_pred),
            'Average Precision': metrics.average_precision_score(y_true, y_score),
            'Adjusted Threshold': best_thresh,
        }

    # Split once. observed=True is explicit so the result does not depend on
    # the pandas default for categorical groupers; empty categories are
    # handled below instead.
    observed_groups = dict(iter(data.groupby(group, observed=True)))
    empty_row = dict.fromkeys(result_columns, 0.0)
    rows = [score(observed_groups[key]) if key in observed_groups else empty_row
            for key in counts.index]
    scores = pd.DataFrame(rows, index=counts.index, columns=result_columns)

    return counts.join(scores).sort_index()

def print_metrics(data, adjust_by=None):
    """Print overall classification metrics for ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``target`` and ``rsna_densenet_proba``; also
        ``densenet_prediction`` when ``adjust_by`` is not given, or the
        ``adjust_by`` column when it is.
    adjust_by : str, optional
        Name of a grouping column. When given, predictions are re-derived
        from ``rsna_densenet_proba`` using, for each row, the threshold that
        maximises F1 within that row's group. The thresholds are fitted on
        ``data`` itself, so the printed numbers are optimistic. Rows whose
        group value is missing get no threshold and are predicted negative.
        When omitted, the stored ``densenet_prediction`` labels are used.

    Returns
    -------
    None
        The summary is written to stdout.
    """
    targets = data['target']
    scores = data['rsna_densenet_proba']

    def best_f1_threshold(grp):
        """Return the score threshold that maximises F1 on one group."""
        precision, recall, thresholds = metrics.precision_recall_curve(
            grp['target'], grp['rsna_densenet_proba'])
        # The small constant guards against 0/0 where precision + recall == 0.
        f1_scores = (2 * precision * recall) / ((precision + recall) + 1e-8)
        # precision/recall carry one trailing point (recall 0) that has no
        # threshold; drop it so the argmax always indexes `thresholds`.
        return thresholds[np.argmax(f1_scores[:-1])]

    if adjust_by:
        # Only groups that actually have rows need a threshold.
        group_thresholds = {key: best_f1_threshold(grp)
                            for key, grp in data.groupby(adjust_by, observed=True)}
        # Look thresholds up by the original group values. Casting the keys to
        # str (as opposed to object) would miss every non-string key and
        # silently turn all predictions into 0.
        row_thresholds = (data[adjust_by]
                          .astype(object)
                          .map(group_thresholds)
                          .astype(float))
        preds = (scores >= row_thresholds).astype(int)
    else:
        preds = data['densenet_prediction']

    n = len(data)
    n_neg = int((targets == 0).sum())
    n_pos = int((targets == 1).sum())

    # zero_division=0 keeps undefined ratios at 0 without warnings.
    precision = metrics.precision_score(targets, preds, zero_division=0)
    recall = metrics.recall_score(targets, preds, zero_division=0)
    f1 = metrics.f1_score(targets, preds, zero_division=0)
    accuracy = metrics.accuracy_score(targets, preds)
    average_precision = metrics.average_precision_score(targets, scores)

    print('N Samples:\t\t{:,}'.format(n))
    print('N Negative:\t\t{:,}'.format(n_neg))
    print('N Positive:\t\t{:,}'.format(n_pos))
    print()
    print('Precision:\t\t{:.2f}'.format(precision))
    print('Recall:\t\t\t{:.2f}'.format(recall))
    print('F1:\t\t\t{:.2f}'.format(f1))
    print('Accuracy:\t\t{:.2f}'.format(accuracy))
    print('Average Precision:\t{:.2f}'.format(average_precision))

In [4]:
# Overall test-split metrics using the stored `densenet_prediction` labels.
print_metrics(df_test)

N Samples:		2,668
N Negative:		2,067
N Positive:		601

Precision:		0.61
Recall:			0.69
F1:			0.65
Accuracy:		0.83
Average Precision:	0.70


In [10]:
# Overall test-split metrics with one F1-maximising threshold per age group
# (the thresholds are fitted on df_test itself).
print_metrics(df_test, adjust_by='age_group')

N Samples:		2,668
N Negative:		2,067
N Positive:		601

Precision:		0.59
Recall:			0.75
F1:			0.66
Accuracy:		0.83
Average Precision:	0.70


In [5]:
# Test-split metrics broken down by view position, using `densenet_prediction`.
get_metrics_by_group(df_test, group='view_position')

Unnamed: 0,n-patients,n-negative,n-positive,Precision,Recall,F1,Accuracy,Average Precision
Anterior/Posterior,1242,784,458,0.62,0.77,0.69,0.74,0.75
Posterior/Anterior,1426,1283,143,0.58,0.41,0.48,0.91,0.51


In [6]:
# Test-split metrics broken down by patient sex, using `densenet_prediction`.
get_metrics_by_group(df_test, group='patient_sex')

Unnamed: 0,n-patients,n-negative,n-positive,Precision,Recall,F1,Accuracy,Average Precision
F,1148,906,242,0.56,0.68,0.62,0.82,0.68
M,1520,1161,359,0.66,0.69,0.67,0.84,0.71


In [11]:
# Same breakdown by patient sex, but each group is scored at its own
# F1-maximising threshold (fitted on df_test itself).
get_f1_adjusted_metrics_by_group(df_test, group='patient_sex')

Unnamed: 0,n-patients,n-negative,n-positive,Precision,Recall,F1,Accuracy,Average Precision,Adjusted Threshold
F,1148,906,242,0.68,0.62,0.65,0.86,0.68,0.52
M,1520,1161,359,0.62,0.73,0.67,0.83,0.71,0.36


In [12]:
# Test-split metrics broken down by age group, using `densenet_prediction`.
get_metrics_by_group(df_test, group='age_group')

Unnamed: 0,n-patients,n-negative,n-positive,Precision,Recall,F1,Accuracy,Average Precision
Infant 1 year-old,0,0,0,0.0,0.0,0.0,0.0,0.0
preschool (2-5),8,6,2,1.0,1.0,1.0,1.0,1.0
Child (6-12),43,28,15,0.81,0.87,0.84,0.88,0.84
Adolescent (13-18),82,55,27,0.74,0.74,0.74,0.83,0.86
Adult (19-44),942,730,212,0.58,0.7,0.64,0.82,0.68
Middle age (45-64),1221,962,259,0.64,0.68,0.66,0.85,0.72
Aged (65-79),351,273,78,0.56,0.63,0.59,0.81,0.6
Aged 80+,21,13,8,0.5,0.38,0.43,0.62,0.72


In [8]:
# Same breakdown by age group, but each group is scored at its own
# F1-maximising threshold (fitted on df_test itself).
get_f1_adjusted_metrics_by_group(df_test, group='age_group')

Unnamed: 0,n-patients,n-negative,n-positive,Precision,Recall,F1,Accuracy,Average Precision,Adjusted Threshold
Infant 1 year-old,0,0,0,0.0,0.0,0.0,0.0,0.0,0.0
preschool (2-5),8,6,2,1.0,1.0,1.0,1.0,1.0,0.74
Child (6-12),43,28,15,0.81,0.87,0.84,0.88,0.84,0.56
Adolescent (13-18),82,55,27,0.83,0.74,0.78,0.87,0.86,0.54
Adult (19-44),942,730,212,0.55,0.79,0.65,0.81,0.68,0.33
Middle age (45-64),1221,962,259,0.63,0.69,0.66,0.85,0.72,0.4
Aged (65-79),351,273,78,0.5,0.82,0.62,0.77,0.6,0.24
Aged 80+,21,13,8,0.58,0.88,0.7,0.71,0.72,0.18
