In [None]:
import numpy as np
import torch
import torch.nn.functional
from mace import data, modules, tools
from mace.tools import torch_geometric
import matplotlib.pyplot as plt
# Global run settings: the CUDA device used throughout the notebook, and
# float64 as the default floating-point dtype for newly created tensors.
device = tools.init_device('cuda')
torch.set_default_dtype(torch.float64)

In [None]:
# Load the trained MACE model and wrap it in MACE's LLPRModel, which provides
# the covariance / inverse-covariance machinery used for prediction rigidities.
# map_location=device loads the checkpoint straight onto the configured device
# (instead of a hard-coded 'cuda'), whatever device it was saved from.
# NOTE(review): this unpickles a full nn.Module; on torch versions where
# torch.load defaults to weights_only=True this may need weights_only=False —
# only do that for trusted checkpoints.
mace = torch.load('50_0/CG_water.model', map_location=device)
mace_llpr = modules.LLPRModel(mace)
mace_llpr = mace_llpr.to(device)

In [None]:
# Data settings shared by every loader below.
#   atomic_numbers: the CG_water files use atomic number 0 only
#                   (presumably a single coarse-grained bead type).
#   r_max:          neighbour-list cutoff passed to AtomicData.from_config
#                   (presumably in Å — confirm against the training config).
stats = dict(atomic_numbers=[0], r_max=6.0)
config_type_weights = dict(Default=1.0)
z_table = tools.get_atomic_number_table_from_zs(stats["atomic_numbers"])

In [None]:
from mace.tools.scripts_utils import get_dataset_from_xyz

# Read the train / validation / test splits of the CG water data set.
# valid_fraction=0: an explicit validation file is supplied, so no part of the
# training set is held out (presumably ignored when valid_path is given).
split_paths = dict(
    train_path="50_0/CG_water_train_50_0.xyz",
    valid_path="50_0/CG_water_val_1k.xyz",
    test_path="50_0/CG_water_test_1k.xyz",
)
collections, atomic_energies_dict = get_dataset_from_xyz(
    **split_paths,
    valid_fraction=0,
    config_type_weights=config_type_weights,
)

In [None]:
def make_loader(configs, batch_size):
    """Build a deterministic (non-shuffled, no dropped batches) DataLoader.

    Args:
        configs: iterable of configurations as returned by get_dataset_from_xyz.
        batch_size: number of structures per batch.

    Returns:
        A torch_geometric DataLoader over AtomicData graphs built with the
        notebook-wide ``z_table`` and cutoff ``stats['r_max']``.
    """
    return torch_geometric.dataloader.DataLoader(
        dataset=[
            data.AtomicData.from_config(config, z_table=z_table, cutoff=stats['r_max'])
            for config in configs
        ],
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )


train_loader = make_loader(collections.train, batch_size=10)
valid_loader = make_loader(collections.valid, batch_size=10)
# collections.tests is a list of (name, configs) pairs; take the first test set.
# batch_size must stay 1: the prediction-rigidity loop averages atomic features
# over a whole batch to get one structure-level value.
test_loader = make_loader(collections.tests[0][1], batch_size=1)

In [None]:

# Accumulate the last-layer covariance over the training set, with force
# contributions enabled and energy contributions disabled.
mace_llpr.compute_covariance(train_loader, include_energy=False, include_forces=True)


In [None]:
# Build the inverse covariance used by the rigidity loops.
# C and sigma are LLPRModel hyperparameters (presumably calibration scale and
# regularisation) — the values are specific to this model/data set.
LLPR_C = 1
LLPR_SIGMA = 5e-6
mace_llpr.compute_inv_covariance(C=LLPR_C, sigma=LLPR_SIGMA)

In [None]:
import tqdm

# Prediction rigidities on the test set. For last-layer features f and the
# inverse covariance computed above, PR = 1 / (f^T C^{-1} f):
#   - LPRs: one value per atom (f = per-atom last-layer features)
#   - PRs:  one value per structure (f = mean of that structure's atomic features)
PRs = []
LPRs = []
llpr_device = mace_llpr.covariance.device
for batch in tqdm.tqdm(test_loader):
    # Rebind instead of relying on .to() mutating the batch in place.
    batch = batch.to(llpr_device)
    # The structure-level PR averages over every atom in the batch, so it is a
    # per-structure value only when each batch holds exactly one graph.
    assert batch["ptr"].numel() - 1 == 1, "test_loader must use batch_size=1"
    cur_llfeats = mace_llpr(batch, save_atomic_llfeats=True)['atomic_llfeats']

    cur_LPRs = 1 / torch.einsum(
        "ij,jk,ik->i", cur_llfeats, mace_llpr.inv_covariance, cur_llfeats
    )
    # detach() so the stored results do not keep this batch's autograd graph
    # (and the memory behind it) alive for the rest of the loop.
    LPRs.append(cur_LPRs.detach())

    struc_llfeats = cur_llfeats.mean(dim=0, keepdim=True)
    cur_PR = 1 / torch.einsum(
        "ij,jk,ik->i", struc_llfeats, mace_llpr.inv_covariance, struc_llfeats
    )
    PRs.append(cur_PR.detach())
    

