In [1]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.util.testing import PymatgenTest
from scipy.spatial.distance import jensenshannon

from lemat_genbench.benchmarks.distribution_benchmark import (
    DistributionBenchmark,
)
from lemat_genbench.metrics.distribution_metrics import (
    MMD,
    FrechetDistance,
    JSDistance,
)
from lemat_genbench.preprocess.base import PreprocessorResult
from lemat_genbench.preprocess.distribution_preprocess import (
    DistributionPreprocessor,
)
from lemat_genbench.preprocess.universal_stability_preprocess import (
    UniversalStabilityPreprocessor,
)
from lemat_genbench.utils.distribution_utils import (
    map_space_group_to_crystal_system,
)

%load_ext autoreload
%autoreload 2

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [2]:
test = PymatgenTest()

filename = "CsBr.cif"
structure = Structure.from_file(filename)
structure = structure.remove_oxidation_states()

filename2 = "CsPbBr3.cif"
structure2 = Structure.from_file(filename2)
structure2 = structure2.remove_oxidation_states()

structures = [
    # structure,
    # structure2,
    test.get_structure("Si"),
    test.get_structure("LiFePO4"),
]

Use MatSciTest in pymatgen.util.testing instead.
  test = PymatgenTest()


In [29]:
with open("../data/sample_lematbulk.pkl", "rb") as f:
    test_df = pickle.load(f)

In [27]:
with open("../data/full_reference_df.pkl", "rb") as f:
    reference_df = pickle.load(f)

In [28]:
reference_df

