# Binary compound formation energy prediction example

This notebook demonstrates how to create a probabilistic model for predicting
formation energies of binary compounds with a quantified uncertainty.


In [1]:
import shutil
from pathlib import Path

import pandas as pd
from megnet.models import MEGNetModel
from pymatgen.ext.matproj import MPRester
from tensorflow.keras.callbacks import TensorBoard
from unlockgnn.download import load_data
from unlockgnn.model import MEGNetProbModel
from unlockgnn.metrics import evaluate_uq_metrics


In [2]:
THIS_DIR = Path(".").parent
CONFIG_FILE = THIS_DIR / ".config"

MODEL_SAVE_DIR: Path = THIS_DIR / "binary_e_form_model"
LOG_DIR = THIS_DIR / "logs"
BATCH_SIZE: int = 128
NUM_INDUCING_POINTS: int = 500
OVERWRITE: bool = True
TRAINING_RATIO: float = 0.8

if OVERWRITE:
    for directory in [MODEL_SAVE_DIR, LOG_DIR]:
        if directory.exists():
            shutil.rmtree(directory)


# Data gathering

Here we download binary compounds that lie on the convex hull from the Materials
Project, then split them into training and validation subsets.


In [3]:
full_df = load_data("binary_e_form")
full_df.head()

,structure,formation_energy_per_atom
0,"[[ 1.982598 -4.08421341 3.2051745 ] La, [1....",-0.737439
1,"[[0. 0. 0.] Fe, [1.880473 1.880473 1.880473] H]",-0.068482
2,"[[1.572998 0. 0. ] Ta, [0. ...",-0.773151
3,"[[0. 0. 7.42288687] Hf, [0. ...",-0.177707
4,"[[ 1.823716 -3.94193291 3.47897025] Tm, [1....",-0.905038


In [4]:
num_training = int(TRAINING_RATIO * len(full_df.index))
train_df = full_df[:num_training]
val_df = full_df[num_training:]

print(f"{num_training} training samples, {len(val_df.index)} validation samples.")

train_structs = train_df["structure"]
val_structs = val_df["structure"]

train_targets = train_df["formation_energy_per_atom"]
val_targets = val_df["formation_energy_per_atom"]


4217 training samples, 1055 validation samples.
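
The split above takes the first 80% of rows in their stored order. If the rows happen to be ordered in some way, a shuffled split may be preferable. The following is a minimal sketch using only the standard library; `shuffled_split_indices` and the seed are hypothetical, not part of `unlockgnn`, and the row count of 5272 is simply the sum of the two sample counts printed above.

```python
import random


def shuffled_split_indices(num_rows, training_ratio, seed=0):
    """Return (train_idx, val_idx) after shuffling row positions reproducibly."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)  # local RNG, leaves global state alone
    num_training = int(training_ratio * num_rows)  # same rounding as the notebook
    return indices[:num_training], indices[num_training:]


train_idx, val_idx = shuffled_split_indices(5272, 0.8)
print(len(train_idx), len(val_idx))
```

With pandas, these positions could then be applied as `full_df.iloc[train_idx]` and `full_df.iloc[val_idx]`.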


# Model creation

Now we load the `MEGNet` 2019 formation energies model, then convert this to a
probabilistic model.


In [5]:
meg_model = MEGNetModel.from_mvl_models("Eform_MP_2019")


INFO:megnet.utils.models:Package-level mvl_models not included, trying temperary mvl_models downloads..
INFO:megnet.utils.models:Model found in local mvl_models path


Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


In [6]:
kl_weight = BATCH_SIZE / num_training

prob_model = MEGNetProbModel(
    num_inducing_points=NUM_INDUCING_POINTS,
    save_path=MODEL_SAVE_DIR,
    meg_model=meg_model,
    kl_weight=kl_weight,
)


INFO:tensorflow:Assets written to: binary_e_form_model/megnet/assets


INFO:tensorflow:Assets written to: binary_e_form_model/gnn/assets


Instructions for updating:
`jitter` is deprecated; please use `marginal_fn` directly.
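
The `kl_weight` passed to `MEGNetProbModel` above is the ratio of batch size to training-set size. A common motivation, assumed here rather than taken from the `unlockgnn` documentation, is that the KL term of a variational loss is added once per mini-batch, so scaling it by this ratio makes it count roughly once per epoch. The arithmetic for this notebook's values can be checked as follows; the sample count is the one printed by the split cell.

```python
import math

BATCH_SIZE = 128
num_training = 4217  # training samples reported by the split cell

kl_weight = BATCH_SIZE / num_training
steps_per_epoch = math.ceil(num_training / BATCH_SIZE)

print(f"kl_weight = {kl_weight:.4f}")
print(f"steps per epoch = {steps_per_epoch}")
# Summed over one epoch, the KL term is weighted by about one.
print(f"kl_weight * steps = {kl_weight * steps_per_epoch:.3f}")
```

The 33 steps per epoch agree with the `33/33` progress counter in the training output.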


# Train the uncertainty quantifier

Now we train the model. By default, the `MEGNet` (GNN) layers of the model are
frozen after initialization. Therefore, when we call `prob_model.train()`, the
only layers that are optimized are the `VariationalGaussianProcess` (VGP) and the
`BatchNormalization` layer (`Norm`) that feeds into it.

After this initial training, we unfreeze _all_ the layers and train the full model simultaneously.
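
The freeze-then-fine-tune schedule can be illustrated without any deep learning library. The sketch below is a toy with two scalar parameters, not `unlockgnn`'s implementation: `body` stands in for the pretrained GNN layers and `head` for the VGP and `Norm` layers. The data, the `train` helper, and the learning rate are all made up for illustration.

```python
import math

# Toy data generated from y = 2 * tanh(0.8 * x); purely illustrative.
XS = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
YS = [2.0 * math.tanh(0.8 * x) for x in XS]


def loss(params):
    """Mean squared error of the toy model head * tanh(body * x)."""
    return sum(
        (params["head"] * math.tanh(params["body"] * x) - y) ** 2
        for x, y in zip(XS, YS)
    ) / len(XS)


def train(params, trainable, epochs=500, lr=0.05):
    """Gradient descent that only updates the parameters named in `trainable`."""
    params = dict(params)
    for _ in range(epochs):
        grads = {"head": 0.0, "body": 0.0}
        for x, y in zip(XS, YS):
            feat = math.tanh(params["body"] * x)
            err = params["head"] * feat - y
            grads["head"] += 2 * err * feat / len(XS)
            grads["body"] += 2 * err * params["head"] * x * (1 - feat**2) / len(XS)
        for name in trainable:  # frozen parameters are skipped
            params[name] -= lr * grads[name]
    return params


initial = {"body": 0.5, "head": 0.0}  # "pretrained" body, untrained head
stage_1 = train(initial, trainable=["head"])  # body frozen
stage_2 = train(stage_1, trainable=["head", "body"])  # everything unfrozen

print(f"initial loss:  {loss(initial):.4f}")
print(f"after stage 1: {loss(stage_1):.4f} (body = {stage_1['body']})")
print(f"after stage 2: {loss(stage_2):.6f} (body = {stage_2['body']:.3f})")
```

In this toy, the first stage can only reduce the loss as far as the fixed `body` allows; the second stage lowers it further by adjusting both parameters together.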


In [7]:
tb_callback_1 = TensorBoard(log_dir=LOG_DIR / "vgp_training", write_graph=False)
tb_callback_2 = TensorBoard(log_dir=LOG_DIR / "fine_tuning", write_graph=False)


In [8]:
%load_ext tensorboard
%tensorboard --logdir logs

In [9]:
prob_model.train(
    train_structs,
    train_targets,
    epochs=50,
    val_structs=val_structs,
    val_targets=val_targets,
    callbacks=[tb_callback_1],
)
prob_model.save()


Epoch 1/50
33/33 - 21s - loss: 2044489.2500 - val_loss: 1890266.8750
Epoch 2/50
33/33 - 9s - loss: 1892080.8750 - val_loss: 1733111.5000
Epoch 3/50
33/33 - 9s - loss: 1611927.8750 - val_loss: 1363083.5000
Epoch 4/50
33/33 - 9s - loss: 1240741.3750 - val_loss: 955943.0000
Epoch 5/50
33/33 - 9s - loss: 890893.8750 - val_loss: 602054.0000
Epoch 6/50
33/33 - 9s - loss: 571936.0000 - val_loss: 337989.4375
Epoch 7/50
33/33 - 9s - loss: 344380.3438 - val_loss: 171421.3906
Epoch 8/50
33/33 - 9s - loss: 211042.9062 - val_loss: 106291.5469
Epoch 9/50
33/33 - 9s - loss: 155044.4375 - val_loss: 83854.5938
Epoch 10/50
33/33 - 9s - loss: 123114.7344 - val_loss: 67730.7812
Epoch 11/50
33/33 - 9s - loss: 107424.1094 - val_loss: 58536.0352
Epoch 12/50
33/33 - 9s - loss: 88169.8984 - val_loss: 52115.5156
Epoch 13/50
33/33 - 9s - loss: 76644.9922 - val_loss: 45150.8438
Epoch 14/50
33/33 - 9s - loss: 70484.7188 - val_loss: 46568.0898
Epoch 15/50
33/33 - 9s - loss: 61135.7930 - val_loss: 37954.0234
Epoch 16/50
33/33 

In [10]:
prob_model.set_frozen(["GNN", "VGP"], freeze=False)


In [11]:
prob_model.train(
    train_structs,
    train_targets,
    epochs=50,
    val_structs=val_structs,
    val_targets=val_targets,
    callbacks=[tb_callback_2],
)
prob_model.save()


Epoch 1/50
33/33 - 20s - loss: 63314.0469 - val_loss: 221106.9062
Epoch 2/50
33/33 - 9s - loss: 30254.3105 - val_loss: 51659.4648
Epoch 3/50
33/33 - 9s - loss: 23182.4062 - val_loss: 23414.0781
Epoch 4/50
33/33 - 9s - loss: 19415.3867 - val_loss: 12821.4141
Epoch 5/50
33/33 - 9s - loss: 16007.1553 - val_loss: 9837.3643
Epoch 6/50
33/33 - 9s - loss: 15056.1191 - val_loss: 11087.1143
Epoch 7/50
33/33 - 9s - loss: 12715.9199 - val_loss: 23158.5977
Epoch 8/50
33/33 - 9s - loss: 9836.6113 - val_loss: 13169.1006
Epoch 9/50
33/33 - 9s - loss: 9448.9180 - val_loss: 15868.6494
Epoch 10/50
33/33 - 9s - loss: 9751.2949 - val_loss: 11778.0898
Epoch 11/50
33/33 - 9s - loss: 10222.7979 - val_loss: 10327.6270
Epoch 12/50
33/33 - 9s - loss: 11687.4141 - val_loss: 7964.2944
Epoch 13/50
33/33 - 9s - loss: 7418.4922 - val_loss: 9480.5498
Epoch 14/50
33/33 - 9s - loss: 9931.1279 - val_loss: 10010.1680
Epoch 15/50
33/33 - 9s - loss: 8833.6992 - val_loss: 9882.1807
Epoch 16/50
33/33 - 9s - loss: 7542.1870 -

# Model evaluation

Finally, we evaluate the model's metrics and make some sample predictions. The model returns a predicted value and a standard deviation for each structure. The standard deviation can then be converted to an uncertainty:
in this example, we take the uncertainty to be twice the standard deviation, which corresponds to an approximately 95% confidence interval if the predictive distribution is Gaussian (see <https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule>).
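
The link between a two-standard-deviation interval and the 95% figure can be checked with the standard library. This sketch assumes a Gaussian predictive distribution, which is an assumption about the model's output rather than something verified in this notebook.

```python
from statistics import NormalDist

standard_normal = NormalDist(mu=0.0, sigma=1.0)

# Probability mass of a Gaussian within two standard deviations of its mean.
coverage_2_sigma = standard_normal.cdf(2.0) - standard_normal.cdf(-2.0)

# Multiplier that gives an exact two-sided 95% interval.
z_95 = standard_normal.inv_cdf(0.975)

print(f"coverage within 2 sigma: {coverage_2_sigma:.4f}")
print(f"multiplier for exactly 95%: {z_95:.3f}")
```

Two standard deviations cover roughly 95.4% of a Gaussian, while an exact 95% interval uses a multiplier of about 1.96, so the factor of two is a convenient rounding.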


In [12]:
example_structs = val_structs[:10].tolist()
example_targets = val_targets[:10].tolist()

predicted, stddevs = prob_model.predict(example_structs)
uncerts = 2 * stddevs


In [13]:
pd.DataFrame(
    {
        "Composition": [struct.composition.reduced_formula for struct in example_structs],
        "Formation energy per atom / eV": example_targets,
        "Predicted / eV": [
            f"{pred:.2f} ± {uncert:.2f}" for pred, uncert in zip(predicted, uncerts)
        ],
    }
)


,Composition,Formation energy per atom / eV,Predicted / eV
0,Zr2Cu,-0.132384,-0.16 ± 0.04
1,NbRh,-0.401313,-0.47 ± 0.05
2,Cu3Ge,-0.005707,-0.08 ± 0.05
3,Pr3In,-0.273232,-0.36 ± 0.04
4,InS,-0.742895,-0.91 ± 0.04
5,TmPb3,-0.215892,-0.25 ± 0.03
6,InNi,-0.174754,-0.26 ± 0.04
7,GdGe,-0.857117,-0.88 ± 0.08
8,GdTl,-0.380423,-0.47 ± 0.04
9,HoTl3,-0.215986,-0.26 ± 0.04
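
One quick sanity check on the table above is to count how many targets fall inside the quoted interval. The sketch below copies the ten rows by hand. Because the predictions and uncertainties are rounded to two decimals in the table, borderline rows could flip, so the count is only indicative.

```python
# (formula, target, predicted, uncertainty) copied from the table above;
# predictions and uncertainties are rounded to two decimals there.
rows = [
    ("Zr2Cu", -0.132384, -0.16, 0.04),
    ("NbRh", -0.401313, -0.47, 0.05),
    ("Cu3Ge", -0.005707, -0.08, 0.05),
    ("Pr3In", -0.273232, -0.36, 0.04),
    ("InS", -0.742895, -0.91, 0.04),
    ("TmPb3", -0.215892, -0.25, 0.03),
    ("InNi", -0.174754, -0.26, 0.04),
    ("GdGe", -0.857117, -0.88, 0.08),
    ("GdTl", -0.380423, -0.47, 0.04),
    ("HoTl3", -0.215986, -0.26, 0.04),
]

# A target is covered when it lies within predicted +/- uncertainty.
covered = [
    formula
    for formula, target, predicted, uncert in rows
    if abs(target - predicted) <= uncert
]
print(f"{len(covered)}/{len(rows)} targets inside the interval: {covered}")
```

On these rounded values, two of the ten targets lie inside the interval. Ten rows are far too few, and the rounding too coarse, to judge calibration; the full-validation-set metrics are the better guide.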


In [14]:
val_metrics = evaluate_uq_metrics(prob_model, val_structs, val_targets)
train_metrics = evaluate_uq_metrics(prob_model, train_structs, train_targets)

print(f"{val_metrics=}")
print(f"{train_metrics=}")




val_metrics={'nll': 3837.821282918251, 'sharpness': 0.029964437195619257, 'variation': 0.4809019003733541, 'mae': 0.06154699488300184, 'mse': 0.00690016962570721, 'rmse': 0.08306725964967913}
train_metrics={'nll': -2010.5427565007687, 'sharpness': 0.030200000085766555, 'variation': 0.46698754909699147, 'mae': 0.04274989286400665, 'mse': 0.0024746102763464678, 'rmse': 0.04974545483103424}
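
How `evaluate_uq_metrics` computes these values is not shown in this notebook. As a rough guide to what the point-error and likelihood entries mean, the sketch below uses textbook definitions and assumes a Gaussian predictive distribution with the NLL summed over samples. `point_and_nll_metrics` is a hypothetical helper, and `sharpness` and `variation` are left out because their definitions are not given here.

```python
import math


def point_and_nll_metrics(targets, means, stddevs):
    """MAE, MSE, RMSE and summed Gaussian NLL (assumed definitions)."""
    errors = [m - t for m, t in zip(means, targets)]
    mse = sum(e**2 for e in errors) / len(errors)
    # Negative log density of each target under N(mean, stddev**2), summed.
    nll = sum(
        0.5 * math.log(2 * math.pi * s**2) + e**2 / (2 * s**2)
        for e, s in zip(errors, stddevs)
    )
    return {
        "mae": sum(abs(e) for e in errors) / len(errors),
        "mse": mse,
        "rmse": math.sqrt(mse),
        "nll": nll,
    }


# Hypothetical values, chosen only to exercise the function.
example = point_and_nll_metrics(
    targets=[-0.10, -0.40, -0.75],
    means=[-0.12, -0.45, -0.70],
    stddevs=[0.02, 0.03, 0.05],
)
print(example)
```

Under these definitions a Gaussian NLL can be negative when the predicted standard deviations are small, and the reported `rmse` values above are consistent with the square root of the reported `mse`.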
