In [18]:
import numpy as np
import torch
import yaml
import importlib

# So that we can plot sisl geometries
import sisl.viz

from e3nn import o3
from pathlib import Path

from graph2mat import (
    PointBasis,
    BasisTableWithEdges,
    BasisConfiguration,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixData, TorchBasisMatrixDataset
from graph2mat.bindings.e3nn import E3nnGraph2Mat

def load_config(path="../config.yaml"):
    with open(path, "r") as f:
        return yaml.safe_load(f)
    
def flatten(xss):
    return [x for xs in xss for x in xs]

from mace.modules import RadialEmbeddingBlock, EquivariantProductBasisBlock
from mace.modules.utils import get_edge_vectors_and_lengths

class EmbeddingBase(torch.nn.Module):
    def __init__(self, config, orbitals):
        super(EmbeddingBase, self).__init__()

        embeddings_config = config["model"]["embedding"]
        self.device = config["device"]

        # Define the irreducible representations for the node attributes and features.
        node_attr_irreps = o3.Irreps([(embeddings_config["num_elements"], (0, 1))]) # E.g. [(10, (0,1))]
        hidden_irreps = o3.Irreps(embeddings_config["hidden_irreps"]) # E.g. "8x0e+8x1o"
        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) # Counting how many Irrep(0, 1) there are inside hidden_irreps.

        # Linear transformation from node attributes to node features.
        # / I think this is the same as torch.nn.Linear
        self.node_embedding = o3.Linear(
            node_attr_irreps,
            node_feats_irreps,
            shared_weights=True,
            internal_weights=True,
        )

        # Radial embedding block using Bessel functions and polynomial cutoffs.
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=embeddings_config["r_max"],
            num_bessel=embeddings_config["num_bessel"],
            num_polynomial_cutoff=embeddings_config["num_polynomial_cutoff"],
            radial_type=embeddings_config["radial_type"],
            distance_transform=embeddings_config["distance_transform"],
        )

        # Angular embedding using spherical harmonics.
        sh_irreps = o3.Irreps.spherical_harmonics(embeddings_config["max_ell"])
        self.angular_embedding = o3.SphericalHarmonics(sh_irreps, normalize=True, normalization="component")

        # Element encoding configuration
        self.orbitals = orbitals
        self.nr_bit = embeddings_config["nr_bits"]

    def forward(self, data):

        # Encode atomic numbers into binary orbital-based representation. 
        atom_types = data.metadata['atom_types'].unsqueeze(1) # The reshape is for format reasons
        one_hot_z = z_one_hot(atom_types, orbitals=self.orbitals, nr_bits=self.nr_bit).to(self.device)

        # Input node descriptors.
        node_feats = one_hot_z

        # Calculate edge vectors and their lengths (distances).
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data.positions,
            edge_index=data.edge_index,
            shifts=data.shifts,
        )

        # Apply node embedding.
        node_feats = self.node_embedding(node_feats)
        

        # Apply radial and angular embeddings for edges.
        radial_embedding = self.radial_embedding(
            lengths,
            node_feats,
            data.edge_index,
            atom_types
        )
        angular_embedding = self.angular_embedding(vectors)

        # Bundle the embeddings.
        embedding_collection = {
            "nodes": {
                "one_hot": one_hot_z,
                "node_features": node_feats,
            },
            "edges": {
                "radial_embedding": radial_embedding,
                "angular_embedding": angular_embedding,
            }
        }

        return embedding_collection
    

