In [None]:
import json
import itertools
import os

import pandas as pd
from stable_baselines3 import DQN, PPO

from gym_sepsis.envs.sepsis_env import SepsisEnv
from environments.sepsis_env_wrapper import SepsisEnvWrapper
from policies.sb3_policy import SB3Policy
from utils.offline_dataset import OfflineRLDataset
from models.large_a_fnn_nuisance_model import LargeAFeedForwardNuisanceModel
from models.fnn_critic import FeedForwardCritic
from learners.robust_fqi_learner import RobustFQILearner
from learners.iterative_sieve_critic import IterativeSieveLearner

In [None]:
# Load the experiment configuration; the cells below read model/dataset paths,
# learner kwargs and the sweep ranges from this dict.
config_path = 'configs/sepsis_config.json'
with open(config_path) as config_file:
    sepsis_config = json.load(config_file)

In [None]:
def single_evaluation_run(config, rep_i, adversarial_lambda, device=None):
    """Train the nuisance model for one (lambda, repetition) pair and evaluate it.

    The run proceeds in three training stages that share a single
    ``LargeAFeedForwardNuisanceModel``: q/beta (``RobustFQILearner``), then eta
    and then w (both ``IterativeSieveLearner``). The trained model is saved
    under ``config['base_model_path']`` and its policy value is estimated on
    the test dataset with five estimators.

    Args:
        config: experiment configuration dict. Keys read here: 'gamma',
            'default_batch_size', 'dqn_model_path', 'pi_e_name',
            'model_config', 'critic_config', 'train_dataset_path',
            'test_dataset_path', 'q_learner_kwargs', 'eta_learner_kwargs',
            'w_learner_kwargs' and 'base_model_path'.
        rep_i: repetition index; used only for logging, the saved model's
            file name and the returned rows.
        adversarial_lambda: lambda value passed to every learner and to the
            'dr' estimators.
        device: optional device; when not None the initial state/action and
            both datasets are moved to it, and it is forwarded to the learners.

    Returns:
        A list of five dicts (one per estimator: 'q', 'w', 'w_norm', 'dr',
        'dr_norm') with keys 'rep_i', 'lambda', 'est_policy_value' and
        'estimator'.
    """
    print(f'doing evaluation for lambda={adversarial_lambda}, rep={rep_i}')

    env = SepsisEnvWrapper(base_env=SepsisEnv(), s_init_idx=0)

    s_dim, num_a = env.get_s_dim(), env.get_num_a()
    gamma = config['gamma']
    default_batch_size = config['default_batch_size']

    # evaluation policy: a pre-trained DQN loaded from disk
    pi_e = SB3Policy(env, model=DQN.load(config['dqn_model_path']))
    pi_e_name = config['pi_e_name']

    s_init, a_init = env.get_s_a_init(pi_e)
    if device is not None:
        s_init, a_init = s_init.to(device), a_init.to(device)

    model = LargeAFeedForwardNuisanceModel(
        s_dim=s_dim, num_a=num_a, gamma=gamma, config=config['model_config'],
    )
    critic_kwargs = dict(
        s_dim=s_dim, num_a=num_a, config=config['critic_config'],
    )

    train_dataset = OfflineRLDataset.load_dataset(config['train_dataset_path'])
    test_dataset = OfflineRLDataset.load_dataset(config['test_dataset_path'])
    if device is not None:
        # NOTE(review): the return value of .to() is ignored, so this assumes
        # OfflineRLDataset.to() moves the data in place -- confirm.
        for dataset in (train_dataset, test_dataset):
            dataset.to(device)

    # keyword arguments shared by every learner constructor / train() call
    learner_init_kwargs = dict(
        nuisance_model=model, gamma=gamma, use_dual_cvar=True,
        adversarial_lambda=adversarial_lambda,
    )
    common_train_kwargs = dict(
        dataset=train_dataset, pi_e_name=pi_e_name, verbose=False,
        device=device,
    )

    def fit_sieve_stage(train_eta, train_w, kwargs_key):
        # One IterativeSieveLearner stage; the eta and w stages differ only in
        # which component is trained and which config entry supplies kwargs.
        sieve_learner = IterativeSieveLearner(
            **learner_init_kwargs,
            train_q_beta=False, train_eta=train_eta, train_w=train_w,
            debug_beta=False,
        )
        stage_kwargs = config[kwargs_key]
        sieve_learner.train(
            **common_train_kwargs,
            init_basis_func=env.bias_basis_func, num_init_basis=1,
            critic_class=FeedForwardCritic, s_init=s_init,
            critic_kwargs=critic_kwargs, **stage_kwargs,
        )

    # stage 1: q/beta
    print('  -- training q')
    q_learner = RobustFQILearner(**learner_init_kwargs)
    q_stage_kwargs = config['q_learner_kwargs']
    q_learner.train(**common_train_kwargs, **q_stage_kwargs)
    model.freeze_embeds()

    # stage 2: eta
    print('  -- training eta')
    fit_sieve_stage(train_eta=True, train_w=False,
                    kwargs_key='eta_learner_kwargs')

    # stage 3: w
    print('  -- training w')
    fit_sieve_stage(train_eta=False, train_w=True,
                    kwargs_key='w_learner_kwargs')

    # persist the trained model, one file per (lambda, rep) combination
    model_dir = config['base_model_path']
    model_file = f'sepsis_model_lambda={adversarial_lambda}_rep={rep_i}'
    model.save_model(os.path.join(model_dir, model_file))

    ## evaluate model using 3 policy value estimators
    ## (the w and dr estimators are each also run with normalize=True)

    dl_test = test_dataset.get_batch_loader(batch_size=default_batch_size)
    dr_kwargs = dict(
        s_init=s_init, a_init=a_init, pi_e_name=pi_e_name, dl=dl_test,
        adversarial_lambda=adversarial_lambda, gamma=gamma, dual_cvar=True,
    )

    # dict literal entries are evaluated top to bottom, which keeps the
    # estimators running in the order q, w, w_norm, dr, dr_norm
    pv_results = {
        'q': model.estimate_policy_val_q(
            s_init=s_init, a_init=a_init, gamma=gamma,
        ),
        'w': model.estimate_policy_val_w(dl=dl_test, pi_e_name=pi_e_name),
        'w_norm': model.estimate_policy_val_w(
            dl=dl_test, pi_e_name=pi_e_name, normalize=True,
        ),
        'dr': model.estimate_policy_val_dr(**dr_kwargs),
        'dr_norm': model.estimate_policy_val_dr(**dr_kwargs, normalize=True),
    }

    return [
        {
            'rep_i': rep_i,
            'lambda': adversarial_lambda,
            'est_policy_value': estimate,
            'estimator': estimator_name,
        }
        for estimator_name, estimate in pv_results.items()
    ]


