*Training Poker Using RLLib*

In [1]:
import sys
from poker_env import PokerEnv
from agents.random_policy import RandomActions
from agents.heuristic_policy import HeuristicPolicy
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.a3c import A3C
from ray.rllib.algorithms.sac import SAC
from ray.rllib.algorithms.dqn import DQN
from gym import spaces
import mpu
import numpy as np
import ray
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import PolicySpec
from ray.tune.registry import register_env

  from collections import Iterable


In Rllib, a policy function needs to be passed to map agent IDs to the policy to use. We also create and register the learning environment.

In [2]:
def select_policy(agent_id, episode, **kwargs):
    if agent_id == 0:
        return "a3c"
    elif agent_id == 1:
        return "sac"
    elif agent_id == 2:
        return "dqn"
    elif agent_id == 3:
        return "ppo"
    return "learned4"

def env_creator(config):
    env = PokerEnv(select_policy, config)
    return env

register_env("poker", lambda config: env_creator(config))

The config describes all aspects of the training. A full list of the parameters is found here: https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm_config.py 

In this example, we have used a default config that runs the PPO algorithm. Other heuristic agents have been defined that will play the game with the single learning agent. 

In [3]:
heuristic_observation_space = spaces.Dict({
            "hand": spaces.Box(0, 1, shape=(24, )),
            "community": spaces.Box(0, 1, shape=(24, ))
        })
action_space = spaces.Discrete(3)

#Defines the learning models architecture. 
model = MODEL_DEFAULTS.update({'fcnet_hiddens': [512, 512], 'fcnet_activation': 'relu'})

config = (
    PPOConfig()
    #Each rollout worker uses a single cpu
    .rollouts(num_rollout_workers=8, num_envs_per_worker=1)\
    .training(train_batch_size=4000, gamma=0.99, model=model, lr=0.0004)\
    .environment(disable_env_checking=True)\
    .multi_agent(
        policies={
            #These policies thave pre-definded polices that dont learn.
            "a3c": PolicySpec(config=A3C.get_default_config()),
            "sac": PolicySpec(config=SAC.get_default_config()),
            "dqn": PolicySpec(config=DQN.get_default_config()),
            #Passing nothing causes this agent to deafult to using a PPO policy
            "ppo": PolicySpec(
                config={}
            ),
        },
        policy_mapping_fn=select_policy,
        policies_to_train=['a3c', 'sac', 'dqn', 'ppo'],
    )\
    .resources(num_gpus=0)\
    .framework('torch')
)
trainer = config.build(env="poker")

2022-11-10 14:37:09,920	INFO worker.py:1528 -- Started a local Ray instance.
[2m[36m(RolloutWorker pid=76323)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=76325)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=76330)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=76327)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=76328)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=76324)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=76329)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=76326)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=76323)[0m 2022-11-10 14:37:22,378	ERROR worker.py:763 -- Exception raised in creation task: The actor died because of an error raised in its creation task, [36mray::RolloutWorker.__init__()[39m (pid=76323, ip=127.0.0.1, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7f7c857459d0

ImportError: Could not import tensorflow!

In [None]:
#Start up tensorboard
#!tensorboard --logdir=~/ray_results --host 0.0.0.0

Training loop, each run will rollout x timesteps (where x is train_batch_size). An weight update is then applied using the rollout data. 

In [None]:
for i in range(1000):
    trainer.train()

In [None]:
#Saves a checkpoint of the trainer.
trainer.save("checkpoint/ppo_poker")