<a href="https://colab.research.google.com/github/EureXaAI/EurexaBook/blob/main/letter/EurexaBook_A_Z.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title 安装 brax
# 安装 Brax: Google 开发的独立物理引擎, 原生用 JAX 写的
!pip install brax

Collecting brax
  Downloading brax-0.12.1-py3-none-any.whl.metadata (7.7 kB)
Collecting dm_env (from brax)
  Downloading dm_env-1.6-py3-none-any.whl.metadata (966 bytes)
Collecting flask_cors (from brax)
  Downloading flask_cors-5.0.1-py3-none-any.whl.metadata (961 bytes)
Collecting jaxopt (from brax)
  Downloading jaxopt-0.8.3-py3-none-any.whl.metadata (2.6 kB)
Collecting ml_collections (from brax)
  Downloading ml_collections-1.0.0-py3-none-any.whl.metadata (22 kB)
Collecting mujoco (from brax)
  Downloading mujoco-3.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting mujoco-mjx (from brax)
  Downloading mujoco_mjx-3.3.0-py3-none-any.whl.metadata (3.3 kB)
Collecting pytinyrenderer (from brax)
  Downloading pytinyrenderer-0.0.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting tensorb

In [2]:
#@title 导入依赖包
#@markdown ### 本代码研究如何用 brax 和 mujoco 搭建最简单的强化学习案例
# JAX Imports:
import jax
import jax.numpy as jnp

In [86]:
#@title 物理模型
# MJCF description of the scene: a checkered ground plane, one light, and a
# single body shaped like the letter "A" (two slanted box legs plus a
# crossbar) that swings on one hinge joint.
#
# The two legs are mirror images of each other: positions are +/-0.15 on x and
# the quaternions differ only in the sign of the y component, so both legs tilt
# by the same angle in opposite directions. The quaternions are not unit
# length; this relies on MuJoCo normalizing them when the model is compiled
# (see the MuJoCo XML reference to confirm).
xml_model = """
<mujoco>
  <asset>
    <texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.9375 0.7226 0.04296"
        rgb2="0 0 0" markrgb="0.8 0.8 0.8" width="1000" height="1000"/>
    <material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5"
        reflectance="0.2"/>
  </asset>

  <worldbody>
    <light name="top" pos="0 0 1"/>
    <geom name="floor" pos="0 0 -1" size="0 0 .125" type="plane" material="groundplane" conaffinity="15" condim="3"/>
    <body name="letter_A" pos="0 0 0">
      <!-- 固定铰链 -->
      <joint name="swing" type="hinge" axis="1 -1 0" pos="0 0 0"/>
      <!-- 左侧斜杆 -->
      <geom name="left_leg" type="box" pos="-0.15 0 0" size=".025 .025 .36" quat="1.1659 0 0.2588 0" rgba="0.9375 0.7226 0.04296 1"/>
      <!-- 右侧斜杆 -->
      <geom name="right_leg" type="box" pos="0.15 0 0" size=".025 .025 .36" quat="1.1659 0 -0.2588 0" rgba="0.9375 0.7226 0.04296 1"/>
      <!-- 横杆 -->
      <geom name="crossbar" type="box" pos="0 0 0" size=".15 .025 .025" rgba="0.9375 0.7226 0.04296 1"/>
    </body>
  </worldbody>
</mujoco>
"""

In [87]:
#@title 初始化模型 { run: "auto" }
# Brax Imports:
from brax.mjx import pipeline
from brax.io import mjcf, html

# Parse the MJCF string into the system object used by the Brax pipeline.
# NOTE: the name `sys` would shadow the standard-library `sys` module if that
# were imported in this notebook; it is kept because later cells refer to it.
sys = mjcf.loads(xml_model)

# Wrap the pipeline entry points in jax.jit so they are compiled on first call.
init_fn, step_fn = jax.jit(pipeline.init), jax.jit(pipeline.step)

# Build the starting state: the model's initial joint positions and
# all-zero joint velocities.
zero_qd = jnp.zeros(sys.qd_size())
state = init_fn(sys=sys, q=sys.init_q, qd=zero_qd)

In [84]:
#@title 执行模拟
# Roll the simulation forward `num_steps` steps and record every state.
num_steps = 100

# Control input of the model's actuator size, all zeros (no actuation).
ctrl = jnp.zeros(sys.act_size())

# Step a separate variable instead of overwriting the global `state`, so that
# re-running this cell always replays the same trajectory from the initial
# state rather than continuing from wherever the previous run ended.
# The final state of the rollout is available as `state_history[-1]`.
rollout_state = state
state_history = []
for _ in range(num_steps):
    rollout_state = step_fn(sys, rollout_state, act=ctrl)
    state_history.append(rollout_state)

In [85]:
#@title 可视化
from IPython.display import HTML

# Render the recorded states to an HTML viewer string, then hand it to IPython
# as the cell's last expression so the notebook displays it.
rollout_markup = html.render(
    sys=sys,
    states=state_history,
    height=500,
    colab=True,
)
HTML(rollout_markup)