In [1]:
import jax
import jax.numpy as jnp

In [6]:
terminated = jnp.logical_or(
    1 < 2,
    3 < 4,
)
terminated = jnp.where(
    terminated, jnp.ones(1), jnp.zeros(1)
).astype(float)
truncated = jnp.where(
    5 >= 4, 1 - terminated, jnp.zeros_like(terminated)
)

print(terminated.shape)
print(terminated)
print(truncated.shape)
print(truncated)

(1,)
[1.]
(1,)
[0.]


In [10]:
jnp.ones(1) * 4

Array([4.], dtype=float32)

In [2]:
64*128*2

16384

In [5]:
import jax
import jax.numpy as jnp

from envs import make_env, Transition, MCTSTransition, has_discrete_action_space, is_atari_env
# from envs.brax_v1_wrappers import wrap_for_training
from envs.brax_wrappers import EvalWrapper, wrap_for_training
from networks.policy import Policy, ForwardPass
from networks.networks import FeedForwardNetwork, ActivationFn, make_policy_network, make_value_network, make_atari_feature_extractor
from networks.distributions import NormalTanhDistribution, ParametricDistribution, PolicyNormalDistribution, DiscreteDistribution
import replay_buffers
import running_statistics
from gymnax import gymnax
from gymnax.gymnax.wrappers.brax import GymnaxToBraxWrapper, State
import mctx

from functools import partial

is_atari = is_atari_env('CartPole-v1')
environment, env_params = gymnax.make('CartPole-v1')
discrete_action_space = has_discrete_action_space(environment, env_params)
if not discrete_action_space:
    raise NotImplementedError('Currently only discrete action spaces are supported.')
environment = GymnaxToBraxWrapper(environment)

env = wrap_for_training(
    environment,
    episode_length=500,
    action_repeat=1,
)
key = jax.random.PRNGKey(42)
key_envs, key = jax.random.split(key, 2)
reset_fn = jax.jit(jax.vmap(env.reset))
key_envs = jax.random.split(key_envs, 8 // 1)
key_envs = jnp.reshape(key_envs,
                        (1, -1) + key_envs.shape[1:])
env_state = reset_fn(key_envs)

action_size = env.action_size()

if is_atari:
    observation_shape = env_state.obs.shape[-3:]
else:
    observation_shape = env_state.obs.shape[-1:]

dummy_obs = jnp.zeros(observation_shape,)
dummy_action = jnp.zeros((action_size,))
dummy_transition = MCTSTransition(  # pytype: disable=wrong-arg-types  # jax-ndarray
    observation=dummy_obs,
    action=dummy_action,
    reward=0.,
    discount=0.,
    next_observation=dummy_obs,
    target_policy_probs=jnp.zeros((action_size,)),
    target_value=0.,
    extras={
        'state_extras': {
            'truncation': 0.
        },
        'policy_extras': {
            'prior_log_prob': dummy_action,
            'raw_action': dummy_action
        }
    })



  KeyArray = Union[jax.Array, jax.random.KeyArray]  # pylint: disable=invalid-name
  PRNGKey = jax.random.KeyArray
  init_key: Optional[random.KeyArray] = None,
  init_key: Optional[random.KeyArray] = None,
  init_key: Optional[random.KeyArray] = None,
  KeyArray = Union[jax.Array, jax.random.KeyArray]
  from tensorflow.tsl.python.lib.core import pywrap_ml_dtypes


In [6]:
dummy_flatten, _unflatten_fn = jax.flatten_util.ravel_pytree(
        dummy_transition
    )

print(dummy_transition)
print(dummy_flatten.shape)
print(_unflatten_fn(dummy_flatten))

MCTSTransition(observation=Array([0., 0., 0., 0.], dtype=float32), action=Array([0., 0.], dtype=float32), reward=0.0, discount=0.0, next_observation=Array([0., 0., 0., 0.], dtype=float32), target_policy_probs=Array([0., 0.], dtype=float32), target_value=0.0, extras={'state_extras': {'truncation': 0.0}, 'policy_extras': {'prior_log_prob': Array([0., 0.], dtype=float32), 'raw_action': Array([0., 0.], dtype=float32)}})
(20,)
MCTSTransition(observation=Array([0., 0., 0., 0.], dtype=float32), action=Array([0., 0.], dtype=float32), reward=Array(0., dtype=float32), discount=Array(0., dtype=float32), next_observation=Array([0., 0., 0., 0.], dtype=float32), target_policy_probs=Array([0., 0.], dtype=float32), target_value=Array(0., dtype=float32), extras={'policy_extras': {'prior_log_prob': Array([0., 0.], dtype=float32), 'raw_action': Array([0., 0.], dtype=float32)}, 'state_extras': {'truncation': Array(0., dtype=float32)}})


In [None]:
key, logits_rng, search_rng = jax.random.split(key, 3)

# logits at root produced by the prior policy 
def forward()
prior_logits, value = forward(env_state.obs)

use_mixed_value = False

# NOTE: For AlphaZero embedding is env_state, for MuZero
# the root output would be the output of MuZero representation network.
root = mctx.RootFnOutput(
    prior_logits=prior_logits,
    value=value,
    # The embedding is used only to implement the MuZero model.
    embedding=env_state, 
)

# The recurrent_fn is provided by MuZero dynamics network.
# Or true environment for AlphaZero
# TODO MCTS: pass in dynamics function for MuZero
def recurrent_fn(params, rng_key, action, embedding):
    # environment (model)
    env_state = embedding
    nstate = env.step(env_state, action)

    # policy & value networks
    prior_logits, value = forward(env_state.obs)

    # Create the new MCTS node.
    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=nstate.reward,
        # discount when terminal state reached
        discount=1 - nstate.done,
        # prior for the new state
        prior_logits=prior_logits,
        # value for the new state
        value=value,
    )

    # Return the new node and the new environment.
    return recurrent_fn_output, nstate

