In [1]:
import sys
from typing import Tuple
import pandas as pd
from tqdm.notebook import tqdm
from pymatgen.core import Structure
sys.path.append("..")
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [2]:
datasets = {
    "WyckoffTransformer-raw": ("WyckoffTransformer",),
    "WyFormer-harmonic-raw": ("WyckoffTransformer-harmonic",),
    "WyFormer-letters": ("WyckoffTransformer-letters",),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++", "CHGNet_fix"),
    "SymmCD": ("SymmCD", "CHGNet_fix"),
    "SymmCD-raw": ("SymmCD",),
    "WyFormer": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormer-harmonic-DiffCSP++": ("WyckoffTransformer-harmonic", "DiffCSP++", "CHGNet_fix"),
    "WyForDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "CHGNet_fix"),
    "WyLLM-naive-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++", "CHGNet_fix"),
    "WyLLM-vanilla-DiffCSP++": ("WyckoffLLM-vanilla", "DiffCSP++"),
    "WyLLM-site-symmetry-DiffCSP++": ("WyckoffLLM-site-symmetry", "DiffCSP++"),
    #"WyckoffTransformer-free": ("WyckoffTransformer", "CrySPR", "CHGNet_free"),
    "CrystalFormer": ("CrystalFormer", "CHGNet_fix_release"),
    #"DiffCSP++ raw": ("DiffCSP++",),
    "DiffCSP++": ("DiffCSP++", "CHGNet_fix_release"),
    "DiffCSP": ("DiffCSP", "CHGNet_fix"),
    "FlowMM": ("FlowMM", "CHGNet_fix"),
    "MiAD": ("MiAD", "CHGNet_free"),
    "MatterGen": ("MatterGen", "MatterGen_10k", "CHGNet_fix")
    #"MP-20 train": ("split", "train"),
    #"MP-20 test": ("split", "test"),
}
raw_datasets = {
    "SymmCD": ("SymmCD",),
    "WyFormer": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++"),
    "WyFormer-harmonic-DiffCSP++": ("WyckoffTransformer-harmonic", "DiffCSP++"),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++"),
    "WyLLM-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++"),
    "CrystalFormer": ("CrystalFormer",),
    "DiffCSP++": ("DiffCSP++",),
    "DiffCSP": ("DiffCSP",),
    "FlowMM": ("FlowMM",),
    "MatterGen": ("MatterGen", "MatterGen_10k")
}

In [3]:
all_datasets = load_all_from_config(
    datasets=list(datasets.values()) + list(raw_datasets.values()) + \
        [("split", "train"), ("split", "val"), ("split", "test"), ("WyckoffTransformer", "CrySPR", "CHGNet_fix")],
    dataset_name="mp_20")

In [4]:
wycryst_transformations = ('WyCryst', 'CrySPR', 'CHGNet_fix')
datasets["WyCryst"] = wycryst_transformations
raw_datasets["WyCryst"] = wycryst_transformations
all_datasets[wycryst_transformations] = GeneratedDataset.from_cache(wycryst_transformations, "mp_20_biternary")

In [5]:
novelty_reference = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
novelty_filter = NoveltyFilter(novelty_reference)

In [6]:
import evaluation.statistical_evaluator
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)

In [7]:
import evaluation.novelty
train_w_template_set = frozenset(novelty_reference.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1))

In [8]:
def is_sg_preserved(relaxed_sg, transformations: Tuple[str]) -> pd.Series:
    source_sg = all_datasets[(transformations[0],)].data.spacegroup_number
    return relaxed_sg == source_sg.reindex_like(relaxed_sg)

In [9]:
mp_20 = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'test')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
(mp_20.spacegroup_number == 1).mean()
mp_20.smact_validity.mean()

0.9057020937893829

In [10]:
from collections import Counter
from operator import itemgetter
from itertools import chain
element_counts = Counter(chain(*mp_20.elements))

In [11]:
represented_elements=frozenset(map(itemgetter(0), element_counts.most_common(30)))

In [12]:
def check_represented_composition(structure: Structure) -> bool:
    for element in structure.composition:
        if element not in represented_elements:
            return False
    return True

In [13]:
def check_double_represented_composition(structure: Structure) -> bool:
    found_first_exotic = False
    for element in structure.composition:
        if element not in represented_elements:
            if found_first_exotic:
                return False
            found_first_exotic = True
    return True

