In [1]:
from typing import Tuple
import pandas as pd
from scipy.stats import pearsonr
from tqdm.notebook import tqdm
import sys
sys.path.append('..')
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure_reduced_comp_index, record_to_anonymous_fingerprint

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [2]:
dft_datasets = {
    "WyFormerDirect": ("WyckoffTransformer", "DFT"),
    "WyFormerCrySPR": ("WyckoffTransformer", "CrySPR", "CHGNet_fix", "DFT"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "DFT"),
    "WyFormerDiffCSP++10k": ("WyckoffTransformer", "DiffCSP++10k", "CHGNet_free", "DFT"),
    "WyFormerDiffCSP++10k-GGA-relax-1": ("WyckoffTransformer", "DiffCSP++10k", "CHGNet_free", "DFT-GGA-relax-1"),
    "WyFormerHarmonicDiffCSP++": ("WyckoffTransformer-harmonic", "DiffCSP++", "DFT"),
    "WyLLM-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++", "DFT"),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++", "DFT"),
    "SymmCD": ("SymmCD", "DFT"),
    "DiffCSP": ("DiffCSP", "DFT"),
    "CrystalFormer": ("CrystalFormer", "DFT"),
    "DiffCSP++": ("DiffCSP++", "DFT"),
    "FlowMM": ("FlowMM", "DFT"),
    "MatterGen": ("MatterGen", "MatterGen_10k", "DFT")
    # This dataset is corrupted
    # "MatterGen-CHGNet": ("MatterGen", "MatterGen_10k", "CHGNet_fix", "DFT")
}

source_datasets = {name: t[:-1] for name, t in dft_datasets.items()}

In [3]:
chgnet_datasets = {
    "WyFormerDirect": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormerCrySPR": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "CHGNet_fix"),
    "WyFormerDiffCSP++10k": ("WyckoffTransformer", "DiffCSP++10k", "CHGNet_free"),
    "WyFormerDiffCSP++10k-GGA-relax-1": ("WyckoffTransformer", "DiffCSP++10k", "CHGNet_free"),
    "WyLLM-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++", "CHGNet_fix"),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++", "CHGNet_fix"),
    "WyFormerHarmonicDiffCSP++": ("WyckoffTransformer-harmonic", "DiffCSP++", "CHGNet_fix"),
    "SymmCD": ("SymmCD", "CHGNet_fix"),
    "DiffCSP": ("DiffCSP", "CHGNet_fix"),
    "CrystalFormer": ("CrystalFormer", "CHGNet_fix_release"),
    "DiffCSP++": ("DiffCSP++", "CHGNet_fix_release"),
    "FlowMM": ("FlowMM", "CHGNet_fix"),
    "MatterGen": ("MatterGen", "MatterGen_10k", "CHGNet_fix")
}

In [4]:
chgnet_data = load_all_from_config(datasets=list(chgnet_datasets.values()) + [('WyckoffTransformer', 'CrySPR', 'CHGNet_fix')])

In [5]:
all_datasets = load_all_from_config(
    datasets=list(dft_datasets.values()) + list(source_datasets.values()) + \
        [("split", "train"), ("split", "val"), ("split", "test")],
    dataset_name="mp_20")

In [6]:
wycryst_transformations = ('WyCryst', 'CrySPR', 'CHGNet_fix')
source_datasets["WyCryst"] = wycryst_transformations
chgnet_datasets["WyCryst"] = wycryst_transformations
chgnet_data[wycryst_transformations] = GeneratedDataset.from_cache(wycryst_transformations, "mp_20_biternary")
dft_datasets["WyCryst"] = tuple(list(wycryst_transformations) + ["DFT"])
all_datasets[dft_datasets["WyCryst"]] = GeneratedDataset.from_cache(dft_datasets["WyCryst"], "mp_20_biternary")

In [7]:
excluded_categories = frozenset(["radioactive", "rare_earth_metal", "noble_gas"])
from pymatgen.core import Structure
def check_composition(structure: Structure) -> bool:
    for category in excluded_categories:
        if structure.composition.contains_element_type(category):
            return False
    return True

In [8]:
novelty_reference = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
novelty_filter = NoveltyFilter(novelty_reference, reference_index_type="reduced_composition")



In [9]:
import evaluation.statistical_evaluator
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)

In [10]:
train_w_template_set = frozenset(novelty_reference.apply(record_to_anonymous_fingerprint, axis=1))

In [11]:
def is_sg_preserved(relaxed_sg, transformations: Tuple[str]) -> pd.Series:
    source_sg = all_datasets[transformations[:-1]].data.spacegroup_number
    return relaxed_sg == source_sg.reindex_like(relaxed_sg)

