In [None]:
import itertools
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import scipy.stats.distributions as distr
# from abc import ABC, abstractmethod
from collections import OrderedDict


# >> CyberSecurity <<
class Defence():
    
    def __init__(self, name:str, p:float=1.0, distribution:distr.rv_continuous=distr.uniform(loc=0.0, scale=0.0)) -> None:
        """
        Generic Cyber Security defence with an associated probability of success and a certain
        amount of time required to compromise it.
        Time spent attacking the defence is remembered, so repeated partial attempts add up and
        eventually reach the required time.

        Args:
            name (str): Name of this defence.
            p (float, optional): Probability that the defence can be compromised at all, given infinite time. Defaults to 1.0.
            distribution (optional): Frozen scipy distribution (anything with ``.rvs()``) of the time needed
                to break this defence. Defaults to instantaneous (uniform with scale 0).
        """
        super().__init__()
        self.name = name
        self.success_distr = distr.bernoulli(p=p)
        self.time_distr = distribution
        # Sampled once per instance: the total effort this particular defence requires.
        self.time_to_compromise = float(self.time_distr.rvs())
        # If 'can_be_successful' is False then no amount of effort can break this defence
        self.can_be_successful = bool(self.success_distr.rvs())
        self.is_compromised = False
        self.time_spent = 0.0  # Total time spent attacking this defence (accumulates across attempts)

    def attack(self, budget:float) -> tuple[bool, float]:
        """
        Attack this defence with a certain time budget available. Note that time is not an exact measure.

        Args:
            budget (float): Amount of time available to attack the defence. Negative budgets are treated as 0.

        Returns:
            tuple[bool, float]: Whether the defence is (now) broken, and how much of the budget was used.
        """
        budget = max(float(budget), 0.0)
        remaining = self.time_to_compromise - self.time_spent
        if remaining > budget:
            # Not enough time to finish: bank the effort so a later attempt continues from here.
            self.time_spent += budget
            return False, budget
        self.time_spent = self.time_to_compromise
        # Only a defence that can actually be broken is marked compromised; an unbreakable one
        # keeps reporting failure (at zero extra cost) on every subsequent attempt.
        self.is_compromised = self.can_be_successful
        return self.can_be_successful, remaining

class CommmonDefences():
    """
    Factories for common types of Communication Network cyber defences.

    Every call builds a new ``Defence``, which samples its own attack time and success
    outcome, so components never share defence state. Exponential attack times are given
    as ``scale = 1 / rate``; the mean attack time equals ``scale``.
    """

    @staticmethod
    def EasyAndCertain():
        """Always breakable; exponential attack time with mean 1.0."""
        return Defence("EasyAndCertain", p=1.0, distribution=distr.expon(scale=1/1.0))

    @staticmethod
    def EasyAndUncertain():
        """Breakable with probability 0.5; uses ``Defence``'s default (instantaneous) attack time."""
        return Defence("EasyAndUncertain", p=0.5)

    @staticmethod
    def HardAndCertain():
        """Always breakable; exponential attack time with mean 10.0."""
        return Defence("HardAndCertain", p=1.0, distribution=distr.expon(scale=1/0.1))

    @staticmethod
    def HardAndUncertain():
        """Breakable with probability 0.5; exponential attack time with mean 10.0."""
        return Defence("HardAndUncertain", p=0.5, distribution=distr.expon(scale=1/0.1))

    @staticmethod
    def VeryHardAndUncertain():
        """Breakable with probability 0.5; exponential attack time with mean 100.0."""
        return Defence("VeryHardAndUncertain", p=0.5, distribution=distr.expon(scale=1/0.01))

class Vulnerability:
    """
    Weakness attached to a cyber component. The intent is that a present and discovered
    vulnerability makes the component's defences easier to compromise.
    For now it only carries a name, and ``attack`` is an unimplemented stub.
    """

    def __init__(self, name:str) -> None:
        super().__init__()
        self.name: str = name

    def attack(self) -> (bool, float):
        # Not implemented yet: always yields None.
        return None

