In [66]:
import tensorflow.compat.v1 as tf
from open_spiel.python.algorithms import policy_gradient
from open_spiel.python.algorithms import dqn
import policy_gradient as pt_policy_gradient
from open_spiel.python.environments import catch

def _eval_agent(env, agent, num_episodes):
  """Evaluates `agent` for `num_episodes`."""
  rewards = 0.0
  for _ in range(num_episodes):
    time_step = env.reset()
    episode_reward = 0
    while not time_step.last():
      agent_output = agent.step(time_step, is_evaluation=True)
      time_step = env.step([agent_output.action])
      episode_reward += time_step.rewards[0]
    rewards += episode_reward
  return rewards / num_episodes


def test_tf(num_episodes, eval_interval, algorithm):
  env = catch.Environment()
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  train_episodes = num_episodes

  result = []
  with tf.Session() as sess:
    if algorithm in {"rpg", "qpg", "rm", "a2c"}:
      agent = policy_gradient.PolicyGradient(
          sess,
          player_id=0,
          info_state_size=info_state_size,
          num_actions=num_actions,
          loss_str=algorithm,
          hidden_layers_sizes=[128, 128],
          batch_size=128,
          entropy_cost=0.01,
          critic_learning_rate=0.1,
          pi_learning_rate=0.1,
          num_critic_before_pi=3)
    elif algorithm == "dqn":
      agent = dqn.DQN(
          sess,
          player_id=0,
          state_representation_size=info_state_size,
          num_actions=num_actions,
          learning_rate=1e-3,
          replay_buffer_capacity=10000,
          hidden_layers_sizes=[32, 32],
          epsilon_decay_duration=2000,  # 10% total data
          update_target_network_every=250)
    else:
      raise ValueError("Algorithm not implemented!")

    sess.run(tf.global_variables_initializer())

    # Train agent
    for ep in range(train_episodes):
      time_step = env.reset()
      while not time_step.last():
        agent_output = agent.step(time_step)
        action_list = [agent_output.action]
        time_step = env.step(action_list)
      # Episode is over, step agent with final info state.
      agent.step(time_step)

      if ep and ep % eval_interval == 0:
        avg_return = _eval_agent(env, agent, 100)
        result.append(avg_return)
        #print(f"{ep}:{avg_return}")
  return result

def test_pt(num_episodes, eval_interval, algorithm):
  """Trains a DQN agent in the catch environment."""
  env = catch.Environment()
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  train_episodes = num_episodes

  result = []
  if algorithm in {"rpg", "qpg", "rm", "a2c"}:
    agent = pt_policy_gradient.PolicyGradient(
        player_id=0,
        info_state_size=info_state_size,
        num_actions=num_actions,
        loss_str=algorithm,
        hidden_layers_sizes=[128, 128],
        batch_size=128,
        entropy_cost=1e-3,
        critic_learning_rate=0.1,
        pi_learning_rate=0.1,
        num_critic_before_pi=3)
  else:
    raise ValueError("Algorithm not implemented!")
  
  # Train agent
  for ep in range(train_episodes):
    time_step = env.reset()
    while not time_step.last():
      agent_output = agent.step(time_step)
      action_list = [agent_output.action]
      time_step = env.step(action_list)
    # Episode is over, step agent with final info state.
    agent.step(time_step)
  
    if ep and ep % eval_interval == 0:
      avg_return = _eval_agent(env, agent, 100)
      result.append(avg_return)
      #print(f"{ep}:{avg_return}")
  return result

In [64]:
# Experiment settings read by the run cells below.
ALGORITHM = "a2c"
NUM_EPISODES = 100_000  # training episodes per run
EVAL_INTERVAL = 1_000   # training episodes between evaluations

In [65]:
# Train and evaluate the TensorFlow agent three times with the current
# NUM_EPISODES / EVAL_INTERVAL / ALGORITHM settings, printing each run's list
# of evaluation returns.
for _ in range(3):
    # `res_tf` is rebound on every iteration, so only the last run's list
    # remains available after the loop.
    res_tf = test_tf(NUM_EPISODES, EVAL_INTERVAL, ALGORITHM)
    print(res_tf)

