In [1]:
from pymatgen.core import Composition
from pymatgen.core.periodic_table import Element
allowed_elements = [Element.from_Z(z) for z in frozenset(range(1, 84)).difference([2, 4, 6, 10, 18, 35, 43, 44, 45, 46, 54, 73, 76, 79, 80, 81, 82] +
    list(range(58, 72)) +
        list(map(lambda x: x.Z, Composition("Os, Hg, Pb, Cd, As, Sb, Te, Se, Pd, Tc, Ru, Sr, Sc, Ir, Cs, Rb, Pt, Ge, Hf, W".replace(", ", "")).elements)))]
allowed_elements = frozenset((e for e in allowed_elements if not e.is_transition_metal))
required_elements = frozenset((Element("Li"),))

In [2]:
import sys
sys.path.append("..")
from evaluation.generated_dataset import GeneratedDataset, DATA_KEYS

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [3]:
from omegaconf import OmegaConf
import pandas as pd
from scripts.cache_generated_datasets import compute_fields_and_cache
dft_datasets = {}
dataset_config = OmegaConf.load("../generated/datasets.yaml")
def load_dataset(transformations, config, dataset):
    if transformations and transformations[0] == "FlowMM":
        # Can't use FlowMM
        return
    if transformations and transformations[-1] == "DFT":
        key = tuple(transformations)
        try:
            dft_datasets[key] = GeneratedDataset.from_cache(key, dataset=dataset).data
        except FileNotFoundError:
            print(f"Dataset {dataset} with transformations {transformations} not found in cache.")
            dataset_raw = GeneratedDataset.from_transformations(
                    transformations=key,
                    dataset=dataset)
            dft_datasets[key] = compute_fields_and_cache(dataset_raw).data
        dft_datasets[key]["origin"]=[key] * len(dft_datasets[key])
    else:
        for next_transformation, next_config in config.items():
            if next_transformation not in DATA_KEYS:
                load_dataset(transformations + [next_transformation], next_config, dataset)
for dataset, config in dataset_config.items():
    load_dataset([], config, dataset)

In [4]:
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure
all_dft_data = pd.concat(dft_datasets.values(), ignore_index=True, copy=False)
print(f"Loaded {len(all_dft_data)} DFT structures from {len(dft_datasets)} datasets.")
acceptable_composition = all_dft_data.structure.map(lambda s:
    required_elements.issubset(s.composition) and frozenset(s.composition.elements).issubset(allowed_elements))
all_dft_data = all_dft_data[acceptable_composition]
print(f"Have acceptable composition: {len(all_dft_data)} entries.")

Loaded 24418 DFT structures from 16 datasets.
Have acceptable composition: 292 entries.


In [5]:
reference_datasets = ("mp_2022", "mp_20", "mpts_52")
from itertools import chain
novelty_reference = pd.concat(chain(*
    ((GeneratedDataset.from_cache(("split", part), dataset=reference_dataset).data for part in ("train", "val", "test"))
    for reference_dataset in reference_datasets)),
    axis=0, ignore_index=True, verify_integrity=False, copy=False)

In [6]:
novelty_filter = NoveltyFilter(novelty_reference)
all_dft_data = novelty_filter.get_novel(all_dft_data)
print(f"Novel wrt {reference_datasets}: {len(all_dft_data)} entries.")

Novel wrt ('mp_2022', 'mp_20', 'mpts_52'): 274 entries.


In [7]:
all_dft_data = filter_by_unique_structure(all_dft_data)
print(f"Unique structures: {len(all_dft_data)} entries.")

Unique structures: 254 entries.


In [8]:
E_threshold = 0.08
all_dft_data = all_dft_data.loc[all_dft_data.e_above_hull_corrected < E_threshold]
print(f"Energy above hull < {E_threshold} eV: {len(all_dft_data)} entries.")

Energy above hull < 0.08 eV: 100 entries.


In [9]:
from monty.json import MontyEncoder
encoder = MontyEncoder()
def to_json(obj):
    if isinstance(obj, str):
        return obj
    if isinstance(obj, frozenset):
        obj = tuple(obj)
    return encoder.encode(obj)
export_filter = all_dft_data.filter(
                ["cdvae_crystal", "fingerprint", "composition", "naive_validity",
                 "spacegroup_number", "density"], axis=1)
all_dft_data.drop(export_filter, axis=1).map(to_json).to_csv(
    "battery_candidates.csv.gz", index=False)
print("Exported battery candidates to battery_candidates.csv.gz")

Exported battery candidates to battery_candidates.csv.gz


In [10]:
all_dft_data.origin.value_counts()

origin
(MiAD, CHGNet_free, DFT)                                61
(WyckoffTransformer, DiffCSP++10k, CHGNet_free, DFT)    20
(DiffCSP, 1k-sample, eq-V2_free, DFT)                    8
(DiffCSP, 1k-sample, CHGNet_free, DFT)                   4
(SymmCD, CHGNet_free, DFT)                               3
(DiffCSP, 1k-sample, DFT)                                2
(WyckoffTransformer-letters, DiffCSP++, DFT)             1
(WyCryst, CrySPR, CHGNet_fix, DFT)                       1
Name: count, dtype: int64