In [None]:
import json
import os
import logging
import plotly.express as px
import pandas as pd
import multiprocessing
from itertools import repeat

import ray
from ray.rllib.agents.trainer import with_common_config, with_base_config
from ray.rllib.models.catalog import MODEL_DEFAULTS
from ray.rllib.utils import try_import_tf

from mprl.scripts.poker_parallel_algos.utils.policy_config_keys import POKER_ARCH1_MODEL_CONFIG_KEY
from mprl.rl.envs.opnspl.measure_exploitability_eval_callback import measure_exploitability_nonlstm
from mprl.utility_services.cloud_storage import maybe_download_object, connect_storage_client, BUCKET_NAME
from mprl.rl.sac.sac_policy import SACDiscreteTFPolicy
from mprl.rl.common.stratego_preprocessor import STRATEGO_PREPROCESSOR, StrategoDictFlatteningPreprocessor
from mprl.rl.envs.opnspl.poker_multiagent_env import POKER_ENV, KUHN_POKER, LEDUC_POKER, PARTIALLY_OBSERVABLE, PokerMultiAgentEnv
from mprl.rl.common.sac_stratego_model import SAC_STRATEGO_MODEL
from mprl.scripts.poker_parallel_algos.utils.metanash import get_fp_metanash_for_payoff_table
from mprl.utility_services.payoff_table import PayoffTable
from mprl.utils import datetime_str

# TensorFlow is resolved through RLlib's helper; the code below uses the TF1-style
# tf.Graph / tf.Session / tf.ConfigProto API.
tf = try_import_tf()

logger = logging.getLogger(__name__)

# Observation mode constant imported from the poker env module.
OBSERVATION_MODE = PARTIALLY_OBSERVABLE

# The policy class whose checkpoints are evaluated. Its class name is passed to the
# metanash solver as the accepted opponent policy class name.
POLICY_CLASS = SACDiscreteTFPolicy
POLICY_CLASS_NAME = POLICY_CLASS.__name__
MODEL_CONFIG_KEY = POKER_ARCH1_MODEL_CONFIG_KEY

MANAGER_SEVER_HOST = "localhost"


def _parse_int_tag(tags, prefix):
    """Return the integer following ``prefix`` in the first tag that starts with ``prefix``.

    Raises:
        ValueError: if no tag starts with ``prefix`` (or its remainder is not an int).
    """
    for tag in tags:
        if tag.startswith(prefix):
            return int(tag[len(prefix):])
    raise ValueError(f"No tag starting with {prefix!r} found in policy tags {tags!r}")