class CyberComponent():

    def __init__(self, is_accessible:bool, is_compromised:bool=False) -> None:
        """
        A Cyber Component is any electronic device that can be hacked. To attempt a hack, the
        attacker must reach it through an open or already compromised connection.
        The time taken to compromise the component depends on its defences and vulnerabilities.

        Args:
            is_accessible (bool): Whether the component is an entry point to the Communication Network
            is_compromised (bool, optional): Whether the component has been compromised. Defaults to False.
        """
        super().__init__()
        self.is_accessible = is_accessible
        self.is_compromised = is_compromised
        self.total_time_spent = 0.0
        self.defences = OrderedDict()  # name -> defence, attacked in insertion order
        self.vulnerabilities = OrderedDict()  # name -> vulnerability

    def attack(self, budget:float) -> tuple[bool, float]:
        """
        Attack this component with a certain time budget available. Note that time is not an exact measure.
        To successfully compromise the device, all defences must be broken, in insertion order.
        A component without defences is compromised immediately.

        Args:
            budget (float): Amount of time available to attack the component

        Returns:
            tuple[bool, float]: Whether the component is now compromised, and how much time was spent
        """
        # TODO: Must we really break all defences to compromise an asset?
        time_spent = 0.0
        is_successful = True
        for defence in self.defences.values():
            # Defences remember their own progress, so an already-broken defence reports its
            # real outcome at zero cost. We do not trust `is_compromised` as a shortcut.
            is_successful, toc = defence.attack(budget)
            is_successful = bool(is_successful)
            time_spent += toc
            budget -= toc
            # Could not get past this defence (unbreakable, or ran out of time): stop here.
            if not is_successful:
                break
        # Only mark compromised when every defence has actually been broken.
        if is_successful:
            self.is_compromised = True
        self.total_time_spent += time_spent
        return is_successful, time_spent
        
    def add_defence(self, defence:"Defence"):
        """Register ``defence`` under its name, replacing any defence with the same name."""
        self.defences[defence.name] = defence

    def remove_defence(self, defence:"Defence | str") -> "Defence":
        """Remove and return a defence, given either the defence itself or its name (KeyError if absent)."""
        name = defence if isinstance(defence, str) else defence.name
        return self.defences.pop(name)
    
    def add_vulnerability(self, vulnerability:"Vulnerability"):
        """Register ``vulnerability`` under its name, replacing any with the same name."""
        self.vulnerabilities[vulnerability.name] = vulnerability

    def remove_vulnerability(self, vulnerability:"Vulnerability | str") -> "Vulnerability":
        """Remove and return a vulnerability, given either the object or its name (KeyError if absent)."""
        name = vulnerability if isinstance(vulnerability, str) else vulnerability.name
        return self.vulnerabilities.pop(name)

# >> Nodes <<
class Node():

    id_iter = itertools.count()  # Shared counter: every Node (of any subclass) gets a unique id

    def __init__(self, parent:"Node | None"=None) -> None:
        """
        Generic Node in a graph. Can have exactly 1 parent, or none at all.
        Can have 0 or more children.
        Incoming and outgoing edges are stored separately, which allows for a directed graph.
        Hashable by its unique ``id``, which means it can be used as a node in NetworkX.

        Args:
            parent (Node | None, optional): Node which has this node as its child. Defaults to None.
        """
        super().__init__()
        self.id = next(self.id_iter)
        self.parent = parent
        self.children = []
        self.outgoing_edges = []
        self.incoming_edges = []

    def set_parent(self, parent):
        """Set ``parent`` and register this node among its children."""
        self.parent = parent
        self.parent.add_child(self)

    def add_child(self, *children):
        """Append one or more children (does not set their ``parent``)."""
        self.children.extend(children)

    def add_outgoing_edge(self, other, edge):
        """Attach ``edge`` as self -> other, registering it on both endpoints."""
        edge.source = self
        self.outgoing_edges.append(edge)
        # Let the other endpoint register the edge unless it already has (avoids infinite ping-pong).
        if edge.target is None:
            other.add_incoming_edge(self, edge)

    def add_incoming_edge(self, other, edge):
        """Attach ``edge`` as other -> self, registering it on both endpoints."""
        edge.target = self
        self.incoming_edges.append(edge)
        if edge.source is None:
            other.add_outgoing_edge(self, edge)

    def reset_edges(self):
        """Forget all incoming and outgoing edges of this node (the other endpoints keep theirs)."""
        self.outgoing_edges = []
        self.incoming_edges = []

    def reset_children(self):
        """Forget all children of this node."""
        self.children = []
    
    def __str__(self):
        # Subclasses in this file define a class-level `__name__`; a plain Node instance has none,
        # so fall back to the real class name instead of raising AttributeError.
        name = getattr(self, "__name__", type(self).__name__)
        return f"{name}(id={self.id})"

    def __hash__(self):
        # Hash on the immutable id so the hash agrees with __eq__ and never changes while the
        # node sits in a set/dict (the string repr includes mutable attributes in subclasses).
        return hash(self.id)

    def __eq__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.id == other.id

