In [1]:
from pathlib import Path
import json
import numpy as np
from scipy import stats
from pymatgen.core.structure import Structure

import pandas as pd
from ase.db import connect

from mlip_arena.models import REGISTRY, MLIPEnum


def load_wbm_structures():
    """
    Load the WBM structures from a ASE DB file.
    """
    with connect("../wbm_structures.db") as db:
        for row in db.select():
            yield row.toatoms(add_additional_information=True)

  from .autonotebook import tqdm as notebook_tqdm


# Calculate relevant metrics for EoS

In [2]:
summary_table = pd.DataFrame(columns=[
        "model", 
        "energy-diff-flip-times",
        "tortuosity",
        "spearman-repulsion-energy",
        "spearman-repulsion-derivative",
        "spearman-attraction-energy",
        "missing_predictions",
    ])

for model in MLIPEnum:
    if "eos_bulk" not in REGISTRY[model.name].get("gpu-tasks", []):
        print(f"Results for {model.name} have not been computed for the EoS-bulk task.")
        continue


    df_raw_results = pd.read_parquet(f"{model.name}.parquet")


    df_analyzed = pd.DataFrame(columns=[
        "model", 
        "structure", 
        "volume_per_atom", 
        "E",
        "energy-diff-flip-times",
        "tortuosity",
        "spearman-repulsion-energy",
        "spearman-repulsion-derivative",
        "spearman-attraction-energy",
    ])

    for wbm_struct in load_wbm_structures():
        structure_id = wbm_struct.info["key_value_pairs"]["wbm_id"]
        struct = Structure.from_ase_atoms(wbm_struct)


        try:
            results = df_raw_results.loc[df_raw_results["id"] == structure_id]
            results = results["eos"].values[0]
            es = np.array(results["energies"])
            vols = np.array(results["volumes"])
            
            indices = np.argsort(vols)[::-1]
            vols = vols[indices]
            es = es[indices]
            eshift = es[0]
            es -= eshift

            imine = len(es) // 2

            de_dr = np.gradient(es, vols)
            d2e_dr2 = np.gradient(de_dr, vols)

            

            interpolated_volumes = [(vols[i] + vols[i+1])/2 for i in range(0, len(vols)-1)]
            ediff = np.diff(es)
            ediff_orig = ediff
            ediff[np.abs(ediff) < 1e-3] = 0 # 1meV
            ediff_sign = np.sign(ediff)
            mask = ediff_sign != 0
            ediff = ediff[mask]
            ediff_sign = ediff_sign[mask]
            ediff_flip = np.diff(ediff_sign) != 0
            ejump = np.abs(ediff[:-1][ediff_flip]).sum() + np.abs(ediff[1:][ediff_flip]).sum()


        
            etv = np.sum(np.abs(np.diff(es)))

            data = {
                "model": model.name,
                "structure": structure_id,
                "composition": struct.symbol_set,
                "missing_prediction": False, 
                "volume_per_atom": vols,
                "E": es + eshift,
                "energy-diff-flip-times": np.sum(ediff_flip),
                "tortuosity": etv / (abs(es[0] - es.min()) + (es[-1] - es.min())),
                "spearman-repulsion-energy": stats.spearmanr(vols[imine:], es[imine:]).statistic,
                "spearman-repulsion-derivative": stats.spearmanr(interpolated_volumes[imine:], ediff_orig[imine:]).statistic,
                "spearman-attraction-energy": stats.spearmanr(vols[:imine], es[:imine]).statistic,
            }

            df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
        except Exception as e:

            data = {
                "model": model.name,
                "structure": structure_id,
                "composition": struct.symbol_set,
                "missing_prediction": True, 
                "volume_per_atom": None,
                "E": None,
                "energy-diff-flip-times": None,
                "tortuosity": None,
                "spearman-repulsion-energy": None,
                "spearman-repulsion-derivative": None,
                "spearman-attraction-energy": None,
            }

            df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
    
    
    json_fpath = Path(f"eos_analyzed_{model.name}.json")

    df_analyzed.to_json(json_fpath, orient="records")
    print(df_analyzed["volume_per_atom"])

    valid_results = df_analyzed[df_analyzed["missing_prediction"] == False]
    valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)

    analysis_summary = {
        "model": model.name, 
        "energy-diff-flip-times": valid_results["energy-diff-flip-times"].mean(),
        "tortuosity": valid_results["tortuosity"].mean(),
        "spearman-repulsion-energy": valid_results["spearman-repulsion-energy"].mean(),
        "spearman-repulsion-derivative": valid_results["spearman-repulsion-derivative"].mean(),
        "spearman-attraction-energy": valid_results["spearman-attraction-energy"].mean(),
        "missing_predictions": len(df_analyzed[df_analyzed["missing_prediction"] == True]),
    }
    summary_table = pd.concat([summary_table, pd.DataFrame([analysis_summary])], ignore_index=True)

  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


