In [None]:
import itertools
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
# from abc import ABC, abstractmethod
from collections import OrderedDict


# >> CyberSecurity <<
class Defence():
    
    def __init__(self, name:str) -> None:
        super().__init__()
        self.name = name

class Vulnerability():

    def __init__(self, name:str) -> None:
        super().__init__()
        self.name = name

class CyberComponent():

    def __init__(self) -> None:
        super().__init__()
        self.defences = OrderedDict()
        self.vulnerabilities = OrderedDict()

    def add_defence(self, defence:Defence):
        self.defences[defence.name] = defence

    def remove_defence(self, defence:Defence|str) -> Defence:
        name = defence.name if isinstance(defence, Defence) else defence
        return self.defences.pop(name)
    
    def add_vulnerability(self, vulnerability:Vulnerability):
        self.vulnerabilities[vulnerability.name] = vulnerability

    def remove_vulnerability(self, vulnerability:Vulnerability|str) -> Vulnerability:
        name = vulnerability.name if isinstance(vulnerability, Vulnerability) else vulnerability
        return self.vulnerabilities.pop(name)

# >> Nodes <<
class Node:
    """Tree node that also records directed edges to and from other nodes.

    Every instance takes a unique integer ``id`` from the class-wide
    ``id_iter`` counter. Equality and hashing are both based on that id.
    Edge objects passed to the ``add_*_edge`` methods must expose writable
    ``source`` and ``target`` attributes.
    """

    # Shared by all instances (including subclasses) so ids never collide.
    id_iter = itertools.count()

    def __init__(self, parent: "Node | None" = None) -> None:
        """Create a node with no children and no edges.

        ``parent`` is only stored; it is NOT added to the parent's
        ``children``. Use ``set_parent`` to link both directions.
        """
        super().__init__()
        self.id = next(self.id_iter)
        self.parent = parent
        self.children = []
        self.outgoing_edges = []
        self.incoming_edges = []

    def set_parent(self, parent):
        """Set ``parent`` and register this node as one of its children."""
        self.parent = parent
        self.parent.add_child(self)

    def add_child(self, *children):
        """Append one or more nodes to ``children`` (no de-duplication)."""
        self.children.extend(children)

    def add_outgoing_edge(self, other, edge):
        """Record ``edge`` as leaving this node.

        If the edge has no target yet, ``other`` is asked to register it as
        an incoming edge, which fills in ``edge.target``.
        """
        edge.source = self
        self.outgoing_edges.append(edge)
        if edge.target is None:
            other.add_incoming_edge(self, edge)

    def add_incoming_edge(self, other, edge):
        """Record ``edge`` as arriving at this node.

        If the edge has no source yet, ``other`` is asked to register it as
        an outgoing edge, which fills in ``edge.source``.
        """
        edge.target = self
        self.incoming_edges.append(edge)
        if edge.source is None:
            other.add_outgoing_edge(self, edge)

    def reset_edges(self):
        """Drop every recorded edge of this node (both directions).

        Only this node's own lists are cleared; the nodes at the other end
        of those edges keep their references.
        """
        self.outgoing_edges = []
        self.incoming_edges = []

    def reset_children(self):
        """Forget all children (their ``parent`` attribute is untouched)."""
        self.children = []

    def __str__(self):
        # Subclasses may define a class-level ``__name__`` label; plain
        # instances have no such attribute, so fall back to the class name.
        label = getattr(self, "__name__", type(self).__name__)
        return f"{label}(id={self.id})"

    def __hash__(self):
        # Hash on the same key that __eq__ compares.
        return hash(self.id)

    def __eq__(self, other):
        if not isinstance(other, Node):
            # Let Python fall back to its default handling instead of
            # raising AttributeError on objects without an ``id``.
            return NotImplemented
        return self.id == other.id

class Aggregator(CyberComponent, Node):
    """Node type that combines ``CyberComponent`` and ``Node`` behaviour."""

    # Instance-visible label; ``Node.__str__`` reads it through ``self``.
    __name__ = "Aggregator"

    def __init__(self, *args, **kwargs) -> None:
        """Pass every argument up the MRO unchanged."""
        super().__init__(*args, **kwargs)

    def aggregate(self):
        """Placeholder hook: does nothing and returns ``None``."""
        return None

class Device(CyberComponent, Node):
    """Node type carrying two role flags: controller and/or sensor."""

    # Instance-visible label; ``Node.__str__`` reads it through ``self``.
    __name__ = "Device"

    def __init__(self, is_controller:bool, is_sensor:bool, *args, **kwargs) -> None:
        """Store the role flags; remaining arguments go up the MRO."""
        super().__init__(*args, **kwargs)
        self.is_sensor = is_sensor
        self.is_controller = is_controller

    def collect(self):
        """Placeholder hook: does nothing and returns ``None``."""
        return None

    def act(self):
        """Placeholder hook: does nothing and returns ``None``."""
        return None