class Aggregator(CyberComponent, Node):

    __name__ = "Aggregator"

    def __init__(self, *args, **kwargs) -> None:
        """
        Communication network component that gathers data from one or more sources (its children).
        Being a CyberComponent, it can be hacked, which also undermines the reliability of
        the data flowing through it.
        All arguments are forwarded unchanged to the parent initialisers.
        """
        super().__init__(*args, **kwargs)

    def aggregate(self):
        """Placeholder for aggregation logic; currently does nothing."""
        return None

    def __str__(self):
        fields = f"id={self.id}, is_accessible={self.is_accessible}"
        return f"{self.__name__}({fields})"

class Device(CyberComponent, Node):

    __name__ = "Device"

    def __init__(self, is_controller:bool, is_sensor:bool, *args, **kwargs) -> None:
        """
        End device of the communication network: it collects data and/or acts in the real world.
        Being a CyberComponent, it can be hacked, which affects how far its emitted data can be trusted.

        Args:
            is_controller (bool): Whether the device controls a real-world object, such as the power output of a battery
            is_sensor (bool): Whether the device measures a real-world object, such as the state of charge of a battery
            *args, **kwargs: Forwarded to the parent initialisers (e.g. ``is_accessible``, ``is_compromised``).
        """
        super().__init__(*args, **kwargs)
        self.is_controller, self.is_sensor = is_controller, is_sensor

    def collect(self):
        """Placeholder for data collection; currently does nothing."""
        return None

    def act(self):
        """Placeholder for acting on the real world; currently does nothing."""
        return None

    def __str__(self):
        fields = ", ".join([
            f"id={self.id}",
            f"is_controller={self.is_controller}",
            f"is_sensor={self.is_sensor}",
            f"is_accessible={self.is_accessible}",
        ])
        return f"{self.__name__}({fields})"

# >> Edges <<
class Link():

    def __init__(self, source:"Node | None"=None, target:"Node | None"=None) -> None:
        """
        Generic directed Edge in a Graph, with a source Node and a target Node. Can also have
        attributes associated with it.
        Hashable only by identity (no custom __eq__/__hash__); use to_edge() to get keyword
        arguments in the NetworkX ``add_edge`` format.

        Args:
            source (Node, optional): The Node where this edge starts. Defaults to None
                (filled in later by ``Node.add_outgoing_edge``).
            target (Node, optional): The Node where this edge ends. Defaults to None
                (filled in later by ``Node.add_incoming_edge``).
        """
        super().__init__()
        self.source = source
        self.target = target
        self.attributes = {}

    def to_edge(self):
        """Return ``dict(u_of_edge=source, v_of_edge=target, attr=self)``."""
        return dict(u_of_edge=self.source, v_of_edge=self.target, attr=self)

class WiredLink(Link):

    def __init__(self, *args, **kwargs) -> None:
        """
        Physical communication connection between 2 nodes, such as through fibre optic cables.
        Positional and keyword arguments (``source``, ``target``) are forwarded to ``Link``.
        """
        super().__init__(*args, **kwargs)

class WirelessLink(Link):

    def __init__(self, *args, **kwargs) -> None:
        """
        Wireless communication connection between 2 nodes, such as through radio waves.
        Positional and keyword arguments (``source``, ``target``) are forwarded to ``Link``.
        """
        super().__init__(*args, **kwargs)