# Running the search.
policy_output = mctx.gumbel_muzero_policy(
    params=(),
    rng_key=search_rng,
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=30,
    max_num_considered_actions=16,
    qtransform=partial(
        mctx.qtransform_completed_by_mix_value,
        use_mixed_value=use_mixed_value),
)

actions = policy_output.action
action_weights = policy_output.action_weights
best_actions = jnp.argmax(action_weights, axis=-1).astype(jnp.int32)
actions = jax.lax.select(deterministic_actions, best_actions, actions)

search_value = policy_output.search_tree.summary().value

policy_extras = {
    'prior_log_prob': tfd.Categorical(logits=prior_logits).log_prob(actions),
    'raw_action': actions
}

nstate = env.step(env_state, actions)
state_extras = {x: nstate.info[x] for x in extra_fields}
return nstate, MCTSTransition(  # pytype: disable=wrong-arg-types  # jax-ndarray
    observation=env_state.obs,
    action=actions,
    reward=nstate.reward,
    discount=1 - nstate.done,
    next_observation=nstate.obs,
    target_policy_probs=action_weights,
    target_value=search_value,
    extras={
        'policy_extras': policy_extras, 
        'state_extras': state_extras
    })


In [53]:
import jax
import jax.numpy as jnp
import chex

def n_step_bootstrapped_targets(
        rewards: jnp.ndarray,
        discounts: jnp.ndarray,
        termination_discount: jnp.ndarray,
        observations: jnp.ndarray,
        values: jnp.ndarray,
        n: int = 5,
        gamma: float = 1.,
    ) -> jnp.ndarray:
    """Computes n-step bootstrapped return targets over a sequence.

    Args:
        rewards: rewards at times [1, ..., T].
        discounts: discounts at times [1, ..., T].
        termination_discount: discount from termination at times [1, ..., T].
        observations: observation at time [1, ...., T].
        values: values at time [1, ...., T].
        n: number of steps over which to accumulate reward before bootstrapping.

    Returns:
        estimated bootstrapped returns prefixes at times [0, ...., T-1]
        observation to bootstrap from at times [0, ...., T-1]
        discount factor for bootstrap value at times [0, ...., T-1]
    """
    chex.assert_type([rewards, discounts, values], float)
    chex.assert_equal_shape([rewards, discounts, values])
    batch_shape = rewards.shape
    seq_len = batch_shape[0]

    # Shift bootstrap values by n and pad end of sequence with last value v_t[-1].
    pad_size = min(n - 1, seq_len)
    bootstrap_observations = jnp.concatenate([observations[n - 1:], jnp.array([observations[-1]] * pad_size)])
    bootstrap_values = jnp.concatenate([values[n - 1:], jnp.array([values[-1]] * pad_size)])

    # Pad sequences. Shape is now (T + n - 1, ...).
    rewards = jnp.concatenate([rewards, jnp.zeros((n - 1,) + batch_shape[1:])])
    discounts = jnp.concatenate([discounts, jnp.ones((n - 1,) + batch_shape[1:])]) * gamma

    value_prefix_targets = jax.lax.dynamic_slice_in_dim(rewards, n-1, seq_len)
    bootstrap_discounts = jnp.concatenate([termination_discount, jnp.ones((n - 1,) + batch_shape[1:])]) * gamma
    bootstrap_discounts = jax.lax.dynamic_slice_in_dim(bootstrap_discounts, n-1, seq_len)

    def f(carry, unused_t):
        i, value_prefix_targets, bootstrap_discounts = carry
        i -= 1
        r_ = jax.lax.dynamic_slice_in_dim(rewards, i, seq_len)
        discount_ = jax.lax.dynamic_slice_in_dim(discounts, i, seq_len)
        value_prefix_targets = r_ + discount_ * value_prefix_targets
        bootstrap_discounts *= discount_
        return (i, value_prefix_targets, bootstrap_discounts), unused_t

    (_, value_prefix_targets, bootstrap_discounts), _ = jax.lax.scan(
        f, (n-1, value_prefix_targets, bootstrap_discounts),
        (),
        length=n-1)

    return value_prefix_targets, bootstrap_observations, bootstrap_values, bootstrap_discounts

