In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Set before numpy/torch are imported.
# NOTE(review): presumably a workaround for the Intel OpenMP "OMP: Error #15"
# abort raised when two copies of the OpenMP runtime get loaded (e.g. one
# bundled by numpy, one by torch). Intel documents this override as unsafe:
# it may cause crashes or silently incorrect results, so prefer fixing the
# environment if possible.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

In [3]:
import numpy as np
import torch
import torch.nn as nn

In [4]:
import sys

# Make the in-repo `cace` package (one directory up) importable without
# installing it. The membership guard keeps this cell idempotent: re-running
# it does not keep appending duplicate '../' entries to sys.path.
# NOTE: '../' is resolved against the kernel's working directory, so the
# notebook is expected to be launched from its own folder.
if '../' not in sys.path:
    sys.path.append('../')

In [5]:
import cace
from cace import data

In [6]:
from scipy.spatial.transform import Rotation as R

In [7]:
# Work in double precision throughout so the rotation-invariance comparison
# later in the notebook is not dominated by float32 round-off.
torch.set_default_dtype(torch.double)

In [8]:
# Reference fixture: one O and three H atoms with synthetic positions, forces,
# energy, charges and dipole. Values are shared by the original and rotated
# configurations; only vector quantities differ between the two.
ref_atomic_numbers = np.array([8, 1, 1, 1])
ref_positions = np.array(
    [
        [0.0452, -2.02, 0.0452],
        [1.0145, 0.034, 0.0232],
        [0.0111, 1.041, -0.010],
        [-0.0111, -0.041, 0.510],
    ]
)
ref_forces = np.array(
    [
        [0.0, -1.3, 0.0],
        [1.0, 0.2, 0.0],
        [0.0, 1.1, 0.3],
        [0.0, 1.1, 0.3],
    ]
)
ref_energy = -1.5
ref_charges = np.array([-2.0, 1.0, 1.0, 0])
ref_dipole = np.array([-1.5, 1.5, 2.0])

config = data.Configuration(
    atomic_numbers=ref_atomic_numbers,
    positions=ref_positions,
    forces=ref_forces,
    energy=ref_energy,
    charges=ref_charges,
    dipole=ref_dipole,
)

# Create the rotated environment.
# The lowercase (extrinsic) "zxy" sequence rotates 70 deg about z, then 10.6 deg
# about x, then 190 deg about y, all around the fixed axes, i.e.
# rot = Ry(190) @ Rx(10.6) @ Rz(70). This is the same result as applying the
# three single-axis rotations one after another.
rot = R.from_euler("zxy", [70, 10.6, 190], degrees=True).as_matrix()

# Positions and forces are stored row-wise as (n_atoms, 3), so `v @ rot.T`
# maps every row v to rot @ v. The dipole is a single 3-vector. Forces and
# dipole are vectors and must rotate with the geometry, otherwise the rotated
# fixture is physically inconsistent. Energy and charges are scalars and stay
# unchanged.
positions_rotated = ref_positions @ rot.T
forces_rotated = ref_forces @ rot.T
dipole_rotated = rot @ ref_dipole

config_rotated = data.Configuration(
    # Copies avoid aliasing the reference arrays between the two configs.
    atomic_numbers=ref_atomic_numbers.copy(),
    positions=positions_rotated,
    forces=forces_rotated,
    energy=ref_energy,
    charges=ref_charges.copy(),
    dipole=dipole_rotated,
)


In [9]:
# Single cutoff radius reused by the graph construction, the radial basis,
# the cutoff function and the CACE model below.
# NOTE(review): presumably in the same length units as the positions.
cutoff: float = 5.0

In [10]:
# Convert the original and the rotated configuration into AtomicData graphs,
# both built with the same cutoff so their edges are directly comparable.
atomic_data, atomic_data2 = (
    data.AtomicData.from_config(cfg, cutoff=cutoff)
    for cfg in (config, config_rotated)
)

In [11]:
# Positions of the reference configuration as stored in its AtomicData graph.
atomic_data.positions

tensor([[ 0.0452, -2.0200,  0.0452],
        [ 1.0145,  0.0340,  0.0232],
        [ 0.0111,  1.0410, -0.0100],
        [-0.0111, -0.0410,  0.5100]])

