In [None]:
import jax.numpy as jnp
from jax import random, jit
import matplotlib.pyplot as plt
from cartpole import CartPoleEnv, CartPoleParams
import jax.lax as lax

In [None]:
# Single-agent setup: one cart-pole, advanced with time step dt = 0.05.
env = CartPoleEnv()
params = CartPoleParams(dt=0.05, num_agents=1)

def simulate(rng: int = 0, num_steps: int = 100):
    """Roll the environment forward under zero action and record two physics columns.

    Uses the notebook-level ``env`` and ``params``. After ``env.reset`` the
    physics array is overwritten so that every agent starts from the same
    row ``[0.2, 0, -1, 0, 0]``.

    Args:
        rng: Integer seed converted to a JAX PRNG key for ``env.reset``.
        num_steps: Number of frames to record; the environment is stepped once
            after each recorded frame. Defaults to 100.

    Returns:
        Array of shape ``(num_steps, params.num_agents, 2)``. ``frames[t, k]``
        holds columns 0 and 2 of agent ``k``'s physics row as they were before
        step ``t`` (presumably cart position and pole angle -- confirm against
        the ``cartpole`` module).
    """
    key = random.PRNGKey(rng)
    state = env.reset(key, params)

    # One identical start row per agent; tiling works for any num_agents,
    # not only for a single agent.
    start_row = jnp.array([0.2, 0, -1, 0, 0])
    state = state.replace(physics=jnp.tile(start_row, (params.num_agents, 1)))

    frames = jnp.zeros((num_steps, params.num_agents, 2))
    action = jnp.zeros((params.num_agents,))
    for i in range(num_steps):
        # Stack along the last axis so each row is one agent's (col 0, col 2)
        # pair; reshaping a (2, n) array to (n, 2) would mix agents for n > 1.
        snapshot = jnp.stack([state.physics[:, 0], state.physics[:, 2]], axis=-1)
        frames = frames.at[i].set(snapshot)
        state = env.step(state, action, params)
    return frames

# Run the rollout with seed 0 and print every recorded frame.
frames = simulate(rng=0)
print(frames)

In [None]:
# Three-agent setup: three cart-poles, advanced with time step dt = 0.05.
env = CartPoleEnv()
params = CartPoleParams(dt=0.05, num_agents=3)

def simulate(rng: int = 0, num_steps: int = 100):
    """Roll the multi-agent environment forward under zero action and record two physics columns.

    Uses the notebook-level ``env`` and ``params``. After ``env.reset`` the
    physics array is overwritten so that every agent starts from the same
    row ``[0.2, 0, -1, 0, 0]``.

    Args:
        rng: Integer seed converted to a JAX PRNG key for ``env.reset``.
        num_steps: Number of frames to record; the environment is stepped once
            after each recorded frame. Defaults to 100.

    Returns:
        Array of shape ``(num_steps, params.num_agents, 2)``. ``frames[t, k]``
        holds columns 0 and 2 of agent ``k``'s physics row as they were before
        step ``t`` (presumably cart position and pole angle -- confirm against
        the ``cartpole`` module).
    """
    key = random.PRNGKey(rng)
    state = env.reset(key, params)

    # One identical start row per agent, shape (num_agents, 5).
    start_row = jnp.array([0.2, 0, -1, 0, 0])
    state = state.replace(physics=jnp.tile(start_row, (params.num_agents, 1)))

    frames = jnp.zeros((num_steps, params.num_agents, 2))
    action = jnp.zeros((params.num_agents,))
    for i in range(num_steps):
        # Stack along the last axis so row k is agent k's (col 0, col 2) pair.
        # Building a (2, n) array and reshaping it to (n, 2) is NOT a
        # transpose: it would spread one column's values across several agents.
        snapshot = jnp.stack([state.physics[:, 0], state.physics[:, 2]], axis=-1)
        frames = frames.at[i].set(snapshot)
        state = env.step(state, action, params)
    return frames

# Run the three-agent rollout with seed 0 and print every recorded frame.
frames = simulate(rng=0)
print(frames)

