# Copyright 2019 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Silence every Python warning for the remainder of this notebook session.
import warnings

warnings.simplefilter(action="ignore")

These are all the dependencies that will be used in this notebook.

In [9]:
import abc
import collections
import contextlib
import copy
import random
import time

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import six
import sonnet as snt
import tensorflow as tf

通过多层感知机（MLP）分别对节点特征和边特征进行编码

In [None]:
class GraphEncoder(snt.AbstractModule):
    """Encodes node and edge features, each with an optional MLP.

    Node features and edge features each go through their own MLP when the
    corresponding hidden sizes are configured; otherwise they are returned
    unchanged.
    """

    def __init__(self, node_hidden_sizes=None, edge_hidden_sizes=None,
                 name="graph-encoder"):
        """Constructor.

        Args:
          node_hidden_sizes: if provided, a list of ints giving the hidden
            sizes of the node encoder MLP; the last element is the size of the
            node outputs. If None or empty, node features are passed through
            unchanged.
          edge_hidden_sizes: if provided, a list of ints giving the hidden
            sizes of the edge encoder MLP; the last element is the size of the
            edge outputs. If None or empty, edge features are passed through
            unchanged.
          name: name of this module.
        """
        super(GraphEncoder, self).__init__(name=name)

        # Normalize empty lists to None so that both an omitted argument and
        # an empty list mean "pass the features through unchanged" -- for
        # nodes AND edges (previously only the node sizes were normalized).
        self._node_hidden_sizes = node_hidden_sizes if node_hidden_sizes else None
        self._edge_hidden_sizes = edge_hidden_sizes if edge_hidden_sizes else None

    def _build(self, node_features, edge_features=None):
        """Encode node and edge features.

        Args:
          node_features: [n_nodes, node_feat_dim] float tensor.
          edge_features: if provided, should be [n_edges, edge_feat_dim] float
            tensor.

        Returns:
          node_outputs: [n_nodes, node_embedding_dim] float tensor, node
            embeddings; the input node_features if no node MLP is configured.
          edge_outputs: if edge_features is not None and edge hidden sizes are
            configured, this is [n_edges, edge_embedding_dim] float tensor,
            edge embeddings; otherwise just the input edge_features.
        """
        if self._node_hidden_sizes is None:
            node_outputs = node_features
        else:
            node_outputs = snt.nets.MLP(
                self._node_hidden_sizes, name='node-feature-mlp')(node_features)

        if edge_features is None or self._edge_hidden_sizes is None:
            edge_outputs = edge_features
        else:
            edge_outputs = snt.nets.MLP(
                self._edge_hidden_sizes, name='edge-feature-mlp')(edge_features)

        return node_outputs, edge_outputs


传播层

In [None]:
def graph_prop_once(node_states,
                    from_idx,
                    to_idx,
                    message_net,
                    aggregation_module=tf.unsorted_segment_sum,
                    edge_features=None):
  """Run a single round of message passing over the edges of a graph.

  For every edge, the states of its source and target node (plus the edge's
  own features, when given) are concatenated and fed to `message_net`. The
  resulting per-edge messages are then pooled onto the target nodes by
  `aggregation_module`.

  Args:
    node_states: [n_nodes, node_state_dim] float tensor, one state row per
      node.
    from_idx: [n_edges] int tensor, source node index of every edge.
    to_idx: [n_edges] int tensor, target node index of every edge; messages
      are delivered to these nodes.
    message_net: callable mapping the concatenated per-edge inputs to message
      vectors.
    aggregation_module: callable invoked as
      `aggregation_module(messages, to_idx, n_nodes)` that pools the
      [n_edges, edge_message_dim] messages into one row per node. Defaults to
      `tf.unsorted_segment_sum`.
    edge_features: optional [n_edges, edge_feature_dim] float tensor of extra
      per-edge inputs.

  Returns:
    An [n_nodes, edge_message_dim] float tensor of aggregated messages, one
    row per node.
  """
  # Source-node states first, then target-node states, then edge features.
  per_edge_parts = [tf.gather(node_states, idx) for idx in (from_idx, to_idx)]
  if edge_features is not None:
    per_edge_parts.append(edge_features)

  messages = message_net(tf.concat(per_edge_parts, axis=-1))
  num_nodes = tf.shape(node_states)[0]
  return aggregation_module(messages, to_idx, num_nodes)


