# Functions

In [1]:
import numpy as np
import numpy.linalg as LA
import torch

from tblite.interface import Calculator
import dxtb
from dxtb import OutputHandler
from dxtb.config import ConfigCache
from dxtb.typing import DD

# Atomic numbers of the elements listed as supported here.
# Not referenced elsewhere in this section.
possible_elements = [
    1,  # H
    6,  # C
    7,  # N
    8,  # O
    9,  # F
]


def _get_permutation_map_closedshell_convention(element_numbers_batch, calculator="tblite"):
    """
    Create permutation maps to reorder orbital matrices from calculator conventions
    to OrbNet-Equi's convention.

    Hydrogen (Z == 1) contributes two basis functions that keep their order.
    Every other element contributes a block of four basis functions whose
    order is rewritten according to ``calculator``.

    Parameters:
    - element_numbers_batch: iterable of per-molecule sequences of atomic numbers.
    - calculator: str, either "tblite" or "dxtb".

    Returns:
    - np.ndarray of dtype int64 and shape (n_molecules, nao) with the permutation
      indices of each molecule. All molecules must therefore have the same
      number of basis functions (ragged input cannot form this array).

    Raises:
    - ValueError: if ``calculator`` is neither "tblite" nor "dxtb".
    """
    # Offsets, relative to the start of a heavy-atom block, that produce the
    # target ordering for each calculator.
    heavy_atom_orders = {
        "tblite": (0, 3, 2, 1),  # [s, px, py, pz] -> [s, pz, py, px]
        "dxtb": (0, 1, 3, 2),    # [s, pz, px, py] -> [s, pz, py, px]
    }
    if calculator not in heavy_atom_orders:
        # Without this check an unknown calculator silently produced maps that
        # skip every heavy-atom orbital.
        raise ValueError(
            f"Unknown calculator {calculator!r}; expected 'tblite' or 'dxtb'."
        )
    heavy_order = heavy_atom_orders[calculator]

    perm_maps = []
    for element_numbers in element_numbers_batch:
        perm_map = []
        idx = 0
        for el in element_numbers:
            if el == 1:
                perm_map.extend([idx, idx + 1])
                idx += 2
            else:
                perm_map.extend(idx + offset for offset in heavy_order)
                idx += 4
        perm_maps.append(perm_map)
    return np.array(perm_maps, dtype=np.int64)


def apply_perm_map(array, perm_maps, pos):
    """
    Apply a permutation map to a matrix or a batch of matrices.

    Both orbital axes are permuted with the same index map, i.e.
    ``out[..., i, j, ...] = array[..., p[i], p[j], ...]``.

    Parameters:
    - array: np.ndarray. With a single map: shape (n, n), (1, n, n),
      (n, n, nat, 3) or (1, n, n, nat, 3). With several maps (one per batch
      entry): shape (batch_size, n, n) or (batch_size, n, n, nat, 3).
    - perm_maps: array-like of shape (n_maps, n) with the permutation indices
      for each element in the batch. A list of index sequences is accepted.
    - pos: np.ndarray, atomic positions (currently unused but could be used for future extensions).

    Returns:
    - Permuted array or batch of permuted arrays.

    Raises:
    - ValueError: if the array shape is not one of the supported layouts.
    """
    # The batched branch relies on NumPy fancy indexing of perm_maps, so turn
    # list input (allowed by the docstring) into an integer array first.
    perm_maps = np.asarray(perm_maps, dtype=np.int64)

    if len(perm_maps) == 1:  # Non-batched case
        p = perm_maps[0]
        if array.ndim == 2:  # Features (nao, nao)
            return array[p, :][:, p]
        if array.ndim == 3 and array.shape[0] == 1:  # (1, nao, nao)
            return array[:, p, :][:, :, p]
        if array.ndim == 4:  # Gradients (nao, nao, nat, 3)
            return array[p, :, :, :][:, p, :, :]
        if array.ndim == 5 and array.shape[0] == 1:  # (1, nao, nao, nat, 3)
            return array[:, p, :, :, :][:, :, p, :, :]
        raise ValueError(f"Unsupported array shape {array.shape}.")

    # Batched case with multiple permutation maps (same length): each batch
    # entry is permuted with its own map.
    if array.ndim == 3:  # Features (batch_size, nao, nao)
        permuted_rows = np.take_along_axis(array, perm_maps[:, :, None], axis=1)
        return np.take_along_axis(permuted_rows, perm_maps[:, None, :], axis=2)
    if array.ndim == 5:  # Gradients (batch_size, nao, nao, nat, 3)
        permuted_rows = np.take_along_axis(array, perm_maps[:, :, None, None, None], axis=1)
        return np.take_along_axis(permuted_rows, perm_maps[:, None, :, None, None], axis=2)
    raise ValueError(f"Unsupported array shape {array.shape}.")
  

