### Env

In [10]:
from cartpole_jax_env import CartPole

### Train

In [12]:
import gymnasium as gym 
import jax
import jax.numpy as jnp
import numpy as np

class JaxToGymWrapper(gym.Env):
    """
    Adapt a functional JAX environment to the stateful gymnasium API.

    The wrapped environment is used through ``default_params``,
    ``reset(key, params)``, ``step(key, state, action, params)`` and
    ``render(state, params)``. This wrapper owns the PRNG key and the
    current environment state so callers can use the familiar
    ``reset()`` / ``step(action)`` interface expected by the training loop.
    """
    def __init__(self, jax_env, seed=42):
        """
        Args:
            jax_env: The JAX environment instance to wrap.
            seed: Integer used to create the wrapper's initial PRNG key.
        """
        self.env = jax_env
        self.params = jax_env.default_params
        self._key = jax.random.PRNGKey(seed)
        # Populated by reset(); step() and render() read it.
        self._state = None

        # Observation space (Box): bounded components use twice the failure
        # threshold, the other two are bounded only by float32 range.
        high = np.array([
            self.params.x_threshold * 2,
            np.finfo(np.float32).max,
            self.params.theta_threshold_radians * 2,
            np.finfo(np.float32).max
        ], dtype=np.float32)
        self.observation_space = gym.spaces.Box(-high, high, shape=(4,), dtype=np.float32)

        # Action space (Discrete)
        self.action_space = gym.spaces.Discrete(2)

        # Mock the 'spec' attribute accessed in the fit loop
        # The loop calls: range(env.spec.max_episode_steps)
        class EnvSpec:
            max_episode_steps = self.params.max_steps_in_episode
            id = "JaxCartPole-v1"
        self.spec = EnvSpec()

    def reset(self, seed=None, options=None):
        """
        Start a new episode.

        Args:
            seed: If given, re-creates the wrapper's PRNG key from it (and
                seeds the base class RNG) before resetting.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` where the observation is a NumPy array
            and ``info`` is an empty dict.
        """
        # Let gym.Env seed its own np_random, as the gymnasium API requires
        # of custom environments.
        super().reset(seed=seed)

        if seed is not None:
            self._key = jax.random.PRNGKey(seed)

        self._key, subkey = jax.random.split(self._key)

        # NOTE(review): presumably the env's reset is jit-compiled on first
        # call — confirm against the env implementation.
        obs, self._state = self.env.reset(subkey, self.params)

        # Convert jax array to numpy for the buffer/tracer
        return np.array(obs), {}

    def step(self, action):
        """
        Advance the environment by one step.

        Args:
            action: Action forwarded unchanged to the JAX environment.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with a
            NumPy observation, a Python float reward and Python bool flags.
            ``info`` is returned exactly as produced by the JAX environment.
        """
        self._key, subkey = jax.random.split(self._key)

        # Step the JAX env
        obs, self._state, reward, done, info = self.env.step(
            subkey, self._state, action, self.params
        )

        # The JAX env reports a single 'done'; split it into a time-limit
        # truncation and a failure termination by inspecting the step count.
        # Convert to Python bools so callers never receive JAX arrays as flags.
        truncated = bool(self._state.time >= self.params.max_steps_in_episode)
        terminated = bool(done) and not truncated

        return np.array(obs), float(reward), terminated, truncated, info

    def render(self):
        """Return the JAX env's rendering of the current state as a NumPy array."""
        return np.array(self.env.render(self._state, self.params))

  fn()




  import urllib3


In [15]:
# Smoke test for the Gym-style wrapper: build, reset, sample an action, step.
# 1. Create the underlying JAX environment
jax_cartpole = CartPole()

# 2. Wrap it to look like a Gym environment
# Only the training instance is created here; a separate wrapper (with its
# own seed) would be needed to keep evaluation randomness independent.
train_env = JaxToGymWrapper(jax_cartpole, seed=42)


# reset
obs, info = train_env.reset()

# sample action (drawn from the Gym space's own RNG, not from the wrapper's JAX key)
action = train_env.action_space.sample()

# step
obs, reward, terminated, truncated, info = train_env.step(action)



In [3]:
# Smoke test for the raw functional JAX environment: reset, then one step.
key = jax.random.PRNGKey(0)
env = CartPole()
params = env.default_params  # optional; you can also skip and use default

# reset
key, subkey = jax.random.split(key)
obs, state = env.reset(subkey, params)

