In [None]:
# Check that the notebook's dependencies are importable and report any that are missing.
# This cell does not install anything; it only prints install hints.
import subprocess
import sys

# Names of packages that failed to import. Each check below records a failure
# instead of raising, so every missing package is reported in a single pass.
missing = []

# Successful imports also bind the names (jax, jnp, brax, html, HTML) in the
# notebook namespace.
try:
    import jax
    import jax.numpy as jnp
    print("✓ JAX is available")
except ImportError:
    missing.append("jax")
    print("❌ JAX not found. Please install with: conda install jax -c conda-forge")

try:
    import brax
    from brax.io import html
    print("✓ Brax is available")
except ImportError:
    missing.append("brax")
    print("❌ Brax not found. Please install with: pip install brax")

try:
    from IPython.display import HTML
except ImportError:
    missing.append("IPython")
    print("❌ IPython not found. Please install with: pip install ipython")

# One summary line covering every missing package, rather than only the first failure.
if missing:
    print(f"❌ Missing required packages: {', '.join(missing)}")
    print("Please install the missing packages and restart the kernel")
else:
    print("✓ All required packages are available")

In [None]:
# Load the left hand environment directly from the local QDax repo checkout
# and instantiate it as `env`.
import os
import sys
import importlib.util

# NOTE(review): abspath("") is the kernel's current working directory, so this
# assumes the notebook runs from a direct subdirectory of the QDax repo
# (e.g. `notebooks/`) — confirm for your checkout.
qdax_repo_path = os.path.dirname(os.path.abspath(""))
# Guard the insert so re-running this cell does not keep prepending duplicates.
if qdax_repo_path not in sys.path:
    sys.path.insert(0, qdax_repo_path)

# Load left_hand.py by file path, following the importlib
# "import a source file directly" recipe.
env_file_path = os.path.join(qdax_repo_path, "qdax", "environments", "left_hand.py")
if not os.path.isfile(env_file_path):
    # Fail early with the resolved path instead of an opaque loader error.
    raise FileNotFoundError(
        f"Left hand environment not found at {env_file_path}; "
        "run this notebook from a subdirectory of the QDax repo."
    )
spec = importlib.util.spec_from_file_location("left_hand", env_file_path)
left_hand_module = importlib.util.module_from_spec(spec)
# Register the module before executing it: module code that looks itself up in
# sys.modules (e.g. dataclasses resolving string annotations) fails otherwise.
sys.modules[spec.name] = left_hand_module
spec.loader.exec_module(left_hand_module)

# Create the left hand environment
env = left_hand_module.LeftHand()
print(f"✓ Left hand environment loaded: {env}")
print(f"Action space: {env.action_size}")
print(f"Observation space: {env.observation_size}")

In [None]:
# Run a rollout with uniformly random actions in [-1, 1] and render it as HTML.
import jax
import jax.numpy as jnp
from brax.io import html
from IPython.display import HTML

episode_length = 1000  # Total steps to simulate
render_steps = 500     # Number of leading steps included in the HTML render

# Compile reset/step once. Calling the un-jitted functions op-by-op in a
# 1000-step Python loop is very slow.
# NOTE(review): assumes LeftHand.reset/step are jittable, as Brax envs normally are.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Initialize environment with a dedicated key so it is not reused for an action.
key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
state = jit_reset(reset_key)

rollout = []
for _ in range(episode_length):
    # Split *before* sampling so every action consumes a fresh subkey. The
    # carried `key` is only ever used for further splitting, never for sampling.
    key, action_key = jax.random.split(key)
    action = jax.random.uniform(action_key, (env.action_size,), minval=-1.0, maxval=1.0)

    # Step environment
    state = jit_step(state, action)
    rollout.append(state)

print(f"✓ Completed rollout with {len(rollout)} steps")

# Render the leading `render_steps` states.
# NOTE(review): `.qp` is the Brax v1 state field; newer Brax exposes
# `pipeline_state` instead — confirm against the installed version.
HTML(html.render(env.sys, [s.qp for s in rollout[:render_steps]]))