# Monte Carlo Counterfactual Regret Minimization (MCCFR)

In this example, we show how to use `cfrx` to run MCCFR (outcome-sampling variant) on two simple games: Kuhn Poker and Leduc Poker.

We'll see how to:
 - Initialize an environment from `cfrx`
 - Initialize a random policy and sample a rollout
 - Write a small training loop to run the MCCFR algorithm
 - Measure how the exploitability of our strategy evolves during training

In [None]:
!pip install matplotlib

In [None]:
import jax

jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from cfrx.algorithms.mccfr.outcome_sampling import MCCFRState, do_iteration, unroll
from cfrx.metrics import exploitability
from cfrx.policy import TabularPolicy
from cfrx.utils import regret_matching

In [None]:
from IPython.display import clear_output

def plot_partial(plot_fn, *plot_args):
    clear_output(wait=True)
    fig = plot_fn(*plot_args)
    plt.show(fig)

In [None]:
device = jax.devices("cpu")[0]

In [None]:
# Hyperparameters
ENV_NAME = "Kuhn Poker"
NUM_ITERATIONS = 100000
EXPLORATION_FACTOR = 0.6
SEED = 0
METRICS_PERIOD = 10000

random_key = jax.random.PRNGKey(SEED)

In [None]:
if ENV_NAME == "Kuhn Poker":
    from cfrx.envs.kuhn_poker.env import KuhnPoker

    env_cls = KuhnPoker


elif ENV_NAME == "Leduc Poker":
    from cfrx.envs.leduc_poker.env import LeducPoker

    env_cls = LeducPoker

## The environment

[Kuhn Poker](https://en.wikipedia.org/wiki/Kuhn_poker) is a simplified form of poker. In `cfrx`, we use the environment from [pgx](https://github.com/sotetsuk/pgx) and add a wrapper that explicitly handles random nodes and information states.

In [None]:
env = env_cls()

In [None]:
# Number of info_sets, number of possible actions
n_states = env.n_info_states
n_actions = env.n_actions

n_states, n_actions

In [None]:
s0 = env.init(random_key)
s0  # Cards haven't been dealt yet

In [None]:
# Give a J to player 1 and a K to player 2
s1 = env.step(s0, action=jnp.array(0))
s2 = env.step(s1, action=jnp.array(2))
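# jax.tree_map was removed in recent JAX releases; jax.tree_util.tree_map is the stable equivalent
jax.tree_util.tree_map(lambda *z: jnp.stack(z), s1, s2)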

## Random policy

In [None]:
# Initialize a training state
training_state = MCCFRState.init(n_states, n_actions)
jax.tree_util.tree_map(np.shape, training_state)

In [None]:
# Initialize a Policy object and print the probability distribution for our current strategy and state
policy = TabularPolicy(
    n_actions=n_actions,
    exploration_factor=EXPLORATION_FACTOR,
    info_state_idx_fn=env.info_state_idx,
)

policy.prob_distribution(
    params=training_state.probs,
    info_state=s2.info_state,
    action_mask=s2.legal_action_mask,
    use_behavior_policy=jnp.bool_(False),
)
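
The `exploration_factor` and `use_behavior_policy` arguments suggest that trajectories are sampled from a behavior policy that differs from the current strategy. In outcome-sampling MCCFR, a common choice is to mix the current strategy with a uniform distribution over legal actions. The sketch below illustrates that idea in plain Python; `behavior_policy` is a hypothetical helper, and whether `cfrx` mixes the distributions in exactly this way is an assumption.

```python
def behavior_policy(strategy, legal_mask, exploration_factor):
    """Mix a strategy with the uniform distribution over legal actions.

    Hypothetical helper for illustration only; not part of cfrx.
    Assumes `strategy` already puts zero mass on illegal actions.
    """
    n_legal = sum(legal_mask)
    if n_legal == 0:
        raise ValueError("at least one action must be legal")
    uniform = 1.0 / n_legal
    mixed = []
    for prob, legal in zip(strategy, legal_mask):
        if legal:
            mixed.append(
                exploration_factor * uniform + (1.0 - exploration_factor) * prob
            )
        else:
            # Illegal actions are never sampled
            mixed.append(0.0)
    return mixed


# Two legal actions, with a current strategy that strongly favours the first
sampling_probs = behavior_policy([0.9, 0.1], [True, True], exploration_factor=0.6)
```

With an exploration factor of 0.6, every legal action keeps a sampling probability of at least 0.6 divided by the number of legal actions, so no branch of the game tree is starved of samples.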

In [None]:
# Let's do an unroll with our uniformly random policy
random_key, subkey = jax.random.split(random_key)
episode, states = unroll(
    init_state=s2,
    training_state=training_state,
    random_key=subkey,
    update_player=0,
    env=env,
    policy=policy,
    n_max_steps=env.max_episode_length,
)

Print out the action sequence: "b" means "bet" and "p" means "pass".

In [None]:
jax.tree_util.tree_map(lambda x: x[~states.terminated], states)

In [None]:
"".join(
    [env_cls.action_to_string(x) for x in episode.action[episode.mask.astype(bool)]]
)

## MCCFR implementation
We use the `cfrx` components to implement the MCCFR algorithm.

The algorithm alternates iterations between the two players and logs the exploitability periodically.

Note: We JIT-compile both the iteration and exploitability functions to make the most of JAX.
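At the core of every CFR variant is regret matching, which turns the cumulative regrets of an information state into a strategy. The notebook imports `regret_matching` from `cfrx.utils`, but its exact signature is not shown here, so the function below is a standalone, hypothetical sketch of the rule rather than the library's implementation.

```python
def regret_matching_sketch(cumulative_regrets, legal_mask):
    """Return a strategy proportional to the positive cumulative regrets.

    Hypothetical illustration; not the cfrx implementation.
    """
    # Keep only positive regrets, and only for legal actions
    positive = [
        max(regret, 0.0) if legal else 0.0
        for regret, legal in zip(cumulative_regrets, legal_mask)
    ]
    total = sum(positive)
    if total > 0.0:
        return [p / total for p in positive]
    # No positive regret: fall back to uniform over legal actions
    n_legal = sum(legal_mask)
    return [1.0 / n_legal if legal else 0.0 for legal in legal_mask]


strategy = regret_matching_sketch([3.0, -1.0, 1.0], [True, True, True])
fallback = regret_matching_sketch([-2.0, -5.0], [True, True])
```

Here `strategy` puts three quarters of its mass on the first action and one quarter on the third, while `fallback` is uniform because no action has positive regret.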

In [None]:
# This function samples a trajectory, computes counterfactual regrets, and updates the policy accordingly
do_iteration_fn = jax.jit(
    lambda training_state, random_key, update_player: do_iteration(
        training_state=training_state,
        random_key=random_key,
        env=env,
        policy=policy,
        update_player=update_player,
    )
)

In [None]:
# This function measures the exploitability of a strategy
exploitability_fn = jax.jit(
    lambda policy_params: exploitability(
        policy_params=policy_params,
        env=env,
        n_players=env.n_players,
        n_max_nodes=env.max_nodes,
        policy=policy,
    ),
    device=device,
)
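
To build intuition for what exploitability measures, the following sketch computes it for a two-player zero-sum matrix game (rock-paper-scissors) instead of a poker game tree. It uses one common convention, the average of what each player could gain by switching to a best response; the normalization used by `cfrx.metrics.exploitability` may differ, and all names below are hypothetical.

```python
# Payoff matrix for the row player; the column player receives the negation
RPS_PAYOFFS = [
    [0.0, -1.0, 1.0],   # rock
    [1.0, 0.0, -1.0],   # paper
    [-1.0, 1.0, 0.0],   # scissors
]


def matrix_game_exploitability(payoffs, row_strategy, col_strategy):
    """Average best-response gain in a zero-sum matrix game (illustrative)."""
    n_rows, n_cols = len(payoffs), len(payoffs[0])
    # Best value the row player can get against the column strategy
    row_br = max(
        sum(payoffs[i][j] * col_strategy[j] for j in range(n_cols))
        for i in range(n_rows)
    )
    # Best value the column player can get against the row strategy
    col_br = max(
        -sum(payoffs[i][j] * row_strategy[i] for i in range(n_rows))
        for j in range(n_cols)
    )
    # In a zero-sum game the current values cancel, leaving the two
    # best-response values
    return (row_br + col_br) / 2.0


uniform = [1.0 / 3.0] * 3
always_rock = [1.0, 0.0, 0.0]
nash_gap = matrix_game_exploitability(RPS_PAYOFFS, uniform, uniform)
rock_gap = matrix_game_exploitability(RPS_PAYOFFS, always_rock, always_rock)
```

The uniform profile is the equilibrium of rock-paper-scissors, so `nash_gap` is zero up to floating-point error, whereas a profile that always plays rock can be fully exploited by paper.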

In [None]:
# One iteration consists of updating the policy for both players
n_loops = 2 * NUM_ITERATIONS

exploitabilities = []
iterations = []

for k in range(n_loops):
    random_key, subkey = jax.random.split(random_key)

    # Update players alternately
    update_player = k % 2
    training_state = do_iteration_fn(
        training_state=training_state,
        random_key=subkey,
        update_player=update_player,
    )

    # Logging
    if k == 0 or (k + 1) % (METRICS_PERIOD * 2) == 0:
        current_policy = training_state.avg_probs
        current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True)

        exp = exploitability_fn(policy_params=current_policy)

        exploitabilities.append(exp)
        iterations.append(k // 2)
        plt.xlabel("Iterations")
        plt.title(f"MCCFR outcome sampling on {ENV_NAME}")
        plt.ylabel("Exploitability")
        plt.yscale("log")
        plt.xlim(0, NUM_ITERATIONS)

        plot_partial(plt.plot, iterations, exploitabilities)
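
The logging step above evaluates the average strategy, obtained by normalizing each row of `training_state.avg_probs` so that it sums to one. The plain-Python sketch below shows the same row-wise normalization on a small table. The uniform fallback for rows that sum to zero is an added assumption for illustration; the notebook divides directly, and `normalize_rows` is a hypothetical name.

```python
def normalize_rows(accumulated):
    """Normalize each row of accumulated strategy weights to sum to one.

    Hypothetical helper; rows with zero total fall back to uniform, which
    is an assumption and not necessarily what cfrx does.
    """
    normalized = []
    for row in accumulated:
        total = sum(row)
        if total > 0.0:
            normalized.append([value / total for value in row])
        else:
            # Information state with no accumulated weight yet
            normalized.append([1.0 / len(row)] * len(row))
    return normalized


# One visited information state and one with no accumulated weight
average_policy = normalize_rows([[3.0, 1.0], [0.0, 0.0]])
```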

All of this logic is also implemented in a trainer, which is further optimized to reduce runtime.

In [None]:
from cfrx.trainers.mccfr import MCCFRTrainer

In [None]:
trainer = MCCFRTrainer(env=env, policy=policy)

In [None]:
training_state = trainer.train(
    random_key=random_key, n_iterations=NUM_ITERATIONS, metrics_period=METRICS_PERIOD
)