class GraphPropLayer(snt.AbstractModule):
  """Implementation of a graph propagation (message passing) layer.

  One call computes per-edge messages with an MLP (optionally also along the
  reversed edges), sums them into the receiving nodes, and then updates the
  node states with an MLP, a residual MLP, or a GRU.
  """

  # Node update rules understood by `_compute_node_update`.
  _NODE_UPDATE_TYPES = ('mlp', 'gru', 'residual')

  def __init__(self,
               node_state_dim,
               edge_hidden_sizes,
               node_hidden_sizes,
               edge_net_init_scale=0.1,
               node_update_type='residual',
               use_reverse_direction=True,
               reverse_dir_param_different=True,
               layer_norm=False,
               name='graph-net'):
    """Constructor.

    Args:
      node_state_dim: int, dimensionality of node states.
      edge_hidden_sizes: sequence of ints, hidden sizes for the edge message
        net, the last element in the list is the size of the message vectors.
      node_hidden_sizes: sequence of ints, hidden sizes for the node update
        net.
      edge_net_init_scale: initialization scale for the edge networks.  This
        is typically set to a small value such that the gradient does not blow
        up.
      node_update_type: type of node updates, one of {mlp, gru, residual}.
      use_reverse_direction: set to True to also propagate messages in the
        reverse direction.
      reverse_dir_param_different: set to True to have the messages computed
        using a different set of parameters than for the forward direction.
      layer_norm: set to True to use layer normalization in a few places.
      name: name of this module.
    """
    super(GraphPropLayer, self).__init__(name=name)

    self._node_state_dim = node_state_dim
    # list(...) copies the caller's sequence and also accepts tuples, which
    # the previous `sizes[:] + [dim]` form rejected with a TypeError.
    self._edge_hidden_sizes = list(edge_hidden_sizes)

    # The node update net must output vectors of size node_state_dim.
    self._node_hidden_sizes = list(node_hidden_sizes) + [node_state_dim]
    self._edge_net_init_scale = edge_net_init_scale
    self._node_update_type = node_update_type

    self._use_reverse_direction = use_reverse_direction
    self._reverse_dir_param_different = reverse_dir_param_different

    self._layer_norm = layer_norm

  def _make_message_net(self, name):
    """Create an edge message MLP using the configured initialization scale."""
    return snt.nets.MLP(
        self._edge_hidden_sizes,
        initializers={
            'w': tf.variance_scaling_initializer(
                scale=self._edge_net_init_scale),
            'b': tf.zeros_initializer()},
        name=name)

  def _compute_aggregated_messages(
      self, node_states, from_idx, to_idx, edge_features=None):
    """Compute aggregated messages for each node.

    Args:
      node_states: [n_nodes, input_node_state_dim] float tensor, node states.
      from_idx: [n_edges] int tensor, from node indices for each edge.
      to_idx: [n_edges] int tensor, to node indices for each edge.
      edge_features: if not None, should be [n_edges, edge_embedding_dim]
        tensor, edge features.

    Returns:
      aggregated_messages: [n_nodes, aggregated_message_dim] float tensor, the
        aggregated messages for each node.
    """
    self._message_net = self._make_message_net('message-mlp')

    aggregated_messages = graph_prop_once(
        node_states,
        from_idx,
        to_idx,
        self._message_net,
        aggregation_module=tf.unsorted_segment_sum,
        edge_features=edge_features)

    # Optionally also send messages along the reversed edges (to -> from).
    if self._use_reverse_direction:
      if self._reverse_dir_param_different:
        self._reverse_message_net = self._make_message_net(
            'reverse-message-mlp')
      else:
        # Share parameters with the forward-direction message net.
        self._reverse_message_net = self._message_net

      reverse_aggregated_messages = graph_prop_once(
          node_states,
          to_idx,
          from_idx,
          self._reverse_message_net,
          aggregation_module=tf.unsorted_segment_sum,
          edge_features=edge_features)

      aggregated_messages += reverse_aggregated_messages

    if self._layer_norm:
      aggregated_messages = snt.LayerNorm()(aggregated_messages)

    return aggregated_messages

  def _compute_node_update(self,
                           node_states,
                           node_state_inputs,
                           node_features=None):
    """Compute node updates.

    Args:
      node_states: [n_nodes, node_state_dim] float tensor, the input node
        states.
      node_state_inputs: a list of tensors used to compute node updates.  Each
        element tensor should have shape [n_nodes, feat_dim], where feat_dim can
        be different.  These tensors will be concatenated along the feature
        dimension.  The list itself is not modified.
      node_features: extra node features if provided, should be of size
        [n_nodes, extra_node_feat_dim] float tensor, can be used to implement
        different types of skip connections.

    Returns:
      new_node_states: [n_nodes, node_state_dim] float tensor, the new node
        state tensor.

    Raises:
      ValueError: if node update type is not supported.
    """
    # Validate first so an unsupported type does not build (and create
    # variables for) a node MLP that is then thrown away.
    if self._node_update_type not in self._NODE_UPDATE_TYPES:
      raise ValueError('Unknown node update type %s' % self._node_update_type)

    # Work on a copy so the caller's list is never mutated by the appends.
    inputs = list(node_state_inputs)
    if self._node_update_type in ('mlp', 'residual'):
      inputs.append(node_states)
    if node_features is not None:
      inputs.append(node_features)

    if len(inputs) == 1:
      update_inputs = inputs[0]
    else:
      update_inputs = tf.concat(inputs, axis=-1)

    if self._node_update_type == 'gru':
      # The GRU consumes the current node states as its recurrent state.
      _, new_node_states = snt.GRU(self._node_state_dim)(
          update_inputs, node_states)
      return new_node_states

    mlp_output = snt.nets.MLP(
        self._node_hidden_sizes, name='node-mlp')(update_inputs)
    if self._layer_norm:
      mlp_output = snt.LayerNorm()(mlp_output)
    if self._node_update_type == 'mlp':
      return mlp_output
    # 'residual': the MLP predicts an additive update to the node states.
    return node_states + mlp_output

  def _build(self,
             node_states,
             from_idx,
             to_idx,
             edge_features=None,
             node_features=None):
    """Run one propagation step.

    Args:
      node_states: [n_nodes, input_node_state_dim] float tensor, node states.
      from_idx: [n_edges] int tensor, from node indices for each edge.
      to_idx: [n_edges] int tensor, to node indices for each edge.
      edge_features: if not None, should be [n_edges, edge_embedding_dim]
        tensor, edge features.
      node_features: extra node features if provided, should be of size
        [n_nodes, extra_node_feat_dim] float tensor, can be used to implement
        different types of skip connections.

    Returns:
      node_states: [n_nodes, node_state_dim] float tensor, new node states.
    """
    aggregated_messages = self._compute_aggregated_messages(
        node_states, from_idx, to_idx, edge_features=edge_features)

    return self._compute_node_update(node_states,
                                     [aggregated_messages],
                                     node_features=node_features)