In [37]:
n_step_bootstrapped_returns(
        rewards=jnp.array([1,2,3,4,5,6,7,8,9,10], dtype=float),
        discount_t=jnp.array([1,1,1,1,1,1,0,1,1,1], dtype=float),
        v_t=jnp.array([10,20,30,40,50,60,70,80,90,100], dtype=float),
        n=5,
        discount=0.5,
    )

Array([ 5.125 ,  7.375 ,  7.4375,  8.875 ,  9.75  ,  9.5   ,  7.    ,
       18.125 , 17.125 , 13.125 ], dtype=float32)

In [54]:
value_prefix_targets, bootstrap_observations, bootstrap_values, bootstrap_discounts = n_step_bootstrapped_targets(
        rewards=jnp.array([1,2,3,4,5,6,7,8,9,10], dtype=float),
        discounts=jnp.array([1,1,1,1,1,1,0,1,1,1], dtype=float),
        termination_discount=jnp.array([1,1,1,1,1,1,0,1,1,1], dtype=float),
        observations=jnp.array([10,20,30,40,50,60,70,80,90,100], dtype=float),
        values=jnp.array([10,20,30,40,50,60,70,80,90,100], dtype=float),
        n=5,
        gamma=0.5,
    )

print(value_prefix_targets)
print(bootstrap_observations)
print(bootstrap_values)
print(bootstrap_discounts)
print(bootstrap_discounts*bootstrap_observations + value_prefix_targets)

[ 3.5625  5.5     7.4375  8.875   9.75    9.5     7.     15.     14.
 10.    ]
[ 50.  60.  70.  80.  90. 100. 100. 100. 100. 100.]
[ 50.  60.  70.  80.  90. 100. 100. 100. 100. 100.]
[0.03125 0.03125 0.      0.      0.      0.      0.      0.03125 0.03125
 0.03125]
[ 5.125   7.375   7.4375  8.875   9.75    9.5     7.     18.125  17.125
 13.125 ]


In [50]:
1 + 1 + 3*0.25 + 4*0.125 + 5*0.125*0.5 # + 50*0.125*0.25

3.5625

In [30]:
v_t=jnp.array([[10,120,30,40,50,60,70,80,90,100],
               [20,220,30,40,50,60,70,80,90,100],
               [30,320,30,40,50,60,70,80,90,100],
               [130,4320,30,40,50,60,70,80,90,100],
               [230,6320,30,40,50,60,70,80,90,100],
               [330,7320,30,40,50,60,70,80,90,100],
               [430,8320,30,40,50,60,70,80,90,100],
               [530,9320,30,40,50,60,70,80,90,100],
               [630,11320,30,40,50,60,70,80,90,100],
               [730,12320,30,40,50,60,70,80,90,100],], dtype=float)
jnp.concatenate([v_t[5 - 1:], jnp.array([v_t[-1]] * 4)])

Array([[  230.,  6320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  330.,  7320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  430.,  8320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  530.,  9320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  630., 11320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  730., 12320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  730., 12320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  730., 12320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  730., 12320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.],
       [  730., 12320.,    30.,    40.,    50.,    60.,    70.,    80.,
           90.,   100.]], dtype=float32)

