## Prediction rigidity calculation script for SchNet/PaiNN

**Important:** this notebook requires the SchNetPack fork (LLPR branch) available here — install and use that version:
<https://github.com/SanggyuChong/schnetpack/tree/LLPR>

In [None]:
import os
import schnetpack as spk
from schnetpack.datasets import AtomsDataModule
import schnetpack.transform as trn

import torch
import torchmetrics
import pytorch_lightning as pl

from schnetpack.utils.llpr import calibrate_llpr_params

import numpy as np
import tqdm

import matplotlib.pyplot as plt

### load datasets

In [None]:
# Data module for the training set.
# NOTE: "data.db" must already be stored in the database format that
# schnetpack expects; no conversion is done here.
train_module_kwargs = dict(
    datapath="data.db",
    # presumably the split that was used when the model was trained -- confirm
    split_file="PaiNN/split.npz",
    batch_size=100,
    load_properties=["energy"],
    transforms=[
        trn.ASENeighborList(cutoff=5.0),
        trn.CastTo32(),
    ],
    # keep False when running without a GPU
    pin_memory=False,
)
qm9data = AtomsDataModule(**train_module_kwargs)

# Both calls are needed before any dataloader can be requested.
qm9data.prepare_data()
qm9data.setup()

In [None]:
# Data module for the held-out test set.
# NOTE: "test_data.db" must already be stored in the database format that
# schnetpack expects; no conversion is done here.
test_module_kwargs = dict(
    datapath="test_data.db",
    batch_size=100,
    # Only the test partition is consumed in this notebook; the train/val
    # sizes of 1 look like placeholders -- TODO confirm.
    num_train=1,
    num_val=1,
    num_test=5000,
    load_properties=["energy"],
    transforms=[
        trn.ASENeighborList(cutoff=5.0),
        trn.CastTo32(),
    ],
    pin_memory=False,
)
qm9test = AtomsDataModule(**test_module_kwargs)

# Both calls are needed before any dataloader can be requested.
qm9test.prepare_data()
qm9test.setup()

### load PaiNN model

In [None]:
# Path of the trained model file; change it to match your own checkpoint.
model_path = "best_model"

# The model is mapped onto the CPU regardless of where it was saved from.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
best_model = torch.load(model_path, map_location='cpu')

# Wrap the trained model in the last-layer prediction rigidity module and ask
# it to also keep the per-atom last-layer features.
llpr_model = spk.model.LLPredRigidityNNP(
    best_model,
    save_ll_feat_per_atom=True,
)

### compute covariance and inv covariance on original training set

In [None]:
# Weights handed to the covariance computation. Only 'E' is non-zero.
# Presumably E/F/S stand for energy, forces and stress -- TODO confirm
# against the LLPR implementation.
weight_dict = dict(E=1, F=0, S=0)

# The covariance is built from the training split of the original dataset.
train_loader = qm9data.train_dataloader()
llpr_model.compute_covariance(train_loader, weights=weight_dict)

In [None]:
# Invert the covariance computed above; the result is later read from
# `llpr_model.inv_covariance`.
# NOTE(review): the two positional arguments are presumably a calibration
# constant and a regularization strength -- confirm against the method
# signature. `calibrate_llpr_params` is imported but not used in the visible
# cells, so these values look hard-coded; verify they suit your model.
llpr_model.compute_inv_covariance(1, 5e-6)

### obtain last-layer features for test set

In [None]:

# One tensor per test batch is collected in each of these lists.
ll_feat_batches = []
ll_feat_per_atom_batches = []

test_loader = qm9test.test_dataloader()
for batch in tqdm.tqdm(test_loader):
    outputs = llpr_model(batch)
    # Detach before storing so the kept tensors are cut from the autograd graph.
    ll_feat_batches.append(outputs['ll_feats'].detach())
    ll_feat_per_atom_batches.append(outputs['ll_feats_per_atom'].detach())

# Stack the per-batch tensors row-wise into one tensor each.
pred_ll_feats = torch.vstack(ll_feat_batches)
pred_ll_feats_per_atom = torch.vstack(ll_feat_per_atom_batches)

### compute prediction rigidity (PR) and local prediction rigidity (LPR)

In [None]:
# For every row f_i of the feature matrix, evaluate the quadratic form
# f_i^T . inv_covariance . f_i (one scalar per row).
pr_denominator = torch.einsum(
    "ij, jk, ik -> i",
    pred_ll_feats,
    llpr_model.inv_covariance,
    pred_ll_feats,
)

# The prediction rigidity is the reciprocal of that quadratic form.
test_pr = 1 / pr_denominator

In [None]:
# Same quadratic form as for the PR, evaluated on the per-atom last-layer
# features: one scalar per row (presumably one row per atom -- TODO confirm).
lpr_denominator = torch.einsum(
    "ij, jk, ik -> i",
    pred_ll_feats_per_atom,
    llpr_model.inv_covariance,
    pred_ll_feats_per_atom,
)

# The local prediction rigidity is the reciprocal of that quadratic form.
test_lpr = 1 / lpr_denominator