<a href="https://colab.research.google.com/github/EureXaAI/EurexaBook/blob/main/rl/EurexaBook_0402.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
#@title 导入依赖包
#@markdown ### 本代码研究如何用 brax 和 mujoco 搭建最简单的强化学习案例
# JAX Imports:
import jax
import jax.numpy as jnp

In [18]:
#@title 模型: 木棍子
xml_model = """
<mujoco model="inverted pendulum">
    <compiler inertiafromgeom="true"/>

    <default>
        <joint armature="0" damping="1"/>
        <geom contype="0" conaffinity="0" friction="1 0.1 0.1"/>
    </default>

    <worldbody>
        <light diffuse=".5 .5 .5" pos="0 0 3" dir="0 0 -1"/>
        <geom name="rail" type="capsule"  size="0.02 1.5" pos="0 0 0" quat="1 0 1 0" rgba="1 1 1 1"/>
        <body name="cart" pos="0 0 0">
            <joint name="slider" type="slide" axis="1 0 0" pos="0 0 0" limited="true" range="-1.5 1.5"/>
            <geom name="cart_geom" pos="0 0 0" quat="1 0 1 0" size="0.1 0.1" type="capsule"/>
            <body name="pole" pos="0 0 0">
                <joint name="hinge" type="hinge" axis="0 1 0" pos="0 0 0"/>
                <geom name="pole_geom" type="capsule" size="0.049 0.3"  fromto="0 0 0 0.001 0 0.6"/>
            </body>
        </body>
    </worldbody>

    <actuator>
        <motor ctrllimited="true" ctrlrange="-3 3" gear="100" joint="slider" name="slide"/>
    </actuator>

</mujoco>
"""

