In [None]:
!pip install muax

In [None]:
import jax 
# Select 'cpu' as the JAX platform for this notebook.
jax.config.update('jax_platform_name', 'cpu')

import gymnasium as gym 

# 1. Use `muax.fit` to fit CartPole-v1

In [1]:
import muax
from muax import nn 

`muax` provides example `representation`, `prediction` and `dynamic` modules

In [2]:
# Hyperparameters used to build the example networks.
discount = 0.997
num_actions = 2
embedding_size = 8
support_size = 10
# Output width used for the value and reward heads: 2 * support_size + 1.
full_support_size = int(2 * support_size + 1)

# Build the representation / prediction / dynamic functions from the
# example modules that ship in `muax.nn`.
repr_fn = nn._init_representation_func(nn.Representation, embedding_size)
pred_fn = nn._init_prediction_func(
    nn.Prediction, num_actions, full_support_size)
dy_fn = nn._init_dynamic_func(
    nn.Dynamic, embedding_size, num_actions, full_support_size)

Alternatively, you can use your customized models

In [3]:
import haiku as hk

class Representation(hk.Module):
  """Embeds an observation with one linear layer, then applies `nn.min_max_normalize`."""

  def __init__(self, embedding_dim, name='representation'):
    """Create the module.

    Args:
      embedding_dim: output width of the linear layer.
      name: name given to the Haiku module.
    """
    super().__init__(name=name)

    layers = [hk.Linear(embedding_dim)]
    self.repr_func = hk.Sequential(layers)

  def __call__(self, obs):
    """Return the embedding of `obs` after `nn.min_max_normalize`."""
    embedding = self.repr_func(obs)
    return nn.min_max_normalize(embedding)


class Prediction(hk.Module):
  """Two MLP heads over an embedding: one of width `full_support_size`, one of width `num_actions`."""

  def __init__(self, num_actions, full_support_size, name='prediction'):
    """Create both heads.

    Args:
      num_actions: output width of the policy head (`pi_func`).
      full_support_size: output width of the value head (`v_func`).
      name: name given to the Haiku module.
    """
    super().__init__(name=name)

    def make_head(output_size):
      # Hidden widths 64 -> 64 -> 16, each followed by ELU, then a linear output layer.
      stack = []
      for width in (64, 64, 16):
        stack += [hk.Linear(width), jax.nn.elu]
      stack.append(hk.Linear(output_size))
      return hk.Sequential(stack)

    # The value head is built first so the layers are created in the same order as before.
    self.v_func = make_head(full_support_size)
    self.pi_func = make_head(num_actions)

  def __call__(self, s):
    """Return `(v, logits)` for embedding `s`; the logits are returned without a softmax."""
    value_out = self.v_func(s)
    policy_logits = self.pi_func(s)
    return value_out, policy_logits


class Dynamic(hk.Module):
  """Maps an (embedding, action) pair to a reward output and a next embedding."""

  def __init__(self, embedding_dim, num_actions, full_support_size, name='dynamic'):
    """Create the two MLPs and the input-concatenation helper.

    Args:
      embedding_dim: output width of the next-state head (`ns_func`).
      num_actions: number of classes used to one-hot encode the action.
      full_support_size: output width of the reward head (`r_func`).
      name: name given to the Haiku module.
    """
    super().__init__(name=name)
    
    self.ns_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(64), jax.nn.elu,
        hk.Linear(16), jax.nn.elu,
        hk.Linear(embedding_dim)
    ])
    self.r_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(64), jax.nn.elu,
        hk.Linear(16), jax.nn.elu,
        hk.Linear(full_support_size)
    ])
    # Reach numpy through the `jax` import so this class does not depend on a
    # `jnp` alias existing in the notebook namespace when the module is called.
    self.cat_func = jax.jit(lambda s, a: 
                            jax.numpy.concatenate([s, jax.nn.one_hot(a, num_actions)],
                                                  axis=1)
                            )
  
  def __call__(self, s, a):
    """Return `(r, ns)`: the reward head output and the next embedding after `nn.min_max_normalize`."""
    # Concatenate the embedding with the one-hot action along axis 1.
    sa = self.cat_func(s, a)
    r = self.r_func(sa)
    ns = self.ns_func(sa)
    ns = nn.min_max_normalize(ns)
    return r, ns


