<a href="https://colab.research.google.com/github/alexeiplatzer/unitree-go2-mjx-rl/blob/main/notebooks/Universal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Universal Notebook for Quadruped RL Training in MJX

## Hardware Setup

In [1]:
# @title Setup configuration
# Colab form cell: both values are plain strings consumed by the installation
# cell that follows.

# @markdown Choose your hardware option:
# Only the value "Kaggle" changes the installation (it adds the CUDA toolkit).
hardware = "Colab" # @param ["local","Colab","Kaggle"]

# @markdown Choose whether you want to build the rendering setup for training
# @markdown with vision, and with what backend:
# Note this is the string "None", not Python's None. Only "Madrona" triggers
# the extra build steps in the installation cell.
vision_backend = "None" # @param ["None","MJX","Madrona"]

In [2]:
# @title run this cell once each time on a new machine

import time

if vision_backend == "Madrona":
    # Install madrona MJX
    print("Intalling Madrona MJX...")
    start_time = time.perf_counter()
    print("Setting up environment... (Step 1/3)")
    !pip uninstall -y jax
    !pip install jax["cuda12_local"]==0.4.35

    !sudo apt install libx11-dev libxrandr-dev libxinerama-dev libxcursor-dev libxi-dev mesa-common-dev

    !mkdir modules
    !git clone https://github.com/shacklettbp/madrona_mjx.git modules/madrona_mjx

    !git -C modules/madrona_mjx submodule update --init --recursive

    !mkdir modules/madrona_mjx/build

    if hardware == "Kaggle":
        !sudo apt-get install -y nvidia-cuda-toolkit

    print("Building the Madrona backend ... (Step 2/3)")
    !cd modules/madrona_mjx/build && cmake -DLOAD_VULKAN=OFF .. && make -j 8

    print ("Installing Madrona MJX ... (Step 3/3)")
    !pip install -e modules/madrona_mjx

    minutes, seconds = divmod((time.perf_counter() - start_time), 60)
    print(f"Finished installing Madrona MJX in {minutes} m {seconds:.2f} s")

# Clones and installs our Quadruped RL package
!git clone https://github.com/alexeiplatzer/unitree-go2-mjx-rl.git
!pip install -e unitree-go2-mjx-rl/