In [None]:
# Persist the test-set rigidities:
#   test_PRs  — one value per structure
#   test_LPRs — one value per atom, all structures concatenated in loader order
# detach()/cpu() make the tensors convertible to NumPy.
np.savez(
    "50_results_testPR.npz",
    test_PRs=torch.hstack(PRs).detach().cpu().numpy(),
    test_LPRs=torch.cat(LPRs).detach().cpu().numpy(),
)

In [None]:
# Checkpoint the whole LLPR-wrapped module (pickled), so the covariance work
# does not have to be redone before the trajectory analysis.
cov_model_path = '50_0/CG_water_cov.model'
torch.save(mace_llpr, cov_model_path)

In [None]:
# Reload the checkpointed LLPR model onto the configured device. map_location
# avoids a hard-coded 'cuda' and lets a checkpoint saved elsewhere load here.
# NOTE(review): full-module unpickling — only load trusted files; newer torch
# versions may require weights_only=False for this to work.
mace_llpr = torch.load('50_0/CG_water_cov.model', map_location=device).to(device)

In [None]:
from mace.tools.scripts_utils import get_dataset_from_xyz

# Load the frames of 50_0/test.xyz as the "train" split so they can be fed
# through the model. The valid/test paths only satisfy get_dataset_from_xyz's
# interface; this notebook reads only collections2.train.
traj_split_paths = dict(
    train_path="50_0/test.xyz",
    valid_path="50_0/CG_water_test_1k.xyz",
    test_path="50_0/CG_water_train_50_0.xyz",
)
collections2, atomic_energies_dict2 = get_dataset_from_xyz(
    **traj_split_paths,
    valid_fraction=0,
    config_type_weights=config_type_weights,
)

In [None]:
# One structure per batch: structure-level rigidities are obtained by
# averaging the atomic features of a whole batch.
traj_dataset = [
    data.AtomicData.from_config(config, z_table=z_table, cutoff=stats['r_max'])
    for config in collections2.train
]
traj_loader = torch_geometric.dataloader.DataLoader(
    dataset=traj_dataset, batch_size=1, shuffle=False, drop_last=False,
)

In [None]:
import tqdm

# Prediction rigidities along the trajectory. For last-layer features f and the
# loaded inverse covariance, PR = 1 / (f^T C^{-1} f):
#   - traj_LPRs: one value per atom
#   - traj_PRs:  one value per frame (f = mean of that frame's atomic features)
traj_PRs = []
traj_LPRs = []
llpr_device = mace_llpr.covariance.device
for batch in tqdm.tqdm(traj_loader):
    # Rebind instead of relying on .to() mutating the batch in place.
    batch = batch.to(llpr_device)
    # The frame-level PR averages over every atom in the batch, so it is a
    # per-frame value only when each batch holds exactly one graph.
    assert batch["ptr"].numel() - 1 == 1, "traj_loader must use batch_size=1"
    cur_llfeats = mace_llpr(batch, save_atomic_llfeats=True)['atomic_llfeats']

    cur_LPRs = 1 / torch.einsum(
        "ij,jk,ik->i", cur_llfeats, mace_llpr.inv_covariance, cur_llfeats
    )
    # detach() so the stored results do not keep this batch's autograd graph
    # (and the memory behind it) alive for the rest of the loop.
    traj_LPRs.append(cur_LPRs.detach())

    struc_llfeats = cur_llfeats.mean(dim=0, keepdim=True)
    cur_PR = 1 / torch.einsum(
        "ij,jk,ik->i", struc_llfeats, mace_llpr.inv_covariance, struc_llfeats
    )
    traj_PRs.append(cur_PR.detach())
    

In [None]:
# Persist the trajectory rigidities:
#   traj_PRs  — one value per frame
#   traj_LPRs — one value per atom, all frames concatenated in loader order
# detach()/cpu() make the tensors convertible to NumPy.
np.savez(
    "50_results_trajPR.npz",
    traj_PRs=torch.hstack(traj_PRs).detach().cpu().numpy(),
    traj_LPRs=torch.cat(traj_LPRs).detach().cpu().numpy(),
)