In [1]:
"""
Running a simple batched calculation.
"""
import torch

import dxtb
from dxtb.typing import DD

dd: DD = {"device": torch.device("cpu"), "dtype": torch.double}


numbers = torch.tensor(
    [
        [3, 1, 0],
        [8, 1, 1],
    ],
    device=dd["device"],
)
positions = torch.tensor(
    [
        [
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0],
        ],
        [
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 2.0],
        ],
    ],
    **dd
).requires_grad_(True)
charge = torch.tensor([0, 0], **dd)


# no conformers -> batched mode 1
opts = {"verbosity": 0, "batch_mode": 1}

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
energy = calc.get_energy(positions, chrg=charge)
# forces = calc.get_forces(positions, chrg=charge) # Does not work for batched calculations

print(f"energy: {energy}")

energy: tensor([ 0.0111, -4.8440], dtype=torch.float64, grad_fn=<SumBackward1>)


# Changed functions

In [2]:
import numpy as np

# Example matrix
array = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

batch_matrices = np.array([[[1, 2, 3],
                            [4, 5, 6],
                            [7, 8, 9]],
                           [[9, 8, 7],
                            [6, 5, 4],
                            [3, 2, 1]]])

# Example permutation
p = [2, 0, 1]

# mat[p, p] gives only diagonal elements (permuted)
print("mat[p, p]:")
print(array[p, p])  # Output: [9 1 5]

# mat[p, :][:, p] gives the fully permuted matrix
print("mat[p, :][:, p]:")
print(array[p, :][:, p])  # Output: full permuted matrix

print("mat[np.ix_(p, p)]:")
print(array[np.ix_(p, p)])  # Output: full permuted matrix

print("batch_matrices[:, np.ix_(p, p)[0], np.ix_(p, p)[1]]:")
print(batch_matrices[:, np.ix_(p, p)[0], np.ix_(p, p)[1]])  # Output: full permuted batched matrices

print("batch_matrices[:, p, :][:, :, p]:")
print(batch_matrices[:, p, :][:, :, p])  # Output: full permuted batched matrices



mat[p, p]:
[9 1 5]
mat[p, :][:, p]:
[[9 7 8]
 [3 1 2]
 [6 4 5]]
mat[np.ix_(p, p)]:
[[9 7 8]
 [3 1 2]
 [6 4 5]]
batch_matrices[:, np.ix_(p, p)[0], np.ix_(p, p)[1]]:
[[[9 7 8]
  [3 1 2]
  [6 4 5]]

 [[1 3 2]
  [7 9 8]
  [4 6 5]]]
batch_matrices[:, p, :][:, :, p]:
[[[9 7 8]
  [3 1 2]
  [6 4 5]]

 [[1 3 2]
  [7 9 8]
  [4 6 5]]]


In [3]:
import numpy as np
import time

# Example data
n_batch = 2
nao = 3

# Batched matrices (n_batch, nao, nao)
batch_matrices = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],   # First matrix
                           [[9, 8, 7], [6, 5, 4], [3, 2, 1]]])  # Second matrix

# Batched permutation maps (n_batch, nao)
perm_maps = np.array([[2, 0, 1],  # Perm for first matrix
                      [1, 2, 0]]) # Perm for second matrix

# ----------------------
# Loop method
# ----------------------
start_loop = time.time()

# Initialize the permuted result array
permuted_matrices_loop = np.empty_like(batch_matrices)

# Apply the permutation for each matrix in the batch
for i in range(n_batch):
    permuted_matrices_loop[i] = batch_matrices[i][np.ix_(perm_maps[i], perm_maps[i])]

end_loop = time.time()

# ----------------------
# Vectorized method
# ----------------------
start_vectorized = time.time()

# Permute the rows first
permuted_rows = np.take_along_axis(batch_matrices, perm_maps[:, :, None], axis=1)

# Permute the columns on the row-permuted result
permuted_matrices_vectorized = np.take_along_axis(permuted_rows, perm_maps[:, None, :], axis=2)

end_vectorized = time.time()

# ----------------------
# Results
# ----------------------
print("Original Matrices:")
print(batch_matrices)
print("\nPermuted Matrices (Loop Method):")
print(permuted_matrices_loop)
print("\nPermuted Matrices (Vectorized Method):")
print(permuted_matrices_vectorized)

# Time comparison
print("\nTime taken (Loop method): {:.6f} seconds".format(end_loop - start_loop))
print("Time taken (Vectorized method): {:.6f} seconds".format(end_vectorized - start_vectorized))


