In [1]:
!pip install minatar
!pip install dm-haiku
!pip install distrax
!pip install pgx
!pip install omegaconf
!pip install learn2learn

Collecting minatar
  Downloading MinAtar-1.0.15-py3-none-any.whl.metadata (685 bytes)
Downloading MinAtar-1.0.15-py3-none-any.whl (16 kB)
Installing collected packages: minatar
Successfully installed minatar-1.0.15
Collecting dm-haiku
  Downloading dm_haiku-0.0.13-py3-none-any.whl.metadata (19 kB)
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl.metadata (8.9 kB)
Downloading dm_haiku-0.0.13-py3-none-any.whl (373 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m373.9/373.9 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.13 jmp-0.0.4
Collecting distrax
  Downloading distrax-0.1.5-py3-none-any.whl.metadata (13 kB)
Downloading distrax-0.1.5-py3-none-any.whl (319 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.7/319.7 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:0

In [2]:

import sys
import jax
import jax.numpy as jnp
import haiku as hk
import optax
from typing import NamedTuple, Literal
import distrax
import pgx
from pgx.experimental import auto_reset
import time

import pickle
from omegaconf import OmegaConf
from pydantic import BaseModel
import wandb
import learn2learn as l2l
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax



In [18]:
class PPOConfig(BaseModel):
    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-space_invaders"  # Focus on space_invaders
    seed: int = 0
    lr: float = 5e-3  # Updated Learning Rate
    num_envs: int = 64  # Updated number of environments
    num_eval_envs: int = 100
    num_steps: int = 128  # Unroll length
    total_timesteps: int = int(1e7)  # Updated number of timesteps
    update_epochs: int = 4  # Updated number of update epochs
    minibatch_size: int = 4096  # This depends on your choice of minibatches
    gamma: float = 0.99  # Updated gamma
    gae_lambda: float = 0.95
    clip_eps: float = 0.2
    ent_coef: float = 0.01  # Updated entropy coefficient
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5  # Updated max grad norm
    wandb_entity: str = "nonarruginitocalamarodiferro-usi"  # Your wandb entity
    wandb_project: str = "pgx-minatar-ppo"  # Your wandb project
    save_model: bool = False

    class Config:
        extra = "forbid"


# Initialize the PPOConfig manually with these hyperparameters
args = PPOConfig(
    env_name="minatar-space_invaders",  # Can be adjusted to other MinAtar environments if needed
    seed=0,
    lr=1e-3,
    num_envs=64,
    num_eval_envs=100,
    num_steps=128,
    total_timesteps=int(1e7),  # 10 million timesteps
    update_epochs=4,
    minibatch_size=4096,  # Adjusted as per the batch size and minibatches
    gamma=0.99,
    gae_lambda=0.90,
    clip_eps=0.25,
    ent_coef=0.005,  # Entropy coefficient set to 0.01
    vf_coef=0.5,
    max_grad_norm=0.5,  # Max grad norm set to 0.5
    wandb_entity="nonarruginitocalamarodiferro-usi",  # wandb entity
    wandb_project="pgx-minatar-ppo",  # wandb project
    save_model=False  # Set to True if you want to save the model after training
)

print(args)  # To verify your updated configuration

# Initialize environment
env = pgx.make(str(args.env_name))

num_updates = args.total_timesteps // args.num_envs // args.num_steps
num_minibatches = args.num_envs * args.num_steps // args.minibatch_size




class ActorCritic(hk.Module):
    def __init__(self, num_actions, activation="tanh"):
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation
        assert activation in ["relu", "tanh"]

    def __call__(self, x):
        x = x.astype(jnp.float32)
        if self.activation == "relu":
            activation = jax.nn.relu
        else:
            activation = jax.nn.tanh
        x = hk.Conv2D(32, kernel_shape=2)(x)
        x = jax.nn.relu(x)
        x = hk.avg_pool(x, window_shape=(2, 2),
                        strides=(2, 2), padding="VALID")
        x = x.reshape((x.shape[0], -1))  # flatten
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        actor_mean = hk.Linear(64)(x)
        actor_mean = activation(actor_mean)
        actor_mean = hk.Linear(64)(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = hk.Linear(self.num_actions)(actor_mean)

        critic = hk.Linear(64)(x)
        critic = activation(critic)
        critic = hk.Linear(64)(critic)
        critic = activation(critic)
        critic = hk.Linear(1)(critic)

        return actor_mean, jnp.squeeze(critic, axis=-1)


def forward_fn(x, is_eval=False):
    net = ActorCritic(env.num_actions, activation="tanh")
    logits, value = net(x)
    return logits, value


forward = hk.without_apply_rng(hk.transform(forward_fn))


optimizer = optax.chain(optax.clip_by_global_norm(
    args.max_grad_norm), optax.adam(args.lr, eps=1e-5))


class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray


def make_update_fn():
    # TRAIN LOOP
    def _update_step(runner_state):
        # COLLECT TRAJECTORIES
        step_fn = jax.vmap(auto_reset(env.step, env.init))

        def _env_step(runner_state, unused):
            params, opt_state, env_state, last_obs, rng = runner_state
            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            logits, value = forward.apply(params, last_obs)
            pi = distrax.Categorical(logits=logits)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            keys = jax.random.split(_rng, env_state.observation.shape[0])
            env_state = step_fn(env_state, action, keys)
            transition = Transition(
                env_state.terminated,
                action,
                value,
                jnp.squeeze(env_state.rewards),
                log_prob,
                last_obs
            )
            runner_state = (params, opt_state, env_state,
                            env_state.observation, rng)
            return runner_state, transition

        runner_state, traj_batch = jax.lax.scan(
            _env_step, runner_state, None, args.num_steps
        )

        # CALCULATE ADVANTAGE
        params, opt_state, env_state, last_obs, rng = runner_state
        _, last_val = forward.apply(params, last_obs)

        def _calculate_gae(traj_batch, last_val):
            def _get_advantages(gae_and_next_value, transition):
                gae, next_value = gae_and_next_value
                done, value, reward = (
                    transition.done,
                    transition.value,
                    transition.reward,
                )
                delta = reward + args.gamma * next_value * (1 - done) - value
                gae = (
                    delta
                    + args.gamma * args.gae_lambda * (1 - done) * gae
                )
                return (gae, value), gae

            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            def _update_minbatch(tup, batch_info):
                params, opt_state = tup
                traj_batch, advantages, targets = batch_info

                def _loss_fn(params, traj_batch, gae, targets):
                    # RERUN NETWORK
                    logits, value = forward.apply(params, traj_batch.obs)
                    pi = distrax.Categorical(logits=logits)
                    log_prob = pi.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS
                    value_pred_clipped = traj_batch.value + (
                        value - traj_batch.value
                    ).clip(-args.clip_eps, args.clip_eps)
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(
                        value_pred_clipped - targets)
                    value_loss = (
                        0.5 * jnp.maximum(value_losses,
                                          value_losses_clipped).mean()
                    )

                    # CALCULATE ACTOR LOSS
                    ratio = jnp.exp(log_prob - traj_batch.log_prob)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - args.clip_eps,
                            1.0 + args.clip_eps,
                        )
                        * gae
                    )
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                    loss_actor = loss_actor.mean()
                    entropy = pi.entropy().mean()

                    total_loss = (
                        loss_actor
                        + args.vf_coef * value_loss
                        - args.ent_coef * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                total_loss, grads = grad_fn(
                    params, traj_batch, advantages, targets)
                updates, opt_state = optimizer.update(grads, opt_state)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), total_loss

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            rng, _rng = jax.random.split(rng)
            batch_size = args.minibatch_size * num_minibatches
            assert (
                batch_size == args.num_steps * args.num_envs
            ), "batch size must be equal to number of steps * number of envs"
            permutation = jax.random.permutation(_rng, batch_size)
            batch = (traj_batch, advantages, targets)
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
            )
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0), batch
            )
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(
                    x, [num_minibatches, -1] + list(x.shape[1:])
                ),
                shuffled_batch,
            )
            (params, opt_state),  total_loss = jax.lax.scan(
                _update_minbatch, (params, opt_state), minibatches
            )
            update_state = (params, opt_state, traj_batch,
                            advantages, targets, rng)
            return update_state, total_loss

        update_state = (params, opt_state, traj_batch,
                        advantages, targets, rng)
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, args.update_epochs
        )
        params, opt_state, _, _, _, rng = update_state

        runner_state = (params, opt_state, env_state, last_obs, rng)
        return runner_state, loss_info
    return _update_step


