In [None]:
import warnings
from pathlib import Path

import ase.io
import torch

import amptorch
from amptorch.trainer import AtomsTrainer

In [None]:
# Both paths are relative, so they resolve against the kernel's current working directory.
amptorch_path = Path("..", "..", "amptorch")  # amptorch source checkout
train_data_path = Path("..", "data", "oc20_3k_train.traj")  # training trajectory

In [None]:
# Read *every* image in the trajectory. ase.io.read defaults to index=-1, which
# returns only the last frame and would miss elements present only in earlier frames.
train_data = ase.io.read(train_data_path, index=":")

# Collect the element symbols across all images. They are sorted so the element order
# (and anything derived from it) is identical between runs; the iteration order of a
# set of strings varies between interpreter sessions.
elements = sorted({atom.symbol for atoms in train_data for atom in atoms})

def get_path_to_gaussian(element, gaussians_dir=None):
    """Return the valence-Gaussian parameter file for a chemical element.

    Args:
        element: Chemical symbol, e.g. ``"Cu"``. A file matches when its name
            starts with ``element + "_"``. The trailing underscore keeps ``"C"``
            from matching ``"Cu_..."`` or ``"Cl_..."`` files.
        gaussians_dir: Directory to search (``str`` or ``Path``). Defaults to
            ``amptorch_path / "examples/GMP/valence_gaussians"``.

    Returns:
        Path: The first matching file in sorted path order.

    Raises:
        FileNotFoundError: If no file in ``gaussians_dir`` matches ``element``.
    """
    if gaussians_dir is None:
        gaussians_dir = amptorch_path / "examples/GMP/valence_gaussians"
    gaussians_dir = Path(gaussians_dir)
    prefix = element + "_"
    # iterdir() yields entries in filesystem-dependent order. Sorting keeps the
    # choice deterministic if several files share the prefix.
    matches = sorted(p for p in gaussians_dir.iterdir() if p.name.startswith(prefix))
    if not matches:
        raise FileNotFoundError(
            f"No valence Gaussian file for element {element!r} in {gaussians_dir}"
        )
    return matches[0]

# Valence-Gaussian parameter file for every element found in the training data.
atom_gaussians = {symbol: get_path_to_gaussian(symbol) for symbol in elements}

# Gaussian widths shared by every MCSH order below (the same list object is reused).
sigmas = [0.2, 0.69, 1.1, 1.66, 2.66]

# GMP fingerprint parameters, passed as "fp_params" to AtomsTrainer.
# Each MCSH order ("0", "1", "2") gets its own "groups" list plus the shared sigmas.
# The cutoff is presumably in Å, matching ASE's length unit; confirm against amptorch.
MCSHs = {
    "MCSHs": {
        order: {"groups": groups, "sigmas": sigmas}
        for order, groups in (("0", [1]), ("1", [1]), ("2", [1, 2]))
    },
    "atom_gaussians": atom_gaussians,
    "cutoff": 8,
}


In [None]:
def calc_loss(**params):
    """Train an amptorch ``singlenn`` model with the given hyperparameters.

    All keywords listed below are required. Model keywords are copied into
    ``config["model"]`` and optimiser keywords into ``config["optim"]``. The
    dataset section is built from the notebook-level ``train_data_path``,
    ``MCSHs`` (GMP fingerprint parameters) and ``elements``.

    Keyword Args:
        get_forces, num_layers, num_nodes, batchnorm: Model settings.
        force_coefficient, lr, batch_size, epochs, loss, metric: Optimiser settings.

    Returns:
        AtomsTrainer: The trainer after ``train()`` has run, so the caller can
        inspect or reuse the trained model.

    Raises:
        KeyError: If any required keyword is missing. All missing names are reported at once.

    Unknown keywords are not forwarded to amptorch. A ``UserWarning`` names
    them instead of dropping them silently.
    """
    model_params = ["get_forces", "num_layers", "num_nodes", "batchnorm"]
    optim_params = ["force_coefficient", "lr", "batch_size", "epochs", "loss", "metric"]

    missing = [k for k in model_params + optim_params if k not in params]
    if missing:
        raise KeyError(f"calc_loss() missing required hyperparameters: {missing}")
    unknown = sorted(set(params).difference(model_params, optim_params))
    if unknown:
        # Dropping these silently would hide misspelled or unsupported options.
        warnings.warn(
            f"calc_loss() ignores unknown hyperparameters: {unknown}", stacklevel=2
        )

    config = {
        "model": {"name": "singlenn", **{k: params[k] for k in model_params}},
        "optim": {k: params[k] for k in optim_params},
        "dataset": {
            "raw_data": str(train_data_path),
            "fp_scheme": "gmp",
            "fp_params": MCSHs,
            "elements": elements,
            "save_fps": True,
            "scaling": {"type": "normalize", "range": (0, 1)},
            "val_split": 0.1,
        },
        "cmd": {
            "debug": False,
            "run_dir": "./",
            "seed": 1,
            "identifier": "test",
            "verbose": True,
            "logger": False,
        },
    }

    trainer = AtomsTrainer(config)
    trainer.train()
    return trainer

In [None]:
# Short training run: a 3-layer, 10-node network trained for 10 epochs.
# force_coefficient=0 presumably removes the force term from the loss (confirm in amptorch).
# The trailing ";" keeps Jupyter from echoing the call's return value.
calc_loss(
    # model architecture
    num_layers=3,
    num_nodes=10,
    batchnorm=True,
    get_forces=True,
    # optimisation
    lr=1e-3,
    batch_size=16,
    epochs=10,
    loss="mse",
    metric="mae",
    force_coefficient=0,
);