In [1]:
import sys
from typing import Tuple
import pandas as pd
from tqdm.notebook import tqdm
from pymatgen.core import Structure
sys.path.append("..")
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [2]:
datasets = {
    "WyckoffTransformer-raw": ("WyckoffTransformer",),
    "WyFormer-harmonic-raw": ("WyckoffTransformer-harmonic",),
    "WyFormer-letters": ("WyckoffTransformer-letters",),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++", "CHGNet_fix"),
    "SymmCD": ("SymmCD", "CHGNet_fix"),
    "WyFormer": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormer-harmonic-DiffCSP++": ("WyckoffTransformer-harmonic", "DiffCSP++", "CHGNet_fix"),
    "WyForDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "CHGNet_fix"),
    "WyLLM-naive-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++", "CHGNet_fix"),
    "WyLLM-vanilla-DiffCSP++": ("WyckoffLLM-vanilla", "DiffCSP++"),
    "WyLLM-site-symmetry-DiffCSP++": ("WyckoffLLM-site-symmetry", "DiffCSP++"),
    #"WyckoffTransformer-free": ("WyckoffTransformer", "CrySPR", "CHGNet_free"),
    "CrystalFormer": ("CrystalFormer", "CHGNet_fix_release"),
    #"DiffCSP++ raw": ("DiffCSP++",),
    "DiffCSP++": ("DiffCSP++", "CHGNet_fix_release"),
    "DiffCSP": ("DiffCSP", "CHGNet_fix"),
    "FlowMM": ("FlowMM", "CHGNet_fix"),
    "MiAD": ("MiAD", "CHGNet_free"),
    #"MP-20 train": ("split", "train"),
    #"MP-20 test": ("split", "test"),
}
raw_datasets = {
    "SymmCD": ("SymmCD",),
    "WyFormer": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++"),
    "WyFormer-harmonic-DiffCSP++": ("WyckoffTransformer-harmonic", "DiffCSP++"),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++"),
    "WyLLM-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++"),
    "CrystalFormer": ("CrystalFormer",),
    "DiffCSP++": ("DiffCSP++",),
    "DiffCSP": ("DiffCSP",),
    "FlowMM": ("FlowMM",),
}

In [3]:
all_datasets = load_all_from_config(
    datasets=list(datasets.values()) + list(raw_datasets.values()) + \
        [("split", "train"), ("split", "val"), ("split", "test"), ("WyckoffTransformer", "CrySPR", "CHGNet_fix")],
    dataset_name="mp_20")

In [4]:
wycryst_transformations = ('WyCryst', 'CrySPR', 'CHGNet_fix')
datasets["WyCryst"] = wycryst_transformations
raw_datasets["WyCryst"] = wycryst_transformations
all_datasets[wycryst_transformations] = GeneratedDataset.from_cache(wycryst_transformations, "mp_20_biternary")

In [5]:
novelty_reference = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
novelty_filter = NoveltyFilter(novelty_reference)

In [6]:
import evaluation.statistical_evaluator
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)

In [7]:
import evaluation.novelty
train_w_template_set = frozenset(novelty_reference.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1))

In [8]:
def is_sg_preserved(relaxed_sg, transformations: Tuple[str]) -> pd.Series:
    source_sg = all_datasets[(transformations[0],)].data.spacegroup_number
    return relaxed_sg == source_sg.reindex_like(relaxed_sg)

In [9]:
mp_20 = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'test')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
(mp_20.spacegroup_number == 1).mean()
mp_20.smact_validity.mean()

0.9057020937893829

In [10]:
from collections import Counter
from operator import itemgetter
from itertools import chain
element_counts = Counter(chain(*mp_20.elements))

In [11]:
represented_elements=frozenset(map(itemgetter(0), element_counts.most_common(30)))

In [12]:
def check_represented_composition(structure: Structure) -> bool:
    for element in structure.composition:
        if element not in represented_elements:
            return False
    return True

In [13]:
def check_double_represented_composition(structure: Structure) -> bool:
    found_first_exotic = False
    for element in structure.composition:
        if element not in represented_elements:
            if found_first_exotic:
                return False
            found_first_exotic = True
    return True

In [None]:
top_10_groups = frozenset(mp_20.spacegroup_number.value_counts().iloc[:10].index)
n_elements_dist = {}

