In [None]:
import itertools
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import scipy.stats.distributions as distr

from cyber import Defence, CommmonDefences, CyberComponent
from tree import TreeNode, Link
from comm_network import Aggregator, Device, CommNetwork

In [None]:
# Seed NumPy's global RNG before anything stochastic runs, so that the steps in this
# notebook that draw from it (np.random.choice in the attacker, scipy's .rvs sampling,
# networkx's spring layout) give the same results on every Restart & Run All.
# NOTE(review): CommNetwork's procedural generation may draw from a different RNG
# (e.g. the stdlib `random` module) -- confirm, and seed that one as well if so.
SEED = 42
np.random.seed(SEED)

# Build the communication network and keep handles to its root node and graph,
# which the attack and plotting cells below rely on.
pcn = CommNetwork(n_devices=15, n_entrypoints=3)
root = pcn.root
print(CommNetwork.show_tree(root))
tree = pcn.graph
print(f"Number of Components: {pcn.n_components}")

In [None]:
from abc import ABC, abstractmethod
from collections import deque
from math import isclose

class Attacker(ABC):
    """
    Abstract base class for an adversary that attacks a communication network.

    Subclasses must implement `attack_network`. Deriving from `ABC` is what makes
    `@abstractmethod` enforceable: without it the decorator is silently ignored and
    the incomplete base class could be instantiated.
    """
    
    def __init__(self, budget:float, verbose:bool=False):
        """
        Args:
            budget (float): Time available to compromise nodes, starting at the entry point.
            verbose (bool, optional): Whether to print out attack steps.
                Defaults to False
        """
        self.budget = budget
        # Initialised empty; nothing in this class writes to it
        self.path = []
        self.verbose = verbose
    
    @abstractmethod
    def attack_network(self, comm_network:"CommNetwork") -> None:
        """
        Carry out the attack against the given communication network.

        Args:
            comm_network (CommNetwork): The network to attack.
        """

class RandomAttacker(Attacker):
    """
    Attacker without a strategy: from each entry point it performs a random walk
    along outgoing edges, trying to compromise every node it lands on until its
    time budget is spent.
    """

    @staticmethod
    def next_available_nodes(current_node:TreeNode):
        """
        Finds all direct neighbours of this node that we can 
        reach through an outgoing connection / edge.

        Args:
            current_node (Node): The active node

        Returns:
            set[Node]: Unique set of all outgoing Nodes. Empty if the active
                node has not been compromised.
        """
        if not current_node.is_compromised:
            return set()
        # Can only attack through outgoing connections
        # I.e. need ability to write data to send payload
        return {outgoing_edge.target for outgoing_edge in current_node.outgoing_edges}

    def random_walk_with_budget(self, current_node:TreeNode, time_available:float, 
                                attack_path:deque=None, nodes_compromised:set[TreeNode]=None, 
                                max_can_compromise:int=1):
        """
        Recursively walk a graph, trying to compromise any components / nodes we come across.
        Stopping criterion:
        * Reached a dead-end (no outgoing edges to follow)
        * Ran out of time to compromise devices
        * Compromised all devices in the network

        Args:
            current_node (Node): The active node
            time_available (float): Time available to try and compromise Nodes.
            attack_path (deque, optional): Ids of the nodes visited so far, in order. 
                Extended in place. Defaults to a new empty deque.
            nodes_compromised (set[Node], optional): Unique set of nodes in the communication network that 
                have been compromised so far. Extended in place. Defaults to a new empty set.
            max_can_compromise (int, optional): Maximum no. of components that can be compromised. Defaults to 1.

        Returns:
            tuple[deque, set[Node]]: The attack path (node ids, plus a trailing "Dead End" marker 
                if the walk got stuck) and the unique set of communication network components 
                that have been compromised.
        """
        # Fresh containers per call: mutable default arguments are created once at 
        # definition time and would leak state between independent walks.
        if attack_path is None:
            attack_path = deque()
        if nodes_compromised is None:
            nodes_compromised = set()

        attack_path.append(current_node.id)
        # Try to Compromise current node
        if current_node.is_compromised:
            is_successful, time_spent = True, 0
        else:
            is_successful, time_spent = current_node.attack(time_available)
        # Lose time spent trying to break this node
        time_available -= time_spent
        if is_successful:
            nodes_compromised.add(current_node)
        # Still have time available, and haven't compromised entire network yet
        if time_available > 0 and len(nodes_compromised) < max_can_compromise:
            available_nodes = RandomAttacker.next_available_nodes(current_node)
            # Don't revisit nodes that were unsuccessfully attacked (not worth attacking)
            next_nodes = [node for node in available_nodes
                          if node.is_worth_attacking() or node.is_compromised]
            if len(next_nodes) > 0:
                next_node = np.random.choice(next_nodes)
                # NOTE(review): stepping onto an already-compromised node costs no time, so if 
                # every reachable node is compromised while others remain out of reach, the budget 
                # never runs out and this recursion ends in a RecursionError -- confirm whether 
                # generated networks can produce that situation.
                # Both containers are updated in place by the recursive call, so there is 
                # no need to merge (and copy) the compromised set at every level.
                attack_path, nodes_compromised = \
                    self.random_walk_with_budget(next_node, time_available,
                                                 attack_path=attack_path,
                                                 nodes_compromised=nodes_compromised,
                                                 max_can_compromise=max_can_compromise)
            else:
                attack_path.append("Dead End")
    
        # If we've compromised all nodes, or have run out of time, stop.
        return attack_path, nodes_compromised
    
    def attack_network(self, comm_network:CommNetwork):
        """
        Randomly attack network from all entrypoints at the same time. 
        There is no coordinated strategy behind this attacker, it wanders through the
        communication network without regard for the position or importance of components.
        Each entry point starts with the same attack budget.

        Args:
            comm_network (CommNetwork): Procedurally generated Communication Network
        """
        n_components = comm_network.n_components
        # Shared across entry points, so later walks know what has already been compromised
        nodes_compromised = set()
        for entrypoint in comm_network.entrypoints:
            attack_path, nodes_compromised = \
                self.random_walk_with_budget(entrypoint, self.budget,
                                             attack_path=deque(),
                                             nodes_compromised=nodes_compromised,
                                             max_can_compromise=n_components)
            if self.verbose:
                print(f"Attack Path:\n{' --> '.join([str(elt) for elt in attack_path])}")

