In [None]:
!pip install minatar
!pip install dm-haiku
!pip install distrax
!pip install pgx
!pip install omegaconf
!pip install learn2learn

In [None]:

import sys
import jax
import jax.numpy as jnp
import haiku as hk
import optax
from typing import NamedTuple, Literal
import distrax
import pgx
from pgx.experimental import auto_reset
import time

import pickle
from omegaconf import OmegaConf
from pydantic import BaseModel
import wandb
import learn2learn as l2l
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax



In [None]:
# PPO hyperparameter configuration (pydantic model)
class PPOConfig(BaseModel):
    """Hyperparameters for the PPO run on a pgx MinAtar environment.

    Extra (unknown) fields are rejected via ``Config.extra = "forbid"``.
    """

    # Environment id handed to ``pgx.make``.
    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-space_invaders"
    seed: int = 0  # seeds jax.random.PRNGKey in the entry point
    lr: float = 1e-3  # Adam learning rate
    num_envs: int = 64  # parallel training environments
    num_eval_envs: int = 100  # parallel environments used by evaluate()
    num_steps: int = 128  # rollout length (scan length) per update
    total_timesteps: int = 10_000_000  # budget used to derive num_updates
    update_epochs: int = 4  # passes over each rollout batch
    minibatch_size: int = 4096  # used to derive num_minibatches
    gamma: float = 0.99  # discount factor in GAE
    gae_lambda: float = 0.95  # GAE lambda
    clip_eps: float = 0.2  # clip range for the ratio and the value update
    ent_coef: float = 0.01  # weight of the entropy bonus
    vf_coef: float = 0.5  # weight of the value loss
    max_grad_norm: float = 0.5  # global-norm gradient clipping threshold
    wandb_entity: str = "alfiovavassori-universit-della-svizzera-italiana-org"
    wandb_project: str = "pgx-minatar-ppo"
    save_model: bool = False  # pickle the final params when True

    class Config:
        extra = "forbid"

# Single module-level configuration instance read by everything below.
args = PPOConfig()

# Environment built from the configured id.
env = pgx.make(str(args.env_name))

# Derived schedule sizes; integer division drops any remainder.
num_updates = (args.total_timesteps // args.num_envs) // args.num_steps
num_minibatches = (args.num_envs * args.num_steps) // args.minibatch_size

# Actor-Critic model definition
class ActorCritic(hk.Module):
    """Conv/dense trunk shared by separate policy and value MLP heads."""

    def __init__(self, num_actions, activation="tanh"):
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation
        assert activation in ["relu", "tanh"]

    def __call__(self, x):
        """Return ``(logits, value)`` for a batch of observations ``x``.

        The value output has its trailing unit axis squeezed away.
        """
        head_act = jax.nn.relu if self.activation == "relu" else jax.nn.tanh

        # Shared trunk: conv -> relu -> 2x2 average pool -> flatten -> dense.
        features = x.astype(jnp.float32)
        features = jax.nn.relu(hk.Conv2D(32, kernel_shape=2)(features))
        features = hk.avg_pool(
            features, window_shape=(2, 2), strides=(2, 2), padding="VALID"
        )
        features = features.reshape((features.shape[0], -1))
        features = jax.nn.relu(hk.Linear(64)(features))

        def mlp_head(out_dim):
            # Two hidden layers of width 64 followed by a linear output layer.
            # Layers are created and applied in input-to-output order, exactly
            # as in the unrolled form, so auto-generated names are unchanged.
            hidden = head_act(hk.Linear(64)(features))
            hidden = head_act(hk.Linear(64)(hidden))
            return hk.Linear(out_dim)(hidden)

        # The policy head is built before the value head.
        actor_mean = mlp_head(self.num_actions)
        critic = mlp_head(1)

        return actor_mean, jnp.squeeze(critic, axis=-1)


def forward_fn(x, is_eval=False):
    """Apply a tanh-headed ``ActorCritic`` to ``x``; returns ``(logits, value)``.

    ``is_eval`` is accepted but not used by the body.
    """
    model = ActorCritic(env.num_actions, activation="tanh")
    return model(x)


# Haiku init/apply pair for the network; apply is called without an RNG key.
forward = hk.without_apply_rng(hk.transform(forward_fn))


# Gradients are clipped by global norm first, then passed to Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(args.max_grad_norm),
    optax.adam(learning_rate=args.lr, eps=1e-5),
)


class Transition(NamedTuple):
    """One rollout step as recorded by the trajectory-collection scan."""

    done: jnp.ndarray  # termination flag of the state reached by the step
    action: jnp.ndarray  # action sampled from the policy
    value: jnp.ndarray  # critic output for the pre-step observation
    reward: jnp.ndarray  # reward returned by the step (squeezed)
    log_prob: jnp.ndarray  # log-probability of `action` at sampling time
    obs: jnp.ndarray  # observation the action was chosen from


