In [None]:
# Generic imports
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.metrics import accuracy_score, f1_score, precision_score, confusion_matrix
from tqdm import tqdm

In [None]:
# Custom imports
from src.diagnosis_tools import (
    mark_hypoxemic_episodes,
    mark_abnormal_cxr,
    mark_cxr_within_48h_of_post_vent_hypoxemia,
    mark_note_within_7d,
    mark_notes_with_ml,
    text_match_risk_factors,
    diagnose_or_exclude_encounters,
    flag_echos
)
import src.plots as plots

In [None]:
# Set plotting params: start from a clean matplotlib state, then apply the project style.
import matplotlib as mpl
# Reset every rcParam to matplotlib's built-in defaults so settings left over
# from an earlier run of this kernel do not leak into the figures.
mpl.rcParams.update(mpl.rcParamsDefault)
# Re-read the style library.
plt.style.reload_library()
# NOTE(review): `plots.stdrcparams1()` is a project helper; presumably it returns a
# dict of rcParams overrides -- confirm in src/plots.py.
rcparams = plots.stdrcparams1()
mpl.rcParams.update(rcparams)

In [None]:
# Project directory layout; everything hangs off the parent of the working directory.
basedir = Path("..")
# Training data for the ML classifiers
training_location = basedir.joinpath('Analysis_Data', 'train_ML')
# Pre-processed input tables (CSV) for the labeled MIMIC-III subset
path = basedir.joinpath('Analysis_Data', 'MIMIC_III', 'labeled_subset')
figure_path = basedir.joinpath("Figures")
# Raw inputs, including the clinician adjudication spreadsheet
raw_path = basedir.joinpath('Raw_Data', 'MIMIC_III', 'labeled_subset')
feihong_path = basedir.joinpath('for_Curt_MIMIC_III')

In [None]:
# Load every source table from `path` and convert its timestamp column(s) to datetimes.

# P/F ratios (carries both the measurement time and the ventilation start time)
pf = pd.read_csv(path / "pf_ratio.csv")
pf = pf.assign(
    pf_ratio_timestamp=pd.to_datetime(pf['pf_ratio_timestamp']),
    vent_start_timestamp=pd.to_datetime(pf['vent_start_timestamp']),
)

# PEEP is optional: a dataset without a peep.csv file leaves `peep` as None
try:
    peep = pd.read_csv(path / "peep.csv")
except FileNotFoundError:
    peep = None
    print("This dataset doesn't seem to have peep separately specified.")
else:
    peep = peep.assign(peep_timestamp=pd.to_datetime(peep['peep_timestamp']))

# Chest X-ray reports
cxr = pd.read_csv(path / "cxr.csv")
cxr = cxr.assign(cxr_timestamp=pd.to_datetime(cxr['cxr_timestamp']))

# Attending physician notes
notes = pd.read_csv(path / "attending_notes.csv")
notes = notes.assign(notes_timestamp=pd.to_datetime(notes['notes_timestamp']))

# Echocardiogram reports
echo = pd.read_csv(path / "echo_reports.csv")
echo = echo.assign(echo_timestamp=pd.to_datetime(echo['echo_timestamp']))

# BNP lab values
bnp = pd.read_csv(path / "bnp.csv")
bnp = bnp.assign(bnp_timestamp=pd.to_datetime(bnp['bnp_timestamp']))

# Clinician adjudication spreadsheet, with its columns renamed to match the other tables.
final_numbers = pd.read_excel(raw_path / "ARDS_Criteria_Curt.xlsx").rename(
    columns={
        'HADM_ID': 'encounter_id',
        'FINAL ARDS CURT (1=YES)': 'clinician_diagnosed',
    }
)

# Collapse to one row per encounter: exact duplicate (encounter, label) pairs are
# dropped first, then the remaining labels of each encounter are summed.
encounters_by_Curt = (
    final_numbers[['encounter_id', 'clinician_diagnosed']]
    .drop_duplicates()
    .groupby('encounter_id')['clinician_diagnosed']
    .sum()
    .reset_index()
)

