# Run the cell below multiple times: one of two errors appears at random

In [5]:
import os

# Debugging switches for the CUDA / PyTorch runtime. They are exported here,
# ahead of the `import torch` further down in this cell.
os.environ.update(
    {
        "CUDA_LAUNCH_BLOCKING": "1",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "PYTORCH_JIT_DISABLE_NVFUSER": "1",
        "TORCH_COMPILE_DISABLE": "1",
    }
)

###########################################################

import torch
import numpy as np
import dxtb
from tqdm import tqdm


# Anomaly detection and every determinism-related switch are explicitly
# turned off for this run.
torch.autograd.set_detect_anomaly(False)
torch.use_deterministic_algorithms(False, warn_only=False)
torch.set_deterministic_debug_mode(False)

cudnn = torch.backends.cudnn
cudnn.deterministic = cudnn.benchmark = False

# Seed the torch CPU generator, all CUDA generators and NumPy's global
# generator with the same value (0).
for seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
    seed_rng(0)

# Ask the CUDA caching allocator to release unused cached memory before the run.
torch.cuda.empty_cache()

###########################################################

# Record the library versions used for this run alongside the output.
print(f"Torch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"CUDNN version: {torch.backends.cudnn.version()}")

###########################################################

problematic_batch_path = "../problematic_batches/problematic_batch_for_full.pt"

# Tensor options shared by everything created below: float32 on cuda:0.
dd = dict(dtype=torch.float32, device=torch.device("cuda:0"))

# Options forwarded to the dxtb calculator.
opts = dict(
    scf_mode="full",
    batch_mode=2,
    int_driver="libcint",
    exclude=["disp"],  # excluding D3 fixes the issue
)

# Load the problematic batch.
# NOTE(review): weights_only=False permits full pickle deserialisation, so
# only point this at a trusted file.
problematic_batch = torch.load(problematic_batch_path, weights_only=False)

# Atomic numbers are only moved to the device; positions are also cast to
# the dtype in `dd`.
numbers = problematic_batch["numbers"].to(dd["device"])
# NOTE(review): the loop below differentiates w.r.t. `positions`, so the
# stored tensor presumably has requires_grad=True -- confirm.
positions = problematic_batch["positions"].to(**dd)

# One zero charge per batch entry.
batch_size = numbers.size(0)
charges = torch.zeros((batch_size,), **dd)

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)

# Reproduction loop: evaluate energy and gradients 100 times with the same
# calculator, positions and charges. The statements are kept exactly as
# written because the failure being investigated is intermittent.
for i in tqdm(range(100)):
    # Energy of the batch from the calculator built above.
    e = calc.get_energy(positions, chrg=charges)
    # Gradient of the summed energy w.r.t. positions. retain_graph=True keeps
    # the graph of `e` alive so it can be differentiated again below.
    # NOTE(review): no sign flip here (unlike `func` below), so despite the
    # name this holds the gradient, not its negative.
    forces = torch.autograd.grad(e.sum(), positions, retain_graph=True)[0]

    # MF: Also fails, when commented out
    # Features calc
    # Second differentiation of the same energy, negated and routed through
    # a lambda that is re-created on every iteration.
    func = lambda e, p: -torch.autograd.grad(e.sum(), p, retain_graph=True)[0]
    res = func(e, positions)

Torch version: 2.5.1
CUDA version: 12.4
CUDNN version: 90100


 69%|██████▉   | 69/100 [00:21<00:09,  3.17it/s]


KeyboardInterrupt: 

# Checks about the sample

The batch is a sample from the Transition1x dataset, specifically C2H2N2O/rxn2091, i.e. a single molecule at different steps along the reaction path. The positions are given in Bohr.

In [None]:
# Reload the batch exactly as stored (no device move, no dtype cast) to
# sanity-check its contents.
# NOTE: this rebinds `numbers` and `positions`, which the reproduction cell
# above also defines (there moved to the device / cast to `dd`).
# NOTE(review): weights_only=False permits full pickle deserialisation, so
# only point this at a trusted file.
problematic_batch = torch.load(problematic_batch_path, weights_only=False)
numbers = problematic_batch["numbers"]
positions = problematic_batch["positions"]

print(f"numbers.shape: {numbers.shape}")
print(f"positions.shape: {positions.shape}")

# Show a single batch entry; the labels follow `idx` so they stay correct if
# a different entry is selected.
print("\nSubset:")
idx = 0
print(f"numbers[{idx}]: {numbers[idx]}")
print(f"positions[{idx}]: {positions[idx]}")

# Basic numerical health checks on the coordinates.
print("\nProperties of the problematic batch:")
print(f"isnan positions: {torch.isnan(positions).any()}")
print(f"isinf positions: {torch.isinf(positions).any()}")
print(f"max positions: {positions.max()}")
print(f"min positions: {positions.min()}")


# Comments

- Env:
    - Python 3.11
    - dxtb installed from its main branch with `pip install -e .`
    - pip install jupyter notebook tqdm torch tad-libcint

    - Torch 2.5.1, CUDA 12.4, driver 550.120


- If the kernel is restarted after each run, it does not break. Maybe something is cached in the calculator?

- It only breaks when using a function.