In [1]:
import sys
sys.path.append('..')

In [4]:
import torch
import os

import rlcard
from rlcard.agents import RandomAgent, NFSPAgent, DQNAgent, CFRAgent


import argparse
import pprint
from rlcard.utils import (
    get_device,
    set_seed,
    tournament,
    reorganize,
    Logger,
    plot_curve,
)


# Module-level Leduc Hold'em environment created with a fixed seed of 0.
# NOTE(review): prepare_environment() in this file builds its own local `env`
# from args, so this global is not what the visible training code runs on --
# confirm whether another cell still needs it.
env = rlcard.make('leduc-holdem', config={'seed': 0})

In [3]:
''' An example of training a reinforcement learning agent on the environments in RLCard
'''

def prepare_environment(args):
    """Build the environment, the learning agent and its random opponents.

    Args:
        args: Namespace-like object. Always reads ``algorithm`` (``'dqn'``,
            ``'nfsp'`` or ``'cfr'``), ``seed``, ``env`` and
            ``load_checkpoint_path``. When no checkpoint path is given it
            also reads ``log_dir`` and, depending on the algorithm,
            ``save_every``, ``estimator_network``, ``mlp_layers`` (DQN),
            ``hidden_layers_sizes`` and the optional ``q_mlp_layers``
            (NFSP, defaults to ``[64, 64]``).

    Returns:
        tuple: ``(env, agent, agents)`` where ``agents[0]`` is ``agent`` and
        every remaining seat is a ``RandomAgent``.

    Raises:
        ValueError: If ``args.algorithm`` is not one of the supported names.
    """
    # Fail fast: without this check an unknown algorithm would only surface
    # later as an UnboundLocalError on `agent`, after the device, the seed and
    # the environment had already been set up.
    supported_algorithms = ('dqn', 'nfsp', 'cfr')
    if args.algorithm not in supported_algorithms:
        raise ValueError(
            f"Unsupported algorithm {args.algorithm!r}; "
            f"expected one of {supported_algorithms}"
        )

    # Check whether gpu is available
    device = get_device()

    # Seed numpy, torch, random
    set_seed(args.seed)

    # Make the environment with seed
    env = rlcard.make(args.env, config={'seed': args.seed})

    # Truthiness treats both "" and None as "no checkpoint to resume from".
    checkpoint_path = args.load_checkpoint_path

    # Initialize the agent and use random agents as opponents
    if args.algorithm == 'dqn':
        from rlcard.agents import DQNAgent
        if checkpoint_path:
            agent = DQNAgent.from_checkpoint(checkpoint=torch.load(checkpoint_path))
        else:
            agent = DQNAgent(
                num_actions=env.num_actions,
                state_shape=env.state_shape[0],
                mlp_layers=args.mlp_layers,
                device=device,
                save_path=args.log_dir,
                save_every=args.save_every,
                estimator_network=args.estimator_network,
            )
    elif args.algorithm == 'nfsp':
        from rlcard.agents import NFSPAgent
        if checkpoint_path:
            agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(checkpoint_path))
        else:
            agent = NFSPAgent(
                num_actions=env.num_actions,
                state_shape=env.state_shape[0],
                hidden_layers_sizes=args.hidden_layers_sizes,
                # Previously hard-coded; still [64, 64] unless args overrides it.
                q_mlp_layers=getattr(args, 'q_mlp_layers', [64, 64]),
                device=device,
                save_path=args.log_dir,
                save_every=args.save_every,
                estimator_network=args.estimator_network,
            )
    else:  # 'cfr' -- guaranteed by the validation above
        from rlcard.agents import CFRAgent
        if checkpoint_path:
            # NOTE(review): upstream RLCard's CFRAgent appears to expose
            # save()/load() rather than from_checkpoint() -- confirm that the
            # rlcard build used here provides this classmethod.
            agent = CFRAgent.from_checkpoint(checkpoint=torch.load(checkpoint_path))
        else:
            # Unlike the DQN/NFSP branches, no device or save_every is passed:
            # only the environment and a model path.
            agent = CFRAgent(
                env,
                model_path=args.log_dir,
            )

    # Set agents to environment: the learner takes seat 0, random agents
    # fill every other seat as opponents.
    agents = [agent]
    for _ in range(1, env.num_players):
        agents.append(RandomAgent(num_actions=env.num_actions))
    env.set_agents(agents)

    return env, agent, agents



def train(args):
    """Run the training loop for the configured agent and save the result.

    Builds the environment and agents via ``prepare_environment(args)``, then
    for ``args.num_episodes`` episodes plays one game, feeds the seat-0
    transitions to the learning agent, and every ``args.evaluate_every``
    episodes logs the seat-0 tournament result over ``args.num_eval_games``
    games. Afterwards the learning curve is plotted and the agent object is
    written with ``torch.save`` to ``model.pth`` inside ``args.log_dir``.

    Args:
        args: Namespace-like object; besides the fields consumed by
            ``prepare_environment`` this reads ``log_dir``, ``num_episodes``,
            ``algorithm``, ``evaluate_every`` and ``num_eval_games``.
    """
    env, learner, seats = prepare_environment(args)

    # Start training
    with Logger(args.log_dir) as logger:
        for episode in range(args.num_episodes):

            # NFSP picks its policy mode anew at the start of every episode.
            if args.algorithm == 'nfsp':
                seats[0].sample_episode_policy()

            # Play one full game in training mode.
            raw_trajectories, payoffs = env.run(is_training=True)

            # Reorganize the data into (state, action, reward, next_state, done)
            per_player = reorganize(raw_trajectories, payoffs)

            # Only seat 0 learns: the learning agent always plays the first
            # position while every other seat (if any) plays randomly.
            for transition in per_player[0]:
                learner.feed(transition)

            # Skip evaluation unless this episode is on the schedule.
            if episode % args.evaluate_every != 0:
                continue

            # Evaluate the performance. Play with random agents.
            seat0_result = tournament(env, args.num_eval_games)[0]
            logger.log_performance(episode, seat0_result)

        # Grab the output paths while the logger is still open.
        csv_path = logger.csv_path
        fig_path = logger.fig_path

    # Plot the learning curve
    plot_curve(csv_path, fig_path, args.algorithm)

    # Save model
    save_path = os.path.join(args.log_dir, 'model.pth')
    torch.save(learner, save_path)
    print('Model saved in', save_path)