def init_representation_func(representation_module, embedding_dim):
  """Return a function `obs -> representation_module(embedding_dim)(obs)`.

  The module is constructed on every call of the returned function rather
  than once up front (presumably so it is built inside a Haiku transform).
  """
  def representation_func(obs):
    return representation_module(embedding_dim)(obs)
  return representation_func
  
def init_prediction_func(prediction_module, num_actions, full_support_size):
  """Return a function `s -> prediction_module(num_actions, full_support_size)(s)`.

  The module is constructed each time the returned function is called.
  """
  def prediction_func(s):
    return prediction_module(num_actions, full_support_size)(s)
  return prediction_func

def init_dynamic_func(dynamic_module, embedding_dim, num_actions, full_support_size):
  """Return a function `(s, a) -> dynamic_module(embedding_dim, num_actions, full_support_size)(s, a)`.

  The module is constructed each time the returned function is called.
  """
  def dynamic_func(s, a):
    module = dynamic_module(embedding_dim, num_actions, full_support_size)
    return module(s, a)
  return dynamic_func

`muax` has an `episode tracer` and a `replay buffer` to track and store trajectories collected while interacting with environments.

In [22]:
from muax.frameworks.coax.episode_tracer import PNStep
from muax.frameworks.coax.replay_buffer import TrajectoryReplayBuffer

# Episode tracer built from positional arguments (10, discount, 0.5).
# NOTE(review): presumably n-step length, discount factor and a priority exponent — confirm against PNStep's signature.
tracer = PNStep(10, discount, 0.5)
# Trajectory replay buffer; 500 is presumably its capacity — TODO confirm.
buffer = TrajectoryReplayBuffer(500)

`muax` leverages `optax` to update weights

In [23]:
from muax.frameworks.coax import model 
# Optimizer built by `model.optimizer` from a learning-rate schedule description.
# init_value and peak_value are the same number (0.02) here.
gradient_transform = model.optimizer(
    init_value=0.02,
    peak_value=0.02,
    end_value=0.0001,
    warmup_steps=15000,
    transition_steps=15000,
)
# Alternative left by the author: gradient_transform = optax.adam(0.02)

In [10]:

# Value passed to `fit` as `random_seed` in the training cell.
i = 1
support_size = 20
embedding_size = 10
# Output width used for the value and reward heads: 2 * support_size + 1.
full_support_size = int(support_size * 2 + 1)
num_actions = 2

# Rebuild the network functions from the custom modules defined above,
# using the sizes set in this cell.
repr_fn = init_representation_func(Representation, embedding_size)
pred_fn = init_prediction_func(Prediction, num_actions, full_support_size)
dy_fn = init_dynamic_func(Dynamic, embedding_size, num_actions, full_support_size)

from muax.frameworks.coax.model import MuZero
# Hand the model the optimizer built earlier (it was otherwise never used) and
# the same `discount` the tracer was created with, so that training, n-step
# targets and the model reloaded after training all agree.
# NOTE: this rebinds the name `model`, which previously referred to the imported module.
model = MuZero(repr_fn, pred_fn, dy_fn,
               policy='muzero',
               optimizer=gradient_transform,
               discount=discount,
               support_size=support_size)


In [None]:
from muax.frameworks.coax.train import fit 

# Train on CartPole-v1. The return value is kept as `model_path` and is later
# passed to `model.load`.
model_path = fit(
    model,
    'CartPole-v1',
    # limits on the run
    max_episodes=1000,
    max_training_steps=50000,
    # data collection and storage objects created earlier
    tracer=tracer,
    buffer=buffer,
    # sampling settings
    k_steps=10,
    sample_per_trajectory=1,
    buffer_warm_up=128,
    num_trajectory=128,
    # logging / saving
    tensorboard_dir='/content/data/tensorboard/',
    save_name='model_params',
    random_seed=i,
    log_all_metrics=True,
)

