# Training a model with uncertainty quantification
This notebook demonstrates how to train a `MEGNetProbModel` to predict target values
and uncertainties for a benchmark dataset from `matminer`.

In [1]:
from pathlib import Path

import numpy as np
from matminer.datasets import get_all_dataset_info, load_dataset
from megnet.data.crystal import CrystalGraph
from sklearn.model_selection import train_test_split
from sse_gnn import MEGNetProbModel
from sse_gnn.datalib.metrics import MetricAnalyser

In [2]:
SAVE_DIR = Path.home() / "matbench_dielectric"
DATASET = "matbench_dielectric"
TARGET_VAR = "n"

In [3]:
print(get_all_dataset_info(DATASET))

Dataset: matbench_dielectric
Description: Matbench v0.1 test dataset for predicting refractive index from structure. Adapted from Materials Project database. Removed entries having a formation energy (or energy above the convex hull) more than 150meV and those having refractive indices less than 1 and those containing noble gases. Retrieved April 2, 2019.
Columns:
	n: Target variable. Refractive index (unitless).
	structure: Pymatgen Structure of the material.
Num Entries: 4764
Reference: Petousis, I., Mrdjenovich, D., Ballouz, E., Liu, M., Winston, D.,
Chen, W., Graf, T., Schladt, T. D., Persson, K. A. & Prinz, F. B.
High-throughput screening of inorganic compounds for the discovery
of novel dielectric and optical materials. Sci. Data 4, 160134 (2017).
Bibtex citations: ['@article{Jain2013,\nauthor = {Jain, Anubhav and Ong, Shyue Ping and Hautier, Geoffroy and Chen, Wei and Richards, William Davidson and Dacek, Stephen and Cholia, Shreyas and Gunter, Dan and Skinner, David and Ceder, 

In [4]:
data = load_dataset(DATASET)

In [5]:
train_df, test_df = train_test_split(data, random_state=2020)

print(train_df.describe())
print(test_df.describe())

                 n
count  3573.000000
mean      2.435454
std       2.298581
min       1.000000
25%       1.677955
50%       2.040549
75%       2.587695
max      62.062998
                 n
count  1191.000000
mean      2.406894
std       1.349136
min       1.000000
25%       1.712178
50%       2.105446
75%       2.630097
max      21.911912
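
The row counts above (3573 training, 1191 test) follow from the default 25% test fraction of `train_test_split`. A minimal sketch of that arithmetic is given here, assuming the test count is rounded up and the remaining samples go to training. The helper `split_sizes` is hypothetical and is not part of scikit-learn.

```python
import math


def split_sizes(n_samples, test_fraction=0.25):
    """Return (n_train, n_test) for a fractional hold-out split.

    Assumption: the test count is rounded up and the remaining
    samples are assigned to the training set.
    """
    n_test = math.ceil(n_samples * test_fraction)
    return n_samples - n_test, n_test


# 4764 entries in matbench_dielectric, as reported by the dataset info.
n_train, n_test = split_sizes(4764)
print(n_train, n_test)
```

Passing `test_size` explicitly to `train_test_split` would change these counts, which may be worth doing if a separate validation set is needed.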


In [6]:
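nfeat_bond = 10  # number of Gaussian basis functions per bond distance
r_cutoff = 5  # neighbour cutoff radius for graph construction, in Å
# Gaussian centers are evenly spaced from 0 to r_cutoff + 1.
gaussian_centers = np.linspace(0, r_cutoff + 1, nfeat_bond)
gaussian_width = 0.5  # width shared by all Gaussians
graph_converter = CrystalGraph(cutoff=r_cutoff)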
meg_args = {
    "graph_converter": graph_converter,
    "centers": gaussian_centers,
    "width": gaussian_width,
    "metrics": ["mae", "mape"],
}
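
To illustrate what `centers` and `width` control, the sketch below expands a few bond distances onto a Gaussian basis. It assumes the expansion has the form exp(-(d - c)^2 / width^2) for each center c. The function `gaussian_expand` is a hypothetical stand-in written for this example and is not the library's own routine.

```python
import numpy as np


def gaussian_expand(distances, centers, width):
    """Expand each distance onto a set of Gaussian basis functions.

    Assumed form: exp(-(d - c)^2 / width^2) for every center c.
    """
    distances = np.asarray(distances, dtype=float)
    # Pairwise differences: one row per distance, one column per center.
    diff = distances[:, None] - centers[None, :]
    return np.exp(-(diff ** 2) / width ** 2)


r_cutoff = 5
centers = np.linspace(0, r_cutoff + 1, 10)  # same grid as the cell above
features = gaussian_expand([0.0, 2.5, 5.0], centers, width=0.5)
print(features.shape)  # one 10-dimensional feature vector per distance
```

Each distance activates mostly the centers closest to it, so a smaller `width` gives sharper, more localised bond features, while a larger one smooths them out.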

In [7]:
prob_model = MEGNetProbModel(
    train_df["structure"],
    train_df[TARGET_VAR],
    "VGP",
    test_df["structure"],
    test_df[TARGET_VAR],
    SAVE_DIR,
    num_inducing_points=200,
    **meg_args
)

In [8]:
prob_model.train_meg_model(epochs=10, batch_size=32)

[...] out of the last 9 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f4744cf8af0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.


KeyboardInterrupt: 

In [None]:
prob_model.train_uq(epochs=10)

In [None]:
prob_model.save(train_df.index, test_df.index)
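
Once a model yields predicted means and standard deviations, one simple sanity check is the fraction of true targets that fall inside a nominal 95% interval. The sketch below is only illustrative: it uses synthetic arrays rather than output from `MEGNetProbModel`, it assumes Gaussian predictive distributions, and `interval_coverage` is a hypothetical helper that belongs to neither `sse_gnn` nor `MetricAnalyser`.

```python
import numpy as np


def interval_coverage(targets, means, stddevs, z=1.96):
    """Fraction of targets lying inside mean ± z * stddev."""
    targets, means, stddevs = map(np.asarray, (targets, means, stddevs))
    inside = np.abs(targets - means) <= z * stddevs
    return float(inside.mean())


# Synthetic stand-ins for predictions; NOT output from MEGNetProbModel.
rng = np.random.default_rng(0)
means = rng.uniform(1.0, 3.0, size=1000)
stddevs = np.full(1000, 0.2)
# Targets drawn so that the stated uncertainties are well calibrated.
targets = means + rng.normal(0.0, stddevs)

coverage = interval_coverage(targets, means, stddevs)
print(f"{coverage:.3f}")
```

For well-calibrated Gaussian uncertainties the coverage should be close to 0.95. Values far below suggest overconfident predictions, and values near 1.0 suggest overly wide intervals.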