In [2]:
import torch
import ase.io
from amptorch.trainer import AtomsTrainer
import amptorch

from pathlib import Path

In [14]:
# Both paths are relative to the notebook's working directory.
# amptorch checkout; used below to locate the GMP valence-gaussian files.
amptorch_path = Path("../..") / "amptorch"
# ASE trajectory holding the training images.
train_data_path = Path("..") / "oc20_3k_train.traj"

In [15]:
# Read *every* image in the trajectory. Without `index`, `ase.io.read` returns
# only the last image, so `elements` would describe a single structure.
train_data = ase.io.read(train_data_path, index=":")

# Collect symbols across all images. Sorting makes the element order reproducible;
# the iteration order of a `set` of strings can change between interpreter
# sessions (string hash randomization).
elements = sorted({atom.symbol for image in train_data for atom in image})


def get_path_to_gaussian(element):
    """Return the valence-gaussian file for `element` from amptorch's GMP examples.

    Files are matched by the ``"<element>_"`` filename prefix inside
    ``<amptorch_path>/examples/GMP/valence_gaussians``. Candidates are sorted so
    the choice is deterministic if several files match.

    Raises
    ------
    FileNotFoundError
        If no file in that directory starts with ``"<element>_"``.
    """
    gaussians_path = amptorch_path / "examples/GMP/valence_gaussians"
    matches = sorted(
        p for p in gaussians_path.iterdir() if p.name.startswith(element + "_")
    )
    if not matches:
        raise FileNotFoundError(
            f"No valence-gaussian file for element {element!r} in {gaussians_path}"
        )
    return matches[0]


atom_gaussians = {element: get_path_to_gaussian(element) for element in elements}

# Gaussian widths shared by every MCSH order below.
sigmas = [0.2, 0.69, 1.1, 1.66, 2.66]

# GMP fingerprint parameters passed to amptorch as `fp_params`.
# NOTE(review): keys "0"/"1"/"2" are presumably MCSH orders and `cutoff`
# presumably in Angstrom; confirm against the amptorch GMP docs.
MCSHs = {
    "MCSHs": {
        "0": {"groups": [1], "sigmas": sigmas},
        "1": {"groups": [1], "sigmas": sigmas},
        "2": {"groups": [1, 2], "sigmas": sigmas},
    },
    "atom_gaussians": atom_gaussians,
    "cutoff": 8,
}


In [16]:
# Standalone copies of the dataset / cmd sections that `calc_loss` builds inline.
dataset_config = {
    # str(), matching the config assembled inside `calc_loss`.
    "raw_data": str(train_data_path),
    "fp_scheme": "gmp",
    "fp_params": MCSHs,
    "elements": elements,
    "save_fps": True,
    "scaling": {"type": "normalize", "range": (0, 1)},
    "val_split": 0.1,
}

# Previously an incomplete `cmd_config = ` (SyntaxError); now mirrors the "cmd"
# section used by `calc_loss`.
cmd_config = {
    "debug": False,
    "run_dir": "./",
    "seed": 1,
    "identifier": "test",
    "verbose": True,
    "logger": False,
}