# Derive two independent keys: a JAX key must not be consumed twice, so the
# action sample and the environment transition each get their own.
key, akey, skey = jax.random.split(key, 3)

# sample action
action = env.action_space(params).sample(akey)

# step
obs, new_state, reward, done, info = env.step(skey, state, action, params)


In [9]:
# Build a batch-of-one example observation by sampling the observation space.
key, subkey = jax.random.split(key)
obs_space = env.observation_space(params)
# expand_dims adds a leading axis of size 1 in front of the sampled observation.
sample_input = jnp.expand_dims(
    obs_space.sample(subkey), axis=0
).astype(float)
# NOTE(review): the recorded output of this cell contains `inf` in two of the
# four entries, so this sample is probably unsuitable as a numeric input
# without clipping — confirm before feeding it to a model.
sample_input

Array([[1.7698131,       inf, 0.2091639,       inf]], dtype=float32)

In [4]:
import time
import math


class TrainMonitor:
    """
    Lightweight per-episode training logger.

    Tracks global and per-episode step counts, the episode return, a
    smoothed return, and the most recent values of arbitrary named metrics,
    and prints one summary line each time an episode ends.

    Usage:
        mon = TrainMonitor()

        # at the start of an episode:
        mon.begin_episode()

        # each environment step:
        mon.step(
            reward,
            v=v_value,
            Rn=return_n,
            loss=loss_value,
            training_step=global_training_step,
        )

        # when episode terminates:
        mon.end_episode()
    """

    def __init__(self, smoothing: int = 10):
        """
        Args:
            smoothing: Window size for the smoothed return ``avg_G``. The
                first ``smoothing`` episodes are averaged exactly; after
                that each new return is blended in with weight
                ``1 / smoothing``.

        Raises:
            ValueError: If ``smoothing`` is not a positive number (a
                non-positive window would otherwise cause a division by
                zero at the end of the first episode).
        """
        smoothing = float(smoothing)
        # `not > 0` also rejects NaN, which would fail the same way as 0.
        if not smoothing > 0:
            raise ValueError(f"smoothing must be positive, got {smoothing!r}")
        self.smoothing = smoothing
        self.reset_global()

    # ---------- global counters ----------

    def reset_global(self) -> None:
        """Reset every counter, the smoothed return and the episode state."""
        self.T = 0              # global steps
        self.ep = 0             # episodes
        self.t = 0              # steps in current episode
        self.G = 0.0            # return in current episode
        self.avg_G = 0.0        # smoothed return
        self._n_avg_G = 0.0     # episodes counted so far, capped at `smoothing`
        # perf_counter is monotonic, so dt_ms cannot be skewed by clock changes.
        self._ep_start_time = time.perf_counter()
        self._last_metrics = {}
        self._in_episode = False

    # ---------- derived properties ----------

    @property
    def avg_r(self) -> float:
        """Mean reward per step in the current episode (NaN before any step)."""
        if self.t == 0:
            return math.nan
        return self.G / self.t

    @property
    def dt_ms(self) -> float:
        """Mean elapsed milliseconds per step in the current episode (NaN before any step)."""
        if self.t == 0:
            return math.nan
        return 1000.0 * (time.perf_counter() - self._ep_start_time) / self.t

    # ---------- episode lifecycle ----------

    def begin_episode(self) -> None:
        """Call once after env.reset() at the start of an episode."""
        if self._in_episode:
            # if user forgot to end previous episode, close it
            self.end_episode()

        self.ep += 1
        self.t = 0
        self.G = 0.0
        self._ep_start_time = time.perf_counter()
        self._last_metrics = {}
        self._in_episode = True

    def step(self, reward: float, **metrics) -> None:
        """
        Call every environment step.

        `reward` is the scalar environment reward (Python float).
        Extra metrics (v, Rn, loss, training_step, etc.) can be passed as kwargs.
        Only the LAST values in the episode are printed.
        """
        if not self._in_episode:
            # allow using step() without explicit begin_episode()
            self.begin_episode()

        self.T += 1
        self.t += 1
        self.G += float(reward)

        # store last seen metrics for this episode
        self._last_metrics.update(metrics)

    def end_episode(self) -> None:
        """Call when the episode terminates (done=True); prints the summary line.

        Does nothing if no episode is in progress.
        """
        if not self._in_episode:
            return
        self._in_episode = False

        # Running average: exact mean until the window is full, then an
        # exponential-style update with fixed weight 1 / smoothing.
        if self._n_avg_G < self.smoothing:
            self._n_avg_G += 1.0
        self.avg_G += (self.G - self.avg_G) / self._n_avg_G

        self._print_line()

    # ---------- printing ----------

    def _metric(self, name: str) -> float:
        """Return the last recorded value of metric `name` as a float (NaN if absent)."""
        return float(self._last_metrics.get(name, math.nan))

    def _print_line(self) -> None:
        """Print the one-line summary for the episode that just ended."""
        avg_r = self.avg_r
        dt = self.dt_ms

        # Optional metrics; default to NaN if not provided
        v = self._metric("v")
        Rn = self._metric("Rn")
        loss = self._metric("loss")
        training_step = self._metric("training_step")

        msg = (
            "[TrainMonitor|INFO] "
            f"ep: {self.ep},\t"
            f"T: {self.T:,},\t"
            f"G: {self.G:.0f},\t"
            f"avg_r: {avg_r:.3g},\t"
            f"avg_G: {self.avg_G:.1f},\t"
            f"t: {self.t},\t"
            f"dt: {dt:.3f}ms,\t"
            f"v: {v:.1f},\t"
            f"Rn: {Rn:.1f},\t"
            f"loss: {loss:.2f},\t"
            f"training_step: {training_step:.2e}"
        )
        print(msg)