Original Matrices:
[[[1 2 3]
  [4 5 6]
  [7 8 9]]

 [[9 8 7]
  [6 5 4]
  [3 2 1]]]

Permuted Matrices (Loop Method):
[[[9 7 8]
  [3 1 2]
  [6 4 5]]

 [[5 4 6]
  [2 1 3]
  [8 7 9]]]

Permuted Matrices (Vectorized Method):
[[[9 7 8]
  [3 1 2]
  [6 4 5]]

 [[5 4 6]
  [2 1 3]
  [8 7 9]]]

Time taken (Loop method): 0.000274 seconds
Time taken (Vectorized method): 0.000202 seconds


In [4]:
import numpy as np
import numpy.linalg as LA
import torch

from tblite.interface import Calculator
import dxtb
from dxtb import OutputHandler
from dxtb.config import ConfigCache
from dxtb.typing import DD

possible_elements = [1, 6, 7, 8, 9]


def _get_permutation_map_closedshell_convention(element_numbers_batch, calculator="tblite"):
    """
    Create permutation maps to reorder orbital matrices from calculator conventions
    to OrbNet-Equi's convention.

    Parameters:
    - element_numbers_batch: List or array of atomic numbers for each molecule in the batch.
    - calculator: str, either "tblite" or "dxtb".

    Returns:
    - List of numpy arrays containing permutation indices for each molecule.
    """
    perm_maps = []
    for element_numbers in element_numbers_batch:
        perm_map = []
        idx = 0
        for el in element_numbers:
            if el == 1:
                perm_map.extend([idx, idx + 1])
                idx += 2
            else:
                if calculator == "tblite":
                    # [s, px, py, pz] -> [s, pz, py, px]
                    perm_map.extend([idx, idx + 3, idx + 2, idx + 1])
                elif calculator == "dxtb":
                    # [s, pz, px, py] -> [s, pz, py, px]
                    perm_map.extend([idx, idx + 1, idx + 3, idx + 2])
                idx += 4
        perm_maps.append(perm_map)
    return np.array(perm_maps, dtype=np.int64)


def apply_perm_map(array, perm_maps, pos):
    """
    Apply a permutation map to a matrix or a batch of matrices.

    Parameters:
    - array: np.ndarray, shape (n, n), (batch_size, n, n), (batch_size, n, n, nat, 3).
    - perm_maps: List of np.ndarray, permutation indices for each element in the batch.
    - pos: np.ndarray, atomic positions (currently unused but could be used for future extensions).

    Returns:
    - Permuted array or batch of permuted array.
    """
    if len(perm_maps) == 1:  # Non-batched case
        p = perm_maps[0]
        if array.ndim == 2:  # Features (nao, nao)
            return array[p, :][:, p]
        elif array.ndim == 3 and array.shape[0] == 1:  # Non-batched with batch dimension (1, nao, nao)
            return array[:, p, :][:, :, p]
        elif array.ndim == 4:  # Gradients (nao, nao, nat, 3)
            return array[p, :, :, :][:, p, :, :]
        elif array.ndim == 5 and array.shape[0] == 1:  # Gradients with batch dimension (1, nao, nao, nat, 3)
            return array[:, p, :, :, :][:, :, p, :, :]
        else:
            raise ValueError(f"Unsupported array shape {array.shape}.")
    else:  # Batched case with multiple permutation maps (same length)
        if array.ndim == 3:  # Features (batch_size, nao, nao)
            permuted_rows = np.take_along_axis(array, perm_maps[:, :, None], axis=1)
            return np.take_along_axis(permuted_rows, perm_maps[:, None, :], axis=2)
        if array.ndim == 5:  # Gradients (batch_size, nao, nao, nat, 3)
            permuted_rows = np.take_along_axis(array, perm_maps[:, :, None, None, None], axis=1)
            return np.take_along_axis(permuted_rows, perm_maps[:, None, :, None, None], axis=2)
        else:
            raise ValueError(f"Unsupported array shape {array.shape}.")
  

