In [1]:
import numpy as np
import torch
import torch.nn.functional
from e3nn import o3
from e3nn.util import jit
from scipy.spatial.transform import Rotation as R
from mace import data, modules, tools
from mace.tools import torch_geometric
torch.set_default_dtype(torch.float64)

# Hand-written O / Be / H test structure and labels.
# Pairwise distances: O-Be = sqrt(5) ~ 2.24, O-H = 3.0 exactly, Be-H = sqrt(2) ~ 1.41.
# This matters when choosing the neighbour-list cutoff for these configurations.
atomic_numbers = np.array([8, 4, 1])
positions = np.array(
    [
        [0.0, -2.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
    ]
)
forces = np.array(
    [
        [0.0, -1.3, 0.0],
        [1.0, 0.2, 0.0],
        [0.0, 1.1, 0.3],
    ]
)
charges = np.array([-2.0, 1.0, 1.0])
dipole = np.array([-1.5, 1.5, 2.0])
energy = -1.5

config = data.Configuration(
    atomic_numbers=atomic_numbers,
    positions=positions,
    forces=forces,
    energy=energy,
    charges=charges,
    dipole=dipole,
)

# Create the rotated environment: the same structure rotated 60 degrees about z.
# Vector quantities (positions, per-atom forces, dipole) are rotated with it, so the
# copy is a consistent rotated image. Scalar labels (energy, charges) are unchanged.
rot = R.from_euler("z", 60, degrees=True).as_matrix()
positions_rotated = (rot @ positions.T).T
config_rotated = data.Configuration(
    atomic_numbers=atomic_numbers.copy(),
    positions=positions_rotated,
    forces=(rot @ forces.T).T,
    energy=energy,
    charges=charges.copy(),
    dipole=rot @ dipole,
)

# Element table and per-element reference energies. The energies are presumably
# matched to table.zs by position (H, Be, O) -- confirm against mace.tools.
table = tools.AtomicNumberTable([1, 4, 8])
atomic_energies = np.array([1.0, 3.0, 4.0], dtype=float)




In [2]:

# Build a small MACE model and script it. Then check that (a) the eager and
# TorchScript models agree, and (b) the energy is invariant under the rotation.
torch.manual_seed(0)  # weights are randomly initialised; fix them for reproducible outputs

model_config = dict(
    r_max=5,
    num_bessel=8,
    num_polynomial_cutoff=6,
    max_ell=2,
    interaction_cls=modules.interaction_classes[
        "RealAgnosticResidualInteractionBlock"
    ],
    interaction_cls_first=modules.interaction_classes[
        "RealAgnosticResidualInteractionBlock"
    ],
    num_interactions=5,
    num_elements=3,
    hidden_irreps=o3.Irreps("32x0e + 32x1o"),
    MLP_irreps=o3.Irreps("16x0e"),
    gate=torch.nn.functional.silu,
    atomic_energies=atomic_energies,
    avg_num_neighbors=8,
    atomic_numbers=table.zs,
    correlation=3,
    radial_type="bessel",
)
model = modules.MACE(**model_config)
model_compiled = jit.compile(model)

# Build the graphs with the model's own r_max. A hard-coded 3.0 would disagree with
# r_max and sit exactly on the O-H distance (3.0). Round-off from the rotation could
# then keep that edge in one copy and drop it in the other.
graph_cutoff = float(model_config["r_max"])
atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=graph_cutoff)
atomic_data2 = data.AtomicData.from_config(
    config_rotated, z_table=table, cutoff=graph_cutoff
)
data_loader = torch_geometric.dataloader.DataLoader(
    dataset=[atomic_data, atomic_data2],
    batch_size=2,
    shuffle=False,  # keep the [original, rotated] order deterministic
    drop_last=False,
)
batch = next(iter(data_loader))
output1 = model(batch.to_dict(), training=True)
output2 = model_compiled(batch.to_dict(), training=True)
# Eager and scripted models must agree on every graph in the batch...
assert torch.allclose(output1["energy"], output2["energy"])
# ...and the energy must not change when the structure is rotated.
assert torch.allclose(output2["energy"][0], output2["energy"][1])



