In [1]:
from ase.io import read

import matplotlib.pyplot as plt
import numpy as np

import random

import metatensor as mts
from metatensor import Labels

from rascaline import SoapPowerSpectrum, LodeSphericalExpansion, SphericalExpansion
from rascaline.utils import MonomialBasis, LodeDensity, LodeSpliner, SoapSpliner, DensityCorrelations


### utility functions

In [2]:

def train_model(X_struc, E_struc, alpha):
    """Fit a ridge-regression model on structure-level features.

    Solves the regularised normal equations

        (X^T X + alpha * I) w = X^T (E - mean(E))

    The targets are centred by their mean before fitting; the features are
    used exactly as given (no centring or scaling is applied here).

    Parameters
    ----------
    X_struc : array-like, shape (n_structures, n_features)
        One feature row per structure.
    E_struc : array-like, shape (n_structures,)
        Target value for each structure.
    alpha : float
        Strength of the diagonal (ridge) regulariser added to ``X^T X``.

    Returns
    -------
    weights : numpy.ndarray, shape (n_features,)
        Regression weights.
    E_mean : float
        Mean of ``E_struc``; add it back to ``X @ weights`` when predicting.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the regularised Gram matrix is exactly singular.
    """
    X_struc = np.asarray(X_struc)
    E_struc = np.asarray(E_struc)

    # Regularised Gram matrix of the features.
    gram = X_struc.T @ X_struc
    regularised_gram = gram + alpha * np.eye(gram.shape[0])

    E_mean = E_struc.mean()
    rhs = X_struc.T @ (E_struc - E_mean)

    # Solve the linear system directly instead of forming an explicit inverse:
    # this is cheaper and loses less precision when alpha is tiny and the
    # Gram matrix is close to singular.
    weights = np.linalg.solve(regularised_gram, rhs)

    return weights, E_mean


def predict_compwise(X_struc, weights, E_mean, comp_dims):
    """Predict targets with the full model and with each feature component alone.

    The feature axis is treated as a concatenation of consecutive components
    whose widths are listed in ``comp_dims``.

    Parameters
    ----------
    X_struc : array-like, shape (n_structures, n_features)
        Feature rows to predict for.
    weights : array-like, shape (n_features,)
        Regression weights (e.g. from ``train_model``).
    E_mean : float
        Offset added to every prediction.
    comp_dims : sequence of int
        Width of each consecutive feature component. Any integer sequence
        (list, tuple, NumPy array) is accepted; the widths must sum to
        ``n_features``.

    Returns
    -------
    tot_pred : numpy.ndarray, shape (n_structures,)
        ``X_struc @ weights + E_mean``.
    cw_preds : numpy.ndarray, shape (n_components, n_structures)
        Row ``c`` uses only the features and weights of component ``c``.
        ``E_mean`` is added to every row, so the rows do not sum to
        ``tot_pred``.

    Raises
    ------
    ValueError
        If the component widths do not add up to the number of weights.
    """
    X_struc = np.asarray(X_struc)
    weights = np.asarray(weights)
    widths = np.asarray(comp_dims, dtype=int).ravel()

    if widths.size and widths.sum() != weights.shape[0]:
        raise ValueError(
            f"comp_dims sums to {widths.sum()} but there are "
            f"{weights.shape[0]} weights"
        )

    tot_pred = X_struc @ weights + E_mean

    # Start/stop column of every component along the feature axis.
    edges = np.concatenate(([0], np.cumsum(widths)))

    cw_preds = np.zeros((len(widths), len(X_struc)))
    for ci, (start, stop) in enumerate(zip(edges[:-1], edges[1:])):
        # Restricting both operands to the component's columns is equivalent
        # to zeroing all other weights, without the full-width product.
        cw_preds[ci] = X_struc[:, start:stop] @ weights[start:stop] + E_mean

    return tot_pred, cw_preds


