# Unitree G1 Humanoid Training on Google Colab

This notebook trains a PPO policy for the Unitree G1 humanoid robot using Brax and MuJoCo MJX.

**Based on the official [MuJoCo MJX Tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb)**

## Setup Instructions:
1. **Runtime → Change runtime type → T4 GPU** (or A100 for faster training)
2. Run all cells in order
3. Training takes ~10-15 minutes on A100 GPU, ~60-90 minutes on T4 GPU
4. Download trained policy from Files panel

## Features:
- ✅ Humanoid locomotion (forward walking)
- ✅ Domain randomization (friction + actuator parameters)
- ✅ PPO training with Brax
- ✅ 29 DOF (6 per leg, 3 waist, 7 per arm)
- ✅ Policy visualization and video export

## 1. Install Dependencies

In [None]:
!pip install mujoco mujoco_mjx brax
!pip install flax optax orbax-checkpoint
!pip install mediapy
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)

print("\n✅ All dependencies installed!")

## 2. Verify GPU and Configure Environment

In [None]:
#@title Check GPU and Configure Environment

import subprocess
import os

# Check GPU
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add Nvidia EGL driver config for rendering
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use EGL rendering (GPU)
os.environ['MUJOCO_GL'] = 'egl'

# XLA optimization for better performance
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Test MuJoCo installation
import mujoco
mujoco.MjModel.from_xml_string('<mujoco/>')

import jax
print(f"✅ GPU detected: {jax.devices()}")
print(f"✅ MuJoCo configured with EGL rendering")
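
Re-running the cell above appends the same XLA flag to `XLA_FLAGS` each time. A minimal sketch of an idempotent variant follows; `append_flag` is a hypothetical helper, not part of JAX or MuJoCo, and it assumes the variable is set before JAX initializes its backend.

```python
import os


def append_flag(env_var: str, flag: str) -> str:
    """Append `flag` to a space-separated env var unless it is already there."""
    current = os.environ.get(env_var, "").split()
    if flag not in current:
        current.append(flag)
    os.environ[env_var] = " ".join(current)
    return os.environ[env_var]


# Calling the helper twice still leaves a single copy of the flag.
append_flag("XLA_FLAGS", "--xla_gpu_triton_gemm_any=True")
append_flag("XLA_FLAGS", "--xla_gpu_triton_gemm_any=True")
```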

## 3. Download MuJoCo Menagerie Models

In [None]:
!git clone https://github.com/google-deepmind/mujoco_menagerie.git
print("\n✅ MuJoCo Menagerie downloaded!")

## 4. Import Libraries

In [None]:
import functools
from typing import Any, Dict, List, Sequence, Tuple
from datetime import datetime
from etils import epath

import jax
from jax import numpy as jp
import numpy as np
import matplotlib.pyplot as plt
import mediapy as media
from IPython.display import HTML

from ml_collections import config_dict
from flax.training import orbax_utils
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Motion, Transform
from brax.envs.base import PipelineEnv, State
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

# More legible numpy printing
np.set_printoptions(precision=3, suppress=True, linewidth=100)

print("✅ Libraries imported successfully!")

## 5. Define Unitree G1 Humanoid Environment

In [None]:
#@title Unitree G1 Humanoid Environment Definition (Based on MJX Tutorial Humanoid)

G1_ROOT_PATH = epath.Path('mujoco_menagerie/unitree_g1')


class UnitreeG1Env(PipelineEnv):
  """Environment for training the Unitree G1 humanoid in MJX."""

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.6, 2.0),  # Lower bound reduced: the G1's pelvis sits lower than the tutorial Humanoid's torso
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    path = G1_ROOT_PATH / 'scene_mjx.xml'
    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    
    # Optimize solver for G1
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # Center-of-mass velocity of the subtree rooted at body 1
    # (assumed to be the G1's pelvis, so the subtree covers the whole robot)
    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes G1 body position, velocities, and angles."""
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    # External contact forces are excluded
    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ])


envs.register_environment('unitree_g1', UnitreeG1Env)
print("✅ UnitreeG1Env registered!")
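
The reward logic in `step` can be checked in isolation. The sketch below mirrors it in plain Python, with no JAX or MJX; `g1_reward` is a hypothetical helper, and its defaults copy the constructor arguments above.

```python
def g1_reward(x_velocity, z_height, action,
              forward_reward_weight=1.25, ctrl_cost_weight=0.1,
              healthy_reward=5.0, healthy_z_range=(0.6, 2.0),
              terminate_when_unhealthy=True):
    """Return (reward, done) for one control step, following UnitreeG1Env.step."""
    min_z, max_z = healthy_z_range
    is_healthy = 1.0 if min_z <= z_height <= max_z else 0.0
    forward = forward_reward_weight * x_velocity
    # With termination enabled the alive bonus is paid on every step,
    # including the final (unhealthy) one.
    alive = healthy_reward if terminate_when_unhealthy else healthy_reward * is_healthy
    ctrl_cost = ctrl_cost_weight * sum(a * a for a in action)
    done = 1.0 - is_healthy if terminate_when_unhealthy else 0.0
    return forward + alive - ctrl_cost, done


# Walking at 1 m/s with the root at 0.79 m and small controls on 29 actuators.
reward, done = g1_reward(x_velocity=1.0, z_height=0.79, action=[0.1] * 29)
```

At these weights the alive bonus (5.0 per step) outweighs the forward term unless the robot moves faster than 4 m/s, so early in training a policy can earn most of its return by standing still.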

## 6. Domain Randomization Function

In [None]:
#@title Domain Randomization Function

def domain_randomize(sys, rng):
  """Randomizes the mjx.Model for better sim-to-real transfer."""
  @jax.vmap
  def rand(rng):
    _, key = jax.random.split(rng, 2)
    # friction randomization
    friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)
    friction = sys.geom_friction.at[:, 0].set(friction)
    # actuator randomization
    _, key = jax.random.split(key, 2)
    gain_range = (-5, 5)
    param = jax.random.uniform(
        key, (1,), minval=gain_range[0], maxval=gain_range[1]
    ) + sys.actuator_gainprm[:, 0]
    gain = sys.actuator_gainprm.at[:, 0].set(param)
    bias = sys.actuator_biasprm.at[:, 1].set(-param)
    return friction, gain, bias

  friction, gain, bias = rand(rng)

  in_axes = jax.tree_util.tree_map(lambda x: None, sys)
  in_axes = in_axes.tree_replace({
      'geom_friction': 0,
      'actuator_gainprm': 0,
      'actuator_biasprm': 0,
  })

  sys = sys.tree_replace({
      'geom_friction': friction,
      'actuator_gainprm': gain,
      'actuator_biasprm': bias,
  })

  return sys, in_axes


print("✅ Domain randomization function defined!")
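
To make the sampling ranges concrete, here is a plain-Python sketch of one draw from the randomization above. It assumes position actuators, for which MuJoCo stores the gain `kp` in `gainprm[0]` and `-kp` in `biasprm[1]`. The name `sample_randomization` and the nominal gains are hypothetical.

```python
import random


def sample_randomization(nominal_kp, rng):
    """Draw one randomized (friction, gainprm[:, 0], biasprm[:, 1]) sample."""
    friction = rng.uniform(0.6, 1.4)   # sliding friction applied to every geom
    offset = rng.uniform(-5.0, 5.0)    # a single offset shared by all actuators
    kp = [k + offset for k in nominal_kp]
    bias = [-k for k in kp]            # position actuator: biasprm[1] = -kp
    return friction, kp, bias


rng = random.Random(0)
nominal_kp = [40.0, 40.0, 75.0]  # hypothetical gains, not taken from the G1 model
friction, kp, bias = sample_randomization(nominal_kp, rng)
```

The friction value replaces the sliding coefficient outright rather than scaling it. The gain offset is additive and shared by all actuators, so low-gain joints see a larger relative change.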

## 7. Create Environment

In [None]:
# Create environment
env_name = 'unitree_g1'
env = envs.get_environment(env_name)

print(f"✅ Environment created!")
print(f"  - Observation size: {env.observation_size}")
print(f"  - Action size: {env.action_size}")
print(f"  - Generalized coordinates: nq={env.sys.nq} (qpos), nv={env.sys.nv} (qvel, i.e. DOF incl. floating base)")
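
The printed sizes can be sanity-checked by hand from how `_get_obs` builds the observation. The sketch below assumes a free-floating root plus 29 single-DOF joints, and a hypothetical body count of 31 (world, pelvis, one body per joint). The actual model may contain more bodies, so treat the result as an estimate.

```python
def g1_sizes(num_joints=29, num_bodies=31, exclude_xy=True):
    """Return (nq, nv, observation size) under the stated assumptions."""
    nq = 7 + num_joints                # free joint: 3 position + 4 quaternion
    nv = 6 + num_joints                # free joint: 3 linear + 3 angular velocity
    position = nq - 2 if exclude_xy else nq
    cinert = 10 * (num_bodies - 1)     # 10 inertia values per body, world skipped
    cvel = 6 * (num_bodies - 1)        # 6 velocity values per body, world skipped
    qfrc_actuator = nv                 # actuator force per degree of freedom
    return nq, nv, position + nv + cinert + cvel + qfrc_actuator


nq, nv, obs_size = g1_sizes()
print(nq, nv, obs_size)
```

If `env.observation_size` differs from this estimate, the most likely cause is the body count, since each extra body adds 16 entries.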

## 8. Visualize Untrained Policy (Optional)

In [None]:
#@title Quick Test: Visualize Constant Actions

jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Initialize
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# Grab a short trajectory with a fixed control signal
for i in range(50):
  ctrl = 0.1 * jp.ones(env.sys.nu)  # Small constant action (not random)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)
print("\n⚠️ This is WITHOUT training - the robot will likely fall!")

## 9. Train Policy with Progress Tracking

In [None]:
#@title Train Policy with Progress Tracking

# Training configuration (humanoids need more steps than quadrupeds)
NUM_TIMESTEPS = 20_000_000  # 20M steps for basic walking (100M+ for robust locomotion)
NUM_EVALS = 5
EPISODE_LENGTH = 1000
NUM_ENVS = 3072  # BATCH_SIZE * num_minibatches (512 * 24) must be a multiple of this
BATCH_SIZE = 512

# Progress tracking
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 13000, 0

def progress(num_steps, metrics):
  """Callback to track and display training progress."""
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  plt.xlim([0, NUM_TIMESTEPS * 1.25])
  plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'Unitree G1 Training | Reward: {y_data[-1]:.3f}')
  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.grid(True, alpha=0.3)
  plt.show()

  print(f"Steps: {num_steps:,} | Reward: {metrics['eval/episode_reward']:.2f}")

# Setup checkpoint saving
ckpt_path = epath.Path('/tmp/unitree_g1_humanoid/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  """Save checkpoints during training."""
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)
  print(f"💾 Checkpoint saved: {path}")

# Configure PPO training (matching tutorial's Humanoid settings)
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(256, 256, 256))  # Larger network for humanoid

train_fn = functools.partial(
    ppo.train,
    num_timesteps=NUM_TIMESTEPS,
    num_evals=NUM_EVALS,
    reward_scaling=0.1,
    episode_length=EPISODE_LENGTH,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=24,
    num_updates_per_batch=8,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-3,
    num_envs=NUM_ENVS,
    batch_size=BATCH_SIZE,
    network_factory=make_networks_factory,
    randomization_fn=domain_randomize,
    policy_params_fn=policy_params_fn,
    seed=0,
    progress_fn=progress
)

# Reset environment and start training
env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)

print(f"\n🚀 Starting training...")
print(f"  - Total steps: {NUM_TIMESTEPS:,}")
print(f"  - Parallel envs: {NUM_ENVS}")
print(f"  - Estimated time: ~10-15 min on A100, ~60-90 min on T4 GPU")
print(f"  - Note: Humanoids are harder to train than quadrupeds!\n")

make_inference_fn, params, _ = train_fn(
    environment=env,
    eval_env=eval_env
)

print(f'\n🎉 Training completed!')
print(f'  - Time to JIT: {times[1] - times[0]}')
print(f'  - Time to train: {times[-1] - times[1]}')
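
The batch-related hyperparameters above are coupled. The sketch below works through the arithmetic. It assumes that Brax's PPO trainer requires `batch_size * num_minibatches` to be a multiple of `num_envs`, and that one update consumes `batch_size * num_minibatches * unroll_length * action_repeat` environment steps. `ppo_batch_summary` is a hypothetical helper, and the update count is approximate because Brax also rounds per evaluation epoch.

```python
import math


def ppo_batch_summary(num_timesteps, num_envs, batch_size,
                      num_minibatches, unroll_length, action_repeat=1):
    """Return (env steps per PPO update, approximate number of updates)."""
    # The trajectories gathered for one update must split evenly
    # across the parallel environments.
    if (batch_size * num_minibatches) % num_envs != 0:
        raise ValueError(
            "batch_size * num_minibatches must be a multiple of num_envs")
    steps_per_update = (
        batch_size * num_minibatches * unroll_length * action_repeat)
    return steps_per_update, math.ceil(num_timesteps / steps_per_update)


steps_per_update, num_updates = ppo_batch_summary(
    num_timesteps=20_000_000, num_envs=3072, batch_size=512,
    num_minibatches=24, unroll_length=10)
print(steps_per_update, num_updates)  # 122880 163
```

When changing `NUM_ENVS` to fit a smaller GPU, pick a value that still divides 12,288 (for example 2048 or 1536) so the check passes.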

## 10. Plot Training Progress

In [None]:
plt.figure(figsize=(12, 6))
plt.plot(x_data, y_data, 'b-', linewidth=2)
plt.fill_between(x_data, 
                 [y - err for y, err in zip(y_data, ydataerr)],
                 [y + err for y, err in zip(y_data, ydataerr)],
                 alpha=0.3)
plt.xlabel('Training Steps', fontsize=12)
plt.ylabel('Episode Reward', fontsize=12)
plt.title('Unitree G1 Humanoid Training Progress', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 11. Save Trained Policy

In [None]:
# Save and reload params
model_path = '/tmp/mjx_brax_unitree_g1_policy'
model.save_params(model_path, params)

print(f"✅ Final policy saved: {model_path}")
print(f"\n📦 To download:")
print(f"1. Click Files icon (📁) in left sidebar")
print(f"2. Navigate to {ckpt_path} or {model_path}")
print(f"3. Right-click → Download")

# Reload params and create inference function
params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

print(f"\n✅ Policy reloaded and ready for visualization!")

## 12. Visualize Trained Policy

In [None]:
#@title Visualize Trained Policy

# Create evaluation environment
eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

# Initialize
rng = jax.random.PRNGKey(1)
state = jit_reset(rng)
rollout = [state.pipeline_state]

# Run trajectory
n_steps = 500
render_every = 2

print(f"Running {n_steps} steps with trained policy...")
for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
  
  if state.done:
    print(f"Episode terminated at step {i}")
    break

print(f"✅ Simulation completed: {len(rollout)} frames")

# Render with MuJoCo renderer
media.show_video(
    eval_env.render(rollout[::render_every], camera='side'),
    fps=1.0 / eval_env.dt / render_every)
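
The `fps` argument above keeps playback at real-time speed even though only every second frame is rendered. A short sketch of the arithmetic follows. The 0.002 s physics timestep is a hypothetical value chosen for illustration; the real one comes from the model XML, and `env.dt` is that timestep multiplied by `n_frames`.

```python
def playback_fps(physics_timestep, n_frames, render_every):
    """Frames per second that plays a subsampled rollout at real-time speed."""
    control_dt = physics_timestep * n_frames  # what env.dt reports
    return 1.0 / control_dt / render_every


# Hypothetical timestep of 2 ms with the notebook's n_frames=5, render_every=2.
fps = playback_fps(physics_timestep=0.002, n_frames=5, render_every=2)
simulated_seconds = 500 * 0.002 * 5  # n_steps * control_dt
```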

## 13. Alternative: Render with Brax HTML Renderer

In [None]:
#@title Alternative: Render with Brax HTML Renderer

# This provides an interactive 3D visualization
HTML(html.render(
    eval_env.sys.tree_replace({'opt.timestep': eval_env.dt}),
    rollout
))

## 14. Package and Download All Results

In [None]:
#@title Package and Download All Results

# Create zip file with all outputs
!zip -r /tmp/unitree_g1_training_results.zip {ckpt_path} {model_path}

print("\n📦 All results packaged: /tmp/unitree_g1_training_results.zip")
print("\n✅ Download from Files panel (📁)")
print("\nContains:")
print(f"  - Final policy: {model_path}")
print(f"  - Checkpoints: {ckpt_path}")
print(f"\n💡 Use these files with a local inference script!")
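
Outside Colab, or where the `zip` command is unavailable, the same packaging can be done with the standard library. This is a minimal sketch: `package_results` is a hypothetical helper, and the demo runs on throwaway placeholder files rather than real checkpoints.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path


def package_results(paths, archive_base):
    """Copy files and directories into a staging dir and zip them together."""
    with tempfile.TemporaryDirectory() as staging:
        for p in map(Path, paths):
            target = Path(staging) / p.name
            if p.is_dir():
                shutil.copytree(p, target)
            elif p.is_file():
                shutil.copy2(p, target)
        # make_archive appends ".zip" and returns the final archive path.
        return shutil.make_archive(str(archive_base), "zip", root_dir=staging)


# Demo on placeholder files standing in for ckpt_path and model_path.
demo_root = Path(tempfile.mkdtemp())
(demo_root / "ckpts").mkdir()
(demo_root / "ckpts" / "0.txt").write_text("checkpoint placeholder")
(demo_root / "policy").write_text("policy placeholder")
archive = package_results(
    [demo_root / "ckpts", demo_root / "policy"], demo_root / "results")
with zipfile.ZipFile(archive) as zf:
    names = sorted(zf.namelist())
```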

## Summary

### What was trained:
- **Robot**: Unitree G1 humanoid from [MuJoCo Menagerie](https://github.com/google-deepmind/mujoco_menagerie/tree/main/unitree_g1)
- **Task**: Forward walking locomotion
- **Algorithm**: PPO (Proximal Policy Optimization)
- **Training steps**: 20M (20 million environment steps)
- **Parallel environments**: 3072 (GPU-accelerated with MJX)
- **Domain randomization**: Friction (sliding coefficient sampled uniformly from 0.6-1.4) + Actuator parameters (±5 additive gain offset)

### Robot Specifications:
- **DOF**: 29 actuated joints
  - Legs: 6 DOF each (hip pitch/roll/yaw, knee, ankle pitch/roll)
  - Waist: 3 DOF (yaw, roll, pitch)
  - Arms: 7 DOF each (shoulder pitch/roll/yaw, elbow, wrist roll/pitch/yaw)
- **Height**: ~1.3m standing
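- **Observation space**: several hundred dimensions; the exact size is printed as `env.observation_size` in section 7 (joint positions, joint velocities, per-body inertias and velocities, actuator forces)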

### Files generated:
1. **`/tmp/mjx_brax_unitree_g1_policy`** - Final trained policy
2. **`/tmp/unitree_g1_humanoid/ckpts/`** - Intermediate checkpoints
3. **Training progress plot** - Reward vs. steps visualization
4. **Video demos** - Rendered humanoid locomotion

### Training Notes:
- Humanoid locomotion is **much harder** than quadruped locomotion
- 20M steps gives basic walking; 100M+ steps for robust locomotion
- The G1 has a **high-dimensional action space** with 29 actuators
- Domain randomization helps with sim-to-real transfer
- Based on the [MuJoCo Playground](https://playground.mujoco.org/) project

### Expected Behavior:
With 20M steps training:
- ✅ Basic forward walking
- ✅ Maintaining balance
- ⚠️ May fall occasionally

For production-quality policies, train for 100M+ steps (roughly 50-75 min on A100, scaling the 20M-step estimate linearly).
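
Longer runs can be budgeted by scaling the 20M-step estimates from the setup section. The sketch below assumes throughput stays constant and ignores JIT compilation time, so it gives a rough guide rather than a measurement. `extrapolate_minutes` is a hypothetical helper.

```python
def extrapolate_minutes(base_steps, base_minutes, target_steps):
    """Linearly scale a (low, high) wall-clock estimate to a new step budget."""
    low, high = base_minutes
    scale = target_steps / base_steps
    return low * scale, high * scale


# 20M-step estimates from the setup section, scaled to 100M steps.
a100 = extrapolate_minutes(20_000_000, (10, 15), 100_000_000)
t4 = extrapolate_minutes(20_000_000, (60, 90), 100_000_000)
print(a100, t4)  # (50.0, 75.0) (300.0, 450.0)
```

On a T4 this comes to roughly 5 to 7.5 hours, which may exceed a free Colab session.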

### Based on:
- [Official MuJoCo MJX Tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb)
- [MuJoCo Playground Paper](https://arxiv.org/abs/2502.08844)
- Unitree G1 model from MuJoCo Menagerie