def generate_xtb_matrices_fpsh(
        calculator="tblite",
        element_numbers=None,
        coordinates=None,
        spin_pol=False,
        cutoff=None,
        *args, **kwargs):
    """
    Generate Fock (F), density (P), overlap (S), and Hamiltonian (H) matrices
    using either the 'tblite' or 'dxtb' calculator, reordered to OrbNet-Equi's
    orbital convention.

    Supports batch processing with the 'dxtb' calculator.

    Parameters:
    - calculator: str, either "tblite" or "dxtb" (default: "tblite").
    - element_numbers: list or np.ndarray, atomic numbers.
        Shape: (n_atoms,) for single, (batch_size, n_atoms) for batch.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
        Shape: (n_atoms, 3) for single, (batch_size, n_atoms, 3) for batch.
    - spin_pol: bool, whether to perform spin-polarized calculations (default: False).
    - cutoff: float, entries with absolute value <= cutoff are set to zero
      before permutation (default: None, no cutoff).
    - *args: accepted for backward compatibility; ignored (not forwarded).
    - **kwargs: forwarded to the selected backend function.

    Returns:
    - res_dict: dict containing matrices and optionally energy and forces.
        If spin_pol=True:
            {"F_a", "F_b", "P_a", "P_b", "S", "H", ...}
        Else:
            {"F", "P", "S", "H", ...}
      Every entry except "energy" and "force" is permuted (and cut off).

    Raises:
    - NotImplementedError: batched input with a calculator other than "dxtb".
    - ValueError: if ``calculator`` is neither "tblite" nor "dxtb".
    """
    # Scalar/per-atom results that are not orbital matrices.
    non_matrix_keys = {"energy", "force"}

    element_numbers = np.asarray(element_numbers)
    coordinates = np.asarray(coordinates)
    batched = element_numbers.ndim == 2 and coordinates.ndim == 3

    if batched and calculator != "dxtb":
        raise NotImplementedError("Batch processing is only supported with 'dxtb' calculator.")

    if calculator == "tblite":
        res_dict = generate_xtb_matrices_tblite(
            element_numbers=element_numbers,
            coordinates=coordinates,
            spin_pol=spin_pol,
            **kwargs
        )
    elif calculator == "dxtb":
        res_dict = generate_xtb_matrices_dxtb(
            element_numbers=element_numbers,
            coordinates=coordinates,
            spin_pol=spin_pol,
            batch_mode=1 if batched else 0,
            **kwargs
        )
    else:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            f"Unknown calculator {calculator!r}; expected 'tblite' or 'dxtb'."
        )

    # One permutation map per molecule (a single map when not batched).
    perm_maps = _get_permutation_map_closedshell_convention(
        element_numbers_batch=element_numbers if batched else [element_numbers],
        calculator=calculator
    )

    # Apply the cutoff, then the permutation, to every matrix-valued entry.
    # Replacing values of existing keys while iterating is safe.
    for key, mat in res_dict.items():
        if key in non_matrix_keys:
            continue
        if cutoff is not None:
            mat = np.where(np.abs(mat) <= cutoff, 0, mat)
        res_dict[key] = apply_perm_map(
            array=mat,
            perm_maps=perm_maps,
            pos=coordinates
        )
    return res_dict


def generate_xtb_matrices_tblite(
        element_numbers,
        coordinates,
        option="GFN1-xTB",
        charge=None,
        uhf=None,
        spin_pol=False,
        get_energy=False,
        get_forces=False,
        verbosity=False,
        **kwargs):
    """
    Generate xTB matrices using the 'tblite' calculator.

    Parameters:
    - element_numbers: list or np.ndarray, atomic numbers.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
    - option: str, xTB method to use (default: "GFN1-xTB").
    - charge: int, molecular charge (default: None).
    - uhf: int, unrestricted Hartree-Fock value (default: None).
    - spin_pol: bool, whether to perform spin-polarized calculations (default: False).
    - get_energy: bool, whether to compute the total energy (default: False).
    - get_forces: bool, whether to compute forces (default: False).
    - verbosity: bool, whether to enable verbosity (default: False).
    - **kwargs: Additional arguments passed to the tblite calculator.

    Returns:
    - res_dict: dict containing the matrices.
        If spin_pol=True:
            {"F_a", "F_b", "P_a", "P_b", "S", "H"}
        Else:
            {"F", "P", "S", "H"}
      plus "energy" when get_energy is True, and "force" (the negated tblite
      "gradient") when get_forces is True.
    """
    calc = Calculator(
        method=option,
        numbers=element_numbers,
        positions=coordinates,
        charge=charge,
        uhf=uhf,
        **kwargs
    )
    if spin_pol:
        calc.add("spin-polarization", 1.0)
    calc.set("verbosity", int(verbosity))
    # Needed so the overlap / Hamiltonian matrices are available in the result.
    calc.set("save-integrals", 1)
    res = calc.singlepoint()

    S = res.get("overlap-matrix")  # (nao, nao)
    H = res.get("hamiltonian-matrix")  # (nao, nao)
    P = res.get("density-matrix")  # (nao, nao) or (2, nao, nao)
    E = res.get("orbital-energies")  # (nao) or (2, nao)
    C = res.get("orbital-coefficients")  # (nao, nao) or (2, nao, nao)

    # Rebuild the Fock matrix from F C = S C diag(E):  F = S C diag(E) C^-1.
    if spin_pol:
        F_a = S @ C[0] @ np.diag(E[0]) @ LA.inv(C[0])
        F_b = S @ C[1] @ np.diag(E[1]) @ LA.inv(C[1])
        res_dict = {"F_a": F_a, "F_b": F_b, "P_a": P[0], "P_b": P[1], "S": S, "H": H}
    else:
        F = S @ C @ np.diag(E) @ LA.inv(C)
        res_dict = {"F": F, "P": P, "S": S, "H": H}

    if get_energy:
        res_dict["energy"] = res.get("energy")
    if get_forces:
        # tblite reports the energy gradient dE/dR. Negate it so "force" uses the
        # same sign convention as generate_xtb_matrices_dxtb (-dE/dpos).
        res_dict["force"] = -np.asarray(res.get("gradient"))

    return res_dict


