In [None]:
# Requirements fo Kaggle
# install all missing packages
!pip install minatar
!pip install dm-haiku
!pip install distrax
!pip install pgx
!pip install omegaconf
!pip install learn2learn

Collecting minatar
  Downloading MinAtar-1.0.15-py3-none-any.whl.metadata (685 bytes)
Downloading MinAtar-1.0.15-py3-none-any.whl (16 kB)
Installing collected packages: minatar
Successfully installed minatar-1.0.15
Collecting dm-haiku
  Downloading dm_haiku-0.0.13-py3-none-any.whl.metadata (19 kB)
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl.metadata (8.9 kB)
Downloading dm_haiku-0.0.13-py3-none-any.whl (373 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m373.9/373.9 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25hDownloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.13 jmp-0.0.4
Collecting distrax
  Downloading distrax-0.1.5-py3-none-any.whl.metadata (13 kB)
Downloading distrax-0.1.5-py3-none-any.whl (319 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.7/319.7 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m00:01[0m
[

In [None]:

import sys
import jax
import jax.numpy as jnp
import haiku as hk
import optax
from typing import NamedTuple, Literal
import distrax
import pgx
from pgx.experimental import auto_reset
import time
import pickle
from omegaconf import OmegaConf
from pydantic import BaseModel
import wandb
import learn2learn as l2l

In [None]:
# BASE CONFIG
class PPOConfig(BaseModel):
    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-space_invaders"  
    seed: int = 0
    lr: float = 5e-3 
    num_envs: int = 64 
    num_eval_envs: int = 100
    num_steps: int = 128 
    total_timesteps: int = int(1e7)  
    update_epochs: int = 4 
    minibatch_size: int = 4096  
    gamma: float = 0.99 
    gae_lambda: float = 0.95
    clip_eps: float = 0.2
    ent_coef: float = 0.01  
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5  
    wandb_entity: str = "nonarruginitocalamarodiferro-usi"
    wandb_project: str = "pgx-minatar-ppo" 
    save_model: bool = False

    class Config:
        extra = "forbid"


# OUR iperparameters just to test
# We are in a notebook, so we can't use argparse
args = PPOConfig(
    env_name="minatar-space_invaders", 
    seed=0,
    lr=1e-3,
    num_envs=64,
    num_eval_envs=100,
    num_steps=128,
    total_timesteps=int(1e7),
    update_epochs=4,
    minibatch_size=4096,  
    gamma=0.99,
    gae_lambda=0.90,
    clip_eps=0.25,
    ent_coef=0.005, 
    vf_coef=0.5,
    max_grad_norm=0.5, 
    wandb_entity="nonarruginitocalamarodiferro-usi", 
    wandb_project="pgx-minatar-ppo",  
    save_model=False 
)

print(args) 

# THe rest is prette

env = pgx.make(str(args.env_name))

num_updates = args.total_timesteps // args.num_envs // args.num_steps
num_minibatches = args.num_envs * args.num_steps // args.minibatch_size




class ActorCritic(hk.Module):
    def __init__(self, num_actions, activation="tanh"):
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation
        assert activation in ["relu", "tanh"]

    def __call__(self, x):
        x = x.astype(jnp.float32)
        if self.activation == "relu":
            activation = jax.nn.relu
        else:
            activation = jax.nn.tanh
        x = hk.Conv2D(32, kernel_shape=2)(x)
        x = jax.nn.relu(x)
        x = hk.avg_pool(x, window_shape=(2, 2),
                        strides=(2, 2), padding="VALID")
        x = x.reshape((x.shape[0], -1))
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        actor_mean = hk.Linear(64)(x)
        actor_mean = activation(actor_mean)
        actor_mean = hk.Linear(64)(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = hk.Linear(self.num_actions)(actor_mean)

        critic = hk.Linear(64)(x)
        critic = activation(critic)
        critic = hk.Linear(64)(critic)
        critic = activation(critic)
        critic = hk.Linear(1)(critic)

        return actor_mean, jnp.squeeze(critic, axis=-1)


def forward_fn(x, is_eval=False):
    net = ActorCritic(env.num_actions, activation="tanh")
    logits, value = net(x)
    return logits, value


forward = hk.without_apply_rng(hk.transform(forward_fn))


optimizer = optax.chain(optax.clip_by_global_norm(
    args.max_grad_norm), optax.adam(args.lr, eps=1e-5))


class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray


def make_update_fn():
    # TRAIN LOOP
    def _update_step(runner_state):
        # COLLECT TRAJECTORIES
        step_fn = jax.vmap(auto_reset(env.step, env.init))

        def _env_step(runner_state, unused):
            params, opt_state, env_state, last_obs, rng = runner_state
            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            logits, value = forward.apply(params, last_obs)
            pi = distrax.Categorical(logits=logits)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            keys = jax.random.split(_rng, env_state.observation.shape[0])
            env_state = step_fn(env_state, action, keys)
            transition = Transition(
                env_state.terminated,
                action,
                value,
                jnp.squeeze(env_state.rewards),
                log_prob,
                last_obs
            )
            runner_state = (params, opt_state, env_state,
                            env_state.observation, rng)
            return runner_state, transition

        runner_state, traj_batch = jax.lax.scan(
            _env_step, runner_state, None, args.num_steps
        )

        # CALCULATE ADVANTAGE
        params, opt_state, env_state, last_obs, rng = runner_state
        _, last_val = forward.apply(params, last_obs)

        def _calculate_gae(traj_batch, last_val):
            def _get_advantages(gae_and_next_value, transition):
                gae, next_value = gae_and_next_value
                done, value, reward = (
                    transition.done,
                    transition.value,
                    transition.reward,
                )
                delta = reward + args.gamma * next_value * (1 - done) - value
                gae = (
                    delta
                    + args.gamma * args.gae_lambda * (1 - done) * gae
                )
                return (gae, value), gae

            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            def _update_minbatch(tup, batch_info):
                params, opt_state = tup
                traj_batch, advantages, targets = batch_info

                def _loss_fn(params, traj_batch, gae, targets):
                    # RERUN NETWORK
                    logits, value = forward.apply(params, traj_batch.obs)
                    pi = distrax.Categorical(logits=logits)
                    log_prob = pi.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS
                    value_pred_clipped = traj_batch.value + (
                        value - traj_batch.value
                    ).clip(-args.clip_eps, args.clip_eps)
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(
                        value_pred_clipped - targets)
                    value_loss = (
                        0.5 * jnp.maximum(value_losses,
                                          value_losses_clipped).mean()
                    )

                    # CALCULATE ACTOR LOSS
                    ratio = jnp.exp(log_prob - traj_batch.log_prob)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - args.clip_eps,
                            1.0 + args.clip_eps,
                        )
                        * gae
                    )
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                    loss_actor = loss_actor.mean()
                    entropy = pi.entropy().mean()

                    total_loss = (
                        loss_actor
                        + args.vf_coef * value_loss
                        - args.ent_coef * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                total_loss, grads = grad_fn(
                    params, traj_batch, advantages, targets)
                updates, opt_state = optimizer.update(grads, opt_state)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), total_loss

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            rng, _rng = jax.random.split(rng)
            batch_size = args.minibatch_size * num_minibatches
            assert (
                batch_size == args.num_steps * args.num_envs
            ), "batch size must be equal to number of steps * number of envs"
            permutation = jax.random.permutation(_rng, batch_size)
            batch = (traj_batch, advantages, targets)
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
            )
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0), batch
            )
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(
                    x, [num_minibatches, -1] + list(x.shape[1:])
                ),
                shuffled_batch,
            )
            (params, opt_state),  total_loss = jax.lax.scan(
                _update_minbatch, (params, opt_state), minibatches
            )
            update_state = (params, opt_state, traj_batch,
                            advantages, targets, rng)
            return update_state, total_loss

        update_state = (params, opt_state, traj_batch,
                        advantages, targets, rng)
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, args.update_epochs
        )
        params, opt_state, _, _, _, rng = update_state

        runner_state = (params, opt_state, env_state, last_obs, rng)
        return runner_state, loss_info
    return _update_step


