# Figures 3(b, c)

Requires the corresponding experiment script to be run first; see `grande_experiment_slurm.sh` for details.

In [1]:
import os

# All paths used below (./outputs, ./data, figures/) and `from src import utils` are relative to
# the directory one level above this notebook. Only move up when `src` is not already visible,
# so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir("src"):
    os.chdir("../")

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import json
from src import utils
from IPython.display import clear_output
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import ConnectionPatch
import matplotlib.ticker as mtick
import numpy as np

sns.set_theme(style='ticks', context='paper', font_scale=1)  # global seaborn look for every figure below

In [4]:
# Prefix of the result files in ./outputs/experiments/
name = "grande"
# Document width in points handed to utils.get_fig_dim (presumably the LaTeX \textwidth -- confirm)
width_pt = 397
# Five husl colours; indices 3 and 4 are used by the figures
palette = sns.color_palette(palette='husl', n_colors=5)

In [5]:
# Result files written by the experiment script (see grande_experiment_slurm.sh).
# Sorted so that the row order of raw_df is reproducible across runs and filesystems.
files = sorted(glob.glob("./outputs/experiments/{name}_*.json".format(name=name)))
assert files, "No result files found for '{name}' -- run grande_experiment_slurm.sh first".format(name=name)

# Fields copied verbatim from every result file into the DataFrame.
RESULT_KEYS = ['k', 'horizon', 'pid', 'anchor_runtime', 'astar_runtime', 'ebf', 'reward', 'cf_reward',
               'states', 'actions', 'cf_states', 'cf_actions']

dicts = []
for fl_id, fl in enumerate(files):
    clear_output(wait=True)
    print('Reading file ' + str(fl_id + 1) + '/' + str(len(files)))
    with open(fl, "r") as f:
        js = json.load(f)
    # `record` rather than `dict`: shadowing the builtin would break any later dict(...) call in this kernel.
    record = {key: js[key] for key in RESULT_KEYS}
    dicts.append(record)

raw_df = pd.DataFrame(dicts)

Reading file 15992/15992


In [294]:
# Per-feature min/max from preprocessing; the SOFA pair is needed to undo the
# normalisation of the SOFA entry stored in every state vector.
with open("./data/processed/feature_normalization.json", 'r') as f:
    scaling = json.load(f)

sofa_min, sofa_max = scaling['min']['SOFA'], scaling['max']['SOFA']

def compute_sofa(row):
    """Return ``(total_sofa, cf_total_sofa)`` for one episode.

    ``row`` must provide ``'states'`` and ``'cf_states'``: per-timestep state
    vectors whose last entry is the normalised SOFA score. That entry is mapped
    back with ``(x + 0.5) * (sofa_max - sofa_min) + sofa_min`` (module-level
    ``sofa_min``/``sofa_max``), clipped to ``[sofa_min, sofa_max]`` and summed
    over all time steps.
    """
    sofa_range = sofa_max - sofa_min
    totals = []
    for key in ('states', 'cf_states'):
        episode_states = np.array(row[key])
        episode_states[:, -1] = (episode_states[:, -1] + 0.5) * sofa_range + sofa_min
        episode_states[:, -1] = np.clip(episode_states[:, -1], sofa_min, sofa_max)
        totals.append(np.sum(episode_states[:, -1]))
    observed_total, counterfactual_total = totals
    return observed_total, counterfactual_total

In [295]:
# Work on a copy so raw_df keeps only the loaded results and this cell is idempotent.
input_df = raw_df.copy()

# Total observed and counterfactual SOFA per episode (see compute_sofa).
sofa_totals = [compute_sofa(row) for _, row in input_df.iterrows()]
input_df['total_sofa'] = [total for total, _ in sofa_totals]
input_df['cf_total_sofa'] = [cf_total for _, cf_total in sofa_totals]

# Relative improvement (%) of the counterfactual over the observed episode.
# A zero observed total makes the ratio undefined (0/0 -> NaN, x/0 -> -inf); map it to NaN
# so it is skipped by the median and the histogram instead of producing an infinite value.
observed_total = input_df['total_sofa'].where(input_df['total_sofa'] != 0)
input_df['improvement'] = (input_df['total_sofa'] - input_df['cf_total_sofa']) / observed_total * 100

In [296]:
utils.latexify() # Computer Modern, with TeX

fig_width, fig_height = utils.get_fig_dim(width_pt, fraction=0.6)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(fig_width, fig_height))

# Main panel: distribution of the relative SOFA improvement over all patients
sns.histplot(data=input_df, x="improvement", bins=30, color=palette[3], ax=ax)
sns.despine(ax=ax)
ax.set_xlabel('Counterfactual improvement')
ax.set_ylabel('Number of patients')
ax.set_xlim(left=0, right=40)
ax.xaxis.set_major_formatter(mtick.PercentFormatter(decimals=0))

