In [1]:
import numpy as np
import torch
import torch.nn.functional
from e3nn import o3
from e3nn.util import jit
from scipy.spatial.transform import Rotation as R

from mace import data, modules, tools
from mace.tools import torch_geometric

# Make float64 the default floating-point dtype for tensors created below.
torch.set_default_dtype(torch.float64)

# Reference labels for the three-atom test structure. They are defined once so
# that the original and the rotated configuration are built from the same data.
species_ref = np.array([8, 4, 1])
forces_ref = np.array(
    [
        [0.0, -1.3, 0.0],
        [1.0, 0.2, 0.0],
        [0.0, 1.1, 0.3],
    ]
)
charges_ref = np.array([-2.0, 1.0, 1.0])
dipole_ref = np.array([-1.5, 1.5, 2.0])
energy_ref = -1.5

config = data.Configuration(
    atomic_numbers=species_ref,
    positions=np.array(
        [
            [0.0, -2.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        ]
    ),
    forces=forces_ref,
    energy=energy_ref,
    charges=charges_ref,
    dipole=dipole_ref,
)

# Create the rotated environment: a 60 degree rotation about the z axis.
rot = R.from_euler("z", 60, degrees=True).as_matrix()
# Positions are stored one atom per row, hence the transposes around `rot @ ...`.
positions_rotated = np.array(rot @ config.positions.T).T
config_rotated = data.Configuration(
    atomic_numbers=species_ref.copy(),
    positions=positions_rotated,
    # Forces and the dipole are vectors, so they are rotated with the same
    # matrix as the positions; energy and charges are scalars and stay as-is.
    forces=(rot @ forces_ref.T).T,
    energy=energy_ref,
    charges=charges_ref.copy(),
    dipole=rot @ dipole_ref,
)

# Element table and one reference energy per entry of that table.
table = tools.AtomicNumberTable([1, 4, 8])
atomic_energies = np.array([1.0, 3.0, 4.0], dtype=float)



In [2]:

# Hyperparameters used to build the MACE model (and reused later in the notebook).
model_config = dict(
    r_max=5,
    num_bessel=8,
    num_polynomial_cutoff=6,
    max_ell=2,
    interaction_cls=modules.interaction_classes[
        "RealAgnosticResidualInteractionBlock"
    ],
    interaction_cls_first=modules.interaction_classes[
        "RealAgnosticResidualInteractionBlock"
    ],
    num_interactions=5,
    num_elements=3,
    hidden_irreps=o3.Irreps("32x0e + 32x1o"),
    MLP_irreps=o3.Irreps("16x0e"),
    gate=torch.nn.functional.silu,
    atomic_energies=atomic_energies,
    avg_num_neighbors=8,
    atomic_numbers=table.zs,
    correlation=3,
    radial_type="bessel",
)

# Eager model plus a compiled copy of the same module (same weights).
model = modules.MACE(**model_config)
model_compiled = jit.compile(model)

# Use the model's own r_max as the neighbour-list cutoff. A smaller hard-coded
# value of 3.0 would put the two atoms of `config` that are exactly 3.0 apart
# right on the cutoff boundary, where round-off in the rotated positions could
# change the edge set and break the invariance check below.
graph_cutoff = float(model_config["r_max"])
atomic_data = data.AtomicData.from_config(
    config, z_table=table, cutoff=graph_cutoff
)
atomic_data2 = data.AtomicData.from_config(
    config_rotated, z_table=table, cutoff=graph_cutoff
)

# Put both structures into a single batch of size two.
data_loader = torch_geometric.dataloader.DataLoader(
    dataset=[atomic_data, atomic_data2],
    batch_size=2,
    shuffle=True,
    drop_last=False,
)
batch = next(iter(data_loader))
output1 = model(batch.to_dict(), training=True)
output2 = model_compiled(batch.to_dict(), training=True)

# The eager and compiled models must agree on the energy of the first graph ...
assert torch.allclose(
    output1["energy"][0], output2["energy"][0]
), "eager and compiled models disagree on the energy"
# ... and the two graphs in the batch (original / rotated) must have equal energy.
assert torch.allclose(
    output2["energy"][0], output2["energy"][1]
), "energy is not invariant under the rotation"




In [3]:
# Construct MACE_LLPR from the eager model, passing every entry of
# model_config through as keyword arguments.
llpr_model = modules.models.MACE_LLPR(
    model,
    per_atom=True,
    **model_config,
)



x
x
x
x
x
x
x
x
y
288




In [4]:
# Display the shape of the node features returned by the eager model.
output1["node_feats"].size()

torch.Size([6, 544])

In [None]:
# NOTE(review): unexecuted scratch cell. The key "nod" looks truncated
# (possibly "node_feats" — confirm), and indexing the model object with a
# string only works if MACE_LLPR defines __getitem__ — verify or remove.
llpr_model["nod"]