def make_update_fn():
    """Build the function that performs one PPO update.

    Returns:
        ``_update_step(runner_state) -> (runner_state, loss_info)`` where
        ``runner_state`` is ``(params, opt_state, env_state, last_obs, rng)``
        and ``loss_info`` is the stacked output of the loss for every epoch
        and minibatch.
    """
    # TRAIN LOOP
    def _update_step(runner_state):
        # COLLECT TRAJECTORIES
        # env.step is wrapped with auto_reset (given env.init) and batched
        # over the leading axis with vmap.
        step_fn = jax.vmap(auto_reset(env.step, env.init))

        def _env_step(runner_state, unused):
            # Scan body: sample one action per environment and step once.
            params, opt_state, env_state, last_obs, rng = runner_state
            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            logits, value = forward.apply(params, last_obs)
            pi = distrax.Categorical(logits=logits)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            # One key per environment (leading axis of the observation).
            keys = jax.random.split(_rng, env_state.observation.shape[0])
            env_state = step_fn(env_state, action, keys)
            # Record the post-step termination flag and reward together with
            # the pre-step observation, value and log-probability.
            transition = Transition(
                env_state.terminated,
                action,
                value,
                jnp.squeeze(env_state.rewards),
                log_prob,
                last_obs
            )
            runner_state = (params, opt_state, env_state,
                            env_state.observation, rng)
            return runner_state, transition

        # Run args.num_steps environment steps; traj_batch stacks the
        # transitions along a new leading (time) axis.
        runner_state, traj_batch = jax.lax.scan(
            _env_step, runner_state, None, args.num_steps
        )

        # CALCULATE ADVANTAGE
        params, opt_state, env_state, last_obs, rng = runner_state
        # Bootstrap value for the observation after the final step.
        _, last_val = forward.apply(params, last_obs)

        def _calculate_gae(traj_batch, last_val):
            """Return ``(advantages, targets)`` via a reverse-time GAE scan."""
            def _get_advantages(gae_and_next_value, transition):
                gae, next_value = gae_and_next_value
                done, value, reward = (
                    transition.done,
                    transition.value,
                    transition.reward,
                )
                # (1 - done) zeroes the bootstrap and the accumulated
                # advantage where the step ended in termination.
                delta = reward + args.gamma * next_value * (1 - done) - value
                gae = (
                    delta
                    + args.gamma * args.gae_lambda * (1 - done) * gae
                )
                # Carry this step's value as next_value for the earlier step.
                return (gae, value), gae

            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            # Value targets are advantages plus the recorded value estimates.
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            # Scan body: one shuffled pass over the whole rollout batch.
            def _update_minbatch(tup, batch_info):
                # Scan body: one optimizer step on a single minibatch.
                params, opt_state = tup
                traj_batch, advantages, targets = batch_info

                def _loss_fn(params, traj_batch, gae, targets):
                    """Clipped PPO loss; aux is (value_loss, loss_actor, entropy)."""
                    # RERUN NETWORK
                    logits, value = forward.apply(params, traj_batch.obs)
                    pi = distrax.Categorical(logits=logits)
                    log_prob = pi.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS
                    # The new value is kept within clip_eps of the recorded
                    # value; the larger of the two squared errors is used.
                    value_pred_clipped = traj_batch.value + (
                        value - traj_batch.value
                    ).clip(-args.clip_eps, args.clip_eps)
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(
                        value_pred_clipped - targets)
                    value_loss = (
                        0.5 * jnp.maximum(value_losses,
                                          value_losses_clipped).mean()
                    )

                    # CALCULATE ACTOR LOSS
                    ratio = jnp.exp(log_prob - traj_batch.log_prob)
                    # Advantages are standardised per minibatch.
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - args.clip_eps,
                            1.0 + args.clip_eps,
                        )
                        * gae
                    )
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                    loss_actor = loss_actor.mean()
                    entropy = pi.entropy().mean()

                    # Entropy enters with a negative sign (it is a bonus).
                    total_loss = (
                        loss_actor
                        + args.vf_coef * value_loss
                        - args.ent_coef * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                # total_loss here is the (loss, aux) pair from has_aux=True.
                total_loss, grads = grad_fn(
                    params, traj_batch, advantages, targets)
                updates, opt_state = optimizer.update(grads, opt_state)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), total_loss

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            rng, _rng = jax.random.split(rng)
            batch_size = args.minibatch_size * num_minibatches
            # Evaluated at trace time on Python ints.
            assert (
                batch_size == args.num_steps * args.num_envs
            ), "batch size must be equal to number of steps * number of envs"
            permutation = jax.random.permutation(_rng, batch_size)
            batch = (traj_batch, advantages, targets)
            # Merge the two leading axes (time, env) into one batch axis.
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
            )
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0), batch
            )
            # Split the shuffled batch into num_minibatches equal chunks.
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(
                    x, [num_minibatches, -1] + list(x.shape[1:])
                ),
                shuffled_batch,
            )
            (params, opt_state),  total_loss = jax.lax.scan(
                _update_minbatch, (params, opt_state), minibatches
            )
            # The unshuffled rollout is carried forward for the next epoch.
            update_state = (params, opt_state, traj_batch,
                            advantages, targets, rng)
            return update_state, total_loss

        update_state = (params, opt_state, traj_batch,
                        advantages, targets, rng)
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, args.update_epochs
        )
        params, opt_state, _, _, _, rng = update_state

        runner_state = (params, opt_state, env_state, last_obs, rng)
        return runner_state, loss_info
    return _update_step


@jax.jit
def evaluate(params, rng_key):
    """Return the mean accumulated reward over ``args.num_eval_envs`` rollouts.

    Actions are sampled from the policy (not argmax). The loop runs until
    every environment reports ``terminated``; rewards are summed on every
    iteration for all environments.
    NOTE(review): environments are stepped without auto-reset, so this
    presumably relies on ``env.step`` yielding zero reward once an
    environment has terminated - confirm against pgx.
    """
    batched_step = jax.vmap(env.step)
    rng_key, init_key = jax.random.split(rng_key)
    init_keys = jax.random.split(init_key, args.num_eval_envs)
    init_state = jax.vmap(env.init)(init_keys)
    init_return = jnp.zeros_like(init_state.rewards)

    def any_running(carry):
        # Continue while at least one environment has not terminated.
        return jnp.logical_not(jnp.all(carry[0].terminated))

    def advance(carry):
        cur_state, returns, key = carry
        logits, _ = forward.apply(params, cur_state.observation)
        key, sample_key = jax.random.split(key)
        action = distrax.Categorical(logits=logits).sample(seed=sample_key)
        key, step_key = jax.random.split(key)
        step_keys = jax.random.split(step_key, cur_state.observation.shape[0])
        next_state = batched_step(cur_state, action, step_keys)
        return next_state, returns + next_state.rewards, key

    _, returns, _ = jax.lax.while_loop(
        any_running, advance, (init_state, init_return, rng_key)
    )
    return returns.mean()


