In [232]:
# Automatically reload edited project modules (e.g. cace) while iterating.
# NOTE(review): `%reload_ext autoreload` would make this cell idempotent and
# avoid the "already loaded" message when the cell is re-run.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [233]:
import os

# NOTE(review): KMP_DUPLICATE_LIB_OK presumably works around a duplicate OpenMP
# runtime being loaded (e.g. by numpy and torch); it is set here, before torch
# is imported below. Treat it as an environment workaround and confirm it is
# still needed on the current setup.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

In [234]:
import numpy as np
import torch
import torch.nn as nn

In [235]:
import cace
from cace import data

In [236]:
from scipy.spatial.transform import Rotation as R

In [237]:
# Use float32 as torch's default floating dtype for everything built below;
# the rotation-invariance check later compares features with rtol/atol = 1e-5.
torch.set_default_dtype(torch.float32)

In [238]:
def make_water_config(positions, forces, dipole):
    """Build the O-H-H test ``data.Configuration`` used in this notebook.

    Atomic numbers, energy and charges are rotation-invariant and fixed here.
    The vector-valued quantities (positions, forces, dipole) are passed in so
    that a rotated copy can be built with all of them rotated consistently.
    Fresh arrays are created for the fixed labels on every call, so the two
    configurations never share mutable buffers.
    """
    return data.Configuration(
        atomic_numbers=np.array([8, 1, 1]),
        positions=positions,
        forces=forces,
        energy=-1.5,
        charges=np.array([-2.0, 1.0, 1.0]),
        dipole=dipole,
    )


# Arbitrary test values; forces and dipole are not derived from the geometry.
water_positions = np.array(
    [
        [0.0452, -2.02, 0.0452],
        [1.0145, 0.034, 0.0232],
        [0.0111, 1.041, -0.010],
    ]
)
water_forces = np.array(
    [
        [0.0, -1.3, 0.0],
        [1.0, 0.2, 0.0],
        [0.0, 1.1, 0.3],
    ]
)
water_dipole = np.array([-1.5, 1.5, 2.0])

config = make_water_config(water_positions, water_forces, water_dipole)

# Create the rotated environment: 70 deg about z, then 10.6 deg about x, then
# 190 deg about y (fixed lab axes). For scipy rotations `p * q` applies q first,
# so the composed matrix is Ry @ Rx @ Rz, i.e. the same sequence of rotations
# that was previously applied step by step.
rotation = (
    R.from_euler("y", 190, degrees=True)
    * R.from_euler("x", 10.6, degrees=True)
    * R.from_euler("z", 70, degrees=True)
)

# Positions, forces and the dipole are Cartesian vectors and must all rotate
# together; rotating only the positions would leave the rotated configuration
# physically inconsistent. Energy and charges are invariant scalars.
positions_rotated = rotation.apply(water_positions)
config_rotated = make_water_config(
    positions_rotated,
    rotation.apply(water_forces),
    rotation.apply(water_dipole),
)


In [239]:
# Cutoff radius shared by graph construction (AtomicData.from_config), the
# cosine cutoff function and the Cace representation below.
cutoff = 5.0

In [240]:
# Convert the original and the rotated configuration into graph data, using the
# same `cutoff` for both so that their edge sets are comparable.
atomic_data, atomic_data2 = (
    data.AtomicData.from_config(cfg, cutoff=cutoff)
    for cfg in (config, config_rotated)
)

In [241]:
# Coordinates of the original configuration as stored on the graph data.
atomic_data.positions

tensor([[ 0.0452, -2.0200,  0.0452],
        [ 1.0145,  0.0340,  0.0232],
        [ 0.0111,  1.0410, -0.0100]])

In [242]:
# Coordinates of the rotated configuration as stored on the graph data.
atomic_data2.positions

tensor([[-1.8716, -0.6457,  0.4060],
        [-0.3450,  0.9442, -0.1426],
        [ 0.9496,  0.3621, -0.2259]])

In [243]:
from cace.representations.cace_representation import Cace

In [244]:
# Fields carried by the graph data before the forward pass (no node_feat_B yet).
atomic_data

