In [None]:
# Unpack the two dataset archives with the shell. The notebook later reads
# the `dichalcogenides_public` directory for training and the
# `dichalcogenides_private` directory for the submission, so the archives are
# expected to extract to directories with those names.
!tar -xf dichalcogenides_private.tar.gz
!tar -xf dichalcogenides_public.tar.gz

In [None]:
import numpy as np
import tensorflow as tf
import random
import json
from pymatgen.core import Structure

def read_pymatgen_dict(file):
    """Build a pymatgen ``Structure`` from a JSON file.

    Args:
        file: Path of a JSON file whose content is a dictionary accepted by
            ``Structure.from_dict``.

    Returns:
        The ``Structure`` created from the parsed dictionary.
    """
    with open(file, "r") as handle:
        return Structure.from_dict(json.load(handle))

def seed_everything(seed):
    """Seed Python's ``random``, NumPy's legacy global RNG and TensorFlow.

    Args:
        seed: Integer seed handed unchanged to each of the three seeders.
    """
    # Same order as before: stdlib first, then NumPy, then TensorFlow.
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

In [None]:
import yaml
import json
import random

import pandas as pd
import numpy as np
import tensorflow as tf
import argparse

from pathlib import Path
from sklearn.model_selection import train_test_split
from megnet.models import MEGNetModel
from megnet.data.crystal import CrystalGraph
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.metrics import MeanAbsoluteError

# Directory containing `targets.csv` and the `structures/` folder used for training.
train_datapath = "dichalcogenides_public"

# Run configuration. It is dumped to `config.json` after training and read
# back by the submission cell, so key names and order are kept stable.
# NOTE(review): `additional_data` and `test_size` are not read by any code
# visible in this notebook.
config = dict(
    seed=17,
    epochs=800,
    batch_size=128,
    lr=0.001,
    cutoff=4.0,
    nblocks=3,
    npass=2,
    width=0.5,
    nfeat_bond=100,
    embedding_dim=16,
    additional_data=False,
    test_size=0.20,
)

# Keep only the weights of the epoch with the highest `ewt` value. The
# monitored name carries no `val_` prefix, and training below passes no
# validation data, so this is presumably the training-set metric.
checkpoint_filepath = "./checkpoint"
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor="ewt",
    mode="max",
    save_best_only=True,
    save_weights_only=True,
)

# Multiply the learning rate by 0.2 after 30 epochs without improvement of
# the mean absolute error, never going below 1e-5.
reduce_lr_callback = ReduceLROnPlateau(
    monitor="mean_absolute_error",
    mode="min",
    factor=0.2,
    patience=30,
    min_lr=1e-5,
    verbose=1,
)


def ewt(prediction, target):
    """Fraction of entries whose absolute error is below 0.02.

    The error is ``|target - prediction|``, so swapping the two tensors does
    not change which entries count as a success.

    Args:
        prediction: Tensor of predicted values.
        target: Tensor of reference values; its element count is the
            denominator of the returned fraction.

    Returns:
        Scalar tensor: number of entries within the threshold divided by the
        number of elements in ``target``.
    """
    tolerance = 0.02
    abs_error = tf.math.abs(target - prediction)

    # Both operands of the division are int64 counts.
    n_within = tf.math.count_nonzero(abs_error < tolerance)
    n_total = tf.cast(tf.size(target), tf.int64)
    return n_within / n_total

def prepare_dataset(config):
    """Load the training structures and their targets into one DataFrame.

    Reads ``targets.csv`` and every ``*.json`` file found in the
    ``structures`` sub-directory of ``train_datapath``.

    Args:
        config: Run configuration. Kept for interface compatibility; it is
            not read by this function.

    Returns:
        pandas.DataFrame indexed by structure id (file name without its
        ``.json`` extension) with a ``structures`` column holding the loaded
        pymatgen structures and a ``targets`` column taken from
        ``targets.csv``.
    """
    dataset_path = Path(train_datapath)
    targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)

    # `Path.stem` removes exactly the extension. The previous
    # `name.strip(".json")` removed any leading/trailing characters from the
    # set {'.', 'j', 's', 'o', 'n'}, which truncates ids that start or end
    # with one of them and breaks the match against the targets index.
    # Globbing for "*.json" skips stray entries (e.g. `.ipynb_checkpoints`),
    # and sorting makes the row order independent of the filesystem so the
    # seeded run is repeatable.
    struct = {
        item.stem: read_pymatgen_dict(item)
        for item in sorted((dataset_path / "structures").glob("*.json"))
    }

    data = pd.DataFrame(columns=["structures"], index=struct.keys())
    data = data.assign(structures=struct.values(), targets=targets)

    return data

 
