In [None]:
from open_spiel.python import policy
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import exploitability
from open_spiel.python.pytorch import armac

In [None]:
class ARMACPolicies(policy.Policy):
  """Joint policy, backed by one ARMAC agent per player, to be evaluated.

  Wraps a collection of per-player agents behind the `policy.Policy`
  interface so it can be handed to evaluation routines that query
  `action_probabilities(state)`.
  """

  def __init__(self, env, armac_policies, mode):
    """Initializes the joint policy.

    Args:
      env: environment object; only its `game` attribute is read.
      armac_policies: collection of agents indexed by player id. Each agent
        must provide `temp_mode_as(mode)` and `step(time_step,
        is_evaluation=True)` returning an object with a `probs` attribute.
      mode: value forwarded to each agent's `temp_mode_as` while querying it.
    """
    game = env.game
    # One slot per agent rather than a hard-coded two-player layout.
    num_players = len(armac_policies)
    player_ids = list(range(num_players))
    super(ARMACPolicies, self).__init__(game, player_ids)
    self._policies = armac_policies
    self._mode = mode
    # Observation buffer reused across calls; only the acting player's slots
    # are refreshed before each query.
    self._obs = {
        "info_state": [None] * num_players,
        "legal_actions": [None] * num_players,
    }

  def action_probabilities(self, state, player_id=None):
    """Returns the acting player's policy at `state`.

    Args:
      state: game state exposing `current_player()`, `legal_actions(player)`
        and `information_state_tensor(player)`.
      player_id: optional id of the player to query. When None, the state's
        current player is used.

    Returns:
      A dict mapping each legal action to the probability taken from the
      `probs` of the queried agent's step output.
    """
    # Honor an explicit player_id; fall back to whoever is to move.
    cur_player = state.current_player() if player_id is None else player_id
    legal_actions = state.legal_actions(cur_player)

    self._obs["current_player"] = cur_player
    self._obs["info_state"][cur_player] = (
        state.information_state_tensor(cur_player))
    self._obs["legal_actions"][cur_player] = legal_actions

    # Only the observations are populated; rewards, discounts and step type
    # are passed as None for this evaluation-only query.
    info_state = rl_environment.TimeStep(
        observations=self._obs, rewards=None, discounts=None, step_type=None)

    agent = self._policies[cur_player]
    with agent.temp_mode_as(self._mode):
      p = agent.step(info_state, is_evaluation=True).probs
    return {action: p[action] for action in legal_actions}

def pt_main(game,
            env_config,
            num_train_episodes,
            eval_every,
            hidden_layers_sizes,
            replay_buffer_capacity,
            reservoir_buffer_capacity,
            anticipatory_param):
  """Trains one ARMAC agent per player and tracks exploitability.

  Args:
    game: game identifier forwarded to `rl_environment.Environment`.
    env_config: dict of keyword arguments forwarded to
      `rl_environment.Environment`.
    num_train_episodes: number of training episodes to play; also used as the
      agents' `epsilon_decay_duration`.
    eval_every: evaluate (and print) every this many episodes.
    hidden_layers_sizes: iterable of layer sizes; each entry is cast to int.
    replay_buffer_capacity: forwarded to each agent as a keyword argument.
    reservoir_buffer_capacity: forwarded positionally to each agent.
    anticipatory_param: forwarded positionally to each agent.

  Returns:
    List of exploitability values of the agents' joint average policy, one
    entry per evaluation, in evaluation order.
  """
  # `nfsp_pt` is not imported in the visible part of the file; import it here
  # so the function does not depend on another notebook cell having run.
  # NOTE(review): assumes ARMAC agents accept the NFSP MODE enum - confirm.
  from open_spiel.python.pytorch import nfsp as nfsp_pt  # pylint: disable=g-import-not-at-top

  # Use the `env_config` argument (the original referenced an undefined
  # `env_configs`).
  env = rl_environment.Environment(game, **env_config)
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]
  # Take the player count from the environment instead of an undefined name.
  num_players = env.num_players

  hidden_layers_sizes = [int(size) for size in hidden_layers_sizes]
  kwargs = {
      "replay_buffer_capacity": replay_buffer_capacity,
      "epsilon_decay_duration": num_train_episodes,
      "epsilon_start": 0.06,
      "epsilon_end": 0.001,
  }
  expl_list = []
  agents = [
      armac.ARMAC(idx, info_state_size, num_actions, hidden_layers_sizes,
                  reservoir_buffer_capacity, anticipatory_param,
                  **kwargs) for idx in range(num_players)
  ]
  expl_policies_avg = ARMACPolicies(env, agents, nfsp_pt.MODE.average_policy)
  for ep in range(num_train_episodes):
    # Periodic evaluation happens before the episode at this index is played.
    if (ep + 1) % eval_every == 0:
      losses = [agent.loss for agent in agents]
      print("Losses: %s" %losses)
      expl = exploitability.exploitability(env.game, expl_policies_avg)
      expl_list.append(expl)
      print("[%s] Exploitability AVG %s" %(ep + 1, expl))
      print("_____________________________________________")
    time_step = env.reset()
    while not time_step.last():
      # Only the player to move acts; its action is sent as a one-item list.
      player_id = time_step.observations["current_player"]
      agent_output = agents[player_id].step(time_step)
      action_list = [agent_output.action]
      time_step = env.step(action_list)
    # Episode is over, step all agents with final info state.
    for agent in agents:
      agent.step(time_step)
  return expl_list