# One structure

In [1]:
import numpy as np
import torch
import yaml
import importlib

# So that we can plot sisl geometries
import sisl.viz

from e3nn import o3
from pathlib import Path

from graph2mat import (
    PointBasis,
    BasisTableWithEdges,
    BasisConfiguration,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixData
from graph2mat.bindings.e3nn import E3nnGraph2Mat

def load_config(path="../config.yaml"):
    """Read a YAML configuration file and return its parsed content.

    Args:
        path: Location of the YAML file. Defaults to ``"../config.yaml"``.

    Returns:
        The object produced by ``yaml.safe_load`` for that file. The cells
        below index it with string keys, so a mapping is expected.
    """
    with open(path, "r") as f:
        return yaml.safe_load(f)

# NOTE: a stray statement used to follow this function:
#     _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))
# It was indented at top level (an IndentationError for the whole cell), used
# `os` (not imported in this notebook) and `__file__` (not defined in a notebook
# kernel), and none of the three names is referenced by the visible cells, so
# it was removed.


In [2]:
path = Path('../dataset/SHARE_OUTPUTS_2_ATOMS')
config = load_config()
model_config = config["model"]
orbitals = config['orbitals']

# Everything in this notebook runs on the CPU, whatever the config says.
# (Author's note: CUDA is believed to be needed only for training.)
device = 'cpu'

# One sub-directory per sample. Sorted and restricted to directories so that
# "sample i" refers to the same structure on every run and on every platform.
children_path = sorted(child for child in path.glob('*/') if child.is_dir())
if not children_path:
    raise FileNotFoundError(f"No sample directories found in {path.resolve()}")

# The basis must be the same for all samples, so it is collected by scanning
# samples until every wanted element has been seen once.
TARGET_ATOMIC_NUMBERS = (5, 6, 7)  # Boron, Carbon, Nitrogen

basis = []
unique_atom_types = []
for child in children_path:
    file = sisl.get_sile(child / "aiida.fdf")
    geometry = file.read_geometry()
    for atom in geometry.atoms:
        if atom.Z in TARGET_ATOMIC_NUMBERS and atom.Z not in unique_atom_types:
            basis.append(PointBasis.from_sisl_atom(atom))
            unique_atom_types.append(atom.Z)
        if len(unique_atom_types) == len(TARGET_ATOMIC_NUMBERS):
            break
    if len(unique_atom_types) == len(TARGET_ATOMIC_NUMBERS):
        print("Finished building the basis. Breaking...")
        break

if not basis:
    raise ValueError(
        f"None of the atomic numbers {TARGET_ATOMIC_NUMBERS} were found in the samples under {path}"
    )

# Keep the basis list and the list of atomic numbers in the same (ascending) order.
basis.sort(key=lambda x: x.type)
unique_atom_types.sort()
print("Basis built!")

# * Change this to pick another sample.
i = 0
file = sisl.get_sile(children_path[i] / "aiida.fdf")
geometry = file.read_geometry()

for basis_index, point_basis in enumerate(basis):
    print(f"\nBasis for atom {basis_index}.\n\tAtom type: {point_basis.type}\n\tBasis: {point_basis.basis}\n\tBasis convention: {point_basis.basis_convention}\n\tR: {point_basis.R}")

# Now we need to define the shape of our inputs: the descriptor concatenates the
# node features of every interaction layer, so the input irreps are the hidden
# irreps repeated once per interaction.
hidden_irreps = o3.Irreps(model_config["atomic_descriptors"]["hidden_irreps"])
num_interactions = model_config["atomic_descriptors"]["num_interactions"]
final_irreps = o3.Irreps([(mul, ir) for mul, ir in hidden_irreps] * num_interactions)

# Initialize the model
model = E3nnGraph2Mat(
    unique_basis=basis,
    irreps=dict(node_feats_irreps=final_irreps, edge_feats_irreps=final_irreps),
    # NOTE(review): symmetric=False was reported as problematic by the author;
    # it may also be solvable by passing symmetric=True through the node operation.
    symmetric=True,
    blocks_symmetry='ij=ji',  # This is the symmetry of the matrix we want to compute
).to(device)

# Config
embeddings_config = BasisConfiguration(
    point_types=geometry.atoms.Z,
    positions=geometry.xyz,
    basis=basis,
    cell=geometry.cell,
    pbc=geometry.pbc,
    metadata={
        "device": device,
        "atom_types": torch.from_numpy(geometry.atoms.Z).to(device),  # Unlike point_types, this is not rescaled.
    }
)

# Create the basis table.
table = BasisTableWithEdges(basis)

# Initialize the processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False, out_matrix=None
)

