In [1]:
from pathlib import Path

import numpy as np

# So that we can plot sisl geometries
import sisl.viz

from e3nn import o3

from graph2mat import (
    PointBasis,
    BasisTableWithEdges,
    BasisConfiguration,
    MatrixDataProcessor,
)

from graph2mat.bindings.torch import TorchBasisMatrixData
from graph2mat.bindings.e3nn import E3nnGraph2Mat

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


In [None]:
# Let's see how the MACE descriptor works: we want to convert our atomic types and positions
# into an equivariant matrix.

# === Define a basis by hand ===
# Our dataset contains Boron, Carbon and Nitrogen: three atom types, so three types of "points".
# - Types are atomic numbers so that they match `geometry.atoms.Z`, which is what we later pass
#   as `BasisConfiguration(point_types=...)`; string labels such as "B" would not match them.
# - The dataset comes from SIESTA (aiida.HSX files) and the basis read from those files uses the
#   "siesta_spherical" convention (see the next cell), so we use the same convention here.
# - Inspecting the dataset shows that every species has two s, two p and one d shell.
# (First draft of this cell came from DeepSeek; updated to match the dataset.)
SPECIES_IRREPS = "2x0e + 2x1o + 1x2e"
# NOTE(review): hand-picked cutoff radius. The HSX files report R = -1 (no range stored), so there
# is no dataset value to copy here — confirm the value and its units.
SPECIES_R = 5

boron_basis = PointBasis(5, R=SPECIES_R, basis=SPECIES_IRREPS, basis_convention="siesta_spherical")
nitrogen_basis = PointBasis(7, R=SPECIES_R, basis=SPECIES_IRREPS, basis_convention="siesta_spherical")
carbon_basis = PointBasis(6, R=SPECIES_R, basis=SPECIES_IRREPS, basis_convention="siesta_spherical")

basis = [boron_basis, nitrogen_basis, carbon_basis]

# The irreps of the node features that we will input into the model
node_feats_irreps = o3.Irreps("8x0e + 8x1o")

In [None]:
# Alternative: infer the basis from the sisl atoms stored in the dataset.
path = Path('../dataset/SHARE_OUTPUTS_2_ATOMS')

# Get all sample subdirectories. They are sorted because directory-listing order depends on the
# filesystem, and `children_path[i]` must point to the same sample on every machine.
children_path = sorted(child for child in path.iterdir() if child.is_dir())

i = 4  # index of the sample to inspect
file = sisl.get_sile(children_path[i] / "aiida.HSX")
geometry = file.read_geometry()
unique_atom_types = set(geometry.atoms.Z)

# One PointBasis per species: `geometry.atoms.atom` is the list of unique atoms, so a structure
# containing many atoms of the same element still yields a single basis for that element.
basis = [PointBasis.from_sisl_atom(atom) for atom in geometry.atoms.atom]
assert sorted(atom_basis.type for atom_basis in basis) == sorted(unique_atom_types), (
    "expected exactly one PointBasis per species"
)

for species_idx, atom_basis in enumerate(basis):
    print(
        f"\nBasis for species {species_idx}."
        f"\n\tAtom type: {atom_basis.type}"
        f"\n\tBasis: {atom_basis.basis}"
        f"\n\tBasis convention: {atom_basis.basis_convention}"
        f"\n\tR: {atom_basis.R}"
    )

# Now we need to define the shape of our inputs.
node_feats_irreps = o3.Irreps("0e + 1o + 2e")  # Let's use this as a first approx.
edge_feats_irreps = o3.Irreps("0e + 1o + 2e")  # Let's use this as a first approx.

# Initialize the model.
# NOTE(review): `symmetric=False` is combined with `blocks_symmetry='ij=ji'`. The H and S matrices
# stored in HSX files are presumably symmetric — confirm whether `symmetric=True` is intended.
model = E3nnGraph2Mat(
    unique_basis=basis,
    irreps=dict(node_feats_irreps=node_feats_irreps, edge_feats_irreps=edge_feats_irreps),
    symmetric=False,
    blocks_symmetry='ij=ji',  # symmetry imposed on the matrix blocks (shown in `model.summary`)
)


Basis for atom 0.
	Atom type: 6
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.]

Basis for atom 1.
	Atom type: 7
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.]



The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1

By re-running the cell above with `i = 0` and `i = 2` as well, we infer that:

Boron atoms have a basis of ``2x0e + 2x1o + 1x2e``.

Carbon atoms have a basis of ``2x0e + 2x1o + 1x2e``.

Nitrogen atoms have a basis of ``2x0e + 2x1o + 1x2e``.

All three species use the same basis: two $s$ shells, two $p$ shells and one $d$ shell, in the `siesta_spherical` convention.

