In [1]:
import numpy as np
import torch
import torch.nn.functional
from mace import data, modules, tools
from mace.tools import torch_geometric
torch.set_default_dtype(torch.float64)

In [2]:
mace_mp = torch.load('calculators/foundations_models/2023-12-03-mace-mp.model')
mace_mp_llpr = modules.LLPRScaleShiftMACE(mace_mp)

In [3]:
stats = {"atomic_energies": 
    {1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: -10.149818838085771, 40: -11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: -3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245},
    "avg_num_neighbors": 61.964672446250916,
    "mean": 0.16409696359187365,
    "std": 0.8041538754478097, 
    "atomic_numbers": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 89, 90, 91, 92, 93, 94],
    "r_max": 6.0}

config_type_weights = {"Default": 1.0}

z_table = tools.get_atomic_number_table_from_zs(stats['atomic_numbers'])

In [4]:
from mace.tools.scripts_utils import get_dataset_from_xyz
collections, atomic_energies_dict = get_dataset_from_xyz(
    train_path="dummy_train.xyz",
    valid_path=None,
    valid_fraction=0.1,
    config_type_weights=config_type_weights,
)

In [5]:
train_loader = torch_geometric.dataloader.DataLoader(
    dataset=[
        data.AtomicData.from_config(config, z_table=z_table, cutoff=stats['r_max'])
        for config in collections.train
    ],
    batch_size=10,
    shuffle=True,
    drop_last=False,
)

In [6]:
valid_loader = torch_geometric.dataloader.DataLoader(
    dataset=[
        data.AtomicData.from_config(config, z_table=z_table, cutoff=stats['r_max'])
        for config in collections.valid
    ],
    batch_size=10,
    shuffle=False,
    drop_last=False,
)

In [7]:
from mace.tools.scripts_utils import get_dataset_from_xyz
collections_test, atomic_energies_dict = get_dataset_from_xyz(
    train_path="dummy_test.xyz",
    valid_path=None,
    valid_fraction=0.00000001,
    config_type_weights=config_type_weights,
)

In [8]:
test_loader = torch_geometric.dataloader.DataLoader(
    dataset=[
        data.AtomicData.from_config(config, z_table=z_table, cutoff=stats['r_max'])
        for config in collections_test.train
    ],
    batch_size=10,
    shuffle=False,
    drop_last=False,
)

In [9]:
mace_mp_llpr.compute_covariance(train_loader)

RuntimeError: expected scalar type Float but found Double