@jax.jit
def evaluate(params, rng_key):
    step_fn = jax.vmap(env.step)
    rng_key, sub_key = jax.random.split(rng_key)
    subkeys = jax.random.split(sub_key, args.num_eval_envs)
    state = jax.vmap(env.init)(subkeys)
    R = jnp.zeros_like(state.rewards)

    def cond_fn(tup):
        state, _, _ = tup
        return ~state.terminated.all()

    def loop_fn(tup):
        state, R, rng_key = tup
        logits, value = forward.apply(params, state.observation)
        # action = logits.argmax(axis=-1)
        pi = distrax.Categorical(logits=logits)
        rng_key, _rng = jax.random.split(rng_key)
        action = pi.sample(seed=_rng)
        rng_key, _rng = jax.random.split(rng_key)
        keys = jax.random.split(_rng, state.observation.shape[0])
        state = step_fn(state, action, keys)
        return state, R + state.rewards, rng_key
    state, R, _ = jax.lax.while_loop(cond_fn, loop_fn, (state, R, rng_key))
    return R.mean()


def train(rng):
    tt = 0
    st = time.time()
    # INIT NETWORK
    rng, _rng = jax.random.split(rng)
    init_x = jnp.zeros((1, ) + env.observation_shape)
    params = forward.init(_rng, init_x)
    opt_state = optimizer.init(params=params)

    # INIT UPDATE FUNCTION
    _update_step = make_update_fn()
    jitted_update_step = jax.jit(_update_step)

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, args.num_envs)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)

    rng, _rng = jax.random.split(rng)
    runner_state = (params, opt_state, env_state, env_state.observation, _rng)

    # warm up
    _, _ = jitted_update_step(runner_state)

    steps = 0

    # initial evaluation
    et = time.time()  # exclude evaluation time
    tt += et - st
    rng, _rng = jax.random.split(rng)
    eval_R = evaluate(runner_state[0], _rng)
    log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
    print(log)
    wandb.log(log)
    st = time.time()

    for i in range(num_updates):
        runner_state, loss_info = jitted_update_step(runner_state)
        steps += args.num_envs * args.num_steps

        # evaluation
        et = time.time()  # exclude evaluation time
        tt += et - st
        rng, _rng = jax.random.split(rng)
        eval_R = evaluate(runner_state[0], _rng)
        log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
        print(log)
        wandb.log(log)
        st = time.time()

    return runner_state