@jax.jit
def evaluate(params, rng_key):
    step_fn = jax.vmap(env.step)
    rng_key, sub_key = jax.random.split(rng_key)
    subkeys = jax.random.split(sub_key, args.num_eval_envs)
    state = jax.vmap(env.init)(subkeys)
    R = jnp.zeros_like(state.rewards)

    def cond_fn(tup):
        state, _, _ = tup
        return ~state.terminated.all()

    def loop_fn(tup):
        state, R, rng_key = tup
        logits, value = forward.apply(params, state.observation)
        # action = logits.argmax(axis=-1)
        pi = distrax.Categorical(logits=logits)
        rng_key, _rng = jax.random.split(rng_key)
        action = pi.sample(seed=_rng)
        rng_key, _rng = jax.random.split(rng_key)
        keys = jax.random.split(_rng, state.observation.shape[0])
        state = step_fn(state, action, keys)
        return state, R + state.rewards, rng_key
    state, R, _ = jax.lax.while_loop(cond_fn, loop_fn, (state, R, rng_key))
    return R.mean()


def train(rng):
    tt = 0
    st = time.time()
    # INIT NETWORK
    rng, _rng = jax.random.split(rng)
    init_x = jnp.zeros((1, ) + env.observation_shape)
    params = forward.init(_rng, init_x)
    opt_state = optimizer.init(params=params)

    # INIT UPDATE FUNCTION
    _update_step = make_update_fn()
    jitted_update_step = jax.jit(_update_step)

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, args.num_envs)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)

    rng, _rng = jax.random.split(rng)
    runner_state = (params, opt_state, env_state, env_state.observation, _rng)

    # warm up
    _, _ = jitted_update_step(runner_state)

    steps = 0

    # initial evaluation
    et = time.time()  # exclude evaluation time
    tt += et - st
    rng, _rng = jax.random.split(rng)
    eval_R = evaluate(runner_state[0], _rng)
    log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
    print(log)
    wandb.log(log)
    st = time.time()

    for i in range(num_updates):
        runner_state, loss_info = jitted_update_step(runner_state)
        steps += args.num_envs * args.num_steps

        # evaluation
        et = time.time()  # exclude evaluation time
        tt += et - st
        rng, _rng = jax.random.split(rng)
        eval_R = evaluate(runner_state[0], _rng)
        log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
        print(log)
        wandb.log(log)
        st = time.time()

    return runner_state

