In [None]:
from functools import partial

import numpy as np
import pandas as pd
from sklearn import metrics

# Location of the patient-details CSV that is read into df_patients.
# NOTE(review): this is an absolute, machine-specific path; adjust it to your
# local checkout before running the notebook.
patient_data_path = '/home/tstrebel/repos/umich-mads-capstone-project/assets/rsna-patient-details.csv'

In [None]:
# NOTE(review): TL_THRESH is not defined anywhere in the notebook as shown; it
# has to exist in the kernel before this cell runs — confirm where it is set.
df_patients = pd.read_csv(patient_data_path, index_col='index')

# Bucket patient_age into labelled groups. pd.cut uses right-closed intervals
# by default, i.e. (0, 1], (1, 5], ..., (79, inf).
age_bin_edges = [0, 1, 5, 12, 18, 44, 64, 79, np.inf]
age_bin_labels = [
    'Infant 1 year-old',
    'preschool (2-5)',
    'Child (6-12)',
    'Adolescent (13-18)',
    'Adult (19-44)',
    'Middle age (45-64)',
    'Aged (65-79)',
    'Aged 80+',
]
df_patients['age_group'] = pd.cut(df_patients['patient_age'],
                                  bins=age_bin_edges,
                                  labels=age_bin_labels)

# Spell out the view-position codes; any code other than 'AP'/'PA' becomes NaN.
view_position_names = {'AP': 'Anterior/Posterior', 'PA': 'Posterior/Anterior'}
df_patients['view_position'] = df_patients['view_position'].map(view_position_names)

# Hard label: 1 where the model probability is at or above TL_THRESH, else 0.
meets_threshold = df_patients['densenet_predict_proba'] >= TL_THRESH
df_patients['densenet_prediction'] = meets_threshold.astype(int)

df_patients.head()

In [None]:
# Render floats in displayed tables with three digits of precision.
pd.set_option('display.precision', 3)

# Rows whose 'split' column equals 'test'; every metric cell below uses this frame.
df_test = df_patients.loc[df_patients['split'].eq('test')]

def get_metrics_by_group(data, group):
    """Summarise classifier performance separately for each value of ``group``.

    Precision, recall, F1 and accuracy are computed from the precomputed
    ``densenet_prediction`` column; average precision is computed from
    ``densenet_predict_proba``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``target``, ``densenet_prediction``,
        ``densenet_predict_proba`` and the ``group`` column.
    group : str
        Name of the column to break the metrics down by. It is expected to be
        a different column from the three listed above.

    Returns
    -------
    pandas.DataFrame
        Indexed by group value (sorted), with columns ``n-patients``,
        ``n-negative``, ``n-positive``, ``Precision``, ``Recall``, ``F1``,
        ``Accuracy`` and ``Average Precision``.
    """
    # Per-group sample counts: total, target == 0 and target == 1.
    n = data[group].value_counts().rename('n-patients')
    n_neg = data[data.target == 0][group].value_counts().rename('n-negative')
    n_pos = data[data.target == 1][group].value_counts().rename('n-positive')
    counts = n.to_frame().join([n_neg, n_pos])

    # zero_division=0 on every ratio metric (not only precision): a group with
    # no positive labels or no positive predictions scores 0 without raising
    # an UndefinedMetricWarning. The returned values are unchanged.
    label_metrics = [
        ('Precision', partial(metrics.precision_score, zero_division=0)),
        ('Recall', partial(metrics.recall_score, zero_division=0)),
        ('F1', partial(metrics.f1_score, zero_division=0)),
        ('Accuracy', metrics.accuracy_score),
    ]

    def score_group(grp, metric_fn, score_col):
        # An empty group scores 0 rather than handing sklearn empty arrays.
        if len(grp) == 0:
            return 0
        return metric_fn(grp['target'], grp[score_col])

    # observed=False is spelled out so categorical group columns behave the
    # same regardless of the pandas default. Selecting the needed columns keeps
    # apply() from operating on the grouping column itself.
    grouped = data.groupby(group, observed=False)[
        ['target', 'densenet_prediction', 'densenet_predict_proba']]

    results = [
        grouped.apply(score_group, metric_fn, 'densenet_prediction').rename(metric_name)
        for metric_name, metric_fn in label_metrics
    ]
    results.append(
        grouped.apply(score_group, metrics.average_precision_score, 'densenet_predict_proba')
        .rename('Average Precision'))

    return counts.join(results).sort_index()

