# Offline reinforcement learning with Ray AIR
In this example, we'll train a reinforcement learning agent using offline training.

Offline training means that the data from the environment (and the actions performed by the agent) have been stored on disk. In contrast, online training samples experiences live by interacting with the environment.
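To make the distinction concrete, here is a minimal plain-Python sketch. `ToyEnv` and the uniformly random behavior policy are hypothetical stand-ins, not RLlib or Gymnasium APIs; `reset`/`step` only mimic the Gymnasium return values. The online phase calls `env.step` and records every transition. The offline phase only reads that log, and its "learner" simply estimates the average reward of each action.

```python
import random
from typing import Dict, List, Tuple

Transition = Tuple[int, int, float]  # (observation, action, reward)


class ToyEnv:
    """Hypothetical environment mimicking the Gymnasium reset/step return values."""

    def __init__(self, max_steps: int = 5):
        self.max_steps = max_steps
        self.t = 0

    def reset(self) -> Tuple[int, dict]:
        self.t = 0
        return self.t, {}

    def step(self, action: int) -> Tuple[int, float, bool, bool, dict]:
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        terminated = self.t >= self.max_steps
        return self.t, reward, terminated, False, {}


def collect_episode(env: ToyEnv, rng: random.Random) -> List[Transition]:
    """Online phase: interact with the live environment and record every step."""
    transitions = []
    obs, _ = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        action = rng.choice([0, 1])  # behavior policy: uniformly random actions
        next_obs, reward, terminated, truncated, _ = env.step(action)
        transitions.append((obs, action, reward))
        obs = next_obs
    return transitions


def average_reward_per_action(dataset: List[Transition]) -> Dict[int, float]:
    """Offline phase: learn only from stored transitions, never calling env.step()."""
    totals: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for _, action, reward in dataset:
        totals[action] = totals.get(action, 0.0) + reward
        counts[action] = counts.get(action, 0) + 1
    return {action: totals[action] / counts[action] for action in totals}


rng = random.Random(0)
dataset = collect_episode(ToyEnv(), rng) + collect_episode(ToyEnv(), rng)
print(average_reward_per_action(dataset))
```

Once the log exists, the offline phase can run any number of times without touching the environment again.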

Let's start by installing our dependencies:

In [1]:
# !pip install -qU "ray[rllib]" gymnasium

Now we can run some imports:

In [1]:
import gymnasium as gym

import numpy as np
import ray
from ray.air import Checkpoint
from ray.air.config import RunConfig, ScalingConfig
from ray.air.result import Result
from ray.rllib.algorithms.bc import BC
from ray.train.rl.rl_predictor import RLPredictor
from ray.train.rl.rl_trainer import RLTrainer
from ray.tune.tuner import Tuner


We will train on offline data, which means that full agent trajectories are stored somewhere on disk and we want to learn from these past experiences.

In practice, this data usually comes from external systems or a database of historical data. For this example, however, we'll generate some offline data ourselves and store it using RLlib's `output_config`.

In [2]:
def generate_offline_data(path: str):
    print(f"Generating offline data for training at {path}")
    trainer = RLTrainer(
        algorithm="PPO",
        run_config=RunConfig(stop={"timesteps_total": 5000}),
        config={
            "env": "CartPole-v1",
            "output": "dataset",
            "output_config": {
                "format": "json",
                "path": path,
                "max_num_samples_per_file": 1,
            },
            "batch_mode": "complete_episodes",
            "framework": "torch"
        },
    )
    trainer.fit()
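The `"output": "dataset"` and `output_config` settings above make RLlib write the sampled experiences as JSON files under `path`. RLlib defines its own record layout for these files. The standard-library sketch below only illustrates the general idea of persisting trajectories as JSON lines and loading them back for offline training. The one-episode-per-line layout, the field names (`eps_id`, `obs`, `actions`, `rewards`) and the file name are assumptions for illustration, not RLlib's schema.

```python
import json
import os
import tempfile
from typing import Dict, List


def write_episodes(file_path: str, episodes: List[Dict]) -> None:
    """Store one episode per line as a JSON object (assumed layout, not RLlib's schema)."""
    with open(file_path, "w", encoding="utf-8") as f:
        for episode in episodes:
            f.write(json.dumps(episode) + "\n")


def read_episodes(file_path: str) -> List[Dict]:
    """Load every stored episode back into memory, skipping blank lines."""
    with open(file_path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Dummy trajectories with hypothetical field names, for illustration only.
episodes = [
    {"eps_id": 0, "obs": [[0.0, 0.0, 0.0, 0.0], [0.0, 0.1, 0.0, -0.1]], "actions": [1, 0], "rewards": [1.0, 1.0]},
    {"eps_id": 1, "obs": [[0.0, 0.0, 0.0, 0.0]], "actions": [0], "rewards": [1.0]},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    file_path = os.path.join(tmp_dir, "episodes.json")  # hypothetical file name
    write_episodes(file_path, episodes)
    loaded = read_episodes(file_path)

# An episode's return is simply the sum of its stored rewards.
returns = [sum(episode["rewards"]) for episode in loaded]
print(returns)
```

A line-per-record layout like this lets a reader stream or shard the data file by file, which is one reason JSON-lines style formats are common for logged experiences.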

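Here we define the training function. It creates an `RLTrainer` that uses the `BC` (behavior cloning) algorithm to train a policy for the `CartPole-v1` environment from the offline data in `path`, which is loaded as a Ray Dataset. After every training iteration, one evaluation worker runs the current policy in the live environment (`"input": "sampler"`) to measure its reward.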

In [3]:
def train_rl_bc_offline(path: str, num_workers: int, use_gpu: bool = False) -> Result:
    print("Starting offline training")
    dataset = ray.data.read_json(
        path, parallelism=num_workers, ray_remote_args={"num_cpus": 1}
    )

    trainer = RLTrainer(
        run_config=RunConfig(stop={"training_iteration": 5}),
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
        datasets={"train": dataset},
        algorithm=BC,
        config={
            "env": "CartPole-v1",
            # Evaluate in the live environment after every training iteration.
            "evaluation_num_workers": 1,
            "evaluation_interval": 1,
            "evaluation_config": {"input": "sampler"},
            "framework": "torch",
        },
    )

    # Todo (krfricke/xwjiang): Enable checkpoint config in RunConfig
    # result = trainer.fit()
    tuner = Tuner(
        trainer,
        _tuner_kwargs={"checkpoint_at_end": True},
    )
    result = tuner.fit()[0]
    return result
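Behavior cloning, the algorithm behind `BC`, turns the offline data into a supervised-learning problem. It fits a policy that maximizes the likelihood of the logged actions given the logged observations, and uses no rewards during training. The sketch below shows that idea in plain Python for CartPole's two discrete actions, using a logistic-regression policy trained by gradient descent. It is an illustration only: RLlib's `BC` trains a neural-network model instead. The synthetic "expert" that pushes right whenever the pole angle `obs[2]` is positive is also an assumption.

```python
import math
import random
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_bc(
    observations: List[Sequence[float]], actions: List[int], lr: float = 0.5, epochs: int = 200
) -> Tuple[List[float], float]:
    """Fit P(action=1 | obs) by gradient descent on the negative log-likelihood of the logged actions."""
    dim = len(observations[0])
    w, b = [0.0] * dim, 0.0
    n = len(observations)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, a in zip(observations, actions):
            p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            err = p - a  # derivative of the Bernoulli NLL with respect to the logit
            for i in range(dim):
                grad_w[i] += err * x[i]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def act(w: List[float], b: float, obs: Sequence[float]) -> int:
    """Greedy action of the cloned policy."""
    return 1 if sum(wi * xi for wi, xi in zip(w, obs)) + b > 0 else 0


# Hypothetical "expert" data: push right (1) when the pole angle obs[2] is positive.
rng = random.Random(0)
observations = [[rng.uniform(-1.0, 1.0) for _ in range(4)] for _ in range(200)]
actions = [1 if obs[2] > 0 else 0 for obs in observations]

w, b = train_bc(observations, actions)
accuracy = sum(act(w, b, o) == a for o, a in zip(observations, actions)) / len(actions)
print(f"agreement with the logged actions: {accuracy:.2f}")
```

Because the policy only imitates the logged actions, it can be no better than the behavior that produced the data. This is why the quality of the offline dataset matters so much for BC.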

Once we have trained our RL policy, we want to evaluate it on a fresh environment. For this, we also define a utility function:

In [4]:
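def evaluate_using_checkpoint(checkpoint: Checkpoint, num_episodes: int) -> list:
    # Restore the trained policy from the AIR checkpoint.
    predictor = RLPredictor.from_checkpoint(checkpoint)

    env = gym.make("CartPole-v1")

    rewards = []
    for _ in range(num_episodes):
        obs, _ = env.reset()
        reward = 0.0
        terminated = truncated = False
        # Run one episode until the pole falls (terminated) or the time limit is hit (truncated).
        while not terminated and not truncated:
            # The predictor expects a batch, so wrap the single observation and unwrap the action.
            action = predictor.predict(np.array([obs]))
            obs, r, terminated, truncated, _ = env.step(action[0])
            reward += r
        rewards.append(reward)

    return rewards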
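The loop checks both `terminated` and `truncated` because, under the Gymnasium API, an episode can end either because the task failed (for CartPole, the pole fell) or because a time limit was reached. `CartPole-v1` gives a reward of +1 per step and truncates episodes at 500 steps, so 500 is the highest return this evaluation can report. The sketch below runs the same episode loop without Ray or a checkpoint. `ToyTimeLimitEnv` is a hypothetical stand-in that rewards +1 per step, terminates when the policy outputs action 0, and truncates after `max_episode_steps` steps. The two policies are plain callables, also hypothetical.

```python
from statistics import mean
from typing import Callable, List, Tuple


class ToyTimeLimitEnv:
    """Hypothetical env: +1 reward per step, terminates on action 0, truncates at a time limit."""

    def __init__(self, max_episode_steps: int = 500):
        self.max_episode_steps = max_episode_steps
        self.t = 0

    def reset(self) -> Tuple[float, dict]:
        self.t = 0
        return float(self.t), {}

    def step(self, action: int) -> Tuple[float, float, bool, bool, dict]:
        self.t += 1
        terminated = action == 0  # the "pole falls" when the policy picks action 0
        truncated = self.t >= self.max_episode_steps  # time limit reached
        return float(self.t), 1.0, terminated, truncated, {}


def evaluate(env: ToyTimeLimitEnv, policy: Callable[[float], int], num_episodes: int) -> List[float]:
    """Same episode loop as evaluate_using_checkpoint, with a plain callable as the policy."""
    rewards = []
    for _ in range(num_episodes):
        obs, _ = env.reset()
        episode_reward = 0.0
        terminated = truncated = False
        while not terminated and not truncated:
            obs, r, terminated, truncated, _ = env.step(policy(obs))
            episode_reward += r
        rewards.append(episode_reward)
    return rewards


def always_right(obs: float) -> int:
    return 1  # never "fails", so every episode ends by truncation


def falls_early(obs: float) -> int:
    return 0 if obs >= 19 else 1  # "fails" on the 20th step


print(mean(evaluate(ToyTimeLimitEnv(), always_right, num_episodes=3)))
print(mean(evaluate(ToyTimeLimitEnv(), falls_early, num_episodes=3)))
```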

Let's put it all together. First, we initialize Ray and create the offline data:

In [5]:
ray.init(num_cpus=8)

path = "/tmp/out"
generate_offline_data(path)

2023-03-29 13:35:26,608	INFO worker.py:1612 -- Started a local Ray instance.


Generating offline data for training at /tmp/out


Current time: 2023-03-29 13:35:44
Running for: 00:00:17.76
Memory: 17.7/32.0 GiB

| Trial name | status | loc | iter | total time (s) | ts | reward | episode_reward_max | episode_reward_min | episode_len_mean |
|---|---|---|---|---|---|---|---|---|---|
| AIRPPO_3d2ab_00000 | TERMINATED | 127.0.0.1:10755 | 2 | 8.53423 | 8111 | 39.0381 | 163 | 9 | 39.0381 |


[2m[36m(pid=10755)[0m   DESCRIPTOR = _descriptor.FileDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.FieldDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.FieldDescriptor(
[2m[36m(pid=10755)[0m   _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor(
[2m[36m(pid=10755)[0m   _descriptor.FieldDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.FieldDescriptor(
[2m[36m(pid=10755)[0m   _TENSORSHAPEPROTO = _descriptor.Descriptor(
[2m[36m(pid=10755)[0m   DESCRIPTOR = _descriptor.FileDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10755)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=

[2m[36m(pid=10755)[0m   if (distutils.version.LooseVersion(tf.__version__) <
[2m[36m(pid=10755)[0m   distutils.version.LooseVersion(required_tensorflow_version)):
[2m[36m(pid=10755)[0m Instructions for updating:
[2m[36m(pid=10755)[0m experimental_relax_shapes is deprecated, use reduce_retracing instead
[2m[36m(AIRPPO pid=10755)[0m 2023-03-29 13:35:31,721	INFO algorithm.py:527 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
[2m[36m(pid=10777)[0m   DESCRIPTOR = _descriptor.FileDescriptor([32m [repeated 10x across cluster][0m
[2m[36m(pid=10777)[0m   _descriptor.FieldDescriptor([32m [repeated 80x across cluster][0m
[2m[36m(pid=10777)[0m   _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor([32m [repeated 2x across cluster][0m
[2m[36m(pid=10777)[0m   _TENSORSHAPEPROTO = _descriptor.Descriptor([32m [repeated 2x across cluster][0m
[2m[36m(pid=10777)[0m   _descriptor.EnumValueDescriptor([32m [r

[2m[36m(pid=10777)[0m Write: 0 active, 0 queued, 0.0 MiB objects 3:   0%|                                                                                                                                    | 0/1 [00:00<?, ?it/s][0m[A[A[A




[2m[36m(pid=10776)[0m Resource usage vs limits 0:   0%|                                                                                                                                                      | 0/1 [00:00<?, ?it/s][0m[A[A[A[A[A





[2m[36m(pid=10776)[0m Repartition 1:   0%|                                                                                                                                                                   | 0/1 [00:00<?, ?it/s][0m[A[A[A[A[A[A






[2m[36m(pid=10776)[0m Repartition 2:   0%|                                                                                                                                                                   | 0/1 [00:00<?, ?it/s][0m[A

[2m[36m(pid=10777)[0m Resource usage vs limits: 0.0/8.0 CPU, 0.0/0.0 GPU, 0.0 MiB/512.0 MiB object_store_memory 0:   0%|                                                                                     | 0/1 [00:01<?, ?it/s][0m[A[A[A



[2m[36m(pid=10777)[0m output: 1 queued 4:   0%|                                                                                                                                                              | 0/1 [00:01<?, ?it/s][0m[A[A[A[A


[2m[36m(pid=10777)[0m Write: 0 active, 0 queued, 0.0 MiB objects 3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.06s/it][0m[A[A[A



                                                                                                                                                                                                                         0,  1.06s/it][0m[A[A[A[A
          



[2m[36m(pid=10776)[0m Resource usage vs limits 0:   0%|                                                                                                                                                      | 0/1 [00:00<?, ?it/s][0m
[2m[36m(pid=10776)[0m Repartition 1:   0%|                                                                                                                                                                   | 0/1 [00:00<?, ?it/s][0m[A

[2m[36m(pid=10776)[0m Repartition 2:   0%|                                                                                                                                                                   | 0/1 [00:00<?, ?it/s][0m[A[A

[2m[36m(pid=10776)[0m   *- Repartition 2:   0%|                                                                                                                                                              | 0/1 [00:00<?, ?it/s][0m[A[A


[2m[36m(pid=10776)[0m Write 3:   0%|  

[2m[36m(pid=10777)[0m Repartition 2:   0%|                                                                                                                                                                   | 0/1 [00:00<?, ?it/s][0m[A[A

[A[A                                                                                                                                                                                                                   


[2m[36m(pid=10776)[0m Resource usage vs limits: 0.0/8.0 CPU, 0.0/0.0 GPU, 0.0 MiB/512.0 MiB object_store_memory 0:   0%|                                                                                     | 0/1 [00:00<?, ?it/s][0m[A[A[A


[A[A[A                                                                                                                                                                                                                



[2m[36m(pid=10776)[0m Repartition: 0 active, 0 queued, 0.0 MiB objec

[2m[36m(pid=10777)[0m output: 0 queued 4:   0%|                                                                                                                                                              | 0/1 [00:00<?, ?it/s][0m[A[A[A[A
[2m[36m(pid=10777)[0m Repartition: 0 active, 0 queued, 0.0 MiB objects, 0 output 1:   0%|                                                                                                                    | 0/1 [00:00<?, ?it/s][0m[A


[2m[36m(pid=10777)[0m Write: 0 active, 0 queued, 0.0 MiB objects 3:   0%|                                                                                                                                    | 0/1 [00:00<?, ?it/s][0m[A[A[A
[2m[36m(pid=10777)[0m Resource usage vs limits: 0.0/8.0 CPU, 0.0/0.0 GPU, 0.3 MiB/512.0 MiB object_store_memory 0:   0%|                                                                                     | 0/1 [00:00<?, ?it/s][0m[A
[2m[36m(pid=10777)[0m Repart

Then, we run training:

In [6]:
result = train_rl_bc_offline(path=path, num_workers=2, use_gpu=False)

Starting offline training


Current time: 2023-03-29 13:36:00
Running for: 00:00:14.26
Memory: 17.9/32.0 GiB

| Trial name | status | loc | iter | total time (s) | ts | reward | episode_reward_max | episode_reward_min | episode_len_mean |
|---|---|---|---|---|---|---|---|---|---|
| AIRBC_484d3_00000 | TERMINATED | 127.0.0.1:10841 | 5 | 0.859835 | 20292 | | | | |


[2m[36m(pid=10841)[0m   DESCRIPTOR = _descriptor.FileDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.FieldDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.FieldDescriptor(
[2m[36m(pid=10841)[0m   _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor(
[2m[36m(pid=10841)[0m   _descriptor.FieldDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.FieldDescriptor(
[2m[36m(pid=10841)[0m   _TENSORSHAPEPROTO = _descriptor.Descriptor(
[2m[36m(pid=10841)[0m   DESCRIPTOR = _descriptor.FileDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=10841)[0m   _descriptor.EnumValueDescriptor(
[2m[36m(pid=

[2m[36m(pid=10841)[0m   if (distutils.version.LooseVersion(tf.__version__) <
[2m[36m(pid=10841)[0m   distutils.version.LooseVersion(required_tensorflow_version)):
[2m[36m(pid=10841)[0m Instructions for updating:
[2m[36m(pid=10841)[0m experimental_relax_shapes is deprecated, use reduce_retracing instead
[2m[36m(AIRBC pid=10841)[0m 2023-03-29 13:35:50,196	INFO algorithm.py:527 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
[2m[36m(AIRBC pid=10841)[0m 2023-03-29 13:35:50,201	INFO streaming_executor.py:83 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadJSON] -> AllToAllOperator[Repartition]
[2m[36m(pid=10841)[0m Resource usage vs limits 0:   0%|                                                                                                                                                      | 0/1 [00:00<?, ?it/s][0m
[2m[36m(pid=10841)[0m ReadJSON 1:   0%|                        

[2m[36m(RolloutWorker pid=10887)[0m DatasetReader 1 has 2, samples.


[2m[36m(pid=10888)[0m   if (distutils.version.LooseVersion(tf.__version__) <[32m [repeated 2x across cluster][0m
[2m[36m(pid=10888)[0m   distutils.version.LooseVersion(required_tensorflow_version)):[32m [repeated 2x across cluster][0m
[2m[36m(pid=10887)[0m Instructions for updating:[32m [repeated 2x across cluster][0m
[2m[36m(pid=10887)[0m experimental_relax_shapes is deprecated, use reduce_retracing instead[32m [repeated 2x across cluster][0m
[2m[36m(RolloutWorker pid=10887)[0m 2023-03-29 13:35:59,108	INFO streaming_executor.py:83 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[RandomShuffle]
[2m[36m(RolloutWorker pid=10887)[0m 2023-03-29 13:35:59,136	INFO streaming_executor.py:83 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[RandomShuffle]


Evaluation results after the final training iteration (from the `evaluation` field of the trial result):

| episode_reward_mean | episode_reward_min | episode_reward_max | episodes_this_iter |
|---|---|---|---|
| 26.7 | 14.0 | 50.0 | 10 |


[2m[36m(pid=10927)[0m   DESCRIPTOR = _descriptor.FileDescriptor([32m [repeated 5x across cluster][0m
[2m[36m(pid=10927)[0m   _descriptor.FieldDescriptor([32m [repeated 40x across cluster][0m
[2m[36m(pid=10927)[0m 2023-03-29 13:35:59,505	INFO streaming_executor.py:83 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[RandomShuffle]
[2m[36m(pid=10927)[0m 2023-03-29 13:35:59,505	INFO streaming_executor.py:83 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[RandomShuffle]
[2m[36m(pid=10927)[0m   _descriptor.EnumValueDescriptor([32m [repeated 47x across cluster][0m
[2m[36m(pid=10927)[0m 2023-03-29 13:35:59,505	INFO streaming_executor.py:83 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[RandomShuffle]
[2m[36m(pid=10927)[0m 2023-03-29 13:35:59,505	INFO streaming_executor.py:83 -- Executing DAG InputDataBuffer[Input] -> AllToAllOperator[RandomShuffle]
[2m[36m(pid=10927)[0m 2023-03-29 13:35:59,505	INFO streaming_executor.py:83 -- Ex

Finally, we use the resulting checkpoint to evaluate the policy on a fresh environment:

In [7]:
num_eval_episodes = 3

rewards = evaluate_using_checkpoint(result.checkpoint, num_episodes=num_eval_episodes)
print(f"Average reward over {num_eval_episodes} episodes: {np.mean(rewards)}")

2023-03-29 13:36:00,540	INFO policy.py:1284 -- Policy (worker=local) running on CPU.
2023-03-29 13:36:00,540	INFO torch_policy_v2.py:110 -- Found 0 visible cuda devices.


Average reward over 3 episodes: 18.333333333333332