0      [95.65299401941463, 94.05877745242435, 92.4645...
1      [216.94623216343595, 213.3304616273788, 209.71...
2      [91.75929736765559, 90.2299757448612, 88.70065...
3      [278.7937003329284, 274.14713866071304, 269.50...
4      [240.47476524482596, 236.46685249074523, 232.4...
                             ...                        
995    [274.17766203428715, 269.60803433371586, 265.0...
996    [137.26488359429916, 134.977135534394, 132.689...
997    [239.23611854017216, 235.24884989783567, 231.2...
998    [152.6027033895552, 150.0593249997292, 147.515...
999    [329.2542033954668, 323.76663333887586, 318.27...
Name: volume_per_atom, Length: 1000, dtype: object


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  summary_table = pd.concat([summary_table, pd.DataFrame([analysis_summary])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_gui

0      [95.53250657419684, 93.94029813129345, 92.3480...
1      [215.25117648041763, 211.6636568724104, 208.07...
2      [91.6378703297095, 90.11057249088097, 88.58327...
3      [274.83055080492653, 270.2500416248441, 265.66...
4      [235.91540380975908, 231.9834804129299, 228.05...
                             ...                        
995    [274.95282114486514, 270.37027412578396, 265.7...
996    [138.982629026709, 136.66625187626366, 134.349...
997    [239.0829301886657, 235.098214685521, 231.1134...
998    [152.60492275008374, 150.06150737091562, 147.5...
999    [331.37332495599713, 325.85043620673014, 320.3...
Name: volume_per_atom, Length: 1000, dtype: object


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


0      [94.67679451200793, 93.09884793680769, 91.5209...
1      [217.58975705930106, 213.96326110831285, 210.3...
2      [92.92852823460017, 91.37971943069013, 89.8309...
3      [270.6886716589618, 266.17719379797927, 261.66...
4      [246.03697423812048, 241.9363580008184, 237.83...
                             ...                        
995    [275.021043304762, 270.4373592496823, 265.8536...
996    [138.39326659784484, 136.0867121545474, 133.78...
997    [243.4307289010188, 239.37355008600193, 235.31...
998    [154.11150247656974, 151.54297743529352, 148.9...
999    [331.3217767223884, 325.79974711034816, 320.27...
Name: volume_per_atom, Length: 1000, dtype: object


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


0      [95.44414624001409, 93.85341046934724, 92.2626...
1      [216.20605635158543, 212.60262207905893, 208.9...
2      [92.00683566539125, 90.47338840430145, 88.9399...
3      [310.64101079673134, 305.463660616786, 300.286...
4      [236.81933618992608, 232.87234725342742, 228.9...
                             ...                        
995    [274.86411914872554, 270.2830504962467, 265.70...
996    [137.26836700073062, 134.98056088405158, 132.6...
997    [240.63125532720693, 236.62073440508652, 232.6...
998    [152.6027033895552, 150.0593249997292, 147.515...
999    [331.3414421685603, 325.81908479908395, 320.29...
Name: volume_per_atom, Length: 1000, dtype: object


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], igno

0      [94.53237945003127, 92.9568397925308, 91.38130...
1      [215.9606116253537, 212.3612680982646, 208.761...
2      [91.6378703297095, 90.11057249088097, 88.58327...
3      [274.83055080492653, 270.2500416248441, 265.66...
4      [231.9471163763046, 228.0813311033659, 224.215...
                             ...                        
995    [274.95282114486514, 270.37027412578396, 265.7...
996    [137.30583450454796, 135.017403929472, 132.728...
997    [237.67334073817153, 233.71211839253527, 229.7...
998    [152.6027033895552, 150.0593249997292, 147.515...
999    [329.83400316429623, 324.3367697782245, 318.83...
Name: volume_per_atom, Length: 1000, dtype: object


  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


0      [95.96098071697223, 94.36163103835591, 92.7622...
1      [216.89127569601476, 213.2764211010811, 209.66...
2      [91.3682807697216, 89.8454760902262, 88.322671...
3      [295.9610927853412, 291.02840790558514, 286.09...
4      [234.39940968624893, 230.4927528581447, 226.58...
                             ...                        
995    [274.95282114486514, 270.37027412578396, 265.7...
996    [141.41448269222568, 139.0575746473551, 136.70...
997    [239.60158470424543, 235.60822495917458, 231.6...
998    [152.6027033895552, 150.0593249997292, 147.515...
999    [330.57180643730925, 325.06227633002067, 319.5...
Name: volume_per_atom, Length: 1000, dtype: object
Results for eqV2(OMat) have not been computed for the EoS-bulk task.


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
  df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)


0      [95.77188208842381, 94.1756840536167, 92.57948...
1      [213.7492188381466, 210.18673185751075, 206.62...
2      [92.57953779315888, 91.03654549660628, 89.4935...
3      [279.68075212533154, 275.0194062565757, 270.35...
4      [233.71164327835902, 229.81644922371962, 225.9...
                             ...                        
995    [274.89874649181627, 270.3171007169523, 265.73...
996    [133.58193219330863, 131.35556665675333, 129.1...
997    [232.73823925916358, 228.8592686048439, 224.98...
998    [153.29641541028482, 150.7414751534467, 148.18...
999    [330.6480207536061, 325.1372204077123, 319.626...
Name: volume_per_atom, Length: 1000, dtype: object
Results for EquiformerV2(OC22) have not been computed for the EoS-bulk task.
Results for EquiformerV2(OC20) have not been computed for the EoS-bulk task.
Results for eSCN(OC20) have not been computed for the EoS-bulk task.
Results for MACE-OFF(M) have not been computed for the EoS-bulk task.
Results for ANI2x have not be

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid_results["energy-diff-flip-times"] = valid_results["energy-diff-flip-times"].astype(int)


In [3]:
flip_rank = summary_table["energy-diff-flip-times"].rank(ascending = True, method = "min")
tortuosity_rank = summary_table["tortuosity"].rank(method = "min")
spearman_repulsion_energy_rank = summary_table["spearman-repulsion-energy"].rank(method = "min")
spearman_repulsion_derivative_rank = summary_table["spearman-repulsion-derivative"].rank(method = "min")
spearman_attraction_energy_rank = summary_table["spearman-attraction-energy"].rank(ascending = False, method = "min")

rank_aggregate = flip_rank + tortuosity_rank + spearman_repulsion_energy_rank + spearman_repulsion_derivative_rank + spearman_attraction_energy_rank
rank = rank_aggregate.rank(method = "min")

summary_table.insert(1, "Rank", rank)
summary_table.insert(2, "Rank aggregate", rank_aggregate)
summary_table = summary_table.sort_values(by = "Rank", ascending = True)
summary_table = summary_table.reset_index(drop=True)
summary_table.to_csv("summarized_results.csv")
summary_table

Unnamed: 0,model,Rank,Rank aggregate,energy-diff-flip-times,tortuosity,spearman-repulsion-energy,spearman-repulsion-derivative,spearman-attraction-energy,missing_predictions
0,MACE-MPA,1.0,6.0,1.03507,1.003633,-0.999371,-0.996332,0.994535,2
1,MACE-MP(M),2.0,12.0,1.042211,1.007933,-0.999095,-0.994125,0.993082,5
2,MatterSim,3.0,14.0,1.041123,1.002993,-0.996708,-0.992786,0.989457,3
3,SevenNet,4.0,22.0,1.098295,1.008328,-0.997766,-0.988936,0.986951,3
4,CHGNet,5.0,24.0,1.101304,1.013812,-0.995824,-0.992994,0.987806,3
5,M3GNet,6.0,27.0,1.157789,1.010885,-0.996199,-0.989743,0.981197,5
6,ORBv2,7.0,35.0,1.305136,1.017229,-0.988236,-0.970143,0.969935,7
