In [3]:
from open_spiel.python import policy
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import exploitability
import pyspiel

import eva

class JointPolicy(policy.Policy):
  """Adapter that exposes a list of per-player agents as one `policy.Policy`.

  Each query is forwarded to the agent stored at the index of the player who
  is about to act in the given state.
  """

  def __init__(self, env, agents):
    """Builds the joint policy.

    Args:
      env: Object exposing a `game` attribute; that game is handed to the
        `policy.Policy` base class.
      agents: Sequence indexed by player id. Every entry must provide an
        `action_probabilities(state)` method whose result can be indexed by
        action.
    """
    underlying_game = env.game
    seat_ids = [seat for seat in range(len(agents))]
    super(JointPolicy, self).__init__(underlying_game, seat_ids)
    self._agents = agents

  def action_probabilities(self, state, player_id=None):
    """Returns the acting agent's probabilities over the legal actions.

    Args:
      state: Game state; must provide `current_player()` and
        `legal_actions(player)`.
      player_id: Accepted for interface compatibility but not used; the acting
        player is always read from `state`.

    Returns:
      A dict mapping each legal action of the current player to the value the
      corresponding agent reports for that action.
    """
    acting_player = state.current_player()
    allowed = state.legal_actions(acting_player)
    agent_probs = self._agents[acting_player].action_probabilities(state)
    restricted = {}
    # Only legal actions are kept; everything else the agent reports is dropped.
    for action in allowed:
      restricted[action] = agent_probs[action]
    return restricted


def pt_main(game_name, num_episodes, eval_every=1000):
  """Trains one EVA agent per player through self-play and tracks NashConv.

  Args:
    game_name: Name of the game handed to `rl_environment.Environment`.
    num_episodes: Number of self-play episodes to run.
    eval_every: Evaluation period, in episodes. NashConv of the joint policy
      is computed whenever the 1-based episode index is a multiple of this
      value. Defaults to 1000, the interval that used to be hard-coded.

  Returns:
    A list with one NashConv value per evaluation, in chronological order.

  Raises:
    ValueError: If `eval_every` is not positive.
  """
  # Fail fast, before the environment and the agents are constructed.
  if eval_every <= 0:
    raise ValueError("eval_every must be positive, got %r" % (eval_every,))

  env_configs = {"players": 2}
  env = rl_environment.Environment(game_name, **env_configs)
  num_players = env.num_players
  num_actions = env.action_spec()["num_actions"]
  # First entry of the info-state spec; presumably the length of the flat
  # info-state vector -- TODO confirm against rl_environment.
  state_size = env.observation_spec()["info_state"][0]

  # One independent agent per seat, all sharing the same hyperparameters.
  eva_agents = []
  for player in range(num_players):
    eva_agents.append(
        eva.EVAAgent(
            env,
            player,
            state_size,
            num_actions,
            batch_size=128,
            learning_rate=0.01,
            mixing_parameter=0.5,
            memory_capacity=int(1e6),
            discount_factor=1.0,
            update_target_network_every=1000,
            epsilon_start=0.06,
            epsilon_end=0.001,
            epsilon_decay_duration=int(1e6)))

  joint_policy = JointPolicy(env, eva_agents)

  result = []
  for episode in range(num_episodes):
    # NOTE: the evaluation runs *before* the episode with the same 1-based
    # index is played, so "Episode:N" reflects N - 1 completed episodes.
    if (episode + 1) % eval_every == 0:
      conv = exploitability.nash_conv(env.game, joint_policy)
      result.append(conv)
      print("Episode:%s - NashConv: %s" %(episode+1, conv))

    time_step = env.reset()
    while not time_step.last():
      # Only the player whose turn it is acts on a non-terminal step.
      current_player = time_step.observations["current_player"]
      current_agent = eva_agents[current_player]
      step_out = current_agent.step(time_step)
      time_step = env.step([step_out.action])

    # Every agent steps once on the terminal time step; presumably this lets
    # each of them record the final outcome -- verify against eva.EVAAgent.
    for agent in eva_agents:
      agent.step(time_step)

  return result


In [4]:
# Repeat the Leduc poker experiment three times, printing each run's list of
# NashConv values and collecting all of them in `pt_result`.
pt_result = []
for _ in range(3):
    result = pt_main(game_name="leduc_poker", num_episodes=10000)
    print(result)
    pt_result += [result]

Episode:1000 - NashConv: 4.816378315359527
Episode:2000 - NashConv: 5.183435771931083
Episode:3000 - NashConv: 4.981805847350352
Episode:4000 - NashConv: 4.959138290012623
Episode:5000 - NashConv: 4.987403549910527
Episode:6000 - NashConv: 4.894065221510903
Episode:7000 - NashConv: 5.087053143228525
Episode:8000 - NashConv: 5.001063352058948
Episode:9000 - NashConv: 4.574781550661821
Episode:10000 - NashConv: 4.45628410467555
[4.816378315359527, 5.183435771931083, 4.981805847350352, 4.959138290012623, 4.987403549910527, 4.894065221510903, 5.087053143228525, 5.001063352058948, 4.574781550661821, 4.45628410467555]
Episode:1000 - NashConv: 4.762049732195072
Episode:2000 - NashConv: 5.220065921048683
Episode:3000 - NashConv: 5.434335654150852
Episode:4000 - NashConv: 5.148536129794721
Episode:5000 - NashConv: 4.927764855547408
Episode:6000 - NashConv: 4.880683280302989
Episode:7000 - NashConv: 4.770471931455553
Episode:8000 - NashConv: 4.908578138855047
Episode:9000 - NashConv: 4.944229230