In [40]:
# `model.summary` is a multi-line text report, so print it instead of showing its repr.
summary_text = model.summary
print(summary_text)

Preprocessing nodes: None
Preprocessing edges: None
Node operations:
 (6)  E3nnSimpleNodeBlock: (2x0e+2x1o+1x2e) x (2x0e+2x1o+1x2e) -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e
 (7)  E3nnSimpleNodeBlock: (2x0e+2x1o+1x2e) x (2x0e+2x1o+1x2e) -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e
Edge operations:
 (6, 6) [XY = YX.T] E3nnSimpleEdgeBlock: (2x0e+2x1o+1x2e) x (2x0e+2x1o+1x2e) -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e.
 (6, 7) E3nnSimpleEdgeBlock: (2x0e+2x1o+1x2e) x (2x0e+2x1o+1x2e) -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e.
 (7, 6) E3nnSimpleEdgeBlock: (2x0e+2x1o+1x2e) x (2x0e+2x1o+1x2e) -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e.
 (7, 7) [XY = YX.T] E3nnSimpleEdgeBlock: (2x0e+2x1o+1x2e) x (2x0e+2x1o+1x2e) -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e.


In [41]:
# Display the full PyTorch module tree of the model: the per-node-type blocks (self_interactions)
# and the per-edge-type blocks (interactions).
model

E3nnGraph2Mat(
  (self_interactions): ModuleList(
    (0-1): 2 x E3nnIrrepsMatrixBlock(
      (operation): E3nnSimpleNodeBlock(
        (tsq): TensorSquare(1x0e+1x1o+1x2e -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e | 56 paths | 56 weights)
      )
    )
  )
  (interactions): ModuleDict(
    ((0, 0, 0)): E3nnIrrepsMatrixBlock(
      (operation): E3nnSimpleEdgeBlock(
        (tensor_products): ModuleList(
          (0): FullyConnectedTensorProduct(1x0e+1x1o+1x2e x 1x0e+1x1o+1x2e -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e | 80 paths | 80 weights)
        )
      )
    )
    ((0, 1, 1)): E3nnIrrepsMatrixBlock(
      (operation): E3nnSimpleEdgeBlock(
        (tensor_products): ModuleList(
          (0): FullyConnectedTensorProduct(1x0e+1x1o+1x2e x 1x0e+1x1o+1x2e -> 7x0e+6x1o+1x1e+2x2o+6x2e+2x3o+1x4e | 80 paths | 80 weights)
        )
      )
    )
    ((1, 0, -1)): E3nnIrrepsMatrixBlock(
      (operation): E3nnSimpleEdgeBlock(
        (tensor_products): ModuleList(
          (0): FullyConnectedTensorP

In [52]:
# Describe the system for graph2mat: which species sits where, inside which (periodic) cell,
# and the basis associated with each species.
config = BasisConfiguration(
    point_types=geometry.atoms.Z,  # atomic numbers, matching the PointBasis types
    positions=geometry.xyz,
    cell=geometry.cell,
    pbc=geometry.pbc,
    basis=basis,
)
config

BasisConfiguration(point_types=array([6, 7], dtype=int32), positions=array([[0.        , 0.        , 3.33146675],
       [0.        , 0.        , 2.39449735]]), basis=[PointBasis(type=6, R=array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]), basis=((2, 0, 1), (2, 1, -1), (1, 2, 1)), basis_convention='siesta_spherical'), PointBasis(type=7, R=array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]), basis=((2, 0, 1), (2, 1, -1), (1, 2, 1)), basis_convention='siesta_spherical')], cell=array([[3.59999, 0.     , 0.     ],
       [0.     , 3.59999, 0.     ],
       [0.     , 0.     , 3.59999]]), pbc=array([ True,  True,  True]), matrix=None, weight=1.0, config_type='Default', metadata=None)

In [53]:
# Convert the configuration back into a sisl geometry to check visually that it was built correctly.
geometry = config.to_sisl_geometry()

# Size each atom by its maximum orbital range. The basis read from the HSX file has R = -1
# (no range stored), which would produce negative marker sizes, so use a fixed size for any
# non-positive range.
FALLBACK_ATOM_SIZE = 1.0  # NOTE(review): arbitrary display size, only used when R <= 0
max_ranges = np.asarray(geometry.maxR(all=True), dtype=float)
atom_sizes = np.where(max_ranges > 0, max_ranges, FALLBACK_ATOM_SIZE)

# NOTE: rendering the figure requires `nbformat>=4.2.0` installed in the kernel's environment.
geometry.plot(show_cell=False, atoms_style={"size": atom_sizes})

ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed