# Run the cell below multiple times: each run fails with one of two different errors, chosen seemingly at random

In [3]:
import torch
import dxtb

def generate_xtb_features_dxtb(
        element_numbers,
        coordinates,
        charge=0,
        spin=0,
        ):
    """Compute GFN1-xTB energies and forces with dxtb.

    Args:
        element_numbers: Atomic numbers handed to ``dxtb.Calculator``.
        coordinates: Positions tensor. It must take part in autograd
            (``requires_grad=True``), because the forces are obtained by
            differentiating the energy with respect to it.
        charge: Forwarded to ``get_energy`` as ``chrg`` (default 0).
        spin: Forwarded to ``get_energy`` as ``spin`` (default 0).

    Returns:
        Tuple ``(energy, forces)``; ``forces`` is the negative gradient of
        the summed energy with respect to ``coordinates``.
    """
    # Follow the dtype/device of the incoming coordinates instead of
    # hard-coding float32 on cuda:0, so the calculator always matches its
    # inputs (identical behaviour for float32 tensors living on cuda:0).
    dd = {"dtype": coordinates.dtype, "device": coordinates.device}
    opts = {"scf_mode": "implicit", "batch_mode": 2, "int_driver": "libcint"}

    calc = dxtb.Calculator(element_numbers, dxtb.GFN1_XTB, **dd, opts=opts)

    energy = calc.get_energy(coordinates, chrg=charge, spin=spin)

    # retain_graph=True keeps the graph alive so the returned energy can
    # still be differentiated by the caller.
    (gradient,) = torch.autograd.grad(
        energy.sum(), coordinates, retain_graph=True
    )

    return energy, -gradient


# Load the problematic batch
dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
# NOTE: weights_only=False unpickles arbitrary Python objects; only load a
# batch file that comes from a trusted source.
problematic_batch = torch.load("problematic_batch.pt", weights_only=False)
numbers = problematic_batch["numbers"].to(dd["device"])
positions = problematic_batch["positions"].to(**dd)

opts = {"scf_mode": "full", "batch_mode": 2, "int_driver": "libcint"}

# One neutral charge / zero spin entry per structure in the batch
batch_size = numbers.shape[0]
charges = torch.full((batch_size,), 0, **dd)
spin = torch.full((batch_size,), 0, **dd)

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)

e = calc.get_energy(positions, chrg=charges, spin=spin)
# Forces are the negative energy gradient, matching the sign convention of
# generate_xtb_features_dxtb. The builtin sum(e) is kept on purpose so the
# autograd graph of the reproducer stays exactly as reported.
forces = -torch.autograd.grad(sum(e), positions, retain_graph=True)[0]

# Features calc: second calculation on the same tensors, this time through
# the helper function (this is the call after which the error shows up).
res = generate_xtb_features_dxtb(
    numbers,
    positions,
    charge=charges,
    spin=spin,
)

RuntimeError: _Map_base::at

# Or just run this looped version

In [None]:
import torch
from tqdm import tqdm

import dxtb

def generate_xtb_features_dxtb(
        element_numbers,
        coordinates,
        charge=0,
        spin=0,
        ):
    """Compute GFN1-xTB energies and forces with dxtb.

    Args:
        element_numbers: Atomic numbers handed to ``dxtb.Calculator``.
        coordinates: Positions tensor. It must take part in autograd
            (``requires_grad=True``), because the forces are obtained by
            differentiating the energy with respect to it.
        charge: Forwarded to ``get_energy`` as ``chrg`` (default 0).
        spin: Forwarded to ``get_energy`` as ``spin`` (default 0).

    Returns:
        Tuple ``(energy, forces)``; ``forces`` is the negative gradient of
        the summed energy with respect to ``coordinates``.
    """
    # Follow the dtype/device of the incoming coordinates instead of
    # hard-coding float32 on cuda:0, so the calculator always matches its
    # inputs (identical behaviour for float32 tensors living on cuda:0).
    dd = {"dtype": coordinates.dtype, "device": coordinates.device}
    opts = {"scf_mode": "implicit", "batch_mode": 2, "int_driver": "libcint"}

    calc = dxtb.Calculator(element_numbers, dxtb.GFN1_XTB, **dd, opts=opts)

    energy = calc.get_energy(coordinates, chrg=charge, spin=spin)

    # retain_graph=True keeps the graph alive so the returned energy can
    # still be differentiated by the caller.
    (gradient,) = torch.autograd.grad(
        energy.sum(), coordinates, retain_graph=True
    )

    return energy, -gradient


