In [1]:
import jax
import jax.numpy as jnp

In [145]:
num_samples = 10000

# One categorical distribution per row, over 4 categories. The -50 logit makes
# category 2 of the second row practically impossible to sample.
sample_dist_logits = jnp.array([[1, 2, 3, 4], [3, 3, -50, 1], [1, 1, 1, 1]], dtype=float)
# Per-row sample counts, normalised below into empirical probabilities
# (despite the name, the final values are frequencies, not logits).
empirical_logits = jnp.zeros_like(sample_dist_logits)
batch_size = empirical_logits.shape[0]

# A singleton axis gives the logits the batch shape (batch_size, 1), which
# broadcasts against the requested sample shape (batch_size, num_samples).
sample_dist_logits = jnp.expand_dims(sample_dist_logits, axis=1)
samples = jax.random.categorical(
    jax.random.PRNGKey(14), sample_dist_logits, shape=(batch_size, num_samples))


def update(x, *indices):
  """Returns a copy of `x` with 1 added at the position(s) given by `indices`.

  `indices` is used as a tuple index. `.at[...].add` accumulates repeated
  indices, so a 1-D array of category ids produces a histogram of those ids.
  """
  return x.at[indices].add(1.)


# vmap over rows: row b of `empirical_logits` is counted with row b of `samples`.
batch_update = jax.vmap(update)
empirical_logits = batch_update(empirical_logits, samples) / num_samples

In [146]:
# Empirical frequencies first, then the exact softmax probabilities they estimate.
true_probs = jax.nn.softmax(sample_dist_logits)
for arr in (empirical_logits, true_probs):
    print(arr)

[[0.0343 0.0847 0.2402 0.6408]
 [0.465  0.4692 0.     0.0658]
 [0.2522 0.2484 0.2537 0.2457]]
[[[3.2058604e-02 8.7144323e-02 2.3688284e-01 6.4391428e-01]]

 [[4.6831053e-01 4.6831053e-01 4.4970365e-24 6.3378938e-02]]

 [[2.5000000e-01 2.5000000e-01 2.5000000e-01 2.5000000e-01]]]


In [149]:
def _apply_temperature(logits, temperature):
  """Returns `logits / temperature`, supporting also temperature=0."""
  # The max subtraction prevents +inf after dividing by a small temperature.
  logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  tiny = jnp.finfo(logits.dtype).tiny
  return logits / jnp.maximum(tiny, temperature)

temp = 0.1
# Same distribution computed two ways: plain division vs. the max-shifted helper.
naive_probs = jax.nn.softmax(sample_dist_logits / temp)
shifted_probs = jax.nn.softmax(_apply_temperature(sample_dist_logits, temp))
print(naive_probs)
print(shifted_probs)

[[[9.3571980e-14 2.0610600e-09 4.5397868e-05 9.9995458e-01]]

 [[5.0000000e-01 5.0000000e-01 0.0000000e+00 1.0305768e-09]]

 [[2.5000000e-01 2.5000000e-01 2.5000000e-01 2.5000000e-01]]]
[[[9.3571980e-14 2.0610600e-09 4.5397868e-05 9.9995458e-01]]

 [[5.0000000e-01 5.0000000e-01 0.0000000e+00 1.0305768e-09]]

 [[2.5000000e-01 2.5000000e-01 2.5000000e-01 2.5000000e-01]]]


: 

In [5]:
import jax.numpy as jnp
import jax
a = jnp.array([[[1], [2], [3]], [[4], [5], [6]]])
# The two index arrays are paired element-wise: this selects a[0, 1] and a[1, 2].
rows = jnp.arange(a.shape[0])
a[rows, jnp.array([1, 2]), :]

Array([[2],
       [6]], dtype=int32)

In [4]:
import jax
import jax.numpy as jnp

