# One structure

In [123]:
import numpy as np

# So that we can plot sisl geometries
import sisl.viz

from e3nn import o3

from graph2mat import (
    PointBasis,
    BasisTableWithEdges,
    BasisConfiguration,
    MatrixDataProcessor,
)

from graph2mat.bindings.torch import TorchBasisMatrixData
from graph2mat.bindings.e3nn import E3nnGraph2Mat

# Alternative: infere the basis from the sisl atom.
from pathlib import Path

In [None]:
path= Path('../dataset/SHARE_OUTPUTS_3_ATOMS')

# Get all subdirectories
children_path = list(path.glob('*/'))

i=1
file = sisl.get_sile(children_path[i] / "aiida.fdf")
geometry = file.read_geometry()

basis = []
unique_atom_types = []
for atom in geometry.atoms:
    if atom.Z not in unique_atom_types:
        basis.append(PointBasis.from_sisl_atom(atom))
        unique_atom_types.append(atom.Z)

# atom_symbols = []
# for atom in geometry.atoms:
#     atom_symbols.append(atom.symbol)

[print(f"\nBasis for atom {i}.\n\tAtom type: {basis[i].type}\n\tBasis: {basis[i].basis}\n\tBasis convention: {basis[i].basis_convention}\n\tR: {basis[i].R}") for i in range(len(basis))]

# Now we need to define the shape of our inputs.
node_feats_irreps = o3.Irreps("0e + 1o + 2e") # Let's use this as a first approx.
edge_feats_irreps = o3.Irreps("0e + 1o + 2e") # Let's use this as a first approx.

# Initialize the model
model = E3nnGraph2Mat(
    unique_basis=basis,
    irreps=dict(node_feats_irreps=node_feats_irreps, edge_feats_irreps=edge_feats_irreps),
    symmetric=False,
    blocks_symmetry='ij=ji', # This is the symmetry of the matrix we want to compute
)

# Config
config = BasisConfiguration(
    point_types=geometry.atoms.Z,
    positions=geometry.xyz,
    basis=basis,
    cell=geometry.cell,
    pbc=geometry.pbc,
)

# geometry = config.to_sisl_geometry()

# Create the basis table.
table = BasisTableWithEdges(basis)

# Initialize the processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=False, sub_point_matrix=False
)

data = TorchBasisMatrixData.from_config(config, processor)



Basis for atom 0.
	Atom type: 5
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [3.02420918 2.02341372 3.73961942 3.73961942 3.73961942 2.51253945
 2.51253945 2.51253945 3.73961942 3.73961942 3.73961942 3.73961942
 3.73961942]

Basis for atom 1.
	Atom type: 6
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.57112067 1.67068795 3.16682115 3.16682115 3.16682115 2.05819653
 2.05819653 2.05819653 3.16682115 3.16682115 3.16682115 3.16682115
 3.16682115]

Basis for atom 2.
	Atom type: 7
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.25704422 1.4271749  2.78012609 2.78012609 2.78012609 1.75309697
 1.75309697 1.75309697 2.78012609 2.78012609 2.78012609 2.78012609
 2.78012609]



The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.


The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1