In [None]:
# Reset with seed 0, then give every agent the same starting physics row.
state = env.reset(random.PRNGKey(0), params)
start_row = jnp.array([0.2, 0, -1, 0, 0])
state = state.replace(physics=jnp.tile(start_row, (params.num_agents, 1)))

# Buffer for 100 frames; only frame 0 is filled here, with physics columns
# 0 and 2 laid out as one (col 0, col 2) row per agent.
frames = jnp.zeros((100, params.num_agents, 2))
first_frame = jnp.stack([state.physics[:, 0], state.physics[:, 2]], axis=-1)
frames = frames.at[0].set(first_frame)
frames

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Old (reference) Gaussian policy -- example values similar to the PPO paper
mu_old = 0.0                      # mean of the old policy
log_std_old = 0.0                 # log std of the old policy, so sigma_old = 1
sigma_old = np.exp(log_std_old)

a_t = 0.0                         # action whose probability ratio is evaluated
eps = 0.2                         # PPO clipping parameter epsilon

# Candidate new-policy parameters: a 300 x 300 grid over (mu_t, log sigma_t)
mu_range = np.linspace(-1.2, 1.2, 300)
log_std_range = np.linspace(-1.4, 0.5, 300)
MU, LOG_STD = np.meshgrid(mu_range, log_std_range)
SIGMA = np.exp(LOG_STD)

# log r_t = log pi_new(a_t) - log pi_old(a_t) for 1-D Gaussians; the
# normalisation constants cancel, leaving the log-std and quadratic terms.
log_r = (
    (log_std_old - LOG_STD)
    + ((a_t - mu_old)**2) / (2 * sigma_old**2)
    - ((a_t - MU)**2) / (2 * SIGMA**2)
)

# The ratio clip [1 - eps, 1 + eps], expressed in log space
lower_bound, upper_bound = np.log(1 - eps), np.log(1 + eps)

# True wherever the ratio lies inside the clip interval (bounds included)
mask = np.logical_and(lower_bound <= log_r, log_r <= upper_bound)

# Closed-form D_KL(pi_old || pi_new) between the two Gaussian policies
KL = (
    np.log(SIGMA / sigma_old)
    + (sigma_old**2 + (MU - mu_old)**2) / (2 * SIGMA**2)
    - 0.5
)

# Plot the clip region in (mu_t, log sigma_t) space together with KL contours
fig, ax = plt.subplots(figsize=(6, 6))

# Dashed black lines: the two boundaries of the clip interval (log r_t at the
# lower and upper bound).
ax.contour(MU, LOG_STD, log_r, levels=[lower_bound, upper_bound],
           colors='k', linestyles='--', linewidths=1.5, alpha=0.8)

# Light grey shading over everything OUTSIDE the clip interval (1 - mask is 1
# where the ratio would be clipped).
ax.imshow(1 - mask,
          extent=[mu_range.min(), mu_range.max(),
                  log_std_range.min(), log_std_range.max()],
          origin='lower', alpha=0.15, cmap='gray')

# Labelled KL divergence contours
kl_levels = [0.01, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0]
contours = ax.contour(MU, LOG_STD, KL, levels=kl_levels, cmap='coolwarm')
ax.clabel(contours, inline=True, fontsize=8, fmt="%.2f")

# Mark theta_old
ax.scatter(mu_old, log_std_old, color='red', label=r'$\theta_{\mathrm{old}}$')
ax.set_xlabel(r'$\mu_t$')
# The vertical axis is the log standard deviation, so label it with a
# lowercase sigma (capital Sigma would read as a covariance/variance).
ax.set_ylabel(r'$\log \sigma_t$')
ax.legend()
ax.set_title("PPO Clipping Region with KL Divergence Contours")
plt.show()


In [1]:
import jax
import jax.numpy as jnp

# Root PRNG key built from seed 0; a later cell splits it into sub-keys.
_rng = jax.random.PRNGKey(0)



In [4]:
# Shape of the array produced by splitting the root key into 10 sub-keys.
jax.random.split(_rng, num=10).shape

(10, 2)