In [1]:
import os

from absl import logging
logging.set_verbosity(logging.INFO)

import jax

from open_spiel.python import policy
from open_spiel.python import rl_environment
from open_spiel.python.mfg import utils
from open_spiel.python.mfg.algorithms import distribution
from open_spiel.python.mfg.algorithms import trajectory_munchausen_deep_mirror_descent
from open_spiel.python.mfg.algorithms import nash_conv
from open_spiel.python.mfg.algorithms import policy_value
from open_spiel.python.mfg.games import factory
from open_spiel.python.utils import metrics



# Constants

In [2]:
# ---------------------------------------------------------------------------
# Game
# ---------------------------------------------------------------------------
# Name of the game to play.
GAME_NAME = "mfg_crowd_modelling_2d"
# Name of the game setting.
ENV_SETTING = "crowd_modelling_2d_four_rooms"

# ---------------------------------------------------------------------------
# Training schedule
# ---------------------------------------------------------------------------
# Training seed.
SEED = 42
# Number of iterations.
NUM_ITERATIONS = 100
# Number of training episodes in each iteration.
NUM_EPISODES_PER_ITERATION = 1000
# Episode frequency at which the agents are evaluated.
EVAL_EVERY = 200
# Number of steps between learning updates.
LEARN_EVERY = 64
# Number of transitions to sample at each learning step.
BATCH_SIZE = 128
# Number of steps between DQN target network updates.
UPDATE_TARGET_NETWORK_EVERY = 200
# Discount factor for future rewards.
DISCOUNT_FACTOR = 1.0

# ---------------------------------------------------------------------------
# Exploration (epsilon schedule)
# ---------------------------------------------------------------------------
# Starting and final exploration parameters. They are equal here, so the
# schedule below presumably leaves epsilon constant -- confirm against the agent.
EPSILON_START = 0.1
EPSILON_END = 0.1
# Number of game steps over which epsilon is decayed.
EPSILON_DECAY_DURATION = 100000
# Power for the epsilon decay.
EPSILON_POWER = 1.0

# ---------------------------------------------------------------------------
# Network and optimisation
# ---------------------------------------------------------------------------
# Hidden layer sizes handed to the agent as `hidden_layers_sizes`.
HIDDEN_LAYERS_SIZES = [128, 128]
# Optimizer.
OPTIMIZER = "adam"
# Learning rate for the inner RL agent.
LEARNING_RATE = 0.01
# Loss function.
LOSS = "mse"
# Parameter for the Huber loss.
HUBER_LOSS_PARAMETER = 1.0
# Value to clip the gradient to (None: no value given).
GRADIENT_CLIPPING = None

# ---------------------------------------------------------------------------
# Trajectory replay buffer
# ---------------------------------------------------------------------------
# Size of the trajectory replay buffer.
TRAJECTORY_REPLAY_BUFFER_CAPACITY = 1000
# Number of trajectories in the buffer before learning begins.
MIN_TRAJECTORY_REPLAY_BUFFER_SIZE_TO_LEARN = 25
# Number of transitions in a trajectory.
TRAJECTORY_SAMPLE_LENGTH = 40
# Number of transitions to overlap between trajectories.
TRAJECTORY_SAMPLE_OVERLAP_LENGTH = 20
# Reset the replay buffer when the softmax policy is updated.
RESET_REPLAY_BUFFER_ON_UPDATE = False

# ---------------------------------------------------------------------------
# Munchausen target
# ---------------------------------------------------------------------------
# Temperature parameter in the Munchausen target.
TAU = 10
# Alpha parameter in the Munchausen target.
ALPHA = 0.99
# Use Munchausen penalty terms.
WITH_MUNCHAUSEN = True

# ---------------------------------------------------------------------------
# Checkpoints and logging
# ---------------------------------------------------------------------------
# Save/load neural network weights, and the directory to use for it.
# NOTE(review): neither name is referenced by the cells visible in this notebook.
USE_CHECKPOINTS = False
CHECKPOINT_DIR = "/tmp/dqn_test"
# Logging directory to use for TF summary files (None: no directory).
LOGDIR = None
# Enables logging of the distribution.
LOG_DISTRIBUTION = False

# Implementation

In [3]:
# Build the game named GAME_NAME with the ENV_SETTING setting through the MFG
# game factory; the bare name on the last line displays the resulting game.
game = factory.create_game_with_setting(GAME_NAME, ENV_SETTING)
game

INFO:absl:Creating mfg_crowd_modelling_2d game with parameters: {'forbidden_states': '[0|0;1|0;2|0;3|0;4|0;5|0;6|0;7|0;8|0;9|0;10|0;11|0;12|0;0|1;6|1;12|1;0|2;6|2;12|2;0|3;12|3;0|4;6|4;12|4;0|5;6|5;12|5;0|6;1|6;2|6;4|6;5|6;6|6;7|6;8|6;10|6;11|6;12|6;0|7;6|7;12|7;0|8;6|8;12|8;0|9;12|9;0|10;6|10;12|10;0|11;6|11;12|11;0|12;1|12;2|12;3|12;4|12;5|12;6|12;7|12;8|12;9|12;10|12;11|12;12|12]', 'horizon': 40, 'initial_distribution': '[1|1]', 'initial_distribution_value': '[1.0]', 'size': 13, 'only_distribution_reward': True}