def generate_xtb_matrices_dxtb(
        element_numbers,
        coordinates,
        option="GFN1-xTB",
        charge=0,
        spin=0,
        spin_pol=False,
        batch_mode=0,
        get_energy=False,
        get_forces=False,
        get_analytical_gradients=False,
        verbosity=False,
        **kwargs):
    """
    Generate xTB matrices using the 'dxtb' calculator.

    Parameters:
    - element_numbers: list or np.ndarray, atomic numbers.
        Shape: (n_atoms,) for single, (batch_size, n_atoms) for batch.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
        Shape: (n_atoms, 3) for single, (batch_size, n_atoms, 3) for batch.
    - option: str, xTB method to use (default: "GFN1-xTB").
    - charge: int or list, molecular charge(s) (default: 0).
    - spin: int or list, number of unpaired electrons (default: 0).
    - spin_pol: bool, must be False; spin-polarized calculations are not supported.
    - batch_mode: int, 1 for batched, 0 for single (default: 0).
    - get_energy: bool, whether to compute the total energy (default: False).
    - get_forces: bool, whether to compute forces -dE/dpos via autograd.
      Requires get_energy=True and batch_mode=0 (default: False).
    - get_analytical_gradients: bool, only used together with get_forces. If
      True, also return autograd Jacobians "grad_<key>" of every matrix with
      respect to the positions, with shape matrix.shape + pos.shape
      (default: False).
    - verbosity: bool, whether to enable verbosity (default: False).
    - **kwargs: accepted for interface compatibility; currently not used.

    Returns:
    - res_dict: dict of numpy arrays {"F", "P", "S", "H"}, plus "energy",
      "force" and "grad_F"/"grad_P"/"grad_S"/"grad_H" when requested.
    """
    assert not spin_pol, "Spin-polarized calculations are not supported for dxtb."
    # Accept plain lists as documented; the batch-mode validation below needs
    # .ndim / .shape.
    element_numbers = np.asarray(element_numbers)
    coordinates = np.asarray(coordinates)

    # e.g. "GFN1-xTB" -> dxtb.GFN1_XTB
    par = getattr(dxtb, option.replace('-x', '_X'))
    dd: DD = {"dtype": torch.double, "device": torch.device("cpu")}

    numbers = torch.tensor(element_numbers, dtype=torch.int64, device=dd["device"])
    # Positions only need to track gradients when forces are requested.
    pos = torch.tensor(coordinates, dtype=dd["dtype"], device=dd["device"]).requires_grad_(get_forces)

    # Handle charge and spin: one value per molecule in batch mode, a scalar otherwise.
    if batch_mode == 1:
        if element_numbers.ndim != 2 or coordinates.ndim != 3:
            raise ValueError("For batch_mode=1, element_numbers must be (batch_size, natoms) and coordinates must be (batch_size, natoms, 3).")
        batch_size = element_numbers.shape[0]
        charge = torch.tensor(charge, dtype=torch.double, device=dd["device"]) if isinstance(charge, (list, np.ndarray)) else torch.full((batch_size,), charge, dtype=torch.double, device=dd["device"])
        spin = torch.tensor(spin, dtype=torch.double, device=dd["device"]) if isinstance(spin, (list, np.ndarray)) else torch.full((batch_size,), spin, dtype=torch.double, device=dd["device"])
    else:
        if isinstance(charge, (list, np.ndarray)):
            raise ValueError("For non-batched calculations, charge must be a single value.")
        if isinstance(spin, (list, np.ndarray)):
            raise ValueError("For non-batched calculations, spin must be a single value.")
        charge = torch.tensor(charge, dtype=torch.double, device=dd["device"])
        spin = torch.tensor(spin, dtype=torch.double, device=dd["device"])

    # Set options
    opts = {"verbosity": verbosity, "batch_mode": batch_mode}

    calc = dxtb.Calculator(numbers, par, opts=opts, **dd)
    # Enable caching so the Fock matrix can be read back from calc.cache below.
    calc.opts.cache = ConfigCache(enabled=True, density=True, fock=True, overlap=True, hcore=True)
    OutputHandler.verbosity = int(verbosity)

    P = calc.get_density(pos, chrg=charge, spin=spin)
    S = calc.integrals.build_overlap(pos)
    H = calc.integrals.build_hcore(pos)
    F = calc.cache["fock"]

    res_dict = {"F": F, "P": P, "S": S, "H": H}

    if get_energy:
        res_dict["energy"] = calc.get_energy(pos, chrg=charge, spin=spin)
    if get_forces:
        assert get_energy, "Energy must be calculated to get forces."
        assert batch_mode == 0, "Forces are not supported in batch mode."
        # Keep the graph alive when the matrix Jacobians are needed afterwards.
        # The matrices may share autograd nodes with the energy, and the default
        # retain_graph=False would free those buffers.
        res_dict["force"] = -torch.autograd.grad(
            res_dict["energy"], pos, retain_graph=get_analytical_gradients
        )[0]
        if get_analytical_gradients:
            keys_to_process = [key for key in res_dict if key not in ["energy", "force"]]
            for key in keys_to_process:
                res_dict[f"grad_{key}"] = get_jacobian(res_dict[key], pos)

    # Convert tensors to numpy arrays
    for k, v in res_dict.items():
        res_dict[k] = v.detach().cpu().numpy()

    return res_dict

