In [1]:
import jax
import jax.numpy as jnp

import numpy as np
import wandb

from functools import partial
from pprint import pprint
import pickle

import chex
from chex import PRNGKey

from pgx import State

from type_aliases import Observation, Reward, Done, Action
import connect_four_env as env
import mcts_agent

In [2]:
def evaluate_pvp(rng: PRNGKey, policy1, policy2, batch_size: int) -> tuple[chex.Array, chex.Array, chex.Array]:
    """Play ``policy1`` against ``policy2`` on a batch of games and tally the outcomes.

    Runs ``env.max_steps`` environment steps on ``batch_size`` parallel games using
    ``env.step_autoreset`` (presumably resetting finished games in place, so a lane may
    complete several episodes during the rollout). ``policy1`` acts on lanes where
    ``state.current_player == 0`` and ``policy2`` on the others.

    Args:
        rng: PRNG key driving the reset, both policies and the environment steps.
        policy1: callable ``(rng, batched_state) -> actions`` used for player 0.
        policy2: callable ``(rng, batched_state) -> actions`` used for player 1.
        batch_size: number of games simulated in parallel.

    Returns:
        ``(wins, draws, losses)``: fractions of completed episodes (steps where ``done``
        is set) whose sign-adjusted reward was positive, zero or negative. Episodes still
        running when the rollout ends are not counted. If no episode finished, the
        divisions yield NaN.
    """
    def single_move(
        carry: tuple[State, Observation], rng: PRNGKey
    ) -> tuple[tuple[State, Observation], tuple[Reward, Done]]:
        state, _observation = carry
        rng0, rng1, rng2 = jax.random.split(rng, 3)

        # Both policies are evaluated on every lane; jnp.where keeps the one whose turn it is.
        action0 = policy1(rng0, state)
        action1 = policy2(rng1, state)
        action = jnp.where(state.current_player == 0, action0, action1)

        new_state, new_observation, new_reward, new_done = jax.vmap(env.step_autoreset)(
            state, action, jax.random.split(rng2, batch_size)
        )
        # Sign flip: reward kept as-is when player 0 moved, negated when player 1 moved.
        # NOTE(review): this yields player 0's (policy1's) perspective only if
        # step_autoreset reports the reward of the player who just moved — confirm.
        reward_p0 = -new_reward * (state.current_player * 2 - 1)
        return (new_state, new_observation), (reward_p0, new_done)

    rng, subkey = jax.random.split(rng)
    state, observation = jax.vmap(env.reset)(jax.random.split(subkey, batch_size))

    _, (rewards, done) = jax.lax.scan(
        single_move, (state, observation), jax.random.split(rng, env.max_steps)
    )
    chex.assert_shape(rewards, [env.max_steps, batch_size])
    chex.assert_shape(done, [env.max_steps, batch_size])

    # Only steps flagged `done` close an episode; the reward at that step decides its outcome.
    num_episodes = done.sum()
    wins = ((rewards > 0) & done).sum() / num_episodes
    draws = ((rewards == 0) & done).sum() / num_episodes
    losses = ((rewards < 0) & done).sum() / num_episodes
    return wins, draws, losses


def random_policy(rng: PRNGKey, state: State) -> chex.Array:
    """Sample an action uniformly at random among the legal ones.

    Legal actions share a logit of 0 and illegal ones get -1e9, so the categorical
    draw is uniform over legal moves. If ``legal_action_mask`` carries a leading batch
    axis, one action is drawn per row (sampling is along the last axis).
    """
    legal = state.legal_action_mask
    masked_logits = jnp.where(legal, jnp.zeros(env.num_actions), -1e9)
    return jax.random.categorical(rng, masked_logits, axis=-1)


def make_mcts_policy(num_simulations: int):
    """Return a greedy policy that picks the legal action with the largest MCTS weight.

    The returned callable runs ``mcts_agent.batched_compute_policy`` with
    ``num_simulations`` simulations and takes the argmax of ``action_weights`` along
    the last axis; illegal actions are pushed to -1e9 so they can never win the argmax.
    """
    def mcts_policy(rng: PRNGKey, state: State) -> chex.Array:
        search_output = mcts_agent.batched_compute_policy(rng, state, num_simulations)
        scores = jnp.where(state.legal_action_mask, search_output.action_weights, -1e9)
        return jnp.argmax(scores, axis=-1)

    return mcts_policy

In [3]:
# Baseline sanity check: random policy (player 0) vs. random policy.
evaluate_pvp(rng=jax.random.key(0), policy1=random_policy, policy2=random_policy, batch_size=1024)

(Array(0.51399493, dtype=float32),
 Array(0.00127226, dtype=float32),
 Array(0.48473284, dtype=float32))

In [4]:
# 16-simulation MCTS (player 0) vs. random policy.
evaluate_pvp(rng=jax.random.key(0), policy1=make_mcts_policy(num_simulations=16), policy2=random_policy, batch_size=32)

(Array(0.82758623, dtype=float32),
 Array(0., dtype=float32),
 Array(0.1724138, dtype=float32))

In [5]:
# 64-simulation MCTS (player 0) vs. 16-simulation MCTS.
evaluate_pvp(rng=jax.random.key(0), policy1=make_mcts_policy(num_simulations=64), policy2=make_mcts_policy(num_simulations=16), batch_size=32)

(Array(0.9066667, dtype=float32),
 Array(0.01333333, dtype=float32),
 Array(0.08, dtype=float32))