In [1]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.util.testing import MatSciTest, PymatgenTest
from scipy.spatial.distance import jensenshannon

from lematerial_forgebench.benchmarks.distribution_benchmark import (
    DistributionBenchmark,
)
from lematerial_forgebench.metrics.distribution_metrics import (
    MMD,
    FrechetDistance,
    JSDistance,
)
from lematerial_forgebench.preprocess.distribution_preprocess import (
    DistributionPreprocessor,
)
from lematerial_forgebench.utils.distribution_utils import (
    map_space_group_to_crystal_system,
)

# Automatically reload imported modules (e.g. edited lematerial_forgebench code)
# before each cell executes, so library changes are picked up without a kernel restart.
# NOTE(review): on Windows the recorded run shows autoreload failing with a cp1252
# UnicodeDecodeError while re-reading a library source file; starting Jupyter with
# PYTHONUTF8=1 (Python UTF-8 mode) should avoid this — confirm on the affected machine.
%load_ext autoreload
%autoreload 2

In [2]:
# Reference dataset, presumably a small LeMat-Bulk subset pickled beforehand; it is
# passed as `reference_df` to the metrics/benchmark below (likely a pandas DataFrame — confirm).
# SECURITY: pickle.load can execute arbitrary code. Only load this file if you created
# it yourself or fully trust its source.
REFERENCE_PKL_PATH = "small_lematbulk.pkl"

with open(REFERENCE_PKL_PATH, "rb") as f:
    test_lemat = pickle.load(f)

In [3]:
def load_structure(path: str) -> Structure:
    """Read a structure file and return it with oxidation states removed from its species."""
    return Structure.from_file(path).remove_oxidation_states()


# Two locally supplied CIF files plus two structures bundled with pymatgen's test helpers.
# pymatgen warns that PymatgenTest is deprecated in favour of MatSciTest; get_structure
# is called on the class directly, so no (deprecated) test-case instance is created.
structure = load_structure("CsBr.cif")
structure2 = load_structure("CsPbBr3.cif")

structures = [
    structure,
    structure2,
    MatSciTest.get_structure("Si"),
    MatSciTest.get_structure("LiFePO4"),
]

Use MatSciTest in pymatgen.util.testing instead.
  test = PymatgenTest()


In [4]:
# Run the distribution preprocessor over all structures; each processed structure
# carries a "distribution_properties" entry (inspected in the next cell).
preprocessor_result = (distribution_preprocessor := DistributionPreprocessor())(structures)

In [5]:
# Peek at the distribution properties computed for the first structure (CsBr).
first_processed_structure = preprocessor_result.processed_structures[0]
first_processed_structure.properties.get("distribution_properties")

{'Volume': 392.55602810984936,
 'Density(g/cm^3)': 3.600794626101272,
 'Density(atoms/A^3)': 0.020379256531914348,
 'SpaceGroup': 225,
 'CrystalSystem': 7,
 'CompositionCounts': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        4., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 4., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'Composition': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.

In [8]:
# Frechet-distance metric against the reference dataset loaded from the pickle above.
# NOTE(review): in the recorded run, autoreload failed to re-read distribution_metrics
# while this cell executed, so FrechetDistance may be a stale definition — restart the
# kernel before trusting results from this metric.
metric = FrechetDistance(reference_df=test_lemat)

[autoreload of lematerial_forgebench.metrics.distribution_metrics failed: Traceback (most recent call last):
  File "C:\Users\samue\AppData\Local\uv\cache\archive-v0\Lj7vvSFvlp-U62Tn9J8TQ\Lib\site-packages\IPython\extensions\autoreload.py", line 280, in check
    elif self.deduper_reloader.maybe_reload_module(m):
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\samue\AppData\Local\uv\cache\archive-v0\Lj7vvSFvlp-U62Tn9J8TQ\Lib\site-packages\IPython\extensions\deduperreload\deduperreload.py", line 533, in maybe_reload_module
    new_source_code = f.read()
                      ^^^^^^^^
  File "C:\Users\samue\AppData\Roaming\uv\python\cpython-3.11.11-windows-x86_64-none\Lib\encodings\cp1252.py", line 23, in decode
    return codecs.charmap_decode(input,self.errors,decoding_table)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
UnicodeDecodeError: 'charmap' codec can't decode byte 0x8f in position 290: character maps to <undefined>
]


In [7]:
# Compute the metric for the preprocessed structures (presumably against the reference data).
# FIXME(review): the recorded run printed the per-structure property table and then raised
# NameError: name 'reference_df' is not defined — apparently from inside compute(). Fix the
# metric implementation (or re-run after a kernel restart) so this cell completes cleanly.
metric.compute(preprocessor_result.processed_structures)

       Volume  Density(g/cm^3)  Density(atoms/A^3)  SpaceGroup  CrystalSystem  \
0  392.556028         3.600795            0.020379         225              7   
1  769.207176         5.006763            0.026001          62              3   
2   40.044795         2.329245            0.049944         227              7   
3  299.607968         3.497400            0.093455          14              2   

                                   CompositionCounts  \
0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
3  [0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0,...   

                                         Composition  
0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
3  [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...  


NameError: name 'reference_df' is not defined

In [8]:
# Distribution benchmark built on the reference dataset; its per-metric results
# (JSDistance, MMD) are inspected in the cells below.
benchmark = DistributionBenchmark(test_lemat)

In [11]:
# Evaluate the candidate data against the reference distribution.
# FIXME(review): `test_df` is not defined in any cell of this notebook — this cell only
# works with leftover kernel state and raises NameError after Restart & Run All.
# Define it explicitly in an earlier cell (presumably a DataFrame of distribution
# properties for the candidate structures — confirm the type evaluate() expects).
benchmark_result = benchmark.evaluate([test_df])

MMD_results


In [215]:
# Jensen-Shannon distance per property (space group, crystal system, composition vectors).
js_distance_by_property = benchmark_result.evaluator_results["JSDistance"]["JSDistance_value"]
js_distance_by_property

{'SpaceGroup': np.float64(0.8325546111576977),
 'CrystalSystem': np.float64(0.2788095948658411),
 'CompositionCounts': np.float64(0.6901659684588751),
 'Composition': np.float64(0.713320308611451)}

In [216]:
# Maximum mean discrepancy per property (volume and the two density measures).
mmd_by_property = benchmark_result.evaluator_results["MMD"]["MMD_value"]
mmd_by_property

{'Volume': np.float64(0.3578373686257075),
 'Density(g/cm^3)': np.float64(0.16174021104526248),
 'Density(atoms/A^3)': np.float64(0.0003819190830129937)}