In [1]:
import torch
import dxtb
from dxtb.typing import DD

from qcm_ml.orbnet_equi.model import generate_xtb_matrices_dxtb

def add_padding(numbers_list, positions_list, padding_value=0):
    """
    Pads a list of atomic numbers and positions tensors to match the longest tensor in the batch.

    The dtype and device of each output are taken from the first entry of the
    corresponding input list.

    Parameters:
    - numbers_list (List[torch.Tensor]): List of 1D tensors with atomic numbers for each molecule.
    - positions_list (List[torch.Tensor]): List of 2D tensors with atomic positions for each molecule,
      each of shape (natoms, 3) with natoms matching the corresponding numbers tensor.
    - padding_value (int): Value used for padding (default: 0).

    Returns:
    - padded_numbers (torch.Tensor): Padded atomic numbers, shape (nbatch, max_natoms).
    - padded_positions (torch.Tensor): Padded atomic positions, shape (nbatch, max_natoms, 3).

    Raises:
    - ValueError: If the lists are empty, differ in length, or a positions tensor
      does not have shape (natoms, 3) for its molecule.
    """
    nbatch = len(numbers_list)
    if nbatch == 0:
        # Without at least one molecule there is no dtype/device to copy from.
        raise ValueError("add_padding requires at least one molecule, got an empty list.")
    if len(positions_list) != nbatch:
        raise ValueError(
            f"numbers_list and positions_list must have the same length, "
            f"got {nbatch} and {len(positions_list)}."
        )

    # Validate up front so that a (1, 3) or (3,) positions tensor is rejected
    # instead of being silently broadcast over all atoms during the copy below.
    for i, (mol_numbers, mol_positions) in enumerate(zip(numbers_list, positions_list)):
        expected_shape = (mol_numbers.size(0), 3)
        if tuple(mol_positions.shape) != expected_shape:
            raise ValueError(
                f"positions_list[{i}] has shape {tuple(mol_positions.shape)}, "
                f"expected {expected_shape}."
            )

    max_natoms = max(mol_numbers.size(0) for mol_numbers in numbers_list)

    # Initialize padded tensors with padding_value
    padded_numbers = torch.full(
        (nbatch, max_natoms),
        padding_value,
        dtype=numbers_list[0].dtype,
        device=numbers_list[0].device,
    )
    padded_positions = torch.full(
        (nbatch, max_natoms, 3),
        padding_value,
        dtype=positions_list[0].dtype,
        device=positions_list[0].device,
    )

    # Copy each molecule into the leading slots of its row; trailing slots keep padding_value.
    for i, (mol_numbers, mol_positions) in enumerate(zip(numbers_list, positions_list)):
        natoms = mol_numbers.size(0)
        padded_numbers[i, :natoms] = mol_numbers
        padded_positions[i, :natoms, :] = mol_positions

    return padded_numbers, padded_positions


def remove_padding(numbers, positions, padding_value=0):
    """
    Strips padding atoms (and their coordinates) from a padded batch.

    An atom counts as padding when its entry in `numbers` equals `padding_value`.

    Parameters:
    - numbers (torch.Tensor): Tensor of atomic numbers with padding, shape (nbatch, natoms).
    - positions (torch.Tensor): Tensor of atomic positions with padding, shape (nbatch, natoms, 3).
    - padding_value (int): The padding value in `numbers` to remove (default: 0).

    Returns:
    - cleaned_numbers (List[torch.Tensor]): One tensor per batch entry holding its non-padded atomic numbers.
    - cleaned_positions (List[torch.Tensor]): One tensor per batch entry holding its non-padded atomic positions.
    """
    # True wherever a real (non-padding) atom sits, computed for the whole batch at once.
    keep = numbers != padding_value
    batch_indices = range(numbers.size(0))

    cleaned_numbers = [numbers[b][keep[b]] for b in batch_indices]
    cleaned_positions = [positions[b][keep[b]] for b in batch_indices]

    return cleaned_numbers, cleaned_positions


# Shared tensor options: floating-point tensors below are created on the CPU in double precision.
dd: DD = {"device": torch.device("cpu"), "dtype": torch.double}

# Batch of two molecules. The first row has only two atoms (atomic numbers 3, 1);
# its third slot is padding, marked by atomic number 0.
numbers = torch.tensor([[3, 1, 0], [8, 1, 1]], device=dd["device"])

# One (x, y, z) row per atom slot; the padding slot of the first molecule is all zeros.
positions = torch.tensor(
    [
        [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]],
    ],
    **dd
)
# Track gradients so the energy can later be differentiated with respect to the positions.
positions.requires_grad_(True)
print(f"positions.shape (nbatch, nat, 3): {positions.shape}")

