In [2]:
import numpy as np
import torch

def _get_permutation_map_closedshell_convention(element_numbers_batch, calculator="tblite"):
    """
    Create permutation maps to reorder orbital matrices from calculator conventions
    to OrbNet-Equi's convention.

    Parameters:
    - element_numbers_batch: List or array of atomic numbers for each molecule in the batch.
    - calculator: str, either "tblite" or "dxtb".

    Returns:
    - List of numpy arrays containing permutation indices for each molecule.
    """
    perm_maps = []
    for element_numbers in element_numbers_batch:
        perm_map = []
        idx = 0
        for el in element_numbers:
            if el == 1:
                perm_map.extend([idx, idx + 1])
                idx += 2
            else:
                if calculator == "tblite":
                    # [s, px, py, pz] -> [s, pz, py, px]
                    perm_map.extend([idx, idx + 3, idx + 2, idx + 1])
                elif calculator == "dxtb":
                    # [s, pz, px, py] -> [s, pz, py, px]
                    perm_map.extend([idx, idx + 1, idx + 3, idx + 2])
                idx += 4
        perm_maps.append(perm_map)
    return np.array(perm_maps, dtype=np.int64)


def apply_perm_map(array, perm_maps, pos):
    """
    Apply permutation map(s) to a matrix, a batch of matrices, or their gradients.

    The same index map is applied to both orbital axes, i.e.
    out[..., i, j, ...] = array[..., p[i], p[j], ...].

    Parameters:
    - array: np.ndarray with one of the shapes
        * (nao, nao) or (1, nao, nao)                  -- one molecule, one map
        * (nao, nao, nat, 3) or (1, nao, nao, nat, 3)  -- one molecule's gradients, one map
        * (batch_size, nao, nao)                       -- batch, one map per molecule
        * (batch_size, nao, nao, nat, 3)               -- batched gradients
    - perm_maps: List of np.ndarray (or a 2D integer array), permutation indices for
      each element in the batch.
    - pos: np.ndarray, atomic positions (currently unused but could be used for future extensions).

    Returns:
    - Permuted array with the same shape as `array`.

    Raises:
    - ValueError: for an unsupported shape, an empty `perm_maps`, or when the batch
      size of `array` does not equal the number of permutation maps.
    """
    if len(perm_maps) == 0:
        raise ValueError("perm_maps must contain at least one permutation map.")

    if len(perm_maps) == 1:  # Non-batched case
        p = perm_maps[0]
        if array.ndim in (2, 4):  # (nao, nao) features or (nao, nao, nat, 3) gradients
            return array[p][:, p]
        if array.ndim in (3, 5) and array.shape[0] == 1:  # same, with a leading batch dim of 1
            return array[:, p][:, :, p]
        raise ValueError(f"Unsupported array shape {array.shape}.")

    # Batched case with one permutation map per batch element.
    perm_maps = np.asarray(perm_maps)
    if perm_maps.ndim != 2:
        raise ValueError("Batched permutation maps must all have the same length.")
    if array.ndim not in (3, 5):
        raise ValueError(f"Unsupported array shape {array.shape}.")
    batch_size, nao = perm_maps.shape
    # take_along_axis broadcasts the index, so a mismatch (e.g. array batch 1 vs
    # two maps) would otherwise be silently broadcast instead of rejected.
    if array.shape[0] != batch_size:
        raise ValueError(
            f"Batch size {array.shape[0]} does not match the number of permutation maps {batch_size}."
        )
    # Index shapes: rows (B, nao, 1, [1, 1]) and columns (B, 1, nao, [1, 1]);
    # trailing singleton axes broadcast over the gradient dimensions.
    row_index = perm_maps.reshape(batch_size, nao, *([1] * (array.ndim - 2)))
    col_index = perm_maps.reshape(batch_size, 1, nao, *([1] * (array.ndim - 3)))
    permuted_rows = np.take_along_axis(array, row_index, axis=1)
    return np.take_along_axis(permuted_rows, col_index, axis=2)