data = TorchBasisMatrixData.from_config(embeddings_config, processor).to(device)


Finished building the basis. Breaking...
Basis built!

Basis for atom 0.
	Atom type: 5
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [3.02420918 2.02341372 3.73961942 3.73961942 3.73961942 2.51253945
 2.51253945 2.51253945 3.73961942 3.73961942 3.73961942 3.73961942
 3.73961942]

Basis for atom 1.
	Atom type: 6
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.57112067 1.67068795 3.16682115 3.16682115 3.16682115 2.05819653
 2.05819653 2.05819653 3.16682115 3.16682115 3.16682115 3.16682115
 3.16682115]

Basis for atom 2.
	Atom type: 7
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.25704422 1.4271749  2.78012609 2.78012609 2.78012609 1.75309697
 1.75309697 1.75309697 2.78012609 2.78012609 2.78012609 2.78012609
 2.78012609]




In [3]:
from mace.modules import RadialEmbeddingBlock, EquivariantProductBasisBlock
from mace.modules.utils import get_edge_vectors_and_lengths

class EmbeddingBase(torch.nn.Module):
    """Build the initial node and edge embeddings for one structure.

    The module produces, from a graph ``data`` object:

    * a one-hot encoding of the atom types and its linear node embedding,
    * a radial embedding of the edge lengths,
    * a spherical-harmonics (angular) embedding of the edge vectors.

    Args:
        config: Full configuration mapping. ``config["model"]["embedding"]``
            provides ``num_elements``, ``hidden_irreps``, ``r_max``,
            ``num_bessel``, ``num_polynomial_cutoff``, ``radial_type``,
            ``distance_transform``, ``max_ell`` and ``nr_bits``;
            ``config["device"]`` is stored on ``self.device``.
        orbitals: Mapping whose keys are the atomic numbers to encode; the
            position of a key in the mapping is its one-hot column.
    """

    def __init__(self, config, orbitals):
        super().__init__()

        embeddings_config = config["model"]["embedding"]
        # Kept for backward compatibility. forward() places tensors on the device
        # of this module's parameters and only falls back to this value.
        self.device = config["device"]

        # Define the irreducible representations for the node attributes and features.
        node_attr_irreps = o3.Irreps([(embeddings_config["num_elements"], (0, 1))])  # E.g. [(10, (0,1))]
        hidden_irreps = o3.Irreps(embeddings_config["hidden_irreps"])  # E.g. "8x0e+8x1o"
        # As many scalar-even (0e) channels as hidden_irreps contains.
        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])

        # Linear transformation from node attributes to node features.
        # Both sides only contain scalar (0e) irreps.
        # NOTE(review): the author believes this is equivalent to torch.nn.Linear — unconfirmed.
        self.node_embedding = o3.Linear(
            node_attr_irreps,
            node_feats_irreps,
            shared_weights=True,
            internal_weights=True,
        )

        # Radial embedding block using Bessel functions and polynomial cutoffs.
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=embeddings_config["r_max"],
            num_bessel=embeddings_config["num_bessel"],
            num_polynomial_cutoff=embeddings_config["num_polynomial_cutoff"],
            radial_type=embeddings_config["radial_type"],
            distance_transform=embeddings_config["distance_transform"],
        )

        # Angular embedding using spherical harmonics.
        sh_irreps = o3.Irreps.spherical_harmonics(embeddings_config["max_ell"])
        self.angular_embedding = o3.SphericalHarmonics(sh_irreps, normalize=True, normalization="component")

        # Element encoding configuration
        self.orbitals = orbitals
        self.nr_bit = embeddings_config["nr_bits"]

    def forward(self, data):
        """Compute node and edge embeddings for ``data``.

        Args:
            data: Graph object exposing ``metadata['atom_types']``,
                ``positions``, ``edge_index`` and ``shifts``.

        Returns:
            dict: ``{"nodes": {"one_hot", "node_features"},
            "edges": {"radial_embedding", "angular_embedding"}}``.
        """
        # Encode atomic numbers as one-hot rows; z_one_hot expects one value per row.
        atom_types = data.metadata['atom_types'].unsqueeze(1)

        # The one-hot tensor is created on the CPU and fed to self.node_embedding,
        # so it has to live where this module's weights live. Using the configured
        # device here broke whenever the module was moved elsewhere with .to(...).
        reference_param = next(self.parameters(), None)
        target_device = reference_param.device if reference_param is not None else self.device
        one_hot_z = z_one_hot(atom_types, orbitals=self.orbitals, nr_bits=self.nr_bit).to(target_device)

        # Calculate edge vectors and their lengths (distances).
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data.positions,
            edge_index=data.edge_index,
            shifts=data.shifts,
        )

        # Apply node embedding to the input node descriptors.
        node_feats = self.node_embedding(one_hot_z)

        # Apply radial and angular embeddings for edges.
        radial_embedding = self.radial_embedding(
            lengths,
            node_feats,
            data.edge_index,
            atom_types
        )
        angular_embedding = self.angular_embedding(vectors)

        # Bundle the embeddings.
        embedding_collection = {
            "nodes": {
                "one_hot": one_hot_z,
                "node_features": node_feats,
            },
            "edges": {
                "radial_embedding": radial_embedding,
                "angular_embedding": angular_embedding,
            }
        }

        return embedding_collection
    