def get_jacobian(matrix, pos):
    """
    Differentiate every entry of ``matrix`` with respect to ``pos`` via autograd.

    Parameters:
    - matrix: torch.Tensor of any shape whose entries may depend on ``pos``.
    - pos: torch.Tensor with ``requires_grad=True``.

    Returns:
    - torch.Tensor of shape ``matrix.shape + pos.shape``, with ``matrix``'s
      dtype and device, where ``result[idx] = d matrix[idx] / d pos``. Entries
      that do not depend on ``pos`` get a zero gradient instead of raising.
      Gradients are taken with ``create_graph=True``, so the result stays
      differentiable.
    """
    matrix_jac = torch.zeros(
        matrix.shape + pos.shape, dtype=matrix.dtype, device=matrix.device
    )
    for idx in np.ndindex(tuple(matrix.shape)):
        entry = matrix[idx]
        if not entry.requires_grad:
            continue  # constant entry: derivative is zero
        (grad,) = torch.autograd.grad(
            entry, pos, create_graph=True, retain_graph=True, allow_unused=True
        )
        if grad is not None:  # None means entry is not connected to pos
            matrix_jac[idx] = grad
    return matrix_jac


In [2]:
# def get_2body_grads_new(element_numbers_batch, coordinates_batch):
#     """
#     Calculates 2body_grad (dT/dx) using the Five-point stencil finite difference approximation.
#     This function is batch-optimized and should produce results consistent with the non-batch version.
    
#     Coordinates should be given in the unit of bohr.

#     Returns:
#     dT/d{x} = [dF/d{x}, dP/d{x}, dS/d{x}, dH/d{x}],
#     a numpy array of shape (4, batch_size, norb, norb, nat, 3).
#     """
#     import numpy as np

#     h = 0.01  # Step size in bohr
#     batch_size, nat, _ = coordinates_batch.shape
#     matrices_list = ["F", "P", "S", "H"]
    
#     # Define shifts and calculate total perturbations per sample
#     shifts = np.array([2 * h, h, -h, -2 * h])  # Shape: (4,)
#     n_shifts = shifts.size  # Number of shifts per coordinate (4)
#     n_coords = 3            # Number of coordinates (x, y, z)
#     total_perts_per_sample = nat * n_coords * n_shifts  # Total perturbations per sample
#     total_perts = batch_size * total_perts_per_sample

#     # Generate indices for batches, atoms, coordinates, and shifts
#     batch_indices = np.repeat(np.arange(batch_size), total_perts_per_sample)  # Shape: (total_perts,)
#     atom_indices = np.tile(np.repeat(np.arange(nat), n_coords * n_shifts), batch_size)  # Shape: (total_perts,)
#     coord_indices = np.tile(np.tile(np.repeat(np.arange(n_coords), n_shifts), nat), batch_size)  # Shape: (total_perts,)
#     shift_indices = np.tile(np.tile(np.arange(n_shifts), nat * n_coords), batch_size)  # Shape: (total_perts,)
#     shifts_array = shifts[shift_indices]  # Shape: (total_perts,)

#     # Expand coordinates and element numbers to match the total number of perturbations
#     coordinates_batch_expanded = coordinates_batch[batch_indices, :, :]  # Shape: (total_perts, nat, 3)
#     element_numbers_batch_expanded = element_numbers_batch[batch_indices, :]  # Shape: (total_perts, nat)

#     # Initialize delta_coords and apply shifts
#     delta_coords = np.zeros_like(coordinates_batch_expanded)  # Shape: (total_perts, nat, 3)
#     delta_coords[np.arange(total_perts), atom_indices, coord_indices] = shifts_array

#     # Compute perturbed coordinates
#     perturbed_coords_batch = coordinates_batch_expanded + delta_coords  # Shape: (total_perts, nat, 3)

#     # Run the batch calculation on all perturbed coordinates
#     perturbed_matrices = generate_xtb_matrices_fpsh(
#         calculator="dxtb",
#         element_numbers=element_numbers_batch_expanded,
#         coordinates=perturbed_coords_batch,
#         spin_pol=False
#     )

#     # Number of orbitals (norb) is inferred from the matrix shape
#     num_orbitals = perturbed_matrices["F"].shape[-1]

#     # Reshape matrices to (batch_size, total_perts_per_sample, norb, norb)
#     for M in matrices_list:
#         perturbed_matrices[M] = perturbed_matrices[M].reshape(
#             batch_size, total_perts_per_sample, num_orbitals, num_orbitals
#         )

#     # Reshape matrices to (batch_size, nat, n_coords, n_shifts, norb, norb)
#     for M in matrices_list:
#         perturbed_matrices[M] = perturbed_matrices[M].reshape(
#             batch_size, nat, n_coords, n_shifts, num_orbitals, num_orbitals
#         )

