In [1]:
from pymatgen.core import Structure
from datasets import load_dataset
import pandas as pd
from scipy.spatial.distance import jensenshannon
import matplotlib.pyplot as plt
import numpy as np
from lematerial_forgebench.preprocess.distribution_preprocess import DistributionPreprocessor
from lematerial_forgebench.metrics.distribution_metrics import JSDistance, MMD, FrechetDistance
from pymatgen.util.testing import PymatgenTest

%load_ext autoreload
%autoreload 2

In [2]:
reference_df_1 = pd.read_csv("lematbulk_scaled.csv")

In [60]:
# reference_df_2 = pd.read_csv("lematbulk_composition.csv")

In [39]:
# reference_df_final = pd.concat([reference_df_1, reference_df_2])

In [3]:
reference_df_density = pd.read_csv("lematbulk_density.csv")

In [4]:
reference_df_density.rename(columns = {"density":"Density(atoms/A^3)"}, inplace = True)

In [4]:
def generate_probabilities(df, show_hist = True):
    # create an empty list of space groups/crystal systems/compositions and fill in proportions/counts
    # depending on the application (as some samples will have zero of space group 1 etc) 
    
    probs = np.asarray(df.value_counts("SpaceGroup")/len(df))
    indicies = np.asarray(df.value_counts("SpaceGroup").index)
    strut_list = np.concatenate(([indicies], [probs]), axis = 0).T
    strut_list = strut_list[strut_list[:, 0].argsort()]
    # strut_list = np.flip(strut_list)
    if show_hist:     
        plt.bar(strut_list.T[0], strut_list.T[1])
        plt.show()

    return strut_list # 2d array with col1 = crystal descriptor and col2 = probability

In [3]:
dataset_name = "Lematerial/LeMat-Bulk"
name = "compatible_pbe"
split = "train"
dataset = load_dataset(dataset_name, name=name, split=split, streaming=False)

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/17 [00:00<?, ?it/s]

In [12]:
pd.DataFrame(dataset.select(range(0, 10))).columns

Index(['elements', 'nsites', 'chemical_formula_anonymous',
       'chemical_formula_reduced', 'chemical_formula_descriptive', 'nelements',
       'dimension_types', 'nperiodic_dimensions', 'lattice_vectors',
       'immutable_id', 'cartesian_site_positions', 'species',
       'species_at_sites', 'last_modified', 'elements_ratios', 'stress_tensor',
       'energy', 'magnetic_moments', 'forces', 'total_magnetization', 'dos_ef',
       'functional', 'cross_compatibility', 'entalpic_fingerprint'],
      dtype='object')

In [5]:
test = PymatgenTest()

filename = "CsBr.cif"
structure = Structure.from_file(filename)
structure = structure.remove_oxidation_states()

filename2 = "CsPbBr3.cif"
structure2 = Structure.from_file(filename2)
structure2 = structure2.remove_oxidation_states()

structures = [
    structure,
    structure2,
    test.get_structure("Si"),
    test.get_structure("LiFePO4"),
]

Use MatSciTest in pymatgen.util.testing instead.
  test = PymatgenTest()


In [6]:
distribution_preprocessor = DistributionPreprocessor()
preprocessor_result = distribution_preprocessor(structures)

In [7]:
test_df = pd.DataFrame(preprocessor_result.processed_structures, columns = ["SpaceGroup", "Volume", "Density(atoms/A^3)", "Composition", "CompositionCounts"])

In [35]:
type(test_df.Composition.iloc[0])

numpy.ndarray

In [8]:
MMD.compute_structure(test_df, reference_df_density)

{'Density(atoms/A^3)': np.float64(0.05325134883413285)}

In [13]:
FrechetDistance.compute_structure(test_df, reference_df_1)

ValueError: Input curves do not have the same dimensions.

In [36]:
test_df

Unnamed: 0,SpaceGroup,Volume,Density(atoms/A^3),Composition,CompositionCounts
0,225,392.556028,0.020379,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,62,769.207176,0.026001,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,227,40.044795,0.049944,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,14,299.607968,0.093455,"[0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0,...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ..."
