In [None]:
import os

import gym
import numpy as np
import pandas as pd
import seaborn as sns

from ray import init, rllib, shutdown

In [None]:
from attack_simulator.agents import DEFENDERS
from attack_simulator.alphazero_env import AttackSimulationAlphaZeroEnv
from attack_simulator.env import AttackSimulationEnv
from attack_simulator.graph import AttackGraph, SIZES

In [None]:
class AgentPolicy(rllib.policy.Policy):
    """RLlib policy that delegates every action choice to a non-learning defender agent.

    The agent class is looked up in `DEFENDERS` by `config["agent_type"]`. The policy
    carries no weights and performs no learning.
    """

    def __init__(self, observation_space, action_space, config):
        """Build the wrapped agent from the spaces and the trainer configuration."""
        super().__init__(observation_space, action_space, config)
        agent_settings = {
            "input_dim": observation_space.shape[0],
            "num_actions": action_space.n,
            "random_seed": config["seed"],
            "attack_graph": config["env_config"]["attack_graph"],
        }
        agent_class = DEFENDERS[config["agent_type"]]
        self._agent = agent_class(agent_settings)

    def compute_actions(self, observations, *args, **kwargs):
        """Return one action per observation, plus empty state and extra-info outputs."""
        chosen = [self._agent.act(observation) for observation in observations]
        # FIXME: the actions are returned as a `numpy` array as a temporary workaround for
        #        https://github.com/ray-project/ray/issues/10100
        return np.array(chosen), [], {}

    def get_weights(self):
        """Return an empty mapping: there are no weights to export."""
        return {}

    def set_weights(self, weights):
        """Ignore `weights`: there is nothing to update."""
        pass

    def learn_on_batch(self, samples):
        """Ignore `samples` and return empty learner statistics."""
        return {}

    
def template_agent(agent_type):
    """Build an RLlib trainer class named `agent_type` whose default policy is `AgentPolicy`.

    The trainer's default configuration is RLlib's common configuration extended
    with an `agent_type` entry, which `AgentPolicy` reads to pick the agent class.
    """
    trainer_defaults = rllib.agents.trainer.with_common_config({"agent_type": agent_type})
    build_trainer = rllib.agents.trainer_template.build_trainer
    return build_trainer(
        name=agent_type,
        default_policy=AgentPolicy,
        default_config=trainer_defaults,
    )

In [None]:
# 'contrib/AlphaZero' does NOT appear to work without its custom dense model
from ray.rllib.contrib.alpha_zero.models.custom_torch_models import DenseModel

# Register the model under the name that `generate` passes as `custom_model`
# in the AlphaZero trainer configuration.
rllib.models.ModelCatalog.register_custom_model("alpha_zero_dense_model", DenseModel)

In [None]:
from ray.util.client import worker

# Set both Ray client timeout constants to 1 second.
# NOTE(review): presumably this makes a failed `ray://` connection attempt in `ray_init`
# give up quickly -- confirm against the Ray client implementation.
worker.INITIAL_TIMEOUT_SEC = 1
worker.MAX_TIMEOUT_SEC = 1


def ray_init():
    """Connect to a reachable Ray cluster, or start a local instance as a fallback.

    The attempts are made in this order:

    1. `address="auto"` when a Kubernetes secrets directory or a
       `~/ray_bootstrap_config.yaml` file is present;
    2. a Ray client server on port 10001 (`host.docker.internal` when
       `/.dockerenv` exists, `127.0.0.1` otherwise);
    3. a local instance with 4 CPUs when the client connection fails.

    The value returned by `init` is printed between ANSI escape codes and returned.
    """
    on_managed_cluster = os.path.isdir("/var/run/secrets/kubernetes.io")
    if not on_managed_cluster:
        on_managed_cluster = os.path.exists(os.path.expanduser("~/ray_bootstrap_config.yaml"))

    if on_managed_cluster:
        # running inside a k8s pod or on a cluster managed by Ray's autoscaler
        context = init(address="auto")
    else:
        if os.path.exists("/.dockerenv"):
            client_host = "host.docker.internal"
        else:
            client_host = "127.0.0.1"
        try:
            context = init(address=f"ray://{client_host}:10001")
        except ConnectionError:
            # tidy up whatever the failed connection attempt left behind
            shutdown()
            # inside a container, bind to all interfaces so that port-forwarding works
            if os.path.exists("/.dockerenv"):
                dashboard_host = "0.0.0.0"
            else:
                dashboard_host = "127.0.0.1"
            context = init(num_cpus=4, dashboard_host=dashboard_host)

    print("\x1b[33;1m", context, "\x1b[m")
    return context

