In [1]:
import os
import sys
# Put the ./rlcard directory at the front of the module search path.
# The path is resolved against the current working directory, so the notebook
# must be started from the directory that contains ./rlcard.
# NOTE(review): presumably this is a local rlcard checkout meant to take
# precedence over any installed copy — confirm.
sys.path.insert(0, os.path.abspath('./rlcard'))

In [2]:
import numpy as np

import rlcard
from rlcard.agents import RandomAgent, CFRAgent
from rlcard.utils import set_global_seed, tournament
from rlcard.utils import Logger

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
# Build two Leduc Hold'em environments with the same seed:
#   - env:      used by the CFR agent for training; created with
#               allow_step_back enabled
#               (presumably required by CFR's game-tree traversal — confirm).
#   - eval_env: used only for evaluation tournaments.
env = rlcard.make('leduc-holdem', config={'seed': 0, 'allow_step_back':True})
eval_env = rlcard.make('leduc-holdem', config={'seed': 0})

In [4]:
# Training / evaluation schedule.
#   evaluate_every  - run an evaluation every this many training iterations
#   save_plot_every - NOTE(review): not referenced by any cell visible here
#   evaluate_num    - passed to tournament() at each evaluation
#                     (presumably the number of games played — confirm)
#   episode_num     - total number of training iterations
evaluate_every = 100
save_plot_every = 1000
evaluate_num = 10000
episode_num = 10000

# The paths for saving the logs and learning curves
log_dir = './experiments/leduc_holdem_cfr_result/'

# Set a global seed
set_global_seed(0)

# Initialize the CFR agent on the training environment
# (model_path is presumably where agent.save() writes the model — confirm)
agent = CFRAgent(env, model_path='./models/leduc_holdem_cfr')

# Baseline opponent that is used for evaluation
random_agent = RandomAgent(action_num=env.action_num)

# Evaluate CFR against the random agent (CFR is registered first)
eval_env.set_agents([agent, random_agent])

# Init a Logger to plot the learning curve
logger = Logger(log_dir)





* Now we start to train CFR on Leduc Hold'em. The training logs and the learning curve are shown below.

In [5]:
# Train the CFR agent for episode_num iterations, evaluating periodically.
for episode in range(episode_num):
    # One training iteration per loop pass.
    agent.train()

    # '\r' rewinds to the start of the line so the counter overwrites itself.
    print('\rIteration {}'.format(episode), end='')

    # Nothing more to do unless this is an evaluation iteration
    # (iteration 0 counts as one).
    if episode % evaluate_every != 0:
        continue

    # Save the current model, then log the result of an evaluation tournament.
    agent.save()
    # Entry 0 of the tournament result is logged; presumably that is the
    # payoff of the first agent registered on eval_env — confirm.
    logger.log_performance(
        episode,
        tournament(eval_env, evaluate_num)[0],
    )

Iteration 0
----------------------------------------
  timestep     |  0
  reward       |  -0.0292
----------------------------------------
Iteration 100
----------------------------------------
  timestep     |  100
  reward       |  0.4784
----------------------------------------
Iteration 200
----------------------------------------
  timestep     |  200
  reward       |  0.63435
----------------------------------------
Iteration 300
----------------------------------------
  timestep     |  300
  reward       |  0.6745
----------------------------------------
Iteration 400
----------------------------------------
  timestep     |  400
  reward       |  0.8033
----------------------------------------
Iteration 500
----------------------------------------
  timestep     |  500
  reward       |  0.79115
----------------------------------------
Iteration 600
----------------------------------------
  timestep     |  600
  reward       |  0.82865
----------------------------------------

Iteration 5700
----------------------------------------
  timestep     |  5700
  reward       |  0.793
----------------------------------------
Iteration 5800
----------------------------------------
  timestep     |  5800
  reward       |  0.7922
----------------------------------------
Iteration 5900
----------------------------------------
  timestep     |  5900
  reward       |  0.789
----------------------------------------
Iteration 6000
----------------------------------------
  timestep     |  6000
  reward       |  0.78395
----------------------------------------
Iteration 6100
----------------------------------------
  timestep     |  6100
  reward       |  0.7665
----------------------------------------
Iteration 6200
----------------------------------------
  timestep     |  6200
  reward       |  0.73615
----------------------------------------
Iteration 6300
----------------------------------------
  timestep     |  6300
  reward       |  0.76005
-------------------------

In [6]:
# Close the files held by the logger before plotting
# (presumably so the logged data is flushed to disk first — confirm)
logger.close_files()

# Plot the learning curve; 'CFR' is the name passed to the plot
# (presumably used as the curve label — confirm)
logger.plot('CFR')

./experiments/leduc_holdem_cfr_result/performance.csv
