In [45]:
%reload_ext autoreload
%autoreload 2

import h5py
N_Cs = 9

with h5py.File('../dxtb/dxtb-gpu/gpu-cpu_analysis/rdkit/alkanes_data_500.hdf5', 'r') as f:
    for mol_name, data in f.items():
        if mol_name == f"alkane_{N_Cs}_carbons":
            atomic_numbers = data['atomic_numbers'][:]
            coordinates = data['coordinates'][:]

print(f"Number of carbon atoms in {mol_name}: {N_Cs}")
print(f"Nb of atoms: {len(atomic_numbers)}")

Number of carbon atoms in alkane_9_carbons: 9
Nb of atoms: 29


In [46]:
import dxtb
from dxtb._src.typing import DD
import torch
from dxtb.config import ConfigCache

batch_size = 64

print(f"Number of carbon atoms in {mol_name}: {N_Cs}")
print(f"Nb of atoms: {len(atomic_numbers)}")
print(f"batch_size: {batch_size}")

dd = {"device": torch.device("cuda:0"), "dtype": torch.float64}
numbers = torch.tensor(atomic_numbers, device=dd["device"], dtype=torch.int32)
positions = torch.tensor(coordinates, **dd).requires_grad_()
charges = torch.tensor(0.0, **dd)
# numbers = torch.stack([numbers] * batch_size)
# positions = torch.stack([positions] * batch_size).requires_grad_()
# charges = torch.zeros((batch_size,), device=dd["device"], dtype=dd["dtype"])

results = {}

Number of carbon atoms in alkane_9_carbons: 9
Nb of atoms: 29
batch_size: 64


In [47]:
opts = {"scf_mode": "full", "batch_mode": 0, "int_driver": "libcint", "maxiter":10000}

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts, timer=True)
calc.opts.cache = ConfigCache(enabled=False, density=True, fock=True, overlap=False)
dxtb.timer.reset()
e = calc.get_energy(positions, chrg=charges)
dxtb.timer.start("Forces autograd")
forces = torch.autograd.grad(e, positions, retain_graph=True)[0]
dxtb.timer.stop("Forces autograd")
dxtb.timer.print(v=0)

results[f"e_{opts['scf_mode']}"] = e
results[f"forces_{opts['scf_mode']}"] = forces
results[f"Fgrad_{opts['scf_mode']}"] = torch.autograd.grad(calc.cache["fock"].sum(), positions, retain_graph=True)[0]
results[f"Pgrad_{opts['scf_mode']}"] = torch.autograd.grad(calc.get_density(positions, chrg=charges).sum(), positions, retain_graph=True)[0]

# For reconnect modes
scf_charges = calc.get_charges(positions, chrg=charges)
scf_charge_mode = opts["scf_mode"]

Total Energy: -3.55489866559794 Hartree.




Timings
-------

