In [None]:
import matplotlib.pyplot as plt

from cfrx.algorithms.cfr import CFRState
from cfrx.policy import TabularPolicy
from cfrx.trainers.cfr import CFRTrainer

In [None]:
# Game to train on. The environment-selection cell supports
# "Kuhn Poker" and "Leduc Poker".
ENV_NAME: str = "Kuhn Poker"

In [None]:
# Resolve the environment class from ENV_NAME. Each env module is imported
# inside its branch, so only the selected game's module is loaded.
if ENV_NAME == "Kuhn Poker":
    from cfrx.envs.kuhn_poker.env import KuhnPoker

    env_cls = KuhnPoker

elif ENV_NAME == "Leduc Poker":
    from cfrx.envs.leduc_poker.env import LeducPoker

    env_cls = LeducPoker

else:
    # Fail fast here instead of with a confusing NameError on `env_cls`
    # in a later cell.
    raise ValueError(
        f"Unsupported ENV_NAME {ENV_NAME!r}; "
        "expected 'Kuhn Poker' or 'Leduc Poker'."
    )

In [None]:
# Instantiate the game environment class selected above (no constructor arguments).
env = env_cls()

In [None]:
# Initial CFR state built from the game's info-state and action counts.
# NOTE(review): this `training_state` is never passed to `trainer.train(...)`
# and is overwritten by that call's return value. Confirm whether train()
# should receive it, or whether this initialization can be dropped.
training_state = CFRState.init(
    n_states=env.n_info_states,
    n_actions=env.n_actions,
)

# Tabular policy over env.n_actions actions, with info states indexed by
# env.info_state_idx.
policy = TabularPolicy(
    n_actions=env.n_actions,
    info_state_idx_fn=env.info_state_idx,
)

In [None]:
# Train CFR on CPU. `metrics` is expected to hold the "step" and
# "exploitability" series plotted in the next cell (presumably sampled every
# `metrics_period` iterations — confirm against CFRTrainer.train).
trainer = CFRTrainer(
    env=env,
    policy=policy,
    device="cpu",
)
training_state, metrics = trainer.train(
    n_iterations=10_000,
    metrics_period=10,
)

In [None]:
# Plot exploitability against training step on a log y-axis. Use the explicit
# Figure/Axes API, and end with plt.show() so no stray `Text(...)` repr is
# printed below the figure.
fig, ax = plt.subplots()
ax.plot(metrics["step"], metrics["exploitability"])
ax.set_yscale("log")
ax.set(
    title=f"CFR on {ENV_NAME}",
    xlabel="Iterations",
    ylabel="Exploitability",
)
plt.show()