In [None]:
!pip install megnet -qq

In [None]:
!pip install pymatgen -qq

In [None]:
import yaml
import json

import pandas as pd
import numpy as np
import tensorflow as tf

from pathlib import Path
from pymatgen.core import Structure
from sklearn.model_selection import train_test_split
from megnet.models import MEGNetModel
from megnet.data.crystal import CrystalGraph, CrystalGraphDisordered

In [None]:
class config:
    """Settings shared by the notebook cells, grouped by what they control."""

    # --- data locations -----------------------------------------------------
    # Folder handed to prepare_dataset (expects targets.csv and structures/).
    datapath = 'idao_2022_data/dichalcogenides_public/dichalcogenides_public'
    # Not referenced by the cells visible in this notebook section.
    test_datapath = 'idao_2022_data/dichalcogenides_private/dichalcogenides_private'

    # --- training loop ------------------------------------------------------
    epochs = 1500      # forwarded to model.train
    batch_size = 64    # forwarded to model.train
    lr = 2e-4          # forwarded to prepare_model
    fold = 0           # rows whose 'Fold' equals this value form the validation split

    # --- graph construction -------------------------------------------------
    cutoff = 4         # forwarded to prepare_model, which gives it to CrystalGraph

    # --- weights / checkpoints ----------------------------------------------
    # None of the four settings below is read by the cells visible here.
    checkpoint_path = 'model.hdf5'
    from_file = False
    from_MVL = True
    model = 'Efermi_MP_2019' # "../input/mvl-models/mvl_models/mp-2018.6.1/formation_energy.hdf5"

In [None]:
def read_pymatgen_dict(file):
    """Load a structure from a JSON file.

    Parameters
    ----------
    file : str or path-like
        Path of a JSON document holding the dictionary form of a structure.

    Returns
    -------
    The object produced by ``Structure.from_dict`` for the decoded dictionary.
    """
    with open(file, "r") as handle:
        payload = json.loads(handle.read())
    structure = Structure.from_dict(payload)
    return structure


def energy_within_threshold(prediction, target):
    """Fraction of entries whose absolute error is strictly below 0.02.

    Parameters
    ----------
    prediction, target
        Tensors (or tensor-likes) that can be subtracted element-wise.

    Returns
    -------
    The number of entries with ``|target - prediction| < 0.02`` divided by the
    total number of entries in ``target``.
    """
    tolerance = 0.02

    # Element-wise absolute deviation between prediction and reference.
    abs_deviation = tf.math.abs(target - prediction)
    n_within = tf.math.count_nonzero(abs_deviation < tolerance)

    # tf.size is cast to int64 so that both operands of the division share a dtype.
    n_total = tf.cast(tf.size(target), tf.int64)
    return n_within / n_total

def prepare_dataset(dataset_path):
    """Assemble structures and targets from a dataset folder into one DataFrame.

    The folder is expected to contain ``targets.csv`` (its first column is used
    as the index) and a ``structures`` directory with one JSON file per entry.

    Parameters
    ----------
    dataset_path : str or Path
        Root folder of the dataset.

    Returns
    -------
    pandas.DataFrame
        Indexed by structure id (the file name without its ``.json`` suffix),
        in directory-iteration order, with two columns: ``structures`` (the
        objects returned by ``read_pymatgen_dict``) and ``targets`` (taken from
        ``targets.csv`` and aligned on the index).
    """
    suffix = ".json"
    dataset_path = Path(dataset_path)
    targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)

    struct = {}
    for item in (dataset_path / "structures").iterdir():
        # Remove the suffix itself. ``str.strip(".json")`` would instead strip
        # any of the characters '.', 'j', 's', 'o', 'n' from BOTH ends of the
        # name and corrupt ids that start or end with one of those letters.
        name = item.name
        key = name[: -len(suffix)] if name.endswith(suffix) else name
        struct[key] = read_pymatgen_dict(item)

    # Directory order is kept as-is (no sorting) so row positions are unchanged.
    data = pd.DataFrame(columns=["structures"], index=list(struct))
    data = data.assign(structures=list(struct.values()), targets=targets)

    return data