def get_stats_for_single_payoff_table(payoff_table_key, experiment_name, poker_game_version, model_config_key):
    """Measure metanash exploitability after each generation recorded in one payoff table.

    Downloads the payoff table and the JSON model config from cloud storage and builds a
    single policy graph. Then, for every n in 1..payoff_table.size(), it solves a
    fictitious-play metanash over the first n policies and measures the exploitability
    of that policy mixture.

    Args:
        payoff_table_key: storage object name of the dill-serialized payoff table.
        experiment_name: label copied into every output row.
        poker_game_version: poker env 'version' (e.g. "kuhn_poker", "leduc_poker").
        model_config_key: storage object name of the JSON model config. It is also
            passed to the metanash solver as the accepted opponent model config key.

    Returns:
        A dict of equal-length lists with one entry per generation, keyed by
        'payoff_table_key', 'experiment_name', 'num_policies', 'exploitability',
        'total_steps' and 'total_episodes'. The two totals are running sums of each
        policy's "timesteps: " / "episodes: " tags.

    Raises:
        ValueError: if a policy in the table lacks a "timesteps: " or "episodes: " tag.
    """
    poker_env_config = {
        'version': poker_game_version,
    }

    storage_client = connect_storage_client()

    # If you use ray for more than just this single example fn, you'll need to move ray.init to the top of your main()
    ray.init(address=os.getenv('RAY_HEAD_NODE'), ignore_reinit_error=True, local_mode=True)

    model_config_file_path, _ = maybe_download_object(storage_client=storage_client,
                                                      bucket_name=BUCKET_NAME,
                                                      object_name=model_config_key,
                                                      force_download=False)
    with open(model_config_file_path, 'r') as config_file:
        model_config = json.load(fp=config_file)

    example_env = PokerMultiAgentEnv(env_config=poker_env_config)
    obs_space = example_env.observation_space
    act_space = example_env.action_space

    preprocessor = StrategoDictFlatteningPreprocessor(obs_space=obs_space)
    graph = tf.Graph()
    # Session configured with zero GPU devices.
    sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}), graph=graph)

    def fetch_logits(policy):
        # Extra action fetch exposing the model's last output as "behaviour_logits".
        return {
            "behaviour_logits": policy.model.last_output(),
        }

    _policy_cls = POLICY_CLASS.with_updates(extra_action_fetches_fn=fetch_logits)

    with graph.as_default():
        with sess.as_default():
            policy = _policy_cls(
                obs_space=preprocessor.observation_space,
                action_space=act_space,
                config=with_common_config({
                    'model': with_base_config(base_config=MODEL_DEFAULTS, extra_config=model_config),
                    'env': POKER_ENV,
                    'env_config': poker_env_config,
                    'custom_preprocessor': STRATEGO_PREPROCESSOR}))

    def set_policy_weights(weights_key):
        """Fetch the weights object ``weights_key`` (force_download=False) and load it into ``policy``."""
        weights_file_path, _ = maybe_download_object(storage_client=storage_client,
                                                     bucket_name=BUCKET_NAME,
                                                     object_name=weights_key,
                                                     force_download=False)
        policy.load_model_weights(weights_file_path)

    payoff_table_local_path, _ = maybe_download_object(storage_client=storage_client,
                                                       bucket_name=BUCKET_NAME,
                                                       object_name=payoff_table_key,
                                                       force_download=False)
    payoff_table = PayoffTable.from_dill_file(dill_file_path=payoff_table_local_path)

    # Policy ordering of the payoff matrix. It does not change between generations, so
    # look it up once. Zipping it with the first-n metanash probabilities below
    # truncates it to the first n policies.
    policy_weights_keys = payoff_table.get_ordered_keys_in_payoff_matrix()

    stats_out = {
        'payoff_table_key': [],
        'experiment_name': [],
        'num_policies': [],
        'exploitability': [],
        'total_steps': [],
        'total_episodes': [],
    }
    total_steps = 0
    total_episodes = 0

    for policy_index in range(payoff_table.size()):
        # Generation k covers the first k policies added to the table.
        n_policies = policy_index + 1
        metanash_probs = get_fp_metanash_for_payoff_table(payoff_table=payoff_table,
                                                          fp_iters=40000,
                                                          accepted_opponent_policy_class_names=[POLICY_CLASS_NAME],
                                                          # Previously passed the env-config dict here by mistake.
                                                          accepted_opponent_model_config_keys=[model_config_key],
                                                          add_payoff_matrix_noise_std_dev=0.000,
                                                          mix_with_uniform_dist_coeff=None,
                                                          only_first_n_policies=n_policies,
                                                          p_or_lower_rounds_to_zero=0.0)

        policy_dict = dict(zip(policy_weights_keys, metanash_probs))

        exploitability_this_gen = measure_exploitability_nonlstm(rllib_policy=policy,
                                                                 poker_game_version=poker_game_version,
                                                                 policy_mixture_dict=policy_dict,
                                                                 set_policy_weights_fn=set_policy_weights)

        print(f"{n_policies} policies, {exploitability_this_gen} exploitability")

        # Training cost of the policy added in this generation is stored in its tags.
        latest_policy_tags = payoff_table.get_policy_for_index(policy_index).tags
        total_steps += _parse_int_tag(latest_policy_tags, "timesteps: ")
        total_episodes += _parse_int_tag(latest_policy_tags, "episodes: ")

        # Exactly one row per generation. The previous version re-appended every
        # accumulated list on each iteration, producing quadratically many duplicate rows.
        stats_out['payoff_table_key'].append(payoff_table_key)
        stats_out['experiment_name'].append(experiment_name)
        stats_out['num_policies'].append(n_policies)
        stats_out['exploitability'].append(exploitability_this_gen)
        stats_out['total_steps'].append(total_steps)
        stats_out['total_episodes'].append(total_episodes)

    return stats_out