if __name__ == "__main__":
    wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args.dict())
    rng = jax.random.PRNGKey(args.seed)
    out = train(rng)
    if args.save_model:
        with open(f"{args.env_name}-seed={args.seed}.ckpt", "wb") as f:
            pickle.dump(out[0], f)

env_name='minatar-space_invaders' seed=0 lr=0.001 num_envs=64 num_eval_envs=100 num_steps=128 total_timesteps=10000000 update_epochs=4 minibatch_size=4096 gamma=0.99 gae_lambda=0.9 clip_eps=0.25 ent_coef=0.005 vf_coef=0.5 max_grad_norm=0.5 wandb_entity='nonarruginitocalamarodiferro-usi' wandb_project='pgx-minatar-ppo' save_model=False


In [None]:
class DPOConfig(BaseModel):
    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-space_invaders"  # Ambiente di default
    seed: int = 0  # Random seed
    lr: float = 5e-3  # Learning rate aggiornato
    num_envs: int = 64  # Numero di ambienti paralleli
    num_eval_envs: int = 100  # Numero di ambienti di valutazione
    num_steps: int = 128  # Unroll length
    total_timesteps: int = int(1e7)  # Totale timesteps: 10 milioni
    update_epochs: int = 4  # Numero di epoche per aggiornare
    minibatch_size: int = 8192  # Calcolato come num_envs * num_steps // num_minibatches
    num_minibatches: int = 8  # Numero di minibatch
    gamma: float = 0.99  # Fattore di sconto
    gae_lambda: float = 0.95  # Lambda per GAE
    clip_eps: float = 0.25  # Clipping ratio
    ent_coef: float = 0.01  # Entropy coefficient aggiornato
    vf_coef: float = 0.5  # Coefficiente per il valore (critic)
    max_grad_norm: float = 8.0  # Max grad norm per LPO/DPO
    dpo_alpha: float = 2.0  # Alpha per il drift di DPO
    dpo_beta: float = 0.6  # Beta per il drift di DPO
    wandb_entity: str = "nonarruginitocalamarodiferro-usi"  # wandb entity
    wandb_project: str = "pgx-minatar-dpo"  # wandb project aggiornato
    save_model: bool = False  # Se salvare il modello o meno

    class Config:
        extra = "forbid"  # Impedisce parametri extra nella configurazione

