In [None]:
import itertools
import math
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import scipy.stats.distributions as distr

from cyber import Defence, CommmonDefences, CyberComponent
from tree import TreeNode, Link
from comm_network import Aggregator, Device, CommNetwork

In [None]:
from comm_network import LevelOfRedundancy
# Fixed seed so the generated network (and every later draw from numpy's
# global RNG) is reproducible. To explore a fresh random network, use e.g.
#     seed = np.random.randint(low=0, high=52600)
seed = 27194
print(f"Seed: {seed}")
np.random.seed(seed)
# Build the communication network; see CommNetwork for the meaning of the
# redundancy / entry-point parameters.
pcn = CommNetwork(n_devices=3, n_entrypoints=3, redundancy=2, redundancy_deviation=1, enable_sibling_to_sibling_comm=True)
root = pcn.root
print(CommNetwork.show_tree(root))
tree = pcn.graph
print(f"Number of Components: {pcn.n_components}")

In [None]:
from attackers import RandomAttacker
# Run a verbose random attack with budget=52 against the generated network.
# Whatever attack_network returns is shown as this cell's output.
attacker = RandomAttacker(verbose=True, budget=52)
attacker.attack_network(pcn)

In [None]:
def hierarchy_pos(G:nx.DiGraph, root:TreeNode, width:float=1., vert_gap:float=0.2, vert_loc:float=0, xcenter:float=0.5):

    '''
    Credit: Joel (https://stackoverflow.com/a/29597209/2966723)
    Licensed under CC Attribution-Share Alike

    Compute node positions that draw a tree in a top-down hierarchical layout.

    Adapted from the original recipe: instead of walking graph neighbours it
    walks ``root.children`` (sorted by ``id``), so ``G`` is only consulted to
    pick a root when none is given. The node objects are never modified.

    G (networkx.Graph): Graph containing the tree; only used when ``root`` is None
    root (TreeNode): Root node of the layout
    width (float): Horizontal space allocated for this branch - avoids overlap with other branches. Defaults to 1.0
    vert_gap (float): Gap between levels of hierarchy. Defaults to 0.2
    vert_loc (float): Vertical location of root. Defaults to 0.0
    xcenter (float): Horizontal location of root. Defaults to 0.5

    Returns:
        dict mapping each reached node to an ``(x, y)`` tuple. A node reachable
        through several parents is placed once per path; the last placement wins.
    '''
    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  # allows back compatibility with nx version 1.11
        else:
            # Pick by index so numpy never tries to turn the node objects into an array.
            candidates = list(G.nodes)
            root = candidates[np.random.randint(len(candidates))]

    def _hierarchy_pos(G, root, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
        '''
        Place ``root`` at (xcenter, vert_loc) and recursively spread its
        children evenly across ``width``, one level (``vert_gap``) below.

        pos: dict of positions assigned so far (updated in place and returned)
        parent: node this branch was reached from; never treated as a child
        '''
        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        # The original recipe removed the parent from an undirected node's
        # neighbour list. Because children are read from the node itself, it is
        # enough to skip `parent` here. Do NOT call child.remove_parents(...):
        # judging by its name it edits the network model, which a plotting
        # helper must not do.
        children = [child for child in root.children if child is not parent]
        if children:
            dx = width / len(children)
            nextx = xcenter - width/2 - dx/2
            for child in sorted(children, key=lambda child: child.id):
                nextx += dx
                pos = _hierarchy_pos(G, child, width=dx, vert_gap=vert_gap,
                                     vert_loc=vert_loc - vert_gap, xcenter=nextx,
                                     pos=pos, parent=root)
        return pos

    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

# Per-node / per-edge styling. The arrays follow the iteration order of
# tree.nodes() / tree.edges(), which is also the default draw order below.
node_color_mask = np.full(tree.number_of_nodes(), fill_value="#1f78b4", dtype=object)
node_edge_color_mask = np.full(tree.number_of_nodes(), fill_value="#000000", dtype=object)
edge_color_mask = np.full(tree.number_of_edges(), fill_value="#000000", dtype=object)
for i, node in enumerate(tree.nodes()):
    if node.is_leaf:
        # Leaf nodes (devices): dark green if accessible (entry point), light green otherwise
        node_color_mask[i] = "green" if node.is_accessible else "lightgreen"
    else:
        # Internal nodes (aggregators): darker blue if accessible (entry point), lighter blue otherwise
        node_color_mask[i] = "#1f78b4" if node.is_accessible else "#1f98ff"
    # A node without parents is the root of the tree (the control center)
    if len(node.parents) == 0:
        node_color_mask[i] = "coral"
    if node.is_compromised:
        # Compromised/hacked nodes get a red outline
        node_edge_color_mask[i] = "#ff0000"

for j, (start_node, end_node) in enumerate(tree.edges()):
    # A communication channel between two compromised nodes is compromised too
    if start_node.is_compromised and end_node.is_compromised:
        edge_color_mask[j] = "#ff0000"

# >> Plotting <<
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(24, 6), width_ratios=[0.6, 0.4])
label_map = {node: node.id for node in tree.nodes()}