In [3]:
def calculate_CPR(
    train_struc_feats,
    test_struc_feats,
    alpha,
    comp_dims,
):
    """Compute a per-component CPR value for every test structure.

    With ``C = X_train^T X_train + alpha * I`` and ``x_c`` the test feature
    row restricted to component ``c`` (all other features set to zero), the
    returned value is ``1 / (x_c^T C^{-1} x_c)``.

    Parameters
    ----------
    train_struc_feats : array-like, shape (n_train, n_features)
        Feature rows of the training structures.
    test_struc_feats : array-like, shape (n_test, n_features)
        Feature rows of the structures to evaluate.
    alpha : float
        Strength of the diagonal regulariser added to ``X_train^T X_train``.
    comp_dims : sequence of int
        Width of each consecutive feature component. Any integer sequence
        (list, tuple, NumPy array) is accepted; the widths must sum to
        ``n_features``.

    Returns
    -------
    CPR : numpy.ndarray, shape (n_components, n_test)

    Raises
    ------
    ValueError
        If the component widths do not add up to the number of features.
    """
    X_struc_train = np.asarray(train_struc_feats)
    X_struc_test = np.asarray(test_struc_feats)
    widths = np.asarray(comp_dims, dtype=int).ravel()

    if widths.size and widths.sum() != X_struc_test.shape[1]:
        raise ValueError(
            f"comp_dims sums to {widths.sum()} but the test features have "
            f"{X_struc_test.shape[1]} columns"
        )

    gram = X_struc_train.T @ X_struc_train
    # The explicit inverse is needed here because its diagonal sub-blocks are
    # used below.
    Xinv = np.linalg.inv(gram + alpha * np.eye(gram.shape[0]))

    # Start/stop column of every component along the feature axis.
    edges = np.concatenate(([0], np.cumsum(widths)))

    CPR = np.zeros((len(widths), len(X_struc_test)))
    for ci, (start, stop) in enumerate(zip(edges[:-1], edges[1:])):
        # Zeroing the other components of x is the same as contracting the
        # component's columns with the matching diagonal block of the inverse.
        x_cur = X_struc_test[:, start:stop]
        CPR[ci] = 1 / np.einsum(
            "ij, jk, ik -> i", x_cur, Xinv[start:stop, start:stop], x_cur
        )

    return CPR


### load water datasets

In [4]:
# Read every frame (index ":") of the monomer and dimer trajectories.
monomers = read("datasets/sel_large_monomers.xyz", ":")
dimers = read("datasets/sel_large_dimers.xyz", ":")

# Reorder the dimer frames by increasing distance between atoms 0 and 3
# (presumably the two oxygen atoms, going by the variable name — confirm
# against the dataset's atom ordering).
# NOTE(review): OO_dist itself keeps the ORIGINAL frame order; only `dimers`
# is re-sorted.
OO_dist = np.array([frame.get_distance(0, 3) for frame in dimers])
dimers = [dimers[frame_idx] for frame_idx in np.argsort(OO_dist)]

In [5]:
# Target values: the stored total energy of each frame divided by its number
# of atoms. `dimers` is already sorted here, so `dimer_E` follows that order.
monomer_E = np.array(
    [frame.info["energy"] / len(frame) for frame in monomers]
)
dimer_E = np.array(
    [frame.info["energy"] / len(frame) for frame in dimers]
)

### compute SOAP power spectrum

In [6]:

# Settings forwarded verbatim (as keyword arguments) to the SOAP
# power-spectrum calculator.
HYPER_PARAMETERS = dict(
    cutoff=2.8,
    max_radial=4,
    max_angular=3,
    atomic_gaussian_width=0.3,
    center_atom_weight=1.0,
    radial_basis={"Gto": {}},
    cutoff_function={"ShiftedCosine": {"width": 0.25}},
)

soap_calculator = SoapPowerSpectrum(**HYPER_PARAMETERS)

In [7]:
# Compute SOAP for both datasets and move all three key dimensions into the
# properties, using one shared recipe instead of two copy-pasted lines.
monomer_SOAP, dimer_SOAP = (
    soap_calculator.compute(frames).keys_to_properties(
        ["neighbor_1_type", "neighbor_2_type", "center_type"]
    )
    for frames in (monomers, dimers)
)

### compute LODE power spectrum

In [8]:
# Settings shared by the real-space and Fourier-space LODE expansions below.
cutoff, max_radial, max_angular = 2.8, 4, 3
potential_exponent = 3
atomic_gaussian_width, center_atom_weight = 1.2, 1.0
# Accuracy passed to both spliners.
spline_accuracy = 1e-10

In [9]:
# Radial basis and density shared by both spliners.
basis = MonomialBasis(cutoff=cutoff)
density = LodeDensity(
    atomic_gaussian_width=atomic_gaussian_width,
    potential_exponent=potential_exponent,
)

# Usually this value for `k_cutoff` gives good convergence for the k-space version
k_cutoff = 1.2 * np.pi / atomic_gaussian_width

# Keyword arguments common to the real-space and Fourier-space spliners.
spliner_settings = dict(
    max_radial=max_radial,
    max_angular=max_angular,
    basis=basis,
    density=density,
    accuracy=spline_accuracy,
)

# Real space splines (limited by `cutoff`)
rs_splines = SoapSpliner(cutoff=cutoff, **spliner_settings).compute()

# Fourier space splines (limited by `k_cutoff`)
fs_splines = LodeSpliner(k_cutoff=k_cutoff, **spliner_settings).compute()

# Keyword arguments common to both spherical-expansion calculators.
expansion_settings = dict(
    cutoff=cutoff,
    max_radial=max_radial,
    max_angular=max_angular,
    atomic_gaussian_width=atomic_gaussian_width,
    center_atom_weight=center_atom_weight,
)

