In [1]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.util.testing import PymatgenTest
from scipy.spatial.distance import jensenshannon

from lematerial_forgebench.benchmarks.distribution_benchmark import (
    DistributionBenchmark,
)
from lematerial_forgebench.metrics.distribution_metrics import (
    MMD,
    FrechetDistance,
    JSDistance,
)
from lematerial_forgebench.preprocess.distribution_preprocess import (
    DistributionPreprocessor,
)
from lematerial_forgebench.preprocess.universal_stability_preprocess import (
    UniversalStabilityPreprocessor,
)
from lematerial_forgebench.utils.distribution_utils import (
    map_space_group_to_crystal_system,
)

# Reload edited lematerial_forgebench modules without restarting the kernel.
# NOTE(review): in the recorded run autoreload failed with a cp1252
# UnicodeDecodeError while re-reading lematerial_forgebench.metrics.distribution_metrics
# (see tracebacks in later cells); running Python in UTF-8 mode (PYTHONUTF8=1)
# is a likely fix on Windows — confirm.
%load_ext autoreload
%autoreload 2

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [2]:
def lematbulk_item_to_structure(item: dict) -> Structure:
    """Build a pymatgen ``Structure`` from a single LeMat-Bulk record.

    The record's species list, site positions and lattice vectors are passed
    directly to the ``Structure`` constructor, with the positions interpreted
    as Cartesian coordinates.

    Parameters
    ----------
    item : dict
        One LeMat-Bulk row providing the ``"species_at_sites"``,
        ``"cartesian_site_positions"`` and ``"lattice_vectors"`` fields.

    Returns
    -------
    Structure
        The corresponding pymatgen structure.
    """
    species = item["species_at_sites"]
    cartesian_positions = item["cartesian_site_positions"]
    lattice_vectors = item["lattice_vectors"]
    return Structure(
        lattice=lattice_vectors,
        species=species,
        coords=cartesian_positions,
        coords_are_cartesian=True,
    )

In [4]:
# Load the `compatible_pbe` config / `train` split of LeMat-Bulk via `datasets`.
# Non-streaming, so rows can be indexed directly (dataset[i]) below.
dataset_name = "Lematerial/LeMat-Bulk"
name = "compatible_pbe"
split = "train"
dataset = load_dataset(dataset_name, name=name, split=split, streaming=False)

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/17 [00:00<?, ?it/s]

In [8]:
# Convert a small slice of LeMat-Bulk rows (indices 5..9) into pymatgen Structures.
# A comprehension keeps the loop variables out of the notebook namespace.
struts = [lematbulk_item_to_structure(dataset[i]) for i in range(5, 10)]

In [9]:
# Run the ORB-based stability preprocessor on the LeMat-Bulk sample; the processed
# structures carry energy, forces, formation_energy, e_above_hull and embedding
# entries in their ``properties`` (inspected below).
stability_preprocessor = UniversalStabilityPreprocessor(model_name="orb")
stability_preprocessor_result = stability_preprocessor(struts)

orb




In [20]:
# One row per processed structure, stored as Python objects in a single column.
# Build the column from a plain list rather than np.asarray(..., dtype="object"):
# pymatgen Structures are sequences of sites, so NumPy may descend into them and
# produce a 2-D array of sites (whenever all structures share a site count),
# which does not fit a one-column frame.
lematbulk_embeddings = pd.DataFrame(
    {"OrbProcessedStructures": list(stability_preprocessor_result.processed_structures)}
)

In [32]:
# Inspect what the preprocessor attached to the first structure: model info,
# energy, forces, formation_energy, e_above_hull and node/graph embeddings.
stability_preprocessor_result.processed_structures[0].properties