class MACEDescriptor(torch.nn.Module):
    def __init__(self, atomic_descriptors_config):
        super(MACEDescriptor, self).__init__()

        # --- Irreps definitions ---
        node_attr_irreps = o3.Irreps([(atomic_descriptors_config["num_elements"], (0, 1))])  # One-hot per element (scalar-even)

        # Extract number of scalar-even irreps from hidden_irreps
        hidden_irreps = o3.Irreps(atomic_descriptors_config["hidden_irreps"])
        num_scalar_irreps = hidden_irreps.count(o3.Irrep(0, 1))
        node_feats_irreps = o3.Irreps([(num_scalar_irreps, (0, 1))])

        sh_irreps = o3.Irreps.spherical_harmonics(atomic_descriptors_config["max_ell"])  # Angular features

        radial_out_dim = atomic_descriptors_config["radial_embedding.out_dim"]
        edge_feats_irreps = o3.Irreps(f"{radial_out_dim}x0e")  # Radial embeddings as scalar-even

        hidden_irreps_out = hidden_irreps  # Output IRs remain same

        # Determine output irreps of interaction (spherical harmonics ⊗ scalar features)
        interaction_irreps = (sh_irreps * num_scalar_irreps).sort()[0].simplify()

        # Support for correlation order per layer
        if isinstance(atomic_descriptors_config["correlation"], int):
            correlation = [atomic_descriptors_config["correlation"]] * atomic_descriptors_config["num_interactions"]

        # --- First Interaction Layer ---
        interaction_cls_first = get_object_from_module(atomic_descriptors_config["interaction_cls_first"], "mace.modules")
        first_interaction = interaction_cls_first(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=atomic_descriptors_config["avg_num_neighbors"],
            radial_MLP=atomic_descriptors_config["radial_mlp"],
            cueq_config=None,
        )

        self.interactions = torch.nn.ModuleList([first_interaction])

        # Determine whether to use self-connection (important for residual-based models)
        use_sc_first = "Residual" in str(atomic_descriptors_config["interaction_cls_first"])

        first_product = EquivariantProductBasisBlock(
            node_feats_irreps=first_interaction.target_irreps,
            target_irreps=hidden_irreps,
            correlation=correlation[0],
            num_elements=atomic_descriptors_config["num_elements"],
            use_sc=use_sc_first,
            cueq_config=None,
        )

        self.products = torch.nn.ModuleList([first_product])

        # --- Remaining Interaction-Product Blocks ---
        for i in range(atomic_descriptors_config["num_interactions"] - 1):
            interaction_cls = get_object_from_module(atomic_descriptors_config["interaction_cls"], "mace.modules")
            interaction = interaction_cls(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps_out,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=atomic_descriptors_config["avg_num_neighbors"],
                radial_MLP=atomic_descriptors_config["radial_mlp"],
                cueq_config=None,
            )

            product = EquivariantProductBasisBlock(
                node_feats_irreps=interaction_irreps,
                target_irreps=hidden_irreps_out,
                correlation=correlation[i + 1],
                num_elements=atomic_descriptors_config["num_elements"],
                use_sc=True,
                cueq_config=None,
            )

            self.interactions.append(interaction)
            self.products.append(product)

    def forward(self, embeddings, edge_index):
        """_summary_

        Args:
            data (_type_): Already preprocessed data.

        Returns:
            _type_: _description_
        """

        node_feats = embeddings["nodes"]["node_features"]

        node_feats_list = []
        for interaction, product in zip(self.interactions, self.products):
            node_feats, sc = interaction(
                node_attrs=embeddings["nodes"]["one_hot"],
                node_feats=node_feats,
                edge_attrs=embeddings["edges"]["angular_embedding"],
                edge_feats=embeddings["edges"]["radial_embedding"],
                edge_index=edge_index,
            )

            node_feats = product(
                node_feats=node_feats,
                sc=sc,
                node_attrs=embeddings["nodes"]["one_hot"],
            )

            node_feats_list.append(node_feats)

        # Concatenate features from all interaction layers
        node_feats_out = torch.cat(node_feats_list, dim=-1)

        # Final descriptor
        descriptors = {
            "nodes": {
                "node_env": node_feats_out,
            }
        }

        return descriptors


    
def z_one_hot(z, orbitals, nr_bits):
    """
    Generate one-hot encodings from a list of single-value tensors.

    Args:
        z (list of torch.Tensor): A list of single-value tensors, e.g., [[2], [3], [4], [2], [2], ...].
        orbitals (dict): A dictionary mapping numbers to their corresponding values.
        nr_bits (int): The number of bits for one-hot encoding.

    Returns:
        torch.Tensor: A tensor containing the one-hot encodings.
    """

    # Extract values from the list of single-value tensors
    node_map={}
    k=0
    for key in orbitals.keys():
        node_map[key]=k
        k+=1

    indices = [tensor.item() for tensor in z]

    # Create an empty tensor for one-hot encoding
    one_hot = torch.zeros(len(indices), nr_bits)

    # Fill in the one-hot encoding based on the indices
    for i, idx in enumerate(indices):
        if idx in orbitals:  # Ensure the index exists in orbitals
            one_hot[i, int(node_map[idx])] = 1  # Set the corresponding bit to 1
        else:
            raise ValueError(f"Index {idx} not found in orbitals.")

    return one_hot

