# Functions

In [1]:
import numpy as np
import numpy.linalg as LA
import torch

from tblite.interface import Calculator
import dxtb
from dxtb import OutputHandler
from dxtb.config import ConfigCache
from dxtb.typing import DD

possible_elements = [1, 6, 7, 8, 9]


def _get_permutation_map_closedshell_convention(element_numbers_batch, calculator="tblite"):
    """
    Create permutation maps to reorder orbital matrices from calculator conventions
    to OrbNet-Equi's convention.

    Parameters:
    - element_numbers_batch: List or array of atomic numbers for each molecule in the batch.
    - calculator: str, either "tblite" or "dxtb".

    Returns:
    - List of numpy arrays containing permutation indices for each molecule.
    """
    perm_maps = []
    for element_numbers in element_numbers_batch:
        perm_map = []
        idx = 0
        for el in element_numbers:
            if el == 1:
                perm_map.extend([idx, idx + 1])
                idx += 2
            else:
                if calculator == "tblite":
                    # [s, px, py, pz] -> [s, pz, py, px]
                    perm_map.extend([idx, idx + 3, idx + 2, idx + 1])
                elif calculator == "dxtb":
                    # [s, pz, px, py] -> [s, pz, py, px]
                    perm_map.extend([idx, idx + 1, idx + 3, idx + 2])
                idx += 4
        perm_maps.append(perm_map)
    return np.array(perm_maps, dtype=np.int64)


def apply_perm_map(array, perm_maps, pos):
    """
    Apply a permutation map to a matrix or a batch of matrices.

    Parameters:
    - array: np.ndarray, shape (n, n), (batch_size, n, n), (batch_size, n, n, nat, 3).
    - perm_maps: List of np.ndarray, permutation indices for each element in the batch.
    - pos: np.ndarray, atomic positions (currently unused but could be used for future extensions).

    Returns:
    - Permuted array or batch of permuted array.
    """
    if len(perm_maps) == 1:  # Non-batched case
        p = perm_maps[0]
        if array.ndim == 2:  # Features (nao, nao)
            return array[p, :][:, p]
        elif array.ndim == 3 and array.shape[0] == 1:  # Non-batched with batch dimension (1, nao, nao)
            return array[:, p, :][:, :, p]
        elif array.ndim == 4:  # Gradients (nao, nao, nat, 3)
            return array[p, :, :, :][:, p, :, :]
        elif array.ndim == 5 and array.shape[0] == 1:  # Gradients with batch dimension (1, nao, nao, nat, 3)
            return array[:, p, :, :, :][:, :, p, :, :]
        else:
            raise ValueError(f"Unsupported array shape {array.shape}.")
    else:  # Batched case with multiple permutation maps (same length)
        if array.ndim == 3:  # Features (batch_size, nao, nao)
            permuted_rows = np.take_along_axis(array, perm_maps[:, :, None], axis=1)
            return np.take_along_axis(permuted_rows, perm_maps[:, None, :], axis=2)
        if array.ndim == 5:  # Gradients (batch_size, nao, nao, nat, 3)
            permuted_rows = np.take_along_axis(array, perm_maps[:, :, None, None, None], axis=1)
            return np.take_along_axis(permuted_rows, perm_maps[:, None, :, None, None], axis=2)
        else:
            raise ValueError(f"Unsupported array shape {array.shape}.")
  

