In [1]:
from open_spiel.python import policy
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import exploitability
import pyspiel

import eva

class JointPolicy(policy.Policy):
  """Joint policy that delegates each decision to a per-player agent.

  Wraps a sequence of agents so that they can be handed, as one policy
  object, to evaluation routines such as `exploitability.nash_conv`.
  """

  def __init__(self, agents):
    """Stores the per-player agents.

    Args:
      agents: Sequence indexed by player id. Each element must expose an
        `action_probabilities(state)` method.
    """
    # NOTE(review): the base-class `__init__` is not invoked, so attributes it
    # would normally set are absent on this object -- confirm that no consumer
    # of this policy relies on them.
    self._agents = agents

  def action_probabilities(self, state, player_id=None):
    """Returns the action probabilities of the agent acting at `state`.

    Args:
      state: Game state to query. It must provide `current_player()` when
        `player_id` is not given.
      player_id: Optional id of the player whose agent should be queried.
        When None, the state's current player is used.

    Returns:
      Whatever the selected agent's `action_probabilities(state)` returns.
    """
    # Honour an explicit player id. Falling back to `current_player()`
    # unconditionally would pick the wrong agent whenever the current player
    # is a special (negative) id, because negative indices silently wrap.
    acting_player = state.current_player() if player_id is None else player_id
    return self._agents[acting_player].action_probabilities(state)


def pt_main(game_name, num_episodes, eval_every=1000):
  """Trains one EVA agent per player through self-play and tracks NashConv.

  Args:
    game_name: Name of the game handed to `rl_environment.Environment`.
    num_episodes: Number of self-play episodes to run.
    eval_every: NashConv of the joint policy is computed and printed each time
      this many further episodes have been completed. Must be positive.

  Returns:
    List with one NashConv value per evaluation, in chronological order.

  Raises:
    ValueError: If `eval_every` is not positive.
  """
  if eval_every <= 0:
    raise ValueError("eval_every must be positive, got %s" % eval_every)

  env_configs = {"players": 2}
  env = rl_environment.Environment(game_name, **env_configs)
  num_players = env.num_players
  num_actions = env.action_spec()["num_actions"]
  # First entry of the "info_state" observation spec is used as the input size.
  state_size = env.observation_spec()["info_state"][0]

  eva_agents = [
      eva.EVAAgent(
          env,
          player,
          state_size,
          num_actions,
          batch_size=128,
          learning_rate=0.01,
          mixing_parameter=0.5,
          memory_capacity=int(1e6),
          discount_factor=1.0,
          update_target_network_every=1000,
          epsilon_start=0.06,
          epsilon_end=0.001,
          epsilon_decay_duration=int(1e6))
      for player in range(num_players)
  ]

  joint_policy = JointPolicy(eva_agents)

  result = []
  for episode in range(num_episodes):
    time_step = env.reset()
    while not time_step.last():
      current_player = time_step.observations["current_player"]
      current_agent = eva_agents[current_player]
      step_out = current_agent.step(time_step)
      time_step = env.step([step_out.action])

    # Every agent is stepped once more with the terminal time step
    # (presumably so it can record the episode's outcome -- TODO confirm
    # against the EVAAgent implementation).
    for agent in eva_agents:
      agent.step(time_step)

    # Evaluate only after the episode has been played, so the reported
    # episode count matches the amount of training the policy has received
    # and the final episode is reflected in the last evaluation.
    completed_episodes = episode + 1
    if completed_episodes % eval_every == 0:
      conv = exploitability.nash_conv(env.game, joint_policy)
      result.append(conv)
      print("Episode:%s - NashConv: %s" % (completed_episodes, conv))

  return result


In [3]:
# Repeat the Kuhn poker experiment three times, keeping each run's NashConv
# curve so the runs can be compared afterwards.
pt_result = []
for _ in range(3):
    result = pt_main(game_name='kuhn_poker', num_episodes=20000)
    pt_result.append(result)
    print(result)

Episode:1000 - NashConv: 0.9117420066336805
Episode:2000 - NashConv: 0.7940201644225783
Episode:3000 - NashConv: 0.7288770716970222
Episode:4000 - NashConv: 0.6910837244776069
Episode:5000 - NashConv: 0.6371336078693635
Episode:6000 - NashConv: 0.5907259362084191
Episode:7000 - NashConv: 0.5574534220973357
Episode:8000 - NashConv: 0.535207538216038
Episode:9000 - NashConv: 0.5255008743026045
Episode:10000 - NashConv: 0.5212872126513359
Episode:11000 - NashConv: 0.5251768393628992
Episode:12000 - NashConv: 0.5195450900686065
Episode:13000 - NashConv: 0.5142098394847641
Episode:14000 - NashConv: 0.5130339942029863
Episode:15000 - NashConv: 0.5081925707640036
Episode:16000 - NashConv: 0.5081132336912204
Episode:17000 - NashConv: 0.502744607793246
Episode:18000 - NashConv: 0.5031543143834853
Episode:19000 - NashConv: 0.5098682246087921
Episode:20000 - NashConv: 0.5149615066565998
[0.9117420066336805, 0.7940201644225783, 0.7288770716970222, 0.6910837244776069, 0.6371336078693635, 0.59072593

KeyboardInterrupt: 

In [None]:
# Display the per-run NashConv lists collected in `pt_result`.
pt_result