In [40]:
import tensorflow.compat.v1 as tf

from open_spiel.python import policy
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import exploitability
from open_spiel.python.algorithms import policy_gradient
import policy_gradient as pt_policy_gradient



class PolicyGradientPolicies(policy.Policy):
  """Joint policy that delegates each decision to a per-player RL agent.

  Adapts a sequence of agents (one per player, indexed by player id) to the
  ``policy.Policy`` interface so they can be evaluated, e.g. with
  ``exploitability.exploitability``. Each agent must expose
  ``step(time_step, is_evaluation=True)`` returning an object with a
  ``probs`` attribute indexable by action.
  """

  def __init__(self, env, nfsp_policies):
    """Initializes the joint policy.

    Args:
      env: the ``rl_environment.Environment`` whose ``game`` is evaluated.
      nfsp_policies: sequence of agents indexed by player id. (The parameter
        name is kept for backward compatibility; any agent with a compatible
        ``step`` method works, e.g. policy-gradient agents.)
    """
    game = env.game
    # Derive the player set from the game instead of hard-coding two players.
    num_players = game.num_players()
    player_ids = list(range(num_players))
    super(PolicyGradientPolicies, self).__init__(game, player_ids)
    self._policies = nfsp_policies
    # Reusable observation buffer; only the acting player's entries are
    # refreshed on each query.
    self._obs = {
        "info_state": [None] * num_players,
        "legal_actions": [None] * num_players,
    }

  def action_probabilities(self, state, player_id=None):
    """Returns the acting agent's action distribution at ``state``.

    Args:
      state: game state exposing ``current_player``, ``legal_actions`` and
        ``information_state_tensor`` (presumably a ``pyspiel.State``).
      player_id: unused; the player to move is read from ``state``.

    Returns:
      A dict mapping each legal action to the probability the agent assigns.
    """
    cur_player = state.current_player()
    legal_actions = state.legal_actions(cur_player)

    self._obs["current_player"] = cur_player
    self._obs["info_state"][cur_player] = (
        state.information_state_tensor(cur_player))
    self._obs["legal_actions"][cur_player] = legal_actions

    # Wrap the observation in a TimeStep so the agent's regular step() API can
    # be reused; rewards/discounts/step_type are irrelevant for evaluation.
    info_state = rl_environment.TimeStep(
        observations=self._obs, rewards=None, discounts=None, step_type=None)

    probs = self._policies[cur_player].step(info_state, is_evaluation=True).probs
    return {action: probs[action] for action in legal_actions}


def test_tf(game, num_episodes, eval_every, loss_str):
  """Trains TensorFlow policy-gradient agents in self-play, tracking exploitability.

  Args:
    game: name of the game (e.g. 'kuhn_poker'), passed to
      ``rl_environment.Environment``.
    num_episodes: number of self-play training episodes.
    eval_every: exploitability is computed after every ``eval_every``
      completed training episodes.
    loss_str: one of "a2c", "rpg", "qpg", "rm".

  Returns:
    List of exploitability values, one per evaluation
    (``num_episodes // eval_every`` entries for positive ``eval_every``).
  """
  num_players = 2

  env_configs = {"players": num_players}
  env = rl_environment.Environment(game, **env_configs)
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  result = []
  # Build the agents in a fresh graph: repeated calls (e.g. several runs in a
  # loop) would otherwise keep adding variables to the shared default graph.
  # Ops created under Graph.as_default() are graph ops, so the compat.v1
  # Session/placeholder API also works when eager execution is enabled.
  with tf.Graph().as_default(), tf.Session() as sess:
    # pylint: disable=g-complex-comprehension
    agents = [
        policy_gradient.PolicyGradient(
            sess,
            idx,
            info_state_size,
            num_actions,
            loss_str=loss_str,
            hidden_layers_sizes=(128,)) for idx in range(num_players)
    ]
    expl_policies_avg = PolicyGradientPolicies(env, agents)

    sess.run(tf.global_variables_initializer())
    for ep in range(num_episodes):
      time_step = env.reset()
      while not time_step.last():
        player_id = time_step.observations["current_player"]
        agent_output = agents[player_id].step(time_step)
        time_step = env.step([agent_output.action])

      # Episode is over, step all agents with final info state.
      for agent in agents:
        agent.step(time_step)

      # Evaluate after training so the printed count equals the number of
      # episodes the agents have actually been trained on.
      if (ep + 1) % eval_every == 0:
        expl = exploitability.exploitability(env.game, expl_policies_avg)
        result.append(expl)
        print(ep + 1, expl)
  return result

        
