# Batching

This notebook introduces you to one aspect of generating matrices that you will inevitably face when training a model: **batching**.

Before reading this notebook, **make sure you have read the [notebook on computing a matrix](<./Computing a matrix.ipynb>)**, which introduces the basic concepts of `e3nn_matrix` that we assume you already know. **We will also use exactly the same setup**, the only difference being that we will compute **two matrices at the same time instead of just one**.

In [None]:
import numpy as np

# So that we can plot sisl geometries
import sisl.viz

from e3nn import o3

from e3nn_matrix.data import (
    BasisConfiguration,
    PointBasis,
    BasisTableWithEdges,
    MatrixDataProcessor,
)
from e3nn_matrix.torch import BasisMatrixDataset, BasisMatrixReadout

from e3nn_matrix.tools.viz import plot_basis_matrix

The matrix-computing function 
-----------------------------

As we have already seen in the notebook on computing a matrix, we need to define a **basis**, **a basis table**, **a data processor** and **the input shape**. With all this, we can **initialize the matrix-computing function**. We define everything exactly as in the other notebook:

In [None]:
# The basis
point_1 = PointBasis("A", "spherical", o3.Irreps("0e"), R=2)
point_2 = PointBasis("B", "spherical", o3.Irreps("2x0e + 1o"), R=5)

basis = [point_1, point_2]

# The basis table.
table = BasisTableWithEdges(basis)

# The data processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)

# The input shape
input_irreps = o3.Irreps("0e + 1o")

# The matrix readout function
readout = BasisMatrixReadout(
    unique_basis=basis,
    irreps_in=input_irreps,
    symmetric=True,
)
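The size of each point's block in the matrix follows from its irreps: an irrep of degree l has 2l + 1 components, and a multiplicity prefix such as `2x` repeats it. The sketch below uses a hypothetical `irreps_dim` helper, which only understands the simple irreps strings used in this notebook, to show that point type `A` contributes 1 basis function and `B` contributes 5. It is only an illustration; with e3nn itself, `o3.Irreps("2x0e + 1o").dim` gives the same number.

```python
def irreps_dim(irreps: str) -> int:
    """Total dimension of an irreps string such as "2x0e + 1o".

    Hypothetical helper: it only handles the simple "[mul x]<l><parity>"
    terms used in this notebook. Each irrep of degree l has 2l + 1 components.
    """
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        if "x" in term:
            mul_str, ir = term.split("x")
            mul = int(mul_str)
        else:
            mul, ir = 1, term
        l = int(ir[:-1])  # drop the parity letter ("e" or "o")
        total += mul * (2 * l + 1)
    return total


basis_dims = {"A": irreps_dim("0e"), "B": irreps_dim("2x0e + 1o")}
print(basis_dims)  # {'A': 1, 'B': 5}
```

The same helper gives 4 for `input_irreps` (`"0e + 1o"`), that is, one scalar plus one 3-component vector per node.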

Creating two configurations
---------------------------

Now, **we will create two configurations instead of one**. Both will have the same coordinates; the only difference is that **we swap the point types**. However, you could also give each of them different coordinates, or even a different number of points.

We'll store both configurations in a `configs` list.

In [None]:
positions = np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]])

config1 = BasisConfiguration(
    point_types=["A", "B", "A"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

config2 = BasisConfiguration(
    point_types=["B", "A", "B"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

configs = [config1, config2]
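Even though both configurations share the same coordinates, swapping the point types changes the size of the matrices. Assuming one row and one column per basis function (1 for `A` and 5 for `B`, as set by the irreps above), a minimal sketch of the arithmetic with a hypothetical `matrix_shape` helper:

```python
# Basis functions per point type, from the irreps defined above:
# "0e" -> 1 function, "2x0e + 1o" -> 2 + 3 = 5 functions.
BASIS_SIZE = {"A": 1, "B": 5}


def matrix_shape(point_types):
    """Shape of the (square) matrix of a configuration."""
    n = sum(BASIS_SIZE[t] for t in point_types)
    return (n, n)


shape1 = matrix_shape(["A", "B", "A"])
shape2 = matrix_shape(["B", "A", "B"])
print(shape1, shape2)  # (7, 7) (11, 11)
```

This is why the two matrices cannot simply be stacked into one array, and why batching them requires extra bookkeeping.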

As we did in the other notebook, we plot the configurations to see what they look like and to visualize the overlaps:

In [None]:
geom1 = config1.to_sisl_geometry()
geom1.plot(show_cell=False, atoms_style={"size": geom1.maxR(all=True)}).update_layout(
    title="Config 1"
).show()

geom2 = config2.to_sisl_geometry()
geom2.plot(show_cell=False, atoms_style={"size": geom2.maxR(all=True)}).update_layout(
    title="Config 2"
).show()
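The overlaps in the plots can also be checked by hand. The sketch below assumes that two points interact (and therefore get a matrix block and an edge) when their distance is smaller than the sum of their radii `R`. Both this criterion and the `overlapping_pairs` helper are illustrative assumptions, not `e3nn_matrix`'s own neighbor-finding code.

```python
import math
from itertools import combinations

# Radii per point type, as given to PointBasis above.
RADIUS = {"A": 2.0, "B": 5.0}
positions = [(0.0, 0.0, 0.0), (6.0, 0.0, 0.0), (12.0, 0.0, 0.0)]


def overlapping_pairs(point_types, positions):
    """Pairs (i, j) with i < j whose spheres overlap.

    Assumption: points i and j interact when their distance is smaller
    than the sum of their radii.
    """
    pairs = []
    for i, j in combinations(range(len(point_types)), 2):
        d = math.dist(positions[i], positions[j])
        if d < RADIUS[point_types[i]] + RADIUS[point_types[j]]:
            pairs.append((i, j))
    return pairs


pairs1 = overlapping_pairs(["A", "B", "A"], positions)
pairs2 = overlapping_pairs(["B", "A", "B"], positions)
print(pairs1)  # [(0, 1), (1, 2)]
print(pairs2)  # [(0, 1), (1, 2)]
```

In both configurations each point only overlaps with its nearest neighbor: even the two outer `B` points of config 2 are 12 apart, which is more than 5 + 5.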

Build a dataset
---------------

With all our configurations, we can **create a dataset**. The class that does this is `BasisMatrixDataset`, which, **apart from the configurations, needs the data processor** as usual.

In [None]:
dataset = BasisMatrixDataset(configs, data_processor=processor)

This dataset contains all the configurations. We now just need some tool to create batches from it.

Batching with a DataLoader
-------------------

You don't need an `e3nn_matrix`-specific tool to create batches. In fact, **we recommend that you use `torch_geometric`'s `DataLoader`**:

In [None]:
from torch_geometric.loader import DataLoader

All you need to do is **pass the dataset** and **specify a batch size**.

In [None]:
loader = DataLoader(dataset, batch_size=2)

In this case we use a batch size of `2`, which equals our total number of configurations. Therefore, **there will be only one batch**.
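With `shuffle` left at its default (`False`), the loader groups consecutive items of the dataset into batches of `batch_size`, so the number of batches is the number of configurations divided by the batch size, rounded up. A pure-Python sketch of that grouping follows; the `batch_indices` helper is hypothetical, and its `drop_last` flag mirrors the PyTorch `DataLoader` argument of the same name.

```python
import math


def batch_indices(n_items, batch_size, drop_last=False):
    """Group dataset indices into consecutive batches (no shuffling).

    The last batch is smaller when n_items is not a multiple of
    batch_size, unless drop_last=True discards it.
    """
    batches = [
        list(range(start, min(start + batch_size, n_items)))
        for start in range(0, n_items, batch_size)
    ]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


print(batch_indices(2, batch_size=2))  # [[0, 1]]: a single batch, as here
print(batch_indices(5, batch_size=2))  # [[0, 1], [2, 3], [4]]
print(math.ceil(5 / 2))                # 3
```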

Let's loop through the batches (only 1) and print them:

In [None]:
for data in loader:
    print(data)
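The printed batch has 6 nodes because `torch_geometric` batches graphs by merging them into one large disconnected graph: node-level arrays are concatenated, the node indices in `edge_index` are shifted by the number of nodes in the preceding graphs, and the `batch` and `ptr` vectors record which graph each node comes from. Below is a simplified pure-Python sketch of that idea, with a hypothetical `collate` function and illustrative edges; the real `Batch.from_data_list` handles arbitrary attributes and tensors.

```python
def collate(graphs):
    """Merge several graphs into one disconnected graph.

    Each graph is a dict with "num_nodes" and "edge_index" (a list of
    (src, dst) pairs with local node indices). The node indices of graph k
    are shifted by the total number of nodes in graphs 0..k-1.
    """
    edge_index, batch, ptr = [], [], [0]
    for graph_id, g in enumerate(graphs):
        offset = ptr[-1]
        edge_index += [(s + offset, d + offset) for s, d in g["edge_index"]]
        batch += [graph_id] * g["num_nodes"]
        ptr.append(offset + g["num_nodes"])
    return {"edge_index": edge_index, "batch": batch, "ptr": ptr}


# Illustrative graphs: 3 nodes each, edges between neighboring points.
g = {"num_nodes": 3, "edge_index": [(0, 1), (1, 2)]}
merged = collate([g, g])
print(merged["edge_index"])  # [(0, 1), (1, 2), (3, 4), (4, 5)]
print(merged["batch"])       # [0, 0, 0, 1, 1, 1]
print(merged["ptr"])         # [0, 3, 6]
```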

Calling the function
--------------------

We now have our batch object, `data`, which is a `torch_geometric` `Batch` object. In the previous notebook, we called the function using a single `BasisMatrixTorchData` object, so one might expect batched data to make calling the function more complicated.

However, **the code to compute matrices in a batch is exactly the same**. First, of course, we need our inputs, which we generate artificially here (the batch contains 6 nodes, 3 from each configuration, and each of them needs a scalar and a vector, as specified by `input_irreps`):

In [None]:
node_inputs = input_irreps.randn(6, -1, requires_grad=True)
node_inputs

From them, we compute the matrices, using the inputs together with the preprocessed data in the batch and exactly the same code that we have already seen:

In [None]:
node_labels, edge_labels = readout.forward(
    node_types=data["point_types"],
    edge_index=data["edge_index"],
    edge_types=data["edge_types"],
    edge_type_nlabels=data["edge_type_nlabels"],
    node_kwargs={"node_state": node_inputs},
)

In [None]:
node_labels

In [None]:
edge_labels

Disentangling the batch
----------------------

Running a batched calculation is simple, but **disentangling the results back into individual cases is harder**. It is even harder in our case, because we are dealing with **batched sparse matrices**.

Not only do you have to handle the indices of the sparsity pattern, you also have to handle the additional level of aggregation introduced by batching. This is why `BasisMatrixData` objects contain so many pointer arrays: they are needed to keep track of how the data is organized.
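The core of this bookkeeping is slicing with pointer arrays: `ptr[k]:ptr[k + 1]` selects the entries of the k-th configuration, and subtracting the node offset turns batched edge indices back into local ones. A minimal sketch of that idea follows, with hypothetical helpers and illustrative arrays named `node_ptr` and `edge_ptr`; the actual arrays in `BasisMatrixData` may be named and laid out differently, and the processor additionally has to rebuild the matrix blocks.

```python
def split_by_ptr(values, ptr):
    """Split a flat, batch-concatenated list into one chunk per item.

    ptr[k]:ptr[k + 1] is the slice of item k, so len(ptr) == n_items + 1.
    """
    return [values[ptr[k]:ptr[k + 1]] for k in range(len(ptr) - 1)]


def local_edges(edge_index, edge_ptr, node_ptr):
    """Recover per-item edge lists with local (unshifted) node indices."""
    per_item = split_by_ptr(edge_index, edge_ptr)
    return [
        [(s - node_ptr[k], d - node_ptr[k]) for s, d in edges]
        for k, edges in enumerate(per_item)
    ]


# Illustrative batch of two items with 3 nodes and 2 edges each.
node_ptr = [0, 3, 6]
edge_ptr = [0, 2, 4]
edge_index = [(0, 1), (1, 2), (3, 4), (4, 5)]
node_values = ["a0", "a1", "a2", "b0", "b1", "b2"]

nodes_per_item = split_by_ptr(node_values, node_ptr)
edges_per_item = local_edges(edge_index, edge_ptr, node_ptr)
print(nodes_per_item)  # [['a0', 'a1', 'a2'], ['b0', 'b1', 'b2']]
print(edges_per_item)  # [[(0, 1), (1, 2)], [(0, 1), (1, 2)]]
```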

Using those indices, **the data processor can disentangle the batch** and give you the individual cases. You'll be happy to see that you can call the processor's `matrix_from_data` method **just as you did in the single-matrix case**, and it will return a `tuple` of matrices instead of just one:

In [None]:
matrices = processor.matrix_from_data(
    data,
    predictions={"node_labels": node_labels, "edge_labels": edge_labels},
)
matrices

<div class="alert alert-info">

Note

`matrix_from_data` has automatically detected that the data passed was a `torch_geometric` `Batch` object. You can also use the `is_batch` argument to state explicitly whether the data is a batch. In addition, the processor has a `yield_from_batch` method, which is more explicit and returns a generator instead of a tuple; this is better for very large matrices that you want to process one at a time.

</div>
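The tuple-versus-generator trade-off mentioned in the note is about when the work happens. In the sketch below, a hypothetical `build_matrix` stands in for reconstructing one matrix and logs each call: the tuple version builds every matrix before returning, while the generator builds each one only when it is requested, so you can process and discard them one at a time. This only illustrates the general Python pattern, not the internals of `yield_from_batch`.

```python
def build_matrix(k, log):
    """Hypothetical stand-in for reconstructing the k-th matrix."""
    log.append(k)
    return [[float(k)]]


def all_at_once(n, log):
    # Like a tuple return: every matrix is built before you get any of them.
    return tuple(build_matrix(k, log) for k in range(n))


def one_at_a_time(n, log):
    # Like a generator: each matrix is built only when it is requested.
    for k in range(n):
        yield build_matrix(k, log)


eager_log = []
matrices = all_at_once(3, eager_log)
print(eager_log)  # [0, 1, 2]

lazy_log = []
gen = one_at_a_time(3, lazy_log)
first = next(gen)
print(lazy_log)  # [0]
```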

As we already did in the previous notebook, we can plot the matrices:

In [None]:
for config, matrix in zip(configs, matrices):
    plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    ).show()

Try to relate the matrices to the systems we created and see if they make sense (they do) :)
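One way to check whether the plots make sense is to predict which blocks should be nonzero. The sketch below assumes that the block between points i and j can only be nonzero for on-site blocks (i == j) or for overlapping points (distance smaller than the sum of their radii); `block_mask` is a hypothetical helper, not part of `e3nn_matrix`.

```python
import math

BASIS_SIZE = {"A": 1, "B": 5}   # basis functions per point type
RADIUS = {"A": 2.0, "B": 5.0}   # radii per point type
POSITIONS = [(0.0, 0.0, 0.0), (6.0, 0.0, 0.0), (12.0, 0.0, 0.0)]


def block_mask(point_types, positions):
    """Basis-level mask of the entries that may be nonzero.

    Assumption: the block between points i and j can be nonzero only if
    i == j (on-site block) or their distance is smaller than the sum of
    their radii.
    """
    sizes = [BASIS_SIZE[t] for t in point_types]
    starts = [sum(sizes[:i]) for i in range(len(sizes))]
    n = sum(sizes)
    mask = [[False] * n for _ in range(n)]
    for i, ti in enumerate(point_types):
        for j, tj in enumerate(point_types):
            d = math.dist(positions[i], positions[j])
            if i == j or d < RADIUS[ti] + RADIUS[tj]:
                for a in range(starts[i], starts[i] + sizes[i]):
                    for b in range(starts[j], starts[j] + sizes[j]):
                        mask[a][b] = True
    return mask


mask = block_mask(["A", "B", "A"], POSITIONS)
for row in mask:
    print("".join("x" if v else "." for v in row))
```

For config 1 this prints a 7 x 7 pattern in which only the two corner entries coupling the distant `A` points are `.`; for `["B", "A", "B"]`, the two 5 x 5 blocks coupling the outer `B` points are empty instead.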

Summary and next steps
-----------

In this notebook we learned **how to batch systems** and then **use the data processor to unbatch them**.

The **next steps** could be:

- Understanding how to **train the function** to produce the target matrix.
- Combining this function with other modules for a particular application.