def get_object_from_module(class_name, module="mace.modules"):
    try:
        return getattr(importlib.import_module(module), class_name)
    except AttributeError:
        return None  # Or raise an error if you prefer

In [19]:
# === List of paths to all structures ===
parent_path = Path('../dataset')
n_atoms_paths = list(parent_path.glob('*/'))
paths = []
for n_atoms_path in n_atoms_paths:
    structure_paths = list(n_atoms_path.glob('*/'))
    paths.append(structure_paths)
paths = flatten(paths)


# === Config init ===
config_path = Path("./training_loop_config.yaml")
config = load_config(config_path)
model_config = config["model"]
orbitals = config['orbitals']
device = torch.device(config["device"] if (torch.cuda.is_available() and config["device"]!="cpu") 
else 'cpu')
device = 'cpu'


# == Basis === 
basis = []
unique_atom_types = []

# We need a basis that contains all atom types in our dataset. To do so, we will examine just one sample of 64 atoms, because there already are atoms of all types.
for i, path in enumerate(paths):
    if 'SHARE_OUTPUTS_64_ATOMS' in str(path):
        file = sisl.get_sile(path / "aiida.fdf")
        geometry = file.read_geometry()
        for atom in geometry.atoms:
            # Boron, Carbon, Nitrogen
            if (atom.Z == 5 or atom.Z == 6 or atom.Z == 7) and atom.Z not in unique_atom_types:
                basis.append(PointBasis.from_sisl_atom(atom))
                unique_atom_types.append(atom.Z)
            if len(unique_atom_types) == 3:
                break
        if len(unique_atom_types) == 3:
                print("Finished building the basis. Breaking...")
                break

basis.sort(key=lambda x: x.type)
unique_atom_types.sort()

# Check that there are 3 atoms in the basis.
if len(basis) != 3:
    raise ValueError("There are not three elements in the basis")

print("Basis built!")
[print(f"\nBasis for atom {i}.\n\tAtom type: {basis[i].type}\n\tBasis: {basis[i].basis}\n\tBasis convention: {basis[i].basis_convention}\n\tR: {basis[i].R}") for i in range(len(basis))]


# === Basis table === 
table = BasisTableWithEdges(basis)


# === Data processor ===
# Initialize the processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False # ! In princpiple no out_matrix
)

# === Shape of our inputs ===
hidden_irreps = o3.Irreps(model_config["atomic_descriptors"]["hidden_irreps"])
num_interactions = model_config["atomic_descriptors"]["num_interactions"]
final_irreps = o3.Irreps([(mul, ir) for mul, ir in hidden_irreps] * num_interactions)

# === The matrix readout function ===
model = E3nnGraph2Mat(
    unique_basis=basis,
    irreps=dict(node_feats_irreps=final_irreps, edge_feats_irreps=final_irreps),
    symmetric=True,
    blocks_symmetry='ij=ji', # This is the symmetry of the matrix we want to compute
)

# === Initialization of the enviroment representation ===
get_embeddings = EmbeddingBase(config, orbitals)

model_config["atomic_descriptors"]["radial_embedding.out_dim"] = get_embeddings.radial_embedding.out_dim
model_config["atomic_descriptors"]["angular_embedding.out_dim"] = sum(2 * l + 1 for l in range(model_config["embedding"]["max_ell"] + 1))
get_enviroment_representation = MACEDescriptor(model_config["atomic_descriptors"])