args = DPOConfig(
    env_name="minatar-space_invaders",  # Ambiente target
    seed=0,  # Seed random per riproducibilità
    lr=5e-3,  # Learning rate aggiornato
    num_envs=64,  # Numero di ambienti paralleli
    num_eval_envs=100,  # Numero di ambienti di valutazione
    num_steps=128,  # Unroll length
    total_timesteps=int(1e7),  # 10 milioni di timesteps
    update_epochs=4,  # Numero di epoche per aggiornare
    num_minibatches=8,  # Numero di minibatch
    gamma=0.99,  # Fattore di sconto
    gae_lambda=0.95,  # Lambda per GAE
    clip_eps=0.20,  # Clipping ratio
    ent_coef=0.01,  # Entropy coefficient aggiornato
    vf_coef=0.5,  # Coefficiente per il valore (critic)
    max_grad_norm=8.0,  # Max grad norm per LPO/DPO
    dpo_alpha=2.0,  # Parametro alpha per il drift di DPO
    dpo_beta=1.2,  # Parametro beta per il drift di DPO
    wandb_entity="nonarruginitocalamarodiferro-usi",  # Nome entità per wandb
    wandb_project="pgx-minatar-ppo",  # Nome progetto per wandb
    save_model=False  # Se salvare il modello o meno
)

print(args)  # To verify your updated configuration

# Initialize environment
env = pgx.make(str(args.env_name))

num_updates = args.total_timesteps // args.num_envs // args.num_steps
num_minibatches = args.num_envs * args.num_steps // args.minibatch_size