# The budget is the time available for compromising nodes; attack_network hands the
# full amount to every entry point. verbose=True prints each attack path.
attacker = RandomAttacker(
    budget=52,
    verbose=True,
)

In [None]:
# Launch the random walk from every entry point of `pcn`. The plotting cell further
# down reads each node's `is_compromised` flag to highlight the outcome.
attacker.attack_network(pcn)

In [None]:
def hierarchy_pos(G:nx.DiGraph, root:TreeNode, width:float=1., vert_gap:float=0.2, vert_loc:float=0, xcenter:float=0.5):

    '''
    Credit: Joel (https://stackoverflow.com/a/29597209/2966723) 
    Licensed under CC Attribution-Share Alike 
    
    Compute node positions for drawing a tree in a hierarchical (top-down) layout.
    Every node splits the horizontal space it was given evenly among its children,
    and each level sits `vert_gap` below its parent.
    
    G (networkx.DiGraph): Graph (must be a tree)
    root (Node): Root node of current graph. If None, the first node in topological
        order is used for a directed graph, or a random node for an undirected one.
    width (float): Horizontal space allocated for this branch - avoids overlap with other branches. Defaults to 1.0
    vert_gap (float): Gap between levels of hierarchy. Defaults to 0.2
    vert_loc (float): Vertical location of root. Defaults to 0.0
    xcenter (float): Horizontal location of root. Defaults to 0.5

    Returns a dict mapping each node to its (x, y) position.
    Raises TypeError if G is not a tree.
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    directed = isinstance(G, nx.DiGraph)
    if root is None:
        # allows back compatibility with nx version 1.11
        root = next(iter(nx.topological_sort(G))) if directed else np.random.choice(list(G.nodes))

    positions = {}

    def place(node, span, x, y, came_from):
        '''
        Record the position of `node`, then lay out its subtree inside the
        horizontal band of width `span` centred on `x`.

        came_from: the node we arrived from - only matters for undirected graphs,
            where it also shows up among the neighbours and must be skipped.
        '''
        positions[node] = (x, y)
        kids = list(G.neighbors(node))
        if not directed and came_from is not None:
            kids.remove(came_from)
        if not kids:
            return
        slot = span / len(kids)
        # Start half a slot left of the band so that the first increment
        # lands on the centre of the first child's slot.
        cursor = x - span/2 - slot/2
        for kid in kids:
            cursor += slot
            place(kid, slot, cursor, y - vert_gap, node)

    place(root, width, xcenter, vert_loc, None)
    return positions

# >> Styling <<
# One entry per node / edge, pre-filled with the default look and overridden below.
n_nodes = tree.number_of_nodes()
n_edges = tree.number_of_edges()
node_color_mask = np.full(n_nodes, fill_value="#1f78b4", dtype=object)
node_edge_color_mask = np.full(n_nodes, fill_value="#000000", dtype=object)
edge_color_mask = np.full(n_edges, fill_value="#000000", dtype=object)
# Not passed to the drawing calls in this cell (they use node_shape="s" directly)
node_shape_mask = np.full(n_nodes, fill_value="s", dtype=object)
root_idx = None

for i, node in enumerate(tree.nodes()):
    is_leaf = node.children == []
    accessible = node.is_accessible
    if is_leaf:
        # Leaf Node (Device): dark green when it is an entry point
        fill = "green" if accessible else "lightgreen"
    else:
        # Internal Node (Aggregator): dark blue when it is an entry point
        fill = "#1f78b4" if accessible else "#1f98ff"
    if node.parent is None:
        # No parent => root of the Tree (the control center); overrides the colours above
        root_idx, fill = i, "red"
    node_color_mask[i] = fill
    if node.is_compromised:
        # Compromised/hacked nodes have a red outline around them
        node_edge_color_mask[i] = "#ff0000"

for j, endpoints in enumerate(tree.edges()):
    # A communication channel counts as compromised only if both of its ends are
    if all(endpoint.is_compromised for endpoint in endpoints):
        edge_color_mask[j] = "#ff0000"

# >> Plotting <<
fig, axes = plt.subplots(nrows=1, ncols=2,  figsize=(24,6), width_ratios=[0.6, 0.4])
label_map = {n: n.id for n in tree.nodes()}

# Left panel: hierarchical / tree layout. Right panel: spring layout.
tree_pos = hierarchy_pos(nx.to_undirected(tree), root)
spring_pos = nx.layout.spring_layout(tree)

node_style = dict(node_size=500, node_shape="s", node_color=node_color_mask,
                  linewidths=1.0, edgecolors=node_edge_color_mask)
for ax, pos in zip(axes, (tree_pos, spring_pos)):
    nx.draw_networkx_nodes(tree, pos=pos, ax=ax, **node_style)
    nx.draw_networkx_labels(tree, pos=pos, labels=label_map, ax=ax, font_size=10)
    nx.draw_networkx_edges(tree, pos=pos, ax=ax, edge_color=edge_color_mask)
plt.tight_layout()
plt.show()

In [None]:
import scipy.stats.distributions as distr
# Distributions of interest, keyed by a readable name.
distr_lookup = dict(
    # Continuous; a, b = truncation bounds in standard units, loc / scale = mean / std of the untruncated normal
    TruncNorm=distr.truncnorm,
    # Continuous; scale = 1 / lambda (float)
    Exponential=distr.expon,
    # Continuous; a = shape parameter (positive float)
    Gamma=distr.gamma,
    # Discrete; p = probability of success
    Bernoulli=distr.bernoulli,
)
n_attacks = 20

# One fair coin flip per attack decides whether it succeeds
success_distribution = distr.bernoulli(0.5)
is_successful = success_distribution.rvs(size=n_attacks).astype(bool)

# Draw a duration for every attack, then keep only those of the successful ones.
# NOTE(review): scale=0.0 is degenerate (scale = 1 / lambda should be positive), so the
# sampled times carry no information -- presumably a placeholder; confirm the intended rate.
duration_distribution = distr.expon(scale=0.0)
time_taken = duration_distribution.rvs(size=n_attacks)[is_successful]

n_successful = sum(is_successful)
print(f"Successful Attacks {n_successful}/{n_attacks}\nTime Taken per Successful Attack: {time_taken}")