# **Mujoco Playground**

In [6]:
# 产看 Jax 是否可以使用 Nvidia 显卡
import jax
from jax import numpy as jp

# 检查JAX设备配置和GPU可用性
print("JAX devices:", jax.devices())
print("JAX default backend:", jax.default_backend())
gpu_devices = [d for d in jax.devices() if d.platform == 'gpu']
print("CUDA devices:", gpu_devices if gpu_devices else "No CUDA devices found")

# 检查当前计算在哪个设备上进行
test_array = jp.array([1.0, 2.0, 3.0])
print(f"Test array device: {test_array.device}")

# 检查是否可以使用GPU
try:
    import os
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
    print(f"JAX_PLATFORM_NAME: {os.environ.get('JAX_PLATFORM_NAME', 'Not set')}")
except:
    pass

JAX devices: [CudaDevice(id=0)]
JAX default backend: gpu
CUDA devices: [CudaDevice(id=0)]
Test array device: cuda:0
CUDA_VISIBLE_DEVICES: Not set
JAX_PLATFORM_NAME: Not set


In [16]:
# @title Import MuJoCo, MJX, and Brax
from datetime import datetime
import functools
import os
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.io import html, mjcf, model
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import networks as sac_networks
from brax.training.agents.sac import train as sac
from etils import epath
from flax import struct
from flax.training import orbax_utils
from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
import numpy as np
from orbax import checkpoint as ocp

# **Intro**
使用倒立摆环境

In [17]:
# Load the 'CartpoleBalance' environment from the MuJoCo Playground registry
# and display it as the cell output.
from mujoco_playground import registry
env = registry.load('CartpoleBalance')
env

<mujoco_playground._src.dm_control_suite.cartpole.Balance at 0x7cd86bfbb8e0>

In [18]:
# Fetch the registry's default configuration for 'CartpoleBalance' and show it.
env_cfg = registry.get_default_config('CartpoleBalance')
env_cfg

action_repeat: 1
ctrl_dt: 0.01
episode_length: 1000
sim_dt: 0.01
vision: false
vision_config:
  enabled_geom_groups:
  - 0
  - 1
  - 2
  gpu_id: 0
  history: 3
  render_batch_size: 512
  render_height: 64
  render_width: 64
  use_rasterizer: false

# **Rollout**

In [19]:
# Wrap the environment's reset and step functions with jax.jit.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [20]:
# Open-loop rollout: drive every actuator with a phase-shifted sine wave.
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state]

f = 0.5  # frequency multiplier applied to the simulation time
for i in range(env_cfg.episode_length):
  # One sine value per action dimension, each offset by an equal phase share.
  action = jp.array([
      jp.sin(state.data.time * 2 * jp.pi * f + j * 2 * jp.pi / env.action_size)
      for j in range(env.action_size)
  ])
  state = jit_step(state, action)
  rollout.append(state)

# Render the collected states and play them back at 1 / env.dt frames per second.
frames = env.render(rollout)
media.show_video(frames, fps=1.0 / env.dt)

100%|██████████| 1001/1001 [00:01<00:00, 628.01it/s]



0
This browser does not support the video tag.


In [22]:
# Show which device holds the observation of the last rollout state.
state.obs.device

CpuDevice(id=0)

# **RL**

In [None]:
# Load the Brax PPO configuration that the Playground ships for
# 'CartpoleBalance' and display it.
from mujoco_playground.config import dm_control_suite_params
ppo_params = dm_control_suite_params.brax_ppo_config('CartpoleBalance')
ppo_params

action_repeat: 1
batch_size: 1024
discounting: 0.995
entropy_cost: 0.01
episode_length: 1000
learning_rate: 0.001
normalize_observations: true
num_envs: 2048
num_evals: 10
num_minibatches: 32
num_timesteps: 60000000
num_updates_per_batch: 16
reward_scaling: 10.0
unroll_length: 30

# **PPO**

In [None]:
# Buffers filled by `progress`. `times` starts with the moment this cell ran,
# so later cells can derive durations from consecutive entries.
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]