def train(rng):
    """Run PPO training and return the final runner state.

    Args:
        rng: JAX PRNG key; split for network init, environment init, the
            rollout stream and every evaluation.

    Returns:
        ``(params, opt_state, env_state, last_obs, rng)`` after
        ``num_updates`` updates.

    An evaluation is printed and logged to wandb before training and after
    every update. The logged ``"sec"`` excludes evaluation time.
    """

    def _log_evaluation(params, eval_key, elapsed, steps):
        # Evaluate the current policy and emit one log record.
        eval_R = evaluate(params, eval_key)
        log = {"sec": elapsed, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
        print(log)
        wandb.log(log)

    tt = 0
    st = time.time()
    # INIT NETWORK
    rng, _rng = jax.random.split(rng)
    init_x = jnp.zeros((1, ) + env.observation_shape)
    params = forward.init(_rng, init_x)
    opt_state = optimizer.init(params=params)

    # INIT UPDATE FUNCTION
    jitted_update_step = jax.jit(make_update_fn())

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, args.num_envs)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)

    rng, _rng = jax.random.split(rng)
    runner_state = (params, opt_state, env_state, env_state.observation, _rng)

    # Warm up (compiles the update); the result is discarded. JAX dispatches
    # asynchronously, so block until the computation has really finished
    # before reading the clock.
    jax.block_until_ready(jitted_update_step(runner_state))

    steps = 0

    # initial evaluation
    tt += time.time() - st  # exclude evaluation time
    rng, _rng = jax.random.split(rng)
    _log_evaluation(runner_state[0], _rng, tt, steps)
    st = time.time()

    for _ in range(num_updates):
        runner_state, _ = jitted_update_step(runner_state)
        # Without this the timer would stop right after dispatch and the
        # update's compute would be charged to the excluded evaluation.
        jax.block_until_ready(runner_state)
        steps += args.num_envs * args.num_steps

        # evaluation
        tt += time.time() - st  # exclude evaluation time
        rng, _rng = jax.random.split(rng)
        _log_evaluation(runner_state[0], _rng, tt, steps)
        st = time.time()

    return runner_state

if __name__ == "__main__":
    # Start the wandb run with the full config, train from the configured
    # seed, and optionally pickle the trained parameters.
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        config=args.dict(),
    )
    rng = jax.random.PRNGKey(args.seed)
    out = train(rng)
    if args.save_model:
        # out[0] holds the network parameters of the final runner state.
        with open(f"{args.env_name}-seed={args.seed}.ckpt", "wb") as f:
            pickle.dump(out[0], f)

In [None]:
# Configuration class for DPO
class DPOConfig(BaseModel):
    """Hyperparameters for the DPO (drift-based) run on a pgx MinAtar env.

    Extra (unknown) fields are rejected via ``Config.extra = "forbid"``.
    """

    # Environment id handed to ``pgx.make``.
    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-space_invaders"
    seed: int = 2  # seeds jax.random.PRNGKey in the entry point
    lr: float = 5e-4  # Adam learning rate
    num_envs: int = 2048  # parallel training environments
    num_eval_envs: int = 128  # parallel environments used by evaluate()
    num_steps: int = 5  # rollout length (scan length) per update
    total_timesteps: int = 10000000  # budget used to derive num_updates
    update_epochs: int = 4  # passes over each rollout batch
    minibatch_size: int = 1024  # used to derive the module-level num_minibatches
    # NOTE(review): not read by the code in this cell; the module-level
    # `num_minibatches` is derived from num_envs * num_steps // minibatch_size.
    num_minibatches: int = 32
    gamma: float = 0.99  # discount factor in GAE
    gae_lambda: float = 0.95  # GAE lambda
    # Clip range for the value-function update. The loss functions read
    # `args.clip_eps`, so the field must exist on this model; 0.2 matches
    # the PPO configuration in this file.
    clip_eps: float = 0.2
    ent_coef: float = 0.01  # weight of the entropy bonus
    vf_coef: float = 0.5  # weight of the value loss
    max_grad_norm: float = 8.0  # global-norm gradient clipping threshold
    dpo_alpha: float = 8.0  # scale inside the drift term for advantages >= 0
    dpo_beta: float = 0.6  # scale inside the drift term for advantages < 0
    wandb_entity: str = "your-wandb-entity"
    wandb_project: str = "pgx-minatar-dpo"
    save_model: bool = False  # pickle the final params when True

    class Config:
        extra = "forbid"

# Single module-level configuration instance read by everything below.
args = DPOConfig()

# Environment built from the configured id.
env = pgx.make(str(args.env_name))

# Derived schedule sizes; integer division drops any remainder.
num_updates = (args.total_timesteps // args.num_envs) // args.num_steps
num_minibatches = (args.num_envs * args.num_steps) // args.minibatch_size

# Actor-Critic model definition
class ActorCritic(hk.Module):
    """Conv/dense trunk shared by separate policy and value MLP heads."""

    def __init__(self, num_actions, activation="tanh"):
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation

    def __call__(self, x):
        """Return ``(logits, value)`` for a batch of observations ``x``."""
        # Anything other than "tanh" selects relu for the heads.
        if self.activation == "tanh":
            act = jax.nn.tanh
        else:
            act = jax.nn.relu

        # Shared trunk: conv -> relu -> 2x2 average pool -> flatten -> dense.
        feats = x.astype(jnp.float32)
        feats = jax.nn.relu(hk.Conv2D(32, kernel_shape=2)(feats))
        feats = hk.avg_pool(feats, window_shape=(2, 2), strides=(2, 2), padding="VALID")
        feats = feats.reshape((feats.shape[0], -1))
        feats = jax.nn.relu(hk.Linear(64)(feats))

        def head(out_dim):
            # The layers are constructed output-first and applied input-first,
            # the same order the nested one-line expression evaluates in, so
            # auto-generated layer names are left as they were.
            out_layer = hk.Linear(out_dim)
            second = hk.Linear(64)
            first = hk.Linear(64)
            return out_layer(act(second(act(first(feats)))))

        # The policy head is built before the value head.
        actor_mean = head(self.num_actions)
        critic = head(1)
        return actor_mean, jnp.squeeze(critic, axis=-1)

def forward_fn(x):
    """Apply a tanh-headed ``ActorCritic`` to ``x``; returns ``(logits, value)``."""
    model = ActorCritic(env.num_actions, activation="tanh")
    return model(x)

# Haiku init/apply pair for the network; apply is called without an RNG key.
forward = hk.without_apply_rng(hk.transform(forward_fn))

# Optimizer: gradients are clipped by global norm first, then passed to Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(args.max_grad_norm),
    optax.adam(learning_rate=args.lr, eps=1e-5),
)

