In [None]:

import torch
import random
from ase.io import read, write
from datetime import datetime
import matplotlib.pyplot as plt

from metatensor.models.experimental.soap_bpnn import Model, LLPRModel
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput, systems_to_torch
from metatensor.models.utils.neighbors_lists import get_system_with_neighbors_lists
from metatensor.models.utils.data.readers import read_systems, read_targets
from metatensor.learn.data import Dataset, DataLoader
from metatensor.models.utils.output_gradient import compute_gradient
from metatensor.torch import mean_over_samples


from omegaconf import OmegaConf
from ase.io import read
import numpy as np

from tqdm import tqdm

# Work in double precision everywhere: install float64 as the global default
# first, then keep a handle on the active default dtype for later use.
torch.set_default_dtype(torch.float64)
cur_dtype = torch.get_default_dtype()

### Load the metatrain SOAP-BPNN model

In [None]:

# Exported (TorchScript) SOAP-BPNN model to analyse; modify the filename as needed.
MODEL_PATH = "model.pt"

model = torch.jit.load(MODEL_PATH, map_location="cpu")
# Wrap the exported model so it can accumulate a training-set covariance and
# return last-layer features (both are used in the cells below).
llpr_model = LLPRModel(model, exported=True)


### Compute the covariance and inverse covariance on the original training set

In [None]:

# Training set used to build the last-layer covariance; modify the filename as needed.
TRAIN_PATH = "train.xyz"

train_frames = read(TRAIN_PATH, ":")
train_systems = [systems_to_torch(frame) for frame in train_frames]

batch_size = 100

# Feed the training systems to the LLPR wrapper in fixed-size batches. Stepping
# the range by `batch_size` yields only non-empty slices (the last one may be
# shorter), so no empty-batch guard or manual cleanup is needed.
# NOTE(review): if compute_covariance accumulates into internal state, re-running
# this cell without reloading the model would count the training set twice — confirm.
for start in tqdm(range(0, len(train_systems), batch_size)):
    llpr_model.compute_covariance(train_systems[start:start + batch_size])


In [None]:
# Positional arguments for LLPRModel.compute_inv_covariance, lifted out of the
# call so they can be tuned in one place.
# NOTE(review): these are presumably a calibration scale (1) and a regularization
# strength (5e-6) — confirm their meaning against the LLPRModel signature.
INV_COV_ARGS = (1, 5e-6)
llpr_model.compute_inv_covariance(*INV_COV_ARGS)

### Obtain last-layer features for the test set

In [None]:

def _last_layer_features_request(per_atom: bool) -> dict:
    """Build the `outputs` request dict asking the model for last-layer features.

    Args:
        per_atom: False requests per-structure features (used for PR);
            True requests per-atom features (used for LPR).

    Returns:
        A dict with the single key "last_layer_features" mapped to a
        ModelOutput with empty quantity/unit and the given `per_atom` flag.
    """
    return {
        "last_layer_features": ModelOutput(
            quantity="",
            unit="",
            per_atom=per_atom,
        )
    }


llfeats = _last_layer_features_request(per_atom=False)  # per-structure -> PR
per_atom_llfeats = _last_layer_features_request(per_atom=True)  # per-atom -> LPR

In [None]:
# Test set to evaluate; modify the filename / ASE index string as needed.
TEST_PATH = "test.xyz"
TEST_INDEX = ":100"  # ASE slice string: the first 100 frames of the file

random.seed(1215)
test_frames = read(TEST_PATH, TEST_INDEX)
# NOTE(review): this only permutes the frames already read. It does not draw a
# random subset of the test file, and since (L)PR is computed row by row below,
# it changes only the order of the output rows, not their values.
random.shuffle(test_frames)
test_systems = [systems_to_torch(frame) for frame in test_frames]


def _collect_last_layer_features(llpr, systems, outputs):
    """Evaluate `llpr` one system at a time and stack the returned features.

    Args:
        llpr: the LLPR-wrapped model, called as ``llpr([system], outputs)``.
        systems: list of systems produced by ``systems_to_torch``.
        outputs: request dict containing a "last_layer_features" entry.

    Returns:
        All returned ``block().values`` tensors, detached from the autograd
        graph and stacked row-wise with ``torch.vstack``.
    """
    rows = []
    for system in tqdm(systems):
        output = llpr([system], outputs)
        rows.append(output["last_layer_features"].block().values.detach())
    return torch.vstack(rows)


test_llfeats = _collect_last_layer_features(llpr_model, test_systems, llfeats)
test_per_atom_llfeats = _collect_last_layer_features(
    llpr_model, test_systems, per_atom_llfeats
)

### Compute the prediction rigidity (PR) and local prediction rigidity (LPR)

In [None]:
# For every per-structure feature row f_i, evaluate the quadratic form
# f_i^T · inv_covariance · f_i and take its reciprocal: PR_i = 1 / (f_i^T Σ^{-1} f_i).
test_pr = torch.einsum(
    "ij, jk, ik -> i", test_llfeats, llpr_model.inv_covariance, test_llfeats
).reciprocal()

In [None]:
# Same quadratic form as for PR, but over per-atom feature rows:
# LPR_a = 1 / (f_a^T Σ^{-1} f_a) for every atom a of every test structure.
test_lpr = torch.einsum(
    "ij, jk, ik -> i",
    test_per_atom_llfeats,
    llpr_model.inv_covariance,
    test_per_atom_llfeats,
).reciprocal()