In [None]:
import pandas as pd
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import json
from IPython.display import display, Markdown
from scipy.stats import pearsonr, spearmanr, ks_2samp, mannwhitneyu, ttest_ind

import sys
sys.path.append('../')

from utilities import data
from evalutils.roc import get_bootstrapped_roc_ci_curves
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

## Directories where results are stored.
# Two locations hold the NLST prediction files: the experiment share (CHANSEY_*)
# and a local Teams/OneDrive backup (LOCAL_*). NLST_PREDS is the one the rest of
# the notebook reads from and writes to.
EXPERIMENT_DIR = f"W:/experiments/lung-malignancy-fairness-shaurya"
CHANSEY_NLST_PREDS = f"{EXPERIMENT_DIR}/nlst"

# NOTE(review): absolute, user-specific path; the notebook only runs on a machine with this layout.
TEAMS_DIR = "C:/Users/shaur/OneDrive - Radboudumc/Documents - Master - Shaurya Gaur/General/Malignancy-Estimation Results"
LOCAL_NLST_PREDS = f"{TEAMS_DIR}/nlst" ## Comment out if not using Teams backup (aka Chansey is up :)
NLST_PREDS = LOCAL_NLST_PREDS  # currently points at the local backup, not CHANSEY_NLST_PREDS

In [None]:
nlst_preds_nodule = pd.read_csv(f"{NLST_PREDS}/nlst_demov4_allmodels_cal.csv")
nlst_preds_nodule.info()

In [None]:
# Load the column specification: a nested dict keyed first by column kind
# ('cat' / 'num' are the keys used in this notebook) and then by category name.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(f'{NLST_PREDS}/nlst_demo_v4_cols.json') as json_data:
    nlst_democols = json.load(json_data)

# Remove the 'nodule' group from the numerical columns (the popped value is
# shown as this cell's output).
nlst_democols['num'].pop('nodule')

In [None]:

nlst_democols

In [None]:
nlst_preds = data.prep_nlst_preds(nlst_preds_nodule, scanlevel=True, tijmen=False, sybil=True)
nlst_preds.info()

In [None]:
nlst_policy_thresholds = pd.read_csv(f"{NLST_PREDS}/policy-thresholds-{len(nlst_preds)}.csv", index_col=0)
nlst_policy_thresholds

In [None]:
THRESHOLD = 'Brock'

In [None]:
sybil_worse_df = pd.read_csv(f"{NLST_PREDS}/sybil_worse.csv")
sybil_worse_df

Sybil's worse columns are from our [scan level AUC results](./plot_nlst_scanlevel.ipynb).