[1mObjective                Time (s)        % Total[0m
------------------------------------------------
[1mClassicals                  0.005           3.20[0m
 - DispersionD3        [37m     0.005          83.95[0m
 - Halogen             [37m     0.000           3.49[0m
 - Repulsion           [37m     0.001          11.72[0m
[1mIntegrals                   0.004           2.43[0m
 - Overlap             [37m     0.002          56.72[0m
 - Core Hamiltonian    [37m     0.002          43.14[0m
[1mSCF                         0.102          60.97[0m
 - Interaction Cache   [37m     0.001           0.72[0m
 - Potential           [37m     0.009           9.04[0m
 - Fock build          [37m     0.001           0.66[0m
 - Diagonalize         [37m     0.065          63.21[0m
 - Density             [37m     0.003           2.65[0m
 - Charges             [37m     0.003           2.88[0m
[1mForces autograd             0.054          32.18[0m
---------

In [48]:
opts = {"scf_mode": "implicit", "batch_mode":0, "int_driver": "libcint", "maxiter":10000}

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts, timer=True)
calc.opts.cache = ConfigCache(enabled=False, density=True, fock=True, overlap=False)
dxtb.timer.reset()
e = calc.get_energy(positions, chrg=charges)
dxtb.timer.start("Forces autograd")
forces = torch.autograd.grad(e, positions, retain_graph=True)[0]
dxtb.timer.stop("Forces autograd")
dxtb.timer.print(v=0)

results[f"e_{opts['scf_mode']}"] = e
results[f"forces_{opts['scf_mode']}"] = forces
results[f"Fgrad_{opts['scf_mode']}"] = torch.autograd.grad(calc.cache["fock"].sum(), positions, retain_graph=True)[0]
results[f"Pgrad_{opts['scf_mode']}"] = torch.autograd.grad(calc.get_density(positions, chrg=charges).sum(), positions, retain_graph=True)[0]

# For reconnect modes
scf_charges = calc.get_charges(positions, chrg=charges)
scf_charge_mode = opts["scf_mode"]


Total Energy: -3.55489865388107 Hartree.


Timings
-------

[1mObjective                Time (s)        % Total[0m
------------------------------------------------
[1mClassicals                  0.004           4.17[0m
 - Halogen             [37m     0.000           3.76[0m
 - Repulsion           [37m     0.001          14.76[0m
 - DispersionD3        [37m     0.003          80.50[0m
[1mIntegrals                   0.004           3.93[0m
 - Overlap             [37m     0.002          61.51[0m
 - Core Hamiltonian    [37m     0.002          38.33[0m
[1mSCF                         0.086          83.81[0m
 - Interaction Cache   [37m     0.001           0.86[0m
 - Potential           [37m     0.069          79.56[0m
 - Fock build          [37m     0.001           0.58[0m
 - Diagonalize         [37m     0.054          62.39[0m
 - Density             [37m     0.002           2.68[0m
 - Charges             [37m     0.002           2.76[0m
[1mForces autograd      

Total Energy: -3.55489865388107 Hartree.
Total Energy: -3.55489865388107 Hartree.


# Gradchecker

In [49]:
import torch
from tad_mctc.autograd import dgradcheck

from dxtb import GFN1_XTB as par
from dxtb import Calculator, OutputHandler, labels
from dxtb._src.typing import DD, Callable, Tensor
from dxtb.config import ConfigCache

# from .samples import samples

tol = 1e-6

def gradchecker(
    dtype: torch.dtype, scp_mode: str = "potential", scf_mode: str = "implicit"
) -> tuple[Callable[[Tensor], Tensor], Tensor]:
    """Prepare gradient check from `torch.autograd`."""
    dd: DD = {"dtype": dtype, "device": torch.device("cpu")}

    numbers = torch.tensor(atomic_numbers, device=dd["device"], dtype=torch.int32)
    positions = torch.tensor(coordinates, **dd)

    opts = {
        "scf_mode": scf_mode,
        "scp_mode": scp_mode,
    }

    calc = Calculator(numbers, par, **dd, opts=opts)
    calc.opts.cache = ConfigCache(enabled=False, fock=True)
    OutputHandler.verbosity = 0

    # variables to be differentiated
    pos = positions.clone().requires_grad_(True)

    def func(p: Tensor) -> Tensor:
        _ = calc.get_energy(p)  # triggers Fock matrix computation
        return calc.cache["fock"]

    return func, pos

def test_grad_fock(dtype: torch.dtype, scp_mode: str, scf_mode: str) -> None:
    """
    Check analytical gradient of Fock matrix against numerical
    gradient from `torch.autograd.gradcheck`.
    """
    func, diffvars = gradchecker(dtype, scp_mode, scf_mode)
    assert dgradcheck(func, diffvars, atol=tol, fast_mode=True)

test_grad_fock(torch.float64, "potential", "full" )

# Manual jacobian calculations

In [63]:
import torch
from torch.autograd.gradcheck import (
    get_analytical_jacobian,
    get_numerical_jacobian,
)

# — your existing gradchecker() definition must be in scope here —

def analytical_jacobian(fn, inputs):
    # wrap input in a 1‑tuple, flatten output to 1‑D
    inputs_tup = (inputs,)
    y = fn(inputs)            # e.g. your Fock matrix, shape [M,N]
    y_flat = y.reshape(-1)    # shape [M*N]
    # returns a tuple of Jacobians—one per input
    (J_flat,), reentrant, sizes_ok, types_ok = get_analytical_jacobian(
        inputs_tup,
        y_flat,
        nondet_tol=0.0,
        grad_out=1.0,
    )
    # J_flat shape: (inputs.numel(), y_flat.numel())
    return J_flat

def numerical_jacobian(fn, inputs, eps=1e-6):
    # we need a function that takes a tuple of inputs
    # and returns a flat (1‑D) output
    def flat_fn(inp_tuple):
        x = inp_tuple[0]
        y = fn(x)
        return y.reshape(-1)
    # get_numerical_jacobian returns one Jacobian per input
    (J_flat,) = get_numerical_jacobian(
        flat_fn,
        inputs,   # single Tensor; internals call _as_tuple on it
        eps=eps,
    )
    # J_flat shape: (inputs.numel(), y_flat.numel())
    return J_flat

# —— usage example —— 

# build your function + variables
SCF_MODE = "implicit"
SCP_MODE = "potential"
func32, pos32 = gradchecker(torch.float32, SCP_MODE, SCF_MODE)
func64, pos64 = gradchecker(torch.float64, SCP_MODE, SCF_MODE)

# compute slow/full analytical and numerical
J_an32 = analytical_jacobian(func32, pos32); print(J_an32.shape)
J_num32 = numerical_jacobian(func32, pos32); print(J_num32.shape)
J_an64 = analytical_jacobian(func64, pos64); print(J_an64.shape)
J_num64 = numerical_jacobian(func64, pos64); print(J_num64.shape)

  (J_flat,), reentrant, sizes_ok, types_ok = get_analytical_jacobian(


torch.Size([87, 5776])


  (J_flat,) = get_numerical_jacobian(


torch.Size([87, 5776])


  (J_flat,), reentrant, sizes_ok, types_ok = get_analytical_jacobian(


torch.Size([87, 5776])


  (J_flat,) = get_numerical_jacobian(


torch.Size([87, 5776])


In [64]:
# Print datatypes
print(f"SCF_MODE: {SCF_MODE}, SCP_MODE: {SCP_MODE}")
print(f"Number of carbon atoms in {mol_name}: {N_Cs}")
print(f"Nb of atoms: {len(atomic_numbers)}")
print(f"J_an32: {J_an32.dtype}")
print(f"J_num32: {J_num32.dtype}")
print(f"J_an64: {J_an64.dtype}")
print(f"J_num64: {J_num64.dtype}")
print()

assert J_an32.shape == J_num32.shape
max_diff_32an_64an = torch.max(torch.abs(J_an32 - J_an64))
max_diff_32nu_64nu = torch.max(torch.abs(J_num32 - J_num64))

max_diff_32an_32nu = torch.max(torch.abs(J_an32 - J_num32))
max_diff_64an_64nu = torch.max(torch.abs(J_an64 - J_num64))
max_diff_32an_64nu = torch.max(torch.abs(J_an32 - J_num64))

print(f"max_diff_32an_64an: {max_diff_32an_64an:.2e}")
print(f"max_diff_32nu_64nu: {max_diff_32nu_64nu:.2e}\n")
print(f"max_diff_32an_32nu: {max_diff_32an_32nu:.2e}")
print(f"max_diff_64an_64nu: {max_diff_64an_64nu:.2e}")
print(f"max_diff_32an_64nu: {max_diff_32an_64nu:.2e}")

print()

# Check the relative max difference 
max_rdiff_32an_64an = torch.max(torch.abs(J_an32 - J_an64) / (torch.abs(J_an32) + 1e-15))
max_rdiff_32nu_64nu = torch.max(torch.abs(J_num32 - J_num64) / (torch.abs(J_num32) + 1e-15))

max_rdiff_32an_32nu = torch.max(torch.abs(J_an32 - J_num32) / (torch.abs(J_an32) + 1e-15))
max_rdiff_64an_64nu = torch.max(torch.abs(J_an64 - J_num64) / (torch.abs(J_an64) + 1e-15))
max_rdiff_32an_64nu = torch.max(torch.abs(J_an32 - J_num64) / (torch.abs(J_an32) + 1e-15))
print(f"max_rdiff_32an_64an: {max_rdiff_32an_32nu:.2e}")
print(f"max_rdiff_32nu_64nu: {max_rdiff_32nu_64nu:.2e}\n")
print(f"max_rdiff_32an_32nu: {max_rdiff_32an_32nu:.2e}")
print(f"max_rdiff_64an_64nu: {max_rdiff_64an_64nu:.2e}")
print(f"max_rdiff_32an_64nu: {max_rdiff_32an_64nu:.2e}")


SCF_MODE: implicit, SCP_MODE: potential
Number of carbon atoms in alkane_9_carbons: 9
Nb of atoms: 29
J_an32: torch.float32
J_num32: torch.float32
J_an64: torch.float64
J_num64: torch.float64

max_diff_32an_64an: 5.18e-07
max_diff_32nu_64nu: 2.47e+01

max_diff_32an_32nu: 2.47e+01
max_diff_64an_64nu: 1.15e-01
max_diff_32an_64nu: 1.15e-01

max_rdiff_32an_64an: 1.02e+14
max_rdiff_32nu_64nu: 3.70e+14

max_rdiff_32an_32nu: 1.02e+14
max_rdiff_64an_64nu: 7.93e+12
max_rdiff_32an_64nu: 7.93e+12


# Other attempts

In [None]:
from torch.autograd import gradcheck
from dxtb import OutputHandler
from dxtb.config import ConfigCache

OutputHandler.verbosity = 0

# Inputs (must be float64)
positions_d = positions.detach().double().requires_grad_()
charges_d = charges.double()

calc.opts.cache = ConfigCache(enabled=False, fock=True, density=True)

def run_gradcheck(fn, inputs):
    # gradcheck assumes float64 and requires_grad=True
    inputs = tuple(i.detach().double().requires_grad_() for i in inputs)
    passed = gradcheck(fn, inputs, eps=1e-6, atol=1e-3, rtol=1e-3, nondet_tol=1e-5)
    print(f"Gradcheck passed: {passed}")

# Energy (scalar output)
def energy_fn(pos):
    return calc.get_energy(pos, chrg=charges_d)

print("Energy")
run_gradcheck(energy_fn, (positions_d,))

# Fock (matrix output)
def fock_fn(pos):
    _ = calc.get_energy(pos, chrg=charges_d)  # populate cache
    return calc.cache["fock"]

print("Fock")
run_gradcheck(fock_fn, (positions_d,))

# Density (matrix output)
def density_fn(pos):
    return calc.get_density(pos, chrg=charges_d)

print("Density")
run_gradcheck(density_fn, (positions_d,))


In [None]:
from torch.autograd import gradcheck
from dxtb import OutputHandler

OutputHandler.verbosity = 0

def run_gradcheck(fn, inputs):
    inputs = tuple(i.detach().requires_grad_() for i in inputs)
    passed = gradcheck(fn, inputs)
    print(f"Gradcheck passed: {passed}")

# Functions must return tuple of tensor outputs in float64
def energy_fn(pos):
    return (calc.get_energy(pos, chrg=charges),)

def fock_fn(pos):
    calc.get_energy(pos, chrg=charges)  # populate cache
    return (calc.cache["fock"].sum(),)

def density_fn(pos):
    return (calc.get_density(pos, chrg=charges).sum(),)

# Inputs
positions_d = positions.detach().requires_grad_()
charges_d = charges

print("Energy")
run_gradcheck(energy_fn, (positions_d,))

print("Fock")
run_gradcheck(fock_fn, (positions_d,))

print("Density")
run_gradcheck(density_fn, (positions_d,))


In [None]:
from torch.autograd import gradcheck
from torch.autograd.gradcheck import _get_numerical_jacobian, _as_tuple
from dxtb import OutputHandler

OutputHandler.verbosity = 0

def compare_grads(fn, inputs):
    # Prepare input
    inputs = tuple(i.detach().requires_grad_() for i in _as_tuple(inputs))
    output = fn(*inputs)
    output = _as_tuple(output)

    # Compute autograd
    autograd_grads = torch.autograd.grad(output, inputs, grad_outputs=[torch.ones_like(o) for o in output], retain_graph=True)

    # Compute numerical
    numerical_grads = _get_numerical_jacobian(fn, inputs, eps=1e-6)

    # Print comparison
    for i, (a, n) in enumerate(zip(autograd_grads, numerical_grads)):
        n_tensor = n[0][0]  # FIXED: Unwrap twice
        print(f"[Input {i}] max(abs diff): {(a - n_tensor).abs().max().item():.2e}")
        # print(f"Autograd:\n{a}\nNumerical:\n{n_tensor}")

def energy_fn(pos):
    return calc.get_energy(pos, chrg=charges)
print("Energy")
compare_grads(energy_fn, (positions,))

def fock_fn(pos):
    calc.get_energy(pos, chrg=charges)
    return calc.cache["fock"].sum()
print("Fock")
compare_grads(fock_fn, (positions,))

def density_fn(pos):
    return calc.get_density(pos, chrg=charges).sum()
print("Density")
compare_grads(density_fn, (positions,))