# Real-space calculator: splined radial basis with a step cutoff function.
rs_lode_calc = SphericalExpansion(
    radial_basis=rs_splines,
    cutoff_function={"Step": {}},
    **expansion_settings,
)

# Fourier-space calculator.
fs_lode_calc = LodeSphericalExpansion(
    potential_exponent=potential_exponent,
    radial_basis=fs_splines,
    k_cutoff=k_cutoff,
    **expansion_settings,
)

In [10]:
# Evaluate the Fourier-space and real-space expansions for each dataset
# (monomers first, then dimers; Fourier-space before real-space).
(
    (monomer_fs_lode_sphex, monomer_rs_lode_sphex),
    (dimer_fs_lode_sphex, dimer_rs_lode_sphex),
) = (
    (fs_lode_calc.compute(frames), rs_lode_calc.compute(frames))
    for frames in (monomers, dimers)
)

In [11]:
# Subtract the real-space expansion from the Fourier-space one. The
# difference is what the notebook later calls "farLODE" / "pure" LODE.
monomer_subtracted_lode_sphex, dimer_subtracted_lode_sphex = (
    monomer_fs_lode_sphex - monomer_rs_lode_sphex,
    dimer_fs_lode_sphex - dimer_rs_lode_sphex,
)

In [12]:
# Get power spectrum of farLODE
cg_hypers = {
    "correlation_order": 2,
    "max_angular": max_angular * 2, 
    # important for memory consumption if target correlation order or starting l values are high
    "angular_cutoff": None,
    # We only want invariants with even inversion parity
    "selected_keys": Labels(
        names=["o3_lambda", "o3_sigma"],
        values=np.array([[0, 1]]), #
    ),
    "skip_redundant": True,
    "output_selection": None,
    "arrays_backend": None,
    "cg_backend": None,
}

density_correlator = DensityCorrelations(**cg_hypers)

In [13]:
# Apply the same pipeline to all four spherical expansions: correlate the
# density, then move the keys and the `o3_mu` component into the properties.
(
    monomer_fs_lode_ps,
    monomer_pure_lode_ps,
    dimer_fs_lode_ps,
    dimer_pure_lode_ps,
) = (
    density_correlator.compute(sphex)
    .keys_to_properties(["o3_lambda", "o3_sigma", "center_type"])
    .components_to_properties(["o3_mu"])
    for sphex in (
        monomer_fs_lode_sphex,
        monomer_subtracted_lode_sphex,
        dimer_fs_lode_sphex,
        dimer_subtracted_lode_sphex,
    )
)

### join tensormaps for the final descriptor sets

In [14]:
# Final descriptor sets: SOAP features followed by LODE power-spectrum
# features along the property axis ("orig" = Fourier-space LODE,
# "pure" = farLODE from the subtracted expansion).
monomer_orig_LODE = mts.join([monomer_SOAP, monomer_fs_lode_ps], axis='properties')
monomer_pure_LODE = mts.join([monomer_SOAP, monomer_pure_lode_ps], axis='properties')

dimer_orig_LODE = mts.join([dimer_SOAP, dimer_fs_lode_ps], axis='properties')
dimer_pure_LODE = mts.join([dimer_SOAP, dimer_pure_lode_ps], axis='properties')

# Width of each component (SOAP, LODE) along the property axis. Derived from
# the blocks rather than hard-coded so it stays correct if the
# hyper-parameters change; with the settings above it was previously written
# out as [384, 512].
comp_dims = [
    int(dimer_SOAP.block().values.shape[-1]),
    int(dimer_fs_lode_ps.block().values.shape[-1]),
]
comp_idxs = np.cumsum(np.array([0] + comp_dims))

# The same split is used for every descriptor set, so each must match it.
for _joined in (monomer_orig_LODE, monomer_pure_LODE, dimer_orig_LODE, dimer_pure_LODE):
    assert _joined.block().values.shape[-1] == comp_idxs[-1], (
        "joined descriptor width does not match comp_dims"
    )
del _joined

### CASE 1: dimer + monomer with original LODE

In [15]:
# Structure-level features: per-atom features averaged over the "atom"
# sample dimension. Each average is computed once and reused below.
X_struc_dimer = np.array(mts.mean_over_samples(dimer_orig_LODE, sample_names="atom").block().values)
X_struc_monomer = np.array(mts.mean_over_samples(monomer_orig_LODE, sample_names="atom").block().values)

# "bef": dimer-only data.
X_bef = np.array(dimer_orig_LODE.block().values)
X_struc_bef = X_struc_dimer.copy()  # separate array, same values
E_bef = dimer_E

