### Training RL Policies using L5Kit Closed-Loop Environment

This notebook describes how to train RL policies for self-driving using our gym-compatible closed-loop environment.

We will be using [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347) algorithm as our reinforcement learning algorithm, as it not only demonstrates remarkable performance but it is also empirically easy to tune.

The PPO implementation in this notebook is based on [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) framework, a popular framework for training RL policies. Note that our environment is also compatible with [RLlib](https://docs.ray.io/en/latest/rllib.html), another popular frameworks for the same.

ref: 
([rllib] Best workflow to train, save, and test agent #9123
)[https://github.com/ray-project/ray/issues/9123]

In [1]:
import os
os.environ["L5KIT_DATA_FOLDER"] = '/DATA/l5kit/l5kit-dataset'
# os.environ["TUNE_RESULT_DIR"] =  '/DATA/l5kit/rllib_tb_logs'

In [2]:
import gym

# from stable_baselines3 import PPO
# from stable_baselines3.common.callbacks import CheckpointCallback
# from stable_baselines3.common.env_util import make_vec_env
# from stable_baselines3.common.utils import get_linear_fn
# from stable_baselines3.common.vec_env import SubprocVecEnv

from l5kit.configs import load_config_data
# from l5kit.environment.feature_extractor import CustomFeatureExtractor
# from l5kit.environment.callbacks import L5KitEvalCallback
from l5kit.environment.envs.l5_env import SimulationConfigGym, GymStepOutput, L5Env

from l5kit.visualization.visualizer.zarr_utils import episode_out_to_visualizer_scene_gym_cle
from l5kit.visualization.visualizer.visualizer import visualize
from bokeh.io import output_notebook, show
from l5kit.environment.gym_metric_set import L2DisplacementYawMetricSet, CLEMetricSet
from prettytable import PrettyTable
import datetime
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import torch.nn as nn
import numpy as np
import gym
from typing import Dict
import numpy as np
import ray
import pytz


## Init ray and env

In [4]:
# Dataset is assumed to be on the folder specified
# in the L5KIT_DATA_FOLDER environment variable
from l5kit.configs import load_config_data

# get environment config
env_config_path = '/DATA/l5kit/configs/gym_config84.yaml'
cfg = load_config_data(env_config_path)


In [5]:
newYorkTz = pytz.timezone("Asia/Ho_Chi_Minh") 
date = datetime.datetime.now(newYorkTz).strftime("%d-%m-%Y_%H-%M-%S")
ray_result_logdir = '/DATA/l5kit/ray_results/' + date

ray.init(num_cpus=64, ignore_reinit_error=True, log_to_driver=False)

2023-01-01 08:30:23,356	INFO worker.py:1538 -- Started a local Ray instance.


0,1
Python version:,3.8.10
Ray version:,2.2.0


## Customize my model
SAC: https://github.com/ray-project/ray/blob/dfb9689701361cfd18f383e0a3edeed6baf81abb/rllib/agents/sac/sac_torch_model.py

In [6]:
class GNCNN(SACTorchModel):
    """
    Simple Convolution agent that calculates the required linear output layer
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # raise ValueError(obs_space.shape)
        self._num_objects = obs_space.shape[2] # num_of_channels of input, size x size x channels
        self._num_actions = num_outputs
        self._feature_dim = model_config["custom_model_config"]['feature_dim']

        # linear_flatten = np.prod(obs_space.shape[:2])*64

        self.network = nn.Sequential(
            nn.Conv2d(self._num_objects, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            nn.GroupNorm(4, 64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            nn.GroupNorm(2, 32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(in_features=1568, out_features=self._feature_dim),
            # layer_init(nn.Conv2d(self._num_objects, 32, 3, padding=1)),
            # nn.ReLU(),
            # layer_init(nn.Conv2d(32, 64, 3, padding=1)),
            # nn.ReLU(),
            # nn.Flatten(),
            # layer_init(nn.Linear(linear_flatten, 1024)),
            # nn.ReLU(),
            # layer_init(nn.Linear(1024, 512)),
            # nn.ReLU(),
        )

        self._actor_head = nn.Sequential(
            # layer_init(nn.Linear(512, 256), std=0.01),
            # nn.ReLU(),
            # layer_init(nn.Linear(256, self._num_actions), std=0.01)
            nn.Linear(self._feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, self._num_actions),
        )

        self._critic_head = nn.Sequential(
            # layer_init(nn.Linear(512, 1), std=0.01)
            nn.Linear(self._feature_dim, 1),
        )

    def forward(self, input_dict, state, seq_lens):
        obs_transformed = input_dict['obs'].permute(0, 3, 1, 2) # 32 x 112 x 112 x 7 [B, size, size, channels]
        network_output = self.network(obs_transformed)
        value = self._critic_head(network_output)
        self._value = value.reshape(-1)
        logits = self._actor_head(network_output)
        return logits, state

    def value_function(self):
        return self._value
    def get_policy_output(self, , input_dict, state, seq_lens):
        return forward()
    def 

SyntaxError: invalid syntax (3054087535.py, line 63)

In [None]:
from ray.rllib.models import ModelCatalog
ModelCatalog.register_custom_model(
        "GN_CNN_torch_model", GNCNN
    )

## Define Training and Evaluation Environments

**Training**: We will be training the PPO policy on episodes of length 32 time-steps. We will have 4 sub-processes (training environments) that will help to parallelize and speeden up episode rollouts. The *SimConfig* dataclass will define the parameters of the episode rollout: like length of episode rollout, whether to use log-replayed agents or simulated agents etc.

**Evaluation**: We will evaluate the performance of the PPO policy on the *entire* scene (~248 time-steps).

## Customize gym env

In [6]:
class L5EnvWrapper(gym.Wrapper):
    def __init__(self, env, raster_size = 112, n_channels = 7):
        super().__init__(env)
        self.env = env
        self.n_channels = n_channels
        self.raster_size = raster_size
        obs_shape = (self.raster_size, self.raster_size, self.n_channels)
        self.observation_space =gym.spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.float32)

    def step(self, action:  np.ndarray) -> GymStepOutput:
        # return GymStepOutput(obs, reward["total"], done, info)
        output =  self.env.step(action)
        onlyImageState = output.obs['image'].reshape(self.raster_size, self.raster_size, self.n_channels)
        return GymStepOutput(onlyImageState, output.reward, output.done, output.info)

    def reset(self) -> Dict[str, np.ndarray]:
        return self.env.reset()['image'].reshape(self.raster_size, self.raster_size, self.n_channels)

In [7]:
from ray import tune
train_eps_length = 32
train_sim_cfg = SimulationConfigGym()
train_sim_cfg.num_simulation_steps = train_eps_length + 1
# Register , how your env should be constructed (always with 5, or you can take values from the `config` EnvContext object):
env_kwargs = {'env_config_path': env_config_path, 'use_kinematic': True, 'sim_cfg': train_sim_cfg}

tune.register_env("L5-CLE-V0", lambda config: L5Env(**env_kwargs))
tune.register_env("L5-CLE-V1", lambda config: L5EnvWrapper(env = L5Env(**env_kwargs), \
                                                           raster_size= cfg['raster_params']['raster_size'][0], \
                                                           n_channels = 7))

## Train

import numpy as np

import ray
from ray import air, tune
from ray.air import session
from ray.air.integrations.wandb import setup_wandb
from ray.air.integrations.wandb import WandbLoggerCallback
os.environ['WANDB_NOTEBOOK_NAME'] = '/DATA/rllib_ppo_policy_training.ipynb'!wandb login 083592c84134c040dcca598c644c348d32540a08
import wandb ## ref

Resume stop tune: https://docs.ray.io/en/latest/tune/tutorials/tune-stopping.html

tune.Tuner analysis: https://docs.ray.io/en/latest/rllib/rllib-training.html#basic-python-api

get best result, load from dir: https://docs.ray.io/en/master/tune/examples/tune_analyze_results.html#trial-level-analysis-working-with-an-individual-result

In [8]:
!wandb login 083592c84134c040dcca598c644c348d32540a08

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [9]:
import numpy as np

import ray
from ray import air, tune
from ray.air import session
from ray.air.integrations.wandb import setup_wandb
from ray.air.integrations.wandb import WandbLoggerCallback
os.environ['WANDB_NOTEBOOK_NAME'] = '/DATA/rllib_sac_policy_training.ipynb'

In [11]:
import wandb
wandb.init(project="l5kit2", reinit = True)

VBox(children=(Label(value='0.007 MB of 0.012 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.566350…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666868578334591, max=1.0)…

In [None]:
import ray
from ray import air, tune
train_envs = 4
ray_result_logdir = '/DATA/l5kit/ray_results/' + date

config_param_space = {
    "env": "L5-CLE-V1",
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 63,
    "num_envs_per_worker": train_envs,
    'q_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'policy_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'tau': 0.005,
    'target_network_update_freq': 5,
    'replay_buffer_config':{
        'type': 'PrioritizedReplayBuffer',
        'capacity': int(4e5),
    },
    'num_steps_sampled_before_learning_starts': 1000,
    
    'target_entropy': 'auto',
#     "model": {
#         "custom_model": "GN_CNN_torch_model",
#         "custom_model_config": {'feature_dim':128},
#     },
    '_disable_preprocessor_api': True,
     "eager_tracing": True,
     "restart_failed_sub_environments": True,
    # 'train_batch_size': 4000,
    # 'sgd_minibatch_size': 256,
    # 'num_sgd_iter': 16,
    'store_buffer_in_checkpoints' : True,
    'seed': 42,
    'batch_mode': 'truncate_episodes',
    "rollout_fragment_length": 32,
    'gamma': 0.8,
}

result_grid = tune.Tuner(
    "SAC",
    run_config=air.RunConfig(
        stop={"episode_reward_mean": 0, 'timesteps_total': int(6e6)},
        local_dir=ray_result_logdir,
        checkpoint_config=air.CheckpointConfig(num_to_keep=2, checkpoint_frequency = 10, checkpoint_score_attribute = 'episode_reward_mean')
        ),
    param_space=config_param_space).fit()

In [None]:
import ray
from ray import air, tune
train_envs = 4

hcmTz = pytz.timezone("Asia/Ho_Chi_Minh") 
date = datetime.datetime.now(hcmTz).strftime("%d-%m-%Y_%H-%M-%S")
ray_result_logdir = '/DATA/l5kit/ray_results/' + date

lr = 3e-3
lr_start = 3e-4
lr_end = 3e-5
config_param_space = {
    "env": "L5-CLE-V1",
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 63,
    "num_envs_per_worker": train_envs,
    'q_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'policy_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'tau': 0.005,
    'target_network_update_freq': 1,
    'replay_buffer_config':{
        'type': 'MultiAgentPrioritizedReplayBuffer',
        'capacity': int(1e5),
        "worker_side_prioritization": True,
    },
    'num_steps_sampled_before_learning_starts': 8000,
    
    'target_entropy': 'auto',
#     "model": {
#         "custom_model": "GN_CNN_torch_model",
#         "custom_model_config": {'feature_dim':128},
#     },
    '_disable_preprocessor_api': True,
     "eager_tracing": True,
     "restart_failed_sub_environments": True,
 
    # 'train_batch_size': 4000,
    # 'sgd_minibatch_size': 256,
    # 'num_sgd_iter': 16,
    # 'store_buffer_in_checkpoints' : False,
    'seed': 42,
    'batch_mode': 'truncate_episodes',
    "rollout_fragment_length": 1,
    'train_batch_size': 2048,
    'training_intensity' : 32, # (4x 'natural' value = 8)
    'gamma': 0.8,
    'twin_q' : True,
    "lr": 3e-4,
    "min_sample_timesteps_per_iteration": 8000,
}

result_grid = tune.Tuner(
    "SAC",
    run_config=air.RunConfig(
        stop={"episode_reward_mean": 0, 'timesteps_total': int(4e6)},
        local_dir=ray_result_logdir,
        checkpoint_config=air.CheckpointConfig(num_to_keep=2, checkpoint_frequency = 10, checkpoint_score_attribute = 'episode_reward_mean'),
        callbacks=[WandbLoggerCallback(project="l5kit2", 
						save_checkpoints=True),],
        ),
        
    param_space=config_param_space).fit()

2023-01-01 08:53:50,174	INFO wandb.py:250 -- Already logged into W&B.


0,1
Current time:,2023-01-01 09:44:56
Running for:,00:51:06.62
Memory:,231.2/503.2 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
SAC_L5-CLE-V1_cf7bb_00000,RUNNING,172.17.0.2:1654720,10,2813.26,80640,-266.552,-58.2638,-395.789,31


Trial name,agent_timesteps_total,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
SAC_L5-CLE-V1_cf7bb_00000,80640,"{'num_env_steps_sampled': 80640, 'num_env_steps_trained': 2367488, 'num_agent_steps_sampled': 80640, 'num_agent_steps_trained': 2367488, 'last_target_update_ts': 80640, 'num_target_updates': 289}",{},2023-01-01_09-41-41,False,31,{},-58.2638,-266.552,-395.789,252,2520,728aa68f10b24a71807ff882bf21f6b9,38d1f3f62321,"{'learner': {'default_policy': {'custom_metrics': {}, 'learner_stats': {'actor_loss': 8.812812805175781, 'critic_loss': 0.39401447772979736, 'alpha_loss': -1.7508076429367065, 'alpha_value': 0.7067752, 'log_alpha_value': -0.34704265, 'target_entropy': -3.0, 'policy_t': 0.022778328508138657, 'mean_q': -8.996729850769043, 'max_q': 0.1875818520784378, 'min_q': -19.785879135131836}, 'model': {}, 'num_grad_updates_lifetime': 1156.0, 'diff_num_grad_updates_vs_sampler_policy': 1155.0, 'td_error': array([6.3309903, 3.6437593, 4.47351 , ..., 4.5698147, 3.6771932,  1.9372492], dtype=float32), 'mean_td_error': 3.2561662197113037}}, 'num_env_steps_sampled': 80640, 'num_env_steps_trained': 2367488, 'num_agent_steps_sampled': 80640, 'num_agent_steps_trained': 2367488, 'last_target_update_ts': 80640, 'num_target_updates': 289}",10,172.17.0.2,80640,2367488,80640,8064,2367488,262144,0,63,0,0,262144,"{'cpu_util_percent': 21.818246445497632, 'ram_util_percent': 45.19928909952607}",1654720,{},{},{},"{'mean_raw_obs_processing_ms': 543.2517422664438, 'mean_inference_ms': 38.52429465642051, 'mean_action_processing_ms': 2.763415269503187, 'mean_env_wait_ms': 403.98261404789645, 'mean_env_render_ms': 0.0}","{'episode_reward_max': -58.263843804597855, 'episode_reward_min': -395.78902223706245, 'episode_reward_mean': -266.55183226853194, 'episode_len_mean': 31.0, 'episode_media': {}, 'episodes_this_iter': 252, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-239.0102435052395, -315.46955117583275, -350.2809314131737, -295.3525523394346, -375.7609887123108, -125.42248171940446, -213.39028072357178, -321.204388320446, -165.30160097777843, -278.51806247234344, -368.7421975284815, -222.80448591709137, -281.59126146137714, -100.46899828314781, -360.35176506638527, -324.17164608836174, -282.2768241763115, -358.5813331156969, -318.5083861798048, -273.05490189790726, -374.263945132494, -345.19716399908066, -329.61641654372215, -91.25386509299278, -317.0819436609745, -141.67730418592691, -313.07272773236036, -283.61031448841095, -384.7334594428539, -182.61282911151648, -315.39646857976913, -281.41736520826817, -269.46059215068817, -291.2918329387903, -149.78548197448254, -381.0407643914223, -278.61631241440773, -318.82015654444695, -393.1304204761982, -346.84480476379395, -206.43473315238953, -271.129971280694, -324.2219986617565, -122.4961024671793, -149.06780755519867, -318.75388012453914, -372.715628772974, -168.17121502757072, -394.420886695385, -297.1803601384163, -378.12008449435234, -263.08119531720877, -138.7313081510365, -130.89524812996387, -331.40234012156725, -341.44916197657585, -356.0900799408555, -258.9323396682739, -329.97140258550644, -213.98328721523285, -76.17018977552652, -199.0896539390087, -116.36399615556002, -256.76849437877536, -195.1551467180252, -278.98062989115715, -373.14412117004395, -162.08038310706615, -267.6527770459652, -395.78902223706245, -280.88155114650726, -212.79353946447372, -299.28727918863297, -276.9038734138012, -157.7410768866539, -239.55959382653236, -128.08969803154469, -157.7236249744892, -83.28353877365589, -375.7882164120674, -123.53195330128074, -315.46474266052246, -282.3662536740303, -307.5886910408735, -248.6706149019301, -379.85153594613075, -149.24158400297165, -341.6842777132988, -243.37295603752136, -136.38045568019152, -139.9097843170166, -252.9941124022007, -372.57703068852425, -172.69577984511852, -201.23314698413014, -362.9781885743141, -252.24430656433105, -339.5008824765682, -386.95836544036865, -180.5984003841877, -387.8370141983032, -338.3813439011574, -320.6602288633585, -182.11181201040745, -331.15034152567387, -295.0036338567734, -351.50891917943954, -370.22955053299665, -315.38987512886524, -219.2766069471836, -281.8167742192745, -337.45818120241165, -136.9562622308731, -326.5631651878357, -129.88796013593674, -372.9880223274231, -257.91548904776573, -346.78891310095787, -275.48679903149605, -276.64143876731396, -68.60699209570885, -262.44351311028004, -320.73733642697334, -317.57978264428675, -153.9463251233101, -360.0428886190057, -355.87650751695037, -318.85867639631033, -381.90275936294347, -173.61199155449867, -254.27757105231285, -294.07805059850216, -348.5094857811928, -332.40808129683137, -302.2544457241893, -224.95612782239914, -368.79041600227356, -308.72165640443563, -81.85456553474069, -373.8572616279125, -352.3852998614311, -291.7810097038746, -209.45745495706797, -298.30812910199165, -288.91846653819084, -355.6318387687206, -175.4664726704359, -160.70283073186874, -311.34211841225624, -130.31787486374378, -381.113315731287, -214.31183449178934, -177.55279672145844, -337.85355463624, -329.1560100913048, -248.05376362800598, -329.61016638576984, -347.2066205702722, -101.01703269779682, -365.98493924736977, -126.35004635620862, -310.025057464838, -291.4803636074066, -149.51297974959016, -324.3074985444546, -334.34783523902297, -324.5137799978256, -189.60005459189415, -278.78084920346737, -352.72923892736435, -378.7834988832474, -348.7612207531929, -259.84096771478653, -368.79233261942863, -192.33007282018661, -300.38981211185455, -154.30786885879934, -58.263843804597855, -93.19370300322771, -293.36644001305103, -304.796572946012, -168.984716668725, -318.72418519854546, -370.6341763138771, -307.6780381947756, -321.5807832479477, -337.34231155738235, -328.14320838451385, -77.33276599645615, -73.4064095467329, -179.20058937370777, -205.91496774554253, -329.2398442327976, -116.89023643359542, -379.51175037026405, -201.87782264314592, -385.4600729942322, -276.5918136537075, -352.695658326149, -304.96964110434055, -288.5717467814684, -310.58537343144417, -340.9859722480178, -310.1423180401325, -125.64173893630505, -323.93349212408066, -316.2093194723129, -140.2889180481434, -71.23376521468163, -185.49103969335556, -268.1483285985887, -333.3573638498783, -381.0383552014828, -366.2847504019737, -141.01020368933678, -326.9266217201948, -389.55089950561523, -129.7415716946125, -119.29032766819, -286.13683247566223, -311.6091835796833, -308.49939363077283, -138.90350463986397, -293.912226960063, -337.5189410299063, -315.37532091140747, -189.5664700344205, -363.56122364103794, -279.73857159912586, -65.6487163901329, -329.10058203339577, -358.5928415954113, -248.3689722418785, -341.20086854696274, -178.77139765024185, -291.5753686353564, -134.45915368944407, -78.27500683069229, -336.74289134144783, -326.5278294980526, -172.49989058077335, -251.43231546878815, -165.35790884494781, -86.14193984866142, -344.15941359102726, -299.9199951440096, -292.821906670928, -185.8763953447342, -318.7057000398636, -232.87917717546225, -344.14256888628006, -310.89024114608765], 'episode_lengths': [31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 543.2517422664438, 'mean_inference_ms': 38.52429465642051, 'mean_action_processing_ms': 2.763415269503187, 'mean_env_wait_ms': 403.98261404789645, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0}",2813.26,306.02,2813.26,"{'training_iteration_time_ms': 9504.646, 'load_time_ms': 133.039, 'load_throughput': 15393.991, 'learn_time_ms': 164.86, 'learn_throughput': 12422.695, 'synch_weights_time_ms': 598.243}",1672566101,0,80640,10,cf7bb_00000,53.376


[34m[1mwandb[0m: Adding directory to artifact (/DATA/l5kit/ray_results/01-01-2023_15-53-49/SAC/SAC_L5-CLE-V1_cf7bb_00000_0_2023-01-01_08-53-50/checkpoint_000010)... Done. 0.3s


In [38]:
import ray
from ray import air, tune
train_envs = 4

hcmTz = pytz.timezone("Asia/Ho_Chi_Minh") 
date = datetime.datetime.now(hcmTz).strftime("%d-%m-%Y_%H-%M-%S")
ray_result_logdir = '/DATA/l5kit/ray_results/' + date

lr = 3e-3
lr_start = 3e-4
lr_end = 3e-5
config_param_space = {
    "env": "L5-CLE-V1",
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 63,
    "num_envs_per_worker": train_envs,
    'q_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'policy_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'tau': 0.005,
    'target_network_update_freq': 1,
    'replay_buffer_config':{
        'type': 'MultiAgentPrioritizedReplayBuffer',
        'capacity': int(1e5),
        "worker_side_prioritization": True,
    },
    'num_steps_sampled_before_learning_starts': 8000,
    
    'target_entropy': 'auto',
#     "model": {
#         "custom_model": "GN_CNN_torch_model",
#         "custom_model_config": {'feature_dim':128},
#     },
    '_disable_preprocessor_api': True,
     "eager_tracing": True,
     "restart_failed_sub_environments": True,
 
    # 'train_batch_size': 4000,
    # 'sgd_minibatch_size': 256,
    # 'num_sgd_iter': 16,
    # 'store_buffer_in_checkpoints' : False,
    'seed': 42,
    'batch_mode': 'truncate_episodes',
    "rollout_fragment_length": 1,
    'train_batch_size': 2048,
    # 'training_intensity' : 1000,
    'gamma': 0.8,
    'twin_q' : True,
    "lr": 3e-4,
    "min_sample_timesteps_per_iteration": 8000,
}

result_grid = tune.Tuner(
    "SAC",
    run_config=air.RunConfig(
        stop={"episode_reward_mean": 0, 'timesteps_total': int(2e6)},
        local_dir=ray_result_logdir,
        checkpoint_config=air.CheckpointConfig(num_to_keep=2, checkpoint_frequency = 10, checkpoint_score_attribute = 'episode_reward_mean')
        ),
    param_space=config_param_space).fit()

0,1
Current time:,2022-12-29 09:10:12
Running for:,08:22:49.98
Memory:,323.5/503.2 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
SAC_L5-CLE-V1_5af7a_00000,TERMINATED,172.17.0.2:2100464,249,30097.2,2007936,-48.418,-4.97356,-342.104,31


Trial name,agent_timesteps_total,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
SAC_L5-CLE-V1_5af7a_00000,2007936,"{'num_env_steps_sampled': 2007936, 'num_env_steps_trained': 16254976, 'num_agent_steps_sampled': 2007936, 'num_agent_steps_trained': 16254976, 'last_target_update_ts': 2007936, 'num_target_updates': 7937}",{},2022-12-29_09-10-12,True,31,{},-4.97356,-48.418,-342.104,252,64764,6a1b193b49884b9783bc68502a9fba4d,38d1f3f62321,"{'learner': {'default_policy': {'custom_metrics': {}, 'learner_stats': {'actor_loss': 10.66728687286377, 'critic_loss': 0.2743040919303894, 'alpha_loss': -8.968500137329102, 'alpha_value': 0.097858936, 'log_alpha_value': -2.3242283, 'target_entropy': -3.0, 'policy_t': 0.002813179511576891, 'mean_q': -10.056241989135742, 'max_q': -0.7284332513809204, 'min_q': -71.80377197265625}, 'model': {}, 'num_grad_updates_lifetime': 7937.0, 'diff_num_grad_updates_vs_sampler_policy': 7936.0, 'td_error': array([3.714016 , 0.5761843, 1.6428685, ..., 3.9594688, 2.1804185,  0.9676223], dtype=float32), 'mean_td_error': 2.5138769149780273}}, 'num_env_steps_sampled': 2007936, 'num_env_steps_trained': 16254976, 'num_agent_steps_sampled': 2007936, 'num_agent_steps_trained': 16254976, 'last_target_update_ts': 2007936, 'num_target_updates': 7937}",249,172.17.0.2,2007936,16254976,2007936,8064,16254976,65536,0,63,0,0,65536,"{'cpu_util_percent': 38.95, 'ram_util_percent': 64.37439024390243}",2100464,{},{},{},"{'mean_raw_obs_processing_ms': 557.5587640255825, 'mean_inference_ms': 47.593003390485485, 'mean_action_processing_ms': 2.9690678893112814, 'mean_env_wait_ms': 426.11047655216566, 'mean_env_render_ms': 0.0}","{'episode_reward_max': -4.973555717617273, 'episode_reward_min': -342.10410130023956, 'episode_reward_mean': -48.418040810326126, 'episode_len_mean': 31.0, 'episode_media': {}, 'episodes_this_iter': 252, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-7.4621262066066265, -41.47999408096075, -30.397633604705334, -49.24600240588188, -29.92940052598715, -72.64650725573301, -26.87973716855049, -11.681799755431712, -18.12396379560232, -19.913102217018604, -23.069430757313967, -11.442539423704147, -99.51352718472481, -22.244776694104075, -27.77426766604185, -98.61479258537292, -49.127030685544014, -22.231638342142105, -199.97455298900604, -76.25392800569534, -18.71334876306355, -42.81597018241882, -73.56289330124855, -7.541620695963502, -55.609064519405365, -41.371235713362694, -22.585967406630516, -12.339263431727886, -51.4718153513968, -80.66785857826471, -58.32361352443695, -87.72112779319286, -28.258968710899353, -12.798266559839249, -18.643772268667817, -42.05495078116655, -10.752920123748481, -7.985721813514829, -83.72157561779022, -9.54317206516862, -13.384278029203415, -68.8031562268734, -10.638589896261692, -119.62784695625305, -132.9736880660057, -96.00003948807716, -11.209963822737336, -10.219531135633588, -8.989193486981094, -8.054704411886632, -14.753779930993915, -11.431751469150186, -19.78944766148925, -63.47205105796456, -136.97895389050245, -12.408421391621232, -48.95064224302769, -101.75829920172691, -20.026732839643955, -6.55037109926343, -16.132280349731445, -25.484977897256613, -42.59231808036566, -33.8322284668684, -9.60510440915823, -8.720755230635405, -51.93710348010063, -10.440021116286516, -64.29593467526138, -64.36974440515041, -114.32647287100554, -66.07944963127375, -50.41681718826294, -40.00171493180096, -16.200324084609747, -29.237795993685722, -153.15099585056305, -168.20879462361336, -25.784470826387405, -78.44883043318987, -91.56494441628456, -166.42671021819115, -26.73681242763996, -80.4181410074234, -12.422316025942564, -101.45837111771107, -120.86603850126266, -51.918699741363525, -10.306001126766205, -10.808054637163877, -11.039475049823523, -10.298748157918453, -52.092963591217995, -10.005635634995997, -8.237097959034145, -103.76691523194313, -24.40737035870552, -13.600290849804878, -159.21457397937775, -81.14093877747655, -27.153657734394073, -24.96376897767186, -34.3931125625968, -22.96482305508107, -75.07949529588223, -15.354894362390041, -81.35158762335777, -60.97712790966034, -14.311687611043453, -33.83127050474286, -90.24751383066177, -8.340447791735642, -25.292251905426383, -9.46522231400013, -49.51268584281206, -65.61743873357773, -26.8861066326499, -7.616639997810125, -9.61411420069635, -17.074677165597677, -33.59195667505264, -127.78772127628326, -26.960867166519165, -11.141022946685553, -21.494155287742615, -8.329457223415375, -134.2124132886529, -207.68313336372375, -12.8949363976717, -53.89726731926203, -10.092645371332765, -60.15844654291868, -24.786295540630817, -135.5761418044567, -18.681159734725952, -41.64393958076835, -8.743692738935351, -61.73739293217659, -157.37170243263245, -36.76923814415932, -8.601557003334165, -133.0081479549408, -65.10541343688965, -7.6507846750319, -9.336956173181534, -58.92861983180046, -34.97503907978535, -14.337392661720514, -32.35617619752884, -83.36293375492096, -88.2966719083488, -11.025825656950474, -30.715365447103977, -342.10410130023956, -56.04067733883858, -82.09118078276515, -12.04058371623978, -16.880190078169107, -24.79690684378147, -56.83008548617363, -37.77252276428044, -51.7009187489748, -96.63202089071274, -8.88059919513762, -69.18232207000256, -10.123821537941694, -23.197871148586273, -61.34557843208313, -15.173034584149718, -10.205897832289338, -36.05493459291756, -25.538627788424492, -45.104613564908504, -8.038095336407423, -116.4073496311903, -23.685832545161247, -11.050906583666801, -11.857267241925001, -27.28685772791505, -30.16975475475192, -102.22034421097487, -56.96879315376282, -49.29720479995012, -10.223536057397723, -9.4553512185812, -7.503548703156412, -19.415623735636473, -54.286733977496624, -15.285702427849174, -86.9964390695095, -8.883664108812809, -84.89092037081718, -36.3870615363121, -7.033709031995386, -23.797568812966347, -123.78697862941772, -135.91445318609476, -47.04819901660085, -13.954086169600487, -13.610693532973528, -8.311992602422833, -9.553745400160551, -10.068777902051806, -60.71382158994675, -55.95641040802002, -82.21351046860218, -9.413188518024981, -39.387228824198246, -45.233738109469414, -24.85464497283101, -91.2011684179306, -127.55325359106064, -94.68607670068741, -154.2859691977501, -8.40143384411931, -21.433973629027605, -82.24952015280724, -112.48282658308744, -84.93823811411858, -75.58485485613346, -16.214268926531076, -34.300846844911575, -44.14082069694996, -27.32290068268776, -11.797557841986418, -62.649516731500626, -22.95103609189391, -99.95410460233688, -34.06992940232158, -81.63254825770855, -4.973555717617273, -49.15727382898331, -114.69648578763008, -18.213496141135693, -94.72749266214669, -43.1529616266489, -29.413159757852554, -16.90747994184494, -48.11836316809058, -15.246637467294931, -9.59222056902945, -44.99032598733902, -55.53623207658529, -12.991382762789726, -206.02356186509132, -55.89227467775345, -11.870953384786844, -52.908058792352676, -60.95068273693323, -33.86251822859049, -10.461833715438843, -53.137717263773084], 'episode_lengths': [31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 557.5587640255825, 'mean_inference_ms': 47.593003390485485, 'mean_action_processing_ms': 2.9690678893112814, 'mean_env_wait_ms': 426.11047655216566, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0}",30097.2,119.19,30097.2,"{'training_iteration_time_ms': 3943.931, 'load_time_ms': 185.149, 'load_throughput': 11061.376, 'learn_time_ms': 210.285, 'learn_throughput': 9739.143, 'synch_weights_time_ms': 458.456}",1672305012,0,2007936,249,5af7a_00000,56.4943


2022-12-29 09:10:12,594	INFO tune.py:762 -- Total run time: 30170.23 seconds (30169.98 seconds for the tuning loop).


In [43]:
config_param_space = {
    "env": "L5-CLE-V1",
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 63,
    "num_envs_per_worker": train_envs,
    'q_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'policy_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'tau': 0.005,
    'target_network_update_freq': 1,
    'replay_buffer_config':{
        'type': 'MultiAgentPrioritizedReplayBuffer',
        'capacity': int(1e5),
        "worker_side_prioritization": True,
    },
    'num_steps_sampled_before_learning_starts': 8000,
    
    'target_entropy': 'auto',
#     "model": {
#         "custom_model": "GN_CNN_torch_model",
#         "custom_model_config": {'feature_dim':128},
#     },
    '_disable_preprocessor_api': True,
     "eager_tracing": True,
     "restart_failed_sub_environments": True,
 
    # 'train_batch_size': 4000,
    # 'sgd_minibatch_size': 256,
    # 'num_sgd_iter': 16,
    # 'store_buffer_in_checkpoints' : False,
    'seed': 42,
    'batch_mode': 'truncate_episodes',
    "rollout_fragment_length": 1,
    'train_batch_size': 2048,
    # 'training_intensity' : 1000,
    'gamma': 0.8,
    'twin_q' : True,
    "lr": 3e-4,
    "min_sample_timesteps_per_iteration": 8000,
}

In [45]:
# config_param_space['stop']['timesteps_total'] = 3e-5
path_to_trained_agent_checkpoint = 'l5kit/ray_results/29-12-2022_07-47-22/SAC/SAC_L5-CLE-V1_5af7a_00000_0_2022-12-29_00-47-23/checkpoint_000249'
from ray.rllib.algorithms.sac import SAC
ray.tune.run(SAC, config=config_param_space, restore=path_to_trained_agent_checkpoint)

0,1
Current time:,2022-12-29 15:00:53
Running for:,00:04:32.61
Memory:,270.1/503.2 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
SAC_L5-CLE-V1_f4f69_00000,RUNNING,172.17.0.2:137387,250,30255,2016000,-137.817,-15.617,-371.548,31


Trial name,agent_timesteps_total,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
SAC_L5-CLE-V1_f4f69_00000,2016000,"{'num_env_steps_sampled': 2016000, 'num_env_steps_trained': 16320512, 'num_agent_steps_sampled': 2016000, 'num_agent_steps_trained': 16320512, 'last_target_update_ts': 2016000, 'num_target_updates': 7969}",{},2022-12-29_14-59-58,False,31,{},-15.617,-137.817,-371.548,252,65016,6a1b193b49884b9783bc68502a9fba4d,38d1f3f62321,"{'learner': {'default_policy': {'custom_metrics': {}, 'learner_stats': {'actor_loss': 4.579573631286621, 'critic_loss': 0.9118623733520508, 'alpha_loss': -9.704286575317383, 'alpha_value': 0.09693226, 'log_alpha_value': -2.3337429, 'target_entropy': -3.0, 'policy_t': -0.08073669672012329, 'mean_q': -3.388627767562866, 'max_q': -0.7259072065353394, 'min_q': -15.54305648803711}, 'model': {}, 'num_grad_updates_lifetime': 32.0, 'diff_num_grad_updates_vs_sampler_policy': 31.0, 'td_error': array([1.2998202, 8.56468 , 1.5235349, ..., 2.8340416, 3.6547613,  1.2337224], dtype=float32), 'mean_td_error': 4.344608306884766}}, 'num_env_steps_sampled': 2016000, 'num_env_steps_trained': 16320512, 'num_agent_steps_sampled': 2016000, 'num_agent_steps_trained': 16320512, 'last_target_update_ts': 2016000, 'num_target_updates': 7969}",1,172.17.0.2,2016000,16320512,2016000,8064,16320512,65536,0,63,0,0,65536,"{'cpu_util_percent': 60.70904977375565, 'ram_util_percent': 42.32488687782805}",137387,{},{},{},"{'mean_raw_obs_processing_ms': 691.7444644724308, 'mean_inference_ms': 41.91511855095612, 'mean_action_processing_ms': 3.431972257789366, 'mean_env_wait_ms': 584.0683852668678, 'mean_env_render_ms': 0.0}","{'episode_reward_max': -15.617014065384865, 'episode_reward_min': -371.54844295978546, 'episode_reward_mean': -137.81730142700079, 'episode_len_mean': 31.0, 'episode_media': {}, 'episodes_this_iter': 252, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-299.97706861048937, -51.82022622227669, -212.13150823116302, -244.13909198343754, -197.46283899247646, -31.607199588790536, -96.4726347476244, -15.617014065384865, -273.5172915905714, -194.24902094714344, -129.61998998373747, -65.69232455454767, -19.393883276730776, -111.7408323250711, -59.09626119583845, -90.9191816598177, -199.88144850730896, -299.2814950942993, -190.75740267336369, -130.1176925431937, -179.24520199000835, -194.28794536739588, -143.29140626639128, -101.1886000931263, -86.3916173428297, -29.176321119070053, -35.31472562253475, -223.68560089170933, -74.8477296680212, -54.759615938179195, -307.39633852243423, -240.48044553399086, -268.46005898714066, -223.39069479703903, -220.1556041240692, -50.030099771916866, -80.1811374425888, -28.63940942659974, -222.0825450643897, -63.31155708618462, -110.8121508359909, -40.13294491916895, -269.29217824339867, -186.04005989432335, -249.58123522996902, -218.90602142363787, -47.48398283869028, -55.637234918773174, -52.47745594102889, -111.15736995637417, -211.80621452629566, -43.390553280711174, -177.45844892412424, -52.30602751299739, -214.2268007695675, -161.82819437980652, -192.73150944709778, -48.21109406277537, -227.2928847670555, -103.89885324984789, -30.138481315225363, -111.42090831696987, -158.19647543132305, -182.80938386917114, -251.64277704060078, -137.9156246036291, -140.02911753207445, -98.01596263051033, -134.044519148767, -111.15653241798282, -71.10387574322522, -225.6438679099083, -212.01682150363922, -121.28016351722181, -75.78648805245757, -153.77385725080967, -53.40755498409271, -168.57436906546354, -39.32255797833204, -70.8449877682142, -32.380650751292706, -246.21190083026886, -43.56784425675869, -24.78948475793004, -133.95375065878034, -185.20826063677669, -75.8833856601268, -247.84678468108177, -165.07357239723206, -332.57243210077286, -172.52162861824036, -167.6060889363289, -46.20060350000858, -270.73676973581314, -101.12249380350113, -50.965405613183975, -371.54844295978546, -25.443911347538233, -161.73639491945505, -43.21673648804426, -34.5110229998827, -198.44270712137222, -76.74124410748482, -88.40089039504528, -337.2680090069771, -147.06688775122166, -67.2424168586731, -49.01937335729599, -171.42071080207825, -146.56798292323947, -173.93097941577435, -31.48320733755827, -94.1955065280199, -23.609794601798058, -96.3097610771656, -236.95774947106838, -97.59911411628127, -72.61437848210335, -180.87512730807066, -129.8502264805138, -143.43173751235008, -93.29107657074928, -128.40934970974922, -271.45442797709256, -192.11385215818882, -163.12449376285076, -314.89704218506813, -289.9468065202236, -116.46002896130085, -150.59363002888858, -130.35035149008036, -256.71769541502, -93.11557127535343, -250.47537092864513, -164.41879165172577, -71.24733397364616, -139.42685773968697, -202.00427490472794, -104.31808914244175, -210.9410725682974, -194.30858905240893, -89.75386482477188, -67.03705172333866, -160.9048422574997, -215.08678121119738, -264.1829191148281, -152.43290308117867, -72.9229040145874, -276.17744693160057, -196.89465822279453, -96.54025993961841, -68.37835693359375, -21.348181121982634, -31.92421041801572, -123.40187205374241, -90.50316984020174, -160.30389396846294, -167.4033802896738, -71.28940810263157, -32.464180406183004, -252.95062878355384, -196.7002201974392, -210.17943699657917, -56.82348436117172, -138.51692804694176, -91.46426229923964, -254.60068672895432, -143.89270337671041, -71.78226737631485, -46.60442142933607, -46.624403320252895, -70.64460425078869, -120.41957493126392, -86.16039713518694, -142.7494396492839, -267.28318813443184, -82.15140162222087, -156.07647028565407, -186.9594401270151, -63.831169694662094, -150.02693651244044, -191.60844258964062, -84.98108759522438, -203.62129264511168, -43.943675100803375, -55.66991034895182, -126.72345610335469, -283.9452374931425, -80.56798090040684, -133.2169153690338, -55.0657594576478, -88.30508452653885, -53.43094354867935, -64.081110522151, -207.8457684367895, -56.04468463920057, -230.81134974956512, -59.9131864765659, -72.7744899392128, -237.27829033136368, -121.15408451855183, -126.59941921383142, -78.40274704247713, -56.59736351296306, -137.6304468885064, -57.37811692850664, -208.78320770338178, -56.395890951156616, -191.32112161070108, -262.88904443383217, -30.63711839914322, -58.60731145367026, -42.97315649688244, -43.484818050055765, -103.17349380254745, -39.4192226678133, -65.10727168619633, -301.94181221723557, -207.65956818684936, -272.50115633010864, -64.54533440619707, -160.38112823665142, -274.56690204143524, -107.91420528292656, -279.497002184391, -186.86970533430576, -66.16285207122564, -155.11157170310616, -109.60744011029601, -177.88816595077515, -144.8501342087984, -183.32364154607058, -35.32881237566471, -98.80114330351353, -132.91177412495017, -226.34882310405374, -144.25793534331024, -156.4375006519258, -92.1831620708108, -125.86796532571316, -126.13164404407144, -148.4325664229691, -79.81684523820877, -168.8368287011981, -218.18234498798847, -117.6562134847045, -296.0146112330258, -47.064086839556694, -172.08175939507782, -191.52176237478852, -53.177558897063136, -104.96447165310383], 'episode_lengths': [31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 691.7444644724308, 'mean_inference_ms': 41.91511855095612, 'mean_action_processing_ms': 3.431972257789366, 'mean_env_wait_ms': 584.0683852668678, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0}",157.784,157.784,30255,"{'training_iteration_time_ms': 7780.57, 'load_time_ms': 205.295, 'load_throughput': 9975.888, 'learn_time_ms': 239.504, 'learn_throughput': 8551.009, 'synch_weights_time_ms': 2729.168}",1672325998,0,2016000,250,f4f69_00000,54.8865


2022-12-29 15:00:54,331	ERROR tune.py:758 -- Trials did not complete: [SAC_L5-CLE-V1_f4f69_00000]
2022-12-29 15:00:54,333	INFO tune.py:762 -- Total run time: 272.99 seconds (272.61 seconds for the tuning loop).


<ray.tune.analysis.experiment_analysis.ExperimentAnalysis at 0x7f0903752fd0>

*** SIGTERM received at time=1672326111 on cpu 28 ***
PC: @     0x7f0aa3d5b46e  (unknown)  epoll_wait
    @     0x7f0aa3c7f090  (unknown)  (unknown)
[2022-12-29 15:01:51,032 E 1533378 1533378] logging.cc:361: *** SIGTERM received at time=1672326111 on cpu 28 ***
[2022-12-29 15:01:51,032 E 1533378 1533378] logging.cc:361: PC: @     0x7f0aa3d5b46e  (unknown)  epoll_wait
[2022-12-29 15:01:51,032 E 1533378 1533378] logging.cc:361:     @     0x7f0aa3c7f090  (unknown)  (unknown)


In [None]:
import ray
from ray import air, tune
train_envs = 4
ray_result_logdir = '/DATA/l5kit/ray_results/' + date
lr = 3e-3
lr_start = 3e-4
lr_end = 3e-5

config_param_space = {
    "env": "L5-CLE-V1",
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 63,
    "num_envs_per_worker": train_envs,
    'q_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'policy_model_config' : {
            # "dim": 112,
            # "conv_filters" : [[64, [7,7], 3], [32, [11,11], 3], [32, [11,11], 3]],
            # "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "post_fcnet_activation": "relu",
        },
    'tau': 0.005,
    'n_step': 1
    'target_network_update_freq': 5,
    'replay_buffer_config':{
        'type': 'MultiAgentPrioritizedReplayBuffer',
        'capacity': int(3e5),
    },
    'num_steps_sampled_before_learning_starts': 256,
    
    'target_entropy': 'auto',
#     "model": {
#         "custom_model": "GN_CNN_torch_model",
#         "custom_model_config": {'feature_dim':128},
#     },
    '_disable_preprocessor_api': True,
     "eager_tracing": True,
     "restart_failed_sub_environments": True,
    # 'train_batch_size': 4000,
    # 'sgd_minibatch_size': 256,
    # 'num_sgd_iter': 16,
    'seed': 42,
    'batch_mode': 'truncate_episodes',
    "rollout_fragment_length": 32,
    'gamma': 0.8,
    "lr_schedule": [
        [1e6, lr_start],
        [2e6, lr_end],
    ],
    "lr": lr,
}

result_grid = tune.Tuner(
    "SAC",
    run_config=air.RunConfig(
        stop={"episode_reward_mean": 0, 'timesteps_total': int(3e6)},
        local_dir=ray_result_logdir,
        checkpoint_config=air.CheckpointConfig(num_to_keep=2, checkpoint_frequency = 10, checkpoint_score_attribute = 'episode_reward_mean')
        ),
    param_space=config_param_space).fit()

In [39]:
import ray
from ray import air, tune
train_envs = 4
ray_result_logdir = '/DATA/l5kit/ray_results/29-12-2022_07-47-22'

tuner = tune.Tuner.restore(
    path=ray_result_logdir + '/SAC'
)
result = tuner.fit()

2022-12-29 14:46:39,383	INFO experiment_analysis.py:795 -- No `self.trials`. Drawing logdirs from checkpoint file. This may result in some information that is out of sync, as checkpointing is periodic.
2022-12-29 14:46:39,426	INFO trial_runner.py:688 -- A local experiment checkpoint was found and will be used to restore the previous experiment state.
2022-12-29 14:46:39,427	INFO trial_runner.py:825 -- Using following checkpoint to resume: /DATA/l5kit/ray_results/29-12-2022_07-47-22/SAC/experiment_state-2022-12-29_00-47-22.json
2022-12-29 14:46:39,436	INFO tune.py:653 -- TrialRunner resumed, ignoring new add_experiment but updating trial resources.


0,1
Current time:,2022-12-29 14:46:39
Running for:,00:00:00.02
Memory:,144.4/503.2 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
SAC_L5-CLE-V1_5af7a_00000,TERMINATED,172.17.0.2:2100464,249,30097.2,2007936,-48.418,-4.97356,-342.104,31


2022-12-29 14:46:39,596	INFO tune.py:762 -- Total run time: 0.17 seconds (0.00 seconds for the tuning loop).


In [None]:
best_result = result.get_best_result(metric="episode_reward_mean", mode = 'max')
best_checkpoint = best_result.checkpoint
best_checkpoint

In [19]:
num_results = len(result)
print("Number of results:", num_results)

Number of results: 1


In [None]:
result.get_best_result()

In [34]:
result.get_dataframe()

In [40]:
type(best_result)

ray.air.result.Result

*** SIGTERM received at time=1671780714 on cpu 2 ***
PC: @     0x7f4dbec0d46e  (unknown)  epoll_wait
    @     0x7f4dbeb31090  (unknown)  (unknown)
[2022-12-23 07:31:54,799 E 2700257 2700257] logging.cc:361: *** SIGTERM received at time=1671780714 on cpu 2 ***
[2022-12-23 07:31:54,799 E 2700257 2700257] logging.cc:361: PC: @     0x7f4dbec0d46e  (unknown)  epoll_wait
[2022-12-23 07:31:54,799 E 2700257 2700257] logging.cc:361:     @     0x7f4dbeb31090  (unknown)  (unknown)


In [39]:
result_df = best_result.metrics_dataframe() 
result_df[['episode_reward_mean']]

TypeError: 'NoneType' object is not callable

NOTE: Experiment has been interrupted, but the most recent state was saved. You can continue running this experiment by passing `resume=True` to `tune.run()`

2022-12-04 05:50:38,570	INFO experiment_analysis.py:795 -- No `self.trials`. Drawing logdirs from checkpoint file. This may result in some information that is out of sync, as checkpointing is periodic.

2022-12-04 05:50:39,684	INFO trial_runner.py:601 -- A local experiment checkpoint was found and will be used to restore the previous experiment state.
2022-12-04 05:50:39,687	INFO trial_runner.py:738 -- Using following checkpoint to resume: /content/drive/MyDrive/Colab Notebooks/l5kit/ray_results/PPO/experiment_state-2022-12-04_05-28-55.json

2022-12-04 05:50:39,710	WARNING trial_runner.py:743 -- Attempting to resume experiment from /content/drive/MyDrive/Colab Notebooks/l5kit/ray_results/PPO. This will ignore any new changes to the specification.

2022-12-04 05:50:40,703	INFO tune.py:668 -- TrialRunner resumed, ignoring new add_experiment but updating trial resources.

In [None]:
train_envs=2
ray_result_logdir = '/content/drive/MyDrive/Colab Notebooks/l5kit/ray_results'
# Create the Trainer.
algo = ppo.PPO(
        env="L5-CLE-V1",
        config={
            "framework": "torch",
            "num_gpus": 1,
            "num_workers": 2,
            "num_envs_per_worker": train_envs,
            'num_sgd_iter': 5,
            'sgd_minibatch_size': 256,
            'num_cpus_per_worker': 0,  # This avoids running out of resources in the notebook environment when this cell is re-executed
            "model": {
                "custom_model": "GN_CNN_torch_model",
                "custom_model_config": {'feature_dim':128},
            },
            '_disable_preprocessor_api': True,
        },
        logger_creator=custom_log_creator(os.path.expanduser(ray_result_logdir), 'L5_PPO')
    )

2022-12-04 04:42:57,593	INFO ppo.py:379 -- In multi-agent mode, policies will be optimized sequentially by the multi-GPU optimizer. Consider setting simple_optimizer=True if this doesn't work for you.
2022-12-04 04:42:57,599	INFO algorithm.py:457 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
2022-12-04 04:43:51,371	INFO trainable.py:164 -- Trainable.setup took 53.780 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [None]:
checkpoint_path =  '/content/drive/MyDrive/Colab Notebooks/l5kit/rllib_logs/'+ str(datetime.date.today())
for i in range(1, 1000):
   # Perform one iteration of training the policy with PPO
   result = algo.train()
   print(pretty_print(result))

   if i % 10 == 0:
       checkpoint = algo.save(checkpoint_dir= checkpoint_path)
       print("checkpoint saved at", checkpoint)

In [None]:
checkpoint_path = '/content/drive/MyDrive/Colab Notebooks/l5kit/rllib_logs/2022-12-02/checkpoint_000130'
config={
            "framework": "torch",
            "num_gpus": 1,
            "num_workers": 2,
            "num_envs_per_worker": train_envs,
            'num_sgd_iter': 15,
            'sgd_minibatch_size': 128,
            'num_cpus_per_worker': 0,  # This avoids running out of resources in the notebook environment when this cell is re-executed
            "model": {
                "custom_model": "GN_CNN_torch_model",
                "custom_model_config": {'feature_dim':128},
            },
            '_disable_preprocessor_api': True,
        }
algo = ppo.PPO(config=config, env='L5-CLE-V1')
algo.restore(checkpoint_path)

2022-12-04 01:14:25,596	INFO trainable.py:164 -- Trainable.setup took 25.769 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
2022-12-04 01:14:26,933	INFO trainable.py:766 -- Restored on 172.28.0.12 from checkpoint: /content/drive/MyDrive/Colab Notebooks/l5kit/rllib_logs/2022-12-02/checkpoint_000130
2022-12-04 01:14:26,936	INFO trainable.py:775 -- Current state after restoring: {'_iteration': 130, '_timesteps_total': None, '_time_total': 14015.14718747139, '_episodes_total': 16772}


In [None]:
checkpoint_path =  '/content/drive/MyDrive/Colab Notebooks/l5kit/rllib_logs/'+ str(datetime.date.today())
for i in range(1, 1000):
   # Perform one iteration of training the policy with PPO
   result = algo.train()
   print(pretty_print(result))

   if i % 10 == 0:
       checkpoint = algo.save(checkpoint_dir= checkpoint_path)
       print("checkpoint saved at", checkpoint)

In [None]:
checkpoint_path =  '/content/drive/MyDrive/Colab Notebooks/l5kit/rllib_logs/'+ str(datetime.date.today())
for i in range(1, 1000):
   # Perform one iteration of training the policy with PPO
   result = algo.train()
   print(pretty_print(result))

   if i % 1 == 0:
       checkpoint = algo.save(checkpoint_dir= checkpoint_path)
       print("checkpoint saved at", checkpoint)

In [None]:
train_envs = 2
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 1
# config["framework"] = 'tf2'
config["num_workers"] = 2
config["num_envs_per_worker"] = train_envs
config['_disable_preprocessor_api'] = True,
config["model"]["dim"] = 112
config["model"]["conv_filters"] = [[64, 7, 3], [32, 11, 3], [32, 11, 3]]
config['num_sgd_iter'] = 1
config['sgd_minibatch_size'] = 256
# config['model']['fcnet_hiddens'] = [100, 100]
config['num_cpus_per_worker'] = 0  # This avoids running out of resources in the notebook environment when this cell is re-executed
# config['env_config'] = env_kwargs 
# config["log_level"] = 1
# config["evaluation_interval"] = 1 # change to 10000
# config["evaluation_duration"] = "auto"
# config["evaluation_parallel_to_training"] = True,
# config["evaluation_duration_unit"] = "timesteps"
# config["evaluation_num_workers"] = 3
# config["enable_async_evaluation"] = True,

config["model"]["conv_activation"] = 'relu'
config["model"]["post_fcnet_hiddens"] =  [256]
config["model"]["post_fcnet_activation"] = 'relu'
# config["train_batch_size"] = 200
algo = ppo.PPO(config=config, env="L5-CLE-V1")

2022-12-14 14:57:14,925	INFO algorithm_config.py:2503 -- Your framework setting is 'tf', meaning you are using static-graph mode. Set framework='tf2' to enable eager execution with tf2.x. You may also then want to set eager_tracing=True in order to reach similar execution speed as with static-graph mode.
2022-12-14 14:57:14,964	INFO algorithm.py:501 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
2022-12-14 14:58:13,797	INFO trainable.py:172 -- Trainable.setup took 58.834 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [None]:
checkpoint_path =  '/content/drive/MyDrive/Colab Notebooks/l5kit/rllib_logs/'+ str(datetime.date.today())
for i in range(1, 1000):
   # Perform one iteration of training the policy with PPO
   result = algo.train()
   print(pretty_print(result))

   if i % 10 == 0:
       checkpoint = algo.save(checkpoint_dir= checkpoint_path)
       print("checkpoint saved at", checkpoint)

## Stable baselines 3

In [None]:
# Train on episodes of length 32 time steps
train_eps_length = 32
train_envs = 4

# Evaluate on entire scene (~248 time steps)
eval_eps_length = None
eval_envs = 1

# make train env
train_sim_cfg = SimulationConfigGym()
train_sim_cfg.num_simulation_steps = train_eps_length + 1
env_kwargs = {'env_config_path': env_config_path, 'use_kinematic': True, 'sim_cfg': train_sim_cfg}
env = make_vec_env("L5-CLE-v0", env_kwargs=env_kwargs, n_envs=train_envs,
                   vec_env_cls=SubprocVecEnv, vec_env_kwargs={"start_method": "fork"})

# make eval env
validation_sim_cfg = SimulationConfigGym()
validation_sim_cfg.num_simulation_steps = None
eval_env_kwargs = {'env_config_path': env_config_path, 'use_kinematic': True, \
                   'return_info': True, 'train': False, 'sim_cfg': validation_sim_cfg}
eval_env = make_vec_env("L5-CLE-v0", env_kwargs=eval_env_kwargs, n_envs=eval_envs,
                        vec_env_cls=SubprocVecEnv, vec_env_kwargs={"start_method": "fork"})

### Define backbone feature extractor

The backbone feature extractor is shared between the policy and the value networks. The feature extractor *simple_gn* is composed of two convolutional networks followed by a fully connected layer, with ReLU activation. The feature extractor output is passed to both the policy and value networks composed of two fully connected layers with tanh activation (SB3 default).

We perform **group normalization** after every convolutional layer. Empirically, we found that group normalization performs far superior to batch normalization. This can be attributed to the fact that activation statistics change quickly in on-policy algorithms (PPO is on-policy) while batch-norm learnable parameters can be slow to update causing training issues.

In [None]:
# A simple 2 Layer CNN architecture with group normalization
model_arch = 'simple_gn'
features_dim = 128

# Custom Feature Extractor backbone
policy_kwargs = {
    "features_extractor_class": CustomFeatureExtractor,
    "features_extractor_kwargs": {"features_dim": features_dim, "model_arch": model_arch},
    "normalize_images": False
}

### Clipping Schedule

We linearly decrease the value of the clipping parameter $\epsilon$ as the PPO training progress as it shows improved training stability

In [None]:
# Clipping schedule of PPO epsilon parameter
start_val = 0.1
end_val = 0.01
training_progress_ratio = 1.0
clip_schedule = get_linear_fn(start_val, end_val, training_progress_ratio)

### Hyperparameters for PPO. 

For detailed description, refer https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/ppo/ppo.html#PPO

In [None]:
lr = 3e-4
num_rollout_steps = 256
gamma = 0.8
gae_lambda = 0.9
n_epochs = 10
seed = 42
batch_size = 64
tensorboard_log = '/content/drive/MyDrive/Colab Notebooks/l5kit/tb_logs/' + str(datetime.date.today()) + '/'

### Define the PPO Policy.

SB3 provides an easy interface to the define the PPO policy. Note: We do need to tweak appropriate hyperparameters and the custom policy backbone has been defined above.


In [None]:
# define model
model = PPO("MultiInputPolicy", env, policy_kwargs=policy_kwargs, verbose=1, n_steps=num_rollout_steps,
            learning_rate=lr, gamma=gamma, tensorboard_log=tensorboard_log, n_epochs=n_epochs,
            clip_range=clip_schedule, batch_size=batch_size, seed=seed, gae_lambda=gae_lambda)

### Defining Callbacks

We can additionally define callbacks to save model checkpoints and evaluate models during training.

In [None]:
callback_list = []

# Save Model Periodically
save_freq = 10000
save_path = '/content/drive/MyDrive/Colab Notebooks/l5kit/logs/'+ str(datetime.date.today())
output = 'PPO'
checkpoint_callback = CheckpointCallback(save_freq=(save_freq // train_envs), save_path=save_path, \
                                         name_prefix=output)
callback_list.append(checkpoint_callback)

# Eval Model Periodically
eval_freq = 10000
n_eval_episodes = 1
val_eval_callback = L5KitEvalCallback(eval_env, eval_freq=(eval_freq // train_envs), \
                                      n_eval_episodes=n_eval_episodes, n_eval_envs=eval_envs)
callback_list.append(val_eval_callback)


### Train

In [None]:
outdir = '/content/drive/MyDrive/Colab Notebooks/l5kit/ppo_interrupted/'+ str(datetime.date.today()) + '/'
model_name = 'PPO'
try:
    n_steps = 1000000
    model.learn(n_steps, callback=callback_list)
except:
    model.save(outdir + model_name)
    # model.save_replay_buffer(outdir + model_name+ "_buffer")
    model.policy.save(outdir + model_name + "_policy")

**Voila!** We have a trained PPO policy! Train for larger number of steps for better accuracy. Typical RL algorithms require training atleast 1M steps for good convergence. You can visualize the quantitiative evaluation using tensorboard.

In [None]:
model = PPO.load('/content/drive/MyDrive/Colab Notebooks/l5kit/ppo_interrupted/PPO2022-11-16.zip', env = env)


In [None]:
model = PPO.load('/content/drive/MyDrive/Colab Notebooks/l5kit/logs/2022-11-25/PPO_4490000_steps.zip', env = env)
n_steps = 1000000
# model = PPO.load('./PPO_100000_steps.zip', env = env)
model.learn(n_steps, callback=callback_list, reset_num_timesteps=False)


### Visualize the episode from the environment

We can easily visualize the outputs obtained by rolling out episodes in the L5Kit using the Bokeh visualizer.

In [None]:
model = PPO.load('/content/drive/MyDrive/Colab Notebooks/l5kit/logs/2022-11-23/PPO_4210000_steps.zip', env = env)

In [None]:
# Train on episodes of length 32 time steps
train_eps_length = 32
train_envs = 4

# Evaluate on entire scene (~248 time steps)
eval_eps_length = None
eval_envs = 1

# make train env
train_sim_cfg = SimulationConfigGym()
train_sim_cfg.num_simulation_steps = 32 + 1
env_kwargs = {'env_config_path': env_config_path, 'use_kinematic': True, 'sim_cfg': train_sim_cfg}
env = make_vec_env("L5-CLE-v0", env_kwargs=env_kwargs, n_envs=train_envs,
                   vec_env_cls=SubprocVecEnv, vec_env_kwargs={"start_method": "fork"})

# make eval env
validation_sim_cfg = SimulationConfigGym()
validation_sim_cfg.num_simulation_steps = None
eval_env_kwargs = {'env_config_path': env_config_path, 'use_kinematic': True, \
                   'return_info': True, 'train': False, 'sim_cfg': validation_sim_cfg}
eval_env = make_vec_env("L5-CLE-v0", env_kwargs=eval_env_kwargs, n_envs=eval_envs,
                        vec_env_cls=SubprocVecEnv, vec_env_kwargs={"start_method": "fork"})

In [None]:
rollout_sim_cfg = SimulationConfigGym()
rollout_sim_cfg.num_simulation_steps = None
rollout_env = gym.make("L5-CLE-v0", env_config_path=env_config_path, sim_cfg=rollout_sim_cfg, \
                       use_kinematic=True, train=False, return_info=True)

def rollout_episode(model, env, idx = 0):
    """Rollout a particular scene index and return the simulation output.

    :param model: the RL policy
    :param env: the gym environment
    :param idx: the scene index to be rolled out
    :return: the episode output of the rolled out scene
    """

    # Set the reset_scene_id to 'idx'
    env.reset_scene_id = idx
    
    # Rollout step-by-step
    obs = env.reset()
    done = False
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, _, done, info = env.step(action)
        if done:
            break

    # The episode outputs are present in the key "sim_outs"
    sim_out = info["sim_outs"][0]
    return sim_out

# Rollout one episode
# sim_out = rollout_episode(model, rollout_env)
# Rollout 5 episodes
sim_outs =[]
for i in range(5):
    sim_outs.append(rollout_episode(model, rollout_env))

In [None]:
# might change with different rasterizer
map_API = rollout_env.dataset.rasterizer.sem_rast.mapAPI

def visualize_outputs(sim_outs, map_API):
    for sim_out in sim_outs: # for each scene
        vis_in = episode_out_to_visualizer_scene_gym_cle(sim_out, map_API)
        show(visualize(sim_out.scene_id, vis_in))

output_notebook()
visualize_outputs(sim_outs, map_API)

## Calculate the performance metrics from the episode outputs

We can also calculate the various quantitative metrics on the rolled out episode output. 

In [None]:
def quantify_outputs(sim_outs, metric_set=None):
    metric_set = metric_set if metric_set is not None else L2DisplacementYawMetricSet()

    metric_set.evaluate(sim_outs)
    scene_results = metric_set.evaluator.scene_metric_results
    fields = ["scene_id", "FDE", "ADE"]
    table = PrettyTable(field_names=fields)
    tot_fde = 0.0
    tot_ade = 0.0
    for scene_id in scene_results:
        scene_metrics = scene_results[scene_id]
        ade_error = scene_metrics["displacement_error_l2"][1:].mean()
        fde_error = scene_metrics['displacement_error_l2'][-1]
        table.add_row([scene_id, round(fde_error.item(), 4), round(ade_error.item(), 4)])
        tot_fde += fde_error.item()
        tot_ade += ade_error.item()

    ave_fde = tot_fde / len(scene_results)
    ave_ade = tot_ade / len(scene_results)
    table.add_row(["Overall", round(ave_fde, 4), round(ave_ade, 4)])
    print(table)


quantify_outputs(sim_outs)

+----------+----------+---------+
| scene_id |   FDE    |   ADE   |
+----------+----------+---------+
|    18    | 123.9884 | 58.2833 |
|    76    | 94.0163  | 62.4942 |
|    80    | 95.1473  | 20.5979 |
|    11    | 25.0332  | 18.1672 |
|    16    | 80.5144  |  31.691 |
| Overall  | 83.7399  | 38.2467 |
+----------+----------+---------+


In [None]:
def quantify_outputs(sim_outs, metric_set=None):
    metric_set = metric_set if metric_set is not None else CLEMetricSet()

    metric_set.evaluate(sim_outs)
    scene_results = metric_set.evaluator.scene_metric_results
    fields = ["scene_id", "FDE", "ADE", "DRT", "CF", "CR", "CS", "PEGO"]
    table = PrettyTable(field_names=fields)
    tot_fde = 0.0
    tot_ade = 0.0
    tot_drt = 0.0
    tot_cf = 0.0
    tot_cr = 0.0
    tot_cs = 0.0
    tot_p_ego = 0.0
    tot_a_ego = 0.0
    # print(scene_results[0])
    for scene_id in scene_results:
        scene_metrics = scene_results[scene_id]
        ade_error = scene_metrics["displacement_error_l2"][1:].mean()
        fde_error = scene_metrics['displacement_error_l2'][-1]
        drt_error = scene_metrics['distance_to_reference_trajectory'][-1]
        cf_error = scene_metrics['collision_front'][-1]
        cr_error = scene_metrics['collision_rear'][-1]
        cs_error = scene_metrics['collision_side'][-1]
        p_ego = scene_metrics['simulated_minus_recorded_ego_speed'][-1]
        # a_ego = scene_metrics['aggressive_ego'][-1]
        table.add_row([scene_id, round(fde_error.item(), 4), round(ade_error.item(), 4), round(drt_error.item(), 4), round(cf_error.item(), 4), round(cr_error.item(), 4), 
        round(cs_error.item(), 4), round(p_ego.item(), 4)])
        tot_fde += fde_error.item()
        tot_ade += ade_error.item()
        tot_drt += drt_error.item()
        tot_cf += cf_error.item()
        tot_cr += cr_error.item()
        tot_cs += cs_error.item()
        tot_p_ego += p_ego.item()
        # tot_a_ego += a_ego.item()

    ave_fde = tot_fde / len(scene_results)
    ave_ade = tot_ade / len(scene_results)
    ave_drt = tot_drt / len(scene_results)
    ave_cf = tot_cf / len(scene_results)
    ave_cr = tot_cr / len(scene_results)
    ave_cs = tot_cs / len(scene_results)
    ave_p_ego = tot_p_ego / len(scene_results)
    # ave_a_ego = tot_a_ego / len(scene_results)
    table.add_row(["Overall", round(ave_fde, 4), round(ave_ade, 4), round(ave_drt, 4), round(ave_cf, 4), round(ave_cr, 4), round(ave_cs, 4), round(ave_p_ego, 4)])
    print(table)


quantify_outputs(sim_outs)

+----------+----------+---------+---------+-----+-----+-----+----------+
| scene_id |   FDE    |   ADE   |   DRT   |  CF |  CR |  CS |   PEGO   |
+----------+----------+---------+---------+-----+-----+-----+----------+
|    18    | 123.9884 | 58.2833 |  14.79  | 0.0 | 0.0 | 0.0 | -3.0587  |
|    76    | 94.0163  | 62.4942 |  4.3439 | 0.0 | 0.0 | 0.0 |  2.7008  |
|    80    | 95.1473  | 20.5979 |  4.7899 | 0.0 | 0.0 | 0.0 | -13.0379 |
|    11    | 25.0332  | 18.1672 |  3.202  | 0.0 | 0.0 | 0.0 |  0.678   |
|    16    | 80.5144  |  31.691 | 22.2307 | 0.0 | 0.0 | 0.0 |  0.138   |
| Overall  | 83.7399  | 38.2467 |  9.8713 | 0.0 | 0.0 | 0.0 |  -2.516  |
+----------+----------+---------+---------+-----+-----+-----+----------+
