# Graph Failure Propagation using DeepMind's graph_nets library

This notebook models a network failure propagation scenario using a Graph Neural Network.

The framework used is [DeepMind's Graph Nets library](https://github.com/deepmind/graph_nets/).

## Install Libraries

In [None]:
!df -h

In [None]:
!cat /proc/cpuinfo

In [None]:
!cat /proc/meminfo

In [None]:
%pip install --quiet watermark

In [None]:
%pip install --quiet "graph_nets>=1.1" "dm-sonnet>=2.0.0b0"

In [None]:
%load_ext watermark

In [None]:
%watermark -u -n -t -z -p numpy,tensorflow,sonnet,graph_nets -g

## Configure Environment



In [None]:
#@title Imports
import time
import copy
import itertools

import tensorflow as tf
import networkx as nx
import sonnet as snt

from matplotlib import pyplot as plt
try:
  import seaborn as sns
except ImportError:
  pass
else:
  sns.reset_orig()

from graph_nets import blocks
from graph_nets import utils_tf
from graph_nets import utils_np
from graph_nets.demos_tf2 import models

import numpy as np


In [None]:
#@title Set Random seeds
# Seed NumPy's global RNG (used by create_graph_dict to shuffle candidate
# edges) and TensorFlow's global RNG (presumably used for model weight
# initialisation) so that Restart & Run All reproduces the same run.
SEED = 42 #@param
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
#@title Load Graph Visualization Functions

def plot_graphs_tuple(graphs_tuple):
  """Draw every graph contained in `graphs_tuple` side by side.

  Each graph gets its own 5x5-inch panel and is drawn by `plot_graph_networkx`.
  """
  nx_graphs = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
  count = len(nx_graphs)
  _, panel_axes = plt.subplots(1, count, figsize=(5 * count, 5))
  # With a single panel plt.subplots returns a bare Axes, not an array.
  if count == 1:
    panel_axes = (panel_axes,)
  for idx, nx_graph in enumerate(nx_graphs):
    plot_graph_networkx(nx_graph, panel_axes[idx])


def plot_graph_networkx(graph, ax, pos=None):
  """Draw one networkx graph on `ax`, colouring failed elements red.

  Node and edge labels show their first two feature values. A node or edge is
  drawn red when the argmax of its features is 1 (faulty) and green otherwise.
  If the graph carries a global feature vector, its first entry is written in
  the top-left corner of `ax`.

  Args:
    graph: a networkx graph as produced by `utils_np.graphs_tuple_to_networkxs`,
      with a "features" entry on every node, every edge and on `graph.graph`.
    ax: matplotlib Axes to draw on.
    pos: optional dict of node positions. When None, a spring layout is
      computed starting from a fixed-seed random layout.

  Returns:
    The node-position dict that was used, so callers can pass it back in to
    keep several drawings of the same graph aligned.
  """
  node_labels = {node: "{:.2g}|{:.2g}".format(*data["features"])
                 for node, data in graph.nodes(data=True)
                 if data["features"] is not None}
  edge_labels = {(sender, receiver): "{:.2g}|{:.2g}".format(*data["features"])
                 for sender, receiver, data in graph.edges(data=True)
                 if data["features"] is not None}
  global_label = ("{:.3g}".format(graph.graph["features"][0])
                  if graph.graph["features"] is not None else None)
  # Class 1 of the one-hot features means "faulty".
  node_color_map = ["r" if np.argmax(data["features"]) == 1 else "g"
                    for _, data in graph.nodes(data=True)]
  edge_color_map = ["r" if np.argmax(data["features"]) == 1 else "g"
                    for _, _, data in graph.edges(data=True)]

  if pos is None:
    random_pos = nx.random_layout(graph, seed=42)
    pos = nx.spring_layout(graph, pos=random_pos)

  nx.draw_networkx(graph, pos, ax=ax,
                   labels=node_labels,
                   edge_color=edge_color_map,
                   node_color=node_color_map)

  if edge_labels:
    nx.draw_networkx_edge_labels(graph, pos, edge_labels, ax=ax)

  if global_label is not None:
    # Draw on `ax` itself: plt.text would attach the artist to pyplot's current
    # axes (typically the last subplot created), not necessarily to `ax`.
    ax.text(0.05, 0.95, global_label, transform=ax.transAxes)

  ax.yaxis.set_visible(False)
  ax.xaxis.set_visible(False)
  return pos


def plot_compare_graphs(graphs_tuples, labels, panel_size=50):
  """Draw the first graph of each GraphsTuple side by side with a shared layout.

  The node positions computed for the first panel are reused for every later
  panel, so a given node appears at the same place in all drawings.

  Args:
    graphs_tuples: sequence of `GraphsTuple`s; only graph 0 of each is drawn.
    labels: panel titles, matched to `graphs_tuples` by position. If fewer
      labels than graphs are given, the remaining panels get an empty title
      instead of being silently skipped.
    panel_size: width and height of each panel in inches (default 50, the size
      this notebook has always used; large graphs need large panels).
  """
  num_graphs = len(graphs_tuples)
  _, axes = plt.subplots(1, num_graphs,
                         figsize=(panel_size * num_graphs, panel_size))
  # With a single panel plt.subplots returns a bare Axes, not an array.
  if num_graphs == 1:
    axes = (axes,)
  # Pad missing titles so zip() below does not drop any graph.
  titles = list(labels) + [""] * (num_graphs - len(labels))
  pos = None
  for name, graphs_tuple, ax in zip(titles, graphs_tuples, axes):
    graph = utils_np.graphs_tuple_to_networkxs(graphs_tuple)[0]
    pos = plot_graph_networkx(graph, ax, pos=pos)
    ax.set_title(name)

In [None]:
#@title Create Graph Dict { form-width: "30%" }
#@markdown number of features per node
NODE_FEATURES_DIM = 2 #@param
#@markdown number of features per edge
EDGES_FEATURES_DIM = 2 #@param

def create_graph_dict(nodes_n, edges_n):
  """Build a random directed graph (graph_nets data dict) with one failed node.

  Nodes and edges are one-hot encoded as healthy `[1, 0]` or faulty `[0, 1]`
  (feature 0 = healthy, feature 1 = faulty). All edges start healthy, and all
  nodes start healthy except node 0, which is the initial failure.

  Edges are sampled without replacement from the pairs `(a, b)` with `a < b`.
  The result is therefore a DAG, and failures can only flow from lower to
  higher node indices. Sampling uses NumPy's global RNG.

  Args:
    nodes_n: number of nodes; must be at least 1 (node 0 is always faulty).
    edges_n: number of edges requested. If it exceeds the number of possible
      pairs, nodes_n * (nodes_n - 1) / 2, every pair is used exactly once.

  Returns:
    A dict with keys "globals", "nodes", "edges", "receivers", "senders", as
    accepted by `utils_tf.data_dicts_to_graphs_tuple`. "nodes" and "edges" are
    2-D float32 arrays; "edges" has shape (num_edges, EDGES_FEATURES_DIM) even
    when there are no edges.

  Raises:
    ValueError: if nodes_n < 1.
  """
  if nodes_n < 1:
    raise ValueError("nodes_n must be at least 1, got {}".format(nodes_n))

  # Nodes: all healthy (feature 0 set), except node 0, the initial failure
  # (feature 1 set).
  nodes = np.zeros((nodes_n, NODE_FEATURES_DIM), dtype=np.float32)
  nodes[:, 0] = 1.
  nodes[0, :2] = [0., 1.]

  # We don't have a global state (for now).
  globals_ft = [0.0, 0.0]

  # Edges: a random subset of the pairs (a, b) with a < b, all healthy.
  candidate_pairs = list(itertools.combinations(np.arange(nodes_n), 2))
  order = np.arange(len(candidate_pairs))
  np.random.shuffle(order)
  chosen = order[:edges_n]

  senders = [candidate_pairs[ix][0] for ix in chosen]
  receivers = [candidate_pairs[ix][1] for ix in chosen]
  edges = np.zeros((len(chosen), EDGES_FEATURES_DIM), dtype=np.float32)
  edges[:, 0] = 1.

  return {
      "globals": globals_ft,
      "nodes": nodes,
      "edges": edges,
      "receivers": receivers,
      "senders": senders
  }

In [None]:
#@title Flip Functions
#@markdown These functions propagate the failure across the Graph
def flip_edges(graph):
    """Mark every outgoing edge of a failed node as failed.

    A node counts as failed when its feature 0 (the "healthy" flag) is 0.
    Every edge whose sender is such a node is set to the faulty one-hot
    `[0, 1]`. Node features are returned unchanged.

    Args:
        graph: an eager `GraphsTuple` whose node and edge features are 2-wide
            one-hot vectors.

    Returns:
        A new `GraphsTuple` with the updated edges.
    """
    nodes = graph.nodes.numpy()
    edges = graph.edges.numpy()
    senders = graph.senders.numpy()

    # Boolean lookup table indexed by node id. The previous `sender in list`
    # check made this step O(num_nodes * num_edges).
    failed_nodes = nodes[:, 0] == 0
    if len(edges) > 0:
        edges[failed_nodes[senders]] = [0.0, 1.0]

    return graph.replace(
        nodes=tf.convert_to_tensor(nodes),
        edges=tf.convert_to_tensor(edges)
    )
def flip_nodes(graph):
    """Mark every node that receives a failed edge as failed.

    An edge counts as failed when its feature 0 (the "healthy" flag) is 0.
    The receiver of each such edge is set to the faulty one-hot `[0, 1]`.
    Edge features are returned unchanged.

    Args:
        graph: an eager `GraphsTuple` whose node and edge features are 2-wide
            one-hot vectors.

    Returns:
        A new `GraphsTuple` with the updated nodes.
    """
    nodes = graph.nodes.numpy()
    edges = graph.edges.numpy()
    receivers = graph.receivers.numpy()

    if len(edges) > 0:
        failed_edges = edges[:, 0] == 0
        nodes[receivers[failed_edges]] = [0.0, 1.0]

    return graph.replace(nodes=tf.convert_to_tensor(nodes))

# Simulator

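With these two functions we can build a simulator that produces the ground-truth failure propagation. Starting from the initial graph, it alternates two half-steps. On even steps, every outgoing edge of a failed node becomes failed (`flip_edges`). On odd steps, every node that receives a failed edge becomes failed (`flip_nodes`). The simulator returns the graph state after every step.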

In [None]:
def simulator(initial_state, steps):
    """Run `steps` propagation half-steps and return every intermediate graph.

    Even-numbered steps (0, 2, ...) push failures from nodes to their outgoing
    edges via `flip_edges`. Odd-numbered steps push failures from edges to
    their receiving nodes via `flip_nodes`.

    Returns:
        A list of `steps + 1` graphs: a shallow copy of `initial_state`,
        followed by a shallow copy of the state after each step.
    """
    history = [copy.copy(initial_state)]
    state = initial_state
    for step in range(steps):
        update = flip_nodes if step % 2 else flip_edges
        state = update(state)
        history.append(copy.copy(state))
    return history

# Create the Model


## Encode Process Decode

We will use the EncodeProcessDecode module.

The number of message-passing rounds is chosen when the model is called: it is the second argument, `NUM_PROCESSING_STEPS` below.

Parameters are shared between the different processing steps.

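The model has three parts:

- **Encoder**: maps the input node and edge features to a latent representation.
- **Core**: a graph net that performs `N` rounds of processing (message-passing) steps. Its input is the concatenation of the Encoder's output and the Core's previous output ("Hidden(t)", where "t" is the processing step).
- **Decoder**: maps the Core's output after each step to node and edge outputs (here, one logit per class).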



In [None]:
#@title EncodeProcessDecode module
# Instantiate the graph_nets demo model. Edge and node outputs have the same
# width as the input features (one logit per class); there is no global output.
# NOTE(review): this instance is deleted and re-created in the "Reset State"
# cell further down; the later instance is the one that gets trained.
model = models.EncodeProcessDecode(edge_output_size=EDGES_FEATURES_DIM, 
                                   node_output_size=NODE_FEATURES_DIM, 
                                   global_output_size=None)

# Loss Function
The training loss is computed on the output of each processing step. 

The reason for this is to encourage the model to try to solve the problem in as
few steps as possible. 

It also helps make the output of intermediate steps more interpretable.

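The per-step losses are then averaged over all processing steps to form the training objective (see `update_step`). The accuracy metrics use only the final step's output.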

For each node and edge we compute the classification error with the categorical cross-entropy loss over the two classes (healthy, faulty):

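$$\mathcal{L} = -\sum_{c} y_c \log p_c = -\left[y \log p + (1-y)\log(1-p)\right]$$

Here $y$ is 1 for a faulty target and 0 otherwise, and $p$ is the predicted probability of "faulty", i.e. the softmax of the model's logits.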

The loss for each step is the sum of the losses for each node, plus the sum of the losses for each edge.


In [None]:
#@title Loss Function

cross_entropy = tf.losses.categorical_crossentropy

def create_loss(target, predicted, **kwds):
    """Per-step categorical cross-entropy, summed over all nodes and edges.

    Args:
        target: a `graphs.GraphsTuple` holding the one-hot target nodes/edges.
        predicted: a `list` of `graphs.GraphsTuple`s, the model output after
            each processing step.
        **kwds: forwarded unchanged to `tf.losses.categorical_crossentropy`,
            e.g. `from_logits=True` when the outputs are unnormalised logits.

    Returns:
        A `list` with one scalar loss per processing step: the summed node
        cross-entropy plus the summed edge cross-entropy of that step's output.
    """
    def step_loss(output):
        node_term = tf.math.reduce_sum(cross_entropy(target.nodes, output.nodes, **kwds))
        edge_term = tf.math.reduce_sum(cross_entropy(target.edges, output.edges, **kwds))
        return node_term + edge_term

    return [step_loss(output) for output in predicted]


In [None]:
#@title Accuracy 
def compute_accuracy(target, output):
  """Calculate model accuracy over nodes and edges.

  Each node and edge is assigned the class given by the argmax of its
  features. An element is correct when the predicted class equals the target
  class.

  Args:
    target: A `graphs.GraphsTuple` that contains the target graph(s).
    output: A `graphs.GraphsTuple` that contains the output graph(s).

  Returns:
    correct: A `float` fraction of correctly labeled nodes and edges, pooled
      over all graphs.
    solved: A `float` fraction of graphs whose nodes and edges are all
      correctly labeled.
  """
  per_element_hits = []
  per_graph_solved = []
  target_dicts = utils_np.graphs_tuple_to_data_dicts(target)
  output_dicts = utils_np.graphs_tuple_to_data_dicts(output)
  for target_dict, output_dict in zip(target_dicts, output_dicts):
    node_hits = (np.argmax(target_dict["nodes"], axis=-1)
                 == np.argmax(output_dict["nodes"], axis=-1))
    edge_hits = (np.argmax(target_dict["edges"], axis=-1)
                 == np.argmax(output_dict["edges"], axis=-1))
    hits = np.concatenate((node_hits, edge_hits), axis=0)
    per_element_hits.append(hits)
    per_graph_solved.append(np.all(hits))
  correct = np.mean(np.concatenate(per_element_hits, axis=0))
  solved = np.mean(np.stack(per_graph_solved))
  return correct, solved


In [None]:
# Sanity check: a graph compared against itself must be fully correct and
# fully solved, i.e. compute_accuracy must return (1.0, 1.0).
graph_tuple = utils_tf.data_dicts_to_graphs_tuple([create_graph_dict(5, 5)])

input_graph_tuple = graph_tuple
# Only a cell's last expression is displayed, so print the shapes explicitly.
print("nodes:", input_graph_tuple.nodes.shape, "edges:", input_graph_tuple.edges.shape)
self_accuracy = compute_accuracy(input_graph_tuple, input_graph_tuple)
assert self_accuracy == (1.0, 1.0), self_accuracy
self_accuracy

# Train Loop


- **Ltr**: training loss
- **Lge**: test/generalization loss
- **Ctr**: fraction of training nodes/edges labeled correctly
- **Str**: fraction of training examples solved completely (every node and edge correct)
- **Cge**: fraction of test/generalization nodes/edges labeled correctly
- **Sge**: fraction of test/generalization examples solved completely (every node and edge correct)

In [None]:
import gc

In [None]:
#@title # Reset State
# Drop any model left over from an earlier run of this cell so that training
# always starts from a freshly built model.
try:
    del model
except NameError:
    # No model defined yet (e.g. the earlier cell was skipped): nothing to drop.
    pass
gc.collect()

NUM_PROCESSING_STEPS=10 #@param 

#@markdown # Data / training parameters
num_training_iterations = 2000 #@param 
BATCH_SIZE_TRAIN = 32 #@param 
# NOTE(review): BATCH_SIZE_VAL is not used by any cell visible here.
BATCH_SIZE_VAL = 100 #@param

EDGES_OUTPUT_DIM = EDGES_FEATURES_DIM
NODES_OUTPUT_DIM = NODE_FEATURES_DIM

# Model
model = models.EncodeProcessDecode(edge_output_size=EDGES_OUTPUT_DIM, 
                                   node_output_size=NODES_OUTPUT_DIM, 
                                   global_output_size=None)



In [None]:
#@title Setup Optimizer, GraphNetwork
# Adam optimizer that update_step uses to apply the gradients of the
# step-averaged training loss to the model's trainable variables.
learning_rate = 1e-3 #@param
optimizer = snt.optimizers.Adam(learning_rate)

In [None]:
#@title Update Step, this function updates the weights of our model based on the loss

def update_step(inputs, targets):
  """Run one optimisation step of the global `model` on one batch.

  The model is unrolled for NUM_PROCESSING_STEPS message-passing steps. The
  cross-entropy (on logits) of every step's output is computed, and the
  per-step losses are summed and divided by NUM_PROCESSING_STEPS before
  differentiating.

  Returns:
    (outputs, total_loss, list_loss): the per-step model outputs, the
    step-averaged loss, and the list of per-step losses.
  """
  with tf.GradientTape() as tape:
    step_outputs = model(inputs, NUM_PROCESSING_STEPS)
    step_losses = create_loss(targets, step_outputs, from_logits=True)
    mean_loss = tf.math.reduce_sum(step_losses) / NUM_PROCESSING_STEPS

  variables = model.trainable_variables
  optimizer.apply(tape.gradient(mean_loss, variables), variables)
  return step_outputs, mean_loss, step_losses


In [None]:
NUM_NODES = 5
NUM_EDGES = 6

def create_example(num_nodes, num_edges, timesteps=NUM_PROCESSING_STEPS):
    """Generate one (input, target) training pair.

    The input is a fresh random graph from `create_graph_dict` (node 0
    failed). The target is that graph after `timesteps` half-steps of
    `simulator`.

    Returns:
        (input GraphsTuple, target GraphsTuple), each holding one graph.
    """
    graph_dict = create_graph_dict(num_nodes, num_edges)
    input_graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
    # Simulate the failure propagation; the final state is the target.
    target_graph = simulator(input_graph, steps=timesteps)[-1]
    return input_graph, target_graph


def gen_batch(batch_size, num_nodes=NUM_NODES, num_edges=NUM_EDGES):
    """Generate `batch_size` (input, target) pairs batched into GraphsTuples.

    NOTE: We need a better heuristic to generate graphs that
          represent interesting cases for our problem.

    Args:
        batch_size: number of graphs in the batch.
        num_nodes: number of nodes in every generated graph.
        num_edges: number of edges requested for every generated graph.

    Returns:
        (inputs, targets): two `GraphsTuple`s, each with `batch_size` graphs.
    """
    pairs = [create_example(num_nodes, num_edges, timesteps=NUM_PROCESSING_STEPS)
             for _ in range(batch_size)]
    inputs_batch = [input_graph for input_graph, _ in pairs]
    targets_batch = [target_graph for _, target_graph in pairs]
    return (utils_tf.concat(inputs_batch, axis=0),
            utils_tf.concat(targets_batch, axis=0))

In [None]:
# Build one training batch of BATCH_SIZE_TRAIN (input, target) graph pairs.
input_batch, target_batch = gen_batch(BATCH_SIZE_TRAIN)

In [None]:
# Total number of nodes in the input and target batches; they should match.
# NOTE(review): assumes index 0 of a GraphsTuple is its `nodes` field.
len(input_batch[0]), len(target_batch[0])


# LOOP

In [None]:
#@title init state
last_iteration = 0
logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []

log_every_seconds = 5 #@param
# Tie this to NUM_PROCESSING_STEPS, the step count update_step trains with, so
# the validation rollout in the train loop cannot drift from training.
num_processing_steps_tr = NUM_PROCESSING_STEPS

# Compile the single-graph update step.
gt, gt_truth = create_example(NUM_NODES, NUM_EDGES)

# Input signature taken from a one-graph example. The train loop feeds larger
# graphs than this, so node/edge counts must be dynamic in these specs.
single_input_signature = [
  utils_tf.specs_from_graphs_tuple(gt),
  utils_tf.specs_from_graphs_tuple(gt_truth)
]

single_compiled_update_step = tf.function(update_step, 
                                          input_signature=single_input_signature)
# Example pair used by the (commented-out) smoke-test cell below.
g_in, g_out = create_example(NUM_NODES, NUM_EDGES)

In [None]:
# # train for a bit
# for i in range(10):
#     print(i)
#     steps, loss, list_loss = single_compiled_update_step(g_in, g_out)
#     labels = [ f"loss:{v.numpy()}" for v in list_loss]
#     plot_compare_graphs(steps, labels)
#     print(compute_accuracy(g_out, steps[-1]))

In [None]:
#@title Sample failure propagation: 5 nodes and 6 edges
SAMPLE_NUM_NODES = 5
SAMPLE_NUM_EDGES = 6
SAMPLE_STEPS = 9

sample_graph = create_graph_dict(SAMPLE_NUM_NODES, SAMPLE_NUM_EDGES)
graph_tuple = utils_tf.data_dicts_to_graphs_tuple([sample_graph])
# Node 0 always starts as the failed node (see create_graph_dict).
results = simulator(graph_tuple, steps=SAMPLE_STEPS)
# simulator returns the initial state plus one graph per step, so derive the
# titles from the result instead of hard-coding their count.
plot_compare_graphs(results, [f"step{i}" for i in range(len(results))])

In [None]:
#@title Train Loop
start_time = time.time()
last_log_time = start_time
# Training graphs are larger than the NUM_NODES/NUM_EDGES defaults; validation
# graphs below are one node and two edges larger still.
n_nodes, n_edges = NUM_NODES + 30, NUM_EDGES + 40
num_processing_steps_ge = num_processing_steps_tr
# Set to True to draw prediction vs. target for every logged validation graph.
plot_validation = False

for iteration in range(last_iteration, num_training_iterations):
  last_iteration = iteration

  # One freshly generated single-graph example per optimisation step.
  g_in, g_out = create_example(n_nodes, n_edges)
  outputs_tr, total_loss_tr, _ = single_compiled_update_step(g_in, g_out)

  the_time = time.time()
  elapsed_since_last_log = the_time - last_log_time
  # Evaluate and log at most once every `log_every_seconds`.
  if elapsed_since_last_log <= log_every_seconds:
    continue
  last_log_time = the_time

  # LOG
  g_val_in, g_val_out = create_example(n_nodes + 1, n_edges + 2)
  outputs_val = model(g_val_in, num_processing_steps_ge)
  # The model outputs logits (update_step trains with from_logits=True), so
  # the validation loss must use the same setting. Averaging over the steps,
  # as update_step does, keeps Lge on the same scale as Ltr.
  loss_list = create_loss(g_val_out, outputs_val, from_logits=True)
  loss_ge = tf.math.reduce_sum(loss_list) / len(loss_list)

  if plot_validation:
    plot_compare_graphs([outputs_val[-1], g_val_out], ["Predicted", "Actual"])

  correct_tr, solved_tr = compute_accuracy(g_out, outputs_tr[-1])
  correct_ge, solved_ge = compute_accuracy(g_val_out, outputs_val[-1])

  elapsed = time.time() - start_time
  # train
  losses_tr.append(total_loss_tr.numpy())
  corrects_tr.append(correct_tr)
  solveds_tr.append(solved_tr)
  # val
  losses_ge.append(loss_ge.numpy())
  corrects_ge.append(correct_ge)
  solveds_ge.append(solved_ge)

  logged_iterations.append(iteration)
  print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, "
          "Str {:.4f}, Cge {:.4f}, Sge {:.4f}".format(
              iteration, elapsed, total_loss_tr.numpy(), loss_ge.numpy(),
              correct_tr, solved_tr, correct_ge, solved_ge))

In [None]:
def _report_split(title, split_name, corrects, solveds):
    """Print the logged accuracies of one split and their mean, if any."""
    print(title)
    print("Number of evaluated models: " + str(len(corrects)))
    print(corrects)
    print(solveds)
    if not corrects:
        # Nothing was logged, e.g. training finished within log_every_seconds.
        print("No evaluations were logged.")
        return
    average = sum(corrects) / len(corrects)
    print("The average of correct classified nodes (" + split_name + ") " + str(round(average, 2)))


_report_split("Training overall:", "Training", corrects_tr, solveds_tr)
_report_split("Generalization overall:", "Generalization", corrects_ge, solveds_ge)

## Manual Inspection

In [None]:
# MAX_ITERATIONS = 10
# for node_variance in range(1, 10):
#     for edge_variance in range(1, 10):
#         g_in, g_out = create_example(NUM_NODES + node_variance, NUM_EDGES + edge_variance)
#         for timesteps in range(1, MAX_ITERATIONS+1):
#             results = model(g_in, timesteps)
#             correct, solved = compute_accuracy(results[-1], g_out)
#             if correct and solved:
#                 plot_compare_graphs([g_in, results[-1], g_out], [ "input", f"Correct:{correct:.2f} Solved:{solved:.2f}", "target"])
#                 break
#             if timesteps == MAX_ITERATIONS:
#                 print("not solved")
#                 plot_compare_graphs([g_in, results[-1], g_out], [ "input", f"Correct:{correct:.2f} Solved:{solved:.2f}", "target"])


In [None]:
# Generalization check on a graph much larger than any seen during training.
LARGE_NUM_NODES = 500
LARGE_NUM_EDGES = 15625
LARGE_EVAL_STEPS = 5  # message-passing steps used for this evaluation
LARGE_NUM_TRIALS = 1

accuracy_list = []
for _ in range(LARGE_NUM_TRIALS):
    g_in, g_out = create_example(LARGE_NUM_NODES, LARGE_NUM_EDGES)
    results = model(g_in, LARGE_EVAL_STEPS)
    # compute_accuracy takes (target, output).
    correct, solved = compute_accuracy(g_out, results[-1])
    accuracy_list.append(correct)
    # Three graphs need three titles so each panel is labelled correctly.
    plot_compare_graphs([g_in, results[-1], g_out], ["Input", "Prediction", "Actual"])

print(f"{LARGE_NUM_NODES} nodes generalization:")
print(accuracy_list)
average = sum(accuracy_list) / len(accuracy_list)
print("The average of correct classified nodes/edges " + str(round(average, 2)))

In [None]:
# Loss curves at each logged iteration; Ltr and Lge were appended together.
fig, ax = plt.subplots(figsize=(18, 3))
ax.plot(logged_iterations, losses_tr, label="Ltr (train)")
ax.plot(logged_iterations, losses_ge, label="Lge (validation)")
ax.set(title="Loss at each logged iteration", xlabel="training iteration", ylabel="loss")
ax.legend();

In [None]:
%watermark -d -t -n