if __name__ == "__main__":
    wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args.dict())
    rng = jax.random.PRNGKey(args.seed)
    out = train(rng)
    if args.save_model:
        with open(f"{args.env_name}-seed={args.seed}.ckpt", "wb") as f:
            pickle.dump(out[0], f)

env_name='minatar-space_invaders' seed=0 lr=0.001 num_envs=64 num_eval_envs=100 num_steps=128 total_timesteps=10000000 update_epochs=4 minibatch_size=4096 gamma=0.99 gae_lambda=0.9 clip_eps=0.25 ent_coef=0.005 vf_coef=0.5 max_grad_norm=0.5 wandb_entity='nonarruginitocalamarodiferro-usi' wandb_project='pgx-minatar-ppo' save_model=False


VBox(children=(Label(value='0.065 MB of 0.065 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
minatar-space_invaders/eval_R,▁▁▁▂▂▃▃▃▃▄▄▄▅▄▄▅▅▅▅▅▆▆▅▅▅▆▆▆▆▆▇▇▇▇▅▇█▇▇▇
sec,▁▁▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇███
steps,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▇████

0,1
minatar-space_invaders/eval_R,55.08
sec,39.35105
steps,4218880.0


{'sec': 4.752812385559082, 'minatar-space_invaders/eval_R': 4.349999904632568, 'steps': 0}
{'sec': 4.844480991363525, 'minatar-space_invaders/eval_R': 4.46999979019165, 'steps': 8192}
{'sec': 4.916128396987915, 'minatar-space_invaders/eval_R': 5.329999923706055, 'steps': 16384}
{'sec': 4.9766552448272705, 'minatar-space_invaders/eval_R': 5.039999961853027, 'steps': 24576}
{'sec': 5.037154197692871, 'minatar-space_invaders/eval_R': 4.909999847412109, 'steps': 32768}
{'sec': 5.096987247467041, 'minatar-space_invaders/eval_R': 5.5, 'steps': 40960}
{'sec': 5.1582982540130615, 'minatar-space_invaders/eval_R': 5.869999885559082, 'steps': 49152}
{'sec': 5.220142841339111, 'minatar-space_invaders/eval_R': 6.119999885559082, 'steps': 57344}
{'sec': 5.281617164611816, 'minatar-space_invaders/eval_R': 7.170000076293945, 'steps': 65536}
{'sec': 5.341115474700928, 'minatar-space_invaders/eval_R': 6.440000057220459, 'steps': 73728}
{'sec': 5.402669429779053, 'minatar-space_invaders/eval_R': 7.190000

In [18]:
print(jax.devices())


x = jnp.ones((1000, 1000))
print(f"Array x is on device: {x.devices()}")

[cuda(id=0), cuda(id=1)]
Array x is on device: {cuda(id=0)}


In [15]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk
import distrax
import pgx
import wandb
import time
import pickle

# Configuration class for LPO
class LPOConfig:
    def __init__(self,
                 env_name: str = "minatar-space_invaders",
                 seed: int = 0,
                 lr: float = 5e-3,
                 num_envs: int = 64,
                 num_eval_envs: int = 100,
                 num_steps: int = 128,
                 total_timesteps: int = int(1e7),
                 update_epochs: int = 4,
                 minibatch_size: int = 4096,
                 gamma: float = 0.99,
                 gae_lambda: float = 0.95,
                 clip_eps: float = 0.2,
                 ent_coef: float = 0.01,
                 vf_coef: float = 0.5,
                 max_grad_norm: float = 0.5,
                 wandb_entity: str = "nonarruginitocalamarodiferro-usi",
                 wandb_project: str = "pgx-minatar-ppo",
                 save_model: bool = False):
        self.env_name = env_name
        self.seed = seed
        self.lr = lr
        self.num_envs = num_envs
        self.num_eval_envs = num_eval_envs
        self.num_steps = num_steps
        self.total_timesteps = total_timesteps
        self.update_epochs = update_epochs
        self.minibatch_size = minibatch_size
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.wandb_entity = wandb_entity
        self.wandb_project = wandb_project
        self.save_model = save_model

args = LPOConfig(
    env_name="minatar-space_invaders",
    seed=0,
    lr=5e-3,
    num_envs=64,
    num_eval_envs=100,
    num_steps=128,
    total_timesteps=int(1e7),
    update_epochs=4,
    minibatch_size=4096,
    gamma=0.99,
    gae_lambda=0.95,
    clip_eps=0.2,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    wandb_entity="nonarruginitocalamarodiferro-usi",
    wandb_project="pgx-minatar-ppo",
    save_model=False
)

# Initialize environment
env = pgx.make(args.env_name)
num_updates = args.total_timesteps // args.num_envs // args.num_steps
num_minibatches = args.num_envs * args.num_steps // args.minibatch_size

# Actor-Critic Model
class ActorCritic(hk.Module):
    def __init__(self, num_actions, activation="tanh"):
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation
        assert activation in ["relu", "tanh"]

    def __call__(self, x):
        x = x.astype(jnp.float32)
        activation_fn = jax.nn.relu if self.activation == "relu" else jax.nn.tanh
        x = hk.Conv2D(32, kernel_shape=2)(x)
        x = jax.nn.relu(x)
        x = hk.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
        x = x.reshape((x.shape[0], -1))  # Flatten
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        actor_mean = hk.Linear(64)(x)
        actor_mean = activation_fn(actor_mean)
        actor_mean = hk.Linear(env.num_actions)(actor_mean)

        critic = hk.Linear(64)(x)
        critic = activation_fn(critic)
        critic = hk.Linear(1)(critic)

        return actor_mean, jnp.squeeze(critic, axis=-1)

def forward_fn(x):
    net = ActorCritic(env.num_actions)
    return net(x)

forward = hk.without_apply_rng(hk.transform(forward_fn))

# Optimizer
optimizer = optax.chain(
    optax.clip_by_global_norm(args.max_grad_norm),
    optax.adam(args.lr, eps=1e-5)
)

# Evaluation function
def evaluate(params, rng_key, env):
    eval_env = pgx.make(args.env_name)
    rng_key, sub_key = jax.random.split(rng_key)
    subkeys = jax.random.split(sub_key, args.num_eval_envs)
    state = jax.vmap(eval_env.init)(subkeys)
    R = jnp.zeros((args.num_eval_envs,))  # Rewards
    steps = 0  # Step counter

    def cond_fn(carry):
        state, R, steps, rng_key = carry
        return (~state.terminated).any() & (steps < 1000)  # Stop if all terminated or max steps

    def body_fn(carry):
        state, R, steps, rng_key = carry
        obs = state.observation
        logits, _ = forward.apply(params, obs)
        pi = distrax.Categorical(logits=logits)

        rng_key, sub_key = jax.random.split(rng_key)
        actions = pi.sample(seed=sub_key)

        rng_key, step_key = jax.random.split(rng_key)
        keys = jax.random.split(step_key, state.observation.shape[0])
        state = jax.vmap(eval_env.step)(state, actions, keys)

        # Squeeze state.rewards to remove extra dimensions
        R += state.rewards.squeeze()  # Ensure shape matches (num_eval_envs,)
        steps += 1
        return state, R, steps, rng_key

    # Include rng_key in the carry tuple
    carry = (state, R, steps, rng_key)
    state, R, steps, rng_key = jax.lax.while_loop(cond_fn, body_fn, carry)
    return R.mean()


# Training loop
def train(rng, env):
    rng, init_rng = jax.random.split(rng)
    dummy_input = jnp.zeros((1,) + env.observation_shape)
    params = forward.init(init_rng, dummy_input)
    opt_state = optimizer.init(params)

    # Initialize environment states
    rng, env_rng = jax.random.split(rng)
    env_keys = jax.random.split(env_rng, args.num_envs)
    env_states = jax.vmap(env.init)(env_keys)

    total_steps = 0
    for update in range(num_updates):
        trajectories = []
        for step in range(args.num_steps):
            # Sample actions
            obs = env_states.observation
            logits, values = forward.apply(params, obs)
            pi = distrax.Categorical(logits=logits)
            rng, action_rng = jax.random.split(rng)
            actions = pi.sample(seed=action_rng)
            log_probs = pi.log_prob(actions)

            # Step environment with a new PRNG key
            rng, step_rng = jax.random.split(rng)
            env_keys = jax.random.split(step_rng, args.num_envs)
            env_states = jax.vmap(env.step)(env_states, actions, env_keys)

            rewards, dones = env_states.rewards, env_states.terminated

            trajectories.append({
                'obs': obs,
                'actions': actions,
                'log_probs': log_probs,
                'values': values,
                'rewards': rewards,
                'dones': dones
            })

            total_steps += args.num_envs

        # Convert trajectories to arrays
        traj_batch = {k: jnp.array([traj[k] for traj in trajectories]) for k in trajectories[0]}

           # Compute advantages and targets
        last_obs = env_states.observation
        _, last_values = forward.apply(params, last_obs)
        
        # Initialize advantages with the same shape as rewards: (num_steps, num_envs)
        advantages = jnp.zeros((args.num_steps, args.num_envs))  # Shape: (128, 64)
        
        # Initialize GAE with shape matching the number of environments: (num_envs,)
        gae = jnp.zeros((args.num_envs,))  # Shape: (64,)
        
        for t in reversed(range(args.num_steps)):
            # Compute delta, ensuring the shape is (num_envs,)
            delta = (
                traj_batch['rewards'][t].squeeze()  # Shape: (64,)
                + args.gamma * (1 - traj_batch['dones'][t].squeeze()) * last_values  # Shape: (64,)
                - traj_batch['values'][t].squeeze()  # Shape: (64,)
            )
        
            # Update GAE, ensuring consistent shape
            gae = delta + args.gamma * args.gae_lambda * (1 - traj_batch['dones'][t].squeeze()) * gae
        
            # Assign to advantages
            advantages = advantages.at[t].set(gae)  # Shape: (128, 64)
        
            # Update last_values for the next step
            last_values = traj_batch['values'][t].squeeze()  # Shape: (64,)
        
        # Compute returns as advantages + values
        returns = advantages + traj_batch['values']

        # Flatten the batch
        batch = {
            'obs': traj_batch['obs'].reshape(-1, *env.observation_shape),
            'actions': traj_batch['actions'].reshape(-1),
            'log_probs': traj_batch['log_probs'].reshape(-1),
            'advantages': advantages.reshape(-1),
            'returns': returns.reshape(-1),
        }

        # Normalize advantages
        batch['advantages'] = (batch['advantages'] - batch['advantages'].mean()) / (batch['advantages'].std() + 1e-8)

        # Update policy
        for epoch in range(args.update_epochs):
            idxs = jax.random.permutation(rng, len(batch['advantages']))
            for start in range(0, len(batch['advantages']), args.minibatch_size):
                end = start + args.minibatch_size
                mb_idxs = idxs[start:end]
                mb = {k: v[mb_idxs] for k, v in batch.items()}

                def loss_fn(params):
                    logits, values = forward.apply(params, mb['obs'])
                    pi = distrax.Categorical(logits=logits)
                    log_probs = pi.log_prob(mb['actions'])

                    ratio = jnp.exp(log_probs - mb['log_probs'])
                    surr1 = ratio * mb['advantages']
                    surr2 = jnp.clip(ratio, 1 - args.clip_eps, 1 + args.clip_eps) * mb['advantages']
                    actor_loss = -jnp.minimum(surr1, surr2).mean()

                    value_loss = ((values - mb['returns']) ** 2).mean()
                    entropy_bonus = pi.entropy().mean()

                    total_loss = actor_loss + args.vf_coef * value_loss - args.ent_coef * entropy_bonus
                    return total_loss

                grads = jax.grad(loss_fn)(params)
                updates, opt_state = optimizer.update(grads, opt_state, params)
                params = optax.apply_updates(params, updates)

        # Logging and evaluation
        if update % 10 == 0 or update == num_updates - 1:
            rng, eval_rng = jax.random.split(rng)
            eval_reward = evaluate(params, eval_rng, env)
            wandb.log({
                'update': update,
                'total_steps': total_steps,
                'eval_reward': eval_reward,
            })
            print(f"Update {update}, Total Steps: {total_steps}, Eval Reward: {eval_reward}")

    return params


if __name__ == "__main__":
    wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args))
    rng = jax.random.PRNGKey(args.seed)
    trained_params = train(rng, env)

    if args.save_model:
        with open(f"{args.env_name}-model.pkl", "wb") as f:
            pickle.dump(trained_params, f)


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

Update 0, Total Steps: 8192, Eval Reward: 4.25


KeyboardInterrupt: 