In [35]:
from scipy.stats import ttest_ind

Validity
1. Vanilla; Valid records: 2866 / 9648 = 29.71%
2. Naive; Valid records: 9492 / 9804 = 96.82%
3. Site Symmetry; Valid records: 8955 / 9709 = 92.23%

Formal validity:
1. WyFomer: 0.9772; [WanDB](https://wandb.ai/symmetry-advantage/WyckoffTransformer/runs/yj1cme83/logs?nw=nwuserkazeev)
2. SymmCD: 0.9580

In [None]:
table = pd.DataFrame(
    index=datasets.keys(), columns=[
        "Novelty (%)",
        "Represented composition",
        "Structural", "Compositional", 
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements",
        "S.U.N. (%)",
        "R.S.U.N. (%)",
        "~R.S.U.N. (%)",
        "Top-10 S.U.N. (%)",
        "Novel Uniques Templates (#)",
        "Novel Template (%)", 
        "P1 (%)",
        "Space Group", "S.S.U.N. (%)"])
table.index.name = "Method"
E_hull_threshold = 0.08
precision_etc_sample_size = 1000
best_p1 = (novelty_filter.get_novel(filter_by_unique_structure(
    all_datasets[("WyckoffTransformer", "DiffCSP++", "CHGNet_fix")].data)).group == 1).astype(float)
for name, transformations in tqdm([("train", ("split", "train"))]+list(datasets.items())):
    data = all_datasets[transformations].data.copy()
    if len(data) > precision_etc_sample_size:
        data = data.sample(precision_etc_sample_size, random_state=42)
    unique = filter_by_unique_structure(data)
    print(f"{name} unique: {len(unique)} / {len(data)} = {len(unique) / len(data) :.3%}")
    if transformations == ("split", "train"):
        novel = unique
    else:
        novel = novelty_filter.get_novel(unique)
    table.loc[name, "Novelty (%)"] = 100 * len(novel) / len(unique)
    if "structural_validity" in novel.columns:
        table.loc[name, "Structural"] = 100 * novel.structural_validity.mean()
        table.loc[name, "Compositional"] = 100 * novel.smact_validity.mean()
    if "cdvae_crystal" in novel.columns:
        table.loc[name, "Represented composition"] = novel.structure.apply(check_represented_composition).mean()
        cov_metrics = test_evaluator.get_coverage(novel.cdvae_crystal)    
        table.loc[name, "Recall"] = 100 * cov_metrics["cov_recall"]
        table.loc[name, "Precision"] = 100 * cov_metrics["cov_precision"]
        novel = novel[novel.structural_validity]
        all_templates = novel.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1)
        novel_template = ~all_templates.isin(train_w_template_set)
        table.loc[name, "Novel Template (%)"] = 100 * novel_template.mean()
        table.loc[name, "Novel Uniques Templates (#)"] = all_templates[novel_template].nunique() 
        table.loc[name, r"$\rho$"] = test_evaluator.get_density_emd(novel)
        table.loc[name, "$E$"] = test_evaluator.get_cdvae_e_emd(novel)
        table.loc[name, "# Elements"] = test_evaluator.get_num_elements_emd(novel)
        n_elements_dist[name] = novel.elements.apply(lambda e: len(frozenset(e))).value_counts() / len(novel)
    p1 = (novel.group == 1).astype(float)
    table.loc[name, "P1 (%)"] = 100 * p1.mean()
    table.loc[name, "P1 p_value"] = ttest_ind(p1, best_p1).pvalue
    # table.loc[name, "# DoF"] = test_evaluator.get_dof_emd(novel)
    table.loc[name, "Space Group"] = test_evaluator.get_sg_chi2(novel)
    #try:
    #    table.loc[name, "SG preserved (%)"] = 100 * is_sg_preserved(novel.spacegroup_number, transformations).mean()
    #except KeyError:
    #    pass
    #table.loc[name, "Elements"] = test_evaluator.get_elements_chi2(novel)
    if "corrected_chgnet_ehull" in novel.columns:
        # S.U.N. is measured with respect to the initial structures
        has_ehull = data.corrected_chgnet_ehull.notna()
        data_is_represented = data.structure.apply(check_represented_composition)
        is_sun = (novel.corrected_chgnet_ehull <= E_hull_threshold) # & (novel.elements.apply(lambda x: len(frozenset(x))) >= 2)
        table.loc[name, "S.U.N. (%)"] = 100 * is_sun.sum() / has_ehull.sum()
        is_represented_novel = novel.structure.apply(check_represented_composition)
        table.loc[name, "R.S.U.N. (%)"] = 100 * (is_sun & is_represented_novel).sum() / (has_ehull & data_is_represented).sum()
        table.loc[name, "~R.S.U.N. (%)"] = 100 * (is_sun & ~is_represented_novel).sum() / (has_ehull & ~data_is_represented).sum()
        table.loc[name, "S.S.U.N. (%)"] = 100 * (is_sun & (novel.group != 1)).sum() / has_ehull.sum()
        has_ehull_top_10 = data[data.spacegroup_number.isin(top_10_groups)].corrected_chgnet_ehull.notna().sum()
        table.loc[name, "Top-10 S.U.N. (%)"] = 100 * is_sun[novel.spacegroup_number.isin(top_10_groups)].sum() / has_ehull_top_10
    #if transformations == ("split", "train"):
        # Train forms the baseline of the hull
      #  test_dataset = all_datasets[("split", "test")].data
      #  test_with_ehull = test_dataset[test_dataset.corrected_chgnet_ehull.notna()]
      #  test_unique = filter_by_unique_structure(test_with_ehull)
      #  test_novel = novelty_filter.get_novel(test_unique)
      #  table.loc[name, "S.U.N. (%)"] = 100 * (test_novel.corrected_chgnet_ehull <= E_hull_threshold).sum() / len(test_with_ehull)
