# Binary compound formation energy prediction example

This notebook demonstrates how to create a probabilistic model for predicting
formation energies of binary compounds with a quantified uncertainty.


In [1]:
import shutil
from pathlib import Path
from pprint import pprint

import pandas as pd
from megnet.models import MEGNetModel
from pymatgen.ext.matproj import MPRester
from tensorflow.keras.callbacks import TensorBoard
from unlockgnn.download import load_data
from unlockgnn.model import MEGNetProbModel
from unlockgnn.metrics import evaluate_uq_metrics


In [2]:
THIS_DIR = Path(".").parent
CONFIG_FILE = THIS_DIR / ".config"

MODEL_SAVE_DIR: Path = THIS_DIR / "binary_e_form_model"
LOG_DIR = THIS_DIR / "logs"
BATCH_SIZE: int = 128
NUM_INDUCING_POINTS: int = 500
OVERWRITE: bool = True
TRAINING_RATIO: float = 0.8

if OVERWRITE:
    for directory in [MODEL_SAVE_DIR, LOG_DIR]:
        if directory.exists():
            shutil.rmtree(directory)


# Data gathering

Here we download a Materials Project dataset of binary compounds that lie on the
convex hull, then split it into training and validation subsets: the first
`TRAINING_RATIO` (80%) of the rows are used for training and the rest for validation.


In [3]:
full_df = load_data("binary_e_form")
full_df.head()

| | structure | formation_energy_per_atom |
|---|---|---|
| 0 | `[[ 1.982598 -4.08421341 3.2051745 ] La, [1....` | -0.737439 |
| 1 | `[[0. 0. 0.] Fe, [1.880473 1.880473 1.880473] H]` | -0.068482 |
| 2 | `[[1.572998 0. 0. ] Ta, [0. ...` | -0.773151 |
| 3 | `[[0. 0. 7.42288687] Hf, [0. ...` | -0.177707 |
| 4 | `[[ 1.823716 -3.94193291 3.47897025] Tm, [1....` | -0.905038 |


In [4]:
num_training = int(TRAINING_RATIO * len(full_df.index))
train_df = full_df[:num_training]
val_df = full_df[num_training:]

print(f"{num_training} training samples, {len(val_df.index)} validation samples.")

train_structs = train_df["structure"]
val_structs = val_df["structure"]

train_targets = train_df["formation_energy_per_atom"]
val_targets = val_df["formation_energy_per_atom"]


4217 training samples, 1055 validation samples.
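The split above takes the first 80% of rows in file order. If the downloaded data happen to be sorted (for example by chemical system or ID), a shuffled split avoids systematic differences between the two subsets. The following is a minimal sketch using pandas' `DataFrame.sample`; the helper name `shuffled_split` and the fixed seed are illustrative assumptions, not part of unlockgnn.

```python
import pandas as pd


def shuffled_split(df: pd.DataFrame, train_ratio: float, seed: int = 0):
    """Shuffle rows reproducibly, then split into (train, validation) frames.

    Hypothetical helper for illustration; not part of unlockgnn.
    """
    shuffled = df.sample(frac=1.0, random_state=seed)  # permute all rows
    num_train = int(train_ratio * len(shuffled.index))
    return shuffled.iloc[:num_train], shuffled.iloc[num_train:]


# Toy frame standing in for `full_df`.
toy_df = pd.DataFrame({"formation_energy_per_atom": [-0.1 * i for i in range(10)]})
toy_train, toy_val = shuffled_split(toy_df, train_ratio=0.8)
print(len(toy_train), len(toy_val))  # 8 2
```

With the notebook's variables this would read `train_df, val_df = shuffled_split(full_df, TRAINING_RATIO)`. The subset sizes (4217 and 1055) depend only on the number of rows, so they stay the same; only which compounds land in each subset changes.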


# Model creation

Now we load the pretrained `MEGNet` 2019 formation energy model (`Eform_MP_2019`),
then convert it to a probabilistic model.


In [5]:
meg_model = MEGNetModel.from_mvl_models("Eform_MP_2019")


INFO:megnet.utils.models:Package-level mvl_models not included, trying temperary mvl_models downloads..
INFO:megnet.utils.models:Model found in local mvl_models path


Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.



In [6]:
kl_weight = BATCH_SIZE / num_training

prob_model = MEGNetProbModel(
    meg_model=meg_model,
    num_inducing_points=NUM_INDUCING_POINTS,
    kl_weight=kl_weight,
)


Instructions for updating:
`jitter` is deprecated; please use `marginal_fn` directly.



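Why `kl_weight = BATCH_SIZE / num_training`? Assuming `MEGNetProbModel` follows the usual minibatch convention for variational GP layers (as in TensorFlow Probability's `VariationalGaussianProcess` examples), each batch's loss is the negative expected log-likelihood of that batch plus `kl_weight` times the KL divergence between the variational posterior and the prior. Rescaling the full-data objective by B/N (batch size over training-set size) gives exactly this weight, and over the batches of one epoch the KL term is then counted about once. The arithmetic below checks this with the numbers from this notebook.

```python
import math

# Assumes the minibatch ELBO convention described above.
BATCH_SIZE = 128
num_training = 4217  # value printed by the data-split cell

kl_weight = BATCH_SIZE / num_training
steps_per_epoch = math.ceil(num_training / BATCH_SIZE)  # the last batch is partial

print(f"kl_weight = {kl_weight:.4f}")          # kl_weight = 0.0304
print(f"steps per epoch = {steps_per_epoch}")  # steps per epoch = 33
# Total KL weight accumulated over one epoch is close to 1.
print(f"KL weight per epoch = {steps_per_epoch * kl_weight:.4f}")  # 1.0017
```

The 33 steps per epoch match the `33/33` progress counter in the training logs; the small excess over 1 comes from the final, partial batch of 121 samples still receiving the full per-batch weight.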
# Train the uncertainty quantifier

Now we train the model. By default, the `MEGNet` (GNN) layers of the model are
frozen after initialization. Therefore, when we call `prob_model.train()`, the
only layers that are optimized are the `VariationalGaussianProcess` (VGP) and the
`BatchNormalization` layer (`Norm`) that feeds into it.

After this initial training, we unfreeze _all_ of the layers and fine-tune the whole model jointly.
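In Keras, freezing works through each layer's `trainable` attribute: a frozen layer's weights are excluded from `model.trainable_weights`, so the optimizer never updates them. The toy model below is a hypothetical stand-in for the GNN → `Norm` → VGP stack (a `Dense` layer replaces the VGP). It only illustrates the general mechanism and is not necessarily how `MEGNetProbModel.set_frozen` is implemented.

```python
import tensorflow as tf

# Hypothetical toy stack: "gnn" stands in for the MEGNet layers,
# "norm" for the BatchNormalization layer and "head" for the VGP.
inputs = tf.keras.Input(shape=(8,))
features = tf.keras.layers.Dense(16, name="gnn")(inputs)
normed = tf.keras.layers.BatchNormalization(name="norm")(features)
outputs = tf.keras.layers.Dense(1, name="head")(normed)
model = tf.keras.Model(inputs, outputs)

# Stage 1: freeze the feature extractor; only "norm" and "head" train.
model.get_layer("gnn").trainable = False
stage_1_trainable = len(model.trainable_weights)  # gamma, beta, kernel, bias

# Stage 2: unfreeze everything for fine-tuning.
model.get_layer("gnn").trainable = True
stage_2_trainable = len(model.trainable_weights)  # adds the gnn kernel and bias

print(stage_1_trainable, stage_2_trainable)  # 4 6

# Keras applies `trainable` changes at compile time, so recompile
# before calling fit() again after changing them.
model.compile(optimizer="adam", loss="mse")
```

The moving mean and variance of `BatchNormalization` are never trainable weights, which is why the normalization layer contributes only two entries to the counts above.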


In [7]:
tb_callback_1 = TensorBoard(log_dir=LOG_DIR / "vgp_training", write_graph=False)
tb_callback_2 = TensorBoard(log_dir=LOG_DIR / "fine_tuning", write_graph=False)


In [8]:
%load_ext tensorboard
%tensorboard --logdir logs

In [9]:
prob_model.train(
    train_structs,
    train_targets,
    epochs=50,
    val_structs=val_structs,
    val_targets=val_targets,
    callbacks=[tb_callback_1],
)
prob_model.save(MODEL_SAVE_DIR)


Epoch 1/50




33/33 - 21s - loss: 2043370.6250 - val_loss: 1889545.3750
Epoch 2/50
33/33 - 9s - loss: 1891937.8750 - val_loss: 1724109.3750
Epoch 3/50
33/33 - 9s - loss: 1623211.0000 - val_loss: 1369590.3750
Epoch 4/50
33/33 - 9s - loss: 1263507.3750 - val_loss: 980873.8750
Epoch 5/50
33/33 - 9s - loss: 910178.8125 - val_loss: 641627.2500
Epoch 6/50
33/33 - 9s - loss: 597134.5625 - val_loss: 359959.3750
Epoch 7/50
33/33 - 9s - loss: 342737.5000 - val_loss: 171222.0312
Epoch 8/50
33/33 - 9s - loss: 207822.9062 - val_loss: 104365.7812
Epoch 9/50
33/33 - 9s - loss: 150359.3438 - val_loss: 80415.6016
Epoch 10/50
33/33 - 9s - loss: 125844.5938 - val_loss: 67596.2812
Epoch 11/50
33/33 - 9s - loss: 104631.8047 - val_loss: 61675.3828
Epoch 12/50
33/33 - 9s - loss: 92663.6562 - val_loss: 51817.0273
Epoch 13/50
33/33 - 9s - loss: 76480.6797 - val_loss: 45383.0703
Epoch 14/50
33/33 - 8s - loss: 65874.5469 - val_loss: 41886.4219
Epoch 15/50
33/33 - 9s - loss: 55451.9219 - val_loss: 36796.4844
Epoch 16/50
33/33 

INFO:tensorflow:Assets written to: binary_e_form_model/megnet/assets


INFO:tensorflow:Assets written to: binary_e_form_model/gnn/assets


In [10]:
prob_model.set_frozen(["GNN", "VGP"], freeze=False)


In [11]:
prob_model.train(
    train_structs,
    train_targets,
    epochs=50,
    val_structs=val_structs,
    val_targets=val_targets,
    callbacks=[tb_callback_2],
)
prob_model.save(MODEL_SAVE_DIR)


Epoch 1/50




33/33 - 21s - loss: 53061.4062 - val_loss: 481662.5625
Epoch 2/50
33/33 - 9s - loss: 33067.5000 - val_loss: 65340.0273
Epoch 3/50
33/33 - 9s - loss: 25956.3125 - val_loss: 29207.7422
Epoch 4/50
33/33 - 9s - loss: 21629.7422 - val_loss: 15283.6445
Epoch 5/50
33/33 - 9s - loss: 14493.3877 - val_loss: 11428.2402
Epoch 6/50
33/33 - 9s - loss: 15374.6836 - val_loss: 12495.3408
Epoch 7/50
33/33 - 9s - loss: 17335.4980 - val_loss: 11754.5166
Epoch 8/50
33/33 - 9s - loss: 12458.4238 - val_loss: 10183.0635
Epoch 9/50
33/33 - 9s - loss: 12680.4619 - val_loss: 10713.4814
Epoch 10/50
33/33 - 9s - loss: 16434.4258 - val_loss: 12243.9531
Epoch 11/50
33/33 - 9s - loss: 12405.5117 - val_loss: 8716.2539
Epoch 12/50
33/33 - 9s - loss: 15193.8047 - val_loss: 8424.6230
Epoch 13/50
33/33 - 9s - loss: 12915.4229 - val_loss: 9271.1270
Epoch 14/50
33/33 - 9s - loss: 10881.9014 - val_loss: 10849.7568
Epoch 15/50
33/33 - 9s - loss: 6977.2739 - val_loss: 10196.3164
Epoch 16/50
33/33 - 9s - loss: 9455.2100 - val_

INFO:tensorflow:Assets written to: binary_e_form_model/megnet/assets


INFO:tensorflow:Assets written to: binary_e_form_model/gnn/assets


# Model evaluation

Finally, we make some sample predictions and evaluate the model's metrics. `prob_model.predict` returns both the predicted values and their standard deviations. We convert each standard deviation into an uncertainty by doubling it: for a Gaussian predictive distribution, an interval of ±2 standard deviations around the mean covers approximately 95% of the probability (see <https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule>).
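For a Gaussian, the probability of a sample falling within k standard deviations of the mean is erf(k/√2). The short check below shows that ±2σ covers about 95.4%, while exactly 95% corresponds to k ≈ 1.96. Either way, the interval is only as trustworthy as the Gaussian assumption and the calibration of the predicted σ.

```python
import math


def gaussian_coverage(k: float) -> float:
    """Probability that a Gaussian sample lies within k standard deviations of its mean."""
    return math.erf(k / math.sqrt(2.0))


print(f"±2σ    covers {gaussian_coverage(2.0):.4f}")   # 0.9545
print(f"±1.96σ covers {gaussian_coverage(1.96):.4f}")  # 0.9500
```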


In [12]:
example_structs = val_structs[:10].tolist()
example_targets = val_targets[:10].tolist()

predicted, stddevs = prob_model.predict(example_structs)
uncerts = 2 * stddevs




In [13]:
pd.DataFrame(
    {
        "Composition": [struct.composition.reduced_formula for struct in example_structs],
        "Formation energy per atom / eV": example_targets,
        "Predicted / eV": [
            f"{pred:.2f} ± {uncert:.2f}" for pred, uncert in zip(predicted, uncerts)
        ],
    }
)


| | Composition | Formation energy per atom / eV | Predicted / eV |
|---|---|---|---|
| 0 | Zr2Cu | -0.132384 | -0.14 ± 0.03 |
| 1 | NbRh | -0.401313 | -0.52 ± 0.04 |
| 2 | Cu3Ge | -0.005707 | -0.07 ± 0.04 |
| 3 | Pr3In | -0.273232 | -0.28 ± 0.04 |
| 4 | InS | -0.742895 | -0.72 ± 0.04 |
| 5 | TmPb3 | -0.215892 | -0.25 ± 0.04 |
| 6 | InNi | -0.174754 | -0.20 ± 0.04 |
| 7 | GdGe | -0.857117 | -0.82 ± 0.08 |
| 8 | GdTl | -0.380423 | -0.35 ± 0.03 |
| 9 | HoTl3 | -0.215986 | -0.21 ± 0.03 |
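Ten examples can only hint at how well calibrated the uncertainties are. One simple check is the fraction of targets that fall inside the ±kσ interval; for well-calibrated Gaussian uncertainties it should be close to 0.95 at k = 2. The sketch below assumes `prob_model.predict` returns array-like means and standard deviations, as in the cell above; the helper `interval_coverage` is hypothetical and not part of unlockgnn.

```python
import numpy as np


def interval_coverage(targets, predicted, stddevs, k: float = 2.0) -> float:
    """Fraction of targets lying within predicted ± k * stddev (hypothetical helper)."""
    # Flatten everything so that column-vector outputs cannot broadcast
    # against a 1-D target array into an (N, N) comparison.
    targets = np.asarray(targets, dtype=float).ravel()
    predicted = np.asarray(predicted, dtype=float).ravel()
    stddevs = np.asarray(stddevs, dtype=float).ravel()
    inside = np.abs(targets - predicted) <= k * stddevs
    return float(np.mean(inside))


# Toy values; with the notebook's model one might instead use
#   predicted, stddevs = prob_model.predict(val_structs.tolist())
#   interval_coverage(val_targets, predicted, stddevs)
toy_cov = interval_coverage(
    targets=[0.0, 0.0, 0.0, 0.0],
    predicted=[0.0, 0.1, 0.3, -0.5],
    stddevs=[0.1, 0.1, 0.1, 0.1],
)
print(toy_cov)  # 0.5
```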


In [14]:
val_metrics = evaluate_uq_metrics(prob_model, val_structs, val_targets)
train_metrics = evaluate_uq_metrics(prob_model, train_structs, train_targets)

print("Validation metrics:")
pprint(val_metrics)
print("Training metrics:")
pprint(train_metrics)




Validation metrics:
{'mae': 0.049803377508653594,
 'mse': 0.00571253439208755,
 'nll': 1270.362237633195,
 'rmse': 0.07558130980664168,
 'sharpness': 0.030653343161385096,
 'variation': 0.5497132304190294}
Training metrics:
{'mae': 0.02887053388680764,
 'mse': 0.0017579980037584235,
 'nll': -7803.821887522724,
 'rmse': 0.04192848678116614,
 'sharpness': 0.030605892140298463,
 'variation': 0.533647833791028}
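The exact definitions used by `evaluate_uq_metrics` are not shown in this notebook. For orientation, the sketch below uses common definitions, which are an assumption here: a summed Gaussian negative log-likelihood, sharpness as the root-mean-square predicted standard deviation, and the coefficient of variation of the predicted standard deviations as a measure of their dispersion. unlockgnn's implementation may differ in normalization or details.

```python
import numpy as np

# Assumed, commonly used definitions; unlockgnn's evaluate_uq_metrics may differ.


def gaussian_nll(targets, predicted, stddevs) -> float:
    """Summed negative log-likelihood of targets under independent Gaussians."""
    y = np.asarray(targets, dtype=float).ravel()
    mu = np.asarray(predicted, dtype=float).ravel()
    sigma = np.asarray(stddevs, dtype=float).ravel()
    return float(np.sum(0.5 * np.log(2 * np.pi * sigma**2) + (y - mu) ** 2 / (2 * sigma**2)))


def sharpness(stddevs) -> float:
    """Root-mean-square predicted standard deviation (smaller means sharper)."""
    sigma = np.asarray(stddevs, dtype=float).ravel()
    return float(np.sqrt(np.mean(sigma**2)))


def coefficient_of_variation(stddevs) -> float:
    """Spread of the predicted standard deviations relative to their mean."""
    sigma = np.asarray(stddevs, dtype=float).ravel()
    return float(np.std(sigma, ddof=1) / np.mean(sigma))


# Toy values; e.g. gaussian_nll(val_targets, *prob_model.predict(val_structs.tolist()))
print(gaussian_nll([0.0, 0.0], [0.0, 0.0], [1.0, 1.0]))  # equals log(2π)
print(sharpness([0.1, 0.1]))
print(coefficient_of_variation([1.0, 3.0]))
```

Because a summed NLL grows with the number of samples, comparing it across datasets of different sizes is easier after dividing by the sample count.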
