# Sparse GGNN low-bandwidth model for dense hardware

In this notebook we demonstrate the key parts of our implementation of the low-bandwidth graph message propagation algorithm for Graph Neural Networks. This is code accompanying the ICLR 2020 submission **Fast Training of Sparse Graph Neural Networks on Dense Hardware** by Anonymous Authors.

Licensed under the Apache License, Version 2.0.

## The model

This section shows the code that is executed on the dense hardware (TPU). We start by showing the code for the crucial inner loop that performs one step of graph message passing. The function, called `build_one_timestep`, corresponds to Algorithm 1 in the paper. Later in this section we show how the `build_one_timestep` function is invoked within the full model, and the next section shows how the input pipeline prepares the data in the format that the model expects.

The `build_one_timestep` function assumes that it receives its adjacency input as an `EinsumAdjacencyInformation` namedtuple containing Tensors of the following shapes (N is the number of nodes, S the block size, and P the number of edge types):

In [0]:
import collections

EinsumAdjacencyInformation = collections.namedtuple(
    "EinsumAdjacencyInformation",
    ["main_diagonal", "superior_diagonal", "inferior_diagonal"]
)
# Expected shapes:
# - main_diagonal: [N'/S', S'*P, S] Tensor
# - superior_diagonal: [N/S-1, S'*P, S] Tensor
# - inferior_diagonal: [N/S-1, S'*P, S] Tensor

This is the `build_one_timestep` function corresponding to Algorithm 1 in the paper:

In [0]:
def build_one_timestep(
    self, node_states, weights, adjacency_information, biases):
  """Returns new node states after a single message passing timestep.

  Args:
    node_states: [N, H] Tensor of previous node states.
    weights: Dictionary of weights, as returned by `build_weights`.
    adjacency_information: EinsumAdjacencyInformation namedtuple containing
        adjacency information required for one graph message propagation step.
    biases: [N, H] Tensor of precomputed biases to add for each node.
  Returns:
    new_node_states: [N, H] Tensor of new node states.
  """

  # Extract all relevant parameters from self.params.
  num_edge_types = self.params["data"]["num_edge_types"]
  num_nodes = self.params["data"]["num_nodes"]
  block_size = self.params["model"]["block_size"]
  assert num_nodes % block_size == 0
  num_blocks = num_nodes // block_size
  hidden_size = self.params["model"]["hidden_size"]

  # Reshape node states in preparation for the message passing einsums.
  node_states_for_diagonal = tf.reshape(
      node_states,
      [num_blocks, block_size, hidden_size]
  )  # [N'/S', S, H]

  # Message passing einsums.
  messages_from_main_diagonal = tf.einsum(
      "kzv,kvh->kzh",
      adjacency_information.main_diagonal,  # [N'/S, S'*P, S]
      node_states_for_diagonal  # [N'/S, S, H]
  )  # [N'/S, S'*P, H]
  messages_from_superior_diagonal = tf.einsum(
      "kzv,kvh->kzh",
      adjacency_information.superior_diagonal,  # [N/S-1, S'*P, S]
      node_states_for_diagonal[1:, :, :]  # [N/S-1, S, H]
  )  # [N'/S-1, S'*P, H]
  messages_from_inferior_diagonal = tf.einsum(
      "kzv,kvh->kzh",
      adjacency_information.inferior_diagonal,  # [N/S-1, S'*P, S]
      node_states_for_diagonal[:-1, :, :]  # [N/S-1, S, H]
  )  # [N'/S-1, S'*P, H]

  # Concat incoming messages with zeros to align the shapes before summing.
  zeros = tf.zeros(
      [1, block_size * num_edge_types, hidden_size],
      dtype=tf.float32
  )  # [1, S'*P, H]
  messages_from_superior_diagonal_aligned = tf.reshape(
      tf.concat([
          messages_from_superior_diagonal,
          zeros
      ], axis=0),  # [N'/S', S'*P, H]
      [num_blocks, block_size * num_edge_types, hidden_size]
  )  # [N'/S', S'*P, H]
  messages_from_inferior_diagonal_aligned = tf.reshape(
      tf.concat([
          zeros,
          messages_from_inferior_diagonal
      ], axis=0),  # [N'/S', S'*P, H]
      [num_blocks, block_size * num_edge_types, hidden_size]
  )  # [N'/S', S'*P, H]

  # Aggregate the incoming messages by summing pointwise.
  messages = (
      messages_from_main_diagonal
      + messages_from_superior_diagonal_aligned
      + messages_from_inferior_diagonal_aligned
  )  # [N'/S, S'*P, H]

  # Reshape into [N', P*H] for the linear layer below.
  messages = tf.reshape(
      messages,
      [num_nodes, num_edge_types * hidden_size],
      name="messages_reshaped"
  )  # [N', P*H]

  # Pass incoming messages through a linear layer.
  messages_passed = tf.einsum(
      "ny,yh->nh",  # y=ph
      messages,  # [N', P*H]
      weights["edge_weights"],  # [P*H, H]
      name="messages_passed"
  )  # [N', H]

  # Add the edge biases.
  messages_passed += biases

  # Update the node states using the recurrent cell. The cell returns
  # (output, new_state), and index 1 selects the new state.
  new_node_states = weights["rnn_cell"](
      messages_passed,
      node_states
  )[1]  # [N, H]
  return new_node_states
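The block-tridiagonal einsums above compute the same messages as a dense product with the full transposed adjacency tensor, provided every edge connects nodes in the same block or in neighbouring blocks. The following sketch checks this on a random graph whose edges all satisfy |source - target| < S. NumPy stands in for TensorFlow, and `block_diagonals_from_dense` and `block_messages` are hypothetical helpers written only for this illustration.

```python
import numpy as np


def block_diagonals_from_dense(adj_t, block_size):
    """Slices a transposed adjacency tensor into three block diagonals.

    adj_t: [P, N, N] array with adj_t[p, target, source] = 1 for each edge.
    Returns (main [N/S, S'*P, S], superior [N/S-1, S'*P, S],
    inferior [N/S-1, S'*P, S]), laid out like EinsumAdjacencyInformation.
    """
    num_types, num_nodes, _ = adj_t.shape
    s, nb = block_size, num_nodes // block_size
    # [P, target block, target offset, source block, source offset]
    blocks = adj_t.reshape(num_types, nb, s, nb, s)

    def block(target_block, source_block):
        # [P, S', S] -> [S', P, S] -> [S'*P, S]
        b = blocks[:, target_block, :, source_block, :]
        return b.transpose(1, 0, 2).reshape(s * num_types, s)

    main = np.stack([block(k, k) for k in range(nb)])
    superior = np.stack([block(k, k + 1) for k in range(nb - 1)])  # source k+1
    inferior = np.stack([block(k + 1, k) for k in range(nb - 1)])  # target k+1
    return main, superior, inferior


def block_messages(node_states, main, superior, inferior, block_size):
    """NumPy version of the message computation in build_one_timestep."""
    num_nodes, hidden = node_states.shape
    num_types = main.shape[1] // block_size
    x = node_states.reshape(num_nodes // block_size, block_size, hidden)
    from_main = np.einsum("kzv,kvh->kzh", main, x)
    from_superior = np.einsum("kzv,kvh->kzh", superior, x[1:])
    from_inferior = np.einsum("kzv,kvh->kzh", inferior, x[:-1])
    zeros = np.zeros((1, block_size * num_types, hidden))
    messages = (from_main
                + np.concatenate([from_superior, zeros], axis=0)
                + np.concatenate([zeros, from_inferior], axis=0))
    return messages.reshape(num_nodes, num_types * hidden)  # [N, P*H]


rng = np.random.default_rng(0)
num_types, num_nodes, block_size, hidden = 3, 12, 4, 5
adj_t = np.zeros((num_types, num_nodes, num_nodes))
for _ in range(40):
    target = rng.integers(num_nodes)
    # Offsets in [-(S-1), S-1] keep every edge within neighbouring blocks.
    source = target + rng.integers(-(block_size - 1), block_size)
    if 0 <= source < num_nodes:
        adj_t[rng.integers(num_types), target, source] = 1.0
x = rng.normal(size=(num_nodes, hidden))

blockwise = block_messages(x, *block_diagonals_from_dense(adj_t, block_size),
                           block_size)
dense = np.einsum("pnm,mh->nph", adj_t, x).reshape(num_nodes, num_types * hidden)
print(np.allclose(blockwise, dense))  # True
```

An edge spanning more than one block boundary would fall outside the three block diagonals and make the two results differ, which is presumably why the input pipeline is given a `max_bandwidth` parameter.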

The `build_one_timestep` function referred to a dictionary of `weights`, which contains all the trainable parameters of the model. They are built in the following `build_weights` function:

In [0]:
def build_weights(self, hidden_size, num_edge_types, dropout_keep_prob):
  """Builds and returns weights used by the GGNN propagation model.

  Returns:
    A dictionary of weights mapping from weight names to lists of Tensors.
  """

  weights = {}

  # Build edge weights.
  weights["edge_weights"] = tf.get_variable(
        "gnn_edge_weights",
        shape=[num_edge_types * hidden_size, hidden_size],
        initializer=tf.contrib.layers.xavier_initializer()
  )  # [P*H, H]

  # Build edge biases.
  weights["edge_biases"] = tf.get_variable(
        "gnn_edge_biases", shape=[num_edge_types, hidden_size]
  )  # [P, H]

  # Build RNN cell.
  cell = tf.nn.rnn_cell.GRUCell(hidden_size, activation=activation_fun)
  if self.mode == tf.estimator.ModeKeys.TRAIN:
    cell = tf.nn.rnn_cell.DropoutWrapper(
        cell, state_keep_prob=dropout_keep_prob)
  weights["rnn_cells"] = cell

  # Build dense output layer weights.
  weights["output_weights"] = tf.get_variable(
      "output_weights",
      shape=[hidden_size, 1],
      initializer=tf.contrib.layers.xavier_initializer())

  return weights
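Because `edge_weights` has shape [P*H, H], the single einsum `"ny,yh->nh"` in `build_one_timestep` is equivalent to applying a separate H x H matrix to the messages of each edge type and summing the results. The sketch below checks this equivalence with random values; NumPy stands in for TensorFlow and `fused_and_per_type` is a hypothetical helper.

```python
import numpy as np


def fused_and_per_type(messages, edge_weights, num_types, hidden):
    """Returns the fused [P*H, H] product and the per-type sum for comparison."""
    num_nodes = messages.shape[0]
    # One matrix multiplication over the concatenated per-type messages.
    fused = np.einsum("ny,yh->nh", messages, edge_weights)  # [N, H]
    # The same computation split by edge type: row y = p*H + h_in.
    per_type_messages = messages.reshape(num_nodes, num_types, hidden)  # [N, P, H]
    per_type_weights = edge_weights.reshape(num_types, hidden, hidden)  # [P, H, H]
    per_type = sum(per_type_messages[:, p] @ per_type_weights[p]
                   for p in range(num_types))  # [N, H]
    return fused, per_type


rng = np.random.default_rng(1)
num_nodes, num_types, hidden = 6, 3, 4
messages = rng.normal(size=(num_nodes, num_types * hidden))  # [N, P*H]
edge_weights = rng.normal(size=(num_types * hidden, hidden))  # [P*H, H]
fused, per_type = fused_and_per_type(messages, edge_weights, num_types, hidden)
print(np.allclose(fused, per_type))  # True
```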

The following function loops over the $T$ graph message propagation steps and thus produces the final node embeddings:

In [0]:
def compute_final_node_representations(
    self, weights, initial_node_representation, adjacency_information,
    num_incoming_edges_per_type):
  """Builds graph message passing computation.

  Args:
    weights: Dictionary of weights, as constructed by `build_weights`.
    initial_node_representation: Tensor of initial node states from the input
        pipeline, in a shape that can be reshaped to [N, H].
    adjacency_information: A namedtuple with the adjacency information
        required to compute incoming messages using `build_one_timestep`.
    num_incoming_edges_per_type: [N, P] Tensor of per-type incoming edge
        counts.
  Returns:
    [N, H] Tensor with node embeddings after graph propagation.
  """

  # Edge biases don't depend on the timestep, and so can be precomputed.
  biases = tf.einsum(
      "np,ph->nh",
      num_incoming_edges_per_type,  # [N', P]
      weights["edge_biases"],  # [P, H]
      name="biases"
  )  # [N', H]

  # Perform the specified number T of propagation steps.
  num_timesteps = self.params["model"]["num_timesteps"]
  node_states = initial_node_representation
  for _ in range(num_timesteps):
    node_states = self.build_one_timestep(
        node_states, weights, adjacency_information, biases)

  return node_states
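The einsum `"np,ph->nh"` gives each node the bias vector `edge_biases[p]` once for every incoming edge of type p, so precomputing it from the per-type in-edge counts is equivalent to adding a per-type bias for each edge and summing at the target node. The sketch below checks this against an explicit per-edge loop. NumPy stands in for TensorFlow, and the (type, source, target) edge-list format and both helper names are hypothetical.

```python
import numpy as np


def biases_per_edge(edges, num_nodes, edge_biases):
    """Adds edge_biases[type] at the target node once for every edge."""
    biases = np.zeros((num_nodes, edge_biases.shape[1]))
    for edge_type, _source, target in edges:
        biases[target] += edge_biases[edge_type]
    return biases  # [N, H]


def biases_from_counts(edges, num_nodes, edge_biases):
    """Precomputes the same biases from per-type incoming edge counts."""
    counts = np.zeros((num_nodes, edge_biases.shape[0]))  # [N, P]
    for edge_type, _source, target in edges:
        counts[target, edge_type] += 1.0
    return np.einsum("np,ph->nh", counts, edge_biases)  # [N, H]


edge_biases = np.arange(6.0).reshape(3, 2)  # [P=3, H=2]
edges = [(0, 1, 2), (2, 0, 2), (2, 3, 2), (1, 2, 0)]  # (type, source, target)
print(biases_from_counts(edges, 4, edge_biases).tolist())
# [[2.0, 3.0], [0.0, 0.0], [8.0, 11.0], [0.0, 0.0]]
```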

This concludes our implementation of a GGNN encoder with low-bandwidth graph message passing. The output of the GGNN encoder is then fed into a readout layer that computes the predictions and the loss:

In [0]:
def build_model(self, features, labels):
  """Builds the part of the model that is to be run on a TPU."""

  num_nodes = params["data"]["num_nodes"]
  hidden_size = params["model"]["hidden_size"]

  weights = self.build_weights()
  adjacency_information, num_incoming_edges_per_type = (
      self.prepare_adjacency_information(features))
  final_node_representations = self.compute_final_node_representations(
      weights,
      tf.reshape(features["initial_node_states"], [num_nodes, hidden_size]),
      adjacency_information,
      num_incoming_edges_per_type
  )

  # Map from final node representations to predicted labels.
  node_logits = tf.squeeze(
      tf.einsum(
          "nh,hi->ni",
          final_node_representations,
          weights["output_weights"]),
      axis=-1,
      name="node_logits"
  )  # [N]
  loss = self.build_loss(
      node_logits,
      features["candidates"],
      labels
  )
  return loss

def build_loss(self,
               node_logits,
               candidates,
               labels):
  """Builds the output layer and returns the resulting loss Op.

  Args:
    node_logits: [N] Tensor of logits (pre-softmax probabilities) that give
        the model's predictions of which candidate is correct.
    candidates: [N] "batch hot" Tensor. A nonzero value indicates whether
        each node is a candidate. The value gives the index of the graph (with
        indices starting at 1) where the node is a candidate.
    labels: [N] binary Tensor indicating whether each node is a correct
        solution. Used as the target in computing the loss.
  Returns:
    A scalar Tensor `loss` containing the value of the loss.
  """

  # Extract relevant parameters.
  max_graphs_per_supergraph = self.params["data"]["max_graphs_per_supergraph"]
  label_smoothing = self.params["training"]["label_smoothing"]

  # Compute binary mask of candidate nodes.
  candidates_one_hot = tf.clip_by_value(candidates, 0.0, 1.0)

  # Apply softmax within each graph independently.
  groups = tf.cast(candidates[0, :], tf.int32)
  num_groups = max_graphs_per_supergraph + 1
  node_logprobs = log_softmax_by_group(
      node_logits,  # [N]
      groups,
      num_groups,
  )
  node_logprobs = tf.expand_dims(node_logprobs, 0)

  # Apply label smoothing.
  labels_smoothed = smooth_by_group(
      labels[0, :], groups, num_groups, label_smoothing)
  # Remove label smoothing from group 0 (non-candidates).
  labels_smoothed = tf.multiply(candidates_one_hot, labels_smoothed)

  losses_per_node = tf.multiply(labels_smoothed, -node_logprobs)  # [1, N]
  loss_normalizer = tf.reduce_sum(labels)

  # Mask out losses of non-candidate nodes.
  losses_masked = tf.multiply(candidates_one_hot, losses_per_node)  # [1, N]

  # Compute average loss across both batch and superbatch.
  loss = tf.reduce_sum(losses_masked) / loss_normalizer
  return loss
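To see what `build_loss` computes, here is a pure-Python sketch of the same per-graph cross-entropy for a toy supergraph with two graphs, written with label smoothing switched off (`label_smoothing = 0`). The function name and the example values are hypothetical. Nodes with candidate value 0 are skipped entirely, which gives the same loss as the masking in `build_loss`, since non-candidates form their own group (group 0) and their losses are multiplied by zero.

```python
import math


def per_graph_loss(node_logits, candidates, labels):
    """Cross-entropy of a softmax taken separately over each graph's candidates.

    candidates[i] is the (1-based) graph index if node i is a candidate, else 0.
    labels[i] is 1 if node i is a correct solution, else 0.
    """
    graphs = {}
    for node, graph in enumerate(candidates):
        if graph > 0:  # non-candidates (group 0) are masked out
            graphs.setdefault(graph, []).append(node)
    total = 0.0
    for nodes in graphs.values():
        # Stable log-sum-exp over this graph's candidates only.
        top = max(node_logits[i] for i in nodes)
        log_normalizer = top + math.log(
            sum(math.exp(node_logits[i] - top) for i in nodes))
        total += sum(log_normalizer - node_logits[i] for i in nodes if labels[i])
    return total / sum(labels)


# Two graphs in one supergraph; node 3 is not a candidate, so its large logit
# does not affect either softmax.
logits = [0.0, 2.0, 0.0, 5.0, 1.0, 1.0]
candidates = [0, 1, 1, 0, 2, 2]
labels = [0, 1, 0, 0, 0, 1]
print(round(per_graph_loss(logits, candidates, labels), 4))  # 0.41
```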

The `build_model` function calls `prepare_adjacency_information`, a helper function that converts the dictionary of `features` passed to the dense hardware (TPU) by the input pipeline (see below) into an `EinsumAdjacencyInformation` namedtuple and a Tensor of per-type incoming edge counts.

In [0]:
def prepare_adjacency_information(self, features):
  """Builds Ops preprocessing features for graph propagation.

  Args:
    features: Dictionary of feature tensors from the input pipeline.
  Returns:
    adjacency_information: EinsumAdjacencyInformation namedtuple required for
        one graph message propagation step.
    num_incoming_edges_per_type: [N, P] Tensor of incoming edge counts per
        node (N) and edge type (P). This Tensor is used in computing the edge
        bias terms in graph message propagation.
  """

  # Extract all relevant parameters from self.params.
  num_edge_types = self.params["data"]["num_edge_types"]
  num_nodes = self.params["data"]["num_nodes"]
  block_size = self.params["model"]["block_size"]
  assert num_nodes % block_size == 0
  num_blocks = num_nodes // block_size
  transfer_dtype = self.params["input_pipeline"]["transfer_dtype"]

  # Extract relevant features.
  main_diagonal = features["main_diagonal"]
  superior_diagonal = features["superior_diagonal"]
  inferior_diagonal = features["inferior_diagonal"]
  num_inedges_per_type = features["num_incoming_edges_per_type"]

  # Unflatten the features received from the input pipeline.
  main_diagonal = tf.reshape(
      main_diagonal,
      [num_blocks, block_size * num_edge_types, block_size]
  )  # [N'/S', S'*P, S]
  superior_diagonal = tf.reshape(
      superior_diagonal,
      [num_blocks - 1, block_size * num_edge_types, block_size]
  )  # [N/S-1, S'*P, S]
  inferior_diagonal = tf.reshape(
      inferior_diagonal,
      [num_blocks - 1, block_size * num_edge_types, block_size]
  )  # [N/S-1, S'*P, S]
  num_inedges_per_type = tf.reshape(
      num_inedges_per_type,
      [num_nodes, num_edge_types]
  )  # [N, P]

  # Cast back to tf.float32 if necessary.
  if transfer_dtype != tf.float32:
    main_diagonal = tf.cast(main_diagonal, tf.float32)
    superior_diagonal = tf.cast(superior_diagonal, tf.float32)
    inferior_diagonal = tf.cast(inferior_diagonal, tf.float32)
    num_inedges_per_type = tf.cast(num_inedges_per_type, tf.float32)

  adjacency_information = EinsumAdjacencyInformation(
      main_diagonal=main_diagonal,
      superior_diagonal=superior_diagonal,
      inferior_diagonal=inferior_diagonal
  )
  return adjacency_information, num_inedges_per_type
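The input pipeline lays out each block diagonal as [N/S, S', P, S] and flattens it for transfer, so reshaping the flat Tensor to [N/S, S'*P, S] places target offset s' and edge type p in row s'*P + p of each block. A small sketch (NumPy in place of TensorFlow, arbitrary sizes, superbatch dimension omitted, hypothetical function name) confirms this correspondence:

```python
import numpy as np


def check_unflatten_layout(num_blocks, block_size, num_edge_types):
    """Checks that row z of each unflattened block is (s', p) = divmod(z, P)."""
    # Stand-in for one supergraph's main_diagonal: [N/S, S', P, S].
    blocks = np.arange(
        num_blocks * block_size * num_edge_types * block_size
    ).reshape(num_blocks, block_size, num_edge_types, block_size)
    flat = blocks.reshape(-1)  # as flattened for transfer to the TPU
    unflattened = flat.reshape(
        num_blocks, block_size * num_edge_types, block_size)  # [N/S, S'*P, S]
    for k in range(num_blocks):
        for z in range(block_size * num_edge_types):
            target_offset, edge_type = divmod(z, num_edge_types)
            if not np.array_equal(unflattened[k, z],
                                  blocks[k, target_offset, edge_type]):
                return False
    return True


print(check_unflatten_layout(num_blocks=3, block_size=2, num_edge_types=4))  # True
```

This is the same s'-major, p-minor ordering that the final reshape to [N', P*H] in `build_one_timestep` relies on.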

The `build_loss` function uses the following two helper functions, which respectively apply log-softmax and label smoothing on a per-graph basis within a supergraph.

In [0]:
def log_softmax_by_group(node_scores, groups, num_groups):
  """Compute per-graph log-softmaxes without any per-graph batching.

  Args:
    node_scores: A [num_nodes] tensor of per-node scores.
    groups: A [num_nodes] tensor of normalizing group identities.
    num_groups: Upper bound on the max value in `groups` + 1.
  Returns:
    A tensor the same shape as node_scores with scores per group normalized so
    that exponentiated results sum to 1 within each group.
  """

  # Subtract per-group maxes for numerical stability.
  group_maxes = tf.unsorted_segment_max(node_scores, groups, num_groups)
  stabilizers = tf.gather(group_maxes, groups)
  stabilized_scores = node_scores - stabilizers

  # Compute per-group normalizing constants after subtracting stabilizers.
  group_sum_exps = tf.unsorted_segment_sum(tf.exp(stabilized_scores),
                                           groups,
                                           num_groups)
  normalizers = tf.gather(group_sum_exps, groups)

  # Normalize scores to get logprobs.
  node_logprobs = stabilized_scores - tf.log(normalizers)
  return node_logprobs

def smooth_by_group(labels, groups, num_groups, smoothing):
  """Applies label smoothing to `labels`, per group.

  Args:
    labels: A binary tensor.
    groups: A non-negative integer tensor of the same shape as `labels`.
    num_groups: An upper bound on the integers in `groups`, plus one.
    smoothing: A float in [0, 1] parametrising the strength of smoothing.
  Returns:
    Tensor of floats in [0, 1] of the same shape as `labels`, obtained by
    applying label smoothing [Szegedy et al., 2016] to each group individually.
  """

  group_sizes_gathered = tf.gather(
      tf.unsorted_segment_sum(tf.ones_like(labels), groups, num_groups),
      groups
  )
  on_value = 1.0 - smoothing
  off_values = tf.divide(smoothing, tf.maximum(1.0, group_sizes_gathered - 1))
  labels_smoothed = tf.add(labels * (on_value - off_values), off_values)
  return labels_smoothed
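For intuition, the two helpers can be mirrored in pure Python, with loops in place of `tf.unsorted_segment_*` and `tf.gather` (the `_py` function names and `group_totals` are hypothetical). The example checks that probabilities within each group sum to 1 even for logits around 1000, where a naive `math.exp` would overflow, and that the smoothed labels of a group with exactly one correct node and at least two members also sum to 1.

```python
import math


def log_softmax_by_group_py(node_scores, groups, num_groups):
    """Per-group log-softmax with per-group max subtraction for stability."""
    group_maxes = [-math.inf] * num_groups
    for score, group in zip(node_scores, groups):
        group_maxes[group] = max(group_maxes[group], score)
    group_sum_exps = [0.0] * num_groups
    for score, group in zip(node_scores, groups):
        group_sum_exps[group] += math.exp(score - group_maxes[group])
    return [score - group_maxes[group] - math.log(group_sum_exps[group])
            for score, group in zip(node_scores, groups)]


def smooth_by_group_py(labels, groups, num_groups, smoothing):
    """Per-group label smoothing: 1 - smoothing on positives, rest spread."""
    group_sizes = [0] * num_groups
    for group in groups:
        group_sizes[group] += 1
    on_value = 1.0 - smoothing
    smoothed = []
    for label, group in zip(labels, groups):
        off_value = smoothing / max(1, group_sizes[group] - 1)
        smoothed.append(label * (on_value - off_value) + off_value)
    return smoothed


def group_totals(values, groups, num_groups):
    """Sums values within each group."""
    totals = [0.0] * num_groups
    for value, group in zip(values, groups):
        totals[group] += value
    return totals


groups = [1, 1, 1, 1, 2, 2]
logprobs = log_softmax_by_group_py(
    [1000.0, 1001.0, 999.0, 1000.5, -3.0, 0.0], groups, num_groups=3)
probs = [math.exp(lp) for lp in logprobs]
print([round(t, 6) for t in group_totals(probs, groups, 3)])  # [0.0, 1.0, 1.0]
smoothed = smooth_by_group_py([1, 0, 0, 0, 0, 1], groups, 3, smoothing=0.3)
print([round(v, 6) for v in smoothed])  # [0.7, 0.1, 0.1, 0.1, 0.3, 0.7]
```

Group 0 is empty in this example; in `build_loss` it holds the non-candidate nodes, whose smoothed labels are zeroed afterwards.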

## The input pipeline

The crucial part of the input pipeline that is specific to our low-bandwidth model is the following `_compute_einsum_representation` function. It takes as input a sparse representation of the non-zero entries of the adjacency matrices (see the docstring in the code below) and computes the densified block diagonals to be passed to the dense hardware (TPU).

In [0]:
def _compute_einsum_representation(params, adjacency_matrix_indices):
  """Transforms adjacency matrix indices into representation for Einsum.
  
  Args:
    params: Dictionary of parameters.
    adjacency_matrix_indices: [B, E, 3] Tensor representing the non-zero indices
      of the adjacency matrices. Here B is the superbatch dimension (number of
      supergraphs) and E is the maximum number of edges in a supergraph. The
      three numbers in the last dimension give the edge type number, source
      index and target index of each edge, respectively.
  Returns:
    A tuple (main_diagonal, superior_diagonal, inferior_diagonal,
    num_incoming_edges_by_type) of dense Tensors to be passed to the dense
    hardware.
  """

  # Extract all relevant parameters from params.
  block_size = params["model"]["block_size"]
  num_edge_types = params["data"]["num_edge_types"]
  num_nodes = params["data"]["num_nodes"]
  max_num_edges = params["data"]["max_num_edges"]
  superbatch_size = params["batch_size"]
  transfer_dtype = params["input_pipeline"]["transfer_dtype"]
  
  # To construct dense adjacency matrices as a Tensor of shape [B, P, N, N'],
  # we first need to transform the [B, E, 3] `adjacency_matrix_indices` into
  # an `indices` Tensor of shape [B*E, 4], with the new first column (in the
  # last dimension) containing the superbatch index.
  indices_with_superbatch = prepend_first_coordinate_to_third_dimension(
      adjacency_matrix_indices, superbatch_size, max_num_edges)  # [B, E, 4]
  indices = tf.reshape(indices_with_superbatch, [-1, 4])  # [B*E, 4]

  # Filter out padding edges, which are marked by a negative edge type.
  indices = tf.boolean_mask(
      indices,
      tf.greater_equal(indices[:, 1], 0)
  )  # [num_edges, 4]

  target_block_indices = tf.floor_div(indices[:, 3], block_size)
  source_block_indices = tf.floor_div(indices[:, 2], block_size)

  # Distribute the edge lists according to whether each edge lies on the
  # main (block) diagonal, the superior (block) diagonal, or the inferior
  # (block) diagonal in the *TRANSPOSED* adjacency matrix.
  main_diagonal_indices = tf.boolean_mask(
      indices,
      tf.equal(target_block_indices, source_block_indices)
  )  # [num_edges_on_main_diagonal, 4]
  superior_diagonal_indices = tf.boolean_mask(
      indices,
      tf.equal(target_block_indices + 1, source_block_indices)
  )  # [num_edges_on_superior_diagonal, 4]
  inferior_diagonal_indices = tf.boolean_mask(
      indices,
      tf.equal(target_block_indices, source_block_indices + 1)
  )  # [num_edges_on_inferior_diagonal, 4]

  # Compute incoming edge counts here, as it's simpler to do before the
  # splitting up into blocks that happens below.
  used_indices = tf.concat([main_diagonal_indices,
                            superior_diagonal_indices,
                            inferior_diagonal_indices], axis=0)

  used_indices_tensor = tf.SparseTensor(
      indices=used_indices,
      values=tf.ones_like(used_indices[:, 0], dtype=tf.float32),
      dense_shape=[superbatch_size, num_edge_types, num_nodes, num_nodes])

  # [B, P, N]
  num_incoming_edges_by_type = tf.sparse_reduce_sum(used_indices_tensor, 2)
  num_incoming_edges_by_type.set_shape([superbatch_size,
                                        num_edge_types,
                                        num_nodes])
  # [B, N, P]
  num_incoming_edges_by_type = tf.transpose(num_incoming_edges_by_type,
                                            [0, 2, 1])

  # Transform the indices such that they refer to the block index and then
  # the indices within that block.
  main_diagonal_indices = tf.stack([
      main_diagonal_indices[:, 0],
      tf.floor_div(main_diagonal_indices[:, 2], block_size),
      tf.mod(main_diagonal_indices[:, 3], block_size),
      main_diagonal_indices[:, 1],
      tf.mod(main_diagonal_indices[:, 2], block_size),
  ], axis=1)  # [num_edges_on_main_diagonal, 5]
  superior_diagonal_indices = tf.stack([
      superior_diagonal_indices[:, 0],
      tf.floor_div(superior_diagonal_indices[:, 3], block_size),
      tf.mod(superior_diagonal_indices[:, 3], block_size),
      superior_diagonal_indices[:, 1],
      tf.mod(superior_diagonal_indices[:, 2], block_size)
  ], axis=1)  # [num_edges_on_superior_diagonal, 5]
  inferior_diagonal_indices = tf.stack([
      inferior_diagonal_indices[:, 0],
      tf.floor_div(inferior_diagonal_indices[:, 2], block_size),
      tf.mod(inferior_diagonal_indices[:, 3], block_size),
      inferior_diagonal_indices[:, 1],
      tf.mod(inferior_diagonal_indices[:, 2], block_size)
  ], axis=1)  # [num_edges_on_inferior_diagonal, 5]

  # Compute dense representations of the three block diagonals.
  main_diagonal = tf.sparse_to_dense(
      main_diagonal_indices,
      [superbatch_size, num_nodes // block_size, block_size, num_edge_types,
        block_size],
      tf.ones_like(main_diagonal_indices[:, 0], dtype=transfer_dtype),
      validate_indices=False,
      name="main_diagonal"
  )  # [B, N/S, S', P, S]
  superior_diagonal = tf.sparse_to_dense(
      superior_diagonal_indices,
      [superbatch_size, num_nodes // block_size - 1, block_size,
        num_edge_types, block_size],
      tf.ones_like(superior_diagonal_indices[:, 0], dtype=transfer_dtype),
      validate_indices=False,
      name="superior_diagonal"
  )  # [B, N/S-1, S', P, S]
  inferior_diagonal = tf.sparse_to_dense(
      inferior_diagonal_indices,
      [superbatch_size, num_nodes // block_size - 1, block_size,
        num_edge_types, block_size],
      tf.ones_like(inferior_diagonal_indices[:, 0], dtype=transfer_dtype),
      validate_indices=False,
      name="inferior_diagonal"
  )  # [B, N/S-1, S', P, S]

  return (main_diagonal,
          superior_diagonal,
          inferior_diagonal,
          num_incoming_edges_by_type)
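The index arithmetic above can be traced by hand for single edges. The pure-Python sketch below (a hypothetical helper, with the superbatch index omitted) assigns an edge given as (edge type, source, target) to one of the three block diagonals and returns its index into the corresponding Tensor, following the `tf.stack` calls in `_compute_einsum_representation`. Here `divmod` matches `tf.floor_div` and `tf.mod` because the indices that remain after padding is filtered out are non-negative.

```python
def block_diagonal_index(edge_type, source, target, block_size):
    """Returns (diagonal name, (block, target offset, type, source offset))."""
    source_block, source_offset = divmod(source, block_size)
    target_block, target_offset = divmod(target, block_size)
    if target_block == source_block:
        # Main diagonal: indexed by the shared block.
        return "main", (source_block, target_offset, edge_type, source_offset)
    if target_block + 1 == source_block:
        # Superior diagonal: indexed by the target block.
        return "superior", (target_block, target_offset, edge_type,
                            source_offset)
    if target_block == source_block + 1:
        # Inferior diagonal: indexed by the source block.
        return "inferior", (source_block, target_offset, edge_type,
                            source_offset)
    # Edges spanning more than one block boundary are not represented.
    return None


for edge in [(2, 5, 6), (0, 9, 6), (1, 3, 5), (0, 0, 9)]:
    print(edge, block_diagonal_index(*edge, block_size=4))
# (2, 5, 6) ('main', (1, 2, 2, 1))
# (0, 9, 6) ('superior', (1, 2, 0, 1))
# (1, 3, 5) ('inferior', (0, 1, 1, 3))
# (0, 0, 9) None
```

Dropped edges are also excluded from `num_incoming_edges_by_type`, since only the three filtered index lists are concatenated into `used_indices`.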

The function above uses the standalone helper function `prepend_first_coordinate_to_third_dimension`:

In [0]:
def prepend_first_coordinate_to_third_dimension(
    tensor, first_dimension_size, second_dimension_size):
  """Prepends 1st coordinate index as a new column in the 3rd dimension.

  Args:
    tensor: A 3D Tensor of shape [A, B, C].
    first_dimension_size: Value of `A` (size of first dimension of `tensor`).
    second_dimension_size: Value of `B` (size of second dimension of `tensor`).
  Returns:
    Tensor of shape [A, B, C+1], where the (a, b, c) entry equals a if c = 0
    and tensor[a, b, c-1] otherwise.
  """
  first_coordinate_range = tf.range(
      first_dimension_size,
      dtype=tensor.dtype
  )  # [A]
  first_coordinate_range_reshaped = tf.reshape(
      first_coordinate_range,
      [first_dimension_size, 1, 1]
  )  # [A, 1, 1]
  first_coordinates = tf.broadcast_to(
      first_coordinate_range_reshaped,
      [first_dimension_size, second_dimension_size, 1]
  )  # [A, B, 1]
  result = tf.concat(
      [first_coordinates, tensor],
      axis=2
  )  # [A, B, C+1]
  return result
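A NumPy equivalent (NumPy in place of TensorFlow; the function name is hypothetical) makes the effect on an edge Tensor easy to inspect. The example fills its padding edge with -1, which is an assumption: `_compute_einsum_representation` only requires that padding edges have a negative edge type, which after prepending sits in column 1.

```python
import numpy as np


def prepend_first_coordinate_np(tensor):
    """[A, B, C] -> [A, B, C+1], with entry (a, b, 0) equal to a."""
    a, b, _ = tensor.shape
    first_coordinates = np.broadcast_to(
        np.arange(a, dtype=tensor.dtype).reshape(a, 1, 1), (a, b, 1))  # [A, B, 1]
    return np.concatenate([first_coordinates, tensor], axis=2)


# Two supergraphs (A = 2), up to two edges each (B = 2), edges given as
# (type, source, target); the last row is padding.
edges = np.array([[[0, 1, 2], [1, 2, 3]],
                  [[0, 4, 5], [-1, -1, -1]]])
print(prepend_first_coordinate_np(edges).tolist())
# [[[0, 0, 1, 2], [0, 1, 2, 3]], [[1, 0, 4, 5], [1, -1, -1, -1]]]
```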

To show how `_compute_einsum_representation` is invoked, we assume access to a function `_load_tpu_batch_data` that implements the sparse batching logic and returns a padded supergraph. The following `build_input_pipeline` function assembles the `features` dictionary passed over to the dense hardware (TPU).

In [0]:
def build_input_pipeline(params, split):
  """Builds the input pipeline as a TensorFlow computation.

  This function can be called in the `input_fn` of a (TPU)Estimator, and the
  returned `features` and `labels` (see below) are suitable for being in turn
  returned by the `input_fn` itself. The (TPU)Estimator then ensures that these
  values are passed to the `model_fn` of the (TPU)Estimator.

  Args:
    params: Dictionary of parameters from the (TPU)Estimator API's `input_fn`.
        Contains the following keys:
        - "batch_size", mapping to the size of the superbatch that this input
          pipeline is expected to produce. This is a keyword reserved by
          TPUEstimator, and it is populated automatically.
        - The four keys "data", "model", "training", and "input_pipeline", all
          of which map to dictionaries themselves, with each dictionary mapping
          from parameter names to values.
        - Optionally "context", which, when present, is used to determine the
          index of this input pipeline invocation.
    split: A DataSplit enum value.
  Returns:
    `features` and `labels`, the two values to be returned by the `input_fn`.
  """

  # Retrieve relevant parameters from params.
  problem = params["data"]["problem_name"]
  data_dir = params["data"]["data_dir"]
  max_bandwidth = params["data"]["max_bandwidth"]
  num_nodes = params["data"]["num_nodes"]
  max_num_edges = params["data"]["max_num_edges"]
  seed = params["input_pipeline"]["seed"]
  data_loading_config = params["input_pipeline"]["data_loading"]

  # The `batch_size` entry in `params` is auto-populated by TPUEstimator,
  # which ensures that the value is per-host or per-core, as appropriate. See
  # https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimator
  superbatch_size = params["batch_size"]

  # Retrieve information about this input pipeline invocation.
  if "context" in params:
    _, invocation_index, total_invocations, _ = (
        params["context"].current_input_fn_deployment())
  else:
    invocation_index = 0
    total_invocations = 1

  # Retrieve padded TPU data.
  holes, candidates, candidate_indices, solutions, edges = (
      _load_tpu_batch_data(
          problem, data_dir, split, num_nodes, max_bandwidth,
          max_num_edges, superbatch_size, invocation_index, total_invocations
      )
  )

  # Ensure static shapes of Tensors constructed from TensorArrays.
  holes = tf.reshape(
      holes, [superbatch_size, num_nodes], name="holes")
  candidates = tf.reshape(
      candidates, [superbatch_size, num_nodes], name="candidates")
  adjacency_matrix_indices = tf.reshape(
      edges, [superbatch_size, max_num_edges, 3],
      name="adjacency_indices")
  labels = tf.reshape(
      solutions, [superbatch_size, num_nodes], name="labels")

  # Concatenate node features and pad to hidden size for initial node states.
  candidates_one_hot = tf.clip_by_value(candidates, 0.0, 1.0)
  holes_one_hot = tf.clip_by_value(holes, 0.0, 1.0)
  node_annotations_list = [
      tf.expand_dims(holes_one_hot, axis=2),  # [superbatch_size, N, 1]
      tf.expand_dims(candidates_one_hot, axis=2),  # [superbatch_size, N, 1]
  ]
  node_annotations = tf.concat(node_annotations_list, axis=2)
  annotation_size = node_annotations.shape[2]
  assert hidden_size >= annotation_size
  initial_node_states = tf.pad(
      node_annotations,  # [superbatch_size, N, annotation_size]
      [[0, 0], [0, 0], [0, hidden_size - annotation_size]],
      name="initial_node_states"
  )  # [superbatch_size, N, H]

  # Convert adjacency information to format suitable for the einsum.
  (main_diagonal,
    superior_diagonal,
    inferior_diagonal,
    num_incoming_edges_per_type) = _compute_einsum_representation(
        params, adjacency_matrix_indices)

  # Explicitly flatten the large Tensors being sent to TPU.
  # If superbatch_size > 1, this dimension should be preserved for sharding.
  target_shape = [-1] if superbatch_size == 1 else [superbatch_size, -1]
  initial_node_states = tf.reshape(
      initial_node_states, target_shape)
  main_diagonal = tf.reshape(main_diagonal, target_shape)
  superior_diagonal = tf.reshape(superior_diagonal, target_shape)
  inferior_diagonal = tf.reshape(inferior_diagonal, target_shape)

  # Return `features` and `labels`.
  features = {
      "initial_node_states": initial_node_states,
      "main_diagonal": main_diagonal,
      "superior_diagonal": superior_diagonal,
      "inferior_diagonal": inferior_diagonal,
      "num_incoming_edges_per_type": num_incoming_edges_per_type,
      "candidates": candidates,
  }
  return features, labels
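The construction of `initial_node_states` can be reproduced on a toy example. In the NumPy sketch below (NumPy in place of TensorFlow; the function name and values are hypothetical), `candidates` carries the 1-based graph index of each candidate, so clipping to [0, 1] turns it into a binary indicator before the two annotation channels are zero-padded up to the hidden size H.

```python
import numpy as np


def initial_node_states_np(holes, candidates, hidden_size):
    """[B, N] hole and candidate arrays -> [B, N, H] initial node states."""
    holes_one_hot = np.clip(holes, 0.0, 1.0)
    candidates_one_hot = np.clip(candidates, 0.0, 1.0)
    # Stacking on a new last axis matches expand_dims + concat on axis 2.
    node_annotations = np.stack([holes_one_hot, candidates_one_hot], axis=2)
    annotation_size = node_annotations.shape[2]
    assert hidden_size >= annotation_size
    # Zero-pad the annotation channels up to the hidden size.
    return np.pad(node_annotations,
                  [(0, 0), (0, 0), (0, hidden_size - annotation_size)])


holes = np.array([[1.0, 0.0, 0.0]])       # [B=1, N=3]
candidates = np.array([[0.0, 1.0, 2.0]])  # nodes 1 and 2 are candidates
print(initial_node_states_np(holes, candidates, hidden_size=4).tolist())
# [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]
```

On the model side, `build_model` reshapes the flattened Tensor back to [N, H].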