In [45]:
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import equinox as eqx
import jax
import jax.numpy as jnp
from typing import Optional, Tuple
from jaxtyping import PRNGKeyArray, Array, Float, PyTree
import numpy as np
import gymnax
import rlax
import optax
from tqdm import tqdm

In [46]:
@eqx.filter_jit
def calculate_gae(
    rewards: Array,
    values: Array,
    dones: Array,
    gamma: float,
    lam: float,
) -> Array:
    def body_fun(
        carry: tuple[Array, Array], t: Array
    ) -> tuple[tuple[Array, Array], None]:
        advantages, gae_inner = carry
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
        gae_inner = delta + gamma * lam * (1 - dones[t]) * gae_inner
        advantages = advantages.at[t].set(gae_inner)
        return (advantages, gae_inner), None

    values = jnp.append(values, values[0])
    advt = jnp.zeros_like(rewards)
    gae = jnp.array(0.0)
    t = len(rewards)

    (advt, _), _ = jax.lax.scan(body_fun, (advt, gae), jnp.arange(t - 1, -1, -1))
    return advt

In [47]:
def rollout(rng_input, policy, env, env_params, steps_in_episode, epoch):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)
    
    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        logits = policy(obs)
        action = jax.random.categorical(logits=logits, key=rng_net)
        log_prob = jax.nn.log_softmax(logits)[action]
        next_obs, next_state, reward, done, _ = env.step(
          rng_step, state, action, env_params
        )
        carry = [next_obs, next_state, rng]
        return carry, [obs, action, log_prob, reward, done, state]

    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
      policy_step,
      [obs, state, rng_episode],
      (),
      steps_in_episode
    )
    # Return masked sum of rewards accumulated by agent in episode
    obs, action, log_prob, reward, done, states = scan_out
    return obs, action, log_prob, reward, done, states

In [48]:
class RLDataset(Dataset):
    def __init__(self, states, actions, rewards, log_probs, values, advantages, dones) -> None:
        self.rewards = torch.tensor(rewards)
        self.actions = torch.tensor(actions)
        self.obs = torch.tensor(states)
        self.dones = torch.tensor(dones)
        self.log_probs = torch.tensor(log_probs)
        self.values = torch.tensor(values)
        self.advantages = torch.tensor(advantages)

    def __len__(self) -> int:
        return len(self.rewards)

    def __getitem__(
        self, idx
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor, torch.Tensor, 
               torch.Tensor]:
        return (
            self.obs[idx],
            self.actions[idx],
            self.rewards[idx],
            self.log_probs[idx],
            self.values[idx],
            self.advantages[idx],
            self.dones[idx],
        )

In [49]:
class Actor(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, in_size: int,
                 out_size: int,
                 width_size: int,
                 depth: int,
                 *,
                 key: PRNGKeyArray):
        self.mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=out_size,
            width_size=width_size,
            depth=depth,
            key=key
        )
    
    def __call__(self, x: Array, key: Optional[PRNGKeyArray] = None):
        return self.mlp(x)

class Critic(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, in_size: int,
                 width_size: int,
                 depth: int,
                 *,
                 key: PRNGKeyArray):
        self.mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=1,
            width_size=width_size,
            depth=depth,
            key=key
        )
    
    def __call__(self, x: Array, key: Optional[PRNGKeyArray] = None):
        return self.mlp(x)
        

In [50]:
@eqx.filter_jit
def get_value(obs: Float[Array, "n_dims"], critic: PyTree) -> Array:
    return critic(obs)

In [51]:
jit_rollout = eqx.filter_jit(rollout)

In [60]:
@eqx.filter_jit
def update_ppo(actor: PyTree,
               actor_optimizer: optax.GradientTransformation,
               actor_opt_state: optax.OptState,
               critic: PyTree,
               critic_optimizer: optax.GradientTransformation,
               critic_opt_state: optax.OptState,
               batch: Tuple[Array],
               policy_clip: float = 0.2,
              ):
    print("JIT")
    obs, actions, rewards, old_log_probs, values, advantages, dones = batch
    def ppo_objective(_actor, _obs, _old_log_probs, _advantages):
        new_logits = eqx.filter_vmap(actor)(obs)
        log_probs = jax.nn.log_softmax(new_logits)
        new_log_probs = jnp.take_along_axis(
            log_probs, jnp.expand_dims(actions, -1), axis=1
        )
        prob_ratio = jnp.exp(new_log_probs) / jnp.exp(old_log_probs)

        weighted_probs = advantages * prob_ratio
        weighted_clipped_probs = jnp.clip(prob_ratio, 
                                          1.0 - policy_clip, 1.0 + policy_clip) * advantages
        clipped_objective = jnp.fmin(weighted_probs, weighted_clipped_probs)
        return -clipped_objective.mean()
    
    def actor_step(_actor, _obs, _old_log_probs, _advantages, _actor_optimizer, 
                   _actor_opt_state):
        grad = eqx.filter_grad(ppo_objective)(_actor, _obs, _old_log_probs, _advantages)
        updates, _actor_opt_state = _actor_optimizer.update(grad, _actor_opt_state, _actor)
        _actor = eqx.apply_updates(_actor, updates)
        
        return _actor, _actor_opt_state
    
    def critic_loss(_critic, _obs, _advantages, _values):
        critic_values = eqx.filter_vmap(_critic)(obs)
        returns = _advantages + _values
        loss = returns - critic_values
        return loss.mean()
    
    def critic_step(_critic, _obs, _advantages, _values, 
                    _critic_optimizer, _critic_opt_state):
        grad = eqx.filter_grad(critic_loss)(_critic, _obs, _advantages, _values)
        updates, _critic_opt_state = _critic_optimizer.update(grad, _critic_opt_state, _critic)
        _critic = eqx.apply_updates(_critic, updates)
        
        return _critic, _critic_opt_state
    
    actor, actor_opt_state = actor_step(actor, obs, old_log_probs, advantages,
                                        actor_optimizer, actor_opt_state)
    
    critic, critic_opt_state = critic_step(critic, obs, advantages, values, critic_optimizer,
                                           critic_opt_state)
    
    return actor, actor_opt_state, critic, critic_opt_state

