In [1]:
#@title Setup and Installation
%%capture
!pip install mujoco
!pip install mujoco_mjx
!pip install ml_collections
!git clone https://github.com/Itssshikhar/mujoco_playground.git
!cd mujoco_playground && pip install -e .

import sys
sys.path.append('mujoco_playground')

In [2]:
!pip install playground

Collecting playground
  Using cached playground-0.0.3-py3-none-any.whl.metadata (8.0 kB)
Collecting brax>=0.12.1 (from playground)
  Downloading brax-0.12.1-py3-none-any.whl.metadata (7.7 kB)
Collecting dm_env (from brax>=0.12.1->playground)
  Downloading dm_env-1.6-py3-none-any.whl.metadata (966 bytes)
Collecting flask_cors (from brax>=0.12.1->playground)
  Downloading Flask_Cors-5.0.0-py2.py3-none-any.whl.metadata (5.5 kB)
Collecting jaxopt (from brax>=0.12.1->playground)
  Downloading jaxopt-0.8.3-py3-none-any.whl.metadata (2.6 kB)
Collecting pytinyrenderer (from brax>=0.12.1->playground)
  Downloading pytinyrenderer-0.0.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting tensorboardX (from brax>=0.12.1->playground)
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Downloading playground-0.0.3-py3-none-any.whl (5.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m60.6 MB/s[0m eta [3

In [3]:
#@title Import Dependencies and Set Environment
# (The title line previously started with `!`, which made IPython hand it to
# the shell instead of treating it as a cell-title comment.)
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, HTML

# Configure logging for Colab: INFO and above, timestamped.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

In [4]:
#@title Download and Prepare Training Script
%%writefile train_zbot.py

#######################
# Setup & Dependencies
#######################

import argparse
import logging
import pickle
from datetime import datetime
from pathlib import Path

import cv2
import jax
import matplotlib.pyplot as plt
import numpy as np
from ml_collections import config_dict
from playground.zbot import joystick as zbot_joystick
from playground.zbot import randomize as zbot_randomize
from playground.zbot import zbot_constants
from playground.runner import ZBotRunner

# Logging setup: INFO-level records are sent to the console and appended to
# 'zbot_training.log' in the current working directory.
# Note: logging.basicConfig() does nothing when the root logger already has
# handlers configured (e.g. when this module is imported after another
# basicConfig call).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(), logging.FileHandler('zbot_training.log')],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

########################
# Training Configuration
########################

def create_training_args(task="flat_terrain", load_existing=False):
    """Build the argument namespace handed to the training runner.

    Args:
        task: Stored as the ``task`` attribute of the namespace.
        load_existing: Stored as the ``load_model`` attribute.

    Returns:
        argparse.Namespace with the attributes ``env``, ``task``, ``debug``,
        ``save_model``, ``load_model``, ``seed``, ``num_episodes``,
        ``episode_length``, ``x_vel``, ``y_vel`` and ``yaw_vel``; everything
        except ``task`` and ``load_model`` is a fixed default.

    NOTE(review): ``env`` is always "ZbotJoystickFlatTerrain", whatever ``task``
    is -- confirm that this is what the runner expects for rough terrain.
    """
    settings = {
        "env": "ZbotJoystickFlatTerrain",
        "task": task,
        "debug": False,
        "save_model": True,
        "load_model": load_existing,
        "seed": 42,
        "num_episodes": 3,
        "episode_length": 3000,
        # Commanded velocities: forward, lateral, yaw.
        "x_vel": 1.0,
        "y_vel": 0.0,
        "yaw_vel": 0.0,
    }
    return argparse.Namespace(**settings)

def plot_training_progress(runner, title):
    """Save a reward-vs-steps curve with a +/- error band as a PNG.

    Args:
        runner: Object exposing ``x_data``, ``y_data`` and ``y_dataerr``
            sequences of equal length.
        title: Human-readable title; also used (lower-cased, spaces replaced
            by underscores) to build the output file name
            ``<title>_progress.png`` in the current working directory.

    The figure is closed after saving, so nothing remains open for display.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    mean_reward = np.array(runner.y_data)
    spread = np.array(runner.y_dataerr)

    ax.plot(runner.x_data, runner.y_data, label='Mean Reward')
    ax.fill_between(
        runner.x_data,
        mean_reward - spread,
        mean_reward + spread,
        alpha=0.2,
        label='Std Dev',
    )

    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Episode Reward')
    ax.set_title(f'Training Progress: {title}')
    ax.grid(True)
    ax.legend()

    file_stem = title.lower().replace(" ", "_")
    fig.savefig(f'{file_stem}_progress.png')
    plt.close(fig)

def save_training_metrics(runner, filename):
    """Pickle the runner's training curves and wall-clock duration to ``filename``.

    Args:
        runner: Object exposing ``x_data``, ``y_data``, ``y_dataerr`` and a
            non-empty ``times`` sequence whose elements subtract to a
            ``timedelta``.
        filename: Destination path; overwritten if it exists.

    The pickled dict has the keys ``steps``, ``rewards``, ``reward_std`` and
    ``training_time`` (seconds between the first and last entry of ``times``).
    """
    metrics = dict(
        steps=runner.x_data,
        rewards=runner.y_data,
        reward_std=runner.y_dataerr,
    )
    elapsed = runner.times[-1] - runner.times[0]
    metrics['training_time'] = elapsed.total_seconds()

    with open(filename, 'wb') as sink:
        pickle.dump(metrics, sink)

#############################
# Flat Terrain Training Phase
#############################

def train_flat_terrain():
    """Run the flat-terrain phase: train, report, save artifacts, evaluate.

    Builds its own arguments via ``create_training_args(task="flat_terrain",
    load_existing=False)``, constructs a ``ZBotRunner`` from them and calls its
    ``train()``. Afterwards it logs summary statistics, writes the progress
    plot and ``flat_terrain_metrics.pkl`` to the current working directory,
    and calls the runner's ``evaluate()``.

    Returns:
        The ``ZBotRunner`` instance used for training.
    """
    banner = "=" * 50
    for line in (banner, "Starting flat terrain training phase", banner):
        logger.info(line)

    # Configuration for this phase, echoed to the log one field per line.
    args = create_training_args(task="flat_terrain", load_existing=False)
    logger.info("Training configuration:")
    for name, setting in vars(args).items():
        logger.info(f"  {name}: {setting}")

    runner = ZBotRunner(args, logger)

    logger.info("Beginning training loop...")
    runner.train()

    # Summary of the finished run.
    logger.info("Training completed. Final statistics:")
    logger.info(f"  Total steps: {len(runner.x_data)}")
    last_reward, last_spread = runner.y_data[-1], runner.y_dataerr[-1]
    logger.info(f"  Final reward: {last_reward:.2f} ± {last_spread:.2f}")
    elapsed_seconds = (runner.times[-1] - runner.times[0]).total_seconds()
    logger.info(f"  Training time: {elapsed_seconds:.2f}s")

    # Persist the learning curve and raw metrics.
    logger.info("Saving training visualizations and metrics...")
    plot_training_progress(runner, "Flat Terrain Training")
    save_training_metrics(runner, "flat_terrain_metrics.pkl")

    logger.info("Starting flat terrain policy evaluation...")
    runner.evaluate()

    return runner

##############################
# Rough Terrain Training Phase
##############################

def train_rough_terrain(flat_terrain_runner):
    """Run the rough-terrain phase, starting from the flat-terrain parameters.

    Builds arguments via ``create_training_args(task="rough_terrain",
    load_existing=True)``, constructs a new ``ZBotRunner``, copies ``params``
    from ``flat_terrain_runner`` onto it and calls ``train()``. Afterwards it
    logs summary statistics, writes the progress plot and
    ``rough_terrain_metrics.pkl`` to the current working directory, and calls
    the runner's ``evaluate()``.

    Args:
        flat_terrain_runner: Runner from the flat-terrain phase; only its
            ``params`` attribute is read.

    Returns:
        The new ``ZBotRunner`` instance used for the rough-terrain phase.
    """
    banner = "=" * 50
    for line in (banner, "Starting rough terrain adaptation phase", banner):
        logger.info(line)

    # Configuration for this phase, echoed to the log one field per line.
    args = create_training_args(task="rough_terrain", load_existing=True)
    logger.info("Training configuration:")
    for name, setting in vars(args).items():
        logger.info(f"  {name}: {setting}")

    runner = ZBotRunner(args, logger)

    # Seed the new runner with the parameters learned on flat terrain.
    logger.info("Loading pre-trained flat terrain policy...")
    runner.params = flat_terrain_runner.params

    logger.info("Beginning rough terrain adaptation...")
    runner.train()

    # Summary of the finished run.
    logger.info("Adaptation completed. Final statistics:")
    logger.info(f"  Total steps: {len(runner.x_data)}")
    last_reward, last_spread = runner.y_data[-1], runner.y_dataerr[-1]
    logger.info(f"  Final reward: {last_reward:.2f} ± {last_spread:.2f}")
    elapsed_seconds = (runner.times[-1] - runner.times[0]).total_seconds()
    logger.info(f"  Training time: {elapsed_seconds:.2f}s")

    # Persist the learning curve and raw metrics.
    logger.info("Saving training visualizations and metrics...")
    plot_training_progress(runner, "Rough Terrain Training")
    save_training_metrics(runner, "rough_terrain_metrics.pkl")

    logger.info("Starting rough terrain policy evaluation...")
    runner.evaluate()

    return runner

#######################
# Analysis & Evaluation
#######################

def analyze_performance(flat_metrics, rough_metrics):
    """Log a summary of both training phases and save a comparison plot.

    Args:
        flat_metrics: Dict with the keys ``steps``, ``rewards``, ``reward_std``
            and ``training_time`` for the flat-terrain phase.
        rough_metrics: Same structure for the rough-terrain phase.

    Writes ``training_comparison.png`` to the current working directory and
    closes the figure afterwards.
    """
    banner = "=" * 50
    for line in (banner, "Performance Analysis", banner):
        logger.info(line)

    # (legend label, metrics dict, plot colour) for each phase, in report order.
    phases = (
        ("Flat Terrain", flat_metrics, "blue"),
        ("Rough Terrain", rough_metrics, "red"),
    )

    logger.info("Training Summary:")
    for label, metrics, _ in phases:
        logger.info(f"{label}:")
        logger.info(f"  Training time: {metrics['training_time']:.2f}s")
        logger.info(f"  Final reward: {metrics['rewards'][-1]:.2f} ± {metrics['reward_std'][-1]:.2f}")
        logger.info(f"  Peak reward: {max(metrics['rewards']):.2f}")

    logger.info("Generating performance comparison plot...")
    plt.figure(figsize=(12, 6))

    # One curve plus shaded +/- band per phase.
    for label, metrics, colour in phases:
        mean_reward = np.array(metrics['rewards'])
        spread = np.array(metrics['reward_std'])
        plt.plot(metrics['steps'], metrics['rewards'], label=label, color=colour)
        plt.fill_between(
            metrics['steps'],
            mean_reward - spread,
            mean_reward + spread,
            alpha=0.2,
            color=colour,
        )

    plt.xlabel('Training Steps')
    plt.ylabel('Episode Reward')
    plt.title('Training Progress Comparison')
    plt.grid(True)
    plt.legend()
    plt.savefig('training_comparison.png')
    plt.close()

##############
# Main Script
##############

def main():
    """Run the full pipeline: flat-terrain training, rough-terrain adaptation, analysis.

    Side effects:
        Creates ``outputs/`` under the current working directory, runs both
        training phases, reloads the metrics pickles they wrote, produces the
        comparison plot, and finally moves the generated plots and metrics
        into ``outputs/`` so the closing log message points at the right place.
        (Previously the directory was created but nothing was ever put in it.)

    Raises:
        Exception: any error raised by the pipeline is logged with its
            traceback and then re-raised unchanged.
    """
    logger.info("=" * 50)
    logger.info("Starting ZBot Training Pipeline")
    logger.info("=" * 50)

    # Create output directory
    output_dir = Path("outputs")
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Created output directory: {output_dir}")

    # Files the training/analysis helpers write to the working directory.
    artifact_names = (
        "flat_terrain_training_progress.png",
        "flat_terrain_metrics.pkl",
        "rough_terrain_training_progress.png",
        "rough_terrain_metrics.pkl",
        "training_comparison.png",
    )

    try:
        # Train on flat terrain
        logger.info("Starting flat terrain training phase...")
        flat_runner = train_flat_terrain()

        # Train on rough terrain
        logger.info("Starting rough terrain adaptation phase...")
        rough_runner = train_rough_terrain(flat_runner)

        # Load and analyze results.
        # These pickles were written by this same process moments ago; do not
        # point this at files from an untrusted source.
        logger.info("Loading training metrics for analysis...")
        with open("flat_terrain_metrics.pkl", 'rb') as f:
            flat_metrics = pickle.load(f)
        with open("rough_terrain_metrics.pkl", 'rb') as f:
            rough_metrics = pickle.load(f)

        analyze_performance(flat_metrics, rough_metrics)

        # Collect the artifacts in the output directory. Path.replace
        # overwrites a same-named file left over from a previous run.
        for artifact_name in artifact_names:
            artifact = Path(artifact_name)
            if artifact.exists():
                artifact.replace(output_dir / artifact.name)

        logger.info("Training pipeline completed successfully!")
        logger.info("Check the outputs directory for results and visualizations.")

    except Exception as e:
        logger.error(f"An error occurred during training: {str(e)}", exc_info=True)
        raise

if __name__ == "__main__":
    main()

Writing train_zbot.py


In [5]:
#@title Verify Installation
# Recursively list the cloned checkout so its layout can be inspected.
!ls -R mujoco_playground
# Try importing `mujoco_playground._src.mjx_env` in a separate Python process;
# 'Import successful!' is printed only if the import does not raise.
!python -c "from mujoco_playground._src import mjx_env; print('Import successful!')"

mujoco_playground:
kscale_mujoco_playground.egg-info  LICENSE   playground      README.md	setup.py
learning			   Makefile  pyproject.toml  setup.cfg	train_zbot.py

mujoco_playground/kscale_mujoco_playground.egg-info:
dependency_links.txt  PKG-INFO	requires.txt  SOURCES.txt  top_level.txt

mujoco_playground/learning:
notebooks

mujoco_playground/learning/notebooks:
locomotion.ipynb

mujoco_playground/playground:
__init__.py  requirements-dev.txt  requirements.txt  resources	runner.py  zbot

mujoco_playground/playground/resources:
zbot

mujoco_playground/playground/resources/zbot:
assets	LICENSE  README.md  scene.xml  zbot.png  zbot.xml

mujoco_playground/playground/resources/zbot/assets:
3215_1Flange_2.stl     3215_BothFlange_5.stl  FOOT_2.stl	  U-HIP-R.stl
3215_1Flange.stl       3215_BothFlange_6.stl  FOOT.stl		  Z-BOT2_MASTER-BODY-SKELETON.stl
3215_BothFlange_2.stl  3215_BothFlange.stl    L-ARM-MIRROR_1.stl  Z-BOT2-MASTER-SHOULDER2_2.stl
3215_BothFlange_3.stl  FINGER_1_2.stl	      R-A

In [6]:
#@title Training Configuration
#@markdown Adjust training parameters here
NUM_EPISODES = 3  #@param {type:"integer"}
EPISODE_LENGTH = 3000  #@param {type:"integer"}
TASK = "flat_terrain"  #@param ["flat_terrain", "rough_terrain"]
LOAD_EXISTING = False  #@param {type:"boolean"}

from train_zbot import create_training_args, train_flat_terrain, train_rough_terrain

# Start from the script's default argument namespace for the selected task.
args = create_training_args(
    task=TASK,
    load_existing=LOAD_EXISTING
)
# Overlay the form values on the defaults.
args.num_episodes = NUM_EPISODES
args.episode_length = EPISODE_LENGTH
# NOTE(review): in the train_zbot.py written by this notebook,
# train_flat_terrain() and train_rough_terrain() take no `args` parameter and
# build their own namespace via create_training_args(), so the values set on
# `args` here do not appear to reach training -- confirm and wire through.

  and should_run_async(code)


In [7]:
#@title Check Repository Structure
# Print the notebook's current working directory.
!pwd
# Recursively list the zbot package of the checkout. The absolute path is
# hard-coded to /content, so this only works where the repository was cloned there.
!ls -R /content/mujoco_playground/playground/zbot/

/content
/content/mujoco_playground/playground/zbot/:
base.py  __init__.py  joystick.py  __pycache__	randomize.py  xmls  zbot_constants.py

/content/mujoco_playground/playground/zbot/__pycache__:
base.cpython-311.pyc	  joystick.cpython-311.pyc   zbot_constants.cpython-311.pyc
__init__.cpython-311.pyc  randomize.cpython-311.pyc

/content/mujoco_playground/playground/zbot/xmls:
assets				     scene_mjx_feetonly_rough_terrain.xml
scene_mjx_feetonly_flat_terrain.xml  zbot_feet_only.xml

/content/mujoco_playground/playground/zbot/xmls/assets:
hfield.png  rocky_texture.png


In [None]:
#@title Run Training
#@markdown Click to start training

# First, ensure we're in the correct directory
import os
os.chdir('/content/mujoco_playground')

# Add the repository root to Python path
import sys
sys.path.insert(0, '/content/mujoco_playground')

# Import after path setup
from playground.zbot import zbot_constants
from playground.runner import ZBotRunner
from train_zbot import create_training_args, train_flat_terrain, train_rough_terrain

# Verify XML file exists
xml_path = zbot_constants.task_to_xml(TASK)
print(f"Looking for XML file at: {xml_path}")
print(f"File exists: {os.path.exists(xml_path)}")

# Run training with proper error handling
try:
    if TASK == "flat_terrain":
        print("Starting flat terrain training...")
        runner = train_flat_terrain()
    else:
        print("Starting rough terrain training...")
        flat_runner = train_flat_terrain()
        runner = train_rough_terrain(flat_runner)

    # Display training progress
    display.display(plt.gcf())

except FileNotFoundError as e:
    print(f"Error: Could not find required files: {e}")
    print("Current working directory:", os.getcwd())
    print("\nContents of zbot directory:")
    !ls -R playground/zbot/
except Exception as e:
    print(f"An error occurred: {e}")
    import traceback
    traceback.print_exc()

Looking for XML file at: playground/zbot/xmls/scene_mjx_feetonly_flat_terrain.xml
File exists: True
Starting flat terrain training...


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