# === Batching ===
# from graph2mat.core.data.configuration import OrbitalConfiguration
from graph2mat import OrbitalConfiguration
embeddings_configs = []
for i, path in enumerate(paths):
    if i==2:
        break

    # Load the structure config
    file = sisl.get_sile(path / "aiida.fdf")
    file_h = sisl.get_sile(path / "aiida.HSX")
    geometry = file.read_geometry()

    # Load the true hamiltonian
    true_h = file_h.read_hamiltonian()

    # ! This only works for square matrices.
    # embeddings_config = BasisConfiguration(
    #     point_types=geometry.atoms.Z,
    #     positions=geometry.xyz,
    #     basis=basis,
    #     cell=geometry.cell,
    #     pbc=geometry.pbc,
    #     matrix=true_h,
        # metadata={
        #     "device": device,
        #     "atom_types": torch.from_numpy(geometry.atoms.Z), # Unlike point_types, this is not rescaled.
        # }
    # )

    # * This works in the last github version!
    embeddings_config = BasisConfiguration.from_matrix(
        matrix = true_h,
        geometry = geometry,
        labels = True,
        metadata={
            "device": device,
            "atom_types": torch.from_numpy(geometry.atoms.Z), # Unlike point_types, this is not rescaled.
        },
    )

    # ! This works but doesn't allow plotting (gives problems when plotting):
    # embeddings_config = OrbitalConfiguration.from_matrix(
    #     true_h,
    #     # geometry=geometry,
    #     # labels=true_h
    #     )

    embeddings_configs.append(embeddings_config)

# === Dataset creation ===
dataset = TorchBasisMatrixDataset(embeddings_configs, data_processor=processor)

Finished building the basis. Breaking...
Basis built!

Basis for atom 0.
	Atom type: 5
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [3.02420918 2.02341372 3.73961942 3.73961942 3.73961942 2.51253945
 2.51253945 2.51253945 3.73961942 3.73961942 3.73961942 3.73961942
 3.73961942]

Basis for atom 1.
	Atom type: 6
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.57112067 1.67068795 3.16682115 3.16682115 3.16682115 2.05819653
 2.05819653 2.05819653 3.16682115 3.16682115 3.16682115 3.16682115
 3.16682115]

Basis for atom 2.
	Atom type: 7
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.25704422 1.4271749  2.78012609 2.78012609 2.78012609 1.75309697
 1.75309697 1.75309697 2.78012609 2.78012609 2.78012609 2.78012609
 2.78012609]



The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1

In [34]:
from graph2mat.tools.viz import plot_basis_matrix

data_example = dataset[0]

def plot_matrices(data, predictions=None, title="", show=True):
    """Helper function to plot (possibly batched) matrices"""

    matrices = processor.matrix_from_data(data, predictions=predictions)

    if not isinstance(matrices, (tuple, list)):
        matrices = (matrices,)

    for i, (config, matrix) in enumerate(zip(embeddings_configs, matrices)):
        # if i == 0:
        #     continue
        if show is True or show == i:
            plot_basis_matrix(
                matrix,
                config,
                point_lines={"color": "black"},
                basis_lines={"color": "blue"},
                colorscale="temps",
                text=".2f",
                basis_labels=True,
            ).update_layout(title=f"{title} [{i}]").show()

        # break


plot_matrices(data_example, title="Labels")

In [21]:
# The simplest training loop

# Create the data loader
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=2)

# Number of training steps
n_steps = 3
# Initialize an optimizer
optimizer = torch.optim.Adam(model.parameters())

# Initialize arrays to store errors
losses = np.zeros(n_steps)
node_rmse = np.zeros(n_steps)
edge_rmse = np.zeros(n_steps)

# The loss function, which we get from graph2mat's metrics functions
from graph2mat import metrics
from tqdm import tqdm

loss_fn = metrics.elementwise_mse # This computes the elementwise squared diff and computes the mean between all the elements. Then it computes another mean for all nodes and edges.
# ! I don't unierstand why this last mean.

# Loop
for i in tqdm(range(n_steps)):
    for data in loader:
        # Reset gradients
        optimizer.zero_grad()
        data = data.to(device)
        get_embeddings = get_embeddings.to(device)

        # Get the node feats. Since this function is not learnable, it could be outside the loop, but we keep it here to show how things could work with a learnable environment representation.
        # Get enviroment representation.
        embeddings = get_embeddings(data)
        node_inputs = get_enviroment_representation(embeddings, data.edge_index)["nodes"]["node_env"]

        # Make predictions for this batch
        step_predictions = model(data, node_feats=node_inputs)

        # Compute the loss
        loss, info = loss_fn(
            nodes_pred=step_predictions[0],
            nodes_ref=data.point_labels,
            edges_pred=step_predictions[1],
            edges_ref=data.edge_labels,
        )

        # Store errors
        losses[i] = loss
        node_rmse[i] = info["node_rmse"]
        edge_rmse[i] = info["edge_rmse"]

        # Compute gradients
        loss.backward()

        # Update weights
        optimizer.step()