In [12]:
# Positions of the rotated configuration: they differ from the reference,
# while the node features computed below are expected to stay the same.
atomic_data2.positions

tensor([[-1.8716, -0.6457,  0.4060],
        [-0.3450,  0.9442, -0.1426],
        [ 0.9496,  0.3621, -0.2259],
        [-0.1205, -0.1179, -0.4832]])

In [13]:
from cace.representations.cace_representation import Cace

In [14]:
# Summary of the graph fields (edge_index, shifts, per-atom arrays, ...).
atomic_data

AtomicData(atomic_numbers=[4], cell=[3, 3], charges=[4], dipole=[1, 3], edge_index=[2, 12], energy=-1.5, energy_weight=1.0, forces=[4, 3], forces_weight=1.0, positions=[4, 3], shifts=[12, 3], stress_weight=1.0, unit_shifts=[12, 3], virials_weight=1.0, weight=1.0)

In [15]:
from cace.modules import CosineCutoff, MollifierCutoff, PolynomialCutoff
from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered, ExponentialDecayRBF

In [16]:
# Fixed (non-trainable) Bessel radial basis with 5 functions, plus a cosine
# cutoff function, both limited to the shared `cutoff`.
# Alternative previously tried here:
#   ExponentialDecayRBF(n_rbf=4, cutoff=cutoff, prefactor=1, trainable=True)
radial_basis = BesselRBF(n_rbf=5, cutoff=cutoff, trainable=False)
cutoff_fn = CosineCutoff(cutoff=cutoff)

In [17]:
# CACE representation for the elements in the fixture (H=1, O=8).
# With timeit=True the forward pass prints per-stage timings, as seen in the
# outputs of the next two cells.
cace_representation = Cace(
    zs=[1, 8],
    n_atom_basis=2,
    cutoff=cutoff,
    radial_basis=radial_basis,
    cutoff_fn=cutoff_fn,
    max_l=6,
    max_nu=4,
    num_message_passing=2,
    timeit=True,
)

In [18]:
# Forward pass on the reference configuration. The result is indexed by key
# below ('node_feats'); timings are printed because timeit=True.
cace_result = cace_representation(atomic_data)

node_one_hot time: 0.0028450489044189453
node_embedded time: 0.0012078285217285156
encoded_edges time: 0.0010211467742919922
edge_vectors time: 0.0008630752563476562
radial and angular component time: 0.0019729137420654297
elementwise_multiply_3tensors time: 0.003393888473510742
scatter_sum time: 0.016660213470458984
radial_transform time: 0.0028228759765625
symmetrizer time: 0.18505287170410156
message passing time: 0.1920318603515625
message passing time: 0.18765807151794434


In [19]:
# Same model, same weights, applied to the rotated configuration.
cace_result2 = cace_representation(atomic_data2)

node_one_hot time: 9.703636169433594e-05
node_embedded time: 0.0001971721649169922
encoded_edges time: 6.67572021484375e-05
edge_vectors time: 4.601478576660156e-05
radial and angular component time: 0.0005249977111816406
elementwise_multiply_3tensors time: 0.002585887908935547
scatter_sum time: 5.602836608886719e-05
radial_transform time: 0.00040411949157714844
symmetrizer time: 0.18598198890686035
message passing time: 0.1862947940826416
message passing time: 0.1874370574951172


In [20]:
# Node features of the reference configuration.
features = cace_result['node_feats']

In [21]:
# Node features of the rotated configuration.
features2 = cace_result2['node_feats']

In [22]:
# Rotation-invariance check: the node features of the rotated copy must match
# the reference element-wise within rtol = atol = 1e-5. The assert makes a
# regression fail loudly under Restart & Run All, instead of just displaying
# False. The trailing bare expression still renders the result.
features_match = torch.allclose(features, features2, rtol=1e-05, atol=1e-05)
assert features_match, "CACE node features are not invariant under rotation"
features_match

True

In [23]:
# Shape of the node-feature tensor. Axis 0 has one entry per atom (4 atoms).
# NOTE(review): the other axes presumably index the radial basis (n_rbf=5),
# angular/body-order features, channels, and message-passing stages
# (num_message_passing + 1 = 3). Confirm against the Cace implementation.
features.shape

