In [16]:
import random
import sisl
import torch
import torch.optim as optim
import numpy as np
import yaml

from e3nn import o3
from pathlib import Path
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch_geometric.loader import DataLoader
from mace.modules import MACE, RealAgnosticResidualInteractionBlock
from graph2mat.bindings.e3nn import E3nnGraph2Mat
from graph2mat.models import MatrixMACE
from graph2mat.bindings.torch import TorchBasisMatrixDataset
from graph2mat import (
    PointBasis,
    BasisTableWithEdges,
    BasisConfiguration,
    MatrixDataProcessor,
)

def load_config(path="../config.yaml"):
    with open(path, "r") as f:
        return yaml.safe_load(f)
    
def flatten(xss):
    return [x for xs in xss for x in xs]

In [8]:
def get_basis_from_structures_paths(paths, verbose=False, num_unique_z=None):
    """_summary_

    Args:
        paths (list[str]): _description_
        verbose (bool, optional): _description_. Defaults to False.
        num_unique_z (int, optional): If known, input the total number of different atom types here for faster performance.

    Returns:
        list[PointBasis]: _description_
    """
    if verbose:
        print("="*60)
        print("Basis computation.")
        print(f"Number of structures to look on: {len(paths)}")
        print("Looking for unique atoms in each structure...")

    unique_atom_types = []
    unique_atom_types_path_idx = []

    # Look for all atom types in your list of structures
    iterator = tqdm(enumerate(paths)) if verbose else enumerate(paths)
    for i, path in iterator:
        if num_unique_z is not None and len(unique_atom_types) == num_unique_z:
            print("Found enough basis points. Breaking the search...")
            break
        geometry = sisl.get_sile(path / "aiida.fdf").read_geometry()
        for z in geometry.atoms.Z:
            if z not in unique_atom_types:
                unique_atom_types.append(z)
                unique_atom_types_path_idx.append(i)
            if num_unique_z is not None and len(unique_atom_types) == num_unique_z:
                print("Found enough basis points. Breaking the search...")
                break
        # print("n_atoms= ", len(geometry.atoms.Z))

    if verbose:
        print(f"Found the following atomic numbers: {unique_atom_types}")
        print(f"Corresponding path indices: {unique_atom_types_path_idx}")

    # Build the basis
    basis = []
    unique_atom_types_basis = []
    for path_idx in unique_atom_types_path_idx:
        geometry = sisl.get_sile(paths[path_idx] / "aiida.fdf").read_geometry()
        for atom in geometry.atoms:
            if atom.Z not in unique_atom_types_basis:
                basis.append(PointBasis.from_sisl_atom(atom))
                unique_atom_types_basis.append(atom.Z)

    basis.sort(key=lambda x: x.type)
    unique_atom_types_basis.sort()

    if verbose:
        print(f"Basis with {len(basis)} elements built!")
        [print(f"\nBasis for atom {i}.\n\tAtom type: {basis[i].type}\n\tBasis: {basis[i].basis}\n\tBasis convention: {basis[i].basis_convention}\n\tR: {basis[i].R}") for i in range(len(basis))]

    return basis

In [None]:
# === Configuration load ===
directory = Path("../results/hamiltonian") # * Write here the directory where the model is stored
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device=torch.device("cpu")
filename = "model_epoch_4500.tar"
print(f"Loading model {directory / filename}...")


config = load_config(directory / "config.yaml")

# === List of paths to all structures ===
parent_path = Path('../dataset')
n_atoms_paths = list(parent_path.glob('*/'))
paths = []
for n_atoms_path in n_atoms_paths:
    structure_paths = list(n_atoms_path.glob('*/'))
    paths.append(structure_paths)
paths = flatten(paths)

random.seed(config["dataset"]["seed"])
random.shuffle(paths)



# == Basis creation === 
basis = get_basis_from_structures_paths(paths, verbose=False, num_unique_z=config["dataset"].get("num_unique_z", None))
table = BasisTableWithEdges(basis)



# === Dataset creation ===
processor = MatrixDataProcessor(basis_table=table, symmetric_matrix=True, sub_point_matrix=False)
embeddings_configs = []
for i, path in enumerate(paths):

    # We need to keep track of the training/val splits, so we can't plot more than used for training (at least for training dataset)
    if i==20:
        break
    
    print(f"Processing structure {i+1} of {len(paths)}...")

    # Load the structure config
    file = sisl.get_sile(path / "aiida.fdf")
    file_h = sisl.get_sile(path / "aiida.HSX")
    geometry = file.read_geometry()

    # Load the true hamiltonian
    true_h = file_h.read_hamiltonian()

    embeddings_config = BasisConfiguration.from_matrix(
        matrix = true_h,
        geometry = geometry,
        labels = True,
        metadata={
            "atom_types": torch.from_numpy(geometry.atoms.Z), # Unlike point_types, this is not rescaled.,
            "path": path
        },
    )

    embeddings_configs.append(embeddings_config)

dataset = TorchBasisMatrixDataset(embeddings_configs, data_processor=processor)

# Split and stratify
# n_atoms_list = [dataset[i].num_nodes for i in range(len(dataset))] if config["dataset"]["stratify"] == True else None
n_atoms_list = None
train_dataset, val_dataset = train_test_split(
    dataset, 
    train_size=config["dataset"]["train_split_ratio"],
    stratify=n_atoms_list,
    random_state=None # Dataset already shuffled (paths)
)

Loading model ../results/hamiltonian/model_epoch_4500.tar...
Found enough basis points. Breaking the search...
Found enough basis points. Breaking the search...
Processing structure 1 of 1311...
Processing structure 2 of 1311...
Processing structure 3 of 1311...


  warn(
  warn(
  warn(
  warn(
  warn(


Processing structure 4 of 1311...
Processing structure 5 of 1311...
Processing structure 6 of 1311...
Processing structure 7 of 1311...


  warn(
  warn(


Processing structure 8 of 1311...
Processing structure 9 of 1311...
Processing structure 10 of 1311...


  warn(
  warn(
  warn(
  warn(


Processing structure 11 of 1311...
Processing structure 12 of 1311...


  warn(
  warn(


Processing structure 13 of 1311...
Processing structure 14 of 1311...
Processing structure 15 of 1311...
Processing structure 16 of 1311...
Processing structure 17 of 1311...


  warn(
  warn(
  warn(
  warn(


Processing structure 18 of 1311...
Processing structure 19 of 1311...
Processing structure 20 of 1311...
../dataset/SHARE_OUTPUTS_8_ATOMS/907c-7487-46bd-8ad3-65b031f7f12b


  warn(
  warn(
  warn(


In [40]:
Path(dataset[0].metadata["path"])

PosixPath('../dataset/SHARE_OUTPUTS_8_ATOMS/907c-7487-46bd-8ad3-65b031f7f12b')