In [None]:
# Distinct encounter ids present in each source table
pf_encounters = pf['encounter_id'].drop_duplicates()
peep_encounters = None if peep is None else peep['encounter_id'].drop_duplicates()
cxr_encounters = cxr['encounter_id'].drop_duplicates()
notes_encounters = notes['encounter_id'].drop_duplicates()
echo_encounters = echo['encounter_id'].drop_duplicates()
bnp_encounters = bnp['encounter_id'].drop_duplicates()

# Union of encounter ids, built up one source at a time via outer merges.
# PEEP only takes part when the dataset provides it.
total_encounters = pf_encounters
if peep_encounters is not None:
    total_encounters = pd.merge(total_encounters, peep_encounters, how='outer').drop_duplicates()
total_encounters = pd.merge(total_encounters, cxr_encounters, how='outer').drop_duplicates()

if peep is None:
    print(f"Patient encounters with PF ratios or CXRs: {len(total_encounters)}")
else:
    print(f"Patient encounters with PF ratios, PEEP, or CXRs: {len(total_encounters)}")

# Notes, echo and BNP encounters are folded in after the count above is reported
for extra_encounters in (notes_encounters, echo_encounters, bnp_encounters):
    total_encounters = pd.merge(total_encounters, extra_encounters, how='outer').drop_duplicates()

In [None]:
# Regex building blocks for extracting findings from free-text echo reports.
# `echo_prefix` and `echo_suffix` share the same keys, one per finding. The keys are
# used to name the flag columns (later cells read e.g. `lvef_value`). Each prefix
# entry is a list of alternative phrasings that introduce the finding; the matching
# suffix is the pattern expected to follow it. Both are passed to `flag_echos`.
#
# All patterns are raw strings so that regex escapes such as \s, \d and \D reach the
# regex engine untouched instead of being parsed as (invalid) Python string escapes.
#
# Regex reminders:
#   (?i)    turns off case-sensitivity
#   (?:...) marks a group as non-capturing; a plain (...) is a capturing group
echo_prefix = {
    'lvef': [
        r'(?i)lv\s+ejection\s+fraction',
        r'(?i)left\s+ventricular\s+ejection\s+fraction',
        r'(?i)lvef',
        r'(?i)left\s+ventricular\s+ef',
        r'(?i)lvef\s+is',
        r'(?i)left\s+ventricle\s+ejection\s+fraction\s+is',
        r'(?i)lv\s+ejection\s+fraction\s+is',
    ],

    # Match "cardiopulmonary bypass" ensuring at least one whitespace character between those words
    'cp_bypass': [r'(?i)cardiopulmonary\s+bypass'],

    'la_dimension': [
        r'(?i)la\s+diameter',
        r'(?i)la\s+dimension',
    ],

    'la_volume_index': [
        r'(?i)la\s+volume',
        r'(?i)LA\s+Vol\s+BP\s+A/L\s+Index',
    ],

    # The negative lookbehind skips mentions directly preceded by "borderline "
    'lv_hypertrophy': [
        r'(?i)(?<!borderline\s)(?:left\s+ventricular|lv|lv\s+concentric)\s*hypertrophy',
        r'(?i)(?<!borderline\s)LVH',
    ],

    'diastolic_dysfunction': [
        r'(?i)(grade\s*ii)',
        r'(?i)(grade\s*iii)',
    ],
}

echo_suffix = {
    # Sample matches: 45%, 45 %, 45-55%, 45 - 55 %, 45- 100%, 45- %
    'lvef': r'\D{0,20}(\d{1,3}|\d{1,2}\s*-\s*\d{1,3})-{0,1}\s*%',

    # Don't match if N/A or Patient wasn't placed on CPB
    'cp_bypass': r'(?!\s*N\/A|\s*Patient\s+was\s+not\s+placed\s+on\s+cardiopulmonary\s+bypass|\s*NA)',

    # Sample matches: 2.7cm, 2.7 cm, 2.7   centimeter
    'la_dimension': r'\D{0,25}(\d\.\s*\d)\s*(?:cm|centimeter)',

    # Match anything until "ml" appears once or never, then match anything until the number of interest appears
    # followed by either ml/m or ml per square meter
    'la_volume_index': r'.*?(?:ml)?.*?(\d+\.\s*\d+)\s+(?:(?=ml\/m)|(?=ml\s+per\s+square\s+meter))',

    # No suffix: the prefix match alone is the finding
    'lv_hypertrophy': '',

    # Matches anything, either never or up to 30 characters, then an arbitrary number of white spaces,
    # as long as "diastolic dysfunction" immediately follows.
    'diastolic_dysfunction': r'.{0,30}\s*?(?=diastolic\s+dysfunction)',
}

