In [1]:
import torch
import dxtb
import os 

# dxtb params shared by both SCF modes; "scf_mode" is added per calculator below
# so this dict itself is never modified.
opts = {
    "verbosity": 0,
    "batch_mode": 0,
    "int_driver": "libcint",
}
dd = {"dtype": torch.float64, "device": torch.device("cpu")}

data_dir = "data"
# os.listdir returns entries in arbitrary order; sort so the report is reproducible.
pt_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.pt'))

for pt_file in pt_files:
    file_path = os.path.join(data_dir, pt_file)
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only load trusted files.
    data = torch.load(file_path, weights_only=False)

    z = data["z"]
    # Differentiate w.r.t. a fresh leaf copy instead of flipping requires_grad on
    # the loaded tensor in place.
    # NOTE(review): the large xTB forces next to the small DFT force labels may hint at a
    # unit mismatch -- confirm data["pos"] is in the length unit dxtb expects.
    pos = data["pos"].detach().clone().requires_grad_(True)

    # Energy and autograd forces with scf_mode="full"
    opts_full = {**opts, "scf_mode": "full"}
    calc_full = dxtb.Calculator(z, dxtb.GFN1_XTB, opts=opts_full, **dd)
    e_full = calc_full.get_energy(pos)
    forces_full = -torch.autograd.grad(e_full.sum(), pos, create_graph=True)[0]

    # Energy and autograd forces with scf_mode="implicit"
    opts_implicit = {**opts, "scf_mode": "implicit"}
    calc_implicit = dxtb.Calculator(z, dxtb.GFN1_XTB, opts=opts_implicit, **dd)
    e_implicit = calc_implicit.get_energy(pos)
    forces_implicit = -torch.autograd.grad(e_implicit.sum(), pos, create_graph=True)[0]

    # Report energies and the largest per-row force norm for each method
    print(f"\n{data['name']}:")
    print(f"e_full: {e_full:.2e}, e_implicit: {e_implicit:.2e}, e_DFT: {data['label']:.2e}")
    print(f"max force full: {torch.max(torch.norm(forces_full, dim=-1)):.2e}")
    print(f"max force implicit: {torch.max(torch.norm(forces_implicit, dim=-1)):.2e}")
    print(f"max force DFT: {torch.max(torch.norm(data['force_label'], dim=-1)):.2e}")



C3H6N2O_rxn2575_434:
e_full: -1.98e+01, e_implicit: -2.00e+01, e_DFT: -3.03e+02
max force full: 6.70e+21
max force implicit: 1.93e-01
max force DFT: 1.35e-03

C4H8O2_rxn3652_213:
e_full: -2.17e+01, e_implicit: -2.17e+01, e_DFT: -3.07e+02
max force full: 3.13e+04
max force implicit: 1.14e-01
max force DFT: 3.69e-02

C6H7N_rxn8411_252:
e_full: -1.94e+01, e_implicit: -1.94e+01, e_DFT: -2.87e+02
max force full: 9.56e+04
max force implicit: 1.23e-01
max force DFT: 6.65e-03

C4H9NO2_rxn6584_382:
e_full: -2.49e+01, e_implicit: -2.49e+01, e_DFT: -3.63e+02
max force full: 1.71e+22
max force implicit: 2.86e-01
max force DFT: 9.97e-04

C4H9NO2_rxn7545_241:
e_full: -2.52e+01, e_implicit: -2.53e+01, e_DFT: -3.63e+02
max force full: 9.99e+04
max force implicit: 1.06e-01
max force DFT: 5.00e-04


# Autograd forces vs. numerical (finite-difference) gradients

In [3]:
import torch
import dxtb
import os

# Manual central finite difference gradient
def finite_difference_grad(f, x, eps=1e-3):
    """Approximate the gradient of a scalar function by central differences.

    Every element of ``x`` is shifted by ``+eps`` and ``-eps`` in turn and the
    partial derivative is estimated as ``(f(x + eps) - f(x - eps)) / (2 * eps)``,
    so ``f`` is evaluated ``2 * x.numel()`` times.

    Args:
        f: Callable that takes a tensor shaped like ``x`` and returns a scalar
            value (anything assignable to a single tensor element).
        x: Point at which to differentiate. It is detached and only copies are
            perturbed, so the caller's tensor is left untouched. Non-contiguous
            inputs (e.g. transposed tensors) are supported.
        eps: Step added to / subtracted from one element at a time.

    Returns:
        A contiguous tensor with the shape, dtype and device of ``x`` holding
        the estimated gradient.
    """
    x = x.detach()
    # Work on flat, contiguous buffers: zeros_like(x) would inherit the strides
    # of a non-contiguous x, and flattening that with view(-1) raises.
    flat = x.reshape(-1)
    grad = torch.zeros_like(flat)

    for idx in range(flat.numel()):
        x_pos = flat.clone()
        x_neg = flat.clone()

        x_pos[idx] += eps
        x_neg[idx] -= eps

        f_pos = f(x_pos.view(x.shape))
        f_neg = f(x_neg.view(x.shape))

        grad[idx] = (f_pos - f_neg) / (2 * eps)

    return grad.view(x.shape)

# dxtb params shared by both SCF modes; "scf_mode" is added per calculator below
# so this dict itself is never modified.
opts = {
    "verbosity": 0,
    "batch_mode": 0,
    "int_driver": "libcint",
}
dd = {"dtype": torch.float64, "device": torch.device("cpu")}

data_dir = "data"
# os.listdir returns entries in arbitrary order; sort so the report is reproducible.
pt_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.pt'))

