# You can run the cell below multiple times and see that one of two errors appears at random

In [2]:
import torch
import dxtb

# Single-run reproduction: load the saved batch, compute the energy with dxtb
# and differentiate it twice with respect to the positions.

# Path of the saved batch that triggers the error.
problematic_batch_path = "problematic_batch_for_full.pt"

# dtype/device keyword arguments shared by the tensors and the calculator below.
dd = {"dtype": torch.float64, "device": torch.device("cuda:0")}
# Options handed to dxtb.Calculator through its `opts` argument.
# NOTE(review): presumably these select the full SCF mode, a batched evaluation
# and the libcint integral driver — confirm against the dxtb documentation.
opts = {"scf_mode": "full", "batch_mode": 1, "int_driver": "libcint"}

# Load the problematic batch
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects, so this
# must only be used on files from a trusted source.
problematic_batch = torch.load(problematic_batch_path, weights_only=False)
# `numbers` is only moved to the target device; its stored dtype is kept.
numbers = problematic_batch["numbers"].to(dd["device"])
# `positions` is converted to the dtype and device in `dd`; the gradients below
# are taken with respect to this tensor.
positions = problematic_batch["positions"].to(**dd)

# One zero charge per entry along the first dimension of `numbers`.
batch_size = numbers.shape[0]
charges = torch.full((batch_size,), 0, **dd)

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)

# NOTE(review): assumes get_energy returns one energy per batch entry — TODO confirm.
e = calc.get_energy(positions, chrg=charges)
# Built-in sum() adds the entries of `e` one by one before differentiating.
# retain_graph=True keeps the graph alive so `e` can be differentiated again below.
# NOTE(review): this is the plain gradient of the summed energy; `func` below
# returns its negative — confirm which sign `forces` is meant to carry.
forces = torch.autograd.grad(sum(e), positions, retain_graph=True)[0]

# Features calc
# Second differentiation of the same `e`, this time wrapped in a function and
# negated; the notes at the end of the notebook report that the error only
# shows up when a function is used.
func = lambda e, p: -torch.autograd.grad(e.sum(), p, retain_graph=True)[0]
res = func(e, positions)

# Or just run this looped version

In [4]:
import torch
import dxtb
from tqdm import tqdm

# Looped reproduction: build the calculator once, then repeat the energy and
# gradient evaluation 100 times on the same batch.

# dtype/device keyword arguments shared by the tensors and the calculator below.
# NOTE(review): float32 here, whereas the single-run cell uses float64 — confirm
# that the difference is intended.
dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
# Options handed to dxtb.Calculator through its `opts` argument.
opts = {"scf_mode": "full", "batch_mode": 1, "int_driver": "libcint"}

# Load the problematic batch
# `problematic_batch_path` is not defined in this cell; it must already exist
# from the single-run cell above.
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects, so this
# must only be used on files from a trusted source.
problematic_batch = torch.load(problematic_batch_path, weights_only=False)
# `numbers` is only moved to the target device; its stored dtype is kept.
numbers = problematic_batch["numbers"].to(dd["device"])
# `positions` is converted to the dtype and device in `dd`.
positions = problematic_batch["positions"].to(**dd)

# One zero charge per entry along the first dimension of `numbers`.
batch_size = numbers.shape[0]
charges = torch.full((batch_size,), 0, **dd)

# The same calculator instance is reused by every iteration of the loop.
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)

for i in tqdm(range(100)):
    # NOTE(review): assumes get_energy returns one energy per batch entry — TODO confirm.
    e = calc.get_energy(positions, chrg=charges)
    # Built-in sum() adds the entries of `e` one by one before differentiating;
    # retain_graph=True keeps the graph so `e` can be differentiated again below.
    forces = torch.autograd.grad(sum(e), positions, retain_graph=True)[0]

    # Features calc
    # Second differentiation of the same `e`, wrapped in a function and negated.
    # The lambda is re-created on every iteration.
    func = lambda e, p: -torch.autograd.grad(e.sum(), p, retain_graph=True)[0]
    res = func(e, positions)

100%|██████████| 100/100 [00:23<00:00,  4.30it/s]


# Checks about the sample

This is a sample from the Transition1x dataset, specifically C2H2N2O/rxn2091, i.e. a single molecule at different steps along the reaction path. The positions are given in Bohr.

In [3]:
# Inspect the batch as stored on disk (no dtype or device conversion applied).
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects, so this
# must only be used on files from a trusted source.
problematic_batch = torch.load(problematic_batch_path, weights_only=False)
numbers = problematic_batch["numbers"]
positions = problematic_batch["positions"]

print(f"numbers.shape: {numbers.shape}")
print(f"positions.shape: {positions.shape}")

# Show one entry of the batch. The labels are built from `idx` so they stay
# correct when a different entry is inspected.
print("\nSubset:")
idx = 0
print(f"numbers[{idx}]: {numbers[idx]}")
print(f"positions[{idx}]: {positions[idx]}")

# Sanity checks: rule out NaN/inf values and show the coordinate range.
print("\nProperties of the problematic batch:")
print(f"isnan positions: {torch.isnan(positions).any()}")
print(f"isinf positions: {torch.isinf(positions).any()}")
print(f"max positions: {positions.max()}")
print(f"min positions: {positions.min()}")


numbers.shape: torch.Size([64, 7])
positions.shape: torch.Size([64, 7, 3])

Subset:
numbers[0]: tensor([8, 6, 7, 6, 7, 1, 1], device='cuda:0', dtype=torch.int32)
positions[0]: tensor([[ 0.8494,  3.3460,  0.0348],
        [ 1.2790, -1.3759, -0.2903],
        [-0.2177, -1.3857, -2.1151],
        [-1.1522, -0.2871,  0.3423],
        [-1.1985,  2.4734,  0.5247],
        [ 3.2372, -1.7409,  0.1731],
        [-2.7249, -1.2143,  1.2866]], device='cuda:0',
       grad_fn=<SelectBackward0>)

Properties of the problematic batch:
isnan positions: False
isinf positions: False
max positions: 3.877077102661133
min positions: -3.9229071140289307


# Comments

- Env:
    - Python 3.11
    - dxtb installed from its main branch with `pip install -e .`
    - pip install jupyter notebook tqdm torch tad-libcint

    - Torch 2.5.1, CUDA 12.4, driver 550.120


- If the kernel is restarted after each run, it does not break. Maybe something is cached in the calculator?

- It only breaks when using a function.