# Fitting matrices

**This notebook shows how you can fit your function to predict matrices for configurations**. We create the target matrices synthetically.

<div class="alert alert-info">

Note

This tutorial shows the rawest possible training workflow. Notice that we provide tools to easily perform training with ``pytorch_lightning`` (including a CLI). Look for "Lightning" in the API documentation to understand the tools that we provide, or go to the CLI tutorials to understand how train with the CLI.

</div>

Prerequisites
-------------

Before reading this notebook, **make sure you have read the [notebook on computing a matrix](<./Computing a matrix.ipynb>) and [the notebook on batching](./Batching.ipynb)**, which introduce the basic concepts of `graph2mat` that we are going to assume are already known. Also **we will use exactly the same setup as in the batching notebook**, with the only difference that we will compute add target matrices to each structure.

In this notebook we will:

- Introduce the **addition of a target matrix** to a configuration.
- **Introduce the metrics** that can be used as loss functions.
- **Introduce the simplest training loop**.

It is **specially useful if you are quite new to machine learning**, because it goes step by step. It also serves as a minimal example from which you can expand to create training flows different from the ones we propose.

In [1]:
import numpy as np
import pandas as pd
import torch

# To load plotly templates for sisl visualization
import sisl.viz

from e3nn import o3

from graph2mat import (
    BasisConfiguration,
    PointBasis,
    BasisTableWithEdges,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixDataset, TorchBasisMatrixData

from graph2mat.bindings.e3nn import E3nnGraph2Mat

from graph2mat.tools.viz import plot_basis_matrix

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


Setting up the model
--------------------

As usual, let's create our model:

In [2]:
# The basis
point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical")
point_2 = PointBasis("B", R=5, basis="2x0e + 1o", basis_convention="spherical")

basis = [point_1, point_2]

# The basis table.
table = BasisTableWithEdges(basis)

# The data processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)

positions = np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]])

# The shape of the node features.
node_feats_irreps = o3.Irreps("0e + 1o")


# The fake environment representation function that we will use
# to compute node features.
def get_environment_representation(data, irreps):
    """Function that mocks a true calculation of an environment representation.

    Computes a random array and then ensures that the numbers obey our particular
    system's symmetries.
    """
    import torch

    torch.manual_seed(0)

    node_features = irreps.randn(data.num_nodes, -1)
    # The point in the middle sees the same in -X and +X directions
    # therefore its representation must be 0.
    # In principle the +/- YZ are also equivalent, but let's say that there
    # is something breaking the symmetry to make the numbers more interesting.
    # Note that the spherical harmonics convention is YZX.
    node_features[1, 3] = 0
    # We make both A points have equivalent features except in the X direction,
    # where the features are opposite
    node_features[2::3, :3] = node_features[0::3, :3]
    node_features[2::3, 3] = -node_features[0::3, 3]
    return node_features


# The matrix readout function
model = E3nnGraph2Mat(
    unique_basis=basis,
    irreps=dict(node_feats_irreps=node_feats_irreps),
    symmetric=True,
)



Including target matrices in the data
-------------------------------------

We will now create our data. The difference between this notebook and the previous notebooks is that **each configuration will have an associated matrix**, which is what we will try to fit.

Usually, this matrix would be computed by the algorithm we are trying to substitute with ML (e.g. DFT for atomic systems) or experimental observations, but here we will just take random matrices.

We create a function to compute random symmetric matrices:

In [3]:
def true_matrix(size):
    """Mocks the algorithm that provides the training matrices.

    It just computes a random matrix
    """
    matrix = np.random.random((size, size)) * 2 - 1
    matrix += matrix.T
    return matrix

And then initialize the configurations as we have done in the previous notebooks, except that in this case we use the `matrix` argument to pass the matrix associated with the configuration:

In [4]:
positions = np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]])

