In [1]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.util.testing import PymatgenTest
from scipy.spatial.distance import jensenshannon

from lematerial_forgebench.benchmarks.distribution_benchmark import (
    DistributionBenchmark,
)
from lematerial_forgebench.metrics.distribution_metrics import (
    MMD,
    FrechetDistance,
    JSDistance,
)
from lematerial_forgebench.preprocess.base import PreprocessorResult
from lematerial_forgebench.preprocess.distribution_preprocess import (
    DistributionPreprocessor,
)
from lematerial_forgebench.preprocess.universal_stability_preprocess import (
    UniversalStabilityPreprocessor,
)
from lematerial_forgebench.utils.distribution_utils import (
    map_space_group_to_crystal_system,
)

%load_ext autoreload
%autoreload 2

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [2]:
with open("sample_lematbulk.pkl", "rb") as f:
    test_lemat = pickle.load(f)

In [3]:
test = PymatgenTest()

filename = "CsBr.cif"
structure = Structure.from_file(filename)
structure = structure.remove_oxidation_states()

filename2 = "CsPbBr3.cif"
structure2 = Structure.from_file(filename2)
structure2 = structure2.remove_oxidation_states()

structures = [
    # structure,
    # structure2,
    test.get_structure("Si"),
    test.get_structure("LiFePO4"),
]

Use MatSciTest in pymatgen.util.testing instead.
  test = PymatgenTest()


In [4]:
stability_preprocessor = UniversalStabilityPreprocessor(model_name="orb", timeout = 100000, relax_structures = False)
stability_preprocessor_result = stability_preprocessor(structures)



In [5]:
stability_preprocessor_result.processed_structures[1].properties

{'mlip_model': 'ORBCalculator',
 'model_config': 'orb_v3_conservative_inf_omat',
 'energy': -190.75131225585938,
 'forces': array([[-2.1145190e-03, -1.6023011e-03,  3.0304189e-05],
        [-5.6832342e-04, -5.0707202e-04,  7.0619571e-04],
        [-4.5502523e-04, -2.0675794e-03,  4.9744768e-04],
        [ 7.1775360e-04, -7.8805070e-04, -4.2188680e-05],
        [ 7.2393157e-02,  9.3152896e-02,  7.6207250e-02],
        [-7.4446805e-02, -9.4374299e-02,  7.6251946e-02],
        [ 7.0277348e-02,  9.4230048e-02, -7.2140619e-02],
        [-7.2767898e-02, -9.1961712e-02, -7.3393255e-02],
        [ 6.2961161e-02,  9.8316088e-02, -5.5881400e-02],
        [-6.5632656e-02, -9.9183455e-02, -5.1081479e-02],
        [ 6.5183312e-02,  9.7099535e-02,  5.0565090e-02],
        [-6.7757100e-02, -9.7848818e-02,  4.3986291e-02],
        [ 6.3024655e-02,  2.1918677e-02,  4.4074915e-03],
        [-1.4486998e-02,  3.1478763e-02,  6.7651063e-02],
        [-6.2403355e-02,  7.9694912e-03, -7.2695017e-02],
       

In [32]:
distribution_preprocessor = DistributionPreprocessor()
dist_preprocessor_result = distribution_preprocessor(structures)

In [10]:
metric = JSDistance(reference_df=test_lemat) 
default_args = metric._get_compute_attributes()
metric_result = metric(preprocessor_result.processed_structures, **default_args)
print(metric_result.metrics)

FileNotFoundError: [Errno 2] No such file or directory: 'data/lematbulk_composition_counts_distribution.json'

In [11]:
metric = MMD(reference_df=test_lemat) 
default_args = metric._get_compute_attributes()
metric_result = metric(preprocessor_result.processed_structures, **default_args)
print(metric_result.metrics)

{'Volume': np.float64(0.5156492331277469), 'Density(g/cm^3)': np.float64(0.378594576755264), 'Density(atoms/A^3)': np.float64(5.477059509595428e-05), 'Average_MMD': np.float64(0.29809952682603563)}


In [14]:
metric = FrechetDistance(reference_df=test_lemat) 

sample_embeddings = list(stability_preprocessor_result.processed_structures)

default_args = metric._get_compute_attributes()
metric_result = metric(sample_embeddings, **default_args)
print(metric_result.metrics)

Index(['LeMatID', 'Volume', 'Density(g/cm^3)', 'Density(atoms/A^3)',
       'SpaceGroup', 'CrystalSystem', 'Structure', 'OrbGraphEmbeddings'],
      dtype='object')
{'FrechetDistance': 40.38935887003667}


In [34]:
final_processed_structures = []

for ind in range(0, len(dist_preprocessor_result.processed_structures)): 
    combined_structure = dist_preprocessor_result.processed_structures[ind]
    for entry in stability_preprocessor_result.processed_structures[ind].properties.keys():
        combined_structure.properties[entry] = stability_preprocessor_result.processed_structures[ind].properties[entry]
    final_processed_structures.append(combined_structure)

preprocessor_result = PreprocessorResult(processed_structures=final_processed_structures,
    config={
        "stability_preprocessor_config":stability_preprocessor_result.config,
        "distribution_preprocessor_config": dist_preprocessor_result.config,
    },
    computation_time={
        "stability_preprocessor_computation_time": stability_preprocessor_result.computation_time,
        "distribution_preprocessor_computation_time": dist_preprocessor_result.computation_time,
    },
    n_input_structures=stability_preprocessor_result.n_input_structures,
    failed_indices={
        "stability_preprocessor_failed_indices": stability_preprocessor_result.failed_indices,
        "distribution_preprocessor_failed_indices": dist_preprocessor_result.failed_indices,
    },
    warnings={
        "stability_preprocessor_warnings": stability_preprocessor_result.warnings,
        "distribution_preprocessor_warnings": dist_preprocessor_result.warnings,
    },
)

In [40]:
benchmark = DistributionBenchmark(reference_df=test_lemat)
benchmark_result = benchmark.evaluate(preprocessor_result.processed_structures)

Index(['LeMatID', 'Volume', 'Density(g/cm^3)', 'Density(atoms/A^3)',
       'SpaceGroup', 'CrystalSystem', 'Structure', 'OrbGraphEmbeddings'],
      dtype='object')


In [41]:
print("JSDistance")
print(benchmark_result.evaluator_results["JSDistance"]["metric_results"]["JSDistance"].metrics)
print("Average JSDistance: " + str(benchmark_result.evaluator_results["JSDistance"]["JSDistance_value"]))
print("MMD")
print(benchmark_result.evaluator_results["MMD"]["metric_results"]["MMD"].metrics)
print("Average MMD: " + str(benchmark_result.evaluator_results["MMD"]["MMD_value"]))
print("FrechetDistance")
print(benchmark_result.evaluator_results["FrechetDistance"]["metric_results"]["FrechetDistance"].metrics)
print("Average Frechet Distance: " + str(benchmark_result.evaluator_results["FrechetDistance"]["FrechetDistance_value"]))


JSDistance
{'JSDistance': nan}
Average JSDistance: nan
MMD
{'Volume': np.float64(0.5156492331277469), 'Density(g/cm^3)': np.float64(0.378594576755264), 'Density(atoms/A^3)': np.float64(5.477059509595428e-05), 'Average_MMD': np.float64(0.29809952682603563)}
Average MMD: 0.29809952682603563
FrechetDistance
{'FrechetDistance': 40.38935887003667}
Average Frechet Distance: 40.38935887003667