In [None]:
# Sweep every (repetition, lambda) combination -- repetitions in the outer
# loop, lambda values in the inner loop -- collecting one result row per
# estimator. The running table is re-displayed and re-written to CSV after
# each run, so partial results are available if the sweep is interrupted.
restart_range = sepsis_config['num_restart_range']
lambda_range = sepsis_config['adversarial_lambda_values']
results_list = []

for rep_i in restart_range:
    for adversarial_lambda in lambda_range:
        run_rows = single_evaluation_run(
            config=sepsis_config,
            rep_i=rep_i,
            adversarial_lambda=adversarial_lambda,
            device=sepsis_config['device'],
        )
        results_list += run_rows

        results_df = pd.DataFrame(results_list)
        display(results_df)
        results_df.to_csv('sepsis_results.csv')

In [None]:
# Summarise the results: for each quantile level, display a table (indexed by
# lambda, one column per estimator) of that quantile of the estimated policy
# value, taken across repetitions. Note that this cell only displays tables;
# it does not draw a figure.

# Keep only estimates strictly inside (0, 1); everything else is dropped before
# aggregating (presumably treated as an implausible estimate -- confirm).
# The mask does not depend on the quantile level, so it is computed once
# instead of on every loop iteration.
keep_rows = (results_df['est_policy_value'] > 0) \
            & (results_df['est_policy_value'] < 1)
kept_results_df = results_df[keep_rows]

# Estimators shown, in display order ('w_norm' and 'dr_norm' are left out).
shown_estimators = ['q', 'w', 'dr']

for q in (0.2, 0.5, 0.8):

    plot_df = (
        kept_results_df
        .groupby(['lambda', 'estimator'])['est_policy_value']
        .quantile(q)
        .unstack('estimator')
        # reindex rather than .loc: if every row of one estimator was filtered
        # out above, this gives a NaN column instead of raising KeyError
        .reindex(columns=shown_estimators)
    )

    print(f'{q} quantile')
    display(plot_df)
    print('')