In [19]:
from typing import *
import tensorflow as tf
import functools
import jax.numpy as jnp
import jax
import jraph
import numpy as np
import os
import chex

import sys
sys.path.append('..')
import datatypes
import dynamic_batcher

In [2]:
class GraphsTupleSize(NamedTuple):
  """Node, edge and graph counts of a GraphsTuple.

  Used both to report the size of an existing GraphsTuple (see
  `get_graphs_tuple_size`) and as the padding budget handed to
  `jraph.pad_with_graphs` and `jraph.dynamically_batch`.
  """
  # Number of nodes (summed over all graphs when describing a GraphsTuple).
  n_node: int
  # Number of edges (summed over all graphs when describing a GraphsTuple).
  n_edge: int
  # Number of graphs.
  n_graph: int


def get_graphs_tuple_size(graph: jraph.GraphsTuple):
  """Computes the total node, edge and graph counts of a GraphsTuple.

  Args:
    graph: A GraphsTuple whose `n_node` and `n_edge` fields hold the per-graph
      node and edge counts.

  Returns:
    A GraphsTupleSize holding the sum of `n_node`, the sum of `n_edge`, and the
    length of the leading dimension of `n_node` as the number of graphs.
  """
  nodes_per_graph = graph.n_node
  total_nodes = np.sum(nodes_per_graph)
  total_edges = np.sum(graph.n_edge)
  # One entry of `n_node` per graph, so its leading dimension is the graph
  # count.
  num_graphs = np.shape(nodes_per_graph)[0]
  return GraphsTupleSize(
      n_node=total_nodes, n_edge=total_edges, n_graph=num_graphs)


def specs_from_graphs_tuple(graph: jraph.GraphsTuple):
  """Builds a GraphsTuple of tf.TensorSpecs mirroring the given graph.

  Args:
    graph: A GraphsTuple whose fields are all arrays exposing `shape` and
      `dtype`.

  Returns:
    A GraphsTuple in which every field is replaced by a tf.TensorSpec with
    that field's shape and dtype.
  """
  field_names = (
      'nodes', 'edges', 'senders', 'receivers', 'globals', 'n_node', 'n_edge')

  def spec_of(array: np.ndarray):
    return tf.TensorSpec(shape=list(array.shape), dtype=array.dtype)

  return jraph.GraphsTuple(
      **{name: spec_of(getattr(graph, name)) for name in field_names})


def get_dummy_raw_datasets(dataset_length) -> Dict[str, tf.data.Dataset]:
  """Creates dummy raw datasets, mocking tfds.DatasetBuilder.as_dataset().

  Every split ('train', 'validation', 'test') yields the same fixed graph
  dictionary `dataset_length` times.

  Args:
    dataset_length: Number of dummy graphs each split yields.

  Returns:
    A dict mapping split name to a tf.data.Dataset of graph dictionaries.
  """
  num_nodes, num_edges = 3, 4

  # Signature of the yielded dictionaries; node and edge counts are left
  # unspecified (None) in the leading dimension.
  dummy_graph_spec = {
      'edge_feat': tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
      'edge_index': tf.TensorSpec(shape=(None, 2), dtype=tf.int64),
      'labels': tf.TensorSpec(shape=(128,), dtype=tf.float32),
      'node_feat': tf.TensorSpec(shape=(None, 9), dtype=tf.float32),
      'num_edges': tf.TensorSpec(shape=(None,), dtype=tf.int64),
      'num_nodes': tf.TensorSpec(shape=(None,), dtype=tf.int64),
  }

  # The single dummy graph: zero features, zero edge indices, all-one labels.
  dummy_graph = {
      'edge_feat': tf.zeros((num_edges, 3), dtype=tf.float32),
      'edge_index': tf.zeros((num_edges, 2), dtype=tf.int64),
      'labels': tf.ones((128,), dtype=tf.float32),
      'node_feat': tf.zeros((num_nodes, 9), dtype=tf.float32),
      'num_edges': tf.expand_dims(num_edges, axis=0),
      'num_nodes': tf.expand_dims(num_nodes, axis=0),
  }

  def get_dummy_graphs():
    yield from (dummy_graph for _ in range(dataset_length))

  return {
      split: tf.data.Dataset.from_generator(
          get_dummy_graphs, output_signature=dummy_graph_spec)
      for split in ('train', 'validation', 'test')
  }



