In [None]:
import os

import seaborn as sns

from ray import init, rllib, tune, shutdown

In [None]:
from attack_simulator.agents import DEFENDERS
from attack_simulator.env import AttackSimulationEnv
from attack_simulator.graph import AttackGraph, SIZES

In [None]:
import gym
import numpy as np

class BooleanVectorPreprocessor(rllib.models.preprocessors.Preprocessor):
    """Custom RLlib preprocessor that presents an observation as a flat vector.

    The output has one entry per sub-space of the original observation space
    and is advertised as an int8 ``Box`` bounded by 0 and 1.
    """

    def _init_shape(self, observation_space, options=None):
        """Return the 1-D output shape: one slot per sub-space of the input space."""
        num_components = len(observation_space.spaces)
        return (num_components,)

    def transform(self, observation):
        """Convert a raw observation into a NumPy array."""
        vector = np.array(observation)
        return vector

    @property
    def observation_space(self):
        """Box space of shape ``self.shape``; the raw space is kept as ``original_space``."""
        box = gym.spaces.Box(low=0, high=1, shape=self.shape, dtype='int8')
        box.original_space = self._obs_space
        return box

# Make the preprocessor selectable by name, i.e. via model config `custom_preprocessor='boolean_vector'`
rllib.models.ModelCatalog.register_custom_preprocessor('boolean_vector', BooleanVectorPreprocessor)


class AgentPolicy(rllib.policy.Policy):
    """RLlib policy that delegates every decision to an agent from ``DEFENDERS``.

    The agent class is selected by ``config['agent_type']``.  The policy has no
    weights of its own, so the weight accessors are no-ops.
    """

    def __init__(self, observation_space, action_space, config):
        """Build the wrapped agent from the spaces and the trainer config.

        Args:
            observation_space: preprocessed observation space; its first
                dimension is passed to the agent as ``input_dim``.
            action_space: action space exposing ``n``, passed as ``num_actions``.
            config: trainer config; reads ``agent_type``, ``seed`` and
                ``env_config['attack_graph']``.
        """
        super().__init__(observation_space, action_space, config)
        agent_settings = {
            # same as len(observation_space.original_space.spaces)
            'input_dim': observation_space.shape[0],
            'num_actions': action_space.n,
            'random_seed': config['seed'],
            'attack_graph': config['env_config']['attack_graph'],
        }
        agent_cls = DEFENDERS[config['agent_type']]
        self._agent = agent_cls(agent_settings)

    def compute_actions(self, observations, *args, **kwargs):
        """Ask the agent for one action per observation.

        Returns:
            A 3-tuple ``(actions, [], {})``: the chosen actions, an empty list
            and an empty dict.
        """
        actions = list(map(self._agent.act, observations))
        return actions, [], {}

    def get_weights(self):
        """Return an empty dict: this policy carries no weights."""
        return {}

    def set_weights(self, weights):
        """Ignore ``weights``: this policy carries no weights."""
        pass
    
    
def instantiate_agent(agent_type, config):
    """Create a trainer whose policy is an ``AgentPolicy`` for ``agent_type``.

    Args:
        agent_type: name of the agent; used as the trainer name and stored in
            the default config under ``agent_type``.
        config: trainer config; must contain ``env``, which is also recorded
            as ``env_class`` in the default config.

    Returns:
        An instance of the trainer class produced by ``build_trainer``,
        constructed with ``config``.
    """
    overrides = dict(
        config,
        agent_type=agent_type,
        model=dict(custom_preprocessor='boolean_vector'),
        env_class=config['env'],
    )
    default_config = rllib.agents.trainer.with_common_config(overrides)
    trainer_cls = rllib.agents.trainer_template.build_trainer(
        name=agent_type,
        default_policy=AgentPolicy,
        default_config=default_config,
    )
    return trainer_cls(config=config)

In [None]:
from copy import deepcopy