In [None]:
# 5. Flag ECHO reports (we can do this now since it is independent of previous steps. Plus, it saves runtime)
# `echo_prefix` / `echo_suffix` supply the regex fragments to look for in each report.
# NOTE(review): presumably this adds the `*_value` columns (e.g. `lvef_value`,
# `cp_bypass_value`) that the threshold grid search reads from `echo` -- confirm in
# src.diagnosis_tools.flag_echos.
echo = flag_echos(echo, echo_prefix, echo_suffix)

In [None]:
# 1. Adjudicate hypoxemia
# The first return value replaces `pf`; the second, `hypox_df`, is passed to the
# later CXR and notes steps.
# `peep` is None when peep.csv was not found. NOTE(review): assumes
# mark_hypoxemic_episodes accepts None for the PEEP table -- confirm.
pf, hypox_df = mark_hypoxemic_episodes(pf, peep, 'encounter_id')

In [None]:
# Training set handed to mark_abnormal_cxr for the CXR classifier
train_data = training_location.joinpath('cxr_whole_training_dataset.csv')

In [None]:
# Candidate decision thresholds for the two classifiers:
# 1001 evenly spaced values from 0 to 1 inclusive (a step of 0.001).
cxr_threshs = np.linspace(start=0, stop=1, num=1001)
notes_threshs = np.linspace(start=0, stop=1, num=1001)

In [None]:
# Grid search over the two ML decision thresholds. For every
# (cxr_thresh, notes_thresh) pair the adjudication pipeline is re-run and its
# encounter-level diagnoses are scored against the clinician labels in
# `encounters_by_Curt`. One dict of metrics per grid point is appended to `results`.
#
# NOTE(review): `cxr` and `notes` are re-assigned from the helpers' return values on
# every iteration. This assumes the helpers overwrite (rather than accumulate) their
# flag columns -- confirm in src.diagnosis_tools.
results = []

# --- Objective rule-out lists (steps 5a-5d) ----------------------------------
# These depend only on `bnp` and `echo`, never on either threshold, so they are
# built once here instead of once per grid point.

# 5a. BNP > 100 rule out
bnp_elevated = bnp['bnp_value'] > 100
encounters_with_bnp_greater_than_100 = list(bnp.loc[bnp_elevated, 'encounter_id'].unique())

# 5b. LVEF < 40% rule out
lvef_reduced = echo['lvef_value'] < 40
encounters_with_lvef_smaller_than_40 = list(echo.loc[lvef_reduced, 'encounter_id'].unique())

# 5c. Cardiopulmonary bypass rule out
cpb = echo['cp_bypass_value'].notnull()
encounters_with_cardiopulmonary_bypass = list(echo.loc[cpb, 'encounter_id'].unique())

# 5d. Two out of three: (LA dim > 4 cm or LA volume index > 28 ml/m^2), LV hypertrophy, diastolic dysfunction
la_dim = echo['la_dimension_value'] > 4
la_vol_idx = echo['la_volume_index_value'] > 28
echo['la_enlargement_bool'] = (la_dim | la_vol_idx).astype(int)
echo['lv_hypertrophy_bool'] = echo['lv_hypertrophy_value'].notnull().astype(int)
echo['diastolic_dysfunction_bool'] = echo['diastolic_dysfunction_value'].notnull().astype(int)
echo['additional_criteria_count'] = (
    echo['la_enlargement_bool']
    + echo['lv_hypertrophy_bool']
    + echo['diastolic_dysfunction_bool']
)

add_crit = echo['additional_criteria_count'] > 1
encounters_with_additional_criteria = list(echo.loc[add_crit, 'encounter_id'].unique())