# Load the problematic batch
# NOTE: weights_only=False unpickles arbitrary Python objects; only load a
# batch file that comes from a trusted source.
problematic_batch = torch.load("problematic_batch.pt", weights_only=False)
# The tensors are used on whatever device they were saved with
# (presumably cuda:0 already -- TODO confirm for your copy of the file).
numbers = problematic_batch["numbers"]
positions = problematic_batch["positions"]

# Loop-invariant settings, built once instead of on every iteration
dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
opts = {"scf_mode": "full", "batch_mode": 2, "int_driver": "libcint"}
batch_size = numbers.shape[0]

for attempt in tqdm(range(10)):
    # Fresh charge/spin tensors and a fresh calculator per iteration, so
    # every pass mirrors a manual re-run of the single-shot cell.
    charges = torch.full((batch_size,), 0, **dd)
    spin = torch.full((batch_size,), 0, **dd)

    calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)

    e = calc.get_energy(positions, chrg=charges, spin=spin)
    # Forces are the negative energy gradient, matching the sign convention
    # of generate_xtb_features_dxtb. The builtin sum(e) is kept on purpose
    # so the autograd graph of the reproducer stays exactly as reported.
    forces = -torch.autograd.grad(sum(e), positions, retain_graph=True)[0]

    # Features calc: second calculation on the same tensors via the helper
    res = generate_xtb_features_dxtb(
        numbers,
        positions,
        charge=charges,
        spin=spin,
    )

 20%|██        | 2/10 [00:01<00:04,  1.66it/s]


RuntimeError: could not compute gradients for some functions

# Checks about the sample

The batch is a sample from the Transition1x dataset, specifically C2H2N2O/rxn2091, i.e., a single molecule at different steps along the reaction path. The positions are given in Bohr.

In [5]:
# NOTE: weights_only=False unpickles arbitrary Python objects; only load a
# batch file that comes from a trusted source.
problematic_batch = torch.load("problematic_batch.pt", weights_only=False)
numbers = problematic_batch["numbers"]
positions = problematic_batch["positions"]

print(f"numbers.shape: {numbers.shape}")
print(f"positions.shape: {positions.shape}")

# Show one entry of the batch; the labels follow idx so they cannot drift
# out of sync with the element that is actually printed.
print("\nSubset:")
idx = 0
print(f"numbers[{idx}]: {numbers[idx]}")
print(f"positions[{idx}]: {positions[idx]}")

# Basic sanity checks: no NaN/Inf values and the coordinate range
print("\nProperties of the problematic batch:")
print(f"isnan positions: {torch.isnan(positions).any()}")
print(f"isinf positions: {torch.isinf(positions).any()}")
print(f"max positions: {positions.max()}")
print(f"min positions: {positions.min()}")


numbers.shape: torch.Size([64, 7])
positions.shape: torch.Size([64, 7, 3])

Subset:
numbers[0]: tensor([8, 6, 7, 7, 6, 1, 1], device='cuda:0', dtype=torch.int32)
positions[0]: tensor([[ 1.5977, -3.0446,  0.4738],
        [ 1.9159, -0.8995, -0.2115],
        [ 0.4210,  1.1152,  0.7972],
        [-1.3898,  2.4456, -1.2272],
        [-2.1598,  0.8804,  0.3519],
        [ 3.4245, -0.2894, -1.5262],
        [-3.7901, -0.1763,  0.9885]], device='cuda:0',
       grad_fn=<SelectBackward0>)

Properties of the problematic batch:
isnan positions: False
isinf positions: False
max positions: 3.943268299102783
min positions: -3.950294017791748


# Comments

- Env:
    - Python 3.11
    - dxtb installed from the main branch with `pip install -e .`
    - pip install jupyter notebook tqdm torch tad-libcint


- If the kernel is restarted after each run, it does not break. Maybe something is cached in the calculator?

- It only breaks when the calculation is run inside a function.