AtomicData(edge_index=[2, 6], num_nodes=3, positions=[3, 3], shifts=[6, 3], unit_shifts=[6, 3], cell=[3, 3], atomic_numbers=[3], weight=1.0, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, virials_weight=1.0, forces=[3, 3], energy=-1.5, dipole=[1, 3], charges=[3])

In [245]:
from cace.modules import CosineCutoff, MollifierCutoff, PolynomialCutoff
from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered, ExponentialDecayRBF
from cace.modules import EdgeEncoder

In [246]:
# Building blocks for the Cace representation: edge-type encoding, radial
# basis and cutoff function.
edge_coding = EdgeEncoder(directed=True)
# NOTE(review): the radial basis uses its own cutoff=2, independent of the
# global `cutoff` (5.0) used for graph construction and `cutoff_fn`. Confirm
# this is intentional (e.g. a decay length) rather than a stale value.
radial_basis = ExponentialDecayRBF(n_rbf=4, cutoff=2, prefactor=1, trainable=True)
cutoff_fn = CosineCutoff(cutoff=cutoff)

In [247]:
# CACE representation for H/O systems, assembled from the edge encoder, radial
# basis and cutoff function defined above. timeit=True makes every forward
# pass print per-stage timings.
cace_representation = Cace(
    # species and embedding
    zs=[1, 8],
    n_atom_basis=2,
    edge_coding=edge_coding,
    # radial part
    cutoff=cutoff,
    cutoff_fn=cutoff_fn,
    radial_basis=radial_basis,
    # angular / body-order part and message passing
    max_l=4,
    max_nu=4,
    num_message_passing=2,
    # diagnostics
    timeit=True,
)

In [248]:
# Inspect the representation's `l_list` (presumably the angular index list
# derived from max_l; TODO confirm).
cace_representation.l_list

In [249]:
# Forward pass on the original geometry. Judging from the displayed repr, the
# call adds `node_feat_B` to `atomic_data` and returns it; the features are
# read back from `atomic_data` further down.
cace_representation(atomic_data)

node_one_hot time: 0.00031113624572753906
node_embedded time: 0.0005676746368408203
encoded_edges time: 0.0014290809631347656
edge_vectors time: 0.0005381107330322266
radial and angular component time: 0.0010600090026855469
elementwise_multiply_3tensors time: 0.0005650520324707031
radial_transform time: 0.003568887710571289
scatter_sum time: 0.0007491111755371094
symmetrizer time: 0.059903860092163086
message passing time: 0.06686806678771973
message passing time: 0.07671713829040527


AtomicData(edge_index=[2, 6], num_nodes=3, positions=[3, 3], shifts=[6, 3], unit_shifts=[6, 3], cell=[3, 3], atomic_numbers=[3], weight=1.0, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, virials_weight=1.0, forces=[3, 3], energy=-1.5, dipole=[1, 3], charges=[3], node_feat_B=[3, 4, 13, 4, 3])

In [250]:
# Positions after the forward pass: same values as displayed before it.
atomic_data.positions

tensor([[ 0.0452, -2.0200,  0.0452],
        [ 1.0145,  0.0340,  0.0232],
        [ 0.0111,  1.0410, -0.0100]])

In [251]:
# Forward pass on the rotated geometry; per the displayed repr it adds
# `node_feat_B` to `atomic_data2`.
cace_representation(atomic_data2)

node_one_hot time: 0.0002789497375488281
node_embedded time: 0.0002601146697998047
encoded_edges time: 0.0012738704681396484
edge_vectors time: 0.00047206878662109375
radial and angular component time: 0.0010080337524414062
elementwise_multiply_3tensors time: 0.0006868839263916016
radial_transform time: 0.002415180206298828
scatter_sum time: 0.0003368854522705078
symmetrizer time: 0.0652761459350586
message passing time: 0.062361955642700195
message passing time: 0.052996158599853516


AtomicData(edge_index=[2, 6], num_nodes=3, positions=[3, 3], shifts=[6, 3], unit_shifts=[6, 3], cell=[3, 3], atomic_numbers=[3], weight=1.0, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, virials_weight=1.0, forces=[3, 3], energy=-1.5, dipole=[1, 3], charges=[3], node_feat_B=[3, 4, 13, 4, 3])