def generate_xtb_matrices_fpsh(
        calculator="tblite",
        element_numbers=None,
        coordinates=None,
        spin_pol=False,
        cutoff=None,
        *args, **kwargs):
    """
    Generate Fock (F), density (P), overlap (S), and Hamiltonian (H) matrices
    using either the 'tblite' or 'dxtb' calculator.

    Supports batch processing with the 'dxtb' calculator.

    Parameters:
    - calculator: str, either "tblite" or "dxtb" (default: "tblite").
    - element_numbers: list or np.ndarray, atomic numbers.
        Shape: (n_atoms,) for single, (batch_size, n_atoms) for batch.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
        Shape: (n_atoms, 3) for single, (batch_size, n_atoms, 3) for batch.
    - spin_pol: bool, whether to perform spin-polarized calculations (default: False).
    - cutoff: float, threshold below which matrix elements are set to zero (default: None).
    - *args, **kwargs: Additional arguments passed to the calculators.

    Returns:
    - res_dict: dict containing matrices and optionally energy and forces.
        If spin_pol=True:
            {"F_a", "F_b", "P_a", "P_b", "S", "H", ...}
        Else:
            {"F", "P", "S", "H", ...}
    """
    element_numbers = np.asarray(element_numbers)
    coordinates = np.asarray(coordinates)
    batched = element_numbers.ndim == 2 and coordinates.ndim == 3

    if batched and calculator != "dxtb":
        raise NotImplementedError("Batch processing is only supported with 'dxtb' calculator.")

    if calculator == "tblite":
        res_dict = generate_xtb_matrices_tblite(
            element_numbers=element_numbers,
            coordinates=coordinates,
            spin_pol=spin_pol,
            **kwargs
        )
    elif calculator == "dxtb":
        batch_mode = 1 if batched else 0
        res_dict = generate_xtb_matrices_dxtb(
            element_numbers=element_numbers,
            coordinates=coordinates,
            spin_pol=spin_pol,
            batch_mode=batch_mode,
            **kwargs
        )

    # Generate permutation maps
    perm_maps = _get_permutation_map_closedshell_convention(
        element_numbers_batch=element_numbers if batched else [element_numbers],
        calculator=calculator
    )

    # Apply permutation maps and apply cutoff
    for key, mat in res_dict.items():
        if cutoff is not None and key not in ["energy", "force"]:
            mat = np.where(np.abs(mat) <= cutoff, 0, mat)
        if key not in ["energy", "force"]:
            res_dict[key] = apply_perm_map(
                array=mat,
                perm_maps=perm_maps,
                pos=coordinates
            )
    return res_dict


def generate_xtb_matrices_tblite(
        element_numbers,
        coordinates,
        option="GFN1-xTB",
        charge=None,
        uhf=None,
        spin_pol=False,
        get_energy=False,
        get_forces=False,
        verbosity=False,
        **kwargs):
    """
    Generate xTB matrices using the 'tblite' calculator.

    Parameters:
    - element_numbers: list or np.ndarray, atomic numbers.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
    - option: str, xTB method to use (default: "GFN1-xTB").
    - charge: int, molecular charge (default: None).
    - uhf: int, unrestricted Hartree-Fock value (default: None).
    - spin_pol: bool, whether to perform spin-polarized calculations (default: False).
    - get_energy: bool, whether to compute the total energy (default: False).
    - get_forces: bool, whether to compute forces (default: False).
    - verbosity: bool, whether to enable verbosity (default: False).
    - **kwargs: Additional arguments passed to the tblite calculator.

    Returns:
    - res_dict: dict containing matrices and optionally energy and forces.
        If spin_pol=True:
            {"F_a", "F_b", "P_a", "P_b", "S", "H", "energy", "force"}
        Else:
            {"F", "P", "S", "H", "energy", "force"}
    """
    calc = Calculator(
        method=option,
        numbers=element_numbers,
        positions=coordinates,
        charge=charge,
        uhf=uhf,
        **kwargs
    )
    if spin_pol:
        calc.add("spin-polarization", 1.0)
    calc.set("verbosity", int(verbosity))
    calc.set("save-integrals", 1)
    res = calc.singlepoint()

    S = res.get("overlap-matrix")  # (nao, nao)
    H = res.get("hamiltonian-matrix")  # (nao, nao)
    P = res.get("density-matrix")  # (nao, nao) or (2, nao, nao)
    E = res.get("orbital-energies")  # (nao) or (2, nao)
    C = res.get("orbital-coefficients")  # (nao, nao) or (2, nao, nao)

    if spin_pol:
        F_a = S @ C[0] @ np.diag(E[0]) @ LA.inv(C[0])
        F_b = S @ C[1] @ np.diag(E[1]) @ LA.inv(C[1])
        res_dict = {"F_a": F_a, "F_b": F_b, "P_a": P[0], "P_b": P[1], "S": S, "H": H}
    else:
        F = S @ C @ np.diag(E) @ LA.inv(C)
        res_dict = {"F": F, "P": P, "S": S, "H": H}

    if get_energy:
        res_dict["energy"] = res.get("energy")
    if get_forces:
        res_dict["force"] = res.get("gradient")

    return res_dict