# >> Edges <<
class Link:
    """Directed connection between a ``source`` and a ``target``.

    Either end may be left as ``None`` at construction time and assigned
    later. ``attributes`` starts as an empty dict for free-form metadata.
    """

    def __init__(self, source=None, target=None) -> None:
        super().__init__()
        self.source = source
        self.target = target
        self.attributes = {}

    def to_edge(self):
        """Return the link as a dict with ``u_of_edge``/``v_of_edge``/``attr`` keys.

        NOTE(review): the key names match the parameters of networkx's
        ``Graph.add_edge``; presumably meant for ``G.add_edge(**link.to_edge())``.
        Confirm against callers.
        """
        return dict(u_of_edge=self.source, v_of_edge=self.target, attr=self)

class WiredLink(Link):
    """``Link`` subtype marking a wired connection; adds no behaviour."""

    def __init__(self, *args, **kwargs) -> None:
        # Forward keywords too, so ``WiredLink(source=a, target=b)`` works
        # exactly like ``Link(source=a, target=b)``.
        super().__init__(*args, **kwargs)

class WirelessLink(Link):
    """``Link`` subtype marking a wireless connection; adds no behaviour."""

    def __init__(self, *args, **kwargs) -> None:
        # Forward keywords too, so ``WirelessLink(source=a, target=b)`` works
        # exactly like ``Link(source=a, target=b)``.
        super().__init__(*args, **kwargs)

In [None]:
def connect_by_edges(parent, child):
    """Wire ``child`` to ``parent`` with fresh ``WiredLink`` edges.

    A ``Device`` gets a child -> parent edge when it is a sensor and a
    parent -> child edge when it is a controller. Any other child gets both.
    """
    is_device = isinstance(child, Device)
    wants_uplink = child.is_sensor if is_device else True
    wants_downlink = child.is_controller if is_device else True

    if wants_uplink:
        # Incoming on the parent: the edge runs from child to parent.
        parent.add_incoming_edge(child, WiredLink())
    if wants_downlink:
        # Outgoing on the parent: the edge runs from parent to child.
        parent.add_outgoing_edge(child, WiredLink())

def build_network(family_size=3, max_family_size_deviation=2, n_devices=20, components=None, controller_prob=0.3, sensor_prob=0.9):
    """Build a random hierarchy of devices and aggregators; return its root.

    Starting from the leaf ``components`` (generated randomly when none are
    supplied), consecutive components are grouped under new ``Aggregator``
    nodes, level by level, until at most ``family_size`` nodes remain. Those
    are then attached to a final root ``Aggregator``.

    Args:
        family_size: Target number of children per aggregator; also the
            maximum number of children of the root. Must be >= 1.
        max_family_size_deviation: Each aggregator's size is drawn uniformly
            from ``family_size - d .. family_size + d`` (inclusive, floored
            at 1). Must be >= 0.
        n_devices: Number of devices to generate when ``components`` is
            empty or ``None``; ignored otherwise.
        components: Optional pre-built leaf nodes. The list itself is not
            modified, but the nodes in it receive parents and edges.
        controller_prob: Probability that a generated device is a controller.
        sensor_prob: Probability that a generated device is a sensor.

    Returns:
        The root ``Aggregator`` of the hierarchy.

    Raises:
        ValueError: for arguments that can never produce a single-rooted
            tree (see the individual checks below).
    """
    if family_size < 1:
        raise ValueError("family_size must be at least 1")
    if max_family_size_deviation < 0:
        raise ValueError("max_family_size_deviation must be non-negative")

    # Work on a private list: a shared mutable default (or the caller's own
    # list) must never accumulate devices across calls.
    components = [] if components is None else list(components)

    if not components:
        if n_devices < 1:
            raise ValueError("n_devices must be at least 1 when no components are given")
        for _ in range(n_devices):
            is_sensor = bool(np.random.choice([True, False], p=[sensor_prob, 1 - sensor_prob]))
            is_controller = bool(np.random.choice([True, False], p=[controller_prob, 1 - controller_prob]))
            if not is_controller and not is_sensor:
                # Every device needs at least one role; fall back to the
                # more likely one (sensor wins ties).
                is_sensor = sensor_prob >= controller_prob
                is_controller = controller_prob > sensor_prob
            components.append(Device(is_controller=is_controller, is_sensor=is_sensor))

    if len(components) > family_size and family_size + max_family_size_deviation < 2:
        # Every aggregator would hold exactly one child, so the number of
        # nodes per level could never shrink.
        raise ValueError(
            "family_size=1 with max_family_size_deviation=0 cannot merge "
            "several components under a single root"
        )

    while len(components) > family_size:
        aggregators = []
        slots_left = 0
        for component in components:
            if slots_left == 0:
                # Draw the size once per aggregator. numpy's randint excludes
                # the upper bound, hence the +1 for a symmetric range.
                deviation = int(np.random.randint(-max_family_size_deviation,
                                                  max_family_size_deviation + 1))
                slots_left = max(1, family_size + deviation)
                aggregator = Aggregator()
                aggregators.append(aggregator)
            component.set_parent(aggregator)
            connect_by_edges(aggregator, component)
            slots_left -= 1
        components = aggregators

    root = Aggregator()
    for component in components:
        component.set_parent(root)
        connect_by_edges(root, component)
    return root

# Build a demo network from 15 randomly generated devices and keep the root
# aggregator returned by build_network.
root = build_network(n_devices=15)

def show_tree(root, s="", depth=0):
    """Return ``s`` followed by an indented text rendering of the tree.

    Each node contributes one line, ``str(node)`` prefixed by three spaces
    per level of ``depth``. Children follow their parent, one level deeper.
    """
    pieces = [s, f"{'   ' * depth}{root}\n"]
    pieces.extend(show_tree(child, depth=depth + 1) for child in root.children)
    return "".join(pieces)

# Print the indented text rendering of the hierarchy built above.
print(show_tree(root))

In [None]:
def build_graph(root, G):
    """Add the subtree under ``root`` to graph ``G`` and return ``G``.

    ``root`` is added as a node. Then, for every child, its subtree is added
    recursively, followed by all of that child's outgoing and incoming edges
    as ``(edge.source, edge.target)`` pairs.
    """
    G.add_node(root)
    for child in root.children:
        # Recurse first so the child's subtree is present before its edges.
        G = build_graph(child, G)
        for edge_list in (child.outgoing_edges, child.incoming_edges):
            for edge in edge_list:
                G.add_edge(edge.source, edge.target)
    return G

def hierarchy_pos(G, root=None, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5):
    """Compute node positions that draw a tree in a top-down hierarchy.

    Adapted from Joel's answer at https://stackoverflow.com/a/29597209/2966723
    (licensed under Creative Commons Attribution-Share Alike).

    Args:
        G: The graph; it must be a tree.
        root: Node to place at the top of the layout. When omitted, a
            directed graph uses the first node of a topological sort and an
            undirected graph uses a randomly chosen node. For a directed
            graph with an explicit root, only that node's descendants are
            positioned.
        width: Horizontal space given to the branch; it is divided evenly
            among the children so that sibling branches do not overlap.
        vert_gap: Vertical distance between consecutive levels.
        vert_loc: Vertical coordinate of the root.
        xcenter: Horizontal coordinate of the root.

    Returns:
        A dict mapping each positioned node to its ``(x, y)`` tuple.

    Raises:
        TypeError: if ``G`` is not a tree.
    """
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    directed = isinstance(G, nx.DiGraph)
    if root is None:
        if directed:
            # First node in topological order is taken as the root.
            root = next(iter(nx.topological_sort(G)))
        else:
            root = np.random.choice(list(G.nodes))

    positions = {}

    def _place(node, span, level_y, center_x, came_from):
        """Position ``node``, then spread its children evenly below it."""
        positions[node] = (center_x, level_y)
        branches = list(G.neighbors(node))
        if came_from is not None and not directed:
            # In an undirected graph the parent shows up as a neighbour.
            branches.remove(came_from)
        if not branches:
            return
        step = span / len(branches)
        # Start half a step left of the span so the first += lands on the
        # centre of the first child's slot.
        cursor = center_x - span/2 - step/2
        for branch in branches:
            cursor += step
            _place(branch, step, level_y - vert_gap, cursor, node)

    _place(root, width, vert_loc, xcenter, None)
    return positions
# Turn the node hierarchy under `root` into a directed networkx graph.
tree = build_graph(root, nx.DiGraph())

# One entry per graph node, in `tree.nodes()` order.
node_color_mask = np.full(tree.number_of_nodes(), "#1f78b4", dtype=object)
# NOTE: the shape mask is built but never passed to nx.draw below.
node_shape_mask = np.full(tree.number_of_nodes(), "s", dtype=object)
root_idx = None
for i, node in enumerate(tree.nodes()):
    is_leaf = not node.children
    if is_leaf:
        node_color_mask[i] = "lightgreen"
    if node.parent is None:
        # Checked after the leaf test, so a childless root still ends up red.
        root_idx = i
        node_color_mask[i] = "red"

# Left panel: hierarchical layout; right panel: spring layout.
fig, axes = plt.subplots(1, 2, figsize=(24, 6), width_ratios=[0.6, 0.4])
label_map = {node: node.id for node in tree.nodes()}
draw_kwargs = dict(with_labels=True, labels=label_map, font_size=10,
                   node_shape="s", node_color=node_color_mask)

# hierarchy_pos is given the undirected view of the graph.
tree_pos = hierarchy_pos(nx.to_undirected(tree), root=root)
nx.draw(tree, pos=tree_pos, ax=axes[0], **draw_kwargs)

spring_pos = nx.layout.spring_layout(tree)
nx.draw(tree, pos=spring_pos, ax=axes[1], **draw_kwargs)
plt.tight_layout()
plt.show()