[-0.56, -0.52, -0.52, -0.36, -0.44, -0.52, -0.46, -0.46, -0.52, -0.44, -0.48, -0.62, -0.4, -0.62, -0.46, -0.44, -0.56, -0.34, -0.48, -0.4, -0.4, -0.36, -0.44, -0.48, -0.48, -0.52, -0.52, -0.48, -0.34, -0.44, -0.44, -0.46, -0.54, -0.5, -0.3, -0.52, -0.36, -0.44, -0.38, -0.26, -0.46, -0.38, -0.42, -0.38, -0.24, -0.26, -0.32, -0.2, -0.42, -0.4, -0.4, -0.34, -0.28, -0.4, -0.2, -0.34, -0.22, -0.3, -0.14, -0.14, -0.3, -0.24, -0.34, -0.02, -0.14, -0.06, -0.16, -0.24, -0.1, -0.1, -0.04, 0.1, -0.1, 0.18, 0.12, -0.06, 0.16, 0.14, 0.16, 0.02, 0.0, 0.16, 0.28, 0.24, 0.36, 0.3, 0.18, 0.4, 0.28, 0.12, 0.28, 0.22, 0.26, 0.44, 0.44, 0.4, 0.58, 0.6, 0.42]
[-0.5, -0.68, -0.68, -0.58, -0.48, -0.4, -0.42, -0.5, -0.18, -0.3, -0.48, -0.58, -0.54, -0.52, -0.46, -0.56, -0.6, -0.52, -0.4, -0.5, -0.58, -0.6, -0.5, -0.58, -0.42, -0.58, -0.54, -0.54, -0.28, -0.56, -0.48, -0.26, -0.52, -0.54, -0.46, -0.52, -0.32, -0.48, -0.36, -0.32, -0.38, -0.28, -0.52, -0.46, -0.46, -0.38, -0.38, -0.4, -0.38, -0.2, -0.34, -0.44,

In [67]:
# Train and evaluate the `pt_policy_gradient` agent three times with the
# current NUM_EPISODES / EVAL_INTERVAL / ALGORITHM settings, printing each
# run's list of evaluation returns.
for _ in range(3):
    # `res_pt` is rebound on every iteration, so only the last run's list
    # remains available after the loop.
    res_pt = test_pt(NUM_EPISODES, EVAL_INTERVAL, ALGORITHM)
    print(res_pt)

[-0.64, -0.48, -0.72, -0.44, -0.56, -0.62, -0.52, -0.44, -0.54, -0.48, -0.52, -0.56, -0.64, -0.44, -0.38, -0.56, -0.5, -0.42, -0.58, -0.38, -0.46, -0.54, -0.46, -0.52, -0.72, -0.4, -0.6, -0.38, -0.36, -0.5, -0.44, -0.44, -0.42, -0.38, -0.52, -0.36, -0.54, -0.36, -0.48, -0.38, -0.44, -0.4, -0.32, -0.4, -0.5, -0.44, -0.44, -0.32, -0.44, -0.42, -0.3, -0.22, -0.34, -0.5, -0.28, -0.46, -0.4, -0.04, -0.1, -0.26, -0.32, -0.28, -0.18, -0.36, -0.26, -0.1, -0.42, -0.22, -0.16, -0.12, -0.14, -0.28, -0.1, -0.08, -0.06, 0.0, 0.1, 0.02, 0.0, 0.04, 0.12, 0.06, 0.18, 0.12, 0.1, 0.34, 0.22, 0.2, 0.28, 0.22, 0.16, 0.04, 0.14, 0.28, 0.42, 0.42, 0.44, 0.4, 0.54]
[-0.62, -0.48, -0.36, -0.56, -0.6, -0.6, -0.56, -0.5, -0.56, -0.56, -0.4, -0.26, -0.46, -0.5, -0.44, -0.5, -0.52, -0.38, -0.54, -0.5, -0.44, -0.56, -0.46, -0.3, -0.54, -0.48, -0.4, -0.42, -0.52, -0.46, -0.48, -0.5, -0.58, -0.38, -0.54, -0.38, -0.42, -0.46, -0.38, -0.4, -0.32, -0.38, -0.34, -0.44, -0.34, -0.38, -0.36, -0.24, -0.18, -0.18, -0.2, -0.

In [53]:
# Experiment settings for the DQN run below.
ALGORITHM = "dqn"
NUM_EPISODES = 100_000  # training episodes per run
EVAL_INTERVAL = 1_000   # training episodes between evaluations

In [None]:
# Single TensorFlow run using whatever NUM_EPISODES / EVAL_INTERVAL /
# ALGORITHM currently hold in the kernel; prints the list of evaluation
# returns.
for _ in range(1):
    res_tf = test_tf(NUM_EPISODES, EVAL_INTERVAL, ALGORITHM)
    print(res_tf)