In [1]:
import numpy as np
import torch
import dxtb
from dxtb.typing import DD
from dxtb.config import ConfigCache
from dxtb import OutputHandler
from tblite.interface import Calculator

# Run on the first CUDA device when one is available; otherwise fall back to
# the CPU so the notebook also runs on machines without a GPU.
dd: DD = {
    "dtype": torch.double,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
}

# LiH: small test system for quick checks. Stored under its own names so it is
# not silently overwritten by the benchmark system below.
numbers_lih = torch.tensor([3, 1], device=dd["device"])
positions_lih = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]], **dd)

# 12-atom benchmark system (2 C, 2 N, 6 H, 2 O) used by all cells below.
# NOTE(review): coordinates presumably in Bohr (dxtb's length unit) - confirm.
numbers = torch.tensor([6, 6, 7, 7, 1, 1, 1, 1, 1, 1, 8, 8], device=dd["device"])
positions = torch.tensor(
    [
        [-3.81469488143921, +0.09993441402912, 0.00000000000000],
        [+3.81469488143921, -0.09993441402912, 0.00000000000000],
        [-2.66030049324036, -2.15898251533508, 0.00000000000000],
        [+2.66030049324036, +2.15898251533508, 0.00000000000000],
        [-0.73178529739380, -2.28237795829773, 0.00000000000000],
        [-5.89039325714111, -0.02589114569128, 0.00000000000000],
        [-3.71254944801331, -3.73605775833130, 0.00000000000000],
        [+3.71254944801331, +3.73605775833130, 0.00000000000000],
        [+0.73178529739380, +2.28237795829773, 0.00000000000000],
        [+5.89039325714111, +0.02589114569128, 0.00000000000000],
        [-2.74426102638245, +2.16115570068359, 0.00000000000000],
        [+2.74426102638245, -2.16115570068359, 0.00000000000000],
    ],
    **dd,  # unpack dd so dtype/device are passed as keyword arguments
)

# Copy of the coordinates that records gradients (differentiable input).
pos = positions.clone().requires_grad_(True)

# Full dxtb GFN1-xTB calculator, timed in the next cell as the reference.
# Optional cache configuration (enabled=False): uncomment both lines to apply.
# cache_config = ConfigCache(enabled=False, density=True)
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd)
# calc.opts.cache = cache_config
OutputHandler.verbosity = 0  # lowest verbosity for dxtb's output handler

In [2]:
from tqdm import tqdm

# Time repeated full dxtb energy evaluations (reference for the trimmed path).
N_ENERGY_CALLS = 1000  # number of get_energy calls to time
COMPUTE_ENERGY_GRAD = False  # set True to also time a backward pass

for i in tqdm(range(N_ENERGY_CALLS)):
    e = calc.get_energy(pos)
    if COMPUTE_ENERGY_GRAD:
        # NOTE(review): retain_graph=True kept from the initial experiment;
        # check whether it is still required.
        grad = torch.autograd.grad(e, pos, retain_graph=True)[0]

100%|██████████| 1000/1000 [00:34<00:00, 28.94it/s]


In [3]:
from dxtb import labels
from dxtb._src import scf
from dxtb.config import Config
from dxtb.integrals import DriverManager
from dxtb._src.integral.container import IntegralMatrices
from dxtb import IndexHelper
from dxtb._src.components.interactions.coulomb.secondorder import new_es2
from dxtb._src.components.interactions.coulomb.thirdorder import new_es3
from dxtb._src.components.interactions.list import InteractionList
from dxtb._src.xtb.gfn1 import GFN1Hamiltonian
from dxtb._src.constants import defaults

from dxtb.integrals import DriverManager
from dxtb.integrals.factories import new_overlap
from tad_mctc.convert import any_to_tensor

