In [None]:
from ray.rllib.algorithms.ppo import PPOConfig

config = (  # 1. Configure the algorithm,
    PPOConfig()
    .environment("Taxi-v3")
    .env_runners(num_env_runners=2)
    .framework("torch")
    .training(model={"fcnet_hiddens": [64, 64]})
    .evaluation(evaluation_num_env_runners=1)
)

algo = config.build()  # 2. build the algorithm,

for _ in range(5):
    print(algo.train())  # 3. train it,

algo.evaluate()  # 4. and evaluate it

In [None]:
from pprint import pprint

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .env_runners(num_env_runners=1)
)

algo = config.build()

for i in range(10):
    result = algo.train()
    result.pop("config")
    pprint(result)

    if i % 5 == 0:
        checkpoint_dir = algo.save()
        print(f"Checkpoint saved in directory {checkpoint_dir}")

In [None]:
from ray import train, tune

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .training(
        lr=tune.grid_search([0.01, 0.001, 0.0001]),
    )
)

tuner = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop={"env_runners/episode_return_mean": 150.0},
    ),
)

tuner.fit()

In [None]:
from ray import train, tune

# Tuner.fit() allows setting a custom log directory (other than ~/ray-results).
tuner = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop={"num_env_steps_sampled_lifetime": 20000},
        checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),
    ),
)

results = tuner.fit()

# Get the best result based on a particular metric.
best_result = results.get_best_result(
    metric="env_runners/episode_return_mean", mode="max"
)

# Get the best checkpoint corresponding to the best result.
best_checkpoint = best_result.checkpoint

print(f"Best learning rate: {best_result.config['lr']}")

In [None]:
import pathlib
import gymnasium as gym
import numpy as np
import torch
from ray.rllib.core.rl_module import RLModule

env = gym.make("CartPole-v1")

# Create only the neural network (RLModule) from our checkpoint.
rl_module = RLModule.from_checkpoint(
    pathlib.Path(best_checkpoint.path) / "learner_group" / "learner" / "rl_module"
)["default_policy"]

episode_return = 0
terminated = truncated = False

obs, info = env.reset()

while not terminated and not truncated:
    # Compute the next action from a batch (B=1) of observations.
    torch_obs_batch = torch.from_numpy(np.array([obs]))
    action_logits = rl_module.forward_inference({"obs": torch_obs_batch})[
        "action_dist_inputs"
    ]
    # The default RLModule used here produces action logits (from which
    # we'll have to sample an action or use the max-likelihood one).
    action = torch.argmax(action_logits[0]).numpy()
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward

print(f"Reached episode return of {episode_return}.")

In [1]:
#@title Run to install MuJoCo and `dm_control`
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# print('Installing dm_control...')
# !pip install -q dm_control>=1.0.18

# Configure dm_control to use the EGL rendering backend (requires GPU)
%env MUJOCO_GL=egl

print('Checking that the dm_control installation succeeded...')
try:
  from dm_control import suite
  env = suite.load('cartpole', 'swingup')
  pixels = env.physics.render()
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')
else:
  del pixels, suite

!echo Installed dm_control $(pip show dm_control | grep -Po "(?<=Version: ).+")

Sat Oct 19 17:43:27 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:4B:00.0 Off |                  Off |
|  0%   31C    P8             20W /  450W |     277MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00

In [2]:
#@title Other imports and helper functions

# General
import copy
import os
import itertools
from IPython.display import clear_output
import numpy as np

# Graphics-related
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML
import PIL.Image
# Internal loading of video libraries.

# Use svg backend for figure rendering
%config InlineBackend.figure_format = 'svg'

# Font sizes
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12
plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

# Inline video helper function
if os.environ.get('COLAB_NOTEBOOK_TEST', False):
  # We skip video generation during tests, as it is quite expensive.
  display_video = lambda *args, **kwargs: None
