Skip to content

Commit

Permalink
Reworked data loader and fixed bug in casting padded disjoint to batc…
Browse files Browse the repository at this point in the history
…hed attributes.
  • Loading branch information
PatReis committed Jan 5, 2024
1 parent d67703c commit 7d02fda
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 125 deletions.
18 changes: 13 additions & 5 deletions kgcnn/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import os
from sklearn.model_selection import KFold
from kgcnn.io.loader import tf_disjoint_list_generator
from kgcnn.io.loader import tf_dataset_disjoint_generator
# import typing as t
from typing import Union, List, Callable, Dict, Optional
# from collections.abc import MutableSequence
Expand Down Expand Up @@ -329,10 +329,18 @@ def rename_property_on_graphs(self, old_property_name: str, new_property_name: s
set = assign_property
get = obtain_property

def tf_disjoint_data_generator(self, inputs, outputs, **kwargs):
assert isinstance(inputs, list), "Dictionary input is not yet implemented"
module_logger.info("Dataloader is experimental and does not have all features for in and output.")
return tf_disjoint_list_generator(self, inputs=inputs, outputs=outputs, **kwargs)
def tf_dataset_disjoint(self, inputs, **kwargs):
r"""Return generator via :obj:`tf.data.Dataset` from this list.
Uses :obj:`kgcnn.io.loader.tf_dataset_disjoint_generator`
Args:
inputs:
kwargs: Kwargs for :obj:`tf_dataset_disjoint_generator`
Returns:
tf.data.Dataset: Dataset from generator.
"""
return tf_dataset_disjoint_generator(self, inputs=inputs, **kwargs)


class MemoryGraphDataset(MemoryGraphList):
Expand Down
156 changes: 107 additions & 49 deletions kgcnn/io/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,125 @@ def pad_at_axis(x, pad_width, axis=0, **kwargs):
return np.pad(x, pad_width=pads, **kwargs)


def tf_disjoint_list_generator(
def tf_dataset_disjoint_generator(
graphs,
inputs: list,
outputs: list,
assignment_to_id: list = None,
assignment_of_indices: list = None,
pos_batch_id: list = None,
pos_subgraph_id: list = None,
pos_count: list = None,
inputs: Union[list, dict],
assignment_to_id: Union[list, dict] = None,
assignment_of_indices: Union[list, dict] = None,
pos_batch_id: Union[list, dict] = None,
pos_subgraph_id: Union[list, dict] = None,
pos_count: Union[list, dict] = None,
batch_size=32,
epochs=None,
padded_disjoint=False,
shuffle=True,
seed=42
):
r"""Make a tensorflow dataset for disjoint graph loading.
For the moment only IDs that have their values in inputs can be generated, as the value tensors of e.g. node
or edge are used to generate batch IDs.
Inputs is a list or dictionary of keras input layer configs. The names of the layers are linked to the properties
in `graph` .
With `assignment_to_id` and `assignment_of_indices` disjoint indices and attributes can be defined.
Their IDs are marked with `pos_batch_id` etc. One must use a name or index for each general split, since for
example edge IDs can be used for edge indices, edge attributes and edge relation tensors at the same time.
Therefore, one batch ID for edges is enough. One could however assign as many as IDs as there are disjoint
graph properties in `graph` .
Args:
graphs: List of dictionaries with named graph properties.
inputs: List or dict of keras input layer configs.
assignment_to_id: Assignment of if inputs to disjoint properties to IDs.
assignment_of_indices: Assignment of inputs (if they are indices) to their reference.
pos_batch_id: Position or name of batch IDs.
pos_subgraph_id: Position or name of batch IDs.
pos_count: Position or name of batch IDs.
batch_size: Batch size.
epochs: Expected number of epochs. Only required for padded disjoint.
padded_disjoint: If padded disjoint tensors should be generated.
shuffle: Whether to shuffle each epoch.
seed: Seed for shuffle.
Returns:
tf.data.Dataset: Tensorflow dataset to load disjoint graphs.
"""
# Stats on the required dataset.
dataset_size = len(graphs)
data_index = np.arange(dataset_size)
num_inputs = len(inputs)

# Check input information for outputspec.
is_single_input = False
is_list_input = False
if isinstance(inputs, list):
is_list_input = True
output_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in inputs])
elif isinstance(inputs, dict):
if "shape" in inputs and "dtype" in inputs:
output_spec = tf.TensorSpec(shape=tuple([None] + list(inputs["shape"])), dtype=inputs["dtype"])
inputs = {0: inputs}
is_single_input = True
num_inputs = 1
else:
output_spec = dict(
{i: tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for i, x in inputs.items()})
else:
raise ValueError("Inputs must be list or dict of keras input layer kwargs.")