In [None]:
from ray.tune.utils.trainable import TrainableUtil
from tqdm.auto import tqdm

# Maps raw result columns to the human-readable labels used by the plots in this notebook.
# RLlib reports per-episode lengths in `hist_stats` under the plural key `episode_lengths`;
# the singular spelling is kept as well so that frames using it are still renamed
# (`DataFrame.rename` ignores keys that are not present).
rename = dict(
    agent_type="Agent",
    graph_size="Graph size",
    episode_length="Episode lengths",
    episode_lengths="Episode lengths",
    episode_reward="Returns",
)

# Agent types to compare. `generate` treats names found in `DEFENDERS` as
# non-learning agents and every other name as an RLlib trainer to train first.
agent_types = ["contrib/AlphaZero", "R2D2", "rule-based", "random"]
# One attack graph per entry in `SIZES`.
graphs = [AttackGraph({"graph_size": size}) for size in SIZES]
# Every agent/graph combination is repeated once per seed.
seeds = [0, 1, 2, 3, 6, 7, 11, 28, 42, 1337]
iterations = 10  # number of `train()` calls per learning agent
rollouts = 10  # number of evaluation episodes per run

# Extra trainer settings applied only while training a learning agent.
train_config = {
    "num_workers": 4,
    "rollout_fragment_length": 32,
    "train_batch_size": 640,
    "buffer_size": 512,
    "batch_mode": "complete_episodes",
}
# Extra trainer settings applied whenever an agent is evaluated.
eval_config = {
    "evaluation_interval": 1,
    "evaluation_num_workers": 1,
    "evaluation_config": {"explore": False, "replay_sequence_length": -1},
    "evaluation_num_episodes": rollouts,
}