n_elements_dist["MP-20"] = mp_20.elements.apply(lambda e: len(frozenset(e))).value_counts() / len(mp_20)
table

  0%|          | 0/18 [00:00<?, ?it/s]

train unique: 1000 / 1000 = 100.000%
WyckoffTransformer-raw unique: 999 / 1000 = 99.900%
WyFormer-harmonic-raw unique: 999 / 1000 = 99.900%
WyFormer-letters unique: 995 / 1000 = 99.500%
WyFormer-letters-DiffCSP++ unique: 958 / 959 = 99.896%
SymmCD unique: 998 / 1000 = 99.800%
WyFormer unique: 1000 / 1000 = 100.000%
WyFormer-harmonic-DiffCSP++ unique: 983 / 983 = 100.000%
WyForDiffCSP++ unique: 1000 / 1000 = 100.000%
WyLLM-naive-DiffCSP++ unique: 979 / 981 = 99.796%
WyLLM-vanilla-DiffCSP++ unique: 568 / 1000 = 56.800%
WyLLM-site-symmetry-DiffCSP++ unique: 998 / 999 = 99.900%
CrystalFormer unique: 988 / 992 = 99.597%
DiffCSP++ unique: 999 / 1000 = 99.900%
DiffCSP unique: 996 / 1000 = 99.600%
FlowMM unique: 994 / 997 = 99.699%
MiAD unique: 997 / 1000 = 99.700%
WyCryst unique: 994 / 994 = 100.000%