def progress(num_steps, metrics):
  """Record one evaluation point and redraw the reward curve.

  Args:
    num_steps: Environment-step count of this evaluation (x coordinate).
    metrics: Mapping providing "eval/episode_reward" and
      "eval/episode_reward_std".
  """
  clear_output(wait=True)

  # Timestamp the callback and store the new data point.
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])

  # Draw on the current axes; the x range leaves 25% headroom beyond the
  # configured number of timesteps.
  ax = plt.gca()
  ax.set_xlim([0, ppo_params["num_timesteps"] * 1.25])
  ax.set_ylim([0, 1100])
  ax.set_xlabel("# environment steps")
  ax.set_ylabel("reward per episode")
  ax.set_title(f"y={y_data[-1]:.3f}")
  ax.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

  display(ax.figure)

# Turn the PPO config into keyword arguments for ppo.train.
ppo_training_params = dict(ppo_params)
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
  # Network settings parameterise the factory instead of being passed to
  # ppo.train directly, so take them out of the training kwargs.
  ppo_training_params.pop("network_factory")
  network_factory = functools.partial(
      ppo_networks.make_ppo_networks, **ppo_params.network_factory
  )

# Pre-bind everything except the environment, which is supplied at call time.
train_fn = functools.partial(
    ppo.train,
    network_factory=network_factory,
    progress_fn=progress,
    **ppo_training_params,
)

In [None]:
# Run PPO training on the environment, wrapping it with the Playground's
# Brax-training wrapper.
from mujoco_playground import wrapper

make_inference_fn, params, metrics = train_fn(
    environment=env,
    wrap_env_fn=wrapper.wrap_for_brax_training,
)
# times[0] is when the progress buffers were created; every later entry is
# appended by a `progress` callback, so times[1] marks the first callback.
print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")

# **GPU训练配置**

要使用GPU进行训练，您需要完成以下步骤：

## 1. 硬件和驱动要求
- NVIDIA GPU（支持CUDA）
- 正确安装NVIDIA驱动程序
- 安装CUDA工具包

## 2. 安装GPU版本的JAX
如果上面的设备检查没有列出GPU，说明当前环境安装的可能是CPU版本的JAX。要启用GPU支持，需要安装GPU版本：

```bash
# 卸载CPU版本的JAX
pip uninstall jax jaxlib

# 安装GPU版本的JAX (根据您的CUDA版本选择)
# 对于CUDA 12.x (加引号可避免 zsh 等 shell 把方括号当作通配符):
pip install "jax[cuda12]"

# 对于CUDA 11.x:
pip install "jax[cuda11]"
```

## 3. 环境变量配置
可以通过环境变量来控制JAX的行为：
- `JAX_PLATFORM_NAME=gpu` - 强制使用GPU
- `CUDA_VISIBLE_DEVICES=0` - 指定使用的GPU设备

## 4. 代码中的GPU配置
JAX会自动检测并使用可用的GPU。下面的代码展示了如何验证和配置GPU使用。

In [24]:
# GPU training configuration helper
def configure_gpu_training():
    """Report whether JAX exposes any GPU device.

    Prints the detected GPU devices, or a warning when none are found.
    Nothing is reconfigured; the function only inspects and reports.

    Returns:
        True when at least one GPU device is visible to JAX, otherwise False.
    """
    # Only query the 'gpu' backend when the device list already contains one.
    has_gpu_platform = any(dev.platform == 'gpu' for dev in jax.devices())
    gpu_devices = jax.devices('gpu') if has_gpu_platform else []

    if gpu_devices:
        print(f"✅ 检测到 {len(gpu_devices)} 个GPU设备:")
        for index in range(len(gpu_devices)):
            print(f"  GPU {index}: {gpu_devices[index]}")

        print(f"🚀 将使用GPU进行训练")
        return True

    # No GPU: warn that training falls back to the CPU.
    print("⚠️  警告: 没有检测到GPU设备")
    print("当前将使用CPU进行训练，速度可能较慢")
    return False

# Prefer the GPU for subsequent computations (only when one is present)
def force_gpu_if_available():
    """Make the first visible GPU the default JAX device, if there is one.

    Prints which path was taken. When no GPU is visible, nothing is changed
    and JAX keeps using the CPU.

    The previous implementation called ``jax.clear_backends()``, which is
    deprecated and no longer available in recent JAX releases, and relied on
    setting ``JAX_PLATFORM_NAME`` after JAX had already initialised its
    backends, which has no effect on the running process. The default device
    is now selected through ``jax.config`` instead.
    """
    import os

    gpu_devices = [d for d in jax.devices() if d.platform == 'gpu']
    if gpu_devices:
        print("设置JAX使用GPU...")
        # Kept so that the environment seen by code started afterwards still
        # carries the same variable as before.
        os.environ['JAX_PLATFORM_NAME'] = 'gpu'
        # Route new arrays and computations to the first GPU without
        # tearing down the already-initialised backends.
        jax.config.update('jax_default_device', gpu_devices[0])
        print(f"当前设备: {jax.devices()}")
    else:
        print("没有可用的GPU，继续使用CPU")

