# Franka Emika Panda Manipulation Training on Google Colab

This notebook trains a PPO policy for the Franka Emika Panda robotic arm using Brax and MuJoCo MJX.

**Based on the official [MuJoCo MJX Tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb)**

## Setup Instructions:
1. **Runtime → Change runtime type → T4 GPU** (or A100 for faster training)
2. Run all cells in order
3. Training takes ~5-10 minutes on A100 GPU, ~20-30 minutes on T4 GPU
4. Download trained policy from Files panel

## Features:
- ✅ Robotic arm reaching task
- ✅ 7-DOF arm + 2-finger gripper (8 actuators)
- ✅ Domain randomization (friction + actuator parameters)
- ✅ PPO training with Brax
- ✅ Target reaching with end-effector
- ✅ Policy visualization and video export

## Task Description:
The Panda arm learns to reach a randomly placed target sphere with its end-effector (gripper).
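
The task has two ingredients: sampling a target and deciding when it has been reached. The NumPy sketch below illustrates both. It assumes the workspace box and the 5 cm threshold used by `FrankaPandaReachEnv.reset` and `step`. A target is drawn uniformly from the box, and a step counts as a success when the end-effector is strictly closer than the threshold. `sample_target` and `is_success` are illustrative helpers, not part of Brax or MJX.

```python
import numpy as np

# Bounds copied from FrankaPandaReachEnv.reset (metres, world frame).
TARGET_LOW = np.array([0.3, -0.3, 0.2])
TARGET_HIGH = np.array([0.7, 0.3, 0.6])
REACHING_THRESHOLD = 0.05  # 5 cm, the env's default reaching_threshold


def sample_target(rng: np.random.Generator) -> np.ndarray:
    """Draw one target position uniformly from the workspace box."""
    return rng.uniform(TARGET_LOW, TARGET_HIGH)


def is_success(ee_pos: np.ndarray, target_pos: np.ndarray) -> bool:
    """Success: end-effector strictly within the reaching threshold."""
    return bool(np.linalg.norm(ee_pos - target_pos) < REACHING_THRESHOLD)


rng = np.random.default_rng(0)
target = sample_target(rng)
print(target, is_success(target + np.array([0.03, 0.0, 0.0]), target))
```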

## 1. Install Dependencies

In [None]:
!pip install mujoco mujoco_mjx brax
!pip install flax optax orbax-checkpoint
!pip install mediapy
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)

print("\n✅ All dependencies installed!")

## 2. Verify GPU and Configure Environment

In [None]:
#@title Check GPU and Configure Environment

import subprocess
import os

# Check GPU
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add Nvidia EGL driver config for rendering
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use EGL rendering (GPU)
os.environ['MUJOCO_GL'] = 'egl'

# XLA optimization for better performance
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Test MuJoCo installation
import mujoco
mujoco.MjModel.from_xml_string('<mujoco/>')

import jax
print(f"✅ GPU detected: {jax.devices()}")
print(f"✅ MuJoCo configured with EGL rendering")
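
Hand-writing the ICD file as a JSON string makes brace mistakes easy. In a plain (non-f) string, `{{` is written literally and yields invalid JSON. An alternative is to build the config as a dict, serialize it with the standard-library `json` module, and check that it round-trips before writing it to `NVIDIA_ICD_CONFIG_PATH`. The sketch below assumes the same file contents as the setup cell.

```python
import json

# Same contents as the NVIDIA EGL vendor (ICD) file written in the setup cell.
icd_config = {
    "file_format_version": "1.0.0",
    "ICD": {"library_path": "libEGL_nvidia.so.0"},
}
icd_text = json.dumps(icd_config, indent=4) + "\n"

# Sanity check: the text parses back to the same structure.
assert json.loads(icd_text) == icd_config
print(icd_text)

# In the notebook you would then write it out, e.g.:
# with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
#     f.write(icd_text)
```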

## 3. Download MuJoCo Menagerie Models

In [None]:
!git clone https://github.com/google-deepmind/mujoco_menagerie.git
print("\n✅ MuJoCo Menagerie downloaded!")

## 4. Import Libraries

In [None]:
import functools
from typing import Any, Dict, List, Sequence, Tuple
from datetime import datetime
from etils import epath

import jax
from jax import numpy as jp
import numpy as np
import matplotlib.pyplot as plt
import mediapy as media
from IPython.display import HTML

from ml_collections import config_dict
from flax.training import orbax_utils
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Motion, Transform
from brax.envs.base import PipelineEnv, State
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

# More legible numpy printing
np.set_printoptions(precision=3, suppress=True, linewidth=100)

print("✅ Libraries imported successfully!")

## 5. Define Franka Panda Reaching Environment

In [None]:
#@title Franka Panda Reaching Environment Definition

PANDA_ROOT_PATH = epath.Path('mujoco_menagerie/franka_emika_panda')


class FrankaPandaReachEnv(PipelineEnv):
  """Environment for training Franka Panda to reach a target with its end-effector."""

  def __init__(
      self,
      distance_reward_weight=10.0,
      ctrl_cost_weight=0.01,
      reaching_reward=100.0,
      reaching_threshold=0.05,  # 5cm distance
      target_range=0.3,  # Currently unused: reset() samples targets from a fixed box
      reset_noise_scale=0.01,
      **kwargs,
  ):
    # Wrap the Menagerie MJX scene and add a target marker. The target is a
    # mocap body: it has no joint, so qpos keeps only the robot's 9 entries
    # (and any keyframes in the scene stay valid), yet it can be moved at
    # every reset through data.mocap_pos.
    scene_path = PANDA_ROOT_PATH / 'mjx_scene.xml'
    xml = f"""
    <mujoco model="panda_reach">
      <include file="{scene_path.name}"/>

      <worldbody>
        <body name="target" mocap="true" pos="0.5 0 0.5">
          <geom name="target" type="sphere" size="0.03" rgba="1 0 0 0.5"
                contype="0" conaffinity="0"/>
          <site name="target_site" size="0.01"/>
        </body>
      </worldbody>
    </mujoco>
    """
    # Write the wrapper next to mjx_scene.xml so the include and mesh paths
    # resolve relative to the Menagerie directory, then load it from disk.
    reach_path = PANDA_ROOT_PATH / 'panda_reach.xml'
    reach_path.write_text(xml)
    mj_model = mujoco.MjModel.from_xml_path(reach_path.as_posix())

    # Optimize solver for manipulation
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._distance_reward_weight = distance_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._reaching_reward = reaching_reward
    self._reaching_threshold = reaching_threshold
    self._target_range = target_range
    self._reset_noise_scale = reset_noise_scale

    # Find important body/site indices
    self._target_body_id = mujoco.mj_name2id(
        mj_model, mujoco.mjtObj.mjOBJ_BODY, 'target')
    # Row of the target in data.mocap_pos
    self._target_mocap_id = int(mj_model.body_mocapid[self._target_body_id])
    # End-effector site. mj_name2id returns -1 for a missing name, which would
    # silently index the last site, so fail loudly instead.
    self._ee_site_id = mujoco.mj_name2id(
        mj_model, mujoco.mjtObj.mjOBJ_SITE, 'attachment_site')
    if self._ee_site_id < 0:
      raise ValueError(
          "Site 'attachment_site' not found; check the site names in the "
          "Menagerie Panda XML.")

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2, rng3 = jax.random.split(rng, 4)

    # Reset robot with small noise
    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    # Randomize target position in a fixed workspace box (metres)
    target_pos = jax.random.uniform(
        rng3, (3,),
        minval=jp.array([0.3, -0.3, 0.2]),
        maxval=jp.array([0.7, 0.3, 0.6])
    )

    data = self.pipeline_init(qpos, qvel)
    # The target is a mocap body (not part of qpos), so place it via mocap_pos.
    data = data.replace(
        mocap_pos=data.mocap_pos.at[self._target_mocap_id].set(target_pos))

    obs = self._get_obs(data)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'distance_reward': zero,
        'ctrl_cost': zero,
        'reaching_reward': zero,
        'distance_to_target': zero,
        'success': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # Get end-effector position (attachment_site on hand)
    ee_pos = data.site_xpos[self._ee_site_id]

    # Get target position (mocap bodies are positioned by mocap_pos)
    target_pos = data.mocap_pos[self._target_mocap_id]

    # Calculate distance to target
    distance = jp.linalg.norm(ee_pos - target_pos)

    # Distance reward (negative distance, so closer = higher reward)
    distance_reward = -self._distance_reward_weight * distance

    # Reaching bonus (if within threshold)
    reaching_reward = jp.where(
        distance < self._reaching_threshold,
        self._reaching_reward,
        0.0
    )

    # Control cost (penalize large actions)
    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    # Success indicator
    success = jp.where(distance < self._reaching_threshold, 1.0, 0.0)

    obs = self._get_obs(data)
    reward = distance_reward + reaching_reward - ctrl_cost
    # No terminal condition for reaching; episodes end at the time limit.
    # Keep done as an array matching the shape/dtype set in reset().
    done = jp.zeros_like(state.done)

    state.metrics.update(
        distance_reward=distance_reward,
        ctrl_cost=ctrl_cost,
        reaching_reward=reaching_reward,
        distance_to_target=distance,
        success=success,
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(self, data: mjx.Data) -> jp.ndarray:
    """Observes robot joint positions, velocities, and relative target position."""
    # Robot state (9 DOF: 7 arm + 2 gripper)
    robot_qpos = data.qpos[:9]
    robot_qvel = data.qvel[:9]

    # End-effector position
    ee_pos = data.site_xpos[self._ee_site_id]

    # Target position (read from mocap_pos so it is current right after reset)
    target_pos = data.mocap_pos[self._target_mocap_id]

    # Relative target position (goal - current)
    relative_target = target_pos - ee_pos

    return jp.concatenate([
        robot_qpos,
        robot_qvel,
        ee_pos,
        target_pos,
        relative_target,
    ])


envs.register_environment('franka_panda_reach', FrankaPandaReachEnv)
print("✅ FrankaPandaReachEnv registered!")
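
Brax's MJX pipeline appears to write the policy action directly into `data.ctrl`, and MuJoCo then clamps `ctrl` to each actuator's `ctrlrange`. PPO's actions are squashed to [-1, 1], so with position actuators the policy can only command joint targets in that range. A common remedy is to rescale actions affinely to the control range. The sketch below assumes you pass in per-actuator bounds, for example from `mj_model.actuator_ctrlrange`. `scale_action` is a hypothetical helper, written in NumPy for clarity; inside the environment you would apply the same arithmetic with `jp`. The example ranges are made up.

```python
import numpy as np


def scale_action(action: np.ndarray, ctrl_low: np.ndarray,
                 ctrl_high: np.ndarray) -> np.ndarray:
    """Map a policy action in [-1, 1] affinely onto [ctrl_low, ctrl_high] per actuator."""
    action = np.clip(action, -1.0, 1.0)
    return ctrl_low + 0.5 * (action + 1.0) * (ctrl_high - ctrl_low)


# Hypothetical bounds for two actuators (a joint and a gripper), illustration only.
ctrl_low = np.array([-2.9, 0.0])
ctrl_high = np.array([2.9, 255.0])

print(scale_action(np.array([0.0, 1.0]), ctrl_low, ctrl_high))
```

In `step`, you would apply the `jp` equivalent to `action` before `pipeline_step`. You would typically keep the control cost on the unscaled action so its weight stays comparable across actuators.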

## 6. Domain Randomization Function

In [None]:
#@title Domain Randomization Function

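NUM_ARM_ACTUATORS = 7  # Menagerie Panda: 7 arm actuators, then 1 gripper actuator

def domain_randomize(sys, rng):
  """Randomizes the mjx.Model for better sim-to-real transfer."""
  @jax.vmap
  def rand(rng):
    _, key = jax.random.split(rng, 2)
    # friction randomization: one sliding-friction coefficient shared by all geoms
    friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)
    friction = sys.geom_friction.at[:, 0].set(friction)
    # actuator randomization (arm position actuators only). The gripper is
    # driven by a tendon actuator whose biasprm[1] is not -gainprm[0], so the
    # gain/bias update below would corrupt it; leave it untouched.
    _, key = jax.random.split(key, 2)
    gain_range = (-5, 5)
    param = jax.random.uniform(
        key, (1,), minval=gain_range[0], maxval=gain_range[1]
    ) + sys.actuator_gainprm[:NUM_ARM_ACTUATORS, 0]
    gain = sys.actuator_gainprm.at[:NUM_ARM_ACTUATORS, 0].set(param)
    bias = sys.actuator_biasprm.at[:NUM_ARM_ACTUATORS, 1].set(-param)
    return friction, gain, bias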

  friction, gain, bias = rand(rng)

  in_axes = jax.tree_util.tree_map(lambda x: None, sys)
  in_axes = in_axes.tree_replace({
      'geom_friction': 0,
      'actuator_gainprm': 0,
      'actuator_biasprm': 0,
  })

  sys = sys.tree_replace({
      'geom_friction': friction,
      'actuator_gainprm': gain,
      'actuator_biasprm': bias,
  })

  return sys, in_axes


print("✅ Domain randomization function defined!")
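
The actuator randomization relies on the position-actuator convention: for an actuator with gain `kp`, `gainprm[0] = kp` and `biasprm[1] = -kp`. Shifting the gain therefore means shifting the bias by the opposite amount. The NumPy sketch below shows that update for one sampled environment. It assumes the 7 arm actuators come first and the gripper actuator last. The toy parameter arrays are hypothetical, not read from the Menagerie model; real `gainprm`/`biasprm` rows have 10 entries.

```python
import numpy as np

NUM_ARM_ACTUATORS = 7  # assumption: arm actuators first, gripper actuator last


def randomize_arm_gains(gainprm, biasprm, rng, gain_range=(-5.0, 5.0)):
    """Add one random offset to every arm kp, keeping biasprm[:, 1] == -kp."""
    gainprm, biasprm = gainprm.copy(), biasprm.copy()
    # One offset shared by all arm actuators, like the (1,)-shaped draw in JAX.
    offset = rng.uniform(*gain_range)
    kp = gainprm[:NUM_ARM_ACTUATORS, 0] + offset
    gainprm[:NUM_ARM_ACTUATORS, 0] = kp
    biasprm[:NUM_ARM_ACTUATORS, 1] = -kp
    return gainprm, biasprm


# Hypothetical toy parameters: 7 position actuators plus a gripper actuator.
gainprm = np.zeros((8, 3))
biasprm = np.zeros((8, 3))
gainprm[:7, 0], biasprm[:7, 1] = 4500.0, -4500.0
gainprm[7, 0], biasprm[7, 1] = 0.0157, -100.0

new_gain, new_bias = randomize_arm_gains(gainprm, biasprm, np.random.default_rng(0))
print(new_gain[:, 0], new_bias[:, 1])
```

Consider the scale. If the arm gains are in the thousands, a ±5 shift is well under 1%, so you may want a wider or multiplicative range for meaningful randomization.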

## 7. Create Environment

In [None]:
# Create environment
env_name = 'franka_panda_reach'
env = envs.get_environment(env_name)

print(f"✅ Environment created!")
print(f"  - Observation size: {env.observation_size}")
print(f"  - Action size: {env.action_size}")
print(f"  - DOF: {env.sys.nq} position, {env.sys.nv} velocity")

## 8. Visualize Untrained Policy (Optional)

In [None]:
#@title Quick Test: Visualize Random Actions

jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Initialize
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# Grab a short trajectory with random actions
for i in range(100):
  ctrl = 0.1 * jax.random.normal(jax.random.PRNGKey(i), (env.sys.nu,))
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

# Default free camera; pass camera='<name>' only for a camera defined in the scene XML
media.show_video(env.render(rollout), fps=1.0 / env.dt)
print("\n⚠️ This is WITHOUT training - random arm movements!")

## 9. Train Policy with Progress Tracking

In [None]:
#@title Train Policy with Progress Tracking

# Training configuration (manipulation tasks learn faster than locomotion)
NUM_TIMESTEPS = 10_000_000  # 10M steps for reaching task
NUM_EVALS = 5
EPISODE_LENGTH = 500  # Shorter episodes for reaching
NUM_ENVS = 4096  # Good for manipulation
BATCH_SIZE = 512

# Progress tracking
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
# No fixed y-limits: the return can range from a few thousand below zero (far
# from the target all episode) up to ~50,000 (at the target, +100 per step),
# so let matplotlib autoscale the y-axis.

def progress(num_steps, metrics):
  """Callback to track and display training progress."""
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  plt.xlim([0, NUM_TIMESTEPS * 1.25])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'Franka Panda Reaching | Reward: {y_data[-1]:.3f}')
  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.grid(True, alpha=0.3)
  plt.show()

  print(f"Steps: {num_steps:,} | Reward: {metrics['eval/episode_reward']:.2f}")

# Setup checkpoint saving
ckpt_path = epath.Path('/tmp/franka_panda_reach/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  """Save checkpoints during training."""
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)
  print(f"💾 Checkpoint saved: {path}")

# Configure PPO training (standard network for manipulation)
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128, 128, 128))  # Standard size for manipulation

train_fn = functools.partial(
    ppo.train,
    num_timesteps=NUM_TIMESTEPS,
    num_evals=NUM_EVALS,
    reward_scaling=0.1,
    episode_length=EPISODE_LENGTH,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=24,
    num_updates_per_batch=8,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-3,
    num_envs=NUM_ENVS,
    batch_size=BATCH_SIZE,
    network_factory=make_networks_factory,
    randomization_fn=domain_randomize,
    policy_params_fn=policy_params_fn,
    seed=0,
    progress_fn=progress
)

# Reset environment and start training
env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)

print(f"\n🚀 Starting training...")
print(f"  - Total steps: {NUM_TIMESTEPS:,}")
print(f"  - Parallel envs: {NUM_ENVS}")
print(f"  - Estimated time: ~5-10 min on A100, ~20-30 min on T4 GPU")
print(f"  - Task: Reach randomly placed target with end-effector\n")

make_inference_fn, params, _ = train_fn(
    environment=env,
    eval_env=eval_env
)

print(f'\n🎉 Training completed!')
print(f'  - Time to JIT: {times[1] - times[0]}')
print(f'  - Time to train: {times[-1] - times[1]}')
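
The next sketch shows how the PPO settings translate into environment steps, based on our reading of `brax.training.agents.ppo.train` (worth re-checking against your installed Brax version). Each training step collects `batch_size * num_minibatches` rollouts of `unroll_length` steps. Brax asserts that `batch_size * num_minibatches` is divisible by `num_envs`, and the step budget is rounded up to whole training steps per evaluation, so the run slightly overshoots `NUM_TIMESTEPS`. `ppo_step_budget` is a hypothetical helper, and it assumes the default `num_resets_per_eval=0`.

```python
import math


def ppo_step_budget(num_timesteps, num_evals, num_envs, batch_size,
                    num_minibatches, unroll_length, action_repeat=1):
    """Approximate Brax PPO step accounting (num_resets_per_eval assumed 0)."""
    assert batch_size * num_minibatches % num_envs == 0, "Brax requires this divisibility"
    env_steps_per_training_step = (
        batch_size * unroll_length * num_minibatches * action_repeat)
    num_evals_after_init = max(num_evals - 1, 1)
    training_steps_per_eval = math.ceil(
        num_timesteps / (num_evals_after_init * env_steps_per_training_step))
    total_env_steps = (
        num_evals_after_init * training_steps_per_eval * env_steps_per_training_step)
    return env_steps_per_training_step, training_steps_per_eval, total_env_steps


print(ppo_step_budget(10_000_000, 5, 4096, 512, 24, 10))  # → (122880, 21, 10321920)
```

With this notebook's settings, that works out to 4 × 21 = 84 training steps, or roughly 10.3M environment steps.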

## 10. Plot Training Progress

In [None]:
plt.figure(figsize=(12, 6))
plt.plot(x_data, y_data, 'b-', linewidth=2)
plt.fill_between(x_data, 
                 [y - err for y, err in zip(y_data, ydataerr)],
                 [y + err for y, err in zip(y_data, ydataerr)],
                 alpha=0.3)
plt.xlabel('Training Steps', fontsize=12)
plt.ylabel('Episode Reward', fontsize=12)
plt.title('Franka Panda Reaching Training Progress', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 11. Save Trained Policy

In [None]:
# Save and reload params
model_path = '/tmp/mjx_brax_franka_panda_policy'
model.save_params(model_path, params)

print(f"✅ Final policy saved: {model_path}")
print(f"\n📦 To download:")
print(f"1. Click Files icon (📁) in left sidebar")
print(f"2. Navigate to {ckpt_path} or {model_path}")
print(f"3. Right-click → Download")

# Reload params and create inference function
params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

print(f"\n✅ Policy reloaded and ready for visualization!")
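
The checkpoint callback names each directory after the training step (`ckpt_path / f'{current_step}'`). Picking the most recent one therefore means comparing those names as integers; sorted as strings, `'2580480'` would come after `'10321920'`. The standard-library sketch below uses a hypothetical helper, `latest_checkpoint`. Restoring the result could then go through Orbax, e.g. `ocp.PyTreeCheckpointer().restore(path)`, but that last step is an assumption to verify against your Orbax version.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(ckpt_dir: Path) -> Optional[Path]:
    """Return the step-named subdirectory with the largest step, or None."""
    step_dirs = [p for p in Path(ckpt_dir).iterdir()
                 if p.is_dir() and p.name.isdigit()]
    # Compare numerically: as strings, '2580480' > '10321920'.
    return max(step_dirs, key=lambda p: int(p.name), default=None)

# Usage in the notebook (assumption): latest_checkpoint(Path(str(ckpt_path)))
```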

## 12. Visualize Trained Policy

In [None]:
#@title Visualize Trained Policy

# Create evaluation environment
eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

# Initialize
rng = jax.random.PRNGKey(42)
state = jit_reset(rng)
rollout = [state.pipeline_state]

# Run trajectory
n_steps = 250
render_every = 2

print(f"Running {n_steps} steps with trained policy...")
for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
  
  # Every 50 steps, report whether the end-effector is within the reaching threshold
  if state.metrics['success'] > 0.5 and i % 50 == 0:
    print(f"  ✅ Target reached at step {i}! Distance: {state.metrics['distance_to_target']:.3f}m")

print(f"✅ Simulation completed: {len(rollout)} frames")

# Render with MuJoCo renderer (default free camera; pass camera='<name>' only for a camera defined in the scene XML)
media.show_video(
    eval_env.render(rollout[::render_every]),
    fps=1.0 / eval_env.dt / render_every)

## 13. Alternative: Render with Brax HTML Renderer

In [None]:
#@title Alternative: Render with Brax HTML Renderer

# This provides an interactive 3D visualization
HTML(html.render(
    eval_env.sys.tree_replace({'opt.timestep': eval_env.dt}),
    rollout
))

## 14. Test Multiple Random Targets

In [None]:
#@title Test Multiple Random Targets

# Test success rate across multiple random targets
# (success = within the reaching threshold at the final step of the episode)
num_tests = 10
success_count = 0
distances = []

print(f"Testing policy on {num_tests} random targets...\n")

for test_id in range(num_tests):
  rng = jax.random.PRNGKey(test_id + 100)
  state = jit_reset(rng)
  
  # Run episode
  for i in range(EPISODE_LENGTH):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
  
  final_distance = float(state.metrics['distance_to_target'])
  success = float(state.metrics['success']) > 0.5
  
  distances.append(final_distance)
  if success:
    success_count += 1
  
  status = "✅ SUCCESS" if success else "❌ FAILED"
  print(f"Test {test_id + 1}: {status} | Final distance: {final_distance:.4f}m")

print(f"\n📊 Results:")
print(f"  - Success rate: {success_count}/{num_tests} ({100*success_count/num_tests:.1f}%)")
print(f"  - Average distance: {np.mean(distances):.4f}m")
print(f"  - Min distance: {np.min(distances):.4f}m")
print(f"  - Max distance: {np.max(distances):.4f}m")
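
The loop above scores each test only by the state after the last step. A policy that reaches the target and then drifts away therefore counts as a failure, while one that brushes past at the final step counts as a success. The NumPy sketch below reports both views from a per-step distance trace, which could be collected by appending `float(state.metrics['distance_to_target'])` inside the loop. `episode_success` is a hypothetical helper, and the trace is made up.

```python
import numpy as np


def episode_success(distances, threshold=0.05):
    """Return (reached at any step, within threshold at final step, first success step or -1)."""
    within = np.asarray(distances) < threshold
    reached = bool(within.any())
    first_step = int(np.argmax(within)) if reached else -1
    return reached, bool(within[-1]), first_step


# Made-up trace: the end-effector reaches the target, then drifts away.
trace = [0.40, 0.20, 0.04, 0.03, 0.07]
print(episode_success(trace))  # → (True, False, 2)
```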

## 15. Package and Download All Results

In [None]:
#@title Package and Download All Results

# Create zip file with all outputs
!zip -r /tmp/franka_panda_training_results.zip {ckpt_path} {model_path}

print("\n📦 All results packaged: /tmp/franka_panda_training_results.zip")
print("\n✅ Download from Files panel (📁)")
print("\nContains:")
print(f"  - Final policy: {model_path}")
print(f"  - Checkpoints: {ckpt_path}")
print(f"\n💡 Use these files with a local inference script!")

## Summary

### What was trained:
- **Robot**: Franka Emika Panda from [MuJoCo Menagerie](https://github.com/google-deepmind/mujoco_menagerie/tree/main/franka_emika_panda)
- **Task**: Reaching task - move end-effector to randomly placed target sphere
- **Algorithm**: PPO (Proximal Policy Optimization)
- **Training steps**: 10M (10 million environment steps)
- **Parallel environments**: 4096 (GPU-accelerated with MJX)
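- **Domain randomization**: Sliding friction (coefficient sampled uniformly in [0.6, 1.4], shared by all geoms) + arm actuator gains (one offset in [-5, 5] added to the position gains)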

### Robot Specifications:
- **Type**: 7-DOF robotic manipulator + 2-finger parallel gripper
- **Total DOF**: 9 joints (7 arm + 2 finger), driven by 8 actuators (7 arm + 1 shared gripper actuator)
- **Workspace**: ~0.85m reach radius
- **Observation space**: 27 dimensions
  - Joint positions (9)
  - Joint velocities (9)
  - End-effector position (3)
  - Target position (3)
  - Relative target vector (3)
- **Action space**: 8 dimensions (position control)
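
`_get_obs` concatenates the 27 observation dimensions in a fixed order. When debugging a logged observation, it helps to slice it back into named segments. The small sketch below uses a hypothetical `obs_slices` helper and assumes the order listed above, which matches the environment code. These are raw observations: with `normalize_observations=True`, Brax normalizes inside the policy network, not in the environment.

```python
import numpy as np

# (name, size) in the order FrankaPandaReachEnv._get_obs concatenates them.
OBS_LAYOUT = [
    ("joint_pos", 9),
    ("joint_vel", 9),
    ("ee_pos", 3),
    ("target_pos", 3),
    ("relative_target", 3),
]


def obs_slices(layout):
    """Return ({name: slice}, total_size) for a list of (name, size) segments."""
    slices, start = {}, 0
    for name, size in layout:
        slices[name] = slice(start, start + size)
        start += size
    return slices, start


SLICES, OBS_SIZE = obs_slices(OBS_LAYOUT)
obs = np.arange(OBS_SIZE, dtype=float)  # stand-in for a real observation
print(OBS_SIZE, obs[SLICES["relative_target"]])
```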

### Reward Function:
- **Distance reward**: -10 × distance (closer = higher reward)
- **Reaching bonus**: +100 per step while within 5 cm of the target
- **Control cost**: -0.01 × sum of squared actions
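
The worked sketch below computes the per-step reward with the default weights. It uses NumPy and mirrors the arithmetic in `FrankaPandaReachEnv.step`; `step_reward` is a hypothetical helper. It also shows why the reaching bonus dominates. Near the target a single step is worth about +100, so an episode's return can approach 500 × 100 = 50,000. An episode spent 0.5 m away scores about -2,500.

```python
import numpy as np

DISTANCE_REWARD_WEIGHT = 10.0
CTRL_COST_WEIGHT = 0.01
REACHING_REWARD = 100.0
REACHING_THRESHOLD = 0.05  # metres


def step_reward(ee_pos, target_pos, action):
    """Distance penalty + reaching bonus - control cost, as in the environment's step()."""
    distance = np.linalg.norm(np.asarray(ee_pos) - np.asarray(target_pos))
    distance_reward = -DISTANCE_REWARD_WEIGHT * distance
    reaching_reward = REACHING_REWARD if distance < REACHING_THRESHOLD else 0.0
    ctrl_cost = CTRL_COST_WEIGHT * np.sum(np.square(action))
    return distance_reward + reaching_reward - ctrl_cost


zero_action = np.zeros(8)
far = step_reward([0.0, 0.0, 0.0], [0.3, 0.4, 0.0], zero_action)   # 0.5 m away
near = step_reward([0.5, 0.0, 0.4], [0.52, 0.0, 0.4], np.ones(8))  # 2 cm away, unit actions
print(far, near)
```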

### Files generated:
1. **`/tmp/mjx_brax_franka_panda_policy`** - Final trained policy
2. **`/tmp/franka_panda_reach/ckpts/`** - Intermediate checkpoints
3. **Training progress plot** - Reward vs. steps visualization
4. **Video demos** - Rendered reaching behavior

### Training Notes:
- **Manipulation tasks** typically need fewer steps than locomotion (10M here vs. 20M-100M in the locomotion examples)
- Reaching is a **standard benchmark** for robotic manipulation
- Target position is **randomized** each episode for generalization
- The policy learns to reach from any starting configuration
- Domain randomization helps with sim-to-real transfer

### Expected Behavior:
With 10M training steps, the following is expected (not guaranteed; measure it with the evaluation in Section 14):
- ✅ Consistent reaching to random targets
- ✅ Success rate: 80-95%
- ✅ Smooth, efficient trajectories
- ✅ Generalizes to unseen target positions

For even better performance, train for 20M+ steps (~10-15 min on A100).

### Differences from Locomotion:
| Aspect | Locomotion (Barkour/G1) | Manipulation (Panda) |
|--------|-------------------------|----------------------|
| Task | Walking/running | Reaching target |
| Episode length | 1000 steps | 500 steps |
| Training steps | 20M-100M | 10M-20M |
| Success metric | Distance traveled | Target reached |
| Observation | Body pose, velocities | Joint state, target position |
| Termination | Falling | None (continuous) |

### Next Steps:
- **Object manipulation**: Add objects to grasp and move
- **Multi-target**: Reach sequence of targets
- **Obstacle avoidance**: Add obstacles in workspace
- **Contact-rich tasks**: Push, slide, insert objects

### Based on:
- [Official MuJoCo MJX Tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb)
- [Franka Emika Panda](https://www.franka.de/) robot, MJCF model from MuJoCo Menagerie
- Standard robotic reaching task from RL literature