In [None]:
!pip install --upgrade open_spiel

from open_spiel.python import policy
from open_spiel.python.algorithms import deep_cfr_tf2
from open_spiel.python.algorithms import expected_game_score
from open_spiel.python.algorithms import exploitability
import pyspiel



In [None]:
# Deep CFR on Kuhn poker, trained in chunks so NashConv / exploitability can be
# tracked (and later plotted) as training progresses.
#
# The solver is built with num_iterations=num_iter_each, and the training loop
# calls solve() num_iterations // num_iter_each times, giving num_iterations CFR
# iterations in total. Setting num_iter_each = num_iterations (400) turns this
# into a single solve() call, matching the example on GitHub.
num_iterations = 400  # total CFR iterations across all solve() calls
num_iter_each = 4  # CFR iterations per solve() call, i.e. evaluate every 4 iterations
num_traversal = 40  # passed to the solver as num_traversals

# The loop runs num_iterations // num_iter_each chunks; if these are not
# multiples, the leftover iterations would be silently skipped.
assert num_iterations % num_iter_each == 0, (
    "num_iterations must be a multiple of num_iter_each")

# NOTE(review): no random seed is set anywhere in this notebook, so results
# presumably differ run to run; consider seeding numpy and TensorFlow here.
game = pyspiel.load_game('kuhn_poker')
deep_cfr_solver = deep_cfr_tf2.DeepCFRSolver(
    game,
    policy_network_layers=(16,),
    advantage_network_layers=(16,),
    num_iterations=num_iter_each,  # per solve() call, not the overall total
    num_traversals=num_traversal,
    learning_rate=1e-3,
    batch_size_advantage=128,
    batch_size_strategy=1024,
    memory_capacity=int(1e7),  # a capacity is a count, so pass an int rather than the float 1e7
    policy_network_train_steps=400,
    advantage_network_train_steps=20,
    reinitialize_advantage_networks=False,
    infer_device="cpu",
    train_device="cpu")

# One record per solve() chunk, so the learning curve can be plotted afterwards
# (e.g. build a DataFrame from `history` and plot nash_conv vs cfr_iterations).
history = []
# Set to True to print per-player advantage losses, buffer sizes and the
# expected game value of the current average policy after every chunk.
verbose = False

for i in range(num_iterations // num_iter_each):
  cfr_iterations_done = (i + 1) * num_iter_each
  print(f"iteration: {i} (CFR iterations {i * num_iter_each + 1}-{cfr_iterations_done})")

  # Each call runs num_iter_each more CFR iterations on the same solver object,
  # so training accumulates across calls (consistent with the falling NashConv
  # in the recorded output). NOTE: solve() returns a policy loss on every call,
  # so the policy network is presumably retrained per chunk -- chunked runs do
  # more policy-network updates than one solve() over all iterations.
  _, advantage_losses, policy_loss = deep_cfr_solver.solve()
  if verbose:
    for player, losses in advantage_losses.items():
      print("Advantage for player {}: {}".format(
          player, losses[:2] + ["..."] + losses[-2:]))
      print("Advantage Buffer Size for player {}: {}".format(
          player, len(deep_cfr_solver.advantage_buffers[player])))
    print("Strategy Buffer Size: {}".format(
        len(deep_cfr_solver.strategy_buffer)))
  print("Final policy loss: {}".format(policy_loss))

  # Snapshot the solver's current action probabilities as a tabular policy so
  # it can be evaluated exactly.
  average_policy = policy.tabular_policy_from_callable(
      game, deep_cfr_solver.action_probabilities)

  # For this two-player zero-sum game exploitability is NashConv / 2 (the
  # recorded output shows exactly half), so one of these could be dropped if
  # evaluation time ever matters.
  conv = exploitability.nash_conv(game, average_policy)
  exp = exploitability.exploitability(game, average_policy)
  print(f"NashConv: {conv}")
  print(f"Exploitability: {exp}")

  if verbose:
    average_policy_values = expected_game_score.policy_value(
        game.new_initial_state(), [average_policy] * game.num_players())
    for player, value in enumerate(average_policy_values):
      print("Computed player {} value: {}".format(player, value))

  history.append({
      "cfr_iterations": cfr_iterations_done,
      "policy_loss": policy_loss,
      "nash_conv": conv,
      "exploitability": exp,
  })
  print()

iteration: 0
Final policy loss: 0.03564165532588959
NashConv: 0.8016757269971324
Exploitability: 0.4008378634985662

iteration: 1
Final policy loss: 0.044525474309921265
NashConv: 0.558005497071003
Exploitability: 0.2790027485355015

iteration: 2
Final policy loss: 0.04287794604897499
NashConv: 0.47397062391173056
Exploitability: 0.23698531195586528

iteration: 3
Final policy loss: 0.047114867717027664
NashConv: 0.3981467962463053
Exploitability: 0.19907339812315264

iteration: 4
Final policy loss: 0.05440378189086914
NashConv: 0.3338458445140342
Exploitability: 0.1669229222570171

iteration: 5
Final policy loss: 0.06457510590553284
NashConv: 0.25687932256799256
Exploitability: 0.12843966128399628

iteration: 6
Final policy loss: 0.08764326572418213
NashConv: 0.2727949743483316
Exploitability: 0.1363974871741658

iteration: 7
Final policy loss: 0.08211521804332733
NashConv: 0.22686152826079997
Exploitability: 0.11343076413039999

iteration: 8
Final policy loss: 0.07931879162788391
Nash