# Run the check; the boolean result is reused by the following cells.
gpu_available = configure_gpu_training()

⚠️  警告: 没有检测到GPU设备
当前将使用CPU进行训练，速度可能较慢


In [25]:
# Training-parameter adjustment depending on GPU availability
def get_gpu_optimized_training_params(base_params, gpu_available=False):
    """Return a copy of the training parameters tuned for GPU or CPU.

    With a GPU, ``batch_size`` is doubled (at least 512), ``num_envs`` is
    doubled (at least 2048) and, when a ``network_factory`` entry exists,
    missing hidden-layer sizes default to ``(256, 256)``. Without a GPU,
    ``batch_size`` is capped at 256 and ``num_envs`` at 512.

    Args:
        base_params: Mapping of training parameters. It is never modified,
            including its nested ``network_factory`` mapping.
        gpu_available: Whether a GPU should be assumed.

    Returns:
        A new ``dict`` with the adjusted parameters.
    """
    optimized_params = dict(base_params)

    if gpu_available:
        print("🚀 配置GPU优化参数:")

        # A GPU can handle a larger batch size.
        # NOTE(review): brax PPO is believed to require batch_size *
        # num_minibatches to be divisible by num_envs; confirm the scaled
        # values still satisfy that for other configs.
        if 'batch_size' in optimized_params:
            original_batch_size = optimized_params['batch_size']
            optimized_params['batch_size'] = max(original_batch_size * 2, 512)
            print(f"  批量大小: {original_batch_size} -> {optimized_params['batch_size']}")

        # A GPU can simulate more environments in parallel.
        if 'num_envs' in optimized_params:
            original_num_envs = optimized_params['num_envs']
            optimized_params['num_envs'] = max(original_num_envs * 2, 2048)
            print(f"  并行环境数: {original_num_envs} -> {optimized_params['num_envs']}")

        # Optionally enlarge the networks.
        if 'network_factory' in optimized_params:
            configured = optimized_params['network_factory']
            # dict(base_params) is only a shallow copy, so the nested mapping
            # is copied here; filling in defaults must not write through to
            # the caller's configuration.
            network_params = {} if configured is None else dict(configured)
            network_params.setdefault('policy_hidden_layer_sizes', (256, 256))
            network_params.setdefault('value_hidden_layer_sizes', (256, 256))

            optimized_params['network_factory'] = network_params
            print(f"  网络大小: 增加到 {network_params.get('policy_hidden_layer_sizes', '默认')}")

        print("  ✅ GPU优化配置完成")

    else:
        print("💻 使用CPU优化参数:")

        # Cap the batch size on CPU to limit memory use.
        if 'batch_size' in optimized_params:
            optimized_params['batch_size'] = min(optimized_params['batch_size'], 256)
            print(f"  批量大小限制为: {optimized_params['batch_size']}")

        # Cap the number of parallel environments on CPU.
        if 'num_envs' in optimized_params:
            optimized_params['num_envs'] = min(optimized_params['num_envs'], 512)
            print(f"  并行环境数限制为: {optimized_params['num_envs']}")

    return optimized_params

# Apply the adjustment to the current PPO configuration and print the
# parameters it may have changed, plus the total number of timesteps.
optimized_ppo_params = get_gpu_optimized_training_params(ppo_params, gpu_available)
print("\n📊 优化后的关键参数:")
for key in ['batch_size', 'num_envs', 'num_timesteps']:
    if key in optimized_ppo_params:
        print(f"  {key}: {optimized_ppo_params[key]}")

💻 使用CPU优化参数:
  批量大小限制为: 256
  并行环境数限制为: 512

📊 优化后的关键参数:
  batch_size: 256
  num_envs: 512
  num_timesteps: 60000000


