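Adapted from `rlcard/examples/leduc_holdem_dqn_pytorch.py`, modified to train a DQN agent on Kuhn poker through self-play and to evaluate it against a random agent.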

In [None]:
import os
import sys

# Put the local ./rlcard directory (presumably a checkout of the rlcard
# repository) at the front of the import path so it takes precedence over
# any other copy on sys.path.
_rlcard_root = os.path.abspath('./rlcard')
sys.path.insert(0, _rlcard_root)

In [None]:
import warnings

# Suppress every FutureWarning for the rest of the notebook session.
warnings.simplefilter('ignore', FutureWarning)

In [None]:
import torch
import rlcard
from DQNAgent import DQNAgent
from rlcard.agents import RandomAgent
from rlcard.utils import set_global_seed, tournament
from rlcard.utils import Logger

In [None]:
# Make environments: one Kuhn poker env for training and a separate one for
# evaluation, each created with seed 0 (training env is created first).
env, eval_env = (
    rlcard.make('kuhn-poker', config={'seed': 0}) for _ in range(2)
)

In [None]:
# Training schedule:
#   episode_num    - total number of self-play training episodes
#   evaluate_every - evaluate after every this many episodes
#   evaluate_num   - size of each evaluation tournament (presumably the
#                    number of games played against the random agent)
episode_num = 10_000
evaluate_every = 100
evaluate_num = 1_000

In [None]:
# Directory handed to Logger(...) for the training logs and the plotted
# learning curve.
log_dir = './experiments/kuhn_poker_dqn_result/'

In [None]:
# Set a global seed via rlcard's set_global_seed (presumably seeds the
# global random number generators) so the run is reproducible.
# NOTE(review): this runs after the environments were created; they receive
# their own seed through config={'seed': 0}. The DQN agent is constructed
# afterwards, so its initialization is covered by this seed.
set_global_seed(0)

In [None]:
# Build the learner and the evaluation opponent. The DQN agent occupies both
# seats of the training env (self-play); in eval_env it faces a RandomAgent.
agent = DQNAgent(env.action_num, env.state_shape[0])
random_agent = RandomAgent(action_num=eval_env.action_num)

env.set_agents([agent] * 2)
eval_env.set_agents([agent, random_agent])

In [None]:
# Init a Logger writing into log_dir. Evaluation results are recorded
# through logger.log_performance and later plotted with logger.plot.
logger = Logger(log_dir)

In [None]:
for episode in range(episode_num):
    # Play one game in the training env (agent in both seats) and collect
    # the transitions, indexed by player.
    trajectories, _ = env.run(is_training=True)

    # Feed player 0's transitions into agent memory and train the agent.
    # NOTE(review): seat 1 is also played by `agent`, but its transitions
    # (trajectories[1]) are not used here -- confirm whether DQNAgent.train
    # should see them too.
    for ts in trajectories[0]:
        agent.train(ts)

    # Evaluate against the random agent in eval_env every `evaluate_every`
    # episodes, and also after the final episode so the learning curve
    # includes the fully trained policy (otherwise the last point logged
    # can be up to evaluate_every - 1 episodes before the end of training).
    is_last_episode = episode == episode_num - 1
    if episode % evaluate_every == 0 or is_last_episode:
        # tournament(...)[0] is presumably the average payoff of seat 0,
        # i.e. the DQN agent -- confirm against rlcard.utils.tournament.
        logger.log_performance(episode,
                               tournament(eval_env, evaluate_num)[0])

In [None]:
# Close the files held by the logger; no further results are logged after
# this point.
logger.close_files()

In [None]:
# Plot the learning curve from the logged evaluation results.
# 'DQN' is presumably the algorithm label shown on the plot -- confirm
# against rlcard's Logger.plot.
logger.plot('DQN')

In [None]:
# Save model: make sure the output directory for the trained weights exists.
# exist_ok=True turns this into a no-op when the directory is already there
# and avoids the check-then-create race of a separate os.path.exists() test.
save_dir = 'models/kuhn_poker_dqn'
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Collect the agent's parameters for saving (whatever DQNAgent.get_state_dict
# returns -- presumably a torch state dict; confirm in DQNAgent).
state_dict = agent.get_state_dict()

In [None]:
# Serialize the collected state dict to <save_dir>/model.pth.
model_path = os.path.join(save_dir, 'model.pth')
torch.save(state_dict, model_path)

In [None]:
# Display the agent's `weight_updates` attribute as this cell's output
# (presumably the number of network weight updates performed in training).
agent.weight_updates