<a href="https://colab.research.google.com/github/alexeiplatzer/unitree-go2-mjx-rl/blob/main/notebooks/Universal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Universal Notebook for Quadruped RL Training in MJX**
This notebook uses the `quadruped-mjx-rl` Python package from the `unitree-go2-mjx-rl` repository to train locomotion policies for quadrupeds using reinforcement learning in the MuJoCo XLA (MJX) simulation environment.

# Hardware Setup
This part sets up the `quadruped-mjx-rl` package on the machine.

In [1]:
#@title run this cell once each time on a new machine
#@markdown #### Setup configuration

#@markdown Choose your hardware option:
hardware = "Colab" #@param ["local","Colab","Kaggle"]

#@markdown Choose whether you want to build the madrona rendering setup for training
#@markdown with vision:
build_madrona_backend = False #@param {"type":"boolean"}

#@markdown Choose if you want to pull changes to the package repository during the runtime.
#@markdown (Requires a restart after executing this cell!)
editable_mode = True #@param {"type":"boolean"}

if build_madrona_backend:
    # Install madrona MJX
    import time
    print("Intalling Madrona MJX...")
    start_time = time.perf_counter()

    print("Setting up environment... (Step 1/3)")

    if hardware=="Kaggle":
        # Install the 12.4 cuda toolkit
        !wget -qO cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
        !sudo dpkg -i ./cuda-keyring.deb
        !sudo apt-get update -y -qq
        !sudo apt-get install -y -qq cuda-toolkit-12-4
        %CUDA_HOME=/usr/local/cuda-12.4
        %CUDAToolkit_ROOT=/usr/local/cuda-12.4
        %XLA_FLAGS="--xla_gpu_cuda_data_dir=/usr/local/cuda-12.4"
        %PATH="/usr/local/cuda-12.4/bin:$PATH"
        %LD_LIBRARY_PATH="/usr/local/cuda-12.4/lib64:${LD_LIBRARY_PATH}"

    !pip install jax["cuda12"]==0.5.2

    !sudo apt install libx11-dev libxrandr-dev libxinerama-dev libxcursor-dev libxi-dev mesa-common-dev

    !mkdir modules
    !git clone https://github.com/shacklettbp/madrona_mjx.git modules/madrona_mjx
    !git -C modules/madrona_mjx submodule update --init --recursive
    !mkdir modules/madrona_mjx/build

    print("Building the Madrona backend ... (Step 2/3)")
    !cmake -S modules/madrona_mjx -B modules/madrona_mjx/build -DLOAD_VULKAN=OFF
    !cmake --build modules/madrona_mjx/build -j

    print ("Installing Madrona MJX ... (Step 3/3)")
    !pip install modules/madrona_mjx

    minutes, seconds = divmod((time.perf_counter() - start_time), 60)
    print(f"Finished installing Madrona MJX in {minutes} m {seconds:.2f} s")

# Clones and installs our Quadruped RL package
!git clone https://github.com/alexeiplatzer/unitree-go2-mjx-rl.git
if editable_mode:
    !pip install -e unitree-go2-mjx-rl
else:
    !pip install unitree-go2-mjx-rl