def get_f1_adjusted_metrics_by_group(data, group):
    """Per-group metrics after re-thresholding each group at its best-F1 cut-off.

    For every value of ``group`` the probability threshold that maximises F1
    on that group's precision-recall curve is selected, predictions are
    re-derived from ``densenet_predict_proba`` with that threshold, and
    precision, recall, F1 and accuracy are computed on the result.

    Note that each threshold is chosen on the same rows it is then evaluated
    on, so the reported scores are an optimistic estimate.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``target``, ``densenet_predict_proba`` and the ``group``
        column.
    group : str
        Name of the column to break the metrics down by. It is expected to be
        a different column from the two listed above.

    Returns
    -------
    pandas.DataFrame
        Indexed by group value (sorted), with columns ``n-patients``,
        ``n-negative``, ``n-positive``, ``Precision``, ``Recall``, ``F1``,
        ``Accuracy``, ``Average Precision`` and ``Adjusted Threshold``.
    """
    # Per-group sample counts: total, target == 0 and target == 1.
    n = data[group].value_counts().rename('n-patients')
    n_neg = data[data.target == 0][group].value_counts().rename('n-negative')
    n_pos = data[data.target == 1][group].value_counts().rename('n-positive')
    counts = n.to_frame().join([n_neg, n_pos])

    # observed=False is spelled out so categorical group columns behave the
    # same regardless of the pandas default. Selecting the needed columns keeps
    # apply() from operating on the grouping column itself.
    grouped = data.groupby(group, observed=False)[['target', 'densenet_predict_proba']]

    def best_f1_threshold(grp):
        # An empty group gets threshold 0 rather than calling sklearn on nothing.
        if len(grp) == 0:
            return 0
        precision, recall, pr_thresholds = metrics.precision_recall_curve(
            grp['target'], grp['densenet_predict_proba'])
        # precision/recall have one more entry than pr_thresholds: a trailing
        # (precision=1, recall=0) point with no threshold. Drop it so the
        # argmax below is always a valid index into pr_thresholds.
        precision, recall = precision[:-1], recall[:-1]
        # The 1e-8 guards against 0/0 when precision and recall are both zero.
        f1_scores = (2 * precision * recall) / ((precision + recall) + 1e-8)
        return pr_thresholds[np.argmax(f1_scores)]

    adj_thresh = grouped.apply(best_f1_threshold).rename('Adjusted Threshold')
    adj_thresh_dict = adj_thresh.to_dict()

    def score_group(grp, metric_fn):
        if len(grp) == 0:
            return 0
        # grp.name is the group key that pandas attaches to each sub-frame.
        best_thresh = adj_thresh_dict[grp.name]
        adj_prediction = (grp['densenet_predict_proba'] >= best_thresh).astype(int)
        return metric_fn(grp['target'], adj_prediction)

    def average_precision(grp):
        if len(grp) == 0:
            return 0
        return metrics.average_precision_score(grp['target'], grp['densenet_predict_proba'])

    # zero_division=0 on every ratio metric (not only precision) so degenerate
    # groups score 0 without an UndefinedMetricWarning; values are unchanged.
    label_metrics = [
        ('Precision', partial(metrics.precision_score, zero_division=0)),
        ('Recall', partial(metrics.recall_score, zero_division=0)),
        ('F1', partial(metrics.f1_score, zero_division=0)),
        ('Accuracy', metrics.accuracy_score),
    ]

    results = [
        grouped.apply(score_group, metric_fn).rename(metric_name)
        for metric_name, metric_fn in label_metrics
    ]
    results.append(grouped.apply(average_precision).rename('Average Precision'))
    results.append(adj_thresh)

    return counts.join(results).sort_index()

