In [1]:
import jax
import jax.numpy as jnp

In [6]:
# Scratch check of the terminated/truncated flag arithmetic using literal
# comparisons in place of real environment conditions.
terminated = jnp.where(
    jnp.logical_or(1 < 2, 3 < 4),
    jnp.ones(1),
    jnp.zeros(1),
).astype(float)

# Truncation is only flagged when the episode has not already terminated.
truncated = jnp.where(5 >= 4, 1 - terminated, jnp.zeros_like(terminated))

# Show shape then value for each flag, terminated first.
for flag in (terminated, truncated):
    print(flag.shape)
    print(flag)

(1,)
[1.]
(1,)
[0.]


In [10]:
# Scalar broadcast check: scaling a length-1 array of ones keeps shape (1,).
jnp.multiply(jnp.ones(1), 4)

Array([4.], dtype=float32)

In [2]:
# Quick arithmetic check of the product 64 * 128 * 2.
2 * 64 * 128

16384

In [5]:
import jax
import jax.numpy as jnp

from envs import make_env, Transition, MCTSTransition, has_discrete_action_space, is_atari_env
# from envs.brax_v1_wrappers import wrap_for_training
from envs.brax_wrappers import EvalWrapper, wrap_for_training
from networks.policy import Policy, ForwardPass
from networks.networks import FeedForwardNetwork, ActivationFn, make_policy_network, make_value_network, make_atari_feature_extractor
from networks.distributions import NormalTanhDistribution, ParametricDistribution, PolicyNormalDistribution, DiscreteDistribution
import replay_buffers
import running_statistics
from gymnax import gymnax
from gymnax.gymnax.wrappers.brax import GymnaxToBraxWrapper, State
import mctx

from functools import partial

# Experiment settings gathered in one place so that the environment name,
# batch size and seed cannot drift apart between the calls below.
ENV_NAME = 'CartPole-v1'
NUM_ENVS = 8
EPISODE_LENGTH = 500
ACTION_REPEAT = 1
SEED = 42

# Build the gymnax environment and adapt it to the brax-style training API.
is_atari = is_atari_env(ENV_NAME)
environment, env_params = gymnax.make(ENV_NAME)
discrete_action_space = has_discrete_action_space(environment, env_params)
if not discrete_action_space:
    raise NotImplementedError('Currently only discrete action spaces are supported.')
environment = GymnaxToBraxWrapper(environment)

env = wrap_for_training(
    environment,
    episode_length=EPISODE_LENGTH,
    action_repeat=ACTION_REPEAT,
)

# Reset a batch of environments from independent PRNG keys.
key = jax.random.PRNGKey(SEED)
key_envs, key = jax.random.split(key, 2)
reset_fn = jax.jit(jax.vmap(env.reset))
# One key per environment, regrouped as (1, NUM_ENVS, ...): the leading axis
# of size 1 is the axis mapped over by the `jax.vmap` inside `reset_fn`.
key_envs = jax.random.split(key_envs, NUM_ENVS)
key_envs = jnp.reshape(key_envs, (1, -1) + key_envs.shape[1:])
env_state = reset_fn(key_envs)

action_size = env.action_size()

# Per-sample observation shape: the trailing three axes for Atari
# observations, otherwise only the trailing axis.
if is_atari:
    observation_shape = env_state.obs.shape[-3:]
else:
    observation_shape = env_state.obs.shape[-1:]

# A single all-zero transition; the next cell flattens it to inspect the
# layout of one MCTSTransition.
dummy_obs = jnp.zeros(observation_shape)
dummy_action = jnp.zeros((action_size,))
dummy_transition = MCTSTransition(  # pytype: disable=wrong-arg-types  # jax-ndarray
    observation=dummy_obs,
    action=dummy_action,
    reward=0.,
    discount=0.,
    next_observation=dummy_obs,
    target_policy_probs=jnp.zeros((action_size,)),
    target_value=0.,
    extras={
        'state_extras': {
            'truncation': 0.
        },
        'policy_extras': {
            'prior_log_prob': dummy_action,
            'raw_action': dummy_action
        }
    })



  KeyArray = Union[jax.Array, jax.random.KeyArray]  # pylint: disable=invalid-name
  PRNGKey = jax.random.KeyArray
  init_key: Optional[random.KeyArray] = None,
  init_key: Optional[random.KeyArray] = None,
  init_key: Optional[random.KeyArray] = None,
  KeyArray = Union[jax.Array, jax.random.KeyArray]
  from tensorflow.tsl.python.lib.core import pywrap_ml_dtypes


In [6]:
# `import jax` does not by itself guarantee that the `jax.flatten_util`
# submodule is loaded; attribute access only works if some other package
# imported it first. Import the function explicitly so this cell also runs
# on a fresh kernel.
from jax.flatten_util import ravel_pytree

# Flatten the transition pytree into one 1-D array and keep the inverse
# function to rebuild the original structure.
dummy_flatten, _unflatten_fn = ravel_pytree(dummy_transition)

print(dummy_transition)
print(dummy_flatten.shape)
# Round trip: unflattening should reproduce the same tree structure.
print(_unflatten_fn(dummy_flatten))

