In [1]:
import numpy as np
from pathlib import Path
import os
import pandas as pd

root = Path('/hpc/group/kamaleswaranlab/EmoryDataset/Images/chest_xrays')
embedding_path = root / 'BioMedCLIP_embeddings'
supertable_path = root / 'matched_supertables_with_images'
supertable_template = "_timing_corrected.pickle"

In [76]:
supertables = list(supertable_path.glob("*" + supertable_template))
len(supertables)

48194

In [2]:
import pickle

with open(root / 'supertable_stats.pickle', 'rb') as f:
    stats = pickle.load(f)

In [78]:
stats

{'age_min': np.int64(18),
 'age_max': np.int64(121),
 'age_nan': 0,
 'cci9_min': np.float64(0.0),
 'cci9_max': np.int64(14),
 'cci9_nan': 0,
 'cci10_min': np.float64(0.0),
 'cci10_max': np.int64(81),
 'cci10_nan': 2879,
 'vent_rate_set_min': np.float64(-16.0),
 'vent_rate_set_max': np.float64(647.0),
 'vent_rate_set_nan': 35569,
 'vent_tidal_rate_set_min': np.float64(-480.0),
 'vent_tidal_rate_set_max': np.float64(5800.0),
 'vent_tidal_rate_set_nan': 37282,
 'vent_tidal_rate_exhaled_min': np.float64(0.0),
 'vent_tidal_rate_exhaled_max': np.float64(7647.0),
 'vent_tidal_rate_exhaled_nan': 33722,
 'peep_min': np.float64(0.0),
 'peep_max': np.float64(96.0),
 'peep_nan': 35086,
 'Oxygen_Flow_Rate_min': np.float64(-2.0),
 'Oxygen_Flow_Rate_max': np.float64(99.0),
 'Oxygen_Flow_Rate_nan': 6623,
 'covid': 0,
 'procedure': 5293,
 'gender': Counter({'Male': 25685, 'Female': 22509}),
 'race': Counter({'Caucasian or White': 24465,
          'African American  or Black': 19429,
          'Unknown,

In [79]:
selected_columns = [
    #Vitals
    'temperature',
    'daily_weight_kg',
    'sbp', 'dbp',
    'pulse', 'unassisted_resp_rate',
    'spo2', 'end_tidal_co2',
    'Oxygen_Flow_Rate',  'best_map',
    #'pulse_pressure', removing because linear combination sbp - dbp
    'pf_sp', 'pf_pa',

    
    #Labs
    'anion_gap', 'base_excess',
    'bicarb_(hco3)', 'blood_urea_nitrogen_(bun)',
    'calcium','calcium_adjusted','calcium_ionized', 'chloride',
    'creatinine','glucose','magnesium','phosphorus',
    'potassium','sodium','hematocrit','hemoglobin',         
    'platelets',
    'white_blood_cell_count',
    'alanine_aminotransferase_(alt)',
    'albumin',
    'alkaline_phosphatase',
    'ammonia',
    'aspartate_aminotransferase_(ast)',
    'bilirubin_direct',
    'bilirubin_total',
    'fibrinogen',
    'inr',
    'lactate_dehydrogenase',
    'partial_prothrombin_time_(ptt)',
    'protein',
    'lipase',
    'troponin',
    'fio2',
    #'partial_pressure_of_carbon_dioxide_(paco2)',
    #'partial_pressure_of_oxygen_(pao2)', Removing these two because correlated with pf_pa, pf_sp
    'ph',
    'saturation_of_oxygen_(sao2)',  'n_to_l',
    
    'gcs_total_score',

    #Vasopressors
    'norepinephrine',
    'epinephrine',
    'dobutamine',
    'dopamine',
    'phenylephrine',
    'vasopressin',

    #Other data  
    'covid',  #Removing vent_status because same as on_vent
    'icu', #Flag
    'procedure', #Flag 
    'on_pressors',
    'on_dobutamine',
    'on_dialysis',
    'elapsed_icu',
    'elapsed_hosp',
    'on_vent',
    'age',
    'gender',
    'race',
    'ethnicity',
    'cci9',
    'cci10',
    'infection',
    'sepsis',
    
    'bed_type', #Ward/ICU string
    'icu_type', #ICU type string
    

    'vent_mode',
    'vent_rate_set',
    'vent_tidal_rate_exhaled',
    'peep', #'vent_fio2' somehow less accurate than fio2
     'cxr_timing',
     'cxr_timing_approx_flag'
]

imputation_strategy = {
    #Vitals
    'temperature' : [37, 'FeedForward', 36, 38],  
    'daily_weight_kg': [None, 'FeedForward', 60, 90],
    'sbp_selected': [90, 'FeedForward', 90, 130], 
    'dbp_selected': [60, 'FeedForward', 65, 75],
    'pulse': [75, 'FeedForward', 60, 90], 
    'unassisted_resp_rate': [17, 'FeedForward', 10, 24],
    'spo2': [98, 'FeedForward', 95, 100], 
    'end_tidal_co2': [None, None, 35, 45],
    'Oxygen_Flow_Rate': [None, 'FeedForward6', 0, 70, 0, 70],  
    'best_map': [70, 'FeedForward', 65, 75],
    'pf_sp': [None, 'FeedForward24', 300, 500], 
    'pf_pa': [None, 'FeedForward24', 400, 500],

    
    #Labs
    'anion_gap' : [None, 'FeedForward24', 8, 12], 
    'base_excess': [None, 'FeedForward24', -2, 2],
    'bicarb_(hco3)': [None, 'FeedForward24', 22, 27], 
    'blood_urea_nitrogen_(bun)': [None, 'FeedForward24', 6, 20],
    'calcium' : [None, 'FeedForward24', 8.5, 10.5], #'calcium_adjusted',
    'calcium_ionized': [None, 'FeedForward24', 1, 1.3], 
    'chloride': [None, 'FeedForward24', 96, 106],
    'creatinine': [None, 'FeedForward24', 0.5, 1.3],
    'glucose': [None, 'FeedForward24', 60, 200],
    'magnesium': [None, 'FeedForward24', 1.5, 2.5],
    'phosphorus': [None, 'FeedForward24', 2.5, 4.5],
    'potassium': [None, 'FeedForward24', 3.5, 4.5],
    'sodium': [None, 'FeedForward24', 135, 145],
    'hematocrit': [None, 'FeedForward24', 35, 45],
    'hemoglobin': [None, 'FeedForward24', 12, 17],         
    'platelets': [None, 'FeedForward24', 150, 450],
    'white_blood_cell_count': [None, 'FeedForward24', 4, 11],
    'alanine_aminotransferase_(alt)': [None, 'FeedForward24', 4, 36],
    'albumin': [None, 'FeedForward24', 3.4, 5.4],
    'alkaline_phosphatase': [None, 'FeedForward24', 44, 147],
    'ammonia': [None, 'FeedForward24', 15, 45],
    'aspartate_aminotransferase_(ast)': [None, 'FeedForward24', 8, 33],
    'bilirubin_direct': [None, 'FeedForward24', 0.1, 0.4],
    'bilirubin_total': [None, 'FeedForward24', 0.2, 1.2],
    'fibrinogen': [None, 'FeedForward24', 200, 400],
    'inr': [None, 'FeedForward24', 0.8, 1.3],
    'lactate_dehydrogenase': [None, 'FeedForward24', 105, 350],
    'partial_prothrombin_time_(ptt)': [None, 'FeedForward24', 25, 35],
    'protein': [None, 'FeedForward24', 6, 8.3],
    'lipase': [None, 'FeedForward24', 1, 160],
    'troponin': [None, 'FeedForward24', 0, 0.03],
    'fio2': [None, 'FeedForward24', 21, 40],
    #'partial_pressure_of_carbon_dioxide_(paco2)',
    #'partial_pressure_of_oxygen_(pao2)', Removing these two because correlated with pf_pa, pf_sp
    'ph': [None, 'FeedForward24', 7.35, 7.45],
    'saturation_of_oxygen_(sao2)': [None, 'FeedForward24', 95, 100],  
    'n_to_l' : [None, 'FeedForward24', 0, 100],
    
    'gcs_total_score': [15, 'FeedForward24', 1, 15],

    #Vasopressors
    #'norepinephrine',
    #'epinephrine',
    #'dobutamine',
    #'dopamine',
    #'phenylephrine',
    #'vasopressin',

    #Other data  
    'covid':[0, 0, 0, 1],  #Removing vent_status because same as on_vent
    'icu':[0, 0, 0, 1], #Flag
    'procedure': [0, 0, 0, 1], #Flag 
    'on_pressors':[0, 0, 0, 1],
    'on_dobutamine': [0, 0, 0, 1],
    'on_dialysis': [0, 0, 0, 1],
    'elapsed_icu': [0, None, None, None],
    'elapsed_hosp': [0, None, None, None],
    'on_vent': [0, 0, 0, 1],
    'age': [None, None, 18, 100],
    'gender': [None, None, 'Onehot'],
    'race': [None, None, 'Onehot'],
    ##'ethnicity': [None, None, 'Onehot'],
    'cci9': [None, None, 0, 100],
    'cci10': [None, None, 0, 100],
    'infection': [0,0, 0, 1],
    'sepsis': [0,0, 0, 1],
    
    'bed_type': [None, None, 'Onehot'], #Ward/ICU string
    'icu_type': [None, None, 'Onehot'], #ICU type string
    

    ##'vent_mode': [None, 'FeedForwardVO', 'Onehot'],
    'vent_rate_set': [None, 'FeedforwardVO',0, 100, 0, 600],
    'vent_tidal_rate_exhaled': [None, 'FeedForwardVO', 0, 500],
    'peep': [None, 'FeedForwardVO', 5, 15], #'vent_fio2' somehow less accurate than fio2
    #'cxr_timing',
    # 'cxr_timing_approx_flag'
    
}

In [80]:
def process_icu_type(df):
    icu_types_in_df = df.icu_type.unique()
    
    if 'sicu BEFORE 1/18/2018; cticu ON OR AFTER 1/18/2018' in icu_types_in_df:
        indices = df.loc[df.icu_type == 'sicu BEFORE 1/18/2018; cticu ON OR AFTER 1/18/2018'].index
        if indices[0] > pd.to_datetime('1/18/2018'): 
            df.loc[indices, 'icu_type'] = 'cticu'
        else:
            df.loc[indices, 'icu_type'] = 'sicu'
    if 'cticu BEFORE 1/18/2018; micu ON OR AFTER 1/18/2018' in icu_types_in_df:
        indices = df.loc[df.icu_type == 'cticu BEFORE 1/18/2018; micu ON OR AFTER 1/18/2018'].index 
        if indices[0] > pd.to_datetime('1/18/2018'): 
            df.loc[indices, 'icu_type'] = 'micu'
        else:
            df.loc[indices, 'icu_type'] = 'cticu'
    
    if 'sicu BEFORE 1/18/2018' in icu_types_in_df:
        indices = df.loc[df.icu_type == 'sicu BEFORE 1/18/2018'].index 
        if indices[0] > pd.to_datetime('1/18/2018'): 
            df.loc[indices, 'icu_type'] = np.nan
        else:
            df.loc[indices, 'icu_type'] = 'sicu'
    return df['icu_type'].values

def process_race(df):
    if df.race.values[0] in ['Multiple', 'Unknown, Unavailable or Unreported']:
        return [np.nan]*len(df)

def process_bedtype(df):
    other_indices = df.loc[df.bed_type == 'other']
    if len(other_indices) > 0:
        df.loc[other_indices, 'bed_type'] = np.nan
    return df['bed_type'].values




In [81]:
def bp_selector(row):
    """
    Selects "line" or "cuff" measurements for blood pressure.  
    """
    if row[['sbp_line','dbp_line']].notnull().all() and (row['sbp_line'] - row['dbp_line']) > 15:
        return row['sbp_line'], row['dbp_line']
    elif row[['sbp_cuff','dbp_cuff']].notnull().all() and (row['sbp_cuff'] - row['dbp_cuff']) > 15 :
        return row['sbp_cuff'], row['dbp_cuff']
    else:
        return np.nan, np.nan       

In [82]:
def process_dataframes(df, imputation_strategy, stats_dict=None):
    """
    Process dataframes according to the imputation strategy.
    
    Args:
        df: pandas DataFrame to process
        imputation_strategy: dictionary containing processing rules
        stats_dict: dictionary containing stats for one-hot encoding (optional)
    """
    
    def onehot_encode(df, column_name, final_stats):
        nan_mask = df[column_name].isna()
        df[column_name] = pd.Categorical(df[column_name], categories = list(final_stats[column_name].keys()))
        dummies = pd.get_dummies(df[column_name], prefix=column_name, columns=list(final_stats[column_name].keys()))
        if nan_mask.any():
            dummies.loc[nan_mask] = 0
        return pd.concat([df, dummies], axis=1)

    # First, select only columns that are in imputation_strategy
    columns_to_process = [col for col in df.columns if col in imputation_strategy]
    df_processed = df[columns_to_process].copy()
    
    # First pass: Handle initial missing values and value bounds
    for col in columns_to_process:
        strategy = imputation_strategy[col]
        # Handle 5th and 6th elements if they exist (value bounds)
        if len(strategy) >= 6:
            min_bound, max_bound = strategy[4], strategy[5]
            mask = (df_processed[col] < min_bound) | (df_processed[col] > max_bound)
            df_processed.loc[mask, col] = np.nan
        
        # Handle initial missing values
        initial_value = strategy[0]
        if initial_value is not None:
            # If all values are missing, fill entire column
            if df_processed[col].isna().all():
                df_processed[col] = initial_value
            else:
                # Find first non-null value and fill everything before it
                first_valid = df_processed[col].first_valid_index()
                if first_valid is not None:
                    df_processed.loc[:first_valid, col] = initial_value
    
    # Second pass: Handle imputation strategies
    for col in columns_to_process:
        strategy = imputation_strategy[col]
        impute_method = strategy[1]
        
        if impute_method == 'FeedForward':
            df_processed[col] = df_processed[col].ffill()
        
        elif impute_method == 'FeedForward24':
            df_processed[col] = df_processed[col].ffill(limit=24)
        
        elif impute_method == 'FeedForward6':
            df_processed[col] = df_processed[col].ffill(limit=6)
        
        elif impute_method == 'FeedForwardVO':
            temp_series = df_processed[col].copy()
    
            # Get indices where we have valid values
            valid_indices = temp_series.dropna().index

            for valid_idx in valid_indices:
                # Only process if we're in a vent-on period
                if df_processed.loc[valid_idx, 'on_vent'] == 1:
                    # Find next vent off or next valid value after this point
                    next_indices = df_processed.index[df_processed.index > valid_idx]
                    if len(next_indices) > 0:
                        next_off = next_indices[
                            (df_processed.loc[next_indices, 'on_vent'] == 0) |
                            (~temp_series.loc[next_indices].isna())
                        ].min() if len(next_indices) > 0 else len(df_processed)

                        # Forward fill from this valid value until next off/valid
                        temp_series.loc[valid_idx:next_off] = temp_series.loc[valid_idx:next_off].ffill()

            df_processed[col] = temp_series
        
        elif impute_method == 0:
            df_processed[col] = df_processed[col].fillna(0)
    
    # Third pass: Handle normalization and one-hot encoding
    for col in columns_to_process:
        strategy = imputation_strategy[col]
        
        if strategy[2] == 'Onehot':
            if stats_dict is not None:
                df_processed = onehot_encode(df_processed, col, stats_dict)
        elif strategy[2] is not None and strategy[3] is not None:
            # Min-max normalization
            min_val, max_val = strategy[2], strategy[3]
            try:
                df_processed[col] = (df_processed[col] - min_val) / (max_val - min_val)
            except:
                print(col, min_val, max_val)
    return df_processed

### Add Imaging

In [15]:
from scipy.interpolate import interp1d

embedding_path = root / 'BioMedCLIP_embeddings'

supertable_template = "_timing_corrected_processed.pickle"
supertables = list(supertable_path.glob("*" + supertable_template))
len(supertables)

48194

In [4]:
df = pd.read_pickle(supertables[0])
df

Unnamed: 0,temperature,daily_weight_kg,pulse,unassisted_resp_rate,spo2,end_tidal_co2,anion_gap,base_excess,bicarb_(hco3),blood_urea_nitrogen_(bun),...,bed_type_or,bed_type_obs,icu_type_cticu,icu_type_micu,icu_type_neuro,icu_type_ccu,icu_type_msicu,cxr_timing,cxr_timing_approx_flag,encounter_id
2015-04-16 19:30:00,0.5,,0.500000,0.500000,0.6,,,,,,...,0,0,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...
2015-04-16 20:30:00,0.5,,0.500000,0.500000,0.6,,0.5,,0.6,0.285714,...,False,False,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...
2015-04-16 21:30:00,0.5,,1.500000,0.428571,0.8,,0.5,,0.6,0.285714,...,False,False,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...
2015-04-16 22:30:00,0.5,,-0.300000,0.428571,0.6,-3.3,0.5,,0.6,0.285714,...,False,False,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...
2015-04-16 23:30:00,0.5,,-0.300000,0.428571,0.6,,0.5,,0.6,0.285714,...,False,False,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2015-04-21 04:30:00,0.2,1.22,0.000000,0.714286,0.4,,,,,,...,False,False,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...
2015-04-21 05:30:00,0.2,1.22,0.000000,0.714286,0.4,,,,,,...,False,False,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...
2015-04-21 06:30:00,-0.1,1.22,-0.233333,0.714286,0.4,,,,,,...,False,False,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...
2015-04-21 07:30:00,-0.1,1.22,-0.233333,0.714286,0.4,,,,,,...,False,False,0,0,0,0,0,,,7c9832ae9005681cb88a87816d63c70060b1a97e2a922a...


In [5]:
object_columns = [column for column in df.columns if not pd.api.types.is_numeric_dtype(df[column])]
object_columns = [column for column in object_columns if column not in ['cxr_timing', 'encounter_id']]

In [6]:
for column in object_columns:
    df[column] = df[column].astype('float64')

In [7]:
dtypes = [df[column].dtype for column in df.columns]

In [8]:
list(zip(df.columns, dtypes))

[('temperature', dtype('float64')),
 ('daily_weight_kg', dtype('float64')),
 ('pulse', dtype('float64')),
 ('unassisted_resp_rate', dtype('float64')),
 ('spo2', dtype('float64')),
 ('end_tidal_co2', dtype('float64')),
 ('anion_gap', dtype('float64')),
 ('base_excess', dtype('float64')),
 ('bicarb_(hco3)', dtype('float64')),
 ('blood_urea_nitrogen_(bun)', dtype('float64')),
 ('calcium', dtype('float64')),
 ('calcium_ionized', dtype('float64')),
 ('chloride', dtype('float64')),
 ('creatinine', dtype('float64')),
 ('glucose', dtype('float64')),
 ('magnesium', dtype('float64')),
 ('phosphorus', dtype('float64')),
 ('potassium', dtype('float64')),
 ('sodium', dtype('float64')),
 ('hematocrit', dtype('float64')),
 ('hemoglobin', dtype('float64')),
 ('platelets', dtype('float64')),
 ('white_blood_cell_count', dtype('float64')),
 ('alanine_aminotransferase_(alt)', dtype('float64')),
 ('albumin', dtype('float64')),
 ('alkaline_phosphatase', dtype('float64')),
 ('ammonia', dtype('float64')),
 ('

In [9]:
df['cxr_trajectory_endpoints'] = [0]*len(df)

In [10]:
df.loc[df['cxr_timing'].first_valid_index(): df['cxr_timing'].last_valid_index(), 'cxr_trajectory_endpoints'] = 1

In [11]:
df['cxr_timing_ffill'] = df['cxr_timing'].ffill()

In [12]:
interpolated_embeddings = np.zeros((len(df), 512))

In [13]:
recorded_cxr_index = np.where(df.cxr_timing.values != None)[0]

In [14]:
for idx in recorded_cxr_index:
    path = embedding_path / (df.iloc[idx]['cxr_timing'] + '.npy')
    if path.exists():
        embedding = np.load(path)
        interpolated_embeddings[idx, :] = embedding
    else:
        print("No such file: ", path)

In [17]:
for i in range(len(recorded_cxr_index)-1):
    start_idx = recorded_cxr_index[i]
    end_idx = recorded_cxr_index[i+1]
    v1 = interpolated_embeddings[start_idx]
    v2 = interpolated_embeddings[end_idx]
    n_steps = end_idx - start_idx
    weights = np.arange(1, n_steps) / n_steps
    interpolated_embeddings[start_idx+1:end_idx] = v1 + weights.reshape(-1, 1) * (v2 - v1)

array([[ -1.88171816,  -3.88081431, -24.46614456, ...,  -0.17376935,
         -3.70242667,  -0.52481449],
       [ -1.90837222,  -3.81346272, -24.40264664, ...,  -0.16588222,
         -3.65594695,  -0.51140139],
       [ -1.93502628,  -3.74611113, -24.33914871, ...,  -0.15799509,
         -3.60946723,  -0.4979883 ],
       ...,
       [ -3.18776717,  -0.58058639, -21.35474625, ...,   0.21270004,
         -1.42492042,   0.13242707],
       [ -3.21442123,  -0.5132348 , -21.29124832, ...,   0.22058717,
         -1.37844071,   0.14584017],
       [ -3.24107529,  -0.44588321, -21.2277504 , ...,   0.2284743 ,
         -1.33196099,   0.15925326]])