In [3]:
import os
from typing import Dict, List, Optional

import numpy as np
import torch
from ase.io import read
from featomic.torch import SoapPowerSpectrum
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import (
    AtomisticModel,
    ModelCapabilities,
    ModelEvaluationOptions,
    ModelMetadata,
    ModelOutput,
    System,
    systems_to_torch,
)
from scipy.stats import moment

# Trajectory to analyse. Set the SOAP_TRAJECTORY environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
TRAJECTORY_PATH = os.environ.get(
    "SOAP_TRAJECTORY",
    "/Users/markusfasching/EPFL/Work/project-SOAP/scripts/SOAP-time-code/data/interfaces/250_275_fast/positions.lammpstrj",
)
FRAME_STRIDE = 50  # read every FRAME_STRIDE-th frame of the trajectory
N_SYSTEMS = 1  # how many of the loaded frames are converted to systems

structures = read(TRAJECTORY_PATH, index=f"::{FRAME_STRIDE}")
systems = systems_to_torch(structures[:N_SYSTEMS])

In [4]:

# SOAP power-spectrum hyper-parameters passed to featomic's SoapPowerSpectrum.
# Trailing comments record values used in earlier runs.
HYPER_PARAMETERS = {
    "cutoff": {
        "radius": 4,  # values tried: 4, 5, 6
        "smoothing": {"type": "ShiftedCosine", "width": 0.5},
    },
    "density": {"type": "Gaussian", "width": 0.25},  # width changed from 0.3
    "basis": {
        "type": "TensorProduct",
        "max_angular": 2,  # earlier: 8
        "radial": {"type": "Gto", "max_radial": 2},  # earlier: 6
    },
}
calculator = SoapPowerSpectrum(**HYPER_PARAMETERS)

# Only request keys with center type 8 and neighbour types 1
# (presumably atomic numbers, i.e. O centres with H-H pairs -- confirm).
centers = [8]
neighbors = [1]
selected_keys = Labels(
    names=["center_type", "neighbor_1_type", "neighbor_2_type"],
    values=torch.tensor(
        [
            [center, first, second]
            for center in centers
            for first in neighbors
            for second in neighbors
            if first <= second  # keep each unordered neighbour pair only once
        ],
        dtype=torch.int32,
    ),
)

# Restrict the samples to atom indices 0..9.
selected_atoms = list(range(10))
selected_samples = Labels(
    names=["atom"],
    values=torch.tensor(selected_atoms, dtype=torch.int64).reshape(-1, 1),
)

# Compute the spectrum, move center_type into the samples and the neighbour
# types into the properties, then take the resulting block.
soap = (
    calculator(systems, selected_samples=selected_samples, selected_keys=selected_keys)
    .keys_to_samples("center_type")
    .keys_to_properties(["neighbor_1_type", "neighbor_2_type"])
)
soap_block = soap.block()


In [None]:
def _binomial(n: int, k: int) -> int:
    """Exact binomial coefficient C(n, k) for 0 <= k <= n, using integer arithmetic.

    Written as a plain loop (instead of ``math.comb``, which TorchScript may
    not support) so the cumulant computation stays scriptable.
    """
    result = 1
    for i in range(k):
        # Here result == C(n, i), so the floor division below is exact.
        result = result * (n - i) // (i + 1)
    return result


def compute_cumulants_fwd(X: torch.Tensor, n_cumulants: int) -> torch.Tensor:
    """
    Per-feature cumulants of ``X``, taken over its rows.

    Uses only TorchScript-compatible constructs (typed lists, plain loops).

    X: (N, P) tensor; cast to float32 with ``X.float()``.
    n_cumulants: number of cumulants per feature, must be >= 1. Any order is
        supported. kappa_1 is the mean, kappa_2 the (1/N) variance mu_2,
        kappa_3 = mu_3, kappa_4 = mu_4 - 3 mu_2^2, kappa_5 = mu_5 - 10 mu_2 mu_3,
        where mu_k are central moments. Higher orders follow from the
        moment-cumulant recursion.
    returns: (1, P * n_cumulants) float32 tensor on X's device. Columns
        ``j * n_cumulants : (j + 1) * n_cumulants`` hold kappa_1 .. kappa_n of
        feature ``j``.

    Raises ValueError if ``X`` is not 2-D or ``n_cumulants < 1``.
    """
    if X.dim() != 2:
        raise ValueError("X must be a 2-D (N, P) tensor")
    if n_cumulants < 1:
        raise ValueError("n_cumulants must be >= 1")

    X = X.float()
    n_features = X.shape[1]

    mean = torch.mean(X, dim=0)  # (P,)
    centered = X - mean  # broadcast over the rows

    # central[k] = mu_k = mean((x - mean)^k) per feature, for k >= 2.
    # Slots 0 and 1 are placeholders (mu_1 of centred data is zero).
    zeros = torch.zeros_like(mean)
    central: List[torch.Tensor] = [zeros, zeros]
    for k in range(2, n_cumulants + 1):
        central.append(torch.mean(centered ** k, dim=0))

    # kappa[n] is the n-th cumulant (slot 0 is a placeholder). Cumulants of
    # order >= 2 are shift-invariant, so they follow from the central moments:
    #   kappa_n = mu_n - sum_{k=2}^{n-2} C(n-1, k-1) * kappa_k * mu_{n-k}
    # Terms involving kappa_1 or mu_1 of the centred data vanish and are skipped.
    kappa: List[torch.Tensor] = [zeros, mean]
    for order in range(2, n_cumulants + 1):
        correction = torch.zeros_like(mean)
        for k in range(2, order - 1):
            coeff = float(_binomial(order - 1, k - 1))
            correction = correction + coeff * kappa[k] * central[order - k]
        kappa.append(central[order] - correction)

    # (P, n_cumulants) in row-major order, so each feature's cumulants are contiguous.
    per_feature = torch.stack(kappa[1:], dim=1)
    return per_feature.reshape(1, n_features * n_cumulants)

In [21]:
# First 3 cumulants of every SOAP feature, taken over the rows of
# soap_block.values (presumably one row per selected atom -- confirm).
# NOTE(review): the printed output saved below came from an earlier version
# of compute_cumulants_fwd; the current function prints nothing.
new = compute_cumulants_fwd(soap_block.values, n_cumulants=3)

torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([1, 3])


In [19]:
# NOTE(review): the saved output below is stale. This cell (In [19]) ran
# before `new` was recomputed in In [21], and the current
# compute_cumulants_fwd returns a single row. Re-run top to bottom.
new.size()

torch.Size([4, 81])

In [16]:
# Shape of the selected SOAP block, presumably (n_samples, n_properties) -- confirm.
soap_block.shape


[4, 27]