In [1]:
import sys 
sys.path.append('..')

import torch 
import rlcard 
from rlcard.agents import NFSPAgent

In [2]:
# NOTE(review): `state_shape` is not referenced by the rest of this cell;
# the agent below is sized from `env.state_shape` instead.
state_shape = torch.Size([12])

# Leduc Hold'em environment created with a fixed seed.
env = rlcard.make('leduc-holdem', config={'seed': 0})

# NFSP agent sized from the environment, placed on the CPU, using the
# 'mlp' estimator network option.
agent = NFSPAgent(
    num_actions=env.num_actions,
    state_shape=env.state_shape[0],
    hidden_layers_sizes=[64, 64],
    q_mlp_layers=[64, 64],
    device='cpu',
    estimator_network='mlp',
)

# Ideas 
- [x] Maybe we should use an embedding layer before the transformer layer
- [x] Add pooling before the final linear layer

In [3]:
import os
import argparse

import torch

import rlcard
from rlcard.agents import RandomAgent
from rlcard.utils import (
    get_device,
    set_seed,
    tournament,
    reorganize,
    Logger,
    plot_curve,
)

In [4]:
''' An example of training a reinforcement learning agent on the environments in RLCard
'''


def train(args):
    """Train an agent on an RLCard environment against random opponents.

    Builds the environment named by ``args.env`` and the agent selected by
    ``args.algorithm``, places that agent in seat 0 with ``RandomAgent``
    instances in every remaining seat, runs ``args.num_episodes`` training
    episodes while periodically logging tournament results, then plots the
    learning curve and saves the agent to ``<args.log_dir>/model.pth``.

    Args:
        args: Object exposing these attributes:
            env: Environment name passed to ``rlcard.make``.
            algorithm: One of ``'dqn'``, ``'nfsp'`` or ``'cfr'``.
            seed: Seed passed to ``set_seed`` and to the environment config.
            load_checkpoint_path: Checkpoint file to resume from; an empty
                string means "create a fresh agent".
            log_dir: Directory used by the logger, as the agents'
                ``save_path`` and for the final ``model.pth``.
            save_every: Forwarded to the agent constructor.
            estimator_network: Forwarded to ``DQNAgent`` only; the other
                branches do not read it.
            num_episodes: Number of training episodes to run.
            evaluate_every: Evaluate whenever ``episode`` is a multiple of
                this value.
            num_eval_games: Number of games per evaluation tournament.

    Raises:
        ValueError: If ``args.algorithm`` is not one of the supported names.
    """
    # Fail fast with a clear message. Without this check an unknown (or
    # missing, i.e. None) algorithm only surfaces later as an unbound `agent`,
    # after the environment has already been built.
    supported_algorithms = ('dqn', 'nfsp', 'cfr')
    if args.algorithm not in supported_algorithms:
        raise ValueError(
            "Unsupported algorithm {!r}; expected one of: {}".format(
                args.algorithm, ', '.join(supported_algorithms)
            )
        )

    # Check whether gpu is available
    device = get_device()

    # Seed numpy, torch, random
    set_seed(args.seed)

    # Make the environment with seed
    env = rlcard.make(
        args.env,
        config={
            'seed': args.seed,
        }
    )

    # Initialize the learning agent, either fresh or from a checkpoint.
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    resume = args.load_checkpoint_path != ""
    if args.algorithm == 'dqn':
        from rlcard.agents import DQNAgent
        if resume:
            agent = DQNAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
        else:
            agent = DQNAgent(
                num_actions=env.num_actions,
                state_shape=env.state_shape[0],
                mlp_layers=[64,64],
                device=device,
                save_path=args.log_dir,
                save_every=args.save_every,
                estimator_network=args.estimator_network,
            )
    elif args.algorithm == 'nfsp':
        from rlcard.agents import NFSPAgent
        if resume:
            agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
        else:
            agent = NFSPAgent(
                num_actions=env.num_actions,
                state_shape=env.state_shape[0],
                hidden_layers_sizes=[64,64],
                q_mlp_layers=[64,64],
                device=device,
                save_path=args.log_dir,
                save_every=args.save_every
            )
    else:  # 'cfr' -- the only remaining value after the validation above
        from rlcard.agents import CFRAgent
        if resume:
            agent = CFRAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
        else:
            agent = CFRAgent(
                env,
                device=device,
                save_path=args.log_dir,
                save_every=args.save_every
            )

    # The learning agent takes seat 0; random agents fill the other seats.
    agents = [agent]
    for _ in range(1, env.num_players):
        agents.append(RandomAgent(num_actions=env.num_actions))
    env.set_agents(agents)

    # Start training
    with Logger(args.log_dir) as logger:
        for episode in range(args.num_episodes):

            if args.algorithm == 'nfsp':
                agents[0].sample_episode_policy()

            # Generate data from the environment
            trajectories, payoffs = env.run(is_training=True)

            # Reorganize the data to be state, action, reward, next_state, done
            trajectories = reorganize(trajectories, payoffs)

            # Feed transitions into agent memory, and train the agent.
            # Only seat 0 (the learning agent) is trained; the other seats
            # play randomly (if any).
            for ts in trajectories[0]:
                agent.feed(ts)

            # Evaluate the performance. Play with random agents.
            if episode % args.evaluate_every == 0:
                logger.log_performance(
                    episode,
                    tournament(
                        env,
                        args.num_eval_games,
                    )[0]
                )

        # Get the paths
        csv_path, fig_path = logger.csv_path, logger.fig_path

    # Plot the learning curve
    plot_curve(csv_path, fig_path, args.algorithm)

    # Save model
    save_path = os.path.join(args.log_dir, 'model.pth')
    torch.save(agent, save_path)
    print('Model saved in', save_path)