Cloning into 'unitree-go2-mjx-rl'...
remote: Enumerating objects: 856, done.[K
remote: Counting objects: 100% (145/145), done.[K
remote: Compressing objects: 100% (91/91), done.[K
remote: Total 856 (delta 70), reused 105 (delta 41), pack-reused 711 (from 1)[K
Receiving objects: 100% (856/856), 22.29 MiB | 30.03 MiB/s, done.
Resolving deltas: 100% (434/434), done.
Obtaining file:///content/unitree-go2-mjx-rl
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting mujoco (from quadruped_mjx_rl==0.0.1)
  Downloading mujoco-3.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting mujoco_mjx (from quadruped_mjx_rl==0.0.1

### Now restart the session and continue.
### You can skip setup next time while you are on the same machine.

## Training

In [1]:
# @title Configuration for both local and for Colab instances.

repo_path = "./unitree-go2-mjx-rl"

# Refresh the repo for recent changes
# Important in development
!git -C {repo_path} pull

# On your second reading, load the compiled rendering backend to save time!
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Check if MuJoCo installation was successful
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
    raise RuntimeError(
        'Cannot communicate with GPU. '
        'Make sure you are using a GPU Colab runtime. '
        'Go to the Runtime menu and select Choose runtime type.'
    )

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
    with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
        f.write("""{
        "file_format_version" : "1.0.0",
        "ICD" : {
            "library_path" : "libEGL_nvidia.so.0"
        }
    }
    """)

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
    print('Checking that the installation succeeded:')
    import mujoco

    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    raise e from RuntimeError(
        'Something went wrong during installation. Check the shell output above '
        'for more information.\n'
        'If using a hosted Colab runtime, make sure you enable GPU acceleration '
        'by going to the Runtime menu and selecting "Choose runtime type".'
    )

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

Already up to date.
Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl
Checking that the installation succeeded:
Installation successful.


In [11]:
#@title Choose the policy approximation method
# Colab form cell: these string selections drive the branching in the
# environment, model and training configuration cells below.
# Choosing "Teacher-Student" currently makes the environment cell raise
# NotImplementedError.
model_architecture = "Actor-Critic" #@param ["Actor-Critic", "Teacher-Student"]
training_algorithm = "ppo" #@param ["ppo"]
# Declared again here because the session is restarted after the setup section.
# This is the string "None", not Python's None; "MJX" currently makes the
# training configuration cell raise NotImplementedError.
vision_backend = "None" #@param ["None","MJX","Madrona"]

In [5]:
#@title Choose the Robot
from quadruped_mjx_rl.configs import predefined_robot_configs

robot = "unitree_go2" #@param ["unitree_go2", "google_barkour_vb"]

# The registry maps a robot name to a callable; calling it yields the config.
robot_config = predefined_robot_configs[robot]()

# Scene file for the selected robot inside the cloned repository
# (`repo_path` comes from the configuration cell).
init_scene_path = "{}/resources/{}/scene_mjx.xml".format(repo_path, robot)
# TODO: add scenes for vision

In [6]:
# @title Configure Environment
from quadruped_mjx_rl import environments

# Only the Actor-Critic + PPO combination has an environment selected here;
# every other combination is rejected before anything is assigned.
if model_architecture != "Actor-Critic" or training_algorithm != "ppo":
    raise NotImplementedError

env_class = environments.QuadrupedJoystickEnhancedEnv
env_config_class = environments.EnhancedEnvironmentConfig

simulation_timestep = 0.002 #@param {type:"number"}
control_timestep = 0.002 #@param {type:"number"}

# The two timesteps from the form are passed to the config's nested SimConfig.
env_config = env_config_class(
    sim=env_config_class.SimConfig(sim_dt=simulation_timestep, ctrl_dt=control_timestep),
)

# Build the environment from the robot and scene chosen in the previous cell.
env = env_class(
    robot_config=robot_config,
    init_scene_path=init_scene_path,
    environment_config=env_config,
)


In [8]:
#@title Configure Model
from quadruped_mjx_rl.configs.config_classes import models

# Builds `model_config` for the architecture selected in `model_architecture`.
# Each branch exposes its own layer-size form fields and wraps them in the
# matching config class from `models`.
if model_architecture == "Actor-Critic":
    #@markdown ---
    #@markdown Model hyperparameters for the "Actor-Critic" architecture:
    policy_layers = [256, 256] #@param
    value_layers = [256, 256] #@param

    model_config_class = models.ActorCriticConfig
    model_config = model_config_class(
        modules=model_config_class.ActorCriticModulesConfig(
            policy=policy_layers,
            value=value_layers,
        ),
    )
elif model_architecture == "Teacher-Student":
    #@markdown ---
    #@markdown Model hyperparameters for the "Teacher-Student" architecture:
    policy_layers = [256, 256] #@param
    value_layers = [256, 256] #@param
    teacher_encoder_layers = [256, 256] #@param
    student_encoder_layers = [256, 256] #@param
    latent_representation_size = 16 # @param {"type":"integer"}

    model_config_class = models.TeacherStudentConfig
    model_config = model_config_class(
        modules=model_config_class.TeacherStudentModulesConfig(
            policy=policy_layers,
            value=value_layers,
            # The teacher encoder layers go to `encoder`, the student ones to `adapter`.
            encoder=teacher_encoder_layers,
            adapter=student_encoder_layers,
        ),
        latent_size=latent_representation_size,
    )
else:
    # Fail loudly rather than leaving `model_config` undefined, or stale from
    # an earlier run of this cell with a different selection.
    raise NotImplementedError(
        f"Unsupported model architecture: {model_architecture!r}"
    )


In [12]:
#@title Configure training procedure
from quadruped_mjx_rl.configs import TrainingConfig

# Chooses the training configuration from `vision_backend`:
#   "None"    -> builds `training_config` from the form fields below
#   "Madrona" -> only selects the vision config class (`training_config_class`)
#   otherwise -> NotImplementedError
# Arguments use leading commas so each line can end in its `#@param` annotation.

if vision_backend == "None":
    #@markdown ---
    #@markdown #### Training without vision:
    training_config = TrainingConfig(
        num_timesteps=1_000_000 #@param {"type":"integer"}
        ,num_evals=5 #@param {"type":"integer"}
        ,reward_scaling=1 #@param {"type":"integer"}
        ,episode_length=1000 #@param {"type":"integer"}
        ,normalize_observations=True #@param {"type":"boolean"}
        ,action_repeat=1 #@param {"type":"integer"}
        ,unroll_length=10 #@param {"type":"integer"}
        ,num_minibatches=8 #@param {"type":"integer"}
        ,num_updates_per_batch=8 #@param {"type":"integer"}
        ,discounting=0.97 #@param {"type":"number"}
        ,learning_rate=0.0005  #@param {"type":"number"}
        ,entropy_cost=0.005  #@param {"type":"number"}
        ,num_envs=512 #@param {"type":"integer"}
        ,batch_size=256 #@param {"type":"integer"}
    )
elif vision_backend == "Madrona":
    # Imported here so this branch does not raise NameError when the
    # "Training with vision" cell has not been run yet.
    # NOTE(review): module path taken from that cell; the other cells import
    # from quadruped_mjx_rl.configs — confirm which path is current.
    from quadruped_mjx_rl.config_classes import TrainingWithVisionConfig

    training_config_class = TrainingWithVisionConfig
else:
    raise NotImplementedError

In [None]:
#@title #### Training with vision
# NOTE(review): the other cells import config classes from
# quadruped_mjx_rl.configs — confirm this module path is current.
from quadruped_mjx_rl.config_classes import TrainingWithVisionConfig

# Builds `training_config` for vision-based training from the form fields.
# NOTE(review): this assignment is unconditional, so running the cell replaces
# any `training_config` created by the previous cell — presumably it is only
# meant to be run when a vision backend is selected.
# Arguments are separated by leading commas so that every line can still end in
# its `#@param` form annotation.
training_config = TrainingWithVisionConfig(
    num_timesteps=1_000_000 # @param {"type":"integer","placeholder":"1_000_000"}
    ,num_evals=5 #@param {"type":"integer","default":"1"}
    ,reward_scaling=1 #@param {"type":"integer"}
    ,episode_length=1000 #@param {"type":"integer"}
    ,normalize_observations=True #@param {"type":"boolean"}
    ,action_repeat=1 #@param {"type":"integer"}
    ,unroll_length=10 #@param {"type":"integer"}
    ,num_minibatches=8 #@param {"type":"integer"}
    ,num_updates_per_batch=8 #@param {"type":"integer"}
    ,discounting=0.97 #@param {"type":"number"}
    ,learning_rate=0.0005  #@param {"type":"number"}
    ,entropy_cost=0.005  #@param {"type":"number"}
    ,num_envs=512 #@param {"type":"integer"}
    ,batch_size=256 #@param {"type":"integer"}
)


In [None]:
#@title Set Save Path
# Chooses where the trained policy will be written and makes sure the target
# directory exists.
import os

trained_policy_save_name = "my_policy" #@param {"type":"string"}
trained_policy_dir = "trained_policies"
# Created from Python instead of a shell `mkdir -p`, so the cell does not
# depend on a POSIX shell when run locally; exist_ok makes it safe to re-run.
os.makedirs(trained_policy_dir, exist_ok=True)
save_model_path = f"{trained_policy_dir}/{trained_policy_save_name}"
print(f"Trained policy save path: {save_model_path}")

In [None]:
from quadruped_mjx_rl.training import train


## Results