In [1]:
# Configure JAX to use GPU
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # Prevent JAX from grabbing all GPU memory

import jax
# Verify GPU is available
print(f"JAX devices: {jax.devices()}")
print(f"GPU available: {any('gpu' in str(d).lower() or 'cuda' in str(d).lower() for d in jax.devices())}")

import mujoco
import mujoco.mjx as mjx
import jax.numpy as jnp

# Load Ant model with absolute path
project_root = r"C:\GitHub\training-lucy"
xml_path = os.path.join(project_root, "animals", "ant.xml")

model = mujoco.MjModel.from_xml_path(xml_path)
mjx_model = mjx.put_model(model)

def mjx_step(model, data):
    return mjx.step(model, data)

# Batch of 1024 envs on GPU
batch_size = 1024
mjx_datas = mjx.make_data(model)

# JIT-compiled step for GPU acceleration
mjx_step_jit = jax.jit(mjx_step)

# Example rollout
mjx_datas = mjx_step_jit(mjx_model, mjx_datas)
obs = mjx_datas.qpos

JAX devices: [CpuDevice(id=0)]
GPU available: False
Failed to import warp: No module named 'warp'
Failed to import mujoco_warp: No module named 'warp'


In [None]:
import gymnasium
import torch

# Ensure CUDA is available for Stable Baselines3 (PyTorch-based)
assert torch.cuda.is_available(), "CUDA not available - check your PyTorch installation"

xml_file = os.path.join(project_root, "animals", "ant.xml")
env = gymnasium.make("Ant-v5", xml_file=xml_file)

PyTorch CUDA available: False


In [None]:
from stable_baselines3 import PPO
model = PPO("MlpPolicy", env, verbose=1, device="cuda")  # Force CPU for max throughput [web:15]
model.learn(total_timesteps=1_000_000)

model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    ent_coef=0.0,
)  # Hyperparameters typical for MuJoCo continuous control. [web:6][web:15]


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 91.4     |
|    ep_rew_mean     | -104     |
| time/              |          |
|    fps             | 2457     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 151         |
|    ep_rew_mean          | -157        |
| time/                   |             |
|    fps                  | 1782        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010977402 |
|    clip_fraction        | 0.111       |
|    clip_range           | 0.2         |
|    entropy_loss   

In [None]:
env.close()
model.save("animals/ppo_ant")

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Vectorized envs for parallel rollouts (RTX 5070 can handle more)
vec_env = make_vec_env(
    "Ant-v5",
    n_envs=16,  # 16 parallel envs for high throughput
    env_kwargs={"xml_file": xml_file}
)

# RTX 5070 optimized PPO
model = PPO(
    "MlpPolicy",
    vec_env,
    verbose=1,
    device="cuda",
    n_steps=2048,
    batch_size=1024,   # Large batch for RTX 5070 (plenty of VRAM)
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    ent_coef=0.0,
    learning_rate=3e-4,
)

In [None]:
# Cleanup training environment
vec_env.close()
env.close()

In [None]:
env = gym.make("Ant-v4", render_mode="human")
model = PPO.load("ppo_ant")

obs, info = env.reset()
done = False

while True:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    if done:
        obs, info = env.reset()

  logger.deprecation(


ValueError: Error: Unexpected observation shape (27,) for Box environment, please use (105,) or (n_env, 105) for the observation shape.