def generate(savename):
    """Train (where needed) and evaluate every agent on every graph, then save the results.

    For each graph in `graphs`, agent type in `agent_types` and seed in `seeds`:

    - agent types found in `DEFENDERS` are wrapped with `template_agent` and evaluated directly;
    - any other agent type is treated as an RLlib trainer name: unless a path named
      `{agent_name}_{graph_size}_{seed}` already exists, the trainer is trained for
      `iterations` iterations and saved there; it is then rebuilt with `num_workers=0`,
      restored from the first checkpoint found under that path, and evaluated.

    The evaluation `hist_stats` of all runs are concatenated into one data frame whose
    columns are renamed through `rename` before it is written to `savename` as CSV.

    Ray is shut down even when a run fails, and every trainer is stopped as soon as it
    is no longer needed so that its workers do not accumulate across runs.

    Args:
        savename: path of the CSV file to write.

    Returns:
        The combined `pandas.DataFrame` of evaluation results.
    """
    ray_init()

    frames = []
    try:
        for graph in tqdm(graphs, "graphs"):
            for agent_type in tqdm(agent_types, f"└── {graph.graph_size}"):
                agent_name = agent_type.split("/")[-1]
                for seed in tqdm(
                    seeds, f"\u00a0\u2001\u2001\u2001└── {agent_name}@{graph.graph_size}"
                ):
                    config = dict(
                        framework="torch",
                        env=AttackSimulationEnv,
                        env_config=dict(attack_graph=graph),
                        seed=seed,
                        log_level="ERROR",
                    )
                    # NOTE(review): `keep_ipython_sane` is not defined or imported in the
                    # notebook cells visible here -- confirm where it comes from.
                    if agent_type in DEFENDERS:
                        config.update(eval_config)
                        with keep_ipython_sane():
                            agent = template_agent(agent_type)(config=config)
                    else:
                        if agent_type == "contrib/AlphaZero":
                            # NOTE(review): neither `AttackSimAlphaZeroEnv` nor `AlphaZeroWrapper`
                            # is defined or imported in the visible cells (the notebook imports
                            # `AttackSimulationAlphaZeroEnv` instead) -- confirm the intended names.
                            config["env_config"].update(env_class=AttackSimAlphaZeroEnv)
                            config.update(
                                env=AlphaZeroWrapper,
                                model=dict(custom_model="alpha_zero_dense_model"),
                            )
                        if agent_type == "R2D2":
                            config.update(model=dict(use_lstm=True))

                        # the checkpoint location doubles as the "already trained" marker
                        name = f"{agent_name}_{graph.graph_size}_{seed}"
                        if not os.path.exists(name):
                            config.update(train_config)
                            with keep_ipython_sane():
                                agent = rllib.agents.registry.get_trainer_class(agent_type)(
                                    config=config
                                )
                            try:
                                for _ in tqdm(
                                    range(iterations),
                                    f"\u00a0\u2001\u2001\u2001\u2001\u2001\u2001└── {name}",
                                ):
                                    agent.train()
                                    # TODO: break based on results?
                                    # results = agent.train()
                                agent.save(name)
                            finally:
                                # release the training workers explicitly instead of
                                # relying on garbage collection
                                agent.stop()
                                del agent

                        # evaluation uses a fresh trainer without rollout workers
                        config.update(eval_config, num_workers=0)
                        with keep_ipython_sane():
                            agent = rllib.agents.registry.get_trainer_class(agent_type)(
                                config=config
                            )
                            checkpoint_path = TrainableUtil.get_checkpoints_paths(name).chkpt_path[0]
                            agent.restore(checkpoint_path)

                    try:
                        stats = agent.evaluate()["evaluation"]["hist_stats"]
                    finally:
                        # free the evaluation workers before the next run starts
                        agent.stop()
                    frame = pd.DataFrame(
                        dict(agent_type=agent_type, graph_size=graph.num_attacks, **stats)
                    )
                    frames.append(frame)
    finally:
        # always disconnect from / tear down Ray, even if a run raised
        shutdown()

    results_df = pd.concat(frames, ignore_index=True).rename(columns=rename)
    results_df.to_csv(savename)
    return results_df

In [None]:
savename = "returns-agent-eval.csv"

# Reuse previously saved results when the CSV exists; otherwise run the full evaluation.
if os.path.exists(savename):
    df = pd.read_csv(savename, index_col=0)
else:
    df = generate(savename)

In [None]:
# Display the evaluation results loaded or generated by the previous cell.
df

In [None]:
# Use seaborn's "darkgrid" style and a 12x8 default figure size for the plots that follow.
sns.set(style="darkgrid", rc={"figure.figsize": (12, 8)})

In [None]:
# Returns against graph size, one line per agent; `ci="sd"` requests a
# standard-deviation band around each line.
g = sns.lineplot(data=df, x="Graph size", y="Returns", hue="Agent", ci="sd")
g.legend(title="Agent", loc="lower left")
g.set_title("Returns vs Size (random attacker)")

In [None]:
# Episode lengths against graph size, one line per agent; `ci="sd"` requests a
# standard-deviation band around each line.
g = sns.lineplot(data=df, x="Graph size", y="Episode lengths", hue="Agent", ci="sd")
g.legend(title="Agent", loc="upper left")
g.set_title("Episode lengths vs Size (random attacker)")

In [None]:
# Raise pandas' displayed-column limit to 32 before rendering the wide summary table.
pd.set_option("display.max_columns", 32)
# Summary statistics (`describe`) of the results, grouped by agent.
df.groupby("Agent").describe()