class AlphaZeroWrapper(gym.Env):
    """Adapter that wraps an environment for use with 'contrib/AlphaZero'.

    Differences from the wrapped environment, as implemented below:

    * observations are dicts with keys ``obs`` (the wrapped observation) and
      ``action_mask`` (an all-ones mask, i.e. no action is ever masked out);
    * intermediate rewards are withheld: every step returns 0 until the episode
      is done, at which point the accumulated episode reward is returned;
    * ``get_state`` / ``set_state`` snapshot and restore the wrapped environment
      through deep copies.
    """

    def __init__(self, config):
        """Instantiate the wrapped environment from ``config['env_class']``.

        Args:
            config: environment config; ``config['env_class']`` is called with
                the whole ``config`` to build the wrapped environment.

        Raises:
            AssertionError: if the wrapped action space is not ``Discrete``.
        """
        self.env = config['env_class'](config)
        self.action_space = self.env.action_space
        assert isinstance(self.action_space, gym.spaces.Discrete), 'AlphaZero requires a Discrete action space'
        shape = (self.action_space.n,)
        self.observation_space = gym.spaces.Dict(dict(obs=self.env.observation_space, action_mask=gym.spaces.Box(0, 1, shape)))
        self.reward = 0
        # Constant mask shared by every observation: all actions are always allowed.
        self.mask = np.full(shape, 1, dtype='int8')

    def reset(self):
        """Reset the wrapped environment and the accumulated reward."""
        self.reward = 0
        observation = self.env.reset()
        return dict(obs=observation, action_mask=self.mask)

    def step(self, action):
        """Step the wrapped environment, releasing the summed reward only when done."""
        observation, reward, done, info = self.env.step(action)
        self.reward += reward
        reward = self.reward if done else 0
        return dict(obs=observation, action_mask=self.mask), reward, done, info

    def set_state(self, state):
        """Restore a snapshot produced by ``get_state`` and return its observation.

        Relies on the wrapped environment exposing its current observation as
        the ``observation`` attribute.
        """
        env, self.reward = state
        # Copy again so the stored snapshot stays pristine and can be restored repeatedly.
        self.env = deepcopy(env)
        return dict(obs=self.env.observation, action_mask=self.mask)

    def get_state(self):
        """Return ``(deep copy of the wrapped env, accumulated reward)``."""
        return deepcopy(self.env), self.reward

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def render(self, mode=None):
        """Render the wrapped environment and pass through whatever it returns."""
        # Forward the result: some render modes return the rendering instead of displaying it.
        return self.env.render(mode)

    def seed(self, seed=None):
        """Seed the wrapped environment and pass through whatever it returns."""
        return self.env.seed(seed)

In [None]:
import pandas as pd

class RolloutAggregator:
    """Collect the length and total reward of each episode during rollouts.

    Exposes the ``begin_rollout`` / ``append_step`` / ``end_rollout`` hooks and
    keeps one summary record per finished episode.  The keyword arguments given
    at construction are copied into every record.
    """

    def __init__(self, **kwargs):
        """Store the constant columns that are added to every episode record."""
        self._kwargs = kwargs
        self._episodes = []
        # Created here so that `append_step` / `end_rollout` are safe to call
        # even if `begin_rollout` was never invoked.
        self._rewards = []

    def begin_rollout(self):
        """Start a new episode by discarding the rewards collected so far."""
        self._rewards = []

    def append_step(self, obs, action, next_obs, reward, done, info):
        """Record the reward of one step; all other arguments are ignored."""
        self._rewards.append(reward)

    def end_rollout(self):
        """Close the current episode and store its length and summed reward."""
        self._episodes.append(
            dict(self._kwargs, episode_length=len(self._rewards), episode_reward=sum(self._rewards))
        )

    def to_df(self):
        """Return all episode records as a ``pandas.DataFrame``, one row per episode."""
        return pd.DataFrame(self._episodes)

In [None]:
# Choose how to start Ray depending on where the notebook is running.
if not os.path.isdir('/var/run/secrets/kubernetes.io'):
    # Not in a k8s pod: start Ray here.
    # Listen on all interfaces inside a container for port-forwarding to work.
    dashboard_host = "127.0.0.1"
    if os.path.exists("/.dockerenv"):
        dashboard_host = "0.0.0.0"
    args = {'num_cpus': 4, 'dashboard_host': dashboard_host}
else:
    # Inside a k8s pod: pass address='auto' instead of local resource settings.
    args = {'address': 'auto'}

init(**args)

In [None]:
from ray.rllib.rollout import rollout
from tqdm import tqdm

# Agents to compare. Names found in DEFENDERS are wrapped via `instantiate_agent`;
# all others are looked up in the RLlib trainer registry (see `generate`).
agent_types = ['R2D2', 'contrib/AlphaZero', 'rule-based', 'random']
# Random seeds; `generate` runs every agent once per seed and graph size.
seeds = [0, 1, 2, 3, 6, 7, 11, 28, 42, 1337]
# Number of `agent.train()` calls for agents that have no saved checkpoint.
training_iterations = 10
# Number of evaluation episodes per trained/instantiated agent.
rollouts = 10

