In [None]:

import sys, os
sys.path.insert(0, '../src')
sys.path.insert(0, '../src/gnn_eads')
from pymatgen.io.vasp import Outcar
from pyRDTP.geomio import file_to_mol
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import torch
import matplotlib.pyplot as plt
import numpy as np

from gnn_eads.graph_tools import plotter
from gnn_eads.functions import contcar_to_graph
from gnn_eads.nets import PreTrainedModel

# Name of the pre-trained model; a folder with this name must exist in "../models".
MODEL_NAME = "best_model"
MODEL_PATH = os.path.join("../models", MODEL_NAME)
model = PreTrainedModel(MODEL_PATH)

# Folder containing one sub-directory (with CONTCAR and OUTCAR) per calculation.
BM_DIR = "../data/BM_dataset/Biomass"


def _final_energy(calc):
    """Return the ``final_energy`` parsed from ``BM_DIR/<calc>/OUTCAR``."""
    return Outcar(os.path.join(BM_DIR, calc, "OUTCAR")).final_energy


def _evaluate_system(calc, reference_energy=None):
    """Compare the GNN prediction with the VASP energy for one calculation.

    The graph is built from ``BM_DIR/<calc>/CONTCAR`` using the graph settings
    stored on the module-level ``model`` and evaluated with ``model.evaluate``.
    The DFT energy is the OUTCAR final energy, minus ``reference_energy`` when
    one is given (the metal surface energy for adsorbed systems; ``None`` for
    gas-phase molecules). The graph is plotted and a short report is printed.

    Args:
        calc: Name of the calculation sub-folder; also used as the printed label.
        reference_energy: Energy subtracted from the OUTCAR final energy, or None.

    Returns:
        Tuple ``(gnn_energy, dft_energy, absolute_error, num_nodes)`` where
        ``num_nodes`` is the number of nodes of the generated graph.
    """
    total_energy = _final_energy(calc)
    if reference_energy is None:
        dft_energy = total_energy
    else:
        dft_energy = total_energy - reference_energy
    graph = contcar_to_graph(os.path.join(BM_DIR, calc, "CONTCAR"),
                             model.g_tol, model.g_sf, model.g_metal_2nn)
    gnn_energy = model.evaluate(graph)
    absolute_error = abs(gnn_energy - dft_energy)
    print("-----------------------------------")
    plotter(graph)
    plt.show()
    print("System: {}".format(calc))
    print("GNN energy = {:.2f} eV ".format(gnn_energy))
    print("VASP energy = {} eV".format(dft_energy))
    print("Abs. Error = {:.2f} eV".format(absolute_error))
    return gnn_energy, dft_energy, absolute_error, graph.num_nodes


# Reference surface energies (from the ni-0000 / ru-0000 calculations) that are
# subtracted from the adsorbed-system totals before comparing with the GNN.
ni_energy = _final_energy("ni-0000")
ru_energy = _final_energy("ru-0000")
sur = ["ni", "ru"]
pu = ["mol1", "mol2", "mol3", "mol4", "mol5"]
sur_energy = {"ni": ni_energy, "ru": ru_energy}

# (calculation folder, reference energy): every molecule on every metal first,
# then the gas-phase molecules (no reference subtraction).
systems = [("{}-{}".format(metal, molecule), sur_energy[metal])
           for metal in sur for molecule in pu]
systems += [(molecule, None) for molecule in pu]

# Column vectors with one row per system, sized from the actual system list.
error = np.zeros((len(systems), 1))           # signed error: DFT - GNN [eV]
abs_error = np.zeros((len(systems), 1))       # |GNN - DFT| [eV]
error_per_atom = np.zeros((len(systems), 1))  # abs_error / graph.num_nodes
for idx, (calc, reference) in enumerate(systems):
    gnn_energy, dft_energy, absolute_error, num_nodes = _evaluate_system(calc, reference)
    error[idx] = dft_energy - gnn_energy
    abs_error[idx] = absolute_error
    error_per_atom[idx] = abs_error[idx] / num_nodes
# Number of evaluated systems; same final value the manual counter used to have.
counter = len(systems)

print("----------BIOMASS-----------------")
print("MAE = {:.2f} eV".format(np.mean(abs_error)))
print("MAE/atom = {:.2f} eV/atom".format(np.mean(error_per_atom)))
print("----------------------------------")