mfg_crowd_modelling_2d(forbidden_states=[0|0;1|0;2|0;3|0;4|0;5|0;6|0;7|0;8|0;9|0;10|0;11|0;12|0;0|1;6|1;12|1;0|2;6|2;12|2;0|3;12|3;0|4;6|4;12|4;0|5;6|5;12|5;0|6;1|6;2|6;4|6;5|6;6|6;7|6;8|6;10|6;11|6;12|6;0|7;6|7;12|7;0|8;6|8;12|8;0|9;12|9;0|10;6|10;12|10;0|11;6|11;12|11;0|12;1|12;2|12;3|12;4|12;5|12;6|12;7|12;8|12;9|12;10|12;11|12;12|12],horizon=40,initial_distribution=[1|1],initial_distribution_value=[1.0],only_distribution_reward=True,size=13)

In [4]:
# Player count reported by the game; the cells below create one environment and
# one agent per value in range(num_players).
num_players = game.num_players()
num_players

1

In [5]:
# Uniform random policy over the game, and the DistributionPolicy built from it.
# `uniform_dist` is what the environments receive as their `mfg_distribution`.
uniform_policy = policy.UniformRandomPolicy(game)
uniform_dist = distribution.DistributionPolicy(game, uniform_policy)
uniform_policy, uniform_dist

(<open_spiel.python.policy.UniformRandomPolicy at 0x7f6a65661050>,
 <open_spiel.python.mfg.algorithms.distribution.DistributionPolicy at 0x7f6a65663b10>)

In [6]:
# One RL environment per population index, each seeded with the uniform
# distribution and using the OBSERVATION observation type.
envs = [
    rl_environment.Environment(
        game,
        mfg_population=population,
        mfg_distribution=uniform_dist,
        observation_type=rl_environment.ObservationType.OBSERVATION,
    )
    for population in range(num_players)
]
envs

INFO:absl:Using game instance: mfg_crowd_modelling_2d


[<open_spiel.python.rl_environment.Environment at 0x7f6a65744d50>]

In [7]:
# First environment; the next cells read its observation and action specs.
env = envs[0]
env

<open_spiel.python.rl_environment.Environment at 0x7f6a65744d50>

In [8]:
# First entry of the env's "info_state" observation spec; passed to every agent
# as its `info_state_size`.
info_state_size = env.observation_spec()["info_state"][0]
info_state_size

67

In [9]:
# "num_actions" entry of the env's action spec; passed to every agent.
num_actions = env.action_spec()["num_actions"]
num_actions

5

In [10]:
# Hyperparameters shared by all agents, taken from the constants cell. The same
# mapping is later recorded through the metrics writer, so key order is kept.
kwargs = dict(
    alpha=ALPHA,
    batch_size=BATCH_SIZE,
    discount_factor=DISCOUNT_FACTOR,
    epsilon_decay_duration=EPSILON_DECAY_DURATION,
    epsilon_end=EPSILON_END,
    epsilon_power=EPSILON_POWER,
    epsilon_start=EPSILON_START,
    gradient_clipping=GRADIENT_CLIPPING,
    hidden_layers_sizes=HIDDEN_LAYERS_SIZES,
    huber_loss_parameter=HUBER_LOSS_PARAMETER,
    learn_every=LEARN_EVERY,
    learning_rate=LEARNING_RATE,
    loss=LOSS,
    min_trajectory_replay_buffer_size_to_learn=MIN_TRAJECTORY_REPLAY_BUFFER_SIZE_TO_LEARN,
    optimizer=OPTIMIZER,
    trajectory_replay_buffer_capacity=TRAJECTORY_REPLAY_BUFFER_CAPACITY,
    trajectory_sample_length=TRAJECTORY_SAMPLE_LENGTH,
    trajectory_sample_overlap_length=TRAJECTORY_SAMPLE_OVERLAP_LENGTH,
    reset_replay_buffer_on_update=RESET_REPLAY_BUFFER_ON_UPDATE,
    seed=SEED,
    tau=TAU,
    update_target_network_every=UPDATE_TARGET_NETWORK_EVERY,
    with_munchausen=WITH_MUNCHAUSEN,
)

# One TrajectoryMunchausenDQN agent per player index, all with the same
# input size, action count and hyperparameters.
agents = [
    trajectory_munchausen_deep_mirror_descent.TrajectoryMunchausenDQN(
        player_id, info_state_size, num_actions, **kwargs
    )
    for player_id in range(num_players)
]
agents



[<open_spiel.python.mfg.algorithms.trajectory_munchausen_deep_mirror_descent.TrajectoryMunchausenDQN at 0x7f6abfb00f90>]