# We use a dict for both list and dict input.
def _convert_to_dict(container_to_check):
if container_to_check is None:
return {}
if isinstance(container_to_check, (list, tuple)):
return {i: x for i, x in enumerate(container_to_check)}
if not isinstance(container_to_check, dict):
raise ValueError("Must be dict or list for mapping and containers.")
return container_to_check

inputs = _convert_to_dict(inputs)
assignment_to_id = _convert_to_dict(assignment_to_id)
assignment_of_indices = _convert_to_dict(assignment_of_indices)
pos_batch_id = _convert_to_dict(pos_batch_id)
pos_subgraph_id = _convert_to_dict(pos_subgraph_id)
pos_count = _convert_to_dict(pos_count)

# Fill assignments with Nones if they are not used for input.
if len(assignment_to_id) < num_inputs:
assignment_to_id = assignment_to_id + [None for _ in range(num_inputs-len(assignment_to_id))]
for key, values in inputs.items():
if key not in assignment_to_id.keys():
assignment_to_id[key] = None
if len(assignment_of_indices) < num_inputs:
assignment_of_indices = assignment_of_indices + [None for _ in range(num_inputs-len(assignment_of_indices))]
for key, values in inputs.items():
if key not in assignment_of_indices.keys():
assignment_of_indices[key] = None

flag_batch_id = [None for _ in range(num_inputs)]
for i, x in enumerate(pos_batch_id):
flag_batch_id = {i: None for i in inputs.keys()}
for i, x in pos_batch_id.items():
flag_batch_id[x] = i

flag_count = [None for _ in range(num_inputs)]
for i, x in enumerate(pos_count):
flag_count = {i: None for i in inputs.keys()}
for i, x in pos_count.items():
flag_count[x] = i

flag_subgraph_id = [None for _ in range(num_inputs)]
for i, x in enumerate(pos_subgraph_id):
flag_subgraph_id = {i: None for i in inputs.keys()}
for i, x in pos_subgraph_id.items():
flag_subgraph_id[x] = i

all_flags = [flag_batch_id, flag_count, flag_subgraph_id]
is_attributes = [True if all([x[i] is None for x in all_flags]) else False for i in range(num_inputs)]
is_attributes = {i: True if all([x[i] is None for x in all_flags]) else False for i in inputs.keys()}

max_size = [[] if assignment_to_id[i] is not None else None for i in range(num_inputs)]
total_max = [[] if assignment_to_id[i] is not None else None for i in range(num_inputs)]
max_size = {i: [] if assignment_to_id[i] is not None else None for i in inputs.keys()}
total_max = {i: [] if assignment_to_id[i] is not None else None for i in inputs.keys()}

# We can check the maximum batch size at the beginning or just have a maximum batch size for each epoch.
if padded_disjoint:
if epochs is None:
raise ValueError("Requires number of epochs if `padded_disjoint=True` .")

for i in range(num_inputs):
for i in inputs.keys():
if assignment_to_id[i] is None:
continue
len_list = [len(x[inputs[i]["name"]]) for x in graphs]
Expand All @@ -74,26 +145,26 @@ def tf_disjoint_list_generator(
rng = Generator(PCG64(seed=seed))

for epoch in range(epochs):
max_size_epoch = [[] if assignment_to_id[i] is not None else None for i in range(num_inputs)]
max_size_epoch = {i: [] if assignment_to_id[i] is not None else None for i in inputs.keys()}
if shuffle:
rng.shuffle(data_index)
for batch_index in range(0, dataset_size, batch_size):
idx = data_index[batch_index:batch_index + batch_size]
graphs_batch = [graphs[i] for i in idx]
for i in range(num_inputs):
for i in inputs.keys():
if assignment_to_id[i] is None:
continue
len_list = [len(x[inputs[i]["name"]]) for x in graphs_batch]
max_length = sum(len_list)
max_size_epoch[i].append(max_length)
for i, x in enumerate(max_size_epoch):
for i, x in max_size_epoch.items():
if x is not None:
max_size[i].append(max(x))
max_size = [max(x) if x is not None else None for x in max_size]
max_size = {i: max(x) if x is not None else None for i, x in max_size.items()}

module_logger.info("Max of graph: %s." % total_max)
module_logger.info("Padded max of disjoint: %s." % [
x/batch_size if x is not None else None for x in max_size])
x/batch_size if x is not None else None for x in max_size.values()])