# Transition class
class Transition(NamedTuple):
    """One rollout step as recorded by the trajectory-collection scan."""

    done: jnp.ndarray  # termination flag of the state reached by the step
    action: jnp.ndarray  # action sampled from the policy
    value: jnp.ndarray  # critic output for the pre-step observation
    reward: jnp.ndarray  # reward returned by the step (squeezed)
    log_prob: jnp.ndarray  # log-probability of `action` at sampling time
    obs: jnp.ndarray  # observation the action was chosen from
    lifetime_info: jnp.ndarray  # temporal information for TA-LPO

# Generalized Advantage Estimation
def _calculate_gae(traj_batch, last_val):
    """Generalized Advantage Estimation over a time-major rollout.

    Args:
        traj_batch: stacked ``Transition`` whose leading axis is scanned.
        last_val: value estimate used to bootstrap after the final step.

    Returns:
        ``(advantages, targets)`` with ``targets = advantages + values``.
    """

    def backward_step(carry, step):
        running_gae, value_after = carry
        # (1 - done) removes the bootstrap and the accumulated advantage
        # where the step ended in termination.
        td_error = step.reward + args.gamma * value_after * (1 - step.done) - step.value
        running_gae = td_error + args.gamma * args.gae_lambda * (1 - step.done) * running_gae
        return (running_gae, step.value), running_gae

    init_carry = (jnp.zeros_like(last_val), last_val)
    # The scan walks the trajectory from the last step to the first.
    _, advantages = jax.lax.scan(backward_step, init_carry, traj_batch, reverse=True)
    returns = advantages + traj_batch.value
    return advantages, returns

# Drift-based Loss Function
def _loss_fn(params, traj_batch, advantages, targets, lifetime_info):
    """Drift-based actor-critic loss with a temporal weighting of the drift.

    Returns ``(total_loss, (value_loss, loss_actor, entropy))``.

    NOTE(review): ``lifetime_info`` gets a trailing axis appended before it
    multiplies per-sample drift terms; the expected input shape is not
    established by the visible code - confirm with the caller to avoid
    unintended broadcasting.
    """
    logits, value = forward.apply(params, traj_batch.obs)
    pi = distrax.Categorical(logits=logits)
    new_log_prob = pi.log_prob(traj_batch.action)

    # Critic: the new value is kept within clip_eps of the recorded value
    # and the larger of the clipped/unclipped squared errors is used.
    clipped_delta = (value - traj_batch.value).clip(-args.clip_eps, args.clip_eps)
    value_pred_clipped = traj_batch.value + clipped_delta
    unclipped_sq = jnp.square(value - targets)
    clipped_sq = jnp.square(value_pred_clipped - targets)
    value_loss = 0.5 * jnp.maximum(unclipped_sq, clipped_sq).mean()

    # Actor: probability ratio and batch-standardised advantages.
    log_diff = new_log_prob - traj_batch.log_prob
    ratio = jnp.exp(log_diff)
    gae = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Drift: drift1 applies where the advantage is >= 0, drift2 elsewhere;
    # they are weighted by temporal_factor and its complement respectively.
    temporal_factor = lifetime_info[..., None]
    is_pos = (gae >= 0.0).astype("float32")
    ratio_excess = ratio - 1.0
    drift1 = (
        nn.relu(ratio_excess - args.dpo_alpha * nn.tanh(ratio_excess / args.dpo_alpha))
        * temporal_factor
    )
    drift2 = (
        nn.relu(log_diff - args.dpo_beta * nn.tanh(log_diff / args.dpo_beta))
        * (1 - temporal_factor)
    )
    drift = drift1 * is_pos + drift2 * (1 - is_pos)
    loss_actor = -(ratio * gae - drift).mean()

    # Entropy enters the total with a negative sign (it is a bonus).
    entropy = pi.entropy().mean()

    total_loss = loss_actor + args.vf_coef * value_loss - args.ent_coef * entropy
    return total_loss, (value_loss, loss_actor, entropy)