def get_exploitability_stats_over_time_for_payoff_table_all_same_poker_version(
        payoff_table_keys, exp_names, poker_game_version, model_config_key, num_processes=16):
    """Compute exploitability-over-time stats for several payoff tables in parallel.

    Each (payoff_table_key, exp_name) pair is evaluated by
    get_stats_for_single_payoff_table in a multiprocessing worker. All tables share the
    same poker version and model config key.

    Args:
        payoff_table_keys: storage object names of the payoff tables.
        exp_names: experiment labels, aligned one-to-one with payoff_table_keys.
        poker_game_version: poker env 'version' shared by all tables.
        model_config_key: storage object name of the shared model config.
        num_processes: upper bound on worker processes. It is capped at the number of
            tables so no idle workers are spawned.

    Returns:
        A pandas DataFrame whose columns are the concatenation, in input order, of the
        per-table stats lists. Empty if no tables are given.

    Raises:
        ValueError: if payoff_table_keys and exp_names have different lengths.
    """
    payoff_table_keys = list(payoff_table_keys)
    exp_names = list(exp_names)
    if len(payoff_table_keys) != len(exp_names):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(f"Got {len(payoff_table_keys)} payoff table keys but {len(exp_names)} experiment names")
    if not payoff_table_keys:
        return pd.DataFrame()

    with multiprocessing.Pool(processes=min(num_processes, len(payoff_table_keys))) as pool:
        # starmap returns results in the same order as its inputs.
        results = pool.starmap(get_stats_for_single_payoff_table,
                               zip(payoff_table_keys, exp_names,
                                   repeat(poker_game_version), repeat(model_config_key)))

    combined_stats = {}
    for result in results:
        for key, values in result.items():
            combined_stats.setdefault(key, []).extend(values)

    return pd.DataFrame(combined_stats)

### Graph the results of the Leduc poker experiment from the original paper

In [None]:
# (payoff table storage key, experiment label) for the Leduc runs: three runs each of
# pipeline PSRO, rectified PSRO and naive PSRO.
google_cloud_leduc_experiment_payoff_tables_and_names = [
    ("leduc_poker_pipe_3_workers_poker_ps/leduc_pipeline_psro/pipe-1-3-3-leduc-poker_pid_430_09_03_49AM_Jun-01-2020/payoff_tables/latest.dill","leduc_pipe_1"),
    ("leduc_poker_pipe_3_workers_poker_ps/leduc_pipeline_psro/pipe-2-3-3-leduc-poker_pid_430_09_04_00AM_Jun-01-2020/payoff_tables/latest.dill","leduc_pipe_2"),
    ("leduc_poker_pipe_3_workers_poker_ps/leduc_pipeline_psro/pipe-3-3-3-leduc-poker_pid_431_09_04_01AM_Jun-01-2020/payoff_tables/latest.dill","leduc_pipe_3"),
    ("leduc_poker_rect_3_workers_poker_ps/leduc_psro_rectified/rect-1-3-3-leduc-poker_pid_429_09_04_14AM_Jun-01-2020/payoff_tables/latest.dill","leduc_rect_1"),
    ("leduc_poker_rect_3_workers_poker_ps/leduc_psro_rectified/rect-2-3-3-leduc-poker_pid_430_09_04_24AM_Jun-01-2020/payoff_tables/latest.dill","leduc_rect_2"),
    ("leduc_poker_rect_3_workers_poker_ps/leduc_psro_rectified/rect-3-3-3-leduc-poker_pid_430_09_04_45AM_Jun-01-2020/payoff_tables/latest.dill","leduc_rect_3"),
    ("leduc_poker_naive_3_workers_poker_ps/leduc_psro_naive/naive-1-3-3-leduc-poker_pid_430_09_03_04AM_Jun-01-2020/payoff_tables/latest.dill","leduc_naive_1"),
    ("leduc_poker_naive_3_workers_poker_ps/leduc_psro_naive/naive-2-3-3-leduc-poker_pid_430_09_03_16AM_Jun-01-2020/payoff_tables/latest.dill","leduc_naive_2"),
    ("leduc_poker_naive_3_workers_poker_ps/leduc_psro_naive/naive-3-3-3-leduc-poker_pid_430_09_03_34AM_Jun-01-2020/payoff_tables/latest.dill","leduc_naive_3")
]

