In [None]:
import time
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from os.path import join
from src.components.eval import eval_env
from src.components.train import train_env
from src.components.tune import tune_hyper_param
from src.simglucose.env import env_creator
from src.simglucose.env import register_simglucose_env
from src.files.utills import pickle_obj
from src.simglucose.rewards import no_negativityV2, tan_reward

# Start (or reuse) the local Ray runtime. `ignore_reinit_error=True` keeps this
# cell re-runnable: without it, executing the cell a second time on a live
# kernel raises because Ray is already initialised.
# `log_to_driver=False` keeps worker logs out of the notebook output.
ray.init(log_to_driver=False, ignore_reinit_error=True)

# Define Configs

In [None]:


# Rollout sizing and algorithm / environment identifiers used by the cells below.
total_workers = 10
num_envs_per_worker = 1
algo = "PPO"
env_name = "Simglucose-v0"

# Register the environment under `env_name` (presumably so RLlib can resolve it
# by name -- see src.simglucose.env for the details).
register_simglucose_env(env_name)

# Keyword arguments forwarded to the environment through `env_config`.
env_configs = {"reward_fun": tan_reward, "patient_type": "adult"}

# PPO configuration: small fully-connected net wrapped in an LSTM, torch backend.
config = (
    PPOConfig()
    .environment(env_name, env_config=env_configs)
    .training(
        gamma=0.995,
        num_sgd_iter=3,
        sgd_minibatch_size=200,
        clip_param=0.1,
        lr=1e-4,
        train_batch_size=1000,
        entropy_coeff=1e-6,
    )
    .resources(num_gpus=1, num_cpus_per_worker=1)
    .rollouts(
        num_rollout_workers=total_workers,
        num_envs_per_worker=num_envs_per_worker,
    )
    .framework("torch")
    .training(
        model=dict(
            fcnet_hiddens=[8, 8],
            vf_share_layers=False,
            use_lstm=True,
            lstm_cell_size=8,
            max_seq_len=100,
            lstm_use_prev_action=True,
        )
    )
    .evaluation(evaluation_num_workers=1)
)

# Directory handed to the tune / train helpers as `log_dir`.
log_dir = "tmp/pipeline_logs"

# Tune Hyper-parameters

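The hyper-parameters to be optimized can be defined in the config using the Ray Tune API. For example, *use_lstm* could be made a hyper-parameter with values *[True, False]* by setting it to `tune.grid_search([True, False])` in the model config (the config above currently fixes it to *True*). *tune_hyper_param* searches for the optimal parameter values and returns the tuning results, from which the best config is selected by mean episode reward.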

In [None]:
# Run the hyper-parameter search over `config` with the project's tuning helper.
tune_results = tune_hyper_param(
    algo=algo,
    config=config,
    log_dir=log_dir,
    iterations=1,
    name="simglucose_tuning",
)

# Pick the trial with the highest mean episode reward and keep its config
# for the training step.
best_tune_result = tune_results.get_best_result(
    metric="episode_reward_mean",
    mode="max",
)
best_config = best_tune_result.config

# Train RL Agent

The model is trained here using the best config from the tuning step. The best training checkpoint is then chosen for evaluation.

In [None]:
# Train with the tuned config. The keyword arguments are forwarded to the
# project's `train_env` helper; their exact semantics (iteration cap, reward
# threshold, checkpoint cadence) are defined there.
train_results = train_env(
    algo=algo,
    config=best_config,
    log_dir=log_dir,
    iterations=20000,
    stop_reward_mean=1000,
    name="simglucose_solver",
    checkpoint_frequency=5
)


def _checkpoint_reward(checkpoint_and_metrics):
    """Return the mean episode reward recorded with a (checkpoint, metrics) pair.

    Missing or NaN rewards are ranked lowest so they can never win the `max`.
    """
    reward = checkpoint_and_metrics[1].get("episode_reward_mean")
    # `reward == reward` is False only for NaN.
    if reward is None or reward != reward:
        return float("-inf")
    return reward


best_train_result = train_results.get_best_result(metric="episode_reward_mean", mode="max")
if not best_train_result.best_checkpoints:
    raise RuntimeError("Training produced no checkpoints; cannot select a best checkpoint.")

# `best_checkpoints` is a list of (checkpoint, metrics) pairs that is not
# guaranteed to be ordered by reward, so select the best one explicitly
# instead of taking index 0.
best_checkpoint = max(best_train_result.best_checkpoints, key=_checkpoint_reward)

# NOTE(review): `_local_path` is a private attribute that only older Ray
# checkpoint objects expose; fall back to the public `path` attribute when it
# is absent so the cell also works on newer Ray releases.
_best_checkpoint_obj = best_checkpoint[0]
best_checkpoint_path = (
    getattr(_best_checkpoint_obj, "_local_path", None)
    or getattr(_best_checkpoint_obj, "path", None)
)