100%|██████████| 3/3 [00:00<00:00,  3.65it/s]


In [22]:
import pandas as pd

df = pd.DataFrame(
    np.array([losses, node_rmse, edge_rmse]).T,
    columns=["loss", "node_rmse", "edge_rmse"],
)

df.plot(backend="plotly").update_layout(
    yaxis_type="log", yaxis_showgrid=True, xaxis_showgrid=True
).update_layout(
    yaxis_title="Value",
    xaxis_title="Training step",
    title="Error evolution during training",
)

In [33]:
plot_matrices(data, title=f"Target matrix", show=0)
predictions={
    "node_labels": step_predictions[0],
    "edge_labels": step_predictions[1],
}
plot_matrices(
    data,
    predictions=predictions,
    title=f"Prediction after {n_steps} training steps",
    show=0,
)
matrices = processor.matrix_from_data(data, predictions=predictions)

# Testing

In [26]:
from plotly.subplots import make_subplots

def plot_error_matrices_interactive(true_matrix, predicted_matrix, config_matrix, matrix_label=None, figure_title=None, predicted_matrix_text=None, filepath=None, n_atoms=None, absolute_error_cbar_limit=None):
    """Interactive Plotly visualization of error matrices."""

    # === Error matrices computation ===
    absolute_error_matrix = true_matrix - predicted_matrix
    epsilon = 0.001
    relative_error_matrix = absolute_error_matrix / (true_matrix + epsilon)*100

    # === Colorbar limits ===
    vmin = np.min([np.min(true_matrix), np.min(predicted_matrix)])
    vmax = np.max([np.max(true_matrix), np.max(predicted_matrix)])
    lim_data = max(abs(vmin), abs(vmax))

    if absolute_error_cbar_limit is None:
        lim_abs = np.max(np.abs(absolute_error_matrix))
    else:
        lim_abs = absolute_error_cbar_limit

    lim_rel = 100.0  # %

    cbar_limits = [lim_data, lim_data, lim_abs, lim_rel]

    # === Titles ===
    if matrix_label is None:
        matrix_label = ''
    titles = [
        "True " + matrix_label,
        "Predicted " + matrix_label,
        "Absolute error (A-B)",
        f"Relative error (A-B)/(A+{epsilon})"
    ]
    cbar_titles = ["eV", "eV", "eV", "%"]

    # === Figure ===
    # cbar_positions = [0.44, 1, 0.44, 1]
    matrices = [true_matrix, predicted_matrix, absolute_error_matrix, relative_error_matrix]

    fig = make_subplots(
        rows=4, cols=1,
        # subplot_titles=titles,
        # horizontal_spacing=0.15,
        vertical_spacing=0.17
    )

    for i, matrix in enumerate(matrices):
        # row = i // 2 + 1
        # col = i % 2 + 1

        heatmap = plot_basis_matrix(
            matrix,
            config_matrix,
            point_lines={"color": "black"},
            basis_lines={"color": "blue"},
            colorscale="temps",
            text=".2f",
            basis_labels=True,
        )
        fig.add_trace(heatmap, row=row, col=1)

    # === Subplot titles ===
    fig.update_layout(
        xaxis1=dict(side="top", title_text=titles[0]), yaxis1=dict(autorange="reversed"),
        xaxis2=dict(side="top", title_text=titles[1]), yaxis2=dict(autorange="reversed"),
        xaxis3=dict(side="top", title_text=titles[2]), yaxis3=dict(autorange="reversed"),
        xaxis4=dict(side="top", title_text=titles[3]), yaxis4=dict(autorange="reversed"),
        margin={"l":0,
                "r":0,
                "t":0,
                "b":0}
    )

    # # === Atomic orbitals blocks grid ===
    # if n_atoms is not None:
    #     n_orbitals = 13
    #     minor_ticks = np.arange(-0.5, n_orbitals * n_atoms, n_orbitals)

    #     for i, matrix in enumerate(matrices):
    #         row = i // 2 + 1
    #         col = i % 2 + 1  # Ensure shapes are added to the correct subplot

    #         grid_lines = [
    #             # Vertical grid lines
    #             dict(type="line", x0=x, x1=x, y0=-0.5, y1=n_orbitals * n_atoms - 0.5, line=dict(color="black", width=1))
    #             for x in minor_ticks
    #         ] + [
    #             # Horizontal grid lines
    #             dict(type="line", y0=y, y1=y, x0=-0.5, x1=n_orbitals * n_atoms - 0.5, line=dict(color="black", width=1))
    #             for y in minor_ticks
    #         ]

    #         # Add each grid line to the corresponding subplot
    #         for line in grid_lines:
    #             fig.add_shape(line, row=row, col=col)

    # === Text annotations ===

    # Text under predicted matrix
    if predicted_matrix_text:
        fig.add_annotation(
            text=predicted_matrix_text,
            xref='x2 domain', yref='y2 domain',
            x=1.1, y=-0.15,
            showarrow=False,
            font=dict(size=12),
            align='right'
        )

    # Absolute error stats
    max_absolute_error = np.max(absolute_error_matrix)
    min_absolute_error = np.min(absolute_error_matrix)
    fig.add_annotation(
        text=f"max = {max_absolute_error:.2f} eV, min = {min_absolute_error:.2f} eV",
        xref='x3 domain', yref='y3 domain',
        x=0.5, y=-0.07,
        showarrow=False,
        font=dict(size=12),
        align='center'
    )

    # Relative error stats
    max_relative_error = np.max(relative_error_matrix)
    min_relative_error = np.min(relative_error_matrix)
    fig.add_annotation(
        text=f"max = {max_relative_error:.2f}%, min = {min_relative_error:.2f}%",
        xref='x4 domain', yref='y4 domain',
        x=0.5, y=-0.07,
        showarrow=False,
        font=dict(size=12),
        align='center'
    )

    # === Layout of the whole figure ===
    fig.update_layout(
        height=1700,
        width=800,
        title_text=figure_title if figure_title else "Matrix Comparison and Errors",
        title_x=0.5,
        title_y=0.99,
        margin=dict(t=100, b=20)
    )

    # === Output ===
    if filepath:
        fig.write_html(filepath)
    else:
        fig.show()