class ActorCritic(hk.Module):
    def __init__(self, num_actions, activation="tanh"):
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation
        assert activation in ["relu", "tanh"]

    def __call__(self, x):
        x = x.astype(jnp.float32)
        if self.activation == "relu":
            activation = jax.nn.relu
        else:
            activation = jax.nn.tanh
        x = hk.Conv2D(32, kernel_shape=2)(x)
        x = jax.nn.relu(x)
        x = hk.avg_pool(x, window_shape=(2, 2),
                        strides=(2, 2), padding="VALID")
        x = x.reshape((x.shape[0], -1))  # flatten
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        actor_mean = hk.Linear(64)(x)
        actor_mean = activation(actor_mean)
        actor_mean = hk.Linear(64)(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = hk.Linear(self.num_actions)(actor_mean)

        critic = hk.Linear(64)(x)
        critic = activation(critic)
        critic = hk.Linear(64)(critic)
        critic = activation(critic)
        critic = hk.Linear(1)(critic)

        return actor_mean, jnp.squeeze(critic, axis=-1)


def forward_fn(x, is_eval=False):
    net = ActorCritic(env.num_actions, activation="tanh")
    logits, value = net(x)
    return logits, value


forward = hk.without_apply_rng(hk.transform(forward_fn))


optimizer = optax.chain(optax.clip_by_global_norm(
    args.max_grad_norm), optax.adam(args.lr, eps=1e-5))


class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray


def make_update_fn():
    # TRAIN LOOP
    def _update_step(runner_state):
        # COLLECT TRAJECTORIES
        step_fn = jax.vmap(auto_reset(env.step, env.init))

        def _env_step(runner_state, unused):
            params, opt_state, env_state, last_obs, rng = runner_state
            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            logits, value = forward.apply(params, last_obs)
            pi = distrax.Categorical(logits=logits)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            keys = jax.random.split(_rng, env_state.observation.shape[0])
            env_state = step_fn(env_state, action, keys)
            transition = Transition(
                env_state.terminated,
                action,
                value,
                jnp.squeeze(env_state.rewards),
                log_prob,
                last_obs
            )
            runner_state = (params, opt_state, env_state,
                            env_state.observation, rng)
            return runner_state, transition

        runner_state, traj_batch = jax.lax.scan(
            _env_step, runner_state, None, args.num_steps
        )

        # CALCULATE ADVANTAGE
        params, opt_state, env_state, last_obs, rng = runner_state
        _, last_val = forward.apply(params, last_obs)

        def _calculate_gae(traj_batch, last_val):
            def _get_advantages(gae_and_next_value, transition):
                gae, next_value = gae_and_next_value
                done, value, reward = (
                    transition.done,
                    transition.value,
                    transition.reward,
                )
                delta = reward + args.gamma * next_value * (1 - done) - value
                gae = (
                    delta
                    + args.gamma * args.gae_lambda * (1 - done) * gae
                )
                return (gae, value), gae

            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            def _update_minbatch(tup, batch_info):
                params, opt_state = tup
                traj_batch, advantages, targets = batch_info

                def _loss_fn(params, traj_batch, gae, targets):
                    # RERUN NETWORK
                    logits, value = forward.apply(params, traj_batch.obs)
                    pi = distrax.Categorical(logits=logits)
                    log_prob = pi.log_prob(traj_batch.action)
                
                    # CALCOLA LA PERDITA DEL VALORE (CRITICO)
                    value_pred_clipped = traj_batch.value + (
                        value - traj_batch.value
                    ).clip(-args.clip_eps, args.clip_eps)
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(value_pred_clipped - targets)
                    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                
                    # CALCOLA LA PERDITA DELL'ATTORE (DPO)
                    log_diff = log_prob - traj_batch.log_prob
                    ratio = jnp.exp(log_diff)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                
                    # Differenzia tra vantaggio positivo e negativo
                    is_pos = (gae >= 0.0).astype(jnp.float32)
                
                    # Drift per vantaggi positivi
                    r1 = ratio - 1.0
                    drift1 = nn.relu(r1 * gae - args.dpo_alpha * jnp.tanh(r1 * gae / args.dpo_alpha))
                
                    # Drift per vantaggi negativi
                    drift2 = nn.relu(
                        log_diff * gae - args.dpo_beta * jnp.tanh(log_diff * gae / args.dpo_beta)
                    )
                
                    # Combina i drift in base al segno del vantaggio
                    drift = drift1 * is_pos + drift2 * (1 - is_pos)
                
                    # Calcola la perdita dell'attore con drift
                    loss_actor = -(ratio * gae - drift).mean()
                    entropy = pi.entropy().mean()
                
                    # PERDITA TOTALE
                    total_loss = (
                        loss_actor
                        + args.vf_coef * value_loss
                        - args.ent_coef * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                total_loss, grads = grad_fn(
                    params, traj_batch, advantages, targets)
                updates, opt_state = optimizer.update(grads, opt_state)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), total_loss

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            rng, _rng = jax.random.split(rng)
            batch_size = args.minibatch_size * num_minibatches
            assert (
                batch_size == args.num_steps * args.num_envs
            ), "batch size must be equal to number of steps * number of envs"
            permutation = jax.random.permutation(_rng, batch_size)
            batch = (traj_batch, advantages, targets)
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
            )
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0), batch
            )
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(
                    x, [num_minibatches, -1] + list(x.shape[1:])
                ),
                shuffled_batch,
            )
            (params, opt_state),  total_loss = jax.lax.scan(
                _update_minbatch, (params, opt_state), minibatches
            )
            update_state = (params, opt_state, traj_batch,
                            advantages, targets, rng)
            return update_state, total_loss

        update_state = (params, opt_state, traj_batch,
                        advantages, targets, rng)
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, args.update_epochs
        )
        params, opt_state, _, _, _, rng = update_state

        runner_state = (params, opt_state, env_state, last_obs, rng)
        return runner_state, loss_info
    return _update_step


@jax.jit
def evaluate(params, rng_key):
    step_fn = jax.vmap(env.step)
    rng_key, sub_key = jax.random.split(rng_key)
    subkeys = jax.random.split(sub_key, args.num_eval_envs)
    state = jax.vmap(env.init)(subkeys)
    R = jnp.zeros_like(state.rewards)

    def cond_fn(tup):
        state, _, _ = tup
        return ~state.terminated.all()

    def loop_fn(tup):
        state, R, rng_key = tup
        logits, value = forward.apply(params, state.observation)
        # action = logits.argmax(axis=-1)
        pi = distrax.Categorical(logits=logits)
        rng_key, _rng = jax.random.split(rng_key)
        action = pi.sample(seed=_rng)
        rng_key, _rng = jax.random.split(rng_key)
        keys = jax.random.split(_rng, state.observation.shape[0])
        state = step_fn(state, action, keys)
        return state, R + state.rewards, rng_key
    state, R, _ = jax.lax.while_loop(cond_fn, loop_fn, (state, R, rng_key))
    return R.mean()


