*Training Poker Using RLlib*

In [1]:
import sys
from poker_env import PokerEnv
from agents.random_policy import RandomActions
from agents.heuristic_policy import HeuristicPolicy
from agents.folder_policy import FolderPolicy
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.sac import SAC
from ray.rllib.algorithms.dqn import DQN
from ray.rllib.algorithms.appo import APPO
from ray.rllib.algorithms.r2d2 import R2D2
from gym import spaces
import mpu
import numpy as np
import ray
from ray import air, tune
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import PolicySpec
from ray.tune.registry import register_env
from ray.tune import CLIReporter, register_env

  from collections import Iterable


In RLlib, a policy-mapping function must be supplied to map each agent ID to the policy that agent uses. We also create and register the learning environment.

In [2]:
def select_policy(agent_id, episode, **kwargs):
    """Map an agent ID to the name of the policy that controls it.

    Agent IDs 0, 1, 2 and 3 are assigned "learned", "dqn", "folder" and
    "random" respectively; any other agent ID falls back to "Heuristic_100".
    ``episode`` and any extra keyword arguments are accepted but not used.
    """
    assignments = (
        (0, "learned"),
        (1, "dqn"),
        (2, "folder"),
        (3, "random"),
    )
    for known_id, policy_name in assignments:
        if agent_id == known_id:
            return policy_name
    return "Heuristic_100"

def env_creator(config):
    """Build a PokerEnv from ``config``, handing it the select_policy mapping."""
    return PokerEnv(select_policy, config)

# Register the environment factory under the name "poker". env_creator already
# has the (config) -> env signature, so it is passed directly instead of being
# wrapped in a lambda that only forwards its argument.
register_env("poker", env_creator)

The config describes all aspects of the training. A full list of the parameters is found here: https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm_config.py 

In this example, we have used a default config that runs the PPO algorithm. Other heuristic agents have been defined that will play the game with the single learning agent. 

In [3]:
# Observation space handed to the heuristic policy: a dict with two entries,
# each a Box with bounds [0, 1] and shape (24,).
heuristic_observation_space = spaces.Dict(
    {
        "hand": spaces.Box(low=0, high=1, shape=(24,)),
        "community": spaces.Box(low=0, high=1, shape=(24,)),
    }
)
# Discrete action space with three possible actions.
action_space = spaces.Discrete(3)

# Defines the learning model's architecture.
# dict.update() mutates in place and returns None, so its result must not be
# assigned: previously `model` was always None and `.training(model=model)`
# silently received no model settings.
# NOTE(review): the in-place update is kept on purpose because it also changes
# the defaults seen by anything built from MODEL_DEFAULTS afterwards -- confirm
# whether that global side effect is actually wanted.
MODEL_DEFAULTS.update({'fcnet_hiddens': [512, 512], 'fcnet_activation': 'relu', 'use_lstm': True})
# Hand the trainer its own copy of the updated settings.
model = dict(MODEL_DEFAULTS)

# Policies available to the agents, keyed by the names select_policy returns.
policy_specs = {
    # "random", "folder" and "Heuristic_100" use pre-defined policy classes and
    # are not listed in policies_to_train, so they do not learn.
    "random": PolicySpec(policy_class=RandomActions),
    # NOTE(review): no policy_class is given here; the DQN default config is
    # only passed as this policy's config -- confirm this is the intent.
    "dqn": PolicySpec(config=DQN.get_default_config()),
    "folder": PolicySpec(policy_class=FolderPolicy),
    "Heuristic_100": (HeuristicPolicy, heuristic_observation_space, action_space, {'difficulty': 1}),
    # NOTE(review): likewise no policy_class; the R2D2 default config is only
    # passed as this policy's config -- confirm this is the intent.
    "learned": PolicySpec(
        config=R2D2.get_default_config()
    ),
}

# The chain lives inside parentheses, so no line-continuation backslashes are
# needed between the builder calls.
config = (
    PPOConfig()
    # 8 rollout workers, one environment per worker.
    .rollouts(num_rollout_workers=8, num_envs_per_worker=1)
    .training(train_batch_size=4000, gamma=0.99, model=model, lr=0.0004)
    .environment(disable_env_checking=True)
    .multi_agent(
        policies=policy_specs,
        policy_mapping_fn=select_policy,
        policies_to_train=['learned', 'dqn'],
    )
    .resources(num_gpus=0)
    .framework('torch')
)
trainer = config.build(env="poker")

2022-11-10 16:32:01,836	INFO worker.py:1528 -- Started a local Ray instance.
[2m[36m(RolloutWorker pid=85625)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=85622)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=85620)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=85623)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=85621)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=85626)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=85624)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=85627)[0m   from collections import Iterable
2022-11-10 16:32:21,378	INFO trainable.py:164 -- Trainable.setup took 22.040 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


Training loop: each call to `train()` rolls out x timesteps (where x is `train_batch_size`). A weight update is then applied using the rollout data.

In [4]:
# Number of times trainer.train() is called.
num_iterations = 1000
for _ in range(num_iterations):
    trainer.train()

  """Adds a new policy to this Algorithm.


KeyboardInterrupt: 

In [5]:
# Save a checkpoint of the trainer; the value returned by save() is the cell's
# last expression and is therefore displayed as the cell output.
checkpoint_dir = "checkpoint/ppo_poker"
trainer.save(checkpoint_dir)



'checkpoint/ppo_poker/checkpoint_000719'