In [1]:
from functools import partial
import mediapy as mp

import jax
from jax import numpy as jnp
from jax import random as jr
from jax import lax

import mctx
import optax
import equinox as eqx

import mcts

In [None]:
class Model(eqx.Module):
    """Two-headed network with a scalar value head and a policy-logit head.

    Each head is its own MLP (two hidden layers of width 64) applied to the
    flattened observation; the heads share no parameters.
    """

    value: eqx.nn.MLP
    policy: eqx.nn.MLP

    def __init__(self, input_shape, num_actions, key):
        """Create both heads.

        Args:
            input_shape: Indexable; its first two entries are multiplied to
                obtain the flattened input size.
            num_actions: Output size of the policy head.
            key: PRNG key, split once so each head gets its own key.
        """
        value_key, policy_key = jr.split(key)
        flat_size = input_shape[0] * input_shape[1]
        shared = dict(in_size=flat_size, width_size=64, depth=2)
        self.value = eqx.nn.MLP(out_size=1, key=value_key, **shared)
        self.policy = eqx.nn.MLP(out_size=num_actions, key=policy_key, **shared)

    def __call__(self, obs):
        """Return ``(value, logits)`` for a single, unbatched observation."""
        features = jnp.ravel(obs)
        logits = self.policy(features)
        # The value head emits a length-1 vector; drop that axis to get a scalar.
        value = jnp.squeeze(self.value(features))
        return value, logits


def root_fn(env, obs, model, rng_key):
    """Build the search root from a batch of observations.

    The model is vmapped over the leading axis of ``obs``; the observations
    themselves are stored as the root embedding. ``env`` and ``rng_key`` are
    accepted but not used here.
    """
    values, prior_logits = jax.vmap(model)(obs)
    return mctx.RootFnOutput(prior_logits=prior_logits, value=values, embedding=obs)


def recurrent_fn(params, rng_key, action, old_state):
    """Advance one state by one action and evaluate the resulting state.

    Args:
        params: ``(env, model)`` pair.
        rng_key: PRNG key; split to obtain the key handed to ``env.step``.
        action: Action to apply; cast to int32 before stepping.
        old_state: State the action is applied to.

    Returns:
        ``(mctx.RecurrentFnOutput, next_state)``. The discount is 0 where the
        step reports done and 1 otherwise.
    """
    env, model = params
    _, step_key = jr.split(rng_key)
    int_action = action.astype(jnp.int32)
    next_state, reward, done = env.step(old_state, int_action, step_key)
    value, prior_logits = model(next_state)
    output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.where(done, 0.0, 1.0),
        prior_logits=prior_logits,
        value=value,
    )
    return output, next_state


def act(state, model, env, rng_key, num_simulations=500, max_depth=20):
    """Choose an action for a single state with Gumbel MuZero search.

    The state is given a leading batch axis of size one before the search.

    Returns:
        ``(action, action_weights, root_value)``: the action of the single
        batch entry, plus the search's action weights and the root value, both
        still carrying the size-one batch axis.
    """
    rng_key, root_key = jr.split(rng_key)
    _, search_key = jr.split(rng_key)

    # The search calls the recurrent function on batches: actions and states
    # are mapped over axis 0 while params and the key are shared.
    batched_recurrent_fn = jax.vmap(recurrent_fn, in_axes=(None, None, 0, 0))

    root = root_fn(env, state[None], model, root_key)
    policy_output = mctx.gumbel_muzero_policy(
        params=(env, model),
        rng_key=search_key,
        root=root,
        recurrent_fn=batched_recurrent_fn,
        num_simulations=num_simulations,
        max_depth=max_depth,
    )
    return policy_output.action[0], policy_output.action_weights, root.value


def loss_fn(model, batch, env, rng_key):
    """Unrolled value + policy loss with L2 regularisation.

    Starting from the first observation of every sequence in ``batch``, the
    environment is stepped forward with the recorded actions. At each step the
    model's value is regressed (mean squared error) onto ``batch.returns`` and
    its logits are fitted (mean softmax cross-entropy) to
    ``batch.action_probs``. The per-step losses are summed over the sequence
    and ``1e-4 * 0.5 * sum(p ** 2)`` over the model's array leaves is added.

    Args:
        model: Callable mapping one observation to ``(value, logits)``.
        batch: Object exposing ``obs``, ``action``, ``returns`` and
            ``action_probs``; ``action`` must be 2-D, as its shape is unpacked
            into ``(batch_size, seq_len)``.
        env: Environment whose ``step(obs, action, key)`` returns a 3-tuple
            starting with the next observation.
        rng_key: PRNG key; split once per unroll step.

    Returns:
        Scalar loss.
    """
    batch_size, seq_len = batch.action.shape

    def body(carry, i):
        loss, obs, rng = carry
        v_pred, logits_pred = jax.vmap(model)(obs)
        value_loss = jnp.mean((v_pred - batch.returns[:, i]) ** 2)
        policy_loss = jnp.mean(
            optax.softmax_cross_entropy(logits_pred, batch.action_probs[:, i])
        )
        rng, step_key = jr.split(rng)
        step_keys = jr.split(step_key, batch_size)  # one key per batch element
        # Only the next observation feeds the unroll; reward and done are unused.
        next_obs, _, _ = jax.vmap(env.step)(obs, batch.action[:, i], step_keys)
        return (loss + value_loss + policy_loss, next_obs, rng), None

    first_obs = batch.obs[:, 0]
    (loss, _, _), _ = lax.scan(body, (0.0, first_obs, rng_key), jnp.arange(seq_len))
    l2 = 0.5 * sum(
        jnp.sum(jnp.square(p))
        for p in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array))
    )
    return loss + 1e-4 * l2


