In [None]:
from pathlib import Path
import json
import numpy as np
from scipy import stats
from pymatgen.core.structure import Structure

import pandas as pd
from ase.db import connect

from mlip_arena.models import REGISTRY, MLIPEnum


def load_wbm_structures(db_path="../wbm_structures.db"):
    """
    Lazily load the WBM structures from an ASE DB file.

    Args:
        db_path: Location of the ASE database, forwarded to
            ``ase.db.connect``. Defaults to the path this notebook has
            always used, so existing calls without arguments are unchanged.

    Yields:
        One ASE ``Atoms`` object per database row. Rows are converted with
        ``add_additional_information=True`` so that the row metadata ends up
        in ``atoms.info`` (the analysis code reads
        ``atoms.info["key_value_pairs"]["wbm_id"]``).
    """
    # The connection stays open while the generator is being consumed and is
    # released once iteration finishes or the generator is closed.
    with connect(db_path) as db:
        for row in db.select():
            yield row.toatoms(add_additional_information=True)


# for model in MLIPEnum:

#     if "eos-bulk" not in REGISTRY[model.name].get("gpu-tasks", []):
#         continue

#     all_data = []

#     for atoms in load_wbm_structures():

#         fpath = Path(model.name) / f"{atoms.info['key_value_pairs']['wbm_id']}.pkl"
#         if not fpath.exists():
#             continue

#         all_data.append(pd.read_pickle(fpath))

#     df = pd.concat(all_data, ignore_index=True)
#     df.to_parquet(f"{model.name}.parquet")


# Column layout of the per-structure table (one row per WBM structure).
ANALYZED_COLUMNS = [
    "model",
    "structure",
    "volume_per_atom",
    "E",
    "energy-diff-flip-times",
    "tortuosity",
    "spearman-repulsion-energy",
    "spearman-repulsion-derivative",
    "spearman-attraction-energy",
    "composition",
    "missing_prediction",
]

# Column layout of the per-model summary table.
# Disabled metrics: "energy-grad-norm-max", "energy-jump",
# "energy-total-variation". There is no attraction-derivative metric because
# that quantity does not monotonically increase/decrease.
SUMMARY_COLUMNS = [
    "model",
    "energy-diff-flip-times",
    "tortuosity",
    "spearman-repulsion-energy",
    "spearman-repulsion-derivative",
    "spearman-attraction-energy",
    "missing_predictions",
]

# Keys that are filled with None when a structure has no usable prediction.
METRIC_KEYS = [
    "volume_per_atom",
    "E",
    "energy-diff-flip-times",
    "tortuosity",
    "spearman-repulsion-energy",
    "spearman-repulsion-derivative",
    "spearman-attraction-energy",
]


def _analyze_eos(energies, volumes):
    """
    Compute the smoothness/monotonicity metrics of one energy-volume curve.

    Args:
        energies: Sequence of energies, one per sampled volume.
        volumes: Sequence of volumes, same length as ``energies``.

    Returns:
        dict with the keys listed in ``METRIC_KEYS``:
        - "volume_per_atom": the volumes sorted in descending order
          (stored as given; NOTE(review): confirm they are per-atom values).
        - "E": the energies reordered to match the sorted volumes.
        - "energy-diff-flip-times": number of sign changes between
          consecutive energy differences, ignoring differences whose
          magnitude is below 1e-3 (1 meV according to the original note,
          which assumes energies in eV).
        - "tortuosity": total variation of the energy divided by the sum of
          the drops from both curve ends to the minimum energy.
        - "spearman-*": Spearman rank correlations on the two halves of the
          curve (split at the middle sample).

    Raises:
        ValueError: If fewer than two points are given.
        TypeError, IndexError: If the inputs are not numeric or the energies
            are shorter than the volumes.
    """
    vols = np.array(volumes)
    es = np.array(energies)

    # Largest volume first, so the first half of the arrays is the expanded
    # (attraction) side and the second half the compressed (repulsion) side.
    order = np.argsort(vols)[::-1]
    vols = vols[order]
    es = es[order]

    if es.size < 2:
        # Differences and rank correlations are meaningless for a single point.
        raise ValueError("an EoS curve needs at least two points")

    # Work relative to the energy at the largest volume.
    e_ref = es[0]
    es = es - e_ref

    # The split assumes the scan is centred on the equilibrium volume.
    half = len(es) // 2

    mid_volumes = (vols[:-1] + vols[1:]) / 2
    ediff_raw = np.diff(es)

    # Threshold a *copy*: the raw differences are still needed below for the
    # derivative correlation and must not be zeroed in place.
    ediff = ediff_raw.copy()
    ediff[np.abs(ediff) < 1e-3] = 0  # 1meV
    signs = np.sign(ediff)
    signs = signs[signs != 0]
    flips = np.diff(signs) != 0

    total_variation = np.sum(np.abs(ediff_raw))
    e_min = es.min()

    return {
        "volume_per_atom": vols,
        "E": es + e_ref,
        "energy-diff-flip-times": np.sum(flips),
        "tortuosity": total_variation / (abs(es[0] - e_min) + (es[-1] - e_min)),
        "spearman-repulsion-energy": stats.spearmanr(vols[half:], es[half:]).statistic,
        "spearman-repulsion-derivative": stats.spearmanr(
            mid_volumes[half:], ediff_raw[half:]
        ).statistic,
        "spearman-attraction-energy": stats.spearmanr(vols[:half], es[:half]).statistic,
    }