In [16]:
train_novelty_filter = NoveltyFilter(all_datasets[('split', 'train')].data, reference_index_type="reduced_composition")



In [17]:
def get_train_based_sun(dataset: pd.DataFrame, intial_count: int):
    unique = filter_by_unique_structure_reduced_comp_index(dataset)
    print("Unique structures %", len(unique)/len(dataset) * 100)
    unique_novel = train_novelty_filter.get_novel(unique)
    print("Novel & Unique structures %", len(unique_novel)/len(dataset) * 100)
    for threshold in (0, 0.08):
        sun = ((unique_novel.e_above_hull_corrected < threshold) & unique_novel.structure.map(lambda s: len(set(s.composition)) >=2)).sum()
        print(f"S.U.N. (E<{threshold}) %: {sun/intial_count * 100:.2f})")

In [18]:
get_train_based_sun(all_datasets[dft_datasets["WyFormerDiffCSP++10k"]].data, 1e4)

Unique structures % 98.80494648238594




Novel & Unique structures % 91.14621219993765
S.U.N. (E<0) %: 3.78)
S.U.N. (E<0.08) %: 18.54)


In [19]:
get_train_based_sun(all_datasets[dft_datasets["WyFormerDiffCSP++10k-GGA-relax-1"]].data, 1e4)

Unique structures % 98.80855397148676




Novel & Unique structures % 91.17107942973523
S.U.N. (E<0) %: 4.09)
S.U.N. (E<0.08) %: 18.83)


Validity
1. Vanilla; Valid records: 2866 / 9648 = 29.71%
2. Naive; Valid records: 9492 / 9804 = 96.82%
3. Site Symmetry; Valid records: 8955 / 9709 = 92.23%

In [22]:
tables = {}
for E_hull_threshold in (0, 0.08):
    table = pd.DataFrame(
        index=dft_datasets.keys(), columns=[
            "DFT dataset size",
            "Source Novelty (%)",
            "In-DFT Novelty (%)",
            "S.U.N. (%)",
            "P1 in source (%)",
            "S.S.U.N. (%)"])
    table.index.name = "Method"

    for name, transformations in tqdm(dft_datasets.items()):
        dataset = all_datasets[transformations]
        if "corrected_e_hull" not in dataset.data.columns:
            dataset.data["corrected_e_hull"] = dataset.data.e_above_hull_corrected
        table.loc[name, "DFT dataset size"] = len(dataset.data)
        try:
            source_dataset = all_datasets[transformations[:-1]]
        except KeyError:
            source_dataset = chgnet_data[transformations[:-1]]
        chgnet_dataset = chgnet_data[chgnet_datasets[name]]

        unique = filter_by_unique_structure_reduced_comp_index(dataset.data)
        novel = novelty_filter.get_novel(unique)
        table.loc[name, "In-DFT Novelty (%)"] = 100 * len(novel) / len(unique)
        if "structure" not in source_dataset.data.columns:
            print(f"Skipping {name} as the source does not have structure column")
            continue
        source_novel = novelty_filter.get_novel(source_dataset.data)
        source_novelty = 100 * len(source_novel) / len(source_dataset.data)
        table.loc[name, "Source Novelty (%)"] = len(novel) / len(unique) * source_novelty
        table.loc[name, "P1 in source (%)"] = 100 * (source_novel.group == 1).mean()
        try:
            table.loc[name, "SG preserved (%)"] = 100 * is_sg_preserved(novel.spacegroup_number, transformations).mean()
        except KeyError:
            pass
        # source_novel_symmetric = (source_novel.group != 1).sum() / len(source_dataset.data)
        # table["Source Novel !P1 (%)"] = 100 * source_novel_symmetric
        # DFT failure == unreal structure
        if name == "WyFormerDiffCSP++10k":
            # This dataset is a bit special, as it contains 10k structures from the WyckoffTransformer
            # not filtered by the novelty filter
            dft_structures = 10000
            source_novelty = 100
        else:
            dft_structures = 105
        has_ehull = dataset.data.corrected_e_hull.notna()
        is_sun = (novel.corrected_e_hull <= E_hull_threshold) # & (novel.elements.apply(lambda x: len(frozenset(x))) >= 2)
        table.loc[name, "S.U.N. (%)"] = source_novelty * is_sun.sum() / dft_structures
        table.loc[name, "total_sun"] = is_sun.sum().astype(int)
        table.loc[name, "S.S.U.N. (%)"] = source_novelty * (is_sun & (novel.group != 1)).sum() / dft_structures
        table.loc[name, "total_ssun"] = (is_sun & (novel.group != 1)).sum().astype(int)
        table.loc[name, "P1 in stable (%)"] = 100 * (novel[is_sun].group == 1).mean()

        chgnet_unique = filter_by_unique_structure_reduced_comp_index(chgnet_dataset.data)
        chgnet_novel = novelty_filter.get_novel(chgnet_unique)
        if "corrected_chgnet_ehull" in chgnet_dataset.data.columns:
            chgnet_is_sun = (chgnet_novel.corrected_chgnet_ehull < E_hull_threshold)
            #table.loc[name, "CHGNet dataset size"] = chgnet_dataset.data.corrected_chgnet_ehull.notna().sum()
            table.loc[name, "S.U.N. (CHGNet) (%)"] =  100 * chgnet_is_sun.sum() / chgnet_dataset.data.corrected_chgnet_ehull.notna().sum()
            table.loc[name, "S.S.U.N. (CHGNet) (%)"] = 100 * (chgnet_is_sun & (chgnet_novel.group != 1)).sum() / chgnet_dataset.data.corrected_chgnet_ehull.notna().sum()
            
            chgnet_dft_available = chgnet_dataset.data.reindex(dataset.data.index[has_ehull])
            table.loc[name, "r DFT CHGNet"] = \
                pearsonr((chgnet_dft_available.corrected_chgnet_ehull < E_hull_threshold).astype(float),
                        (dataset.data.corrected_e_hull < E_hull_threshold).astype(float)).correlation
    tables[E_hull_threshold] = table

  0%|          | 0/15 [00:00<?, ?it/s]