#     # Apply the finite difference formula along the n_shifts dimension
#     gradients = {}
#     for M in matrices_list:
#         # M_values shape: (batch_size, nat, n_coords, n_shifts, norb, norb)
#         M_values = perturbed_matrices[M]
#         M1 = M_values[:, :, :, 0, :, :]  # Shift = +2h
#         M2 = M_values[:, :, :, 1, :, :]  # Shift = +h
#         M3 = M_values[:, :, :, 2, :, :]  # Shift = -h
#         M4 = M_values[:, :, :, 3, :, :]  # Shift = -2h

#         # Compute the derivative using the five-point stencil formula
#         dM_dx = (-M1 + 8 * M2 - 8 * M3 + M4) / (12 * h)
#         # Transpose axes to match the gradient array shape
#         dM_dx = np.transpose(dM_dx, (0, 3, 4, 1, 2))  # Shape: (batch_size, norb, norb, nat, n_coords)

#         gradients[M] = dM_dx

#     # Stack all gradients into a single tensor for the output
#     dT_dx_np = np.stack([gradients[M] for M in matrices_list], axis=0)  # Shape: (4, batch_size, norb, norb, nat, 3)

#     return dT_dx_np


In [3]:
def get_2body_grads_batched_shifts(element_numbers, coordinates, h=0.01):
    """
    Calculates 2body_grad (dT/dx) using the five-point stencil finite
    difference approximation. All displaced geometries of one molecule are
    evaluated in a single batched dxtb call.

    Coordinates should be given in the unit of bohr.

    Parameters:
    - element_numbers: sequence of atomic numbers, length nat.
    - coordinates: array-like of shape (nat, 3), in bohr.
    - h: float, finite-difference step size in bohr (default: 0.01).

    Returns dT/d{x} = [dF/d{x}, dP/d{x}, dS/d{x}, dH/d{x}],
    a numpy array of shape (4, norb, norb, nat, 3), where
    {x} is a vector of atomic coordinates of atoms (length = nat).

    Raises:
    - ValueError: if there are no atoms or ``coordinates`` is not (nat, 3).
    """
    # Float copy: integer input would otherwise truncate the displacements,
    # and list input would not support 2-D indexing.
    coordinates = np.array(coordinates, dtype=float)
    natoms = len(element_numbers)
    if natoms == 0:
        raise ValueError("At least one atom is required.")
    if coordinates.shape != (natoms, 3):
        raise ValueError(
            f"coordinates must have shape ({natoms}, 3), got {coordinates.shape}."
        )

    shifts = np.array([2 * h, h, -h, -2 * h])  # +2h, +h, -h, -2h
    n_shifts = shifts.size
    matrices_list = ["F", "P", "S", "H"]

    # perturbed[a, c, s] is the geometry with atom a displaced by shifts[s]
    # along Cartesian axis c.
    perturbed = np.broadcast_to(coordinates, (natoms, 3, n_shifts, natoms, 3)).copy()
    for atom_idx in range(natoms):
        for coord_idx in range(3):
            perturbed[atom_idx, coord_idx, :, atom_idx, coord_idx] += shifts
    # Flatten to a batch ordered (atom, axis, shift): (natoms*3*n_shifts, natoms, 3)
    perturbed = perturbed.reshape(-1, natoms, 3)
    n_perturbations = perturbed.shape[0]

    element_numbers_batch = np.tile(element_numbers, (n_perturbations, 1))

    # Run batch calculation on all perturbed coordinates
    res_dict = generate_xtb_matrices_fpsh(
        calculator='dxtb',
        element_numbers=element_numbers_batch,
        coordinates=perturbed,
        spin_pol=False
    )

    gradients = []
    for M in matrices_list:
        # Undo the flattening above: (natoms, 3, n_shifts, norb, norb)
        values = res_dict[M].reshape(natoms, 3, n_shifts, *res_dict[M].shape[-2:])
        dM_dx = (-values[:, :, 0] + 8 * values[:, :, 1]
                 - 8 * values[:, :, 2] + values[:, :, 3]) / (12 * h)
        # (natoms, 3, norb, norb) -> (norb, norb, natoms, 3)
        gradients.append(np.transpose(dM_dx, (2, 3, 0, 1)))

    # Final shape: (4, norb, norb, natoms, 3)
    return np.stack(gradients, axis=0)


# Testing

In [4]:

import numpy as np

# Bohr radius in angstrom (CODATA 2018). The xTB helpers above expect bohr,
# but the geometries below are written in angstrom (C-H bond length 1.089).
BOHR_IN_ANGSTROM = 0.529177210903
ANGSTROM_TO_BOHR = 1.0 / BOHR_IN_ANGSTROM

# Define two methane molecules with different geometries (values in angstrom)
element_numbers_methane1 = [6, 1, 1, 1, 1]
coordinates_methane1 = [
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 1.089],
    [1.026719, 0.0, -0.363],
    [-0.51336, -0.889165, -0.363],
    [-0.51336, 0.889165, -0.363]
]

element_numbers_methane2 = [6, 1, 1, 1, 1]
coordinates_methane2 = [
    [0.0, 0.0, 0.0],
    [0.0, 1.089, 0.0],
    [0.0, -0.363, 1.026719],
    [-0.889165, -0.363, -0.51336],
    [0.889165, -0.363, -0.51336]
]