def make_update_fn():
    """Build the function that performs one DPO update.

    Returns:
        ``_update_step(runner_state) -> (runner_state, None)`` where
        ``runner_state`` is
        ``(params, opt_state, env_state, last_obs, rng, step_count)``.
        ``step_count`` is incremented once per collected environment step.
    """

    def _update_step(runner_state):
        # Wrap env.step with auto_reset (as the PPO cell of this file does)
        # so environments are re-initialised after terminating; a bare
        # env.step keeps stepping terminated states for the rest of the run.
        step_fn = jax.vmap(auto_reset(env.step, env.init))

        def _env_step(runner_state, unused):
            # Scan body: sample one action per environment and step once.
            params, opt_state, env_state, last_obs, rng, step_count = runner_state

            # Lifetime information (temporal awareness). One value per
            # environment so the field has the same (time, env) leading axes
            # as every other Transition field; the minibatch reshape below
            # requires that.
            total_steps = args.num_steps * args.num_envs
            lifetime_info = jnp.full(
                (last_obs.shape[0],), step_count / total_steps, dtype=jnp.float32
            )

            # Select action
            rng, _rng = jax.random.split(rng)
            logits, value = forward.apply(params, last_obs)
            pi = distrax.Categorical(logits=logits)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # Step environment with one key per environment.
            rng, _rng = jax.random.split(rng)
            keys = jax.random.split(_rng, env_state.observation.shape[0])
            env_state = step_fn(env_state, action, keys)

            # Post-step termination flag and reward are stored with the
            # pre-step observation, value and log-probability.
            transition = Transition(
                env_state.terminated,
                action,
                value,
                jnp.squeeze(env_state.rewards),
                log_prob,
                last_obs,
                lifetime_info,
            )

            runner_state = (
                params, opt_state, env_state, env_state.observation, rng, step_count + 1
            )
            return runner_state, transition

        # Collect args.num_steps transitions, stacked along a leading axis.
        runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, args.num_steps)
        params, opt_state, env_state, last_obs, rng, step_count = runner_state
        # Bootstrap value for the observation after the final step.
        _, last_val = forward.apply(params, last_obs)
        advantages, targets = _calculate_gae(traj_batch, last_val)

        def _update_epoch(update_state, unused):
            # Scan body: one shuffled pass over the whole rollout batch.
            def _update_minbatch(tup, batch_info):
                # Scan body: one optimizer step on a single minibatch.
                params, opt_state = tup
                traj_batch, advantages, targets = batch_info

                def _loss_fn(params, traj_batch, gae, targets):
                    """DPO loss; aux is (value_loss, loss_actor, entropy)."""
                    logits, value = forward.apply(params, traj_batch.obs)
                    pi = distrax.Categorical(logits=logits)
                    log_prob = pi.log_prob(traj_batch.action)

                    # Critic loss: the new value is kept within clip_eps of
                    # the recorded value; the larger squared error is used.
                    value_pred_clipped = traj_batch.value + (
                        value - traj_batch.value
                    ).clip(-args.clip_eps, args.clip_eps)
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(value_pred_clipped - targets)
                    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

                    # Actor loss with DPO drift. Advantages are standardised
                    # per minibatch; drift1 applies where the advantage is
                    # >= 0 and drift2 elsewhere.
                    alpha = args.dpo_alpha
                    beta = args.dpo_beta
                    log_diff = log_prob - traj_batch.log_prob
                    ratio = jnp.exp(log_diff)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    is_pos = (gae >= 0.0).astype("float32")
                    drift1 = nn.relu(ratio - 1.0 - alpha * nn.tanh((ratio - 1.0) / alpha))
                    drift2 = nn.relu(log_diff - beta * nn.tanh(log_diff / beta))
                    drift = drift1 * is_pos + drift2 * (1 - is_pos)
                    loss_actor = -(ratio * gae - drift).mean()

                    # Entropy enters with a negative sign (it is a bonus).
                    entropy = pi.entropy().mean()

                    total_loss = (
                        loss_actor
                        + args.vf_coef * value_loss
                        - args.ent_coef * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                total_loss, grads = grad_fn(params, traj_batch, advantages, targets)
                updates, opt_state = optimizer.update(grads, opt_state)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), total_loss

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            rng, _rng = jax.random.split(rng)
            batch_size = args.minibatch_size * num_minibatches
            # Evaluated at trace time on Python ints.
            assert batch_size == args.num_steps * args.num_envs, "Batch size mismatch."
            permutation = jax.random.permutation(_rng, batch_size)
            batch = (traj_batch, advantages, targets)
            # Merge the (time, env) axes, shuffle, then split into minibatches.
            batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
            shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])),
                shuffled_batch,
            )
            (params, opt_state), _ = jax.lax.scan(_update_minbatch, (params, opt_state), minibatches)
            # The unshuffled rollout is carried forward for the next epoch.
            update_state = (params, opt_state, traj_batch, advantages, targets, rng)
            return update_state, None

        update_state = (params, opt_state, traj_batch, advantages, targets, rng)
        update_state, _ = jax.lax.scan(_update_epoch, update_state, None, args.update_epochs)
        params, opt_state, _, _, _, rng = update_state
        runner_state = (params, opt_state, env_state, last_obs, rng, step_count)
        return runner_state, None

    return _update_step

@jax.jit
def evaluate(params, rng_key):
    """Return the mean accumulated reward over ``args.num_eval_envs`` rollouts.

    Actions are sampled from the policy. The loop runs until every
    environment reports ``terminated``; rewards are summed on every
    iteration for all environments.
    NOTE(review): environments are stepped without auto-reset, so this
    presumably relies on ``env.step`` yielding zero reward once an
    environment has terminated - confirm against pgx.
    """
    batched_step = jax.vmap(env.step)
    rng_key, init_key = jax.random.split(rng_key)
    init_keys = jax.random.split(init_key, args.num_eval_envs)
    init_state = jax.vmap(env.init)(init_keys)
    init_return = jnp.zeros_like(init_state.rewards)

    def any_running(carry):
        # Continue while at least one environment has not terminated.
        return jnp.logical_not(jnp.all(carry[0].terminated))

    def advance(carry):
        cur_state, returns, key = carry
        logits, _ = forward.apply(params, cur_state.observation)
        key, sample_key = jax.random.split(key)
        action = distrax.Categorical(logits=logits).sample(seed=sample_key)
        key, step_key = jax.random.split(key)
        step_keys = jax.random.split(step_key, cur_state.observation.shape[0])
        next_state = batched_step(cur_state, action, step_keys)
        return next_state, returns + next_state.rewards, key

    _, returns, _ = jax.lax.while_loop(
        any_running, advance, (init_state, init_return, rng_key)
    )
    return returns.mean()