# Split the pairs into parallel tuples of table keys and experiment labels.
gc_leduc_table_keys = tuple(table_key for table_key, _ in google_cloud_leduc_experiment_payoff_tables_and_names)
gc_leduc_exp_names = tuple(exp_name for _, exp_name in google_cloud_leduc_experiment_payoff_tables_and_names)

gc_leduc_perf_df = get_exploitability_stats_over_time_for_payoff_table_all_same_poker_version(
    payoff_table_keys=gc_leduc_table_keys,
    exp_names=gc_leduc_exp_names,
    poker_game_version="leduc_poker",
    model_config_key=POKER_ARCH1_MODEL_CONFIG_KEY,
)

In [None]:
# One line per experiment: exploitability against cumulative training episodes.
fig = px.line(
    gc_leduc_perf_df.drop_duplicates(),
    x="total_episodes",
    y="exploitability",
    color="experiment_name",
    title="Exploitability over Episodes Leduc 3 workers",
    render_mode="svg",
)
fig.show()

In [None]:
# Save the Leduc stats next to the notebook.
gc_leduc_perf_df.to_csv(path_or_buf="gc_leduc_jun_1.csv")

### Example: graphing the results of a Kuhn poker experiment

In [None]:
# (payoff table storage key, experiment label) for a single Kuhn pipeline-PSRO run.
kuhn_experiment_payoff_tables_and_names = [
    ("kuhn_poker_pipe_3_workers_poker_ps/kuhn_pipeline_psro/goku_pid_143271_02_37_42AM_Oct-12-2020/payoff_tables/latest.dill","kuhn_pipe_1"),
]

# Split the pairs into parallel tuples of table keys and experiment labels.
kuhn_table_keys = tuple(table_key for table_key, _ in kuhn_experiment_payoff_tables_and_names)
kuhn_exp_names = tuple(exp_name for _, exp_name in kuhn_experiment_payoff_tables_and_names)

kuhn_perf_df = get_exploitability_stats_over_time_for_payoff_table_all_same_poker_version(
    payoff_table_keys=kuhn_table_keys,
    exp_names=kuhn_exp_names,
    poker_game_version="kuhn_poker",
    model_config_key=POKER_ARCH1_MODEL_CONFIG_KEY,
)

In [None]:
# One line per experiment: exploitability against cumulative training episodes.
fig = px.line(
    kuhn_perf_df.drop_duplicates(),
    x="total_episodes",
    y="exploitability",
    color="experiment_name",
    title="Exploitability over Episodes Kuhn 3 workers",
    render_mode="svg",
)
fig.show()

In [None]:
# Timestamped file name so repeated runs do not overwrite earlier results.
kuhn_results_csv_path = f"kuhn_results_{datetime_str()}.csv"
kuhn_perf_df.to_csv(kuhn_results_csv_path)