In [None]:
class CommNetwork(object):

    def __init__(self, family_size=3, max_family_size_deviation=2, n_devices=20, 
                 n_entrypoints=3,
                 controller_prob=0.3, sensor_prob=0.9):
        """
        Procedurally generated communication network composed of Devices and Aggregators.
        Each aggregator has a certain number of children (devices, or other aggregators).
        Entry points are components that attackers can use to try and compromise the
        network, such as a remote connection to a substation. The control center (root)
        is never an entry point.

        Args:
            family_size (int, optional): Average number of children per aggregator. Defaults to 3.
            max_family_size_deviation (int, optional): Maximum random deviation (either direction) of
                the number of children per aggregator from ``family_size``. Defaults to 2.
            n_devices (int, optional): Number of end-devices which collect data or execute commands. Defaults to 20.
            n_entrypoints (int, optional): Number of entrypoints for attackers. Defaults to 3.
            controller_prob (float, optional): Probability that a device is a controller. Defaults to 0.3.
            sensor_prob (float, optional): Probability that a device is a sensor. Defaults to 0.9.
        """
        self.family_size = family_size
        self.max_family_size_deviation = max_family_size_deviation
        self.n_devices = n_devices
        self.n_entrypoints = n_entrypoints
        self.controller_prob = controller_prob
        self.sensor_prob = sensor_prob

        # Generate Communication Network (Procedurally)
        self.n_components = 0  # Incremented by the build helpers as components are created
        self.root = self.build_network()
        self.entrypoints = []
        self.add_entrypoints()
        self.graph = nx.DiGraph()
        self.build_graph(self.root)

    def build_network(self, components:list[Device|Aggregator]|None=None):
        """
        Build the Communication Network bottom-up from Devices and Aggregators.

        Starting from the given components (or freshly generated devices when none are given),
        repeatedly groups the current level under new Aggregators until at most ``family_size``
        components remain. Those are then attached to a single root Aggregator.

        Args:
            components (list[Device|Aggregator] | None, optional): Components at the lowest level.
                When None or empty, ``n_devices`` new Devices are generated. Defaults to None.

        Returns:
            Aggregator: Root node of tree (represents control center)
        """
        if not components:
            components = self._generate_devices()
        # max(1, ...) guarantees termination even for family_size <= 0.
        while len(components) > max(1, self.family_size):
            components = self._group_into_aggregators(components)
        return self._build_root(components)

    def _generate_devices(self):
        """Create ``n_devices`` leaf Devices with random sensor/controller roles."""
        devices = []
        for _ in range(self.n_devices):
            is_sensor = bool(np.random.choice([True, False], p=[self.sensor_prob, 1-self.sensor_prob]))
            is_controller = bool(np.random.choice([True, False], p=[self.controller_prob, 1-self.controller_prob]))
            if not is_controller and not is_sensor:
                # Every device must have a role: give it the more likely one (sensor on ties).
                is_sensor = self.sensor_prob >= self.controller_prob
                is_controller = not is_sensor
            device = Device(is_controller=is_controller,
                            is_sensor=is_sensor,
                            is_accessible=False)
            device.add_defence(CommmonDefences.EasyAndUncertain())
            devices.append(device)
        self.n_components += len(devices)
        return devices

    def _draw_family_size(self):
        """Draw a group size uniformly from family_size ± max_family_size_deviation (inclusive), at least 2."""
        deviation = abs(int(self.max_family_size_deviation))
        # numpy's `high` is exclusive, hence the +1 to make the range symmetric.
        offset = np.random.randint(-deviation, deviation + 1)
        # At least 2 children per aggregator so each level is strictly smaller than the one below.
        return max(2, self.family_size + offset)

    def _group_into_aggregators(self, components):
        """Split ``components`` into consecutive groups, each placed under a new Aggregator."""
        aggregators = []
        start = 0
        while start < len(components):
            group_size = self._draw_family_size()
            aggregator = Aggregator(is_accessible=False)
            aggregator.add_defence(CommmonDefences.HardAndUncertain())
            for component in components[start:start + group_size]:
                component.set_parent(aggregator)
                CommNetwork.connect_by_edges(aggregator, component)
            aggregators.append(aggregator)
            start += group_size
        self.n_components += len(aggregators)
        return aggregators

    def _build_root(self, components):
        """Create the control-center Aggregator and attach ``components`` as its children."""
        root = Aggregator(is_accessible=False)
        root.add_defence(CommmonDefences.VeryHardAndUncertain())
        for component in components:
            component.set_parent(root)
            CommNetwork.connect_by_edges(root, component)
        self.n_components += 1
        return root
    
    def add_entrypoints(self):
        """
        Randomly add entry points to aggregators or devices in the network. Excludes control center / root.
        """
        # Walk-order index 0 is the root, so sample only from 1 .. n_components-1.
        accessible_components = np.random.choice(np.arange(1, self.n_components), min(self.n_components - 1, self.n_entrypoints), replace=False)
        self.walk_and_set_entrypoints(self.root, idcs_to_match=accessible_components, idx=0)
    
    def build_graph(self, root:Aggregator):
        """
        Add ``root``, its descendants and all their edges to ``self.graph``.

        Args:
            root (Aggregator): Root node of the Communication Network (e.g. the Control Center)
        Returns:
            networkx.DiGraph: ``self.graph``, with the added nodes/edges
        """
        self.graph.add_node(root)
        for child in root.children:
            self.build_graph(child)
            # Every edge connects a child with its parent, so the children's edge lists cover the tree.
            for edge in child.outgoing_edges:
                self.graph.add_edge(edge.source, edge.target)
            for edge in child.incoming_edges:
                self.graph.add_edge(edge.source, edge.target)
        return self.graph
    
    def walk_and_set_entrypoints(self, root:Aggregator, idcs_to_match:np.ndarray, idx:int=0):
        """
        Walk the tree in pre-order and mark nodes whose walk index is in ``idcs_to_match``
        as entry points (``is_accessible = True``, appended to ``self.entrypoints``).

        Args:
            root (Aggregator): Node at which to (continue to) walk.
            idcs_to_match (np.ndarray): Walk-order indices of the nodes to mark.
            idx (int, optional): Walk-order index of ``root``. Defaults to 0.

        Returns:
            int: Walk-order index of the last node visited in this subtree
        """
        if idx in idcs_to_match:
            root.is_accessible = True
            self.entrypoints.append(root)
        for child in root.children:
            idx += 1
            idx = self.walk_and_set_entrypoints(child, idcs_to_match, idx=idx)
        return idx

    @staticmethod
    def connect_by_edges(parent:Device|Aggregator, child:Device|Aggregator):
        """
        Adds one-way or two-way communication edges depending on whether child is a sensor and/or controller or aggregator.
        Sensors send data up (child -> parent), controllers receive commands (parent -> child),
        and aggregators do both.

        Args:
            parent (Device|Aggregator): Component 1 level up in the communication hierarchy
            child (Device|Aggregator): Component below parent in communication hierarchy
        """
        # Endpoints are passed explicitly as None; Node.add_*_edge fills them in.
        if isinstance(child, Device):
            if child.is_sensor:
                parent.add_incoming_edge(child, WiredLink(None, None))
            if child.is_controller:
                parent.add_outgoing_edge(child, WiredLink(None, None))
        else:
            parent.add_incoming_edge(child, WiredLink(None, None))
            parent.add_outgoing_edge(child, WiredLink(None, None))

    @staticmethod
    def show_tree(root:Aggregator, s:str="", depth:int=0):
        """
        Recursively prints out structure of communication network using whitespace to denote deeper components.

        Args:
            root (Aggregator): Node whose subtree is rendered.
            s (str, optional): Text rendered so far; new lines are appended to it. Defaults to "".
            depth (int, optional): Indentation level of ``root`` (3 spaces per level). Defaults to 0.

        Returns:
            str: String representing the network architecture
        """
        s += f"{depth*'   '}{root}\n"
        for child in root.children:
            s = CommNetwork.show_tree(child, s=s, depth=depth+1)
        return s
        

