In [1]:
%reload_ext tensorboard 
%tensorboard --logdir=content/data/tensorboard/

In [2]:
import jax 
from jax import numpy as jnp
jax.config.update('jax_platform_name', 'cpu')

import muax
from muax import nn 

In [3]:
import haiku as hk

class Representation(hk.Module):
  def __init__(self, embedding_dim, name='representation'):
    super().__init__(name=name)

    self.repr_func = hk.Sequential([
        hk.Linear(embedding_dim), 
        # jax.nn.elu,
    ])

  def __call__(self, obs):
    s = self.repr_func(obs)
    return s 


class Prediction(hk.Module):
  def __init__(self, num_actions, full_support_size, name='prediction'):
    super().__init__(name=name)        
    
    self.v_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(64), jax.nn.elu,
        hk.Linear(16), jax.nn.elu,
        hk.Linear(full_support_size)
    ])
    self.pi_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(64), jax.nn.elu,
        hk.Linear(16), jax.nn.elu,
        hk.Linear(num_actions)
    ])
  
  def __call__(self, s):
    v = self.v_func(s)
    logits = self.pi_func(s)
    # logits = jax.nn.softmax(logits, axis=-1)
    return v, logits


class Dynamic(hk.Module):
  def __init__(self, embedding_dim, num_actions, full_support_size, name='dynamic'):
    super().__init__(name=name)
    
    self.ns_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(64), jax.nn.elu,
        hk.Linear(16), jax.nn.elu,
        hk.Linear(embedding_dim)
    ])
    self.r_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(64), jax.nn.elu,
        hk.Linear(16), jax.nn.elu,
        hk.Linear(full_support_size)
    ])
    self.cat_func = jax.jit(lambda s, a: 
                            jnp.concatenate([s, jax.nn.one_hot(a, num_actions)],
                                            axis=1)
                            )
  
  def __call__(self, s, a):
    sa = self.cat_func(s, a)
    r = self.r_func(sa)
    ns = self.ns_func(sa)
    return r, ns


def init_representation_func(representation_module, embedding_dim):
    def representation_func(obs):
      repr_model = representation_module(embedding_dim)
      return repr_model(obs)
    return representation_func
  
def init_prediction_func(prediction_module, num_actions, full_support_size):
  def prediction_func(s):
    pred_model = prediction_module(num_actions, full_support_size)
    return pred_model(s)
  return prediction_func

def init_dynamic_func(dynamic_module, embedding_dim, num_actions, full_support_size):
  def dynamic_func(s, a):
    dy_model = dynamic_module(embedding_dim, num_actions, full_support_size)
    return dy_model(s, a)
  return dynamic_func 

In [4]:
from gymnasium.wrappers import TimeLimit
from env_hiv import * 
class Spec :
    def __init__(self):
        self.max_episode_steps = 200

class HIVWrapped(HIVPatient):
    def __init__(self):
        super().__init__()
        self.spec = Spec()
env = TimeLimit(HIVWrapped(),200)

In [5]:
i = 1
support_size = 20
embedding_size = 10
full_support_size = int(support_size * 2 + 1)
num_actions = 4

repr_fn = init_representation_func(Representation, embedding_size)
pred_fn = init_prediction_func(Prediction, num_actions, full_support_size)
dy_fn = init_dynamic_func(Dynamic, embedding_size, num_actions, full_support_size)
discount = 0.99
tracer = muax.PNStep(50, discount ,0.5)
buffer = muax.TrajectoryReplayBuffer(500)



gradient_transform = muax.model.optimizer(init_value=0.002, peak_value=0.002, end_value=0.0005, warmup_steps=20000, transition_steps=20000)

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)
from env_hiv import * 
model_path = muax.fit(model, None, 
                env = env,
                test_env = env,
                max_episodes=200,
                max_training_steps=60000,
                tracer=tracer,
                buffer=buffer,
                k_steps=10,
                sample_per_trajectory=1,
                buffer_warm_up=32,
                num_trajectory=32,
                tensorboard_dir='content/data/tensorboard/',
                save_name='model_params',
                random_seed=i,
                log_all_metrics=True)

buffer warm up stage...
start training...


  logger.warn(
[TrainMonitor|INFO] ep: 1,	T: 201,	G: 3.46e+06,	avg_r: 1.73e+04,	avg_G: 3.46e+06,	t: 200,	dt: 214.992ms,	v: 140,	Rn: 6.07e+05,	loss: 195,	training_step: 50,	test_G: 4.27e+06
  logger.warn(
[TrainMonitor|INFO] ep: 2,	T: 402,	G: 4.22e+06,	avg_r: 2.11e+04,	avg_G: 3.84e+06,	t: 200,	dt: 19.503ms,	v: 422,	Rn: 7.49e+05,	loss: 24.2,	training_step: 100
[TrainMonitor|INFO] ep: 3,	T: 603,	G: 3.47e+06,	avg_r: 1.74e+04,	avg_G: 3.72e+06,	t: 200,	dt: 19.538ms,	v: 422,	Rn: 6.1e+05,	loss: 1.22,	training_step: 150
[TrainMonitor|INFO] ep: 4,	T: 804,	G: 3.5e+06,	avg_r: 1.75e+04,	avg_G: 3.66e+06,	t: 200,	dt: 19.505ms,	v: 422,	Rn: 6.14e+05,	loss: 0.136,	training_step: 200
[TrainMonitor|INFO] ep: 5,	T: 1,005,	G: 3.44e+06,	avg_r: 1.72e+04,	avg_G: 3.62e+06,	t: 200,	dt: 19.380ms,	v: 422,	Rn: 6.03e+05,	loss: 0.14,	training_step: 250
[TrainMonitor|INFO] ep: 6,	T: 1,206,	G: 3.53e+06,	avg_r: 1.77e+04,	avg_G: 3.61e+06,	t: 200,	dt: 19.439ms,	v: 422,	Rn: 6.22e+05,	loss: 0.126,	training_step: 300
[TrainM

KeyboardInterrupt: 

In [None]:
model_path

'models/2024-02-29_14-26-35/epoch_0190_test_G_-98.04974613/model_params'

In [None]:

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model.load(model_path)

In [None]:
import gymnasium as gym 
from muax.test import test
env_id = 'LunarLander-v2'
test_env = gym.make(env_id, render_mode='rgb_array')
test_key = jax.random.PRNGKey(0)
test(model, test_env, test_key, num_simulations=50, num_test_episodes=100, random_seed=None)

-101.5249334460554