# Create batched inputs
element_numbers_batch = np.array([
    element_numbers_methane1,
    element_numbers_methane2])

# Converted to bohr, the unit documented by generate_xtb_matrices_fpsh & co.
coordinates_batch = np.array([
    coordinates_methane1,
    coordinates_methane2]) * ANGSTROM_TO_BOHR


### Test: batched vs. single-molecule gradients

In [5]:
# Test gradient calculations for batched and non-batched cases
#
# NOTE(review): `get_2body_grads_new` exists only as commented-out code in
# cell In [2], so this cell raises NameError (as recorded below it). Restore
# or re-implement that function before running this comparison.
# According to its commented-out docstring, it returns an array of shape
# (4, batch_size, norb, norb, nat, 3), which is why results are indexed as
# [j, i] (batched) and [j, 0] (single-molecule batch) below.

# Compute gradients using batched processing
res_batched_gradients = get_2body_grads_new(element_numbers_batch, coordinates_batch)

# Compute gradients separately (non-batched)
res_separate_gradients = []
for en, coords in zip(element_numbers_batch, coordinates_batch):
    # Call the function for a single molecule (non-batched)
    res_single = get_2body_grads_new(
        np.array([en]),  # Wrap element numbers for a single molecule in a batch
        np.array([coords])  # Wrap coordinates for a single molecule in a batch
    )
    res_separate_gradients.append(res_single)

# Now compare the results from the batched vs. non-batched calculations
for i, res_sep in enumerate(res_separate_gradients):
    print(f"\nComparing gradients for molecule {i+1}:")
    
    for j, matrix_name in enumerate(["F", "P", "S", "H"]):  # Compare all matrices
        # Extract the gradients for matrix j (F, P, S, H) from both batched and non-batched results
        grad_batched = res_batched_gradients[j, i]  # Batch result for molecule i
        grad_separate = res_sep[j, 0]  # Separate result (single molecule)
        
        # Calculate the maximum difference
        max_diff = np.max(np.abs(grad_batched - grad_separate))
        
        print(f"  Gradient matrix {matrix_name} max difference: {max_diff:.2e}")
            
        # Batched and per-molecule results must agree to within 1e-6.
        assert np.allclose(grad_batched, grad_separate, atol=1e-6), f"Gradients for {matrix_name} do not match!"


NameError: name 'get_2body_grads_new' is not defined

### Speed test

In [6]:
import numpy as np
import time
from qcm_ml.features.xtb import get_2body_grads

# Speed/consistency benchmark comparing three ways of getting the gradients:
# batched `get_2body_grads_new`, per-molecule `get_2body_grads_new`, and the
# reference `qcm_ml.features.xtb.get_2body_grads`.
# NOTE(review): `get_2body_grads_new` is commented out in cell In [2], so this
# cell stops with NameError at its first call (see recorded output).

# Load data
molecule = "benzene"
file = f"../../../data/rmd17/npz_data/rmd17_{molecule}.npz"
data = np.load(file)

numbers = data["nuclear_charges"]
# NOTE(review): rMD17 coordinates are presumably stored in angstrom, while the
# xTB helpers in this notebook document bohr input. Confirm the units.
coords = data["coords"]

# Repeat nuclear charges to match batch size
numbers_batch = np.tile(numbers, (len(coords), 1))

# Limit the number of molecules for testing
max_nb_mols = 2  # To avoid too long running times during testing
numbers_batch = numbers_batch[:max_nb_mols]
coords = coords[:max_nb_mols]
print(f"numbers_batch.shape: {numbers_batch.shape}") 
print(f"coords.shape: {coords.shape}")

# Batched version
start_batched = time.time()
res_batched_gradients = get_2body_grads_new(
    element_numbers_batch=numbers_batch,
    coordinates_batch=coords
    )   
end_batched = time.time()
print(f"Batched gradient processing time: {end_batched - start_batched:.4f} seconds")

# Non-batched version: each molecule is wrapped in a batch of size 1
start_non_batched = time.time()
res_separate_gradients_new = []
for i in range(max_nb_mols):
    res_grad = get_2body_grads_new(
        element_numbers_batch=np.array([numbers_batch[i]]),
        coordinates_batch=np.array([coords[i]])
    )
    res_separate_gradients_new.append(res_grad)
end_non_batched = time.time()
print(f"Non-batched gradient processing time: {end_non_batched - start_non_batched:.4f} seconds")

# Old non-batched 
start_old = time.time()
res_separate_gradients_old = []
for i in range(max_nb_mols):
    res_grad = get_2body_grads(
        element_numbers=numbers_batch[i],
        coordinates=coords[i]
    )
    res_separate_gradients_old.append(res_grad)
end_old = time.time()
print(f"Old non-batched gradient processing time: {end_old - start_old:.4f} seconds")