In [56]:
#@title 模型: unitree a1
xml_model = """
<mujoco model="a1">
  <compiler angle="radian" meshdir="assets" texturedir="assets" autolimits="true"/>

  <!-- brax doesn't support eliptic friction -->
  <option impratio="100"/>

  <default>
    <default class="a1">
      <geom friction="0.6" margin="0.001"/>
      <joint axis="0 1 0" damping="2" armature="0.01" frictionloss="0.2"/>
      <position kp="100" forcerange="-33.5 33.5"/>
      <default class="abduction">
        <joint axis="1 0 0" damping="1" range="-0.802851 0.802851"/>
        <position ctrlrange="-0.802851 0.802851"/>
      </default>
      <default class="hip">
        <joint range="-1.0472 4.18879"/>
        <position ctrlrange="-1.0472 4.18879"/>
      </default>
      <default class="knee">
        <joint range="-2.69653 -0.916298"/>
        <position ctrlrange="-2.69653 -0.916298"/>
      </default>
      <default class="visual">
        <geom type="mesh" contype="0" conaffinity="0" group="2" material="dark"/>
      </default>
      <default class="collision">
        <!-- to avoid creating a huge A-matrix, turn off self-collision -->
        <geom group="3" type="capsule" contype="1" conaffinity="0" />
        <default class="hip_left">
          <geom size="0.04 0.04" quat="1 1 0 0" type="cylinder" pos="0 0.055 0"/>
        </default>
        <default class="hip_right">
          <geom size="0.04 0.04" quat="1 1 0 0" type="cylinder" pos="0 -0.055 0"/>
        </default>
        <default class="thigh1">
          <geom size="0.015" fromto="-0.02 0 0 -0.02 0 -0.16"/>
        </default>
        <default class="thigh2">
          <geom size="0.015" fromto="0 0 0 -0.02 0 -0.1"/>
        </default>
        <default class="thigh3">
          <geom size="0.015" fromto="-0.02 0 -0.16 0 0 -0.2"/>
        </default>
        <default class="calf1">
          <geom size="0.01" fromto="0 0 0 0.02 0 -0.13"/>
        </default>
        <default class="calf2">
          <geom size="0.01" fromto="0.02 0 -0.13 0 0 -0.2"/>
        </default>
        <default class="foot">
          <!-- BRAX: we don't yet support multiple solparams per geom, but will in the future -->
          <geom type="sphere" size="0.02" pos="0 0 -0.2" priority="1"
              condim="6" friction="0.8 0.02 0.01"/>
        </default>
      </default>
    </default>
  </default>

  <asset>
    <material name="dark" specular="0" shininess="0.25" rgba="0.2 0.2 0.2 1"/>
    <texture type="2d" name="trunk_A1" file="trunk_A1.png"/>
    <material name="carbonfibre" texture="trunk_A1" specular="0" shininess="0.25"/>

    <mesh class="a1" file="calf.obj"/>
    <mesh class="a1" file="hip.obj"/>
    <mesh class="a1" file="thigh.obj"/>
    <mesh class="a1" file="thigh_mirror.obj"/>
    <mesh class="a1" file="trunk.obj"/>
    <texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.9375 0.7226 0.04296"
        rgb2="0 0 0" markrgb="0.8 0.8 0.8" width="1000" height="1000"/>
    <material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5"
        reflectance="0.2"/>
  </asset>

  <worldbody>
    <light name="spotlight" mode="targetbodycom" target="trunk" pos="0 -1 2"/>
    <geom name="floor" size="0 0 .125" type="plane" material="groundplane" conaffinity="15" condim="3"/>
    <body name="trunk" pos="0 0 0.43" childclass="a1">
      <freejoint/>
      <inertial mass="4.713" pos="0 0.0041 -0.0005"
          fullinertia="0.0158533 0.0377999 0.0456542 -3.66e-05 -6.11e-05 -2.75e-05"/>
      <geom class="visual" mesh="trunk" material="carbonfibre"/>
      <geom class="collision" size="0.125 0.04 0.057" type="box"/>
      <geom class="collision" quat="1 0 1 0" pos="0 -0.04 0" size="0.058 0.125" type="cylinder"/>
      <geom class="collision" quat="1 0 1 0" pos="0 +0.04 0" size="0.058 0.125" type="cylinder"/>
      <geom class="collision" pos="0.25 0 0" size="0.005 0.06 0.05" type="box"/>
      <geom class="collision" pos="0.25 0.06 -0.01" size="0.009 0.035"/>
      <geom class="collision" pos="0.25 -0.06 -0.01" size="0.009 0.035"/>
      <geom class="collision" pos="0.25 0 -0.05" size="0.005 0.06" quat="1 1 0 0"/>
      <geom class="collision" pos="0.255 0 0.0355" size="0.021 0.052" quat="1 1 0 0"/>
      <body name="FR_hip" pos="0.183 -0.047 0">
        <inertial mass="0.696" pos="-0.003311 -0.000635 3.1e-05"
            quat="0.507528 0.506268 0.491507 0.494499"
            diaginertia="0.000807752 0.00055293 0.000468983"/>
        <joint class="abduction" name="FR_hip_joint"/>
        <geom class="visual" mesh="hip" quat="0 1 0 0"/>
        <geom class="hip_right"/>
        <body name="FR_thigh" pos="0 -0.08505 0">
          <inertial mass="1.013" pos="-0.003237 0.022327 -0.027326"
              quat="0.999125 -0.00256393 -0.0409531 -0.00806091"
              diaginertia="0.00555739 0.00513936 0.00133944"/>
          <joint class="hip" name="FR_thigh_joint"/>
          <geom class="visual" mesh="thigh_mirror"/>
          <geom class="thigh1"/>
          <geom class="thigh2"/>
          <geom class="thigh3"/>
          <body name="FR_calf" pos="0 0 -0.2">
            <inertial mass="0.226" pos="0.00472659 0 -0.131975"
                quat="0.706886 0.017653 0.017653 0.706886"
                diaginertia="0.00340344 0.00339393 3.54834e-05"/>
            <joint class="knee" name="FR_calf_joint"/>
            <geom class="visual" mesh="calf"/>
            <geom class="calf1"/>
            <geom class="calf2"/>
            <geom class="foot"/>
          </body>
        </body>
      </body>
      <body name="FL_hip" pos="0.183 0.047 0">
        <inertial mass="0.696" pos="-0.003311 0.000635 3.1e-05"
            quat="0.494499 0.491507 0.506268 0.507528"
            diaginertia="0.000807752 0.00055293 0.000468983"/>
        <joint class="abduction" name="FL_hip_joint"/>
        <geom class="visual" mesh="hip"/>
        <geom class="hip_left"/>
        <geom class="collision" size="0.04 0.04" pos="0 0.055 0" quat="1 1 0 0" type="cylinder"/>
        <body name="FL_thigh" pos="0 0.08505 0">
          <inertial mass="1.013" pos="-0.003237 -0.022327 -0.027326"
              quat="0.999125 0.00256393 -0.0409531 0.00806091"
              diaginertia="0.00555739 0.00513936 0.00133944"/>
          <joint class="hip" name="FL_thigh_joint"/>
          <geom class="visual" mesh="thigh"/>
          <geom class="thigh1"/>
          <geom class="thigh2"/>
          <geom class="thigh3"/>
          <body name="FL_calf" pos="0 0 -0.2">
            <inertial mass="0.226" pos="0.00472659 0 -0.131975"
                quat="0.706886 0.017653 0.017653 0.706886"
                diaginertia="0.00340344 0.00339393 3.54834e-05"/>
            <joint class="knee" name="FL_calf_joint"/>
            <geom class="visual" mesh="calf"/>
            <geom class="calf1"/>
            <geom class="calf2"/>
            <geom class="foot"/>
          </body>
        </body>
      </body>
      <body name="RR_hip" pos="-0.183 -0.047 0">
        <inertial mass="0.696" pos="0.003311 -0.000635 3.1e-05"
            quat="0.491507 0.494499 0.507528 0.506268"
            diaginertia="0.000807752 0.00055293 0.000468983"/>
        <joint class="abduction" name="RR_hip_joint"/>
        <geom class="visual" quat="0 0 0 -1" mesh="hip"/>
        <geom class="hip_right"/>
        <body name="RR_thigh" pos="0 -0.08505 0">
          <inertial mass="1.013" pos="-0.003237 0.022327 -0.027326"
              quat="0.999125 -0.00256393 -0.0409531 -0.00806091"
              diaginertia="0.00555739 0.00513936 0.00133944"/>
          <joint class="hip" name="RR_thigh_joint"/>
          <geom class="visual" mesh="thigh_mirror"/>
          <geom class="thigh1"/>
          <geom class="thigh2"/>
          <geom class="thigh3"/>
          <body name="RR_calf" pos="0 0 -0.2">
            <inertial mass="0.226" pos="0.00472659 0 -0.131975"
                quat="0.706886 0.017653 0.017653 0.706886"
                diaginertia="0.00340344 0.00339393 3.54834e-05"/>
            <joint class="knee" name="RR_calf_joint"/>
            <geom class="visual" mesh="calf"/>
            <geom class="calf1"/>
            <geom class="calf2"/>
            <geom class="foot"/>
          </body>
        </body>
      </body>
      <body name="RL_hip" pos="-0.183 0.047 0">
        <inertial mass="0.696" pos="0.003311 0.000635 3.1e-05"
            quat="0.506268 0.507528 0.494499 0.491507"
            diaginertia="0.000807752 0.00055293 0.000468983"/>
        <joint class="abduction" name="RL_hip_joint"/>
        <geom class="visual" quat="0 0 1 0" mesh="hip"/>
        <geom class="hip_left"/>
        <body name="RL_thigh" pos="0 0.08505 0">
          <inertial mass="1.013" pos="-0.003237 -0.022327 -0.027326"
              quat="0.999125 0.00256393 -0.0409531 0.00806091"
              diaginertia="0.00555739 0.00513936 0.00133944"/>
          <joint class="hip" name="RL_thigh_joint"/>
          <geom class="visual" mesh="thigh"/>
          <geom class="thigh1"/>
          <geom class="thigh2"/>
          <geom class="thigh3"/>
          <body name="RL_calf" pos="0 0 -0.2">
            <inertial mass="0.226" pos="0.00472659 0 -0.131975"
                quat="0.706886 0.017653 0.017653 0.706886"
                diaginertia="0.00340344 0.00339393 3.54834e-05"/>
            <joint class="knee" name="RL_calf_joint"/>
            <geom class="visual" mesh="calf"/>
            <geom class="calf1"/>
            <geom class="calf2"/>
            <geom class="foot"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>

  <actuator>
    <position class="abduction" name="FR_hip" joint="FR_hip_joint"/>
    <position class="hip" name="FR_thigh" joint="FR_thigh_joint"/>
    <position class="knee" name="FR_calf" joint="FR_calf_joint"/>
    <position class="abduction" name="FL_hip" joint="FL_hip_joint"/>
    <position class="hip" name="FL_thigh" joint="FL_thigh_joint"/>
    <position class="knee" name="FL_calf" joint="FL_calf_joint"/>
    <position class="abduction" name="RR_hip" joint="RR_hip_joint"/>
    <position class="hip" name="RR_thigh" joint="RR_thigh_joint"/>
    <position class="knee" name="RR_calf" joint="RR_calf_joint"/>
    <position class="abduction" name="RL_hip" joint="RL_hip_joint"/>
    <position class="hip" name="RL_thigh" joint="RL_thigh_joint"/>
    <position class="knee" name="RL_calf" joint="RL_calf_joint"/>
  </actuator>

  <keyframe>
    <key name="home" qpos="0 0 0.27 1 0 0 0 0 0.9 -1.8 0 0.9 -1.8 0 0.9 -1.8 0 0.9 -1.8"
        ctrl="0 0.9 -1.8 0 0.9 -1.8 0 0.9 -1.8 0 0.9 -1.8"/>
  </keyframe>
</mujoco>

"""