{'mlip_model': 'ORBCalculator',
 'model_config': 'orb_v3_conservative_inf_omat',
 'energy': -36.79518127441406,
 'forces': array([[ 0.00265144, -0.0002274 , -0.00225908],
        [-0.00130672,  0.00199342, -0.00678664],
        [ 0.00164026, -0.00098787,  0.00512138],
        [-0.00298501, -0.00077813,  0.00392441]], dtype=float32),
 'formation_energy': 0.3305917601953136,
 'e_above_hull': np.float64(3.207070782729817),
 'node_embeddings': array([[ 1.11983396e-01,  4.06865329e-02,  2.49886602e-01, ...,
          3.33024971e-02, -8.62053912e-06, -7.56915689e-01],
        [ 3.80771548e-01,  9.24586952e-02,  3.82968992e-01, ...,
         -2.75345027e-01, -3.93321079e-06, -3.24808538e-01],
        [ 5.49154580e-01, -2.45792150e-01,  4.73581962e-02, ...,
         -2.90092260e-01,  2.34534491e-06,  6.19082525e-02],
        [ 3.11380718e-02,  8.25848207e-02, -2.67587423e-01, ...,
         -5.74522018e-01,  1.87184787e-06, -3.83741893e-02]],
       shape=(4, 256), dtype=float32),
 'graph_embed

In [30]:
# Cache the processed-structure frame to disk as a pickle.
lematbulk_embeddings.to_pickle("LeMatBulk_embeddings.pkl")

In [7]:
# First 100 rows of the loaded split.
# NOTE(review): dataset_temp is not used in the cells visible here.
dataset_temp = dataset.select(range(100))

In [36]:
# Load the previously pickled reference DataFrame.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open("small_lematbulk.pkl", "rb") as pickle_file:
    test_lemat = pickle.load(pickle_file)

[autoreload of lematerial_forgebench.metrics.distribution_metrics failed: Traceback (most recent call last):
  File "C:\Users\samue\AppData\Local\uv\cache\archive-v0\Lj7vvSFvlp-U62Tn9J8TQ\Lib\site-packages\IPython\extensions\autoreload.py", line 280, in check
    elif self.deduper_reloader.maybe_reload_module(m):
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\samue\AppData\Local\uv\cache\archive-v0\Lj7vvSFvlp-U62Tn9J8TQ\Lib\site-packages\IPython\extensions\deduperreload\deduperreload.py", line 533, in maybe_reload_module
    new_source_code = f.read()
                      ^^^^^^^^
  File "C:\Users\samue\AppData\Roaming\uv\python\cpython-3.11.11-windows-x86_64-none\Lib\encodings\cp1252.py", line 23, in decode
    return codecs.charmap_decode(input,self.errors,decoding_table)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
UnicodeDecodeError: 'charmap' codec can't decode byte 0x8f in position 290: character maps to <undefined>
]


In [39]:
# Give the graph-embedding column a model-specific (ORB) name.
test_lemat = test_lemat.rename(columns={'GraphEmbeddings': 'OrbGraphEmbeddings'})

In [12]:
# Attach the raw pymatgen Structures built from the LeMat-Bulk sample as a column.
# NOTE(review): assumes test_lemat has exactly len(struts) rows in the same order — confirm.
test_lemat["Structure"] = struts

In [40]:
# Overwrite the pickle that was loaded earlier with the renamed / augmented frame.
test_lemat.to_pickle("small_lematbulk.pkl")

In [5]:
# Build a small set of test structures.
# NOTE(review): PymatgenTest is deprecated; the warning below suggests MatSciTest.
test = PymatgenTest()

# CsBr / CsPbBr3 are read from CIF files in the working directory and stripped of
# oxidation states, but they are currently commented out of `structures`.
# NOTE(review): later outputs (four rows, the first being CsBr) were produced with
# these two entries enabled, so they are stale relative to this cell.
filename = "CsBr.cif"
structure = Structure.from_file(filename)
structure = structure.remove_oxidation_states()

filename2 = "CsPbBr3.cif"
structure2 = Structure.from_file(filename2)
structure2 = structure2.remove_oxidation_states()

structures = [
    # structure,
    # structure2,
    test.get_structure("Si"),
    test.get_structure("LiFePO4"),
]

Use MatSciTest in pymatgen.util.testing instead.
  test = PymatgenTest()


In [6]:
# Run the ORB stability preprocessor on the test structures (relaxation enabled
# per the configuration shown in the next cell's output).
stability_preprocessor = UniversalStabilityPreprocessor(model_name="orb")
stability_preprocessor_result = stability_preprocessor(structures)



In [9]:
# Show the preprocessor configuration used for this run.
stability_preprocessor_result.config

UniversalStabilityPreprocessorConfig(name='UniversalStabilityPreprocessor_orb', description='Stability preprocessing using orb', n_jobs=1, model_name='orb', model_config={}, relax_structures=True, relaxation_config={'fmax': 0.02, 'steps': 500}, calculate_formation_energy=True, calculate_energy_above_hull=True, extract_embeddings=True)

In [10]:
# Properties attached to the first relaxed test structure (energy, forces,
# formation_energy, e_above_hull, embeddings, ...).
stability_preprocessor_result.processed_structures[0].properties

{'mlip_model': 'ORBCalculator',
 'model_config': 'orb_v3_conservative_inf_omat',
 'energy': -10.82815170288086,
 'forces': array([[ 1.2829287e-03,  6.0487073e-06,  7.1092043e-04],
        [-1.2829287e-03, -6.0496968e-06, -7.1092264e-04]], dtype=float32),
 'formation_energy': 0.02191613711914009,
 'e_above_hull': np.float64(0.010958068559570044),
 'node_embeddings': array([[ 8.81311595e-02,  1.88308999e-01,  1.77247897e-01,
         -3.92862037e-02,  5.58909215e-02,  7.27849374e-06,
         -1.02837123e-01,  3.98036718e-01, -1.05413198e-01,
          3.26188594e-01, -1.70153499e-01,  3.42952132e-01,
          5.24205454e-02, -2.35818848e-01,  1.71355739e-01,
         -2.59484351e-01, -1.29346829e-03, -8.82879347e-02,
          3.65222991e-01, -2.12877154e-01, -1.16113447e-01,
         -1.75549239e-02,  2.63406128e-01, -2.71737855e-02,
         -8.90087426e-01,  6.56546056e-02, -2.52401512e-02,
          9.42403525e-02,  5.95285773e-01, -9.04264450e-02,
          2.19623432e-01, -1.8126

In [4]:
# Compute distribution descriptors (volume, densities, space group, crystal
# system, composition vectors) for each test structure.
distribution_preprocessor = DistributionPreprocessor()
preprocessor_result = distribution_preprocessor(structures)

In [5]:
# Distribution descriptors stored on the first processed structure.
preprocessor_result.processed_structures[0].properties.get("distribution_properties")

{'Volume': 392.55602810984936,
 'Density(g/cm^3)': 3.600794626101272,
 'Density(atoms/A^3)': 0.020379256531914348,
 'SpaceGroup': 225,
 'CrystalSystem': 7,
 'CompositionCounts': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        4., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 4., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'Composition': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.

In [8]:
# Frechet distance metric with the pickled LeMat-Bulk sample as reference frame.
metric = FrechetDistance(reference_df=test_lemat)

[autoreload of lematerial_forgebench.metrics.distribution_metrics failed: Traceback (most recent call last):
  File "C:\Users\samue\AppData\Local\uv\cache\archive-v0\Lj7vvSFvlp-U62Tn9J8TQ\Lib\site-packages\IPython\extensions\autoreload.py", line 280, in check
    elif self.deduper_reloader.maybe_reload_module(m):
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\samue\AppData\Local\uv\cache\archive-v0\Lj7vvSFvlp-U62Tn9J8TQ\Lib\site-packages\IPython\extensions\deduperreload\deduperreload.py", line 533, in maybe_reload_module
    new_source_code = f.read()
                      ^^^^^^^^
  File "C:\Users\samue\AppData\Roaming\uv\python\cpython-3.11.11-windows-x86_64-none\Lib\encodings\cp1252.py", line 23, in decode
    return codecs.charmap_decode(input,self.errors,decoding_table)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
UnicodeDecodeError: 'charmap' codec can't decode byte 0x8f in position 290: character maps to <undefined>
]


In [7]:
# NOTE(review): in the recorded run this raised `NameError: name 'reference_df' is
# not defined` from inside the metric implementation (the name does not appear in
# this cell); the frame printed before the error appears to be library debug output.
# This cell was also executed before the cell that defines `metric` (In [7] vs In [8]).
metric.compute(preprocessor_result.processed_structures)

       Volume  Density(g/cm^3)  Density(atoms/A^3)  SpaceGroup  CrystalSystem  \
0  392.556028         3.600795            0.020379         225              7   
1  769.207176         5.006763            0.026001          62              3   
2   40.044795         2.329245            0.049944         227              7   
3  299.607968         3.497400            0.093455          14              2   

                                   CompositionCounts  \
0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
3  [0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0,...   

                                         Composition  
0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
3  [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...  


NameError: name 'reference_df' is not defined

In [8]:
# Distribution benchmark using the LeMat-Bulk sample frame as reference.
benchmark = DistributionBenchmark(test_lemat)

In [11]:
# FIXME(review): `test_df` is not defined in any cell visible here, so this fails on a
# fresh kernel; it presumably held the descriptors of the structures under test —
# define it explicitly before this cell.
benchmark_result = benchmark.evaluate([test_df])

MMD_results


In [215]:
# Per-descriptor Jensen-Shannon distances (space group, crystal system, composition).
benchmark_result.evaluator_results["JSDistance"]["JSDistance_value"]

{'SpaceGroup': np.float64(0.8325546111576977),
 'CrystalSystem': np.float64(0.2788095948658411),
 'CompositionCounts': np.float64(0.6901659684588751),
 'Composition': np.float64(0.713320308611451)}

In [216]:
# Per-descriptor MMD values (volume and the two density measures).
benchmark_result.evaluator_results["MMD"]["MMD_value"]

{'Volume': np.float64(0.3578373686257075),
 'Density(g/cm^3)': np.float64(0.16174021104526248),
 'Density(atoms/A^3)': np.float64(0.0003819190830129937)}

In [5]:
# Rebuild the Frechet metric and re-run the stability preprocessor with its default
# model (the captured log prints "orb"). UniversalStabilityPreprocessor and
# FrechetDistance come from the import cell at the top of the notebook.
metric = FrechetDistance(reference_df=test_lemat)
stability_preprocessor = UniversalStabilityPreprocessor()
stability_preprocessor_result = stability_preprocessor(structures)

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


orb




In [9]:
# Graph embedding (one 1-D vector per structure) of the first processed structure.
stability_preprocessor_result.processed_structures[0].properties.get("graph_embedding")

array([-6.18248940e-01,  3.95354360e-01,  1.27876326e-01, -1.16611309e-01,
        2.31794804e-01,  7.32130047e-06, -2.79786468e-01,  8.81949887e-02,
       -7.43535999e-03, -1.39016643e-01,  2.98540831e-01, -4.52517152e-01,
        3.66314828e-01,  2.86402740e-02,  6.79986626e-02, -3.37073579e-02,
       -8.75898898e-02,  1.79945394e-01,  1.30843539e-02, -1.96450241e-02,
       -1.07172832e-01, -1.17588826e-01, -4.13982309e-02,  1.43906191e-01,
        9.81063366e-01,  2.12826937e-01, -1.22115538e-01, -2.34585345e-01,
        1.43253520e-01,  2.65044570e-02,  1.74727708e-01, -5.13868630e-01,
       -6.30615652e-02,  1.81849316e-01, -1.18187435e-01, -1.45289525e-01,
       -1.59209639e-01, -1.15232386e-01, -2.24942550e-01, -2.44276747e-02,
       -2.42034093e-01,  2.23043859e-01,  1.71593294e-01, -6.91397786e-02,
        7.92256668e-02, -1.93331271e-01, -1.50620952e-01, -2.10464019e-02,
       -1.16789214e-01,  1.05905093e-01, -6.28117546e-02,  1.36864364e-01,
        2.18050227e-01,  