# Hierarchical / tree visualization. hierarchy_pos lays the tree out from
# node.children, so the directed graph is passed as-is. An undirected copy is
# unnecessary and only enables the helper's undirected-graph special-casing,
# which operates on the node objects.
tree_pos = hierarchy_pos(tree, root)
nx.draw_networkx_nodes(tree, pos=tree_pos, ax=axes[0],
                       node_size=400, node_shape="s", node_color=node_color_mask,
                       linewidths=1.0, edgecolors=node_edge_color_mask)
nx.draw_networkx_labels(tree, pos=tree_pos, labels=label_map, ax=axes[0], font_size=10)
nx.draw_networkx_edges(tree, pos=tree_pos, ax=axes[0], edge_color=edge_color_mask)
axes[0].set_title("Communication network (hierarchical layout)")

# Spring (force-directed) visualization of the same network
spring_pos = nx.layout.spring_layout(tree)
nx.draw_networkx_nodes(tree, pos=spring_pos, ax=axes[1],
                       node_size=400, node_shape="s", node_color=node_color_mask,
                       linewidths=1.0, edgecolors=node_edge_color_mask)
nx.draw_networkx_labels(tree, pos=spring_pos, labels=label_map, ax=axes[1], font_size=10)
nx.draw_networkx_edges(tree, pos=spring_pos, ax=axes[1], edge_color=edge_color_mask)
axes[1].set_title("Communication network (spring layout)")
plt.tight_layout()
plt.show()

## Static Analysis

In [None]:
# Dense adjacency matrix of the network, rows/columns ordered by node id.
A = nx.adjacency_matrix(
    pcn.graph,
    nodelist=sorted(pcn.graph.nodes(), key=lambda n: n.id),
).todense()
print("A\n", A)
# Ignore self-connections in the products below.
np.fill_diagonal(A, 0)

A2 = A @ A
np.fill_diagonal(A2, 0)
print("A^2\n", A2)

# NOTE(review): A3 is built from the diagonal-zeroed A2, so it is not the true
# matrix power A**3 (walks that return to their start after two hops are dropped).
A3 = A2 @ A
np.fill_diagonal(A3, 0)
print("A^3\n", A3)

In [None]:
for node in pcn.graph.nodes():
    # Each component, its graph neighbours, and the `p` of each of its defences.
    neighbours = list(pcn.graph.neighbors(node))
    print(node, "\n\t", neighbours)
    defence_ps = [defence.p for defence in node.defences.values()]
    print("\t", defence_ps)

In [None]:
# Neighbour set of every node, in graph iteration order.
# NOTE(review): the hard-coded indices below assume the graph has at least 5 nodes.
sets = [set(pcn.graph.neighbors(node)) for node in pcn.graph.nodes()]
print(*([n.id for n in sets[idx]] for idx in (1, 2, 3, 4)))
only_in_4 = sets[4].difference(sets[2])
print(only_in_4, [n.id for n in only_in_4])

In [None]:
{3, 5, 0, 4, 1} - {5, 3}  # scratch check: elements of the first set not in the second

In [None]:
import copy