def prepare_model(config):
    """Create a MEGNet model and initialise it from pretrained weights.

    A fresh ``MEGNetModel`` is built from the hyper-parameters in ``config``
    and then receives the weights of the model stored in
    ``band_gap_regression.hdf5``.

    Args:
        config: Mapping providing ``nfeat_bond``, ``cutoff``, ``width``,
            ``nblocks``, ``npass``, ``lr`` and ``embedding_dim``.

    Returns:
        The ``MEGNetModel`` carrying the pretrained weights, compiled with an
        MAE loss and the ``ewt`` and mean-absolute-error metrics.
    """
    n_centers, cutoff_radius, gauss_width = (
        config["nfeat_bond"],
        config["cutoff"],
        config["width"],
    )
    # Gaussian expansion centres are spread evenly over [0, cutoff + 1].
    centers = np.linspace(0, cutoff_radius + 1, n_centers)

    source_model = MEGNetModel.from_file("band_gap_regression.hdf5")
    model = MEGNetModel(
        graph_converter=CrystalGraph(cutoff=cutoff_radius),
        centers=centers,
        width=gauss_width,
        nblocks=config["nblocks"],
        loss=["MAE"],
        npass=config["npass"],
        lr=config["lr"],
        embedding_dim=config["embedding_dim"],
        metrics=[ewt, MeanAbsoluteError()],
    )
    # Copy the pretrained weights across. No layer is explicitly frozen
    # here, so this function does not restrict what is fine-tuned.
    model.set_weights(source_model.get_weights())

    return model

def main(config):
    """Seed the RNGs, load the training data and fit the model.

    The model is trained on everything returned by ``prepare_dataset``; no
    validation structures are passed to ``train``.

    Args:
        config: Mapping providing ``seed``, ``epochs`` and ``batch_size``
            plus the keys required by ``prepare_model``.
    """
    seed_everything(config["seed"])
    dataset = prepare_dataset(config)
    megnet = prepare_model(config)

    structures, targets = dataset.structures, dataset.targets
    fit_options = dict(
        epochs=config["epochs"],
        batch_size=config["batch_size"],
        callbacks=[model_checkpoint_callback, reduce_lr_callback],
        # Checkpointing is left to the callbacks supplied above; this flag
        # presumably turns off MEGNet's own checkpoint saving (confirm
        # against the MEGNet documentation).
        save_checkpoint=False,
    )
    megnet.train(structures, targets, **fit_options)


# Save the hyper-parameters next to the checkpoint so the submission cell can
# rebuild the same architecture, then start training.
with open("config.json", "w") as config_file:
    config_file.write(json.dumps(config, indent=4))
main(config)

In [None]:
import yaml
import json
import pandas as pd

from pathlib import Path



# Location of the weights to load for prediction; same value as the
# `checkpoint_filepath` given to the ModelCheckpoint callback during training.
checkpoint_path = "./checkpoint"
# Directory whose `structures/` folder holds the structures to predict.
test_datapath = "dichalcogenides_private"

def submit(config):
    """Predict the private test structures and write ``./submission.csv``.

    Rebuilds the model with ``prepare_model``, loads the weights stored at
    ``checkpoint_path``, predicts every ``*.json`` structure found in the
    ``structures`` sub-directory of ``test_datapath`` and writes the
    predictions to a CSV file with an ``id`` index column and a
    ``predictions`` column.

    Args:
        config: Mapping with the keys required by ``prepare_model``.
    """
    model = prepare_model(config)
    model.load_weights(checkpoint_path)

    dataset_path = Path(test_datapath)
    # `Path.stem` removes exactly the extension. The previous
    # `name.strip('.json')` removed any leading/trailing characters from the
    # set {'.', 'j', 's', 'o', 'n'}, which would write truncated ids to the
    # submission. Globbing for "*.json" skips stray entries (e.g.
    # `.ipynb_checkpoints`), and sorting gives a stable row order.
    struct = {
        item.stem: read_pymatgen_dict(item)
        for item in sorted((dataset_path / 'structures').glob('*.json'))
    }
    private_test = pd.DataFrame(columns=['id', 'structures'], index=struct.keys())
    private_test = private_test.assign(structures=struct.values())
    private_test = private_test.assign(predictions=model.predict_structures(private_test.structures))
    # Only the predictions are exported; the DataFrame index becomes the `id` column.
    private_test[['predictions']].to_csv('./submission.csv', index_label='id')

# Reload the hyper-parameters written by the training cell so the rebuilt
# model has the architecture the checkpoint expects, then create the submission.
with open("config.json") as config_file:
    config = json.load(config_file)
submit(config)