def _first_layer_with_prefix(layers, prefix):
    """Return the first layer in ``layers`` whose name starts with ``prefix``."""
    for layer in layers:
        if layer.name.startswith(prefix):
            return layer
    raise IndexError(f"no layer with a name starting with {prefix!r}")


def prepare_model(cutoff, lr, pretrained_model='Efermi_MP_2019'):
    """Build a 3-output MEGNet classifier seeded with pretrained atom embeddings.

    Parameters
    ----------
    cutoff : float
        Cutoff given to ``CrystalGraph``; the Gaussian bond-expansion centres
        span ``0 .. cutoff + 1`` in 10 steps.
    lr : float
        Learning rate for the optimizer.
    pretrained_model : str, optional
        Name passed to ``MEGNetModel.from_mvl_models``; the first layer of that
        model whose name starts with ``'embedding'`` supplies the weights.
        Defaults to ``'Efermi_MP_2019'``.

    Returns
    -------
    MEGNetModel
        A new model whose ``atom_embedding`` layer carries the pretrained
        embedding matrix.

    Raises
    ------
    IndexError
        If either model has no layer with the expected name prefix.
    """
    nfeat_bond = 10
    gaussian_centers = np.linspace(0, cutoff + 1, nfeat_bond)
    gaussian_width = 0.8

    model = MEGNetModel(
        graph_converter=CrystalGraph(cutoff=cutoff),
        centers=gaussian_centers,
        width=gaussian_width,
        # The notebook trains on integer class ids (0/1/2), not one-hot rows,
        # so the sparse variant of the cross-entropy is required.
        # NOTE(review): from_logits=True assumes the readout layer is linear
        # (MEGNet default when is_classification is False) - confirm.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        ntarget=3,
        npass=2,
        # NOTE(review): MEGNetModel names its learning-rate argument `lr`; a
        # `learning_rate=` keyword appears to fall into **kwargs and never
        # reach the optimizer - confirm against the installed MEGNet version.
        lr=lr,
    )

    # Pretrained model used only as a donor for its element-embedding matrix.
    model_form = MEGNetModel.from_mvl_models(pretrained_model)
    donor_layer = _first_layer_with_prefix(model_form.layers, 'embedding')
    embedding = donor_layer.get_weights()[0]

    # Overwrite the freshly initialised embedding with the pretrained one.
    target_layer = _first_layer_with_prefix(model.layers, 'atom_embedding')
    target_layer.set_weights([embedding])

    return model

In [None]:
# Load the public structures/targets and read the precomputed fold ids.
data = prepare_dataset(config.datapath)
folds = pd.read_csv('IDAO_Data_Folds.csv')

# NOTE(review): fold ids are attached by row position, not by structure id;
# this presumes the CSV rows are in the same order as `data` - confirm.
data['Fold'] = folds['Fold'].values

# Discretise the target into three quantile-based classes labelled 0, 1, 2.
data['bins'] = pd.qcut(data['targets'], q=3, labels=[0, 1, 2]).astype('int')

# Rows of fold `config.fold` are held out; all remaining rows are used for fitting.
train = data.loc[data['Fold'].ne(config.fold)]
test = data.loc[data['Fold'].eq(config.fold)]

In [None]:
# Build the classifier; both settings are cast to float before being handed over.
model = prepare_model(
    cutoff=float(config.cutoff),
    lr=float(config.lr),
)


In [None]:
# Fit on the training folds; the held-out fold (`test`) serves as validation data.
# The class bins, not the raw targets, are used as labels.
model.train(
    train["structures"],
    train["bins"],
    validation_structures=test["structures"],
    validation_targets=test["bins"],
    epochs=int(config.epochs),
    batch_size=int(config.batch_size),
)