In [1]:
from pathlib import Path
import json
import numpy as np
from scipy import stats
from pymatgen.core.structure import Structure

import pandas as pd
from ase.db import connect

from mlip_arena.models import REGISTRY, MLIPEnum


def load_wbm_structures():
    """
    Load the WBM structures from a ASE DB file.
    """
    with connect("../wbm_structures.db") as db:
        for row in db.select():
            yield row.toatoms(add_additional_information=True)

  from .autonotebook import tqdm as notebook_tqdm


# Calculate relevant metrics for EV scan

In [None]:
summary_table = pd.DataFrame(columns=[
    "model", 
    "energy-diff-flip-times",
    "tortuosity",
    "spearman-repulsion-energy",
    "spearman-repulsion-derivative",
    "spearman-attraction-energy",
    "missing_predictions",
])


for model in MLIPEnum:

    if "wbm_ev" not in REGISTRY[model.name].get("gpu-tasks", []):
        print(f"Results for {model.name} have not been computed for the EoS-bulk task.")
        continue 

    filename = f"{model.name}.parquet"
    df_raw_results = pd.read_parquet(filename)

    df_analyzed = pd.DataFrame(columns=[
        "model", 
        "structure", 
        "volume_per_atom", 
        "E",
        "energy-diff-flip-times",
        "tortuosity",
        "spearman-repulsion-energy",
        "spearman-repulsion-derivative",
        "spearman-attraction-energy",
    ])

    for wbm_struct in load_wbm_structures():
        structure_id = wbm_struct.info["key_value_pairs"]["wbm_id"]
        struct = Structure.from_ase_atoms(wbm_struct)


        try:
            results = df_raw_results.loc[df_raw_results["id"] == structure_id]
            results = results["eos"].values[0]
            es = np.array(results["energies"])
            vols = np.array(results["volumes"])
            
            indices = np.argsort(vols)[::-1]
            vols = vols[indices]
            es = es[indices]

            eshift = es[0]
            es -= eshift

            imine = len(es) // 2

            de_dr = np.gradient(es, vols)
            d2e_dr2 = np.gradient(de_dr, vols)

            

            interpolated_volumes = [(vols[i] + vols[i+1])/2 for i in range(0, len(vols)-1)]
            ediff = np.diff(es)
            ediff_orig = ediff
            ediff[np.abs(ediff) < 1e-3] = 0 # 1meV
            ediff_sign = np.sign(ediff)
            mask = ediff_sign != 0
            ediff = ediff[mask]
            ediff_sign = ediff_sign[mask]
            ediff_flip = np.diff(ediff_sign) != 0

        
            etv = np.sum(np.abs(np.diff(es)))

            data = {
                "model": model.name,
                "structure": structure_id,
                "composition": struct.symbol_set,
                "missing_prediction": False, 
                "volume_per_atom": vols,
                "E": es + eshift,
                "energy-diff-flip-times": np.sum(ediff_flip),
                "tortuosity": etv / (abs(es[0] - es.min()) + (es[-1] - es.min())),
                "spearman-repulsion-energy": stats.spearmanr(vols[imine:], es[imine:]).statistic,
                "spearman-repulsion-derivative": stats.spearmanr(interpolated_volumes[imine:], ediff_orig[imine:]).statistic,
                "spearman-attraction-energy": stats.spearmanr(vols[:imine], es[:imine]).statistic,
            }

            df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
        except Exception as e:

            data = {
                "model": model.name,
                "structure": structure_id,
                "composition": struct.symbol_set,
                "missing_prediction": True, 
                "volume_per_atom": None,
                "E": None,
                "energy-diff-flip-times": None,
                "tortuosity": None,
                "spearman-repulsion-energy": None,
                "spearman-repulsion-derivative": None,
                "spearman-attraction-energy": None,
            }

            df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)

    
    json_fpath = Path(f"EV_scan_analyzed_{model.name}.json")

    df_analyzed.to_json(json_fpath, orient="records")

    valid_results = df_analyzed[df_analyzed["missing_prediction"] == False]
    valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)

    analysis_summary = {
        "model": model.name, 
        "energy-diff-flip-times": valid_results["energy-diff-flip-times"].mean(),
        "tortuosity": valid_results["tortuosity"].mean(),
        "spearman-repulsion-energy": valid_results["spearman-repulsion-energy"].mean(),
        "spearman-repulsion-derivative": valid_results["spearman-repulsion-derivative"].mean(),
        "spearman-attraction-energy": valid_results["spearman-attraction-energy"].mean(),
        "missing_predictions": len(df_analyzed[df_analyzed["missing_prediction"] == True]),
    }
    summary_table = pd.concat([summary_table, pd.DataFrame([analysis_summary])], ignore_index=True)

  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  summary_table = pd.concat([summary_table, pd.DataFrame([analysis_summary])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


Results for EquiformerV2(OC22) have not been computed for the EoS-bulk task.
Results for EquiformerV2(OC20) have not been computed for the EoS-bulk task.
Results for eSCN(OC20) have not been computed for the EoS-bulk task.
Results for MACE-OFF(M) have not been computed for the EoS-bulk task.
Results for ANI2x have not been computed for the EoS-bulk task.


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  "spearman-repulsion-derivative": stats.spearmanr(interpolated_volumes[imine:], ediff_orig[imine:]).statistic,


Results for ORB have not been computed for the EoS-bulk task.


# Aggregate ranks along each category for overall rank, save summary table to file

In [3]:
flip_rank = summary_table["energy-diff-flip-times"].rank(ascending = True, method = "min")
tortuosity_rank = summary_table["tortuosity"].rank(method = "min")
spearman_repulsion_energy_rank = summary_table["spearman-repulsion-energy"].rank(method = "min")
spearman_repulsion_derivative_rank = summary_table["spearman-repulsion-derivative"].rank(method = "min")
spearman_attraction_energy_rank = summary_table["spearman-attraction-energy"].rank(ascending = False, method = "min")

rank_aggregate = flip_rank + tortuosity_rank + spearman_repulsion_energy_rank + spearman_repulsion_derivative_rank + spearman_attraction_energy_rank
rank = rank_aggregate.rank(method = "min")

summary_table.insert(1, "Rank", rank)
summary_table.insert(2, "Rank aggregate", rank_aggregate)
summary_table = summary_table.sort_values(by = "Rank", ascending = True)
summary_table = summary_table.reset_index(drop=True)
summary_table.to_csv("summarized_results.csv")
summary_table

Unnamed: 0,model,Rank,Rank aggregate,energy-diff-flip-times,tortuosity,spearman-repulsion-energy,spearman-repulsion-derivative,spearman-attraction-energy,missing_predictions
0,MACE-MPA,1.0,9.0,1.0,1.0,-0.998,-0.999309,0.999758,0
1,CHGNet,2.0,14.0,1.0,1.0,-0.997691,-0.943964,0.999939,0
2,MatterSim,3.0,19.0,1.008,1.000242,-0.997864,-0.999709,0.994468,0
3,eqV2(OMat),3.0,19.0,1.035,1.000284,-0.998109,-0.997224,0.999345,0
4,M3GNet,5.0,22.0,1.002,1.000559,-0.996473,-0.997442,0.998701,0
5,ORBv2,6.0,29.0,1.058,1.002964,-0.997527,-0.970752,0.998667,0
6,SevenNet,7.0,32.0,1.034,1.009069,-0.994964,-0.946558,0.995759,0
7,MACE-MP(M),8.0,34.0,1.119,1.076188,-0.947718,-0.901188,0.999867,0
8,ALIGNN,9.0,45.0,3.832,1.172117,-0.851882,-0.761837,0.882031,0