In [57]:
#@title 初始化模型 { run: "auto" }
# Brax Imports:
from brax.mjx import pipeline
from brax.io import mjcf, html

# Load the MJCF model
sys = mjcf.loads(xml_model)

# Jitting the init and step functions for GPU acceleration
init_fn = jax.jit(pipeline.init)
step_fn = jax.jit(pipeline.step)

# Initializing the state:
state = init_fn(
    sys=sys, q=sys.init_q, qd=jnp.zeros(sys.qd_size()),
)

In [58]:
#@title 执行模拟
num_steps = 1000
state_history = []
ctrl = jnp.zeros(sys.act_size())
for i in range(num_steps):
    state = step_fn(sys, state, act=ctrl)
    state_history.append(state)

In [59]:
#@title 状态历史记录
# MJX Backend Requires Contact Information even if it does not exist:
state_history = list(map(lambda x: x.replace(contact=None), state_history))

In [60]:
#@title 可视化
from IPython.display import HTML

HTML(
    html.render(
        sys=sys,
        states=state_history,
        height=480,
        colab=True,
    ),
)

In [37]:
#@title 训练配置 1
from brax.envs.base import PipelineEnv, State

# Environment:
class CartPole(PipelineEnv):
    """ Environment for training Cart Pole balancing """

    def __init__(self, xml_model: str, backend: str = 'mjx', **kwargs):
        # Initialize System:
        sys = mjcf.loads(xml_model)
        self.step_dt = 0.02
        n_frames = kwargs.pop('n_frames', int(self.step_dt / sys.opt.timestep))
        super().__init__(sys, backend=backend, n_frames=n_frames)

    def reset(self, rng: jax.Array) -> State:
        key, theta_key, qd_key = jax.random.split(rng, 3)

        theta_init = jax.random.uniform(theta_key, (1,), minval=-0.1, maxval=0.1)[0]

        # q structure: [x th]
        q_init = jnp.array([0.0, theta_init])

        # qd structure: [dx dth]
        qd_init = jax.random.uniform(qd_key, (2,), minval=-0.1, maxval=0.1)

        # Initialize State:
        pipeline_state = self.pipeline_init(q_init, qd_init)

        # Initialize Rewards:
        reward, done = jnp.zeros(2)

        # Get observation for RL Algorithm (Input to our neural net):
        observation = self.get_observation(pipeline_state)

        # Metrics:
        metrics = {
            'rewards': reward,
            # 'observation': observation,
        }

        state = State(
            pipeline_state=pipeline_state,
            obs=observation,
            reward=reward,
            done=done,
            metrics=metrics,
        )

        return state

    def step(self, state: State, action: jax.Array) -> State:
        # Forward Physics Step:
        pipeline_state = self.pipeline_step(state.pipeline_state, action)

        # Get Observation from new state:
        observation = self.get_observation(pipeline_state)

        # Extract States:
        x, th = pipeline_state.q

        # Terminate if outside the range of the rail and pendulum is past the rail:
        outside_x = jnp.abs(x) > 1.0
        outside_th = jnp.abs(th) > jnp.pi / 2
        done = outside_x | outside_th
        done = jnp.float32(done)

        # Calculate Reward:
        reward = jnp.cos(th)

        metrics = {
            'rewards': reward,
            # 'observation': observation,
        }
        state.metrics.update(metrics)

        # Update State object:
        state = state.replace(
            pipeline_state=pipeline_state, obs=observation, reward=reward, done=done,
        )
        return state

    def get_observation(self, pipeline_state: State) -> jnp.ndarray:
        # Observation: [x, th, dx, dth]
        return jnp.concatenate([pipeline_state.q, pipeline_state.qd])