In [3]:
# Wrap the scripted MACE model with the LLPR head. With ll_feat_format="avg_over_atom",
# the recorded `ll_feats` output has one row per structure in the batch.
# NOTE(review): the readouts of `model_compiled` are RecursiveScriptModules (see the
# printed output). Scripting this wrapper in the next cell fails with
# "Unknown readout for LLPR!". This is possibly a readout-type check in models.py that
# does not recognise scripted/mangled readout classes -- confirm there.
llpr_model = modules.models.LLPredRigidityMACE(
    model_compiled, ll_feat_format="avg_over_atom"
)
# Pass a plain dict, matching how the scripted models (and the scripted LLPR) are called.
llpr_model(batch.to_dict())

RecursiveScriptModule(
  original_name=LinearReadoutBlock
  (linear): RecursiveScriptModule(
    original_name=Linear
    (_compiled_main): RecursiveScriptModule(original_name=GraphModule)
  )
)
RecursiveScriptModule(
  original_name=LinearReadoutBlock
  (linear): RecursiveScriptModule(
    original_name=Linear
    (_compiled_main): RecursiveScriptModule(original_name=GraphModule)
  )
)
RecursiveScriptModule(
  original_name=LinearReadoutBlock
  (linear): RecursiveScriptModule(
    original_name=Linear
    (_compiled_main): RecursiveScriptModule(original_name=GraphModule)
  )
)
RecursiveScriptModule(
  original_name=LinearReadoutBlock
  (linear): RecursiveScriptModule(
    original_name=Linear
    (_compiled_main): RecursiveScriptModule(original_name=GraphModule)
  )
)
RecursiveScriptModule(
  original_name=NonLinearReadoutBlock
  (linear_1): RecursiveScriptModule(
    original_name=Linear
    (_compiled_main): RecursiveScriptModule(original_name=GraphModule)
  )
  (non_linearity): Rec

{'energy': tensor([8.1135, 8.1135], grad_fn=<SumBackward1>),
 'node_energy': tensor([3.7032, 2.8552, 1.5551, 3.7032, 2.8552, 1.5551],
        grad_fn=<SumBackward1>),
 'contributions': tensor([[ 8.0000,  0.1263,  0.1325, -0.0208, -0.0144, -0.1101],
         [ 8.0000,  0.1263,  0.1325, -0.0208, -0.0144, -0.1101]],
        grad_fn=<StackBackward0>),
 'forces': tensor([[-0.0159,  0.0239, -0.0000],
         [-0.2382, -0.0917, -0.0000],
         [ 0.2541,  0.0679, -0.0000],
         [ 0.0127,  0.0257, -0.0000],
         [-0.1985,  0.1604, -0.0000],
         [ 0.1858, -0.1861, -0.0000]]),
 'virials': None,
 'stress': None,
 'displacement': tensor([[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]),
 'll_feats': tensor([[ 0.1767,  0.1193, -0.4057,  ...,  0.0266,  0.0101,  0.0175],
         [ 0.1767,  0.1193, -0.4057,  ...,  0.0266,  0.0101,  0.0175]],
        grad_fn=<DivBackward0>),
 'uncertainty': None}

In [4]:
# Script the LLPR wrapper and run it on the same batch as the eager call.
# NOTE(review): as recorded in the output, calling the scripted wrapper currently
# raises "Unknown readout for LLPR!" from mace/modules/models.py.
llpr_compiled = jit.compile(llpr_model)
llpr_batch = batch.to_dict()
llpr_compiled(llpr_batch)

<__torch__.mace.modules.blocks.LinearReadoutBlock object at 0x29817c8d0>
<__torch__.mace.modules.blocks.___torch_mangle_203.LinearReadoutBlock object at 0x29817e740>


Error: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/Users/sanggyu/Research/mace/mace/modules/models.py", line 418, in forward
                node_energies = node_energies.squeeze(-1)  # [n_nodes, ]
            else:
                raise TypeError("Unknown readout for LLPR!")
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
            energy = scatter_sum(
builtins.TypeError: Unknown readout for LLPR!


In [10]:
# Print the child names of the wrapped model's readouts container, one per line.
readout_children = llpr_compiled.orig_model.readouts.named_children()
for child_name in (pair[0] for pair in readout_children):
    print(child_name)

0
1
2
3
4
