# Graph Failure Propagation using DeepMind's Graph Nets Framework - Phase 2


This notebook models a network failure propagation scenario using a Graph Neural Network.

The framework used is [DeepMind's Graph Nets library](https://github.com/deepmind/graph_nets/).


We add four features to every node: three booleans (Latency, Error, Saturated) and a three-level Traffic feature, as below (edges carry only Latency and Traffic):
```
Latency = 1 -> there is a latency in the app
Latency = 0 -> there is no latency in the app
..
Error = 1 -> there are errors in the app
Error = 0 -> there are no errors in the app
..
Saturated = 1 -> Resources are fully occupied/utilized for that app
Saturated = 0 -> Resources are not fully occupied/utilized for that app (there are available resources)
..
Traffic = 2 -> There is a high amount of traffic hitting that app
Traffic = 1 -> There is a medium amount of traffic hitting that app
Traffic = 0 -> There is a low amount of traffic hitting that app

downstreams (Direct Child nodes)
upstreams (Direct Parent Nodes)
A -> B
A is Direct parent of B
B is direct child of A

Phase 2 Simulator rules:
1- High traffic causes Saturation on the same node in the next step.
2- Saturation causes immediate low traffic on downstreams in the same step.
3- Saturation causes Latency on the same node in the next step.
4- Latency causes Latency on the upstreams in the next step.
5- Latency causes Errors on the same node in the next step.
6- Latency causes Errors on the downstreams in the next step.
7- Low traffic causes Low traffic on the downstreams in the next step.
8- Errors cause Errors on the downstreams in the next step.


```
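One way to read these rules is as a synchronous update: every "next step" rule reads the state at step t and writes step t+1, while rule 2 takes effect within the current step. The sketch below applies them to a plain-Python graph. The `State` dataclass and the `step` helper are hypothetical and not part of the notebook; also, unlike the notebook's `flip_edges` (which applies rule 2 to the edge traffic), this sketch applies rule 2 directly to the child node.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class State:
    latency: int = 0    # 0 = no latency, 1 = latency
    error: int = 0      # 0 = no errors, 1 = errors
    saturated: int = 0  # 0 = free resources, 1 = saturated
    traffic: int = 1    # 0 = low, 1 = medium, 2 = high


def step(nodes: Dict[str, State], edges: List[Tuple[str, str]]) -> Dict[str, State]:
    """Apply rules 1-8 once. Edges are (parent, child) pairs, i.e. A -> B."""
    # Rule 2 (same step): a saturated parent immediately starves its children.
    current = dict(nodes)
    for parent, child in edges:
        if nodes[parent].saturated:
            current[child] = replace(current[child], traffic=0)

    # Every "next step" rule reads `current` and writes into `new`.
    new = dict(current)

    def update(name, **changes):
        new[name] = replace(new[name], **changes)

    for name, s in current.items():
        if s.traffic == 2:
            update(name, saturated=1)   # rule 1
        if s.saturated:
            update(name, latency=1)     # rule 3
        if s.latency:
            update(name, error=1)       # rule 5
    for parent, child in edges:
        p, c = current[parent], current[child]
        if c.latency:
            update(parent, latency=1)   # rule 4: latency moves upstream
        if p.latency or p.error:
            update(child, error=1)      # rules 6 and 8: errors move downstream
        if p.traffic == 0:
            update(child, traffic=0)    # rule 7: low traffic moves downstream
    return new
```

On the chain A -> B -> C with high traffic on B, B saturates at step 1, gets latency at step 2 (while C drops to low traffic), and at step 3 the latency reaches A upstream and errors reach B and C. A only shows errors one step later, because rule 5 reads A's latency from the previous step.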

The health of a node is then defined by the equation:

```
health = med traffic && !error && !latency && !saturated
```

The node is healthy if:
- there is a medium amount of traffic
- no errors 
- no latency 
- there are available resources
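Each node is stored as a 9-dimensional one-hot vector with layout `[latency(2), error(2), saturated(2), traffic(3)]` (see `healthy_node` and `encode_state` in the code cells). A minimal sketch of the health equation on that layout, using hypothetical `encode` and `node_is_healthy` helpers:

```python
import numpy as np

# Slices of each one-hot group inside a node feature vector.
LATENCY, ERROR, SATURATED, TRAFFIC = slice(0, 2), slice(2, 4), slice(4, 6), slice(6, 9)


def encode(l=0, e=0, s=0, t=1):
    """One-hot encode a node state; t is 0 (low), 1 (medium) or 2 (high)."""
    return np.concatenate([np.eye(2)[l], np.eye(2)[e], np.eye(2)[s], np.eye(3)[t]])


def node_is_healthy(features):
    """health = med traffic && !error && !latency && !saturated"""
    latency = np.argmax(features[LATENCY])
    error = np.argmax(features[ERROR])
    saturated = np.argmax(features[SATURATED])
    traffic = np.argmax(features[TRAFFIC])
    return bool(traffic == 1 and not error and not latency and not saturated)
```

Taking the `argmax` of each group means the same check also works on the model's soft predictions, not only on one-hot targets; it mirrors `NodeStatus.is_healthy` in the notebook.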

A typical failure scenario:
- if there is high traffic (traffic = 2):
  1. it leads to full utilization of resources (saturated = 1) at the next step,
  1. this saturation causes latency (latency = 1) at the next step,
  1. and the latency eventually causes errors (error = 1).
- low traffic (traffic = 0) does not saturate the node; it propagates to the downstreams (rule 7).

We want to predict some of these features, e.g. latency and errors, and use the predictions to decide whether each node is healthy.


## Install Libraries

In [None]:
!pip install -q --upgrade ipython==5.5.0
!pip install -q --upgrade ipykernel==4.10

In [None]:
try:
  %pip install --quiet --user --upgrade watermark 
  %pip install --quiet --user --upgrade "graph_nets>=1.1" "dm-sonnet>=2.0.0b0"
  import sonnet
except ModuleNotFoundError:
  import os
  print('libraries installed. restarting ...')
  os.kill(os.getpid(), 9)

In [None]:
%load_ext watermark

In [None]:
%watermark -u -n -t -z -p numpy,tensorflow,sonnet,graph_nets -g

## Configure Environment



In [None]:
#@title Imports
import time
import copy
import itertools

import tensorflow as tf
import networkx as nx
import sonnet as snt

from matplotlib import pyplot as plt
try:
  import seaborn as sns
except ImportError:
  pass
else:
  sns.reset_orig()

from graph_nets import blocks
from graph_nets import utils_tf
from graph_nets import utils_np
from graph_nets.demos_tf2 import models

import numpy as np

plt.rcParams["figure.figsize"] = (20,3)

In [None]:
#@title Set Random seeds
SEED = 42 #@param
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
NODE_HEALTH_THRESHOLD = 0.85 # @param
EDGE_HEALTH_THRESHOLD = 0.85 # @param

In [None]:
#@title Load Graph Visualization Functions

def plot_graphs_tuple(graphs_tuple):
  networkx_graphs = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
  num_graphs = len(networkx_graphs)
  _, axes = plt.subplots(1, num_graphs, figsize=(5*num_graphs, 5))
  if num_graphs == 1:
    axes = axes,
  for graph, ax in zip(networkx_graphs, axes):
    plot_graph_networkx(graph, ax)

import collections

NodeStatus = collections.namedtuple('NodeStatus', ['latency', 'errors', 'saturation', 'traffic'])

NodeStatus.is_healthy = lambda self: (self.latency == 0 and
                                      self.errors == 0 and
                                      self.saturation == 0 and
                                      self.traffic == 1)

EdgeStatus = collections.namedtuple('EdgeStatus', ['latency', 'traffic'])

EdgeStatus.is_healthy = lambda self: (self.latency == 0 and
                                      self.traffic == 1)
def node_status(d):
  L, E, S, T = np.argmax(d[0:2]), np.argmax(d[2:4]), np.argmax(d[4:6]), np.argmax(d[6:])
  return NodeStatus(L, E, S, T)

def edge_status(d):
  L, T = np.argmax(d[0:2]), np.argmax(d[2:])
  return EdgeStatus(L, T)

def nodefmt(node, data):
  d = data['features']
  L, E, S, T = np.argmax(d[0:2]), np.argmax(d[2:4]), np.argmax(d[4:6]), np.argmax(d[6:])
  if (L == 0 and T == 1 and S == 0 and E == 0):
    return ""
  fmt = []
  if L:
    fmt.append("L") #:{:.2g}".format(L))
  if E:
    fmt.append("E") #:{:.2g}".format(E))
  if S:
    fmt.append("S") # :{:.2g}".format(S))
  if (T == 0 or T == 2):
    fmt.append("T:{:.2g}".format(T))
  return "|".join(fmt)

def edgefmt(sender, receiver, data):
  d = data['features']
  L, T = np.argmax(d[0:2]), np.argmax(d[2:])
  if (L == 0 and T == 1):
    return ""
  fmt = [] 
  if L:
    fmt.append("L") #:{:.2g}".format(L))
  if T != 1:  # low (0) or high (2) traffic
    fmt.append("T:{:.2g}".format(T))
  return "|".join(fmt)

# def is_nhealthy(node, data):
#   data = data['features']
#   return  np.mean((data[0], data[2], data[4], data[6])) > NODE_HEALTH_THRESHOLD

# def is_ehealthy(sender, receiver, data):
#   data = data['features']
#   return  np.mean((data[0], data[2])) > EDGE_HEALTH_THRESHOLD

ORANGE = '#FFD23F'
RED = '#EE4266'
GREEN = "g"

def edge_color(sender, receiver, data):
  status = edge_status(data['features'])
 
  if status.is_healthy():
    return GREEN
  
  color = GREEN
  
  if status.latency == 1:   # high latency
    color = ORANGE
  if status.traffic == 0:   # low traffic
    color = ORANGE
  elif status.traffic == 2: # high traffic
    color = RED
  return color

def node_color(node, data):
  status = node_status(data['features'])

  if status.is_healthy():
    return GREEN

  color = GREEN
  if status.latency == 1: # high latency
    color = ORANGE
  if status.errors == 1: # has error
    color = RED
    return color
  if status.saturation == 1: # saturated
    color = ORANGE
  if status.traffic == 0: # low traffic
    color = ORANGE
  elif status.traffic == 2: # high traffic
    color = RED
  return color

def plot_graph_networkx(graph, ax, pos=None):
  # node 
  node_labels = {node: nodefmt(node, data) 
                  for node, data in graph.nodes(data=True)
                    if data["features"] is not None}
  # edge
  edge_labels = {(sender, receiver): edgefmt(sender, receiver, data) 
                 for sender, receiver, data in graph.edges(data=True)
                 if data["features"] is not None}
  # unused
  global_label = ("{:.3g}".format(graph.graph["features"][0])
                  if graph.graph["features"] is not None else None)
  
  node_color_map = [node_color(node, data) 
                          for node, data in graph.nodes(data=True)]

  edge_color_map = [edge_color(sender, receiver, data) 
                 for sender, receiver, data in graph.edges(data=True)]
  
  
  if pos is None:
    random_pos = nx.random_layout(graph, seed=42)
    pos = nx.spring_layout(graph, pos=random_pos)

  nx.draw_networkx(graph, pos, ax=ax, 
                   labels=node_labels,
                   edge_color=edge_color_map,
                   node_color=node_color_map, node_size=700)

  if edge_labels:
    nx.draw_networkx_edge_labels(graph, pos, edge_labels, ax=ax)

  if global_label:
    plt.text(0.05, 0.95, global_label, transform=ax.transAxes)

  ax.yaxis.set_visible(False)
  ax.xaxis.set_visible(False)
  return pos


def plot_compare_graphs(graphs_tuples, labels):
  pos = None
  num_graphs = len(graphs_tuples)
  _, axes = plt.subplots(1, num_graphs, figsize=(5*num_graphs, 5))
  if num_graphs == 1:
    axes = axes,
  pos = None
  for name, graphs_tuple, ax in zip(labels, graphs_tuples, axes):
    graph = utils_np.graphs_tuple_to_networkxs(graphs_tuple)[0]
    pos = plot_graph_networkx(graph, ax, pos=pos)
    ax.set_title(name)

def plot_compare_graphs_custom(graphs_tuples, labels):
  pos = None
  num_graphs = len(graphs_tuples)
  _, axes = plt.subplots(1, num_graphs, figsize=(50*num_graphs, 50))
  if num_graphs == 1:
    axes = axes,
  pos = None
  for name, graphs_tuple, ax in zip(labels, graphs_tuples, axes):
    graph = utils_np.graphs_tuple_to_networkxs(graphs_tuple)[0]
    pos = plot_graph_networkx(graph, ax, pos=pos)
    ax.set_title(name)

# Configure Graph

In [None]:
#@title Create Graph Dict { form-width: "30%" }
#@markdown one-hot width of each feature (nodes use all four groups, edges only latency and traffic)
LATENCY = 2
ERROR = 2
SATURATION = 2
TRAFFIC = 3

NODE_FEATURES_DIM = LATENCY + ERROR + SATURATION + TRAFFIC 

EDGES_FEATURES_DIM = LATENCY + TRAFFIC 

def healthy_node():
  return [ 1., # no latency
           0.,
           1., # no errors
           0.,
           1., # no saturation
           0.,
           0., # medium traffic
           1.,
           0.,   
          ]

def healthy_edge():
  return [
          1., # no latency
          0.,
          0., # medium traffic
          1.,
          0.]

def set_traffic(data, value):
  if value == 1:
    data[6:] = [0.0, 1.0, 0.0]
  elif value == 2:
    data[6:] = [0.0, 0.0, 1.0]
  else:
    data[6:] = [1.0, 0.0, 0.0]
  return data

def set_saturation(data, value):
  if value:
    data[4:6] = [0.0, 1.0]
  else:
    data[4:6] = [1.0, 0.0]
  return data

def set_error(data, value):
  if value:
    data[2:4] = [0.0, 1.0]
  else:
    data[2:4] = [1.0, 0.0]
  return data

def set_latency(data, value):
  if value:
    data[0:2] = [0.0, 1.0]
  else:
    data[0:2] = [1.0, 0.0]
  return data

def set_edge_traffic(data, value):
  if value == 1:
    data[2:] = [0.0, 1.0, 0.0]
  elif value == 2:
    data[2:] = [0.0, 0.0, 1.0]
  else:
    data[2:] = [1.0, 0.0, 0.0]
  return data

from random import randrange
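def create_graph_dict(nodes_n, edges_n, i):
  """Random DAG with `nodes_n` healthy nodes and `edges_n` healthy edges.

  Node 0 starts with low traffic. `i` (the example index) is currently unused.
  """
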
  nodes = np.array([ np.array(healthy_node(), dtype=np.float32) for _ in range(nodes_n) ], dtype=np.float32)

  # first node is having high latency
  # nodes[0][0:2] = [0.0, 1.0]
  # first node is saturated
  # nodes[0][4:6] = [0.0, 1.0]
  
  # random=randrange(2) # traffic control
  # random2=randrange(4) # issue control

  # if random2 == 0:
  #   nodes[0][0:2] = [0.0, 1.0] # high latency
  # elif random2 == 1:
  #   nodes[0][2:4] = [0.0, 1.0] # errors
  # elif random2 == 2:
  #   nodes[0][4:6] = [0.0, 1.0] # saturation
  # else: # low/high traffic
  #   if random == 1:
  #     nodes[0][6:] = [0.0, 0.0, 1.0]
  #   else:
  #     nodes[0][6:] = [1.0, 0.0, 0.0]
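  # node 0 starts with low traffic (traffic = 0); since every edge goes from a lower
  # to a higher index, node 0 has no upstreams and the failure flows downstream from it
  nodes[0][6:] = [1.0, 0.0, 0.0]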
  # we don't have a global state (for now)
  globals_ft = [0.0, 0.0]
  
  # Edges.
  edges, senders, receivers = [], [], []

  perm_edges = list(itertools.combinations(np.arange(nodes_n), 2))
  idx = np.arange(len(perm_edges)) 
  np.random.shuffle(idx)

  for ix in idx[:edges_n]:
    # One healthy edge per sampled pair (a, b) with a < b, so the graph is acyclic.
    edges.append(healthy_edge())
    a,b = perm_edges[ix]
    senders.append(a)
    receivers.append(b)

  return {
      "globals": np.array(globals_ft, dtype=np.float64),
      "nodes": np.array(nodes, dtype=np.float64),
      "edges": np.array(edges, dtype=np.float64),
      # graph_nets expects integer node indices for senders/receivers
      "receivers": np.array(receivers, dtype=np.int32),
      "senders": np.array(senders, dtype=np.int32)
  }

In [None]:
#@title Flip Functions 
#@markdown These functions propagate the failure across the graph
def flip_edges(graph):
    """Propagate node statuses onto their outgoing edges (and apply rule 3)."""
    # .copy(): never mutate memory that may be shared with the input tensors
    nodes = graph.nodes.numpy().copy()
    edges = graph.edges.numpy().copy()
    senders = graph.senders.numpy()

    saturated_nodes = set()
    bad_latency_nodes = set()
    low_traffic_nodes = set()
    for i in range(len(nodes)):
        ns = node_status(nodes[i])

        if ns.saturation:
            saturated_nodes.add(i)
            # 3- Saturation causes Latency on the same node in the next step.
            nodes[i] = set_latency(nodes[i], 1)
        if ns.latency:
            bad_latency_nodes.add(i)
        if ns.traffic == 0:
            low_traffic_nodes.add(i)

    for ix, sender in enumerate(senders):
        # 2- Saturation causes immediate low traffic on downstreams in the same step.
        # 7- Low traffic is carried to the downstreams through the outgoing edges.
        if sender in saturated_nodes or sender in low_traffic_nodes:
            edges[ix] = set_edge_traffic(edges[ix], 0)

        # The sender's latency is carried on its outgoing edges.
        if sender in bad_latency_nodes:
            edges[ix] = set_latency(edges[ix], 1)

    return graph.replace(
        nodes=tf.convert_to_tensor(nodes),
        edges=tf.convert_to_tensor(edges)
    )

def flip_nodes(graph):
    """Apply the node rules of the Phase 2 simulator for one step.

    Statuses are read from a snapshot taken before any update, so every
    "next step" rule sees the state of the previous step only.
    """
    prev_nodes = graph.nodes.numpy().copy()
    nodes = prev_nodes.copy()
    edges = graph.edges.numpy().copy()
    receivers = graph.receivers.numpy()
    senders = graph.senders.numpy()

    # Rules that act on the same node (applied to every node, not only senders).
    for i in range(len(prev_nodes)):
        nstatus = node_status(prev_nodes[i])
        # 1- High traffic causes Saturation on the same node in the next step.
        if nstatus.traffic == 2:
            nodes[i] = set_saturation(nodes[i], 1)
        # 3- Saturation causes Latency on the same node in the next step.
        if nstatus.saturation:
            nodes[i] = set_latency(nodes[i], 1)
        # 5- Latency causes Errors on the same node in the next step.
        if nstatus.latency:
            nodes[i] = set_error(nodes[i], 1)

    # Rules that act on neighbours, one edge (sender -> receiver) at a time.
    # Rule 2 is handled by flip_edges, which lowers the edge traffic.
    for ix in range(len(edges)):
        estatus = edge_status(edges[ix])
        sender, receiver = senders[ix], receivers[ix]
        sstatus = node_status(prev_nodes[sender])
        rstatus = node_status(prev_nodes[receiver])

        # 4- Latency causes Latency on the upstreams in the next step.
        if rstatus.latency:
            nodes[sender] = set_latency(nodes[sender], 1)
        # 6- Latency causes Errors on the downstreams in the next step.
        if sstatus.latency:
            nodes[receiver] = set_error(nodes[receiver], 1)
        # 7- Low traffic causes Low traffic on the downstreams in the next step.
        if estatus.traffic == 0:
            nodes[receiver] = set_traffic(nodes[receiver], 0)
        # 8- Errors cause Errors on the downstreams in the next step.
        if sstatus.errors:
            nodes[receiver] = set_error(nodes[receiver], 1)

    return graph.replace(
        nodes=tf.convert_to_tensor(nodes),
        edges=tf.convert_to_tensor(edges)
    )

### Testing
Create a simple graph.

In [None]:
# 3 nodes and 3 edges
data_dict = create_graph_dict(3, 3,1)

In [None]:
data_dict

In [None]:
#@title convert it to GraphTuple
graph_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict])

In [None]:
#@title plot it
plot_graphs_tuple(graph_tuple)

In [None]:
stage1 = flip_edges(graph_tuple)

In [None]:
plot_graphs_tuple(stage1)

In [None]:
stage2 = flip_nodes(stage1)

In [None]:
plot_graphs_tuple(stage2)


Now we want to update the receiver node at the next step.


In [None]:
stage3 = flip_edges(stage2)
plot_graphs_tuple(stage3)

In [None]:
stage4 = flip_nodes(stage3)
plot_graphs_tuple(stage4)
stage5 = flip_edges(stage4)
plot_graphs_tuple(stage5)
stage5

In [None]:
stage6 = flip_nodes(stage5)
plot_graphs_tuple(stage6)
stage7 = flip_edges(stage6)
plot_graphs_tuple(stage7)
stage7

In [None]:
def encode_state(l=0, e=0, s=0, t=0):
  """One-hot encode a node state as [latency(2), error(2), saturation(2), traffic(3)]."""
  n2 = np.eye(2)
  n3 = np.eye(3)
  return np.concatenate((n2[l], n2[e], n2[s], n3[t]), axis=0)

In [None]:
encode_state(t=1, s=1, l=1)

In [None]:
encode_state(l=1, s=1, t=2)

In [None]:
import typing

class Scenario(object):
  def __init__(self, name):
    self._graph_nx = nx.OrderedMultiDiGraph()
    self._count: int = 0
    self._name = name
    self._map = {}

  def add_node(self, name, l=0, e=0, s=0, t=1):
    """
    l: Latency (low=0, high=1)
    e: Error (no errors=0, errors=1)
    s: Saturation (low=0, high=1)
    t: Traffic (low=0, medium=1, high=2)
    """
    f = encode_state(l=l, e=e, s=s, t=t)
    self._graph_nx.add_node(self._count, 
                                label=name, 
                                features=f)
    node = tuple((name, self._count, f))
    self._map[self._count] = name
    self._count += 1
    return node
  
  def set_traffic(self, node, t=0):
    idx = node[1]
    self._graph_nx.nodes[idx]['features'] = set_traffic(self._graph_nx.nodes[idx]['features'], t)
    return self

  def set_saturation(self, node, s=0):
    idx = node[1]
    self._graph_nx.nodes[idx]['features'] = set_saturation(self._graph_nx.nodes[idx]['features'], s)
    return self
  
  def set_latency(self, node, l=0):
    idx = node[1]
    self._graph_nx.nodes[idx]['features'] = set_latency(self._graph_nx.nodes[idx]['features'], l)
    return self
  
  def set_error(self, node, e=0):
    idx = node[1]
    self._graph_nx.nodes[idx]['features'] = set_error(self._graph_nx.nodes[idx]['features'], e)
    return self

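  def connect(self, a, b, l=0, t=1):
    """Add a directed edge a -> b with latency `l` and traffic `t` (healthy by default)."""
    f = set_edge_traffic(set_latency(healthy_edge(), l), t)
    self._graph_nx.add_edge(a[1], b[1], features=f)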

  def _repr_pretty_(self, p, cycle, pos=0):
    ax = plt.figure(figsize=(3, 3)).gca()
    #nx.draw(self._graph_nx, ax=ax, with_labels=True)
    #ax.set_title(self._name)
    #plot_graph_networkx(self._graph_nx, ax=ax)
    graph = self._graph_nx 
    
    node_color_map = [node_color(node, data) for node, data in graph.nodes(data=True)]
    edge_color_map = [edge_color(sender, receiver, data) for sender, receiver, data in graph.edges(data=True)]
    
    random_pos = nx.random_layout(graph, seed=42)
    pos = nx.spring_layout(graph, pos=random_pos)
  
    nx.draw_networkx(graph, 
                     pos, 
                     ax=ax, 
                     labels=self._map,
                     edge_color=edge_color_map,
                     node_color=node_color_map)
    
    edge_labels = None 
    if edge_labels:
      nx.draw_networkx_edge_labels(graph, pos, edge_labels, ax=ax)

    global_label = None  
    if global_label:
      plt.text(0.05, 0.95, global_label, transform=ax.transAxes)
  
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    return ax


In [None]:
#@title Build Graph Manually
scenario = Scenario(name="Short failing node due to high traffic")

A = scenario.add_node('A', l=0, e=0, s=0, t=1)
B = scenario.add_node('B', l=0, e=0, s=0, t=0)
C = scenario.add_node('C', l=0, e=0, s=0, t=0)
D = scenario.add_node('D', l=0, e=0, s=0, t=0)
scenario.connect(A, B)
scenario.connect(B, C)
scenario.connect(A, D)

In [None]:
scenario

In [None]:
scenario.set_traffic(B, 2)

In [None]:
graphs_tuple = utils_np.networkxs_to_graphs_tuple([scenario._graph_nx])
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(utils_np.graphs_tuple_to_data_dicts(graphs_tuple))
plot_graphs_tuple(graphs_tuple)

In [None]:
steps = 8
state = graphs_tuple
for i in range(steps):
  state = flip_nodes(state)
  plot_graphs_tuple(state)
  state = flip_edges(state)
  plot_graphs_tuple(state)

# Simulator

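With this logic we can build a simulator that propagates the failure across the graph one timestep at a time: even steps run `flip_edges` (node statuses are copied onto the outgoing edges) and odd steps run `flip_nodes` (the node rules are applied using the edges and neighbours).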

In [None]:
def simulator(initial_state, steps):
    current_state = initial_state
    records = [ copy.copy(initial_state) ]
    for s in range(steps):
        if s % 2 == 1:
            current_state = flip_nodes(current_state)
        else:
            current_state = flip_edges(current_state)
        # store a snapshot of the state after this step
        records.append(copy.copy(current_state))
    return records

In [None]:
test_graph_dict = create_graph_dict(5, 6,1)
test_graph = utils_tf.data_dicts_to_graphs_tuple([test_graph_dict])
graph_at_timestep = simulator(test_graph, steps=10)
plot_compare_graphs(graph_at_timestep, [ f"stage{i}" for i in range(10+1) ])

In [None]:
stage0_dict = create_graph_dict(3, 3,1)
stage0 = utils_tf.data_dicts_to_graphs_tuple([stage0_dict])
stage1 = flip_nodes(stage0)
plot_graphs_tuple(stage1)

Now let's run the simulator for N steps and see how it updates the graph

In [None]:
TIMESTEPS=5
graph_at_timestep = simulator(stage0, steps=TIMESTEPS)

In [None]:
plot_compare_graphs(graph_at_timestep, [ f"stage{i}" for i in range(TIMESTEPS+1) ])


Let's try more complex graphs with more nodes

In [None]:
# #@title sample with 5 nodes and 5 edges
# sample_graph = create_graph_dict(5, 5,1)
# graph_tuple = utils_tf.data_dicts_to_graphs_tuple([sample_graph])
# ## seems we have to flip always the first 
# results = simulator(graph_tuple, steps=int(graph_tuple.n_node.numpy()))
# plot_compare_graphs(results, [ f"step{i}" for i in range(10) ])

In [None]:
# #@title sample with 6 nodes and 10 edges
# sample_graph = create_graph_dict(6, 10,1)
# graph_tuple = utils_tf.data_dicts_to_graphs_tuple([sample_graph])
# results = simulator(graph_tuple, steps=int(graph_tuple.n_node.numpy()))
# plot_compare_graphs(results, [ f"step{i}" for i in range(10) ])


In [None]:
graph_tuple.senders.shape, graph_tuple.receivers.shape, graph_tuple.edges.shape, graph_tuple.nodes.shape


# Create the Model

## Questions

- When does the model break?
- How many steps can we simulate?
- How many nodes can we generalize over? (range)



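# Loss Function
The training loss is computed on the output of each processing step, always against the same target: the final state produced by the simulator.

The reason for this is to encourage the model to try to solve the problem in as few steps as possible.

It also helps make the output of intermediate steps more interpretable.

For each node and edge we compute the classification error of every one-hot feature group (latency, error, saturation and traffic for nodes; latency and traffic for edges) using the categorical cross-entropy loss

$-\sum_k y_k \log(p_k)$

where $y$ is the one-hot target and $p = \mathrm{softmax}(z)$ comes from the model logits $z$ (the code passes `from_logits=True`).

The loss for each step is the sum of the losses for each node, plus the sum of the losses for each edge. The total training loss is the mean of the step losses over the `NUM_PROCESSING_STEPS` processing steps.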
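Because every feature group is a separate one-hot variable, the cross-entropy has to be applied per group (a softmax over 2 or 3 logits) rather than over the whole 9-dimensional node vector; this is what the slices in `create_loss` do. A NumPy sketch of one step's loss, with hypothetical helpers `group_cross_entropy` and `step_loss` standing in for `tf.losses.categorical_crossentropy(..., from_logits=True)`:

```python
import numpy as np

NODE_GROUPS = [slice(0, 2), slice(2, 4), slice(4, 6), slice(6, 9)]  # latency, error, saturation, traffic
EDGE_GROUPS = [slice(0, 2), slice(2, 5)]                            # latency, traffic


def group_cross_entropy(targets, logits):
    """Row-wise categorical cross-entropy -sum_k y_k log softmax(z)_k."""
    z = logits - logits.max(axis=-1, keepdims=True)  # numerically stable log-softmax
    log_p = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -(targets * log_p).sum(axis=-1)


def step_loss(t_nodes, o_nodes, t_edges, o_edges):
    """Sum of per-group losses over all nodes plus all edges, for one processing step."""
    loss = sum(group_cross_entropy(t_nodes[:, g], o_nodes[:, g]).sum() for g in NODE_GROUPS)
    loss += sum(group_cross_entropy(t_edges[:, g], o_edges[:, g]).sum() for g in EDGE_GROUPS)
    return loss
```

The total loss used in `update_step` then corresponds to `np.mean([step_loss(...) for each processing step])`. With all-zero logits every binary group costs $\log 2$ and every ternary group $\log 3$, which is a handy sanity check for an untrained model.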


In [None]:
#@title Loss Function

cross_entropy = tf.losses.categorical_crossentropy

def create_loss(target, predicted, **kwds):
    """
    Sum of the categorical cross-entropy losses over all nodes and edges.

    Args:
        target: a `graphs.GraphsTuple` which contains the target as a graph.
        predicted: a `list` of `graphs.GraphsTuple`s which contains the model
            outputs for each processing step as graphs.
        **kwds: forwarded to `cross_entropy` (e.g. `from_logits=True`).

    Returns:
        A `list` of ops which are the loss for each processing step.
    """
    # One scalar per processing step: the cross-entropy is applied separately to
    # each one-hot feature group and summed over all nodes and edges.
    losses = [ 
            tf.math.reduce_sum(
                tf.concat( [
                    cross_entropy(target.nodes[:, 0:2], output.nodes[:, 0:2], **kwds), # latency 
                    cross_entropy(target.nodes[:, 2:4], output.nodes[:, 2:4], **kwds),  # error
                    cross_entropy(target.nodes[:, 4:6], output.nodes[:, 4:6], **kwds),  # saturation 
                    cross_entropy(target.nodes[:, 6:], output.nodes[:, 6:], **kwds)  # traffic
                ], axis=0)
              ) 
        +
            tf.math.reduce_sum(
              tf.concat([
                cross_entropy(target.edges[:, 0:2], output.edges[:, 0:2], **kwds), 
                cross_entropy(target.edges[:, 2:], output.edges[:, 2:], **kwds)
              ], axis=0))
    for output in predicted ]

    return losses


In [None]:
#@title Accuracy 
def compute_accuracy(target, output):
  """Calculate model accuracy.

  Compares the arg-max of every one-hot feature group (node latency, errors,
  saturation and traffic; edge latency and traffic) in the output with the target.

  Args:
    target: A `graphs.GraphsTuple` that contains the target graph.
    output: A `graphs.GraphsTuple` that contains the output graph.

  Returns:
    correct: A `float` fraction of correctly labeled node and edge features.
    solved: A `float` fraction of graphs that are completely correctly labeled.
  """
  tdds = utils_np.graphs_tuple_to_data_dicts(target)
  odds = utils_np.graphs_tuple_to_data_dicts(output)
  cs = []
  ss = []
  for td, od in zip(tdds, odds):
    latency_xn = np.argmax(td["nodes"][:, 0:2], axis=-1)
    latency_yn = np.argmax(od["nodes"][:, 0:2], axis=-1)
    
    errors_xn = np.argmax(td["nodes"][:, 2:4], axis=-1)
    errors_yn = np.argmax(od["nodes"][:, 2:4], axis=-1)
    
    saturation_xn = np.argmax(td["nodes"][:, 4:6], axis=-1)
    saturation_yn = np.argmax(od["nodes"][:, 4:6], axis=-1)
    
    traffic_xn = np.argmax(td["nodes"][:, 6:], axis=-1)
    traffic_yn = np.argmax(od["nodes"][:, 6:], axis=-1)
    
    latency_xe = np.argmax(td["edges"][:, 0:2], axis=-1)
    latency_ye = np.argmax(od["edges"][:, 0:2], axis=-1)
    
    traffic_xe = np.argmax(td["edges"][:, 2:], axis=-1)
    traffic_ye = np.argmax(od["edges"][:, 2:], axis=-1)


    # element-wise matches between target and output, for every node and edge feature
    c = np.concatenate((latency_xn == latency_yn, 
                        errors_xn == errors_yn, 
                        saturation_xn == saturation_yn,
                        traffic_xn == traffic_yn,
                        traffic_xe == traffic_ye,
                        latency_xe == latency_ye), 
                       axis=0)
    # solved
    s = np.all(c)
    cs.append(c)
    ss.append(s)
  correct = np.mean(np.concatenate(cs, axis=0))
  solved = np.mean(np.stack(ss))
  return correct, solved



- **Ltr**: training loss
- **Lge**: test/generalization loss
- **Ctr**: training fraction of node/edge features labeled correctly
- **Str**: training fraction of examples solved correctly
- **Cge**: test/generalization fraction of node/edge features labeled correctly
- **Sge**: test/generalization fraction of examples solved correctly
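In `compute_accuracy`, the "correct" metrics (Ctr, Cge) pool every node and edge feature of every graph and count the fraction whose arg-max matches the target, while the "solved" metrics (Str, Sge) count the fraction of graphs in which all of them match. A toy NumPy illustration with hypothetical helpers `attribute_matches` and `correct_and_solved`:

```python
import numpy as np


def attribute_matches(target_group, output_group):
    """True where the arg-max of one feature group matches the target (one row per node/edge)."""
    return np.argmax(target_group, axis=-1) == np.argmax(output_group, axis=-1)


def correct_and_solved(per_graph_matches):
    """per_graph_matches: one boolean array per graph with all of its feature matches."""
    correct = np.mean(np.concatenate(per_graph_matches))     # pooled over all graphs
    solved = np.mean([np.all(m) for m in per_graph_matches])  # whole-graph exact match
    return correct, solved
```

Because "correct" is pooled, larger graphs weigh more in it, and a single wrong feature is enough to mark a graph as not solved.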

In [None]:
import gc

In [None]:
#@title # Reset State 
try:
    del model
except:
    pass
gc.collect()

NUM_PROCESSING_STEPS=10 #@param 

#@markdown # Data / training parameters
# num_training_iterations = 50000 #@param 
BATCH_SIZE_TRAIN = 32 #@param 
BATCH_SIZE_VAL = 100 #@param



In [None]:
import sonnet as snt
import tensorflow as tf

def make_mlp_model(latent_size, num_layers, final_output_size):
  """Instantiates a new MLP, followed by LayerNorm.
  The parameters of each new MLP are not shared with others generated by
  this function.
  Returns:
    A Sonnet module which contains the MLP and LayerNorm.
  """
  return snt.Sequential([
      snt.nets.MLP(output_sizes=[latent_size] * num_layers + [final_output_size], activate_final=True),
      snt.LayerNorm(axis=-1, create_offset=True, create_scale=True)
  ])

#@title Interaction Network module
# alternative model (commented out below): an InteractionNetwork whose reducer sums the incoming edge messages
from graph_nets.modules import InteractionNetwork
import functools

NUM_EDGE_FEATURES = 2 + 3 
NUM_NODE_FEATURES = 2 + 2 + 2 + 3

# edge_model_fn = functools.partial(make_mlp_model, 
#                                   latent_size=64, 
#                                   num_layers=10, 
#                                   final_output_size=NUM_EDGE_FEATURES) 

# node_model_fn = functools.partial(make_mlp_model, 
#                                   latent_size=64, 
#                                   num_layers=20, 
#                                   final_output_size=NODE_FEATURES_DIM)

# model = InteractionNetwork(edge_model_fn, 
#                            node_model_fn, 
#                            reducer=tf.math.unsorted_segment_sum)


# Model
model = models.EncodeProcessDecode(edge_output_size=NUM_EDGE_FEATURES, 
                                   node_output_size=NUM_NODE_FEATURES, 
                                   global_output_size=None)



In [None]:
#@title Setup Optimizer, GraphNetwork
learning_rate = 0.001 #@param
optimizer = snt.optimizers.Adam(learning_rate)

In [None]:
#@title Update Step, this function updates the weights of our model based on the loss

# def update_step(inputs, targets):
#   with tf.GradientTape() as tape:
#     output = model(inputs)
#     list_loss = create_loss(targets, [output], from_logits=True)
#     # sum the losses of each step and averages over the number of steps
#     # total_loss = tf.math.reduce_sum(list_loss) / NUM_PROCESSING_STEPS 
#     total_loss = tf.math.reduce_sum(list_loss)

#   gradients = tape.gradient(total_loss, model.trainable_variables)
#   optimizer.apply(gradients, model.trainable_variables)
#   return [output], total_loss, list_loss

def update_step(inputs, targets):
  with tf.GradientTape() as tape:
    output = model(inputs, NUM_PROCESSING_STEPS)
    list_loss = create_loss(targets, output, from_logits=True)
    # sum the losses of each step and average over the number of steps
    total_loss = tf.math.reduce_sum(list_loss) / NUM_PROCESSING_STEPS

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply(gradients, model.trainable_variables)
  return output, total_loss, list_loss


# Test update step

In [None]:
NUM_NODES = 10
NUM_EDGES = 15
g = create_graph_dict(NUM_NODES, NUM_EDGES,1)
gt = utils_tf.data_dicts_to_graphs_tuple([g])
gt_list = simulator(gt, NUM_PROCESSING_STEPS)

# batch with only the last element
gt_truth = gt_list[-1]

input_batch = utils_tf.concat([gt], axis=0)
target_batch = utils_tf.concat([gt_truth], axis=0)

In [None]:
def create_example(num_nodes, num_edges, i,timesteps=NUM_PROCESSING_STEPS):
    g = create_graph_dict(num_nodes, num_edges,i)
    gt = utils_tf.data_dicts_to_graphs_tuple([g])
    
    # simulate the failure propagation
    g_list = simulator(gt, steps=timesteps)
    gt_final_state = g_list[-1]
    return gt, gt_final_state


def gen_batch(batch_size):
    # Compute a random graph. 
    # NOTE: We need a better heuristic to generate graphs that 
    #       represent interesting cases for our problem
    inputs_batch = []
    targets_batch = []

    num_nodes = NUM_NODES
    num_edges = NUM_EDGES
    for i in range(batch_size):
        g = create_graph_dict(num_nodes, num_edges,i)
        gt = utils_tf.data_dicts_to_graphs_tuple([g])
        
        # simulate the failure propagation
        gt_list = simulator(gt, steps=NUM_PROCESSING_STEPS)
        inputs_batch.append(gt)
        targets_batch.append(gt_list[-1])

    inputs_tf_batch = utils_tf.concat(inputs_batch, axis=0)
    targets_tf_batch = utils_tf.concat(targets_batch, axis=0)
    return inputs_tf_batch, targets_tf_batch

In [None]:
input_batch, target_batch = gen_batch(BATCH_SIZE_TRAIN)

In [None]:
len(input_batch[0]), len(target_batch[0])

In [None]:
# Get the input signature for that function by obtaining the specs
batch_input_signature = [
  utils_tf.specs_from_graphs_tuple(input_batch),
  utils_tf.specs_from_graphs_tuple(target_batch)
]

# Compile the update function using the input signature for speedy code.
compiled_batch_update_step = tf.function(update_step, input_signature=batch_input_signature)

Let's compile a single-graph update step for debugging.

In [None]:
#@title Compile single graph update step
gt, gt_truth = create_example(NUM_NODES, NUM_EDGES,1)

single_input_signature = [
  utils_tf.specs_from_graphs_tuple(gt),
  utils_tf.specs_from_graphs_tuple(gt_truth)
]

single_compiled_update_step = tf.function(update_step, 
                                          input_signature=single_input_signature)


## Step-by-Step Test on a Single Example

In [None]:
g_in, g_out = create_example(NUM_NODES, NUM_EDGES,1)

In [None]:
plot_compare_graphs([g_in, g_out], ["input", "output"])

In [None]:
# train for a bit
for i in range(10):
    print(i)
    steps, loss, list_loss = single_compiled_update_step(g_in, g_out)
    labels = [ f"loss:{v.numpy()}" for v in list_loss]
    #plot_compare_graphs(steps, labels)
    # print(compute_accuracy(g_out, steps[-1]))

In [None]:
g_in, g_out = create_example(NUM_NODES, NUM_EDGES,1)
plot_compare_graphs([g_in, g_out], ["input", "output"])
results = model(g_in, NUM_PROCESSING_STEPS)
plot_compare_graphs([results[-1], g_out], ["result", "target"])

In [None]:
compute_accuracy(results[-1], g_out)

In [None]:
#@title sample with 5 nodes and 6 edges
sample_graph = create_graph_dict(5, 6, 0)
graph_tuple = utils_tf.data_dicts_to_graphs_tuple([sample_graph])
## it seems we always have to flip the first
results = simulator(graph_tuple, steps=9)
plot_compare_graphs(results, [ f"step{i}" for i in range(10) ])
# results = simulator(results[-1], steps=4)
# plot_compare_graphs(results, [ f"step{i+5}" for i in range(5) ])
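To make the Phase 2 rules easier to check against these plots, here is a minimal pure-Python sketch of one way to apply them to a small directed graph. It is not the notebook's `simulator`; `NodeState`, `sim_step`, `simulate` and `healthy` are hypothetical names. Two behaviours are assumptions rather than stated rules: once Latency, Error or Saturated is set it stays set, and Traffic only changes through rules 2 and 7. Edges are `(parent, child)` pairs, so downstreams are children and upstreams are parents.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Tuple

LOW, MED, HIGH = 0, 1, 2  # Traffic levels


@dataclass(frozen=True)
class NodeState:
    latency: int = 0
    error: int = 0
    saturated: int = 0
    traffic: int = MED


def healthy(s: NodeState) -> bool:
    # health = med traffic && !error && !latency && !saturated
    return s.traffic == MED and not (s.error or s.latency or s.saturated)


def sim_step(states: Dict[str, NodeState],
             edges: List[Tuple[str, str]]) -> Dict[str, NodeState]:
    parents = {n: [] for n in states}
    children = {n: [] for n in states}
    for p, c in edges:
        children[p].append(c)
        parents[c].append(p)

    nxt = {}
    for n, s in states.items():
        up_low = any(states[p].traffic == LOW for p in parents[n])
        up_lat = any(states[p].latency for p in parents[n])
        up_err = any(states[p].error for p in parents[n])
        down_lat = any(states[c].latency for c in children[n])
        nxt[n] = NodeState(
            saturated=int(s.saturated or s.traffic == HIGH),       # rule 1
            latency=int(s.latency or s.saturated or down_lat),    # rules 3, 4
            error=int(s.error or s.latency or up_lat or up_err),  # rules 5, 6, 8
            traffic=LOW if up_low else s.traffic,                 # rule 7
        )
    # Rule 2: saturation lowers downstream traffic within the same step.
    for n in states:
        if nxt[n].saturated:
            for c in children[n]:
                nxt[c] = replace(nxt[c], traffic=LOW)
    return nxt


def simulate(states, edges, steps):
    """Return the initial state followed by one state per step."""
    history = [states]
    for _ in range(steps):
        history.append(sim_step(history[-1], edges))
    return history


start = {"A": NodeState(traffic=HIGH), "B": NodeState(), "C": NodeState()}
hist = simulate(start, [("A", "B"), ("B", "C")], steps=3)
```

On the chain `A -> B -> C` with high traffic on `A`, the sketch saturates `A` and drops `B` to low traffic after one step, gives `A` latency after two steps, and produces errors on `A` and `B` after three.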


# LOOP

In [None]:
#@title init state
last_iteration = 0
logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []

log_every_seconds = 10 #@param
num_processing_steps_tr = 10 #@param

In [None]:
from tqdm import auto as tqdm

In [None]:
#@title Train Loop
num_training_iterations = 3000
start_time = time.time()
last_log_time = start_time
for iteration in tqdm.tqdm(range(last_iteration, num_training_iterations)):
  last_iteration = iteration
  n_nodes, n_edges = NUM_NODES, NUM_EDGES

  # outputs_tr, loss_tr = compiled_batch_update_step(inputs_tr, targets_tr)
  g_in, g_out = create_example(n_nodes, n_edges, iteration)

  outputs_tr, total_loss_tr, _ = single_compiled_update_step(g_in, g_out)

  the_time = time.time()
  elapsed_since_last_log = the_time - last_log_time

  # keep looping 
  if elapsed_since_last_log <= log_every_seconds: continue

  # LOG 
  num_processing_steps_ge = num_processing_steps_tr
  g_val_in, g_val_out = create_example(n_nodes+1, n_edges+2,iteration)

  last_log_time = the_time
  outputs_val = model(g_val_in, NUM_PROCESSING_STEPS)
  loss_list = create_loss(g_val_out, outputs_val)
  loss_ge = loss_list[-1]

  # outputs_val holds one output graph per processing step; the last one is evaluated

  # plot_compare_graphs( [outputs_val[-1], g_val_out], 
  #                     [ "result", "ground truth"])

  correct_tr, solved_tr = compute_accuracy(g_out, outputs_tr[-1])
  correct_ge, solved_ge = compute_accuracy(g_val_out, outputs_val[-1])
  
  elapsed = time.time() - start_time
  # train
  losses_tr.append(total_loss_tr.numpy())
  corrects_tr.append(correct_tr)
  solveds_tr.append(solved_tr)
  # val
  losses_ge.append(loss_ge.numpy())
  corrects_ge.append(correct_ge)
  solveds_ge.append(solved_ge)

  logged_iterations.append(iteration)
  print("# {:05d}, T {:05.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, "
          "Str {:.4f}, Cge {:.4f}, Sge {:.4f}".format(
              iteration, elapsed, total_loss_tr.numpy(), loss_ge.numpy(),
              correct_tr, solved_tr, correct_ge, solved_ge))

In [None]:
print("Training overall:")
print("Number of logged evaluations: " + str(len(corrects_tr)))
print(corrects_tr)
print(solveds_tr)
average = sum(corrects_tr) / len(corrects_tr)
print("Average of correctly classified nodes/edges (training): " + str(round(average, 2)))

print("Generalization overall:")
print("Number of logged evaluations: " + str(len(corrects_ge)))
print(corrects_ge)
print(solveds_ge)
average = sum(corrects_ge) / len(corrects_ge)
print("Average of correctly classified nodes/edges (generalization): " + str(round(average, 2)))


In [None]:
# print("# {:05d}, T {:05.1f} loss:train {:.4f} loss:gen {:.4f} correct:train {:.4f} "
#           "solved:train {:.4f} correct:gen {:.4f} solved:gen {:.4f}".format(
#               iteration, elapsed, total_loss_tr.numpy(), loss_ge.numpy(),
#               correct_tr, solved_tr, correct_ge, solved_ge))

## Manual Inspection

In [None]:
# results

In [None]:
# MAX_ITERATIONS = 20
# plot_compare_graphs([results], [ str(i) for i in range(MAX_ITERATIONS)])

In [None]:
MAX_ITERATIONS = 10
for node_variance in range(1, 10):
    for edge_variance in range(1, 10):
        g_in, g_out = create_example(NUM_NODES + node_variance, NUM_EDGES + edge_variance,1)
        for timesteps in range(1, MAX_ITERATIONS+1):
            results = model(g_in, timesteps)
            correct, solved = compute_accuracy(results[-1], g_out)
            if correct and solved:
                print('solved in %d' % timesteps)
                plot_compare_graphs([g_in, results[-1], g_out], [ "input", f"Correct:{correct:.2f} Solved:{solved:.2f}", "target"])
                break
            if timesteps == MAX_ITERATIONS:
                print("not solved")
                plot_compare_graphs([g_in, results[-1], g_out], [ "input", f"Correct:{correct:.2f} Solved:{solved:.2f}", "target"])


In [None]:
accuracy_list=[]
# the edge count should increase with the square of the node count

for i in range(1,  2):
    g_in, g_out = create_example(100, 625,1)
    results = model(g_in, NUM_PROCESSING_STEPS)
    correct, solved = compute_accuracy(results[-1], g_out)
    accuracy_list.append(correct)
    plot_compare_graphs_custom([g_in, results[-1], g_out], ["input","Prediction", "Actual"])
    # if i > 990:
    #   plot_compare_graphs([g_in, results[-1], g_out], [ "input", f"Prediction Correct:{correct:.2f} Solved:{solved:.2f}", "Actual"])

print("Large-graph generalization (100 nodes, 625 edges):")
print(accuracy_list)
average = sum(accuracy_list) / len(accuracy_list)
print("Average of correctly classified nodes/edges: " + str(round(average, 2)))
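The comment in the cell above says the edge count should grow with the square of the node count. One reading, which is our assumption, is that the goal is to keep edge density roughly constant, since a simple graph with `n` nodes has up to `n(n-1)/2` undirected node pairs. The hypothetical helper below scales a reference `(nodes, edges)` pair quadratically and caps the result at that maximum; it does not reproduce how the notebook picked 625.

```python
def scaled_edge_count(ref_nodes: int, ref_edges: int, num_nodes: int) -> int:
    """Scale an edge count with the square of the node count.

    Assumes a simple graph (no self-loops or parallel edges), so the
    result is capped at num_nodes * (num_nodes - 1) // 2.
    """
    if ref_nodes <= 0:
        raise ValueError("ref_nodes must be positive")
    scaled = round(ref_edges * (num_nodes / ref_nodes) ** 2)
    max_edges = num_nodes * (num_nodes - 1) // 2
    return min(scaled, max_edges)


print(scaled_edge_count(20, 25, 100))  # 625
```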

In [None]:
plt.plot(losses_tr)
# plt.plot(accuracy_list)

In [None]:
plt.figure(figsize=(50, 7))
plt.plot(corrects_tr)
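Because each training step above uses a single random graph, the logged loss and accuracy curves are noisy. A trailing moving average is one simple way to read the trend; the sketch below is a hypothetical helper, not part of the notebook, and could be plotted with `plt.plot(moving_average(corrects_tr, 20))`.

```python
from typing import List, Sequence


def moving_average(values: Sequence[float], window: int) -> List[float]:
    """Trailing mean over the last `window` values (shorter window at the start)."""
    if window < 1:
        raise ValueError("window must be >= 1")
    out = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value leaving the window
        out.append(running / min(i + 1, window))
    return out


print(moving_average([1, 2, 3, 4], 2))  # [1.0, 1.5, 2.5, 3.5]
```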

In [None]:
%watermark -d -t -n