def _get_permutation_map_closedshell_convention_torch(element_numbers_batch, calculator="tblite", device="cpu"):
    perm_maps = []
    for element_numbers in element_numbers_batch:
        perm_map = []
        idx = 0
        for el in element_numbers:
            if el == 1:
                perm_map.extend([idx, idx + 1])
                idx += 2
            else:
                if calculator == "tblite":
                    # [s, px, py, pz] -> [s, pz, py, px]
                    perm_map.extend([idx, idx + 3, idx + 2, idx + 1])
                elif calculator == "dxtb":
                    # [s, pz, px, py] -> [s, pz, py, px]
                    perm_map.extend([idx, idx + 1, idx + 3, idx + 2])
                idx += 4
        perm_maps.append(torch.tensor(perm_map, dtype=torch.long, device=device))
    return perm_maps


def apply_perm_map_torch(array, perm_maps, pos):
    """
    Apply permutation map(s) to a matrix, a batch of matrices, or their gradients.

    The same index map is applied to both orbital axes, i.e.
    out[..., i, j, ...] = array[..., p[i], p[j], ...].

    Parameters:
    - array: torch.Tensor with one of the shapes
        * (nao, nao) or (1, nao, nao)                  -- one molecule, one map
        * (nao, nao, nat, 3) or (1, nao, nao, nat, 3)  -- one molecule's gradients, one map
        * (batch_size, nao, nao)                       -- batch, one map per molecule
        * (batch_size, nao, nao, nat, 3)               -- batched gradients
    - perm_maps: List of torch.Tensor, permutation indices for each element in the batch.
      In the batched case all maps must have the same length.
    - pos: torch.Tensor, atomic positions (currently unused but could be used for future extensions).

    Returns:
    - Permuted tensor with the same shape as `array`.

    Raises:
    - ValueError: for an unsupported shape, an empty `perm_maps`, or when the batch
      size of `array` does not equal the number of permutation maps.
    """
    device = array.device  # index tensors must live on the same device as `array`

    if len(perm_maps) == 0:
        raise ValueError("perm_maps must contain at least one permutation map.")

    if len(perm_maps) == 1:  # Non-batched case
        p = perm_maps[0].to(device)
        if array.ndim in (2, 4):  # (nao, nao) features or (nao, nao, nat, 3) gradients
            return array[p][:, p]
        if array.ndim in (3, 5) and array.shape[0] == 1:  # same, with a leading batch dim of 1
            return array[:, p][:, :, p]
        raise ValueError(f"Unsupported array shape {array.shape}.")

    # Batched case with one permutation map per batch element.
    if array.ndim not in (3, 5):
        raise ValueError(f"Unsupported array shape {array.shape}.")
    # torch.gather only requires index.size(d) <= input.size(d), so fewer maps than
    # batch items would silently return a smaller batch; reject it explicitly.
    if array.shape[0] != len(perm_maps):
        raise ValueError(
            f"Batch size {array.shape[0]} does not match the number of permutation maps {len(perm_maps)}."
        )
    perm = torch.stack([p.to(device) for p in perm_maps])  # (batch_size, nao)
    batch_size, nao = perm.shape
    trailing = [1] * (array.ndim - 3)  # singleton axes for (nat, 3) gradient dims
    # Permute rows: out[b, i, j, ...] = array[b, perm[b, i], j, ...]
    row_index = perm.view(batch_size, nao, 1, *trailing).expand(array.shape)
    permuted = torch.gather(array, dim=1, index=row_index)
    # Permute columns: out[b, i, j, ...] = permuted[b, i, perm[b, j], ...]
    col_index = perm.view(batch_size, 1, nao, *trailing).expand(array.shape)
    return torch.gather(permuted, dim=2, index=col_index)


In [16]:
# Li + H test system: Li contributes 4 orbitals and H contributes 2, so F is 6x6.
numbers = torch.tensor([3, 1])
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]])

F = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5, 1.6],
                  [2.1, 2.2, 2.3, 2.4, 2.5, 2.6],
                  [3.1, 3.2, 3.3, 3.4, 3.5, 3.6],
                  [4.1, 4.2, 4.3, 4.4, 4.5, 4.6],
                  [5.1, 5.2, 5.3, 5.4, 5.5, 5.6],
                  [6.1, 6.2, 6.3, 6.4, 6.5, 6.6]])
F_batch = torch.stack([F, F])