predictions={
        "node_labels": step_predictions[0],
        "edge_labels": step_predictions[1],
    },

data_example = dataset[0]
matrices = processor.matrix_from_data(data_example, predictions)
if not isinstance(matrices, (tuple, list)):
    matrices = (matrices,)


plot_error_matrices_interactive(data_example,)

TypeError: tuple indices must be integers or slices, not str

In [27]:
data_example

TorchBasisMatrixData(
  edge_index=[2, 98],
  num_nodes=2,
  neigh_isc=[98],
  n_edges=98,
  positions=[2, 3],
  shifts=[98, 3],
  cell=[3, 3],
  nsc=[1, 3],
  node_attrs=[2, 3],
  point_labels=[338],
  edge_labels=[8281],
  point_types=[2],
  edge_types=[98],
  metadata={
    device='cpu',
    atom_types=[2],
    data_processor=MatrixDataProcessor(basis_table=BasisTableWithEdges(siesta_spherical, basis=[PointBasis(type=5, R=array([3.02420918, 2.02341372, 3.73961942, 3.73961942, 3.73961942,
       2.51253945, 2.51253945, 2.51253945, 3.73961942, 3.73961942,
       3.73961942, 3.73961942, 3.73961942]), basis=((2, 0, 1), (2, 1, -1), (1, 2, 1)), basis_convention='siesta_spherical'), PointBasis(type=6, R=array([2.57112067, 1.67068795, 3.16682115, 3.16682115, 3.16682115,
       2.05819653, 2.05819653, 2.05819653, 3.16682115, 3.16682115,
       3.16682115, 3.16682115, 3.16682115]), basis=((2, 0, 1), (2, 1, -1), (1, 2, 1)), basis_convention='siesta_spherical'), PointBasis(type=7, R=array([2.25

In [24]:
from sklearn.model_selection import train_test_split

train_dataset, val_dataset = train_test_split(
    np.arange(len(dataset)),  # Indices
    test_size=0.2,           # Validation ratio
    stratify=n_atoms_list,         # Ensures stratification
    random_state=42          # For reproducibility
)

NameError: name 'n_atoms_list' is not defined

In [None]:
n_atoms_list = [dataset[i].num_nodes for i in range(len(dataset))]
n_atoms_list

[2, 2]

In [None]:
train_dataset

array([1])