def test_pt(game, num_episodes, eval_every, loss_str):
  """Trains ``pt_policy_gradient`` agents in self-play, tracking exploitability.

  Uses the local ``policy_gradient`` module imported as
  ``pt_policy_gradient`` (presumably a PyTorch port — confirm), with the same
  environment, agent sizes and evaluation schedule as the TF experiment.

  Args:
    game: name of the game (e.g. 'kuhn_poker'), passed to
      ``rl_environment.Environment``.
    num_episodes: number of self-play training episodes.
    eval_every: exploitability is computed after every ``eval_every``
      completed training episodes.
    loss_str: one of "a2c", "rpg", "qpg", "rm".

  Returns:
    List of exploitability values, one per evaluation
    (``num_episodes // eval_every`` entries for positive ``eval_every``).
  """
  num_players = 2

  env_configs = {"players": num_players}
  env = rl_environment.Environment(game, **env_configs)
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  # pylint: disable=g-complex-comprehension
  agents = [
      pt_policy_gradient.PolicyGradient(
          idx,
          info_state_size,
          num_actions,
          loss_str=loss_str,
          hidden_layers_sizes=(128,)) for idx in range(num_players)
  ]
  expl_policies_avg = PolicyGradientPolicies(env, agents)
  result = []
  for ep in range(num_episodes):
    time_step = env.reset()
    while not time_step.last():
      player_id = time_step.observations["current_player"]
      agent_output = agents[player_id].step(time_step)
      time_step = env.step([agent_output.action])

    # Episode is over, step all agents with final info state.
    for agent in agents:
      agent.step(time_step)

    # Evaluate after training so the printed count equals the number of
    # episodes the agents have actually been trained on.
    if (ep + 1) % eval_every == 0:
      expl = exploitability.exploitability(env.game, expl_policies_avg)
      result.append(expl)
      print(ep + 1, expl)
  return result

In [44]:
# Training budget, evaluation period and loss for the Kuhn poker experiments.
NUM_EPISODES = 10 ** 5
EVAL_EVERY = 10 ** 4
LOSS_STR = 'a2c'

In [67]:
# Kuhn poker, TensorFlow agents: exploitability curve.
tf_res = test_tf(
    'kuhn_poker',
    num_episodes=NUM_EPISODES,
    eval_every=EVAL_EVERY,
    loss_str=LOSS_STR,
)
tf_res

10000 0.37900984807950605
20000 0.32177749715166226
30000 0.2988562104484438
40000 0.29518197830443826
50000 0.2915205565020096
60000 0.2893561715284406
70000 0.28679159436437335
80000 0.2845466054763692
90000 0.281548425862029
100000 0.2785872204966148


[0.37900984807950605,
 0.32177749715166226,
 0.2988562104484438,
 0.29518197830443826,
 0.2915205565020096,
 0.2893561715284406,
 0.28679159436437335,
 0.2845466054763692,
 0.281548425862029,
 0.2785872204966148]

In [63]:
# Kuhn poker, pt_policy_gradient agents: exploitability curve.
pt_res = test_pt(
    'kuhn_poker',
    num_episodes=NUM_EPISODES,
    eval_every=EVAL_EVERY,
    loss_str=LOSS_STR,
)
pt_res

10000 0.3620111648885066
20000 0.31271177510954773
30000 0.29973228050132517
40000 0.29117547732525173
50000 0.2869790826510561
60000 0.2802378814424329
70000 0.27519049690420605
80000 0.2703495365791758
90000 0.26785184518091587
100000 0.26297236263585877


[0.3620111648885066,
 0.31271177510954773,
 0.29973228050132517,
 0.29117547732525173,
 0.2869790826510561,
 0.2802378814424329,
 0.27519049690420605,
 0.2703495365791758,
 0.26785184518091587,
 0.26297236263585877]

In [60]:
# Settings for the Leduc poker runs below (tweak here to change their budget).
NUM_EPISODES = 10 ** 5
EVAL_EVERY = 10 ** 4
LOSS_STR = 'a2c'

In [70]:
# Repeat the Leduc experiment several times (runs differ from one another)
# and keep every curve instead of overwriting it, so the runs can be compared
# or averaged afterwards. `tf_res` still holds the last run.
tf_leduc_runs = []
for _ in range(3):
    tf_res = test_tf('leduc_poker', NUM_EPISODES, EVAL_EVERY, LOSS_STR)
    tf_leduc_runs.append(tf_res)
    print(tf_res)

10000 2.2392726266077894
20000 2.239777544707624
30000 2.2255411968226078
40000 2.1410972362134126
50000 2.115922933413102
60000 2.05420737023231
70000 1.98743257776856
80000 2.024815840747584
90000 1.911366473818718
100000 1.8708556764754363
[2.2392726266077894, 2.239777544707624, 2.2255411968226078, 2.1410972362134126, 2.115922933413102, 2.05420737023231, 1.98743257776856, 2.024815840747584, 1.911366473818718, 1.8708556764754363]
10000 2.595945624369934
20000 2.3978561096544
30000 2.2754552938515595
40000 2.292387165236848
50000 2.161305849638241
60000 2.1529491120985345
70000 2.066636061867233
80000 2.05650548975179
90000 2.0750003303561826
100000 2.029161603889028
[2.595945624369934, 2.3978561096544, 2.2754552938515595, 2.292387165236848, 2.161305849638241, 2.1529491120985345, 2.066636061867233, 2.05650548975179, 2.0750003303561826, 2.029161603889028]
10000 2.369499491805596
20000 2.1855354212742903
30000 2.2028650242181054
40000 2.1622182114165285
50000 2.165994703512826
60000 2.0

In [71]:
# Repeat the Leduc experiment several times (runs differ from one another)
# and keep every curve instead of overwriting it, so the runs can be compared
# or averaged afterwards. `pt_res` still holds the last run.
pt_leduc_runs = []
for _ in range(3):
    pt_res = test_pt('leduc_poker', NUM_EPISODES, EVAL_EVERY, LOSS_STR)
    pt_leduc_runs.append(pt_res)
    print(pt_res)

10000 2.2749961873716575
20000 2.2666778677428434
30000 2.165287145100697
40000 2.079457954314968
50000 2.0918588635337105
60000 2.044400192524746
70000 2.027876559334195
80000 2.030867540135744
90000 1.9833670682156674
100000 1.9716290231472917
[2.2749961873716575, 2.2666778677428434, 2.165287145100697, 2.079457954314968, 2.0918588635337105, 2.044400192524746, 2.027876559334195, 2.030867540135744, 1.9833670682156674, 1.9716290231472917]
10000 2.2145492130950784
20000 2.1373351079906295
30000 2.1402202167504245
40000 2.1780810582849908
50000 2.0549493290473118
60000 2.0663108221192696
70000 2.007294744892595
80000 1.913486012680631
90000 1.9452607731626874
100000 1.8692626557938719
[2.2145492130950784, 2.1373351079906295, 2.1402202167504245, 2.1780810582849908, 2.0549493290473118, 2.0663108221192696, 2.007294744892595, 1.913486012680631, 1.9452607731626874, 1.8692626557938719]
10000 2.5407027568810427
20000 2.4192188835878694
30000 2.3870904835765687
40000 2.3705128244737397
50000 2.37