Unnamed: 0,LeMatID,Volume,Density(g/cm^3),Density(atoms/A^3),SpaceGroup,CrystalSystem,OrbGraphEmbeddings,OrbNodeEmbeddings,OrbEnergy,OrbForces,OrbFormationEnergy,OrbEAboveHull,MaceGraphEmbeddings,MaceNodeEmbeddings,MaceEnergy,MaceForces,MaceFormationEnergy,MaceEAboveHull
0,mp-1196446,1432.672306,1.71,0.09,13,2,"[0.109472044, 0.2572507, 0.02018192, 0.3772189...","[[-1.0641912, 0.3299405, -0.04178173, 1.067353...",-610.357239,"[[0.0012629332, -0.00081717595, 0.0012267774],...",-25.347359,0.003112,"[-0.1861449, -0.020684373, 0.066283025, 0.0940...","[[-0.11461396, -0.015990674, 0.107097805, 0.09...",-609.734131,"[[-1.6835966e-06, -1.3360595e-06, -1.5439e-06]...",-25.222737,0.007980
1,mp-774651,489.504032,3.11,0.08,33,3,"[0.11850287, 0.14595717, 0.02197475, 0.2681615...","[[-0.32058975, 0.19373147, -0.012800036, 0.899...",-213.581329,"[[0.035016894, -0.0034213867, 0.0062979627], [...",-11.177186,0.036492,"[-0.09791283, -0.04369984, 0.010444844, 0.0802...","[[-0.16484381, -0.028772023, 0.07445615, 0.157...",-214.043533,"[[0.0357335, -0.017460236, -0.012615044], [-0....",-11.269627,0.024937
2,agm005415715,181.286910,5.26,0.03,229,7,"[0.26939476, 0.20468056, -0.12542768, 0.032836...","[[-0.07381898, 0.79753023, -0.38630268, 0.8288...",-17.411106,"[[0.00012565777, -0.0016299748, -0.00230372], ...",0.390865,0.557381,"[0.021588078, 0.07617653, -0.032691456, 0.1895...","[[0.04680038, 0.044881374, -0.06807892, -0.018...",-17.212358,"[[-1.2940145e-06, -1.1700031e-06, 4.344911e-06...",0.490239,0.597131
3,mp-735578,693.095559,3.73,0.08,15,2,"[0.11508359, 0.10241604, 0.010340529, 0.121209...","[[-0.27764162, 0.34364513, -0.08888271, 1.4730...",-341.050354,"[[-0.11718704, -0.05151949, -0.02423022], [0.1...",-14.577774,0.254481,"[-0.07588258, -0.04773107, 0.047378432, 0.0385...","[[-0.1785602, -0.11307899, 0.1966575, 0.052531...",-339.172089,"[[-0.13032529, -0.083571285, 0.02394116], [0.1...",-14.202121,0.288021
4,mp-1221862,1088.022952,1.77,0.11,7,2,"[0.19177356, 0.18612428, 0.0010497392, 0.30191...","[[-1.0027432, 0.04330963, -0.025442515, 0.7367...",-595.334900,"[[-0.021937246, -0.0041813655, 0.027856894], [...",-28.518212,0.020043,"[-0.22079198, -0.037785087, 0.064638026, 0.093...","[[-0.05608654, -0.020514945, 0.08029426, 0.044...",-595.084106,"[[-0.00979438, -0.021082044, 0.029391693], [-0...",-28.468053,0.022205
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,mp-698301,382.507041,2.22,0.09,9,2,"[0.15783544, 0.17722034, 0.009200763, 0.274925...","[[-0.8967323, -0.09575288, -0.049783986, 1.143...",-180.585617,"[[0.0065214206, 0.031426117, -0.02118668], [0....",-4.869578,4.025927,"[-0.1996858, -0.02465518, 0.012579967, 0.05657...","[[0.015688254, -0.018094055, -0.049125608, 0.0...",-180.922058,"[[0.018018203, 0.0074682916, 0.034540795], [0....",-4.936867,4.016032
496,mp-1217479,455.938638,3.26,0.08,21,3,"[0.3432244, 0.00506988, -0.018574214, 0.286489...","[[0.24601227, 0.3068273, -0.7677959, 1.4149146...",-273.717407,"[[-0.0013649535, -0.010621709, 0.0059434045], ...",-10.962075,3.390340,"[-0.22776303, -0.005970049, 0.025932766, 0.068...","[[-0.6181542, -0.17707199, 0.18061917, -0.1246...",-273.717896,"[[-4.990885e-06, 4.904792e-06, -2.4374021e-06]...",-10.962173,3.390327
497,agm003915635,68.826719,8.25,0.06,216,7,"[0.22693583, -0.36586532, 0.03352779, -0.32932...","[[-0.17580587, -0.37149918, 0.34359527, 0.7186...",-13.005263,"[[-0.0011570402, 0.0017539749, -0.0019723596],...",1.015155,1.194809,"[0.04072046, -0.07314953, 0.02702967, -0.13443...","[[-0.24228477, 0.10586107, 0.07650542, -0.0950...",-12.991813,"[[1.4227494e-06, -8.860161e-07, 6.2079635e-06]...",1.019639,1.198171
498,mp-622196,945.741987,2.91,0.05,14,2,"[-0.07858891, 0.3029094, 0.2139958, 0.18392271...","[[-0.32075185, 0.40140757, 0.23005971, 1.13496...",-264.373596,"[[-0.014269993, -0.024956819, 0.013857508], [0...",0.345249,0.556899,"[-0.03414698, -0.049499776, -0.024802292, 0.10...","[[-0.07554076, -0.024172362, 0.09211306, 0.193...",-265.068634,"[[-0.008831942, -0.05815789, 0.04731905], [0.0...",0.206241,0.541103


In [3]:
stability_preprocessor = UniversalStabilityPreprocessor(model_name="orb", timeout = 100000, relax_structures = False)
stability_preprocessor_result = stability_preprocessor(structures)



100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:08<00:00,  4.32s/it]


In [4]:
# stability_preprocessor_result.processed_structures[1].properties

In [5]:
distribution_preprocessor = DistributionPreprocessor()
dist_preprocessor_result = distribution_preprocessor(structures)

100%|███████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 304.51it/s]


In [13]:
metric = JSDistance(reference_df=reference_df) 
default_args = metric._get_compute_attributes()
metric_result = metric(dist_preprocessor_result.processed_structures, **default_args)
print(metric_result.metrics)