def train(rng):
    """Run DPO training and return the final runner state.

    Args:
        rng: JAX PRNG key; split for network init, environment init, the
            rollout stream and every evaluation.

    Returns:
        ``(params, opt_state, env_state, last_obs, rng, step_count)`` after
        ``num_updates`` updates.

    An evaluation is printed and logged to wandb before training and after
    every update. The logged ``"sec"`` excludes evaluation time.
    """

    def _log_evaluation(params, eval_key, elapsed, steps):
        # Evaluate the current policy and emit one log record.
        eval_R = evaluate(params, eval_key)
        log = {"sec": elapsed, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
        print(log)
        wandb.log(log)

    tt = 0
    st = time.time()

    # Initialize network
    rng, _rng = jax.random.split(rng)
    init_x = jnp.zeros((1,) + env.observation_shape)
    params = forward.init(_rng, init_x)
    opt_state = optimizer.init(params=params)

    # Initialize update function
    jitted_update_step = jax.jit(make_update_fn())

    # Initialize environment
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, args.num_envs)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)

    rng, _rng = jax.random.split(rng)
    # The update step unpacks six elements; the last one is the running
    # environment-step counter used for the lifetime information. A typed
    # int32 scalar keeps the carry type stable across jitted calls.
    step_count = jnp.zeros((), dtype=jnp.int32)
    runner_state = (params, opt_state, env_state, env_state.observation, _rng, step_count)

    # Warm up (compiles the update); the result is discarded. JAX dispatches
    # asynchronously, so block until the computation has really finished
    # before reading the clock.
    jax.block_until_ready(jitted_update_step(runner_state))

    steps = 0

    # Initial evaluation
    tt += time.time() - st  # exclude evaluation time
    rng, _rng = jax.random.split(rng)
    _log_evaluation(runner_state[0], _rng, tt, steps)
    st = time.time()

    # Training loop
    for _ in range(num_updates):
        runner_state, _ = jitted_update_step(runner_state)
        # Without this the timer would stop right after dispatch and the
        # update's compute would be charged to the excluded evaluation.
        jax.block_until_ready(runner_state)
        steps += args.num_envs * args.num_steps

        # Evaluation
        tt += time.time() - st  # exclude evaluation time
        rng, _rng = jax.random.split(rng)
        _log_evaluation(runner_state[0], _rng, tt, steps)
        st = time.time()

    return runner_state

if __name__ == "__main__":
    # Start the wandb run with the full config, train from the configured
    # seed, and optionally pickle the trained parameters.
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        config=args.dict(),
    )
    rng = jax.random.PRNGKey(args.seed)
    out = train(rng)
    if args.save_model:
        # out[0] holds the network parameters of the final runner state.
        with open(f"{args.env_name}-seed={args.seed}.ckpt", "wb") as f:
            pickle.dump(out[0], f)

In [None]:
# Configuration class for PPO without temporal awareness (TA)
class PPOConfig(BaseModel):
    """Hyperparameters for the PPO run on a pgx MinAtar environment.

    Extra (unknown) fields are rejected via ``Config.extra = "forbid"``.
    """

    # Environment id handed to ``pgx.make``.
    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-space_invaders"
    seed: int = 0
    lr: float = 1e-3  # Adam learning rate
    num_envs: int = 64  # parallel training environments
    num_eval_envs: int = 100  # parallel environments used by evaluate()
    num_steps: int = 128  # rollout length (scan length) per update
    total_timesteps: int = 10_000_000  # budget used to derive num_updates
    update_epochs: int = 4  # passes over each rollout batch
    minibatch_size: int = 4096  # used to derive num_minibatches
    gamma: float = 0.99  # discount factor in GAE
    gae_lambda: float = 0.95  # GAE lambda
    clip_eps: float = 0.2  # clip range for the ratio and the value update
    ent_coef: float = 0.01  # weight of the entropy bonus
    vf_coef: float = 0.5  # weight of the value loss
    max_grad_norm: float = 0.5  # global-norm gradient clipping threshold
    wandb_entity: str = "your-wandb-entity"
    wandb_project: str = "pgx-minatar-ppo"
    save_model: bool = False

    class Config:
        extra = "forbid"

# Single module-level configuration instance read by everything below.
args = PPOConfig()

# Environment built from the configured id.
env = pgx.make(str(args.env_name))

# Derived schedule sizes; integer division drops any remainder.
num_updates = (args.total_timesteps // args.num_envs) // args.num_steps
num_minibatches = (args.num_envs * args.num_steps) // args.minibatch_size

# Actor-Critic model definition
class ActorCritic(hk.Module):
    """Conv/dense trunk shared by separate policy and value MLP heads."""

    def __init__(self, num_actions, activation="tanh"):
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation

    def __call__(self, x):
        """Return ``(logits, value)`` for a batch of observations ``x``."""
        # Anything other than "tanh" selects relu for the heads.
        if self.activation == "tanh":
            act = jax.nn.tanh
        else:
            act = jax.nn.relu

        # Shared trunk: conv -> relu -> 2x2 average pool -> flatten -> dense.
        feats = x.astype(jnp.float32)
        feats = jax.nn.relu(hk.Conv2D(32, kernel_shape=2)(feats))
        feats = hk.avg_pool(feats, window_shape=(2, 2), strides=(2, 2), padding="VALID")
        feats = feats.reshape((feats.shape[0], -1))
        feats = jax.nn.relu(hk.Linear(64)(feats))

        def head(out_dim):
            # The layers are constructed output-first and applied input-first,
            # the same order the nested one-line expression evaluates in, so
            # auto-generated layer names are left as they were.
            out_layer = hk.Linear(out_dim)
            second = hk.Linear(64)
            first = hk.Linear(64)
            return out_layer(act(second(act(first(feats)))))

        # The policy head is built before the value head.
        actor_mean = head(self.num_actions)
        critic = head(1)
        return actor_mean, jnp.squeeze(critic, axis=-1)

def forward_fn(x):
    """Apply a tanh-headed ``ActorCritic`` to ``x``; returns ``(logits, value)``."""
    model = ActorCritic(env.num_actions, activation="tanh")
    return model(x)

# Haiku init/apply pair for the network; apply is called without an RNG key.
forward = hk.without_apply_rng(hk.transform(forward_fn))

# Optimizer: gradients are clipped by global norm first, then passed to Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(args.max_grad_norm),
    optax.adam(learning_rate=args.lr, eps=1e-5),
)