def generate_xtb_matrices_fpsh_new(
        calculator="tblite",
        element_numbers=None,
        coordinates=None,
        spin_pol=False,
        cutoff=None,
        *args, **kwargs):
    """
    Generate Fock (F), density (P), overlap (S), and Hamiltonian (H) matrices
    using either the 'tblite' or 'dxtb' calculator.

    Supports batch processing with the 'dxtb' calculator.

    Parameters:
    - calculator: str, either "tblite" or "dxtb" (default: "tblite").
    - element_numbers: list or np.ndarray, atomic numbers.
        Shape: (n_atoms,) for single, (batch_size, n_atoms) for batch.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
        Shape: (n_atoms, 3) for single, (batch_size, n_atoms, 3) for batch.
    - spin_pol: bool, whether to perform spin-polarized calculations (default: False).
    - cutoff: float, threshold below which matrix elements are set to zero (default: None).
    - *args, **kwargs: Additional arguments passed to the calculators.

    Returns:
    - res_dict: dict containing matrices and optionally energy and forces.
        If spin_pol=True:
            {"F_a", "F_b", "P_a", "P_b", "S", "H", ...}
        Else:
            {"F", "P", "S", "H", ...}
    """
    element_numbers = np.asarray(element_numbers)
    coordinates = np.asarray(coordinates)
    batched = element_numbers.ndim == 2 and coordinates.ndim == 3

    if batched and calculator != "dxtb":
        raise NotImplementedError("Batch processing is only supported with 'dxtb' calculator.")

    if calculator == "tblite":
        res_dict = generate_xtb_matrices_tblite(
            element_numbers=element_numbers,
            coordinates=coordinates,
            spin_pol=spin_pol,
            **kwargs
        )
    elif calculator == "dxtb":
        batch_mode = 1 if batched else 0
        res_dict = generate_xtb_matrices_dxtb(
            element_numbers=element_numbers,
            coordinates=coordinates,
            spin_pol=spin_pol,
            batch_mode=batch_mode,
            **kwargs
        )

    # Generate permutation maps
    perm_maps = _get_permutation_map_closedshell_convention(
        element_numbers_batch=element_numbers if batched else [element_numbers],
        calculator=calculator
    )

    # Apply permutation maps and apply cutoff
    for key, mat in res_dict.items():
        if cutoff is not None and key not in ["energy", "force"]:
            mat = np.where(np.abs(mat) <= cutoff, 0, mat)
        if key not in ["energy", "force"]:
            res_dict[key] = apply_perm_map(
                array=mat,
                perm_maps=perm_maps,
                pos=coordinates
            )
    return res_dict


def generate_xtb_matrices_tblite(
        element_numbers,
        coordinates,
        option="GFN1-xTB",
        charge=None,
        uhf=None,
        spin_pol=False,
        get_energy=False,
        get_forces=False,
        verbosity=False,
        **kwargs):
    """
    Generate xTB matrices using the 'tblite' calculator.

    Parameters:
    - element_numbers: list or np.ndarray, atomic numbers.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
    - option: str, xTB method to use (default: "GFN1-xTB").
    - charge: int, molecular charge (default: None).
    - uhf: int, unrestricted Hartree-Fock value (default: None).
    - spin_pol: bool, whether to perform spin-polarized calculations (default: False).
    - get_energy: bool, whether to compute the total energy (default: False).
    - get_forces: bool, whether to compute forces (default: False).
    - verbosity: bool, whether to enable verbosity (default: False).
    - **kwargs: Additional arguments passed to the tblite calculator.

    Returns:
    - res_dict: dict containing matrices and optionally energy and forces.
        If spin_pol=True:
            {"F_a", "F_b", "P_a", "P_b", "S", "H", "energy", "force"}
        Else:
            {"F", "P", "S", "H", "energy", "force"}
    """
    calc = Calculator(
        method=option,
        numbers=element_numbers,
        positions=coordinates,
        charge=charge,
        uhf=uhf,
        **kwargs
    )
    if spin_pol:
        calc.add("spin-polarization", 1.0)
    calc.set("verbosity", int(verbosity))
    calc.set("save-integrals", 1)
    res = calc.singlepoint()

    S = res.get("overlap-matrix")  # (nao, nao)
    H = res.get("hamiltonian-matrix")  # (nao, nao)
    P = res.get("density-matrix")  # (nao, nao) or (2, nao, nao)
    E = res.get("orbital-energies")  # (nao) or (2, nao)
    C = res.get("orbital-coefficients")  # (nao, nao) or (2, nao, nao)

    if spin_pol:
        F_a = S @ C[0] @ np.diag(E[0]) @ LA.inv(C[0])
        F_b = S @ C[1] @ np.diag(E[1]) @ LA.inv(C[1])
        res_dict = {"F_a": F_a, "F_b": F_b, "P_a": P[0], "P_b": P[1], "S": S, "H": H}
    else:
        F = S @ C @ np.diag(E) @ LA.inv(C)
        res_dict = {"F": F, "P": P, "S": S, "H": H}

    if get_energy:
        res_dict["energy"] = res.get("energy")
    if get_forces:
        res_dict["force"] = res.get("gradient")

    return res_dict