TorchBasisMatrixData(
  edge_index=[2, 224],
  num_nodes=3,
  neigh_isc=[224],
  n_edges=224,
  positions=[3, 3],
  shifts=[224, 3],
  cell=[3, 3],
  nsc=[1, 3],
  node_attrs=[3, 3],
  point_types=[3],
  edge_types=[224],
  metadata={ data_processor=MatrixDataProcessor(basis_table=BasisTableWithEdges(siesta_spherical, basis=[PointBasis(type=5, R=array([3.02420918, 2.02341372, 3.73961942, 3.73961942, 3.73961942,
       2.51253945, 2.51253945, 2.51253945, 3.73961942, 3.73961942,
       3.73961942, 3.73961942, 3.73961942]), basis=((2, 0, 1), (2, 1, -1), (1, 2, 1)), basis_convention='siesta_spherical'), PointBasis(type=6, R=array([2.57112067, 1.67068795, 3.16682115, 3.16682115, 3.16682115,
       2.05819653, 2.05819653, 2.05819653, 3.16682115, 3.16682115,
       3.16682115, 3.16682115, 3.16682115]), basis=((2, 0, 1), (2, 1, -1), (1, 2, 1)), basis_convention='siesta_spherical'), PointBasis(type=7, R=array([2.25704422, 1.4271749 , 2.78012609, 2.78012609, 2.78012609,
       1.75309697, 1.7530

## GO WITH THE ENVIROMENT REPRESENTATION!!!

# TESTS

In [135]:
for atom in geometry.atoms:
    print(atom)

Atom{N, Z: 7, mass(au): 14.00700, maxR: 2.78013,
 AtomicOrbital{2sZ1, q0: 2.0, SphericalOrbital{l: 0, R: 2.2570442171902836, q0: 2.0}},
 AtomicOrbital{2sZ2, q0: 0.0, SphericalOrbital{l: 0, R: 1.4271748989309714, q0: 0.0}},
 AtomicOrbital{2pyZ1, q0: 1.0, SphericalOrbital{l: 1, R: 2.780126091859161, q0: 3.0}},
 AtomicOrbital{2pzZ1, q0: 1.0, SphericalOrbital{l: 1, R: 2.780126091859161, q0: 3.0}},
 AtomicOrbital{2pxZ1, q0: 1.0, SphericalOrbital{l: 1, R: 2.780126091859161, q0: 3.0}},
 AtomicOrbital{2pyZ2, q0: 0.0, SphericalOrbital{l: 1, R: 1.7530969713169737, q0: 0.0}},
 AtomicOrbital{2pzZ2, q0: 0.0, SphericalOrbital{l: 1, R: 1.7530969713169737, q0: 0.0}},
 AtomicOrbital{2pxZ2, q0: 0.0, SphericalOrbital{l: 1, R: 1.7530969713169737, q0: 0.0}},
 AtomicOrbital{3dxyZ1P, q0: 0.0, SphericalOrbital{l: 2, R: 2.780126091859161, q0: 0.0}},
 AtomicOrbital{3dyzZ1P, q0: 0.0, SphericalOrbital{l: 2, R: 2.780126091859161, q0: 0.0}},
 AtomicOrbital{3dz2Z1P, q0: 0.0, SphericalOrbital{l: 2, R: 2.7801260918591

In [None]:
from tqdm import tqdm

# TEMPORAL: Find any edge
path= Path('../dataset')

# Get all subdirectories (atom types)
n_atoms_paths = list(path.glob('*/'))

# Then all structures
structure_paths = [list(n_atoms_paths[i].glob('*/')) for i in range(len(n_atoms_paths))]
structure_paths = [item for sublist in structure_paths for item in sublist]

nonzero_edges = 0
for structrue_path in tqdm(structure_paths):
    file = sisl.get_sile(structrue_path / "aiida.HSX")
    geometry = file.read_geometry()

    basis = []
    unique_atom_types = []
    for atom in geometry.atoms:
        if atom.Z not in unique_atom_types:
            basis.append(PointBasis.from_sisl_atom(atom))
            unique_atom_types.append(atom.Z)

    config = BasisConfiguration(
        point_types=geometry.atoms.Z,
        positions=geometry.xyz,
        basis=basis,
        cell=geometry.cell,
        pbc=geometry.pbc,
    )

    table = BasisTableWithEdges(basis)

    processor = MatrixDataProcessor(
        basis_table=table, symmetric_matrix=False, sub_point_matrix=False
    )

    data = TorchBasisMatrixData.from_config(config, processor)

    if data.num_edges != 0:
        nonzero_edges += 1
        break

print(len(structure_paths))
print(nonzero_edges)

  0%|          | 0/1311 [00:00<?, ?it/s]

100%|██████████| 1311/1311 [01:41<00:00, 12.88it/s]

1311
0





In [125]:
table

Index,Type,Irreps,Max R
0,6,2x0e + 2x1o + 1x2e,-1.0


In [142]:
print(geometry.atoms.Z)
geometry.plot(show_cell=False, atoms_style={"size": geometry.maxR(all=True)})

[5 6 7]