def iter_paths(graph, path=None, prob_to_compromise=1.0, reachable_nodes=None, visited_nodes=None, verbose=True):
    """Recursively enumerate attack paths through ``graph``, yielding each with a probability.

    Starting from one node, every node reachable from the path so far is tried
    as the next step. For each (partial) path the generator yields
    ``(path, p)``, where ``p`` is the product of ``get_prob_to_compromise()``
    over the path's nodes times the product of ``1 - get_prob_to_compromise()``
    over the reachable nodes that have not been visited yet.

    Args:
        graph: object with a ``nodes()`` method. Its nodes must provide
            ``get_neighbours()`` (set-like, combined with ``|``),
            ``get_prob_to_compromise()`` and ``id``.
        path: nodes compromised so far; ``None``/empty starts a new enumeration.
        prob_to_compromise: probability accumulated along ``path`` before its last node.
        reachable_nodes: nodes reachable from ``path`` so far (fresh empty set if omitted).
        visited_nodes: nodes already visited; the current node is added in place
            (fresh empty set if omitted).
        verbose: print the reachable/visited node ids for every yielded path.

    Yields:
        ``(path, probability)`` tuples.
    """
    # Defaults are None rather than set(): a shared default set was mutated by
    # visited_nodes.add() and leaked into every later call of the generator.
    if reachable_nodes is None:
        reachable_nodes = set()
    if visited_nodes is None:
        visited_nodes = set()

    if not path:
        # NOTE(review): the `break` means only the first node of graph.nodes()
        # is used as a starting location, despite the loop.
        for start_node in graph.nodes():
            yield from iter_paths(graph, [start_node], 1.0, verbose=verbose)
            break
        return

    current_node = path[-1]
    visited_previously = current_node in visited_nodes
    visited_nodes.add(current_node)  # no-op when already visited
    reachable_nodes = (reachable_nodes | current_node.get_neighbours()) - visited_nodes

    # Probability to not compromise the remaining reachable nodes
    prob_not_compromised = np.prod([(1 - node.get_prob_to_compromise()) for node in reachable_nodes])
    # Probability that everything along the current path (so far) was compromised
    prob_to_compromise *= current_node.get_prob_to_compromise()
    if verbose:
        print("Reachable Nodes:", [n.id for n in reachable_nodes])
        print("Visited Nodes:", [n.id for n in visited_nodes])
    yield path, prob_to_compromise * prob_not_compromised

    # Do not expand a node a second time along the same path.
    if visited_previously:
        return
    for reachable_node in reachable_nodes:
        # Each branch gets its own copies so siblings do not see each other's visits.
        yield from iter_paths(graph,
                              path + [reachable_node],
                              prob_to_compromise,
                              copy.copy(reachable_nodes),
                              copy.copy(visited_nodes),
                              verbose=verbose)
# Enumerate attack paths, print each one, and aggregate probabilities by path length.
sum_probs = 0.0
n_probs = {}  # path length -> summed probability of all paths of that length
n_paths = 0
for path_no, (path, prob) in enumerate(iter_paths(pcn.graph)):
    print(f"Path {path_no} :: Prob {str(prob):<15} :: {'-'.join(str(node.id) for node in path)}")
    n_probs[len(path)] = n_probs.get(len(path), 0.0) + prob
    sum_probs += prob
    n_paths += 1
# Count explicitly: path_no is 0-based (one less than the number of paths)
# and is undefined when nothing was yielded.
print(f"No. of Paths: {n_paths}. Sum of Probabilities: {sum_probs}")
print("\n".join(f"{k} devices: {v}" for k, v in n_probs.items()))

In [None]:
# Per-node compromise probability: the product of every defence's `p`.
# NOTE(review): this assumes `p` is the probability that the defence is
# breached and that defences are independent -- confirm against cyber.Defence.
time_required = 0.0  # sum of expected effort over all defences (not used further here)
nodes = pcn.graph.nodes()
node_probs = {}
for node in nodes:
    probability_to_compromise = 1.0
    for defence in node.defences.values():
        time_required += defence.effort_distribution.expect()
        probability_to_compromise *= defence.p
    node_probs[node] = probability_to_compromise

# P(at least n components compromised), treating nodes as independent:
# sum the probability of every outcome with exactly n compromised nodes, then
# accumulate from n = N downwards.
prob_to_compromise_n_devices = {}
all_nodes = set(nodes)
cumulative = 0.0
for n_devices in range(pcn.n_components, 0, -1):
    prob_exactly_n = 0.0
    for combination in itertools.combinations(nodes, n_devices):
        compromised = set(combination)
        probability_to_compromise = 1.0
        for node in compromised:
            probability_to_compromise *= node_probs[node]
        for node in all_nodes - compromised:
            probability_to_compromise *= (1 - node_probs[node])
        prob_exactly_n += probability_to_compromise
    # Add only this level's exact probability. Adding the stored (already
    # cumulative) value again would double-count every higher level.
    cumulative += prob_exactly_n
    prob_to_compromise_n_devices[n_devices] = cumulative
    print(f"{n_devices} Devices: {prob_to_compromise_n_devices[n_devices]}")
print(prob_to_compromise_n_devices)

In [None]:
import multiprocessing as mp
from multiprocessing import Process, Pool

def combinations(iterable, k):
    """Yield every k-subset of ``iterable`` together with its complement.

    Each item is ``(chosen, rest, k)``: ``chosen`` is a k-tuple of elements and
    ``rest`` the tuple of the remaining elements, both in input order. ``k`` is
    echoed back so results can be grouped after parallel/unordered processing.
    Subsets come in the same lexicographic order as ``itertools.combinations``:

        combinations('ABC', 2) --> (('A','B'),('C',),2), (('A','C'),('B',),2), (('B','C'),('A',),2)

    Yields nothing when ``k > len(iterable)``. Raises ``ValueError`` (when
    iterated) for negative ``k`` instead of producing a bogus item.
    """
    pool = tuple(iterable)
    n = len(pool)
    for chosen in itertools.combinations(range(n), k):
        chosen_set = set(chosen)  # O(1) membership while building the complement
        yield (
            tuple(pool[i] for i in chosen),
            tuple(pool[i] for i in range(n) if i not in chosen_set),
            k,
        )

def process_wrapper(successes, failures, k):
    """Probability of one exact outcome: every ``successes`` node is compromised and no ``failures`` node is.

    Args:
        successes: compromise probabilities of the nodes that ARE compromised.
        failures: compromise probabilities of the nodes that are NOT compromised.
        k: number of compromised nodes, passed through unchanged for grouping.

    Returns:
        ``(prod(successes) * prod(1 - failures), k)``. An empty input
        contributes a factor of 1.
    """
    prob_success = np.prod(successes)
    # Each non-compromised node contributes (1 - p), not p.
    prob_failure = np.prod(1.0 - np.asarray(failures, dtype=float))
    return prob_success * prob_failure, k

# The same "P(at least k components compromised)" computation, parallelised
# over the k-subsets produced by `combinations`.
# NOTE(review): functions defined in a notebook can only be shipped to worker
# processes with the "fork" start method; under "spawn"/"forkserver", move
# combinations/process_wrapper into an importable module first.
N = pcn.n_components
p = 0.5
np.random.seed(0)
# Alternative scenario with heterogeneous per-component probabilities. It is
# still drawn (but not used) so the global RNG stream seen by later cells is unchanged.
ps_heterogeneous = np.random.uniform(low=0, high=1, size=N)
ps = np.full(shape=N, fill_value=p)
cumulative = 0.0
n_workers = max(1, mp.cpu_count() - 1)  # Pool(processes=0) raises ValueError
with Pool(processes=n_workers) as pool:
    for k in range(N, 0, -1):
        # starmap unpacks each (successes, failures, k) tuple into the three
        # arguments; imap/map would pass the whole tuple as one argument.
        outcomes = pool.starmap(process_wrapper, combinations(ps, k), chunksize=32)
        k_prob = sum(prob for prob, _ in outcomes)
        cumulative += k_prob
        print(f"{k} Devices: {cumulative}")

In [None]:
# If every component has the same compromise probability p (and components are
# independent), the number of compromised components is Binomial(N, p), so
# P(exactly k) has a closed form and no subset enumeration is needed.
# Takes: 12.6 µs
N = pcn.n_components
p = 0.5
cumulative = 0.0
for k in range(N, 0, -1):
    prob = math.comb(N, k) * math.pow(p, k) * math.pow(1 - p, N - k)
    # Running total = P(at least k components compromised)
    print(f"{k} Devices: {cumulative + prob}")
    cumulative += prob

In [None]:
math.comb(5, 3) * math.pow(0.5, 3) * math.pow(1 - 0.5, 5 - 3)  # P(X = 3) for X ~ Binomial(5, 0.5)

In [None]:
import scipy.stats.distributions as distr
# Candidate distributions, keyed by a readable name.
distr_lookup = {
    "TruncNorm": distr.truncnorm,  # Continuous; a, b = clip bounds in standard deviations from loc, loc = mean, scale = standard deviation
    "Exponential": distr.expon,  # Continuous; scale = 1 / lambda (must be > 0)
    "Gamma": distr.gamma,  # Continuous; a = shape parameter (positive float)
    "Bernoulli": distr.bernoulli,  # Discrete; p = success probability
}
n_attacks = 20
# Placeholder mean time per successful attack (scale = 1 / lambda). The previous
# scale=0.0 is degenerate: .rvs() returns all zeros while .mean()/.pdf() are nan.
mean_attack_time = 1.0
is_successful = distr.bernoulli(0.5).rvs(size=n_attacks).astype(bool)
time_taken = distr.expon(scale=mean_attack_time).rvs(size=n_attacks)[is_successful]
print(f"Successful Attacks {is_successful.sum()}/{n_attacks}\nTime Taken per Successful Attack: {time_taken}")