MCTSTransition(observation=Array([0., 0., 0., 0.], dtype=float32), action=Array([0., 0.], dtype=float32), reward=0.0, discount=0.0, next_observation=Array([0., 0., 0., 0.], dtype=float32), target_policy_probs=Array([0., 0.], dtype=float32), target_value=0.0, extras={'state_extras': {'truncation': 0.0}, 'policy_extras': {'prior_log_prob': Array([0., 0.], dtype=float32), 'raw_action': Array([0., 0.], dtype=float32)}})
(20,)
MCTSTransition(observation=Array([0., 0., 0., 0.], dtype=float32), action=Array([0., 0.], dtype=float32), reward=Array(0., dtype=float32), discount=Array(0., dtype=float32), next_observation=Array([0., 0., 0., 0.], dtype=float32), target_policy_probs=Array([0., 0.], dtype=float32), target_value=Array(0., dtype=float32), extras={'policy_extras': {'prior_log_prob': Array([0., 0.], dtype=float32), 'raw_action': Array([0., 0.], dtype=float32)}, 'state_extras': {'truncation': Array(0., dtype=float32)}})


In [None]:
key, logits_rng, search_rng = jax.random.split(key, 3)


def forward(obs):
    """Placeholder policy/value function used until a real network is wired in.

    TODO: replace with the actual policy and value networks. The original
    cell left this definition unfinished (`def forward()` with no body),
    which made the whole cell a syntax error.

    Args:
        obs: Observation array whose trailing axes match `observation_shape`;
            all leading axes are treated as batch axes.

    Returns:
        A `(prior_logits, value)` pair: all-zero logits of shape
        `batch_shape + (action_size,)` (i.e. a uniform prior) and all-zero
        values of shape `batch_shape`.
    """
    batch_shape = obs.shape[:obs.ndim - len(observation_shape)]
    uniform_logits = jnp.zeros(batch_shape + (action_size,))
    zero_value = jnp.zeros(batch_shape)
    return uniform_logits, zero_value


# logits at root produced by the prior policy
prior_logits, value = forward(env_state.obs)

use_mixed_value = False

# NOTE: For AlphaZero embedding is env_state, for MuZero
# the root output would be the output of MuZero representation network.
root = mctx.RootFnOutput(
    prior_logits=prior_logits,
    value=value,
    # The embedding is used only to implement the MuZero model.
    embedding=env_state,
)

# The recurrent_fn is the MuZero dynamics network,
# or the true environment for AlphaZero.
# TODO MCTS: pass in dynamics function for MuZero
def recurrent_fn(params, rng_key, action, embedding):
    """Expand one MCTS node by stepping the true environment.

    Args:
        params: Unused here; part of the signature mctx calls with.
        rng_key: Unused here; part of the signature mctx calls with.
        action: Action(s) to apply to the embedded environment state.
        embedding: Environment state stored at the parent node.

    Returns:
        A `(mctx.RecurrentFnOutput, next_state)` pair, where `next_state`
        becomes the embedding of the newly created node.
    """
    # environment (model)
    env_state = embedding
    nstate = env.step(env_state, action)

    # policy & value networks
    # Evaluate on the *next* observation: the prior and value returned below
    # describe the new node, not the parent it was expanded from.
    prior_logits, value = forward(nstate.obs)

    # Create the new MCTS node.
    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=nstate.reward,
        # discount when terminal state reached
        discount=1 - nstate.done,
        # prior for the new state
        prior_logits=prior_logits,
        # value for the new state
        value=value,
    )

    # Return the new node and the new environment.
    return recurrent_fn_output, nstate

# Running the search.
policy_output = mctx.gumbel_muzero_policy(
    params=(),
    rng_key=search_rng,
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=30,
    max_num_considered_actions=16,
    qtransform=partial(
        mctx.qtransform_completed_by_mix_value,
        use_mixed_value=use_mixed_value),
)

# Cell-level stand-ins for names this snippet used without defining them.
# NOTE(review): confirm both values against the intended training setup.
# False keeps the action selected by the search instead of the argmax.
deterministic_actions = False
# Matches the 'state_extras' keys of the dummy transition built earlier;
# assumes `nstate.info` exposes a 'truncation' entry.
extra_fields = ('truncation',)

actions = policy_output.action
action_weights = policy_output.action_weights
best_actions = jnp.argmax(action_weights, axis=-1).astype(jnp.int32)
actions = jax.lax.select(deterministic_actions, best_actions, actions)

search_value = policy_output.search_tree.summary().value

# Log-probability of the chosen actions under the categorical prior, computed
# with JAX directly because `tfd` is not imported anywhere in this notebook.
prior_log_probs = jax.nn.log_softmax(prior_logits, axis=-1)
chosen_log_prob = jnp.squeeze(
    jnp.take_along_axis(prior_log_probs, actions[..., None], axis=-1),
    axis=-1,
)

policy_extras = {
    'prior_log_prob': chosen_log_prob,
    'raw_action': actions
}

nstate = env.step(env_state, actions)
state_extras = {x: nstate.info[x] for x in extra_fields}
# A bare `return` is a syntax error at cell level, so the transition is bound
# to a name instead; `nstate` above holds the stepped environment state.
transition = MCTSTransition(  # pytype: disable=wrong-arg-types  # jax-ndarray
    observation=env_state.obs,
    action=actions,
    reward=nstate.reward,
    discount=1 - nstate.done,
    next_observation=nstate.obs,
    target_policy_probs=action_weights,
    target_value=search_value,
    extras={
        'policy_extras': policy_extras,
        'state_extras': state_extras
    })