In [14]:
top_10_groups = frozenset(mp_20.spacegroup_number.value_counts().iloc[:10].index)
n_elements_dist = {}

In [15]:
from scipy.stats import ttest_ind

Validity
1. Vanilla; Valid records: 2866 / 9648 = 29.71%
2. Naive; Valid records: 9492 / 9804 = 96.82%
3. Site Symmetry; Valid records: 8955 / 9709 = 92.23%

Formal validity:
1. WyFomer: 0.9772; [WanDB](https://wandb.ai/symmetry-advantage/WyckoffTransformer/runs/yj1cme83/logs?nw=nwuserkazeev)
2. SymmCD: 0.9580

In [16]:
table = pd.DataFrame(
    index=datasets.keys(), columns=[
        "Novelty (%)",
        "Represented composition",
        "Structural", "Compositional", 
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements",
        "S.U.N. (%)",
        "R.S.U.N. (%)",
        "~R.S.U.N. (%)",
        "Top-10 S.U.N. (%)",
        "Novel Uniques Templates (#)",
        "Novel Template (%)", 
        "P1 (%)",
        "Space Group", "S.S.U.N. (%)"])
table.index.name = "Method"
E_hull_threshold = 0.08
precision_etc_sample_size = 1000
best_p1 = (novelty_filter.get_novel(filter_by_unique_structure(
    all_datasets[("WyckoffTransformer", "DiffCSP++", "CHGNet_fix")].data)).group == 1).astype(float)
for name, transformations in tqdm([("train", ("split", "train"))]+list(datasets.items())):
    data = all_datasets[transformations].data.copy()
    if len(data) > precision_etc_sample_size:
        data = data.sample(precision_etc_sample_size, random_state=42)
    unique = filter_by_unique_structure(data)
    print(f"{name} unique: {len(unique)} / {len(data)} = {len(unique) / len(data) :.3%}")
    if transformations == ("split", "train"):
        novel = unique
    else:
        novel = novelty_filter.get_novel(unique)
    table.loc[name, "Novelty (%)"] = 100 * len(novel) / len(unique)
    if "structural_validity" in novel.columns:
        table.loc[name, "Structural"] = 100 * novel.structural_validity.mean()
        table.loc[name, "Compositional"] = 100 * novel.smact_validity.mean()
    if "cdvae_crystal" in novel.columns:
        table.loc[name, "Represented composition"] = novel.structure.apply(check_represented_composition).mean()
        cov_metrics = test_evaluator.get_coverage(novel.cdvae_crystal)    
        table.loc[name, "Recall"] = 100 * cov_metrics["cov_recall"]
        table.loc[name, "Precision"] = 100 * cov_metrics["cov_precision"]
        novel = novel[novel.structural_validity]
        all_templates = novel.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1)
        novel_template = ~all_templates.isin(train_w_template_set)
        table.loc[name, "Novel Template (%)"] = 100 * novel_template.mean()
        table.loc[name, "Novel Uniques Templates (#)"] = all_templates[novel_template].nunique() 
        table.loc[name, r"$\rho$"] = test_evaluator.get_density_emd(novel)
        table.loc[name, "$E$"] = test_evaluator.get_cdvae_e_emd(novel)
        table.loc[name, "# Elements"] = test_evaluator.get_num_elements_emd(novel)
        n_elements_dist[name] = novel.elements.apply(lambda e: len(frozenset(e))).value_counts() / len(novel)
    p1 = (novel.group == 1).astype(float)
    table.loc[name, "P1 (%)"] = 100 * p1.mean()
    table.loc[name, "P1 p_value"] = ttest_ind(p1, best_p1).pvalue
    # table.loc[name, "# DoF"] = test_evaluator.get_dof_emd(novel)
    table.loc[name, "Space Group"] = test_evaluator.get_sg_chi2(novel)
    #try:
    #    table.loc[name, "SG preserved (%)"] = 100 * is_sg_preserved(novel.spacegroup_number, transformations).mean()
    #except KeyError:
    #    pass
    #table.loc[name, "Elements"] = test_evaluator.get_elements_chi2(novel)
    if "corrected_chgnet_ehull" in novel.columns:
        # S.U.N. is measured with respect to the initial structures
        has_ehull = data.corrected_chgnet_ehull.notna()
        data_is_represented = data.structure.apply(check_represented_composition)
        is_sun = (novel.corrected_chgnet_ehull <= E_hull_threshold) # & (novel.elements.apply(lambda x: len(frozenset(x))) >= 2)
        table.loc[name, "S.U.N. (%)"] = 100 * is_sun.sum() / has_ehull.sum()
        is_represented_novel = novel.structure.apply(check_represented_composition)
        table.loc[name, "R.S.U.N. (%)"] = 100 * (is_sun & is_represented_novel).sum() / (has_ehull & data_is_represented).sum()
        table.loc[name, "~R.S.U.N. (%)"] = 100 * (is_sun & ~is_represented_novel).sum() / (has_ehull & ~data_is_represented).sum()
        table.loc[name, "S.S.U.N. (%)"] = 100 * (is_sun & (novel.group != 1)).sum() / has_ehull.sum()
        has_ehull_top_10 = data[data.spacegroup_number.isin(top_10_groups)].corrected_chgnet_ehull.notna().sum()
        table.loc[name, "Top-10 S.U.N. (%)"] = 100 * is_sun[novel.spacegroup_number.isin(top_10_groups)].sum() / has_ehull_top_10
    #if transformations == ("split", "train"):
        # Train forms the baseline of the hull
      #  test_dataset = all_datasets[("split", "test")].data
      #  test_with_ehull = test_dataset[test_dataset.corrected_chgnet_ehull.notna()]
      #  test_unique = filter_by_unique_structure(test_with_ehull)
      #  test_novel = novelty_filter.get_novel(test_unique)
      #  table.loc[name, "S.U.N. (%)"] = 100 * (test_novel.corrected_chgnet_ehull <= E_hull_threshold).sum() / len(test_with_ehull)
