In [33]:
import jax
import jax.numpy as jnp
import flax.linen as nn
#!pip install distrax
import distrax
import optax
from typing import Sequence, NamedTuple
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import numpy as np
import gymnasium as gym
from functools import partial

In [34]:
# Configuration
config = {
    # Environment
    "ENV_NAME": "CartPole-v1",
    "NUM_ENVS": 4,

    # Training
    "TOTAL_TIMESTAMPS": 1e6,
    "NUM_STEPS": 128,
    "NUM_MINIBATCHES": 4,
    "LEARNING_RATE": 2.5e-4,

    # PPO Parameters
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.01,
    "VF_COEF": 0.5,

    # Network
    "ACTIVATION": "tanh",
    "SEED": 42,
}

# Derived Configuration
config["NUM_UPDATES"] = int(config["TOTAL_TIMESTAMPS"] // (config["NUM_STEPS"] * config["NUM_ENVS"]))
config["MINIBATCH"] = int((config["NUM_STEPS"] * config["NUM_ENVS"]) // config["NUM_MINIBATCHES"])

In [35]:
class Transition(NamedTuple):
  """Data structure to store environment transitions"""
  done: jnp.ndarray
  action: jnp.ndarray
  value: jnp.ndarray
  reward: jnp.ndarray
  log_prob: jnp.ndarray
  obs: jnp.ndarray
  info: dict

class ActorCritic(nn.Module):
  """Actor Critic Network"""
  action_dim: int
  activation: str = "tanh"

  @nn.compact
  def __call__(self, x):
    """Defines the forward pass of the Actor Critic Network"""
    if self.activation == "relu":
      activation = nn.relu
    else:
      activation = nn.tanh

    # Shared layers
    shared = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
    shared = activation(shared)
    shared = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(shared)
    shared = activation(shared)

    # Actor head - decides what action to take (outputs a distribution probability)
    actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01))(shared)
    pi = distrax.Categorical(logits=actor_mean)

    # Critic head - estimates how good the current state is (outputs a value score)
    value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(shared)
    return pi, jnp.squeeze(value, axis=-1)