In [26]:
# Build the training function from the adjusted parameters
def setup_optimized_training(env, params_dict, gpu_available=False):
    """Create a pre-configured ``ppo.train`` callable and its timing log.

    Args:
        env: Accepted for interface compatibility; not used here. The
            environment is supplied when the returned function is called.
        params_dict: Mapping of ``ppo.train`` keyword arguments. It must
            contain "num_timesteps" (used for the plot's x range) and may
            contain a "network_factory" mapping of network settings, or
            ``None`` for the default networks.
        gpu_available: Only affects the plot title and curve colour.

    Returns:
        ``(optimized_train_fn, times_opt)``: the partially applied training
        function and a list of timestamps. The list starts with the creation
        time and receives one entry per progress callback.
    """

    # State captured by the progress callback.
    x_data_opt, y_data_opt, y_dataerr_opt = [], [], []
    times_opt = [datetime.now()]

    def progress_optimized(num_steps, metrics):
        """Store one evaluation point and redraw the reward curve."""
        clear_output(wait=True)

        times_opt.append(datetime.now())
        x_data_opt.append(num_steps)
        y_data_opt.append(metrics["eval/episode_reward"])
        y_dataerr_opt.append(metrics["eval/episode_reward_std"])

        plt.figure(figsize=(10, 6))
        plt.xlim([0, params_dict["num_timesteps"] * 1.25])
        plt.ylim([0, 1100])
        plt.xlabel("# environment steps")
        plt.ylabel("reward per episode")

        device_info = "GPU" if gpu_available else "CPU"
        plt.title(f"训练进度 ({device_info}) - 当前奖励: {y_data_opt[-1]:.3f}")

        plt.errorbar(x_data_opt, y_data_opt, yerr=y_dataerr_opt, color="green" if gpu_available else "blue")
        plt.grid(True, alpha=0.3)
        plt.show()

    # Configure the networks. The "network_factory" entry is always removed
    # from the training kwargs: leaving a None entry behind would collide
    # with the explicit network_factory keyword below and raise TypeError.
    training_params = dict(params_dict)
    network_overrides = training_params.pop("network_factory", None)
    network_factory = ppo_networks.make_ppo_networks

    if network_overrides is not None:
        network_factory = functools.partial(
            ppo_networks.make_ppo_networks,
            **network_overrides
        )

    # Pre-bind everything except the arguments supplied at call time.
    optimized_train_fn = functools.partial(
        ppo.train,
        **training_params,
        network_factory=network_factory,
        progress_fn=progress_optimized
    )

    return optimized_train_fn, times_opt

# Build the training function from the adjusted PPO parameters.
print("🔧 准备优化的训练配置...")
optimized_train_fn, times_opt = setup_optimized_training(env, optimized_ppo_params, gpu_available)
print("✅ 优化训练配置完成！")

# Print training advice that matches the detected hardware.
print("\n💡 训练建议:")
if gpu_available:
    print("  🚀 您正在使用GPU训练:")
    print("     - 可以使用更大的批量大小和更多并行环境")
    print("     - 训练速度会显著提升")
    print("     - 可以尝试更大的神经网络")
else:
    print("  💻 您正在使用CPU训练:")
    print("     - 建议使用较小的批量大小以节省内存")
    print("     - 可以考虑减少训练步数进行快速测试")
    print("     - 如果有GPU，安装GPU版JAX可大幅提升速度")

print(f"\n当前配置将使用 {optimized_ppo_params['num_envs']} 个并行环境进行训练")

🔧 准备优化的训练配置...
✅ 优化训练配置完成！

💡 训练建议:
  💻 您正在使用CPU训练:
     - 建议使用较小的批量大小以节省内存
     - 可以考虑减少训练步数进行快速测试
     - 如果有GPU，安装GPU版JAX可大幅提升速度

当前配置将使用 512 个并行环境进行训练


# **GPU安装和使用指南**

## 🔧 安装GPU支持

如果您有NVIDIA GPU并希望加速训练，请执行以下步骤：

### 1. 检查GPU状态
```bash
# 检查GPU是否可用
nvidia-smi

# 检查CUDA版本
nvcc --version
```

### 2. 安装GPU版本的JAX
```bash
# 在终端中执行以下命令:

# 激活虚拟环境 (如果使用虚拟环境)
source .venv/bin/activate

# 卸载CPU版本
pip uninstall jax jaxlib -y

# 安装GPU版本 (根据您的CUDA版本选择；加引号可避免 zsh 等 shell 把方括号当作通配符)
# 对于CUDA 12.x:
pip install "jax[cuda12]==0.4.20"

# 对于CUDA 11.x:
pip install "jax[cuda11]==0.4.20"

# 重新安装其他依赖
pip install -r requirements.txt  # 如果有requirements.txt文件
```