else:
  def display_video(frames, framerate=30):
    height, width, _ = frames[0].shape
    dpi = 70

    orig_backend = matplotlib.get_backend()
    matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.
    fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
    # fig, ax = plt.subplots(1, 1, figsize=(scaled_width / dpi, scaled_height / dpi), dpi=dpi)
    matplotlib.use(orig_backend)  # Switch back to the original backend.
    ax.set_axis_off()
    ax.set_aspect('equal')
    ax.set_position([0, 0, 1, 1])
    im = ax.imshow(frames[0])
    def update(frame):
      im.set_data(frame)
      return [im]
    interval = 1000/framerate
    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                   interval=interval, blit=True, repeat=False)
    return HTML(anim.to_html5_video())

# Seed numpy's global RNG so that cell outputs are deterministic. We also try to
# use RandomState instances that are local to a single cell wherever possible.
np.random.seed(42)

In [3]:
import sys
import os

# Add the "env" folder to Python's module search path
env_path = "/home/zhangzhibo/CyberSpine"
if env_path not in sys.path:
    sys.path.append(env_path)

# print(sys.path)

In [4]:
import ray

ray.shutdown()
ray.init(runtime_env={"py_modules": ["/home/zhangzhibo/CyberSpine"]})

@ray.remote
def check_path():
    import sys
    return sys.path

print(ray.get(check_path.remote()))

2024-10-19 17:43:47,166	INFO worker.py:1783 -- Started a local Ray instance.
2024-10-19 17:43:47,209	INFO packaging.py:530 -- Creating a file package for local directory '/home/zhangzhibo/CyberSpine'.
2024-10-19 17:43:47,264	INFO packaging.py:358 -- Pushing file package 'gcs://_ray_pkg_60cd9e5eb2427767.zip' (4.43MiB) to Ray cluster...
2024-10-19 17:43:47,349	INFO packaging.py:371 -- Successfully pushed file package 'gcs://_ray_pkg_60cd9e5eb2427767.zip'.


['/home/zhangzhibo/CyberSpine/spine', '/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/thirdparty_files', '/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/workers', '/tmp/ray/session_2024-10-19_17-43-45_679365_3106382/runtime_resources/py_modules_files/_ray_pkg_60cd9e5eb2427767', '/home/zhangzhibo/anaconda3/envs/mujoco/lib/python311.zip', '/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11', '/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/lib-dynload', '/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages']


In [5]:
from env import mice_env
import shimmy
from dm_control import suite

In [6]:
gym_mice_env = shimmy.DmControlCompatibilityV0(mice_env.rodent_maze_forage())

In [7]:
from gym.envs.registration import register

# 注册自定义环境
register(
    id='MiceEnv-v0',
    entry_point=lambda: shimmy.DmControlCompatibilityV0(mice_env.rodent_maze_forage())
)


In [8]:
from ray import tune
from shimmy.dm_control_compatibility import DmControlCompatibilityV0
from env import mice_env

def create_custom_env(cfg):
    return DmControlCompatibilityV0(mice_env.rodent_maze_forage())

# Register the environment with Ray
tune.register_env("MiceEnv-v0", lambda cfg: create_custom_env(cfg))


In [9]:
from ray.rllib.algorithms.ppo import PPOConfig

# 使用注册的环境ID
config = PPOConfig().environment("MiceEnv-v0")

  from jax import xla_computation as _xla_computation
  from jax import xla_computation as _xla_computation
  if (distutils.version.LooseVersion(tf.__version__) <
  distutils.version.LooseVersion(required_tensorflow_version)):