config1 = BasisConfiguration(
    point_types=["A", "B", "A"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
    matrix=true_matrix(size=7),
)

config2 = BasisConfiguration(
    point_types=["B", "A", "B"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
    matrix=true_matrix(size=11),
)

configs = [config1, config2]

# Create the dataset
dataset = TorchBasisMatrixDataset(configs, data_processor=processor)

We can take one example from the dataset and check that it now has `point_labels` and `edge_labels`, which contain the values of the matrix organized in the same way that are returned by `Graph2Mat`:

In [5]:
data_example = dataset[0]
data_example.point_labels, data_example.edge_labels

(tensor([ 0.8538, -0.1838, -0.3273, -0.9057, -1.4387, -0.2657, -0.3273, -0.1767,
         -1.8107, -0.6650, -0.2220, -0.9057, -1.8107,  1.6950,  1.8348, -0.5548,
         -1.4387, -0.6650,  1.8348,  0.0188, -0.6430, -0.2657, -0.2220, -0.5548,
         -0.6430,  1.3328,  1.5008]),
 tensor([-3.5343e-01, -7.6467e-04, -8.2642e-01,  5.0351e-01, -6.1751e-02,
         -1.6126e+00, -6.3493e-01, -1.6887e-01, -1.2717e+00,  4.0559e-01]))

During training, we will compare these to the output of `Graph2Mat`.

We can also plot the target matrices from the data example:

In [6]:
def plot_matrices(data, predictions=None, title="", show=True):
    """Helper function to plot (possibly batched) matrices"""

    matrices = processor.matrix_from_data(data, predictions=predictions)

    if not isinstance(matrices, (tuple, list)):
        matrices = (matrices,)

    for i, (config, matrix) in enumerate(zip(configs, matrices)):
        if show is True or show == i:
            plot_basis_matrix(
                matrix,
                config,
                point_lines={"color": "black"},
                basis_lines={"color": "blue"},
                colorscale="temps",
                text=".2f",
                basis_labels=True,
            ).update_layout(title=f"{title} [{i}]").show()


plot_matrices(data_example, title="Labels")

The simplest training loop
--------------------------

Below we just create a simple `pytorch` training loop that:

1. Uses the model to **compute predictions** for the matrix
2. **Computes the loss** (error).
3. Computes the gradients and **updates the model parameters**.
4. **Goes back** to 1.

While doing so we store the errors at each step so that we can plot their evolution later.

There is just one last thing that we need to introduce: `graph2mat`'s metrics. The `metrics` module contains several functions that compare matrices in different ways. They can be used as loss functions. In this case, we will use `elementwise_mse`, which just computes the [Mean Squared Error](https://en.wikipedia.org/wiki/Mean_squared_error) of all the matrix elements.

In [13]:
# Create the data loader
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=2)

# Number of training steps
n_steps = 4000
# Initialize an optimizer
optimizer = torch.optim.Adam(model.parameters())

# Initialize arrays to store errors
losses = np.zeros(n_steps)
node_rmse = np.zeros(n_steps)
edge_rmse = np.zeros(n_steps)

# The loss function, which we get from graph2mat's metrics functions
from graph2mat import metrics

loss_fn = metrics.elementwise_mse

# Loop
for i in range(n_steps):
    for data in loader:
        # Reset gradients
        optimizer.zero_grad()

        # Get the node feats. Since this function is not learnable, it could be
        # outside the loop, but we keep it here to show how things could work
        # with a learnable environment representation.
        node_feats = get_environment_representation(data, node_feats_irreps)

        # Make predictions for this batch
        step_predictions = model(data, node_feats=node_feats)

        # Compute the loss
        loss, info = loss_fn(
            nodes_pred=step_predictions[0],
            nodes_ref=data.point_labels,
            edges_pred=step_predictions[1],
            edges_ref=data.edge_labels,
        )
        

        # Store errors
        losses[i] = loss
        node_rmse[i] = info["node_rmse"]
        edge_rmse[i] = info["edge_rmse"]

        # Compute gradients
        loss.backward()

        # Update weights
        optimizer.step()
        
    

node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
edge_loss=  torch.Size([])
node_loss=  torch.Size([])
e

KeyboardInterrupt: 

Checking results
----------------

After training, we store all the errors in a dataframe:

In [8]:
df = pd.DataFrame(
    np.array([losses, node_rmse, edge_rmse]).T,
    columns=["loss", "node_rmse", "edge_rmse"],
)

And plot them:

In [9]:
df.plot(backend="plotly").update_layout(
    yaxis_type="log", yaxis_showgrid=True, xaxis_showgrid=True
).update_layout(
    yaxis_title="Value",
    xaxis_title="Training step",
    title="Error evolution during training",
)

The model has learned something, but still the errors are quite high.

We can plot the first target matrix and the corresponding prediction:

In [10]:
plot_matrices(data, title=f"Target matrix", show=0)
plot_matrices(
    data,
    predictions={
        "node_labels": step_predictions[0],
        "edge_labels": step_predictions[1],
    },
    title=f"Prediction after {n_steps} training steps",
    show=0,
)

As you can see, the matrices are very different. That is, the model has no idea how to predict the matrices!

This could be shocking considering that it has only been tasked with fitting 2 matrices, a super simple problem that any model would overfit without any trouble. Well, you must take into account two things:

- **The target matrix is random**, while **the model is designed to learn equivariant matrices!**. All operations are equivariant and therefore result into an equivariant predicted matrix. For example, symmetry determines that the scalar element for node blocks for points 0 and 2 (at the top-left and bottom-right corner of the matrix) must be exactly the same because the point are equivalent. The random matrix does not satisfy this condition so it is impossible to fit.

- **The model is limited by the input node features**, which only contain one scalar and one vector. The combination possibilities are very small. If you increase the node feats irreps to `0e + 2x1o` (i.e. add one extra vector) and modify the `get_environment_representation` to still satisfy symmetries you should see some elements that have no symmetry problems (e.g. the 4 scalar elements at the top-left corner of node block for point 1) get very close to the target matrix.

We could work very hard to make our fake environment and true matrix computing functions equivariant to see the model fit perfectly, but you will see this in other real-life examples in the tutorials. Also **it is nice to see how a random matrix can't be fitted by an equivariant model to understand the power of equivariant design**!

# Testing

In [11]:
from plotly.subplots import make_subplots

def plot_error_matrices_interactive(true_matrix, predicted_matrix, config_matrix, matrix_label=None, figure_title=None, predicted_matrix_text=None, filepath=None, n_atoms=None, absolute_error_cbar_limit=None):
    """Interactive Plotly visualization of error matrices."""

    # === Error matrices computation ===
    absolute_error_matrix = true_matrix - predicted_matrix
    epsilon = 0.001
    # relative_error_matrix = absolute_error_matrix / (true_matrix + epsilon)*100
    relative_error_matrix = absolute_error_matrix

    # === Colorbar limits ===
    vmin = np.min([np.min(true_matrix), np.min(predicted_matrix)])
    vmax = np.max([np.max(true_matrix), np.max(predicted_matrix)])
    lim_data = max(abs(vmin), abs(vmax))

    if absolute_error_cbar_limit is None:
        lim_abs = np.max(np.abs(absolute_error_matrix))
    else:
        lim_abs = absolute_error_cbar_limit

    lim_rel = 100.0  # %

    cbar_limits = [lim_data, lim_data, lim_abs, lim_rel]

    # === Titles ===
    if matrix_label is None:
        matrix_label = ''
    titles = [
        "True " + matrix_label,
        "Predicted " + matrix_label,
        "Absolute error (A-B)",
        f"Relative error (A-B)/(A+{epsilon})"
    ]
    cbar_titles = ["eV", "eV", "eV", "%"]

    # === Figure ===
    # cbar_positions = [0.44, 1, 0.44, 1]
    matrices = [true_matrix, predicted_matrix, absolute_error_matrix, relative_error_matrix]

    fig = make_subplots(
        rows=4, cols=1,
        # subplot_titles=titles,
        # horizontal_spacing=0.15,
        vertical_spacing=0.17
    )

    for i, matrix in enumerate(matrices):
        # row = i // 2 + 1
        # col = i % 2 + 1

        heatmap = plot_basis_matrix(
            matrix,
            config_matrix,
            point_lines={"color": "black"},
            basis_lines={"color": "blue"},
            colorscale="temps",
            text=".2f",
            basis_labels=True,
        )
        fig.add_trace(heatmap, row=i, col=1)

    # === Subplot titles ===
    fig.update_layout(
        xaxis1=dict(side="top", title_text=titles[0]), yaxis1=dict(autorange="reversed"),
        xaxis2=dict(side="top", title_text=titles[1]), yaxis2=dict(autorange="reversed"),
        xaxis3=dict(side="top", title_text=titles[2]), yaxis3=dict(autorange="reversed"),
        xaxis4=dict(side="top", title_text=titles[3]), yaxis4=dict(autorange="reversed"),
        margin={"l":0,
                "r":0,
                "t":0,
                "b":0}
    )

    # # === Atomic orbitals blocks grid ===
    # if n_atoms is not None:
    #     n_orbitals = 13
    #     minor_ticks = np.arange(-0.5, n_orbitals * n_atoms, n_orbitals)

    #     for i, matrix in enumerate(matrices):
    #         row = i // 2 + 1
    #         col = i % 2 + 1  # Ensure shapes are added to the correct subplot

    #         grid_lines = [
    #             # Vertical grid lines
    #             dict(type="line", x0=x, x1=x, y0=-0.5, y1=n_orbitals * n_atoms - 0.5, line=dict(color="black", width=1))
    #             for x in minor_ticks
    #         ] + [
    #             # Horizontal grid lines
    #             dict(type="line", y0=y, y1=y, x0=-0.5, x1=n_orbitals * n_atoms - 0.5, line=dict(color="black", width=1))
    #             for y in minor_ticks
    #         ]

    #         # Add each grid line to the corresponding subplot
    #         for line in grid_lines:
    #             fig.add_shape(line, row=row, col=col)

    # === Text annotations ===

    # Text under predicted matrix
    if predicted_matrix_text:
        fig.add_annotation(
            text=predicted_matrix_text,
            xref='x2 domain', yref='y2 domain',
            x=1.1, y=-0.15,
            showarrow=False,
            font=dict(size=12),
            align='right'
        )

    # Absolute error stats
    max_absolute_error = np.max(absolute_error_matrix)
    min_absolute_error = np.min(absolute_error_matrix)
    fig.add_annotation(
        text=f"max = {max_absolute_error:.2f} eV, min = {min_absolute_error:.2f} eV",
        xref='x3 domain', yref='y3 domain',
        x=0.5, y=-0.07,
        showarrow=False,
        font=dict(size=12),
        align='center'
    )

    # Relative error stats
    max_relative_error = np.max(relative_error_matrix)
    min_relative_error = np.min(relative_error_matrix)
    fig.add_annotation(
        text=f"max = {max_relative_error:.2f}%, min = {min_relative_error:.2f}%",
        xref='x4 domain', yref='y4 domain',
        x=0.5, y=-0.07,
        showarrow=False,
        font=dict(size=12),
        align='center'
    )

    # === Layout of the whole figure ===
    fig.update_layout(
        height=1700,
        width=800,
        title_text=figure_title if figure_title else "Matrix Comparison and Errors",
        title_x=0.5,
        title_y=0.99,
        margin=dict(t=100, b=20)
    )

    # === Output ===
    if filepath:
        fig.write_html(filepath)
    else:
        fig.show()



predictions={
    "node_labels": step_predictions[0],
    "edge_labels": step_predictions[1],
}

data_example = dataset[0]
matrices = processor.matrix_from_data(data, predictions=predictions)
if not isinstance(matrices, (tuple, list)):
    matrices = (matrices,)

true_matrix = processor.matrix_from_data(data)[0]
predicted_matrix = matrices[0]
plot_error_matrices_interactive(true_matrix, predicted_matrix, configs[0])

ValueError: 
    Invalid element(s) received for the 'data' property of 
        Invalid elements include: [Figure({
    'data': [{'coloraxis': 'coloraxis',
              'hovertemplate': 'x: %{x}<br>y: %{y}<br>color: %{z}<extra></extra>',
              'name': '0',
              'texttemplate': '%{z:.2f}',
              'type': 'heatmap',
              'xaxis': 'x',
              'yaxis': 'y',
              'z': {'bdata': ('JJVaP2D0tL5cdEi6QZBTv9TlAD/t7n' ... '+ZiiK/RO0svnbFor/gqc8+0hjAPw=='),
                    'dtype': 'f4',
                    'shape': '7, 7'}}],
    'layout': {'coloraxis': {'cmid': 0,
                             'colorscale': [[0.0, 'rgb(0, 147, 146)'],
                                            [0.16666666666666666, 'rgb(57, 177,
                                            133)'], [0.3333333333333333, 'rgb(156,
                                            203, 134)'], [0.5, 'rgb(233, 226,
                                            156)'], [0.6666666666666666, 'rgb(238,
                                            180, 121)'], [0.8333333333333334,
                                            'rgb(232, 132, 113)'], [1.0, 'rgb(207,
                                            89, 126)']]},
               'margin': {'t': 60},
               'shapes': [{'line': {'color': 'black'},
                           'type': 'line',
                           'x0': 0,
                           'x1': 1,
                           'xref': 'x domain',
                           'y0': 0.5,
                           'y1': 0.5,
                           'yref': 'y'},
                          {'line': {'color': 'black'},
                           'type': 'line',
                           'x0': 0.5,
                           'x1': 0.5,
                           'xref': 'x',
                           'y0': 0,
                           'y1': 1,
                           'yref': 'y domain'},
                          {'line': {'color': 'black'},
                           'type': 'line',
                           'x0': 0,
                           'x1': 1,
                           'xref': 'x domain',
                           'y0': 5.5,
                           'y1': 5.5,
                           'yref': 'y'},
                          {'line': {'color': 'black'},
                           'type': 'line',
                           'x0': 5.5,
                           'x1': 5.5,
                           'xref': 'x',
                           'y0': 0,
                           'y1': 1,
                           'yref': 'y domain'},
                          {'line': {'color': 'blue', 'dash': 'dot'},
                           'type': 'line',
                           'x0': 0,
                           'x1': 1,
                           'xref': 'x domain',
                           'y0': 1.5,
                           'y1': 1.5,
                           'yref': 'y'},
                          {'line': {'color': 'blue', 'dash': 'dot'},
                           'type': 'line',
                           'x0': 1.5,
                           'x1': 1.5,
                           'xref': 'x',
                           'y0': 0,
                           'y1': 1,
                           'yref': 'y domain'},
                          {'line': {'color': 'blue', 'dash': 'dot'},
                           'type': 'line',
                           'x0': 0,
                           'x1': 1,
                           'xref': 'x domain',
                           'y0': 2.5,
                           'y1': 2.5,
                           'yref': 'y'},
                          {'line': {'color': 'blue', 'dash': 'dot'},
                           'type': 'line',
                           'x0': 2.5,
                           'x1': 2.5,
                           'xref': 'x',
                           'y0': 0,
                           'y1': 1,
                           'yref': 'y domain'}],
               'template': '...',
               'xaxis': {'anchor': 'y',
                         'constrain': 'domain',
                         'domain': [0.0, 1.0],
                         'scaleanchor': 'y',
                         'ticktext': [0: (0, 0), 1: (0, 0), 1: (0, 0), 1: (1, -1),
                                      1: (1, 0), 1: (1, 1), 2: (0, 0)],
                         'tickvals': {'bdata': 'AAECAwQFBg==', 'dtype': 'i1'}},
               'yaxis': {'anchor': 'x',
                         'autorange': 'reversed',
                         'constrain': 'domain',
                         'domain': [0.0, 1.0],
                         'ticktext': [0: (0, 0), 1: (0, 0), 1: (0, 0), 1: (1, -1),
                                      1: (1, 0), 1: (1, 1), 2: (0, 0)],
                         'tickvals': {'bdata': 'AAECAwQFBg==', 'dtype': 'i1'}}}
})]

    The 'data' property is a tuple of trace instances
    that may be specified as:
      - A list or tuple of trace instances
        (e.g. [Scatter(...), Bar(...)])
      - A single trace instance
        (e.g. Scatter(...), Bar(...), etc.)
      - A list or tuple of dicts of string/value properties where:
        - The 'type' property specifies the trace type
            One of: ['bar', 'barpolar', 'box', 'candlestick',
                     'carpet', 'choropleth', 'choroplethmap',
                     'choroplethmapbox', 'cone', 'contour',
                     'contourcarpet', 'densitymap',
                     'densitymapbox', 'funnel', 'funnelarea',
                     'heatmap', 'histogram', 'histogram2d',
                     'histogram2dcontour', 'icicle', 'image',
                     'indicator', 'isosurface', 'mesh3d', 'ohlc',
                     'parcats', 'parcoords', 'pie', 'sankey',
                     'scatter', 'scatter3d', 'scattercarpet',
                     'scattergeo', 'scattergl', 'scattermap',
                     'scattermapbox', 'scatterpolar',
                     'scatterpolargl', 'scattersmith',
                     'scatterternary', 'splom', 'streamtube',
                     'sunburst', 'surface', 'table', 'treemap',
                     'violin', 'volume', 'waterfall']

        - All remaining properties are passed to the constructor of
          the specified trace type

        (e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])

In [None]:
true_matrix

(<Compressed Sparse Row sparse array of dtype 'float32'
 	with 47 stored elements and shape (7, 7)>,
 <Compressed Sparse Row sparse array of dtype 'float32'
 	with 71 stored elements and shape (11, 11)>)

In [None]:
import scipy.sparse as sp
import numpy as np
from graph2mat.tools.viz.sparse_plot import plot_basis_matrix
from plotly.subplots import make_subplots

# Create dummy sparse matrices
matrices = [sp.random(10, 10, density=0.1, format='csr', random_state=i) for i in range(4)]

# Create a 4x1 subplot figure
fig = make_subplots(rows=4, cols=1, subplot_titles=[f"Matrix {i+1}" for i in range(4)])

# Plot each sparse matrix using plot_basis_matrix with Plotly
for i, mat in enumerate(matrices):
    fig_i = plot_basis_matrix(mat, return_plotly=True)  # Get plotly figure
    for trace in fig_i.data:
        fig.add_trace(trace, row=i+1, col=1)

fig.update_layout(height=1200, showlegend=False)
fig.show()



TypeError: plot_basis_matrix() got an unexpected keyword argument 'return_plotly'

In [None]:
matrix = processor.matrix_from_data(data)[0].todense()

# plot_basis_matrix(
#     matrix,
#     config1,
#     point_lines={"color": "black"},
#     basis_lines={"color": "blue"},
#     colorscale="temps",
#     text=".2f",
#     basis_labels=True,
# )

plot_error_matrices_interactive(matrix, matrix)

In [None]:
import plotly.graph_objects as go

def plot_error_matrices_interactive(true_matrix, predicted_matrix, matrix_label=None, figure_title=None, predicted_matrix_text=None, filepath=None, n_atoms=None, absolute_error_cbar_limit=None):
    """Interactive Plotly visualization of error matrices."""

    # === Error matrices computation ===
    absolute_error_matrix = true_matrix - predicted_matrix
    epsilon = 0.001
    relative_error_matrix = absolute_error_matrix / (true_matrix + epsilon)*100

    # === Colorbar limits ===
    vmin = np.min([np.min(true_matrix), np.min(predicted_matrix)])
    vmax = np.max([np.max(true_matrix), np.max(predicted_matrix)])
    lim_data = max(abs(vmin), abs(vmax))

    if absolute_error_cbar_limit is None:
        lim_abs = np.max(np.abs(absolute_error_matrix))
    else:
        lim_abs = absolute_error_cbar_limit

    lim_rel = 100.0  # %

    cbar_limits = [lim_data, lim_data, lim_abs, lim_rel]

    # === Titles ===
    if matrix_label is None:
        matrix_label = ''
    titles = [
        "True " + matrix_label,
        "Predicted " + matrix_label,
        "Absolute error (A-B)",
        f"Relative error (A-B)/(A+{epsilon})"
    ]
    cbar_titles = ["eV", "eV", "eV", "%"]

    # === Figure ===
    cbar_positions = [1, 1, 1, 1]
    matrices = [true_matrix, predicted_matrix, absolute_error_matrix, relative_error_matrix]

    fig = make_subplots(
        rows=4, cols=1,
        # subplot_titles=titles,
        # horizontal_spacing=0.15,
        vertical_spacing=0.17
    )

    for i, matrix in enumerate(matrices):
        row = i + 1
        col = 1

        heatmap = go.Heatmap(
            z=matrix,
            colorscale='RdYlBu',
            zmin=-cbar_limits[i],
            zmax=cbar_limits[i],
            colorbar=dict(title=cbar_titles[i], len=0.475, yanchor="middle", y=0.807 - 0.585*(row-1)),
            colorbar_x = cbar_positions[i]
        )
        fig.add_trace(heatmap, row=row, col=col)

    # === Subplot titles ===
    fig.update_layout(
        xaxis1=dict(side="top", title_text=titles[0]), yaxis1=dict(autorange="reversed"),
        xaxis2=dict(side="top", title_text=titles[1]), yaxis2=dict(autorange="reversed"),
        xaxis3=dict(side="top", title_text=titles[2]), yaxis3=dict(autorange="reversed"),
        xaxis4=dict(side="top", title_text=titles[3]), yaxis4=dict(autorange="reversed"),
        margin={"l":0,
                "r":0,
                "t":0,
                "b":0}
    )

    # === Atomic orbitals blocks grid ===
    if n_atoms is not None:
        n_orbitals = 13
        minor_ticks = np.arange(-0.5, n_orbitals * n_atoms, n_orbitals)

        for i, matrix in enumerate(matrices):
            row = i // 2 + 1
            col = i % 2 + 1  # Ensure shapes are added to the correct subplot

            grid_lines = [
                # Vertical grid lines
                dict(type="line", x0=x, x1=x, y0=-0.5, y1=n_orbitals * n_atoms - 0.5, line=dict(color="black", width=1))
                for x in minor_ticks
            ] + [
                # Horizontal grid lines
                dict(type="line", y0=y, y1=y, x0=-0.5, x1=n_orbitals * n_atoms - 0.5, line=dict(color="black", width=1))
                for y in minor_ticks
            ]

            # Add each grid line to the corresponding subplot
            for line in grid_lines:
                fig.add_shape(line, row=row, col=col)

    # === Text annotations ===

    # Text under predicted matrix
    if predicted_matrix_text:
        fig.add_annotation(
            text=predicted_matrix_text,
            xref='x2 domain', yref='y2 domain',
            x=1.1, y=-0.15,
            showarrow=False,
            font=dict(size=12),
            align='right'
        )

    # Absolute error stats
    max_absolute_error = np.max(absolute_error_matrix)
    min_absolute_error = np.min(absolute_error_matrix)
    fig.add_annotation(
        text=f"max = {max_absolute_error:.2f} eV, min = {min_absolute_error:.2f} eV",
        xref='x3 domain', yref='y3 domain',
        x=0.5, y=-0.07,
        showarrow=False,
        font=dict(size=12),
        align='center'
    )

    # Relative error stats
    max_relative_error = np.max(relative_error_matrix)
    min_relative_error = np.min(relative_error_matrix)
    fig.add_annotation(
        text=f"max = {max_relative_error:.2f}%, min = {min_relative_error:.2f}%",
        xref='x4 domain', yref='y4 domain',
        x=0.5, y=-0.07,
        showarrow=False,
        font=dict(size=12),
        align='center'
    )

    # === Layout of the whole figure ===
    fig.update_layout(
        height=850,
        width=800,
        title_text=figure_title if figure_title else "Matrix Comparison and Errors",
        title_x=0.5,
        title_y=0.99,
        margin=dict(t=100, b=20)
    )

    # === Output ===
    if filepath:
        fig.write_html(filepath)
    else:
        fig.show()