# Dashed vertical line at the median improvement
median = input_df['improvement'].median()
print('Median: ' + str(median))
ax.axvline(median, color='black', linestyle='--', linewidth=1)

# Inset in the upper-right corner zooming on patients whose improvement exceeds `thres` percent
thres = 15
right_lim = 37

axins = inset_axes(ax, width="54%", height="54%", loc='upper right')
zoomed_df = input_df[input_df['improvement'] > thres]
print('Patients above the threshold: ' + str(len(zoomed_df)))
sns.histplot(data=zoomed_df, x="improvement", bins=10, color=palette[4], ax=axins)
sns.despine(ax=axins)
axins.set_xlim(left=thres, right=right_lim)

# Percent tick labels for the inset; raw string so "\%" reaches LaTeX without an invalid-escape warning
xticks = [20, 30]
xticklabels = [str(x) + r"\%" for x in xticks]
axins.set_xticks(xticks)
axins.set_xticklabels(xticklabels)
axins.set_ylim(0, 100)  # fixed y-range for the inset; adjust if the tail grows

# Exclamation mark, placed in data coordinates of the inset
axins.text(25, 40, r'\textbf{!}', fontsize=30, color=palette[4])

axins.set_xlabel('')
axins.set_ylabel('')

# Dotted guide lines from the zoomed x-range (thres, right_lim) on the main axis' baseline
# to the bottom-left (0, 0) and bottom-right (1, 0) corners of the inset.
for main_x, inset_corner in ((thres, (0, 0)), (right_lim, (1, 0))):
    con = ConnectionPatch(xyA=(main_x, 0), xyB=inset_corner, coordsA="data", coordsB="axes fraction",
                          axesA=ax, axesB=axins, linewidth=1, linestyle='dotted',
                          color=sns.axes_style()['axes.edgecolor'], shrinkB=5)
    ax.add_artist(con)

# NOTE(review): tight_layout emits a warning here (see cell output), presumably because of the inset axes.
fig.tight_layout()
fig.savefig('figures/grande_distribution.pdf', dpi=300)

Median: 5.089303140438469
Patients above the threshold: 176


  fig.tight_layout()


In [55]:
# List every patient in the inset tail, ordered by increasing improvement.
# `pid` instead of `id` so the builtin id() is not shadowed in the notebook namespace.
for improv, pid in zoomed_df[['improvement', 'pid']].sort_values(by='improvement').to_numpy().tolist():
    print('Patient ' + str(int(pid)) + ': ' + str(np.round(improv, decimals=2)) + '%')

Patient 27882: 15.02%
Patient 81410: 15.02%
Patient 62482: 15.04%
Patient 8782: 15.04%
Patient 81147: 15.05%
Patient 90585: 15.05%
Patient 63500: 15.06%
Patient 17800: 15.07%
Patient 78499: 15.08%
Patient 1209: 15.1%
Patient 3511: 15.13%
Patient 19962: 15.14%
Patient 72342: 15.15%
Patient 24649: 15.16%
Patient 97506: 15.16%
Patient 40913: 15.17%
Patient 9493: 15.2%
Patient 29092: 15.21%
Patient 34929: 15.24%
Patient 88974: 15.24%
Patient 61040: 15.24%
Patient 54154: 15.25%
Patient 653: 15.25%
Patient 21304: 15.26%
Patient 40506: 15.26%
Patient 87916: 15.27%
Patient 77587: 15.29%
Patient 42592: 15.3%
Patient 11103: 15.3%
Patient 35747: 15.31%
Patient 38793: 15.31%
Patient 41264: 15.32%
Patient 66156: 15.34%
Patient 70853: 15.35%
Patient 85169: 15.36%
Patient 93620: 15.36%
Patient 21825: 15.37%
Patient 12175: 15.38%
Patient 96749: 15.4%
Patient 92509: 15.4%
Patient 64617: 15.42%
Patient 76315: 15.43%
Patient 62179: 15.45%
Patient 6618: 15.48%
Patient 36610: 15.49%
Patient 99626: 15.49%
P

### Code to show the observed and counterfactual episode of a specific patient

The example shown in the paper is the patient with `pid` 65961.

In [None]:
# Select the patient whose episodes are shown in the paper.
patient_id = 65961
patient_df = input_df[input_df['pid'] == patient_id]
# Fail early with a clear message; the next cell indexes .values[0] and would raise an opaque IndexError.
assert not patient_df.empty, 'no result found for pid ' + str(patient_id)

In [300]:
print('Improvement: ' + str(patient_df['improvement'].values[0]) + '%')

# Column layout of the stored state and action vectors.
STATE_COLUMNS = ['gender', 're_admission', 'age', 'FiO2_1', 'paO2', 'Platelets_count', 'Total_bili', 'GCS',
                 'MeanBP', 'Creatinine', 'output_4hourly', 'SOFA']
ACTION_COLUMNS = ['vaso', 'ivfluids', 'mechvent']
# Sentinel action marking a time step without an action (the final state of an episode).
NO_ACTION = (42, 42, 42)