def generate_xtb_matrices_dxtb(
        element_numbers,
        coordinates,
        option="GFN1-xTB",
        charge=0,
        spin=0,
        spin_pol=False,
        batch_mode=0,
        get_energy=False,
        get_forces=False,
        get_analytical_gradients=False,
        verbosity=False,
        **kwargs):
    """
    Generate xTB matrices using the 'dxtb' calculator.

    Parameters:
    - element_numbers: list or np.ndarray, atomic numbers.
        Shape: (n_atoms,) for single, (batch_size, n_atoms) for batch.
    - coordinates: list or np.ndarray, atomic coordinates in bohr.
        Shape: (n_atoms, 3) for single, (batch_size, n_atoms, 3) for batch.
    - option: str, xTB method to use (default: "GFN1-xTB").
    - charge: int or list, molecular charge(s) (default: 0).
    - spin: int or list, number of unpaired electrons (default: 0).
    - spin_pol: bool, whether to perform spin-polarized calculations (not supported).
    - batch_mode: int, 1 for batched, 0 for single (default: 0).
    - get_energy: bool, whether to compute the total energy (default: False).
    - get_forces: bool, whether to compute forces (default: False).
    - verbosity: bool, whether to enable verbosity (default: False).
    - **kwargs: Additional arguments passed to the dxtb calculator.

    Returns:
    - res_dict: dict containing matrices and optionally energy and forces.
        If spin_pol=True (not supported):
            {"F_a", "F_b", "P_a", "P_b", "S", "H", "energy", "force", ...}
        Else:
            {"F", "P", "S", "H", "energy", "force", ...}
    """
    assert not spin_pol, "Spin-polarized calculations are not supported for dxtb."
    par = getattr(dxtb, option.replace('-x', '_X'))
    dd: DD = {"dtype": torch.double, "device": torch.device("cpu")}

    numbers = torch.tensor(element_numbers, dtype=torch.int64, device=dd["device"]) # int64
    pos = torch.tensor(coordinates, dtype=dd["dtype"], device=dd["device"]).requires_grad_(get_forces)

    # Handle charge and spin
    if batch_mode == 1:
        if element_numbers.ndim != 2 or coordinates.ndim != 3:
            raise ValueError("For batch_mode=1, element_numbers must be (batch_size, natoms) and coordinates must be (batch_size, natoms, 3).")
        batch_size = element_numbers.shape[0]
        charge = torch.tensor(charge, dtype=torch.double, device=dd["device"]) if isinstance(charge, (list, np.ndarray)) else torch.full((batch_size,), charge, dtype=torch.double, device=dd["device"])
        spin = torch.tensor(spin, dtype=torch.double, device=dd["device"]) if isinstance(spin, (list, np.ndarray)) else torch.full((batch_size,), spin, dtype=torch.double, device=dd["device"])
    else:
        if isinstance(charge, (list, np.ndarray)):
            raise ValueError("For non-batched calculations, charge must be a single value.")
        if isinstance(spin, (list, np.ndarray)):
            raise ValueError("For non-batched calculations, spin must be a single value.")
        charge = torch.tensor(charge, dtype=torch.double, device=dd["device"])
        spin = torch.tensor(spin, dtype=torch.double, device=dd["device"])

    # Set options
    opts = {"verbosity": verbosity, "batch_mode": batch_mode}

    calc = dxtb.Calculator(numbers, par, opts=opts, **dd)
    calc.opts.cache = ConfigCache(enabled=True, density=True, fock=True, overlap=True, hcore=True)
    OutputHandler.verbosity = int(verbosity)

    P = calc.get_density(pos, chrg=charge, spin=spin)
    S = calc.integrals.build_overlap(pos)
    H = calc.integrals.build_hcore(pos)
    F = calc.cache["fock"]

    res_dict = {"F": F, "P": P, "S": S, "H": H}

    if get_energy:
        res_dict["energy"] = calc.get_energy(pos, chrg=charge, spin=spin)
    if get_forces:
        assert get_energy, "Energy must be calculated to get forces."
        assert batch_mode == 0, "Forces are not supported in batch mode."
        res_dict["force"] = -torch.autograd.grad(res_dict["energy"], pos)[0]
        if get_analytical_gradients:
            # assert batch_mode == 0, "Forces are not supported in batch mode."
            keys_to_process = [key for key in res_dict if key not in ["energy", "force"]]
            for key in keys_to_process:
                res_dict[f"grad_{key}"] = get_jacobian(res_dict[key], pos)

    # Convert tensors to numpy arrays
    for k, v in res_dict.items():
        res_dict[k] = v.detach().cpu().numpy()

    return res_dict

