# Distributed Anakin Agent in `gymnax`
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/01_anakin.ipynb)

Adapted from Hessel et al. (2021) and DeepMind's [Example Colab](https://colab.research.google.com/drive/1974D-qP17fd5mLxy6QZv-ic4yxlPJp-G?usp=sharing#scrollTo=lhnJkrYLOvcs)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install -q git+https://github.com/RobertTLange/gymnax.git@main
!pip install -q dm-haiku rlax

In [1]:
import chex
import os

# Set number of host devices before importing JAX!
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=4"

import jax
import haiku as hk
from jax import lax
from jax import random
from jax import numpy as jnp
import jax.numpy as jnp
import optax
import rlax
import timeit

jax.devices()

[GpuDevice(id=0, process_index=0),
 GpuDevice(id=1, process_index=0),
 GpuDevice(id=2, process_index=0),
 GpuDevice(id=3, process_index=0)]

# Import `gymnax` and make `Catch-bsuite` environment transition/reset

In [2]:
import gymnax
from flax.serialization import to_state_dict, from_state_dict
from gymnax.environments.minatar.space_invaders import EnvState

env, env_params = gymnax.make("SpaceInvaders-MinAtar")

# Anakin DQN-Style (No Target Net) Distributed Agent Setup

In [3]:
@chex.dataclass(frozen=True)
class TimeStep:
    q_values: chex.Array
    action: chex.Array
    discount: chex.Array
    reward: chex.Array

def get_network_fn(num_outputs: int):
    """Define a fully connected multi-layer haiku network."""
    def network_fn(obs: chex.Array, rng: chex.PRNGKey) -> chex.Array:
        return hk.Sequential([  # flatten, 2x hidden + relu, output layer.
            hk.Flatten(),
            hk.Linear(256), jax.nn.relu,
            hk.Linear(256), jax.nn.relu,
            hk.Linear(num_outputs)])(obs)
    return hk.without_apply_rng(hk.transform(network_fn))

def get_learner_fn(
    env, forward_pass, opt_update, rollout_len, agent_discount,
    lambda_, iterations):
    """Define the minimal unit of computation in Anakin."""

    def loss_fn(params, outer_rng, env_state):
        """Compute the loss on a single trajectory."""

        def step_fn(env_state, rng):
            obs = env.get_obs(env_state)
            q_values = forward_pass(params, obs[None,], None)[0]  # forward pass.
            action = jnp.argmax(q_values)  # greedy policy.
            obs, env_state, reward, terminal, info = env.step(rng, env_state, action)  # step environment.
            return env_state, TimeStep(  # return env state and transition data.
              q_values=q_values, action=action, discount=1.-terminal, reward=reward)

        step_rngs = random.split(outer_rng, rollout_len)
        env_state, rollout = lax.scan(step_fn, env_state, step_rngs)  # trajectory.
        qa_tm1 = rlax.batched_index(rollout.q_values[:-1], rollout.action[:-1])
        td_error = rlax.td_lambda(  # compute multi-step temporal diff error.
            v_tm1=qa_tm1,  # predictions.
            r_t=rollout.reward[1:],  # rewards.
            discount_t=agent_discount * rollout.discount[1:],  # discount.
            v_t=jnp.max(rollout.q_values[1:], axis=-1),  # bootstrap values.
            lambda_=lambda_)  # mixing hyper-parameter lambda.
        return jnp.mean(td_error**2), env_state

    def update_fn(params, opt_state, rng, env_state):
        """Compute a gradient update from a single trajectory."""
        rng, loss_rng = random.split(rng)
        grads, new_env_state = jax.grad(  # compute gradient on a single trajectory.
            loss_fn, has_aux=True)(params, loss_rng, env_state)
        grads = lax.pmean(grads, axis_name='j')  # reduce mean across cores.
        grads = lax.pmean(grads, axis_name='i')  # reduce mean across batch.
        updates, new_opt_state = opt_update(grads, opt_state)  # transform grads.
        new_params = optax.apply_updates(params, updates)  # update parameters.
        return new_params, new_opt_state, rng, new_env_state

    def learner_fn(params, opt_state, rngs, env_states):
        """Vectorise and repeat the update."""
        batched_update_fn = jax.vmap(update_fn, axis_name='j')  # vectorize across batch.
        def iterate_fn(_, val):  # repeat many times to avoid going back to Python.
            params, opt_state, rngs, env_states = val
            return batched_update_fn(params, opt_state, rngs, env_states)
        return lax.fori_loop(0, iterations, iterate_fn, (
            params, opt_state, rngs, env_states))

    return learner_fn