summary_rows = []

for model in MLIPEnum:
    if "eos-bulk" not in REGISTRY[model.name].get("gpu-tasks", []):
        continue

    # A missing results file means the task was never run for this model;
    # any other read error is a real problem and is allowed to surface.
    parquet_path = Path(f"{model.name}.parquet")
    if not parquet_path.exists():
        print(f"Results for {model.name} have not been computed for the EoS-bulk task.")
        continue
    df_raw_results = pd.read_parquet(parquet_path)

    # Collect plain dicts and build the DataFrame once: concatenating one row
    # at a time is quadratic and triggers pandas warnings about empty or
    # all-NA frames.
    analyzed_rows = []

    for wbm_struct in load_wbm_structures():
        structure_id = wbm_struct.info["key_value_pairs"]["wbm_id"]
        struct = Structure.from_ase_atoms(wbm_struct)

        row = {
            "model": model.name,
            "structure": structure_id,
            "composition": struct.symbol_set,
        }

        try:
            results = df_raw_results.loc[df_raw_results["id"] == structure_id]
            # IndexError here means the model produced no entry for this structure.
            eos = results["eos"].values[0]
            row.update(_analyze_eos(eos["energies"], eos["volumes"]))
            row["missing_prediction"] = False
        except (IndexError, KeyError, TypeError, ValueError):
            # Best effort: an absent or malformed result is recorded as a
            # missing prediction instead of aborting the whole analysis.
            row.update(dict.fromkeys(METRIC_KEYS))
            row["missing_prediction"] = True

        analyzed_rows.append(row)

    df_analyzed = pd.DataFrame(analyzed_rows, columns=ANALYZED_COLUMNS)

    json_fpath = Path(f"eos_analyzed_{model.name}.json")

    df_analyzed.to_json(json_fpath, orient="records")
    print(df_analyzed["volume_per_atom"])

    is_missing = df_analyzed["missing_prediction"].astype(bool)
    # Explicit copy so the column assignment below does not write into a
    # slice of df_analyzed (avoids SettingWithCopyWarning).
    valid_results = df_analyzed.loc[~is_missing].copy()
    valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)

    summary_rows.append({
        "model": model.name,
        "energy-diff-flip-times": valid_results["energy-diff-flip-times"].mean(),
        "tortuosity": valid_results["tortuosity"].mean(),
        "spearman-repulsion-energy": valid_results["spearman-repulsion-energy"].mean(),
        "spearman-repulsion-derivative": valid_results["spearman-repulsion-derivative"].mean(),
        "spearman-attraction-energy": valid_results["spearman-attraction-energy"].mean(),
        "missing_predictions": int(is_missing.sum()),
    })

summary_table = pd.DataFrame(summary_rows, columns=SUMMARY_COLUMNS)

summary_table


need to add EoS to registry


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  summary_table = pd.concat([summary_table, pd.DataFrame([analysis_summary])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFram

need to add EoS to registry


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


need to add EoS to registry


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], igno

need to add EoS to registry


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


need to add EoS to registry


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], igno

need to add EoS to registry


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


need to add EoS to registry
Results for eqV2(OMat) have not been computed for the EoS-bulk task.
need to add EoS to registry
Results for EquiformerV2(OC22) have not been computed for the EoS-bulk task.
need to add EoS to registry
Results for EquiformerV2(OC20) have not been computed for the EoS-bulk task.
need to add EoS to registry
Results for eSCN(OC20) have not been computed for the EoS-bulk task.
need to add EoS to registry
Results for MACE-OFF(M) have not been computed for the EoS-bulk task.
need to add EoS to registry
Results for ANI2x have not been computed for the EoS-bulk task.
need to add EoS to registry
Results for ALIGNN have not been computed for the EoS-bulk task.
need to add EoS to registry
Results for ORB have not been computed for the EoS-bulk task.


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)


Unnamed: 0,model,energy-diff-flip-times,tortuosity,spearman-repulsion-energy,spearman-repulsion-derivative,spearman-attraction-energy,missing_predictions
0,MACE-MP(M),1.042211,1.007933,-0.999095,-0.994125,0.993082,5
1,CHGNet,1.101304,1.013812,-0.995824,-0.992994,0.987806,3
2,M3GNet,1.157789,1.010885,-0.996199,-0.989743,0.981197,5
3,MatterSim,1.041123,1.002993,-0.996708,-0.992786,0.989457,3
4,ORBv2,1.305136,1.017229,-0.988236,-0.970143,0.969935,7
5,SevenNet,1.098295,1.008328,-0.997766,-0.988936,0.986951,3


In [None]:
# Summary statistics for the last analyzed model, restricted to structures
# that are not flagged as missing a prediction.
has_prediction = df_analyzed["missing_prediction"] == False  # noqa: E712
print(df_analyzed[has_prediction].describe())