In [1]:
%load_ext autoreload
%autoreload 2

# Imports
## Common libs
import pandas as pd
import numpy as np
import pickle
from itertools import cycle, chain
from copy import deepcopy
from datetime import datetime
from collections import OrderedDict

## Plotting
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter

## Graphs
from networkx.drawing import nx_agraph
import networkx as nx
from graphviz import Source
import pygraphviz

## GPs
from GPy.kern import RBF
from GPy.models.gp_regression import GPRegression

## DCBO Utils
import sys
sys.path.append("../DCBO")
from src.utils.utilities import powerset
from src.utils.sem_utils.sem_estimate import build_sem_hat
from src.utils.sequential_intervention_functions import get_interventional_grids
from src.experimental.experiments import optimal_sequence_of_interventions, run_methods_replicates

## Optimisation Algos
from src.methods.cbo import CBO
from src.methods.dcbo import DCBO
from src.methods.bo import BO

## YT
import random
from gym import spaces
from numpy import repeat
from stable_baselines3.common.env_checker import check_env

sys.path.append("../YAWNING-TITAN_fork")
from yawning_titan.integrations.dcbo.dcbo_agent import DCBOAgent
from yawning_titan.envs.generic.core.blue_interface import BlueInterface
from yawning_titan.envs.generic.core.network_interface import NetworkInterface
from yawning_titan.envs.generic.core.red_interface import RedInterface
from yawning_titan.envs.generic.generic_env import GenericNetworkEnv
from yawning_titan.envs.generic.helpers import network_creator

In [3]:
def create_env(use_same_net=False):
    """
    Helper function to create an environment.

    Args:
        use_same_net: If true uses a saved network, otherwise creates a new network

    Returns: The env

    """
    settings_path = "dcbo_config.yaml"

    if use_same_net:
        matrix, node_positions = network_creator.load_network("dcbo_base_net.txt")
    else:
        matrix, node_positions = network_creator.create_mesh(size=10)

    network_interface = NetworkInterface(
        matrix, node_positions, settings_path=settings_path
    )

    red = RedInterface(network_interface)
    blue = BlueInterface(network_interface)

    number_of_actions = blue.get_number_of_actions()

    env = GenericNetworkEnv(
        red,
        blue,
        network_interface,
        number_of_actions,
        # print_metrics=True,
        show_metrics_every=10,
        collect_additional_per_ts_data=True,
    )

    check_env(env, warn=True)

    env.reset()

    return env

# Creates all 50 envs
all_envs = [create_env() for i in range(10)]
agent = DCBOAgent(all_envs[0].action_space, [0.5, 0.5])

TIMESTEPS = 50
COSTS = {"restore_node": 1, "isolate": 1, "do_nothing": 0, "compromise": 10}

centralities = nx.degree_centrality(all_envs[0].network_interface.current_graph)
slice_data = {k: [] for k in "PICSHAT"}

for counter, current_env in enumerate(all_envs):
    # Init data
    slice_p, slice_i, slice_s, slice_h, slice_c, slice_a, slice_t = (list() for i in range(7))
    
    # Init probabilities
    current_env.reset()
    agent.reset()

    w1, w2 = (
        random.gauss(0.5, 0.167),
        random.gauss(0.5, 0.167),
    )
    agent.update_probabilities([w1, w2])

    for i in range(TIMESTEPS):
        done = False
        current_step = 0

        action = agent.predict("", "", "", current_env)
        env_observation, reward, done, notes = current_env.step(action)

        c_cost = (COSTS["compromise"] * sum(notes["end_state"].values())) ** 1.5
        action_cost = COSTS.get(notes["blue_action"], 0)
        surface = sum(
            map(
                lambda x: not x[0] and not x[1],
                zip(
                    notes["end_state"].values(),
                    current_env.network_interface.get_all_isolation().values(),
                ),
            )
        )
        # Append P, I and S
        slice_p.append(agent.probabilities[0])
        slice_i.append(agent.probabilities[1])
        slice_s.append(surface) 
        
        # Calculate and append H
        h = 0
        node_states = current_env.network_interface.get_all_node_compromised_states()
        node_iso = current_env.network_interface.get_all_isolation()
        comp_nodes = [k for k in node_states if node_states[k] == 1]
        isolated_nodes = [k for k in node_iso if node_iso[k]]
        for node in comp_nodes:
            conn_nodes = current_env.network_interface.get_current_connected_nodes(node)
            for c in conn_nodes:
                if c in isolated_nodes:
                    continue
                h += current_env.network_interface.get_single_node_vulnerability(c)
        slice_h.append(h)
        # Append C, A and T
        slice_c.append(c_cost)
        slice_a.append(action_cost)
        slice_t.append(action_cost + c_cost)

    slice_data["P"].append(np.asarray(slice_p))
    slice_data["I"].append(np.asarray(slice_i))
    slice_data["S"].append(np.asarray(slice_s))
    slice_data["H"].append(np.asarray(slice_h))
    slice_data["C"].append(np.asarray(slice_c))
    slice_data["A"].append(np.asarray(slice_a))
    slice_data["T"].append(np.asarray(slice_t))

for k in slice_data.keys():
    slice_data[k] = np.asarray(slice_data[k])

--Game over--
Total number of Games Played:  10
Stats over the last 10 games:
Average episode length:  14 

Blue Won    Red Won
----------  ---------
0           10
0.0%        100.0%


Action          Avg Times Used  Percentage of Action Usage
------------  ----------------  ----------------------------
isolate                     10  71.43%
connect                      2  14.29%
restore_node                 2  14.29%



--Game over--
Total number of Games Played:  20
Stats over the last 10 games:
Average episode length:  14 

Blue Won    Red Won
----------  ---------
0           10
0.0%        100.0%


Action          Avg Times Used  Percentage of Action Usage
------------  ----------------  ----------------------------
isolate                     10  71.43%
connect                      2  14.29%
restore_node                 2  14.29%



--Game over--
Total number of Games Played:  30
Stats over the last 10 games:
Average episode length:  14 

Blue Won    Red Won
----------  --------

In [4]:
slice_data

{'P': array([[0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ,
         0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 , 0.6133631 ],
        [0.64221776, 0.64221776, 0.64221776, 0.64221776, 0.64221776,
         0.64221776, 0.64221776, 0.64221776, 0.64221776, 0.64221776,
         0.64221776, 0.64221776, 0.64221776, 0.64221776, 0.64221776,
         0.64221776, 0.64221776, 0.64221776, 0.64221776, 0.64221776,
         0.64221776, 0.64221