@eqx.filter_jit
def train_step(model, batch, env, optim, opt_state, rng_key):
    """Run one gradient step on ``loss_fn``.

    Args:
        model: Equinox model being trained.
        batch: Training batch forwarded to ``loss_fn``.
        env: Environment forwarded to ``loss_fn``.
        optim: Optax gradient transformation.
        opt_state: Optimiser state matching ``optim``.
        rng_key: PRNG key forwarded to ``loss_fn``.

    Returns:
        ``(model, opt_state, loss)`` after the update.
    """
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, batch, env, rng_key)
    # Hand the optimiser only the array leaves. The raw model also carries
    # non-array leaves (e.g. activation functions) that have no counterpart in
    # ``grads``, which breaks transforms that read the params (weight decay).
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

In [3]:
def scan_step(env, model, carry, _):
    """Scan body: pick an action by search, step the environment, record data.

    Args:
        env: Environment providing ``step`` and ``render``.
        model: Model used by ``act``.
        carry: ``(state, key)`` pair.
        _: Unused scan input.

    Returns:
        ``((next_state, key), outputs)`` where ``outputs`` is
        ``(state, action, reward, done, root_value, action_probs, frame)``.
        The recorded state and the rendered frame are those from *before* the
        step.
    """
    state, key = carry
    key, act_key = jr.split(key)
    key, step_key = jr.split(key)

    action, action_probs, root_value = act(state, model, env, act_key)
    next_state, reward, done = env.step(state, action, step_key)
    frame = env.render(state)

    outputs = (state, action, reward, done, root_value, action_probs, frame)
    return (next_state, key), outputs


@eqx.filter_jit
def rollout_episode(state, env, model, rng_key, max_steps=200):
    """Collect ``max_steps`` transitions by scanning ``scan_step``.

    Args:
        state: State the rollout starts from.
        env: Environment forwarded to ``scan_step``.
        model: Model forwarded to ``scan_step``.
        rng_key: PRNG key threaded through the scan carry.
        max_steps: Number of scan iterations.

    Returns:
        ``(transition, final_state, frames)``. ``transition`` is an
        ``mcts.Transition`` whose ``returns`` are zeros and whose ``weight`` is
        ones, both of length ``max_steps``.
    """
    (final_state, _), data = jax.lax.scan(
        partial(scan_step, env, model), (state, rng_key), None, length=max_steps
    )

    states, acts, rews, dones, vals, probs, frames = data
    num_steps = states.shape[0]
    transition = mcts.Transition(
        obs=states,
        action=acts,
        reward=rews,
        done=dones,
        # ``act`` searches a batch of one, so values and action weights carry a
        # size-one axis after the time axis. Squeeze only that axis: a bare
        # ``squeeze()`` would also collapse the time axis when max_steps == 1.
        value=vals.squeeze(axis=1),
        action_probs=probs.squeeze(axis=1),
        returns=jnp.zeros((num_steps,)),
        weight=jnp.ones((num_steps,)),
    )
    return transition, final_state, frames

In [4]:
# --- Settings ---------------------------------------------------------------
num_episodes = 16
num_warmup_episodes = 4
num_train_epochs = 50
rollout_steps = 500  # scan iterations per call to rollout_episode
batch_size = 32  # sequences per training batch
unroll_steps = 10  # steps per sampled sequence
learning_rate = 3e-4
returns_kwargs = dict(n=10, gamma=0.99, alpha=0.5)  # forwarded to mcts.compute_returns

# --- Setup ------------------------------------------------------------------
# Derive a dedicated key for every consumer; a key that has been used must not
# be split or used again.
key = jr.PRNGKey(0)
key, reset_key, model_key = jr.split(key, 3)

env = mcts.Pong()
state = env.reset(reset_key)
buffer = mcts.Buffer()

model = Model(env.observation_shape, env.action_shape, model_key)

optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

# --- Warm-up: fill the buffer without training -------------------------------
# The final state of each rollout seeds the next one; the environment is not
# reset between rollouts.
for ep in range(num_warmup_episodes):
    key, subkey = jr.split(key)
    trans, state, _ = rollout_episode(
        state, env, model, subkey, max_steps=rollout_steps
    )
    trans = mcts.compute_returns(trans, **returns_kwargs)
    buffer.add(trans)


# --- Main loop: collect one rollout, show it, then train on the buffer --------
for ep in range(num_episodes):
    key, subkey = jr.split(key)
    trans, state, frames = rollout_episode(
        state, env, model, subkey, max_steps=rollout_steps
    )
    trans = mcts.compute_returns(trans, **returns_kwargs)
    buffer.add(trans)

    with mp.set_show_save_dir("./data"):
        mp.show_videos({"video": frames}, fps=30)

    for epoch in range(num_train_epochs):
        key, subkey = jr.split(key)
        batch = buffer.sample(subkey, batch_size=batch_size, steps=unroll_steps)

        key, subkey = jr.split(key)
        model, opt_state, loss = train_step(model, batch, env, optim, opt_state, subkey)

0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.


0
video  This browser does not support the video tag.
