# Barkour Quadruped Training on Google Colab

This notebook trains a PPO policy for the Google Barkour VB quadruped robot using Brax and MuJoCo MJX.

**Based on the official [MuJoCo MJX Tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb)**

## Setup Instructions:
1. **Runtime → Change runtime type → T4 GPU** (or A100 for faster training)
2. Run all cells in order
3. Training takes ~6 minutes on A100 GPU, ~30-45 minutes on T4 GPU
4. Download trained policy from Files panel

## Features:
- ✅ Joystick control (x_vel, y_vel, ang_vel commands)
- ✅ Domain randomization (friction + actuator parameters)
- ✅ PPO training with Brax
- ✅ Policy visualization and video export

## 1. Install Dependencies

In [None]:
!pip install mujoco mujoco_mjx brax
!pip install flax optax orbax-checkpoint
!pip install mediapy
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)

print("\n✅ All dependencies installed!")
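
The install cell prints its success message even if a `pip` command failed. As a minimal sketch (not part of the original tutorial), the standard-library `importlib.metadata` module can confirm which distributions are actually present. The `PACKAGES` list and the helper name `installed_versions` are illustrative assumptions.

```python
from importlib import metadata

# Distribution names matching the pip commands above (assumed list; adjust as needed).
PACKAGES = ["mujoco", "mujoco-mjx", "brax", "flax", "optax", "orbax-checkpoint", "mediapy"]


def installed_versions(names):
  """Maps each distribution name to its installed version, or None if missing."""
  versions = {}
  for name in names:
    try:
      versions[name] = metadata.version(name)
    except metadata.PackageNotFoundError:
      versions[name] = None
  return versions


for name, version in installed_versions(PACKAGES).items():
  print(f"{name}: {version or 'NOT INSTALLED'}")
```

Any package reported as missing should be reinstalled before continuing.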

## 2. Verify GPU is Available

In [None]:
#@title Check GPU and Configure Environment

import subprocess
import os

# Check GPU
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Change runtime type.')

# Add Nvidia EGL driver config for rendering
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use EGL rendering (GPU)
os.environ['MUJOCO_GL'] = 'egl'

# XLA optimization for better performance
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Test MuJoCo installation
import mujoco
mujoco.MjModel.from_xml_string('<mujoco/>')

import jax
print(f"✅ GPU detected: {jax.devices()}")
print(f"✅ MuJoCo configured with EGL rendering")

## 3. Download MuJoCo Menagerie Models

In [None]:
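# Skip the clone if the directory already exists, so the cell can be re-run safely.
!test -d mujoco_menagerie || git clone https://github.com/google-deepmind/mujoco_menagerie.git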
print("\n✅ MuJoCo Menagerie downloaded!")

## 4. Import Libraries

In [None]:
import functools
from typing import Any, Dict, List, Sequence, Tuple
from datetime import datetime
from etils import epath

import jax
from jax import numpy as jp
import numpy as np
import matplotlib.pyplot as plt
import mediapy as media
from IPython.display import HTML

from ml_collections import config_dict
from flax.training import orbax_utils
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Motion, Transform
from brax.envs.base import PipelineEnv, State
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

# More legible numpy printing
np.set_printoptions(precision=3, suppress=True, linewidth=100)

print("✅ Libraries imported successfully!")

## 5. Define Barkour Environment

In [None]:
#@title Barkour Environment Definition (from official MJX tutorial)

BARKOUR_ROOT_PATH = epath.Path('mujoco_menagerie/google_barkour_vb')


def get_config():
  """Returns reward config for barkour quadruped environment."""

  def get_default_rewards_config():
    default_config = config_dict.ConfigDict(
        dict(
            scales=config_dict.ConfigDict(
                dict(
                    tracking_lin_vel=1.5,
                    tracking_ang_vel=0.8,
                    lin_vel_z=-2.0,
                    ang_vel_xy=-0.05,
                    orientation=-5.0,
                    torques=-0.0002,
                    action_rate=-0.01,
                    feet_air_time=0.2,
                    stand_still=-0.5,
                    termination=-1.0,
                    foot_slip=-0.1,
                )
            ),
            tracking_sigma=0.25,
        )
    )
    return default_config

  default_config = config_dict.ConfigDict(
      dict(
          rewards=get_default_rewards_config(),
      )
  )

  return default_config


class BarkourEnv(PipelineEnv):
  """Environment for training the barkour quadruped joystick policy in MJX."""

  def __init__(
      self,
      obs_noise: float = 0.05,
      action_scale: float = 0.3,
      kick_vel: float = 0.05,
      scene_file: str = 'scene_mjx.xml',
      **kwargs,
  ):
    path = BARKOUR_ROOT_PATH / scene_file
    sys = mjcf.load(path.as_posix())
    self._dt = 0.02  # this environment is 50 fps
    sys = sys.tree_replace({'opt.timestep': 0.004})

    # override menagerie params for smoother policy
    sys = sys.replace(
        dof_damping=sys.dof_damping.at[6:].set(0.5239),
        actuator_gainprm=sys.actuator_gainprm.at[:, 0].set(35.0),
        actuator_biasprm=sys.actuator_biasprm.at[:, 1].set(-35.0),
    )

    n_frames = kwargs.pop('n_frames', int(self._dt / sys.opt.timestep))
    super().__init__(sys, backend='mjx', n_frames=n_frames)

    self.reward_config = get_config()
    # set custom from kwargs
    for k, v in kwargs.items():
      if k.endswith('_scale'):
        self.reward_config.rewards.scales[k[:-6]] = v

    self._torso_idx = mujoco.mj_name2id(
        sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso'
    )
    self._action_scale = action_scale
    self._obs_noise = obs_noise
    self._kick_vel = kick_vel
    self._init_q = jp.array(sys.mj_model.keyframe('home').qpos)
    self._default_pose = sys.mj_model.keyframe('home').qpos[7:]
    self.lowers = jp.array([-0.7, -1.0, 0.05] * 4)
    self.uppers = jp.array([0.52, 2.1, 2.1] * 4)
    feet_site = [
        'foot_front_left',
        'foot_hind_left',
        'foot_front_right',
        'foot_hind_right',
    ]
    feet_site_id = [
        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f)
        for f in feet_site
    ]
    assert not any(id_ == -1 for id_ in feet_site_id), 'Site not found.'
    self._feet_site_id = np.array(feet_site_id)
    lower_leg_body = [
        'lower_leg_front_left',
        'lower_leg_hind_left',
        'lower_leg_front_right',
        'lower_leg_hind_right',
    ]
    lower_leg_body_id = [
        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, l)
        for l in lower_leg_body
    ]
    assert not any(id_ == -1 for id_ in lower_leg_body_id), 'Body not found.'
    self._lower_leg_body_id = np.array(lower_leg_body_id)
    self._foot_radius = 0.0175
    self._nv = sys.nv

  def sample_command(self, rng: jax.Array) -> jax.Array:
    lin_vel_x = [-0.6, 1.5]  # min max [m/s]
    lin_vel_y = [-0.8, 0.8]  # min max [m/s]
    ang_vel_yaw = [-0.7, 0.7]  # min max [rad/s]

    _, key1, key2, key3 = jax.random.split(rng, 4)
    lin_vel_x = jax.random.uniform(
        key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1]
    )
    lin_vel_y = jax.random.uniform(
        key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1]
    )
    ang_vel_yaw = jax.random.uniform(
        key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1]
    )
    new_cmd = jp.array([lin_vel_x[0], lin_vel_y[0], ang_vel_yaw[0]])
    return new_cmd

  def reset(self, rng: jax.Array) -> State:
    rng, key = jax.random.split(rng)

    pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))

    state_info = {
        'rng': rng,
        'last_act': jp.zeros(12),
        'last_vel': jp.zeros(12),
        'command': self.sample_command(key),
        'last_contact': jp.zeros(4, dtype=bool),
        'feet_air_time': jp.zeros(4),
        'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},
        'kick': jp.array([0.0, 0.0]),
        'step': 0,
    }

    obs_history = jp.zeros(15 * 31)  # store 15 steps of history
    obs = self._get_obs(pipeline_state, state_info, obs_history)
    reward, done = jp.zeros(2)
    metrics = {'total_dist': 0.0}
    for k in state_info['rewards']:
      metrics[k] = state_info['rewards'][k]
    state = State(pipeline_state, obs, reward, done, metrics, state_info)
    return state

  def step(self, state: State, action: jax.Array) -> State:
    rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)

    # kick
    push_interval = 10
    kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
    kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
    kick *= jp.mod(state.info['step'], push_interval) == 0
    qvel = state.pipeline_state.qvel
    qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
    state = state.tree_replace({'pipeline_state.qvel': qvel})

    # physics step
    motor_targets = self._default_pose + action * self._action_scale
    motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
    pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
    x, xd = pipeline_state.x, pipeline_state.xd

    # observation data
    obs = self._get_obs(pipeline_state, state.info, state.obs)
    joint_angles = pipeline_state.q[7:]
    joint_vel = pipeline_state.qd[6:]

    # foot contact data based on z-position
    foot_pos = pipeline_state.site_xpos[self._feet_site_id]
    foot_contact_z = foot_pos[:, 2] - self._foot_radius
    contact = foot_contact_z < 1e-3  # a mm or less off the floor
    contact_filt_mm = contact | state.info['last_contact']
    contact_filt_cm = (foot_contact_z < 3e-2) | state.info['last_contact']
    first_contact = (state.info['feet_air_time'] > 0) * contact_filt_mm
    state.info['feet_air_time'] += self.dt

    # done if joint limits are reached or robot is falling
    up = jp.array([0.0, 0.0, 1.0])
    done = jp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
    done |= jp.any(joint_angles < self.lowers)
    done |= jp.any(joint_angles > self.uppers)
    done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18

    # reward
    rewards = {
        'tracking_lin_vel': (
            self._reward_tracking_lin_vel(state.info['command'], x, xd)
        ),
        'tracking_ang_vel': (
            self._reward_tracking_ang_vel(state.info['command'], x, xd)
        ),
        'lin_vel_z': self._reward_lin_vel_z(xd),
        'ang_vel_xy': self._reward_ang_vel_xy(xd),
        'orientation': self._reward_orientation(x),
        'torques': self._reward_torques(pipeline_state.qfrc_actuator),
        'action_rate': self._reward_action_rate(action, state.info['last_act']),
        'stand_still': self._reward_stand_still(
            state.info['command'], joint_angles,
        ),
        'feet_air_time': self._reward_feet_air_time(
            state.info['feet_air_time'],
            first_contact,
            state.info['command'],
        ),
        'foot_slip': self._reward_foot_slip(pipeline_state, contact_filt_cm),
        'termination': self._reward_termination(done, state.info['step']),
    }
    rewards = {
        k: v * self.reward_config.rewards.scales[k] for k, v in rewards.items()
    }
    reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0)

    # state management
    state.info['kick'] = kick
    state.info['last_act'] = action
    state.info['last_vel'] = joint_vel
    state.info['feet_air_time'] *= ~contact_filt_mm
    state.info['last_contact'] = contact
    state.info['rewards'] = rewards
    state.info['step'] += 1
    state.info['rng'] = rng

    # sample new command if more than 500 timesteps achieved
    state.info['command'] = jp.where(
        state.info['step'] > 500,
        self.sample_command(cmd_rng),
        state.info['command'],
    )
    # reset the step counter when done
    state.info['step'] = jp.where(
        done | (state.info['step'] > 500), 0, state.info['step']
    )

    # log total displacement as a proxy metric
    state.metrics['total_dist'] = math.normalize(x.pos[self._torso_idx - 1])[1]
    state.metrics.update(state.info['rewards'])

    done = jp.float32(done)
    state = state.replace(
        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
    )
    return state

  def _get_obs(
      self,
      pipeline_state: base.State,
      state_info: dict[str, Any],
      obs_history: jax.Array,
  ) -> jax.Array:
    inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])
    local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)

    obs = jp.concatenate([
        jp.array([local_rpyrate[2]]) * 0.25,                 # yaw rate
        math.rotate(jp.array([0, 0, -1]), inv_torso_rot),    # projected gravity
        state_info['command'] * jp.array([2.0, 2.0, 0.25]),  # command
        pipeline_state.q[7:] - self._default_pose,           # motor angles
        state_info['last_act'],                              # last action
    ])

    # clip, noise
    obs = jp.clip(obs, -100.0, 100.0) + self._obs_noise * jax.random.uniform(
        state_info['rng'], obs.shape, minval=-1, maxval=1
    )
    # stack observations through time
    obs = jp.roll(obs_history, obs.size).at[:obs.size].set(obs)

    return obs

  # ------------ reward functions----------------
  def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:
    return jp.square(xd.vel[0, 2])

  def _reward_ang_vel_xy(self, xd: Motion) -> jax.Array:
    return jp.sum(jp.square(xd.ang[0, :2]))

  def _reward_orientation(self, x: Transform) -> jax.Array:
    up = jp.array([0.0, 0.0, 1.0])
    rot_up = math.rotate(up, x.rot[0])
    return jp.sum(jp.square(rot_up[:2]))

  def _reward_torques(self, torques: jax.Array) -> jax.Array:
    return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques))

  def _reward_action_rate(
      self, act: jax.Array, last_act: jax.Array
  ) -> jax.Array:
    return jp.sum(jp.square(act - last_act))

  def _reward_tracking_lin_vel(
      self, commands: jax.Array, x: Transform, xd: Motion
  ) -> jax.Array:
    local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))
    lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
    lin_vel_reward = jp.exp(
        -lin_vel_error / self.reward_config.rewards.tracking_sigma
    )
    return lin_vel_reward

  def _reward_tracking_ang_vel(
      self, commands: jax.Array, x: Transform, xd: Motion
  ) -> jax.Array:
    base_ang_vel = math.rotate(xd.ang[0], math.quat_inv(x.rot[0]))
    ang_vel_error = jp.square(commands[2] - base_ang_vel[2])
    return jp.exp(-ang_vel_error / self.reward_config.rewards.tracking_sigma)

  def _reward_feet_air_time(
      self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array
  ) -> jax.Array:
    rew_air_time = jp.sum((air_time - 0.1) * first_contact)
    rew_air_time *= (
        math.normalize(commands[:2])[1] > 0.05
    )  # no reward for zero command
    return rew_air_time

  def _reward_stand_still(
      self,
      commands: jax.Array,
      joint_angles: jax.Array,
  ) -> jax.Array:
    return jp.sum(jp.abs(joint_angles - self._default_pose)) * (
        math.normalize(commands[:2])[1] < 0.1
    )

  def _reward_foot_slip(
      self, pipeline_state: base.State, contact_filt: jax.Array
  ) -> jax.Array:
    pos = pipeline_state.site_xpos[self._feet_site_id]  # feet position
    feet_offset = pos - pipeline_state.xpos[self._lower_leg_body_id]
    offset = base.Transform.create(pos=feet_offset)
    foot_indices = self._lower_leg_body_id - 1  # we got rid of the world body
    foot_vel = offset.vmap().do(pipeline_state.xd.take(foot_indices)).vel

    return jp.sum(jp.square(foot_vel[:, :2]) * contact_filt.reshape((-1, 1)))

  def _reward_termination(self, done: jax.Array, step: jax.Array) -> jax.Array:
    return done & (step < 500)

  def render(
      self, trajectory: List[base.State], camera: str | None = None,
      width: int = 240, height: int = 320,
  ) -> Sequence[np.ndarray]:
    camera = camera or 'track'
    return super().render(trajectory, camera=camera, width=width, height=height)


envs.register_environment('barkour', BarkourEnv)
print("✅ BarkourEnv registered!")
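
The two tracking rewards in `BarkourEnv` share the same shape: an exponential of the negative squared error divided by `tracking_sigma`. The NumPy sketch below reproduces only that formula so its behavior can be inspected without running the simulator. The function name `tracking_reward` is hypothetical, and the sketch leaves out the rotation into the torso frame that the environment performs first.

```python
import math

import numpy as np

TRACKING_SIGMA = 0.25  # same value as tracking_sigma in get_config()


def tracking_reward(commanded, measured, sigma=TRACKING_SIGMA):
  """Returns exp(-squared_error / sigma); equals 1.0 when tracking is perfect."""
  error = np.sum(np.square(np.asarray(commanded) - np.asarray(measured)))
  return float(np.exp(-error / sigma))


perfect = tracking_reward([1.0, 0.0], [1.0, 0.0])
half_speed = tracking_reward([1.0, 0.0], [0.5, 0.0])
print(perfect, half_speed)
```

With a command of 1.0 m/s and a measured speed of 0.5 m/s, the squared error is 0.25, so the unscaled reward is `exp(-1)`, roughly 0.37. The environment then multiplies this by the `tracking_lin_vel` scale and by `dt`.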

## 6. Domain Randomization

In [None]:
#@title Domain Randomization Function

def domain_randomize(sys, rng):
  """Randomizes the mjx.Model for better sim-to-real transfer."""
  @jax.vmap
  def rand(rng):
    _, key = jax.random.split(rng, 2)
    # friction randomization
    friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)
    friction = sys.geom_friction.at[:, 0].set(friction)
    # actuator randomization
    _, key = jax.random.split(key, 2)
    gain_range = (-5, 5)
    param = jax.random.uniform(
        key, (1,), minval=gain_range[0], maxval=gain_range[1]
    ) + sys.actuator_gainprm[:, 0]
    gain = sys.actuator_gainprm.at[:, 0].set(param)
    bias = sys.actuator_biasprm.at[:, 1].set(-param)
    return friction, gain, bias

  friction, gain, bias = rand(rng)

  in_axes = jax.tree_util.tree_map(lambda x: None, sys)
  in_axes = in_axes.tree_replace({
      'geom_friction': 0,
      'actuator_gainprm': 0,
      'actuator_biasprm': 0,
  })

  sys = sys.tree_replace({
      'geom_friction': friction,
      'actuator_gainprm': gain,
      'actuator_biasprm': bias,
  })

  return sys, in_axes


print("✅ Domain randomization function defined!")
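
To make the sampled ranges concrete, here is a NumPy-only sketch of what `domain_randomize` draws for each environment. It is an illustration under stated assumptions, not the JAX code used in training: `NUM_ENVS_DEMO` and `NUM_GEOMS` are hypothetical sizes, and the friction columns other than column 0 are filled with placeholder ones.

```python
import numpy as np

NUM_ENVS_DEMO = 4    # hypothetical small batch; the notebook trains with 8192
NUM_GEOMS = 3        # hypothetical geom count, for shape illustration only
NOMINAL_GAIN = 35.0  # kp written into actuator_gainprm[:, 0] by BarkourEnv

rng = np.random.default_rng(0)

# One sliding-friction coefficient per environment, shared by all geoms.
friction = rng.uniform(0.6, 1.4, size=NUM_ENVS_DEMO)
geom_friction = np.ones((NUM_ENVS_DEMO, NUM_GEOMS, 3))  # placeholder values
geom_friction[:, :, 0] = friction[:, None]

# One gain offset per environment; the bias mirrors the gain with opposite sign,
# which is how a MuJoCo position actuator encodes kp.
gain = NOMINAL_GAIN + rng.uniform(-5.0, 5.0, size=NUM_ENVS_DEMO)
bias = -gain

print(geom_friction.shape, gain.round(2), bias.round(2))
```

Note that friction is set to an absolute value in [0.6, 1.4] rather than scaled from the model's default, while the actuator gain is an offset around the nominal 35.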

## 7. Create Environment

In [None]:
# Create environment
env_name = 'barkour'
env = envs.get_environment(env_name)

print(f"✅ Environment created!")
print(f"  - Observation size: {env.observation_size}")
print(f"  - Action size: {env.action_size}")
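
The observation size follows from `_get_obs`: each frame has 31 values (1 yaw rate, 3 projected gravity, 3 command, 12 joint angles, 12 last actions), and 15 frames are kept, giving 15 × 31 = 465. The NumPy sketch below mirrors the roll-and-overwrite stacking step. The helper name `push_observation` is hypothetical, and constant-valued frames stand in for real observations.

```python
import numpy as np

OBS_SIZE = 31     # 1 yaw rate + 3 gravity + 3 command + 12 angles + 12 last actions
HISTORY_LEN = 15  # frames kept in the stacked observation


def push_observation(history, obs):
  """Shifts the flat history back by one frame and writes obs at the front."""
  rolled = np.roll(history, obs.size)  # np.roll returns a copy
  rolled[:obs.size] = obs
  return rolled


history = np.zeros(OBS_SIZE * HISTORY_LEN)
for step in range(1, 4):
  # Stand-in frames filled with 1.0, 2.0, 3.0 so the ordering is visible.
  history = push_observation(history, np.full(OBS_SIZE, float(step)))

print(history.shape, history[0], history[OBS_SIZE], history[2 * OBS_SIZE])
```

The newest frame always occupies the first 31 entries, and the oldest frame drops off the end once the buffer is full.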

## 8. Training with Progress Tracking

In [None]:
#@title Train Policy with Progress Tracking

# Training configuration
NUM_TIMESTEPS = 100_000_000  # 100M steps (official tutorial setting)
NUM_EVALS = 10
EPISODE_LENGTH = 1000
NUM_ENVS = 8192  # Large batch for GPU efficiency
BATCH_SIZE = 256

# Progress tracking
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 40, 0

def progress(num_steps, metrics):
  """Callback to track and display training progress."""
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  plt.xlim([0, NUM_TIMESTEPS * 1.25])
  plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'Barkour Training | Reward: {y_data[-1]:.3f}')
  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.grid(True, alpha=0.3)
  plt.show()

  print(f"Steps: {num_steps:,} | Reward: {metrics['eval/episode_reward']:.2f}")

# Setup checkpoint saving
ckpt_path = epath.Path('/tmp/barkour_joystick/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  """Save checkpoints during training."""
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)
  print(f"💾 Checkpoint saved: {path}")

# Configure PPO training
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128, 128, 128, 128))

train_fn = functools.partial(
    ppo.train,
    num_timesteps=NUM_TIMESTEPS,
    num_evals=NUM_EVALS,
    reward_scaling=1,
    episode_length=EPISODE_LENGTH,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=20,
    num_minibatches=32,
    num_updates_per_batch=4,
    discounting=0.97,
    learning_rate=3.0e-4,
    entropy_cost=1e-2,
    num_envs=NUM_ENVS,
    batch_size=BATCH_SIZE,
    network_factory=make_networks_factory,
    randomization_fn=domain_randomize,
    policy_params_fn=policy_params_fn,
    seed=0,
    progress_fn=progress
)

# Create fresh training and evaluation environments, then start training
env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)

print(f"\n🚀 Starting training...")
print(f"  - Total steps: {NUM_TIMESTEPS:,}")
print(f"  - Parallel envs: {NUM_ENVS}")
print(f"  - Estimated time: ~6 min on A100, ~30-45 min on T4 GPU\n")

make_inference_fn, params, _ = train_fn(
    environment=env,
    eval_env=eval_env
)

print(f'\n🎉 Training completed!')
print(f'  - Time to JIT: {times[1] - times[0]}')
print(f'  - Time to train: {times[-1] - times[1]}')
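
The step counts in the configuration are easier to interpret as simulated time. The sketch below is plain arithmetic on constants copied from `BarkourEnv` and the training cell; it makes no assumptions about Brax internals.

```python
import math

CONTROL_DT = 0.02              # self._dt in BarkourEnv (50 Hz control)
PHYSICS_DT = 0.004             # opt.timestep override in BarkourEnv
EPISODE_LENGTH = 1000          # control steps per episode
COMMAND_RESAMPLE_STEPS = 500   # a new joystick command is drawn after this many steps
PUSH_INTERVAL = 10             # push_interval in step()
NUM_TIMESTEPS = 100_000_000

physics_substeps = round(CONTROL_DT / PHYSICS_DT)
episode_seconds = EPISODE_LENGTH * CONTROL_DT
command_seconds = COMMAND_RESAMPLE_STEPS * CONTROL_DT
push_seconds = PUSH_INTERVAL * CONTROL_DT
simulated_days = NUM_TIMESTEPS * CONTROL_DT / 86_400

print(f"physics substeps per control step: {physics_substeps}")
print(f"episode length: {episode_seconds:.1f} s")
print(f"command resampled about every {command_seconds:.1f} s")
print(f"random push every {push_seconds:.1f} s")
print(f"total simulated experience: {simulated_days:.1f} days")
```

So a full episode covers 20 seconds of simulated time with roughly one command change, and the 100M-step budget corresponds to about three weeks of simulated experience.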

## 9. Plot Training Progress

In [None]:
plt.figure(figsize=(12, 6))
# x_data and y_data are filled by the progress() callback in section 8.
plt.plot(x_data, y_data, 'b-', linewidth=2)
plt.xlabel('Training Steps', fontsize=12)
plt.ylabel('Episode Reward', fontsize=12)
plt.title('Barkour Training Progress', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()

# Save next to the checkpoint directory created in section 8.
plot_path = f"{ckpt_path.parent}/training_progress.png"
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
print(f"📊 Training plot saved: {plot_path}")
plt.show()

## 10. Save Trained Policy

In [None]:
# Save and reload final policy
model_path = '/tmp/mjx_brax_quadruped_policy'
model.save_params(model_path, params)

print(f"✅ Final policy saved: {model_path}")
print(f"\n📦 To download:")
print(f"1. Click Files icon (📁) in left sidebar")
print(f"2. Navigate to {ckpt_path} or {model_path}")
print(f"3. Right-click → Download")

# Reload params and create inference function
params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

print(f"\n✅ Policy reloaded and ready for visualization!")

## 11. Visualize Trained Policy

In [None]:
#@title Visualize Trained Policy

# @markdown **Joystick Commands:**
x_vel = 1.0  #@param {type: "number"}
y_vel = 0.0  #@param {type: "number"}
ang_vel = -0.5  #@param {type: "number"}

the_command = jp.array([x_vel, y_vel, ang_vel])

# Create evaluation environment
eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

# Initialize and set command
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]

# Run trajectory
n_steps = 500
render_every = 2

print(f"Running {n_steps} steps with command: x={x_vel}, y={y_vel}, ang={ang_vel}")
for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

print(f"✅ Simulation completed: {len(rollout)} frames")

# Render with MuJoCo renderer
media.show_video(
    eval_env.render(rollout[::render_every], camera='track'),
    fps=1.0 / eval_env.dt / render_every)

## 12. Alternative: Interactive HTML Rendering

In [None]:
#@title Alternative: Render with Brax HTML Renderer

# This provides an interactive 3D visualization
HTML(html.render(
    eval_env.sys.tree_replace({'opt.timestep': eval_env.dt}),
    rollout
))

## 13. Download All Files

In [None]:
#@title Package and Download All Results

# Create zip file with all outputs
!zip -r /tmp/barkour_training_results.zip {ckpt_path} {model_path}

print("\n📦 All results packaged: /tmp/barkour_training_results.zip")
print("\n✅ Download from Files panel (📁)")
print("\nContains:")
print(f"  - Final policy: {model_path}")
print(f"  - Checkpoints: {ckpt_path}")
print(f"\n💡 Use these files with your local run_barkour_local.py script!")

## Summary

### What was trained:
- **Robot**: Google Barkour VB quadruped from [MuJoCo Menagerie](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_vb)
- **Task**: Joystick control (forward/sideways/turning locomotion)
- **Algorithm**: PPO (Proximal Policy Optimization)
- **Training steps**: 100M (100 million environment steps)
- **Parallel environments**: 8192 (GPU-accelerated with MJX)
- **Domain randomization**: Friction coefficient sampled uniformly in [0.6, 1.4] + actuator gain offset sampled uniformly in ±5 around the nominal 35

### Files generated:
1. **`/tmp/mjx_brax_quadruped_policy`** - Final trained policy (main file)
2. **`/tmp/barkour_joystick/ckpts/`** - Intermediate checkpoints
3. **Training progress plot** - Reward vs. steps visualization
4. **Video demos** - Rendered robot locomotion (displayed inline in the notebook, not written to disk)

### Joystick Commands:
- **`x_vel`**: Forward/backward speed (-0.6 to 1.5 m/s)
- **`y_vel`**: Sideways speed (-0.8 to 0.8 m/s)
- **`ang_vel`**: Turning speed (-0.7 to 0.7 rad/s)
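
The policy only saw commands inside these ranges during training, so values outside them are out of distribution. One option, sketched here as an assumption rather than something the notebook does, is to clamp a requested command before passing it to the environment. The helper name `clamp_command` is hypothetical; the limits are copied from `sample_command`.

```python
import numpy as np

# Limits copied from BarkourEnv.sample_command: [x_vel, y_vel, ang_vel].
COMMAND_LOW = np.array([-0.6, -0.8, -0.7])
COMMAND_HIGH = np.array([1.5, 0.8, 0.7])


def clamp_command(x_vel, y_vel, ang_vel):
  """Clips a joystick command to the ranges sampled during training."""
  command = np.array([x_vel, y_vel, ang_vel], dtype=float)
  return np.clip(command, COMMAND_LOW, COMMAND_HIGH)


print(clamp_command(2.0, 0.0, -1.0))
```

In the visualization cell, the result could be wrapped with `jp.array(...)` before assigning it to `state.info['command']`.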

### Next steps:
1. **Download** the trained policy files from the Files panel (📁)
2. **Transfer** to your local machine: `c:\users\hatem\Desktop\MuJoCo\training_checkpoints\`
3. **Run locally** using your `run_barkour_local.py` script
4. **Experiment** with different commands and observe locomotion behaviors

### Based on:
This notebook is directly adapted from the [official MuJoCo MJX Tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb) which demonstrates training quadruped policies with Brax and MJX.