In [33]:
rewards=jnp.array([[10,120,30,40,50,60,70,80,90,100],
               [20,220,30,40,50,60,70,80,90,100],
               [30,320,30,40,50,60,70,80,90,100],
               [130,4320,30,40,50,60,70,80,90,100],
               [230,6320,30,40,50,60,70,80,90,100],
               [330,7320,30,40,50,60,70,80,90,100],
               [430,8320,30,40,50,60,70,80,90,100],
               [530,9320,30,40,50,60,70,80,90,100],
               [630,11320,30,40,50,60,70,80,90,100],
               [730,12320,30,40,50,60,70,80,90,100],], dtype=float)
batch_shape = rewards.shape
jnp.concatenate([rewards, jnp.zeros((5 - 1,) + batch_shape[1:])])

Array([[1.000e+01, 1.200e+02, 3.000e+01, 4.000e+01, 5.000e+01, 6.000e+01,
        7.000e+01, 8.000e+01, 9.000e+01, 1.000e+02],
       [2.000e+01, 2.200e+02, 3.000e+01, 4.000e+01, 5.000e+01, 6.000e+01,
        7.000e+01, 8.000e+01, 9.000e+01, 1.000e+02],
       [3.000e+01, 3.200e+02, 3.000e+01, 4.000e+01, 5.000e+01, 6.000e+01,
        7.000e+01, 8.000e+01, 9.000e+01, 1.000e+02],
       [1.300e+02, 4.320e+03, 3.000e+01, 4.000e+01, 5.000e+01, 6.000e+01,
        7.000e+01, 8.000e+01, 9.000e+01, 1.000e+02],
       [2.300e+02, 6.320e+03, 3.000e+01, 4.000e+01, 5.000e+01, 6.000e+01,
        7.000e+01, 8.000e+01, 9.000e+01, 1.000e+02],
       [3.300e+02, 7.320e+03, 3.000e+01, 4.000e+01, 5.000e+01, 6.000e+01,
        7.000e+01, 8.000e+01, 9.000e+01, 1.000e+02],
       [4.300e+02, 8.320e+03, 3.000e+01, 4.000e+01, 5.000e+01, 6.000e+01,
        7.000e+01, 8.000e+01, 9.000e+01, 1.000e+02],
       [5.300e+02, 9.320e+03, 3.000e+01, 4.000e+01, 5.000e+01, 6.000e+01,
        7.000e+01, 8.000e+01, 9.000e+

In [36]:
jnp.zeros((5 - 1,) + batch_shape[1:])

Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [21]:
1 + 1 + 3*0.25 + 4*0.125 + 5*0.125*0.5 + 50*0.125*0.25

5.125

In [55]:
from typing import NamedTuple
import jax 
import jax.numpy as jnp

class MCTSTransition(NamedTuple):
    """Container for a transition."""
    observation: jnp.array
    action: jnp.array
    reward: jnp.array

data = MCTSTransition(observation=jnp.zeros((4,10,3)), action=jnp.zeros((4,10)), reward=jnp.zeros((4,10)))

In [57]:
data._replace(action=jnp.ones((4,10)))

MCTSTransition(observation=Array([[[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]], dtype=float32), action=Array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1

In [74]:
target = jnp.array([[1,2,3,4], [5,6,7,8]])
chosen_action_logits = jnp.array([[11,22,33,44], [55,66,77,88]])
kappa = 0.5
num_atoms = 4

# target = jnp.array([[1], [5]])
# chosen_action_logits = jnp.array([[3], [9]])
# kappa = 100000
# num_atoms = 1

# bellman_errors = (target[:, None, :] -
#                       chosen_action_logits[:, :, None])  # Input `u' of Eq. 9.
bellman_errors = (jnp.expand_dims(target, -2) -
                      jnp.expand_dims(chosen_action_logits, -1))
# Eq. 9 of paper.
huber_loss = (
    (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
    0.5 * bellman_errors ** 2 +
    (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) *
    kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))

tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) /
            num_atoms)  # Quantile midpoints.  See Lemma 2 of paper.
# Eq. 10 of paper.
tau_bellman_diff = jnp.abs(
    tau_hat[None, :, None] - (bellman_errors < 0).astype(jnp.float32))
quantile_huber_loss = tau_bellman_diff * huber_loss
# Sum over tau dimension, average over target value dimension.
loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1)
final_loss = jnp.mean(loss)

In [75]:
final_loss

Array(37.875, dtype=float32)

In [70]:
jnp.mean(0.5 * (target - chosen_action_logits)**2)

Array(5., dtype=float32)

In [78]:
jnp.mean(jnp.array([[1,2,3],[4,4,5]]), -1)
jnp.mean(jnp.array([[1],[4]]), -1)

Array([1., 4.], dtype=float32)

In [85]:
import jax 
import jax.numpy as jnp

sample_key = jax.random.PRNGKey(103)
x_sample_key = jax.random.PRNGKey(13)
sample_batch_size = 10000
sample_position = 4
insert_position = 15

idx = jax.random.randint(
            sample_key,
            (sample_batch_size,),
            minval=sample_position,
            maxval=insert_position,
        )
c_idx = jax.random.choice(
    sample_key,
    jnp.arange(sample_position, insert_position),
    (sample_batch_size,),
    replace=True,
    p=None,
)

print(idx.shape, c_idx.shape)
print(jnp.mean(idx), jnp.mean(c_idx))
print(jnp.median(idx), jnp.median(c_idx))

(10000,) (10000,)
8.988299 8.988299
9.0 9.0


In [94]:

data = jnp.array([[1,2,3],[4,5,6],[7,8,9]])
print(data[:, -1])

data = data.at[:, -1].set(jnp.array([-1,-2,-3]))
print(data[:, -1])
print(data)

[3 6 9]
[-1 -2 -3]
[[ 1  2 -1]
 [ 4  5 -2]
 [ 7  8 -3]]


In [102]:
import jax.numpy as jnp

test = jnp.ones((6,8))

idx = jnp.array([0,0,0,0])
new_val = jnp.ones((4,)) * 2

test = test.at[idx, -1].set(new_val)

# batch = jnp.take(buffer_state.data, idx, axis=0, mode='wrap')

test

Array([[1., 1., 1., 1., 1., 1., 1., 2.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)

In [107]:
jnp.expand_dims(jnp.ones(4), (0, -1)).shape

(1, 4, 1)

In [114]:
a = jnp.zeros(4)

c = a.at[jnp.array([0,1,0,2,0,1])].add(1)

print(a)
print(c)

[0. 0. 0. 0.]
[3. 2. 1. 0.]


In [145]:
num_samples = 10000

sample_dist_logits = jnp.array([[1,2,3,4], [3,3,-50,1], [1,1,1,1]], dtype=float)
empirical_logits = jnp.zeros_like(sample_dist_logits)

sample_dist_logits = jnp.expand_dims(sample_dist_logits, axis=1)
batch_size = empirical_logits.shape[0]
samples = jax.random.categorical(jax.random.PRNGKey(14), sample_dist_logits, shape=(batch_size, num_samples))

# empirical_logits = empirical_logits.at[jnp.arange(batch_size), samples].add(1.)
# empirical_logits = empirical_logits.at[samples].add(1.)
# empirical_logits = empirical_logits.at[batch_size, samples].add(1.)

def update(x, *indices):
  return x.at[indices].add(1.)

batch_update = jax.vmap(update)
empirical_logits = batch_update(empirical_logits, samples) / num_samples

In [146]:
print(empirical_logits)
print(jax.nn.softmax(sample_dist_logits))

[[0.0343 0.0847 0.2402 0.6408]
 [0.465  0.4692 0.     0.0658]
 [0.2522 0.2484 0.2537 0.2457]]
[[[3.2058604e-02 8.7144323e-02 2.3688284e-01 6.4391428e-01]]

 [[4.6831053e-01 4.6831053e-01 4.4970365e-24 6.3378938e-02]]

 [[2.5000000e-01 2.5000000e-01 2.5000000e-01 2.5000000e-01]]]


In [149]:
def _apply_temperature(logits, temperature):
  """Returns `logits / temperature`, supporting also temperature=0."""
  # The max subtraction prevents +inf after dividing by a small temperature.
  logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  tiny = jnp.finfo(logits.dtype).tiny
  return logits / jnp.maximum(tiny, temperature)

temp = 0.1
print(jax.nn.softmax(sample_dist_logits / temp))
print(jax.nn.softmax(_apply_temperature(sample_dist_logits, temp)))

[[[9.3571980e-14 2.0610600e-09 4.5397868e-05 9.9995458e-01]]

 [[5.0000000e-01 5.0000000e-01 0.0000000e+00 1.0305768e-09]]

 [[2.5000000e-01 2.5000000e-01 2.5000000e-01 2.5000000e-01]]]
[[[9.3571980e-14 2.0610600e-09 4.5397868e-05 9.9995458e-01]]

 [[5.0000000e-01 5.0000000e-01 0.0000000e+00 1.0305768e-09]]

 [[2.5000000e-01 2.5000000e-01 2.5000000e-01 2.5000000e-01]]]


: 