In [4]:
import os
import glob
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 50)
pd.set_option('display.width', 1000)

from bokeh.io import output_notebook, reset_output, show, output_file, save
from bokeh.plotting import figure
from bokeh.layouts import column, row, gridplot
from bokeh.models import ColumnDataSource, HoverTool, Legend

from bokeh.palettes import Category10

In [87]:
# Import test sets with predictions
cfr_data_root = os.path.normpath('/mnt/obi0/andreas/data/cfr')
cfr_meta_date = '200208'
log_dir = os.path.join(cfr_data_root, 'log')
model_name_list = ['200208a4c', '200208a4cresizedpad']

# Collect predictions
df_list = []
for model in model_name_list:
    file = glob.glob(os.path.join(log_dir, model, '*_chkpt_99_predicted.parquet'))[0]
    df_file = pd.read_parquet(file)
    df_file = df_file.assign(model=model)
    df_list.append(df_file)

test_df = pd.concat(df_list, ignore_index=True).reset_index(drop=True)
print(test_df.shape)
test_df.head(2)

(2088, 55)


Unnamed: 0,mrn,study,echo_study_date,reportID,days_post_cfr,subjectid,report_number,cfr_study_date,cfr_report_date,cfr,filename,dir,datetime,fileid,institution,model,manufacturer,frame_time,number_of_frames,heart_rate,deltaX,deltaY,a2c,a2c_laocc,a2c_lvocc_s,...,a4c_laocc,a4c_lvocc_s,a4c_rv,a4c_rv_laocc,a5c,apex,other,plax_far,plax_lac,plax_laz,plax_laz_ao,plax_plax,psax_avz,psax_az,psax_mv,psax_pap,rvinf,subcostal,suprasternal,max_view,mode,rate,im_array_shape,cfr_predicted,cfr_tfr
0,22898944,49017d8b76ae238a_4903a44b32eb6df53fd314396038,2008-07-16,85058,82,3665,0820906L,2008-04-25,2008-07-14,1.368222,49017d8b76ae238a_4903a44b32eb6df53fd314396038_...,/mnt/obi0/phi/echo/npyFiles/BWH/4901/49017d8b7...,2008-07-16 16:50:22,49017d8b76ae238a_4903a44b32eb6df53fd314396038_...,BWH,200208a4c,GEMS Ultrasound,16.968699,121.0,90.0,0.035971,0.035971,8.641974e-12,1.657476e-05,6.875996e-10,...,0.0234852,4.775051e-10,4.632475e-10,2.907211e-11,4.237893e-09,3.390038e-10,4.31721e-13,1.109504e-13,9.310654e-11,1.40882e-11,4.384523e-11,1.357891e-08,5.62115e-05,7.175311e-11,2.418757e-10,0.000142,2.244174e-09,2.333829e-13,7.323142e-11,a4c,test,58.9,"[152, 229, 40]",1.369096,1.368222
1,1088541,4b7a873e7521860a_4903a583583e2c4f6357119c4aa0,2016-09-12,149060,-4,1008,E3092213,2016-09-16,2016-09-12,1.592339,4b7a873e7521860a_4903a583583e2c4f6357119c4aa0_...,/mnt/obi0/phi/echo/npyFiles/BWH/4b7a/4b7a873e7...,2016-09-12 13:27:34,4b7a873e7521860a_4903a583583e2c4f6357119c4aa0_...,BWH,200208a4c,Philips Medical Systems,33.333,74.0,75.0,0.041843,0.041843,1.232015e-07,9.767393e-09,9.907799e-09,...,3.523559e-11,3.159424e-08,8.38786e-10,2.35627e-11,9.229058e-10,2.28529e-07,7.329414e-05,8.65774e-05,3.704577e-07,1.527994e-07,2.829089e-11,5.688312e-09,4.704903e-10,2.728487e-08,1.022642e-08,1e-06,4.527188e-07,1.491174e-12,4.730189e-09,a4c,test,30.0,"[251, 335, 40]",2.214785,1.592339


In [69]:
df_list = []
for m in test_df.model.unique():

    s = {'view': list(test_df[test_df.model==m].model.unique()),
         'model': [m],
         'mrns': [len(test_df[test_df.model==m].mrn.unique())],
         'studies': [len(test_df[test_df.model==m].study.unique())],
         'videos': [len(test_df[test_df.model==m].filename.unique())],
         'unique_cfr_values': [len(test_df[test_df.model==m].cfr.unique())]}
    
    df_list.append(pd.DataFrame(s))

df_stat = pd.concat(df_list, ignore_index=True).reset_index(drop=True)
print(df_stat)

                  view                model  mrns  studies  videos  unique_cfr_values
0            200208a4c            200208a4c   272      356    1044                288
1  200208a4cresizedpad  200208a4cresizedpad   272      356    1044                288


In [30]:
def style(p):
    # Title 
    p.title.align = 'center'
    p.title.text_font_size = '11pt'
    #p.title.text_font = 'serif'

    # Axis titles
    p.xaxis.axis_label_text_font_size = '11pt'
    p.xaxis.axis_label_text_font_style = 'bold'
    p.yaxis.axis_label_text_font_size = '11pt'
    p.yaxis.axis_label_text_font_style = 'bold'

    # Tick labels
    p.xaxis.major_label_text_font_size = '11pt'
    p.yaxis.major_label_text_font_size = '11pt'
    
    return p

In [97]:
def make_plot(df):
    p=figure(title='CFR predictions',
             x_axis_label = 'True cfr',
             y_axis_label = 'Predicted cfr')
             
    plt_dict = {}
    for m, model in enumerate(df.model.unique()):
        df_model = df[df.model==model]
        datasource=ColumnDataSource(df_model)
        plt_dict[model] = p.cross(source=datasource,
                                  x='cfr_tfr',
                                  y='cfr_predicted',
                                  size=5,
                                  line_color=Category10[10][m+2],
                                  legend_label=model,
                                  name=model)
        plt_dict[model].visible = True if model == '200208a4cresizedpad' else False
        
    diag = p.line([1,3], [1,3],
                  line_color='black',
                  line_width=1,
                  line_dash='dashed')
    
    p.legend.location='top_right'
    p.legend.title='model: click to hide'
    p.legend.click_policy='hide'
    p=style(p)
    
    return p

In [98]:
# Calculate correlation coefficients
model_stats_df = pd.DataFrame()
for model in test_df.model.unique():
    x = test_df[test_df.model==model].cfr_tfr
    y = test_df[test_df.model==model].cfr_predicted
    spear = spearmanr(x, y)
    model_dict = {'model': [model],
                  'n': [len(x)],
                  'spear_c': [spear.correlation],
                  'spear_p': [spear.pvalue]}
    
    model_stats_df = pd.concat([model_stats_df, pd.DataFrame(model_dict)], ignore_index=True)

In [99]:
model_stats_df

Unnamed: 0,model,n,spear_c,spear_p
0,200208a4c,1044,0.195612,1.835443e-10
1,200208a4cresizedpad,1044,0.211238,5.382797e-12


In [100]:
pred_plot = make_plot(test_df)
reset_output()
output_notebook()
show(pred_plot)

In [101]:
reset_output()
output_file(os.path.join(log_dir, '200208a4_scatter.html'), title = '200208a4')
save(pred_plot)

'/mnt/obi0/andreas/data/cfr/log/200208a4_scatter.html'