# Rollout/Step the Anakin Agent in Parallel

In [4]:
class TimeIt():
    def __init__(self, tag, frames=None):
        self.tag = tag
        self.frames = frames

    def __enter__(self):
        self.start = timeit.default_timer()
        return self

    def __exit__(self, *args):
        self.elapsed_secs = timeit.default_timer() - self.start
        msg = self.tag + (': Elapsed time=%.2fs' % self.elapsed_secs)
        if self.frames:
            msg += ', FPS=%.2e' % (self.frames / self.elapsed_secs)
        print(msg)


def run_experiment(env, batch_size, rollout_len, step_size, iterations, seed):
    """Runs experiment."""
    cores_count = len(jax.devices())  # get available TPU cores.
    network = get_network_fn(env.num_actions)  # define network.
    optim = optax.adam(step_size)  # define optimiser.

    rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3)  # prng keys.
    obs, state = env.reset(rng_e)
    dummy_obs = obs[None,]  # dummy for net init.
    params = network.init(rng_p, dummy_obs, None)  # initialise params.
    opt_state = optim.init(params)  # initialise optimiser stats.

    learn = get_learner_fn(  # get batched iterated update.
      env, network.apply, optim.update, rollout_len=rollout_len,
      agent_discount=1, lambda_=0.99, iterations=iterations)
    learn = jax.pmap(learn, axis_name='i')  # replicate over multiple cores.

    broadcast = lambda x: jnp.broadcast_to(x, (cores_count, batch_size) + x.shape)
    params = jax.tree.map(broadcast, params)  # broadcast to cores and batch.
    opt_state = jax.tree.map(broadcast, opt_state)  # broadcast to cores and batch

    rng, *env_rngs = jax.random.split(rng, cores_count * batch_size + 1)
    env_obs, env_states = jax.vmap(env.reset)(jnp.stack(env_rngs))  # init envs.
    rng, *step_rngs = jax.random.split(rng, cores_count * batch_size + 1)

    reshape = lambda x: x.reshape((cores_count, batch_size) + x.shape[1:])
    step_rngs = reshape(jnp.stack(step_rngs))  # add dimension to pmap over.
    env_obs = reshape(env_obs)  # add dimension to pmap over.
    env_states_re = to_state_dict(env_states)
    env_states = {k: reshape(env_states_re[k]) for k in env_states_re.keys()}
    env_states = EnvState(**env_states)
    with TimeIt(tag='COMPILATION'):
        learn(params, opt_state, step_rngs, env_states)  # compiles

    num_frames = cores_count * iterations * rollout_len * batch_size
    with TimeIt(tag='EXECUTION', frames=num_frames):
        params, opt_state, step_rngs, env_states = learn(  # runs compiled fn
            params, opt_state, step_rngs, env_states)
    return params

In [5]:
print('Running on', len(jax.devices()), 'cores.', flush=True)
batch_params = run_experiment(env, 128, 16, 3e-4, 10000, 42)

Running on 4 cores.




COMPILATION: Elapsed time=115.92s
EXECUTION: Elapsed time=106.05s, FPS=7.72e+05


# Performance Evaluation

In [8]:
# Get model ready for evaluation - squeeze broadcasted params
model = get_network_fn(env.num_actions)
squeeze = lambda x: x[0][0]
params = jax.tree.map(squeeze, batch_params)

# Simple single episode rollout for policy
rng = jax.random.PRNGKey(0)

In [9]:
obs, state = env.reset(rng)
cum_ret = 0

for step in range(env_params.max_steps_in_episode):
    rng, key_step = jax.random.split(rng)
    q_values = model.apply(params, obs[None,], None)
    action = jnp.argmax(q_values)
    n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
    cum_ret += reward
    
    if done:
        break
    else:
        state = n_state
        obs = n_obs

cum_ret

DeviceArray(7., dtype=float32)