In [1]:
%reload_ext autoreload
%autoreload 2

import h5py
N_Cs = 9

with h5py.File('../../dxtb/dxtb-gpu/gpu-cpu_analysis/rdkit/alkanes_data_500.hdf5', 'r') as f:
    for mol_name, data in f.items():
        if mol_name == f"alkane_{N_Cs}_carbons":
            atomic_numbers = data['atomic_numbers'][:]
            coordinates = data['coordinates'][:]

print(f"Number of carbon atoms in {mol_name}: {N_Cs}")
print(f"Nb of atoms: {len(atomic_numbers)}")

Number of carbon atoms in alkane_9_carbons: 9
Nb of atoms: 29


In [2]:
import torch
import dxtb
from dxtb.config import ConfigCache

# Calculator options passed straight to dxtb (implicit SCF mode, batch_mode 0,
# libcint integral driver, Fock-based SCP mode).
opts = {"scf_mode": "implicit", "batch_mode": 0, "int_driver": "libcint", "scp_mode": dxtb.labels.SCP_MODE_FOCK}

# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

dd = {"dtype": torch.float64, "device": torch.device(device)}
numbers = torch.tensor(atomic_numbers, device=dd["device"], dtype=torch.int32)
# NOTE(review): dxtb presumably expects positions in Bohr. If `coordinates`
# come straight from RDKit they are likely in Angstrom -- confirm the unit of
# the HDF5 data and convert if needed.
positions = torch.tensor(coordinates, device=dd["device"], dtype=dd["dtype"]).requires_grad_()

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts, timer=True)
# Cache intermediate matrices so the Fock matrix can be read back from
# `calc.cache` after the density has been computed.
calc.opts.cache = ConfigCache(enabled=True, density=True, fock=True, overlap=True, hcore=True)

dxtb.timer.reset()
# Matrices whose summed derivative w.r.t. the positions is computed below.
res = {
    "P": calc.get_density(positions),
    "S": calc.integrals.build_overlap(positions),
    "H": calc.integrals.build_hcore(positions),
    # Read from the cache, presumably filled by get_density() above.
    "F": calc.cache["fock"],
}

# Gradients: one backward pass per matrix. The matrices share parts of one
# autograd graph, so it must be retained between the repeated grad calls.
for name, M in res.items():
    dxtb.timer.start(name + "_grad")
    grad = torch.autograd.grad(M.sum(), positions, retain_graph=True)[0]
    dxtb.timer.stop(name + "_grad")
    print(f"Gradients for {name}:")
    print(grad)
dxtb.timer.print(v=0)

Converged to 0.
Total Energy: -3.55489866077250 Hartree.
Gradients for P:
tensor([[ 1.7028, -0.8479,  1.4786],
        [ 2.0329,  1.7361, -1.0413],
        [-0.7372,  1.6017,  1.3697],
        [-1.7581, -0.6297, -0.5075],
        [-2.9175,  3.9542,  1.4011],
        [ 9.7487,  7.1829,  7.2927],
        [ 6.1963,  1.2558,  2.8465],
        [-2.0164, -1.3017, -0.2194],
        [ 3.0827,  3.8443,  1.0395],
        [-2.0056, -0.7665, -1.1169],
        [-1.7348, -1.4431,  0.7918],
        [ 0.5827, -0.0438, -0.1937],
        [-2.3086, -0.5354, -0.1768],
        [ 0.4501,  0.3860, -0.7820],
        [ 0.3220, -0.1155, -0.3557],
        [-1.1651, -1.7233,  0.9540],
        [-0.1809, -0.7914, -1.0213],
        [ 1.8765, -2.8387, -2.2383],
        [-0.5562, -1.1958, -4.0240],
        [ 0.3136, -0.2359, -1.2170],
        [-0.5545,  0.0117, -0.6568],
        [-0.4723, -0.3200,  0.4657],
        [-1.9963, -1.8599, -0.5591],
        [ 0.0622, -0.0686,  0.0331],
        [-3.5130, -1.2217,  0.2026],
 