class MACEDescriptor(torch.nn.Module):
    """Stack of MACE interaction and product blocks producing node descriptors.

    Args:
        atomic_descriptors_config: Mapping with the keys ``num_elements``,
            ``hidden_irreps``, ``max_ell``, ``radial_embedding.out_dim``,
            ``correlation`` (a single int or one value per interaction),
            ``num_interactions``, ``interaction_cls_first``,
            ``interaction_cls``, ``avg_num_neighbors`` and ``radial_mlp``.

    Raises:
        ValueError: If ``correlation`` has fewer entries than
            ``num_interactions`` or an interaction class name cannot be found
            in ``mace.modules``.
    """

    def __init__(self, atomic_descriptors_config):
        super().__init__()

        num_elements = atomic_descriptors_config["num_elements"]
        num_interactions = atomic_descriptors_config["num_interactions"]

        # --- Irreps definitions ---
        node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])  # One-hot per element (scalar-even)

        # Extract number of scalar-even irreps from hidden_irreps
        hidden_irreps = o3.Irreps(atomic_descriptors_config["hidden_irreps"])
        num_scalar_irreps = hidden_irreps.count(o3.Irrep(0, 1))
        node_feats_irreps = o3.Irreps([(num_scalar_irreps, (0, 1))])

        sh_irreps = o3.Irreps.spherical_harmonics(atomic_descriptors_config["max_ell"])  # Angular features

        radial_out_dim = atomic_descriptors_config["radial_embedding.out_dim"]
        edge_feats_irreps = o3.Irreps(f"{radial_out_dim}x0e")  # Radial embeddings as scalar-even

        # Determine output irreps of interaction (spherical harmonics ⊗ scalar features)
        interaction_irreps = (sh_irreps * num_scalar_irreps).sort()[0].simplify()

        # Correlation order per layer: a single int is broadcast to every layer,
        # anything else is taken as an explicit per-layer sequence.
        correlation = atomic_descriptors_config["correlation"]
        if isinstance(correlation, int):
            correlation = [correlation] * num_interactions
        else:
            correlation = list(correlation)
        if len(correlation) < num_interactions:
            raise ValueError(
                f"'correlation' has {len(correlation)} entries but "
                f"{num_interactions} interaction layers were requested."
            )

        # --- First Interaction Layer ---
        interaction_cls_first = self._resolve_interaction_cls(
            atomic_descriptors_config["interaction_cls_first"]
        )
        first_interaction = interaction_cls_first(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=atomic_descriptors_config["avg_num_neighbors"],
            radial_MLP=atomic_descriptors_config["radial_mlp"],
            cueq_config=None,
        )

        self.interactions = torch.nn.ModuleList([first_interaction])

        # The self-connection is only enabled for the first product block when
        # the first interaction class name contains "Residual".
        use_sc_first = "Residual" in str(atomic_descriptors_config["interaction_cls_first"])

        first_product = EquivariantProductBasisBlock(
            node_feats_irreps=first_interaction.target_irreps,
            target_irreps=hidden_irreps,
            correlation=correlation[0],
            num_elements=num_elements,
            use_sc=use_sc_first,
            cueq_config=None,
        )

        self.products = torch.nn.ModuleList([first_product])

        # --- Remaining Interaction-Product Blocks ---
        # These take the hidden irreps as input, since they consume the output of
        # the previous product block.
        for layer in range(1, num_interactions):
            interaction_cls = self._resolve_interaction_cls(
                atomic_descriptors_config["interaction_cls"]
            )
            interaction = interaction_cls(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps,
                avg_num_neighbors=atomic_descriptors_config["avg_num_neighbors"],
                radial_MLP=atomic_descriptors_config["radial_mlp"],
                cueq_config=None,
            )

            product = EquivariantProductBasisBlock(
                node_feats_irreps=interaction_irreps,
                target_irreps=hidden_irreps,
                correlation=correlation[layer],
                num_elements=num_elements,
                use_sc=True,
                cueq_config=None,
            )

            self.interactions.append(interaction)
            self.products.append(product)

    @staticmethod
    def _resolve_interaction_cls(class_name):
        """Return the ``mace.modules`` attribute called ``class_name`` or raise ValueError."""
        interaction_cls = get_object_from_module(class_name, "mace.modules")
        if interaction_cls is None:
            raise ValueError(f"Interaction class {class_name!r} was not found in 'mace.modules'.")
        return interaction_cls

    def forward(self, embeddings, edge_index):
        """Run every interaction/product pair and concatenate their outputs.

        Args:
            embeddings (dict): Nested dict with ``["nodes"]["one_hot"]``,
                ``["nodes"]["node_features"]``,
                ``["edges"]["angular_embedding"]`` and
                ``["edges"]["radial_embedding"]``.
            edge_index: Edge index forwarded unchanged to each interaction block.

        Returns:
            dict: ``{"nodes": {"node_env": tensor}}`` where the tensor is the
            concatenation, along the last dimension, of the node features
            produced after each product block.
        """

        node_feats = embeddings["nodes"]["node_features"]

        node_feats_list = []
        for interaction, product in zip(self.interactions, self.products):
            node_feats, sc = interaction(
                node_attrs=embeddings["nodes"]["one_hot"],
                node_feats=node_feats,
                edge_attrs=embeddings["edges"]["angular_embedding"],
                edge_feats=embeddings["edges"]["radial_embedding"],
                edge_index=edge_index,
            )

            node_feats = product(
                node_feats=node_feats,
                sc=sc,
                node_attrs=embeddings["nodes"]["one_hot"],
            )

            node_feats_list.append(node_feats)

        # Concatenate features from all interaction layers
        node_feats_out = torch.cat(node_feats_list, dim=-1)

        # Final descriptor
        descriptors = {
            "nodes": {
                "node_env": node_feats_out,
            }
        }

        return descriptors


    