# --- Grid search ---------------------------------------------------------------
for cxr_thresh in tqdm(cxr_threshs, desc='cxr_threshold'):

    # 2a. Adjudicate bilateral infiltrates with ML
    cxr = mark_abnormal_cxr(
        cxr,
        train_data,
        train_col=['segmented_report', 'score'],
        test_label_col='curt_bl_infiltrates_(1=yes)',
        thresholding="custom",
        custom_threshold=cxr_thresh
        )

    # 2b. Flag CXRs that are within 48h of hypoxemia, and hypoxemia events that are post intubation
    cxr, hypox_pred_abn_cxr_48h = mark_cxr_within_48h_of_post_vent_hypoxemia(
        hypox_df,
        cxr,
        'encounter_id',
        'cxr_timestamp'
        )

    # 3a. Flag notes that are within 7d of hypoxemia or bilateral infiltrates (whichever is latest)
    notes = mark_note_within_7d(notes, hypox_df, hypox_pred_abn_cxr_48h, 'encounter_id', 'cxr_timestamp')

    # leave=False removes each finished inner bar so the output is not flooded
    # with one progress bar per CXR threshold.
    for notes_thresh in tqdm(notes_threshs, desc='notes_threshold', leave=False):

        # 3b. Adjudicate pneumonia with ML
        notes = mark_notes_with_ml(
            notes,
            training_location,
            train_col=['seg_pneumonia', 'pneumonia_sw'],
            test_label_col='curt_pneumonia_(1=yes)',
            thresholding="custom",
            custom_threshold=notes_thresh
            )

        # 3c. Adjudicate risk factors or heart failure with text matching
        notes = text_match_risk_factors(notes)

        # 4. Make diagnosis decisions based on above flags
        notes, diagnosed, excluded, for_objective_assessment = diagnose_or_exclude_encounters(
            notes,
            hypox_pred_abn_cxr_48h,
            'encounter_id'
            )

        # Encounters without evidence of risk factors or heart failure on notes enter the next stages.
        # Apply the precomputed rule-outs 5a -> 5d in sequence.
        remaining_after_bnp = for_objective_assessment.loc[
            ~for_objective_assessment['encounter_id'].isin(encounters_with_bnp_greater_than_100)
        ]
        remaining_after_lvef = remaining_after_bnp.loc[
            ~remaining_after_bnp['encounter_id'].isin(encounters_with_lvef_smaller_than_40)
        ]
        remaining_after_cpb = remaining_after_lvef.loc[
            ~remaining_after_lvef['encounter_id'].isin(encounters_with_cardiopulmonary_bypass)
        ]
        remaining_after_additional_criteria = remaining_after_cpb.loc[
            ~remaining_after_cpb['encounter_id'].isin(encounters_with_additional_criteria)
        ]

        # Collecting diagnosed encounters
        encounters_diagnosed_by_pipeline = pd.merge(
            diagnosed['encounter_id'].drop_duplicates(),
            remaining_after_additional_criteria['encounter_id'].drop_duplicates(),
            how='outer'
            ).drop_duplicates()

        # Creating encounters table: the merge indicator records whether each
        # encounter was diagnosed by the pipeline ("both") or not ("left_only").
        encounter_summary = pd.merge(
            total_encounters,
            encounters_diagnosed_by_pipeline,
            how='outer',
            indicator=True
            )

        # `_merge` is categorical. Relabel its categories directly instead of using
        # DataFrame.replace, whose category-changing behaviour on categoricals is
        # deprecated (a FutureWarning this notebook silences). Unmapped categories
        # pass through unchanged.
        encounter_summary['_merge'] = encounter_summary['_merge'].cat.rename_categories(
            {'left_only': 'No', 'both': 'Yes'}
        )
        encounter_summary = encounter_summary.rename(columns={'_merge': "pipeline_diagnosed"})

        # Adding encounters diagnosed by Curt
        encounter_summary = pd.merge(encounter_summary, encounters_by_Curt, how='outer')
        encounter_summary = encounter_summary.replace(to_replace={'clinician_diagnosed': {0: "No", 1: "Yes"}})

        y_true = encounter_summary['clinician_diagnosed']
        y_pred = encounter_summary['pipeline_diagnosed']

        # Pin the label order so the matrix is always 2x2 with "No" first, which is
        # what the TN/FP/FN/TP unpacking below relies on.
        cf = confusion_matrix(y_true, y_pred, labels=['No', 'Yes'])
        TN, FP, FN, TP = cf.ravel()

        # Collect thresholds and metrics for a potential table
        results.append({
            'cxr_threshold': cxr_thresh,
            'notes_threshold': notes_thresh,
            'false_negative_rate': FN / (FN + TP),
            'false_positive_rate': FP / (FP + TN),
            'precision': precision_score(y_true, y_pred, pos_label='Yes'),
            'negative_predictive_value': TN / (TN + FN),
            'accuracy': accuracy_score(y_true, y_pred),
            'f1': f1_score(y_true, y_pred, pos_label='Yes')
            })

        