In [68]:
n_episodes = 1000
n_epochs = 10
policy_clip = 0.2
batch_size = 64
rng = jax.random.PRNGKey(0)
rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)

# Create the Pendulum-v1 environment
env, env_params = gymnax.make("CartPole-v1")
jit_rollout = eqx.filter_jit(rollout)

In [69]:
key, actor_key, critic_key = jax.random.split(jax.random.PRNGKey(33), 3)
actor = Actor(in_size=env.observation_space(env_params).shape[0], 
             out_size=env.action_space(env_params).n,
             width_size=32,
             depth=3,
             key=actor_key)
critic = Critic(in_size=env.observation_space(env_params).shape[0], 
             width_size=32,
             depth=3,
             key=critic_key)
actor_optimizer = optax.adam(learning_rate=0.001)
actor_opt_state = actor_optimizer.init(eqx.filter(actor, eqx.is_inexact_array))

critic_optimizer = optax.adam(learning_rate=0.001)
critic_opt_state = critic_optimizer.init(eqx.filter(critic, eqx.is_inexact_array))

In [70]:
rng = jax.random.PRNGKey(33)
all_rewards = []
for eps in tqdm(range(n_episodes)):
    rng, subkey = jax.random.split(rng)
    obs, actions, log_probs, rewards, dones, states = jit_rollout(subkey, actor, env, env_params, 200, 0)
    
    #true_indices = jnp.where(dones)[0]
    #distances = jnp.diff(true_indices, prepend=-1)

    # Calculate average distance
    #avg_distance = jnp.mean(distances)
    #all_rewards.append(avg_distance)
    
    all_rewards.append(jnp.sum(dones))
    
    values = jax.vmap(get_value, in_axes=(0, None))(obs, critic)
    advantages = calculate_gae(
        rewards = rewards,
        values = values,
        dones = dones,
        gamma = 0.99,
        lam = 0.95
    )
    dataset = RLDataset(
        rewards=np.array(rewards),
        states=np.array(obs),
        actions=np.array(actions),
        log_probs=np.array(log_probs),
        dones=np.array(dones),
        values=np.array(values),
        advantages=np.array(advantages)
    )
    dataloader = DataLoader(batch_size=batch_size,
                        shuffle=True, drop_last=True,
                        dataset=dataset)
    for epoch in range(n_epochs):
        for batch in dataloader:
            obs, actions, rewards, old_log_probs, values, advantages, dones = batch
            b = (
                jnp.array(obs.numpy()),
                jnp.array(actions.numpy()), 
                jnp.array(rewards.numpy()),
                jnp.array(old_log_probs.numpy()),
                jnp.array(values.numpy()),
                jnp.array(advantages.numpy()), 
                jnp.array(dones.numpy())
            )
            actor, actor_opt_state, critic, critic_opt_state = update_ppo(actor, actor_optimizer, actor_opt_state, 
               critic, critic_optimizer, critic_opt_state, b,
                                                                         policy_clip=0.2)

print(jnp.array(all_rewards))

  0%|                                                                  | 1/1000 [00:00<08:34,  1.94it/s]

JIT


100%|███████████████████████████████████████████████████████████████| 1000/1000 [00:46<00:00, 21.47it/s]


[10 10  7 10  9 11  7 10  7 12  9  7 10  9 10  6 11  9  7  9  4  9 10  7
  9  9  8  9  6  7 10 10  8  8  8  8 10 11  9 10  9 11  6  6 10 12 11  8
  7  7 11 10  7  8  8  7  8 10  6  8 10 10  9  9  9  8  7 10  9 10  9  7
  6  6  7  6 10  8  9 10 10  9  6  5 12  8  7  8  6  9  9 12 10 12  9  7
 11  9  8 10  8 10  9  8  8  7  6  7  6  9  5  6  6 11 12  7 10 12  6  7
  9 10  6  9  7 11  8 11  7  5 10  7  9 10  8 10 11  9 10 10  6  6  5 11
  8  8  9 11  7 10  8 11  7  7  5  9  8  7 10  8  9 11  8 11  6 10  5  9
  8  8  8 10  6 10  9  8  9 10  6  8 10 10 10  8 11  7  8  8  8 13  7 10
 10  7  9  6  8  7 10  7 10  7  7  9  8 10 10  9 10  8 10  7  8 11 10 10
  9  6  9  8 10  5 11 10  9 11  8  7  8  8  6  9  9 10 10  9  9  7  8  6
 10  8 10  9  7  7  8  8  9 10  8  9  9  7  7 10  8  8 10  9  9 11 11  8
 10  8  9  8  9  6  8  8 11  9  7  8  8 11  6  8 12 10  7  7  7  6  9  7
  9  8 10 10  8  9  7  7 12  8  9  9 11 12 13 10  9  9  7  7  7  7  8  8
  8  9  9  9  9  8 10  7 11 10  9  5  9 11  8  7  9