In [None]:
# Display the value returned by `fit`.
model_path

In [None]:
# Build a fresh MuZero object from the same network functions, then load
# from the path returned by `fit`.
model = muax.MuZero(
    repr_fn,
    pred_fn,
    dy_fn,
    policy='muzero',
    discount=discount,
    optimizer=gradient_transform,
    support_size=support_size,
)
model.load(model_path)

In [None]:
from muax.test import test
# Evaluate the loaded model on a separate CartPole-v1 environment; the value
# returned by `test` is the cell output.
env_id = 'CartPole-v1'
test_env = gym.make(env_id, render_mode='rgb_array')
test_key = jax.random.PRNGKey(0)
test(
    model,
    test_env,
    test_key,
    num_simulations=50,
    num_test_episodes=100,
    random_seed=None,
)

In [None]:
%load_ext tensorboard 
%tensorboard --logdir=tensorboard/cartpole

# 2. Customize

## 2.1 Customize the neural networks

`muax` uses `haiku` to implement the neural networks. A tutorial for using `haiku` can be found in the [Haiku documentation](https://dm-haiku.readthedocs.io/).

In [None]:
import haiku as hk

class Representation(hk.Module):
  """Embeds an observation with one linear layer, then applies `nn.min_max_normalize`."""

  def __init__(self, embedding_dim, name='representation'):
    """Create the module.

    Args:
      embedding_dim: output width of the linear layer.
      name: name given to the Haiku module.
    """
    super().__init__(name=name)

    layers = [hk.Linear(embedding_dim)]
    self.repr_func = hk.Sequential(layers)

  def __call__(self, obs):
    """Return the embedding of `obs` after `nn.min_max_normalize`."""
    embedding = self.repr_func(obs)
    return nn.min_max_normalize(embedding)


class Prediction(hk.Module):
  """Two MLP heads over an embedding: one of width `full_support_size`, one of width `num_actions`."""

  def __init__(self, num_actions, full_support_size, name='prediction'):
    """Create both heads.

    Args:
      num_actions: output width of the policy head (`pi_func`).
      full_support_size: output width of the value head (`v_func`).
      name: name given to the Haiku module.
    """
    super().__init__(name=name)

    def make_head(output_size):
      # Hidden widths 64 -> 64 -> 16, each followed by ELU, then a linear output layer.
      stack = []
      for width in (64, 64, 16):
        stack += [hk.Linear(width), jax.nn.elu]
      stack.append(hk.Linear(output_size))
      return hk.Sequential(stack)

    # The value head is built first so the layers are created in the same order as before.
    self.v_func = make_head(full_support_size)
    self.pi_func = make_head(num_actions)

  def __call__(self, s):
    """Return `(v, logits)` for embedding `s`; the logits are returned without a softmax."""
    value_out = self.v_func(s)
    policy_logits = self.pi_func(s)
    return value_out, policy_logits


class Dynamic(hk.Module):
  """Maps an (embedding, action) pair to a reward output and a next embedding."""

  def __init__(self, embedding_dim, num_actions, full_support_size, name='dynamic'):
    """Create the two MLPs and the input-concatenation helper.

    Args:
      embedding_dim: output width of the next-state head (`ns_func`).
      num_actions: number of classes used to one-hot encode the action.
      full_support_size: output width of the reward head (`r_func`).
      name: name given to the Haiku module.
    """
    super().__init__(name=name)
    
    self.ns_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(64), jax.nn.elu,
        hk.Linear(16), jax.nn.elu,
        hk.Linear(embedding_dim)
    ])
    self.r_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(64), jax.nn.elu,
        hk.Linear(16), jax.nn.elu,
        hk.Linear(full_support_size)
    ])
    # Reach numpy through the `jax` import so this class does not depend on a
    # `jnp` alias existing in the notebook namespace when the module is called.
    self.cat_func = jax.jit(lambda s, a: 
                            jax.numpy.concatenate([s, jax.nn.one_hot(a, num_actions)],
                                                  axis=1)
                            )
  
  def __call__(self, s, a):
    """Return `(r, ns)`: the reward head output and the next embedding after `nn.min_max_normalize`."""
    # Concatenate the embedding with the one-hot action along axis 1.
    sa = self.cat_func(s, a)
    r = self.r_func(sa)
    ns = self.ns_func(sa)
    ns = nn.min_max_normalize(ns)
    return r, ns


