In [None]:
import sys
from poker_env import PokerEnv
from agents.random_policy import RandomActions
from agents.heuristic_policy import HeuristicPolicy
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.a3c import A3C
from ray.rllib.algorithms.sac import SAC
from ray.rllib.algorithms.dqn import DQN
from gym import spaces
import mpu
import numpy as np
import ray
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import PolicySpec
from ray.tune.registry import register_env

In [None]:

def select_policy(agent_id, episode, **kwargs):
    if agent_id == 0:
        return "a3c"
    elif agent_id == 1:
        return "sac"
    elif agent_id == 2:
        return "dqn"
    elif agent_id == 3:
        return "ppo"
    return "learned4"

policy_strings = ['a3c', 'sac', 'dqn', 'ppo']

def env_creator(config):
    env = PokerEnv(select_policy, config)
    return env

register_env("poker", lambda config: env_creator(config))

In [13]:
#Same config as the checkpoint you want to restore.

model = MODEL_DEFAULTS.update({'fcnet_hiddens': [512, 512], 'fcnet_activation': 'relu'})

config = (
    PPOConfig()
    #Each rollout worker uses a single cpu
    .rollouts(num_rollout_workers=8, num_envs_per_worker=1)\
    .training(train_batch_size=4000, gamma=0.99, model=model, lr=0.0004)\
    .environment(disable_env_checking=True)\
    .multi_agent(
        policies={
            #These policies thave pre-definded polices that dont learn.
            "a3c": PolicySpec(config=A3C.get_default_config()),
            "sac": PolicySpec(config=SAC.get_default_config()),
            "dqn": PolicySpec(config=DQN.get_default_config()),
            #Passing nothing causes this agent to deafult to using a PPO policy
            "ppo": PolicySpec(
                config={}
            ),
        },
        policy_mapping_fn=select_policy,
        policies_to_train=policy_strings,
    )\
    .resources(num_gpus=0)\
    .framework('torch')
)
trainer = config.build(env="poker")
trainer.load_checkpoint('checkpoint/many_algos/checkpoint_000484/checkpoint-484')

[2m[36m(RolloutWorker pid=11180)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=11186)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=11183)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=11184)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=11179)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=11185)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=11182)[0m   from collections import Iterable
[2m[36m(RolloutWorker pid=11181)[0m   from collections import Iterable
2022-11-11 09:23:55,175	INFO trainable.py:162 -- Trainable.setup took 23.687 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [16]:
policies = []
for policy_string in policy_strings:
    policies.append(trainer.get_policy(policy_string))

  and should_run_async(code)


In [64]:
policy_returns = np.zeros(4)

env = PokerEnv(select_policy, {})
for i in range(10):
    state = env.reset()[env.current_actor]
    done = False
    agent_index = env.current_actor
    while done == False:
        s = np.array([np.concatenate([state['obs'], state['state']])])
        action = policies[agent_index].compute_actions(s, training=False)
        state_g = env.step({env.current_actor: action[0][0]})
        done = state_g[2]['__all__']
        state = state_g[0][env.current_actor]
    for i, policy_string in enumerate(policy_strings):
        policy_returns[i] += state_g[1][i]

for i, policy_string in enumerate(policy_strings):
    print(policy_string + ': returns ' + str(policy_returns[i]))

a3c: returns 77.0
sac: returns -28.0
dqn: returns -18.0
ppo: returns -31.0


In [None]:
env.step({2: 1})[2]['__all__']