# Compare gradients for batched and non-batched calculations
def compare_gradients(res_batched, res_separate_new, res_separate_old, keys, max_nb_mols):
    """Print, per molecule and matrix, the max absolute difference between the
    batched result and the new and old per-molecule results."""
    for i in range(max_nb_mols):
        print(f"\nComparing gradients for molecule {i+1}:")
        for j, matrix_name in enumerate(keys):
            grad_batched = res_batched[j, i]
            grad_separate_new = res_separate_new[i][j, 0]  # Unpack batch of size 1
            # NOTE(review): get_2body_grads is called per molecule without a
            # batch wrapper (see the loop above). If it returns
            # (4, norb, norb, nat, 3), then [j, 0] selects the first orbital row
            # rather than unpacking a batch. Confirm its output shape.
            grad_separate_old = res_separate_old[i][j, 0]  # Unpack batch of size 1

            max_diff_new = np.max(np.abs(grad_batched - grad_separate_new))
            max_diff_old = np.max(np.abs(grad_batched - grad_separate_old))
            
            print(f"  Gradient matrix {matrix_name}:")
            print(f"    Max difference (batched vs new non-batched): {max_diff_new:.2e}")
            print(f"    Max difference (batched vs old non-batched): {max_diff_old:.2e}")

# List of matrices to compare
keys_to_compare = ["F", "P", "S", "H"]

# Compare batched, new non-batched, and old non-batched results
compare_gradients(res_batched_gradients, res_separate_gradients_new, res_separate_gradients_old, keys_to_compare, max_nb_mols)

numbers_batch.shape: (2, 12)
coords.shape: (2, 12, 3)


NameError: name 'get_2body_grads_new' is not defined

# Batching only the finite-difference shifts (one molecule per call)

In [6]:
def get_2body_grads_batched_shifts(element_numbers, coordinates, h=0.01):
    """
    Calculates 2body_grad (dT/dx) using the five-point stencil finite
    difference approximation. All displaced geometries of one molecule are
    evaluated in a single batched dxtb call.

    Coordinates should be given in the unit of bohr.

    Parameters:
    - element_numbers: sequence of atomic numbers, length nat.
    - coordinates: array-like of shape (nat, 3), in bohr.
    - h: float, finite-difference step size in bohr (default: 0.01).

    Returns dT/d{x} = [dF/d{x}, dP/d{x}, dS/d{x}, dH/d{x}],
    a numpy array of shape (4, norb, norb, nat, 3), where
    {x} is a vector of atomic coordinates of atoms (length = nat).

    Raises:
    - ValueError: if there are no atoms or ``coordinates`` is not (nat, 3).
    """
    # Float copy: integer input would otherwise truncate the displacements,
    # and list input would not support 2-D indexing.
    coordinates = np.array(coordinates, dtype=float)
    natoms = len(element_numbers)
    if natoms == 0:
        raise ValueError("At least one atom is required.")
    if coordinates.shape != (natoms, 3):
        raise ValueError(
            f"coordinates must have shape ({natoms}, 3), got {coordinates.shape}."
        )

    shifts = np.array([2 * h, h, -h, -2 * h])  # +2h, +h, -h, -2h
    n_shifts = shifts.size
    matrices_list = ["F", "P", "S", "H"]

    # perturbed[a, c, s] is the geometry with atom a displaced by shifts[s]
    # along Cartesian axis c.
    perturbed = np.broadcast_to(coordinates, (natoms, 3, n_shifts, natoms, 3)).copy()
    for atom_idx in range(natoms):
        for coord_idx in range(3):
            perturbed[atom_idx, coord_idx, :, atom_idx, coord_idx] += shifts
    # Flatten to a batch ordered (atom, axis, shift): (natoms*3*n_shifts, natoms, 3)
    perturbed = perturbed.reshape(-1, natoms, 3)
    n_perturbations = perturbed.shape[0]

    element_numbers_batch = np.tile(element_numbers, (n_perturbations, 1))

    # Run batch calculation on all perturbed coordinates
    res_dict = generate_xtb_matrices_fpsh(
        calculator='dxtb',
        element_numbers=element_numbers_batch,
        coordinates=perturbed,
        spin_pol=False
    )

    gradients = []
    for M in matrices_list:
        # Undo the flattening above: (natoms, 3, n_shifts, norb, norb)
        values = res_dict[M].reshape(natoms, 3, n_shifts, *res_dict[M].shape[-2:])
        dM_dx = (-values[:, :, 0] + 8 * values[:, :, 1]
                 - 8 * values[:, :, 2] + values[:, :, 3]) / (12 * h)
        # (natoms, 3, norb, norb) -> (norb, norb, natoms, 3)
        gradients.append(np.transpose(dM_dx, (2, 3, 0, 1)))

    # Final shape: (4, norb, norb, natoms, 3)
    return np.stack(gradients, axis=0)


In [7]:
import numpy as np
import time
from qcm_ml.features.xtb import get_2body_grads

# Load data, closing the .npz archive once the arrays have been read.
# NOTE(review): rMD17 coordinates are presumably stored in angstrom, while
# get_2body_grads_batched_shifts documents bohr input. Confirm the unit both
# implementations expect.
molecule = "benzene"
file = f"../../../data/rmd17/npz_data/rmd17_{molecule}.npz"
with np.load(file) as data:
    numbers = data["nuclear_charges"]
    coords = data["coords"]

# For testing, limit the number of molecules
max_nb_mols = 5  # Adjust as needed for testing
coords = coords[:max_nb_mols]
max_nb_mols = len(coords)  # the file may hold fewer structures than requested
numbers_batch = np.tile(numbers, (max_nb_mols, 1))

print(f"numbers_batch.shape: {numbers_batch.shape}")  # Should be (max_nb_mols, natoms)
print(f"coords.shape: {coords.shape}")  # Should be (max_nb_mols, natoms, 3)

