In [1]:
import warnings

import ray
from ray.rllib.utils import PolynomialSchedule

from src.components.train import train_env
from src.simglucose.env import register_simglucose_env

# NOTE: this silences *every* Python warning raised in the driver process for
# the rest of the session, including deprecation notices from Ray itself.
warnings.filterwarnings('ignore')

# `ignore_reinit_error=True` keeps this cell idempotent: re-running it in a
# live kernel reuses the already-started local Ray instance instead of raising.
# `log_to_driver=False` keeps worker logs out of the notebook output.
ray.init(log_to_driver=False, ignore_reinit_error=True)

  if (distutils.version.LooseVersion(tf.__version__) <
2023-05-20 18:18:57,529	INFO worker.py:1538 -- Started a local Ray instance.


0,1
Python version:,3.10.6
Ray version:,2.2.0


# Define Configs

In [2]:
from src.simglucose.rewards import tan_reward, uniform_reward_with_risk

# Horizon (in timesteps) over which both schedules decay, and the spacing
# between the points at which each schedule is sampled. Defined once so the
# schedule objects and the sampled ranges cannot drift apart.
schedule_timesteps = 2000000
schedule_step = 2000

# Learning-rate schedule: polynomial decay with power=1 from 1e-3 to 1e-4,
# sampled into [timestep, value] pairs (passed as `lr_schedule` to PPO below).
pl_sch = PolynomialSchedule(schedule_timesteps=schedule_timesteps, initial_p=1e-3, final_p=1e-4,
                            framework="torch", power=1)
lr_schedule = [[t, pl_sch.value(t)] for t in range(0, schedule_timesteps, schedule_step)]

# Entropy-coefficient schedule: polynomial decay with power=3 from 1e-3 to
# 1e-6, sampled the same way (passed as `entropy_coeff_schedule` to PPO below).
entropy_pl_sch = PolynomialSchedule(schedule_timesteps=schedule_timesteps, initial_p=1e-3, final_p=1e-6,
                                    framework="torch", power=3)
ent_schedule = [[t, entropy_pl_sch.value(t)] for t in range(0, schedule_timesteps, schedule_step)]

# Rollout parallelism used by the algorithm config.
total_workers = 10
num_envs_per_worker = 1

# Register the environment under `env_name` and define the config dict that is
# forwarded to it via `.environment(env_name, env_config=env_configs)`.
env_name = "Simglucose-v0"
register_simglucose_env(env_name)
env_configs = dict(reward_fun=uniform_reward_with_risk, patient_name='adult#004')

In [3]:
from ray.rllib.algorithms.ppo import PPOConfig

# Model settings handed to RLlib through `.training(model=...)`.
lstm_model = dict(
    fcnet_hiddens=[32, 32, 32],
    vf_share_layers=False,
    use_lstm=True,
    lstm_cell_size=32,
    max_seq_len=100,
)

# Algorithm identifier passed to `train_env` alongside the config object.
algo = "PPO"

# PPO configuration, built step by step with the fluent config API.
# The order of the chained calls is kept exactly as originally written.
config = (
    PPOConfig()
    .environment(env_name, env_config=env_configs)
    .training(
        gamma=0.996,
        num_sgd_iter=3,
        sgd_minibatch_size=400,
        clip_param=0.1,
        lr=1e-3,
        train_batch_size=4000,
        entropy_coeff=1e-3,
        entropy_coeff_schedule=ent_schedule,
        lr_schedule=lr_schedule,
        vf_clip_param=10000,
    )
    .resources(num_gpus=1, num_cpus_per_worker=1)
    .rollouts(
        num_rollout_workers=total_workers,
        num_envs_per_worker=num_envs_per_worker,
        enable_connectors=True,
        observation_filter='MeanStdFilter',
        batch_mode='complete_episodes',
    )
    .framework("torch")
    .training(model=lstm_model)
    .evaluation(evaluation_num_workers=1)
)


In [4]:
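# NOTE: disabled alternative (Dreamer) kept for reference only; if uncommented, it replaces the PPO `config` defined above.
# from ray.rllib.algorithms.dreamer import DreamerConfig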
#
# config = (DreamerConfig().
#           training(gamma=0.996)
#
#           .resources(num_gpus=1, num_cpus_per_worker=1)
#           .rollouts(num_rollout_workers=total_workers, num_envs_per_worker=num_envs_per_worker)
#           .framework("torch")
#           .training(
#     model=lstm_model)
#           .evaluation(evaluation_num_workers=1)
#           )

In [5]:
# Relative directory passed to `train_env(..., log_dir=log_dir)` in the training cell.
log_dir = "tmp/pipeline_logs"

# Train RL Agent

The model is trained here using the best configuration from the tuning step. The best training checkpoint is then selected for evaluation.

In [None]:
# Metric used both to pick the best trial and to rank that trial's checkpoints.
selection_metric = "episode_reward_mean"

# Launch training; `train_results` is queried below with `get_best_result`.
train_results = train_env(
    algo=algo,
    config=config,
    log_dir=log_dir,
    iterations=20000,
    stop_reward_mean=1000,
    name="simglucose_solver",
    checkpoint_frequency=5
)

best_result = train_results.get_best_result(metric=selection_metric, mode="max")
if not best_result.best_checkpoints:
    raise RuntimeError("Training finished without any retained checkpoint to evaluate.")


def _checkpoint_score(entry):
    """Return the selection metric of a (checkpoint, metrics) pair; missing/NaN ranks last."""
    score = entry[1].get(selection_metric)
    if score is None or score != score:  # `score != score` is True only for NaN
        return float("-inf")
    return score


# Pick the checkpoint with the highest metric explicitly instead of relying on
# the ordering of `best_checkpoints` (index 0 is not guaranteed to be the best).
# `best_checkpoint` remains a (checkpoint, metrics) tuple, as before.
best_checkpoint = max(best_result.best_checkpoints, key=_checkpoint_score)

# NOTE(review): `_local_path` is a private attribute of the checkpoint object and
# may change between Ray versions -- confirm against the installed Ray release.
best_checkpoint_path = best_checkpoint[0]._local_path



Trial name,agent_timesteps_total,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
PPO_Simglucose-v0_06a8c_00000,4083,"{'num_env_steps_sampled': 4083, 'num_env_steps_trained': 4083, 'num_agent_steps_sampled': 4083, 'num_agent_steps_trained': 4083}",{},2023-05-20_18-19-26,False,16.9419,{},2.12176,1.19943,0.28526,241,241,a15cf10c789845628eb6cabd75b42f14,hamza-Legion-5-15ACH6H,"{'learner': {'default_policy': {'custom_metrics': {}, 'learner_stats': {'cur_kl_coeff': 0.20000000000000007, 'cur_lr': 0.0010000000000000002, 'total_loss': 0.13348971574256818, 'policy_loss': 0.006307779659982771, 'vf_loss': 0.12796844815214475, 'vf_explained_var': 0.3390824536482493, 'kl': 0.002509225061900736, 'entropy': 1.28835479815801, 'entropy_coeff': 0.0010000000000000002}, 'model': {}, 'num_grad_updates_lifetime': 15.5, 'diff_num_grad_updates_vs_sampler_policy': 14.5}}, 'num_env_steps_sampled': 4083, 'num_env_steps_trained': 4083, 'num_agent_steps_sampled': 4083, 'num_agent_steps_trained': 4083}",1,192.168.0.185,4083,4083,4083,4083,4083,4083,0,10,0,0,4083,"{'cpu_util_percent': 64.35882352941178, 'ram_util_percent': 88.56470588235294}",148076,{},{},{},"{'mean_raw_obs_processing_ms': 1.765510754995826, 'mean_inference_ms': 1.5396314379706169, 'mean_action_processing_ms': 0.193541997606311, 'mean_env_wait_ms': 20.45460580531374, 'mean_env_render_ms': 0.0}","{'episode_reward_max': 2.121761197872997, 'episode_reward_min': 0.28526041490423226, 'episode_reward_mean': 1.1994283189441668, 'episode_len_mean': 16.941908713692946, 'episode_media': {}, 'episodes_this_iter': 241, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [1.1741908309575861, 0.9552588631412302, 1.17318428527117, 0.5706309526268465, 2.029138092285152, 1.0288239223642648, 0.28526041490423226, 1.3350373213846123, 1.1388257390405185, 1.9185263480208132, 1.1294051198752733, 1.1464528109216399, 1.1261667244372564, 1.2624735641780171, 0.7760345646965001, 0.9955120638030092, 1.131488924721655, 1.2173219127776345, 1.1948084091608633, 
1.0750744934495664, 1.4418459340256358, 1.3180885325583285, 1.2723730326621179, 0.5231667618841931, 1.186095517490072, 1.144982574899019, 1.1569992954027346, 1.2357566752025675, 0.8593451330771202, 1.1351594392824278, 1.2182389620670593, 1.2052194251378803, 1.123057897580513, 1.2975822448302488, 1.2466449924031322, 1.1896734179158144, 1.091957084519267, 1.293905747934952, 1.334772233997966, 1.7461239381099112, 0.997309261927987, 1.5919556920883111, 1.3237225591054116, 1.2196429714773274, 0.9971827734421893, 1.1568526165787603, 1.4937335715687827, 1.0988662235840592, 1.2466971996453005, 1.0285787060503533, 1.345891417706436, 1.0226161960282323, 1.3036625872955143, 1.427091980884922, 1.188057870521537, 1.577457089725934, 1.4162245619556835, 0.9267233699378592, 1.287576876270868, 1.1019397976935486, 1.2249938699700118, 2.121761197872997, 1.2288324971465339, 1.2246106789887747, 1.2108307006175905, 1.280536783025668, 0.9944633339774639, 1.098469961798034, 1.2004442984314274, 0.9013071504386119, 0.9929687026166253, 0.8802713839140954, 1.2373464024707375, 1.2443027878713147, 1.2197501708399376, 1.1672444770926482, 1.1705599552994266, 0.5183953889473666, 1.2330512516311924, 1.3857389453695594, 1.1588460881090856, 1.196519262072873, 1.3443602195362117, 1.3518623399742162, 1.1567087568872882, 1.1166744343727741, 1.1421651868423566, 1.0687393874483524, 1.4025592342823947, 1.0889075674573647, 1.21096486488215, 1.004961692243685, 1.2768690734366896, 1.3231255075915986, 1.3008718896571343, 1.2087410972524122, 1.2165563045409133, 1.2405724488297136, 1.0433528242008048, 1.0292617677362592, 1.4847936848783443, 1.3586792258915152, 1.0943811584763488, 1.0198124036796494, 1.4040735079632038, 1.2460725346223938, 1.1388490917162046, 1.9033377738758048, 1.1326419756547008, 1.3182319453373903, 1.3686589827367244, 1.0858991165219962, 1.1839678641610811, 1.1778736115601514, 1.3911553888617694, 0.925679621367488, 1.2045514844185183, 0.6950081324109619, 1.2896725424239974, 1.2162238238414151, 
1.1130326258059349, 1.4665039060757599, 1.5107867086832156, 0.9914190607066584, 1.4105063262810147, 1.0949243233340298, 1.066274779127319, 1.2570077860883089, 0.8791408367905108, 1.4506421200603006, 0.8355160048862038, 1.211829065411082, 1.3301487604702225, 1.2183759243120547, 1.1655471940444957, 0.9790180031506471, 1.4712144857644915, 1.2994977582121094, 1.1094771675817514, 1.1408017113097326, 1.007397417617331, 1.3837178186473433, 1.1886381448168841, 0.8670602646036119, 1.0881946458246483, 1.1624679040688308, 1.203546263339217, 1.1317490730637108, 1.0647170562940156, 0.9777582349459334, 1.2412149101236485, 1.0556071786066719, 1.1284630873791652, 1.0099582782713117, 1.1402523758461123, 1.4595385081647716, 1.2926937043763511, 0.9932518986935278, 1.1785745153296667, 0.8826128375438732, 1.5761406012957175, 1.1809551818958022, 1.3997439359902781, 1.2533515010058223, 1.2114319867156267, 1.1740421710332287, 1.0646843857944441, 1.1955753313220858, 1.1582826575792167, 1.4468646467050763, 1.0020223888364117, 1.226666734257881, 1.3482859808492254, 1.0181707969784373, 0.9777159821829783, 1.127547810339693, 1.3525806042329536, 1.4246551810581563, 1.0544734688043185, 1.1739009959301898, 1.2272840581309603, 1.1982718564917556, 1.0619183224793878, 1.2710356488371335, 1.3212603762334123, 0.9619208448828266, 1.0987432148109038, 1.1709930635824481, 1.2817240487504242, 0.4159930950146319, 1.3792170417643284, 0.7688732620720624, 1.1190389243703283, 1.003196906579913, 1.4935124297979805, 1.0774511420067079, 0.9405677168142115, 1.250237385610897, 1.177403768888536, 1.15983800018185, 1.4318819611414493, 1.1452508726032673, 0.9229938643224408, 1.3307590883709455, 1.3923474718622006, 1.517812355967812, 1.513004180842246, 1.1372798698345676, 1.0936747682812586, 1.6588860453858258, 1.0334364959847364, 1.2316399378146445, 1.5822780874411784, 1.5411616245065656, 1.2563011231713932, 1.1617350698805975, 1.0703179849622475, 1.0601812931606525, 1.4154835269740331, 1.1284234920533456, 
1.0381009097869336, 1.1409184152793455, 1.1370062064710866, 1.2054589415594856, 1.0115703800821885, 1.3704363857388995, 1.2310971390485816, 1.5587349762087794, 1.3546492411794318, 1.2028718256395103, 1.2300635570268947, 1.3019332975230336, 1.385120269732615, 1.3059249139238882, 1.1457977637462449, 1.2609684423402805, 1.5038330377977125, 1.124576483714786, 1.7478390522730578, 1.2562322502964334, 1.2251388672910537], 'episode_lengths': [16, 16, 16, 16, 29, 16, 16, 16, 18, 27, 17, 15, 16, 17, 18, 15, 15, 16, 16, 16, 16, 18, 15, 15, 17, 16, 17, 16, 15, 16, 19, 18, 16, 18, 16, 16, 15, 17, 17, 29, 15, 22, 19, 16, 15, 16, 20, 16, 17, 17, 17, 17, 16, 18, 18, 19, 17, 16, 18, 17, 16, 29, 15, 17, 16, 17, 15, 17, 17, 16, 16, 15, 16, 16, 18, 16, 15, 15, 17, 18, 18, 18, 16, 18, 16, 15, 16, 16, 16, 15, 19, 15, 17, 17, 18, 16, 17, 16, 15, 16, 16, 17, 16, 15, 17, 17, 23, 25, 17, 16, 18, 17, 17, 15, 19, 17, 17, 18, 17, 17, 16, 17, 19, 16, 19, 16, 15, 17, 28, 20, 16, 17, 17, 16, 16, 17, 17, 16, 16, 17, 16, 18, 18, 16, 16, 16, 15, 16, 15, 16, 16, 15, 17, 16, 16, 16, 16, 16, 15, 25, 21, 18, 16, 15, 16, 17, 15, 17, 16, 20, 15, 16, 17, 15, 15, 17, 16, 17, 17, 17, 17, 18, 16, 17, 17, 15, 15, 15, 17, 17, 18, 15, 17, 16, 17, 16, 15, 16, 18, 16, 16, 16, 15, 17, 18, 18, 20, 17, 17, 20, 16, 16, 17, 17, 17, 17, 15, 16, 17, 15, 16, 17, 15, 17, 15, 17, 17, 19, 18, 17, 17, 17, 15, 19, 15, 17, 18, 16, 20, 16, 16]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 1.765510754995826, 'mean_inference_ms': 1.5396314379706169, 'mean_action_processing_ms': 0.193541997606311, 'mean_env_wait_ms': 20.45460580531374, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0}",11.4144,11.4144,11.4144,"{'training_iteration_time_ms': 11394.644, 'load_time_ms': 15.98, 'load_throughput': 255506.799, 'learn_time_ms': 306.226, 'learn_throughput': 13333.289, 'synch_weights_time_ms': 3.629}",1684599566,0,4083,1,06a8c_00000,11.8756
