In [None]:
import os

import numpy as np
import pandas as pd
import seaborn as sns

from ray import init, rllib, tune, shutdown

In [None]:
from attack_simulator.agents import ATTACKERS
from attack_simulator.env import AttackSimulationEnv
from attack_simulator.graph import AttackGraph, SIZES

In [None]:
class NoAction(rllib.policy.Policy):
    """Baseline defender policy that always returns action ``0`` and has no weights.

    NOTE(review): the class name suggests action ``0`` is the environment's
    "no action" choice -- confirm against ``AttackSimulationEnv``.
    """

    def compute_actions(self, observations, *args, **kwargs):
        """Return one zero action per observation, no state outputs, and no extra info."""
        # FIXME: use a `numpy` array as a temporary workaround for
        #        https://github.com/ray-project/ray/issues/10100
        batch_size = len(observations)
        actions = np.zeros(batch_size)
        state_outs, extra_fetches = [], {}
        return actions, state_outs, extra_fetches

    def get_weights(self):
        """Nothing is learned, so there are no weights to export."""
        return {}

    def set_weights(self, weights):
        """Ignore incoming weights; there is nothing to load."""


# RLlib Trainer class wrapping the `NoAction` policy; it is instantiated once per
# experiment configuration as `no_action(config=...)`.
no_action = rllib.agents.trainer_template.build_trainer(
    default_policy=NoAction,
    name="NoAction",
)

In [None]:
# Workaround for: https://github.com/ray-project/ray/issues/17618

from IPython.core.interactiveshell import InteractiveShell


class keep_ipython_sane:
    """Context manager that restores IPython's ``InteractiveShell`` singleton on exit.

    Workaround for ray-project/ray#17618 (linked in the comment above): code run
    inside the ``with`` block may leave a different object registered as
    ``InteractiveShell._instance``. The instance that was current on entry is
    put back when the block exits, whether it exits normally or by raising.
    Exceptions are not suppressed. ``__enter__`` returns ``None``.
    """

    def __enter__(self):
        # Remember whichever shell is the singleton right now.
        current = InteractiveShell.instance()
        self.instance = current

    def __exit__(self, *args, **kwargs):
        # Re-register the remembered shell; the implicit ``None`` return lets any
        # exception raised in the ``with`` body propagate.
        # TODO: consider extra error handling here if restoring ever needs it.
        saved = self.instance
        InteractiveShell._instance = saved

In [None]:
# Keyword arguments for `ray.init` (passed as `init(**args)` in `generate`).
if not os.path.isdir("/var/run/secrets/kubernetes.io"):
    # Local run. Inside a Docker container, bind the dashboard to every
    # interface so that port-forwarding can reach it; otherwise stay on loopback.
    if os.path.exists("/.dockerenv"):
        dashboard_host = "0.0.0.0"
    else:
        dashboard_host = "127.0.0.1"
    args = {"num_cpus": 4, "dashboard_host": dashboard_host}
else:
    # Inside a Kubernetes pod: connect to the already-running Ray cluster.
    args = {"address": "auto"}

# ALTERNATIVE: use the "Ray client" to connect to a remote cluster
# Unfortunately, JupyterNotebookReporter displays an object reference
# <IPython.core.display.HTML object> instead of content...
# --- --- ---
#
# from ray.util.client import worker
#
# worker.INITIAL_TIMEOUT_SEC = worker.MAX_TIMEOUT_SEC = 1
#
# ray_client_server = 'host.docker.internal' if os.path.exists("/.dockerenv") else '127.0.0.1'
# try:
#     init(address=f'ray://{ray_client_server}:10001')
# except ConnectionError:
#     pass  # TODO: try something else...

In [None]:
from tqdm.auto import tqdm

# Raw column names (tag columns added in `generate` plus RLlib `hist_stats` keys)
# mapped to the human-readable labels used in the plots and tables below.
rename = {
    "attacker": "Attacker",
    "graph_size": "Graph size",
    "episode_reward": "Returns",
    "episode_lengths": "Episode lengths",
}

# Experiment grid: every graph x attacker x seed combination is evaluated for
# `num_episodes` episodes.
num_episodes = 5  # run 5 episodes on the same environment
attackers = [*ATTACKERS]  # NOTE(review): presumably the keys of an attacker registry -- confirm
graphs = [AttackGraph({"graph_size": size}) for size in SIZES]
seeds = [0, 1, 2, 3, 6, 7, 11, 28, 42, 1337]