torch.Size([4, 5, 36, 4, 3])

In [24]:
# One slice for inspection: atom 0, index 2 on axis 1, every entry of axis 2,
# index 0 on the last two axes.
features[0,2,:,0,0]

tensor([ 6.4856e-02,  2.4586e-02,  7.7137e-03,  1.8188e-02,  6.4129e-03,
         5.0591e-03,  2.4476e-03,  1.9887e-03, -1.5421e-03,  1.5007e-03,
        -7.1045e-04,  4.2688e-04, -4.2294e-04, -7.4507e-04, -1.5998e-04,
         3.8424e-04, -1.2107e-04,  1.1012e-04, -6.3107e-05,  3.9755e-05,
         1.1552e-04,  1.7388e-04,  2.5857e-05, -5.9758e-05,  2.3023e-05,
         3.8166e-05, -9.4524e-05,  4.4998e-05, -2.4850e-05,  4.4167e-05,
         5.6871e-06, -2.6322e-05, -4.2208e-05,  1.3931e-05,  2.4055e-05,
        -1.3016e-05], grad_fn=<SelectBackward0>)

In [25]:
# Reference slice: atom 0, index 1 on axis 1, index 1 on the last axis.
# It is compared by eye with the rotated configuration's slice further down.
features[0,1,:,0,1]

tensor([1.0186e-01, 6.2837e-04, 7.7271e-04, 9.2403e-04, 8.6461e-03, 2.4789e-04,
        1.9048e-02, 1.3659e-05, 1.5787e-05, 4.3820e-05, 9.1601e-06, 2.8647e-05,
        4.9899e-05, 7.3475e-06, 2.1580e-04, 6.7309e-05, 3.5177e-07, 9.6967e-07,
        2.6066e-07, 5.2568e-07, 9.4985e-07, 2.3408e-07, 3.7018e-06, 2.4481e-07,
        3.7749e-06, 6.5026e-07, 1.0861e-06, 4.2402e-07, 5.4014e-07, 2.9400e-07,
        1.3626e-05, 7.3267e-07, 6.0503e-07, 1.6987e-06, 1.9980e-06, 5.1601e-07],
       grad_fn=<SelectBackward0>)

In [26]:
# Same as the previous slice but at index 2 on the last axis.
features[0,1,:,0,2]

tensor([8.0173e-02, 3.6753e-05, 1.3638e-03, 1.4403e-04, 5.4463e-03, 5.1618e-05,
        4.9542e-03, 9.9414e-07, 1.8730e-06, 4.0433e-06, 2.4111e-06, 1.6095e-06,
        7.1048e-05, 1.9018e-06, 1.2551e-04, 5.4381e-06, 5.1122e-08, 1.0345e-07,
        5.6573e-08, 4.2364e-08, 1.0637e-07, 4.8301e-09, 1.5224e-07, 6.2398e-08,
        1.5803e-07, 4.6763e-08, 2.0548e-07, 1.1100e-07, 7.2424e-08, 1.2284e-07,
        6.4900e-06, 8.3540e-08, 2.6921e-07, 1.5589e-07, 1.7261e-07, 1.0176e-07],
       grad_fn=<SelectBackward0>)

In [27]:
# Rotated-configuration counterpart of features[0,1,:,0,1]; the values should
# match it (the full-tensor comparison is the allclose check above).
features2[0,1,:,0,1]

tensor([1.0186e-01, 6.2837e-04, 7.7271e-04, 9.2403e-04, 8.6461e-03, 2.4789e-04,
        1.9048e-02, 1.3659e-05, 1.5787e-05, 4.3820e-05, 9.1601e-06, 2.8647e-05,
        4.9899e-05, 7.3475e-06, 2.1580e-04, 6.7309e-05, 3.5177e-07, 9.6967e-07,
        2.6066e-07, 5.2568e-07, 9.4985e-07, 2.3408e-07, 3.7018e-06, 2.4481e-07,
        3.7749e-06, 6.5026e-07, 1.0861e-06, 4.2402e-07, 5.4014e-07, 2.9400e-07,
        1.3626e-05, 7.3267e-07, 6.0503e-07, 1.6987e-06, 1.9980e-06, 5.1601e-07],
       grad_fn=<SelectBackward0>)