# One total charge per molecule in the batch.
charge = torch.tensor([0, 0], **dd)

# Round trip: strip the padding into per-molecule lists ...
n_ls, pos_ls = remove_padding(numbers, positions)
print(f"numbers_list: {n_ls}")
print(f"positions_list: {pos_ls}")

# ... and pad the lists back into batched tensors.
padded_numbers, padded_positions = add_padding(n_ls, pos_ls)
print(f"padded_numbers: {padded_numbers}")
print(f"padded_positions: {padded_positions}")

positions.shape (nbatch, nat, 3): torch.Size([2, 3, 3])
numbers_list: [tensor([3, 1]), tensor([8, 1, 1])]
positions_list: [tensor([[0., 0., 0.],
        [0., 0., 1.]], dtype=torch.float64, grad_fn=<IndexBackward0>), tensor([[0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 2.]], dtype=torch.float64, grad_fn=<IndexBackward0>)]
padded_numbers: tensor([[3, 1, 0],
        [8, 1, 1]])
padded_positions: tensor([[[0., 0., 0.],
         [0., 0., 1.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 1.],
         [0., 0., 2.]]], dtype=torch.float64, grad_fn=<CopySlices>)


# Batched vs. individual calculations

In [2]:

# Batched: a single calculator built from the padded `numbers` tensor evaluates both molecules.
opts = dict(verbosity=0, batch_mode=1)
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
energy = calc.get_energy(positions, chrg=charge, **dd)
# Not used here: calc.get_forces(positions, chrg=charge) does not work for batched calculations.
print(f"energy: {energy}")

# Individual: one calculator per molecule, fed the unpadded numbers/positions from remove_padding.
opts_ind = dict(verbosity=0, batch_mode=0)
calc0, calc1 = (
    dxtb.Calculator(n_ls[idx], dxtb.GFN1_XTB, opts=opts_ind, **dd) for idx in (0, 1)
)

energy0, energy1 = calc0.get_energy(pos_ls[0]), calc1.get_energy(pos_ls[1])
print(f"energy0: {energy0}, energy1: {energy1}")

energy: tensor([ 0.0111, -4.8440], dtype=torch.float64, grad_fn=<SumBackward1>)
energy0: 0.01105307556317131, energy1: -4.843954905460443


# Forces in batches via the gradient of energy.sum() (forces = -gradient)

In [3]:
# Batched: differentiate the summed batch energy with respect to the padded positions.
# retain_graph=True keeps the graph of `energy` alive for further backward passes.
(grad,) = torch.autograd.grad(energy.sum(), positions, retain_graph=True)

# Individual: per-molecule results from the calculators' get_forces.
# Note: in the recorded output these have the opposite sign of `grad`,
# i.e. get_forces yields the negative of the autograd gradient.
grad0, grad1 = calc0.get_forces(pos_ls[0]), calc1.get_forces(pos_ls[1])

# Fixed-point printing with four decimals makes the tensors easy to compare by eye.
torch.set_printoptions(sci_mode=False, precision=4)
print(f"grad: {grad}")
print()
print(f"grad0: {grad0}")
print(f"grad1: {grad1}")

grad: tensor([[[     0.0000,      0.0000,      2.5666],
         [    -0.0000,     -0.0000,     -2.5666],
         [     0.0000,      0.0000,      0.0000]],

        [[    -0.0000,     -0.0000,      2.8084],
         [     0.0000,      0.0000,     -2.2184],
         [     0.0000,      0.0000,     -0.5900]]], dtype=torch.float64)

grad0: tensor([[     0.0000,      0.0000,     -2.5666],
        [    -0.0000,     -0.0000,      2.5666]], dtype=torch.float64)
grad1: tensor([[    -0.0000,      0.0000,     -2.8084],
        [    -0.0000,     -0.0000,      2.2184],
        [     0.0000,     -0.0000,      0.5900]], dtype=torch.float64)


# Differentiate with respect to 2 tensors

In [4]:
# Request the gradient of energy0 with respect to two inputs in a single call
# (here the same tensor, pos_ls[0], is listed twice).
# NOTE(review): the recorded output of this cell is a RuntimeError ("Trying to
# backward through the graph a second time") even though retain_graph=True is
# passed here. Presumably the graph behind energy0 was already freed by an
# earlier backward pass that did not retain it (calc0.get_forces(pos_ls[0]) in
# the previous cell is the likely candidate) -- TODO confirm, and recompute
# energy0 on a fresh graph before this call if so.
torch.autograd.grad(energy0, [pos_ls[0], pos_ls[0]], retain_graph=True, allow_unused=True)

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.