def get_jacobian(matrix, pos):
    matrix_jac = torch.zeros(matrix.shape + pos.shape, dtype=matrix.dtype)
    for i, j in np.ndindex(matrix.shape):
        matrix_jac[i, j, :, :] = torch.autograd.grad(matrix[i, j], pos, create_graph=True, retain_graph=True)[0]
    return matrix_jac


# Testing

In [5]:

import numpy as np

# Define two methane molecules with different geometries
element_numbers_methane1 = [6, 1, 1, 1, 1]
coordinates_methane1 = [
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 1.089],
    [1.026719, 0.0, -0.363],
    [-0.51336, -0.889165, -0.363],
    [-0.51336, 0.889165, -0.363]
]

element_numbers_methane2 = [6, 1, 1, 1, 1]
coordinates_methane2 = [
    [0.0, 0.0, 0.0],
    [0.0, 1.089, 0.0],
    [0.0, -0.363, 1.026719],
    [-0.889165, -0.363, -0.51336],
    [0.889165, -0.363, -0.51336]
]

# Create batched inputs
element_numbers_batch = np.array([
    element_numbers_methane1, 
    element_numbers_methane2])

coordinates_batch = np.array([
    coordinates_methane1, 
    coordinates_methane2])


### Test batch vs single

In [6]:
# test_batch_processing.ipynb

# Compute matrices using batched processing
res_batched = generate_xtb_matrices_fpsh_new(
    calculator="dxtb",
    element_numbers=element_numbers_batch,
    coordinates=coordinates_batch,
    spin_pol=False,
    get_energy=True,
)

# Compute matrices separately (non-batched)
res_separate = []
for en, coords in zip(element_numbers_batch, coordinates_batch):
    res = generate_xtb_matrices_fpsh_new(
        calculator="dxtb",
        element_numbers=en,
        coordinates=coords,
        spin_pol=False,
        get_energy=True,
    )
    res_separate.append(res)

# Compare the results
for i, res_sep in enumerate(res_separate):
    print(f"\nComparing results for molecule {i+1}:")
    for key in res_sep.keys():
        if key == "energy":
            # Compare energies
            energy_batched = res_batched[key][i]
            energy_separate = res_sep[key]
            diff = abs(energy_batched - energy_separate)
            print(f"  Energy difference: {diff}")
            assert np.allclose(energy_batched, energy_separate), "Energies do not match!"
        else:
            mat_batched = res_batched[key][i]
            mat_separate = res_sep[key]        
            diff = np.max(np.abs(mat_batched - mat_separate))
            print(f"  Matrix {key} max difference: {diff}")
            assert np.allclose(mat_batched, mat_separate), f"Matrix {key} does not match!"




Comparing results for molecule 1:
  Matrix F max difference: 2.1779755776663023e-10
  Matrix P max difference: 8.79509798323852e-11
  Matrix S max difference: 0.0
  Matrix H max difference: 0.0
  Energy difference: 1.2434497875801753e-13

Comparing results for molecule 2:
  Matrix F max difference: 2.3025154005651416e-10
  Matrix P max difference: 9.300293868363951e-11
  Matrix S max difference: 0.0
  Matrix H max difference: 0.0
  Energy difference: 1.2789769243681803e-13


### Speed test

In [7]:
import numpy as np
import time
from qcm_ml.features.xtb import generate_xtb_matrices_fpsh

# Load data
molecule = "benzene"
file = f"../../../data/rmd17/npz_data/rmd17_{molecule}.npz"
data = np.load(file)

numbers = data["nuclear_charges"]
coords = data["coords"]

# Repeat nuclear charges to match batch size
numbers_batch = np.tile(numbers, (len(coords), 1))