In [None]:
# Generate a small network, then print its hierarchy and size.
pcn = CommNetwork(n_devices=15, n_entrypoints=3)
root, tree = pcn.root, pcn.graph
print(CommNetwork.show_tree(root))
print(f"Number of Components: {pcn.n_components}")

In [None]:
def next_available_nodes(current_node:Node):
    """
    Return the direct neighbours reachable over this node's outgoing edges.

    Only a compromised node yields neighbours: an attacker needs write access
    (an outgoing connection) to push a payload onward.

    Args:
        current_node (Node): The active node

    Returns:
        set[Node]: Targets of all outgoing edges, or an empty set if the node is not compromised
    """
    if not current_node.is_compromised:
        return set()
    return {edge.target for edge in current_node.outgoing_edges}

def random_walk_with_budget(current_node:Node, time_available:float, nodes_compromised:set[Node]|None=None, max_can_compromise:int=1, max_steps:int=10_000):
    """
    Randomly walk the graph from ``current_node``, trying to compromise every component visited.

    The walk only moves along outgoing edges of compromised nodes (see ``next_available_nodes``).
    Stopping criteria:
    * Reached a dead-end (no outgoing edges to follow, or the current node could not be compromised)
    * Ran out of time to compromise devices
    * Compromised ``max_can_compromise`` devices
    * Visited ``max_steps`` nodes. Moving between already-compromised nodes costs no time, so
      without this cap a walk confined to compromised nodes would never end.

    Args:
        current_node (Node): The node where the walk starts.
        time_available (float): Time available to try and compromise Nodes.
        nodes_compromised (set[Node] | None, optional): Set of nodes compromised so far; it is
            updated in place. A new set is created when None. Defaults to None.
        max_can_compromise (int, optional): Maximum no. of components that can be compromised. Defaults to 1.
        max_steps (int, optional): Maximum number of nodes visited by the walk. Defaults to 10_000.

    Returns:
        set[Node]: ``nodes_compromised`` (or the new set), containing every compromised node.
    """
    if nodes_compromised is None:
        # Fresh set per call: a shared default set would leak results between unrelated walks.
        nodes_compromised = set()

    node = current_node
    for _ in range(max_steps):
        # Try to compromise the current node (free if it is already compromised).
        if node.is_compromised:
            is_successful, time_spent = True, 0.0
        else:
            is_successful, time_spent = node.attack(time_available)
        time_available -= time_spent
        if is_successful:
            nodes_compromised.add(node)

        # Out of time, or compromised everything we may: stop.
        if time_available <= 0 or len(nodes_compromised) >= max_can_compromise:
            break
        candidates = list(next_available_nodes(node))
        if not candidates:
            break
        node = candidates[np.random.randint(len(candidates))]
    return nodes_compromised