In [19]:
def calc_loss(**params):
    """Build an amptorch config from keyword hyperparameters, train, and return the trainer.

    Parameters
    ----------
    **params
        Must contain every model key (``get_forces``, ``num_layers``,
        ``num_nodes``, ``batchnorm``) and every optimizer key
        (``force_coefficient``, ``lr``, ``batch_size``, ``epochs``, ``loss``,
        ``metric``). Values are copied verbatim into ``config["model"]`` and
        ``config["optim"]``. Any other key is ignored with a warning.

    Returns
    -------
    AtomsTrainer
        The trainer after ``train()`` has run, so the caller can inspect the
        fitted model / its training results. (Previously the function returned
        ``None`` despite its name.)

    Raises
    ------
    KeyError
        If any required hyperparameter is missing; all missing names are listed.
    """
    import warnings

    model_params = ["get_forces", "num_layers", "num_nodes", "batchnorm"]
    optim_params = ["force_coefficient", "lr", "batch_size", "epochs", "loss", "metric"]

    missing = [k for k in model_params + optim_params if k not in params]
    if missing:
        raise KeyError(f"calc_loss missing required hyperparameter(s): {missing}")

    # Unrecognized keys (e.g. a typo such as `learning_rate=`) used to be dropped
    # silently; keep accepting them but make the drop visible.
    unknown = sorted(set(params) - set(model_params) - set(optim_params))
    if unknown:
        warnings.warn(f"calc_loss ignoring unrecognized hyperparameter(s): {unknown}")

    config = {
        "model": {"name": "singlenn", **{k: params[k] for k in model_params}},
        "optim": {k: params[k] for k in optim_params},
        "dataset": {
            "raw_data": str(train_data_path),
            "fp_scheme": "gmp",
            "fp_params": MCSHs,
            "elements": elements,
            "save_fps": True,
            "scaling": {"type": "normalize", "range": (0, 1)},
            "val_split": 0.1,
        },
        "cmd": {
            "debug": False,
            "run_dir": "./",
            "seed": 1,
            "identifier": "test",
            "verbose": True,
            "logger": False,
        },
    }

    trainer = AtomsTrainer(config)
    trainer.train()
    return trainer

In [20]:
# Baseline run: 3 hidden layers x 10 nodes, 10 epochs. Forces are predicted
# (get_forces=True) but force_coefficient=0, so presumably only the energy term
# drives the loss.
baseline_hyperparams = dict(
    # model
    get_forces=True,
    num_layers=3,
    num_nodes=10,
    batchnorm=True,
    # optimizer / training loop
    force_coefficient=0,
    lr=1e-3,
    batch_size=16,
    epochs=10,
    loss="mse",
    metric="mae",
)
calc_loss(**baseline_hyperparams);

Results saved to ./checkpoints/2022-02-14-12-58-31-test


HBox(children=(FloatProgress(value=0.0, description='converting ASE atoms collection to Data objects', max=300…




HBox(children=(FloatProgress(value=0.0, description='Scaling Feature data (normalize)', max=3000.0, style=Prog…




HBox(children=(FloatProgress(value=0.0, description='Scaling Target data', max=3000.0, style=ProgressStyle(des…


Loading dataset: 3000 images
Use Xavier initialization
Loading model: 501 parameters
Loading skorch trainer
  epoch    train_energy_mae    train_forces_mae    train_loss    val_energy_mae    val_forces_mae    valid_loss    cp     dur
-------  ------------------  ------------------  ------------  ----------------  ----------------  ------------  ----  ------
      1           [36m1066.3795[0m             [32m55.6408[0m       [35m33.2362[0m          [31m501.5571[0m           [94m34.6613[0m        [36m7.4407[0m     +  6.4808




      2            [36m441.6300[0m             [32m32.3773[0m        [35m4.8012[0m          [31m327.7580[0m           [94m26.6124[0m        [36m2.3287[0m     +  6.4146
      3            [36m338.0735[0m             [32m22.7018[0m        [35m2.5142[0m          438.6731           [94m20.8129[0m        3.1687        6.4898
      4            [36m295.9209[0m             [32m16.5670[0m        [35m1.9177[0m          [31m279.4230[0m           [94m17.1901[0m        [36m1.4291[0m     +  6.4499
      5            [36m253.7717[0m             [32m13.4289[0m        [35m1.3217[0m          398.5257           [94m15.6561[0m        2.3555        7.1460
      6            [36m248.9843[0m             [32m11.5695[0m        [35m1.1249[0m          [31m239.2520[0m           [94m12.3963[0m        [36m1.0510[0m     +  6.5255
      7            [36m233.0127[0m              [32m9.5620[0m        [35m0.9323[0m          [31m238.6214[0m           [94m10.75