In [7]:
import torch
import dxtb
from dxtb.config import ConfigCache

# Run everything in single precision on the first CUDA device.
dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}

# Load the problematic batch.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects -- only use it on files you produced yourself / trust.
problematic_batch = torch.load("../problematic_batch_for_full.pt", weights_only=False)
numbers = problematic_batch["numbers"].to(dd["device"])
# Make positions an explicit leaf tensor that requires grad. The
# torch.autograd.grad calls below differentiate w.r.t. `positions`; without
# this they would silently depend on whatever requires_grad flag happened to
# be stored in the .pt file.
positions = problematic_batch["positions"].to(**dd).detach().requires_grad_(True)

# Leading dimension of `numbers` is treated as the batch dimension.
batch_size = numbers.shape[0]
# One total charge of 0 per batch entry (passed as `chrg=` below).
charges = torch.zeros(batch_size, **dd)


# Implicit: run the calculation with scf_mode="implicit".
opts = {"scf_mode": "implicit", "batch_mode": 2, "int_driver": "libcint"}
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)
# Caching is switched off (enabled=False), presumably so that get_energy and
# get_density each run their own SCF -- TODO confirm against dxtb's
# ConfigCache semantics.
calc.opts.cache = ConfigCache(enabled=False, density=True, fock=True, overlap=False)

e_impl = calc.get_energy(positions, chrg=charges)
# Gradient of the summed batch energies w.r.t. positions. No minus sign is
# applied, so despite the name this is the energy gradient, not -gradient.
# Tensor .sum() is used instead of builtin sum(), which would iterate the
# tensor element by element (and fail on a 0-d tensor).
forces_impl = torch.autograd.grad(e_impl.sum(), positions, retain_graph=True)[0]
f_impl = calc.get_density(positions=positions, chrg=charges)
# Gradient of the sum of all density entries w.r.t. positions.
f_grad_impl = torch.autograd.grad(f_impl.sum(), positions, retain_graph=True)[0]

# Full: same calculation with scf_mode="full", as the reference to compare
# the implicit run against.
opts = {"scf_mode": "full", "batch_mode": 2, "int_driver": "libcint"}
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)
# Caching is switched off (enabled=False), presumably so that get_energy and
# get_density each run their own SCF -- TODO confirm against dxtb's
# ConfigCache semantics.
calc.opts.cache = ConfigCache(enabled=False, density=True, fock=True, overlap=False)
e_full = calc.get_energy(positions, chrg=charges)
# Gradient of the summed batch energies w.r.t. positions (no minus sign, so
# this is the energy gradient despite the "forces" name).
forces_full = torch.autograd.grad(e_full.sum(), positions, retain_graph=True)[0]
f_full = calc.get_density(positions=positions, chrg=charges)
# Gradient of the sum of all density entries w.r.t. positions.
f_grad_full = torch.autograd.grad(f_full.sum(), positions, retain_graph=True)[0]


# Comparison: largest element-wise absolute difference between the implicit
# and full runs. The reference magnitude max|full| is printed under each
# line because everything runs in float32, and an absolute difference alone
# says nothing about whether it is large relative to the quantity itself.
# NOTE(review): the recorded output shows the density gradient is the
# outlier (abs diff ~0.39 vs ~1e-5 for the others).
for _label, _impl, _full in (
    ("energy", e_impl, e_full),
    ("forces", forces_impl, forces_full),
    ("density", f_impl, f_full),
    ("density grad", f_grad_impl, f_grad_full),
):
    print(f"max diff {_label}: {torch.max(torch.abs(_impl - _full))}")
    print(f"    max |{_label}| (full): {torch.max(torch.abs(_full))}")

max diff energy: 1.33514404296875e-05
max diff forces: 1.996755599975586e-05
max diff density: 3.7550926208496094e-05
max diff density grad: 0.388153076171875