def generate(savename):
    """Train (or restore) and evaluate every agent type on every graph size and seed.

    For each graph size in ``SIZES`` one ``AttackGraph`` is built and shared by
    all seeds and agent types.  Agent types found in ``DEFENDERS`` are wrapped
    with ``instantiate_agent`` and evaluated as-is.  Any other agent type is
    resolved through the RLlib trainer registry and is either restored from the
    checkpoint directory ``<agent>_<graph_size>_<seed>`` (relative to the
    working directory) or trained for ``training_iterations`` iterations and
    then saved there.  Every agent is rolled out for ``rollouts`` episodes.

    Args:
        savename: path the combined results are written to as CSV.

    Returns:
        pandas.DataFrame with one row per evaluation episode and the columns
        ``Agent``, ``Graph size`` (the graph's ``num_attacks``),
        ``Episode lengths`` and ``Returns``.
    """
    frames = []

    for graph_size in SIZES:
        graph = AttackGraph(dict(graph_size=graph_size))

        for seed in seeds:
            base_config = dict(
                framework='torch',
                model=dict(use_lstm=True),
                env=AttackSimulationEnv,
                env_config=dict(attack_graph=graph),
                seed=seed,
                create_env_on_driver=True,  # apparently, assumed by `rollout`
                num_workers=8, # for auto-scaling  # use 0 to run on driver for debugging
                batch_mode='complete_episodes',
            )
            for agent_type in agent_types:
                # Give each agent its own config (and its own `env_config`), so the
                # AlphaZero-specific overrides below cannot leak into the agent
                # types that are processed after it.
                config = dict(base_config, env_config=dict(base_config['env_config']))
                checkpoint_dir = None  # stays None for agents that need no training

                if agent_type in DEFENDERS:
                    agent = instantiate_agent(agent_type, config)
                else:
                    if agent_type == 'contrib/AlphaZero':
                        # AlphaZero runs on the wrapper, which builds the real env from `env_class`.
                        config['env_config'].update(env_class=config['env'])
                        config.update(env=AlphaZeroWrapper)
                    agent = rllib.agents.registry.get_trainer_class(agent_type)(config=config)
                    checkpoint_dir = f'{agent_type.split("/")[-1]}_{graph_size}_{seed}'

                try:
                    if checkpoint_dir is not None:
                        if os.path.exists(checkpoint_dir):
                            checkpoint_path = tune.utils.trainable.TrainableUtil.get_checkpoints_paths(checkpoint_dir).chkpt_path[0]
                            agent.restore(checkpoint_path)
                        else:
                            pbar = tqdm(range(training_iterations), f'{graph_size:13.13s} [{seed: 6d}] {agent_type:11.11s}')
                            for _ in pbar:
                                agent.train()
                                # TODO: break based on results?
                            agent.save(checkpoint_dir)

                    aggregator = RolloutAggregator(agent_type=agent_type, graph_size=graph.num_attacks)
                    rollout(agent, 'AttackSimulator', num_steps=0, num_episodes=rollouts, saver=aggregator)
                finally:
                    # Release the trainer's workers; otherwise every trainer created
                    # in this loop keeps its resources for the rest of the run.
                    agent.stop()

                frames.append(aggregator.to_df())

    df = pd.concat(frames, ignore_index=True).rename(columns=dict(agent_type='Agent', graph_size='Graph size', episode_length='Episode lengths', episode_reward='Returns'))
    df.to_csv(savename)
    return df

In [None]:
%%capture noise --no-stderr

savename = 'data.csv'

df = generate(savename) if not os.path.exists(savename) else pd.read_csv(savename, index_col=0)
df

In [None]:
# Data generation is done: call Ray's `shutdown` to undo the `init` above.
shutdown()

In [None]:
# Shared look for all figures below: dark grid theme, 12x8 figure size.
sns.set(style='darkgrid', rc={'figure.figsize': (12, 8)})

In [None]:
# Returns per graph size, one line per agent; `ci='sd'` requests a standard-deviation band.
g = sns.lineplot(
    x='Graph size',
    y='Returns',
    hue='Agent',
    ci='sd',
    data=df,
)
g.legend(loc='upper left', title='Agent')
g.set_title('Returns vs Size (random attacker)')

In [None]:
# Episode lengths per graph size, one line per agent; `ci='sd'` requests a standard-deviation band.
g = sns.lineplot(
    x='Graph size',
    y='Episode lengths',
    hue='Agent',
    ci='sd',
    data=df,
)
g.legend(loc='upper left', title='Agent')
g.set_title('Episode lengths vs Size (random attacker)')

In [None]:
# Allow up to 32 columns so the per-agent summary table is not truncated.
pd.options.display.max_columns = 32
# Descriptive statistics of every numeric column, grouped by agent.
df.groupby(by='Agent').describe()