# Prediction test

## Imports

In [None]:
from itertools import chain
from os import makedirs
from pathlib import Path
import numpy as np
import pandas as pd
from ocpmodels.datasets import SinglePointLmdbDataset

## Variables

In [None]:
# Root folder holding the tarball and the extracted datasets.
ROOT_DIR = Path("./predictions")

# Dataset tarball; set TARBALL to None (or False) to skip extraction.
TARBALL = ROOT_DIR / "ocp_predictions.tar.xz"

# Selection of the predictions to analyse.
INITIAL_GEOMETRY = "poscar"  # initial geometry source: "poscar" or "contcar"
GEOM_MODEL = "full"          # geometric model: "full" or "ensemble"
GNN_MODEL = "gemnet"         # GNN model: "dpp", "painn" or "gemnet"

# Derived locations.
DS_NAME = f"lmdb_fg_{GEOM_MODEL}_{INITIAL_GEOMETRY}"
DS_DIR = ROOT_DIR / DS_NAME        # dataset directory
PREDICT_DIR = DS_DIR / GNN_MODEL   # prediction files of the chosen GNN model

In [None]:
# Extract the dataset tarball into ROOT_DIR (it is expected to unpack into
# DS_DIR). Skipped when TARBALL is falsy.
if TARBALL:
    import tarfile
    # The context manager closes the archive even if extraction fails.
    with tarfile.open(TARBALL, mode="r:xz") as tar_ds:
        if hasattr(tarfile, "data_filter"):
            # PEP 706 "data" filter: refuses absolute paths, ".." members,
            # links escaping the target and special files.
            tar_ds.extractall(ROOT_DIR, filter="data")
        else:
            # Python without extraction filters: unfiltered, as before.
            tar_ds.extractall(ROOT_DIR)

### Group Names

Map the dataset's folder (group) names to the chemical-family names used in the manuscript.

In [None]:
# Raw dataset group (folder) name -> chemical family label used in the
# manuscript. Several folders may share one label; the LaTeX strings are
# rendered as math text in plots/tables.
group_to_family_dict = {
    "carbamate_esters": "Carbamates",
    "aromatics": "Aromatics",
    "aromatics2": "Aromatics",
    "oximes": "Oximes",
    "group2": "$C_{x}H_{y}O_{(0,1)}$",
    "group2b": "$C_{x}H_{y}O_{(0,1)}$",
    "amides": "Amides",
    "amidines": "Amidines",
    "group3S": "$C_{x}H_{y}S$",
    "group3N": "$C_{x}H_{y}N$",
    "group4": "$C_{x}H_{y}O_{(2,3)}$",
    "metal_surfaces": "metal",
}

## Read Predictions

In [None]:
def arr_load_n_dict(f):
    """Load one prediction file and return an iterator of per-system records.

    Parameters
    ----------
    f : str, path-like or file-like
        Anything accepted by ``np.load`` whose result can be indexed with the
        keys ``"ids"`` and ``"energy"`` (e.g. an ``.npz`` archive).

    Returns
    -------
    iterator of dict
        ``{"sid": <integer id>, "e_pred": <float energy>}`` for each
        (id, energy) pair, in file order. Values are numpy scalars.

    Raises
    ------
    ValueError
        If ``"ids"`` and ``"energy"`` have different lengths.
    """
    loaded = np.load(f)
    try:
        ids = np.asarray(loaded["ids"], dtype=int)
        energies = np.asarray(loaded["energy"], dtype=float)
    finally:
        # An NpzFile keeps its file handle open until closed; plain arrays
        # have no close() and need no cleanup.
        if hasattr(loaded, "close"):
            loaded.close()
    if len(ids) != len(energies):
        # zip() would silently drop the unmatched tail and mis-pair nothing
        # visibly, hiding a corrupt prediction file.
        raise ValueError(
            f"{f}: {len(ids)} ids but {len(energies)} energies")
    return ({"sid": sid, "e_pred": e_pred}
            for sid, e_pred in zip(ids, energies))

# Gather every prediction file of the selected GNN model. Each row is tagged
# with the suffix after the last "_" of its file stem as "index" (presumably
# the cross-validation split id — confirm against the file naming).
# Sorting makes the row order independent of filesystem listing order.
prediction_files = sorted(PREDICT_DIR.glob("predictions*"))
if not prediction_files:
    raise FileNotFoundError(f"No prediction files found in {PREDICT_DIR}")

# A list (not a lazy iterator) so the next cell can be re-run without
# finding the predictions already consumed.
cross_data_test_preds = [
    row | {"index": pred_file.stem.split("_")[-1]}
    for pred_file in prediction_files
    for row in arr_load_n_dict(pred_file)
]

## Collect the data in a dataframe

In [None]:
# Ground truth: one row per system with its id, name, raw group (folder) name
# stored under "family", and reference energy.
crossval_df = pd.read_csv(DS_DIR/"ds_data.csv"
                         , names=("sid", "name", "family", "e_true"))
# Attach the predicted energies; the two tables are joined on the system id.
crossval_df = crossval_df.merge(pd.DataFrame(cross_data_test_preds), on="sid")
# Exclude metal surfaces. The raw group name "metal_surfaces" only becomes
# "metal" after translation, so compare on the translated name (names missing
# from the dictionary are compared as-is) instead of on the raw column.
is_metal = (crossval_df["family"]
            .map(lambda g: group_to_family_dict.get(g, g)) == "metal")
crossval_df = crossval_df.loc[~is_metal].copy()
# Remove sid column to prettify the output
crossval_df.drop(columns=["sid"], inplace=True)
# Prettify family names; an unknown group name raises KeyError instead of
# silently becoming NaN.
crossval_df["family"] = crossval_df["family"].apply(lambda x: group_to_family_dict[x])
# Compute the absolute error per system
crossval_df["error"] = (crossval_df["e_true"] - crossval_df["e_pred"]).abs()

## Results

### Mean Absolute Error (MAE)

In [None]:
# Per-family mean of the numeric columns; the "error" column mean is the MAE.
# numeric_only=True skips the string columns ("name", "index"), which
# pandas >= 2.0 would otherwise try (and fail) to average.
crossval_df.groupby("family").mean(numeric_only=True)

### Standard Error of the Mean (SEM)

In [None]:
# SEM across indices: average each (family, index) group, then within each
# family take the standard deviation of those per-index means and divide by
# sqrt(number of indices) — counted per family rather than assumed.
fold_means = crossval_df.groupby(["family", "index"]).mean(numeric_only=True)
fold_groups = fold_means.groupby(level="family")
fold_groups.std().div(np.sqrt(fold_groups.size()), axis=0)