In [1]:
import mediapy as mp
from functools import partial

import jax
from jax import numpy as jnp
from jax import random as jr
from jax import lax, vmap, nn

import optax
import equinox as eqx

import mctx
import mcts

In [2]:
class Policy(eqx.Module):
    """Two-headed network mapping a flattened observation to value and policy logits.

    The heads are two independent MLPs (there is no shared trunk); both consume
    the same flattened observation.
    """

    # Produces logits over the ``2 * support_size + 1`` value bins.
    value_head: eqx.nn.MLP
    # Produces one logit per action.
    policy_head: eqx.nn.MLP

    def __init__(
        self,
        input_shape,
        num_actions,
        support_size,
        key: jax.Array,
        width_size: int = 64,
        depth: int = 2,
    ):
        """Build both heads.

        Args:
            input_shape: Shape of a single (un-batched) observation. Any rank is
                accepted; the observation is flattened before use.
            num_actions: Output size of the policy head.
            support_size: Half-width of the value support; the value head emits
                ``2 * support_size + 1`` logits.
            key: JAX PRNG key, split once to initialise the two heads.
            width_size: Hidden-layer width of both MLPs.
            depth: Number of hidden layers of both MLPs.
        """
        full_support_size = support_size * 2 + 1

        # ``__call__`` flattens the *whole* observation, so the MLP input width
        # has to be the product of every dimension, not just the first two.
        in_features = 1
        for dim in input_shape:
            in_features *= dim

        value_key, policy_key = jr.split(key)

        self.value_head = eqx.nn.MLP(
            in_size=in_features,
            out_size=full_support_size,
            width_size=width_size,
            depth=depth,
            key=value_key,
        )
        self.policy_head = eqx.nn.MLP(
            in_size=in_features,
            out_size=num_actions,
            width_size=width_size,
            depth=depth,
            key=policy_key,
        )

    def __call__(self, obs: jnp.ndarray):
        """Return ``(value_logits, policy_logits)`` for one un-batched observation."""
        flat_obs = obs.reshape(-1)
        value_logits = self.value_head(flat_obs)
        policy_logits = self.policy_head(flat_obs)
        return value_logits, policy_logits


def loss_fn(model, batch, env, support_size, rng_key):
    """Value + policy cross-entropy over a re-simulated trajectory, plus L2 penalty.

    Starting from the first stored observation of every trajectory in ``batch``,
    the environment is stepped with the recorded actions. At each step the
    model's value logits are scored against the discretised returns and its
    policy logits against the stored action probabilities.

    Args:
        model: Callable mapping one observation to ``(value_logits, policy_logits)``.
        batch: Object exposing ``obs``, ``action``, ``returns`` and
            ``action_probs``; ``action`` is 2-D ``(batch_size, traj_len)``.
        env: Environment whose ``step(obs, action, key)`` returns
            ``(next_obs, reward, done)``.
        support_size: Passed to ``mcts.to_discrete`` to discretise the returns.
        rng_key: PRNG key; split once per unrolled step and once per batch element.

    Returns:
        ``(total_loss, aux)`` where ``aux`` holds the total, value, policy and
        weighted L2 terms.
    """
    batch_size, traj_len = batch.action.shape
    discrete_returns = mcts.to_discrete(batch.returns, support_size)
    batched_model = jax.vmap(model)
    batched_env_step = jax.vmap(env.step)

    def unroll_step(carry, step):
        loss_acc, value_acc, policy_acc, cur_obs, cur_key = carry

        value_logits, policy_logits = batched_model(cur_obs)
        step_value_loss = jnp.mean(
            optax.softmax_cross_entropy(value_logits, discrete_returns[:, step])
        )
        step_policy_loss = jnp.mean(
            optax.softmax_cross_entropy(policy_logits, batch.action_probs[:, step])
        )

        # One fresh key per batch element for the environment transition.
        cur_key, step_key = jr.split(cur_key)
        per_env_keys = jr.split(step_key, batch_size)
        # The reward and done flag returned by the environment are not used here.
        following_obs, _, _ = batched_env_step(
            cur_obs, batch.action[:, step], per_env_keys
        )

        updated = (
            loss_acc + step_value_loss + step_policy_loss,
            value_acc + step_value_loss,
            policy_acc + step_policy_loss,
            following_obs,
            cur_key,
        )
        return updated, None

    start = (0.0, 0.0, 0.0, batch.obs[:, 0], rng_key)
    (summed_loss, summed_value, summed_policy, _, _), _ = lax.scan(
        unroll_step, start, jnp.arange(traj_len)
    )

    # L2 penalty over every array leaf of the model.
    array_leaves = jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array))
    squared_norms = [jnp.sum(jnp.square(leaf)) for leaf in array_leaves]
    weighted_l2 = 1e-4 * (0.5 * sum(squared_norms))

    total_loss = (summed_loss / traj_len) + weighted_l2

    return total_loss, {
        "total_loss": total_loss,
        "value_loss": summed_value / traj_len,
        "policy_loss": summed_policy / traj_len,
        "l2_loss": weighted_l2,
    }