def attack_network(comm_network:CommNetwork, budget:float=10.0):
    """
    Attack network from all entrypoints at the same time. Each entry point starts with the same attack budget.

    Args:
        comm_network (CommNetwork): Procedurally generated Communication Network
        budget (float, optional): Time available to compromise nodes, starting at the entry point. Defaults to 10.0.

    Returns:
        set[Node]: All components compromised by any of the walks.
    """
    n_components = comm_network.n_components
    nodes_compromised = set()
    for entrypoint in comm_network.entrypoints:
        # Every entry point gets its own, full budget; results accumulate in one shared set.
        nodes_compromised.update(random_walk_with_budget(entrypoint, budget, nodes_compromised=nodes_compromised, max_can_compromise=n_components))
    return nodes_compromised

In [None]:
attack_network(comm_network=pcn, budget=52)  # every entry point attacks with a budget of 52

In [None]:
def hierarchy_pos(G:nx.Graph, root:Node|None, width:float=1., vert_gap:float=0.2, vert_loc:float=0, xcenter:float=0.5):

    '''
    Compute node positions that draw a tree as a top-down hierarchy.

    Credit: Joel (https://stackoverflow.com/a/29597209/2966723)
    Licensed under CC Attribution-Share Alike

    Each node's horizontal slot is its parent's width divided evenly among the parent's
    children; each level sits ``vert_gap`` below the previous one.

    Args:
        G (networkx.Graph): Graph to lay out; must be a tree (checked with ``nx.is_tree``).
        root (Node | None): Node placed at the top. If None, a DiGraph uses the first node of a
            topological sort; any other graph uses a randomly chosen node.
        width (float): Horizontal space allocated for this branch - avoids overlap with other branches. Defaults to 1.0
        vert_gap (float): Gap between levels of hierarchy. Defaults to 0.2
        vert_loc (float): Vertical location of root. Defaults to 0.0
        xcenter (float): Horizontal location of root. Defaults to 0.5

    Returns:
        dict: Mapping of node -> (x, y), usable as the ``pos`` argument of the networkx drawing functions.

    Raises:
        TypeError: If ``G`` is not a tree.
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  #allows back compatibility with nx version 1.11
        else:
            root = np.random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5, pos = None, parent = None):
        '''
        Recursively place ``root`` and its subtree; see hierarchy_pos docstring for most arguments.

        pos: dict of node -> (x, y) for nodes already placed (created on the first call)
        parent: parent of this branch; only used for undirected graphs, where it must be
            excluded from the neighbours so the recursion descends instead of going back up.

        Returns the (shared) ``pos`` dict.
        '''
    
        if pos is None:
            pos = {root:(xcenter,vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)  
        if len(children)!=0:
            # Split this branch's width evenly; start half a slot left of the branch's left edge
            # so that the first `+= dx` lands on the centre of the first child's slot.
            dx = width/len(children) 
            nextx = xcenter - width/2 - dx/2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G,child, width = dx, vert_gap = vert_gap, 
                                    vert_loc = vert_loc-vert_gap, xcenter=nextx,
                                    pos=pos, parent = root)
        return pos

            
    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

# Per-node / per-edge style arrays, aligned with the iteration order of tree.nodes() / tree.edges().
node_color_mask, node_edge_color_mask, node_shape_mask = (
    np.full(tree.number_of_nodes(), fill_value=fill, dtype=object)
    for fill in ("#1f78b4", "#000000", "s")
)
edge_color_mask = np.full(tree.number_of_edges(), fill_value="#000000", dtype=object)
root_idx = None

for i, node in enumerate(tree.nodes()):
    is_leaf = not node.children
    if node.parent is None:
        # The only parentless node is the root of the tree (the control center): red.
        root_idx = i
        node_color_mask[i] = "red"
    elif is_leaf:
        # Devices: dark green when the device is an entry point.
        node_color_mask[i] = "green" if node.is_accessible else "lightgreen"
    else:
        # Aggregators: dark blue when the aggregator is an entry point.
        node_color_mask[i] = "#1f78b4" if node.is_accessible else "#1f98ff"
    if node.is_compromised:
        # Hacked nodes get a red outline.
        node_edge_color_mask[i] = "#ff0000"

for j, (start_node, end_node) in enumerate(tree.edges()):
    # A channel is compromised when both of its endpoints are.
    if start_node.is_compromised and end_node.is_compromised:
        edge_color_mask[j] = "#ff0000"

# >> Plotting <<
fig, axes = plt.subplots(nrows=1, ncols=2,  figsize=(24,6), width_ratios=[0.6, 0.4])
label_map = {node: node.id for node in tree.nodes()}

# Left: hierarchical / tree layout. Right: spring layout.
tree_pos = hierarchy_pos(nx.to_undirected(tree), root)
spring_pos = nx.layout.spring_layout(tree)
for ax, pos in ((axes[0], tree_pos), (axes[1], spring_pos)):
    nx.draw_networkx_nodes(tree, pos=pos, ax=ax,
                           node_size=500, node_shape="s", node_color=node_color_mask,
                           linewidths=1.0, edgecolors=node_edge_color_mask)
    nx.draw_networkx_labels(tree, pos=pos, labels=label_map, ax=ax, font_size=10)
    nx.draw_networkx_edges(tree, pos=pos, ax=ax, edge_color=edge_color_mask)
plt.tight_layout()
plt.show()

In [None]:
import scipy.stats.distributions as distr

# Candidate (unfrozen) scipy distributions for modelling defences; parameters are given when freezing.
distr_lookup = {
    "TruncNorm": distr.truncnorm,  # Continuous; a, b = clip bounds in std-devs from loc, loc = mean, scale = std dev
    "Exponential": distr.expon,  # Continuous; scale = 1 / lambda (= mean time)
    "Gamma": distr.gamma,  # Continuous; a = shape parameter (positive float), scale = 1 / rate
    "Bernoulli": distr.bernoulli,  # Discrete; p = probability of outcome 1 (success)
}
n_attacks = 20
is_successful = distr.bernoulli(0.5).rvs(size=n_attacks).astype(bool)
# NOTE: scale=0.0 is degenerate: every sampled time is exactly 0. Use scale > 0 for real attack times.
time_taken = distr.expon(scale=0.0).rvs(size=n_attacks)[is_successful]
print(f"Successful Attacks {is_successful.sum()}/{n_attacks}\nTime Taken per Successful Attack: {time_taken}")