fatal: destination path 'unitree-go2-mjx-rl' already exists and is not an empty directory.
Obtaining file:///content/unitree-go2-mjx-rl
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: quadruped_mjx_rl
  Building editable for quadruped_mjx_rl (pyproject.toml) ... [?25l[?25hdone
  Created wheel for quadruped_mjx_rl: filename=quadruped_mjx_rl-0.0.1-0.editable-py3-none-any.whl size=1943 sha256=a21af704bb5192610e5f5b14fdaaa7bb1c85b8f7315b35608af8fcac402a9672
  Stored in directory: /tmp/pip-ephem-wheel-cache-2b3m565v/wheels/69/a8/f7/6388f9792c334613332ec581d654046b6dcb0f2561a271a157
Successfully built quadruped_mjx_rl
Installing collected packages: quadruped_mjx_rl
  Attempting uninstall: quadruped_mjx_rl
    Found existing installation: quadrupe

### Now restart the session and continue.
### You can skip setup next time while you are on the same machine.

# Session setup
Run once at the beginning of every session, i.e. also after restarts and crashes.

In [7]:
# @title Refresh the package if any necessary changes have been pushed. Important in development
repo_path = "./unitree-go2-mjx-rl"
!git -C {repo_path} pull

remote: Enumerating objects: 6, done.[K
remote: Counting objects:  16% (1/6)[Kremote: Counting objects:  33% (2/6)[Kremote: Counting objects:  50% (3/6)[Kremote: Counting objects:  66% (4/6)[Kremote: Counting objects:  83% (5/6)[Kremote: Counting objects: 100% (6/6)[Kremote: Counting objects: 100% (6/6), done.[K
remote: Total 6 (delta 4), reused 6 (delta 4), pack-reused 0 (from 0)[K
Unpacking objects:  16% (1/6)Unpacking objects:  33% (2/6)Unpacking objects:  50% (3/6)Unpacking objects:  66% (4/6)Unpacking objects:  83% (5/6)Unpacking objects: 100% (6/6)Unpacking objects: 100% (6/6), 495 bytes | 247.00 KiB/s, done.
From https://github.com/alexeiplatzer/unitree-go2-mjx-rl
   6c75107..3341918  main       -> origin/main
Updating 6c75107..3341918
Fast-forward
 src/quadruped_mjx_rl/terrain_gen/obstacles.py | 2 [32m+[m[31m-[m
 1 file changed, 1 insertion(+), 1 deletion(-)


In [1]:
# @title Configuration for both local and for Colab instances.

# Configure logging
import logging
logging.basicConfig(level=logging.INFO, force=True)
logging.info("Logging switched on.")

import os

#@markdown Choose your hardware option:
hardware = "Colab" #@param ["local","Colab","Kaggle"]

#@markdown Choose whether you want to build the madrona rendering setup for training
#@markdown with vision:
build_madrona_backend = False #@param {"type":"boolean"}

if build_madrona_backend:
    from pathlib import Path
    # On your second reading, load the compiled rendering backend to save time!
    cache_path = Path("modules/madrona_mjx/build/cache")
    if cache_path.exists():
        os.environ["MADRONA_MWGPU_KERNEL_CACHE"] = "modules/madrona_mjx/build/cache"
    # Ensure that Madrona gets the chance to pre-allocate memory before Jax
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Check if MuJoCo installation was successful
import subprocess
if subprocess.run('nvidia-smi').returncode:
    raise RuntimeError(
        'Cannot communicate with GPU. '
        'Make sure you are using a GPU Colab runtime. '
        'Go to the Runtime menu and select Choose runtime type.'
    )

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
    with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
        f.write("""{
        "file_format_version" : "1.0.0",
        "ICD" : {
            "library_path" : "libEGL_nvidia.so.0"
        }
    }
    """)

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
    print('Checking that the installation succeeded:')
    import mujoco

    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    raise e from RuntimeError(
        'Something went wrong during installation. Check the shell output above '
        'for more information.\n'
        'If using a hosted Colab runtime, make sure you enable GPU acceleration '
        'by going to the Runtime menu and selecting "Choose runtime type".'
    )

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# More legible printing from numpy.
import numpy as np
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# Prepare directories
from etils.epath import Path
repo_path = Path("unitree-go2-mjx-rl")
experiments_dir = Path("experiments")
trained_policy_dir = experiments_dir / "trained_policies"
!mkdir -p {trained_policy_dir}
configs_dir = experiments_dir / "configs"
!mkdir -p {configs_dir}
rollout_configs_dir = configs_dir / "rollout_configs"
!mkdir -p {rollout_configs_dir}
animations_dir = experiments_dir / "rendered_rollouts"
!mkdir -p {animations_dir}

INFO:root:Logging switched on.
INFO:OpenGL.acceleratesupport:No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'


Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl
Checking that the installation succeeded:
Installation successful.


# Training

## Configurations

### Robot Configuration

In [2]:
#@markdown #### Choose the robot
robot = "unitree_go2" #@param ["unitree_go2", "google_barkour_vb"]

from quadruped_mjx_rl.robots import predefined_robot_configs

# Each registry entry is a callable taking no arguments; calling it builds the config.
robot_config = predefined_robot_configs[robot]()

### Model Configuration

In [3]:
#@title #### Choose the model architecture and set its hyperparameters
from quadruped_mjx_rl import models
model_architecture = "ActorCritic" #@param ["ActorCritic","TeacherStudent","TeacherStudentVision"]
# Exactly one branch below builds `model_config`. An unknown architecture raises instead of
# silently keeping a `model_config` left over from an earlier run of this cell.
#@markdown ---
#@markdown **Model hyperparameters for the Actor-Critic Architecture**
if model_architecture == "ActorCritic":
    policy_layers = [128, 128, 128, 128, 128] # @param
    value_layers = [256, 256, 256, 256, 256] # @param

    model_config_class = models.ActorCriticConfig
    model_config = model_config_class(
        modules=model_config_class.ModulesConfig(
            policy=policy_layers,
            value=value_layers,
        ),
    )
#@markdown ---
#@markdown **Model hyperparameters for the Teacher-Student Architecture**
elif model_architecture == "TeacherStudent":
    policy_layers = [256, 256, 256] #@param
    value_layers = [256, 256, 256] #@param
    teacher_encoder_layers = [256, 256] #@param
    student_encoder_layers = [256, 256] #@param
    latent_representation_size = 64 # @param {"type":"integer"}

    model_config_class = models.TeacherStudentConfig
    model_config = model_config_class(
        modules=model_config_class.ModulesConfig(
            policy=policy_layers,
            value=value_layers,
            encoder=teacher_encoder_layers,
            adapter=student_encoder_layers,
        ),
        latent_size=latent_representation_size,
    )
#@markdown ---
#@markdown **Model hyperparameters for the Teacher-Student-Vision Architecture**
elif model_architecture == "TeacherStudentVision":
    policy_layers = [128, 128] #@param
    value_layers = [256, 256] #@param
    teacher_encoder_convolutional_layers = [32, 64, 64] #@param
    teacher_encoder_dense_layers = [256, 256] #@param
    student_encoder_convolutional_layers = [32, 64, 64] #@param
    student_encoder_dense_layers = [256, 256] #@param
    latent_representation_size = 128 #@param {"type":"integer"}

    model_config_class = models.TeacherStudentVisionConfig
    model_config = model_config_class(
        modules=model_config_class.ModulesConfig(
            policy=policy_layers,
            value=value_layers,
            encoder_convolutional=teacher_encoder_convolutional_layers,
            encoder_dense=teacher_encoder_dense_layers,
            adapter_convolutional=student_encoder_convolutional_layers,
            adapter_dense=student_encoder_dense_layers,
        ),
        latent_size=latent_representation_size,
    )
else:
    raise NotImplementedError(f"Unknown model architecture: {model_architecture!r}")

### Environment Configuration

In [4]:
#@title #### Configure the Environment
from quadruped_mjx_rl import environments

# Each model architecture is paired with the environment config class it trains in.
# Class names are resolved lazily, so only the selected one is looked up on `environments`.
_env_config_class_names = {
    "TeacherStudentVision": "QuadrupedVisionEnvConfig",
    "TeacherStudent": "TeacherStudentEnvironmentConfig",
    "ActorCritic": "JoystickBaseEnvConfig",
}
model_architecture = type(model_config).config_class_key()
if model_architecture not in _env_config_class_names:
    raise NotImplementedError
env_config_class = getattr(environments, _env_config_class_names[model_architecture])

# TODO: add support for more environment params
simulation_timestep = 0.002 #@param {type:"number"}
control_timestep = 0.04 #@param {type:"number"}

sim_config = env_config_class.SimConfig(sim_dt=simulation_timestep, ctrl_dt=control_timestep)
environment_config = env_config_class(sim=sim_config)

### Training Configuration

In [5]:
#@title #### Configure the training process
from quadruped_mjx_rl.training.configs import (
    TrainingConfig,
    TrainingWithVisionConfig,
    HyperparamsPPO,
    OptimizerConfig,
    TeacherStudentOptimizerConfig,
)
from quadruped_mjx_rl.robotic_vision import VisionConfig
model_architecture = type(model_config).config_class_key()
#@markdown ---
#@markdown **PPO Hyperparameters**
discounting = 0.97 #@param {"type":"number"}
entropy_cost = 0.01 #@param {"type":"number"}
clipping_epsilon = 0.3 #@param {"type":"number"}
gae_lambda = 0.95 #@param {"type":"number"}
normalize_advantage = True #@param {"type":"boolean"}
# NOTE(review): `reward_scaling` is collected but not passed to any config in this cell.
reward_scaling = 1 #@param {"type":"integer"}
learning_rate = 0.0004 #@param {"type":"number"}
ppo_hyperparams = HyperparamsPPO(
    discounting=discounting,
    entropy_cost=entropy_cost,
    clipping_epsilon=clipping_epsilon,
    gae_lambda=gae_lambda,
    normalize_advantage=normalize_advantage,
)
#@markdown **Teacher-Student specific hyperparameters**
#@markdown ---
if model_architecture in ("TeacherStudent", "TeacherStudentVision"):
    student_learning_rate = 0.001 #@param {"type":"number"}
    max_grad_norm = 1.0 #@param
    optimizer_config = TeacherStudentOptimizerConfig(
        learning_rate=learning_rate,
        student_learning_rate=student_learning_rate,
        max_grad_norm=max_grad_norm,
    )
else:
    optimizer_config = OptimizerConfig(learning_rate=learning_rate)
# Each branch only collects its form values and picks a config class; the shared
# constructor call follows after the branches.
#@markdown ---
#@markdown **Training hyperparameters without vision**
if model_architecture in ("ActorCritic", "TeacherStudent"):
    num_envs = 8192 #@param {"type":"integer"}
    num_eval_envs = 8192 #@param {"type":"integer"}
    seed = 0 #@param {"type":"integer"}
    num_timesteps = 100_000_000 #@param {"type":"integer"}
    num_evals = 10 #@param {"type":"integer"}
    deterministic_eval = False #@param {"type":"boolean"}
    num_resets_per_eval = 0 #@param {"type":"integer"}
    episode_length = 1000 #@param {"type":"integer"}
    unroll_length = 20 #@param {"type":"integer"}
    normalize_observations = True  #@param {"type":"boolean"}
    action_repeat = 1 #@param {"type":"integer"}
    batch_size = 256 #@param {"type":"integer"}
    num_updates_per_batch = 4 #@param {"type":"integer"}
    num_minibatches = 32 #@param {"type":"integer"}
    training_config_class = TrainingConfig
#@markdown ---
#@markdown **Training hyperparameters with vision**
elif model_architecture == "TeacherStudentVision":
    num_envs = 256 #@param {"type":"integer"}
    num_eval_envs = 256 #@param {"type":"integer"}
    seed = 0 #@param {"type":"integer"}
    num_timesteps = 100_000_000 #@param {"type":"integer"}
    num_evals = 10 #@param {"type":"integer"}
    deterministic_eval = False #@param {"type":"boolean"}
    num_resets_per_eval = 0 #@param {"type":"integer"}
    episode_length = 1000 #@param {"type":"integer"}
    unroll_length = 20 #@param {"type":"integer"}
    normalize_observations = True  #@param {"type":"boolean"}
    action_repeat = 1
    batch_size = 256 #@param {"type":"integer"}
    num_updates_per_batch = 4 #@param {"type":"integer"}
    num_minibatches = 32 #@param {"type":"integer"}
    training_config_class = TrainingWithVisionConfig
#@markdown **Vision renderer parameters**
    enabled_cameras = [0, 1, 2] # @param
    enabled_geom_groups = [0, 1, 2] # @param
    render_width = 64 # @param {"type": "integer"}
    render_height = 64 # @param {"type": "integer"}
else:
    raise NotImplementedError

training_config = training_config_class(
    num_envs=num_envs,
    num_eval_envs=num_eval_envs,
    seed=seed,
    num_timesteps=num_timesteps,
    num_evals=num_evals,
    deterministic_eval=deterministic_eval,
    num_resets_per_eval=num_resets_per_eval,
    episode_length=episode_length,
    unroll_length=unroll_length,
    normalize_observations=normalize_observations,
    action_repeat=action_repeat,
    batch_size=batch_size,
    num_updates_per_batch=num_updates_per_batch,
    num_minibatches=num_minibatches,
    rl_hyperparams=ppo_hyperparams,
    optimizer=optimizer_config,
)

# Only the vision architecture renders camera images; it renders one batch per env.
if model_architecture == "TeacherStudentVision":
    vision_config = VisionConfig(
        render_batch_size=training_config.num_envs,
        enabled_cameras=enabled_cameras,
        enabled_geom_groups=enabled_geom_groups,
        render_width=render_width,
        render_height=render_height,
    )
else:
    vision_config = None

In [6]:
#@title #### Configure the terrain generation
use_challenging_terrain = True #@param {"type":"boolean"}
# `terrain_config` is always bound (None when disabled) because later cells check
# `terrain_config is not None`. Otherwise they would hit a NameError on a fresh kernel,
# or silently reuse a terrain from an earlier run of this cell.
if use_challenging_terrain:
    from quadruped_mjx_rl.terrain_gen.obstacles import FlatTile, StripesTile
    from quadruped_mjx_rl.terrain_gen.tile import TerrainConfig
    # One flat tile followed by stripes of increasing amplitude.
    terrain_config = TerrainConfig(tiles=[[
        FlatTile(),
        StripesTile(stripe_amplitude=0.04),
        StripesTile(stripe_amplitude=0.08),
        StripesTile(stripe_amplitude=0.12),
        StripesTile(stripe_amplitude=0.16),
    ]])
else:
    terrain_config = None

### Save configs to a yaml file

In [10]:
#@title #### Save configs
from quadruped_mjx_rl.config_utils import save_configs
#@markdown Fill out a name for the experiment and all configuration parameters.
#@markdown If you want to add another experiment, change the parameters and run
#@markdown this cell again.
experiment_name = "my_experiment" # @param {type:"string"}
config_file_path = configs_dir / f"{experiment_name}.yaml"

# Bundle every config that describes this experiment into one YAML file; the vision
# config only exists for vision-based architectures.
configs_to_save = [robot_config, model_config, environment_config, training_config]
configs_to_save += [vision_config] if vision_config is not None else []
# TODO(review): the terrain config is currently not written to the YAML file:
# if use_challenging_terrain:
#     configs_to_save.append(terrain_config)
save_configs(config_file_path, *configs_to_save)
print(f"Experiment configs saved to {config_file_path}")

Experiment configs saved to experiments/configs/my_experiment.yaml


## Training runs

In [8]:
#@title List all configuration files
!ls {configs_dir}

my_experiment.yaml  rollout_configs


In [None]:
# @title Sequentially run training for all configurations
import functools
import logging
from quadruped_mjx_rl.config_utils import prepare_configs
from quadruped_mjx_rl.training.train_interface import train
from quadruped_mjx_rl import environments
from quadruped_mjx_rl.environments import get_env_factory
from quadruped_mjx_rl.environments.physics_pipeline import load_to_spec, spec_to_model
from quadruped_mjx_rl.environments.rendering import render_model, large_overview_camera
from quadruped_mjx_rl.models.io import save_params

# @markdown Choose with which configs to train
training_runs = None # @param {"type":"raw","placeholder":"[\"experiment_name1\", \"experiment_name2\", ... ]"}
# @markdown or
run_them_all = True # @param {"type":"boolean"}
if run_them_all:
    # Sorted so the run order is reproducible (directory listing order is unspecified).
    training_runs = sorted(
        config_file.stem
        for config_file in configs_dir.iterdir() if config_file.name.endswith(".yaml")
    )
elif not training_runs:
    raise ValueError(
        "Set `training_runs` to a list of experiment names or enable `run_them_all`."
    )

# The terrain is not stored in the experiment YAML files yet, so fall back to the terrain
# configured in this session (None if the terrain cell was not run or is disabled).
session_terrain_config = globals().get("terrain_config")

for experiment_name in training_runs:
    config_path = configs_dir / f"{experiment_name}.yaml"
    configs = prepare_configs(config_path)
    environment_config = configs["environment"]
    robot_config = configs["robot"]
    model_config = configs["model"]
    training_config = configs["training"]
    vision_config = configs.get("vision")
    run_terrain_config = configs.get("terrain")
    if run_terrain_config is None:
        run_terrain_config = session_terrain_config

    # Pick the base scene: generated terrain needs an empty arena, vision needs cameras.
    if run_terrain_config is not None:
        scene_file = "scene_mjx_empty_arena.xml"
    elif isinstance(environment_config, environments.QuadrupedVisionEnvConfig):
        scene_file = "scene_mjx_vision.xml"
    else:
        scene_file = "scene_mjx.xml"
    init_scene_path = repo_path / "resources" / robot_config.robot_name / scene_file

    env_spec = load_to_spec(init_scene_path)
    if run_terrain_config is not None:
        run_terrain_config.make_arena(env_spec)
    env_model = spec_to_model(env_spec)
    env_model = environments.QuadrupedJoystickBaseEnv.customize_model(
        env_model, environment_config
    )

    render_model(env_model=env_model, camera=large_overview_camera())

    # A per-experiment binding, so vision-only arguments never leak into later runs.
    make_env_factory = get_env_factory
    if vision_config is not None:
        # Instantiate MJX before Madrona creates its renderer.
        # NOTE(review): mjx.make_data presumably starts from the model's default qpos;
        # set a specific initial pose here if one is required.
        from mujoco import mjx
        mjx_model = mjx.put_model(env_model)
        mjx_data = mjx.forward(mjx_model, mjx.make_data(mjx_model))

        from quadruped_mjx_rl.robotic_vision import get_renderer
        renderer_maker = functools.partial(get_renderer, vision_config=vision_config)
        make_env_factory = functools.partial(
            get_env_factory, vision_config=vision_config, renderer_maker=renderer_maker
        )
    env_factory = make_env_factory(
        robot_config=robot_config,
        environment_config=environment_config,
        env_model=env_model,
    )

    logging.info("Initializing the environment...")
    env = env_factory()
    # Vision runs are trained without a separate evaluation environment.
    eval_env = env_factory() if vision_config is None else None

    trained_policy_save_path = trained_policy_dir / f"{experiment_name}"

    print(f"Starting training for: {experiment_name}")
    params = train(
        training_env=env,
        evaluation_env=eval_env,
        model_config=model_config,
        training_config=training_config,
    )
    save_params(trained_policy_save_path, params)
    print(f"Trained policy saved to {trained_policy_save_path}")

INFO:root:Initializing the environment...
INFO:root:Using JAX default device: TFRT_CPU_0.
INFO:root:No CUDA GPU devices found in jax.devices("cuda").
INFO:root:Using JAX default device: TFRT_CPU_0.
INFO:root:No CUDA GPU devices found in jax.devices("cuda").
INFO:root:Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1


Starting training for: my_experiment


INFO:root:Using JAX default device: TFRT_CPU_0.
INFO:root:No CUDA GPU devices found in jax.devices("cuda").
INFO:root:observation size: (465,)
INFO:root:observation size: (465,)


# Results
This section can be run independently of the training section, including after restarts and crashes, as long as all the created files remain on the session's disk.

In [None]:
#@title List all configured experiments
!ls {configs_dir}

my_experiment.yaml  rollout_configs


In [None]:
# @title Configure a rollout for rendering
from quadruped_mjx_rl.config_utils import save_configs
from quadruped_mjx_rl.policy_rendering import RenderConfig

experiment_name = "my_experiment"  #@param {type:"string"}
rollout_name = "my_rollout" # @param {type:"string"}

n_steps = 500 # @param {"type":"integer"}
render_every = 2 # @param {"type":"integer"}
random_seed = 0 # @param {"type":"integer"}

# @markdown ---
# @markdown Joystick command for the robot to follow (in SI)
x_vel = 0.8 # @param {"type":"number"}
y_vel = 0.0 # @param {"type":"number"}
ang_vel = 0.0 # @param {"type":"number"}

# Velocity command the policy is asked to track during the rollout.
joystick_command = dict(x_vel=x_vel, y_vel=y_vel, ang_vel=ang_vel)
render_config = RenderConfig(
    n_steps=n_steps,
    # The episode length is twice the number of rendered steps.
    episode_length=2 * n_steps,
    render_every=render_every,
    seed=random_seed,
    command=joystick_command,
)

# The "<experiment>_rendering_<rollout>" file name lets the rendering cell match each
# rollout to its experiment.
config_file_path = rollout_configs_dir / f"{experiment_name}_rendering_{rollout_name}.yaml"
save_configs(config_file_path, render_config)
print(f"Rollout configs saved to {config_file_path}")

Rollout configs saved to experiments/configs/rollout_configs/my_experiment_rendering_my_rollout.yaml


In [None]:
# @title Render all configured policy rollouts
import functools
from quadruped_mjx_rl.policy_rendering import render_policy_rollout
from quadruped_mjx_rl import environments
from quadruped_mjx_rl.environments import get_env_factory
from quadruped_mjx_rl.environments.physics_pipeline import load_to_spec, spec_to_model
from quadruped_mjx_rl.config_utils import prepare_configs

# @markdown All rollouts present will be rendered
delete_rollouts_after_rendering = True # @param {"type":"boolean"}
save_rollout_gifs = True # @param {"type":"boolean"}

# Sorted by name so experiments are rendered in a reproducible order.
for experiment_config_file in sorted(configs_dir.iterdir(), key=lambda path: path.name):
    if not experiment_config_file.name.endswith(".yaml"):
        continue
    experiment_name = experiment_config_file.stem
    configs = prepare_configs(experiment_config_file)
    environment_config = configs["environment"]
    robot_config = configs["robot"]
    model_config = configs["model"]
    vision_config = configs.get("vision")

    # Build the environment with the same get_env_factory call signature the training
    # loop uses (robot_config / environment_config / env_model -> factory).
    # NOTE(review): the terrain is not stored in the experiment YAML, so rollouts use the
    # plain scene even if training ran on generated terrain.
    if isinstance(environment_config, environments.QuadrupedVisionEnvConfig):
        scene_file = "scene_mjx_vision.xml"
    else:
        scene_file = "scene_mjx.xml"
    init_scene_path = repo_path / "resources" / robot_config.robot_name / scene_file
    env_model = spec_to_model(load_to_spec(init_scene_path))
    env_model = environments.QuadrupedJoystickBaseEnv.customize_model(
        env_model, environment_config
    )

    make_env_factory = get_env_factory
    if vision_config is not None:
        # Instantiate MJX before Madrona creates its renderer.
        from mujoco import mjx
        mjx_model = mjx.put_model(env_model)
        mjx.forward(mjx_model, mjx.make_data(mjx_model))

        from quadruped_mjx_rl.robotic_vision import get_renderer
        renderer_maker = functools.partial(get_renderer, vision_config=vision_config)
        make_env_factory = functools.partial(
            get_env_factory, vision_config=vision_config, renderer_maker=renderer_maker
        )
    env_factory = make_env_factory(
        robot_config=robot_config,
        environment_config=environment_config,
        env_model=env_model,
    )
    # NOTE(review): confirm the expected type of render_policy_rollout's `vision`
    # argument; a bool "uses vision" flag is assumed here.
    uses_vision = vision_config is not None

    trained_policy = trained_policy_dir / f"{experiment_name}"

    # Rollout configs are named "<experiment>_rendering_<rollout>.yaml".
    rollout_configs_list = []
    for rollout_config_file in rollout_configs_dir.iterdir():
        if (
            not rollout_config_file.name.endswith(".yaml")
            or "_rendering_" not in rollout_config_file.name
            or experiment_name != rollout_config_file.name.split("_rendering_")[0]
        ):
            continue
        rollout_configs_list.append(rollout_config_file)

        render_config = prepare_configs(rollout_config_file)["render"]

        if save_rollout_gifs:
            animation_save_path = animations_dir / f"{rollout_config_file.stem}.gif"
        else:
            animation_save_path = None

        render_policy_rollout(
            env_factory=env_factory,
            model_config=model_config,
            trained_model_path=trained_policy,
            render_config=render_config,
            animation_save_path=animation_save_path,
            vision=uses_vision,
        )

    # Deleted only after every rollout for this experiment has rendered successfully.
    if delete_rollouts_after_rendering:
        for rollout_config_file in rollout_configs_list:
            rollout_config_file.unlink()



In [None]:
# @title Saving results
from google.colab import files, drive
from etils.epath import Path

# @markdown (This can be run in a separate session)

# @markdown Choose what you want to save
policies = True # @param {"type":"boolean"}
rollout_gifs = True # @param {"type":"boolean"}
config_files = True # @param {"type":"boolean"}
# @markdown Only the configs for the training are saved

# @markdown Choose whether you want to download your results
download_results = False # @param {"type":"boolean"}

# @markdown Choose whether you want to save results to your Google drive
save_to_drive = True # @param {"type":"boolean"}
drive_save_folder = "quadruped_mjx_rl_Results" # @param {type:"string"}
if save_to_drive:
    drive.mount('/content/drive')

try:
    for do_save, directory in zip(
        [policies, rollout_gifs, config_files], [trained_policy_dir, animations_dir, configs_dir],
    ):
        if not do_save:
            continue
        # One Drive sub-folder per exported directory, created once rather than per file.
        drive_dir = Path(f"/content/drive/MyDrive/{drive_save_folder}/{directory.name}")
        if save_to_drive:
            drive_dir.mkdir(parents=True, exist_ok=True)
        for file_path in directory.iterdir():
            # Only top-level files are exported; sub-directories such as rollout_configs
            # are skipped.
            # NOTE(review): a policy saved as a directory would be skipped too.
            if file_path.is_dir():
                continue
            if download_results:
                files.download(str(file_path))
            if save_to_drive:
                file_path.copy(dst=drive_dir / file_path.name, overwrite=True)
finally:
    # Flush pending writes and unmount even if a copy fails midway.
    if save_to_drive:
        drive.flush_and_unmount()

Mounted at /content/drive