def z_one_hot(z, orbitals, nr_bits):
    """
    Build a one-hot matrix for a sequence of single-value tensors.

    The column assigned to a value is the position of that value among the
    keys of ``orbitals`` (in the mapping's iteration order).

    Args:
        z (iterable of torch.Tensor): Single-value tensors, e.g. the rows of a
            tensor shaped like [[2], [3], [4], [2], ...].
        orbitals (dict): Mapping whose keys are the values that may be encoded.
        nr_bits (int): Width (number of columns) of the one-hot encoding.

    Returns:
        torch.Tensor: Tensor of zeros and ones with one row per entry of ``z``
        and ``nr_bits`` columns.

    Raises:
        ValueError: If an entry of ``z`` is not a key of ``orbitals``.
    """
    # Column of each known value, following the order of the orbitals keys.
    column_of = {species: column for column, species in enumerate(orbitals.keys())}

    # Unwrap every single-value tensor into a plain Python number.
    species_per_row = [entry.item() for entry in z]

    encoding = torch.zeros(len(species_per_row), nr_bits)
    for row, species in enumerate(species_per_row):
        if species not in orbitals:
            raise ValueError(f"Index {species} not found in orbitals.")
        encoding[row, int(column_of[species])] = 1

    return encoding

def get_object_from_module(class_name, module="mace.modules"):
    """Fetch the attribute ``class_name`` from the importable module ``module``.

    Args:
        class_name (str): Name of the attribute to look up.
        module (str): Dotted name of the module to import.

    Returns:
        The attribute, or ``None`` when an ``AttributeError`` is raised while
        importing the module or reading the attribute.
    """
    try:
        source_module = importlib.import_module(module)
        resolved = getattr(source_module, class_name)
    except AttributeError:
        # Deliberately lenient: a missing name is reported as None.
        resolved = None
    return resolved

