In [None]:
import sys
from poker_env import PokerEnv
from agents.random_policy import RandomActions
from agents.heuristic_policy import HeuristicPolicy
from ray.rllib.algorithms.ppo import PPOConfig
from gym import spaces
import mpu
import numpy as np
import ray
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import PolicySpec
from ray.tune.registry import register_env

In [None]:
def select_policy(agent_id, episode, **kwargs):
    """Return the id of the policy that controls `agent_id`.

    Used as RLlib's ``policy_mapping_fn`` (and also handed to ``PokerEnv``).
    ``episode`` and any extra keyword arguments (e.g. ``worker``) are part of
    that callback signature and are not used here.

    Seat 0 -> "learned", 1 -> "Heuristic_10", 2 -> "Heuristic_100",
    3 -> "random"; any other agent id falls back to "Heuristic_1000".
    """
    # Checked in order with ``==`` so the matching semantics mirror a plain
    # if/elif chain over the seat numbers.
    seat_to_policy = (
        (0, "learned"),
        (1, "Heuristic_10"),
        (2, "Heuristic_100"),
        (3, "random"),
    )
    for seat, policy_id in seat_to_policy:
        if agent_id == seat:
            return policy_id
    return "Heuristic_1000"

def env_creator(config):
    """Construct a ``PokerEnv`` for RLlib.

    Passes ``select_policy`` and the env config RLlib supplies straight
    through to ``PokerEnv``.
    """
    return PokerEnv(select_policy, config)


# Register under the name "poker" so the algorithm can be built with env="poker".
# The lambda looks up ``env_creator`` at call time, so redefining it in a later
# cell is picked up without re-registering.
register_env("poker", lambda env_config: env_creator(env_config))

In [3]:
# Spaces given explicitly to the heuristic opponents. The "learned" and "random"
# policies leave theirs unset, so RLlib infers them from the environment.
# NOTE(review): presumably the heuristics read a reduced (hand + community)
# observation rather than the env's full one — confirm against PokerEnv.
heuristic_observation_space = spaces.Dict({
    "hand": spaces.Box(0, 1, shape=(24, )),
    "community": spaces.Box(0, 1, shape=(24, )),
})
action_space = spaces.Discrete(3)

# Copy RLlib's model defaults and override only the fully-connected network.
# `MODEL_DEFAULTS.update(...)` must not be used here: dict.update mutates the
# shared global defaults in place and returns None (so `model` would be None).
model = {**MODEL_DEFAULTS, "fcnet_hiddens": [512, 512], "fcnet_activation": "relu"}

# Heuristic opponent policy id -> difficulty level passed in its policy config.
heuristic_difficulties = {"Heuristic_10": 0, "Heuristic_100": 1, "Heuristic_1000": 2}
heuristic_policies = {
    policy_id: PolicySpec(
        policy_class=HeuristicPolicy,
        observation_space=heuristic_observation_space,
        action_space=action_space,
        config={"difficulty": difficulty},
    )
    for policy_id, difficulty in heuristic_difficulties.items()
}

config = (
    PPOConfig()
    .rollouts(num_rollout_workers=8, num_envs_per_worker=1)
    .training(train_batch_size=4000, gamma=0.99, model=model, lr=0.0004)
    .environment(disable_env_checking=True)
    .multi_agent(
        policies={
            "random": PolicySpec(policy_class=RandomActions),
            **heuristic_policies,
            # No policy_class: RLlib uses the algorithm's default (PPO) policy.
            "learned": PolicySpec(config={}),
        },
        # Policy ids returned by select_policy must match the keys above.
        policy_mapping_fn=select_policy,
        # Only the learned policy is updated; opponents are not trained.
        policies_to_train=["learned"],
    )
    .resources(num_gpus=0)
    .framework("torch")
)
trainer = config.build(env="poker")


2022-11-08 14:21:25,710	INFO worker.py:1518 -- Started a local Ray instance.
[2m[36m(pid=30073)[0m   import imp
[2m[36m(pid=30074)[0m   import imp
[2m[36m(pid=30069)[0m   import imp
[2m[36m(pid=30072)[0m   import imp
[2m[36m(pid=30068)[0m   import imp
[2m[36m(pid=30071)[0m   import imp
[2m[36m(pid=30070)[0m   import imp
[2m[36m(pid=30075)[0m   import imp
[2m[36m(pid=30074)[0m   'nearest': pil_image.NEAREST,
[2m[36m(pid=30074)[0m   'bilinear': pil_image.BILINEAR,
[2m[36m(pid=30074)[0m   'bicubic': pil_image.BICUBIC,
[2m[36m(pid=30074)[0m   if hasattr(pil_image, 'HAMMING'):
[2m[36m(pid=30074)[0m   if hasattr(pil_image, 'BOX'):
[2m[36m(pid=30074)[0m   if hasattr(pil_image, 'LANCZOS'):
[2m[36m(pid=30069)[0m   'nearest': pil_image.NEAREST,
[2m[36m(pid=30069)[0m   'bilinear': pil_image.BILINEAR,
[2m[36m(pid=30069)[0m   'bicubic': pil_image.BICUBIC,
[2m[36m(pid=30069)[0m   if hasattr(pil_image, 'HAMMING'):
[2m[36m(pid=30069)[0m   if has

In [4]:
#!tensorboard --logdir=~/ray_results --host 0.0.0.0

In [5]:
# Run PPO training. Each iteration's result is inspected so progress is visible,
# and a checkpoint is written periodically so an interrupted run (as happened
# previously) still leaves recent weights on disk.
NUM_TRAIN_ITERATIONS = 1000
LOG_EVERY = 10
CHECKPOINT_EVERY = 50

for _ in range(NUM_TRAIN_ITERATIONS):
    result = trainer.train()
    # Global iteration count: keeps increasing across re-runs of this cell.
    iteration = result["training_iteration"]
    if iteration % LOG_EVERY == 0:
        # Per-policy reward stats can be empty before any episode completes.
        learned_reward = result.get("policy_reward_mean", {}).get("learned")
        print(
            f"iter {iteration:5d}  "
            f"episode_reward_mean={result.get('episode_reward_mean')}  "
            f"learned_reward_mean={learned_reward}"
        )
    if iteration % CHECKPOINT_EVERY == 0:
        trainer.save("checkpoint/ppo_poker")



KeyboardInterrupt: 

In [6]:
# Persist the algorithm state; the returned checkpoint directory is shown below.
checkpoint_dir = trainer.save("checkpoint/ppo_poker")
checkpoint_dir

'checkpoint/ppo_poker/checkpoint_000011'