def _evaluate_no_action(graph, attacker, seed):
    """Evaluate the no-action defender once and return one row per evaluated episode.

    Builds a fresh ``no_action`` trainer for ``graph`` / ``attacker`` / ``seed``,
    runs RLlib evaluation (``num_episodes`` episodes, exploration off), and
    returns the ``hist_stats`` lists as a DataFrame tagged with the attacker
    name and graph size.
    """
    config = dict(
        log_level='ERROR',
        framework="torch",
        env=AttackSimulationEnv,
        env_config=dict(attack_graph=graph, attacker=attacker),
        seed=seed,
        evaluation_interval=1,
        evaluation_num_workers=1,
        evaluation_config=dict(explore=False),
        evaluation_num_episodes=num_episodes,
    )
    with keep_ipython_sane():
        agent = no_action(config=config)
    try:
        stats = agent.evaluate()['evaluation']['hist_stats']
    finally:
        # Release this trainer's worker actors; otherwise they can keep holding
        # Ray resources while the sweep builds many more trainers.
        agent.stop()
    # NOTE(review): the "graph_size" column records `graph.num_attacks`, while the
    # progress bars label graphs by `graph.graph_size` -- confirm which is intended.
    return pd.DataFrame(dict(attacker=attacker, graph_size=graph.num_attacks, **stats))


def generate(savename):
    """Evaluate the no-action defender over the whole graph x attacker x seed grid.

    Starts Ray with the module-level ``args`` and evaluates every combination of
    ``graphs``, ``attackers`` and ``seeds``. The per-episode results are
    concatenated, columns are renamed with ``rename``, and the result is written
    to ``savename`` as CSV.

    Args:
        savename: path of the CSV file the combined results are written to.

    Returns:
        The combined results DataFrame (one row per evaluated episode).
    """
    init(**args)
    frames = []
    # Always shut Ray down, even if a run fails or is interrupted part-way;
    # otherwise re-running this cell calls `init` on an already-running Ray.
    try:
        for graph in tqdm(graphs, 'graphs'):
            for attacker in tqdm(attackers, f'└── {graph.graph_size}'):
                for seed in tqdm(seeds, f'\u00a0\u2001\u2001\u2001└── {attacker}@{graph.graph_size}'):
                    frames.append(_evaluate_no_action(graph, attacker, seed))
    finally:
        shutdown()
    results_df = pd.concat(frames, ignore_index=True).rename(columns=rename)
    results_df.to_csv(savename)
    return results_df

In [None]:
savename = "length-rollout.csv"

# Re-use the cached CSV when present to avoid re-running the full evaluation sweep.
if os.path.exists(savename):
    df = pd.read_csv(savename, index_col=0)
else:
    df = generate(savename)

In [None]:
# Plot theme for all figures below: dark grid background, 12 x 8 inch figures.
sns.set(rc={"figure.figsize": (12, 8)}, style="darkgrid")

In [None]:
# Mean episode return vs. graph size, one line per attacker; the band is +/- 1
# standard deviation over all episodes and seeds at that size.
# NOTE: seaborn >= 0.12 deprecates `ci="sd"` in favour of `errorbar="sd"`.
g = sns.lineplot(data=df, x="Graph size", y="Returns", hue="Attacker", ci="sd")
g.legend(title="Attacker", loc="upper right")
g.set_title("Defender: no-action");  # `;` suppresses the stray Text(...) repr

In [None]:
# Mean episode length vs. graph size, one line per attacker; the band is +/- 1
# standard deviation over all episodes and seeds at that size.
# NOTE: seaborn >= 0.12 deprecates `ci="sd"` in favour of `errorbar="sd"`.
g = sns.lineplot(data=df, x="Graph size", y="Episode lengths", hue="Attacker", ci="sd")
g.legend(title="Attacker", loc="upper left")
g.set_title("Defender: no-action");  # `;` suppresses the stray Text(...) repr

In [None]:
# `describe()` emits 8 summary statistics per numeric column; raise pandas'
# column display limit (globally, for this and later cells) so the summary is
# not elided with "...". 32 is presumably sized for the few numeric columns here.
pd.options.display.max_columns = 32
df.groupby("Attacker").describe()