# === G2M's tutorial random enviroment descriptor. Just to Troubleshooting. ===
def get_random_environment_representation(data, irreps):
    """Mock an environment representation with random, symmetry-adjusted values.

    A random array with one row per node is drawn from ``irreps`` and then a
    few entries are overwritten so the numbers respect the symmetries of the
    tutorial system this helper was written for.

    Args:
        data: Object exposing ``num_nodes``.
        irreps: Object exposing ``randn(num_nodes, -1)``.

    Returns:
        The random node-feature array after the symmetry adjustments.
    """
    features = irreps.randn(data.num_nodes, -1)

    # The point in the middle sees the same in the -X and +X directions, so its
    # representation along X must vanish. The +/- YZ directions are in principle
    # equivalent too, but they are left untouched (as if something broke that
    # symmetry) to keep the numbers interesting.
    # Note that the spherical harmonics convention is YZX.
    features[1, 3] = 0

    # Both A points share the same features, except along X where they are opposite.
    first_point = features[0]
    features[-1, :3] = first_point[:3]
    features[-1, 3] = -first_point[3]

    return features

cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [4]:
# Build the embedding module first: the descriptor is sized from its outputs.
get_embeddings = EmbeddingBase(config, orbitals).to(device)

# Record the embedding output sizes in the descriptor's config section.
# MACEDescriptor reads "radial_embedding.out_dim" from it.
model_config["atomic_descriptors"].update(
    {
        "radial_embedding.out_dim": get_embeddings.radial_embedding.out_dim,
        "angular_embedding.out_dim": sum(
            2 * ell + 1 for ell in range(model_config["embedding"]["max_ell"] + 1)
        ),
    }
)
get_enviroment_representation = MACEDescriptor(model_config["atomic_descriptors"]).to(device)



In [5]:
# Embed the structure and turn the embeddings into per-node environment descriptors.
embeddings = get_embeddings(data)
descriptors = get_enviroment_representation(embeddings, data.edge_index)
node_inputs = descriptors["nodes"]["node_env"]

# Forward pass through the Graph2Mat model.
node_labels, edge_labels = model(data, node_feats=node_inputs)

# Post processing: assemble the predicted labels into a matrix.
predictions = {"node_labels": node_labels, "edge_labels": edge_labels}
matrix = processor.matrix_from_data(data, predictions=predictions)

# TESTS

In [6]:
# Report the size of the flat node and edge label vectors returned by the model.
for caption, labels in (("NODE LABELS: ", node_labels), ("EDGE LABELS: ", edge_labels)):
    print(caption, labels.shape)

NODE LABELS:  torch.Size([338])
EDGE LABELS:  torch.Size([8281])


In [7]:
from graph2mat.tools.viz import plot_basis_matrix

# Styling options for the matrix plot, kept together so they are easy to tweak.
plot_style = {
    "point_lines": {"color": "black"},
    "basis_lines": {"color": "blue"},
    "colorscale": "rdylbu",
    "text": ".3f",
    "basis_labels": True,
}

plot_basis_matrix(matrix, embeddings_config, **plot_style)