In [1]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.util.testing import MatSciTest, PymatgenTest
from scipy.spatial.distance import jensenshannon

from lematerial_forgebench.benchmarks.distribution_benchmark import (
    DistributionBenchmark,
)
from lematerial_forgebench.metrics.distribution_metrics import (
    MMD,
    FrechetDistance,
    JSDistance,
)
from lematerial_forgebench.preprocess.base import PreprocessorResult
from lematerial_forgebench.preprocess.distribution_preprocess import (
    DistributionPreprocessor,
)
from lematerial_forgebench.preprocess.universal_stability_preprocess import (
    UniversalStabilityPreprocessor,
)
from lematerial_forgebench.utils.distribution_utils import (
    map_space_group_to_crystal_system,
)

# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [2]:
# Helper that exposes pymatgen's bundled example structures. pymatgen emits a
# deprecation notice for PymatgenTest recommending MatSciTest, so the
# replacement class is used here.
test = MatSciTest()

# CIF paths are resolved relative to the notebook's working directory.
# NOTE: both structures are loaded and passed through remove_oxidation_states(),
# but they are currently commented out of `structures` below, so they do not
# take part in the rest of the run unless re-enabled there.
filename = "CsBr.cif"
structure = Structure.from_file(filename)
structure = structure.remove_oxidation_states()

filename2 = "CsPbBr3.cif"
structure2 = Structure.from_file(filename2)
structure2 = structure2.remove_oxidation_states()

# Sample set fed to both preprocessors in the following cells.
structures = [
    # structure,
    # structure2,
    test.get_structure("Si"),
    test.get_structure("LiFePO4"),
]

Use MatSciTest in pymatgen.util.testing instead.
  test = PymatgenTest()


In [8]:
# Load the pickled reference data used by every metric and by the benchmark.
# NOTE: pickle.load executes arbitrary code from the file; only use it on a
# trusted, locally produced pickle.
with open("../data/full_reference_df.pkl", "rb") as reference_file:
    reference_df = pickle.load(reference_file)

In [3]:
# Run the universal stability preprocessor over the sample structures with the
# "orb" model, `relax_structures` switched off and a large timeout value.
stability_preprocessor = UniversalStabilityPreprocessor(
    model_name="orb",
    timeout=100000,
    relax_structures=False,
)
stability_preprocessor_result = stability_preprocessor(structures)



100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:08<00:00,  4.32s/it]


In [4]:
# Debugging aid (left disabled): uncomment to inspect the properties attached to
# the second processed structure returned by the stability preprocessor.
# stability_preprocessor_result.processed_structures[1].properties

In [5]:
# Run the distribution preprocessor on the same sample structures; its
# processed structures are the input to the JSDistance and MMD metric cells.
distribution_preprocessor = DistributionPreprocessor()
dist_preprocessor_result = distribution_preprocessor(structures)

100%|███████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 304.51it/s]


In [13]:
# Jensen-Shannon distance of the distribution-preprocessed structures against
# the reference data, using the metric's own default compute arguments.
metric = JSDistance(reference_df=reference_df)
default_args = metric._get_compute_attributes()
metric_result = metric(
    dist_preprocessor_result.processed_structures,
    **default_args,
)
print(metric_result.metrics)

{'SpaceGroup': np.float64(0.7564856529396667), 'CrystalSystem': np.float64(0.522976610018087), 'CompositionCounts': np.float64(0.7074950131631114), 'Composition': np.float64(0.7371862430633939), 'Average_Jensen_Shannon_Distance': np.float64(0.6810358797960647)}


In [14]:
# MMD metric of the distribution-preprocessed structures against the reference
# data, using the metric's own default compute arguments.
metric = MMD(reference_df=reference_df)
default_args = metric._get_compute_attributes()
metric_result = metric(
    dist_preprocessor_result.processed_structures,
    **default_args,
)
print(metric_result.metrics)

{'Volume': np.float64(0.49782917763830126), 'Density(g/cm^3)': np.float64(0.29838180697602024), 'Density(atoms/A^3)': np.float64(4.8503570894364856e-05), 'Average_MMD': np.float64(0.26541982939507197)}


In [15]:
# The Frechet distance is evaluated on the stability-preprocessed structures
# (copied into a fresh list), not on the distribution-preprocessed ones.
sample_embeddings = list(stability_preprocessor_result.processed_structures)