# Run the new function: one batched dxtb call over all displaced geometries
# of each molecule. perf_counter is monotonic and high-resolution, so it is
# the right clock for elapsed time (time.time can jump).
start_batched = time.perf_counter()
res_batched_gradients = [
    get_2body_grads_batched_shifts(element_numbers=numbers_batch[i], coordinates=coords[i])
    for i in range(max_nb_mols)
]
end_batched = time.perf_counter()
print(f"New batched gradient processing time: {end_batched - start_batched:.4f} seconds")

# Run the old function
start_old = time.perf_counter()
res_old_gradients = [
    get_2body_grads(element_numbers=numbers_batch[i], coordinates=coords[i])
    for i in range(max_nb_mols)
]
end_old = time.perf_counter()
print(f"Old gradient processing time: {end_old - start_old:.4f} seconds")


def compare_gradients(res_batched, res_old, keys, max_nb_mols):
    """Print, per molecule and matrix, the max absolute difference new vs old."""
    for i in range(max_nb_mols):
        print(f"\nComparing gradients for molecule {i+1}:")
        for j, matrix_name in enumerate(keys):
            max_diff = np.max(np.abs(res_batched[i][j] - res_old[i][j]))
            print(f"  Gradient matrix {matrix_name}:")
            print(f"    Max difference (new batched vs old): {max_diff:.2e}")


# List of matrices to compare
keys_to_compare = ["F", "P", "S", "H"]

# Compare the results
compare_gradients(res_batched_gradients, res_old_gradients, keys_to_compare, max_nb_mols)


numbers_batch.shape: (5, 12)
coords.shape: (5, 12, 3)
New batched gradient processing time: 11.6674 seconds
Old gradient processing time: 61.4366 seconds

Comparing gradients for molecule 1:
  Gradient matrix F:
    Max difference (new batched vs old): 3.00e-03
  Gradient matrix P:
    Max difference (new batched vs old): 1.15e-01
  Gradient matrix S:
    Max difference (new batched vs old): 1.60e-12
  Gradient matrix H:
    Max difference (new batched vs old): 1.30e-12

Comparing gradients for molecule 2:
  Gradient matrix F:
    Max difference (new batched vs old): 1.98e-03
  Gradient matrix P:
    Max difference (new batched vs old): 7.31e-02
  Gradient matrix S:
    Max difference (new batched vs old): 6.51e-12
  Gradient matrix H:
    Max difference (new batched vs old): 1.80e-11

Comparing gradients for molecule 3:
  Gradient matrix F:
    Max difference (new batched vs old): 1.88e-03
  Gradient matrix P:
    Max difference (new batched vs old): 9.49e-02
  Gradient matrix S:
    

In [12]:
import numpy as np
import time
import cProfile
import pstats

# Load data, closing the .npz archive once the arrays have been read.
molecule = "benzene"
file = f"../../../data/rmd17/npz_data/rmd17_{molecule}.npz"
with np.load(file) as data:
    numbers = data["nuclear_charges"]
    coords = data["coords"]

# For testing, limit the number of molecules
max_nb_mols = 5  # Adjust as needed for testing
coords = coords[:max_nb_mols]
max_nb_mols = len(coords)  # the file may hold fewer structures than requested
numbers_batch = np.tile(numbers, (max_nb_mols, 1))

print(f"numbers_batch.shape: {numbers_batch.shape}")  # Should be (max_nb_mols, natoms)
print(f"coords.shape: {coords.shape}")  # Should be (max_nb_mols, natoms, 3)


# ================================
# Run the new batched function with cProfile
# ================================
def run_batched_gradients():
    """Compute the finite-difference gradients of every selected molecule."""
    return [
        get_2body_grads_batched_shifts(element_numbers=numbers_batch[i], coordinates=coords[i])
        for i in range(max_nb_mols)
    ]


print("\nRunning cProfile on new batched gradient calculation...")

# The context manager (Python 3.8+) disables the profiler even if the
# calculation raises, unlike a bare enable()/disable() pair.
with cProfile.Profile() as profiler:
    run_batched_gradients()
stats = pstats.Stats(profiler).sort_stats("cumtime")
stats.print_stats()


numbers_batch.shape: (5, 12)
coords.shape: (5, 12, 3)

Running cProfile on new batched gradient calculation...


         2260633 function calls (2236757 primitive calls) in 15.103 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000   15.115    7.558 /home/beom/anaconda3/envs/orbnet_tblite/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3514(run_code)
        2    0.000    0.000   15.115    7.558 {built-in method builtins.exec}
        1    0.000    0.000   15.115   15.115 /tmp/ipykernel_28224/2438092641.py:40(<module>)
        1    0.001    0.001   15.115   15.115 /tmp/ipykernel_28224/2438092641.py:25(run_batched_gradients)
        5    0.028    0.006   15.115    3.023 /tmp/ipykernel_28224/3855207295.py:1(get_2body_grads_batched_shifts)
        5    0.037    0.007   15.081    3.016 /tmp/ipykernel_28224/1769641353.py:81(generate_xtb_matrices_fpsh)
        5    0.000    0.000   14.939    2.988 /tmp/ipykernel_28224/1769641353.py:223(generate_xtb_matrices_dxtb)
        5    0.000    0.000   13.909  

<pstats.Stats at 0x7fb62ad681c0>