In [38]:
#@title 训练配置 2
env = CartPole(xml_model=xml_model, backend='mjx')
eval_env = CartPole(xml_model=xml_model, backend='mjx')

def progress_fn(current_step, metrics):
    if current_step > 0:
        print(f'Step: {current_step} \t Reward: Episode Reward: {metrics["eval/episode_reward"]:.3f}')

In [39]:
#@title 训练配置 3
import functools

from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo

make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128, 128, 128, 128),
)

train_fn = functools.partial(
    ppo.train,
    num_timesteps=200_000,
    num_evals=10,
    episode_length=200,
    num_envs=32,
    num_eval_envs=4,
    batch_size=32,
    num_minibatches=4,
    unroll_length=20,
    num_updates_per_batch=4,
    normalize_observations=True,
    discounting=0.97,
    learning_rate=3.0e-4,
    entropy_cost=1e-2,
    network_factory=make_networks_factory,
    seed=0,
)


In [40]:
#@title 执行训练
make_policy_fn, params, _ = train_fn(
    environment=env,
    progress_fn=progress_fn,
    eval_env=eval_env,
)

Step: 23040 	 Reward: Episode Reward: 74.832
Step: 46080 	 Reward: Episode Reward: 199.880
Step: 69120 	 Reward: Episode Reward: 199.953
Step: 92160 	 Reward: Episode Reward: 199.948
Step: 115200 	 Reward: Episode Reward: 199.974
Step: 138240 	 Reward: Episode Reward: 199.985
Step: 161280 	 Reward: Episode Reward: 199.987
Step: 184320 	 Reward: Episode Reward: 199.980
Step: 207360 	 Reward: Episode Reward: 199.976


In [41]:
#@title 策略加载
# 把 params 也就是训练好的网络参数 装载 到了 policy 网络中
policy_fn = make_policy_fn(params)
# 将 policy_fn 编译成更快的版本, 后续调用时会极大加速推理速度
policy_fn = jax.jit(policy_fn)

In [42]:
#@title 策略推理
env = CartPole(xml_model=xml_model, backend='mjx')

reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

key = jax.random.key(42)
state = reset_fn(key)

state_history = []
num_steps = 200
for i in range(num_steps):
    key, subkey = jax.random.split(key)
    action, _ = policy_fn(state.obs, subkey)
    state = step_fn(state, action)
    state_history.append(state.pipeline_state)

In [43]:
#@title 构建历史记录
# MJX Backend Requires Contact Information even if it does not exist:
state_history = list(map(lambda x: x.replace(contact=None), state_history))

In [44]:
#@title 渲染历史记录
HTML(
    html.render(
        sys=env.sys.tree_replace({'opt.timestep': env.dt}),
        states=state_history,
        height=480,
        colab=True,
    ),
)