# MACE+Graph2Mat

This notebook shows you how to integrate a `MACE` model with `Graph2Mat` through the Python API. Note that you can also use `MACE+Graph2Mat` through the Command Line Interface (CLI).

Prerequisites
-------------
Before reading this notebook, **make sure you have read the [notebook on computing a matrix](<./Computing a matrix.ipynb>) and [the notebook on batching](./Batching.ipynb)**, which introduce the basic concepts of `graph2mat` that we assume are already known. Also, **we will use exactly the same setup as in the batching notebook**, with the only difference that we will add target matrices to each structure.

```python
import numpy as np
import torch

# To load plotly templates for sisl visualization
import sisl.viz

from e3nn import o3

from graph2mat import (
    BasisConfiguration,
    PointBasis,
    BasisTableWithEdges,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixDataset

from graph2mat.bindings.e3nn import E3nnGraph2Mat

from graph2mat.tools.viz import plot_basis_matrix
```

Generating a dataset
--------------------

We generate a dataset here just as we have done in the other notebooks.

```python
# The basis
point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical")
point_2 = PointBasis("B", R=5, basis="2x0e + 1o", basis_convention="spherical")

# The basis table.
table = BasisTableWithEdges([point_1, point_2])

# The data processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)

# Both configurations share the same positions but swap the point types.
positions = np.array([[0, 0, 0], [6.0, 0, 0], [9, 0, 0]])

config1 = BasisConfiguration(
    point_types=["A", "B", "A"],
    positions=positions,
    basis=[point_1, point_2],
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

config2 = BasisConfiguration(
    point_types=["B", "A", "B"],
    positions=positions,
    basis=[point_1, point_2],
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

configs = [config1, config2]

dataset = TorchBasisMatrixDataset(configs, data_processor=processor)

from torch_geometric.loader import DataLoader

# A single batch that contains both configurations.
loader = DataLoader(dataset, batch_size=2)

data = next(iter(loader))
```
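
Before passing anything through a model, it helps to know the size of the matrices we expect. The sketch below is plain Python (graph2mat is not involved) and the helper `irreps_dim` is hypothetical. It assumes the e3nn convention that an irrep of degree `l` spans `2l + 1` functions, so `"2x0e + 1o"` gives 2·1 + 3 = 5 basis functions.

```python
import re


def irreps_dim(irreps: str) -> int:
    """Dimension of an e3nn-style irreps string such as "2x0e + 1o"."""
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        # Optional multiplicity "Nx", then the degree l and the parity e/o.
        match = re.fullmatch(r"(?:(\d+)x)?(\d+)([eo])", term)
        if match is None:
            raise ValueError(f"Cannot parse irrep {term!r}")
        mul = int(match.group(1) or 1)
        l = int(match.group(2))
        total += mul * (2 * l + 1)
    return total


# Basis of each point type, copied from the cell above.
basis = {"A": "0e", "B": "2x0e + 1o"}
basis_size = {name: irreps_dim(irreps) for name, irreps in basis.items()}

# Point types of the two configurations.
configs_types = {"config1": ["A", "B", "A"], "config2": ["B", "A", "B"]}
matrix_size = {
    name: sum(basis_size[t] for t in types) for name, types in configs_types.items()
}

print(basis_size)  # {'A': 1, 'B': 5}
print(matrix_size)  # {'config1': 7, 'config2': 11}
```

With `batch_size=2`, the batch `data` therefore holds 6 points, and the two matrices we will predict are 7×7 and 11×11.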

Initializing a MACE model
-------------------------

We will now initialize a standard MACE model.

Note that you must have MACE installed, which you can do with:

```
pip install mace_torch
```
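
The package is installed as `mace_torch` but imported as `mace`. As a minimal sketch using only the standard library (`missing_packages` is a hypothetical helper, not part of any of these libraries), we can check that all the import names used in this notebook are available before running it:

```python
import importlib.util


def missing_packages(modules):
    """Return the names in `modules` that cannot be found by the importer."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


# Import names (not pip names) of the packages used in this notebook.
required = ["mace", "graph2mat", "e3nn", "torch", "torch_geometric", "sisl"]
print(missing_packages(required))
```

An empty list means everything can be imported.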

```python
from mace.modules import MACE, RealAgnosticResidualInteractionBlock

num_interactions = 3
hidden_irreps = o3.Irreps("1x0e + 1x1o")

mace_model = MACE(
    r_max=10,
    num_bessel=10,
    num_polynomial_cutoff=10,
    max_ell=2,
    interaction_cls=RealAgnosticResidualInteractionBlock,
    interaction_cls_first=RealAgnosticResidualInteractionBlock,
    num_interactions=num_interactions,
    num_elements=2,
    hidden_irreps=hidden_irreps,
    MLP_irreps=o3.Irreps("2x0e"),
    atomic_energies=torch.tensor([0, 0]),
    avg_num_neighbors=2,
    atomic_numbers=[0, 1],
    correlation=2,
    gate=None,
)
```

Now we can pass our data through the MACE model. MACE returns a dictionary with many outputs, but we are only interested in the node features, which are stored under the `"node_feats"` key.

```python
mace_output = mace_model(data)
mace_output["node_feats"]
```

Our `Graph2Mat` model will take these node features and convert them into a matrix. Therefore, we first need to know their irreps, and then we can initialize the `Graph2Mat` module.

```python
# MACE concatenates the node features produced by every interaction. All
# interactions output `hidden_irreps`, except the last one, which outputs
# only the scalar part (the first irrep of `hidden_irreps`).
mace_out_irreps = hidden_irreps * (num_interactions - 1) + str(hidden_irreps[0])

# Initialize the matrix model with this information
matrix_model = E3nnGraph2Mat(
    unique_basis=table,
    irreps=dict(node_feats_irreps=mace_out_irreps),
    symmetric=True,
)
```
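
As a sanity check on `mace_out_irreps`, the same concatenation can be reproduced with plain strings. This sketch does not use e3nn; it assumes, as the comment in the cell above states, that every interaction except the last outputs `hidden_irreps` and the last one only its scalar part. The helper `irreps_dim` is hypothetical and uses the `2l + 1` dimension per irrep.

```python
import re


def irreps_dim(irreps: str) -> int:
    """Dimension of an irreps string like "1x0e+1x1o" (2l + 1 per irrep)."""
    total = 0
    for term in irreps.split("+"):
        match = re.fullmatch(r"(?:(\d+)x)?(\d+)([eo])", term.strip())
        if match is None:
            raise ValueError(f"Cannot parse irrep {term!r}")
        total += int(match.group(1) or 1) * (2 * int(match.group(2)) + 1)
    return total


hidden = ["1x0e", "1x1o"]  # hidden_irreps from the MACE cell
num_interactions = 3

# Full hidden irreps for all interactions but the last, then only the scalars.
blocks = hidden * (num_interactions - 1) + [hidden[0]]
mace_out = "+".join(blocks)

print(mace_out)  # 1x0e+1x1o+1x0e+1x1o+1x0e
print(irreps_dim(mace_out))  # 9
```

Under this assumption, `mace_output["node_feats"]` should have 9 columns and one row per point in the batch (6 here).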

Now, we can use the matrix model, passing the node features computed by MACE:

```python
node_labels, edge_labels = matrix_model(data=data, node_feats=mace_output["node_feats"])
```
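
`node_labels` and `edge_labels` hold the values of the matrix blocks: one square block per point on the diagonal and one block per pair of connected points. The plain-Python sketch below counts which blocks exist in our two configurations. It assumes (this is not stated in the notebook) that two points are connected when their distance is smaller than the sum of their cutoff radii `R`, i.e. when their basis spheres overlap; `blocks` is a hypothetical helper.

```python
import itertools

# Cutoff radius and number of basis functions per point type (cells above).
R = {"A": 2.0, "B": 5.0}
size = {"A": 1, "B": 5}
# All points lie on the x axis, so one coordinate is enough.
xs = [0.0, 6.0, 9.0]


def blocks(types):
    """Return (node block entries, connected pairs, edge block entries)."""
    node_entries = sum(size[t] ** 2 for t in types)
    pairs = [
        (i, j)
        for i, j in itertools.combinations(range(len(types)), 2)
        # Assumed connectivity rule: the cutoff spheres overlap.
        if abs(xs[i] - xs[j]) < R[types[i]] + R[types[j]]
    ]
    # Entries of the (i, j) blocks only; for a symmetric matrix the (j, i)
    # block is just the transpose.
    edge_entries = sum(size[types[i]] * size[types[j]] for i, j in pairs)
    return node_entries, pairs, edge_entries


print(blocks(["A", "B", "A"]))  # (27, [(0, 1), (1, 2)], 10)
print(blocks(["B", "A", "B"]))  # (51, [(0, 1), (0, 2), (1, 2)], 35)
```

Under this assumption, the two `A` points of the first configuration are 9 apart while their radii only add up to 4, so their block is zero. The counts are consistent with the matrix sizes: 27 + 2·10 + 2 (the two zero 1×1 blocks) = 49 = 7², and 51 + 2·35 = 121 = 11².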

Finally, we convert the predictions into matrices and plot them:

```python
matrices = processor.matrix_from_data(
    data,
    predictions={"node_labels": node_labels, "edge_labels": edge_labels},
)

for config, matrix in zip(configs, matrices):
    plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    ).show()
```
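
Conceptually, `matrix_from_data` places the predicted blocks back into a dense matrix for each configuration. The following NumPy sketch illustrates that idea for the first configuration; it is not graph2mat's implementation. The helper `assemble_symmetric` is hypothetical, the block values are random, and the connected pairs (points 0-1 and 1-2) assume the overlapping-spheres rule.

```python
import numpy as np


def assemble_symmetric(sizes, node_blocks, edge_blocks):
    """Build a dense symmetric matrix from diagonal and off-diagonal blocks.

    sizes: number of basis functions of each point.
    node_blocks: list of (n_i, n_i) arrays, one per point.
    edge_blocks: dict mapping (i, j) with i < j to an (n_i, n_j) array.
    """
    # Offset of each point's rows/columns in the full matrix.
    offsets = np.concatenate([[0], np.cumsum(sizes)])
    matrix = np.zeros((offsets[-1], offsets[-1]))
    for i, block in enumerate(node_blocks):
        sl = slice(offsets[i], offsets[i + 1])
        # Symmetrize the diagonal block so that the full matrix is symmetric.
        matrix[sl, sl] = (block + block.T) / 2
    for (i, j), block in edge_blocks.items():
        rows = slice(offsets[i], offsets[i + 1])
        cols = slice(offsets[j], offsets[j + 1])
        matrix[rows, cols] = block
        # The (j, i) block of a symmetric matrix is the transpose of (i, j).
        matrix[cols, rows] = block.T
    return matrix


rng = np.random.default_rng(0)
sizes = [1, 5, 1]  # config1: A, B, A
node_blocks = [rng.normal(size=(n, n)) for n in sizes]
edge_blocks = {
    (0, 1): rng.normal(size=(1, 5)),
    (1, 2): rng.normal(size=(5, 1)),
}
M = assemble_symmetric(sizes, node_blocks, edge_blocks)
print(M.shape)  # (7, 7)
print(np.allclose(M, M.T))  # True
```

Blocks of unconnected pairs, like `(0, 2)` here, simply stay zero.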

Using MatrixMACE
----------------

If you don't want to handle the details of interfacing `MACE` with `Graph2Mat`, you can use `MatrixMACE`, which takes a MACE model and wraps it so that it also outputs the `node_labels` and `edge_labels` corresponding to a matrix.

Internally, it just initializes an `E3nnGraph2Mat` layer. However, it also handles more complex interactions between `MACE` and `Graph2Mat`, such as an extra preprocessing step for edges, which needs some additional inputs from MACE.
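
The basic wrapping idea can be sketched in plain Python. Everything below is hypothetical (`ToyMACE`, `ToyReadout` and `MatrixWrapper` are not MACE or graph2mat classes, and this is not how `MatrixMACE` is implemented): the wrapper calls the inner model, feeds its `"node_feats"` to a readout, and returns the inner outputs together with `node_labels` and `edge_labels`.

```python
class ToyMACE:
    """Stand-in for a MACE model: returns a dict of outputs."""

    def __call__(self, data):
        return {"energy": 0.0, "node_feats": [[1.0] * 9 for _ in data["points"]]}


class ToyReadout:
    """Stand-in for a Graph2Mat readout: turns node features into labels."""

    def __call__(self, data, node_feats):
        node_labels = [sum(f) for f in node_feats]
        edge_labels = [node_labels[i] * node_labels[j] for i, j in data["edges"]]
        return node_labels, edge_labels


class MatrixWrapper:
    """Wrap a model so that it also outputs matrix labels."""

    def __init__(self, model, readout):
        self.model = model
        self.readout = readout

    def __call__(self, data):
        out = self.model(data)
        node_labels, edge_labels = self.readout(data, node_feats=out["node_feats"])
        # Keep every output of the inner model and add the matrix labels.
        return {**out, "node_labels": node_labels, "edge_labels": edge_labels}


data = {"points": ["A", "B", "A"], "edges": [(0, 1), (1, 2)]}
out = MatrixWrapper(ToyMACE(), ToyReadout())(data)
print(sorted(out))  # ['edge_labels', 'energy', 'node_feats', 'node_labels']
```

Returning the inner outputs unchanged means downstream code that expects MACE's keys keeps working, while matrix-specific code only needs the two extra keys.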

```python
from graph2mat.models import MatrixMACE
from graph2mat.bindings.e3nn import E3nnEdgeMessageBlock
```

```python
matrix_mace_model = MatrixMACE(
    mace_model,
    unique_basis=table,
    readout_per_interaction=True,
    edge_hidden_irreps=o3.Irreps("10x0e + 10x1o + 10x2e"),
    preprocessing_edges=E3nnEdgeMessageBlock,
    preprocessing_edges_reuse_nodes=False,
)
```

The output of this model is MACE's output plus the `node_labels` and `edge_labels` for the predicted matrix:

```python
out = matrix_mace_model(data)

out
```

You can of course plot the predicted matrices:

```python
matrices = processor.matrix_from_data(data, predictions=out)

for config, matrix in zip(configs, matrices):
    plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    ).show()
```

Summary and next steps
----------------------

In this notebook we learned **how to interface MACE with Graph2Mat**.

The **next steps** could be:

- **Train a MACE+Graph2Mat model** following the steps in [this notebook](<./Fitting matrices.ipynb>), replacing the model with the `MACE+Graph2Mat` model.