In [2]:
import json
import os
import pickle
import random
import sys
import pickle

from typing import Callable, Dict, List, Optional
import haiku as hk
import ase
import ase.io
import jax
import jax.numpy as jnp
import numpy as np
import optax
import yaml

from phonax.datasets import (
    datasets,
    ph_datasets,
)
from phonax.optimizer import optimizer
from phonax.energy_force_train import energy_force_train
from phonax.loss import (
    WeightedEnergyFrocesStressLoss,
    crystalHessianLoss,
)
from phonax.nequip_model import NequIP_JAXMD_model

from phonax.utils import (
    create_directory_with_random_name,
)

from phonax.phonons_train import (
    ph_evaluate,
    hessian_train,
    two_stage_hessian_train,
)

from phonax.predictors import (
    predict_energy_forces_stress,
    predict_crystal_hessian,
)
from phonax.phonons import (
    atoms_to_ext_graph, 
    predict_gamma_spectra_filter
)

from phonax.data_utils import to_f32

from phonax.trained_models import NequIP_JAXMD_molecule_model   

# Turn on JAX's NaN and Inf debugging checks so that numerical problems are
# reported where they first appear instead of propagating silently.
for debug_flag in ("jax_debug_nans", "jax_debug_infs"):
    jax.config.update(debug_flag, True)

# Show NumPy arrays with 3 decimals and without scientific notation.
np.set_printoptions(suppress=True, precision=3)

# Introduction

Previously, we have been making the second derivative Hessian matrix predictions given an energy model, which is trained with zeroth and first order derivative data (i.e. energy and forces).
Here in this tutorial, we will demonstrate an alternative view of the Hessian data, and use them as part of the training data (the Hessian matrix, or the eigenvalue spectrum) for the periodic crystal solids.

As discussed in our work, this augments the training dataset in converging the energy models. Locally, the second derivative data improves the local curvatures of the energy model landscape.
To demonstrate this here, we use periodic crystal cases as the examples, and focus on using the Hessian matrix as an extended data type to train the energy model/functional. 


## Energy model training with crystalline Hessian data

This section focuses on using the Hessian matrix as an extended data type to train the energy model/functional for the interatomic potentials.
The training procedure can be viewed as an extension to the conventional interatomic potential training using only the zeroth and first order training data for the energy landscape (i.e. energy and force).

For the specific example, we use the Si crystal (mp-149) example discussed in our work.
When the energy model is trained with energy and force data only, we found that the training can be improved by augmenting the data with supercell geometries, which effectively create a finer mesh grid for the momentum-space sampling.
However, in our training example here, we will only use the $1 \times 1 \times 1$ energy and force data, combined with the crystal Hessian data to achieve good phonon predictions, without supercell geometries for the energy and forces.

Compared to the molecular cases, the derivation is more complicated for periodic crystalline solids.
This is due to the periodic boundary conditions, which require the construction of extended graphs in the computations.
To describe the phonon states, one also has to introduce the crystal momentum $\vec{k}$ from the Bloch theorem.
In summary, the Hessian training involves the random selection of a momentum $\vec{k}$ and two (random) displacement patterns at this momentum.
Then one can project out the dynamical matrix element given these two displacement patterns. (In analogy to the molecular cases, one can also choose random atomic sites and two displacement vectors on these sites to perform the projection.)
With enough projections (or slices) of the Hessians from the trained model, the energy model would converge to the ground-truth energy functional at all momenta $\vec{k}$ and between any states.

One note about the random directions (or patterns) here. If the displacements are constrained to the x, y or z directions, the specific coordinate frame breaks the (rotational) equivariance in the training procedure, even though our energy model respects the equivariance. The randomized projections restore the rotational equivariance.





### Initialize the training procedures

Here we load the training configurations and datasets, initialize the model, loss functions, and the predictors used for the next training steps.

In [4]:
# Load the training configuration.
# NOTE(review): yaml.FullLoader is kept because this is a local, trusted file;
# prefer yaml.safe_load if the config could ever come from an untrusted source.
with open('data/mp-149/mp149-fh.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Use the output directory named in the config when there is one; otherwise
# let phonax pick a randomly named directory.
if 'save_dir_name' in config:
    save_dir_name = config['save_dir_name']
else:
    save_dir_name = create_directory_with_random_name()

# A directory name taken from the config is not guaranteed to exist yet, and
# opening "<dir>/config.yaml" for writing below would then fail with
# FileNotFoundError. Creating it here is a no-op when it already exists.
os.makedirs(save_dir_name, exist_ok=True)
print(save_dir_name)

# Save a copy of the config next to the training outputs.
with open(f"{save_dir_name}/config.yaml", "w") as f:
    yaml.dump(config, f)
    
# Energy/force data: train/valid/test loaders, plus the cutoff radius that
# `datasets` reports back.
EF_train_loader, EF_valid_loader, EF_test_loader, EF_r_max = datasets(
    r_max=config["cutoff"],
    config_dataset=config["dataset"],
)

# Hessian data: same layout, built from the "ph_dataset" section of the config.
# `ph_datasets` additionally receives and returns the number of message-passing
# steps (taken from the model's layer count).
(
    ph_train_loader,
    ph_valid_loader,
    ph_test_loader,
    ph_r_max,
    num_message_passing,
) = ph_datasets(
    r_max=config["cutoff"],
    config_dataset=config["ph_dataset"],
    num_message_passing=config['model']['num_layers'],
)

# Both datasets must agree on the cutoff radius before a single value is kept.
assert EF_r_max == ph_r_max
r_max = EF_r_max

# Build the NequIP (JAX-MD version) energy model and its initial parameters.
model_fn, params, num_message_passing = NequIP_JAXMD_model(
    r_max=r_max,
    # NOTE(review): an empty dict presumably lets phonax estimate the atomic
    # energies from the training graphs — confirm against NequIP_JAXMD_model.
    atomic_energies_dict={},
    train_graphs=EF_train_loader.graphs,
    initialize_seed=config["model"]["seed"],
    num_species=config["model"]["num_species"],
    use_sc=True,
    graph_net_steps=config["model"]["num_layers"],
    hidden_irreps=config["model"]["internal_irreps"],
    nonlinearities={'e': 'swish', 'o': 'tanh'},
    save_dir_name=save_dir_name,
    # Optional `reload` entry of the "initialization" section; None when absent.
    reload=config["initialization"].get('reload'),
)

# Total number of scalar entries over all leaves of the parameter tree.
print("num_params:", sum(leaf.size for leaf in jax.tree_util.tree_leaves(params)))

# Loss for the energy/force/stress data; the three weights come from the
# "training" section of the config.
EF_loss_fn = WeightedEnergyFrocesStressLoss(
    energy_weight=config["training"]["energy_weight"],
    forces_weight=config["training"]["forces_weight"],
    stress_weight=config["training"]["stress_weight"],
)

# Loss for the crystal Hessian data (default settings).
H_loss_fn = crystalHessianLoss()


# Jitted predictors. Each one binds the parameters `w` into `model_fn` and
# hands the resulting closure, together with the graph(s) `g`, to the
# corresponding phonax predictor.
@jax.jit
def EF_predictor(w, g):
    """Run `predict_energy_forces_stress` on `g` with `model_fn` bound to `w`."""
    return predict_energy_forces_stress(lambda *x: model_fn(w, *x), g)


@jax.jit
def H_predictor(w, g):
    """Run `predict_crystal_hessian` on `g` with `model_fn` bound to `w`."""
    return predict_crystal_hessian(lambda *x: model_fn(w, *x), g)


2024-01-11-21:31-tasteless-valentia
2024-01-11-21:31-tasteless-valentia
nums check 1500 9000 128
Loaded 1000 training configurations from 'data/mp-149/mp-149-SC111_train.xyz'
Loaded 200 validation configurations from 'data/mp-149/mp-149-SC333_valid.xyz'
Total number of configurations: train=1000, valid=200, test=0


100%|████████████████████████████████████| 1000/1000 [00:00<00:00, 25685.76it/s]
100%|████████████████████████████████████████| 200/200 [00:00<00:00, 741.75it/s]
0it [00:00, ?it/s]
Pad train: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 267.35it/s]
Pad valid: 100%|█████████████████████████████| 100/100 [00:00<00:00, 265.56it/s]


Total number of configurations: train=2000, valid=100, test=0
Compute the average number of neighbors: 27.992
Do not normalize the radial basis (avg_r_min=None)
Computed average Atomic Energies using least squares: {14: -5.531418045215}
Create NequIP (JAX-MD version) with parameters {'use_sc': True, 'graph_net_steps': 2, 'hidden_irreps': '16x0e + 16x0o +12x1e + 12x1o + 8x2e +8x2o', 'nonlinearities': {'e': 'swish', 'o': 'tanh'}, 'r_max': 5.0, 'avg_num_neighbors': 27.992, 'avg_r_min': None, 'num_species': 100, 'radial_basis': <function bessel_basis at 0x7fd5feea6340>, 'radial_envelope': <function soft_envelope at 0x7fd5feea62a0>}
num_params: 58920


### Training with energy / force training data

In this section, we demonstrate the energy model training by using only the energy and force training data, as is done in the conventional interatomic potential model training procedure.
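Note that in this Si crystal example, the training set consists of 1000 $1 \times 1 \times 1$ configurations (`mp-149-SC111_train.xyz`), and the model is validated on 200 configurations from `mp-149-SC333_valid.xyz`.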



In [5]:
# Baseline: train with energy/force data only (no Hessian data).

# Work on a copy so that `params` keeps the initial model parameters.
Fonly_params = params.copy()

# Optimizer and schedule, all taken from the "training" section of the config.
train_cfg = config["training"]
gradient_transform, steps_per_interval, max_num_intervals = optimizer(
    lr=train_cfg["learning_rate"],
    max_num_intervals=train_cfg["max_num_intervals"],
    steps_per_interval=train_cfg["steps_per_interval"],
    # weight decay is currently disabled:
    # weight_decay = config["training"]["weight_decay"],
)
optimizer_state = gradient_transform.init(params)
print(
    "optimizer num_params:",
    sum(leaf.size for leaf in jax.tree_util.tree_leaves(optimizer_state)),
)

# Positional arguments follow the order expected by `energy_force_train`.
train_out = energy_force_train(
    EF_predictor,
    Fonly_params,
    optimizer_state,
    EF_train_loader,
    EF_valid_loader,
    EF_test_loader,
    gradient_transform,
    EF_loss_fn,
    max_num_intervals,
    steps_per_interval,
    save_dir_name,
    ema_decay=train_cfg["ema_decay"],
    patience=train_cfg["patience"],
)



optimizer num_params: 176762
Started training


eval_train:   0%|                                         | 0/7 [00:00<?, ?it/s]

Compiled function `model` for args:
cache size: 1


eval_train: 100%|██████████████████████████| 7/7 [00:05<00:00,  1.25it/s, n=889]


Interval 0: eval_train: loss=354.6784, mae_e_per_atom=2777.5 meV, mae_f=1523.9 meV/Å, mae_s=27.9 meV/Å³


eval_valid:   0%|                                        | 0/40 [00:00<?, ?it/s]

Compiled function `model` for args:
cache size: 2


eval_valid: 100%|████████████████████████| 40/40 [00:05<00:00,  6.72it/s, n=200]


Interval 0: eval_valid: loss=311.1701, mae_e_per_atom=2772.3 meV, mae_f=1373.5 meV/Å, mae_s=9.1 meV/Å³


Train interval 0:   1%|▏           | 7/500 [00:07<06:21,  1.29it/s, loss=10.871]

Compiled function `update_fn` for args:
Outout: loss= 330.859
Compilation time: 7.259s, cache size: 1
Compiled function `update_fn` for args:
Outout: loss= 245.776
Compilation time: 0.018s, cache size: 2


Train interval 0: 100%|███████████| 500/500 [00:15<00:00, 31.67it/s, loss=0.091]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.18it/s, n=889]


Interval 1: eval_train: loss=0.1210, mae_e_per_atom=51.2 meV, mae_f=24.3 meV/Å, mae_s=68.4 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.54it/s, n=200]


Interval 1: eval_valid: loss=3.1964, mae_e_per_atom=14.6 meV, mae_f=165.6 meV/Å, mae_s=28.2 meV/Å³


Train interval 1: 100%|███████████| 500/500 [00:08<00:00, 58.29it/s, loss=0.096]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 55.97it/s, n=889]


Interval 2: eval_train: loss=0.0810, mae_e_per_atom=38.1 meV, mae_f=20.8 meV/Å, mae_s=62.6 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.60it/s, n=200]


Interval 2: eval_valid: loss=2.3822, mae_e_per_atom=11.1 meV, mae_f=143.2 meV/Å, mae_s=19.9 meV/Å³


Train interval 2: 100%|███████████| 500/500 [00:08<00:00, 58.09it/s, loss=0.052]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.22it/s, n=889]


Interval 3: eval_train: loss=0.0577, mae_e_per_atom=28.3 meV, mae_f=18.4 meV/Å, mae_s=60.0 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.69it/s, n=200]


Interval 3: eval_valid: loss=2.0516, mae_e_per_atom=11.5 meV, mae_f=132.9 meV/Å, mae_s=14.9 meV/Å³


Train interval 3: 100%|███████████| 500/500 [00:08<00:00, 57.90it/s, loss=0.050]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.19it/s, n=889]


Interval 4: eval_train: loss=0.0439, mae_e_per_atom=20.4 meV, mae_f=17.1 meV/Å, mae_s=58.6 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.42it/s, n=200]


Interval 4: eval_valid: loss=1.9294, mae_e_per_atom=14.0 meV, mae_f=128.3 meV/Å, mae_s=13.3 meV/Å³


Train interval 4: 100%|███████████| 500/500 [00:08<00:00, 58.09it/s, loss=0.032]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.00it/s, n=889]


Interval 5: eval_train: loss=0.0367, mae_e_per_atom=14.9 meV, mae_f=16.2 meV/Å, mae_s=57.4 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 59.55it/s, n=200]


Interval 5: eval_valid: loss=1.9241, mae_e_per_atom=15.4 meV, mae_f=128.0 meV/Å, mae_s=13.9 meV/Å³


Train interval 5: 100%|███████████| 500/500 [00:08<00:00, 57.68it/s, loss=0.039]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 53.75it/s, n=889]


Interval 6: eval_train: loss=0.0330, mae_e_per_atom=11.2 meV, mae_f=15.6 meV/Å, mae_s=57.2 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 59.96it/s, n=200]


Interval 6: eval_valid: loss=1.9180, mae_e_per_atom=16.0 meV, mae_f=127.9 meV/Å, mae_s=15.1 meV/Å³


Train interval 6: 100%|███████████| 500/500 [00:08<00:00, 58.12it/s, loss=0.032]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 53.37it/s, n=889]


Interval 7: eval_train: loss=0.0299, mae_e_per_atom=8.8 meV, mae_f=15.1 meV/Å, mae_s=56.6 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 59.84it/s, n=200]


Interval 7: eval_valid: loss=1.9330, mae_e_per_atom=17.3 meV, mae_f=128.4 meV/Å, mae_s=16.2 meV/Å³


Train interval 7: 100%|███████████| 500/500 [00:08<00:00, 57.59it/s, loss=0.029]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 53.34it/s, n=889]


Interval 8: eval_train: loss=0.0284, mae_e_per_atom=6.7 meV, mae_f=14.9 meV/Å, mae_s=57.1 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.52it/s, n=200]


Interval 8: eval_valid: loss=1.9487, mae_e_per_atom=17.3 meV, mae_f=128.9 meV/Å, mae_s=17.2 meV/Å³


Train interval 8: 100%|███████████| 500/500 [00:08<00:00, 57.59it/s, loss=0.026]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 52.51it/s, n=889]


Interval 9: eval_train: loss=0.0273, mae_e_per_atom=5.4 meV, mae_f=14.8 meV/Å, mae_s=57.0 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 59.05it/s, n=200]


Interval 9: eval_valid: loss=1.9568, mae_e_per_atom=17.7 meV, mae_f=129.2 meV/Å, mae_s=18.0 meV/Å³


Train interval 9: 100%|███████████| 500/500 [00:08<00:00, 57.64it/s, loss=0.039]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 53.16it/s, n=889]


Interval 10: eval_train: loss=0.0267, mae_e_per_atom=4.4 meV, mae_f=14.7 meV/Å, mae_s=57.1 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.20it/s, n=200]


Interval 10: eval_valid: loss=1.9708, mae_e_per_atom=17.5 meV, mae_f=129.5 meV/Å, mae_s=18.6 meV/Å³


Train interval 10: 100%|██████████| 500/500 [00:08<00:00, 59.51it/s, loss=0.025]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 55.11it/s, n=889]


Interval 11: eval_train: loss=0.0262, mae_e_per_atom=3.8 meV, mae_f=14.6 meV/Å, mae_s=57.1 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 61.70it/s, n=200]


Interval 11: eval_valid: loss=1.9817, mae_e_per_atom=17.5 meV, mae_f=129.8 meV/Å, mae_s=19.0 meV/Å³


Train interval 11: 100%|██████████| 500/500 [00:08<00:00, 59.19it/s, loss=0.030]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 53.14it/s, n=889]


Interval 12: eval_train: loss=0.0248, mae_e_per_atom=3.3 meV, mae_f=14.3 meV/Å, mae_s=56.6 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.17it/s, n=200]


Interval 12: eval_valid: loss=2.0041, mae_e_per_atom=17.7 meV, mae_f=130.6 meV/Å, mae_s=19.4 meV/Å³


Train interval 12: 100%|██████████| 500/500 [00:08<00:00, 59.06it/s, loss=0.027]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.75it/s, n=889]


Interval 13: eval_train: loss=0.0252, mae_e_per_atom=2.8 meV, mae_f=14.3 meV/Å, mae_s=57.2 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.49it/s, n=200]


Interval 13: eval_valid: loss=2.0313, mae_e_per_atom=18.2 meV, mae_f=131.6 meV/Å, mae_s=19.6 meV/Å³


Train interval 13: 100%|██████████| 500/500 [00:08<00:00, 59.14it/s, loss=0.025]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 53.78it/s, n=889]


Interval 14: eval_train: loss=0.0247, mae_e_per_atom=2.6 meV, mae_f=14.2 meV/Å, mae_s=57.0 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 61.98it/s, n=200]


Interval 14: eval_valid: loss=2.0697, mae_e_per_atom=18.3 meV, mae_f=132.8 meV/Å, mae_s=19.8 meV/Å³


Train interval 14: 100%|██████████| 500/500 [00:08<00:00, 59.11it/s, loss=0.028]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 56.07it/s, n=889]


Interval 15: eval_train: loss=0.0240, mae_e_per_atom=2.6 meV, mae_f=14.1 meV/Å, mae_s=57.0 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 61.61it/s, n=200]


Interval 15: eval_valid: loss=2.1142, mae_e_per_atom=18.3 meV, mae_f=134.3 meV/Å, mae_s=20.0 meV/Å³
Training complete


### Training with crystal Hessians

In the next section, we show the energy model training with additional second order Hessian data beyond the force data at the first order derivative.
Specifically, there are two stages in the training procedure.
The first stage, or the warm-up stage, uses only the energy and force data to initialize the energy model training.
After training for several warm-up steps, the second training stage adds the Hessian data in computing the loss function.
At the second stage, the update gradients come from both the energy / force training losses, and the Hessian losses, and we combine the two (weighted sum) to get the overall training gradient.

From the training results below, one can see that the validation force MAE further decreases once we have included the Hessian training data in the second stage.

We note that it remains an open question how to find the optimal training procedure when adding the second order Hessian data. For example, one can perform the training without warm-up steps, or gradually ramp up the training weights associated with the Hessian losses.
It would also be relevant to examine the converged energy models for their predictions and stability under molecular dynamics simulations.


In [6]:
# Schedule for this demonstration, hard-coded here instead of read from the
# config; the config entry each value stands in for is noted alongside.
FH_MAX_NUM_INTERVALS = 20     # config["training"]["max_num_intervals"]
FH_STEPS_PER_INTERVAL = 50    # config["training"]["steps_per_interval"]
FH_WARMUP_NUM_INTERVALS = 2   # config['training']['warmup_num_intervals']
FH_PHTRAIN_NUM_INTERVALS = 5  # config['training']['phtrain_num_intervals']

# Work on a copy so that `params` keeps the initial model parameters.
FH_params = params.copy()

# Build a fresh optimizer and optimizer state for this run.
gradient_transform, steps_per_interval, max_num_intervals = optimizer(
    lr=config["training"]["learning_rate"],
    max_num_intervals=FH_MAX_NUM_INTERVALS,
    steps_per_interval=FH_STEPS_PER_INTERVAL,
    # weight decay is currently disabled:
    # weight_decay = config["training"]["weight_decay"],
)
optimizer_state = gradient_transform.init(params)
print(
    "optimizer num_params:",
    sum(leaf.size for leaf in jax.tree_util.tree_leaves(optimizer_state)),
)

# Two-stage training: warm-up intervals followed by intervals that also use
# the Hessian loaders.
# NOTE(review): the return value of this call is not captured; if
# `two_stage_hessian_train` returns the trained parameters, assign it — confirm.
two_stage_hessian_train(
    energy_forces_stress_predictor=EF_predictor,
    phonon_predictor=H_predictor,
    params=FH_params,
    gradient_transform=gradient_transform,
    optimizer_state=optimizer_state,
    steps_per_interval=steps_per_interval,
    EF_loss_fn=EF_loss_fn,
    H_loss_fn=H_loss_fn,
    EF_train_loader=EF_train_loader,
    EF_valid_loader=EF_valid_loader,
    H_train_loader=ph_train_loader,
    H_valid_loader=ph_valid_loader,
    warmup_num_intervals=FH_WARMUP_NUM_INTERVALS,
    phtrain_num_intervals=FH_PHTRAIN_NUM_INTERVALS,
    periodic_crystal=True,
)









optimizer num_params: 176762


eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 53.81it/s, n=889]


Interval 0: eval_train: loss=357.8423, mae_e_per_atom=2778.442 meV, mae_f=1534.204 meV/Å, mae_s=27.945 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.96it/s, n=200]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Interval 0: eval_valid: loss=311.1701, mae_e_per_atom=2772.260 meV, mae_f=1373.496 meV/Å, mae_s=9.124 meV/Å³


Interval 0: 100%|████████████████████| 50/50 [00:06<00:00,  7.68it/s, loss=1.14]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 53.56it/s, n=889]


Interval 1: eval_train: loss=0.8267, mae_e_per_atom=66.257 meV, mae_f=58.644 meV/Å, mae_s=76.991 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.73it/s, n=200]


Interval 1: eval_valid: loss=8.2698, mae_e_per_atom=19.607 meV, mae_f=249.091 meV/Å, mae_s=39.159 meV/Å³


Interval 1: 100%|███████████████████| 50/50 [00:00<00:00, 58.55it/s, loss=0.265]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.54it/s, n=889]


Interval 0: eval_train: loss=0.3260, mae_e_per_atom=62.470 meV, mae_f=44.231 meV/Å, mae_s=74.979 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.78it/s, n=200]


Interval 0: eval_valid: loss=6.1873, mae_e_per_atom=24.421 meV, mae_f=224.627 meV/Å, mae_s=37.214 meV/Å³


Evaluating hessian train: 100%|█| 1000/1000 [00:12<00:00, 79.36it/s, n_graphs=20


interval 0: hessian train hessian MAE 0.892 eV/A2 


Evaluating hessian valid: 100%|██| 50/50 [00:00<00:00, 135.82it/s, n_graphs=100]


interval 0: hessian valid hessian MAE 0.956 eV/A2 


Interval 0: 100%|███████████████████| 50/50 [00:20<00:00,  2.48it/s, loss=0.143]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.65it/s, n=889]


Interval 1: eval_train: loss=0.2646, mae_e_per_atom=61.101 meV, mae_f=39.969 meV/Å, mae_s=74.410 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.75it/s, n=200]


Interval 1: eval_valid: loss=5.0671, mae_e_per_atom=21.306 meV, mae_f=204.976 meV/Å, mae_s=35.867 meV/Å³


Evaluating hessian train: 100%|█| 1000/1000 [00:07<00:00, 134.06it/s, n_graphs=2


interval 1: hessian train hessian MAE 0.816 eV/A2 


Evaluating hessian valid: 100%|██| 50/50 [00:00<00:00, 132.46it/s, n_graphs=100]


interval 1: hessian valid hessian MAE 0.866 eV/A2 


Interval 1: 100%|█████████████████████| 50/50 [00:01<00:00, 37.89it/s, loss=0.1]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.31it/s, n=889]


Interval 2: eval_train: loss=0.2264, mae_e_per_atom=60.225 meV, mae_f=36.641 meV/Å, mae_s=73.289 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 60.72it/s, n=200]


Interval 2: eval_valid: loss=4.1941, mae_e_per_atom=18.565 meV, mae_f=187.254 meV/Å, mae_s=34.677 meV/Å³


Evaluating hessian train: 100%|█| 1000/1000 [00:07<00:00, 133.15it/s, n_graphs=2


interval 2: hessian train hessian MAE 0.713 eV/A2 


Evaluating hessian valid: 100%|██| 50/50 [00:00<00:00, 133.18it/s, n_graphs=100]


interval 2: hessian valid hessian MAE 0.738 eV/A2 


Interval 2: 100%|██████████████████| 50/50 [00:01<00:00, 38.01it/s, loss=0.0946]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.82it/s, n=889]


Interval 3: eval_train: loss=0.1886, mae_e_per_atom=57.333 meV, mae_f=32.674 meV/Å, mae_s=72.406 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 59.90it/s, n=200]


Interval 3: eval_valid: loss=3.6191, mae_e_per_atom=17.186 meV, mae_f=174.708 meV/Å, mae_s=33.505 meV/Å³


Evaluating hessian train: 100%|█| 1000/1000 [00:07<00:00, 132.65it/s, n_graphs=2


interval 3: hessian train hessian MAE 0.635 eV/A2 


Evaluating hessian valid: 100%|██| 50/50 [00:00<00:00, 132.49it/s, n_graphs=100]


interval 3: hessian valid hessian MAE 0.639 eV/A2 


Interval 3: 100%|███████████████████| 50/50 [00:01<00:00, 37.76it/s, loss=0.086]
eval_train: 100%|██████████████████████████| 7/7 [00:00<00:00, 54.45it/s, n=889]


Interval 4: eval_train: loss=0.1557, mae_e_per_atom=56.043 meV, mae_f=28.257 meV/Å, mae_s=71.105 meV/Å³


eval_valid: 100%|████████████████████████| 40/40 [00:00<00:00, 59.89it/s, n=200]


Interval 4: eval_valid: loss=3.1565, mae_e_per_atom=15.912 meV, mae_f=163.530 meV/Å, mae_s=31.922 meV/Å³


Evaluating hessian train: 100%|█| 1000/1000 [00:07<00:00, 132.43it/s, n_graphs=2


interval 4: hessian train hessian MAE 0.559 eV/A2 


Evaluating hessian valid: 100%|██| 50/50 [00:00<00:00, 132.19it/s, n_graphs=100]


interval 4: hessian valid hessian MAE 0.545 eV/A2 


Interval 4: 100%|███████████████████| 50/50 [00:01<00:00, 37.67it/s, loss=0.072]