In [11]:
# Use a logging-only writer when no LOGDIR is configured or when this is not
# the first JAX process (process index 0). `jax.process_index()` is the
# supported spelling of the deprecated `jax.host_id()` alias used previously.
just_logging = LOGDIR is None or jax.process_index() > 0
writer = metrics.create_default_writer(
    logdir=LOGDIR, just_logging=just_logging,
)
# Record the agent hyperparameters through the same writer as the metrics.
writer.write_hparams(kwargs)
writer

INFO:absl:[Hyperparameters] {'alpha': 0.99, 'batch_size': 128, 'discount_factor': 1.0, 'epsilon_decay_duration': 100000, 'epsilon_end': 0.1, 'epsilon_power': 1.0, 'epsilon_start': 0.1, 'gradient_clipping': None, 'hidden_layers_sizes': [128, 128], 'huber_loss_parameter': 1.0, 'learn_every': 64, 'learning_rate': 0.01, 'loss': 'mse', 'min_trajectory_replay_buffer_size_to_learn': 25, 'optimizer': 'adam', 'trajectory_replay_buffer_capacity': 1000, 'trajectory_sample_length': 40, 'trajectory_sample_overlap_length': 20, 'reset_replay_buffer_on_update': False, 'seed': 42, 'tau': 10, 'update_target_network_every': 200, 'with_munchausen': True}


<clu.metric_writers.async_writer.AsyncMultiWriter at 0x7f6a65fd0450>

In [12]:
def logging_fn(it, episode, vals):
    """Write `vals` to the metrics writer at a global episode step.

    The step is the episode count since the start of training, assuming every
    iteration runs NUM_EPISODES_PER_ITERATION episodes.

    Args:
        it: Index of the current iteration.
        episode: Index of the episode within that iteration.
        vals: Scalars to record; forwarded unchanged to `writer.write_scalars`.
    """
    global_step = it * NUM_EPISODES_PER_ITERATION + episode
    writer.write_scalars(global_step, vals)

In [13]:
# Mirror-descent driver built from the game and the per-player environments and
# agents. `logging_fn` is handed over as its logging callback; EVAL_EVERY and
# NUM_EPISODES_PER_ITERATION are passed as `eval_every` and
# `num_episodes_per_iteration`.
md = trajectory_munchausen_deep_mirror_descent.TrajectoryDeepOnlineMirrorDescent(
    game,
    envs,
    agents,
    eval_every=EVAL_EVERY,
    num_episodes_per_iteration=NUM_EPISODES_PER_ITERATION,
    logging_fn=logging_fn,
)
md

<open_spiel.python.mfg.algorithms.trajectory_munchausen_deep_mirror_descent.TrajectoryDeepOnlineMirrorDescent at 0x7f6a64fa2550>

In [14]:
def log_metrics(it):
    """Evaluate the current mirror-descent policy and log the results.

    For every initial state of the game, the value returned by
    `PolicyValue.eval_state` for `md.policy` under `md.distribution` is logged
    under the key ``best_response/<state>``. The NashConv of `md.policy` is
    logged as ``nash_conv_md``. When both LOG_DISTRIBUTION and LOGDIR are set,
    `md.distribution` is also saved to ``LOGDIR/distribution_<it>.pkl``.

    Args:
        it: Iteration index; used for the logging step (episode 0 of that
            iteration) and for the distribution file name.
    """
    start_states = game.new_initial_states()
    evaluator = policy_value.PolicyValue(game, md.distribution, md.policy)

    scalars = {}
    for start in start_states:
        scalars[f"best_response/{start}"] = evaluator.eval_state(start)
    scalars["nash_conv_md"] = nash_conv.NashConv(game, md.policy).nash_conv()

    # Saving needs both the switch and a directory to write into.
    if LOG_DISTRIBUTION and LOGDIR:
        target = os.path.join(LOGDIR, f"distribution_{it}.pkl")
        utils.save_parametric_distribution(md.distribution, target)

    logging_fn(it, 0, scalars)

In [15]:
# Log the metrics of the starting policy as iteration 0, then alternate one
# mirror-descent iteration with one evaluation.
try:
    log_metrics(0)
    for it in range(1, NUM_ITERATIONS + 1):
        md.iteration()
        log_metrics(it)
finally:
    # Training is long and is often stopped by hand. Flushing here, on both
    # normal completion and interruption, keeps metrics already handed to the
    # writer from being left pending.
    writer.flush()

INFO:absl:[0] best_response/initial=144.195, nash_conv_md=167.801
INFO:absl:[200] agent0/loss=457.754638671875
INFO:absl:[400] agent0/loss=162.80447387695312
INFO:absl:[600] agent0/loss=95.06572723388672
INFO:absl:[800] agent0/loss=28.330066680908203
INFO:absl:[1000] agent0/loss=38.8974494934082
INFO:absl:[1000] best_response/initial=155.288, nash_conv_md=99.1171


KeyboardInterrupt: 

In [None]:
# Flush the metrics writer. Its repr shows an asynchronous writer, so this
# presumably waits for queued writes to complete -- confirm against CLU docs.
writer.flush()