In [None]:
!pip install -q condacolab
import condacolab
condacolab.install()

In [1]:
!git config --global user.name "Alex McKenzie" 
!git config --global user.email "amckenzie9@gatech.edu"
!git clone https://github.com/Arrrlex/amptorch.git && cd amptorch && git checkout fix-conda-env
!git clone https://github.com/medford-group/bdqm-vip.git
!mamba env update -n base -f amptorch/env_cpu.yml
!pip install ./amptorch

fatal: destination path 'amptorch' already exists and is not an empty directory.
Cloning into 'bdqm-vip'...
remote: Enumerating objects: 694, done.[K
remote: Counting objects: 100% (280/280), done.[K
remote: Compressing objects: 100% (227/227), done.[K
remote: Total 694 (delta 157), reused 117 (delta 49), pack-reused 414[K
Receiving objects: 100% (694/694), 143.31 MiB | 4.86 MiB/s, done.
Resolving deltas: 100% (388/388), done.


In [7]:
import torch
import ase.io
from amptorch.trainer import AtomsTrainer

from pathlib import Path

bdqm_path = Path("bdqm-vip")
amptorch_path = Path("amptorch")

# Read *every* image in the trajectory (index=":"). Without an index,
# ase.io.read returns only the last image, which would make `elements` below
# reflect a single structure instead of the whole training set.
train_data = ase.io.read(bdqm_path / "data/amptorch_data/oc20_3k_train.traj", index=":")

# Union of chemical symbols over all training images. Sorted so the element
# order (and anything keyed on it) is stable across kernel restarts; the
# iteration order of a set of strings depends on per-process hash seeding.
elements = sorted({symbol for image in train_data for symbol in image.get_chemical_symbols()})

def get_path_to_gaussian(element, gaussians_dir=None):
    """Return the valence-Gaussian parameter file for ``element``.

    Files are matched on the ``"<element>_"`` name prefix, so ``"C"`` matches
    ``C_...`` but not ``Ca_...`` or ``Cu_...``.

    Parameters
    ----------
    element : str
        Chemical symbol, e.g. ``"Cu"``.
    gaussians_dir : str or pathlib.Path, optional
        Directory to search. Defaults to ``examples/GMP/valence_gaussians``
        inside the cloned amptorch repository (``amptorch_path``).

    Returns
    -------
    pathlib.Path
        The first matching file in sorted order.

    Raises
    ------
    FileNotFoundError
        If no file in ``gaussians_dir`` starts with ``"<element>_"``.
    """
    if gaussians_dir is None:
        gaussians_dir = amptorch_path / "examples/GMP/valence_gaussians"
    prefix = element + "_"
    # Sort so the pick is deterministic when several files share the prefix;
    # iterdir() order is filesystem-dependent.
    matches = sorted(p for p in Path(gaussians_dir).iterdir() if p.name.startswith(prefix))
    if not matches:
        # A bare next() used to raise an uninformative StopIteration here.
        raise FileNotFoundError(
            f"No valence Gaussian file for element {element!r} in {gaussians_dir}"
        )
    return matches[0]

# Valence-Gaussian parameter file for each symbol in `elements`, in that order.
atom_gaussians = dict((symbol, get_path_to_gaussian(symbol)) for symbol in elements)

# GMP fingerprint parameters. Keys of "MCSHs" are MCSH orders ("0".."2"); each
# order uses its listed groups and the shared `sigmas` list (presumably the
# Gaussian widths of the descriptor — confirm against the amptorch GMP docs).
sigmas = [0.2, 0.69, 1.1, 1.66, 2.66]
mcsh_groups_by_order = {"0": [1], "1": [1], "2": [1, 2]}

MCSHs = dict(
    MCSHs={
        order: {"groups": groups, "sigmas": sigmas}
        for order, groups in mcsh_groups_by_order.items()
    },
    atom_gaussians=atom_gaussians,
    cutoff=8,  # neighbour cutoff radius; units presumably Å — TODO confirm
)

# amptorch's `raw_data` is given the trajectory path as a string
# (the trainer then loads the images from it).
train_data_path = str(bdqm_path / "data/amptorch_data/oc20_3k_train.traj")

# Model: amptorch "singlenn" network, fitted to forces as well as energies.
model_config = dict(
    name="singlenn",
    get_forces=True,
    num_layers=3,
    num_nodes=10,
    batchnorm=True,
)

# Optimisation: MSE loss with the force term weighted by force_coefficient,
# reported with MAE metrics.
optim_config = dict(
    force_coefficient=0.01,
    lr=1e-3,
    batch_size=16,
    epochs=10,
    loss="mse",
    metric="mae",
)

# Data: GMP fingerprints (parameters from MCSHs), features normalised to
# [0, 1], and 10% of the images held out for validation.
dataset_config = dict(
    raw_data=train_data_path,
    fp_scheme="gmp",
    fp_params=MCSHs,
    elements=elements,
    save_fps=True,
    scaling={"type": "normalize", "range": (0, 1)},
    val_split=0.1,
)

# Run settings: fixed seed, outputs under ./ with identifier "test".
cmd_config = dict(
    debug=False,
    run_dir="./",
    seed=1,
    identifier="test",
    verbose=True,
    logger=False,
)

config = dict(
    model=model_config,
    optim=optim_config,
    dataset=dataset_config,
    cmd=cmd_config,
)

# Restrict torch to one intra-op thread before the trainer is built and run.
torch.set_num_threads(1)
trainer = AtomsTrainer(config)
trainer.train()

Results saved to ./checkpoints/2022-02-08-15-14-05-test


converting ASE atoms collection to Data objects: 100%|██████████| 3000/3000 [00:59<00:00, 50.18 systems/s]
Scaling Feature data (normalize): 100%|██████████| 3000/3000 [01:09<00:00, 43.08 scalings/s]
Scaling Target data: 100%|██████████| 3000/3000 [00:00<00:00, 41523.93 scalings/s]


Loading dataset: 3000 images
Loading model: 501 parameters
Loading skorch trainer
  epoch    train_energy_mae    train_forces_mae    train_loss    val_energy_mae    val_forces_mae    valid_loss    cp      dur
-------  ------------------  ------------------  ------------  ----------------  ----------------  ------------  ----  -------
      1           [36m1217.9261[0m            [32m112.0899[0m       [35m37.5123[0m         [31m1512.9265[0m           [94m86.5026[0m       [36m29.8566[0m     +  15.1829




      2            [36m626.4187[0m             [32m85.2256[0m        [35m7.9544[0m          [31m479.3150[0m           [94m71.5122[0m        [36m4.6715[0m     +  15.2036
      3            [36m540.5634[0m             [32m66.5865[0m        [35m5.8387[0m          [31m436.3370[0m           [94m57.6336[0m        [36m3.4785[0m     +  15.1625
      4            [36m467.2123[0m             [32m58.9983[0m        [35m4.0027[0m          [31m316.2756[0m           58.0438        [36m2.5031[0m     +  15.4901
      5            [36m402.9479[0m             [32m53.5226[0m        [35m3.0752[0m          528.6107           58.6039        3.6555        15.4318
      6            [36m384.3636[0m             [32m48.1549[0m        [35m2.6971[0m          [31m270.6585[0m           [94m42.5644[0m        [36m1.6739[0m     +  15.4567
      7            [36m328.0601[0m             [32m35.9508[0m        [35m2.0188[0m          [31m265.5656[0m           47.52