@eqx.filter_jit
def train_step(model, batch, env, optim, opt_state, support_size, rng_key):
    """Run one gradient update of ``model`` on ``batch``.

    Args:
        model: Equinox module to optimise.
        batch: Training batch forwarded to ``loss_fn``.
        env: Environment forwarded to ``loss_fn``.
        optim: optax gradient transformation.
        opt_state: Current optimiser state.
        support_size: Forwarded to ``loss_fn``.
        rng_key: PRNG key forwarded to ``loss_fn``.

    Returns:
        ``(model, opt_state, aux)`` — the updated model, the updated optimiser
        state and the auxiliary loss dictionary produced by ``loss_fn``.
    """
    (_, aux), grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(
        model, batch, env, support_size, rng_key
    )
    # Hand the optimiser only the array leaves: the gradients and the optimiser
    # state were built from the filtered model, and the raw module also carries
    # non-array leaves that a params-aware optimiser could not process.
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, aux

In [3]:
def root_fn(env, obs, model, support_size, rng_key):
    """Build the mctx search root for a batch of observations.

    The model is applied to every observation; its value logits are turned into
    a scalar through a softmax followed by ``mcts.from_discrete``. The raw
    observations serve as the search embedding.

    ``env`` and ``rng_key`` are accepted but not used by this function.
    """
    value_logits, prior_logits = jax.vmap(model)(obs)
    root_value = mcts.from_discrete(nn.softmax(value_logits), support_size)
    return mctx.RootFnOutput(
        prior_logits=prior_logits,
        value=root_value,
        embedding=obs,
    )


def recurrent_fn(params, rng_key, action, embedding):
    """mctx recurrent function: step the real environment and evaluate the model.

    Args:
        params: Tuple ``(env, model, support_size)``.
        rng_key: PRNG key, split into one key per batch element.
        action: Batch of actions; cast to ``int32`` before stepping.
        embedding: Batch of environment states (leading axis is the batch).

    Returns:
        ``(mctx.RecurrentFnOutput, next_state)``; the discount is 0 where the
        environment reports ``done`` and 1 elsewhere.
    """
    env, model, support_size = params

    per_env_keys = jr.split(rng_key, embedding.shape[0])
    int_actions = action.astype(jnp.int32)
    next_state, reward, done = vmap(env.step)(embedding, int_actions, per_env_keys)

    value_logits, prior_logits = jax.vmap(model)(next_state)
    next_value = mcts.from_discrete(nn.softmax(value_logits), support_size)

    output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.where(done, 0.0, 1.0),
        prior_logits=prior_logits,
        value=next_value,
    )
    return output, next_state


def act(state, model, env, support_size, rng_key, num_simulations=500, max_depth=20):
    """Choose an action for a single state with Gumbel MuZero search.

    A batch axis of size one is added to ``state``, the search is run, and the
    batch axis is removed from the results again.

    Returns:
        ``(action, action_weights, root_value)`` for the given state.
    """
    # The same key is handed to root_fn and to the search itself.
    root = root_fn(env, state[None], model, support_size, rng_key)
    search = mctx.gumbel_muzero_policy(
        params=(env, model, support_size),
        rng_key=rng_key,
        root=root,
        recurrent_fn=recurrent_fn,
        num_simulations=num_simulations,
        max_depth=max_depth,
    )
    chosen_action = search.action[0]
    search_policy = search.action_weights[0]
    return chosen_action, search_policy, root.value[0]