data_index = np.arange(dataset_size)
rng = Generator(PCG64(seed=seed))
Expand All @@ -107,10 +178,10 @@ def generator():
idx = data_index[batch_index:batch_index + batch_size]
graphs_batch = [graphs[i] for i in idx]

out = [None for _ in range(num_inputs)]
out_counts = [None for _ in range(num_inputs)]
out = {i: None for i in inputs.keys()}
out_counts = {i: None for i in inputs.keys()}

for i in range(num_inputs):
for i in inputs.keys():
if not is_attributes[i]:
continue

Expand Down Expand Up @@ -146,7 +217,7 @@ def generator():
[np.arange(x, dtype="int64") for x in counts], axis=0)

# Indices
for i in range(num_inputs):
for i in inputs.keys():
if assignment_of_indices[i] is not None:
edge_indices_flatten = out[i]
count_nodes = out_counts[assignment_of_indices[i]]
Expand All @@ -157,30 +228,17 @@ def generator():
disjoint_indices = np.transpose(disjoint_indices)
out[i] = disjoint_indices

if isinstance(outputs, list):
out_y = []
for k in range(len(outputs)):
array_list = [x[outputs[k]["name"]] for x in graphs_batch]
out_y.append(np.array(array_list, dtype=outputs[k]["dtype"]))
else:
out_y = np.array(
[x[outputs["name"]] for x in graphs_batch], dtype=outputs["dtype"])

yield tuple(out), out_y
# Match output container
if is_list_input:
out = tuple([out[i] for i in range(num_inputs)])
if is_single_input:
out = out[0]

input_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in inputs])

if isinstance(outputs, list):
output_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in outputs])
else:
output_spec = tf.TensorSpec(shape=tuple([None] + list(outputs["shape"])), dtype=outputs["dtype"])
yield out

data_loader = tf.data.Dataset.from_generator(
generator,
output_signature=(
input_spec,
output_spec
)
output_signature=output_spec
)

return data_loader
45 changes: 23 additions & 22 deletions kgcnn/layers/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt
padded_disjoint: bool = False, uses_mask: bool = False,
static_batched_node_output_shape: tuple = None,
static_batched_edge_output_shape: tuple = None,
remove_padded_disjoint_from_batched_output: bool = True,
**kwargs):
r"""Initialize layer.
Expand All @@ -32,6 +33,8 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt
non-padded nodes from index 0. Default is False.
static_batched_node_output_shape (tuple): Statical output shape of nodes. Default is None.
static_batched_edge_output_shape (tuple): Statical output shape of edges. Default is None.
remove_padded_disjoint_from_batched_output (bool): Whether to remove the first element on batched output
in case of padding.
"""
super(_CastBatchedDisjointBase, self).__init__(**kwargs)
self.reverse_indices = reverse_indices
Expand All @@ -42,6 +45,7 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt
self.supports_jit = padded_disjoint
self.static_batched_node_output_shape = static_batched_node_output_shape
self.static_batched_edge_output_shape = static_batched_edge_output_shape
self.remove_padded_disjoint_from_batched_output = remove_padded_disjoint_from_batched_output

def get_config(self):
"""Get config dictionary for this layer."""
Expand All @@ -50,7 +54,9 @@ def get_config(self):
"dtype_index": self.dtype_index, "padded_disjoint": self.padded_disjoint,
"uses_mask": self.uses_mask,
"static_batched_node_output_shape": self.static_batched_node_output_shape,
"static_batched_edge_output_shape": self.static_batched_edge_output_shape})
"static_batched_edge_output_shape": self.static_batched_edge_output_shape,
"remove_padded_disjoint_from_batched_output": self.remove_padded_disjoint_from_batched_output
})
return config