n_elements_dist["MP-20"] = mp_20.elements.apply(lambda e: len(frozenset(e))).value_counts() / len(mp_20)
table

  0%|          | 0/20 [00:00<?, ?it/s]

train unique: 1000 / 1000 = 100.000%
WyckoffTransformer-raw unique: 999 / 1000 = 99.900%
WyFormer-harmonic-raw unique: 999 / 1000 = 99.900%
WyFormer-letters unique: 995 / 1000 = 99.500%
WyFormer-letters-DiffCSP++ unique: 958 / 959 = 99.896%
SymmCD unique: 998 / 1000 = 99.800%
SymmCD-raw unique: 998 / 1000 = 99.800%
WyFormer unique: 1000 / 1000 = 100.000%
WyFormer-harmonic-DiffCSP++ unique: 983 / 983 = 100.000%
WyForDiffCSP++ unique: 1000 / 1000 = 100.000%
WyLLM-naive-DiffCSP++ unique: 979 / 981 = 99.796%
WyLLM-vanilla-DiffCSP++ unique: 568 / 1000 = 56.800%
WyLLM-site-symmetry-DiffCSP++ unique: 998 / 999 = 99.900%
CrystalFormer unique: 988 / 992 = 99.597%
DiffCSP++ unique: 999 / 1000 = 99.900%
DiffCSP unique: 996 / 1000 = 99.600%
FlowMM unique: 994 / 997 = 99.699%
MiAD unique: 997 / 1000 = 99.700%
MatterGen unique: 998 / 1000 = 99.800%
WyCryst unique: 994 / 994 = 100.000%