Skipping WyFormerDirect as the source does not have structure column


  pearsonr((chgnet_dft_available.corrected_chgnet_ehull < E_hull_threshold).astype(float),


  0%|          | 0/15 [00:00<?, ?it/s]

Skipping WyFormerDirect as the source does not have structure column


  pearsonr((chgnet_dft_available.corrected_chgnet_ehull < E_hull_threshold).astype(float),


In [23]:
tables[0]

Unnamed: 0_level_0,DFT dataset size,Source Novelty (%),In-DFT Novelty (%),S.U.N. (%),P1 in source (%),S.S.U.N. (%),SG preserved (%),total_sun,total_ssun,P1 in stable (%),S.U.N. (CHGNet) (%),S.S.U.N. (CHGNet) (%),r DFT CHGNet
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
WyFormerDirect,94,,100.0,,,,,,,,,,
WyFormerCrySPR,104,89.78979,100.0,4.275704,1.560758,4.275704,96.153846,5.0,5.0,0.0,13.713714,13.613614,0.323818
WyFormerDiffCSP++,104,88.540385,99.038462,4.257143,1.565996,4.257143,95.145631,5.0,5.0,0.0,12.4,12.4,0.334818
WyFormerDiffCSP++10k,9623,80.786204,90.145141,3.45,7.120536,3.4,92.27628,345.0,340.0,1.449275,,,
WyFormerDiffCSP++10k-GGA-relax-1,9820,80.825152,90.188601,315.796493,7.120536,309.821964,93.863558,370.0,363.0,1.891892,,,
WyFormerHarmonicDiffCSP++,101,88.707921,98.019802,4.309524,2.430939,4.309524,91.919192,5.0,5.0,0.0,7.731434,7.629705,0.35976
WyLLM-DiffCSP++,102,93.171603,99.019608,4.480671,1.382979,4.480671,99.009901,5.0,5.0,0.0,7.441386,7.33945,0.778248
WyFormer-letters-DiffCSP++,104,85.609747,97.115385,3.358195,1.138952,3.358195,96.039604,4.0,4.0,0.0,8.342023,8.342023,0.52915
SymmCD,96,85.021108,95.833333,4.224651,2.224601,4.224651,92.391304,5.0,5.0,0.0,9.728601,9.707724,-0.09759
DiffCSP,104,86.425385,98.076923,6.713905,32.081253,5.874667,78.431373,8.0,7.0,12.5,17.4,14.7,0.441401


In [24]:
tables[0.08]

Unnamed: 0_level_0,DFT dataset size,Source Novelty (%),In-DFT Novelty (%),S.U.N. (%),P1 in source (%),S.S.U.N. (%),SG preserved (%),total_sun,total_ssun,P1 in stable (%),S.U.N. (CHGNet) (%),S.S.U.N. (CHGNet) (%),r DFT CHGNet
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
WyFormerDirect,94,,100.0,,,,,,,,,,
WyFormerCrySPR,104,89.78979,100.0,23.088803,1.560758,22.233662,96.153846,27.0,26.0,3.703704,39.139139,38.138138,0.664923
WyFormerDiffCSP++,104,88.540385,99.038462,22.137143,1.565996,21.285714,95.145631,26.0,25.0,3.846154,36.6,35.9,0.615387
WyFormerDiffCSP++10k,9623,80.786204,90.145141,16.7,7.120536,16.29,92.27628,1670.0,1629.0,2.45509,,,
WyFormerDiffCSP++10k-GGA-relax-1,9820,80.825152,90.188601,1447.542842,7.120536,1410.842168,93.863558,1696.0,1653.0,2.535377,,,
WyFormerHarmonicDiffCSP++,101,88.707921,98.019802,18.1,2.430939,18.1,91.919192,21.0,21.0,0.0,33.570702,32.553408,0.593771
WyLLM-DiffCSP++,102,93.171603,99.019608,10.753611,1.382979,10.753611,99.009901,12.0,12.0,0.0,30.88685,30.173293,0.592083
WyFormer-letters-DiffCSP++,104,85.609747,97.115385,15.951425,1.138952,15.951425,96.039604,19.0,19.0,0.0,29.19708,28.675704,0.548767
SymmCD,96,85.021108,95.833333,16.898605,2.224601,16.898605,92.391304,20.0,20.0,0.0,32.640919,32.160752,0.108845
DiffCSP,104,86.425385,98.076923,21.82019,32.081253,20.141714,78.431373,26.0,24.0,7.692308,56.7,39.9,0.405781


In [None]:
all_datasets[('split', 'test')].data.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1).isin(train_w_template_set).mean()

0.9713685606898077

In [None]:
from scipy.stats import ttest_ind
import numpy as np
table = tables[0.08]
def get_observation(name, column="total_ssun"):
    all_observations = np.zeros(dft_structures)
    all_observations[:int(table.at[name, column])] = table.loc[name, "Source Novelty (%)"]/100
    return all_observations

In [None]:
for second in table.index:
    print(second, ttest_ind(get_observation("WyFormerCrySPR"), get_observation(second)))

WyFormerDirect TtestResult(statistic=4.235610970337431, pvalue=3.422476353970967e-05, df=208.0)
WyFormerCrySPR TtestResult(statistic=0.0, pvalue=1.0, df=208.0)
WyFormerDiffCSP++ TtestResult(statistic=0.22189907647859575, pvalue=0.8246101615178144, df=208.0)
WyFormerDiffCSP++10k TtestResult(statistic=-15.557210817266073, pvalue=1.0401809468383434e-36, df=208.0)
WyFormerDiffCSP++10k-GGA-relax-1 TtestResult(statistic=-15.560980065912586, pvalue=1.0122567974613785e-36, df=208.0)
WyFormerHarmonicDiffCSP++ TtestResult(statistic=0.645849177863795, pvalue=0.519088910017748, df=208.0)
WyLLM-DiffCSP++ TtestResult(statistic=2.1651036448177687, pvalue=0.03151803796626064, df=208.0)
WyFormer-letters-DiffCSP++ TtestResult(statistic=1.3471852097630213, pvalue=0.1793863861847237, df=208.0)
SymmCD TtestResult(statistic=1.1936704932668818, pvalue=0.2339665193408975, df=208.0)
DiffCSP TtestResult(statistic=0.40879296725299696, pvalue=0.6831121767700197, df=208.0)
CrystalFormer TtestResult(statistic=0.487

  res = hypotest_fun_out(*samples, **kwds)


In [None]:
for second in table.index:
    print(second, ttest_ind(get_observation("WyFormerCrySPR", column="total_sun"), get_observation(second, column="total_sun")))

WyFormerDirect TtestResult(statistic=4.392884673305486, pvalue=1.781146148169839e-05, df=208.0)
WyFormerCrySPR TtestResult(statistic=0.0, pvalue=1.0, df=208.0)
WyFormerDiffCSP++ TtestResult(statistic=0.22144572797630796, pvalue=0.824962637800953, df=208.0)
WyFormerDiffCSP++10k TtestResult(statistic=-15.141723815002894, pvalue=2.094904242650081e-35, df=208.0)
WyFormerDiffCSP++10k-GGA-relax-1 TtestResult(statistic=-15.145446238867047, pvalue=2.0392501659104891e-35, df=208.0)
WyFormerHarmonicDiffCSP++ TtestResult(statistic=0.8040325161930311, pvalue=0.4222961543125353, df=208.0)
WyLLM-DiffCSP++ TtestResult(statistic=2.3228084157052336, pvalue=0.021156109729705327, df=208.0)
WyFormer-letters-DiffCSP++ TtestResult(statistic=1.507493556944736, pvalue=0.13320149847034363, df=208.0)
SymmCD TtestResult(statistic=1.3541726172991884, pvalue=0.17715066923558573, df=208.0)
DiffCSP TtestResult(statistic=0.24781409040535707, pvalue=0.8045227172507443, df=208.0)
CrystalFormer TtestResult(statistic=0.6

In [None]:
table

Unnamed: 0_level_0,DFT dataset size,Source Novelty (%),In-DFT Novelty (%),S.U.N. (%),P1 in source (%),S.S.U.N. (%),SG preserved (%),total_sun,total_ssun,P1 in stable (%),S.U.N. (CHGNet) (%),S.S.U.N. (CHGNet) (%),r DFT CHGNet
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
WyFormerDirect,94,90.09,100.0,4.29,1.964702,4.29,86.170213,5.0,5.0,0.0,39.239239,38.238238,0.269486
WyFormerCrySPR,104,89.98999,100.0,23.140283,1.557286,22.283236,96.153846,27.0,26.0,3.703704,39.239239,38.238238,0.664923
WyFormerDiffCSP++,104,88.639423,99.038462,22.161905,1.564246,21.309524,95.145631,26.0,25.0,3.846154,36.7,36.0,0.615387
WyFormerDiffCSP++10k,9623,81.537579,90.488702,17.02,7.081807,16.61,92.218351,1702.0,1661.0,2.408931,,,
WyFormerDiffCSP++10k-GGA-relax-1,9820,81.551936,90.504634,1482.062412,7.081807,1445.161032,93.832499,1727.0,1684.0,2.489867,,,
WyFormerHarmonicDiffCSP++,101,90.19802,99.009901,19.087619,2.414929,19.087619,92.0,22.0,22.0,0.0,33.97762,32.960326,0.593771
WyLLM-DiffCSP++,102,94.594595,100.0,11.711712,1.375661,11.711712,99.019608,13.0,13.0,0.0,31.090724,30.377166,0.592083
WyFormer-letters-DiffCSP++,104,85.902263,97.115385,16.005928,1.135074,16.005928,96.039604,19.0,19.0,0.0,29.405631,28.884254,0.548767
SymmCD,96,85.44591,95.833333,16.983038,2.213542,16.983038,92.391304,20.0,20.0,0.0,33.058455,32.578288,0.108845
DiffCSP,104,88.082885,98.076923,22.238667,31.566641,20.528,78.431373,26.0,24.0,7.692308,57.4,40.6,0.405781


In [None]:
ttest_ind(
    (all_datasets[('WyckoffTransformer', 'CrySPR', 'CHGNet_fix')].data.group == 1).astype(float),
    (chgnet_data[('WyckoffTransformer', "DiffCSP++", "CHGNet_fix")].data.group == 1).astype(float))

TtestResult(statistic=0.19628385802590687, pvalue=0.8444079665608555, df=1997.0)

In [None]:

for second in table.index:
    print(second, ttest_ind(get_observation("WyFormerCrySPR", column="total_sun"), get_observation(second, column="total_sun")))

WyFormerDirect TtestResult(statistic=4.392884673305486, pvalue=1.781146148169839e-05, df=208.0)
WyFormerCrySPR TtestResult(statistic=0.0, pvalue=1.0, df=208.0)
WyFormerDiffCSP++ TtestResult(statistic=0.22144572797630796, pvalue=0.824962637800953, df=208.0)
WyFormerDiffCSP++10k TtestResult(statistic=-15.141723815002894, pvalue=2.094904242650081e-35, df=208.0)
WyFormerDiffCSP++10k-GGA-relax-1 TtestResult(statistic=-15.145446238867047, pvalue=2.0392501659104891e-35, df=208.0)
WyFormerHarmonicDiffCSP++ TtestResult(statistic=0.8040325161930311, pvalue=0.4222961543125353, df=208.0)
WyLLM-DiffCSP++ TtestResult(statistic=2.3228084157052336, pvalue=0.021156109729705327, df=208.0)
WyFormer-letters-DiffCSP++ TtestResult(statistic=1.507493556944736, pvalue=0.13320149847034363, df=208.0)
SymmCD TtestResult(statistic=1.3541726172991884, pvalue=0.17715066923558573, df=208.0)
DiffCSP TtestResult(statistic=0.24781409040535707, pvalue=0.8045227172507443, df=208.0)
CrystalFormer TtestResult(statistic=0.6

  res = hypotest_fun_out(*samples, **kwds)
