*Training Poker Using RLlib*

In [1]:
import sys
from poker_env import PokerEnv
from agents.random_policy import RandomActions
from agents.heuristic_policy import HeuristicPolicy
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.a3c import A3CConfig
from ray.rllib.algorithms.ars import ARSConfig


from gym import spaces
import mpu
import numpy as np
import ray
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import PolicySpec
from ray.tune.registry import register_env



In RLlib, a policy mapping function must be supplied to map each agent ID to the policy that agent uses. We also create and register the learning environment.

In [2]:
def select_policy(agent_id, episode, **kwargs):
    if agent_id == 0:
        return "learned"
    elif agent_id == 1:
        return "Heuristic_10"
    elif agent_id == 2:
        return "Heuristic_100"
    elif agent_id == 3:
        return "random"
    return "Heuristic_1000"

def env_creator(config):
    env = PokerEnv(select_policy, config)
    return env

register_env("poker", lambda config: env_creator(config))
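
The mapping above can also be written as a lookup table, which makes the seat-to-policy assignment easier to read at a glance. This is only a sketch: the names POLICY_BY_AGENT and select_policy_from_table are hypothetical, and the five-seat table (agent IDs 0 to 4) is an assumption, since the notebook does not state how many players PokerEnv seats.

```python
# Hypothetical table-driven equivalent of select_policy; runs without Ray.
POLICY_BY_AGENT = {
    0: "learned",
    1: "Heuristic_10",
    2: "Heuristic_100",
    3: "random",
}

def select_policy_from_table(agent_id, episode=None, **kwargs):
    # Any agent ID not listed falls through to the strongest heuristic,
    # mirroring the final return statement of select_policy.
    return POLICY_BY_AGENT.get(agent_id, "Heuristic_1000")

# Assumption: five seats, agent IDs 0..4.
assignments = {agent_id: select_policy_from_table(agent_id) for agent_id in range(5)}
for agent_id, policy_id in assignments.items():
    print(agent_id, policy_id)
```

Only agent 0 maps to the trainable policy; every other seat is filled by a fixed opponent.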

The config describes all aspects of the training. A full list of the parameters is found here: https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm_config.py 

In this example, we use a config that runs the A3C algorithm (A3CConfig). Several heuristic agents are also defined; they play the game against the single learning agent.

In [3]:
heuristic_observation_space = spaces.Dict({
            "hand": spaces.Box(0, 1, shape=(24, )),
            "community": spaces.Box(0, 1, shape=(24, ))
        })
action_space = spaces.Discrete(3)

# Defines the learning model's architecture.
# dict.update() returns None, so build a merged copy of the defaults instead.
model = {**MODEL_DEFAULTS, 'fcnet_hiddens': [512, 512], 'fcnet_activation': 'relu'}

config = (
    A3CConfig()
    # Each rollout worker uses a single CPU.
    .rollouts(num_rollout_workers=8, num_envs_per_worker=1)
    .training(train_batch_size=4000, gamma=0.99, model=model, lr=0.0004)
    .environment(disable_env_checking=True)
    .multi_agent(
        policies={
            # These agents use pre-defined policies that do not learn.
            "random": PolicySpec(policy_class=RandomActions),
            "Heuristic_10": (HeuristicPolicy, heuristic_observation_space, action_space, {'difficulty': 0}),
            "Heuristic_100": (HeuristicPolicy, heuristic_observation_space, action_space, {'difficulty': 1}),
            "Heuristic_1000": (HeuristicPolicy, heuristic_observation_space, action_space, {'difficulty': 2}),
            # An empty PolicySpec makes this agent fall back to the algorithm's default policy (A3C here).
            "learned": PolicySpec(
                config={}
            ),
        },
        policy_mapping_fn=select_policy,
        policies_to_train=['learned'],
    )
    .resources(num_gpus=0)
    .framework('torch')
)
trainer = config.build(env="poker")

2022-11-10 16:37:30,939	INFO worker.py:1528 -- Started a local Ray instance.
2022-11-10 16:37:42,417	INFO trainable.py:164 -- Trainable.setup took 13.782 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
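
When overriding entries of a defaults dictionary such as MODEL_DEFAULTS, it helps to remember that Python's dict.update changes the dictionary in place and returns None. The sketch below contrasts that with a non-mutating merge. DEFAULTS is a hypothetical stand-in used for illustration, not RLlib's actual defaults.

```python
# Hypothetical stand-in for a defaults dictionary (not RLlib's MODEL_DEFAULTS).
DEFAULTS = {"fcnet_hiddens": [256, 256], "fcnet_activation": "tanh"}
overrides = {"fcnet_hiddens": [512, 512], "fcnet_activation": "relu"}

# dict.update mutates its receiver and returns None.
scratch = dict(DEFAULTS)
result = scratch.update(overrides)

# Dictionary unpacking builds a new dictionary and leaves DEFAULTS untouched.
model = {**DEFAULTS, **overrides}

print(result)
print(model["fcnet_hiddens"], DEFAULTS["fcnet_hiddens"])
```

The merged copy keeps every default key that is not overridden, so it can be passed on as a complete model config without altering the shared defaults.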


In [4]:
#Start up tensorboard
!tensorboard --logdir=~/ray_results --host 0.0.0.0

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

TensorBoard 2.11.0 at http://0.0.0.0:6007/ (Press CTRL+C to quit)
^C


Training loop: each call to train() rolls out x timesteps (where x is train_batch_size). A weight update is then applied using the rollout data.

In [None]:
for i in range(1000):
    trainer.train()
    print(i)
    # Save a checkpoint of the trainer every 200 iterations.
    if i % 200 == 0:
        trainer.save("checkpoint/a3c_poker")



0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
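
To see which iterations of the training loop write a checkpoint, the schedule can be replayed without Ray. StubTrainer below is a hypothetical stand-in that only records calls; it does not train anything and is not part of RLlib.

```python
class StubTrainer:
    """Hypothetical stand-in for the trainer; records calls only."""

    def __init__(self):
        self.iterations = 0
        self.saved_at = []

    def train(self):
        self.iterations += 1

    def save(self, path):
        # Record the zero-based loop index at which save() was called.
        self.saved_at.append(self.iterations - 1)


stub = StubTrainer()
for i in range(1000):
    stub.train()
    if i % 200 == 0:
        stub.save("checkpoint/a3c_poker")

print(stub.saved_at)
```

Because the last loop index, 999, is not a multiple of 200, the loop itself never saves the final weights, so a separate save after the loop is needed to capture them.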


In [None]:
#Saves a checkpoint of the trainer.
trainer.save("checkpoint/a3c_poker")