In [252]:
# Rotated positions after the forward pass: same values as displayed before it.
atomic_data2.positions

tensor([[-1.8716, -0.6457,  0.4060],
        [-0.3450,  0.9442, -0.1426],
        [ 0.9496,  0.3621, -0.2259]])

In [253]:
# Node features of the original geometry, stored by the forward pass.
features = atomic_data['node_feat_B']

In [254]:
# Node features of the rotated geometry, stored by the forward pass.
features2 = atomic_data2['node_feat_B']

In [255]:
# Rotation-invariance check: the features of the original and the rotated
# geometry must agree. assert_close raises (reporting the largest mismatch)
# instead of merely displaying False, so a regression fails Restart & Run All.
# The trailing allclose keeps the displayed result.
# NOTE(review): atol=1e-5 is large relative to many entries (some are ~1e-8),
# so those components are effectively unchecked. Consider a per-layer or
# relative comparison as well.
torch.testing.assert_close(features, features2, rtol=1e-05, atol=1e-05)
torch.allclose(features, features2, rtol=1e-05, atol=1e-05)

True

In [256]:
# Shape of the node features; the author's axis order is
# [n_nodes, radial_dim, angular_dim, embedding_dim, message_passing_layer].
atomic_data['node_feat_B'].shape

torch.Size([3, 4, 13, 4, 3])

In [257]:
# Axis order of node_feat_B:
#[n_nodes, radial_dim, angular_dim, embedding_dim, message_passing_layer]
# Here the shape is [3, 4, 13, 4, 3]: 3 atoms; radial_dim 4 matches n_rbf=4;
# 3 layers, presumably the initial features plus num_message_passing=2
# updates (TODO confirm).

In [258]:
# Node 0, radial channel 1, all 13 angular components, embedding channel 0,
# message-passing layer 0.
features[0,1,:,0,0]

tensor([-1.0894e-02,  1.1522e-04,  1.1211e-04,  1.0932e-04,  1.0680e-04,
        -1.2194e-06, -1.1873e-06, -1.1584e-06, -1.1567e-06,  1.2566e-08,
         1.2250e-08,  1.2260e-08,  1.1938e-08], grad_fn=<SelectBackward0>)

In [259]:
# Same slice (node 0, radial 1, all angular, embedding 0) at layer 1.
features[0,1,:,0,1]

tensor([6.8958e-02, 1.4369e-04, 6.8195e-03, 1.9604e-04, 7.1297e-03, 8.9827e-06,
        1.0936e-05, 1.0730e-05, 5.5299e-04, 9.2028e-07, 9.2413e-07, 9.4158e-07,
        1.1274e-06], grad_fn=<SelectBackward0>)

In [260]:
# Same slice (node 0, radial 1, all angular, embedding 0) at layer 2.
features[0,1,:,0,2]

tensor([1.9175e-01, 3.1308e-04, 4.5179e-02, 3.5143e-04, 4.6458e-02, 6.5051e-05,
        6.8865e-05, 6.9550e-05, 9.5978e-03, 1.4619e-05, 1.4723e-05, 1.4811e-05,
        1.5706e-05], grad_fn=<SelectBackward0>)

In [261]:
# Layer-1 slice for the rotated geometry; it agrees with features[0,1,:,0,1]
# up to last-digit float32 differences.
features2[0,1,:,0,1]

tensor([6.8958e-02, 1.4369e-04, 6.8195e-03, 1.9604e-04, 7.1297e-03, 8.9826e-06,
        1.0936e-05, 1.0730e-05, 5.5299e-04, 9.2028e-07, 9.2412e-07, 9.4158e-07,
        1.1274e-06], grad_fn=<SelectBackward0>)

In [262]:
# Edge list (sources, targets): all 6 ordered pairs of the 3 atoms appear,
# i.e. every pair lies within the cutoff.
atomic_data.edge_index

tensor([[0, 0, 1, 1, 2, 2],
        [1, 2, 0, 2, 0, 1]])