# Limit the number of molecules for testing
max_nb_mols = 1000
numbers_batch = numbers_batch[:max_nb_mols]
coords = coords[:max_nb_mols]
print(f"numbers_batch.shape: {numbers_batch.shape}") 
print(f"coords.shape: {coords.shape}")

# Batched version
start_batched = time.time()
res_batched = generate_xtb_matrices_fpsh_new(
    calculator="dxtb",
    element_numbers=numbers_batch,
    coordinates=coords,
    spin_pol=False,
    get_energy=True,
)
end_batched = time.time()
print(f"Batched processing time: {end_batched - start_batched:.4f} seconds")

# Non-batched version
start_non_batched = time.time()
res_non_batched = []
for i in range(max_nb_mols):
    res = generate_xtb_matrices_fpsh_new(
        calculator="dxtb",
        element_numbers=numbers_batch[i],
        coordinates=coords[i],
        spin_pol=False,
        get_energy=True,
        get_forces=False
    )
    res_non_batched.append(res)
end_non_batched = time.time()
print(f"Non-batched processing time: {end_non_batched - start_non_batched:.4f} seconds")

# Old tblite version
start_old = time.time()
res_old = []
for i in range(max_nb_mols):
    res = generate_xtb_matrices_fpsh(
        calculator="tblite",
        element_numbers=numbers_batch[i],
        coordinates=coords[i],
        spin_pol=False,
        get_energy=True,
        get_forces=False
    )
    res_old.append(res)

end_old = time.time()
print(f"Old tblite processing time: {end_old - start_old:.4f} seconds")

# Old dxtb version
start_old_dxtb = time.time()
res_old_dxtb = []
for i in range(max_nb_mols):
    res = generate_xtb_matrices_fpsh(
        calculator="dxtb",
        element_numbers=numbers_batch[i],
        coordinates=coords[i],
        spin_pol=False,
        get_energy=True,
        get_forces=False
    )
    res_old_dxtb.append(res)
end_old_dxtb = time.time()
print(f"Old dxtb processing time: {end_old_dxtb - start_old_dxtb:.4f} seconds")

# Function to compute maximum difference
def max_diff(matrix1, matrix2):
    return np.max(np.abs(matrix1 - matrix2))

# Compare FPSH matrices (Fock, Density, Overlap, Hamiltonian) and energy
def compare_results(batched, non_batched, old, keys):
    for key in keys:
        if key == "energy":
            batched_val = batched[key]
            non_batched_val = np.array([res[key] for res in non_batched])
            old_val = np.array([res[key] for res in old])
        else:
            batched_val = batched[key]
            non_batched_val = np.array([res[key] for res in non_batched])
            old_val = np.array([res[key] for res in old])

        # Compare the three versions
        print(f"\nComparing {key} matrix:")
        print(f"  Max difference (batched vs non-batched): {max_diff(batched_val, non_batched_val):.2e}")
        print(f"  Max difference (batched vs old): {max_diff(batched_val, old_val):.2e}")
        print(f"  Max difference (non-batched vs old): {max_diff(non_batched_val, old_val):.2e}")

# List of keys to compare
keys_to_compare = ["F", "P", "S", "H", "energy"]

# Run comparison
compare_results(res_batched, res_non_batched, res_old, keys_to_compare)


numbers_batch.shape: (100, 12)
coords.shape: (100, 12, 3)


Batched processing time: 2.9103 seconds
Non-batched processing time: 8.7953 seconds
Old tblite processing time: 5.0252 seconds
Old dxtb processing time: 15.0312 seconds

Comparing F matrix:
  Max difference (batched vs non-batched): 5.87e-05
  Max difference (batched vs old): 2.09e-06
  Max difference (non-batched vs old): 5.91e-05

Comparing P matrix:
  Max difference (batched vs non-batched): 2.55e-03
  Max difference (batched vs old): 1.06e-04
  Max difference (non-batched vs old): 2.60e-03

Comparing S matrix:
  Max difference (batched vs non-batched): 0.00e+00
  Max difference (batched vs old): 1.61e-10
  Max difference (non-batched vs old): 1.61e-10

Comparing H matrix:
  Max difference (batched vs non-batched): 0.00e+00
  Max difference (batched vs old): 2.84e-08
  Max difference (non-batched vs old): 2.84e-08

Comparing energy matrix:
  Max difference (batched vs non-batched): 2.29e-06
  Max difference (batched vs old): 6.26e-07
  Max difference (non-batched vs old): 2.91e-06