In [5]:
# NOTE(review): `parser` is created but not used by the visible cells; the
# settings below are plain variables rather than parsed command-line flags.
parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")

# What to train and where.
env, algorithm = 'leduc-holdem', 'dqn'
seed = 42

# Training / evaluation schedule.
num_episodes = 5000
evaluate_every = 100
num_eval_games = 2000

# Logging and checkpointing.
log_dir = 'experiments/leduc_holdem_dqn_result/'
save_every = -1
load_checkpoint_path = ""

# Device visibility: the value is exported as CUDA_VISIBLE_DEVICES
# (an empty string is expected to hide every GPU -- confirm on your setup).
cuda = ''
os.environ["CUDA_VISIBLE_DEVICES"] = cuda

class dotdict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes."""

    def __getattr__(self, key, default=None):
        # Same contract as dict.get: a missing key yields the default (None)
        # instead of raising AttributeError.
        return dict.get(self, key, default)

    def __setattr__(self, key, value):
        # Attribute assignment stores an ordinary dictionary item.
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        # Deleting an absent key raises KeyError, exactly like `del d[key]`.
        dict.__delitem__(self, key)

# Bundle the settings so they can be read as attributes (args.env, args.seed, ...).
args = dotdict(
    env=env,
    algorithm=algorithm,
    seed=seed,
    num_episodes=num_episodes,
    num_eval_games=num_eval_games,
    evaluate_every=evaluate_every,
    log_dir=log_dir,
    load_checkpoint_path=load_checkpoint_path,
    save_every=save_every,
    estimator_network='transformer',
)

In [6]:
# Run training with the settings bundled in `args`.
train(args)

--> Running on the CPU

----------------------------------------
  episode      |  0
  reward       |  -0.5095
----------------------------------------
[Transition(state=array([0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.]), action=3, reward=2.0, next_state=array([0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.]), done=True, legal_actions=[1, 2, 3]), Transition(state=array([0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.]), action=2, reward=-0.5, next_state=array([0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.]), done=True, legal_actions=[1, 2, 3]), Tran

KeyboardInterrupt: 

In [15]:
from collections import namedtuple


# Scratch check: build a throwaway Transition record and read its first field.
namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'next_state', 'done', 'legal_actions'],
)(1, 2, 3, 4, 5, 6).state

1