# Batch support

This notebook explains how to use the batch support in Haiku Geometric.

In [1]:
!pip install git+https://github.com/alexOarga/haiku-geometric.git


## Batching graphs

The `haiku_geometric.utils.batch` function can be used to batch a list of
`haiku_geometric.data.DataGraphTuple` objects into a single `haiku_geometric.data.DataGraphTuple` object.

In [2]:
import jax.numpy as jnp
from haiku_geometric.utils import batch
from haiku_geometric.datasets.base import DataGraphTuple

graph1 = DataGraphTuple(
    nodes=jnp.array([0.0, 0.1, 0.2]),
    senders=jnp.array([0, 1, 2]),
    receivers=jnp.array([2, 2, 0]),
    edges=None,
    n_node=jnp.array([3]),
    n_edge=jnp.array([3]),
    globals=None,
    position=None,
    y=jnp.array([0, 0, 0]),
    train_mask=jnp.array([0, 0, 0]),  # three entries, matching the three nodes
)

graph2 = DataGraphTuple(
    nodes=jnp.array([0.0, 0.0]),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 0]),
    edges=jnp.array([0, 0]),
    n_node=jnp.array([2]),
    n_edge=jnp.array([2]),
    globals=jnp.array([0, 0]),
    position=jnp.array([0, 0]),
    y=jnp.array([0, 0]),
    train_mask=jnp.array([0, 0]),
)

batched_graph = batch([graph1, graph2])
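
To make the index bookkeeping concrete, here is a minimal sketch of what batching does to the edge indices. It assumes `batch` follows the usual `jraph` convention of concatenating node arrays and shifting each graph's `senders` and `receivers` by the number of nodes placed before it. `TinyGraph` and `batch_sketch` are hypothetical names used only for illustration, and plain Python lists stand in for arrays so the arithmetic stays visible. This is not the library's implementation.

```python
from typing import List, NamedTuple


class TinyGraph(NamedTuple):
    """Hypothetical stand-in for a graph tuple, using plain Python lists."""
    nodes: List[float]
    senders: List[int]
    receivers: List[int]
    n_node: List[int]  # node count of each graph stored in this tuple
    n_edge: List[int]  # edge count of each graph stored in this tuple


def batch_sketch(graphs: List[TinyGraph]) -> TinyGraph:
    """Concatenate graphs, offsetting edge indices by the preceding node count."""
    nodes, senders, receivers, n_node, n_edge = [], [], [], [], []
    offset = 0  # number of nodes already placed in the batch
    for g in graphs:
        nodes.extend(g.nodes)
        # Shift edge endpoints so they index into the concatenated node list.
        senders.extend(s + offset for s in g.senders)
        receivers.extend(r + offset for r in g.receivers)
        n_node.extend(g.n_node)
        n_edge.extend(g.n_edge)
        offset += sum(g.n_node)
    return TinyGraph(nodes, senders, receivers, n_node, n_edge)


g1 = TinyGraph([0.0, 0.1, 0.2], [0, 1, 2], [2, 2, 0], [3], [3])
g2 = TinyGraph([0.0, 0.0], [0, 1], [1, 0], [2], [2])
batched = batch_sketch([g1, g2])
print(batched.senders)    # [0, 1, 2, 3, 4]
print(batched.receivers)  # [2, 2, 0, 4, 3]
print(batched.n_node)     # [3, 2]
```

The second graph's edges now point at nodes 3 and 4, and `n_node` and `n_edge` record where one graph ends and the next begins.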



## Unbatching graphs

To unbatch a `DataGraphTuple` object created with the `batch` function,
we can use the `haiku_geometric.utils.unbatch` function. It takes a batched `DataGraphTuple`
object and returns a list of `haiku_geometric.data.DataGraphTuple` objects, one per original graph.

In [3]:
from haiku_geometric.utils import unbatch

unbatched_graphs = unbatch(batched_graph)
graph1 = unbatched_graphs[0]
graph2 = unbatched_graphs[1]
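
Unbatching is the inverse bookkeeping, which can be sketched as follows. The sketch assumes the batched object stores per-graph counts in `n_node` and `n_edge`, as in the examples above, and uses them to slice the arrays and shift edge indices back to graph-local numbering. `TinyGraph` and `unbatch_sketch` are hypothetical names, and plain Python lists stand in for arrays. This is an illustration, not the library's implementation.

```python
from typing import List, NamedTuple


class TinyGraph(NamedTuple):
    """Hypothetical stand-in for a graph tuple, using plain Python lists."""
    nodes: List[float]
    senders: List[int]
    receivers: List[int]
    n_node: List[int]  # node count of each graph stored in this tuple
    n_edge: List[int]  # edge count of each graph stored in this tuple


def unbatch_sketch(batched: TinyGraph) -> List[TinyGraph]:
    """Split a batched graph back into single graphs using n_node and n_edge."""
    graphs = []
    node_start, edge_start = 0, 0
    for num_nodes, num_edges in zip(batched.n_node, batched.n_edge):
        node_end = node_start + num_nodes
        edge_end = edge_start + num_edges
        graphs.append(TinyGraph(
            nodes=batched.nodes[node_start:node_end],
            # Subtract the node offset so indices are local to the graph again.
            senders=[s - node_start for s in batched.senders[edge_start:edge_end]],
            receivers=[r - node_start for r in batched.receivers[edge_start:edge_end]],
            n_node=[num_nodes],
            n_edge=[num_edges],
        ))
        node_start, edge_start = node_end, edge_end
    return graphs


batched = TinyGraph(
    nodes=[0.0, 0.1, 0.2, 0.0, 0.0],
    senders=[0, 1, 2, 3, 4],
    receivers=[2, 2, 0, 4, 3],
    n_node=[3, 2],
    n_edge=[3, 2],
)
first, second = unbatch_sketch(batched)
print(second.senders)    # [0, 1]
print(second.receivers)  # [1, 0]
```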

## Dynamic batching

Haiku Geometric does not currently support dynamic batching. As a workaround, you can build
`jraph.GraphsTuple` objects and pass them to `jraph.dynamically_batch`, which packs graphs into
padded batches that stay within fixed node, edge, and graph budgets.

In [6]:
import jax.numpy as jnp
import jraph

graph1 = jraph.GraphsTuple(
    nodes=jnp.array([0.0, 0.1, 0.2]),
    senders=jnp.array([0, 1, 2]),
    receivers=jnp.array([2, 2, 0]),
    edges=None,
    n_node=jnp.array([3]),
    n_edge=jnp.array([3]),
    globals=None,
)

graph2 = jraph.GraphsTuple(
    nodes=jnp.array([0.0, 0.0]),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 0]),
    edges=None,
    n_node=jnp.array([2]),
    n_edge=jnp.array([2]),
    globals=None,
)

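# Budgets include padding: jraph needs the node budget to be at least the
# largest graph's node count + 1, and the graph budget to be at least 2.
MAXIMUM_NUM_NODES = 6   # 3 + 2 real nodes, plus 1 padding node
MAXIMUM_NUM_EDGES = 5   # 3 + 2 real edges
MAXIMUM_NUM_GRAPHS = 3  # 2 real graphs, plus 1 padding graph

# Returns a generator that yields padded jraph.GraphsTuple batches when iterated.
batched_generator = jraph.dynamically_batch(
    [graph1, graph2],
    MAXIMUM_NUM_NODES,   # max number of nodes in a batch
    MAXIMUM_NUM_EDGES,   # max number of edges in a batch
    MAXIMUM_NUM_GRAPHS,  # max number of graphs in a batch
)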
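
The budgets passed to `jraph.dynamically_batch` interact with padding, which the following simplified sketch tries to illustrate. It greedily packs graphs (represented only by their node and edge counts) into batches while holding back one node and one graph slot for padding, mirroring jraph's documented requirement that the node budget be at least the largest graph plus one and the graph budget at least two. `greedy_pack` is a hypothetical helper written for this explanation. It is not jraph's implementation and does not build any arrays.

```python
from typing import Iterable, Iterator, List, Tuple


def greedy_pack(
    sizes: Iterable[Tuple[int, int]],
    max_nodes: int,
    max_edges: int,
    max_graphs: int,
) -> Iterator[List[int]]:
    """Yield lists of graph indices, each list fitting within the budgets."""
    # Assumption: one node and one graph slot are reserved for padding.
    node_budget = max_nodes - 1
    edge_budget = max_edges
    graph_budget = max_graphs - 1
    current: List[int] = []
    used_nodes, used_edges = 0, 0
    for index, (num_nodes, num_edges) in enumerate(sizes):
        if graph_budget < 1 or num_nodes > node_budget or num_edges > edge_budget:
            raise ValueError(f"graph {index} can never fit within the budgets")
        fits = (
            len(current) < graph_budget
            and used_nodes + num_nodes <= node_budget
            and used_edges + num_edges <= edge_budget
        )
        if current and not fits:
            # Close the current batch and start a new one with this graph.
            yield current
            current, used_nodes, used_edges = [], 0, 0
        current.append(index)
        used_nodes += num_nodes
        used_edges += num_edges
    if current:
        yield current  # flush the last, possibly partial, batch


# Three graphs given as (number of nodes, number of edges).
sizes = [(3, 3), (2, 2), (3, 3)]
batches = list(greedy_pack(sizes, max_nodes=6, max_edges=5, max_graphs=3))
print(batches)  # [[0, 1], [2]]
```

With a node budget of 6, five real nodes fit in a batch, so the first two graphs share a batch and the third starts a new one. A graph with 3 nodes needs a node budget of at least 4; anything smaller raises an error in this sketch once the generator is iterated.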