In [None]:
# One row per (cxr_threshold, notes_threshold) grid point, one column per key of the metric dicts
results_df = pd.DataFrame(results)

In [None]:
# Youden's J statistic: sensitivity + specificity - 1, i.e. (1 - FNR) - FPR
sensitivity = 1 - results_df["false_negative_rate"]
results_df["youden_j"] = sensitivity - results_df["false_positive_rate"]

# Persist the whole grid so it can be reloaded without repeating the search
results_df.to_csv("mimic_threshold_results.csv", index=False)

In [None]:
# Optional: uncomment to reload previously saved results instead of re-running the grid search.
# results_df = pd.read_csv("mimic_threshold_results.csv")

In [None]:
# Boolean row masks selecting the grid points that attain the best value of each metric
max_youden = results_df['youden_j'].eq(results_df['youden_j'].max())
max_f1 = results_df['f1'].eq(results_df['f1'].max())
max_accuracy = results_df['accuracy'].eq(results_df['accuracy'].max())

In [None]:
# Grid points with the highest accuracy, largest notes threshold first
results_df[max_accuracy].sort_values(by='notes_threshold', ascending=False)

In [None]:
# Accuracy of every grid point against the notes threshold. Each x position
# stacks the points for all CXR thresholds.
fig, ax = plt.subplots()
ax.scatter(results_df['notes_threshold'], results_df['accuracy'])
ax.set(
    xlabel='notes_threshold',
    ylabel='accuracy',
    title='Accuracy vs. notes threshold (all CXR thresholds)',
)
plt.show()

In [None]:
# Grid points with the highest F1 score, largest CXR threshold first
results_df[max_f1].sort_values(by='cxr_threshold', ascending=False)

In [None]:
# F1 score of every grid point against the notes threshold. Each x position
# stacks the points for all CXR thresholds.
fig, ax = plt.subplots()
ax.scatter(results_df['notes_threshold'], results_df['f1'])
ax.set(
    xlabel='notes_threshold',
    ylabel='f1',
    title='F1 vs. notes threshold (all CXR thresholds)',
)
plt.show()

In [None]:
# Grid points with the highest Youden's J, largest CXR threshold first
results_df[max_youden].sort_values(by='cxr_threshold', ascending=False)

In [None]:
# Youden's J of every grid point against the notes threshold. Each x position
# stacks the points for all CXR thresholds.
fig, ax = plt.subplots()
ax.scatter(results_df['notes_threshold'], results_df['youden_j'])
ax.set(
    xlabel='notes_threshold',
    ylabel='youden_j',
    title="Youden's J vs. notes threshold (all CXR thresholds)",
)
plt.show()

In [None]:
# Lift pandas' display limits (column width, column count, row count, total width)
# so the tables below are shown in full for easier inspection.
for display_option in (
    'display.max_colwidth',
    'display.max_columns',
    'display.max_rows',
    'display.width',
):
    pd.set_option(display_option, None)

In [None]:
# Grid points that are simultaneously optimal for Youden's J, F1 and accuracy
# (empty when no single threshold pair wins all three)
joint_optimum = max_youden & max_f1 & max_accuracy
results_df[joint_optimum].sort_values(by='notes_threshold', ascending=False)

In [None]:
# Metrics at the conventional 0.5 / 0.5 operating point, for reference.
# The thresholds are floats, so compare with a tolerance rather than `==`; this
# stays correct if the values were perturbed (e.g. reloaded from the saved CSV).
# The default tolerance is far smaller than the 0.001 grid step, so neighbouring
# grid points are never matched.
results_df.loc[
    np.isclose(results_df['cxr_threshold'], 0.5)
    & np.isclose(results_df['notes_threshold'], 0.5)
]