def get_dummy_datasets(
    dataset_length: int,
    batch_size: Optional[int] = None) -> Dict[str, tf.data.Dataset]:
  """Returns dummy datasets, mocking input_pipeline.get_datasets().

  Args:
    dataset_length: Number of dummy graphs in each raw split.
    batch_size: If None, the splits are returned unbatched. Otherwise a
      padding budget is estimated for this batch size from the 'train' split
      and every split is dynamically batched and padded to that budget.

  Returns:
    A dict mapping split name to a tf.data.Dataset of GraphsTuples.

  Raises:
    ValueError: Propagated from `estimate_padding_budget_for_batch_size` when
      `batch_size` is not greater than 1.
  """
  # Convert every raw split to GraphsTuples.
  datasets = {
      split: raw_split.map(
          convert_to_graphs_tuple,
          num_parallel_calls=tf.data.AUTOTUNE,
          deterministic=True)
      for split, raw_split in get_dummy_raw_datasets(dataset_length).items()
  }

  # If batch size is None, do not batch.
  if batch_size is None:
    return datasets

  budget = estimate_padding_budget_for_batch_size(
      datasets['train'], batch_size, num_estimation_graphs=1)

  # Pad an example graph to see what the output shapes will be.
  # We will use this shape information when creating the tf.data.Dataset.
  example_graph = next(datasets['train'].as_numpy_iterator())
  example_padded_graph = jraph.pad_with_graphs(example_graph, *budget)
  padded_graphs_spec = specs_from_graphs_tuple(example_padded_graph)

  def make_batch_generator(dataset_split):
    """Returns a callable that batches `dataset_split` from the start."""

    def generate_batches():
      # Create the iterator here, not once up front: from_generator calls this
      # function every time the dataset is iterated, and a shared iterator
      # would be exhausted after the first pass, leaving later passes empty.
      return jraph.dynamically_batch(
          graphs_tuple_iterator=iter(dataset_split),
          n_node=budget.n_node,
          n_edge=budget.n_edge,
          n_graph=budget.n_graph)

    return generate_batches

  # Batch and pad each split separately.
  return {
      split: tf.data.Dataset.from_generator(
          make_batch_generator(dataset_split),
          output_signature=padded_graphs_spec)
      for split, dataset_split in datasets.items()
  }


def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor]) -> jraph.GraphsTuple:
  """Converts a dictionary of tf.Tensors to a GraphsTuple.

  Args:
    graph: A dict with keys 'num_nodes', 'num_edges', 'node_feat',
      'edge_feat', 'labels' and 'edge_index'. Column 0 of 'edge_index' is used
      as the senders and column 1 as the receivers.

  Returns:
    A GraphsTuple holding a single graph, with the labels stored as globals
    under a new leading axis.
  """
  # Squeeze to scalars so they can be used as reshape dimensions below; they
  # are re-expanded to shape (1,) for the n_node / n_edge fields.
  num_nodes = tf.squeeze(graph['num_nodes'])
  num_edges = tf.squeeze(graph['num_edges'])
  edge_index = graph['edge_index']

  return jraph.GraphsTuple(
      n_node=tf.expand_dims(num_nodes, 0),
      n_edge=tf.expand_dims(num_edges, 0),
      nodes=tf.reshape(graph['node_feat'], (num_nodes, -1)),
      edges=tf.reshape(graph['edge_feat'], (num_edges, -1)),
      senders=edge_index[:, 0],
      receivers=edge_index[:, 1],
      globals=tf.expand_dims(graph['labels'], axis=0),
  )