def compute_gae(truncation: jnp.ndarray,
                termination: jnp.ndarray,
                rewards: jnp.ndarray,
                values: jnp.ndarray,
                bootstrap_value: jnp.ndarray,
                lambda_: float = 1.0,
                discount: float = 0.99):
    """Calculates the Generalized Advantage Estimation (GAE).

    Steps with ``truncation == 1`` get a zero TD error and cut the backward
    recursion. Their value target therefore equals ``values`` and their
    advantage is 0. Steps with ``termination == 1`` do not bootstrap from the
    next value.

    Args:
        truncation: A float32 tensor of shape [T, B] with truncation signal.
        termination: A float32 tensor of shape [T, B] with termination signal.
        rewards: A float32 tensor of shape [T, B] containing rewards generated by
            following the behaviour policy.
        values: A float32 tensor of shape [T, B] with the value function estimates
            wrt. the target policy.
        bootstrap_value: A float32 of shape [B] with the value function estimate at
            time T.
        lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). Defaults to
            lambda_=1.
        discount: TD discount.

    Returns:
        A tuple ``(vs, advantages)`` of float32 tensors of shape [T, B], both
        wrapped in ``stop_gradient``:
        - ``vs`` can be used as target to train a baseline, (V(x_t) - vs_t)^2.
        - ``advantages`` are
          ``r_t + discount * (1 - termination_t) * vs_{t+1} - V(x_t)``,
          zeroed at truncated steps.
    """
    truncation_mask = 1 - truncation
    # Append bootstrapped value to get [v1, ..., v_t+1]
    values_t_plus_1 = jnp.concatenate(
        [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0)
    # One-step TD errors. At truncated steps the next value presumably belongs
    # to a different (reset) episode, so those deltas are masked to zero.
    deltas = rewards + discount * (1 - termination) * values_t_plus_1 - values
    deltas *= truncation_mask

    def compute_vs_minus_v_xs(acc, target_t):
        # acc_t = delta_t + discount * (1 - d_t) * mask_t * lambda_ * acc_{t+1}
        truncation_mask_t, delta_t, termination_t = target_t
        acc = (delta_t
               + discount * (1 - termination_t) * truncation_mask_t * lambda_ * acc)
        return acc, acc

    # Backward-in-time scan (reverse=True) accumulating vs_t - V(x_t).
    # `lambda_` is a constant, so it is captured by closure instead of carried.
    _, vs_minus_v_xs = jax.lax.scan(
        compute_vs_minus_v_xs,
        jnp.zeros_like(bootstrap_value),
        (truncation_mask, deltas, termination),
        reverse=True)
    # Add V(x_s) to get v_s.
    vs = vs_minus_v_xs + values

    vs_t_plus_1 = jnp.concatenate(
        [vs[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0)
    advantages = (rewards + discount *
                  (1 - termination) * vs_t_plus_1 - values) * truncation_mask
    return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages)

def _column(values_per_step):
    """Returns a float [T, 1] array (a batch of one) built from T scalars."""
    return jnp.array(values_per_step, dtype=float)[:, None]


# Ten steps of a single environment: terminations at t=1 and t=7,
# truncation at t=5.
truncation = _column([0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
termination = _column([0, 1, 0, 0, 0, 0, 0, 1, 0, 0])
rewards = _column([1] * 10)
values = _column([10, 15, 20, 25, 30, 35, 40, 45, 50, 55])
bootstrap_value = jnp.array([100], dtype=float)
lambda_ = 1.
discount = 1.

vs, adv = compute_gae(truncation, termination, rewards, values,
                      bootstrap_value, lambda_, discount)
vs

Array([[  2.],
       [  1.],
       [ 38.],
       [ 37.],
       [ 36.],
       [ 35.],
       [  2.],
       [  1.],
       [102.],
       [101.]], dtype=float32)

In [67]:
import jax
import jax.numpy as jnp
from envs.brax_wrappers import EvalWrapper, wrap_for_training, EpisodeWrapper, AutoResetWrapper
from gymnax import gymnax
from gymnax.gymnax.wrappers.brax import GymnaxToBraxWrapper


# Reproducible key stream. All four split keys are kept so that `key_envs`
# matches the value produced by this same 4-way split.
local_key = jax.random.PRNGKey(42)
local_key = jax.random.fold_in(local_key, 12)
local_key, rb_key, key_envs, eval_key = jax.random.split(local_key, 4)

environment, env_params = gymnax.make('CartPole-v1')
environment = GymnaxToBraxWrapper(environment)

EPISODE_LENGTH = 10
ACTION_REPEAT = 1

# Single (un-vmapped) environment: episode bookkeeping + automatic reset.
# NOTE(review): assumes EpisodeWrapper's positional args are
# (episode_length, action_repeat) as in Brax; confirm in envs.brax_wrappers.
env = EpisodeWrapper(environment, EPISODE_LENGTH, ACTION_REPEAT)
env = AutoResetWrapper(env)

reset_fn = jax.jit(env.reset)
env_state = reset_fn(key_envs)


reset


In [124]:
# Report when the state carries a 'steps' entry, then advance one step with action 0.
has_steps = 'steps' in env_state.info
if has_steps:
    print('first_obs')
env_state = env.step(env_state, jnp.array(0))

first_obs


In [37]:
import jax.numpy as jnp

c1 = 50
c2 = 19652

# Children of one node: visit counts, prior probabilities and value estimates.
visit_counts = jnp.array([0, 50])
prior_probs = jnp.array([0.5, 0.5])
value_score = jnp.array([0, 90])

# pUCT exploration term in the same form as MuZero's pseudocode
# (pb_c_init=c1, pb_c_base=c2): the coefficient grows logarithmically with the
# parent visit count, and each child's share shrinks with its own visits.
node_visit = visit_counts.sum()
pb_c = c1 + jnp.log((node_visit + c2 + 1.) / c2)
policy_score = jnp.sqrt(node_visit) * pb_c * prior_probs / (visit_counts + 1)

for score in (policy_score, value_score, policy_score + value_score):
    print(score)

[176.78586     3.4663894]
[ 0 90]
[176.78586  93.46639]


In [40]:
import jax.numpy as jnp

a = jnp.array([0, 1, 2, 3, 0, 0, 5, 5], dtype=float)

# Both masks map zeros to -inf: `where` uses the truthiness of a float
# condition, which is the same as `a != 0`.
print(a)
for cond in (a, a != 0):
    print(jnp.where(cond, a, -jnp.inf))

[0. 1. 2. 3. 0. 0. 5. 5.]
[-inf   1.   2.   3. -inf -inf   5.   5.]
[-inf   1.   2.   3. -inf -inf   5.   5.]


In [None]:
import jax
import jax.numpy as jnp
import flax
import optax
from typing import Any
from functools import partial

@flax.struct.dataclass
class TrainingState:
    """Immutable container for the state threaded through training epochs.

    Decorated with `flax.struct.dataclass`, so instances are frozen and are
    JAX pytrees (they are passed into and returned from the jitted epoch).
    """
    # Network parameters, e.g. as returned by `mlp.init`.
    params: Any
    # optax optimizer state matching `params`.
    optimizer_state: Any


class MLP(flax.linen.Module):
    """Fully connected ReLU network with a single output unit.

    Attributes:
        hidden_layer_sizes: width of each hidden Dense layer, in order.
    """
    # Flax linen modules are dataclasses: configuration must be declared as an
    # annotated field. A hand-written ``__init__`` is not supported (and
    # ``super.__init__()`` without parentheses on ``super`` raises TypeError).
    hidden_layer_sizes: tuple[int, ...]

    @flax.linen.compact
    def __call__(self, x):
        for width in self.hidden_layer_sizes:
            x = flax.linen.Dense(width)(x)
            x = flax.linen.relu(x)
        return flax.linen.Dense(1)(x)
    
mlp = MLP((16, 16))
# One example with a leading batch axis of size 1 lets Flax infer the input
# feature size of the first Dense layer. (`data.x.take(1)` flattens the array
# and returns a single scalar, i.e. it always initialised for 1 feature.)
# NOTE(review): `data` is defined outside this cell; assumes `data.x` is
# [N, ...features].
dummy_observation = jnp.zeros((1,) + data.x[0].shape)
params = mlp.init(jax.random.PRNGKey(12), dummy_observation)

optimizer = optax.adam(1e-4)
optimizer_state = optimizer.init(params)

training_state = TrainingState(params=params, optimizer_state=optimizer_state)

def loss(params, data, key, network):
    """Mean squared error of `network`'s predictions against `data.label`.

    Args:
        params: parameters passed to `network.apply`.
        data: batch with fields `x` (network input) and `label` (target).
        key: PRNG key; unused by this loss.
        network: object exposing `apply(params, x)`.

    Returns:
        `(mse, mse)`: the scalar loss, and the same value as auxiliary metrics.
    """
    del key  # Unused here.
    predictions = network.apply(params, data.x)
    # An MLP ending in Dense(1) yields a trailing axis of size 1. Reshape to the
    # label shape so a [B] label is not broadcast against [B, 1] predictions
    # into a [B, B] matrix of errors.
    predictions = jnp.reshape(predictions, jnp.shape(data.label))
    mse = jnp.mean((predictions - data.label) ** 2)
    return mse, mse

# Bind the network so the loss depends only on (params, data, key);
# `jax.value_and_grad` differentiates w.r.t. the first argument, `params`.
loss_fn = partial(loss, network=mlp)

def update_fn(loss_fn, optimizer):
    """Builds a function that performs one optimizer step on `loss_fn`.

    Args:
        loss_fn: callable ``loss_fn(params, *rest) -> (loss, aux)``; it is
            differentiated w.r.t. its first positional argument.
        optimizer: an optax ``GradientTransformation``.

    Returns:
        ``f(params, *rest, optimizer_state)`` (``optimizer_state`` keyword-only)
        returning ``((loss, aux), new_params, new_optimizer_state)``.
    """
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def f(*args, optimizer_state):
        values, grads = grad_fn(*args)
        # value_and_grad differentiates w.r.t. the first positional argument,
        # so that is the parameter pytree being updated.
        params = args[0]
        # optax transformations expose `update` (there is no `apply`); passing
        # params supports transformations that need them, e.g. weight decay.
        params_update, optimizer_state = optimizer.update(
            grads, optimizer_state, params)
        params = optax.apply_updates(params, params_update)
        return values, params, optimizer_state

    return f


gradient_update_fn = update_fn(loss_fn, optimizer)

def training_step(carry, data):
    """Scan body: one gradient update on a single minibatch.

    Args:
        carry: ``(optimizer_state, params, key)``.
        data: one minibatch (a slice of the scanned-over data pytree).

    Returns:
        The updated carry and the auxiliary metrics returned by the loss.
    """
    optimizer_state, params, rng = carry
    rng, loss_rng = jax.random.split(rng)
    step_out = gradient_update_fn(
        params, data, loss_rng, optimizer_state=optimizer_state)
    (_, metrics), new_params, new_optimizer_state = step_out
    return (new_optimizer_state, new_params, rng), metrics

def training_epoch(training_state, data, key, num_minibatches):
    """Runs one pass over `data`, split into `num_minibatches` gradient steps.

    Every leaf of `data` is shuffled along its leading axis and reshaped to
    ``[num_minibatches, -1, ...]``; `training_step` is then scanned over the
    minibatches.

    Returns:
        ``(new_training_state, mean_metrics, key)`` where `mean_metrics` is the
        per-leaf mean of the per-step metrics and `key` is a fresh PRNG key.
    """
    key, key_perm, key_grad = jax.random.split(key, 3)

    def _to_minibatches(leaf):
        # The same key for every leaf gives the same permutation, so inputs
        # and labels stay aligned.
        shuffled = jax.random.permutation(key_perm, leaf)  # TODO unnecessary: data already randomly sampled from buffer
        return shuffled.reshape((num_minibatches, -1) + shuffled.shape[1:])

    minibatches = jax.tree_util.tree_map(_to_minibatches, data)
    init_carry = (training_state.optimizer_state, training_state.params, key_grad)
    (optimizer_state, params, _), step_metrics = jax.lax.scan(
        training_step, init_carry, minibatches, length=num_minibatches)
    epoch_metrics = jax.tree_util.tree_map(jnp.mean, step_metrics)
    return (TrainingState(params=params, optimizer_state=optimizer_state),
            epoch_metrics, key)

NUM_MINIBATCHES = 32
NUM_EPOCHS = 10

# `num_minibatches` determines array shapes inside the epoch, so it is bound
# as a Python constant with `partial`; `jax.jit` would trace it if it were
# passed as a keyword argument. Binding to a new name (rather than rebinding
# `training_epoch`) keeps the cell re-runnable: re-wrapping the already-jitted
# function would pass `num_minibatches` through `jit` as a traced value.
train_epoch_jit = jax.jit(partial(training_epoch, num_minibatches=NUM_MINIBATCHES))

key = jax.random.PRNGKey(123)
for _ in range(NUM_EPOCHS):
    training_state, metrics, key = train_epoch_jit(training_state, data, key)