Unnamed: 0_level_0,Novelty (%),Represented composition,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),R.S.U.N. (%),~R.S.U.N. (%),Top-10 S.U.N. (%),Novel Uniques Templates (#),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%),P1 p_value
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
WyckoffTransformer-raw,90.990991,,,,,,,,,,,,,,,2.20022,0.207734,,0.2362964
WyFormer-harmonic-raw,91.391391,,,,,,,,,,,,,,,2.409639,0.171727,,0.1398796
WyFormer-letters,88.542714,,,,,,,,,,,,,,,1.589103,0.213655,,0.8141707
WyFormer-letters-DiffCSP++,88.204593,0.211834,99.526627,82.721893,98.106509,96.672562,0.476117,0.031184,0.111384,29.301356,29.166667,29.335072,31.643002,198.0,24.494649,1.189061,0.231082,28.779979,0.6316157
SymmCD,86.973948,0.190092,95.737327,84.677419,99.539171,94.306876,0.628383,0.095043,0.539334,32.3,31.147541,32.55814,36.798337,101.0,12.876053,2.406739,0.248896,31.4,0.1482273
WyFormer,90.0,0.2,99.555556,80.444444,98.666667,96.705726,0.735854,0.053076,0.096462,38.938939,43.455497,37.871287,44.190871,180.0,22.098214,3.236607,0.222591,38.038038,0.0125926
WyFormer-harmonic-DiffCSP++,90.742625,0.173767,99.775785,82.735426,99.215247,95.710811,0.595413,0.091496,0.03457,33.97762,30.120482,34.761322,39.834025,192.0,23.370787,2.359551,0.176036,32.960326,0.1611706
WyForDiffCSP++,89.5,0.198883,99.664804,80.335196,99.217877,96.783109,0.664005,0.050209,0.098084,36.6,38.219895,36.217553,39.314516,186.0,22.309417,1.457399,0.212014,35.9,0.9931246
WyLLM-naive-DiffCSP++,94.586313,0.185745,99.676026,82.721382,98.488121,95.556047,0.469391,0.057899,0.014727,31.090724,27.932961,31.795511,36.326531,235.0,26.760563,1.300108,0.171937,30.377166,0.7804164
WyLLM-vanilla-DiffCSP++,95.598592,0.766114,99.631676,88.766114,94.475138,59.672784,2.225215,0.233504,0.252875,,,,,87.0,28.096118,2.033272,0.621622,,0.4058305


In [51]:
import numpy as np
def get_observation(name, column, n_structures=1000):
    all_observations = np.zeros(n_structures)
    all_observations[:int(table.at[name, column])] = 1.
    return all_observations

def compare_binary_columns(name1, name2, column):
    obs1 = get_observation(name1, column)
    obs2 = get_observation(name2, column)
    return ttest_ind(obs1, obs2)

for name in datasets:
    try:
        print(name, compare_binary_columns("WyForDiffCSP++", name, "Novel Uniques Templates (#)").pvalue)
    except ValueError:
        print(f"Failed to compare {name} with WyForDiffCSP++ for Novel Uniques Templates (#)")

Failed to compare WyckoffTransformer-raw with WyForDiffCSP++ for Novel Uniques Templates (#)
Failed to compare WyFormer-harmonic-raw with WyForDiffCSP++ for Novel Uniques Templates (#)
Failed to compare WyFormer-letters with WyForDiffCSP++ for Novel Uniques Templates (#)
WyFormer-letters-DiffCSP++ 0.49595439125302077
SymmCD 5.3847860977427054e-08
WyFormer 0.7287682560720293
WyFormer-harmonic-DiffCSP++ 0.7319929638985407
WyForDiffCSP++ 1.0
WyLLM-naive-DiffCSP++ 0.007178703174329506
WyLLM-vanilla-DiffCSP++ 9.306497717605116e-11
WyLLM-site-symmetry-DiffCSP++ 0.954249351326199
CrystalFormer 6.65213364237302e-14
DiffCSP++ 1.0010008329654983e-41
DiffCSP 2.2233449165597084e-13
FlowMM 3.782926557412568e-21
MiAD 3.782926557412568e-21
WyCryst 0.21723847577045158


In [16]:
symmetry_table = table.loc[["WyFormer-letters-DiffCSP++", "WyForDiffCSP++"],
                           ["Novel Uniques Templates (#)", "P1 (%)", "Space Group", "S.U.N. (%)", "S.S.U.N. (%)"]]

In [17]:
symmetry_table.style.format("{:.1f}").highlight_max(axis=0, props="font-weight: bold")

Unnamed: 0_level_0,Novel Uniques Templates (#),P1 (%),Space Group,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
WyFormer-letters-DiffCSP++,198.0,1.2,0.2,29.3,28.8
WyForDiffCSP++,186.0,1.5,0.2,36.6,35.9


In [18]:
print(symmetry_table.to_markdown())

| Method                     |   Novel Uniques Templates (#) |   P1 (%) |   Space Group |   S.U.N. (%) |   S.S.U.N. (%) |
|:---------------------------|------------------------------:|---------:|--------------:|-------------:|---------------:|
| WyFormer-letters-DiffCSP++ |                           198 |  1.18906 |      0.231082 |      29.3014 |          28.78 |
| WyForDiffCSP++             |                           186 |  1.4574  |      0.212014 |      36.6    |          35.9  |


In [19]:
symmetry_table_symmcd = table.loc[["SymmCD", "WyForDiffCSP++"],
                           ["Novel Uniques Templates (#)", "P1 (%)", "Space Group", "S.U.N. (%)", "S.S.U.N. (%)"]]

In [20]:
symmetry_table_symmcd.style.format("{:.2f}").highlight_max(axis=0, props="font-weight: bold")

Unnamed: 0_level_0,Novel Uniques Templates (#),P1 (%),Space Group,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
SymmCD,101.0,2.41,0.25,32.3,31.4
WyForDiffCSP++,186.0,1.46,0.21,36.6,35.9


In [21]:
print(symmetry_table_symmcd.to_markdown())

| Method         |   Novel Uniques Templates (#) |   P1 (%) |   Space Group |   S.U.N. (%) |   S.S.U.N. (%) |
|:---------------|------------------------------:|---------:|--------------:|-------------:|---------------:|
| SymmCD         |                           101 |  2.40674 |      0.248896 |         32.3 |           31.4 |
| WyForDiffCSP++ |                           186 |  1.4574  |      0.212014 |         36.6 |           35.9 |


In [22]:
table.loc[["WyFormer-letters-DiffCSP++", "WyForDiffCSP++"],
          ["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", r"$\rho$", "$E$", "# Elements"]]

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
WyFormer-letters-DiffCSP++,88.204593,99.526627,82.721893,98.106509,96.672562,0.476117,0.031184,0.111384
WyForDiffCSP++,89.5,99.664804,80.335196,99.217877,96.783109,0.664005,0.050209,0.098084


In [23]:
table.loc[:, ["S.U.N. (%)", "S.S.U.N. (%)"]]

Unnamed: 0_level_0,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1
WyckoffTransformer-raw,,
WyFormer-harmonic-raw,,
WyFormer-letters,,
WyFormer-letters-DiffCSP++,29.301356,28.779979
SymmCD,32.3,31.4
WyFormer,38.938939,38.038038
WyFormer-harmonic-DiffCSP++,33.97762,32.960326
WyForDiffCSP++,36.6,35.9
WyLLM-naive-DiffCSP++,31.090724,30.377166
WyLLM-vanilla-DiffCSP++,,


In [24]:
all_datasets[('split', 'test')].data.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1).isin(train_w_template_set).mean()

0.9713685606898077

In [25]:
max_subset=["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", "S.S.U.N. (%)", "S.U.N. (%)", "Novel Template (%)"]
# -1 to exclude the MP-20 training set
def highlight_max_value(s):
    if s.name not in max_subset:
        return ['' for _ in s]
    is_max = s == s.max()
    #is_max.iloc[-1] = False
    return ['font-weight: bold' if v else '' for v in is_max]

min_subset=[r"$\rho$", "$E$", "# Elements", "# DoF", "Space Group", "Elements", "P1 (%)"]
def highlight_min_value(s):
    if s.name not in min_subset:
        return ['' for _ in s]
    is_min = s == s.min()
    #is_min.iloc[-1] = False
    return ['font-weight: bold' if v else '' for v in is_min]

In [26]:
def prettify(table):
    return table.style.format({
    "Novelty (%)": "{:.2f}",
    "Structural": "{:.2f}",
    "Compositional": "{:.2f}",
    "Recall": "{:.2f}",
    "Precision": "{:.2f}",
    r"$\rho$": "{:.2f}",
    "$E$": "{:.3f}",
    "# Elements": "{:.3f}",
    "# DoF": "{:.3f}",
    "Space Group": "{:.3f}",
    "Elements": "{:.3f}",
    "Novel Template (%)": "{:.2f}",
    "P1 (%)": "{:.2f}",
    "S.U.N. (%)": "{:.1f}",
    "S.S.U.N. (%)": "{:.1f}",
}).apply(highlight_max_value).apply(highlight_min_value)
prettify(table)

Unnamed: 0_level_0,Novelty (%),Represented composition,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),R.S.U.N. (%),~R.S.U.N. (%),Top-10 S.U.N. (%),Novel Uniques Templates (#),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
WyckoffTransformer-raw,90.99,,,,,,,,,,,,,,,2.2,0.208,
WyFormer-harmonic-raw,91.39,,,,,,,,,,,,,,,2.41,0.172,
WyFormer-letters,88.54,,,,,,,,,,,,,,,1.59,0.214,
WyFormer-letters-DiffCSP++,88.2,0.211834,99.53,82.72,98.11,96.67,0.48,0.031,0.111,29.3,29.166667,29.335072,31.643002,198.0,24.49,1.19,0.231,28.8
SymmCD,86.97,0.190092,95.74,84.68,99.54,94.31,0.63,0.095,0.539,32.3,31.147541,32.55814,36.798337,101.0,12.88,2.41,0.249,31.4
WyFormer,90.0,0.2,99.56,80.44,98.67,96.71,0.74,0.053,0.096,38.9,43.455497,37.871287,44.190871,180.0,22.1,3.24,0.223,38.0
WyFormer-harmonic-DiffCSP++,90.74,0.173767,99.78,82.74,99.22,95.71,0.6,0.091,0.035,34.0,30.120482,34.761322,39.834025,192.0,23.37,2.36,0.176,33.0
WyForDiffCSP++,89.5,0.198883,99.66,80.34,99.22,96.78,0.66,0.05,0.098,36.6,38.219895,36.217553,39.314516,186.0,22.31,1.46,0.212,35.9
WyLLM-naive-DiffCSP++,94.59,0.185745,99.68,82.72,98.49,95.56,0.47,0.058,0.015,31.1,27.932961,31.795511,36.326531,235.0,26.76,1.3,0.172,30.4
WyLLM-vanilla-DiffCSP++,95.6,0.766114,99.63,88.77,94.48,59.67,2.23,0.234,0.253,,,,,87.0,28.1,2.03,0.622,


In [27]:
LLM_columns_1 = ["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", r"$\rho$", "$E$", "# Elements"]
LLM_columns_2 = ["Novel Uniques Templates (#)", "P1 (%)", "Space Group"]

In [28]:
LLM_rows = ["WyForDiffCSP++", "WyLLM-naive-DiffCSP++", "WyLLM-vanilla-DiffCSP++", "WyLLM-site-symmetry-DiffCSP++"]

In [29]:
prettify(table.loc[["WyForDiffCSP++", "DiffCSP++", "DiffCSP", "FlowMM", "MiAD"], LLM_columns_1])

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
WyForDiffCSP++,89.5,99.66,80.34,99.22,96.78,0.66,0.05,0.098
DiffCSP++,89.69,100.0,85.04,99.33,95.79,0.15,0.036,0.504
DiffCSP,90.06,100.0,80.94,99.55,96.2,0.82,0.052,0.294
FlowMM,90.14,96.21,82.48,99.67,96.35,0.31,0.044,0.115
MiAD,84.15,99.05,84.15,99.88,94.24,0.17,0.047,0.064


In [30]:
len(GeneratedDataset.from_cache(("WyckoffLLM-naive", )).data)

9492

In [31]:
len(GeneratedDataset.from_cache(("WyckoffLLM-site-symmetry", )).data)

7476

In [32]:
len(GeneratedDataset.from_cache(("WyckoffLLM-vanilla", )).data)

2866

In [65]:
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)
cdvae_table = pd.DataFrame(index=pd.Index(datasets.keys(), tupleize_cols=False),
    columns=[
        "Structural", "Compositional",
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements"])
sample_size_for_precision = 1000
for name, transformations in tqdm(datasets.items()):
    dataset = all_datasets[transformations]
    if "structure" in dataset.data.columns:
        cdvae_table.loc[name, "Compositional"] = 100*dataset.data.smact_validity.mean()
        cdvae_table.loc[name, "Structural"] = 100*dataset.data.structural_validity.mean()
        valid = dataset.data[dataset.data.naive_validity]
        cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
        if len(valid) > sample_size_for_precision:
            valid_precision_sample = valid.sample(sample_size_for_precision, random_state=42, replace=False)        
            cov_metrics_subsampled = raw_test_evaluator.get_coverage(valid_precision_sample.cdvae_crystal)
            cdvae_table.loc[name, "Precision"] = 100*cov_metrics_subsampled["cov_precision"]
        else:
            cdvae_table.loc[name, "Precision"] = 100*cov_metrics["cov_precision"]
        cdvae_table.loc[name, "Recall"] = 100*cov_metrics["cov_recall"]
        cdvae_table.loc[name, r"$\rho$"] = raw_test_evaluator.get_density_emd(valid)
        cdvae_table.loc[name, "$E$"] = raw_test_evaluator.get_cdvae_e_emd(valid)
        cdvae_table.loc[name, "# Elements"] = raw_test_evaluator.get_num_elements_emd(valid)
prettify(cdvae_table)

  0%|          | 0/17 [00:00<?, ?it/s]

Ignoring 2 generated samples without composition fingerprints.


Unnamed: 0,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
WyckoffTransformer-raw,,,,,,,
WyFormer-harmonic-raw,,,,,,,
WyFormer-letters,,,,,,,
WyFormer-letters-DiffCSP++,99.58,84.15,98.51,96.21,0.17,0.049,0.116
SymmCD,95.49,85.86,99.19,96.05,0.32,0.095,0.392
WyFormer,99.6,81.4,98.77,95.94,0.39,0.078,0.081
WyFormer-harmonic-DiffCSP++,99.8,83.72,99.15,95.28,0.41,0.079,0.055
WyForDiffCSP++,99.7,81.4,99.26,95.85,0.33,0.07,0.078
WyLLM-naive-DiffCSP++,99.69,83.18,98.77,94.25,0.23,0.078,0.025
WyLLM-vanilla-DiffCSP++,99.8,93.0,96.55,59.81,2.01,0.339,0.284


In [66]:
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)
cdvae_table = pd.DataFrame(index=pd.Index(raw_datasets.keys(), tupleize_cols=False),
    columns=[
        "Structural", "Compositional",
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements"])
sample_size = 1000
for name, transformations in tqdm(list(raw_datasets.items())):
    dataset = all_datasets[transformations]
    if "structure" in dataset.data.columns:
        cdvae_table.loc[name, "Compositional"] = 100*dataset.data.smact_validity.mean()
        cdvae_table.loc[name, "Structural"] = 100*dataset.data.structural_validity.mean()
        valid = dataset.data[dataset.data.naive_validity]
        if len(valid) > sample_size_for_precision:
            valid_precision_sample = valid.sample(sample_size_for_precision, random_state=42, replace=False)        
            cov_metrics_subsampled = raw_test_evaluator.get_coverage(valid_precision_sample.cdvae_crystal)
            cdvae_table.loc[name, "Precision"] = 100*cov_metrics_subsampled["cov_precision"]
        else:
            cdvae_table.loc[name, "Precision"] = 100*cov_metrics["cov_precision"]
        cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
        cdvae_table.loc[name, "Recall"] = 100*cov_metrics["cov_recall"]
        cdvae_table.loc[name, r"$\rho$"] = raw_test_evaluator.get_density_emd(valid)
        cdvae_table.loc[name, "$E$"] = raw_test_evaluator.get_cdvae_e_emd(valid)
        cdvae_table.loc[name, "# Elements"] = raw_test_evaluator.get_num_elements_emd(valid)
prettify(cdvae_table)

  0%|          | 0/11 [00:00<?, ?it/s]

Unnamed: 0,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
SymmCD,100.0,86.27,99.5,94.82,0.06,0.16,0.402
WyFormer,99.6,81.4,98.77,99.57,0.39,0.078,0.081
WyFormerDiffCSP++,99.8,81.4,99.51,95.94,0.36,0.083,0.079
WyFormer-harmonic-DiffCSP++,99.8,83.7,99.52,95.81,0.2,0.084,0.049
WyFormer-letters-DiffCSP++,99.5,84.14,98.92,95.7,0.17,0.065,0.104
WyLLM-DiffCSP++,99.5,83.28,98.91,96.2,0.19,0.09,0.029
CrystalFormer,93.39,84.98,99.62,94.09,0.19,0.208,0.128
DiffCSP++,99.94,85.13,99.67,95.71,0.31,0.069,0.399
DiffCSP,100.0,83.22,99.82,96.84,0.35,0.095,0.346
FlowMM,96.87,83.11,99.73,95.59,0.12,0.073,0.094