def generate_xtb_matrices_dxtb(
        element_numbers,
        coordinates,
        option="GFN1-xTB",
        charge=0,
        spin=0,
        spin_pol=False,
        batch_mode=0,
        get_energy=False,
        get_forces=False,
        get_analytical_gradients=False,
        verbosity=False,
        **kwargs):
    """
    Generate xTB matrices using the 'dxtb' calculator.

    Parameters:
    - element_numbers: list or np.ndarray, atomic numbers.
        Shape: (n_atoms,) for single, (batch_size, n_atoms) for batch.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
        Shape: (n_atoms, 3) for single, (batch_size, n_atoms, 3) for batch.
    - option: str, xTB method to use (default: "GFN1-xTB").
    - charge: int or list, molecular charge(s) (default: 0).
    - spin: int or list, number of unpaired electrons (default: 0).
    - spin_pol: bool, whether to perform spin-polarized calculations (not supported).
    - batch_mode: int, 1 for batched, 0 for single (default: 0).
    - get_energy: bool, whether to compute the total energy (default: False).
    - get_forces: bool, whether to compute forces (default: False).
    - verbosity: bool, whether to enable verbosity (default: False).
    - **kwargs: Additional arguments passed to the dxtb calculator.

    Returns:
    - res_dict: dict containing matrices and optionally energy and forces.
        If spin_pol=True (not supported):
            {"F_a", "F_b", "P_a", "P_b", "S", "H", "energy", "force", ...}
        Else:
            {"F", "P", "S", "H", "energy", "force", ...}
    """
    assert not spin_pol, "Spin-polarized calculations are not supported for dxtb."
    par = getattr(dxtb, option.replace('-x', '_X'))
    dd: DD = {"dtype": torch.double, "device": torch.device("cpu")}

    numbers = torch.tensor(element_numbers, dtype=torch.int64, device=dd["device"]) # int64
    pos = torch.tensor(coordinates, dtype=dd["dtype"], device=dd["device"]).requires_grad_(get_forces)

    # Handle charge and spin
    if batch_mode == 1:
        if element_numbers.ndim != 2 or coordinates.ndim != 3:
            raise ValueError("For batch_mode=1, element_numbers must be (batch_size, natoms) and coordinates must be (batch_size, natoms, 3).")
        batch_size = element_numbers.shape[0]
        charge = torch.tensor(charge, dtype=torch.double, device=dd["device"]) if isinstance(charge, (list, np.ndarray)) else torch.full((batch_size,), charge, dtype=torch.double, device=dd["device"])
        spin = torch.tensor(spin, dtype=torch.double, device=dd["device"]) if isinstance(spin, (list, np.ndarray)) else torch.full((batch_size,), spin, dtype=torch.double, device=dd["device"])
    else:
        if isinstance(charge, (list, np.ndarray)):
            raise ValueError("For non-batched calculations, charge must be a single value.")
        if isinstance(spin, (list, np.ndarray)):
            raise ValueError("For non-batched calculations, spin must be a single value.")
        charge = torch.tensor(charge, dtype=torch.double, device=dd["device"])
        spin = torch.tensor(spin, dtype=torch.double, device=dd["device"])

    # Set options
    opts = {"verbosity": verbosity, "batch_mode": batch_mode}

    calc = dxtb.Calculator(numbers, par, opts=opts, **dd)
    calc.opts.cache = ConfigCache(enabled=True, density=True, fock=True, overlap=True, hcore=True)
    OutputHandler.verbosity = int(verbosity)

    P = calc.get_density(pos, chrg=charge, spin=spin)
    S = calc.integrals.build_overlap(pos)
    H = calc.integrals.build_hcore(pos)
    F = calc.cache["fock"]

    res_dict = {"F": F, "P": P, "S": S, "H": H}

    if get_energy:
        res_dict["energy"] = calc.get_energy(pos, chrg=charge, spin=spin)
    if get_forces:
        assert get_energy, "Energy must be calculated to get forces."
        assert batch_mode == 0, "Forces are not supported in batch mode."
        res_dict["force"] = -torch.autograd.grad(res_dict["energy"], pos)[0]
        if get_analytical_gradients:
            # assert batch_mode == 0, "Forces are not supported in batch mode."
            keys_to_process = [key for key in res_dict if key not in ["energy", "force"]]
            for key in keys_to_process:
                res_dict[f"grad_{key}"] = get_jacobian(res_dict[key], pos)

    # Convert tensors to numpy arrays
    for k, v in res_dict.items():
        res_dict[k] = v.detach().cpu().numpy()

    return res_dict

def get_jacobian(matrix, pos):
    matrix_jac = torch.zeros(matrix.shape + pos.shape, dtype=matrix.dtype)
    for i, j in np.ndindex(matrix.shape):
        matrix_jac[i, j, :, :] = torch.autograd.grad(matrix[i, j], pos, create_graph=True, retain_graph=True)[0]
    return matrix_jac