# Transition class
class Transition(NamedTuple):
    """One rollout step as recorded by the trajectory-collection scan."""

    done: jnp.ndarray  # termination flag of the state reached by the step
    action: jnp.ndarray  # action sampled from the policy
    value: jnp.ndarray  # critic output for the pre-step observation
    reward: jnp.ndarray  # reward returned by the step (squeezed)
    log_prob: jnp.ndarray  # log-probability of `action` at sampling time
    obs: jnp.ndarray  # observation the action was chosen from

# Generalized Advantage Estimation
def _calculate_gae(traj_batch, last_val):
    """Generalized Advantage Estimation over a time-major rollout.

    Args:
        traj_batch: stacked ``Transition`` whose leading axis is scanned.
        last_val: value estimate used to bootstrap after the final step.

    Returns:
        ``(advantages, targets)`` with ``targets = advantages + values``.
    """

    def backward_step(carry, step):
        running_gae, value_after = carry
        # (1 - done) removes the bootstrap and the accumulated advantage
        # where the step ended in termination.
        td_error = step.reward + args.gamma * value_after * (1 - step.done) - step.value
        running_gae = td_error + args.gamma * args.gae_lambda * (1 - step.done) * running_gae
        return (running_gae, step.value), running_gae

    init_carry = (jnp.zeros_like(last_val), last_val)
    # The scan walks the trajectory from the last step to the first.
    _, advantages = jax.lax.scan(backward_step, init_carry, traj_batch, reverse=True)
    returns = advantages + traj_batch.value
    return advantages, returns

def _loss_fn(params, traj_batch, advantages, targets):
    """Clipped PPO loss for one minibatch.

    Returns ``(total_loss, (value_loss, loss_actor, entropy))``.
    """
    logits, value = forward.apply(params, traj_batch.obs)
    pi = distrax.Categorical(logits=logits)
    new_log_prob = pi.log_prob(traj_batch.action)

    # Critic: the new value is kept within clip_eps of the recorded value
    # and the larger of the clipped/unclipped squared errors is used.
    clipped_delta = (value - traj_batch.value).clip(-args.clip_eps, args.clip_eps)
    value_pred_clipped = traj_batch.value + clipped_delta
    unclipped_sq = jnp.square(value - targets)
    clipped_sq = jnp.square(value_pred_clipped - targets)
    value_loss = 0.5 * jnp.maximum(unclipped_sq, clipped_sq).mean()

    # Actor: clipped surrogate on batch-standardised advantages.
    ratio = jnp.exp(new_log_prob - traj_batch.log_prob)
    gae = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    surrogate = ratio * gae
    clipped_ratio = jnp.clip(ratio, 1.0 - args.clip_eps, 1.0 + args.clip_eps)
    clipped_surrogate = clipped_ratio * gae
    loss_actor = -jnp.minimum(surrogate, clipped_surrogate).mean()

    # Entropy enters the total with a negative sign (it is a bonus).
    entropy = pi.entropy().mean()

    total_loss = loss_actor + args.vf_coef * value_loss - args.ent_coef * entropy
    return total_loss, (value_loss, loss_actor, entropy)

def make_update_fn():
    """Build the function that performs one PPO update.

    Returns:
        ``_update_step(runner_state) -> (runner_state, None)`` where
        ``runner_state`` is ``(params, opt_state, env_state, last_obs, rng)``.

    NOTE(review): the ``train`` function in this cell calls
    ``make_update_fn(with_drift=..., with_lifetime=...)`` and builds a
    six-element runner state; this function accepts neither. Reconcile the
    two before running that cell.
    """

    def _update_step(runner_state):
        # Wrap env.step with auto_reset (as the first PPO cell of this file
        # does) so environments are re-initialised after terminating; a bare
        # env.step keeps stepping terminated states for the rest of the run.
        step_fn = jax.vmap(auto_reset(env.step, env.init))

        def _env_step(runner_state, unused):
            # Scan body: sample one action per environment and step once.
            params, opt_state, env_state, last_obs, rng = runner_state
            rng, _rng = jax.random.split(rng)
            logits, value = forward.apply(params, last_obs)
            pi = distrax.Categorical(logits=logits)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)
            rng, _rng = jax.random.split(rng)
            # One key per environment (leading axis of the observation).
            keys = jax.random.split(_rng, env_state.observation.shape[0])
            env_state = step_fn(env_state, action, keys)
            # Post-step termination flag and reward are stored with the
            # pre-step observation, value and log-probability.
            transition = Transition(
                env_state.terminated,
                action,
                value,
                jnp.squeeze(env_state.rewards),
                log_prob,
                last_obs,
            )
            runner_state = (params, opt_state, env_state, env_state.observation, rng)
            return runner_state, transition

        # Collect args.num_steps transitions, stacked along a leading axis.
        runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, args.num_steps)
        params, opt_state, env_state, last_obs, rng = runner_state
        # Bootstrap value for the observation after the final step.
        _, last_val = forward.apply(params, last_obs)
        advantages, targets = _calculate_gae(traj_batch, last_val)

        def _update_epoch(update_state, unused):
            # Scan body: one shuffled pass over the whole rollout batch.
            def _update_minbatch(tup, batch_info):
                # Scan body: one optimizer step on a single minibatch.
                params, opt_state = tup
                traj_batch, advantages, targets = batch_info
                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                total_loss, grads = grad_fn(params, traj_batch, advantages, targets)
                updates, opt_state = optimizer.update(grads, opt_state)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), total_loss

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            rng, _rng = jax.random.split(rng)
            batch_size = args.minibatch_size * num_minibatches
            # Evaluated at trace time on Python ints.
            assert batch_size == args.num_steps * args.num_envs, "Batch size mismatch."
            permutation = jax.random.permutation(_rng, batch_size)
            batch = (traj_batch, advantages, targets)
            # Merge the (time, env) axes, shuffle, then split into minibatches.
            batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
            shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])),
                shuffled_batch,
            )
            (params, opt_state), _ = jax.lax.scan(_update_minbatch, (params, opt_state), minibatches)
            # The unshuffled rollout is carried forward for the next epoch.
            update_state = (params, opt_state, traj_batch, advantages, targets, rng)
            return update_state, None

        update_state = (params, opt_state, traj_batch, advantages, targets, rng)
        update_state, _ = jax.lax.scan(_update_epoch, update_state, None, args.update_epochs)
        # update_state has six elements; take the advanced rng as well so the
        # next update does not reuse the shuffling keys of this one.
        params, opt_state, _, _, _, rng = update_state
        runner_state = (params, opt_state, env_state, last_obs, rng)
        return runner_state, None

    return _update_step