In [5]:
import jax
import jax.numpy as jnp

# Roll out a random policy for 1000 episodes and log each one via TrainMonitor.
env = CartPole()
params = env.default_params
key = jax.random.PRNGKey(0)

monitor = TrainMonitor(smoothing=100)

# Stand-in for an optimizer step counter; incremented once per env step below.
global_training_step = 0.0

for episode in range(1000):
    # Fresh subkey for this episode's initial state.
    key, subkey = jax.random.split(key)
    obs, state = env.reset(subkey, params)

    monitor.begin_episode()

    # Every episode takes at least one step; leave once the env reports done.
    while True:
        # Separate keys for the action draw and for the transition.
        key, akey, skey = jax.random.split(key, 3)

        # Dummy policy: uniform sample from the action space.
        action = env.action_space(params).sample(akey)
        obs, state, reward, done, info = env.step(skey, state, action, params)

        # Placeholder agent statistics (value estimate, n-step return, loss).
        v, Rn, loss = 44.4, 45.7, 2.52
        global_training_step += 1.0

        monitor.step(
            float(reward),
            v=v,
            Rn=Rn,
            loss=loss,
            training_step=global_training_step,
        )

        if bool(done):
            break

    # Close the episode so the monitor prints its summary line.
    monitor.end_episode()


[TrainMonitor|INFO] ep: 1,	T: 28,	G: 28,	avg_r: 1,	avg_G: 28.0,	t: 28,	dt: 14.797ms,	v: 44.4,	Rn: 45.7,	loss: 2.52,	training_step: 2.80e+01
[TrainMonitor|INFO] ep: 2,	T: 44,	G: 16,	avg_r: 1,	avg_G: 22.0,	t: 16,	dt: 6.338ms,	v: 44.4,	Rn: 45.7,	loss: 2.52,	training_step: 4.40e+01
[TrainMonitor|INFO] ep: 3,	T: 62,	G: 18,	avg_r: 1,	avg_G: 20.7,	t: 18,	dt: 6.341ms,	v: 44.4,	Rn: 45.7,	loss: 2.52,	training_step: 6.20e+01
[TrainMonitor|INFO] ep: 4,	T: 78,	G: 16,	avg_r: 1,	avg_G: 19.5,	t: 16,	dt: 6.386ms,	v: 44.4,	Rn: 45.7,	loss: 2.52,	training_step: 7.80e+01
[TrainMonitor|INFO] ep: 5,	T: 97,	G: 19,	avg_r: 1,	avg_G: 19.4,	t: 19,	dt: 6.272ms,	v: 44.4,	Rn: 45.7,	loss: 2.52,	training_step: 9.70e+01
[TrainMonitor|INFO] ep: 6,	T: 114,	G: 17,	avg_r: 1,	avg_G: 19.0,	t: 17,	dt: 6.317ms,	v: 44.4,	Rn: 45.7,	loss: 2.52,	training_step: 1.14e+02
[TrainMonitor|INFO] ep: 7,	T: 145,	G: 31,	avg_r: 1,	avg_G: 20.7,	t: 31,	dt: 6.372ms,	v: 44.4,	Rn: 45.7,	loss: 2.52,	training_step: 1.45e+02
[TrainMonitor|INFO] ep: 

KeyboardInterrupt: 