In [10]:
algo = config.build()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2024-10-19 17:44:33,256	ERROR actor_manager.py:523 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, [36mray::RolloutWorker.__init__()[39m (pid=3120302, ip=192.168.10.49, actor_id=edfd2bbd0d61f8284d4ccd1101000000, repr=<ray.rllib.evaluation.r

[36m(RolloutWorker pid=3120302)[0m No module named 'env'
[36m(RolloutWorker pid=3120302)[0m Traceback (most recent call last):
[36m(RolloutWorker pid=3120302)[0m   File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 423, in deserialize_objects
[36m(RolloutWorker pid=3120302)[0m     obj = self._deserialize_object(data, metadata, object_ref)
[36m(RolloutWorker pid=3120302)[0m           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[36m(RolloutWorker pid=3120302)[0m   File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 280, in _deserialize_object
[36m(RolloutWorker pid=3120302)[0m     return self._deserialize_msgpack_data(data, metadata_fields)
[36m(RolloutWorker pid=3120302)[0m            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[36m(RolloutWorker pid=3120302)[0m   File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packa

RaySystemError: System error: No module named 'env'
traceback: Traceback (most recent call last):
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 423, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 280, in _deserialize_object
    return self._deserialize_msgpack_data(data, metadata_fields)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 235, in _deserialize_msgpack_data
    python_objects = self._deserialize_pickle5_data(pickle5_data)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 225, in _deserialize_pickle5_data
    obj = pickle.loads(in_band)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/cloudpickle/cloudpickle.py", line 457, in subimport
    __import__(name)
ModuleNotFoundError: No module named 'env'


In [12]:
from ray.tune.registry import register_env

def env_creator(env_config):
    return CustomMiceEnv  # Ensure your custom environment is instantiated correctly

register_env("CustomMiceEnv", env_creator)

# Debug Record

24/10/15 Debug:
1. parser.parse_args() 的问题在于 Jupyter Notebook 本身不通过命令行运行，因此无法传递参数给 argparse。
2. 如果你想在 .ipynb 文件中传递参数，可以手动设置参数，而不是从命令行解析。

24/10/16 Debug:
1. RLlib 不接受用function定义的环境作为.environment的输入，
2. 必须将maze_forazing()改为gym环境或者自定义为class变量输入给env
3. 重点考虑手动定义class的方法

24/10/18 Debug:
1. 仔细看RLlib 定义custom env的方法 https://docs.ray.io/en/latest/rllib/rllib-env.html

In [38]:
import sys
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config
from ray.rllib.utils.test_utils import add_rllib_example_script_args

sys.argv = ['script_name','--num-gpus','1','--num-env-runners','4']

parser = add_rllib_example_script_args(
    default_iters=1000000,
    default_reward=800.0,
    default_timesteps=1000000
)
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values toset up `config` below.

args = parser.parse_args()


config = (
    DreamerV3Config()
    # Use image observations.
    .environment(
        env='MiceEnv-v0',
        env_config={"from_pixels": True},
    )
    .learners(
        num_learners=0 if args.num_gpus == 1 else args.num_gpus,
        num_gpus_per_learner=1 if args.num_gpus else 0,
    )
    .env_runners(
        num_env_runners=(args.num_env_runners or 0),
        # If we use >1 GPU and increase the batch size accordingly, we should also
        # increase the number of envs per worker.
        num_envs_per_env_runner=4 * (args.num_gpus or 1),
        remote_worker_envs=True,
    )
    .reporting(
        metrics_num_episodes_for_smoothing=(args.num_gpus or 1),
        report_images_and_videos=False,
        report_dream_data=False,
        report_individual_batch_item_stats=False,
    )
    # See Appendix A.
    .training(
        model_size="S",
        training_ratio=512,
        batch_size_B=16 * (args.num_gpus or 1),
    )

)

config.remote_worker_envs=False
rllib_algo = config.build(use_copy=False)


2024-10-19 17:18:46,801	ERROR actor_manager.py:523 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, [36mray::DreamerV3EnvRunner.__init__()[39m (pid=3045255, ip=192.168.10.49, actor_id=94e658cd9e7074997840b4a601000000, repr=<ray.rllib.algorithms.dreamerv3.utils.env_runner.DreamerV3EnvRunner object at 0x74c6bedd0710>)
  At least one of the input arguments for this task could not be computed:
ray.exceptions.RaySystemError: System error: No module named 'env'
traceback: Traceback (most recent call last):
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/cloudpickle/cloudpickle.py", line 457, in subimport
    __import__(name)
ModuleNotFoundError: No module named 'env'
2024-10-19

RaySystemError: System error: No module named 'env'
traceback: Traceback (most recent call last):
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 423, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 280, in _deserialize_object
    return self._deserialize_msgpack_data(data, metadata_fields)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 235, in _deserialize_msgpack_data
    python_objects = self._deserialize_pickle5_data(pickle5_data)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/_private/serialization.py", line 225, in _deserialize_pickle5_data
    obj = pickle.loads(in_band)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangzhibo/anaconda3/envs/mujoco/lib/python3.11/site-packages/ray/cloudpickle/cloudpickle.py", line 457, in subimport
    __import__(name)
ModuleNotFoundError: No module named 'env'
