# Fitting a single matrix

**This notebook shows how you can fit a model to predict a single matrix, which we create synthetically.**

This is of course not a real-life example of how you might train your models, but it is a nice way to:

- **Introduce the metrics** that can be used as loss functions.
- **Introduce the simplest training loop**.

It is **especially useful if you are quite new to machine learning**, because it goes step by step. It also serves as a minimal example that you can expand to create training flows different from the ones we propose.

In [None]:
import numpy as np
import pandas as pd
import torch

# To load plotly templates for sisl visualization
import sisl.viz

from e3nn import o3

from e3nn_matrix.data import (
    BasisConfiguration,
    PointBasis,
    BasisTableWithEdges,
    MatrixDataProcessor,
)
from e3nn_matrix.torch import (
    BasisMatrixDataset,
    BasisMatrixReadout,
    BasisMatrixTorchData,
)

# from e3nn_matrix.data.batch_utils import batch_to_orbital_matrix_data
from e3nn_matrix.tools.viz import plot_basis_matrix

Setting up the model
--------------------

As usual, let's create our model:

In [None]:
# The basis
point_1 = PointBasis("A", "spherical", o3.Irreps("0e"), R=2)
point_2 = PointBasis("B", "spherical", o3.Irreps("2x0e + 1o"), R=5)

basis = [point_1, point_2]

# The basis table.
table = BasisTableWithEdges(basis)

# The data processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)

# The input shape
input_irreps = o3.Irreps("0e + 1o")

# The matrix readout function
model = BasisMatrixReadout(
    unique_basis=basis,
    irreps_in=input_irreps,
    symmetric=True,
)

And a toy configuration to play around with (this is the same as in [the notebook about computing a matrix](<./Computing a matrix.ipynb>)):

In [None]:
positions = np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]])

config = BasisConfiguration(
    point_types=["A", "B", "A"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

Which we preprocess to be digestible by the model.

In [None]:
data = BasisMatrixTorchData.from_config(config, data_processor=processor)

Get the target matrix
--------------------

The goal of training a model is to make it **approximate another function as well as possible**, where that function is **likely much more expensive** to compute.

$$ model(x, \textbf{W}) \approx function(x) $$

Where $\textbf{W}$ are the *weights*, or tunable parameters of the $model$. Since $model$ and $function$ are different functions, the goal is to tune $\textbf{W}$ to make them as similar as possible. This is usually done by **comparing the outputs of both functions and using an optimizer to reduce the differences** until you are satisfied or they can no longer be reduced.
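
Written as an optimization problem, the tuning described above reads as follows. The symbol $\mathcal{L}$ for a generic error measure and $\textbf{W}^{*}$ for the best weights are introduced here only for illustration:

$$ \textbf{W}^{*} = \underset{\textbf{W}}{\arg\min} \; \mathcal{L}\big(model(x, \textbf{W}),\, function(x)\big) $$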

In this case, our target $function$ will simply be another `BasisMatrixReadout` with different weights. This of course makes no sense for a real-life application, but it serves nicely as a test case here.

In [None]:
# The matrix readout function
function = BasisMatrixReadout(
    unique_basis=basis,
    irreps_in=input_irreps,
    symmetric=True,
)

With **our function we will compute the target matrix**. As usual we will generate some random (but equivariant) inputs:

In [None]:
node_inputs = input_irreps.randn(3, -1)
# We make both A points equivalent under inversion through B:
# same scalar (0e) value, negated vector (1o) components
node_inputs[-1, 0] = node_inputs[0, 0]
node_inputs[-1, 1:] = -node_inputs[0, 1:]
node_inputs

And then compute it:

In [None]:
with torch.no_grad():
    labels = function(
        node_types=data["point_types"],
        edge_index=data["edge_index"],
        edge_types=data["edge_types"],
        edge_type_nlabels=data["edge_type_nlabels"],
        node_kwargs={"node_state": node_inputs},
    )

The target values for the model are usually called *labels*, because they label the data by specifying its properties.

Notice also that we wrapped the call in `torch.no_grad()`. We take the data generated by `function` as the truth, so the parameters inside `function` should not be modified, and therefore we don't need to retain the computational graph from inputs to outputs (just a technical detail arising from using `BasisMatrixReadout` as the true function).
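
As a small illustration of what `torch.no_grad()` changes, the sketch below runs the same computation with and without it. The tensors `weights` and `inputs` are made-up stand-ins, not anything taken from `BasisMatrixReadout`:

```python
import torch

# Hypothetical stand-in for the parameters of `function`
weights = torch.tensor([1.0, 2.0], requires_grad=True)
inputs = torch.tensor([3.0, 4.0])

# Regular call: the result remembers how it was computed
tracked = (weights * inputs).sum()

# Inside torch.no_grad(): same values, but no computational graph is recorded
with torch.no_grad():
    untracked = (weights * inputs).sum()

print(tracked.requires_grad)    # True
print(untracked.requires_grad)  # False
```

Both results hold the same number; only the first one can be used to backpropagate gradients to `weights`.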

Comparing it to predictions
---------------------------

Let's now compare the target matrix to our predictions. Compute the predictions:

In [None]:
predictions = model(
    node_types=data["point_types"],
    edge_index=data["edge_index"],
    edge_types=data["edge_types"],
    edge_type_nlabels=data["edge_type_nlabels"],
    node_kwargs={"node_state": node_inputs},
)

And print both:

In [None]:
print("LABELS (target matrix)\n-------------------")
print(labels)
print("\nPREDICTIONS\n-------------------")
print(predictions)

You can see that they are **nowhere close to resembling one another**. Our goal will be to make them as similar as possible.

We can also plot the actual matrices to see it more clearly:

In [None]:
def plot_labels(labels):
    matrix = processor.matrix_from_data(
        data,
        predictions={"node_labels": labels[0], "edge_labels": labels[1]},
    )

    return plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    )


plot_labels(labels).update_layout(title="Labels").show()
plot_labels(predictions).update_layout(title="Predictions")

Loss functions
---------------

We can clearly see that these two matrices are different, but we need some quantitative way of saying how different they are. 

This seems like a simple task. However, there is no unique way of quantifying the error. Furthermore, if the optimizer needs to tune the model parameters to reduce this quantity, **it is generally useful to keep track of how the error depends on each parameter**. This is in fact the main point of `pytorch`, which can automatically compute gradients $\frac{\partial error}{\partial \textbf{W}_i}$.
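
To see this automatic differentiation at work on something tiny, here is a minimal sketch with a single made-up parameter `w`. The "model" `w * x` and the numbers are assumptions chosen so that the gradient can be checked by hand:

```python
import torch

# Hypothetical one-parameter "model": prediction = w * x
w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)
target = torch.tensor(12.0)

# Squared error between prediction and target
error = (w * x - target) ** 2

# Ask pytorch for d(error)/d(w); the result is stored in w.grad
error.backward()

# Derivative worked out by hand for comparison: 2 * (w*x - target) * x
analytical = 2 * (2.0 * 3.0 - 12.0) * 3.0

print(w.grad.item(), analytical)
```

The gradient is negative, which tells an optimizer that increasing `w` would reduce the error.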

In `e3nn_matrix.data.metrics` you will find some useful metrics that can be used as loss (error) functions. The simplest one is `elementwise_mse`, which just computes the [Mean Squared Error](https://en.wikipedia.org/wiki/Mean_squared_error) of all the matrix elements. Let's use it:

In [None]:
from e3nn_matrix.data.metrics import elementwise_mse

In [None]:
loss_fn = elementwise_mse

loss, info = loss_fn(
    nodes_pred=predictions[0],
    nodes_ref=labels[0],
    edges_pred=predictions[1],
    edges_ref=labels[1],
)
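
For intuition about what an elementwise mean squared error measures, here is a minimal pure-Python sketch. It is not the implementation of `elementwise_mse`: the name `toy_elementwise_mse`, the flat lists of made-up numbers, and the choice of pooling node and edge elements together before averaging are all assumptions made for illustration.

```python
import math


def toy_elementwise_mse(nodes_pred, nodes_ref, edges_pred, edges_ref):
    """Mean of squared differences over all node and edge elements (illustrative)."""
    # Squared error of every matrix element, node and edge blocks pooled together
    squared_errors = [
        (pred - ref) ** 2
        for pred, ref in zip(nodes_pred + edges_pred, nodes_ref + edges_ref)
    ]
    return sum(squared_errors) / len(squared_errors)


# Made-up flattened matrix elements
toy_nodes_pred, toy_nodes_ref = [1.0, 2.0], [1.0, 4.0]
toy_edges_pred, toy_edges_ref = [0.0, 3.0], [2.0, 3.0]

toy_mse = toy_elementwise_mse(
    toy_nodes_pred, toy_nodes_ref, toy_edges_pred, toy_edges_ref
)
# The square root (RMSE) has the same units as the matrix elements
toy_rmse = math.sqrt(toy_mse)
print(toy_mse, toy_rmse)
```

Two of the four elements are off by 2, so the squared errors are 0, 4, 4 and 0, and their mean is 2.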

The simplest training loop
--------------------------

Below we just create a simple `pytorch` training loop that:

1. Uses the model to **compute predictions** for the matrix
2. **Computes the loss** (error).
3. Computes the gradients and **updates the model parameters**.
4. **Goes back** to 1.

While doing so we store the errors at each step so that we can plot their evolution later.

In [None]:
# Number of training steps
n_steps = 5000
# Initialize an optimizer
optimizer = torch.optim.Adam(model.parameters())

# Initialize arrays to store errors
losses = np.zeros(n_steps)
node_rmse = np.zeros(n_steps)
edge_rmse = np.zeros(n_steps)

# Loop
for i in range(n_steps):
    # Reset gradients
    optimizer.zero_grad()

    # Make predictions for this batch
    step_predictions = model(
        node_types=data["point_types"],
        edge_index=data["edge_index"],
        edge_types=data["edge_types"],
        edge_type_nlabels=data["edge_type_nlabels"],
        node_kwargs={"node_state": node_inputs},
    )

    # Compute the loss
    loss, info = loss_fn(
        nodes_pred=step_predictions[0],
        nodes_ref=labels[0],
        edges_pred=step_predictions[1],
        edges_ref=labels[1],
    )

    # Store errors
    losses[i] = loss.item()
    node_rmse[i] = info["node_rmse"]
    edge_rmse[i] = info["edge_rmse"]

    # Compute gradients
    loss.backward()

    # Update weights
    optimizer.step()
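
To see what the four steps do without any library machinery, here is a minimal pure-Python sketch that fits a single weight. It swaps Adam for plain gradient descent and computes the gradient by hand. The names `toy_model`, `toy_weight` and `learning_rate`, as well as all the numbers, are illustrative assumptions.

```python
def toy_model(x, weight):
    """Hypothetical one-parameter model."""
    return weight * x


toy_x, toy_label = 3.0, 12.0  # a single input and its target value
toy_weight = 0.0  # initial guess for the tunable parameter
learning_rate = 0.05  # step size of plain gradient descent
toy_losses = []

for step in range(100):
    # 1. Compute the prediction
    toy_prediction = toy_model(toy_x, toy_weight)
    # 2. Compute the loss (squared error)
    toy_loss = (toy_prediction - toy_label) ** 2
    toy_losses.append(toy_loss)
    # 3. Compute d(loss)/d(weight) by hand and update the weight
    gradient = 2 * (toy_prediction - toy_label) * toy_x
    toy_weight -= learning_rate * gradient
    # 4. Go back to 1.

print(toy_weight, toy_losses[0], toy_losses[-1])
```

In the `pytorch` loop, `loss.backward()` plays the role of the hand-written `gradient` line and `optimizer.step()` plays the role of the weight update.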

Checking results
----------------

After training, we store all the errors in a dataframe:

In [None]:
df = pd.DataFrame(
    np.array([losses, node_rmse, edge_rmse]).T,
    columns=["loss", "node_rmse", "edge_rmse"],
)

And plot them:

In [None]:
df.plot(backend="plotly").update_layout(
    yaxis_type="log", yaxis_showgrid=True, xaxis_showgrid=True
)

We can also plot the target matrix along with the predictions of the model before and after training. You will see that **the model has learned to predict the matrix**. 

Of course it has, **this was an extremely easy task!** :)

In [None]:
plot_labels(labels).update_layout(title="Target matrix").show()
plot_labels(predictions).update_layout(title="Predictions without training").show()
plot_labels(step_predictions).update_layout(
    title=f"Predictions after {n_steps} training steps"
)

Summary and next steps
----------------------