def estimate_padding_budget_for_batch_size(
    dataset: tf.data.Dataset,
    batch_size: int,
    num_estimation_graphs: int) -> GraphsTupleSize:
  """Estimates the padding budget for a dataset of unbatched GraphsTuples.

  The average number of nodes and edges per graph is measured on the graphs
  actually read from the dataset, scaled by `batch_size`, and rounded up to
  the next multiple of 64.

  Args:
    dataset: A dataset of unbatched GraphsTuples.
    batch_size: The intended batch size. Note that no batching is performed by
      this function.
    num_estimation_graphs: How many graphs to take from the dataset to estimate
      the distribution of number of nodes and edges per graph. If the dataset
      holds fewer graphs, the estimate uses only those that exist.

  Returns:
    padding_budget: The padding budget for batching and padding the graphs
    in this dataset to the given batch size.

  Raises:
    ValueError: If `batch_size` is not greater than 1, if the dataset contains
      batched GraphsTuples, or if no graphs could be read for the estimate.
  """

  def next_multiple_of_64(val: float):
    """Returns the smallest multiple of 64 strictly greater than val."""
    return 64 * (1 + int(val // 64))

  if batch_size <= 1:
    raise ValueError('Batch size must be > 1 to account for padding graphs.')

  total_num_nodes = 0
  total_num_edges = 0
  num_graphs_seen = 0
  for graph in dataset.take(num_estimation_graphs).as_numpy_iterator():
    graph_size = get_graphs_tuple_size(graph)
    if graph_size.n_graph != 1:
      raise ValueError('Dataset contains batched GraphTuples.')

    total_num_nodes += graph_size.n_node
    total_num_edges += graph_size.n_edge
    num_graphs_seen += 1

  if num_graphs_seen == 0:
    raise ValueError(
        'Cannot estimate a padding budget: no graphs were read from the '
        'dataset.')

  # Average over the graphs actually read. Dividing by the requested
  # num_estimation_graphs would underestimate the budget whenever the dataset
  # is shorter than that.
  num_nodes_per_graph_estimate = total_num_nodes / num_graphs_seen
  num_edges_per_graph_estimate = total_num_edges / num_graphs_seen

  padding_budget = GraphsTupleSize(
      n_node=next_multiple_of_64(num_nodes_per_graph_estimate * batch_size),
      n_edge=next_multiple_of_64(num_edges_per_graph_estimate * batch_size),
      n_graph=batch_size)
  return padding_budget


# Smoke test: build the dummy datasets with dynamic batching enabled
# (batch_size=100) and print at most 100 padded batches of the 'train' split.
datasets = get_dummy_datasets(dataset_length=500, batch_size=100)
for graphs in datasets['train'].take(100):
  print(graphs)


2023-02-23 14:34:29.585892: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


GraphsTuple(nodes=<tf.Tensor: shape=(320, 9), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, edges=<tf.Tensor: shape=(448, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       ...,
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>, receivers=<tf.Tensor: shape=(448,), dtype=int64, numpy=
array([  0,   0,   0,   0,   3,   3,   3,   3,   6,   6,   6,   6,   9,
         9,   9,   9,  12,  12,  12,  12,  15,  15,  15,  15,  18,  18,
        18,  18,  21,  21,  21,  21,  24,  24,  24,  24,  27,  27,  27,
        27,  30,  30,  30,  30,  33,  33,  33,  33,  36,  36,  36,  36,
        39,  39,  39,  39,  42,  42,  42,  42,  45,  45,  45,  45,  48,
        48,  48,  48,  51,  51,  51,  51,  54,

In [18]:

def get_raw_qm9_datasets(
    rng: chex.PRNGKey,
    root_dir: str = '/Users/ameyad/Documents/qm9_data_tf/',
) -> Dict[str, tf.data.Dataset]:
  """Loads the raw QM9 dataset as tf.data.Datasets for each split.

  Entries of `root_dir` whose name contains 'dataset_tf' are shuffled with
  `rng`; the first shuffled entry becomes the 'train' split, the second 'val'
  and the third 'test'. If fewer than three entries exist, the later splits
  are empty.

  Args:
    rng: PRNG key used to shuffle the dataset files before partitioning.
    root_dir: Directory holding the saved datasets. Each matching entry must
      be loadable with `tf.data.Dataset.load`.

  Returns:
    A dict mapping 'train', 'val' and 'test' to a tf.data.Dataset that
    interleaves the elements of that split's files.

  Raises:
    ValueError: If `root_dir` contains no entry with 'dataset_tf' in its name.
  """
  # os.listdir returns entries in arbitrary order. Sort them so that the same
  # rng always produces the same assignment of files to splits.
  filenames = sorted(
      os.path.join(root_dir, f)
      for f in os.listdir(root_dir)
      if 'dataset_tf' in f)
  if not filenames:
    raise ValueError(
        f'No entries containing "dataset_tf" found in {root_dir!r}.')

  # Shuffle the filenames.
  shuffled_indices = jax.random.permutation(rng, len(filenames))
  shuffled_filenames = [
      filenames[i] for i in np.asarray(shuffled_indices).tolist()
  ]

  # Partition the filenames into train, val, and test.
  num_train_files, num_val_files, num_test_files = 1, 1, 1
  train_end, val_end, test_end = np.cumsum(
      [num_train_files, num_val_files, num_test_files])
  files_by_split = {
      "train": shuffled_filenames[:train_end],
      "val": shuffled_filenames[train_end:val_end],
      "test": shuffled_filenames[val_end:test_end],
  }

  # All files are loaded with the element spec of the first one.
  # NOTE(review): this assumes every file shares that spec - confirm.
  element_spec = tf.data.Dataset.load(filenames[0]).element_spec

  datasets = {}
  for split, files_split in files_by_split.items():
    dataset_split = tf.data.Dataset.from_tensor_slices(files_split)
    dataset_split = dataset_split.interleave(
        lambda x: tf.data.Dataset.load(x, element_spec=element_spec),
        cycle_length=4,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)
    datasets[split] = dataset_split
  return datasets



def specs_from_graphs_tuple(graph: jraph.GraphsTuple):
  """Builds a GraphsTuple of tf.TensorSpecs mirroring a fragment graph.

  Args:
    graph: A GraphsTuple whose `nodes` expose `positions`, `species` and
      `focus_probability`, and whose `globals` expose `stop`,
      `target_positions`, `target_species` and `target_species_probability`.

  Returns:
    A GraphsTuple with the same nesting (datatypes.FragmentNodes for nodes,
    datatypes.FragmentGlobals for globals) in which every array is replaced by
    a tf.TensorSpec with that array's shape and dtype.
  """

  def spec_of(array: np.ndarray):
    return tf.TensorSpec(shape=list(array.shape), dtype=array.dtype)

  node_specs = datatypes.FragmentNodes(
      positions=spec_of(graph.nodes.positions),
      species=spec_of(graph.nodes.species),
      focus_probability=spec_of(graph.nodes.focus_probability),
  )
  global_specs = datatypes.FragmentGlobals(
      stop=spec_of(graph.globals.stop),
      target_positions=spec_of(graph.globals.target_positions),
      target_species=spec_of(graph.globals.target_species),
      target_species_probability=spec_of(
          graph.globals.target_species_probability),
  )
  return jraph.GraphsTuple(
      nodes=node_specs,
      globals=global_specs,
      edges=spec_of(graph.edges),
      receivers=spec_of(graph.receivers),
      senders=spec_of(graph.senders),
      n_node=spec_of(graph.n_node),
      n_edge=spec_of(graph.n_edge),
  )


def convert_to_graphstuple(graph: Dict[str, tf.Tensor]) -> jraph.GraphsTuple:
  """Packs a flat dictionary of fragment tensors into a GraphsTuple.

  Args:
    graph: A dict with keys "positions", "species", "focus_probability",
      "receivers", "senders", "n_node", "n_edge", "stop", "target_positions",
      "target_species" and "target_species_probability".

  Returns:
    A GraphsTuple with datatypes.FragmentNodes as nodes,
    datatypes.FragmentGlobals as globals, and a column of ones (one row per
    sender) as edge features.
  """
  node_data = datatypes.FragmentNodes(
      positions=graph["positions"],
      species=graph["species"],
      focus_probability=graph["focus_probability"],
  )
  receivers = graph["receivers"]
  senders = graph["senders"]
  n_node = graph["n_node"]
  n_edge = graph["n_edge"]

  # The input carries no edge features, so every edge gets a single constant
  # feature of 1.
  num_edges = tf.shape(senders)[0]
  edges = tf.ones((num_edges, 1))

  global_data = datatypes.FragmentGlobals(
      stop=graph["stop"],
      target_positions=graph["target_positions"],
      target_species=graph["target_species"],
      target_species_probability=graph["target_species_probability"],
  )

  return jraph.GraphsTuple(
      nodes=node_data,
      edges=edges,
      receivers=receivers,
      senders=senders,
      globals=global_data,
      n_node=n_node,
      n_edge=n_edge,
  )


def get_qm9_datasets(
    max_n_node: int,
    max_n_edge: int,
    max_n_graph: int,
    rng: Optional[chex.PRNGKey] = None) -> Dict[str, tf.data.Dataset]:
  """Loads and preprocesses the QM9 dataset as tf.data.Datasets for each split.

  Each raw split is converted to GraphsTuples, then dynamically batched and
  padded to the budget given by the three `max_*` arguments.

  Args:
    max_n_node: Node budget of every padded batch.
    max_n_edge: Edge budget of every padded batch.
    max_n_graph: Graph budget of every padded batch.
    rng: PRNG key forwarded to `get_raw_qm9_datasets` to shuffle the dataset
      files. Defaults to `jax.random.PRNGKey(0)` when None.

  Returns:
    A dict mapping split name to a tf.data.Dataset of padded GraphsTuples.
  """
  # get_raw_qm9_datasets requires a PRNG key; fall back to a fixed seed so
  # that callers that do not pass one still work.
  if rng is None:
    rng = jax.random.PRNGKey(0)

  # Get the raw datasets and convert every split to GraphsTuples.
  datasets = {
      split: dataset_split.map(
          convert_to_graphstuple,
          num_parallel_calls=tf.data.AUTOTUNE,
          deterministic=True)
      for split, dataset_split in get_raw_qm9_datasets(rng).items()
  }

  # Pad an example graph to see what the output shapes will be.
  # We will use this shape information when creating the tf.data.Dataset.
  budget = GraphsTupleSize(
      n_node=max_n_node, n_edge=max_n_edge, n_graph=max_n_graph)
  example_graph = next(datasets['train'].as_numpy_iterator())
  example_padded_graph = jraph.pad_with_graphs(example_graph, *budget)
  padded_graphs_spec = specs_from_graphs_tuple(example_padded_graph)

  def make_batch_generator(dataset_split):
    """Returns a callable that batches `dataset_split` from the start."""

    def generate_batches():
      # Create the iterator here, not once up front: from_generator calls this
      # function every time the dataset is iterated, and a shared iterator
      # would be exhausted after the first pass, leaving later passes empty.
      return jraph.dynamically_batch(
          graphs_tuple_iterator=iter(dataset_split),
          n_node=budget.n_node,
          n_edge=budget.n_edge,
          n_graph=budget.n_graph)

    return generate_batches

  # Batch and pad each split separately.
  return {
      split: tf.data.Dataset.from_generator(
          make_batch_generator(dataset_split),
          output_signature=padded_graphs_spec)
      for split, dataset_split in datasets.items()
  }


# Smoke test: build the QM9 datasets with a fixed padding budget and run the
# first 100 padded 'train' batches (or fewer, if the split is shorter) through
# the Fragment conversion.
datasets = get_qm9_datasets(max_n_node=100, max_n_edge=1000, max_n_graph=10)
num = 0
for graphs in datasets['train'].as_numpy_iterator():
  # NOTE(review): from_graphstuple lives in the external `datatypes` module;
  # presumably it rewraps the GraphsTuple as a Fragment - confirm.
  graphs = datatypes.Fragment.from_graphstuple(graphs)
  num += 1
  # Stop after 100 batches.
  if num == 100:
    break

{'positions': TensorSpec(shape=(None, 3), dtype=tf.float32, name=None), 'target_species_probability': TensorSpec(shape=(1, 5), dtype=tf.float32, name=None), 'target_species': TensorSpec(shape=(1,), dtype=tf.int32, name=None), 'n_node': TensorSpec(shape=(1,), dtype=tf.int32, name=None), 'focus_probability': TensorSpec(shape=(None,), dtype=tf.float32, name=None), 'n_edge': TensorSpec(shape=(1,), dtype=tf.int32, name=None), 'stop': TensorSpec(shape=(1,), dtype=tf.bool, name=None), 'species': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'senders': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'receivers': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'target_positions': TensorSpec(shape=(1, 3), dtype=tf.float32, name=None)}
train ['/Users/ameyad/Documents/qm9_data_tf/dataset_tf_0']
val ['/Users/ameyad/Documents/qm9_data_tf/dataset_tf_2']
test ['/Users/ameyad/Documents/qm9_data_tf/dataset_tf_1']
