In [1]:
import sys
sys.path.append('..')

from pathlib import Path
import pickle
import scipy.stats as stats
from tqdm import tqdm

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import itertools
from collections import defaultdict

from paths import *
from constants import EPISODE_LEN, RANDOM_SEED, ALPHAS

from src.esrl.esrl import compute_SOMVPRSL, sample_Vs

%matplotlib inline

In [2]:
# Read the raw train/test splits (tab-separated).
raw_data = {
    'train': pd.read_csv(SEPSIS/'train.tsv', sep='\t'),
    'test': pd.read_csv(SEPSIS/'test.tsv', sep='\t')
}

# One normalisation constant shared by both splits so their rewards stay on the same scale.
# NOTE(review): this also looks at the test split; computing it from 'train' alone would
# avoid leaking test information into preprocessing.
max_reward = max(raw_data['train']['reward'].max(), raw_data['test']['reward'].max())
assert max_reward != 0, "max_reward is 0; reward normalisation would divide by zero."


def preprocess_split(raw: pd.DataFrame, max_reward: float, episode_len: int = EPISODE_LEN) -> pd.DataFrame:
    """Return a cleaned copy of one data split with normalised rewards and a `next_state` column.

    The input frame is not modified, so re-running this cell is safe.

    Args:
        raw: split as read from disk; must contain 'icustayid', 'bloc', 'state',
            'reward' and 'died_in_hosp' columns.
        max_reward: normalisation constant (largest reward over all splits).
        episode_len: value of 'bloc' that marks the final step of an episode.

    Returns:
        New frame sorted by (icustayid, bloc) in which 'bloc' is renamed to 'timestep',
        'icustayid' and 'died_in_hosp' are dropped, and an int 'next_state' column is
        appended (-1 at the last step of each ICU stay and wherever bloc == episode_len).
    """
    # sort_values returns a new frame; the result must be kept so the shift below
    # operates on time-ordered rows.
    df = raw.sort_values(['icustayid', 'bloc'], ascending=True).reset_index(drop=True)

    # Map r -> (max_reward - r) / max_reward: larger rewards map to smaller values
    # (max_reward -> 0, 0 -> 1). The result lies in [0, 1] only if every reward is in
    # [0, max_reward] -- TODO confirm for this dataset.
    df['reward'] = (max_reward - df['reward']) / max_reward

    # Shift within each ICU stay so a stay's last step never picks up the first state of
    # the following stay. Steps at episode_len are terminal as well; -1 is the sentinel.
    next_state = df.groupby('icustayid')['state'].shift(-1)
    next_state = next_state.mask(df['bloc'] == episode_len, -1)
    df['next_state'] = next_state.fillna(-1).astype(int)

    return df.rename(columns={'bloc': 'timestep'}).drop(columns=['icustayid', 'died_in_hosp'])


data = {name: preprocess_split(split, max_reward) for name, split in raw_data.items()}

# Define actions and states
ACTIONS = list(range(25))
STATES = list(range(100))
TIMESTEPS = list(range(1, 11))

# Every declared action/state/timestep must occur in at least one split.
for column, space in [('action', ACTIONS), ('state', STATES), ('timestep', TIMESTEPS)]:
    observed = set(data['train'][column]) | set(data['test'][column])
    assert not set(space) - observed, f"Found unseen {column} in train/test set."

print(data['train'].shape, data['test'].shape)

(95040, 5) (23770, 5)


In [3]:
# ESPRL hyper-parameters

# Prior hyper-parameters used when sampling MDPs; the same dict is passed to both
# policy training and evaluation.
# NOTE(review): these are hyper-parameters, not probabilities; the names suggest a
# conjugate (Normal-Gamma-like) prior -- confirm against src.esrl.
PRIORS = dict(m0=0, lamb0=1e3, alpha0=5.01, gamma0=1)

# Number of MDP samples drawn when training the policy; evaluation draws 4x as many.
N_SAMPLES = 500

In [4]:
# All artefacts of this cell live here. Create the directory up front so writing results
# cannot fail only after an expensive training step has finished.
RESULTS_DIR = SEPSIS/'esprl-results'
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

for alpha in tqdm(ALPHAS):
    # Encode alpha without dots (e.g. 0.1 -> "0_1") so it is not mistaken for a file
    # extension. Only alpha is rewritten -- never the directory or the ".tsv" suffix.
    alpha_tag = str(alpha).replace('.', '_')

    # Train policy
    mv_smu = compute_SOMVPRSL(data['train'], A_space=ACTIONS, S_space=STATES, tau=max(TIMESTEPS),
                              priors=PRIORS, K=N_SAMPLES, alpha=alpha)
    # Context manager so the pickle is flushed and the handle closed before moving on.
    with (RESULTS_DIR/f'mv_smu_{alpha_tag}').open('wb') as fh:
        pickle.dump(mv_smu, fh)

    # Evaluate the trained policy on every split
    for df_name, df in data.items():
        v_sam, st_state_sam = sample_Vs(df, A_space=ACTIONS, S_space=STATES, tau=max(TIMESTEPS),
                                        priors=PRIORS, n_samples=N_SAMPLES*4, mu_st=mv_smu,
                                        random_starting_state=True, seed=RANDOM_SEED)
        value_estimate = pd.DataFrame(st_state_sam.astype(int), columns=['state'])
        value_estimate[f'alpha_{alpha}'] = v_sam
        # NOTE(review): earlier runs of this cell wrote these files with a "_tsv" suffix
        # instead of ".tsv"; update any downstream reader that still expects that name.
        value_estimate.to_csv(RESULTS_DIR/f'{df_name}_value_estimate_{alpha_tag}.tsv',
                              sep='\t', index=None)

100%|████████████████████████████████████████████████████████████████████████████████| 3/3 [2:02:40<00:00, 2453.44s/it]