{'SpaceGroup': np.float64(0.7564856529396667), 'CrystalSystem': np.float64(0.522976610018087), 'CompositionCounts': np.float64(0.7074950131631114), 'Composition': np.float64(0.7371862430633939), 'Average_Jensen_Shannon_Distance': np.float64(0.6810358797960647)}


In [14]:
metric = MMD(reference_df=reference_df) 
default_args = metric._get_compute_attributes()
metric_result = metric(dist_preprocessor_result.processed_structures, **default_args)
print(metric_result.metrics)

{'Volume': np.float64(0.49782917763830126), 'Density(g/cm^3)': np.float64(0.29838180697602024), 'Density(atoms/A^3)': np.float64(4.8503570894364856e-05), 'Average_MMD': np.float64(0.26541982939507197)}


In [15]:
metric = FrechetDistance(reference_df=reference_df) 

sample_embeddings = list(stability_preprocessor_result.processed_structures)

default_args = metric._get_compute_attributes()
metric_result = metric(sample_embeddings, **default_args)

print(metric_result.metrics)

{'FrechetDistance': 38.64963877642341}


In [16]:
final_processed_structures = []

for ind in range(0, len(dist_preprocessor_result.processed_structures)): 
    combined_structure = dist_preprocessor_result.processed_structures[ind]
    for entry in stability_preprocessor_result.processed_structures[ind].properties.keys():
        combined_structure.properties[entry] = stability_preprocessor_result.processed_structures[ind].properties[entry]
    final_processed_structures.append(combined_structure)

preprocessor_result = PreprocessorResult(processed_structures=final_processed_structures,
    config={
        "stability_preprocessor_config":stability_preprocessor_result.config,
        "distribution_preprocessor_config": dist_preprocessor_result.config,
    },
    computation_time={
        "stability_preprocessor_computation_time": stability_preprocessor_result.computation_time,
        "distribution_preprocessor_computation_time": dist_preprocessor_result.computation_time,
    },
    n_input_structures=stability_preprocessor_result.n_input_structures,
    failed_indices={
        "stability_preprocessor_failed_indices": stability_preprocessor_result.failed_indices,
        "distribution_preprocessor_failed_indices": dist_preprocessor_result.failed_indices,
    },
    warnings={
        "stability_preprocessor_warnings": stability_preprocessor_result.warnings,
        "distribution_preprocessor_warnings": dist_preprocessor_result.warnings,
    },
)

In [18]:
benchmark = DistributionBenchmark(reference_df=reference_df)
benchmark_result = benchmark.evaluate(preprocessor_result.processed_structures)

In [19]:
print("JSDistance")
print(benchmark_result.evaluator_results["JSDistance"]["metric_results"]["JSDistance"].metrics)
print("Average JSDistance: " + str(benchmark_result.evaluator_results["JSDistance"]["JSDistance_value"]))
print("MMD")
print(benchmark_result.evaluator_results["MMD"]["metric_results"]["MMD"].metrics)
print("Average MMD: " + str(benchmark_result.evaluator_results["MMD"]["MMD_value"]))
print("FrechetDistance")
print(benchmark_result.evaluator_results["FrechetDistance"]["metric_results"]["FrechetDistance"].metrics)
print("Average Frechet Distance: " + str(benchmark_result.evaluator_results["FrechetDistance"]["FrechetDistance_value"]))


JSDistance
{'SpaceGroup': np.float64(0.7564856529396667), 'CrystalSystem': np.float64(0.522976610018087), 'CompositionCounts': np.float64(0.7074950131631114), 'Composition': np.float64(0.7371862430633939), 'Average_Jensen_Shannon_Distance': np.float64(0.6810358797960647)}
Average JSDistance: 0.6810358797960647
MMD
{'Volume': np.float64(0.49782917763830126), 'Density(g/cm^3)': np.float64(0.29838180697602024), 'Density(atoms/A^3)': np.float64(4.8503570894364856e-05), 'Average_MMD': np.float64(0.26541982939507197)}
Average MMD: 0.26541982939507197
FrechetDistance
{'FrechetDistance': 38.64963877642341}
Average Frechet Distance: 38.64963877642341