In [4]:
def scan_step(env, model, support_size, carry):
    """Advance the environment by one searched action.

    Args:
        env: Environment providing ``step`` and ``render``.
        model: Model handed to ``act``.
        support_size: Forwarded to ``act``.
        carry: ``(state, key)`` pair.

    Returns:
        ``((next_state, key), record)`` where ``record`` is
        ``(state, action, reward, done, value, action_probs, frame)`` and
        ``frame`` is the rendering of the state *before* the step.
    """
    state, key = carry

    # Separate keys for action selection and for the environment transition.
    key, search_key = jr.split(key)
    action, action_probs, value = act(state, model, env, support_size, search_key)

    key, env_key = jr.split(key)
    next_state, reward, done = env.step(state, action, env_key)

    frame = env.render(state)
    record = (state, action, reward, done, value, action_probs, frame)
    return (next_state, key), record


@eqx.filter_jit
def rollout_episode(state, env, model, support_size, rng_key, max_steps=200):
    """Roll the environment forward for exactly ``max_steps`` searched actions.

    The scan has a fixed length: the ``done`` flag is recorded per step but does
    not stop or reset the rollout inside this function.

    Returns:
        ``(transition, final_state, frames)``. ``transition`` is an
        ``mcts.Transition`` whose ``returns`` field is a zero placeholder and
        whose ``weight`` field is all ones, both shaped like the rewards.
    """
    step = partial(scan_step, env, model, support_size)

    (final_state, _), trace = lax.scan(
        lambda carry, _: step(carry),
        (state, rng_key),
        None,
        length=max_steps,
    )
    states, actions, rewards, dones, values, policies, frames = trace

    transition = mcts.Transition(
        obs=states,
        action=actions,
        reward=rewards,
        done=dones,
        value=values,
        action_probs=policies,
        returns=jnp.zeros_like(rewards),
        weight=jnp.ones_like(rewards),
    )
    return transition, final_state, frames

In [None]:
# --- Configuration ------------------------------------------------------------
seed = 0
support_size = 2
num_episodes = 16
num_warmup_episodes = 4
num_episode_steps = 500
num_train_epochs = 50
learning_rate = 3e-4
batch_size = 32
unroll_steps = 10

# --- Environment, replay buffer, model and optimiser --------------------------
key = jr.PRNGKey(seed)
env = mcts.Pong()
# Split off a dedicated key for the reset: a key that has been consumed must
# not be split again afterwards, otherwise the random streams are correlated.
key, reset_key = jr.split(key)
state = env.reset(reset_key)
buffer = mcts.Buffer()

model_key, key = jr.split(key)
model = Policy(
    input_shape=env.observation_shape,
    num_actions=env.action_shape,
    support_size=support_size,
    key=model_key,
)
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

# --- Warm-up: fill the buffer before any training happens ---------------------
for _ in range(num_warmup_episodes):
    key, subkey = jr.split(key)
    # Rendered frames of the warm-up episodes are discarded.
    trajectory, state, _ = rollout_episode(
        state, env, model, support_size, subkey, num_episode_steps
    )
    trajectory = mcts.compute_returns(trajectory)
    buffer.add(trajectory, jnp.mean(trajectory.weight))

# --- Main loop: collect one episode, show it, then train ----------------------
for ep in range(num_episodes):
    key, subkey = jr.split(key)
    trajectory, state, frames = rollout_episode(
        state, env, model, support_size, subkey, num_episode_steps
    )
    trajectory = mcts.compute_returns(trajectory)
    buffer.add(trajectory, jnp.mean(trajectory.weight))

    # NOTE(review): every episode is shown under the same title "video"; with a
    # save directory set this presumably rewrites the same file each time —
    # confirm whether per-episode titles are wanted.
    with mp.set_show_save_dir("."):
        mp.show_videos({"video": frames}, fps=30)

    for _ in range(num_train_epochs):
        key, subkey = jr.split(key)
        batch = buffer.sample(subkey, batch_size=batch_size, steps=unroll_steps)
        key, subkey = jr.split(key)
        # `metrics` holds the loss dictionary of the latest update; it is not
        # logged anywhere in this cell.
        model, opt_state, metrics = train_step(
            model, batch, env, optim, opt_state, support_size, subkey
        )

0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.