horizon = patient_df['horizon'].values[0]

# Maps an encoded action to its real vaso / ivfluids / mechvent values.
with open('./data/processed/action_dictionary.json', 'r') as f:
    action_dict = json.load(f)


def build_episode(state_rows, action_rows, final_step=None):
    """Return a per-step frame with real action values, de-normalised SOFA and time in hours.

    state_rows, action_rows: nested lists as stored in the result JSON
        (columns STATE_COLUMNS and ACTION_COLUMNS respectively).
    final_step: if given, the action at this index is set to NO_ACTION. The counterfactual
        actions are padded this way at horizon - 1, as the original cell did.
    """
    sofa = pd.DataFrame(state_rows, columns=STATE_COLUMNS).loc[:, 'SOFA']
    # Float dtype up front: mapped values (e.g. vaso 0.225) are written back into these
    # columns, which would otherwise be an incompatible-dtype assignment into int64 columns.
    actions = pd.DataFrame(action_rows, columns=ACTION_COLUMNS).astype(float)
    if final_step is not None:
        actions.loc[final_step] = list(NO_ACTION)
    frame = pd.concat([actions, sofa], axis=1)
    frame['time'] = 4 * frame.index  # one step = 4 hours

    for index, row in frame.iterrows():
        action = (row['vaso'], row['ivfluids'], row['mechvent'])
        if action == NO_ACTION:
            continue
        # Key built from plain Python floats, e.g. "(0.0, 2.0, 1.0)" -- the same text str() gave for
        # numpy scalars before NumPy 2, which now renders them as "np.float64(0.0)" and breaks the lookup.
        real = action_dict[str(tuple(float(v) for v in action))]
        frame.loc[index, 'vaso'] = real['vaso']
        frame.loc[index, 'ivfluids'] = real['ivfluids']
        # Judging by the printed episodes, the dictionary stores mechvent as -0.5/0.5; +0.5 gives 0/1.
        frame.loc[index, 'mechvent'] = real['mechvent'] + 0.5

    # De-normalise SOFA exactly like compute_sofa (clip to [sofa_min, sofa_max]) so the plotted
    # trajectories agree with the totals behind the reported improvement.
    frame['SOFA'] = np.clip((frame['SOFA'] + 0.5) * (sofa_max - sofa_min) + sofa_min, sofa_min, sofa_max)
    return frame


episode = build_episode(patient_df['states'].values[0], patient_df['actions'].values[0])
cf_episode = build_episode(patient_df['cf_states'].values[0], patient_df['cf_actions'].values[0],
                           final_step=horizon - 1)

Improvement: 19.907944625929783%


In [301]:
# Plain-text dump of both trajectories: observed first, then counterfactual.
for frame in (episode, cf_episode):
    print(frame)

      vaso  ivfluids  mechvent  SOFA  time
0    0.000       0.0       0.0   7.0     0
1    0.000       0.0       0.0   6.0     4
2    0.000       0.0       0.0   7.0     8
3    0.225     850.0       1.0   8.0    12
4    0.000     850.0       1.0  13.0    16
5    0.000     850.0       1.0  11.0    20
6    0.000     850.0       1.0  15.0    24
7    0.000     850.0       0.0   9.0    28
8    0.000      30.0       0.0   6.0    32
9    0.000       0.0       0.0   6.0    36
10   0.000       0.0       0.0   5.0    40
11  42.000      42.0      42.0   2.0    44
      vaso  ivfluids  mechvent       SOFA  time
0    0.000       0.0       0.0   7.000000     0
1    0.000       0.0       0.0   6.000001     4
2    0.000     850.0       0.0   7.000000     8
3    0.788       0.0       1.0   7.509409    12
4    0.788       0.0       1.0  13.150763    16
5    0.000     850.0       1.0   9.673411    20
6    0.000     850.0       1.0  12.363342    24
7    0.000     850.0       0.0   5.761981    28
8    0.00

In [303]:
utils.latexify() # Computer Modern, with TeX

fig_width, fig_height = utils.get_fig_dim(width_pt, fraction=0.6)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(fig_width, fig_height))

# Counterfactual (dashed, coloured) and observed (solid, black) SOFA trajectories.
# `marker` (singular) is forwarded to matplotlib. seaborn's `markers` only applies together with a
# `style` semantic, so the previous markers='*' was silently ignored and no markers were drawn.
sns.lineplot(data=cf_episode, x='time', y='SOFA', linestyle='--', marker='*', color=palette[4], ax=ax)
sns.lineplot(data=episode, x='time', y='SOFA', marker='*', color='black', ax=ax)
sns.despine(ax=ax)
ax.set_xlabel('Time (hours)')
ax.set_ylabel('SOFA score')
ax.set_xlim(left=0)
ax.set_ylim(bottom=0)

fig.tight_layout()
fig.savefig('figures/patient.pdf', dpi=300)