def train(rng):
    tt = 0
    st = time.time()
    # INIT NETWORK
    rng, _rng = jax.random.split(rng)
    init_x = jnp.zeros((1, ) + env.observation_shape)
    params = forward.init(_rng, init_x)
    opt_state = optimizer.init(params=params)

    # INIT UPDATE FUNCTION
    _update_step = make_update_fn()
    jitted_update_step = jax.jit(_update_step)

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, args.num_envs)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)

    rng, _rng = jax.random.split(rng)
    runner_state = (params, opt_state, env_state, env_state.observation, _rng)

    # warm up
    _, _ = jitted_update_step(runner_state)

    steps = 0

    # initial evaluation
    et = time.time()  # exclude evaluation time
    tt += et - st
    rng, _rng = jax.random.split(rng)
    eval_R = evaluate(runner_state[0], _rng)
    log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
    print(log)
    wandb.log(log)
    st = time.time()

    for i in range(num_updates):
        runner_state, loss_info = jitted_update_step(runner_state)
        steps += args.num_envs * args.num_steps

        # evaluation
        et = time.time()  # exclude evaluation time
        tt += et - st
        rng, _rng = jax.random.split(rng)
        eval_R = evaluate(runner_state[0], _rng)
        log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
        print(log)
        wandb.log(log)
        st = time.time()

    return runner_state

if __name__ == "__main__":
    wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args.dict())
    rng = jax.random.PRNGKey(args.seed)
    out = train(rng)
    if args.save_model:
        with open(f"{args.env_name}-seed={args.seed}.ckpt", "wb") as f:
            pickle.dump(out[0], f)

env_name='minatar-space_invaders' seed=0 lr=0.005 num_envs=64 num_eval_envs=100 num_steps=128 total_timesteps=10000000 update_epochs=4 minibatch_size=8192 num_minibatches=8 gamma=0.99 gae_lambda=0.95 clip_eps=0.2 ent_coef=0.01 vf_coef=0.5 max_grad_norm=8.0 dpo_alpha=2.0 dpo_beta=1.2 wandb_entity='nonarruginitocalamarodiferro-usi' wandb_project='pgx-minatar-ppo' save_model=False


VBox(children=(Label(value='0.033 MB of 0.033 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
minatar-space_invaders/eval_R,▁▁▁▁▁▂▃▄▄▅▅▆▆▆▆▆▆▇▆█▇▇▇█▇▇▇▇▇▇▆▇▆▇▇▆▇▆██
sec,▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇█
steps,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇████

0,1
minatar-space_invaders/eval_R,39.32
sec,19.58736
steps,1376256.0


{'sec': 5.309812068939209, 'minatar-space_invaders/eval_R': 4.349999904632568, 'steps': 0}
{'sec': 5.399070978164673, 'minatar-space_invaders/eval_R': 4.610000133514404, 'steps': 8192}
{'sec': 5.4634788036346436, 'minatar-space_invaders/eval_R': 5.089999675750732, 'steps': 16384}
{'sec': 5.518078327178955, 'minatar-space_invaders/eval_R': 5.559999942779541, 'steps': 24576}
{'sec': 5.573474884033203, 'minatar-space_invaders/eval_R': 5.699999809265137, 'steps': 32768}
{'sec': 5.628509044647217, 'minatar-space_invaders/eval_R': 6.009999752044678, 'steps': 40960}
{'sec': 5.682448148727417, 'minatar-space_invaders/eval_R': 7.349999904632568, 'steps': 49152}
{'sec': 5.739185810089111, 'minatar-space_invaders/eval_R': 6.869999885559082, 'steps': 57344}
{'sec': 5.792344570159912, 'minatar-space_invaders/eval_R': 6.909999847412109, 'steps': 65536}
{'sec': 5.8447425365448, 'minatar-space_invaders/eval_R': 7.989999771118164, 'steps': 73728}
{'sec': 5.898927450180054, 'minatar-space_invaders/eval_