def trimmed_singlepoint(numbers, positions, result="density", chrg=defaults.CHRG, spin=defaults.SPIN):
    """
    Run a stripped-down GFN1-xTB single point and return one SCF result.

    Only what the SCF needs is built: the overlap integral (libcint driver),
    the GFN1 core Hamiltonian and the second-/third-order electrostatic
    interactions. All objects are rebuilt on every call; nothing is cached.

    Parameters
    ----------
    numbers : torch.Tensor
        Atomic numbers. Moved to ``positions.device`` if necessary.
    positions : torch.Tensor
        Cartesian coordinates. Its device and dtype determine where and in
        which precision everything is built.
    result : str, optional
        Key of the dictionary returned by ``scf.solve`` to return.
        Defaults to ``"density"``.
    chrg : int | float | torch.Tensor, optional
        Total charge, converted to a tensor on the target device.
    spin : int | float | torch.Tensor | None, optional
        Spin, converted to a tensor unless it is ``None``. ``None`` is
        passed to the SCF unchanged.

    Returns
    -------
    Any
        ``scf.solve(...)[result]``. Tensors are returned on
        ``positions.device``.

    Raises
    ------
    KeyError
        Presumably raised if ``result`` is not a key of the SCF result.
    """
    # Everything follows the device/dtype of ``positions``.
    device = positions.device
    dtype = positions.dtype
    dd = {"device": device, "dtype": dtype}

    # Move the atomic numbers once instead of at every use.
    numbers = numbers.to(device)

    # any_to_tensor already places the value on ``device``.
    chrg = any_to_tensor(chrg, **dd)
    if spin is not None:
        spin = any_to_tensor(spin, **dd)

    # Core-Hamiltonian integral level. Nothing is excluded, so the SCF and the
    # es2/es3 interactions below are always active.
    opts = Config(**dd)
    opts.ints.level = labels.INTLEVEL_HCORE
    opts.exclude = set()

    par = dxtb.GFN1_XTB
    ihelp = IndexHelper.from_numbers(numbers, par).to(device)

    intmats = IntegralMatrices(**dd)

    # Overlap integral.
    driver_name = 0  # libcint
    drv_mgr = DriverManager(driver_name, **dd)
    drv_mgr.create_driver(numbers, par, ihelp)
    drv_mgr.driver.setup(positions)

    ovlp_integral = new_overlap(drv_mgr.driver_type, **dd)
    ovlp_integral.build(drv_mgr.driver)
    intmats.overlap = ovlp_integral.matrix.to(device)

    # GFN1 core Hamiltonian, built from the positions and the overlap matrix.
    h0 = GFN1Hamiltonian(numbers, par, ihelp, **dd)
    intmats.hcore = h0.build(positions, intmats.overlap)

    # Second- and third-order electrostatics, unless excluded in the config.
    es2 = new_es2(numbers, par, **dd) if not {"all", "es2"} & opts.exclude else None
    es3 = new_es3(numbers, par, **dd) if not {"all", "es3"} & opts.exclude else None
    interactions = InteractionList(es2, es3, **dd)
    icaches = interactions.get_cache(numbers=numbers, positions=positions, ihelp=ihelp)

    refocc = h0.refocc.to(device) if h0.refocc is not None else None

    scf_results = scf.solve(
        numbers,
        positions,
        chrg,
        spin,
        interactions,
        icaches,
        ihelp,
        opts.scf,
        intmats,
        refocc,
    )

    # Return the requested entry. Only tensors can/need to be moved to ``device``.
    value = scf_results[result]
    return value.to(device) if isinstance(value, torch.Tensor) else value

# Run the calculation
# Warm-up / smoke-test call. The return value is the SCF density tensor, hence the name.
density = trimmed_singlepoint(numbers, pos, result="density", chrg=0, spin=None)



In [4]:

# Time repeated trimmed single points (compare with the get_energy benchmark).
# The recorded run reached about 29 it/s, i.e. roughly 6 minutes for 10000 calls;
# lower the count for a quick comparison.
N_TRIMMED_CALLS = 10000
COMPUTE_DENSITY_GRAD = False  # set True to also time a backward pass

for i in tqdm(range(N_TRIMMED_CALLS)):
    density = trimmed_singlepoint(numbers, pos, chrg=0, spin=None)
    if COMPUTE_DENSITY_GRAD:
        grad = torch.autograd.grad(density.sum(), pos)[0]

 46%|████▌     | 4571/10000 [02:38<03:08, 28.83it/s]


KeyboardInterrupt: 