# Neural Networks on Graphs

This notebook describes how to run classification tasks on graphs. The deployed models rely on the "Graph Networks" framework, as defined in "Relational Inductive Biases, Deep Learning and Graph Networks" - [Battaglia et al. (2018)](https://arxiv.org/abs/1806.01261).

The notebook is divided into four parts:

#### 1. [Loading and Preprocessing](#h1)

This part includes helper functions for reading a .txt dataset file and converting it to a list of GraphDicts objects as specified in the **graph_nets** documentation. The GraphDicts format is a Python dictionary that contains all the information required to represent the topology of a graph and the attributes of its elements.
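
As a rough illustration of this dictionary format, a triangle graph might be written as follows. The key names mirror those used by the `base_graph` helper later in this notebook; the variable name `triangle` and all attribute values are made up for the example.

```python
# A hypothetical 3-node undirected triangle; each undirected edge is stored
# as two directed edges (sender -> receiver).
triangle = {
    "globals": [0.0],                                # one graph-level attribute
    "nodes": [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],   # one feature vector per node
    "edges": [[1.0]] * 6,                            # one feature vector per directed edge
    "senders":   [0, 1, 1, 2, 2, 0],
    "receivers": [1, 0, 2, 1, 0, 2],
}

# Basic consistency checks implied by the format.
num_nodes = len(triangle["nodes"])
num_edges = len(triangle["edges"])
assert len(triangle["senders"]) == len(triangle["receivers"]) == num_edges
assert all(0 <= i < num_nodes for i in triangle["senders"] + triangle["receivers"])
print(num_nodes, num_edges)
```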

#### 2. [Defining Architectures](#h2)

The architectures build on implementations included in the demo files of the **graph_nets** library. Most of the edge, node and global update functions are reimplemented to allow the modularity needed for testing different architectures. Specifically, the added features are:

- Customizable Batch Size
- Customizable MLPs for inner update functions
- Dropout and Skip Connections
- Deep Architectures

#### 3. [Building TensorFlow Graph](#h3)

This part contains the functions needed to build a complete training session, including:
- defining placeholders
- loss function and optimizer
- calculating metrics
- utility functions such as plotting, logging, making checkpoints 

#### 4. [Running Classification Tasks](#h4)

Classification tasks can be run either with a nested Cross Validation or with a Holdout set.

- **Cross Validation**: Cross validation is implemented as a single function. The nested CV is comprised of an outer loop where a dataset is split into Test/TrainVal sets, and an inner loop where the TrainVal set is further split into Train and Validation sets. The Validation set is used for making checkpoints of the model. During training, information about the process is returned as a log. Each fold's training process is plotted in graphs that depict the Losses and Accuracy Score of Train and Validation sets. At the end a ".txt" file with the confusion matrices of all folds is returned.

- **Holdout Set**: A dataset is split into Train/Val/Test sets which then are used for training, optimizing and testing a model. Training with a Holdout set was used mainly for testing new architectures and implementations. 
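
The nested split described above can be sketched in plain Python. This is a minimal illustration and not the notebook's own cross-validation function: the helper names `k_fold_indices` and `nested_cv_splits`, the fold counts and the seed are assumptions, and the folds are not stratified by class.

```python
import random

def k_fold_indices(indices, k, rng):
    """Shuffle `indices` and split them into k nearly equal folds."""
    shuffled = list(indices)
    rng.shuffle(shuffled)
    return [shuffled[i::k] for i in range(k)]

def nested_cv_splits(n_samples, outer_k=3, inner_k=2, seed=1):
    """Yield (train, val, test) index lists for every outer/inner fold pair."""
    rng = random.Random(seed)
    outer_folds = k_fold_indices(range(n_samples), outer_k, rng)
    for o, test in enumerate(outer_folds):
        # Everything outside the current outer fold forms the TrainVal set.
        train_val = [i for j, fold in enumerate(outer_folds) if j != o for i in fold]
        inner_folds = k_fold_indices(train_val, inner_k, rng)
        for v, val in enumerate(inner_folds):
            train = [i for j, fold in enumerate(inner_folds) if j != v for i in fold]
            yield train, val, test

splits = list(nested_cv_splits(n_samples=12))
for train, val, test in splits:
    # The three sets are disjoint and together cover the whole dataset.
    assert sorted(train + val + test) == list(range(12))
print(len(splits))  # one entry per (outer fold, inner fold) pair
```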



### Datasets

Classification is carried out on the following benchmark datasets from the field of Bioinformatics:

- **MUTAG**,
- **ENZYMES**,
- **PTC**,

but can be applied seamlessly to any of the datasets that are included in the corresponding GitHub repository.


In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import itertools
import time
import re
import os
import shutil

from operator import itemgetter

from graph_nets import graphs
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets import modules
from graph_nets.demos import models

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import tensorflow as tf
import sonnet as snt

SEED = 1
np.random.seed(SEED)
tf.set_random_seed(SEED)

## Downloading Datasets

The dataset versions used here come from the work **"Deep Graph Convolutional Neural Network (DGCNN)"** of [Zhang et al. (2018)](https://github.com/muhanzhang/dgcnn). These versions were preferred over those included in the [TU Dortmund repository](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets) because they were deemed more consistent with each other: they share a common graph dataset format, which allows easy code reuse without repeating basic preprocessing steps for each dataset.

#### Downloading MUTAG, ENZYMES, PTC Datasets from repository

- **MUTAG**:
https://raw.githubusercontent.com/ChNousias/graph-classification-thesis/master/datasets/MUTAG/MUTAG.txt

- **ENZYMES**:
https://raw.githubusercontent.com/ChNousias/graph-classification-thesis/master/datasets/ENZYMES/ENZYMES.txt

- **PTC**:
https://raw.githubusercontent.com/ChNousias/graph-classification-thesis/master/datasets/PTC/PTC.txt

In [1]:
# If an [SSL error] occurs uncomment following 4 lines
# import os, ssl
# if (not os.environ.get('PYTHONHTTPSVERIFY', '') and
#     getattr(ssl, '_create_unverified_context', None)): 
#     ssl._create_default_https_context = ssl._create_unverified_context

import urllib.request
url = "https://raw.githubusercontent.com/ChNousias/graph-classification-thesis/master/datasets/MUTAG/MUTAG.txt"
#https://raw.githubusercontent.com/ChNousias/graph-classification-thesis/master/datasets/ENZYMES/ENZYMES.txt
#https://raw.githubusercontent.com/ChNousias/graph-classification-thesis/master/datasets/PTC/PTC.txt
file_name = url.split('/')[-1]

urllib.request.urlretrieve(url, file_name)

('MUTAG.txt', <http.client.HTTPMessage at 0x7fcf2b78a278>)

# <a name="h1"></a>1. Loading and preprocessing

The following set of functions imports and preprocesses the dataset:

    - read_graph_dataset_file(f)
    - get_dataset_statistics(X, Y)
    - base_graph(nodes, edges, receivers, senders, glob)
    - one_hot(label, total)
    - convert_targets_to_one_hot(targets, no_targets)
    - graph_dataset_preprocessing(X_in, convert_attr_to_numpy, make_copy)
    
[^](#Neural-Networks-on-Graphs)

## Loading the dataset, basic statistics and plotting

All bioinformatics datasets that are used in this notebook have discrete node labels, but the functions used for loading a dataset also consider node attributes that are numeric, as long as the dataset file is in the specified format as given by the [README](https://github.com/ChNousias/graph-classification-thesis/tree/master/datasets) file in the dataset section of the corresponding repository.

The preprocessing function **graph_dataset_preprocessing** converts these discrete labels into one-hot encodings; no further preprocessing is carried out.

Note that a graph's node attributes may combine a discrete tag with a numeric vector: for example, a vector that accompanies a node of type 1 and a vector (of the same size) that accompanies a node of type 2. In that case the node tag is replaced by its one-hot vector, while the rest of the attribute vector remains the same.

That is, if $v_1$ is the first node with attributes $a_1 = [1 | \vec{x_1}]$ and $v_2$ the second node with attributes $a_2 = [2 | \vec{x_2}]$, then the tags are transformed to the corresponding one-hot vectors while the accompanying numerical values are left unchanged: $a_1 = [1, 0 | \vec{x_1}]$ and $a_2 = [0, 1 | \vec{x_2}]$.
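
The transformation above can be sketched with NumPy as follows. The attribute values and the names `raw_nodes` and `encoded` are hypothetical and serve only to make the $[t | \vec{x}]$ layout concrete.

```python
import numpy as np

# Hypothetical node attributes of the form [tag | x]; the tags are 1 and 2.
raw_nodes = [[1, 0.5, 0.2],
             [2, 0.1, 0.9]]

# Map each distinct tag to a column index (the "tokenizer").
tags = sorted({node[0] for node in raw_nodes})
tokenizer = {tag: index for index, tag in enumerate(tags)}

encoded = []
for node in raw_nodes:
    one_hot_tag = np.zeros(len(tokenizer), dtype=np.float32)
    one_hot_tag[tokenizer[node[0]]] = 1.0
    # Replace the tag with its one-hot vector and keep the remaining values.
    rest = np.asarray(node[1:], dtype=np.float32)
    encoded.append(np.concatenate((one_hot_tag, rest)))

encoded = np.stack(encoded)
print(encoded)
```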



In [None]:
def read_graph_dataset_file(f):
    '''
    Input: f is an opened txt file
    
    - first line is a header specifying number of graphs specified in txt file
    - if line is of the form (n c) then the line specifies the start of a new graph 
        n: number of nodes
        c: class label of new graph
    - if line is different it specifies a new node and is of the form (t m d)
        t: tag of node
        m: number of neighbors followed by m values indicating neighbors indices
        d: following d numbers indicating node's attributes
    '''
    X = []
    Y = []
    
    total = int(next(f))
    
    readlines_generator = iter(f.readlines())
    
    build_flag = False
    
    for line in readlines_generator:
        
        parsedLine = tuple(map(float, re.findall('[0-9]+', line)))
        
        if len(parsedLine) == 2:
            
            if build_flag is True:
                            
                X.append(base_graph(nodes, edges, receivers, senders, glob))
                
                Y.append(clss)
                
                build_flag = False

            count = 0
            
            graph_size, clss = tuple(map(int, parsedLine))
            
            nodes, edges, senders, receivers = [], [], [], []
            
            glob = [0]
            
            build_flag = True
            
            for k in range(graph_size):
                
                parsedLine = re.findall('[0-9]+', next(readlines_generator))
            
                tag, neighbors = tuple(map(int, parsedLine[0:2]))

                edges.extend([[1.0]]*neighbors)

                senders.extend([count]*neighbors)

                receivers.extend(tuple(map(int, parsedLine[2: 2 + neighbors])))
                
                labels = []

                labels.append(tag)
                
                labels.extend(tuple(map(float, parsedLine[2+neighbors:])))
                
                nodes.append(labels)
            
                count += 1
                        
    if build_flag is True:

        X.append(base_graph(nodes, edges, receivers, senders, glob))

        Y.append(clss)

        build_flag = False
    
    print ('Total graphs specified: {}. Total graphs found: {}.'.format(total, len(X)))
    
    if total != len(X):
        print ('Final total number of graphs is different from specified in the dataset file')
    
    Y = convert_targets_to_one_hot(Y)
    
    return X,Y
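
To make the file format concrete, the following sketch parses a tiny in-memory sample holding a single graph. The sample text and all variable names are made up for illustration; only the line layout (header, graph line `n c`, node lines `t m neighbors...`) follows the docstring above.

```python
import re

# A made-up file in the format described above: one graph, three nodes, class 0.
sample = """1
3 0
0 2 1 2
1 2 0 2
0 2 0 1
"""

lines = iter(sample.splitlines())
total_graphs = int(next(lines))                       # header: number of graphs
graph_size, graph_class = map(int, next(lines).split())

tags, senders, receivers = [], [], []
for node_index in range(graph_size):
    fields = re.findall(r"[0-9]+", next(lines))
    tag, n_neighbors = int(fields[0]), int(fields[1])
    neighbors = [int(v) for v in fields[2:2 + n_neighbors]]
    tags.append(tag)
    # Each listed neighbor yields one directed edge node_index -> neighbor.
    senders.extend([node_index] * n_neighbors)
    receivers.extend(neighbors)

print(tags)
print(list(zip(senders, receivers)))
```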

In [None]:
def base_graph(nodes, edges, receivers, senders, glob):
    """Define a basic graph structure to represent a graph
    
    Args:
        nodes: a list of lists that correspond to different nodes 
        edges: a numpy array of edge attributes (of at least rank-2)
        receivers: a list of indices that indicate receiver nodes
        senders: a list of indices that indicate sender nodes
        glob: a vector/tensor of global attributes

    Returns:
        data_dict: dictionary with globals, nodes, edges, receivers and senders 
                   to represent a structure like the one above.
    """
    return {
      "globals": glob,
      "nodes": nodes,
      "edges": edges,
      "receivers": receivers,
      "senders": senders
    }

def one_hot(label, total):
    '''
    Create one-hot vector.
    
    Args:
        label: 
            if label == scalar:
                a numpy array of shape (total, ) is returned
            elif label == numpy array of discrete values and len(label)==n:
                a numpy array of shape (n, total) is returned
        total: number of possible discrete values
    
    Returns:
        one_hot: a binary vector with the value of 1 where the target class is equal to index 
    '''
    if isinstance(label, (int, float, np.integer)):
        one_hot = np.zeros(total, dtype = np.float32)
        # cast to int so that float or numpy integer labels can be used as an index
        one_hot[int(label)] = 1
    elif isinstance(label, (list, np.ndarray)):
        n = len(label)
        one_hot = np.zeros((len(label), total), dtype = np.float32)
        one_hot[np.arange(n), label] = 1
        
    return one_hot

def convert_targets_to_one_hot(targets, no_targets = None):
    ''' Encodes target labels to one_hot.
    
    Args:
        targets: a list of target values
        no_targets: number of unique elements in target list
    
    Returns:
        one_hot_targets: targets encoded to one_hot vectors
    '''
    unique = list(np.unique(targets))
    
    if no_targets is None:
        no_targets = len(unique)
    
    return one_hot([unique.index(i) for i in targets], no_targets)
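
The label mapping performed by `convert_targets_to_one_hot` can also be sketched with `np.unique(..., return_inverse=True)`, which returns the position of each target in the sorted list of distinct labels. This is an equivalent illustration rather than the code used in this notebook; the label values are hypothetical.

```python
import numpy as np

# Hypothetical raw class labels that are not numbered 0..n-1.
targets = [3, 7, 3, 5]

# np.unique returns the sorted distinct labels and, with return_inverse=True,
# the index of each target within that sorted list.
unique, class_index = np.unique(targets, return_inverse=True)

one_hot_targets = np.zeros((len(targets), len(unique)), dtype=np.float32)
one_hot_targets[np.arange(len(targets)), class_index] = 1.0
print(one_hot_targets)
```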

In [None]:
def graph_dataset_preprocessing(X_in, convert_attr_to_numpy = True, make_copy=True):
    '''  
    Convert dataset Node-Tag attributes to one-hot encodings:
    
                        t -> [0, 0, 1,.. 0]
                        
    In case there are extra attributes besides the tag of a node like [t | d]
    the one-hot encoding replaces the tag at the beginning of the array and
    the remaining attributes are kept after it:
    
                    [t | d] -> [0, 0, 1, ... 0 | d]
                    
    In the case where the n available tags do not correspond to an ordered list (0, ..., n-1)
    a mapping takes place which is then returned as a tokenizer for "unseen" examples
    
    Args:
        X_in: graph Dataset in the GraphDicts format of graph_nets
        convert_attr_to_numpy: if True, node, edge and global attributes are
            converted to float32 numpy arrays
        make_copy: if True, a deep copy of X_in is processed and X_in is left unchanged
    Returns:
        X_onehot: dataset with each graph's node tag attributes converted to one-hot.
        tokenizer: for mapping specific tags to specific numbers
    '''
    if make_copy is True:
        X = copy.deepcopy(X_in)
    else:
        X = X_in
    
    f = itemgetter(0)
    tags = set()
    for graph_item in X:
        tags.update(f(n) for n in graph_item['nodes'])
    
    total_tags = len(tags)
    
    # sort the tags so that the tag -> index mapping is deterministic
    tokenizer = {tag: index for index, tag in enumerate(sorted(tags))}
        
    for graph_item in X:
        for i, node in enumerate(graph_item['nodes']):
            label = tokenizer[f(node)]
            graph_item['nodes'][i] = np.concatenate((one_hot(label, total_tags), node[1:]))
            
        if convert_attr_to_numpy is True:
            graph_item['nodes'] = np.array(graph_item['nodes'], dtype = np.float32)
            graph_item['edges'] = np.array(graph_item['edges'], dtype = np.float32)
            graph_item['globals'] = np.array(graph_item['globals'], dtype = np.float32)
        
    return X, tokenizer

In [None]:
def get_dataset_statistics(X, Y):
    """Get graph-dataset statistics:
    Total # of Graphs
    Total # of Classes
    Total # of nodes
    Total # of edges
    Average # of Nodes per example
    Average # of Edges per example
    Edge density: Fraction of all possible edges in a graph (# edges/(# nodes)^2)*
    * We allow directed graphs with self-loops and multiedges. Therefore we can get 
      densities bigger than 1. An "undirected" edge between node v and w consists of 
      two directed ones.
    """
    n = len(X)
    total_classes = Y.shape[-1]
    total_nodes  = sum((len(X[i]['nodes']) for i in range(n)))
    total_edges  = sum((len(X[i]['edges']) for i in range(n)))
    max_min_nodes = (max((len(X[i]['nodes']) for i in range(n))), 
                    min((len(X[i]['nodes']) for i in range(n))))
    max_min_edges = (max((len(X[i]['edges']) for i in range(n))), 
                     min((len(X[i]['edges']) for i in range(n))))
    average_nodes = total_nodes/n
    average_edges = total_edges/n
    edge_density = (sum(len(X[i]['edges'])/(len(X[i]['nodes']))**2 for i in range(n)))/n
    return {"total_graphs":n,
            "total_classes":total_classes,
            "total_nodes":total_nodes,
            "total_edges":total_edges,
            "average_nodes":average_nodes,
            "average_edges":average_edges,
            "max_min_nodes":max_min_nodes,
            "max_min_edges":max_min_edges, 
            "average_edge_density":edge_density}
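
As a small worked example of the edge density defined in the docstring above (number of directed edges divided by the squared number of nodes), consider two toy graphs. The graphs and the names `toy_graphs` and `densities` are hypothetical.

```python
# Hypothetical toy graphs, described only by node count and directed edge count.
toy_graphs = [
    {"n_nodes": 3, "n_edges": 6},   # triangle: 3 undirected edges = 6 directed
    {"n_nodes": 4, "n_edges": 6},   # path 0-1-2-3: 3 undirected edges = 6 directed
]

# Per-graph density, then the average over the dataset.
densities = [g["n_edges"] / g["n_nodes"] ** 2 for g in toy_graphs]
average_density = sum(densities) / len(densities)
print(densities, average_density)
```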

In [None]:
# Specify base directory where the dataset files are stored
base_directory = 'datasets_txt_files/'

datasetFile = 'MUTAG.txt'
with open((base_directory + datasetFile), 'r') as f:
    X,Y = read_graph_dataset_file(f)

# If discrete node tags are involved convert them to one-hot encodings
# tokenizer is a dict that stores a mapping needed for preprocessing new examples 
X, tokenizer = graph_dataset_preprocessing(X, make_copy=False)
no_tags = len(tokenizer)

In [None]:
get_dataset_statistics(X,Y)

## Graph Visualization

Graphs are visualized using the spring layout method. The color of a node depends on the corresponding tag that indicates an element of the compound.

In [None]:
graph_plot_set = [utils_np.data_dict_to_networkx(X[i]) for i in range(len(X))]

In [None]:
colors = ['darkred', 'orangered', 'slategrey', 'blue', 'darkslategrey', 'midnightblue',
          'orchid', 'darkcyan', 'grey', 'dodgerblue', 'turquoise','darkviolet', 'crimson',
          'darkorange', 'khaki', 'ivory', 'palegreen', 'limegreen', 'darkgreen', 
          'mediumseagreen', 'mediumaquamarine','teal']

COLOR_DICTS = {k:colors[k] for k in range(no_tags)}

if datasetFile=="MUTAG.txt":
    lbl = {0:'Non-Mutagenic', 1:'Mutagenic'}
    plot_title = 'Nitro Compounds'
    
elif datasetFile == "ENZYMES.txt":
    lbl = {0:'EC1: Oxidoreductase', 
           1:'EC2: Transferase',
           2:'EC3: Hydrolase', 
           3:'EC4: Lyase', 
           4:'EC5: Isomerase', 
           5:'EC6: Ligase'}
    plot_title = 'Enzymes'
    
elif datasetFile == "PTC.txt":
    lbl = {0: 'Non - Carcinogenic', 1:'Carcinogenic'}
    plot_title = 'Organic Molecules'
    
elif datasetFile == "PROTEINS.txt":
    lbl = {0: 'Enzymatic Function', 1: 'Non-Enzymatic Function'}
    plot_title = "Protein Structures"

def get_color_map(G, graph, no_tags):
    color_map = []
    Identity_Matrix = np.eye(no_tags)
    
    for i,node in enumerate(G):
        color = COLOR_DICTS[np.argmax(np.all(Identity_Matrix == graph['nodes'][i], axis = 1))]
        color_map.append(color)
    return color_map

plt.clf()
f = plt.figure(figsize=(16,6))
    
for i,ind in enumerate(np.random.randint(0, len(Y), 6)):
    ax = f.add_subplot(2, 3, i + 1)
    plt.title( (plot_title + ', i = {}, \nL = {}').format(ind, lbl[np.argmax(Y[ind])]))
    ax.set_aspect('equal', 'box')
    nx.draw(graph_plot_set[ind], 
            node_color = get_color_map(graph_plot_set[ind], X[ind], no_tags), 
            node_size = 140, 
            with_labels = False)

plt.show()

# <a name="h2"></a>2. Defining Architectures

The models that are used are:

1. an **"Encoder - Core - Decoder"** architecture which consists of:
    - an **Encoder** and a **Decoder** part, each represented by an Independent Block where all of the graph's elements (edges, nodes, globals) are updated without their relations taken into consideration,
    - a **Core** model which consists of a Graph Network block applied for a specified number of processing steps. Each processing step performs a "message-passing" action, that is, diffuses information across the graph. The output of each pass is concatenated with the Encoder's output and fed to the Core again.
    - an Output part that includes a **'Softmax'** block, that is, an Independent Block with only a global update function whose output size equals the number of target classes and whose activation function is a softmax.  
  
  
2. A **"Multiple Graph Model"** which consists of Graph Network Blocks stacked in deep architectures. At the end, a **'Softmax'** block (as defined above) is used for making predictions.
  
  
3. An **"Interaction Model"** that is restricted to computations on the nodes and edges of the graph, omitting the global block. At the end, a full GN block gathers the acquired information into the global attributes. The global update function has a Softmax activation for making predictions.
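
A single "message-passing" step can be sketched in NumPy as below. This is a deliberately simplified illustration: it has no learned edge or node update functions (the GN blocks in this notebook use MLPs for those), and it only mimics the sum aggregation over incoming edges. The function name `message_passing_step` and the toy graph are assumptions.

```python
import numpy as np

# Hypothetical path graph 0 - 1 - 2, stored as directed edges in both directions.
senders = np.array([0, 1, 1, 2])
receivers = np.array([1, 0, 2, 1])
node_features = np.array([[1.0], [10.0], [100.0]])

def message_passing_step(nodes, senders, receivers):
    """One simplified step: each node adds the features sent by its neighbors."""
    messages = nodes[senders]                    # one message per directed edge
    aggregated = np.zeros_like(nodes)
    np.add.at(aggregated, receivers, messages)   # sum the messages per receiver
    return nodes + aggregated

step1 = message_passing_step(node_features, senders, receivers)
step2 = message_passing_step(step1, senders, receivers)
print(step1.ravel(), step2.ravel())
```

After one step node 0 only holds information from its direct neighbor (node 1); after two steps it also carries a contribution that originated at node 2, which is the sense in which repeated steps diffuse information across the graph.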

[^](#Neural-Networks-on-Graphs)

In [None]:
# MODELS FOR GN BLOCKS

def MLP_model_factory(hidden_layers, output_layer,
                      keep_rate, use_dropout = True,
                      hidden_activation = tf.nn.relu,
                      output_activation = tf.nn.softmax):
    """
    A function factory for returning build_mlp_model functions. 
    The returned function is evaluated in the corresponding scope of the GN:
        EdgeBlock, NodeBlock, GlobalBlock. 
    
    When evaluated, the returned function returns a custom sonnet MLPmodel.
    """
    
    def build_mlp_model():
        return MLPmodel(hidden_layers, keep_rate,
                        use_dropout, output_layer,
                        hidden_activation,output_activation)
    
    return build_mlp_model


class MLPmodel(snt.AbstractModule):
    """
    Instantiates a sonnet Module that describes a vanilla MLP model
    """
    def __init__(self, hidden_layers, keep_rate, 
                 use_dropout = True, output_layer = None, 
                 hidden_l_non_linearity = tf.nn.relu, 
                 output_l_non_linearity = tf.nn.softmax,
                 name = "MLPmodel"):
        super(MLPmodel, self).__init__(name = name)
        
        self._use_dropout = use_dropout
        self._hidden_layers = hidden_layers
        self._output_layer = output_layer
        
        with self._enter_variable_scope():
            
            if self._hidden_layers:
                self._hidden = [OneLayerPerceptron(layer_size, 
                                                   keep_rate,
                                                   hidden_l_non_linearity,
                                                   use_dropout) 
                                for layer_size in hidden_layers]
                
                if self._use_dropout is not True:
                    self._layernorm = [snt.LayerNorm() 
                                        for i in range(len(self._hidden_layers))]
                else:
                    self._layernorm = [None for i in range(len(self._hidden_layers))]
            
            if self._output_layer is not None:
                self._output = OneLayerPerceptron(self._output_layer, 
                                                  keep_rate,
                                                  output_l_non_linearity,
                                                  use_dropout = False)
                
            
    def _build(self, inputs):
        """Build method for defining a full MLP layer"""
        if self._hidden_layers:
            f = itemgetter(0)
            
            latent = f(self._hidden)(inputs)
            
            if self._use_dropout is not True:
                latent = f(self._layernorm)(latent)
            
            for hid, layernorm in zip(self._hidden[1:], self._layernorm[1:]):
                latent = hid(latent)

                if self._use_dropout is not True:
                    latent = layernorm(latent)
            
        else:
            latent = inputs
        if self._output_layer is not None:
            return self._output(latent)
        else:
            return latent


class OneLayerPerceptron(snt.AbstractModule):
    """Instantiate a single_layer Perceptron with/without Dropout Layer"""
    def __init__(self, layer_size, keep_rate,
                 non_linearity = tf.nn.relu,
                 use_dropout = True,
                 name = "OneLayerPerceptron"):
        super(OneLayerPerceptron, self).__init__(name = name)
        
        self._layer_size = layer_size
        self._use_dropout = use_dropout
        self._non_linearity = non_linearity
        
        with self._enter_variable_scope():
            if self._use_dropout is True:
                self._dropout = Dropout(keep_rate)
            self._linear = snt.Linear(self._layer_size)
        
    def _build(self, inputs):
        if self._use_dropout is True:
            return self._non_linearity(self._linear(self._dropout(inputs)))
        else:
            return self._non_linearity(self._linear(inputs))


class Dropout(snt.AbstractModule):
    """Apply Dropout as a snt.Module"""
    def __init__(self, keep_rate, name = "Dropout"):
        super(Dropout, self).__init__(name = name)
        self._keep_rate = keep_rate
    
    def _build(self, inputs):
        return tf.nn.dropout(inputs, self._keep_rate)


class Softmax(snt.AbstractModule):
    """Softmax non-linearity"""
    def __init__(self, name = "Softmax"):
        super(Softmax, self).__init__(name=name)
    
    def _build(self, inputs):
        return tf.nn.softmax(inputs)

class MLPGraphIndependent(snt.AbstractModule):
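    """GraphIndependent with MLP edge, node, and global models.

    Note: `keep_rate` is not a constructor argument; it is read from the
    enclosing (module-level) scope and must be defined before instantiation.
    """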

    def __init__(self, 
                 edge_mlp = (32, 32), node_mlp = (32, 32), global_mlp = (32, 32),
                 edge_dropout = False, node_dropout = False, global_dropout = False,
                 name="MLPGraphIndependent"):
        super(MLPGraphIndependent, self).__init__(name=name)
        with self._enter_variable_scope():
            self._network = modules.GraphIndependent(
                  edge_model_fn=MLP_model_factory(edge_mlp, None, keep_rate, edge_dropout),
                  node_model_fn=MLP_model_factory(node_mlp, None, keep_rate, node_dropout),
                  global_model_fn=MLP_model_factory(global_mlp, None, keep_rate, global_dropout))

    def _build(self, inputs):
        return self._network(inputs)


class MLPGraphNetwork(snt.AbstractModule):
    """GraphNetwork with MLP edge, node, and global models."""

    def __init__(self, 
                 edge_mlp = (32, 32), node_mlp = (32, 32), global_mlp = (32, 32),
                 edge_dropout = False, node_dropout = False , global_dropout= False,
                 name="MLPGraphNetwork"):
        super(MLPGraphNetwork, self).__init__(name=name)

        with self._enter_variable_scope():
            self._network = modules.GraphNetwork(MLP_model_factory(edge_mlp, None, keep_rate, edge_dropout), 
                                                 MLP_model_factory(node_mlp, None, keep_rate, node_dropout),
                                                 MLP_model_factory(global_mlp, None, keep_rate, global_dropout),
                                                 tf.unsorted_segment_sum)

    def _build(self, inputs):
        return self._network(inputs)
    
class MLPInteraction(snt.AbstractModule):
    """Interaction Networks with MLP edge and node models."""
    
    def __init__(self, 
                 edge_mlp = (32, 32), node_mlp = (32, 32),
                 edge_dropout = False, node_dropout = False,
                 name="MLPInteraction"):
        super(MLPInteraction, self).__init__(name=name)
        with self._enter_variable_scope():
            self._network = modules.InteractionNetwork(MLP_model_factory(edge_mlp, None, keep_rate, edge_dropout),
                                                       MLP_model_factory(node_mlp, None, keep_rate, node_dropout),
                                                       reducer = tf.unsorted_segment_sum)
            
    def _build(self, inputs):
        return self._network(inputs)
    

class MultiGraphNetwork(snt.AbstractModule):
    """Multi-Graph Network is a model for stacking multiple Graph Network
    Blocks in forward fashion with Multi-layer Perceptrons for the update
    functions. The number of Graph Network Blocks that will be used is passed as
    a hyper-parameter at the constructor of the model.
    
                *---------*     *---------*
                |         |     |         |
      Input --->|  #1 GN  | --->|  #2 GN  |---> ... ---> Output
                |         |     |         |
                *---------*     *---------*   
    
    """

    def __init__(self,
                graph_block_kwargs,
                no_graph_network_blocks = 2,
                edge_output_size=None,
                node_output_size=None,
                global_output_size=None,
                use_skip_connections = True,
                name="MultiGraphNetwork"):
        super(MultiGraphNetwork, self).__init__(name=name)
        self._blocks = []
        for i in range(no_graph_network_blocks):
            self._blocks.append(MLPGraphNetwork(**graph_block_kwargs, name = "MLPGraphNetwork"+"_"+str(i)))
        
        self._use_skip_connections = use_skip_connections
        
        # Transforms the outputs into the appropriate shapes.
        if edge_output_size is None:
            edge_fn = None
        else:
            edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output")
        if node_output_size is None:
            node_fn = None
        else:
            node_fn = lambda: snt.Linear(node_output_size, name="node_output")
        if global_output_size is None:
            global_fn = None
        else:
            global_fn = MLP_model_factory((), global_output_size, keep_rate, False)
        with self._enter_variable_scope():
            self._output_transform = modules.GraphIndependent(edge_fn, node_fn,
                                                        global_fn) 

    def _build(self, input_op):
        '''Basic build method for stacking successive 
        graph network blocks'''
        
        latent = self._blocks[0](input_op)
        
        latent_skip = latent
                
        for i in range(1,len(self._blocks)):
            
            if i%2==0 and self._use_skip_connections is True:
                                    
#                 latent = residual_connect([self._blocks[i](latent), latent_skip])
                
                latent = utils_tf.concat([self._blocks[i](latent), latent_skip], axis=1)
                
                latent_skip = latent
            
            else:
                latent = self._blocks[i](latent)
        
        output_ops = self._output_transform(latent)
        
        return output_ops
    
    
class EncodeProcessDecode(snt.AbstractModule):
    """Full encode-process-decode model as used in the graph_nets demos.
    The model includes three components:
    - An "Encoder" graph net, which independently encodes the edge, node, and
        global attributes (does not compute relations etc.).
    - A "Core" graph net, which performs N rounds of processing (message-passing)
        steps. The input to the Core is the concatenation of the Encoder's output
        and the previous output of the Core (labeled "Hidden(t)" below, where "t" is
    the processing step).
    - A "Decoder" graph net, which independently decodes the edge, node, and
        global attributes (does not compute relations etc.), on each message-passing
        step.
                          Hidden(t)   Hidden(t+1)
                             |            ^
                *---------*  |  *------*  |  *---------*
                |         |  |  |      |  |  |         |
      Input --->| Encoder |  *->| Core |--*->| Decoder |---> Output(t)
                |         |---->|      |     |         |
                *---------*     *------*     *---------*
    """

    def __init__(self,
                graph_block_kwargs,
                edge_output_size=None,
                node_output_size=None,
                global_output_size=None,
                name="EncodeProcessDecode"):
        super(EncodeProcessDecode, self).__init__(name=name)
        self._encoder = MLPGraphIndependent(**graph_block_kwargs)
        self._core = MLPGraphNetwork(**graph_block_kwargs)
        self._decoder = MLPGraphIndependent(**graph_block_kwargs)
        # Transforms the outputs into the appropriate shapes.
        if edge_output_size is None:
            edge_fn = None
        else:
            edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output")
        if node_output_size is None:
            node_fn = None
        else:
            node_fn = lambda: snt.Linear(node_output_size, name="node_output")
        if global_output_size is None:
            global_fn = None
        else:
            global_fn = lambda: snt.Sequential([snt.Linear(global_output_size, name="global_output"), Softmax()])
        with self._enter_variable_scope():
            self._output_transform = modules.GraphIndependent(edge_fn, node_fn,
                                                        global_fn) 
            
    def _build(self, input_op, num_processing_steps):
        latent = self._encoder(input_op)
        latent0 = latent
        output_ops = []
        for _ in range(num_processing_steps):
            core_input = utils_tf.concat([latent0, latent], axis=1)
            latent = self._core(core_input)
            decoded_op = self._decoder(latent)
            output_ops.append(self._output_transform(decoded_op))
        return output_ops


class InteractionModel(snt.AbstractModule):
    """Graph Network Interaction Model."""
    
    def __init__(self, graph_block_kwargs,
                 no_interaction_network_blocks = 2,
                 edge_output_size=None,
                 node_output_size=None,
                 global_output_size=None,
                 name="InteractionModel"):
        super(InteractionModel, self).__init__(name=name)
                
        self._blocks = []
        for i in range(no_interaction_network_blocks):
            self._blocks.append(MLPInteraction(**graph_block_kwargs, name = "MLPInteraction"+"_"+str(i)))
        
        edge_fn = MLP_model_factory(graph_block_kwargs["edge_mlp"], None, keep_rate, False)
        node_fn = MLP_model_factory(graph_block_kwargs["node_mlp"], None, keep_rate, False)
        global_fn = MLP_model_factory(graph_block_kwargs["node_mlp"], global_output_size, keep_rate, False)
        with self._enter_variable_scope():
            self._output_transform = modules.GraphNetwork(edge_fn, 
                                                          node_fn,
                                                          global_fn) 

    def _build(self, input_op):
        n = len(self._blocks)
        latent = self._blocks[0](input_op)
        for i in range(1,n):
            latent = self._blocks[i](latent)
        output_ops = self._output_transform(latent)
        return output_ops
    
class MLPCommNet(snt.AbstractModule):
    """Communication Network with MLP edge, node_decoder and node models."""
    
    def __init__(self,
                 edge_mlp = (32,32), node_mlp = (32, 32),
                 edge_dropout = False, node_dropout = False,
                 name="MLPCommNet"):
        super(MLPCommNet, self).__init__(name=name)
        with self._enter_variable_scope():
            self._network = modules.CommNet(MLP_model_factory(edge_mlp, None, keep_rate, edge_dropout),
                                            MLP_model_factory(edge_mlp, None, keep_rate, edge_dropout),
                                            MLP_model_factory(node_mlp, None, keep_rate, node_dropout))
        
    def _build(self, inputs):
        return self._network(inputs)

In [None]:
def residual_connect(input_graphs, name="graph_residual"):
    """Returns a graph whose attributes are the element-wise sum of the
    attributes of the input graphs.
    
    The purpose of this operator is to add graphs that share the same
    topology and the same attribute dimensions, so that if x is a graph and
    M(x) an operator acting on a graph, we are able to retrieve:
    
                            Res(M(x)) = M(x) + x
    
    In all cases, the NODES, EDGES and GLOBALS fields are added
    element-wise (if a field is `None`, the sum is just `None`),
    therefore the RECEIVERS, SENDERS, N_NODE and N_EDGE fields of the
    graphs should all match.
    
    The graphs in `input_graphs` should also have the same set of keys for which 
    the corresponding field is not `None`.
    
    Args:
        input_graphs: A list of `graphs.GraphsTuple` objects containing `Tensor`s
            and satisfying the constraints outlined above.
        name: (string, optional) A name for the operation.
    Returns: 
        A `graphs.GraphsTuple` whose NODES, EDGES and GLOBALS are the element-wise
        sums of the corresponding fields of the input graphs. All other fields
        are taken from the first graph.
    Raises:
        ValueError if `input_graphs` is an empty list, or if the fields which are `None`
        in `input_graphs` are not the same for all the graphs.
    """
    
    if not input_graphs:
        raise ValueError("List argument `input_graphs` is empty")
        
    utils_np._check_valid_sets_of_keys([gr._asdict() for gr in input_graphs])  # pylint: disable=protected-access
  
    if len(input_graphs) == 1:
        return input_graphs[0]

    nodes = [gr.nodes for gr in input_graphs if gr.nodes is not None]
    edges = [gr.edges for gr in input_graphs if gr.edges is not None]
    globals_ = [gr.globals for gr in input_graphs if gr.globals is not None]

    with tf.name_scope(name):
        
        nodes = tf.add_n(nodes, name="add_nodes") if nodes else None
        edges = tf.add_n(edges, name="add_edges") if edges else None
        
        if globals_:
            globals_ = tf.add_n(globals_, name="add_globals")
        else:
            globals_ = None
        
        output = input_graphs[0].replace(nodes=nodes, edges=edges, globals=globals_)
        
        return output

# 3. <a name="h3"></a>Building TensorFlow Graph

This section is about running classification tasks on graphs, either with a k-fold Cross Validation or with a Holdout set. It involves defining beforehand several utility functions that facilitate splitting the dataset, building the TensorFlow Graph, plotting and logging, and of course defining the main training sessions.

1. **[Utility Functions for Holdout Training and Cross Validation](#Utility-Functions-for-Holdout-Training-and-Cross-Validation)**  
2. **[Utility functions for building the TensorFlow Graph](#Utility-functions-for-building-the-TensorFlow-Graph)**
3. **[Utility functions for Plotting and Logging](#Utility-functions-for-Plotting-and-Logging)**
4. **[Main k-fold Cross Validation for graph classification](#Main-k-fold-Cross-Validation-for-graph-classification)**

[^](#Neural-Networks-on-Graphs)

## Utility Functions for Holdout Training and Cross Validation

Some basic functions for splitting a dataset into train/val/test sets, as well as into k folds. A stratification flag is given to ensure the same distribution of classes among the different sets and folds.

    - train_val_test(X, Y, splits, shuffle, stratify, seed)
    - k_cross_validation(X, Y, folds, stratify, seed)

In [None]:
def train_val_test(X, Y, splits, shuffle = True, stratify = True, seed = 123):
    '''Splits and shuffles into training, validation and test sets.
    
    Args:
        X, Y: Training examples / Targets (One-hot vectors)
        splits: tuple of train/val/test percentage splits
        shuffle: boolean flag for shuffling dataset
        stratify: boolean flag for splitting each class separately, so that
            every set keeps the class distribution of the dataset
        seed: seed for random generator
    
    Returns:
        X_train, Y_train, X_val, Y_val, X_test, Y_test
    '''
    def shuffle_data(X, Y):
        '''Shuffle data'''
        indices = np.random.permutation(np.arange(len(X)))
        return ([X[i] for i in indices], np.array([Y[i] for i in indices]))
        
    np.random.seed(seed)
    
    if len(splits) != 3:
        raise ValueError("splits must contain exactly 3 values (train/val/test)")
        
    # Normalize the splits so that they sum to 1.
    if abs(sum(splits) - 1.0) > 1e-06:
        splits = np.array(splits, dtype=float)
        splits = splits / splits.sum()
    
    if stratify is True:
        strata = []
        
        _indices = [Y[:,i]==1 for i in range(Y.shape[-1])]
        
        for stratum in _indices:
            strata.append([
                    [X[i] for i,val in enumerate(stratum) if val==True],
                    [Y[i] for i,val in enumerate(stratum) if val==True]
                ])
    else:
        strata = [(X,Y)]
    
    X_train, Y_train, X_val, Y_val, X_test, Y_test = [], [], [], [], [], []
    
    for stratum in strata:
        X_, Y_ = stratum
        
        n = len(X_)

        if shuffle is True:
            X_tmp, Y_tmp = shuffle_data(X_, Y_)
        else:
            X_tmp, Y_tmp = X_, Y_
        
        tr_split = int(splits[0] * n)

        val_split = int((splits[0] + splits[1]) * n)

        X_train.extend(X_tmp[:tr_split])
        Y_train.extend(Y_tmp[:tr_split])

        X_val.extend(X_tmp[tr_split:val_split])
        Y_val.extend(Y_tmp[tr_split:val_split])

        X_test.extend(X_tmp[val_split:])
        Y_test.extend(Y_tmp[val_split:])
    
    if shuffle is True:
        X_train, Y_train = shuffle_data(X_train, Y_train)
        X_val, Y_val = shuffle_data(X_val, Y_val)
        X_test, Y_test = shuffle_data(X_test, Y_test)

    return X_train, Y_train, X_val, Y_val, X_test, Y_test

In [None]:
def k_cross_validation(X, Y, folds, stratify = True, seed = 123):
    """Divide dataset in k folds
    
    Returns a list with one element per fold. Each element is a shuffled
    numpy array holding the indices of the examples assigned to that fold.
    Train/validation sets are built by the caller from these index arrays.
    """
    np.random.seed(seed)
    
    def shuffle_array(X):
        '''Shuffle array'''
        indices = np.random.permutation(np.arange(len(X)))
        return X[indices]
            
    if stratify is True:
        strata = [np.flatnonzero(Y[:,i]==1) for i in range(Y.shape[-1])]
    else:
        strata = [np.arange(len(Y))]
    
    k_folds = [None for i in range(folds)] 

    for stratum in strata:
        indices = shuffle_array(stratum)
        for i, split in enumerate(np.array_split(indices, folds)):
            if k_folds[i] is None:
                # np.stack raises on an empty split (class with fewer members than folds)
                k_folds[i] = split
            else:
                k_folds[i] = np.concatenate((k_folds[i], split))        
    
    return [shuffle_array(fold) for fold in k_folds]
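
As a rough illustration of how fold index arrays like the ones returned above can be consumed, the sketch below rebuilds a compact stratified splitter in plain NumPy and derives train/validation indices for one fold. The helper name `stratified_fold_indices` and the toy labels are assumptions made for this example. They are not part of the notebook.

```python
import numpy as np

def stratified_fold_indices(Y, folds, seed=123):
    """Hypothetical helper: split the row indices of one-hot labels Y into
    `folds` index arrays with roughly equal class proportions."""
    rng = np.random.RandomState(seed)
    parts = [[] for _ in range(folds)]
    for c in range(Y.shape[-1]):
        # Shuffle the members of class c, then deal them out over the folds.
        members = rng.permutation(np.flatnonzero(Y[:, c] == 1))
        for i, chunk in enumerate(np.array_split(members, folds)):
            parts[i].append(chunk)
    return [rng.permutation(np.concatenate(p)) for p in parts]

# Toy one-hot labels: 6 graphs of class 0 and 4 graphs of class 1.
Y_toy = np.eye(2, dtype=int)[[0] * 6 + [1] * 4]
fold_indices = stratified_fold_indices(Y_toy, folds=2)

# Fold 0 is held out for validation, the remaining folds form the training set.
val_idx = fold_indices[0]
train_idx = np.concatenate(fold_indices[1:])
```

With these toy labels every fold receives three graphs of class 0 and two of class 1, so the 60/40 class ratio of the full set is preserved in both the training and the validation indices.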

## Utility functions for building the TensorFlow Graph

The following functions facilitate the TF graph building process. They involve defining **placeholders**, calculating **loss** and **accuracy**, creating the **feed_dict**, and defining a **batch** generator:

    - create_placeholders(X, Y, batch_size)
    - create_loss_ops(targets, output_ops)
    - compute_accuracy(targets, outputs)
    - make_all_runnable_in_session(*args)
    - create_feed_dict(placeholders, input_objs)
    - graph_batch(iterable, n)
    - check_saver(epoch, test_values, save_variables)

In [None]:
def create_placeholders(X, Y, batch_size):
    """Creates placeholders for the model training and evaluation.

    Args:
    X: training examples
    Y: labels
    batch_size: Total number of graphs per batch.

    Returns:
    input_ph: The input graph's placeholders, as a graph namedtuple.
    target_ph: A float32 placeholder of shape [None, number of classes] for the one-hot labels.
    """

    input_ph = utils_tf.placeholders_from_data_dicts(X[0:batch_size], force_dynamic_num_graphs=True)
    
    target_ph = tf.placeholder(tf.float32, shape =[None, Y.shape[-1]], name='labels')
    
    return input_ph, target_ph

def create_loss_ops(targets, output_ops):
    """
    Define losses based on the global output or on intermediary steps
    """
    if isinstance(output_ops, list):
        loss_ops = [tf.losses.softmax_cross_entropy(targets, output_op.globals)
          for output_op in output_ops
        ]        
    else:
        loss_ops = tf.losses.softmax_cross_entropy(targets, output_ops.globals)
    return loss_ops

def compute_accuracy(targets, outputs):
    """Calculate model accuracy.

    Returns the fraction of correctly classified graphs

    Args:
    targets: A numpy array that contains the target one-hot vectors
    outputs: A `graphs.GraphsTuple` that contains the output graphs, or a
        list of them, in which case only the last one is used.

    Returns:
    accuracy: A `float` fraction of correctly labeled graphs
    """
    if isinstance(outputs, list):
        outputs_np = utils_np.graphs_tuple_to_data_dicts(outputs[-1])
    else:
        outputs_np = utils_np.graphs_tuple_to_data_dicts(outputs)
    total = []
    for target, out in zip(targets, outputs_np):
        xn = np.argmax(target)
        yn = np.argmax(out["globals"])

        total.append(xn == yn)

    accuracy = np.mean(total)
    
    return accuracy

def make_all_runnable_in_session(*args):
    """Lets an iterable of TF graphs be output from a session as NP graphs."""
    return [utils_tf.make_runnable_in_session(a) for a in args]

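def create_feed_dict(placeholders, input_objs):
    """Create feed_dict from a tuple/list of placeholders and input_objs"""
    # Copy first, so that the caller's sequence is not modified and tuples work too.
    input_objs = list(input_objs)
    # The first input object is a list of graph data dicts.
    input_objs[0] = utils_np.data_dicts_to_graphs_tuple(input_objs[0])
    return {k:v for k,v in zip(placeholders, input_objs)}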

def graph_batch(iterable, n = 1):
    """Yield successive batches of at most n items from a sequence"""
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx+n, l)]
        
def check_saver(epoch, test_values, save_variables):
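    '''Function to check validation score
    
    Args:
    epoch: current epoch
    test_values: results from the current validation run (must contain 'loss')
    save_variables: dict that tracks the best validation loss so far
    
    Returns:
    save_variables: Updated Saved Variables
    improve_flag: True if the validation loss improved at this epoch
    '''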
    improve_flag = False
    
    if test_values['loss'] < save_variables['losses_best_val']:
        
        save_variables['losses_best_val'] = test_values['loss']
        save_variables['improved_at'] = epoch
        improve_flag = True
    
    return save_variables, improve_flag
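
The interplay between `check_saver` and the `no_improve_limit` entry used later during training can be traced on a handful of numbers. The following is a minimal, framework-free sketch. The helper `update_best`, the loss values and the limit of 3 epochs are made up for illustration.

```python
def update_best(epoch, val_loss, state):
    """Record a new best validation loss (same idea as check_saver)."""
    improved = val_loss < state["losses_best_val"]
    if improved:
        state["losses_best_val"] = val_loss
        state["improved_at"] = epoch
    return state, improved

# Made-up validation losses, one per epoch (illustrative values only).
val_losses = [0.90, 0.70, 0.72, 0.71, 0.69, 0.75, 0.74, 0.73]
state = {"losses_best_val": float("inf"), "improved_at": 0, "no_improve_limit": 3}

checkpoints = []   # epochs at which a checkpoint would be written
stopped_at = None
for epoch, loss in enumerate(val_losses):
    state, improved = update_best(epoch, loss, state)
    if improved:
        checkpoints.append(epoch)
    # Stop once the loss has not improved for `no_improve_limit` epochs.
    if epoch - state["improved_at"] >= state["no_improve_limit"]:
        stopped_at = epoch
        break
```

In this trace checkpoints would be written at epochs 0, 1 and 4, and training stops at epoch 7, three epochs after the last improvement. Restoring the last checkpoint then recovers the parameters with the lowest validation loss (0.69).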

## Utility functions for Plotting and Logging

Functions for facilitating logging, plotting loss/accuracy curves and Confusion Matrices:
    
    - compute_log(epoch, elapsed, train_values, test_values, log_variables)
    - predictions(targets, outputs)
    - plot_result_curves(log_variables, fold, save, save_title)
    - plot_confusion_matrix(cm, classes, normalize, save_title, save, title, cmap)

In [None]:
def compute_log(epoch, elapsed, train_values, test_values, log_variables):
    '''Function to update logs'''

    acc_tr = compute_accuracy(train_values["target"], train_values["outputs"])
    acc_val = compute_accuracy(test_values["target"], test_values["outputs"])
    log_variables['last_epoch'] = epoch
    log_variables['losses_tr'].append(train_values["loss"])
    log_variables['losses_val'].append(test_values["loss"])
    log_variables['acc_cum_tr'].append(acc_tr)
    log_variables['acc_cum_val'].append(acc_val)
    log_variables['logged_iterations'].append(epoch)

    print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lval {:.4f}, AccTr {:.4f}, AccVal {:.4f}".format(
            epoch, elapsed, train_values["loss"], test_values["loss"],acc_tr, acc_val)) 
    
    return log_variables

        
def predictions(targets, outputs):
    """Run predictions on test/val set"""
    xn, yn = [], []
    if isinstance(outputs, list):
        outputs_np = utils_np.graphs_tuple_to_data_dicts(outputs[-1])
    else:
        outputs_np = utils_np.graphs_tuple_to_data_dicts(outputs)
        
    for target, out in zip(targets, outputs_np):
        xn.append(np.argmax(target))
        yn.append(np.argmax(out["globals"]))
    return np.array(xn), np.array(yn)

In [None]:
def plot_result_curves(log_variables, fold = 0, save = True, save_title = 'GN Model'):
    """Function to plot loss/accuracy curves across training"""
        
    fig = plt.figure(1, figsize=(12, 3))
    fig.clf()
    x = np.array(log_variables['logged_iterations'])
    # Loss.
    y_tr = log_variables['losses_tr']
    y_val = log_variables['losses_val']
    ax = fig.add_subplot(1, 2, 1)
    ax.plot(x, y_tr, "k", label="Training")
    ax.plot(x, y_val, "k--", label="Validation")
    ax.set_title("Loss across training")
    ax.set_xlabel("No of Epochs")
    ax.set_ylabel("Loss (softmax cross-entropy)")
    ax.legend()
    # Accuracy.
    y_tr = log_variables['acc_cum_tr']
    y_val = log_variables['acc_cum_val']
    ax = fig.add_subplot(1, 2, 2)
    ax.plot(x, y_tr, "k", label="Training")
    ax.plot(x, y_val, "k--", label="Validation")
    ax.set_title("Accuracy across training")
    ax.set_xlabel("No of Epochs")
    ax.set_ylabel("Fraction of Graphs correctly classified")
    ax.legend()
    if save is True:
        fig.savefig('plots/' + save_title + ' - Loss & Acc - fold {}.png'.format(fold), 
                    bbox_inches='tight')
        fig.savefig('plots/' + save_title + ' - Loss & Acc - fold {}.pdf'.format(fold), 
                    bbox_inches='tight')
        
        
def plot_confusion_matrix(cm, classes, normalize=False, save_title = 'GN model', save = True,
                          title='Confusion matrix', cmap=plt.cm.Blues): #viridis
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else '.0f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    
    if save is True:
        plt.savefig('plots/' + save_title + '.png', 
                    bbox_inches='tight')
        plt.savefig('plots/' + save_title + '.pdf', 
                    bbox_inches='tight')
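
The normalization in `plot_confusion_matrix` divides each row by its sum, which yields NaN entries for a class that never appears among the true labels. As a hedged sketch of how a confusion matrix can be built from label vectors and normalized with a guard for empty rows, consider the following NumPy example. The function names and the toy labels are assumptions for illustration only.

```python
import numpy as np

def confusion_matrix(true_labels, pred_labels, n_classes):
    """Rows index the true class, columns the predicted class."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(true_labels, pred_labels):
        cm[t, p] += 1
    return cm

def row_normalize(cm):
    """Divide each row by its total; rows without samples stay at zero."""
    totals = cm.sum(axis=1, keepdims=True)
    out = np.zeros(cm.shape, dtype=float)
    return np.divide(cm, totals, out=out, where=totals != 0)

# Toy labels: three classes, class 2 never occurs as a true label.
y_true = np.array([0, 0, 0, 1, 1, 0])
y_pred = np.array([0, 1, 0, 1, 2, 0])

cm = confusion_matrix(y_true, y_pred, n_classes=3)
cm_norm = row_normalize(cm)
accuracy = np.trace(cm) / cm.sum()   # correct predictions lie on the diagonal
```

Here the row for class 2 stays at zero instead of becoming NaN, and the accuracy (4 correct out of 6) can be read directly from the diagonal of the unnormalized matrix.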

## Main k-fold Cross Validation for graph classification

Putting it all together, a function is defined for running multiple models with k-fold Cross Validation:

    - get_data_from_indices(X, Y, indices)
    - run_k_cross_validation(X_dataset, Y_dataset, outer_no_folds, inner_no_folds, 
                             num_epochs, GN_model, batch_size, log_every_epoch,
                             NN_size, NN_dropout, dropout_keep_rate,
                             use_skip_connections, GN_model_steps, input_optimizer, 
                             learning_rate, seed)

In [None]:
def get_data_from_indices(X, Y, indices):
    return [X[k] for k in indices], np.array([Y[k] for k in indices])

def run_k_cross_validation(X_dataset, Y_dataset, outer_no_folds = 5, inner_no_folds = 5, 
                           num_epochs = 600, GN_model = "Multiple_GN", 
                           batch_size = 32, log_every_epoch = 5,
                           NN_size = {"edge_size":(32, 32),
                                      "node_size":(32, 32),
                                      "global_size":(32, 32)},
                           NN_dropout = {'edge_dropout':False,
                                         'node_dropout':False,
                                         'global_dropout':False},
                           dropout_keep_rate = 1.0,
                           use_skip_connections = False,
                           GN_model_steps = 2,
                           input_optimizer = tf.train.AdamOptimizer,
                           learning_rate = 1e-3, seed = 123):
    
    outer_folds = k_cross_validation(X_dataset, Y_dataset, outer_no_folds, seed = seed + 1)
    
    cm_total = []
    
    for o in range(len(outer_folds)):
        
        indices_inner_train = np.concatenate([fold for j,fold in enumerate(outer_folds) if j != o])
        indices_test = outer_folds[o]
    
        X, Y = get_data_from_indices(X_dataset, Y_dataset, indices_inner_train)
        X_test, Y_test = get_data_from_indices(X_dataset, Y_dataset, indices_test)
    
        folds = k_cross_validation(X, Y, inner_no_folds, seed = seed + 2)


        for i in range(len(folds)):
            indices_train = np.concatenate([fold for j,fold in enumerate(folds) if j != i])
            indices_val = folds[i]

            X_tr, Y_tr = get_data_from_indices(X, Y, indices_train)
            X_val, Y_val = get_data_from_indices(X, Y, indices_val)

            no_clss = Y_tr.shape[-1]
            
            GN_graph = tf.Graph()
            # Reset Tensorflow Graph
            tf.reset_default_graph()
            
            with GN_graph.as_default():
                # Input and target placeholders
                input_ph, target_ph = create_placeholders(X_tr, Y_tr, batch_size)
                global keep_rate
                keep_rate = tf.placeholder(tf.float32, shape=(), name = "keep_rate_ph")

                if GN_model == "Multiple_GN":
                    # FORWARD GRAPH-NETWORK BLOCKS
                    # 1. Instantiate Model

                    graph_block_kwargs = {'edge_mlp': NN_size["edge_size"], 
                                          'node_mlp': NN_size["node_size"], 
                                          'global_mlp': NN_size["global_size"],
                                          'edge_dropout': NN_dropout['edge_dropout'],
                                          'node_dropout': NN_dropout['node_dropout'],
                                          'global_dropout': NN_dropout['global_dropout']
                                          }

                    GRAPH_NETWORK_OPTIONS = {'graph_block_kwargs': graph_block_kwargs,
                                             'no_graph_network_blocks': GN_model_steps,
                                             'use_skip_connections':use_skip_connections,
                                             'global_output_size': no_clss}

                    model =  MultiGraphNetwork(**GRAPH_NETWORK_OPTIONS)

                    # 2. Build Graph - Calculate training/test loss.
                    output_ops_tr = model(input_ph)
                    output_ops_val = model(input_ph)
                    loss_op_tr = create_loss_ops(target_ph, output_ops_tr)
                    loss_op_val = create_loss_ops(target_ph, output_ops_val)

                    # 3. Optimizer.
                    optimizer = input_optimizer(learning_rate)
                    step_op = optimizer.minimize(loss_op_tr)
                    
                elif GN_model == "EncodeDecode":
                    # ENCODE-PROCESS-DECODE MODEL
                    # 1. Instantiate Model
                    graph_block_kwargs = {'edge_mlp': NN_size["edge_size"], 
                                          'node_mlp': NN_size["node_size"], 
                                          'global_mlp': NN_size["global_size"],
                                          'edge_dropout': NN_dropout['edge_dropout'],
                                          'node_dropout': NN_dropout['node_dropout'],
                                          'global_dropout': NN_dropout['global_dropout']
                                         }

                    model = EncodeProcessDecode(graph_block_kwargs, global_output_size = no_clss)

                    # Number of Processing Steps in Training/Validation
                    num_processing_steps = GN_model_steps

                    # 2. Build Graph - Calculate training/test loss.
                    output_ops_tr = model(input_ph, num_processing_steps)
                    output_ops_val = model(input_ph, num_processing_steps)

                    loss_ops_tr = create_loss_ops(target_ph, output_ops_tr)
                    loss_op_tr = sum(loss_ops_tr) / num_processing_steps

                    loss_ops_val = create_loss_ops(target_ph, output_ops_val)
                    loss_op_val = loss_ops_val[-1]

                    # 3. Optimizer.
                    optimizer = input_optimizer(learning_rate)
                    step_op = optimizer.minimize(loss_op_tr)

                elif GN_model == "InteractionModel":
                    # INTERACTION NETWORK BLOCKS
                    # 1. Instantiate Model

                    graph_block_kwargs = {'edge_mlp': NN_size["edge_size"], 
                                          'node_mlp': NN_size["node_size"], 
                                          'edge_dropout': NN_dropout['edge_dropout'],
                                          'node_dropout': NN_dropout['node_dropout']
                                         }

                    GRAPH_NETWORK_OPTIONS = {'graph_block_kwargs': graph_block_kwargs,
                                             'no_interaction_network_blocks': GN_model_steps,
                                             'global_output_size': no_clss}

                    model =  InteractionModel(**GRAPH_NETWORK_OPTIONS)

                    # 2. Build Graph - Calculate training/test loss.
                    output_ops_tr = model(input_ph)
                    output_ops_val = model(input_ph)
                    loss_op_tr = create_loss_ops(target_ph, output_ops_tr)
                    loss_op_val = create_loss_ops(target_ph, output_ops_val)

                    # 3. Optimizer.
                    optimizer = input_optimizer(learning_rate)
                    step_op = optimizer.minimize(loss_op_tr)
            
            # Set current Tensorflow Session - tf.Session - tf.train.Saver
            try:
                sess.close()
            except NameError:
                pass

            # Limit TensorFlow to 4 threads within each op and 4 threads across ops
            with tf.Session(graph=GN_graph ,config = tf.ConfigProto(
                                            intra_op_parallelism_threads = 4,
                                            inter_op_parallelism_threads = 4,
                                            log_device_placement = False)) as sess:

                sess.run(tf.global_variables_initializer())

                log_variables = {'last_iteration':0,
                                 'last_epoch':0,
                                 'logged_iterations':[],
                                 'losses_tr':  [],
                                 'losses_val': [],
                                 'acc_cum_tr': [],
                                 'acc_cum_val': []}

                save_variables = {'losses_best_val': 100,
                                  'improved_at': 0,
                                  'no_improve_limit': 100
                }

                saver = tf.train.Saver()
                save_dir = 'checkpoints/'
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                save_path = os.path.join(save_dir, 'best_validation')

                print("# (Epoch number), T (elapsed seconds), "
                      "Ltr (training loss), Lval (validation loss), "
                      "AccTr (accuracy of Training set), "
                      "AccVal (accuracy of Validation set)")

                start_time = time.time()

                # Initialize Training Session

                iteration = 0
                train_session = {"step": step_op,
                                 "target": target_ph,
                                 "loss": loss_op_tr,
                                 "outputs": output_ops_tr}
                val_session = {"target": target_ph,
                               "loss": loss_op_val,
                               "outputs": output_ops_val}

                try:
                    for epoch in range(log_variables['last_epoch'], num_epochs):

                        batch_generator = graph_batch(np.random.permutation(np.arange(len(X_tr))), batch_size)

                        for batch in batch_generator:
                            log_variables['last_iteration'] = iteration
                            feed_dict=create_feed_dict([input_ph, target_ph, keep_rate], [
                                                        [X_tr[ind] for ind in batch],
                                                        [Y_tr[ind] for ind in batch],
                                                        dropout_keep_rate])

                            train_values = sess.run(train_session,feed_dict=feed_dict)

                            iteration += 1

                        # Calculate epoch Training and Validation Loss 
                        feed_dict=create_feed_dict([input_ph, target_ph, keep_rate], [X_tr, Y_tr, 1.00])
                        train_values = sess.run(val_session,feed_dict=feed_dict)

                        feed_dict=create_feed_dict([input_ph, target_ph, keep_rate], [X_val, Y_val, 1.00])
                        test_values = sess.run(val_session, feed_dict=feed_dict)

                        # Check log 
                        if epoch % log_every_epoch==0:
                            elapsed = time.time() - start_time
                            log_variables = compute_log(epoch, elapsed,train_values, 
                                                        test_values, log_variables)

                        save_variables, improve_flag = check_saver(epoch, test_values, save_variables)

                        if improve_flag is True:
                            saver.save(sess=sess, save_path=save_path)

                        if log_variables['last_epoch'] - save_variables['improved_at'] >= save_variables['no_improve_limit']:
                            print ("No improvement found for #{} epochs".format(save_variables['no_improve_limit']))
                            print ("Exiting training session at epoch: #{}".format(epoch))
                            break

                except KeyboardInterrupt:
                    print("\n")
                    print("Exiting training session at epoch: #{}".format(log_variables['last_epoch']))

                print ("Restoring session with lowest validation Loss: ", save_variables['losses_best_val'])
                saver.restore(sess=sess, save_path=save_path)

                # Calculate Confusion Matrices by using Best model on Train/Test 
                feed_dict=create_feed_dict([input_ph, target_ph, keep_rate], [X_tr, Y_tr, 1.00])
                train_values = sess.run(val_session, feed_dict=feed_dict)

                feed_dict=create_feed_dict([input_ph, target_ph, keep_rate], [X_test, Y_test, 1.00])
                test_values = sess.run(val_session, feed_dict=feed_dict)

                if GN_model == "Multiple_GN":
                    save_title = "Multiple_GN" + '_{}'.format(GN_model_steps)
                elif GN_model == "EncodeDecode":
                    save_title = "EncodeDecode" + '_{}'.format(GN_model_steps)
                else:
                    save_title = "InteractionModel" + '_{}'.format(GN_model_steps)

                plot_result_curves(log_variables, fold = (o, i), save = True, save_title = save_title)

                x_tr, pred_tr = predictions(train_values["target"], train_values["outputs"])
                x_test, pred_test = predictions(test_values["target"], test_values["outputs"])

                cm_tr = np.zeros((no_clss, no_clss))
                for a, p in zip(x_tr, pred_tr):
                    cm_tr[a][p] += 1

                cm_test = np.zeros((no_clss, no_clss))
                for a, p in zip(x_test, pred_test):
                    cm_test[a][p] += 1

                cm_total.append((cm_tr, cm_test))
    
    return cm_total
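
The nested structure of `run_k_cross_validation` is easier to follow when stripped of the TensorFlow details. The sketch below only tracks index sets: an outer loop holds out a test fold, and an inner loop splits the remainder into training and validation folds. It is a simplified illustration with assumed names and toy sizes, it omits stratification, and it uses global indices throughout, whereas the notebook re-indexes the inner folds relative to the Train/Val subset.

```python
import numpy as np

def split_into_folds(indices, k, rng):
    """Shuffle `indices` and cut them into k nearly equal folds."""
    return np.array_split(rng.permutation(indices), k)

rng = np.random.RandomState(123)
all_indices = np.arange(20)     # stand-in for a dataset of 20 graphs
outer_k, inner_k = 5, 4         # illustrative fold counts

runs = []                       # one (n_train, n_val, n_test) entry per training run
for test_idx in split_into_folds(all_indices, outer_k, rng):
    # Everything outside the outer test fold is available for training/validation.
    train_val_idx = np.setdiff1d(all_indices, test_idx)
    inner_folds = split_into_folds(train_val_idx, inner_k, rng)
    for i, val_idx in enumerate(inner_folds):
        train_idx = np.concatenate([f for j, f in enumerate(inner_folds) if j != i])
        # The three index sets must never overlap.
        assert not set(train_idx) & set(val_idx)
        assert not set(train_idx) & set(test_idx)
        assert not set(val_idx) & set(test_idx)
        runs.append((len(train_idx), len(val_idx), len(test_idx)))
```

One model is trained per (outer, inner) pair, so this toy setup yields 5 x 4 = 20 runs, each with 12 training, 4 validation and 4 test graphs. With the default of 5 outer and 5 inner folds, the function above trains 25 models and returns 25 pairs of confusion matrices.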

# <a name="h4"></a>4. Running Classification Tasks

The main parts of the notebook where a training session can be defined are:

#### 4A. [Training using Cross Validation](#Training-using-Cross-Validation)

This part includes performing a grid search on different model parameters, training them and storing their results, that is loss/accuracy curves for each model, cumulative confusion matrices, as well as a .txt file including a confusion matrix for each fold. 

#### 4B. [Training Using a Holdout Set](#Training-Using-a-Holdout-Set)

This part is about using a holdout set with a train/val/test split for testing new implementations.

[^](#Neural-Networks-on-Graphs)

The model parameters that can be defined are:

- Graph Network Architecture, choosing between:
     - **"Multiple_GN"**,
     - **"EncodeDecode"** (Encode-Core-Decode) and
     - **"InteractionModel"**
- A model's inner update function characteristics, that is its number of layers and number of units. For each model the same number of units is used for the edge, node and global update functions.
- A model's depth for the "Multiple_GN" and "InteractionModel" architectures, and the number of recurrent steps for the "EncodeDecode" model.
- The model units (edge, node, global) on which dropout will be applied, and with what **keep rate**.
- Skip connections parameter.
- Training parameters, like:
    - **inner/outer number of folds** 
    - **learning rate**
    - **batch size**
    - **number of epochs**
    
The skip connection parameter was tested in two variants:

- As a residual connection on the graph level: if $M(x)$ is a function (a step of the learning model) that is applied on a graph $x$, the residual block computes $Res(M(x)) = M(x) + x$, that is, it performs an element-wise addition between the calculated block and the initial input. This makes the identity function easy for the block to represent.
- As a skip connection that concatenates graph attributes: if $M(x)$ is a function that is applied on a graph $x$, part of the output skips the function $M$ and is concatenated to the result, that is $\mathrm{concat}([M(x), x], \mathrm{axis}=1)$. As this could lead to a blow-up of the attribute size, each block's models are of the form $(2n, n)$, where $n$ is the number of units, which ensures that the model's size does not increase with depth.
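
The width arithmetic behind the two variants can be checked with a small NumPy sketch. It assumes that the $(2n, n)$ sizes denote the hidden and output widths of a block's MLP. The random `mlp` stand-in, the width $n = 4$ and the three toy nodes are illustrative assumptions and do not reproduce the notebook's models.

```python
import numpy as np

rng = np.random.RandomState(0)

def mlp(inputs, layer_sizes):
    """Stand-in for a learned update function: random linear layers with ReLU."""
    h = inputs
    for units in layer_sizes:
        h = np.maximum(h @ rng.randn(h.shape[1], units), 0.0)
    return h

n = 4
x = rng.randn(3, n)   # attributes of 3 nodes, n features each

# Residual variant: the block must preserve the width so that M(x) + x is defined.
res = mlp(x, (n, n)) + x

# Concatenation variant: the width doubles from n to 2n ...
cat = np.concatenate([mlp(x, (n, n)), x], axis=1)

# ... and a following block with layer sizes (2n, n) brings it back to n.
nxt = mlp(cat, (2 * n, n))
```

The residual output keeps the width $n$, while the concatenated output has width $2n$. Ending every block in $n$ units is what keeps the attribute size from growing as more blocks are stacked.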

## Training using Cross Validation

Training with Cross Validation includes the following steps:

- Instantiate model
- Build Graph - Calculate training/test loss
- Set Optimizer
- Initialize tf.Session
- Initialize log variables
- Initialize tf.Saver
- Log of training session, including **Iterations, Time, Losses, Accuracy**
- Terminate based on the validation score, by checking whether it improves over the session
- Use the Saver to create checkpoints and restore the model with the best validation loss after termination
- Calculate Confusion Matrices for training/testing sets and store them
- Visualize Loss, Accuracy curves along training (save them to folder?)

After training and evaluation finish for each fold, the function returns a list of tuples of training/testing confusion matrices as numpy arrays.
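
How the outer and inner folds of a nested cross validation relate can be sketched with plain NumPy index arrays. The function names `k_fold_indices` and `nested_cv_splits` are hypothetical, the split is not stratified, and `run_k_cross_validation` (defined elsewhere in the notebook) may organize its folds differently.

```python
import numpy as np

def k_fold_indices(indices, n_folds, rng):
    """Shuffle the indices and yield (rest, held_out) pairs, one per fold."""
    folds = np.array_split(rng.permutation(indices), n_folds)
    for i in range(n_folds):
        rest = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield rest, folds[i]

def nested_cv_splits(n_samples, outer_no_folds, inner_no_folds, seed=123):
    """Yield (train, val, test) index arrays for every outer/inner fold pair."""
    rng = np.random.default_rng(seed)
    all_indices = np.arange(n_samples)
    # Outer loop: hold out a Test fold, keep the rest as TrainVal.
    for train_val, test in k_fold_indices(all_indices, outer_no_folds, rng):
        # Inner loop: split TrainVal further into Train and Validation.
        for train, val in k_fold_indices(train_val, inner_no_folds, rng):
            yield train, val, test

splits = list(nested_cv_splits(n_samples=50, outer_no_folds=5, inner_no_folds=5))
print(len(splits))  # one training run per outer/inner pair
```

With 5 outer and 5 inner folds this yields 25 training runs, which is consistent with the "5x5-fold" wording used in the plot titles of the next cell.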

In [None]:
# Some Default Settings
# Multiple_GN, InteractionModel
# =============================
# GN_Sizes_cv = [(16,16), (32,32), (64,64), (96,96)]
# GN_model_steps = [2]
# If use skip_connections choose any number of GN_model_steps (e.g. 6, 8 etc.)
# but then must choose sizes of the form (64,32) or (48,24)
# 
#
# EncodeDecode
# =============================
# GN_Sizes_cv = [(32,32), (64,64)]
# GN_model_steps = [5]

In [None]:
GN_models_cv =  ["Multiple_GN", "InteractionModel"] #["Multiple_GN", "EncodeDecode", "InteractionModel"]
GN_Sizes_cv = [(32,32), (64,64)] 
GN_model_steps = [2]
GN_learning_rate_cv = [1e-5]

# The cartesian product for grid search or multiple training sessions
session_cv = itertools.product(GN_models_cv, GN_Sizes_cv, GN_model_steps, GN_learning_rate_cv)

for gn_model, gn_size, gn_model_step, gn_learning_rate in session_cv:
    CV_SETTINGS = {"outer_no_folds":5,
                   "inner_no_folds":5,
                   "num_epochs": 600,
                   "GN_model": gn_model,
                   "batch_size": 16,
                   "log_every_epoch" : 5,
                   "NN_size": {"edge_size":gn_size,
                               "node_size":gn_size,
                               "global_size":gn_size},
                   "NN_dropout":{"edge_dropout": False,
                                 "node_dropout": False,
                                 "global_dropout": False},
                   "dropout_keep_rate":1.00,
                   "use_skip_connections": False, # Applies only to Multiple_GN network
                   "GN_model_steps": gn_model_step,
                   "input_optimizer": tf.train.AdamOptimizer,
                   "learning_rate": gn_learning_rate,
                   "seed": 123}

    cm_total = run_k_cross_validation(X, Y, **CV_SETTINGS)
    
    # Extract some information for plots' titles and saved files
    title_cm = CV_SETTINGS["GN_model"] + '_{}'.format(CV_SETTINGS["GN_model_steps"])
    title_folds = '[{}x{}]'.format(CV_SETTINGS["outer_no_folds"],
                                   CV_SETTINGS["inner_no_folds"])
    title_NN = '[{}N]'.format(CV_SETTINGS["NN_size"]['node_size'])
    title_LR = '[{}L]'.format(CV_SETTINGS["learning_rate"])
    title_BS = '[{}BS]'.format(CV_SETTINGS["batch_size"])

    title_report = '-'.join([title_cm, title_folds, title_NN, title_LR, title_BS])
    plots_cm_path = 'plots/Confusion_Matrices-'
    path = plots_cm_path + title_report + '.txt'
    with open(path, 'w') as f:
        f.write('np.' + repr(np.array(cm_total)))

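    # Sum the per-fold (training, testing) confusion matrices over all folds
    cm_total_cumulative = np.sum(np.array(cm_total), axis = 0)

    # Build the fold description from the settings instead of hard-coding 5x5
    folds_label = '{}x{}-fold'.format(CV_SETTINGS["outer_no_folds"], CV_SETTINGS["inner_no_folds"])

    plt.figure()
    plot_confusion_matrix(cm_total_cumulative[0], classes=[0,1], save_title = title_report+'-CMTrainingSet', 
                          save = True, title='Cumulative Confusion matrix of {} Training Sets'.format(folds_label))

    plt.figure()
    plot_confusion_matrix(cm_total_cumulative[1], classes=[0,1], save_title = title_report+'-CMTestingSet', 
                          save = True, title='Cumulative Confusion matrix of {} Testing Sets'.format(folds_label))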

    # Transfer Output Files to Training Output Files Directory
    current_dir = os.path.join(os.getcwd(), 'plots')
    target_dir = os.path.join(current_dir, title_report)

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    session_files = os.listdir(current_dir)

    for f in session_files:
        if f.endswith(('.txt', '.pdf', '.png')):
            shutil.move(os.path.join(current_dir, f), target_dir)

## Training Using a Holdout Set

Similar to Cross Validation, training with a Holdout Set includes the following steps:

- Instantiate model
- Build Graph - Calculate training/test loss
- Set Optimizer
- Initialize tf.Session
- Initialize log variables
- Initialize tf.Saver
- Log of training session, including **Iterations, Time, Losses, Accuracy**
- Terminate based on the validation score, by checking whether it improves over the session
- Use the Saver to create checkpoints and restore the model with the best validation loss after termination
- Calculate Confusion Matrices for training/testing sets and store them
- Visualize Loss, Accuracy curves along training

After training and evaluation finish, the training/testing confusion matrices are calculated as numpy arrays for visualizing the model's performance.
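
As a minimal sketch of that last step, a confusion matrix can be built from integer class labels with NumPy, and accuracy and per-class recall can be read off it. The helper name `confusion_matrix_from_labels` and the toy labels are illustrative assumptions and not part of the notebook. Rows index the actual class and columns the predicted class.

```python
import numpy as np

def confusion_matrix_from_labels(actual, predicted, no_of_labels):
    """Count (actual, predicted) pairs; rows are actual, columns predicted."""
    cm = np.zeros((no_of_labels, no_of_labels), dtype=int)
    # np.add.at accumulates correctly even when index pairs repeat.
    np.add.at(cm, (actual, predicted), 1)
    return cm

# Toy labels, for illustration only.
actual = np.array([0, 0, 1, 1, 1])
predicted = np.array([0, 1, 1, 1, 0])

cm = confusion_matrix_from_labels(actual, predicted, no_of_labels=2)
accuracy = np.trace(cm) / cm.sum()             # correct predictions / all
recall_per_class = np.diag(cm) / cm.sum(axis=1)

print(cm)
print(accuracy, recall_per_class)
```

Because the matrices are plain integer arrays, summing them over folds or runs gives a cumulative matrix without any further bookkeeping.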

### Defining training Architecture

After splitting the dataset into training, validation and test sets, run the cell that corresponds to the architecture to be trained:

1. #### Encode - Process - Decode Model
2. #### Multiple Graph Network Model
3. #### Interaction Model
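
The stratified 70/15/15 split requested in the next cell can be sketched as follows. This is only an assumption about what a stratified split does: `train_val_test` is defined elsewhere in the notebook and may differ, the helper name `stratified_split_indices` is hypothetical, and integer class labels are assumed (for one-hot targets, `argmax` over the last axis would give them).

```python
import numpy as np

def stratified_split_indices(labels, splits=(0.7, 0.15, 0.15), seed=0):
    """Return train/val/test index arrays that keep the class proportions."""
    rng = np.random.default_rng(seed)
    train, val, test = [], [], []
    for cls in np.unique(labels):
        # Shuffle the samples of this class and cut them by the split ratios.
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_train = int(round(splits[0] * len(idx)))
        n_val = int(round(splits[1] * len(idx)))
        train.extend(idx[:n_train])
        val.extend(idx[n_train:n_train + n_val])
        test.extend(idx[n_train + n_val:])   # remainder goes to the test set
    return np.array(train), np.array(val), np.array(test)

# Toy labels: 60 samples of class 0 and 40 of class 1.
labels = np.array([0] * 60 + [1] * 40)
train_idx, val_idx, test_idx = stratified_split_indices(labels)
print(len(train_idx), len(val_idx), len(test_idx))
```

Splitting each class separately keeps the class ratio roughly equal across the three sets, which matters most when the validation and test sets are small.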


In [None]:
X_train, Y_train, X_val, Y_val, X_test, Y_test = train_val_test(X, Y, splits = (0.7,0.15,0.15), stratify=True)

In [None]:
# Reset Tensorflow Graph
tf.reset_default_graph()

# Data / training parameters.
num_epochs = 600
batch_size_tr = 8
batch_size_val = Y_val.shape[0]

# Data.
# Input and target placeholders
input_ph, target_ph = create_placeholders(X_train, Y_train, batch_size_tr)
keep_rate = tf.placeholder(tf.float32, shape=(), name = "keep_rate_ph")

### 1. Encode - Process - Decode Model

Parameters to change:

- num_processing_steps on training and validation graph (can be different)
- edge/node/global MLP parameters
- edge/node/global dropout enable
- learning_rate
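
In the cell below, the training loss averages the losses of all processing steps, while the validation loss uses only the last step. The sketch illustrates that difference with NumPy. It assumes a softmax cross-entropy loss, which is an assumption about `create_loss_ops`, and the random arrays merely stand in for the model's per-step outputs.

```python
import numpy as np

def softmax_cross_entropy(logits, one_hot_targets):
    """Mean softmax cross-entropy over a batch (numerically stabilized)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(one_hot_targets * log_probs).sum(axis=1).mean())

rng = np.random.default_rng(0)
num_processing_steps = 5
targets = np.eye(2)[[0, 1, 1]]        # one-hot targets for 3 graphs

# One output per message-passing step (random stand-ins for model outputs).
outputs_per_step = [rng.normal(size=(3, 2)) for _ in range(num_processing_steps)]
losses = [softmax_cross_entropy(out, targets) for out in outputs_per_step]

loss_tr = sum(losses) / num_processing_steps   # training: mean over all steps
loss_val = losses[-1]                          # validation: last step only

print(loss_tr, loss_val)
```

Averaging over steps gives every intermediate output a training signal, whereas only the final output is used when scoring the model.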

In [None]:
# -1- ENCODE-PROCESS-DECODE MODEL

# Model parameters - Number of processing (message-passing) steps.
num_processing_steps_tr = 5
num_processing_steps_val = 5

# 1. Instantiate Model
graph_block_kwargs = {'edge_mlp': (32, 32), 
                      'node_mlp': (32, 32), 
                      'global_mlp': (32, 32),
                      'edge_dropout': False,
                      'node_dropout': False,
                      'global_dropout': False
                     }

model = EncodeProcessDecode(graph_block_kwargs, global_output_size = Y_train.shape[-1])

# To add dropout, add an is_training flag to distinguish between training and validation

# 2. Build Graph - Calculate training/test loss.
output_ops_tr = model(input_ph, num_processing_steps_tr)
output_ops_val = model(input_ph, num_processing_steps_val)

loss_ops_tr = create_loss_ops(target_ph, output_ops_tr)
loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr

loss_ops_val = create_loss_ops(target_ph, output_ops_val)
loss_op_val = loss_ops_val[-1]

# 3. Optimizer.
learning_rate = 1e-5
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

### 2. Forward Graph Network Blocks Model

Parameters to change:

- number of graph network blocks
- edge/node/global MLP parameters
- edge/node/global dropout enable
- enable skip connections between Multiple GN blocks
- learning_rate

In [None]:
# -2- FORWARD GRAPH-NETWORK BLOCKS

# 1. Instantiate Model

graph_block_kwargs = {'edge_mlp': (32, 32), 
                      'node_mlp': (32, 32), 
                      'global_mlp': (32, 32),
                      'edge_dropout': False,
                      'node_dropout': False,
                      'global_dropout': False
                     }

GRAPH_NETWORK_OPTIONS = {'graph_block_kwargs': graph_block_kwargs,
                         'no_graph_network_blocks': 4,
                         'use_skip_connections':True,
                         'global_output_size': Y_train.shape[-1]}

model =  MultiGraphNetwork(**GRAPH_NETWORK_OPTIONS)

# 2. Build Graph - Calculate training/test loss.
output_ops_tr = model(input_ph)
output_ops_val = model(input_ph)
loss_op_tr = create_loss_ops(target_ph, output_ops_tr)
loss_op_val = create_loss_ops(target_ph, output_ops_val)

# 3. Optimizer.
learning_rate = 3e-5
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

In [None]:
# total_parameters = 0
# for variable in tf.trainable_variables():
#     # shape is an array of tf.Dimension
#     shape = variable.get_shape()
# #     print(shape)
# #     print(len(shape))
#     variable_parameters = 1
#     for dim in shape:
# #         print(dim)
#         variable_parameters *= dim.value
#     print(variable_parameters)
#     total_parameters += variable_parameters
# print(total_parameters)

### 3. Interaction Model

Parameters to change:

- number of interaction network blocks
- edge/node MLP parameters
- edge/node dropout enable
- learning_rate

In [None]:
# -3- INTERACTION NETWORK BLOCKS
# 1. Instantiate Model

graph_block_kwargs = {'edge_mlp': (32, 32), 
                      'node_mlp': (32, 32), 
                      'edge_dropout': False,
                      'node_dropout': False
                     }

GRAPH_NETWORK_OPTIONS = {'graph_block_kwargs': graph_block_kwargs,
                         'no_interaction_network_blocks': 2,
                         'global_output_size': Y_train.shape[-1]}

model =  InteractionModel(**GRAPH_NETWORK_OPTIONS)

# 2. Build Graph - Calculate training/test loss.
output_ops_tr = model(input_ph)
output_ops_val = model(input_ph)
loss_op_tr = create_loss_ops(target_ph, output_ops_tr)
loss_op_val = create_loss_ops(target_ph, output_ops_val)

# 3. Optimizer.
learning_rate = 3e-5
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

### Tensorflow Graph initialization

- Initialize TensorFlow graph
- Specify threads to be used
- Initialize Global Variables
- Initialize Log Variables
- Initialize Saver and Checkpoint Variables

In [None]:
# This cell resets the Tensorflow session, but keeps the same computational graph

try:
    sess.close()
except NameError:
    pass

sess = tf.Session(config = tf.ConfigProto(intra_op_parallelism_threads = 4,
                                          inter_op_parallelism_threads = 4,
                                          log_device_placement = False))

sess.run(tf.global_variables_initializer())

log_variables = {'last_iteration':0,
                 'last_epoch':0,
                 'logged_iterations':[],
                 'losses_tr':  [],
                 'losses_val': [],
                 'acc_cum_tr': [],
                 'acc_cum_val': []}

saver = tf.train.Saver()
save_dir = 'checkpoints/'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
save_path = os.path.join(save_dir, 'best_validation')

save_variables = {'losses_best_val': 100,
                  'improved_at': 0,
                  'no_improve_limit': 100
}

## Graph Network training session

Training session includes:
- log of training session, including **Iterations, Time, Losses, Accuracy**
- termination based on the validation score, by checking whether it improves over the session
- saver to create checkpoints and restore the model with the best validation loss after termination
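
The termination rule can be sketched independently of TensorFlow. The functions `check_improvement` and `should_stop` are hypothetical and only mirror the keys of the `save_variables` dictionary; the notebook's own `check_saver` is defined elsewhere and may behave differently. The validation losses are made-up numbers.

```python
def check_improvement(epoch, val_loss, state):
    """Record a new best validation loss; return (state, improved)."""
    improved = val_loss < state["losses_best_val"]
    if improved:
        # This is the point where a checkpoint would be saved.
        state = dict(state, losses_best_val=val_loss, improved_at=epoch)
    return state, improved

def should_stop(epoch, state):
    """Stop once no improvement was seen for no_improve_limit epochs."""
    return epoch - state["improved_at"] >= state["no_improve_limit"]

state = {"losses_best_val": float("inf"), "improved_at": 0, "no_improve_limit": 3}
val_losses = [0.9, 0.7, 0.8, 0.75, 0.72, 0.71, 0.6]   # made-up values

stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    state, improved = check_improvement(epoch, val_loss, state)
    if should_stop(epoch, state):
        stopped_at = epoch   # the best checkpoint would be restored here
        break

print(stopped_at, state["losses_best_val"])
```

In this toy run the loss of 0.6 at the last epoch is never reached, which shows the usual trade-off of a patience limit: a small limit saves time but can stop before a later improvement.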

In [None]:
print("# (Epoch number), T (elapsed seconds), "
      "Ltr (training loss), Lval (validation loss), "
      "AccTr (accuracy of training set), "
      "AccVal (accuracy of validation set)")

start_time = time.time()
iteration = 0

try:
    for epoch in range(log_variables['last_epoch'], num_epochs):
        
        batch_generator = graph_batch(np.random.permutation(np.arange(len(X_train))), batch_size_tr)
        
        for batch in batch_generator:
            
            log_variables['last_iteration'] = iteration
            feed_dict=create_feed_dict([input_ph, target_ph, keep_rate], [
                                        [X_train[ind] for ind in batch],
                                        [Y_train[ind] for ind in batch],
                                        1.00])
            train_values = sess.run({
              "step": step_op,
              "target": target_ph,
              "loss": loss_op_tr,
              "outputs": output_ops_tr
            },
                feed_dict=feed_dict)
            
            iteration += 1

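        # Evaluate on the full training set; stored as train_values so that
        # the validation run below does not overwrite the result
        feed_dict=create_feed_dict([input_ph, target_ph, keep_rate], [X_train, Y_train, 1.00])
        train_values = sess.run({
            "target": target_ph,
            "loss": loss_op_val,
            "outputs": output_ops_val
        },
            feed_dict=feed_dict)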
            
        feed_dict=create_feed_dict([input_ph, target_ph, keep_rate], [X_val, Y_val, 1.00])
        test_values = sess.run({
            "target": target_ph,
            "loss": loss_op_val,
            "outputs": output_ops_val
        },
            feed_dict=feed_dict)

        if epoch%5 == 0:
            elapsed = time.time() - start_time
            log_variables = compute_log(epoch, elapsed, train_values, test_values, log_variables)

        save_variables, improve_flag = check_saver(epoch, test_values, save_variables)

        # A truthiness check also works if check_saver returns a NumPy boolean
        if improve_flag:
            saver.save(sess=sess, save_path=save_path)

        # The no-improvement limit is measured in epochs, not iterations
        if log_variables['last_epoch'] - save_variables['improved_at'] >= save_variables['no_improve_limit']:
            print("No improvement found for #{} epochs".format(save_variables['no_improve_limit']))
            print("Restoring session with lowest validation Loss: ", save_variables['losses_best_val'])
            saver.restore(sess=sess, save_path=save_path)
            print("Exiting training session at iteration number: #{}".format(iteration))
            break
            
except KeyboardInterrupt:
    print("\n")
    print("Exiting training session at epoch number: #{}".format(log_variables['last_epoch']))
    print ("Restoring session with lowest validation Loss: ", save_variables['losses_best_val'])
    saver.restore(sess=sess, save_path=save_path)
    
print ("Restoring session with lowest validation Loss: ", save_variables['losses_best_val'])
saver.restore(sess=sess, save_path=save_path)

In [None]:
plot_result_curves(log_variables)

In [None]:
x_tr, pred_tr = predictions(train_values["target"], train_values["outputs"])
x_val, pred_val = predictions(test_values["target"], test_values["outputs"])

In [None]:
no_of_labels = Y_train.shape[-1]

cm_tr = np.zeros((no_of_labels,no_of_labels))
for a, p in zip(x_tr, pred_tr):
    cm_tr[a][p] += 1
    
cm_val = np.zeros((no_of_labels,no_of_labels))
for a, p in zip(x_val, pred_val):
    cm_val[a][p] += 1

In [None]:
plt.figure()
plot_confusion_matrix(cm_tr, classes=list(range(no_of_labels)), save_title = 'Multiple_GN_2', save = False,
                      title='Confusion matrix of the Training Set')