### 3. 验证安装
重启Jupyter kernel后运行GPU检查代码，应该能看到GPU设备。

## ⚡ GPU训练优势

- **速度提升**: GPU训练比CPU快5-50倍
- **更大批量**: 可以使用更大的批量大小
- **并行环境**: 支持更多并行环境同时训练
- **复杂网络**: 可以训练更深更宽的神经网络

## 🚀 开始GPU训练

一旦安装了GPU支持，只需重启kernel并重新运行notebook即可自动使用GPU进行训练！

In [27]:
# 性能测试函数
def benchmark_computation(device_type="current"):
    """
    对当前设备进行简单的计算性能测试
    """
    import time
    
    print(f"🔬 开始 {device_type} 设备性能测试...")
    
    # 矩阵乘法测试
    size = 2048
    key = jax.random.PRNGKey(42)
    
    # 生成随机矩阵
    print(f"生成 {size}x{size} 随机矩阵...")
    A = jax.random.normal(key, (size, size))
    B = jax.random.normal(key, (size, size))
    
    # 编译计算函数
    @jax.jit
    def matrix_multiply(a, b):
        return jp.dot(a, b)
    
    # 预热
    print("预热中...")
    _ = matrix_multiply(A, B).block_until_ready()
    
    # 性能测试
    print("开始性能测试...")
    start_time = time.time()
    
    n_iterations = 10
    for i in range(n_iterations):
        result = matrix_multiply(A, B).block_until_ready()
    
    end_time = time.time()
    
    total_time = end_time - start_time
    avg_time = total_time / n_iterations
    
    print(f"✅ 测试完成!")
    print(f"   总时间: {total_time:.3f}秒")
    print(f"   平均时间: {avg_time:.3f}秒/次")
    print(f"   设备: {result.device}")
    
    # 计算FLOPS (浮点运算次数)
    flops = 2 * size**3  # 矩阵乘法的FLOPS
    gflops = (flops * n_iterations) / total_time / 1e9
    print(f"   性能: {gflops:.2f} GFLOPS")
    
    return avg_time, gflops

# Benchmark the device JAX currently uses.
current_avg_time, current_gflops = benchmark_computation("当前")

print("\n" + "="*50)
print("💡 性能提升建议:")
# NOTE(review): the recorded run on a CPU device reported well above this
# threshold, so the 10 GFLOPS cut-off may not separate CPU from GPU; confirm.
if current_gflops < 10:  # assumes anything below 10 GFLOPS means CPU-level performance
    print("   📈 当前性能较低，建议:")
    print("   🚀 安装GPU版本JAX可获得5-50倍性能提升")
    print("   💾 确保有足够内存支持大批量训练")
    print("   🔧 考虑使用更小的模型进行快速原型开发")
else:
    print("   ⚡ 当前性能良好，可以进行高效训练!")
    print("   🎯 可以尝试更大的批量大小和网络规模")

🔬 开始 当前 设备性能测试...
生成 2048x2048 随机矩阵...
预热中...
开始性能测试...
✅ 测试完成!
   总时间: 0.105秒
   平均时间: 0.010秒/次
   设备: TFRT_CPU_0
   性能: 1637.45 GFLOPS

💡 性能提升建议:
   ⚡ 当前性能良好，可以进行高效训练!
   🎯 可以尝试更大的批量大小和网络规模


In [None]:
# JIT-wrap reset, step, and the trained policy. The policy is built from the
# trained `params` with deterministic=True.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))

In [None]:
# Roll out the trained policy and render the collected states.
rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1

for _ in range(n_episodes):
  # Derive a dedicated key for the reset. Passing `rng` itself to reset and
  # then splitting the same key again would reuse a PRNG key.
  rng, reset_rng = jax.random.split(rng)
  state = jit_reset(reset_rng)
  rollout.append(state)
  for i in range(env_cfg.episode_length):
    # Fresh key for every policy call; `rng` is carried forward.
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)

# Render every `render_every`-th state and scale the playback rate to match.
render_every = 1
frames = env.render(rollout[::render_every])
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)