In [10]:
import jax
import jax.numpy as jnp
import chex
from chex import PRNGKey
import optax
from flax import nnx

from modeling.connect_four import ConnectFourNetwork
from modeling.common import NetworkVariables

In [11]:
model = ConnectFourNetwork(rngs=nnx.Rngs(0))
graphdef, params, state = nnx.split(model, nnx.Param, nnx.BatchStat)
variables = NetworkVariables(
    graphdef=graphdef,
    params=params,
    state=state,
)
model(jnp.zeros((1, 6, 7, 2)))

NetworkOutputs(pi=Array([[0., 0., 0., 0., 0., 0., 0.]], dtype=float32), v=Array([0.], dtype=float32))

In [12]:
from az_agent import batched_compute_policy

In [13]:
import connect_four_env as env

state, observation = jax.vmap(env.reset)(jax.random.split(jax.random.key(0), 1))

In [14]:
batched_compute_policy(variables, jax.random.key(0), state, observation, 4)

PolicyOutput(action=Array([5], dtype=int32), action_weights=Array([[0.07094882, 0.0038107 , 0.07094882, 0.05461996, 0.07094882,
        0.62503827, 0.10368459]], dtype=float32), search_tree=Tree(node_visits=Array([[5, 1, 1, 1, 1]], dtype=int32), raw_values=Array([[ 0.        ,  0.02921522,  0.10380112, -0.17983337,  0.4136383 ]],      dtype=float32), node_values=Array([[-0.07336425,  0.02921522,  0.10380112, -0.17983337,  0.4136383 ]],      dtype=float32), parents=Array([[-1,  0,  0,  0,  0]], dtype=int32), action_from_parent=Array([[-1,  6,  3,  5,  1]], dtype=int32), children_index=Array([[[-1,  4, -1,  2, -1,  3,  1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1]]], dtype=int32), children_prior_logits=Array([[[ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ],
        [-0.11296141, -0.936137  , -0.672769  ,  0.7346219 ,
          

In [15]:
from evaluation import make_az_policy, make_mcts_policy, random_policy, evaluate_pvp

In [16]:
evaluate_pvp(jax.random.key(0), make_mcts_policy(32), random_policy, 64)

(Array(0.928, dtype=float32),
 Array(0., dtype=float32),
 Array(0.072, dtype=float32))

In [17]:
evaluate_pvp(jax.random.key(0), make_az_policy(variables, 32), random_policy, 64)

(Array(0.8175676, dtype=float32),
 Array(0., dtype=float32),
 Array(0.18243243, dtype=float32))

In [18]:
evaluate_pvp(jax.random.key(0), make_az_policy(variables, 32), make_mcts_policy(32), 64)

(Array(0.19178082, dtype=float32),
 Array(0., dtype=float32),
 Array(0.8082192, dtype=float32))