def print_metrics(data, adjust_by=None):
    """Print sample counts and overall precision, recall, F1 and accuracy.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``target`` and ``densenet_predict_proba``; also
        ``densenet_prediction`` when ``adjust_by`` is not given.
    adjust_by : str, optional
        Name of a grouping column. When given, each group's best-F1
        probability threshold is computed from its precision-recall curve and
        predictions are re-derived row by row with the threshold of the row's
        group. Rows whose group value has no threshold (e.g. a missing value)
        are predicted as 0. When omitted, the precomputed
        ``densenet_prediction`` column is used as is.

        The thresholds are chosen on the same rows that are then scored, so
        the adjusted numbers are an optimistic estimate.

    Returns
    -------
    None
        The results are written to stdout.
    """
    n = len(data)
    n_neg = len(data[data.target == 0])
    n_pos = len(data[data.target == 1])

    def best_f1_threshold(grp):
        # An empty group gets threshold 0 rather than calling sklearn on nothing.
        if len(grp) == 0:
            return 0
        precision, recall, pr_thresholds = metrics.precision_recall_curve(
            grp['target'], grp['densenet_predict_proba'])
        # precision/recall have one more entry than pr_thresholds: a trailing
        # (precision=1, recall=0) point with no threshold. Drop it so the
        # argmax below is always a valid index into pr_thresholds.
        precision, recall = precision[:-1], recall[:-1]
        # The 1e-8 guards against 0/0 when precision and recall are both zero.
        f1_scores = (2 * precision * recall) / ((precision + recall) + 1e-8)
        return pr_thresholds[np.argmax(f1_scores)]

    if adjust_by:
        # observed=False is spelled out so categorical group columns behave
        # the same regardless of the pandas default.
        adj_thresh_dict = (data
                           .groupby(adjust_by, observed=False)[['target', 'densenet_predict_proba']]
                           .apply(best_f1_threshold)
                           .to_dict())
        # Look thresholds up by the raw group value. Casting to str (as was
        # done before) only matches when the group keys are already strings;
        # for int/bool group columns every lookup missed and all rows were
        # silently predicted negative. astype(object) also unwraps categorical
        # columns so the mapped result is a plain float Series.
        adj_thresh = data[adjust_by].astype(object).map(adj_thresh_dict).astype(float)
        # Comparisons against a NaN threshold are False, i.e. predicted 0.
        preds = (data['densenet_predict_proba'] >= adj_thresh).astype(int)
    else:
        preds = data['densenet_prediction']

    targets = data['target']
    # zero_division=0 keeps degenerate inputs (no positive labels or no
    # positive predictions) at 0 without an UndefinedMetricWarning.
    precision = metrics.precision_score(targets, preds, zero_division=0)
    recall = metrics.recall_score(targets, preds, zero_division=0)
    f1 = metrics.f1_score(targets, preds, zero_division=0)
    accuracy = metrics.accuracy_score(targets, preds)

    print('N Samples:\t{:,}'.format(n))
    print('N Negative:\t{:,}'.format(n_neg))
    print('N Positive:\t{:,}'.format(n_pos))
    print()
    print('Precision:\t{:.3f}'.format(precision))
    print('Recall:\t{:.3f}'.format(recall))
    print('F1:\t\t{:.3f}'.format(f1))
    print('Accuracy:\t{:.3f}'.format(accuracy))

In [None]:
# Test-split metrics per patient_sex, using the stored densenet_prediction labels.
get_metrics_by_group(df_test, group='patient_sex')

In [None]:
# Test-split metrics per patient_sex, with each group re-thresholded at its own best-F1 cut-off.
get_f1_adjusted_metrics_by_group(df_test, group='patient_sex')

In [None]:
# Test-split metrics per view_position, using the stored densenet_prediction labels.
get_metrics_by_group(df_test, group='view_position')

In [None]:
# Test-split metrics per view_position, with each group re-thresholded at its own best-F1 cut-off.
get_f1_adjusted_metrics_by_group(df_test, group='view_position')

In [None]:
# Test-split metrics per age_group, using the stored densenet_prediction labels.
get_metrics_by_group(df_test, group='age_group')

In [None]:
# Test-split metrics per age_group, with each group re-thresholded at its own best-F1 cut-off.
get_f1_adjusted_metrics_by_group(df_test, group='age_group')

In [None]:
# Overall test-split metrics using the stored densenet_prediction labels.
print_metrics(df_test)

In [None]:
# Overall test-split metrics when every row is thresholded with its age_group's best-F1 cut-off.
print_metrics(df_test, adjust_by='age_group')