# Batch support

This notebook explains how to use the batch support in Haiku Geometric.


[![Open in Colab](https://img.shields.io/static/v1.svg?logo=google-colab&label=Quickstart&message=Open%20In%20Colab&color=blue)](https://colab.research.google.com/github/alexOarga/haiku-geometric/blob/main/docs/source/notebooks/batch_support.ipynb)

In [None]:
!pip install git+https://github.com/alexOarga/haiku-geometric.git

## Batching graphs

The `haiku_geometric.utils.batch` function can be used to batch a list of
`haiku_geometric.data.DataGraphTuple` objects into a single `haiku_geometric.data.DataGraphTuple` object.

The `batch` function returns:
  - A single `haiku_geometric.data.DataGraphTuple` with the batched graphs.
  - A `jax.numpy` array of indices indicating which graph each node belongs to.

In [2]:
import jax.numpy as jnp
from haiku_geometric.utils import batch
from haiku_geometric.datasets.base import DataGraphTuple

graph1 = DataGraphTuple(
    nodes=jnp.array([0.0, 0.1, 0.2]),
    senders=jnp.array([0, 1, 2]),
    receivers=jnp.array([2, 2, 0]),
    edges=None,
    n_node=jnp.array([3]),
    n_edge=jnp.array([3]),
    globals=None,
    position=None,
    y=jnp.array([0, 0, 0]),
    train_mask=None,
)

graph2 = DataGraphTuple(
    nodes=jnp.array([0.0, 0.0]),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 0]),
    edges=None,
    n_node=jnp.array([2]),
    n_edge=jnp.array([2]),
    globals=None,
    position=None,
    y=jnp.array([0, 0]),
    train_mask=None,
)

# `batched_graph` holds both graphs; `batch_index` maps each node to its graph.
batched_graph, batch_index = batch([graph1, graph2])
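
To make the two return values concrete, the sketch below reproduces a common batching convention by hand with plain Python lists: node features are concatenated, the edge indices of each later graph are shifted by the number of nodes that precede it, and the batch index records the source graph of every node. The helper `batch_by_hand` and the dict layout are hypothetical. This is not Haiku Geometric's implementation, and it assumes `batch` follows this convention.

```python
def batch_by_hand(graphs):
    """Concatenate graphs given as dicts with 'nodes', 'senders', 'receivers'.

    Hypothetical helper for illustration only.
    """
    nodes, senders, receivers, batch_index = [], [], [], []
    offset = 0  # number of nodes already placed in the batch
    for graph_id, g in enumerate(graphs):
        nodes.extend(g["nodes"])
        # Shift edge endpoints so they still point at the right nodes.
        senders.extend(s + offset for s in g["senders"])
        receivers.extend(r + offset for r in g["receivers"])
        # Every node of this graph gets the same graph id.
        batch_index.extend([graph_id] * len(g["nodes"]))
        offset += len(g["nodes"])
    return nodes, senders, receivers, batch_index


g1 = {"nodes": [0.0, 0.1, 0.2], "senders": [0, 1, 2], "receivers": [2, 2, 0]}
g2 = {"nodes": [0.0, 0.0], "senders": [0, 1], "receivers": [1, 0]}

nodes, senders, receivers, batch_index = batch_by_hand([g1, g2])
print(senders)      # [0, 1, 2, 3, 4]
print(receivers)    # [2, 2, 0, 4, 3]
print(batch_index)  # [0, 0, 0, 1, 1]
```

The second graph's edges `0 -> 1` and `1 -> 0` become `3 -> 4` and `4 -> 3` because three nodes come before it in the batch.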



## Unbatching graphs

To unbatch a `DataGraphTuple` object created with the `batch` function,
use the `haiku_geometric.utils.unbatch` function. It takes a batched `DataGraphTuple`
object and returns a list of `haiku_geometric.data.DataGraphTuple` objects, one per original graph.

In [3]:
from haiku_geometric.utils import unbatch

unbatched_graphs = unbatch(batched_graph)
graph1 = unbatched_graphs[0]
graph2 = unbatched_graphs[1]
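
Conceptually, unbatching walks the per-graph node and edge counts, slices out each graph, and shifts the edge indices back so they start at zero. The sketch below shows this with plain Python lists. The helper `unbatch_by_hand` and its arguments are hypothetical (the counts play the role of `n_node` and `n_edge`). It illustrates the general technique and is not the library's code.

```python
def unbatch_by_hand(nodes, senders, receivers, n_node, n_edge):
    """Split a batched graph into per-graph dicts (hypothetical helper)."""
    graphs = []
    node_start, edge_start = 0, 0
    for num_nodes, num_edges in zip(n_node, n_edge):
        node_end = node_start + num_nodes
        edge_end = edge_start + num_edges
        graphs.append({
            "nodes": nodes[node_start:node_end],
            # Undo the offset so indices are local to the graph again.
            "senders": [s - node_start for s in senders[edge_start:edge_end]],
            "receivers": [r - node_start for r in receivers[edge_start:edge_end]],
        })
        node_start, edge_start = node_end, edge_end
    return graphs


# A batch of two graphs with 3 and 2 nodes (and 3 and 2 edges).
graphs = unbatch_by_hand(
    nodes=[0.0, 0.1, 0.2, 0.0, 0.0],
    senders=[0, 1, 2, 3, 4],
    receivers=[2, 2, 0, 4, 3],
    n_node=[3, 2],
    n_edge=[3, 2],
)
print(graphs[1])  # {'nodes': [0.0, 0.0], 'senders': [0, 1], 'receivers': [1, 0]}
```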

## Dynamic batching

Haiku Geometric does not currently support dynamic batching. If you are working with `jraph`,
you can build `jraph.GraphsTuple` objects and pass them to `jraph.dynamically_batch`, which groups graphs
into batches that stay within fixed node, edge, and graph budgets and pads each batch up to those budgets.

In [6]:
import jax.numpy as jnp
import jraph

graph1 = jraph.GraphsTuple(
    nodes=jnp.array([0.0, 0.1, 0.2]),
    senders=jnp.array([0, 1, 2]),
    receivers=jnp.array([2, 2, 0]),
    edges=None,
    n_node=jnp.array([3]),
    n_edge=jnp.array([3]),
    globals=None,
)

graph2 = jraph.GraphsTuple(
    nodes=jnp.array([0.0, 0.0]),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 0]),
    edges=None,
    n_node=jnp.array([2]),
    n_edge=jnp.array([2]),
    globals=None,
)

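# Each batch reserves room for one padding graph, so the budgets must satisfy:
# nodes >= largest graph's nodes + 1, edges >= largest graph's edges, graphs >= 2.
MAXIMUM_NUM_NODES = 4
MAXIMUM_NUM_EDGES = 3
MAXIMUM_NUM_GRAPHS = 2

batched_generator = jraph.dynamically_batch(
    [graph1, graph2],
    MAXIMUM_NUM_NODES,   # max number of nodes in a batch (padding included)
    MAXIMUM_NUM_EDGES,   # max number of edges in a batch (padding included)
    MAXIMUM_NUM_GRAPHS,  # max number of graphs in a batch (padding included)
)

# The generator is lazy: batches are only built while iterating over it.
for padded_batch in batched_generator:
    print(padded_batch.n_node, padded_batch.n_edge)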
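
The idea behind budget-based dynamic batching can be sketched without any library: graphs are added to the current batch until the next one would exceed a node, edge, or graph budget, at which point the batch is emitted and a new one starts. The function `greedy_batches` and its size-pair input are hypothetical, and the sketch ignores the padding that `jraph.dynamically_batch` applies, so it illustrates the grouping logic only and is not jraph's implementation.

```python
def greedy_batches(sizes, max_nodes, max_edges, max_graphs):
    """Group graph indices under fixed budgets (hypothetical helper).

    `sizes` is a list of (num_nodes, num_edges) pairs, one per graph.
    """
    current, nodes, edges = [], 0, 0
    for idx, (n, e) in enumerate(sizes):
        if n > max_nodes or e > max_edges:
            raise ValueError(f"graph {idx} does not fit in an empty batch")
        full = (nodes + n > max_nodes or edges + e > max_edges
                or len(current) + 1 > max_graphs)
        if current and full:
            yield current  # emit the batch before it overflows
            current, nodes, edges = [], 0, 0
        current.append(idx)
        nodes, edges = nodes + n, edges + e
    if current:
        yield current  # emit the last, possibly partial, batch


sizes = [(3, 3), (2, 2), (4, 5), (2, 1)]
print(list(greedy_batches(sizes, max_nodes=5, max_edges=5, max_graphs=2)))
# [[0, 1], [2], [3]]
```

Because batch contents depend on the order and size of the incoming graphs, the number of graphs per batch can differ from one batch to the next.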