metric = FrechetDistance(reference_df=reference_df)
default_args = metric._get_compute_attributes()
metric_result = metric(sample_embeddings, **default_args)

print(metric_result.metrics)

{'FrechetDistance': 38.64963877642341}


In [16]:
# Merge the two preprocessing passes: every distribution-preprocessed structure
# additionally receives the properties computed by the stability preprocessor.
dist_structures = dist_preprocessor_result.processed_structures
stability_structures = stability_preprocessor_result.processed_structures

# The two lists are paired purely by position, so they must have the same
# length; otherwise properties would be attached to the wrong structure or the
# pairing would run past the end of the shorter list.
# NOTE(review): equal length is necessary but not sufficient if the two
# preprocessors failed on different inputs - compare their failed_indices when
# either one is non-empty.
if len(dist_structures) != len(stability_structures):
    raise ValueError(
        "Cannot merge preprocessor outputs: "
        f"{len(dist_structures)} distribution structures vs "
        f"{len(stability_structures)} stability structures."
    )

final_processed_structures = []
for combined_structure, stability_structure in zip(
    dist_structures, stability_structures
):
    # `combined_structure` is the very object held by dist_preprocessor_result,
    # so its properties are extended in place (no copy is made).
    combined_structure.properties.update(stability_structure.properties)
    final_processed_structures.append(combined_structure)

# Bundle the merged structures with the bookkeeping of both preprocessors.
# NOTE(review): config, computation_time, failed_indices and warnings are passed
# as dicts keyed by preprocessor - confirm PreprocessorResult accepts that shape.
preprocessor_result = PreprocessorResult(
    processed_structures=final_processed_structures,
    config={
        "stability_preprocessor_config": stability_preprocessor_result.config,
        "distribution_preprocessor_config": dist_preprocessor_result.config,
    },
    computation_time={
        "stability_preprocessor_computation_time": stability_preprocessor_result.computation_time,
        "distribution_preprocessor_computation_time": dist_preprocessor_result.computation_time,
    },
    n_input_structures=stability_preprocessor_result.n_input_structures,
    failed_indices={
        "stability_preprocessor_failed_indices": stability_preprocessor_result.failed_indices,
        "distribution_preprocessor_failed_indices": dist_preprocessor_result.failed_indices,
    },
    warnings={
        "stability_preprocessor_warnings": stability_preprocessor_result.warnings,
        "distribution_preprocessor_warnings": dist_preprocessor_result.warnings,
    },
)

In [18]:
# Evaluate the full distribution benchmark on the merged structures, which
# carry the properties from both preprocessors.
benchmark = DistributionBenchmark(reference_df=reference_df)
benchmark_result = benchmark.evaluate(
    preprocessor_result.processed_structures,
)

In [19]:
# For each evaluator print, in order: its name, the per-property metric values,
# and the aggregate value. The aggregate labels are spelled out because they
# are not derived uniformly from the evaluator key.
for metric_key, average_label in (
    ("JSDistance", "Average JSDistance: "),
    ("MMD", "Average MMD: "),
    ("FrechetDistance", "Average Frechet Distance: "),
):
    evaluator_result = benchmark_result.evaluator_results[metric_key]
    print(metric_key)
    print(evaluator_result["metric_results"][metric_key].metrics)
    print(average_label + str(evaluator_result[metric_key + "_value"]))


JSDistance
{'SpaceGroup': np.float64(0.7564856529396667), 'CrystalSystem': np.float64(0.522976610018087), 'CompositionCounts': np.float64(0.7074950131631114), 'Composition': np.float64(0.7371862430633939), 'Average_Jensen_Shannon_Distance': np.float64(0.6810358797960647)}
Average JSDistance: 0.6810358797960647
MMD
{'Volume': np.float64(0.49782917763830126), 'Density(g/cm^3)': np.float64(0.29838180697602024), 'Density(atoms/A^3)': np.float64(4.8503570894364856e-05), 'Average_MMD': np.float64(0.26541982939507197)}
Average MMD: 0.26541982939507197
FrechetDistance
{'FrechetDistance': 38.64963877642341}
Average Frechet Distance: 38.64963877642341