Expand Down Expand Up @@ -417,30 +423,25 @@ def call(self, inputs: list, **kwargs):
attr_id = ops.arange(0, ops.shape(graph_id_attr)[0], dtype=graph_id_attr.dtype)
attr_splits = ops.pad(ops.cumsum(attr_len), [[1, 0]])
attr_id = attr_id - repeat_static_length(attr_splits[:1], attr_len, ops.shape(graph_id_attr)[0])
output_shape = [target_shape[0]*target_shape[1]] + list(ops.shape(attr)[1:])
indices = graph_id_attr*ops.convert_to_tensor(target_shape[1], dtype=graph_id_attr.dtype) + ops.cast(
attr_id, dtype=graph_id_attr.dtype)
out = scatter_reduce_sum(indices, attr, output_shape)
out = ops.reshape(out, list(target_shape[:2]) + list(ops.shape(attr)[1:]))
if self.return_mask:
output_mask_shape = output_shape[:1]
out_mask = scatter_reduce_sum(indices, ops.ones(ops.shape(attr)[0], dtype="bool"), output_mask_shape)
out_mask = ops.reshape(out_mask, list(target_shape[:2]))
else:
if attr_id is None:
# Required because padded graphs in the general case can have padded nodes inbetween batches.
raise ValueError("Require sub-graph IDs in addition to batch IDs for padded disjoint graphs.")
output_shape = [(target_shape[0]+1)*target_shape[1]] + list(ops.shape(attr)[1:])
indices = graph_id_attr * ops.convert_to_tensor(target_shape[1], dtype=graph_id_attr.dtype) + ops.cast(
attr_id, dtype=graph_id_attr.dtype)
out = scatter_reduce_sum(indices, attr, output_shape)
out = out[target_shape[1]:] # Because first actual graph is 1*size shifted.
out = ops.reshape(out, list(target_shape[:2]) + list(ops.shape(attr)[1:]))

output_shape = tuple([target_shape[0] * target_shape[1]] + list(ops.shape(attr)[1:]))
indices = graph_id_attr * ops.convert_to_tensor(target_shape[1], dtype=graph_id_attr.dtype) + ops.cast(
attr_id, dtype=graph_id_attr.dtype)
out = scatter_reduce_sum(indices, attr, output_shape)
out = ops.reshape(out, list(target_shape[:2]) + list(ops.shape(attr)[1:]))
if self.return_mask:
output_mask_shape = output_shape[:1]
out_mask = scatter_reduce_sum(indices, ops.ones(ops.shape(attr)[0], dtype="bool"), output_mask_shape)
out_mask = ops.reshape(out_mask, list(target_shape[:2]))

if self.padded_disjoint and self.remove_padded_disjoint_from_batched_output:
out = out[1:]
if self.return_mask:
output_mask_shape = output_shape[:1]
out_mask = scatter_reduce_sum(indices, ops.ones(ops.shape(attr)[0], dtype="bool"), output_mask_shape)
out_mask = out_mask[target_shape[1]:]
out_mask = ops.reshape(out_mask, list(target_shape[:2]))
out_mask = out_mask[1:]

if self.return_mask:
return out, out_mask
Expand Down Expand Up @@ -471,7 +472,7 @@ def build(self, input_shape):
self.built = True

def compute_output_shape(self, input_shape):
if self.padded_disjoint:
if self.padded_disjoint and self.remove_padded_disjoint_from_batched_output:
if input_shape[0] is not None:
return tuple([input_shape[0] - 1] + list(input_shape[1:]))
return input_shape
Expand All @@ -487,7 +488,7 @@ def call(self, inputs: list, **kwargs):
Tensor: Graph labels of shape `(batch, ...)` .
"""
# Simply remove first graph.
if self.padded_disjoint:
if self.padded_disjoint and self.remove_padded_disjoint_from_batched_output:
return inputs[1:]
return inputs

Expand Down
Loading

0 comments on commit 7d02fda

Please sign in to comment.