def get_2body_grads_batched_shifts(element_numbers, coordinates):
    """
    Calculates 2body_grad (dT/dx) using the Five-point stencil finite
    difference approximation. This function batches the perturbations
    internally to utilize efficient batch processing.

    Coordinates should be given in the unit of bohr.

    Returns dT/d{x} = [dF/d{x}, dP/d{x}, dS/d{x}, dH/d{x}],
    a numpy array of shape (4, norb, norb, nat, 3), where
    {x} is a vector of atomic coordinates of atoms (length = nat).
    """
    import numpy as np

    h = 0.01  # Step size in bohr
    natoms = len(element_numbers)
    shifts = [2 * h, h, -h, -2 * h]  # Shift values for finite difference
    n_shifts = len(shifts)  # Number of shifts per coordinate (4)
    matrices_list = ["F", "P", "S", "H"]

    # List to hold all perturbed coordinates
    perturbed_coords_list = []
    mapping_info = []  # List of tuples (atom_idx, coord_idx, shift_idx)

    # Generate perturbed coordinates and mapping information
    for atom_idx in range(natoms):
        for coord_idx in range(3):  # x, y, z
            for shift_idx, shift in enumerate(shifts):
                perturbed_coords = coordinates.copy()
                perturbed_coords[atom_idx, coord_idx] += shift
                perturbed_coords_list.append(perturbed_coords)
                mapping_info.append((atom_idx, coord_idx, shift_idx))

    # Total number of perturbations
    Nperturbations = len(perturbed_coords_list)

    # Convert list of perturbed coordinates to a numpy array
    perturbed_coords_array = np.array(perturbed_coords_list)  # Shape: (Nperturbations, natoms, 3)

    # Create element_numbers_batch by repeating element_numbers
    element_numbers_batch = np.tile(element_numbers, (Nperturbations, 1))  # Shape: (Nperturbations, natoms)

    # Run batch calculation on all perturbed coordinates
    res_dict = generate_xtb_matrices_fpsh(
        calculator='dxtb',
        element_numbers=element_numbers_batch,
        coordinates=perturbed_coords_array,
        spin_pol=False
    )

    # Number of orbitals (norb) is inferred from the matrix shape
    norb = res_dict["F"].shape[-1]

    # Initialize arrays to hold the matrices per shift
    M_values = {M: np.zeros((natoms, 3, n_shifts, norb, norb)) for M in matrices_list}

    # Map the results back to the corresponding atom, coordinate, and shift
    for pert_idx in range(Nperturbations):
        atom_idx, coord_idx, shift_idx = mapping_info[pert_idx]
        for M in matrices_list:
            M_values[M][atom_idx, coord_idx, shift_idx] = res_dict[M][pert_idx]

    # Initialize dictionary to hold the gradients
    gradients = {M: np.zeros((norb, norb, natoms, 3)) for M in matrices_list}

    # Compute gradients using the five-point stencil formula
    for M in matrices_list:
        for atom_idx in range(natoms):
            for coord_idx in range(3):
                M1 = M_values[M][atom_idx, coord_idx, 0]  # Shift = +2h
                M2 = M_values[M][atom_idx, coord_idx, 1]  # Shift = +h
                M3 = M_values[M][atom_idx, coord_idx, 2]  # Shift = -h
                M4 = M_values[M][atom_idx, coord_idx, 3]  # Shift = -2h

                # Apply the finite difference stencil formula
                dM_dx = (-M1 + 8 * M2 - 8 * M3 + M4) / (12 * h)
                gradients[M][:, :, atom_idx, coord_idx] = dM_dx

    # Stack all gradients into a single tensor for the output
    # Final shape: (4, norb, norb, natoms, 3)
    # dT_dx_np = np.stack([gradients[M] for M in matrices_list], axis=0)
    # return dT_dx_np
    return gradients



# h5py funcs

In [2]:
from qcm_ml.orbnet_equi.util.utilities import get_unit_conversion