PanCan columns from [Radiopaedia](https://radiopaedia.org/articles/brock-model-for-pulmonary-nodules).

The gender differential columns are from the "What about men vs women?" section of the [training info demographic splits notebook](./nlst/nlst_traininfo.ipynb).

In [None]:
# Inputs of the PanCan (Brock) model.
pancan_cols = ['Age', 'Gender', 'race', 'FamilyHistoryLungCa', 'Emphysema', 'Diameter [mm]', 'NoduleInUpperLung', 'PartSolid', 'NoduleCounts', 'Spiculation']
# Columns that differed between genders in the training-info notebook.
train_diff_cols = ['pipe', 'cigar', 'Married', 'wrknomask', 'wrkfarm', 'smokework', 'diaghear', 'diagasbe', 'smokelive', 'diagpneu', 'diagchro', 'diagcopd', 'PersonalCancerHist', 'cigsmok', 'cancbrea', 'pkyr', 'smokeage']
# De-duplicated union of every column of interest. Sorted (by string form, so a
# stray non-string entry cannot break the sort) to make the list reproducible
# between runs; a bare `list(set(...))` has a run-dependent order.
relevant_cols = sorted(
    set(
        list(sybil_worse_df['col'])
        + pancan_cols
        + train_diff_cols
        + nlst_democols['cat']['lungcanc']
        + nlst_democols['cat']['nodule']
    ),
    key=str,
)

In [None]:
relevant_cols

In [None]:
nlst_preds[['sybil_year1']].info()

In [None]:
MODEL_TO_COL = {
    "Venkadesh": "DL_cal",
    # "de Haas Combined": "Thijmen_mean_cal",
    "de Haas Local": "Thijmen_local_cal",
    "de Haas Global (hidden nodule)": "Thijmen_global_hidden",
    "de Haas Global (w/nodule)": "Thijmen_global_show_cal",
    "Sybil": "sybil_year1",
    "PanCan2b": "PanCan2b",
}

In [None]:
nlst_preds['Sybil_pred_label'] = (nlst_preds[MODEL_TO_COL['Sybil']] > nlst_policy_thresholds.loc['Sybil year 1', THRESHOLD]).astype(int).to_numpy()
nlst_preds['Sybil_PanCan_diff'] = nlst_preds[MODEL_TO_COL['Sybil']] - nlst_preds['PanCan2b']

In [None]:
false_positives = nlst_preds.query("label == 0 and Sybil_pred_label == 1")
false_negatives = nlst_preds.query("label == 1 and Sybil_pred_label == 0")

true_positives = nlst_preds.query("label == 1 and Sybil_pred_label == 1")
true_negatives = nlst_preds.query("label == 0 and Sybil_pred_label == 0")

## Difference between TP/FP/TN/FN

In [None]:
result_sets = {
    "FP": false_positives,
    "FN": false_negatives,
    "TP": true_positives,
    "TN": true_negatives, 
}

### utility code

In [None]:
def combine_col_dfs(cols=None, df_func=pd.DataFrame, dfsets=None, dispdf=False):
    """Build one long table of per-attribute statistics across several result sets.

    Parameters
    ----------
    cols : mapping of category name -> list of attribute (column) names, optional
        Defaults to ``nlst_democols['cat']``, looked up when the function is
        called (not when it is defined) so re-running earlier cells is honoured.
    df_func : callable ``(attribute, dfsets) -> DataFrame``
        Produces the statistics table for one attribute; its index holds the
        attribute's values (or statistic names).
    dfsets : mapping of set name -> DataFrame, optional
        Passed through to ``df_func``. Defaults to the notebook-level
        ``result_sets``, looked up at call time.
    dispdf : bool
        If True, render a heading per category and each per-attribute table.

    Returns
    -------
    DataFrame
        All per-attribute tables stacked with a fresh 0..n-1 index. The first
        three columns are ``category``, ``attribute`` and ``value`` (the former
        index of the per-attribute table); rows of each attribute are sorted by
        ``value``.
    """
    if cols is None:
        cols = nlst_democols['cat']
    if dfsets is None:
        dfsets = result_sets

    splitdfs = []
    for cat in cols:
        if dispdf: display(Markdown(f"### {cat}"))

        for c in cols[cat]:
            stats = df_func(c, dfsets)
            if dispdf: display(stats)

            # Add the three identifying columns without mutating the frame in place.
            stats = stats.assign(category=cat, attribute=c, value=stats.index.values)

            # Move the three trailing columns to the front.
            ordered = stats.columns.tolist()
            ordered = ordered[-3:] + ordered[:-3]

            splitdfs.append(
                stats[ordered]
                .reset_index(drop=True)
                .sort_values(by='value', ascending=True)
            )

    return pd.concat(splitdfs, axis=0, ignore_index=True)

In [None]:
def cat_dist_df(c='Gender', dfsets=None):
    """Distribution of a categorical column ``c`` within each result set.

    Parameters
    ----------
    c : str
        Column to tabulate.
    dfsets : mapping of set name -> DataFrame, optional
        Defaults to the notebook-level ``result_sets``, looked up at call time.

    Returns
    -------
    DataFrame
        Indexed by the values of ``c`` (missing values included). For every set
        ``m`` there is ``{m}_freq`` (row count) and ``{m}_norm`` (percentage of
        the set); a value absent from a set gets 0 in both. For every pair of
        sets, in mapping order, ``diff_norm_{m1}_{m2}`` is the difference in
        percentage points, rounded to 4 decimals.
    """
    if dfsets is None:
        dfsets = result_sets

    dfdict = {}
    for m in dfsets:
        dfdict[f"{m}_freq"] = dfsets[m][c].value_counts(normalize=False, dropna=False)
        dfdict[f"{m}_norm"] = 100 * dfsets[m][c].value_counts(normalize=True, dropna=False).round(6)

    # Aligning on the union of observed values leaves NaN where a set never
    # contains a value; that means "zero occurrences".
    # No drop_duplicates() here: two different values with identical counts are
    # still distinct rows and must both be kept.
    df = pd.DataFrame(dfdict).fillna(0.0)

    names = list(dfsets)
    for i, m1 in enumerate(names):
        for m2 in names[i + 1:]:
            df[f"diff_norm_{m1}_{m2}"] = (df[f"{m1}_norm"] - df[f"{m2}_norm"]).round(4)

    return df

In [None]:
def num_dist_df(c='Gender', dfsets=None):
    """Mean and median of a numerical column ``c`` within each result set.

    Parameters
    ----------
    c : str
        Column to summarise.
    dfsets : mapping of set name -> DataFrame, optional
        Defaults to the notebook-level ``result_sets``, looked up at call time.

    Returns
    -------
    DataFrame
        Rows are the ``describe()`` statistics left after removing count, std,
        min and max (``mean`` and ``50%`` for a numeric column). There is one
        column per set (values rounded to 4 decimals) plus ``diff_{m1}_{m2}``
        for every pair of sets, in mapping order.
    """
    if dfsets is None:
        dfsets = result_sets

    dfdict = {}
    for m in dfsets:
        dfdict[f"{m}"] = dfsets[m][c].describe(percentiles=[0.5]).round(4)

    names = list(dfsets)
    for i, m1 in enumerate(names):
        for m2 in names[i + 1:]:
            dfdict[f"diff_{m1}_{m2}"] = dfdict[f"{m1}"] - dfdict[f"{m2}"]

    # No drop_duplicates() here: when several statistics coincide (e.g. a
    # constant column has mean == min == median == max) it would delete rows
    # and make the drop() below fail with KeyError.
    return pd.DataFrame(dfdict).drop(index=['count', 'max', 'min', 'std'])

### differences

In [None]:
# Categorical distributions per TP/FP/TN/FN set; rows for the value 0 are dropped.
cat_demo_splits = combine_col_dfs(nlst_democols['cat'], cat_dist_df, result_sets).query('value != 0')
# Filter on the already-sorted frame. A mask built from the unsorted frame would
# have to be re-aligned by pandas ("Boolean Series key will be reindexed" warning).
display(cat_demo_splits.sort_values(by='diff_norm_FP_TP', ascending=False).query('attribute in @relevant_cols').head(20))
cat_demo_splits.sort_values(by='diff_norm_FP_TP', ascending=True).query('attribute in @relevant_cols').head(20)

In [None]:
cat_demo_splits = combine_col_dfs(nlst_democols['cat'], cat_dist_df, result_sets).query('value != 0')
display(cat_demo_splits.sort_values(by='diff_norm_FP_FN', ascending=False).query('category == "nodule"'))
cat_demo_splits.sort_values(by='diff_norm_FP_FN', ascending=True).query('category == "nodule"')

In [None]:
# Mean / median of every numerical column per TP/FP/TN/FN set.
num_demo_splits = combine_col_dfs(nlst_democols['num'], num_dist_df, result_sets)
# Filter on the already-sorted frame so no mismatched boolean mask has to be re-aligned.
display(num_demo_splits.sort_values(by='diff_FP_FN', ascending=False).query('attribute in @relevant_cols').head(20))
num_demo_splits.sort_values(by='diff_FP_FN', ascending=True).query('attribute in @relevant_cols').head(20)

### Now with the top 100 scores that were different from PanCan

In [None]:
# For each error type keep the first 100 rows after sorting by
# Sybil_PanCan_diff (Sybil score minus PanCan2b score) in descending order.
# NOTE(review): the false-negative cells further down sort this column in
# ascending order (most negative difference first). Confirm that descending
# order is really intended for "FN" here.
result_top_100_diff = {
    "FP": false_positives.sort_values(by=['Sybil_PanCan_diff'], ascending=False)[0:100],
    "FN": false_negatives.sort_values(by=['Sybil_PanCan_diff'], ascending=False)[0:100],
}

In [None]:
# Categorical distributions within the top-100 FP / FN subsets; rows for the value 0 are dropped.
cat_demo_splits = combine_col_dfs(nlst_democols['cat'], cat_dist_df, result_top_100_diff).query('value != 0')
# Filter on the already-sorted frame so no mismatched boolean mask has to be re-aligned.
display(cat_demo_splits.sort_values(by='diff_norm_FP_FN', ascending=False).query('attribute in @relevant_cols').head(20))
cat_demo_splits.sort_values(by='diff_norm_FP_FN', ascending=True).query('attribute in @relevant_cols').head(20)

In [None]:
cat_demo_splits = combine_col_dfs(nlst_democols['cat'], cat_dist_df, result_top_100_diff).query('value != 0')
display(cat_demo_splits.sort_values(by='diff_norm_FP_FN', ascending=False).query("category == 'nodule'"))
cat_demo_splits.sort_values(by='diff_norm_FP_FN', ascending=True).query("category == 'nodule'")

In [None]:
# Mean / median of every numerical column within the top-100 FP / FN subsets.
num_demo_splits = combine_col_dfs(nlst_democols['num'], num_dist_df, result_top_100_diff)
# Filter on the already-sorted frame so no mismatched boolean mask has to be re-aligned.
display(num_demo_splits.sort_values(by='diff_FP_FN', ascending=False).query('attribute in @relevant_cols').head(20))
num_demo_splits.sort_values(by='diff_FP_FN', ascending=True).query('attribute in @relevant_cols').head(20)

## True Positives

In [None]:
# True positives plus the absolute Sybil-vs-PanCan score gap. The column is
# added with assign() so it lands on a fresh frame; assigning into the result
# of query() raises SettingWithCopyWarning.
true_positives = (
    nlst_preds
    .query("label == 1 and Sybil_pred_label == 1")
    .assign(abs_Sybil_PanCan_diff=lambda df: df['Sybil_PanCan_diff'].abs())
)
# Highest Sybil score first; ties broken by the smallest gap to PanCan, then by the highest PanCan score.
tp_top20 = true_positives.sort_values(by=['sybil_year1','abs_Sybil_PanCan_diff',  'PanCan2b'], ascending=[False, True, False]).head(20)
tp_top20[
    ['PatientID', 'label', 'Sybil_pred_label', 'PanCan2b', 'sybil_year1', 'abs_Sybil_PanCan_diff'] + nlst_democols['cat']['nodule'] + ['Emphysema', 'Gender', 'race', 'BMI', 'Age']
]

In [None]:
tp_top20_series = tp_top20['SeriesInstanceUID'].tolist()

In [None]:
tp_top20_save = nlst_preds_nodule[nlst_preds_nodule['SeriesInstanceUID'].isin(tp_top20_series)]

In [None]:
# tp_top20_save.to_csv(f"{CHANSEY_NLST_PREDS}/sybil_tp_brock_top20.csv", index=False)
tp_top20_save.to_csv(f"{LOCAL_NLST_PREDS}/sybil_tp_brock_top20.csv", index=False)

## False Positives

In [None]:
false_positives = nlst_preds.query("label == 0 and Sybil_pred_label == 1")

In [None]:
sns.histplot(false_positives, x='Sybil_PanCan_diff', hue='Gender', common_norm=False, element='bars', kde=True, stat='density')

### Select FPs for Attention Visualization

In [None]:
# Columns displayed when reviewing the selected false positives.
fp_relevant_cols = [
    'sybil_year1', 'PanCan2b',
    'Gender', 'Age',  'race', 'BMI', 'Emphysema', 'NoduleInUpperLung', 
    'wrkasbe', 'wrkchem',  'wrkweld',
    'diagadas', 'diagcopd', 'diaghear', 'diagpneu',
    # 'diaghype', 'pkyr', 'cigar', 'pipe'
]
# Keep false positives whose PanCan2b score is below 0.06, then take the 50 with
# the largest (Sybil - PanCan) score difference.
# NOTE(review): 0.06 is presumably a PanCan risk cut-off; confirm and consider
# naming it as a constant.
fp_top50 = false_positives.query('PanCan2b < 0.06').sort_values(by='Sybil_PanCan_diff', ascending=False).head(50)
display(fp_top50[fp_relevant_cols])

The indices below are for further analysis only, not for extracting attention maps. They are row labels of the dataframe displayed above and are valid for that dataframe ONLY!

In [None]:
# interesting_idxs = [2263, 3985, 929, 655, 2659, 2439, 1863, 726, 2459, 2676, 4627, 1832]

In [None]:
# interesting_series = [fp_top50.loc[i, 'SeriesInstanceUID'] for i in interesting_idxs]

Get top 50 series IDs.

In [None]:
fp_top50_series = fp_top50['SeriesInstanceUID'].tolist()

In [None]:
fp_top50_save = nlst_preds_nodule[nlst_preds_nodule['SeriesInstanceUID'].isin(fp_top50_series)]

In [None]:
# fp_top50_save.to_csv(f"{CHANSEY_NLST_PREDS}/sybil_fp_brock_top50.csv", index=False)
fp_top50_save.to_csv(f"{LOCAL_NLST_PREDS}/sybil_fp_brock_top50.csv", index=False)

### Gender differences

In [None]:
gender_fps = {
    "M": false_positives.query("Gender == 1"),
    "F": false_positives.query("Gender == 2"),
}

In [None]:
# Categorical distributions of false positives split by gender; rows for the value 0 are dropped.
cat_demo_splits = combine_col_dfs(nlst_democols['cat'], cat_dist_df, gender_fps).query('value != 0')
# Filter on the already-sorted frame so no mismatched boolean mask has to be re-aligned.
display(cat_demo_splits.sort_values(by='diff_norm_M_F', ascending=False).query('attribute in @relevant_cols').head(20))
cat_demo_splits.sort_values(by='diff_norm_M_F', ascending=True).query('attribute in @relevant_cols').head(20)

In [None]:
# num_demo_splits = combine_col_dfs(nlst_democols['num'], num_dist_df, gender_fps)
# display(num_demo_splits.sort_values(by='diff_M_F', ascending=False)[num_demo_splits['attribute'].isin(relevant_cols)].head(20))
# num_demo_splits.sort_values(by='diff_M_F', ascending=True)[num_demo_splits['attribute'].isin(relevant_cols)].head(20)

## False Negatives

In [None]:
len(false_negatives)

In [None]:
false_negatives.sort_values(by=['Sybil_PanCan_diff'], ascending=True)[[
     'PanCan2b', 'sybil_year1', 'Sybil_PanCan_diff', 
     'Age', 'Gender', 'race', 'weight', 'BMI',
     'Emphysema', 'Adenocarcinoma', 'pkyr', 'pipe', 'cigar', 
     'wrknomask', 'wrkfoun', 'wrkasbe', 'diaghype', 'diaghear'
    #  'Squamous_cell_carcinoma', 'Large_cell_carcinoma', 'diagcopd', 'NoduleInUpperLung', 'Solid'
]].head(25)

In [None]:
fn_top_25_sybil = false_negatives.sort_values(by=['Sybil_PanCan_diff'], ascending=True).head(25)

In [None]:
fn_top_25_sybil.to_csv(f"{NLST_PREDS}/sybil_fn_brock_top25.csv")

### Gender differences

In [None]:
sns.histplot(false_negatives, x='Sybil_PanCan_diff', hue='Gender', common_norm=False, element='bars', kde=True, stat='density')

In [None]:
gender_fns = {
    "M": false_negatives.query("Gender == 1"),
    "F": false_negatives.query("Gender == 2"),
}

In [None]:
cat_demo_splits = combine_col_dfs(nlst_democols['cat'], cat_dist_df, gender_fns).query('value != 0')
display(cat_demo_splits.sort_values(by='diff_norm_M_F', ascending=False).head(30))
cat_demo_splits.sort_values(by='diff_norm_M_F', ascending=True).head(30)

In [None]:
display(cat_demo_splits.sort_values(by='diff_norm_M_F', ascending=False).query('category == "nodule"'))
cat_demo_splits.sort_values(by='diff_norm_M_F', ascending=True).query('category == "nodule"')

In [None]:
display(cat_demo_splits.sort_values(by='diff_norm_M_F', ascending=False).query('attribute == "LC_stage"'))
cat_demo_splits.sort_values(by='diff_norm_M_F', ascending=True).query('attribute == "LC_stage"')

In [None]:
# Mean / median of every numerical column for false negatives split by gender.
num_demo_splits = combine_col_dfs(nlst_democols['num'], num_dist_df, gender_fns)
# Filter on the already-sorted frame so no mismatched boolean mask has to be re-aligned.
display(num_demo_splits.sort_values(by='diff_M_F', ascending=False).query('attribute in @relevant_cols').head(20))
num_demo_splits.sort_values(by='diff_M_F', ascending=True).query('attribute in @relevant_cols').head(20)

### Racial differences

In [None]:
sns.histplot(false_negatives, x='Sybil_PanCan_diff', hue='race', common_norm=False, element='bars', kde=True)

In [None]:
false_negatives['race'].value_counts()

In [None]:
race_fns = {
    "white": false_negatives.query("race == 1"),
    "black": false_negatives.query("race == 2"),    
}

In [None]:
cat_race_splits = combine_col_dfs(nlst_democols['cat'], cat_dist_df, race_fns).query('value != 0')
display(cat_race_splits.sort_values(by='diff_norm_white_black', ascending=False).head(30))
cat_race_splits.sort_values(by='diff_norm_white_black', ascending=True).head(30)

In [None]:
display(cat_race_splits.sort_values(by='diff_norm_white_black', ascending=False).query('category == "nodule"'))
cat_race_splits.sort_values(by='diff_norm_white_black', ascending=True).query('category == "nodule"')

In [None]:
display(cat_race_splits.sort_values(by='diff_norm_white_black', ascending=False).query('attribute == "LC_stage"'))
cat_race_splits.sort_values(by='diff_norm_white_black', ascending=True).query('attribute == "LC_stage"')

In [None]:
# Mean / median of every numerical column for false negatives split by race.
num_race_splits = combine_col_dfs(nlst_democols['num'], num_dist_df, race_fns)
# Filter on the already-sorted frame so no mismatched boolean mask has to be re-aligned.
display(num_race_splits.sort_values(by='diff_white_black', ascending=False).query('attribute in @relevant_cols').head(20))
num_race_splits.sort_values(by='diff_white_black', ascending=True).query('attribute in @relevant_cols').head(20)

In [None]:
fn_black_series = false_negatives.query('race == 2')['SeriesInstanceUID'].tolist()
len(fn_black_series)

In [None]:
fn_black_sybil = nlst_preds_nodule[nlst_preds_nodule['SeriesInstanceUID'].isin(fn_black_series)]

In [None]:
fn_black_sybil.to_csv(f"{NLST_PREDS}/sybil_fn_brock_black.csv")
# fn_black_sybil.to_csv(f"{CHANSEY_NLST_PREDS}/sybil_fn_brock_black.csv")

## What do our radiologist's notes line up with?

### False Negatives

In [None]:
fn_traits = pd.read_csv(f"{NLST_PREDS}/Sybil-Heatmap-Info/ernst-sybil-fn-traits.csv")

In [None]:
fn_traits2 = fn_traits.merge(false_negatives, how='left', on='SeriesInstanceUID', copy=False, suffixes=('', ''))

### False Positives

In [None]:
fp_top50 = pd.read_csv(f"{NLST_PREDS}/Sybil-Heatmap-Info/sybil_fp_brock_top50.csv")

In [None]:
fp_traits = pd.read_csv(f"{NLST_PREDS}/Sybil-Heatmap-Info/ernst-sybil-fp-tp-traits.csv").query('PredType == "FP"')

In [None]:
# Replace missing free-text notes with empty strings. This is done by
# reassignment: `fp_traits[col].fillna('', inplace=True)` is chained assignment
# on a selected column, which warns in pandas 2.x and has no effect once
# Copy-on-Write is enabled.
fp_traits = fp_traits.fillna({'AttnOnAreaNotes': '', 'OtherNotes': ''})

In [None]:
fp_traits2 = fp_traits.merge(nlst_preds, how='left', on='SeriesInstanceUID', copy=False, suffixes=('', ''))
fp_traits2[['Reader', 'AttnOnTumor', 'Zshift', 'AbnormalAttn','AttnOnAreaNotes', 'OtherNotes'] + ['PanCan2b', 'sybil_year1', 'BMI', 'Gender', 'race', 'Emphysema', 'diaghype']]

In [None]:
fp_traits['AttnOnAreaNotes'].str.contains('normal').sum()

In [None]:
fp_traits2.query('AbnormalAttn == True')[['AttnOnTumor', 'AttnOnAreaNotes', 'Gender', 'race', 'BMI', "NoduleCounts", 'wrknomask', 'wrkfarm', 'diaghear', 'diaghype', 'diagcopd', 'diagadas', 'diagpneu', 'cigar', 'pipe'] + list(sybil_worse_df['col']) + nlst_democols['cat']['other'] + nlst_democols['cat']['disease'] + nlst_democols['cat']['work']]

In [None]:
fp_abnormal = fp_traits2.query('AbnormalAttn == True and Reader == "SG"')
len(fp_abnormal)

In [None]:
gender_key = {1: 'Male', 2: 'Female'}
race_key = {1.0: 'White American', 2.0: 'Black American', 3.0: 'Asian American'}

def weight_cat(bmi):
    """Map a body-mass index to a weight category label.

    Returns 'underweight' (< 18.5), 'normal weight' (< 25), 'overweight' (< 30)
    or 'obese' (>= 30). A missing BMI (None / NaN) returns 'unknown'; without
    this guard NaN fails every comparison and would be labelled 'obese'.
    """
    if pd.isna(bmi): return 'unknown'
    if bmi < 18.5: return 'underweight'
    if bmi < 25: return 'normal weight'
    if bmi < 30: return 'overweight'
    return 'obese'

def list_workhist(row, work_cols=None):
    """Return a comma-separated list of the work-history flags set on ``row``.

    Parameters
    ----------
    row : mapping or pandas Series
        Looked up by work-history column name.
    work_cols : iterable of str, optional
        Columns to inspect, in output order. Defaults to
        ``nlst_democols['cat']['work']``.

    Returns
    -------
    str
        Readable names of the flagged columns joined by ", ", or "" if none.
        Columns without an entry in the label table are skipped. Missing values
        (NaN / None) count as "not flagged"; NaN is truthy in Python and would
        otherwise be listed.
    """
    workhist_key = {
        'wrkasbe': 'Asbestos',
        'wrkbaki': 'Baking',
        'wrkbutc': 'Butcher',
        'wrkchem': 'Chemicals / plastics manufacturing',
        'wrkfarm': 'Farmwork',
        'wrkfire': 'Firefighting',
        'wrkfoun': 'Foundry',
        'wrkpain': 'Painting',
        'wrksand': 'Sandblasting',
        'wrkweld': 'Welding'
    }
    if work_cols is None:
        work_cols = nlst_democols['cat']['work']

    out_strs = []
    for col in work_cols:
        if col not in workhist_key:
            continue
        flag = row[col]
        if pd.notna(flag) and flag:
            out_strs.append(workhist_key[col])

    return ", ".join(out_strs)

# for i, (index, row) in enumerate(fp_abnormal.iterrows()):
#     if row['Reader'] == 'ES': continue
#     print(f"Scan {i+1}/{len(fp_abnormal)}")
#     print("Series Instance UID:", row['SeriesInstanceUID'], "\n")
#     print(f"Scan with heatmaps: \t serie_{row['SeriesInstanceUID']}.mha")
#     print(f"Scan with NO heatmaps: \t noattn_{row['SeriesInstanceUID']}.mha")

#     if row['AttnOnAreaNotes'] or row['OtherNotes']: print(f"\nNotes from Shaurya:")
#     if row['AttnOnAreaNotes']: print(f"  - Attention: {row['AttnOnAreaNotes']}")
#     if row['OtherNotes']: print(f"  - Other: {row['OtherNotes']}")
    
#     print(f"\nGender: {gender_key[row['Gender']]} \t Race: {race_key[row['race']]} \t Body Mass Index: {row['BMI']:.2f} ({weight_cat(row['BMI'])})")
    
#     print(f"\nEmphysema in scan: {'YES' if row['Emphysema'] else 'NO'} \t Previous diagnosis of hypertension: {'YES' if row['diaghype'] else 'NO'}")
    
#     print(f"\nFamily History of Lung Cancer: {'YES' if row['FamilyHistoryLungCa'] else 'NO'}")

#     print(f"\nDid patient smoke a cigar? {'YES' if row['cigar'] else 'NO'} \t Did patient smoke a pipe? {'YES' if row['pipe'] else 'NO'}")
    
#     print(f"\nDid patient work in a field high-risk for the lungs (according to NLST) without a mask? {'YES' if row['wrknomask'] else 'NO'}")

#     workhist_str = list_workhist(row)
#     if workhist_str != "": print(f"\nWork history (1 or more years) in: {list_workhist(row)}")

#     print("\n\n")

## Do our false positives and false negatives exist in Sybil's set?

In [None]:
# The FP and TP traits share one CSV: parse it once, then split on PredType.
sybil_fp_tp_traits = pd.read_csv(f"{NLST_PREDS}/Sybil-Heatmap-Info/ernst-sybil-fp-tp-traits.csv")
sybil_fp     = sybil_fp_tp_traits.query('PredType == "FP"')
sybil_tp     = sybil_fp_tp_traits.query('PredType == "TP"')
sybil_fn     = pd.read_csv(f"{NLST_PREDS}/Sybil-Heatmap-Info/ernst-sybil-fn-traits.csv")
sybil_splits = pd.read_csv(f"{NLST_PREDS}/sybil-nlst-splitinfo.csv")

In [None]:
sybil_train_series = set(sybil_splits['SeriesInstanceUID'].tolist())