# Compute the exact transition matrix for the sepsis simulator v2
https://github.com/clinicalml/gumbel-max-scm/tree/854229e039b52f10257ad5460fa79d34f0452b27/sepsisSimDiabetes

In [1]:
# Create the output directory for the joblib dumps below. Unlike the shell
# `!mkdir`, this is plain Python (runs outside IPython) and is a no-op when the
# directory already exists, so the notebook can be re-run.
import os

os.makedirs('../data', exist_ok=True)

In [2]:
# The simulator applies transitions in the following order (treatment effects first, then fluctuations):
# abx on/off
# vent on/off
# vaso on/off
# hr fluctuate
# sbp fluctuate
# o2 fluctuate
# glu fluctuate

In [3]:
import numpy as np
from sepsisSimDiabetes.State import State
from sepsisSimDiabetes.Action import Action
from sepsisSimDiabetes.MDP import MDP
import itertools
import joblib

## Metadata

In [4]:
# Categorical state variables and the values each one can take.
# Key order is [hr, sbp, o2, glu, abx, vaso, vent, diab]: the first seven match
# the argument order of State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent]),
# and diab is the diabetes indicator passed separately as diabetic_idx.
# itertools.product over these values yields tuples in this key order, so any
# loop unpacking must list vaso before vent.
state_variable_values = {
    'hr': [0, 1, 2],
    'sbp': [0, 1, 2],
    'o2': [0, 1],
    'glu': [0, 1, 2, 3, 4],
    'abx': [0, 1],
    'vaso': [0, 1],
    'vent': [0, 1],
    'diab': [0, 1],
}
state_variables = list(state_variable_values.keys())

In [5]:
# Observable (non-diabetes) states: hr(3) x sbp(3) x o2(2) x glu(5) x abx(2) x vaso(2) x vent(2).
nS = 3 * 3 * 2 * 5 * 2 * 2 * 2
# Actions: each of the three treatments (abx, vent, vaso) independently on or off.
nA = 2 ** 3

## Reward Matrix (A,S,S) (S,A)

In [6]:
# Reward of every observable state, queried from the simulator: build an MDP
# starting at each state index and ask for its reward. A uniform policy is
# passed because the constructor takes one; presumably it does not affect
# calculateReward() — confirm against sepsisSimDiabetes.MDP.
dummy_pol = np.full((nS, nA), 1 / nA)
reward_per_state = np.array(
    [
        MDP(init_state_idx=state_idx, policy_array=dummy_pol, p_diabetes=0).calculateReward()
        for state_idx in range(nS)
    ],
    dtype=float,
)

# Summarize how many states carry each reward value.
for reward_value, label in ((0, 'non-terminal states'), (-1, 'death states'), (1, 'discharge states')):
    print((reward_per_state == reward_value).sum(), label)

303 non-terminal states
416 death states
1 discharge states


In [7]:
# Reward indexed [action, from_state, to_state]: every transition earns the
# reward of its destination state. Only the non-diabetic and diabetic diagonal
# blocks are filled; the transition matrices are block-diagonal, so
# cross-block transitions never occur. reward_per_state broadcasts along the last axis.
reward_matrix_ASS = np.zeros((nA, nS*2, nS*2))
reward_matrix_ASS[:, :nS, :nS] = reward_per_state
reward_matrix_ASS[:, nS:, nS:] = reward_per_state

In [8]:
# Assign reward for the transition from death/disch:
# state-action reward for the absorbing formulation. Taking any action in a
# death state (reward -1) pays -1 and in the discharge state (reward +1) pays +1,
# in both diabetes halves. All other rows, including the two absorbing states
# at the end, stay 0.
reward_matrix_absorbing_SA = np.zeros((nS*2+2, nA))
for terminal_value in (-1, 1):
    # Mask of length 2*nS: entry s and entry nS+s both mark state s.
    is_terminal = np.tile(reward_per_state == terminal_value, 2)
    reward_matrix_absorbing_SA[:nS*2][is_terminal] = terminal_value

In [9]:
# Transition reward [action, from_state, to_state] for the absorbing formulation.
# Any transition into the state at index -2 (where death states are sent) pays -1.
# Any transition into index -1 (where the discharge state is sent) pays +1.
reward_matrix_absorbing_ASS = np.zeros((nA, nS*2+2, nS*2+2))
for sink, entry_reward in ((-2, -1), (-1, 1)):
    reward_matrix_absorbing_ASS[:, :, sink] = entry_reward
    # No reward once in the absorbing state: its self-loop pays nothing.
    reward_matrix_absorbing_ASS[:, sink, sink] = 0

## Treatments

### Antibiotics

In [10]:
# abx indicator: abx_on sets the antibiotic flag to 1 and abx_off sets it to 0.
# Every other variable is left unchanged, so these are deterministic 0/1 matrices.
# The loop covers every variable except abx and diab. diab is excluded because
# these nS x nS matrices use diabetic_idx=0 observation indices; iterating over it
# only rewrote each entry twice. Unpacking follows state_variable_values key
# order (vaso before vent).
abx_on = np.zeros((nS, nS))
abx_off = np.zeros((nS, nS))
for (hr, sbp, o2, glu, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['abx', 'diab']]
):
    s0 = State(state_categs=[hr, sbp, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()
    s1 = State(state_categs=[hr, sbp, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()
    abx_on[s0, s1] = 1
    abx_on[s1, s1] = 1
    abx_off[s1, s0] = 1
    abx_off[s0, s0] = 1

assert np.isclose(abx_on.sum(axis=1), 1).all()
assert np.isclose(abx_off.sum(axis=1), 1).all()

In [11]:
# abx affects hr: high (2) -> normal (1) w.p. 0.5; low and normal hr are unchanged.
# The loop covers every variable except hr and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
hr_H2N_wp05 = np.zeros((nS, nS))
for (sbp, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['hr', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for hr in (0, 1, 2)
    )
    hr_H2N_wp05[s2, s2] = 0.5
    hr_H2N_wp05[s2, s1] = 0.5
    hr_H2N_wp05[s1, s1] = 1
    hr_H2N_wp05[s0, s0] = 1

assert np.isclose(hr_H2N_wp05.sum(axis=1), 1).all()

In [12]:
# abx affects sbp: high (2) -> normal (1) w.p. 0.5; low and normal sbp are unchanged.
# The loop covers every variable except sbp and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
sbp_H2N_wp05 = np.zeros((nS, nS))
for (hr, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['sbp', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for sbp in (0, 1, 2)
    )
    sbp_H2N_wp05[s2, s2] = 0.5
    sbp_H2N_wp05[s2, s1] = 0.5
    sbp_H2N_wp05[s1, s1] = 1
    sbp_H2N_wp05[s0, s0] = 1

assert np.isclose(sbp_H2N_wp05.sum(axis=1), 1).all()

In [13]:
# abx withdrawn affects hr: if abx is on in the current state (abx=1),
# normal (1) -> high (2) w.p. 0.1. Every other row is the identity.
# The loop covers every variable except hr and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
hr_N2H_wp01 = np.zeros((nS, nS))
for (sbp, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['hr', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for hr in (0, 1, 2)
    )
    hr_N2H_wp01[s0, s0] = 1
    hr_N2H_wp01[s2, s2] = 1
    if abx == 1:
        hr_N2H_wp01[s1, s1] = 0.9
        hr_N2H_wp01[s1, s2] = 0.1
    else:
        hr_N2H_wp01[s1, s1] = 1

assert np.isclose(hr_N2H_wp01.sum(axis=1), 1).all()

In [14]:
# abx withdrawn affects sbp: if abx is on in the current state (abx=1),
# normal (1) -> high (2) w.p. 0.1. Every other row is the identity.
# The loop covers every variable except sbp and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
sbp_N2H_wp01 = np.zeros((nS, nS))
for (hr, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['sbp', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for sbp in (0, 1, 2)
    )
    sbp_N2H_wp01[s0, s0] = 1
    sbp_N2H_wp01[s2, s2] = 1
    if abx == 1:
        sbp_N2H_wp01[s1, s1] = 0.9
        sbp_N2H_wp01[s1, s2] = 0.1
    else:
        sbp_N2H_wp01[s1, s1] = 1

assert np.isclose(sbp_N2H_wp01.sum(axis=1), 1).all()

In [15]:
# Antibiotics on: hr and sbp each go high -> normal w.p. 0.5, then abx is set to 1.
antibiotics_on_ = hr_H2N_wp05 @ sbp_H2N_wp05 @ abx_on
# Same dynamics in the non-diabetic (top-left) and diabetic (bottom-right) blocks.
antibiotics_on = np.zeros((2 * nS, 2 * nS))
antibiotics_on[:nS, :nS] = antibiotics_on_
antibiotics_on[nS:, nS:] = antibiotics_on_
assert np.isclose(antibiotics_on.sum(axis=1), 1).all()

# Antibiotics off: if abx was on, hr and sbp each go normal -> high w.p. 0.1,
# then abx is set to 0.
antibiotics_off_ = hr_N2H_wp01 @ sbp_N2H_wp01 @ abx_off
antibiotics_off = np.zeros((2 * nS, 2 * nS))
antibiotics_off[:nS, :nS] = antibiotics_off_
antibiotics_off[nS:, nS:] = antibiotics_off_
assert np.isclose(antibiotics_off.sum(axis=1), 1).all()

### Ventilation

In [16]:
# vent indicator: vent_on sets the ventilation flag to 1 and vent_off sets it to 0.
# Every other variable is left unchanged (deterministic 0/1 matrices).
# The loop covers every variable except vent and diab. diab is excluded because
# these nS x nS matrices use diabetic_idx=0 indices; iterating over it only
# rewrote each entry twice.
vent_on = np.zeros((nS, nS))
vent_off = np.zeros((nS, nS))
for (hr, sbp, o2, glu, abx, vaso) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['vent', 'diab']]
):
    s0 = State(state_categs=[hr, sbp, o2, glu, abx, vaso, 0], diabetic_idx=0).get_state_idx()
    s1 = State(state_categs=[hr, sbp, o2, glu, abx, vaso, 1], diabetic_idx=0).get_state_idx()
    vent_on[s0, s1] = 1
    vent_on[s1, s1] = 1
    vent_off[s1, s0] = 1
    vent_off[s0, s0] = 1

assert np.isclose(vent_on.sum(axis=1), 1).all()
assert np.isclose(vent_off.sum(axis=1), 1).all()

In [17]:
# vent affects o2: low (0) -> normal (1) w.p. 0.7; normal o2 is unchanged.
# The loop covers every variable except o2 and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
o2_L2N_wp07 = np.zeros((nS, nS))
for (hr, sbp, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['o2', 'diab']]
):
    s0, s1 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for o2 in (0, 1)
    )
    o2_L2N_wp07[s0, s0] = 0.3
    o2_L2N_wp07[s0, s1] = 0.7
    o2_L2N_wp07[s1, s1] = 1

assert np.isclose(o2_L2N_wp07.sum(axis=1), 1).all()

In [18]:
# vent withdrawn affects o2: if vent is on in the current state (vent=1),
# normal (1) -> low (0) w.p. 0.1. Every other row is the identity.
# The loop covers every variable except o2 and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
o2_N2L_wp01 = np.zeros((nS, nS))
for (hr, sbp, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['o2', 'diab']]
):
    s0, s1 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for o2 in (0, 1)
    )
    o2_N2L_wp01[s0, s0] = 1
    if vent == 1:
        o2_N2L_wp01[s1, s0] = 0.1
        o2_N2L_wp01[s1, s1] = 0.9
    else:
        o2_N2L_wp01[s1, s1] = 1

assert np.isclose(o2_N2L_wp01.sum(axis=1), 1).all()

In [19]:
# Ventilation: apply the o2 effect (conditioned on the current vent flag), then
# set the vent flag. The same nS x nS dynamics fill both diabetes blocks.
ventilation_on_ = o2_L2N_wp07 @ vent_on
ventilation_off_ = o2_N2L_wp01 @ vent_off

ventilation_on = np.zeros((2 * nS, 2 * nS))
ventilation_on[:nS, :nS] = ventilation_on_
ventilation_on[nS:, nS:] = ventilation_on_

ventilation_off = np.zeros((2 * nS, 2 * nS))
ventilation_off[:nS, :nS] = ventilation_off_
ventilation_off[nS:, nS:] = ventilation_off_

assert np.isclose(ventilation_on.sum(axis=1), 1).all()
assert np.isclose(ventilation_off.sum(axis=1), 1).all()

### Vasopressor

In [20]:
# vaso indicator: vaso_on sets the vasopressor flag to 1 and vaso_off sets it to 0.
# Every other variable is left unchanged (deterministic 0/1 matrices).
# The loop covers every variable except vaso and diab. diab is excluded because
# these nS x nS matrices use diabetic_idx=0 indices; iterating over it only
# rewrote each entry twice.
vaso_on = np.zeros((nS, nS))
vaso_off = np.zeros((nS, nS))
for (hr, sbp, o2, glu, abx, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['vaso', 'diab']]
):
    s0 = State(state_categs=[hr, sbp, o2, glu, abx, 0, vent], diabetic_idx=0).get_state_idx()
    s1 = State(state_categs=[hr, sbp, o2, glu, abx, 1, vent], diabetic_idx=0).get_state_idx()
    vaso_on[s0, s1] = 1
    vaso_on[s1, s1] = 1
    vaso_off[s1, s0] = 1
    vaso_off[s0, s0] = 1

assert np.isclose(vaso_on.sum(axis=1), 1).all()
assert np.isclose(vaso_off.sum(axis=1), 1).all()

In [21]:
# vaso affects sbp (non-diabetic): w.p. 0.7, sbp rises one level
# (low -> normal, normal -> high). High sbp is unchanged.
# The loop covers every variable except sbp and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
sbp_L2N_N2H_wp07 = np.zeros((nS, nS))
for (hr, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['sbp', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for sbp in (0, 1, 2)
    )
    sbp_L2N_N2H_wp07[s0, s0] = 0.3
    sbp_L2N_N2H_wp07[s0, s1] = 0.7
    sbp_L2N_N2H_wp07[s1, s1] = 0.3
    sbp_L2N_N2H_wp07[s1, s2] = 0.7
    sbp_L2N_N2H_wp07[s2, s2] = 1

assert np.isclose(sbp_L2N_N2H_wp07.sum(axis=1), 1).all()

In [22]:
# vaso affects sbp (diabetic): low -> normal w.p. 0.5, low -> high w.p. 0.4,
# normal -> high w.p. 0.9; high sbp is unchanged.
# Built on diabetic_idx=0 indices; it is placed in the diabetic block when the
# vasopressor matrix is assembled. The loop covers every variable except sbp and
# diab (diab only duplicated the writes). Unpacking follows state_variable_values
# key order (vaso before vent).
sbp_L2N2H = np.zeros((nS, nS))
for (hr, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['sbp', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for sbp in (0, 1, 2)
    )
    sbp_L2N2H[s0, s0] = 0.1
    sbp_L2N2H[s0, s1] = 0.5
    sbp_L2N2H[s0, s2] = 0.4
    sbp_L2N2H[s1, s1] = 0.1
    sbp_L2N2H[s1, s2] = 0.9
    sbp_L2N2H[s2, s2] = 1

assert np.isclose(sbp_L2N2H.sum(axis=1), 1).all()

In [23]:
# vaso affects glu (diabetic): glucose rises one level w.p. 0.5
# (LL->L, L->N, N->H, H->HH). The top level stays put.
# The loop covers every variable except glu and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
glu_raise_by_1 = np.zeros((nS, nS))
for (hr, sbp, o2, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['glu', 'diab']]
):
    # State indices for each glucose level, lowest first.
    glu_idx = [
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for glu in state_variable_values['glu']
    ]
    for lower, upper in zip(glu_idx[:-1], glu_idx[1:]):
        glu_raise_by_1[lower, lower] = 0.5
        glu_raise_by_1[lower, upper] = 0.5
    glu_raise_by_1[glu_idx[-1], glu_idx[-1]] = 1

assert np.isclose(glu_raise_by_1.sum(axis=1), 1).all()

In [24]:
# Vasopressor on: the sbp effect (and, for diabetics, the glucose rise) is
# applied before the vaso flag is set. The non-diabetic (top-left) and
# diabetic (bottom-right) blocks use different sbp dynamics.
vasopressor_on = np.zeros((2 * nS, 2 * nS))
vasopressor_on[:nS, :nS] = sbp_L2N_N2H_wp07 @ vaso_on
vasopressor_on[nS:, nS:] = sbp_L2N2H @ glu_raise_by_1 @ vaso_on
assert np.isclose(vasopressor_on.sum(axis=1), 1).all()

In [25]:
# vaso withdrawn affects sbp (non-diabetic): if vaso is on in the current state
# (vaso=1), sbp drops one level w.p. 0.1 (normal -> low, high -> normal).
# Low sbp and every vaso=0 row are unchanged.
# The loop covers every variable except sbp and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
sbp_H2N2L_wp01 = np.zeros((nS, nS))
for (hr, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['sbp', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for sbp in (0, 1, 2)
    )
    sbp_H2N2L_wp01[s0, s0] = 1
    if vaso == 1:
        sbp_H2N2L_wp01[s1, s1] = 0.9
        sbp_H2N2L_wp01[s1, s0] = 0.1
        sbp_H2N2L_wp01[s2, s2] = 0.9
        sbp_H2N2L_wp01[s2, s1] = 0.1
    else:
        sbp_H2N2L_wp01[s1, s1] = 1
        sbp_H2N2L_wp01[s2, s2] = 1

assert np.isclose(sbp_H2N2L_wp01.sum(axis=1), 1).all()

In [26]:
# vaso withdrawn affects sbp (diabetic): if vaso is on in the current state
# (vaso=1), sbp drops one level w.p. 0.05 (normal -> low, high -> normal).
# Low sbp and every vaso=0 row are unchanged.
# The loop covers every variable except sbp and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
sbp_H2N2L_wp005 = np.zeros((nS, nS))
for (hr, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['sbp', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for sbp in (0, 1, 2)
    )
    sbp_H2N2L_wp005[s0, s0] = 1
    if vaso == 1:
        sbp_H2N2L_wp005[s1, s1] = 0.95
        sbp_H2N2L_wp005[s1, s0] = 0.05
        sbp_H2N2L_wp005[s2, s2] = 0.95
        sbp_H2N2L_wp005[s2, s1] = 0.05
    else:
        sbp_H2N2L_wp005[s1, s1] = 1
        sbp_H2N2L_wp005[s2, s2] = 1

assert np.isclose(sbp_H2N2L_wp005.sum(axis=1), 1).all()

In [27]:
# Vasopressor off: the sbp drop (conditioned on the current vaso flag) happens
# before vaso is cleared. Diabetics use the smaller 0.05 drop probability.
vasopressor_off = np.zeros((2 * nS, 2 * nS))
vasopressor_off[:nS, :nS] = sbp_H2N2L_wp01 @ vaso_off
vasopressor_off[nS:, nS:] = sbp_H2N2L_wp005 @ vaso_off
assert np.isclose(vasopressor_off.sum(axis=1), 1).all()

## Fluctuate
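- all (non-treatment) states fluctuate +/- 1 w.p. .1
- exception: glucose fluctuates +/- 1 w.p. .3 if diabetic
- a variable does not fluctuate in a step where a treatment acting on it is given, or was already on in the current state (e.g. hr stays fixed whenever antibiotics are given or were already on)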

### hr

In [28]:
# hr fluctuates +/- 1 level w.p. 0.1 (clamped at low/high), but only in states
# where abx is off (abx=0); with abx on, hr is unchanged. The "abx action is
# off" half of the condition is handled by only including this matrix for
# actions without antibiotics.
# The loop covers every variable except hr and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
hr_fluctuate = np.zeros((nS, nS))
for (sbp, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['hr', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for hr in (0, 1, 2)
    )
    if abx == 0:
        hr_fluctuate[s0, s0] = 0.9
        hr_fluctuate[s0, s1] = 0.1
        hr_fluctuate[s1, s0] = 0.1
        hr_fluctuate[s1, s1] = 0.8
        hr_fluctuate[s1, s2] = 0.1
        hr_fluctuate[s2, s1] = 0.1
        hr_fluctuate[s2, s2] = 0.9
    else:
        for idx in (s0, s1, s2):
            hr_fluctuate[idx, idx] = 1

assert np.isclose(hr_fluctuate.sum(axis=1), 1).all()
# Same dynamics for the non-diabetic and diabetic halves.
hr_fluctuate = np.block([[hr_fluctuate, np.zeros((nS, nS))], [np.zeros((nS, nS)), hr_fluctuate]])

### sbp

In [29]:
# sbp fluctuates +/- 1 level w.p. 0.1 (clamped at low/high), but only in states
# where both abx and vaso are off (abx=0 and vaso=0); otherwise sbp is unchanged.
# The loop covers every variable except sbp and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
sbp_fluctuate = np.zeros((nS, nS))
for (hr, o2, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['sbp', 'diab']]
):
    s0, s1, s2 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for sbp in (0, 1, 2)
    )
    if abx == 0 and vaso == 0:
        sbp_fluctuate[s0, s0] = 0.9
        sbp_fluctuate[s0, s1] = 0.1
        sbp_fluctuate[s1, s0] = 0.1
        sbp_fluctuate[s1, s1] = 0.8
        sbp_fluctuate[s1, s2] = 0.1
        sbp_fluctuate[s2, s1] = 0.1
        sbp_fluctuate[s2, s2] = 0.9
    else:
        for idx in (s0, s1, s2):
            sbp_fluctuate[idx, idx] = 1

assert np.isclose(sbp_fluctuate.sum(axis=1), 1).all()
# Same dynamics for the non-diabetic and diabetic halves.
sbp_fluctuate = np.block([[sbp_fluctuate, np.zeros((nS, nS))], [np.zeros((nS, nS)), sbp_fluctuate]])

### o2

In [30]:
# o2 fluctuates between low (0) and normal (1) w.p. 0.1, but only in states
# where vent is off (vent=0); with vent on, o2 is unchanged.
# The loop covers every variable except o2 and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
o2_fluctuate = np.zeros((nS, nS))
for (hr, sbp, glu, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['o2', 'diab']]
):
    s0, s1 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for o2 in (0, 1)
    )
    if vent == 0:
        o2_fluctuate[s0, s0] = 0.9
        o2_fluctuate[s0, s1] = 0.1
        o2_fluctuate[s1, s0] = 0.1
        o2_fluctuate[s1, s1] = 0.9
    else:
        o2_fluctuate[s0, s0] = 1
        o2_fluctuate[s1, s1] = 1

assert np.isclose(o2_fluctuate.sum(axis=1), 1).all()
# Same dynamics for the non-diabetic and diabetic halves.
o2_fluctuate = np.block([[o2_fluctuate, np.zeros((nS, nS))], [np.zeros((nS, nS)), o2_fluctuate]])

### glu

In [31]:
# Non-diabetic glucose fluctuation: +/- 1 level w.p. 0.1 each, clamped at the
# lowest/highest level (edge rows stay w.p. 0.9).
# The loop covers every variable except glu and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
glu_fluctuate_01 = np.zeros((nS, nS))
for (hr, sbp, o2, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['glu', 'diab']]
):
    s0, s1, s2, s3, s4 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for glu in (0, 1, 2, 3, 4)
    )
    glu_fluctuate_01[s0, s0] = 0.9
    glu_fluctuate_01[s0, s1] = 0.1
    glu_fluctuate_01[s1, s0] = 0.1
    glu_fluctuate_01[s1, s1] = 0.8
    glu_fluctuate_01[s1, s2] = 0.1
    glu_fluctuate_01[s2, s1] = 0.1
    glu_fluctuate_01[s2, s2] = 0.8
    glu_fluctuate_01[s2, s3] = 0.1
    glu_fluctuate_01[s3, s2] = 0.1
    glu_fluctuate_01[s3, s3] = 0.8
    glu_fluctuate_01[s3, s4] = 0.1
    glu_fluctuate_01[s4, s3] = 0.1
    glu_fluctuate_01[s4, s4] = 0.9

assert np.isclose(glu_fluctuate_01.sum(axis=1), 1).all()

In [32]:
# Diabetic glucose fluctuation: +/- 1 level w.p. 0.3 each, clamped at the
# lowest/highest level (edge rows stay w.p. 0.7). Built on diabetic_idx=0
# indices; it is placed in the diabetic block when glu_fluctuate is assembled.
# The loop covers every variable except glu and diab (diab only duplicated the
# writes). Unpacking follows state_variable_values key order (vaso before vent).
glu_fluctuate_03 = np.zeros((nS, nS))
for (hr, sbp, o2, abx, vaso, vent) in itertools.product(
    *[state_variable_values[key] for key in state_variables if key not in ['glu', 'diab']]
):
    s0, s1, s2, s3, s4 = (
        State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()
        for glu in (0, 1, 2, 3, 4)
    )
    glu_fluctuate_03[s0, s0] = 0.7
    glu_fluctuate_03[s0, s1] = 0.3
    glu_fluctuate_03[s1, s0] = 0.3
    glu_fluctuate_03[s1, s1] = 0.4
    glu_fluctuate_03[s1, s2] = 0.3
    glu_fluctuate_03[s2, s1] = 0.3
    glu_fluctuate_03[s2, s2] = 0.4
    glu_fluctuate_03[s2, s3] = 0.3
    glu_fluctuate_03[s3, s2] = 0.3
    glu_fluctuate_03[s3, s3] = 0.4
    glu_fluctuate_03[s3, s4] = 0.3
    glu_fluctuate_03[s4, s3] = 0.3
    glu_fluctuate_03[s4, s4] = 0.7

assert np.isclose(glu_fluctuate_03.sum(axis=1), 1).all()

In [33]:
# Glucose fluctuation: probability 0.1 per direction for the non-diabetic
# block (top-left) and 0.3 for the diabetic block (bottom-right).
glu_fluctuate = np.zeros((2 * nS, 2 * nS))
glu_fluctuate[:nS, :nS] = glu_fluctuate_01
glu_fluctuate[nS:, nS:] = glu_fluctuate_03
assert np.isclose(glu_fluctuate.sum(axis=1), 1).all()

## Assemble Transition Matrix (A,S,S)

In [34]:
# One (2*nS x 2*nS) matrix per action, named transition_<abx><vent><vaso>
# (e.g. transition_010 = ventilation on only, transition_001 = vasopressors on only).
# The stacking cell below orders them by that suffix read as a binary number.
# NOTE(review): presumably this matches sepsisSimDiabetes.Action's action index
# (bits abx, vent, vaso) — confirm.
#
# Each line computes X.T @ Y.T == (Y @ X).T, where X is the product of treatment
# matrices and Y the product of fluctuation matrices. The stacking cell transposes
# back, so each stored matrix is the row-stochastic product Y @ X: fluctuations
# first, then treatment effects.
# A variable's fluctuation matrix is omitted whenever a treatment acting on it is
# given: abx on -> no hr/sbp fluctuation; vent on -> no o2; vaso on -> no sbp/glu.
# NOTE(review): applying fluctuations before treatments appears equivalent to the
# simulator's treat-then-fluctuate order, because each fluctuation matrix only
# moves variables whose treatment is off in the current state — confirm.
transition_000 = (antibiotics_off @ ventilation_off @ vasopressor_off).T @ (hr_fluctuate @ sbp_fluctuate @ o2_fluctuate @ glu_fluctuate).T

transition_100 = (antibiotics_on @ ventilation_off @ vasopressor_off).T @ (o2_fluctuate @ glu_fluctuate).T
transition_010 = (antibiotics_off @ ventilation_on @ vasopressor_off).T @ (hr_fluctuate @ sbp_fluctuate @ glu_fluctuate).T
transition_001 = (antibiotics_off @ ventilation_off @ vasopressor_on).T @ (hr_fluctuate @ o2_fluctuate).T

transition_110 = (antibiotics_on @ ventilation_on @ vasopressor_off).T @ (glu_fluctuate).T
transition_011 = (antibiotics_off @ ventilation_on @ vasopressor_on).T @ (hr_fluctuate).T
transition_101 = (antibiotics_on @ ventilation_off @ vasopressor_on).T @ (o2_fluctuate).T

transition_111 = (antibiotics_on @ ventilation_on @ vasopressor_on).T

In [35]:
# Stack the per-action matrices into an (nA, 2*nS, 2*nS) tensor indexed
# [action, from_state, to_state]. Each is transposed back to row-stochastic layout,
# in the order of its three-bit suffix read as a binary number.
per_action_transposed = (
    transition_000, transition_001, transition_010, transition_011,
    transition_100, transition_101, transition_110, transition_111,
)
transition_matrix = np.stack([matrix.T for matrix in per_action_transposed])

In [36]:
# Every (action, from_state) row must be a probability distribution.
assert np.allclose(transition_matrix.sum(axis=2), 1)

In [37]:
# Display the stacked per-action transition tensor.
transition_matrix

array([[[0.6561 , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],
        [0.729  , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],
        [0.729  , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],
        ...,
        [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],
        [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],
        [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ]],

       [[0.     , 0.     , 0.243  , ..., 0.     , 0.     , 0.     ],
        [0.     , 0.     , 0.27   , ..., 0.     , 0.     , 0.     ],
        [0.     , 0.     , 0.243  , ..., 0.     , 0.     , 0.     ],
        ...,
        [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],
        [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],
        [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ]],

       [[0.     , 0.2187 , 0.     , ..., 0.     , 0.     , 0.     ],
        [0.     , 0.2187 , 0.     , ..., 0.     , 0.     , 0.     ],
    

In [38]:
# Shape is (nA, 2*nS, 2*nS): actions x (non-diabetic + diabetic states) x same.
transition_matrix.shape

(8, 1440, 1440)

In [39]:
# Absorbing-state formulation: append two absorbing states, index -2 (receives
# death states) and index -1 (receives the discharge state). Every terminal
# state, in both diabetes halves, moves to its sink w.p. 1 under every action,
# and each sink loops on itself.
transition_matrix_absorbing = np.zeros((nA, nS*2+2, nS*2+2))
transition_matrix_absorbing[:, :nS*2, :nS*2] = transition_matrix
for sink in (-2, -1):
    transition_matrix_absorbing[:, sink, sink] = 1
for terminal_reward, sink in ((-1, -2), (1, -1)):
    # Row indices s and nS+s for every state s with this terminal reward.
    rows = np.flatnonzero(np.tile(reward_per_state == terminal_reward, 2))
    transition_matrix_absorbing[:, rows, :] = 0
    transition_matrix_absorbing[:, rows, sink] = 1

## Initial State Distribution (S,)

In [40]:
# Initial-state distribution over the 2*nS non-absorbing states; filled in below.
prior_initial_state = np.zeros(nS*2)

In [41]:
# Independent per-variable priors for the initial state, indexed by variable level.
diab_prior = [0.8, 0.2]  # [non-diabetic, diabetic]
hr_prior = [0.25, 0.5, 0.25]
sbp_prior = [0.25, 0.5, 0.25]
o2_prior = [0.2, 0.8]
# The glucose prior depends on diabetes status.
glu_prior = [
    [0.05, 0.15, 0.6, 0.15, 0.05], # non-diabetic
    [0.01, 0.05, 0.15, 0.6, 0.19], # diabetic
]
# Treatments start off. NOTE(review): these names are not read by the next cell,
# which hard-codes the zeros directly.
abx, vent, vaso = (0,0,0)

In [42]:
# Only states with all treatments off (abx = vaso = vent = 0) get mass. Each such
# state gets the product of the independent per-variable priors, with the
# glucose prior conditioned on diabetes status.
free_keys = [key for key in state_variables if key not in ['abx', 'vaso', 'vent']]
for values in itertools.product(*(state_variable_values[key] for key in free_keys)):
    hr, sbp, o2, glu, diab = values
    full_idx = State(state_categs=[hr, sbp, o2, glu, 0, 0, 0], diabetic_idx=diab).get_state_idx('full')
    prior_initial_state[full_idx] = (
        diab_prior[diab] * hr_prior[hr] * sbp_prior[sbp] * o2_prior[o2] * glu_prior[diab][glu]
    )

In [43]:
# One entry per (diabetes status, observable state) pair: 2*nS.
prior_initial_state.shape

(1440,)

In [44]:
# Inspect the raw (not yet renormalized) initial-state distribution.
prior_initial_state

array([0.0005, 0.    , 0.    , ..., 0.    , 0.    , 0.    ])

In [45]:
# Extend the prior to the two absorbing states. Then drop mass on every terminal
# (death/discharge) state and on both absorbing states, and renormalize.
prior_initial_state_absorbing = np.zeros(nS*2+2)
prior_initial_state_absorbing[:nS*2] = prior_initial_state
is_terminal = reward_per_state != 0
# do not start in an almost-terminal state
prior_initial_state_absorbing[np.concatenate([is_terminal, is_terminal, [True, True]])] = 0
prior_initial_state_absorbing = prior_initial_state_absorbing / prior_initial_state_absorbing.sum()  # renormalize

In [46]:
# 2*nS regular states plus the two absorbing states.
prior_initial_state_absorbing.shape

(1442,)

In [47]:
# Number of states with nonzero initial probability.
(prior_initial_state_absorbing > 0).sum()

74

In [48]:
# Observable index(es) of states with reward +1 (discharge).
np.where(reward_per_state == 1)

(array([376]),)

In [49]:
# Persist the initial-state distribution for the absorbing-state MDP.
joblib.dump(prior_initial_state_absorbing, '../data/prior_initial_state_absorbing.joblib')

['../data/prior_initial_state_absorbing.joblib']

## Modified Initial State Distribution (S,)

In [50]:
# Modified initial-state distribution over the 2*nS non-absorbing states; filled in below.
modified_prior_initial_state = np.zeros(nS*2)

In [51]:
# Per-variable priors for the modified distribution. The physiological priors
# repeat the values used for the unmodified distribution; in addition, each
# treatment flag starts on or off with equal probability.
diab_prior = [0.8, 0.2]  # [non-diabetic, diabetic]
hr_prior = [0.25, 0.5, 0.25]
sbp_prior = [0.25, 0.5, 0.25]
o2_prior = [0.2, 0.8]
# The glucose prior depends on diabetes status.
glu_prior = [
    [0.05, 0.15, 0.6, 0.15, 0.05], # non-diabetic
    [0.01, 0.05, 0.15, 0.6, 0.19], # diabetic
]
abx_prior, vent_prior, vaso_prior = [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]

In [52]:
# Every state gets the product of independent per-variable priors, including the
# treatment flags; the glucose prior is conditioned on diabetes status.
for values in itertools.product(*(state_variable_values[key] for key in state_variables)):
    hr, sbp, o2, glu, abx, vaso, vent, diab = values
    full_idx = State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=diab).get_state_idx('full')
    modified_prior_initial_state[full_idx] = (
        diab_prior[diab] * hr_prior[hr] * sbp_prior[sbp] * o2_prior[o2] * glu_prior[diab][glu]
        * abx_prior[abx] * vent_prior[vent] * vaso_prior[vaso]
    )

In [53]:
# One entry per (diabetes status, observable state) pair: 2*nS.
modified_prior_initial_state.shape

(1440,)

In [54]:
# Inspect the raw (not yet renormalized) modified initial-state distribution.
modified_prior_initial_state

array([6.250e-05, 6.250e-05, 6.250e-05, ..., 2.375e-04, 2.375e-04,
       2.375e-04])

In [55]:
# Extend the modified prior to the two absorbing states. Then drop mass on every
# terminal (death/discharge) state and on both absorbing states, and renormalize.
modified_prior_initial_state_absorbing = np.zeros(nS*2+2)
modified_prior_initial_state_absorbing[:nS*2] = modified_prior_initial_state
is_terminal = reward_per_state != 0
# do not start in an almost-terminal state
modified_prior_initial_state_absorbing[np.concatenate([is_terminal, is_terminal, [True, True]])] = 0
modified_prior_initial_state_absorbing = modified_prior_initial_state_absorbing / modified_prior_initial_state_absorbing.sum()  # renormalize

In [56]:
# 2*nS regular states plus the two absorbing states.
modified_prior_initial_state_absorbing.shape

(1442,)

In [57]:
# Number of states with nonzero initial probability under the modified prior.
(modified_prior_initial_state_absorbing > 0).sum()

606

In [58]:
# Observable index(es) of states with reward +1 (discharge).
np.where(reward_per_state == 1)

(array([376]),)

In [59]:
# Persist the modified initial-state distribution for the absorbing-state MDP.
joblib.dump(modified_prior_initial_state_absorbing, '../data/modified_prior_initial_state_absorbing.joblib')

['../data/modified_prior_initial_state_absorbing.joblib']

# Save

In [60]:
# Bundle the matrices needed to rebuild the MDP for a single joblib dump.
# The modified initial-state prior is saved separately above.
MDP_parameters = dict(
    transition_matrix=transition_matrix,
    transition_matrix_absorbing=transition_matrix_absorbing,
    reward_per_state=reward_per_state,
    reward_matrix_ASS=reward_matrix_ASS,
    reward_matrix_absorbing_SA=reward_matrix_absorbing_SA,
    reward_matrix_absorbing_ASS=reward_matrix_absorbing_ASS,
    prior_initial_state=prior_initial_state,
    prior_initial_state_absorbing=prior_initial_state_absorbing,
)

In [61]:
# Persist the bundled MDP parameters.
joblib.dump(MDP_parameters, '../data/MDP_parameters.joblib')

['../data/MDP_parameters.joblib']