In [None]:
import json
import itertools
import os

import pandas as pd
from stable_baselines3 import DQN, PPO

from gym_sepsis.envs.sepsis_env import SepsisEnv
from environments.sepsis_env_wrapper import SepsisEnvWrapper
from policies.sb3_policy import SB3Policy
from utils.offline_dataset import OfflineRLDataset
from models.fnn_nuisance_model import FeedForwardNuisanceModel
from models.large_a_fnn_nuisance_model import LargeAFeedForwardNuisanceModel
from models.fnn_critic import FeedForwardCritic
from learners.robust_fqi_learner import RobustFQILearner
from learners.iterative_sieve_critic import IterativeSieveLearner

In [None]:
# Load the experiment configuration. The cells below read this dict for
# hyperparameters (e.g. 'gamma') and for model / dataset file paths.
# The file is decoded explicitly as UTF-8 so that parsing does not depend on
# the platform's locale default encoding (which differs across OSes).
with open('configs/sepsis_config.json', encoding='utf-8') as f:
    sepsis_config = json.load(f)

# Train a PPO policy over a large number of timesteps to serve as the behavior policy

In [None]:
# Create the underlying sepsis simulator and wrap it in the project wrapper.
# NOTE(review): s_init_idx=0 presumably selects the initial state; confirm
# against SepsisEnvWrapper.
base_env = SepsisEnv()
env = SepsisEnvWrapper(
    base_env=base_env,
    s_init_idx=0,
)

# Query the wrapper for the action count and the state dimensionality.
num_a, state_dim = env.get_num_a(), env.get_s_dim()

print(f'Num actions: {num_a}')
print(f'State dimension: {state_dim}')


In [None]:
# Fit the PPO model on the wrapped environment, then persist it to disk.

# The discount factor comes from the shared config entry; every other
# constructor argument is taken from the PPO-specific kwargs dict.
ppo_kwargs = sepsis_config['ppo_model_kwargs']
ppo_model = PPO('MlpPolicy', env, gamma=sepsis_config['gamma'], **ppo_kwargs)

# The training budget is expressed in environment steps:
# (configured number of updates) x (n_steps passed to PPO).
ppo_total_timesteps = sepsis_config['ppo_num_updates'] * ppo_kwargs['n_steps']
ppo_model.learn(
    total_timesteps=ppo_total_timesteps,
    progress_bar=True,
)

# Write the trained model to the path named in the config.
ppo_model.save(sepsis_config['ppo_model_path'])

# Train a DQN policy over a relatively small number of timesteps to serve as the evaluation policy

In [None]:
# Fit the DQN model on the same wrapped environment, then persist it to disk.

# Same construction pattern as the PPO model: shared gamma from the config
# plus the DQN-specific kwargs dict.
dqn_kwargs = sepsis_config['dqn_model_kwargs']
dqn_model = DQN('MlpPolicy', env, gamma=sepsis_config['gamma'], **dqn_kwargs)

# Here the step budget is read directly from the config.
dqn_model.learn(
    total_timesteps=sepsis_config['dqn_total_timesteps'],
    progress_bar=True,
)

# Write the trained model to the path named in the config.
dqn_model.save(sepsis_config['dqn_model_path'])

# Build Offline Dataset using Behavioral (PPO) Policy

In [None]:
# Build the train and test offline datasets and save them to disk.

# Both models are loaded back from their saved files rather than taken from
# the in-memory `ppo_model` / `dqn_model` objects of the training cells.
pi_b = SB3Policy(env, model=PPO.load(sepsis_config['ppo_model_path']))
pi_e = SB3Policy(env, model=DQN.load(sepsis_config['dqn_model_path']))
pi_e_name = sepsis_config['pi_e_name']

# Sampling settings from the config, shared by both splits.
burn_in = sepsis_config['dataset_burn_in']
num_sample = sepsis_config['dataset_num_samples']
thin = sepsis_config['dataset_thin']
sampling_kwargs = dict(
    env=env,
    pi=pi_b,
    burn_in=burn_in,
    num_sample=num_sample,
    thin=thin,
)

# Each split is a fresh dataset sampled under pi_b with identical settings;
# the test split is sampled after the train split, on the same `env` object.
dataset = OfflineRLDataset()
dataset.sample_new_trajectory(**sampling_kwargs)

test_dataset = OfflineRLDataset()
test_dataset.sample_new_trajectory(**sampling_kwargs)

# Apply the evaluation policy to both splits under the configured name.
# NOTE(review): presumably this records pi_e's outputs alongside the sampled
# data; confirm in OfflineRLDataset.apply_eval_policy.
for split in (dataset, test_dataset):
    split.apply_eval_policy(pi_e_name, pi_e)

# Persist each split to its configured path (train first, then test).
for split, path_key in (
    (dataset, "train_dataset_path"),
    (test_dataset, "test_dataset_path"),
):
    split.save_dataset(sepsis_config[path_key])