for pt_file in pt_files:
    file_path = os.path.join(data_dir, pt_file)
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only load trusted files.
    data = torch.load(file_path, weights_only=False)

    z = data["z"]
    # Fresh leaf copy of the positions to differentiate against.
    pos = data["pos"].detach().clone().requires_grad_(True)

    # Full SCF: autograd forces and finite-difference forces from the same calculator
    opts_full = {**opts, "scf_mode": "full"}
    calc_full = dxtb.Calculator(z, dxtb.GFN1_XTB, opts=opts_full, **dd)
    e_full = calc_full.get_energy(pos)
    forces_full = -torch.autograd.grad(e_full.sum(), pos, create_graph=True)[0]
    forces_numerical_full = -finite_difference_grad(lambda x: calc_full.get_energy(x).sum(), pos)

    # Implicit SCF: autograd forces and finite-difference forces from the same calculator
    opts_implicit = {**opts, "scf_mode": "implicit"}
    calc_implicit = dxtb.Calculator(z, dxtb.GFN1_XTB, opts=opts_implicit, **dd)
    e_implicit = calc_implicit.get_energy(pos)
    forces_implicit = -torch.autograd.grad(e_implicit.sum(), pos, create_graph=True)[0]
    forces_numerical_implicit = -finite_difference_grad(lambda x: calc_implicit.get_energy(x).sum(), pos)

    # Differences: largest per-row norm of (numerical - autograd) forces
    diff_num_full = torch.norm(forces_numerical_full - forces_full, dim=-1).max()
    diff_num_implicit = torch.norm(forces_numerical_implicit - forces_implicit, dim=-1).max()

    print(f"\n{data['name']}:")
    print(f"e_full: {e_full:.2e}, e_implicit: {e_implicit:.2e}, e_DFT: {data['label']:.2e}")
    print(f"max force full: {torch.norm(forces_full, dim=-1).max():.2e}")
    print(f"max force numerical (full): {torch.norm(forces_numerical_full, dim=-1).max():.2e}")
    print(f"‖num(full) - full‖_max: {diff_num_full:.2e}")
    print(f"max force implicit: {torch.norm(forces_implicit, dim=-1).max():.2e}")
    print(f"max force numerical (implicit): {torch.norm(forces_numerical_implicit, dim=-1).max():.2e}")
    print(f"‖num(impl) - impl‖_max: {diff_num_implicit:.2e}")
    print(f"max force DFT: {torch.norm(data['force_label'], dim=-1).max():.2e}")



C3H6N2O_rxn2575_434:
e_full: -1.98e+01, e_implicit: -2.00e+01, e_DFT: -3.03e+02
max force full: 6.70e+21
max force numerical (full): 1.83e+04
‖num(full) - full‖_max: 6.70e+21
max force implicit: 1.93e-01
max force numerical (implicit): 1.94e-01
‖num(impl) - impl‖_max: 1.56e-03
max force DFT: 1.35e-03

C4H8O2_rxn3652_213:
e_full: -2.17e+01, e_implicit: -2.17e+01, e_DFT: -3.07e+02
max force full: 3.13e+04
max force numerical (full): 2.27e+02
‖num(full) - full‖_max: 3.13e+04
max force implicit: 1.14e-01
max force numerical (implicit): 3.25e+02
‖num(impl) - impl‖_max: 3.25e+02
max force DFT: 3.69e-02

C6H7N_rxn8411_252:
e_full: -1.94e+01, e_implicit: -1.94e+01, e_DFT: -2.87e+02
max force full: 9.56e+04
max force numerical (full): 1.22e-01
‖num(full) - full‖_max: 9.56e+04
max force implicit: 1.23e-01
max force numerical (implicit): 1.22e-01
‖num(impl) - impl‖_max: 1.42e-03
max force DFT: 6.65e-03

C4H9NO2_rxn6584_382:
e_full: -2.49e+01, e_implicit: -2.49e+01, e_DFT: -3.63e+02
max force ful

# Plot the samples (export to XYZ files for visualization)

In [2]:
import torch
import os
from ase import Atoms
from ase.visualize import view
from ase.io import write

data_dir = "data"
# os.listdir returns entries in arbitrary order; sort so samples are processed reproducibly.
pt_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.pt'))

def show_mol(z, pos, name):
    """Build an ASE ``Atoms`` object for one sample and return its x3d view.

    Args:
        z: Tensor of atomic numbers (converted with ``.tolist()``).
        pos: Tensor of atomic positions; detached and moved to CPU before
            being handed to ASE.
            NOTE(review): the same positions are passed to dxtb elsewhere in
            this notebook -- confirm they are in the length unit ASE expects.
        name: Label stored in ``atoms.info['name']``.

    Returns:
        Whatever ``ase.visualize.view(atoms, viewer='x3d')`` returns. It is
        returned (rather than discarded) so that the notebook can render it
        when the call is the last expression of a cell or is passed to
        ``display``.
    """
    atoms = Atoms(numbers=z.tolist(), positions=pos.detach().cpu().numpy())
    atoms.info['name'] = name
    return view(atoms, viewer='x3d')

# Export each sample in the data directory to "<name>.xyz" in the working directory
for pt_file in pt_files:
    file_path = os.path.join(data_dir, pt_file)
    data = torch.load(file_path, weights_only=False)

    # Convert the stored tensors into the plain containers passed to ASE
    atomic_numbers = data["z"].tolist()
    coordinates = data["pos"].detach().cpu().numpy()

    atoms = Atoms(numbers=atomic_numbers, positions=coordinates)
    atoms.info['name'] = data["name"]

    xyz_path = f"{data['name']}.xyz"
    write(xyz_path, atoms)