In [None]:
def make_train(config):
  """Ensure the training function with the given configuration"""

  # Initialize environment
  def make_env():
    env = gym.make(config["ENV_NAME"])
    return env

  # For now let us create a simple vectorized environment wrapper
  class SimpleVecEnv:
    def __init__(self, env_fns):
      self.envs = [env_fn() for env_fn in env_fns]
      self.num_envs = len(env_fns)
      self.observation_space = self.envs[0].observation_space
      self.action_space = self.envs[0].action_space

    def reset(self, seed=None):
      obs = []
      infos = []
      if seed is None:
          seed_list = [None] * self.num_envs
      else:
          # Assuming seed is a JAX PRNGKey; jax.random.split returns keys of shape (2,)
          # Convert each PRNGKey to a Python int seed by taking the first uint32 word.
          # This yields a deterministic, easily convertible integer suitable for gym.reset(seed=...).
          keys = jax.random.split(seed, self.num_envs)
          seed_list = [int(k[0]) for k in keys]

      for i, env in enumerate(self.envs):
        if seed_list[i] is not None:
          o, info = env.reset(seed=seed_list[i])
        else:
          o, info = env.reset()
        obs.append(o)
        infos.append(info)
      return np.array(obs), infos

    def step(self, actions):
      obs, rewards, dones, truncated, infos = [], [], [], [], []
      # Ensure actions is an iterable of scalars (one per env)
      actions = np.asarray(actions)
      for env, action in zip(self.envs, actions):
        # action may be a 0-d array, a numpy scalar, or a JAX array; coerce to Python int
        a = int(np.array(action).item())
        o, r, d, t, info = env.step(a)
        if d or t:
          # reset the env immediately and use the new observation
          o, _ = env.reset()
        obs.append(o)
        rewards.append(r)
        dones.append(d or t)
        truncated.append(t)
        infos.append(info)

      return np.array(obs), np.array(rewards), np.array(dones), np.array(truncated), infos


  # Create vectorized environment
  env = SimpleVecEnv([make_env for _ in range(config["NUM_ENVS"])])

  # Get environment specs
  obs_shape = env.observation_space.shape
  action_dim = env.action_space.n

  def linear_schedule(count):
    """Linear learning rate schedule"""
    frac = 1.0 - (count / config["TOTAL_TIMESTAMPS"])
    return config["LEARNING_RATE"] * frac


  def train(rng):
    """Main training function"""
    print("Starting training inside train function...")

    # Initialize network
    network = ActorCritic(action_dim, config["ACTIVATION"])

    rng, _rng = jax.random.split(rng)
    init_x = jnp.zeros(obs_shape)
    network_params = network.init(_rng, init_x)

    # Initialize optimizer
    tx = optax.chain(
        optax.clip_by_global_norm(0.5),
        optax.adam(learning_rate=linear_schedule, eps=1e-5),
    )
    train_state = TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
    )

    # Environment step function
    def env_step(runner_state, unused):
      train_state, env_state, last_obs, rng = runner_state

      # Select action
      rng, _rng = jax.random.split(rng)
      pi, value = network.apply(train_state.params, last_obs)
      action = pi.sample(seed=_rng)
      log_prob = pi.log_prob(action)

      # Environment step
      rng, _rng = jax.random.split(rng)

      # For now, we will use Numpy environment
      obs, rewards, dones, truncated, infos = env.step(np.array(action))

      transition = Transition(
          done = jnp.array(dones),
          action=action,
          value=value,
          reward=jnp.array(rewards),
          log_prob=log_prob,
          obs=last_obs,
          info=infos)

      runner_state = (train_state, env_state, jnp.array(obs), rng)
      return runner_state, transition

    def calculate_gae(traj_batch, last_val):
      """Calculate Generalized Advantage Estimation"""
      def get_advantages(gae_and_next_value, transition):
        gae, next_value = gae_and_next_value
        done, value, reward = transition.done, transition.value, transition.reward

        delta = reward + config["GAMMA"] * next_value * (1 - done) - value
        gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1-done) * gae

        return(gae, value), gae

      _, advantages = jax.lax.scan(
          get_advantages,
          (jnp.zeros_like(last_val), last_val),
          traj_batch,
          reverse=True
      )
      return advantages, advantages + traj_batch.value

    def update_echo(update_state, unused):
      """Update function for one epoch"""
      def update_minibatch(train_state, batch_info):
        """Update function for one minibatch"""
        traj_batch, advantages, targets = batch_info

        def loss_fn(params):
          # Forward pass
          pi, value = network.apply(params, traj_batch.obs)
          log_prob = pi.log_prob(traj_batch.action)

          # Calculate policy loss
          ratio = jnp.exp(log_prob - traj_batch.log_prob)
          gae = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

          loss_actor1 = ratio * gae
          loss_actor2 = jnp.clip(ratio, 1 - config["CLIP_EPS"], 1 + config["CLIP_EPS"]) * gae
          loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean()

          # Value loss
          value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
          value_losses = jnp.square(value - targets)
          value_losses_clipped = jnp.square(value_pred_clipped - targets)
          value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

          # Entropy loss
          entropy = pi.entropy().mean()
          total_loss = (
              loss_actor +
              config["VF_COEF"] * value_loss -
              config["ENT_COEF"] * entropy
          )
          return total_loss, (loss_actor, value_loss, entropy)

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        total_loss, grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        return train_state, total_loss

      train_state, traj_batch, advantages, targets, rng = update_state

      rng, _rng = jax.random.split(rng)
      batch_size = config["MINIBATCH"] * config["NUM_MINIBATCHES"]
      assert batch_size == config["NUM_STEPS"] * config["NUM_ENVS"], f"batch_size {batch_size}" != "{config['NUM_STEPS'] * config['NUM_ENVS']}"

      permutation = jax.random.permutation(_rng, batch_size)
      batch = (traj_batch, advantages, targets)
      batch = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (batch_size, ) + x.shape[2:]), batch)

      shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch,
      )
      minibatches = jax.tree_util.tree_map(lambda x: jnp.reshape(x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])), shuffled_batch,
      )

      train_batch, total_loss = jax.lax.scan(update_minibatch, train_state, minibatches)
      update_state = (train_batch, traj_batch, advantages, targets, rng)
      return update_state, total_loss

    # Initialize environment
    rng, _rng = jax.random.split(rng)
    reset_rng = _rng # Use the split key directly for the SimpleVecEnv reset
    obsv, infos = env.reset(reset_rng)
    obsv = jnp.array(obsv)
    env_state = obsv # Assuming env_state is the observation

    # Training loop
    runner_state = (train_state, env_state, obsv, rng)
    print(f"Starting training loop for {config['NUM_UPDATES']} updates...")
    for update in range(config["NUM_UPDATES"]):
      # collect rollout
      runner_state, traj_batch = jax.lax.scan(
          env_step, runner_state, None, config["NUM_STEPS"]
      )

      # Calculate advantages
      train_state, env_state, last_obs, rng = runner_state
      _, last_val = network.apply(train_state.params, last_obs)
      advantages, targets = calculate_gae(traj_batch, last_val)

      # Update network
      rng, _rng = jax.random.split(rng)
      update_state = (train_state, traj_batch, advantages, targets, rng)
      # Assuming NUM_EPOCHS is a config parameter
      update_state, loss_info = jax.lax.scan(update_echo, update_state, None, config["NUM_EPOCHS"])
      train_state = update_state[0]
      runner_state = (train_state, env_state, last_obs, rng)

      # Log progress
      if update % 10 == 0:
        print(f"Update {update}/{config['NUM_UPDATES']}, Loss: {loss_info[0].mean():.4f}")

    return train_state

  return train

The code within the `if __name__ == "__main__":` block is not executed when running a notebook cell directly. Moving the code that initiates the training to a separate cell will ensure it runs and you can see the training output.

In [37]:
rng = jax.random.PRNGKey(config["SEED"])
train_fn = make_train(config)

print("Starting PPO training...")
print(f"Environment: {config['ENV_NAME']}")
print(f"Number of environments: {config['NUM_ENVS']}")
print(f"Total updates: {config ['NUM_UPDATES']}")
print(f"Number of steps: {config['NUM_STEPS']}")
print(f"Number of minibatches: {config['NUM_MINIBATCHES']}")
print(f"Learning rate: {config['LEARNING_RATE']}")

final_train_state = train_fn(rng)
print("Training complete!")

Starting PPO training...
Environment: CartPole-v1
Number of environments: 4
Total updates: 1953
Number of steps: 128
Number of minibatches: 4
Learning rate: 0.00025
Starting training inside train function...


TypeError: Only scalar arrays can be converted to Python scalars; got arr.ndim=1