In [None]:
!pip install pymatgen==2022.0.17
!pip install megnet==1.2.9

In [None]:
import yaml
import json

import pandas as pd
import numpy as np
import tensorflow as tf

from pathlib import Path
from pymatgen.core import Structure
from sklearn.model_selection import train_test_split
from megnet.models import MEGNetModel
from megnet.data.crystal import CrystalGraph

In [None]:
def read_pymatgen_dict(file):
    """Load a pymatgen ``Structure`` from a JSON file.

    ``file`` is anything ``open`` accepts. The file's JSON content is handed
    to ``Structure.from_dict`` unchanged.
    """
    with open(file, "r") as handle:
        raw_text = handle.read()
    return Structure.from_dict(json.loads(raw_text))


def energy_within_threshold(prediction, target, e_thresh=0.02):
    """Fraction of elements whose absolute error is strictly below ``e_thresh``.

    The error is computed element-wise as ``|target - prediction|``. The
    elements with error ``< e_thresh`` are counted, and that count is divided
    by the total number of elements in ``target``. No max/reduction over
    systems is taken.

    NOTE(review): Keras calls metric functions as ``fn(y_true, y_pred)``. If
    MEGNetModel forwards ``metrics`` to Keras unchanged, ``prediction``
    actually receives the targets and vice versa. That is harmless here,
    because the absolute error is symmetric in its two arguments.

    Args:
        prediction: tensor of predicted values.
        target: tensor of reference values, assumed to have the same shape
            as ``prediction``.
        e_thresh: absolute-error threshold. The default of 0.02 is the value
            that was previously hard-coded.

    Returns:
        Scalar tensor in [0, 1]: (#elements with error < e_thresh) / size(target).
    """
    error_energy = tf.math.abs(target - prediction)

    success = tf.math.count_nonzero(error_energy < e_thresh)
    total = tf.size(target)
    # count_nonzero returns int64 while tf.size returns int32; align the
    # dtypes so the division is a true division of two int64 counts.
    return success / tf.cast(total, tf.int64)

def prepare_dataset(dataset_path, test_size=0.25, random_state=666):
    """Load structures and targets from ``dataset_path`` and split train/test.

    Reads ``dataset_path/targets.csv``, using its first column as the index.
    Also reads every ``dataset_path/structures/*.json`` file via
    ``read_pymatgen_dict``, keyed by the file name without ``.json``.

    Args:
        dataset_path: directory containing ``targets.csv`` and ``structures/``.
        test_size: fraction forwarded to ``train_test_split`` (default 0.25).
        random_state: seed forwarded to ``train_test_split`` (default 666).

    Returns:
        ``(train, test)`` DataFrames indexed by structure id, with columns
        ``structures`` and ``targets``.

    Raises:
        ValueError: if any structure has no matching (non-NaN) target.
    """
    dataset_path = Path(dataset_path)
    targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)
    # File stems are strings. Make the CSV ids strings too, so numeric-looking
    # ids still align instead of silently producing NaN targets.
    targets.index = targets.index.astype(str)

    # - .stem removes only the ".json" suffix. str.strip(".json") would also
    #   strip any leading/trailing '.', 'j', 's', 'o', 'n' from the id itself.
    # - sorted(): directory listing order is filesystem-dependent, and the
    #   seeded split below is only reproducible if the row order is fixed.
    struct = {
        item.stem: read_pymatgen_dict(item)
        for item in sorted((dataset_path / "structures").glob("*.json"))
    }

    # dtype=object keeps each Structure as a single cell value.
    data = pd.DataFrame({"structures": pd.Series(struct, dtype=object)})
    # Assumes the first data column of targets.csv holds the regression target.
    data["targets"] = targets.iloc[:, 0].reindex(data.index)

    missing = data["targets"].isna()
    if missing.any():
        raise ValueError(
            f"{int(missing.sum())} structure(s) have no target in targets.csv, "
            f"e.g. {list(data.index[missing.to_numpy()][:5])}"
        )

    return train_test_split(data, test_size=test_size, random_state=random_state)

 
def prepare_model(cutoff, lr, nfeat_bond=10, gaussian_width=0.8, npass=2):
    """Build a MEGNet model with a Gaussian expansion of bond distances.

    Args:
        cutoff: radius cutoff for ``CrystalGraph``. The Gaussian centres are
            spaced evenly over ``[0, cutoff + 1]``.
        lr: learning rate forwarded to ``MEGNetModel``.
        nfeat_bond: number of Gaussian centres (default 10).
        gaussian_width: width forwarded as ``MEGNetModel(width=...)``
            (default 0.8).
        npass: forwarded as ``MEGNetModel(npass=...)`` (default 2).

    Returns:
        A freshly constructed ``MEGNetModel`` with MAE loss and the
        ``energy_within_threshold`` metric.
    """
    gaussian_centers = np.linspace(0, cutoff + 1, nfeat_bond)

    return MEGNetModel(
        graph_converter=CrystalGraph(cutoff=cutoff),
        centers=gaussian_centers,
        width=gaussian_width,
        loss=["MAE"],
        npass=npass,
        lr=lr,
        # Keras documents `metrics` as a list, so wrap the single callable.
        metrics=[energy_within_threshold],
    )


In [None]:
!tar -xvf ./dichalcogenides_private.tar.gz

In [None]:
!tar -xvf ./dichalcogenides_public.tar.gz

In [None]:
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)

# Seed the global TF and NumPy RNGs before building the model, so weight
# initialisation (and, presumably, any NumPy-based batch shuffling) repeats
# between runs. GPU kernels may still be nondeterministic.
# "seed" is an optional config key; the default mirrors the split seed.
seed = int(config.get("seed", 666))
tf.random.set_seed(seed)
np.random.seed(seed)

train, test = prepare_dataset(config["datapath"])
model = prepare_model(
    float(config["model"]["cutoff"]),
    float(config["model"]["lr"]),
)

In [None]:
model_cfg = config["model"]
n_epochs = int(model_cfg["epochs"])
batch_size_cfg = int(model_cfg["batch_size"])

# The held-out split doubles as the validation set during training.
model.train(
    train.structures,
    train.targets,
    validation_structures=test.structures,
    validation_targets=test.targets,
    epochs=n_epochs,
    batch_size=batch_size_cfg,
)

In [None]:
dataset_path = Path(config['test_datapath'])
# .stem drops only the ".json" suffix. str.strip('.json') would also eat any
# leading/trailing '.', 'j', 's', 'o', 'n' characters of the id itself.
# sorted() gives a filesystem-independent row order.
struct = {
    item.stem: read_pymatgen_dict(item)
    for item in sorted((dataset_path / 'structures').glob('*.json'))
}

In [None]:
# Index = structure id. Only a 'structures' column is needed; the id is
# written as the index label in the CSV.
private_test = pd.DataFrame({'structures': pd.Series(struct, dtype=object)})
# np.ravel flattens a possible (n, 1) model output into a plain 1-D column.
private_test = private_test.assign(
    predictions=np.ravel(model.predict_structures(private_test.structures))
)
private_test[['predictions']].to_csv('./test.csv', index_label='id')

In [None]:
# Read the written file back as a sanity check, and preview it instead of
# dumping every row.
submission = pd.read_csv('test.csv', index_col='id')
assert submission.index.is_unique, "duplicate ids in test.csv"
assert submission['predictions'].notna().all(), "NaN predictions in test.csv"
submission.head()