def save_data(data_group, data_idx, data):
    conformation_group = data_group.create_group(str(data_idx))
    result_dict = {
        "atomic_numbers": data["nuclear_charges"],
        "geometry_bohr": data["coords"][data_idx] * get_unit_conversion("angstrom","bohr"),
        "Etot_PBEdef2-SVP_Ha": data["energies"][data_idx] * get_unit_conversion("kcal/mol", "hartree"),
        "forces_PBEdef2-SVP_Ha_per_bohr": data["forces"][data_idx] * get_unit_conversion("kcal/mol/angstrom", "hartree/bohr"),
    }
    T = generate_xtb_matrices_fpsh(
        calculator="tblite",
        element_numbers=result_dict["atomic_numbers"],
        coordinates=result_dict["geometry_bohr"],
        spin_pol=False,
        get_energy=True,
        get_forces=True,
    )
    result_dict["Etot_xtb_Ha"] = T["energy"]
    result_dict["energy_delta_Ha"] = result_dict["Etot_PBEdef2-SVP_Ha"] - result_dict["Etot_xtb_Ha"]
    result_dict["force_xtb_Ha_per_bohr"] = T["force"]
    result_dict["forces_delta_Ha_per_bohr"] = result_dict["forces_PBEdef2-SVP_Ha_per_bohr"] - result_dict["force_xtb_Ha_per_bohr"]

    for key in result_dict:
        conformation_group.create_dataset(key, data=result_dict[key])
    
    feat_group = conformation_group.create_group("2body")
    for feat in T:
        if feat not in ["energy", "force"]:
            feat_group.create_dataset(feat, data=T[feat])

    dT = get_2body_grads_batched_shifts(
        element_numbers=result_dict["atomic_numbers"],
        coordinates=result_dict["geometry_bohr"]
    )

    feat_grad_group = conformation_group.create_group("2body_grad")
    for feat_grad in dT:
        feat_grad_group.create_dataset(feat_grad, data=dT[feat_grad])
    

In [3]:
import numpy as np

split = 1 # 1-5
test_indices = np.genfromtxt(f"../../../data/rmd17/splits/index_test_0{split}.csv", delimiter=",", dtype=int)
train_indices = np.genfromtxt(f"../../../data/rmd17/splits/index_train_0{split}.csv", delimiter=",", dtype=int)
print(f"len test_indices = {len(test_indices)}, len train_indices = {len(train_indices)}")

len test_indices = 1000, len train_indices = 1000


In [4]:
import os

import h5py
from tqdm import tqdm

# Path to the npz_data folder
npz_folder = "../../../data/rmd17/npz_data"
npz_files = [f for f in os.listdir(npz_folder) if f.endswith('.npz')]

max_idx = 2  # Define max index for training/validation split

# Create the HDF5 file

hdf5_file_name = "../../../data/rmd17/rmd17_molecules_combined_forces.hdf5"
with h5py.File(hdf5_file_name, "w") as f:
    for npz_file in npz_files:
        molecule_name = npz_file.split('_')[1].split('.')[0]  # Extract molecule name from file
        file_path = os.path.join(npz_folder, npz_file)
        data = np.load(file_path)

        # Create groups for train and val data
        train_group = f.create_group(f"train/{molecule_name}")
        val_group = f.create_group(f"val/{molecule_name}")

        # Save training data
        for data_idx in tqdm(train_indices[:max_idx], desc=f"Saving train data for {molecule_name}"):
            save_data(train_group, data_idx, data)

        # Save validation data
        for val_idx in tqdm(test_indices[:max_idx], desc=f"Saving val data for {molecule_name}"):
            save_data(val_group, val_idx, data)

    print(f"Saved samples to {hdf5_file_name}")

Saving train data for azobenzene:   0%|          | 0/2 [00:00<?, ?it/s]

Saving train data for azobenzene: 100%|██████████| 2/2 [00:12<00:00,  6.01s/it]
Saving val data for azobenzene: 100%|██████████| 2/2 [00:10<00:00,  5.22s/it]
Saving train data for paracetamol: 100%|██████████| 2/2 [00:07<00:00,  3.91s/it]
Saving val data for paracetamol: 100%|██████████| 2/2 [00:07<00:00,  3.56s/it]
Saving train data for ethanol: 100%|██████████| 2/2 [00:01<00:00,  1.14it/s]
Saving val data for ethanol: 100%|██████████| 2/2 [00:01<00:00,  1.40it/s]
Saving train data for malonaldehyde: 100%|██████████| 2/2 [00:01<00:00,  1.03it/s]
Saving val data for malonaldehyde: 100%|██████████| 2/2 [00:01<00:00,  1.19it/s]
Saving train data for benzene: 100%|██████████| 2/2 [00:02<00:00,  1.32s/it]
Saving val data for benzene: 100%|██████████| 2/2 [00:02<00:00,  1.27s/it]
Saving train data for naphthalene: 100%|██████████| 2/2 [00:06<00:00,  3.18s/it]
Saving val data for naphthalene: 100%|██████████| 2/2 [00:05<00:00,  2.86s/it]
Saving train data for aspirin: 100%|██████████| 2/2 [0

Saved samples to ../../../data/rmd17/rmd17_molecules_combined_forces.hdf5