def init_representation_func(representation_module, embedding_dim):
  """Return a function `obs -> representation_module(embedding_dim)(obs)`.

  The module is constructed on every call of the returned function rather
  than once up front (presumably so it is built inside a Haiku transform).
  """
  def representation_func(obs):
    return representation_module(embedding_dim)(obs)
  return representation_func
  
def init_prediction_func(prediction_module, num_actions, full_support_size):
  """Return a function `s -> prediction_module(num_actions, full_support_size)(s)`.

  The module is constructed each time the returned function is called.
  """
  def prediction_func(s):
    return prediction_module(num_actions, full_support_size)(s)
  return prediction_func

def init_dynamic_func(dynamic_module, embedding_dim, num_actions, full_support_size):
  """Return a function `(s, a) -> dynamic_module(embedding_dim, num_actions, full_support_size)(s, a)`.

  The module is constructed each time the returned function is called.
  """
  def dynamic_func(s, a):
    module = dynamic_module(embedding_dim, num_actions, full_support_size)
    return module(s, a)
  return dynamic_func

In [None]:
# Sizes for the customized networks.
embedding_size = 8
support_size = 10
# Output width used for the value and reward heads: 2 * support_size + 1.
full_support_size = int(2 * support_size + 1)

# The literal 2 below fills the `num_actions` parameter of the factories.
repr_fn = init_representation_func(Representation, embedding_size)
pred_fn = init_prediction_func(Prediction, 2, full_support_size)
dy_fn = init_dynamic_func(
    Dynamic, embedding_size, 2, full_support_size)

## 2.2 Customize the training loop

Inside the `muax.fit` function, the main structure is a typical RL interaction loop: reset the environment, let the agent choose an action based on the current observation, step the environment, and update the current observation, repeating until the episode is done.

In [None]:
import numpy as np 
from muax import Trajectory

def temperature_fn(max_training_steps, training_steps):
  """Step-wise temperature schedule based on training progress.

  Returns 1.0 while `training_steps` is below 50% of `max_training_steps`,
  0.5 while it is below 75%, and 0.25 otherwise.
  """
  # (fraction of max_training_steps, temperature used below that fraction)
  schedule = ((0.5, 1.0), (0.75, 0.5))
  for fraction, temperature in schedule:
    if training_steps < fraction * max_training_steps:
      return temperature
  return 0.25
  
def test(model, env, key, num_simulations, num_test_episodes=10, random_seed=None):
    total_rewards = np.zeros(num_test_episodes)
    for episode in range(num_test_episodes):
        obs, info = env.reset(seed=random_seed)
        done = False
        episode_reward = 0
        for t in range(env.spec.max_episode_steps):
            key, subkey = jax.random.split(key)
            a = model.act(subkey, obs, 
                          with_pi=False, 
                          with_value=False, 
                          obs_from_batch=False,
                          num_simulations=num_simulations,
                          temperature=0.) # Use deterministic actions during testing
            obs_next, r, done, truncated, info = env.step(a)
            episode_reward += r
            if done or truncated:
                break 
            obs = obs_next 
        
        total_rewards[episode] = episode_reward

    average_test_reward = np.mean(total_rewards)
    return average_test_reward  

In [None]:
from jax import numpy as jnp 

# --- Hyperparameters ---
random_seed = 0
discount = 0.997
max_episodes = 1000
max_training_steps = 10000
buffer_warm_up = 64
num_simulations = 50
num_test_episodes = 10
num_trajectory = 32
sample_per_trajectory = 1
k_steps = 10