# Evaluation Function
@jax.jit
def evaluate(params, rng_key):
    """Roll out the sampled policy in fresh environments and return the mean return.

    ``args.num_eval_envs`` environments are initialised and stepped in
    lockstep, with actions sampled from the categorical policy given by
    ``forward.apply``, until every environment reports ``terminated``.

    Args:
        params: Network parameters passed to ``forward.apply``.
        rng_key: PRNG key used for environment initialisation, action
            sampling and environment stepping.

    Returns:
        The mean of the rewards accumulated over the rollout.
    """
    batched_step = jax.vmap(env.step)

    rng_key, init_key = jax.random.split(rng_key)
    init_keys = jax.random.split(init_key, args.num_eval_envs)
    initial_state = jax.vmap(env.init)(init_keys)
    initial_returns = jnp.zeros_like(initial_state.rewards)

    def _any_env_running(carry):
        # Keep looping until all environments have terminated.
        return ~carry[0].terminated.all()

    def _rollout_step(carry):
        env_state, returns, key = carry
        logits, _ = forward.apply(params, env_state.observation)
        key, sample_key = jax.random.split(key)
        action = distrax.Categorical(logits=logits).sample(seed=sample_key)
        key, step_key = jax.random.split(key)
        # One key per environment for the vmapped step.
        step_keys = jax.random.split(step_key, env_state.observation.shape[0])
        next_state = batched_step(env_state, action, step_keys)
        return next_state, returns + next_state.rewards, key

    _, total_returns, _ = jax.lax.while_loop(
        _any_env_running,
        _rollout_step,
        (initial_state, initial_returns, rng_key),
    )
    return total_returns.mean()

# Training Function
def train(rng, mode="ppo"):
    """
    Train a policy with the PPO, DPO or LPO update rule, evaluating and
    logging after every update.

    Args:
        rng: JAX PRNG key used for network initialisation, environment
            initialisation, the update step and evaluation.
        mode: Can be "ppo", "dpo", or "lpo". Selects the ``with_drift`` /
            ``with_lifetime`` flags passed to ``make_update_fn``.

    Returns:
        The runner state returned by the last update step.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    tt = 0  # accumulated wall-clock seconds, excluding evaluation/logging
    st = time.time()

    # Initialize Network
    rng, _rng = jax.random.split(rng)
    init_x = jnp.zeros((1,) + env.observation_shape)
    params = forward.init(_rng, init_x)
    opt_state = optimizer.init(params=params)

    # Initialize Update Function
    if mode == "ppo":
        _update_step = make_update_fn(with_drift=False, with_lifetime=False)
    elif mode == "dpo":
        _update_step = make_update_fn(with_drift=True, with_lifetime=False)
    elif mode == "lpo":
        _update_step = make_update_fn(with_drift=True, with_lifetime=True)
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    jitted_update_step = jax.jit(_update_step)

    # Initialize Environment
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, args.num_envs)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)

    rng, _rng = jax.random.split(rng)
    # NOTE(review): this runner state has six entries (the last is a step
    # count starting at 0) -- confirm the update step accepts and returns a
    # tuple of the same arity, since its output is fed back in below.
    runner_state = (params, opt_state, env_state, env_state.observation, _rng, 0)  # step_count = 0

    # Warm-Up: run once so compilation happens before the timed loop; the
    # result is discarded.
    _, _ = jitted_update_step(runner_state)

    steps = 0

    # Initial Evaluation
    et = time.time()  # exclude evaluation time
    tt += et - st
    rng, _rng = jax.random.split(rng)
    eval_R = evaluate(runner_state[0], _rng)
    log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
    print(log)
    wandb.log(log)
    st = time.time()

    for _ in range(num_updates):
        runner_state, loss_info = jitted_update_step(runner_state)
        steps += args.num_envs * args.num_steps

        # Evaluation
        et = time.time()  # exclude evaluation time
        tt += et - st
        rng, _rng = jax.random.split(rng)
        eval_R = evaluate(runner_state[0], _rng)
        log = {
            "sec": tt,
            f"{args.env_name}/eval_R": float(eval_R),
            "steps": steps,
        }
        # The update step may return no loss information (None); indexing it
        # unconditionally would raise TypeError, so log losses only when
        # they are actually provided.
        if loss_info is not None:
            log[f"{args.env_name}/loss_actor"] = float(loss_info[1])
            log[f"{args.env_name}/loss_critic"] = float(loss_info[0])
        print(log)
        wandb.log(log)
        st = time.time()

    return runner_state

if __name__ == "__main__":
    # Initialize Weights and Biases
    wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=args.dict())

    # RNG Initialization
    rng = jax.random.PRNGKey(args.seed)

    # Train PPO
    print("Training PPO...")
    # Keep the final runner state: the checkpoint below needs its first
    # entry (the trained parameters). Previously the result was discarded
    # and `out` was undefined, so saving raised NameError.
    out = train(rng, mode="ppo")

    if args.save_model:
        with open(f"{args.env_name}-seed={args.seed}.ckpt", "wb") as f:
            pickle.dump(out[0], f)