# Single molecule: NumPy and torch code paths.
perm_maps_np = _get_permutation_map_closedshell_convention([numbers], calculator="dxtb")
perm_maps_torch = _get_permutation_map_closedshell_convention_torch([numbers], calculator="dxtb", device="cpu")

print(f"perm_maps_np: {perm_maps_np}")
print(f"perm_maps_torch: {perm_maps_torch}")

F_perm_np = apply_perm_map(F.numpy(), perm_maps_np, positions.numpy())
F_perm_torch = apply_perm_map_torch(F, perm_maps_torch, positions)

print(f"F_perm_np: {F_perm_np}")
print(f"F_perm_torch: {F_perm_torch}")

# Batch of two identical molecules: exercises the batched branches.
perm_maps_np_batch = _get_permutation_map_closedshell_convention([numbers, numbers], calculator="dxtb")
perm_maps_torch_batch = _get_permutation_map_closedshell_convention_torch([numbers, numbers], calculator="dxtb", device="cpu")

print(f"perm_maps_np_batch: {perm_maps_np_batch}")
print(f"perm_maps_torch_batch: {perm_maps_torch_batch}")

F_perm_np_batch = apply_perm_map(F_batch.numpy(), perm_maps_np_batch, positions.numpy())
F_perm_torch_batch = apply_perm_map_torch(F_batch, perm_maps_torch_batch, positions)

print(f"F_perm_np_batch: {F_perm_np_batch}")
print(f"F_perm_torch_batch: {F_perm_torch_batch}")

# Cross-checks instead of eyeballing the prints: for dxtb ([s, pz, px, py]) the
# Li block only swaps its px/py orbitals (indices 2 and 3); H is left in place,
# and the NumPy and torch paths must agree in both single and batched modes.
expected_perm = [0, 1, 3, 2, 4, 5]
assert perm_maps_np.tolist() == [expected_perm]
assert [p.tolist() for p in perm_maps_torch] == [expected_perm]
assert perm_maps_np_batch.tolist() == [expected_perm, expected_perm]
np.testing.assert_allclose(F_perm_np, F_perm_torch.numpy())
np.testing.assert_allclose(F_perm_np_batch, F_perm_torch_batch.numpy())

perm_maps_np: [[0 1 3 2 4 5]]
perm_maps_torch: [tensor([0, 1, 3, 2, 4, 5])]
F_perm_np: [[1.1 1.2 1.4 1.3 1.5 1.6]
 [2.1 2.2 2.4 2.3 2.5 2.6]
 [4.1 4.2 4.4 4.3 4.5 4.6]
 [3.1 3.2 3.4 3.3 3.5 3.6]
 [5.1 5.2 5.4 5.3 5.5 5.6]
 [6.1 6.2 6.4 6.3 6.5 6.6]]
F_perm_torch: tensor([[1.1000, 1.2000, 1.4000, 1.3000, 1.5000, 1.6000],
        [2.1000, 2.2000, 2.4000, 2.3000, 2.5000, 2.6000],
        [4.1000, 4.2000, 4.4000, 4.3000, 4.5000, 4.6000],
        [3.1000, 3.2000, 3.4000, 3.3000, 3.5000, 3.6000],
        [5.1000, 5.2000, 5.4000, 5.3000, 5.5000, 5.6000],
        [6.1000, 6.2000, 6.4000, 6.3000, 6.5000, 6.6000]])
perm_maps_np_batch: [[0 1 3 2 4 5]
 [0 1 3 2 4 5]]
perm_maps_torch_batch: [tensor([0, 1, 3, 2, 4, 5]), tensor([0, 1, 3, 2, 4, 5])]
F_perm_np_batch: [[[1.1 1.2 1.4 1.3 1.5 1.6]
  [2.1 2.2 2.4 2.3 2.5 2.6]
  [4.1 4.2 4.4 4.3 4.5 4.6]
  [3.1 3.2 3.4 3.3 3.5 3.6]
  [5.1 5.2 5.4 5.3 5.5 5.6]
  [6.1 6.2 6.4 6.3 6.5 6.6]]

 [[1.1 1.2 1.4 1.3 1.5 1.6]
  [2.1 2.2 2.4 2.3 2.5 2.6]
  [4.1 4.2 4.