Unnamed: 0_level_0,Novelty (%),Represented composition,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),R.S.U.N. (%),~R.S.U.N. (%),Top-10 S.U.N. (%),Novel Uniques Templates (#),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%),P1 p_value
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
WyckoffTransformer-raw,91.591592,,,,,,,,,,,,,,,2.185792,0.206719,,0.2311815
WyFormer-harmonic-raw,91.891892,,,,,,,,,,,,,,,2.396514,0.170586,,0.135446
WyFormer-letters,89.547739,,,,,,,,,,,,,,,1.571268,0.211415,,0.8145416
WyFormer-letters-DiffCSP++,89.665971,0.208382,99.534342,82.770664,98.137369,96.937873,0.442721,0.036067,0.112597,30.76121,29.166667,31.160365,33.26572,217.0,31.578947,1.169591,0.220777,30.239833,0.6224418
SymmCD,88.176353,0.190909,95.795455,84.772727,99.545455,94.450586,0.619413,0.104534,0.530028,33.5,32.786885,33.659731,37.629938,124.0,20.640569,2.372479,0.2401,32.6,0.1509405
SymmCD-raw,89.378758,0.211883,100.0,84.304933,99.663677,94.141057,0.113228,0.208906,0.447636,,,,,114.0,17.825112,1.793722,0.257048,,0.5481464
WyFormer,91.0,0.198901,99.56044,80.43956,98.681319,96.73889,0.705265,0.055357,0.087259,39.83984,43.979058,38.861386,45.020747,200.0,28.918322,3.200883,0.223609,38.938939,0.01260309
WyFormer-harmonic-DiffCSP++,91.251272,0.175028,99.777035,82.608696,99.219621,95.821358,0.605746,0.095361,0.036371,34.486267,31.325301,35.128519,40.6639,220.0,31.061453,2.346369,0.176766,33.468973,0.1564039
WyForDiffCSP++,90.5,0.19779,99.668508,80.331492,99.226519,96.838382,0.633953,0.05107,0.086851,37.5,38.743455,37.206428,40.120968,206.0,28.935698,1.441242,0.212665,36.8,0.9932012
WyLLM-naive-DiffCSP++,94.892748,0.186222,99.677072,82.77718,98.493003,95.556047,0.47299,0.05991,0.016785,31.396534,28.49162,32.044888,36.734694,259.0,33.585313,1.295896,0.171783,30.682977,0.7956636


In [17]:
import numpy as np
def get_observation(name, column, n_structures=1000):
    all_observations = np.zeros(n_structures)
    all_observations[:int(table.at[name, column])] = 1.
    return all_observations

def compare_binary_columns(name1, name2, column):
    obs1 = get_observation(name1, column)
    obs2 = get_observation(name2, column)
    return ttest_ind(obs1, obs2)

for name in datasets:
    try:
        print(name, compare_binary_columns("WyForDiffCSP++", name, "Novel Uniques Templates (#)").pvalue)
    except ValueError:
        print(f"Failed to compare {name} with WyForDiffCSP++ for Novel Uniques Templates (#)")

Failed to compare WyckoffTransformer-raw with WyForDiffCSP++ for Novel Uniques Templates (#)
Failed to compare WyFormer-harmonic-raw with WyForDiffCSP++ for Novel Uniques Templates (#)
Failed to compare WyFormer-letters with WyForDiffCSP++ for Novel Uniques Templates (#)
WyFormer-letters-DiffCSP++ 0.5471988089627766
SymmCD 7.348490119234499e-07
SymmCD-raw 1.800203042508396e-08
WyFormer 0.7388741235380769
WyFormer-harmonic-DiffCSP++ 0.4447596186015651
WyForDiffCSP++ 1.0
WyLLM-naive-DiffCSP++ 0.005008134562181904
WyLLM-vanilla-DiffCSP++ 5.980109391637373e-11
WyLLM-site-symmetry-DiffCSP++ 0.7415748830204665
CrystalFormer 1.3724399232549557e-11
DiffCSP++ 6.257135287555153e-32
DiffCSP 3.4671589005995573e-13
FlowMM 6.0806918326275166e-21
MiAD 6.0806918326275166e-21
MatterGen 1.2040434850992265e-08
WyCryst 0.31179417717750796


In [18]:
symmetry_table = table.loc[["WyFormer-letters-DiffCSP++", "WyForDiffCSP++"],
                           ["Novel Uniques Templates (#)", "P1 (%)", "Space Group", "S.U.N. (%)", "S.S.U.N. (%)"]]

In [19]:
symmetry_table.style.format("{:.1f}").highlight_max(axis=0, props="font-weight: bold")

Unnamed: 0_level_0,Novel Uniques Templates (#),P1 (%),Space Group,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
WyFormer-letters-DiffCSP++,217.0,1.2,0.2,30.8,30.2
WyForDiffCSP++,206.0,1.4,0.2,37.5,36.8


In [20]:
print(symmetry_table.to_markdown())

| Method                     |   Novel Uniques Templates (#) |   P1 (%) |   Space Group |   S.U.N. (%) |   S.S.U.N. (%) |
|:---------------------------|------------------------------:|---------:|--------------:|-------------:|---------------:|
| WyFormer-letters-DiffCSP++ |                           217 |  1.16959 |      0.220777 |      30.7612 |        30.2398 |
| WyForDiffCSP++             |                           206 |  1.44124 |      0.212665 |      37.5    |        36.8    |


In [21]:
symmetry_table_symmcd = table.loc[["SymmCD", "WyForDiffCSP++"],
                           ["Novel Uniques Templates (#)", "P1 (%)", "Space Group", "S.U.N. (%)", "S.S.U.N. (%)"]]

In [22]:
symmetry_table_symmcd.style.format("{:.2f}").highlight_max(axis=0, props="font-weight: bold")

Unnamed: 0_level_0,Novel Uniques Templates (#),P1 (%),Space Group,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
SymmCD,124.0,2.37,0.24,33.5,32.6
WyForDiffCSP++,206.0,1.44,0.21,37.5,36.8


In [23]:
print(symmetry_table_symmcd.to_markdown())

| Method         |   Novel Uniques Templates (#) |   P1 (%) |   Space Group |   S.U.N. (%) |   S.S.U.N. (%) |
|:---------------|------------------------------:|---------:|--------------:|-------------:|---------------:|
| SymmCD         |                           124 |  2.37248 |      0.2401   |         33.5 |           32.6 |
| WyForDiffCSP++ |                           206 |  1.44124 |      0.212665 |         37.5 |           36.8 |


In [24]:
table.loc[["WyFormer-letters-DiffCSP++", "WyForDiffCSP++"],
          ["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", r"$\rho$", "$E$", "# Elements"]]

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
WyFormer-letters-DiffCSP++,89.665971,99.534342,82.770664,98.137369,96.937873,0.442721,0.036067,0.112597
WyForDiffCSP++,90.5,99.668508,80.331492,99.226519,96.838382,0.633953,0.05107,0.086851


In [25]:
table.loc[:, ["S.U.N. (%)", "S.S.U.N. (%)"]]

Unnamed: 0_level_0,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1
WyckoffTransformer-raw,,
WyFormer-harmonic-raw,,
WyFormer-letters,,
WyFormer-letters-DiffCSP++,30.76121,30.239833
SymmCD,33.5,32.6
SymmCD-raw,,
WyFormer,39.83984,38.938939
WyFormer-harmonic-DiffCSP++,34.486267,33.468973
WyForDiffCSP++,37.5,36.8
WyLLM-naive-DiffCSP++,31.396534,30.682977


In [26]:
all_datasets[('split', 'test')].data.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1).isin(train_w_template_set).mean()

0.9713685606898077

In [27]:
max_subset=["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", "S.S.U.N. (%)", "S.U.N. (%)", "Novel Template (%)"]
# -1 to exclude the MP-20 training set
def highlight_max_value(s):
    if s.name not in max_subset:
        return ['' for _ in s]
    is_max = s == s.max()
    #is_max.iloc[-1] = False
    return ['font-weight: bold' if v else '' for v in is_max]

min_subset=[r"$\rho$", "$E$", "# Elements", "# DoF", "Space Group", "Elements", "P1 (%)"]
def highlight_min_value(s):
    if s.name not in min_subset:
        return ['' for _ in s]
    is_min = s == s.min()
    #is_min.iloc[-1] = False
    return ['font-weight: bold' if v else '' for v in is_min]

In [28]:
def prettify(table):
    return table.style.format({
    "Novelty (%)": "{:.2f}",
    "Structural": "{:.2f}",
    "Compositional": "{:.2f}",
    "Recall": "{:.2f}",
    "Precision": "{:.2f}",
    r"$\rho$": "{:.2f}",
    "$E$": "{:.3f}",
    "# Elements": "{:.3f}",
    "# DoF": "{:.3f}",
    "Space Group": "{:.3f}",
    "Elements": "{:.3f}",
    "Novel Template (%)": "{:.2f}",
    "P1 (%)": "{:.2f}",
    "S.U.N. (%)": "{:.1f}",
    "S.S.U.N. (%)": "{:.1f}",
}).apply(highlight_max_value).apply(highlight_min_value)
prettify(table)

Unnamed: 0_level_0,Novelty (%),Represented composition,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),R.S.U.N. (%),~R.S.U.N. (%),Top-10 S.U.N. (%),Novel Uniques Templates (#),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%),P1 p_value
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
WyckoffTransformer-raw,91.59,,,,,,,,,,,,,,,2.19,0.207,,0.231182
WyFormer-harmonic-raw,91.89,,,,,,,,,,,,,,,2.4,0.171,,0.135446
WyFormer-letters,89.55,,,,,,,,,,,,,,,1.57,0.211,,0.814542
WyFormer-letters-DiffCSP++,89.67,0.208382,99.53,82.77,98.14,96.94,0.44,0.036,0.113,30.8,29.166667,31.160365,33.26572,217.0,31.58,1.17,0.221,30.2,0.622442
SymmCD,88.18,0.190909,95.8,84.77,99.55,94.45,0.62,0.105,0.53,33.5,32.786885,33.659731,37.629938,124.0,20.64,2.37,0.24,32.6,0.150941
SymmCD-raw,89.38,0.211883,100.0,84.3,99.66,94.14,0.11,0.209,0.448,,,,,114.0,17.83,1.79,0.257,,0.548146
WyFormer,91.0,0.198901,99.56,80.44,98.68,96.74,0.71,0.055,0.087,39.8,43.979058,38.861386,45.020747,200.0,28.92,3.2,0.224,38.9,0.012603
WyFormer-harmonic-DiffCSP++,91.25,0.175028,99.78,82.61,99.22,95.82,0.61,0.095,0.036,34.5,31.325301,35.128519,40.6639,220.0,31.06,2.35,0.177,33.5,0.156404
WyForDiffCSP++,90.5,0.19779,99.67,80.33,99.23,96.84,0.63,0.051,0.087,37.5,38.743455,37.206428,40.120968,206.0,28.94,1.44,0.213,36.8,0.993201
WyLLM-naive-DiffCSP++,94.89,0.186222,99.68,82.78,98.49,95.56,0.47,0.06,0.017,31.4,28.49162,32.044888,36.734694,259.0,33.59,1.3,0.172,30.7,0.795664


In [29]:
LLM_columns_1 = ["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", r"$\rho$", "$E$", "# Elements"]
LLM_columns_2 = ["Novel Uniques Templates (#)", "P1 (%)", "Space Group"]

In [30]:
LLM_rows = ["WyForDiffCSP++", "WyLLM-naive-DiffCSP++", "WyLLM-vanilla-DiffCSP++", "WyLLM-site-symmetry-DiffCSP++"]

In [31]:
prettify(table.loc[["WyForDiffCSP++", "DiffCSP++", "DiffCSP", "FlowMM", "MiAD"], LLM_columns_1])

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
WyForDiffCSP++,90.5,99.67,80.33,99.23,96.84,0.63,0.051,0.087
DiffCSP++,90.59,100.0,85.08,99.34,95.91,0.14,0.042,0.496
DiffCSP,91.37,100.0,81.21,99.56,96.26,0.79,0.055,0.291
FlowMM,91.75,96.27,82.57,99.67,96.65,0.29,0.048,0.093
MiAD,85.56,99.06,84.17,99.88,94.24,0.2,0.056,0.066


In [32]:
len(GeneratedDataset.from_cache(("WyckoffLLM-naive", )).data)

9492

In [33]:
len(GeneratedDataset.from_cache(("WyckoffLLM-site-symmetry", )).data)

7476

In [34]:
len(GeneratedDataset.from_cache(("WyckoffLLM-vanilla", )).data)

2866

In [35]:
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)
cdvae_table = pd.DataFrame(index=pd.Index(datasets.keys(), tupleize_cols=False),
    columns=[
        "Structural", "Compositional",
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements"])
sample_size_for_precision = 1000
for name, transformations in tqdm(datasets.items()):
    dataset = all_datasets[transformations]
    if "structure" in dataset.data.columns:
        cdvae_table.loc[name, "Compositional"] = 100*dataset.data.smact_validity.mean()
        cdvae_table.loc[name, "Structural"] = 100*dataset.data.structural_validity.mean()
        valid = dataset.data[dataset.data.naive_validity]
        cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
        if len(valid) > sample_size_for_precision:
            valid_precision_sample = valid.sample(sample_size_for_precision, random_state=42, replace=False)        
            cov_metrics_subsampled = raw_test_evaluator.get_coverage(valid_precision_sample.cdvae_crystal)
            cdvae_table.loc[name, "Precision"] = 100*cov_metrics_subsampled["cov_precision"]
        else:
            cdvae_table.loc[name, "Precision"] = 100*cov_metrics["cov_precision"]
        cdvae_table.loc[name, "Recall"] = 100*cov_metrics["cov_recall"]
        cdvae_table.loc[name, r"$\rho$"] = raw_test_evaluator.get_density_emd(valid)
        cdvae_table.loc[name, "$E$"] = raw_test_evaluator.get_cdvae_e_emd(valid)
        cdvae_table.loc[name, "# Elements"] = raw_test_evaluator.get_num_elements_emd(valid)
prettify(cdvae_table)

  0%|          | 0/19 [00:00<?, ?it/s]

Ignoring 2 generated samples without composition fingerprints.


Unnamed: 0,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
WyckoffTransformer-raw,,,,,,,
WyFormer-harmonic-raw,,,,,,,
WyFormer-letters,,,,,,,
WyFormer-letters-DiffCSP++,99.58,84.15,98.51,96.21,0.17,0.049,0.116
SymmCD,95.49,85.86,99.19,96.05,0.32,0.095,0.392
SymmCD-raw,100.0,86.27,99.5,94.82,0.06,0.16,0.402
WyFormer,99.6,81.4,98.77,95.94,0.39,0.078,0.081
WyFormer-harmonic-DiffCSP++,99.8,83.72,99.15,95.28,0.41,0.079,0.055
WyForDiffCSP++,99.7,81.4,99.26,95.85,0.33,0.07,0.078
WyLLM-naive-DiffCSP++,99.69,83.18,98.77,94.25,0.23,0.078,0.025


In [36]:
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)
cdvae_table = pd.DataFrame(index=pd.Index(raw_datasets.keys(), tupleize_cols=False),
    columns=[
        "Structural", "Compositional",
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements"])
sample_size = 1000
for name, transformations in tqdm(list(raw_datasets.items())):
    dataset = all_datasets[transformations]
    if "structure" in dataset.data.columns:
        cdvae_table.loc[name, "Compositional"] = 100*dataset.data.smact_validity.mean()
        cdvae_table.loc[name, "Structural"] = 100*dataset.data.structural_validity.mean()
        valid = dataset.data[dataset.data.naive_validity]
        if len(valid) > sample_size_for_precision:
            valid_precision_sample = valid.sample(sample_size_for_precision, random_state=42, replace=False)        
            cov_metrics_subsampled = raw_test_evaluator.get_coverage(valid_precision_sample.cdvae_crystal)
            cdvae_table.loc[name, "Precision"] = 100*cov_metrics_subsampled["cov_precision"]
        else:
            cdvae_table.loc[name, "Precision"] = 100*cov_metrics["cov_precision"]
        cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
        cdvae_table.loc[name, "Recall"] = 100*cov_metrics["cov_recall"]
        cdvae_table.loc[name, r"$\rho$"] = raw_test_evaluator.get_density_emd(valid)
        cdvae_table.loc[name, "$E$"] = raw_test_evaluator.get_cdvae_e_emd(valid)
        cdvae_table.loc[name, "# Elements"] = raw_test_evaluator.get_num_elements_emd(valid)
prettify(cdvae_table)

  0%|          | 0/12 [00:00<?, ?it/s]

Unnamed: 0,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
SymmCD,100.0,86.27,99.5,94.82,0.06,0.16,0.402
WyFormer,99.6,81.4,98.77,99.57,0.39,0.078,0.081
WyFormerDiffCSP++,99.8,81.4,99.51,95.94,0.36,0.083,0.079
WyFormer-harmonic-DiffCSP++,99.8,83.7,99.52,95.81,0.2,0.084,0.049
WyFormer-letters-DiffCSP++,99.5,84.14,98.92,95.7,0.17,0.065,0.104
WyLLM-DiffCSP++,99.5,83.28,98.91,96.2,0.19,0.09,0.029
CrystalFormer,93.39,84.98,99.62,94.09,0.19,0.208,0.128
DiffCSP++,99.94,85.13,99.67,95.71,0.31,0.069,0.399
DiffCSP,100.0,83.22,99.82,96.84,0.35,0.095,0.346
FlowMM,96.87,83.11,99.73,95.59,0.12,0.073,0.094