# "aft": first 50 dimer entries stacked with the first 50 monomer entries.
# NOTE(review): X_aft slices rows of the un-averaged blocks, whereas
# X_struc_aft / E_aft slice per-structure rows — confirm the first 50 rows of
# X_aft are the intended selection.
X_aft = np.vstack([X_bef[:50], np.array(monomer_orig_LODE.block().values)[:50]])
X_struc_aft = np.vstack([X_struc_bef[:50], X_struc_monomer[:50]])
E_aft = np.hstack([dimer_E[:50], monomer_E[:50]])

In [16]:
# Ridge regularisation strength passed to train_model / calculate_CPR.
alpha = 1e-11

# Full-model predictions for every training-set size (one entry per size).
E_dims = []
E_mons = []

# Number of leading dimer frames used for training. This list is also saved
# as `num_configs`, so the loop must iterate over this very list rather than
# a duplicated literal.
sizes = [5, 10, 20, 50, 100]
for size in sizes:
    # Train on the first `size` dimers only ...
    weights_bef, E_mean_bef = train_model(X_struc_bef[:size], E_bef[:size], alpha)
    # ... then predict all dimers and all monomers.
    E_dim_bef, E_dim_cw_bef = predict_compwise(X_struc_dimer, weights_bef, E_mean_bef, comp_dims)
    E_mon_bef, E_mon_cw_bef = predict_compwise(X_struc_monomer, weights_bef, E_mean_bef, comp_dims)

    E_mons.append(E_mon_bef)
    E_dims.append(E_dim_bef)


In [19]:
# CPR of every dimer for each training-set size, original LODE descriptor.
# Iterates over `sizes` (the list saved as `num_configs`) instead of a
# duplicated literal. Result shape: (len(sizes), n_components, n_dimers).
CPRs_dirty = np.array(
    [
        calculate_CPR(X_struc_bef[:size], X_struc_dimer, alpha, comp_dims)
        for size in sizes
    ]
)


### CASE 2: dimer + monomer with farLODE

In [21]:
# Structure-level features: per-atom features averaged over the "atom"
# sample dimension. Each average is computed once and reused below.
X_struc_dimer = np.array(mts.mean_over_samples(dimer_pure_LODE, sample_names="atom").block().values)
X_struc_monomer = np.array(mts.mean_over_samples(monomer_pure_LODE, sample_names="atom").block().values)

# "bef": dimer-only data.
X_bef = np.array(dimer_pure_LODE.block().values)
X_struc_bef = X_struc_dimer.copy()  # separate array, same values
E_bef = dimer_E

# "aft": first 90 dimer entries stacked with the first 10 monomer entries.
# NOTE(review): X_aft slices rows of the un-averaged blocks, whereas
# X_struc_aft / E_aft slice per-structure rows — confirm this is intended.
X_aft = np.vstack([X_bef[:90], np.array(monomer_pure_LODE.block().values)[:10]])
X_struc_aft = np.vstack([X_struc_bef[:90], X_struc_monomer[:10]])
E_aft = np.hstack([dimer_E[:90], monomer_E[:10]])

In [22]:
# Ridge regularisation strength passed to train_model / calculate_CPR.
alpha = 1e-11

# Full-model predictions for every training-set size (one entry per size).
E_dims_pure = []
E_mons_pure = []

# Iterate over `sizes` (the list saved as `num_configs`) instead of a
# duplicated literal, so both cases and the saved metadata cannot drift apart.
for size in sizes:
    # Train on the first `size` dimers only ...
    weights_bef, E_mean_bef = train_model(X_struc_bef[:size], E_bef[:size], alpha)
    # ... then predict all dimers and all monomers.
    E_dim_bef, E_dim_cw_bef = predict_compwise(X_struc_dimer, weights_bef, E_mean_bef, comp_dims)
    E_mon_bef, E_mon_cw_bef = predict_compwise(X_struc_monomer, weights_bef, E_mean_bef, comp_dims)

    E_mons_pure.append(E_mon_bef)
    E_dims_pure.append(E_dim_bef)


In [25]:
# CPR of every dimer for each training-set size, farLODE descriptor.
# Iterates over `sizes` (the list saved as `num_configs`) instead of a
# duplicated literal. Result shape: (len(sizes), n_components, n_dimers).
CPRs_pure = np.array(
    [
        calculate_CPR(X_struc_bef[:size], X_struc_dimer, alpha, comp_dims)
        for size in sizes
    ]
)


In [29]:
# Persist reference energies, predictions for both descriptor sets, the
# training-set sizes and the CPR arrays in a single .npz archive.
np.savez(
    "water_results.npz",
    **{
        "monomer_E": monomer_E,
        "dimer_E": dimer_E,
        "dirty_mon_E": E_mons,
        "dirty_dim_E": E_dims,
        "pure_mon_E": E_mons_pure,
        "pure_dim_E": E_dims_pure,
        "num_configs": sizes,
        "CPRs_dirty": CPRs_dirty,
        "CPRs_pure": CPRs_pure,
    },
)
         