# --- Optimizer, tracer, replay buffer and model ---
gradient_transform = muax.model.optimizer(
    init_value=0.02,
    peak_value=0.02,
    end_value=0.002,
    warmup_steps=5000,
    transition_steps=5000,
)
tracer = muax.PNStep(10, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)
model = muax.MuZero(
    repr_fn, pred_fn, dy_fn,
    optimizer=gradient_transform,
    discount=discount,
)

# --- Environments: one for collecting data, one for evaluation ---
env_id = 'CartPole-v1'
env = gym.make(env_id, render_mode='rgb_array')
test_env = gym.make(env_id, render_mode='rgb_array')

# --- Initialise the model with an all-zeros observation that has a leading axis of size 1 ---
sample_input = jnp.zeros(env.observation_space.shape)[None, ...]
key, test_key, subkey = jax.random.split(jax.random.PRNGKey(random_seed), num=3)
model.init(subkey, sample_input)

# --- Bookkeeping for the training loop ---
training_step = 0
best_test_G = -float('inf')
max_training_steps_reached = False

def collect_episode(key, temperature, seed=None):
  """Play one episode with `model`, store it in `buffer` if long enough, and return the advanced key.

  Args:
    key: JAX PRNG key; split once per environment step for `model.act`.
    temperature: forwarded to `model.act` as `temperature`.
    seed: forwarded to `env.reset`; None leaves the environment unseeded.

  Returns:
    The PRNG key remaining after the episode, to be used for the next one.
  """
  obs, info = env.reset(seed=seed)
  trajectory = Trajectory()
  for t in range(env.spec.max_episode_steps):
    key, subkey = jax.random.split(key)
    a, pi, v = model.act(subkey, obs,
                         with_pi=True,
                         with_value=True,
                         obs_from_batch=False,
                         num_simulations=num_simulations,
                         temperature=temperature)
    obs_next, r, done, truncated, info = env.step(a)
    tracer.add(obs, a, r, done or truncated, v=v, pi=pi)
    # Move every transition the tracer currently holds ready into the trajectory.
    while tracer:
      trajectory.add(tracer.pop())
    if done or truncated:
      break
    obs = obs_next
  trajectory.finalize()
  # Trajectories shorter than k_steps are not stored.
  if len(trajectory) >= k_steps:
    buffer.add(trajectory, trajectory.batched_transitions.w.mean())
  return key


# Number of model updates performed after each collected episode.
updates_per_episode = 50
# Seed the environment on its first reset only: passing the same seed to every
# reset would make every episode start from the same initial state.
reset_seed = random_seed

# buffer warm up
print('buffer warm up stage...')
while len(buffer) < buffer_warm_up:
  temperature = temperature_fn(max_training_steps=max_training_steps, training_steps=training_step)
  key = collect_episode(key, temperature, seed=reset_seed)
  reset_seed = None

print('start training...')

for ep in range(max_episodes):
  temperature = temperature_fn(max_training_steps=max_training_steps, training_steps=training_step)
  key = collect_episode(key, temperature, seed=reset_seed)
  reset_seed = None

  train_loss = 0
  num_updates = 0
  for _ in range(updates_per_episode):
    transition_batch = buffer.sample(num_trajectory=num_trajectory,
                                     sample_per_trajectory=sample_per_trajectory,
                                     k_steps=k_steps)
    loss_metric = model.update(transition_batch)
    train_loss += loss_metric['loss']
    training_step += 1
    num_updates += 1
    if training_step >= max_training_steps:
      max_training_steps_reached = True
      break
  # Average over the updates actually performed; the loop above may stop early.
  train_loss /= num_updates
  print(f'epoch: {ep:04d}, loss: {train_loss:.8f}, training_step: {training_step}')

  test_G = test(model, test_env, test_key, num_simulations=num_simulations, num_test_episodes=num_test_episodes)
  print(f'epoch: {ep:04d}, test_G: {test_G:.8f}')
  if test_G >= best_test_G:
    best_test_G = test_G

  # Stop right after the final evaluation instead of collecting one more unused episode.
  if max_training_steps_reached:
    break

print(f'Best total reward in test: {best_test_G}')