# Function in notebook?

In [None]:
import os
import sys
import h5py
import torch
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm

import dxtb

def generate_xtb_features_dxtb(
        element_numbers,
        coordinates,
        charge=0,
        spin=0,
        ):
    

    dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
    opts = {"scf_mode": "implicit", "batch_mode": 2, "int_driver": "libcint"}

    calc = dxtb.Calculator(element_numbers, dxtb.GFN1_XTB, **dd, opts=opts)

    energy = calc.get_energy(coordinates, chrg=charge, spin=spin)
    forces = -torch.autograd.grad(energy.sum(), coordinates, retain_graph=True)[0]

    return energy, forces

In [9]:


dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}

class TransitionBatchDataset(IterableDataset):
    def __init__(self, hdf5_path, split="val", batch_size=64, mol_names=None):
        super().__init__()
        self.hdf5_path = hdf5_path
        self.split = split
        self.batch_size = batch_size
        self.mol_names = mol_names

    def __iter__(self):
        with h5py.File(self.hdf5_path, "r") as f:
            for mol_name in list(self.mol_names or f[f"{self.split}"].keys()):
                mol_group = f[f"{self.split}/{mol_name}"]
                for rxn_name in mol_group.keys():
                    rxn_group = mol_group[rxn_name]
                    positions = rxn_group["positions"]
                    zs = rxn_group["atomic_numbers"][()]
                    n_samples = len(positions)

                    for i in range(0, n_samples, self.batch_size):
                        pos_batch = torch.tensor(positions[i:i+self.batch_size], **dd) / 0.529177 # [A] -> [Bohr]
                        pos_batch.requires_grad_(True)
                        z_batch = torch.tensor([zs] * len(pos_batch), device=dd["device"])  # [B, N]
                        yield {
                            "mol_name": mol_name,
                            "rxn_name": rxn_name,
                            "z": z_batch,
                            "pos": pos_batch,
                            "batch_size": len(pos_batch)
                        }

# Create dataset + dataloader
dataset = TransitionBatchDataset(
    hdf5_path="../../../../../data/Transition1x/data/transition1x.h5",
    batch_size=64,
    mol_names=None
)
dataloader = DataLoader(dataset, batch_size=None)

# Wrap in tqdm and track sample count
sample_count = 0
pbar = tqdm(dataloader, desc="Processing", unit=" datapoints")


i = 0
for batch in pbar:
    i += 1
    # if i <53:
    #     continue
    sample_count += batch["batch_size"]
    pbar.set_description(f"{batch['mol_name']}/{batch['rxn_name']}")
    pbar.set_postfix(total=sample_count)
    
    # print(f"z {batch['z']}")
    # print(f"pos {batch['pos']}")

    numbers = batch["z"]
    positions = batch["pos"]

    # DXTB CALC
    dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
    opts = {"scf_mode": "full", "batch_mode": 2, "int_driver": "libcint"}

    batch_size = batch['z'].shape[0]
    charges = torch.full((batch_size,), 0, **dd)
    spin = torch.full((batch_size,), 0, **dd)

    calc = dxtb.Calculator(batch['z'], dxtb.GFN1_XTB, **dd, opts=opts)

    e = calc.get_energy(batch['pos'], chrg=charges, spin=spin)
    forces = torch.autograd.grad(sum(e), batch['pos'], retain_graph=True)[0]
    
    # Features calc
    res = generate_xtb_features_dxtb(
        batch["z"],
        batch["pos"],
        charge=charges,
        spin=spin
    )


    

  z_batch = torch.tensor([zs] * len(pos_batch), device=dd["device"])  # [B, N]
C2H2N2O/rxn2091: : 2 datapoints [00:01,  1.97 datapoints/s, total=192]


RuntimeError: _Map_base::at

# Saving problematic batch!

In [None]:
# Save the numbers and positions
torch.save(
    {
        "numbers": numbers,
        "positions": positions,
    },
    "problematic_batch.pt",
)

In [None]:
# Load the problematic batch
problematic_batch = torch.load("problematic_batch.pt", weights_only=False)
numbers = problematic_batch["numbers"]
positions = problematic_batch["positions"]

In [2]:
import os
import sys
import h5py
import torch
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm

import dxtb

def generate_xtb_features_dxtb(
        element_numbers,
        coordinates,
        charge=0,
        spin=0,
        ):
    

    dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
    opts = {"scf_mode": "implicit", "batch_mode": 2, "int_driver": "libcint"}

    calc = dxtb.Calculator(element_numbers, dxtb.GFN1_XTB, **dd, opts=opts)

    energy = calc.get_energy(coordinates, chrg=charge, spin=spin)
    forces = -torch.autograd.grad(energy.sum(), coordinates, retain_graph=True)[0]

    return energy, forces


# Load the problematic batch
problematic_batch = torch.load("problematic_batch.pt", weights_only=False)
numbers = problematic_batch["numbers"]
positions = problematic_batch["positions"]

dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
opts = {"scf_mode": "full", "batch_mode": 2, "int_driver": "libcint"}

batch_size = numbers.shape[0]
charges = torch.full((batch_size,), 0, **dd)
spin = torch.full((batch_size,), 0, **dd)

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)

e = calc.get_energy(positions, chrg=charges, spin=spin)
forces = torch.autograd.grad(sum(e), positions, retain_graph=True)[0]

# Features calc
res = generate_xtb_features_dxtb(
   numbers,
    positions,
    charge=charges,
    spin=spin,
)

RuntimeError: _Map_base::at

# Disable caching!

In [8]:
import os
import sys
import h5py
import torch
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm

import dxtb
from dxtb.config import ConfigCache

def generate_xtb_features_dxtb(
        element_numbers,
        coordinates,
        charge=0,
        spin=0,
        ):
    

    dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
    opts = {"scf_mode": "implicit", "batch_mode": 2, "int_driver": "libcint"}

    calc = dxtb.Calculator(element_numbers, dxtb.GFN1_XTB, **dd, opts=opts)

    energy = calc.get_energy(coordinates, chrg=charge, spin=spin)
    forces = -torch.autograd.grad(energy.sum(), coordinates, retain_graph=True)[0]

    return energy, forces

# Load the problematic batch
problematic_batch = torch.load("problematic_batch.pt", weights_only=False)
numbers = problematic_batch["numbers"]
positions = problematic_batch["positions"]

dd = {"dtype": torch.float32, "device": torch.device("cuda:0")}
opts = {"scf_mode": "full", "batch_mode": 2, "int_driver": "libcint"}

batch_size = numbers.shape[0]
charges = torch.full((batch_size,), 0, **dd)
spin = torch.full((batch_size,), 0, **dd)

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)
calc.opts.cache = ConfigCache(enabled=False, density=False, fock=False)

e = calc.get_energy(positions, chrg=charges, spin=spin)
forces = torch.autograd.grad(sum(e), positions, retain_graph=True)[0]

# Features calc
res = generate_xtb_features_dxtb(
   numbers,
    positions,
    charge=charges,
    spin=spin,
)

RuntimeError: _Map_base::at