In [1]:
import os
import glob
import re
import pymatgen as mg
from pymatgen.core import Structure
from tqdm import tqdm
#from evaluate_text2material import parse_material_string, smact_validity


In [2]:
import numpy as np
import itertools
import smact
from smact.screening import pauling_test


def smact_validity(
    comp: tuple[int, ...] | tuple[str, ...],
    count: tuple[int, ...],
    use_pauling_test: bool = True,
    include_alloys: bool = True,
    include_cutoff: bool = False,
    use_element_symbol: bool = False,
) -> bool:
    """Computes SMACT validity.

    Args:
        comp: Tuple of atomic number or element names of elements in a crystal.
        count: Tuple of counts of elements in a crystal.
        use_pauling_test: Whether to use electronegativity test. That is, at least in one
            combination of oxidation states, the more positive the oxidation state of a site,
            the lower the electronegativity of the element for all pairs of sites.
        include_alloys: if True, returns True without checking charge balance or electronegativity
            if the crystal is an alloy (consisting only of metals) (default: True).
        include_cutoff: assumes valid crystal if the combination of oxidation states is more
            than 10^6 (default: False).

    Returns:
        True if the crystal is valid, False otherwise.
    """
    assert len(comp) == len(count)
    if use_element_symbol:
        elem_symbols = comp
    else:
        elem_symbols = tuple([get_element_symbol(Z=elem) for elem in comp])  # type:ignore
    space = smact.element_dictionary(elem_symbols)
    smact_elems = [e[1] for e in space.items()]
    electronegs = [e.pauling_eneg for e in smact_elems]
    ox_combos = [e.oxidation_states for e in smact_elems]
    if len(set(elem_symbols)) == 1:
        return True
    if include_alloys:
        is_metal_list = [elem_s in smact.metals for elem_s in elem_symbols]
        if all(is_metal_list):
            return True

    threshold = np.max(count)
    compositions = []
    n_comb = np.prod([len(ls) for ls in ox_combos])
    # If the number of possible combinations is big, it'd take too much time to run the smact checker
    # In this case, we assum that at least one of the combinations is valid
    if n_comb > 1e6 and include_cutoff:
        return True
    for ox_states in itertools.product(*ox_combos):
        stoichs = [(c,) for c in count]
        # Test for charge balance
        cn_e, cn_r = smact.neutral_ratios(ox_states, stoichs=stoichs, threshold=threshold)
        # Electronegativity test
        if cn_e:
            if use_pauling_test:
                try:
                    electroneg_OK = pauling_test(ox_states, electronegs)
                except TypeError:
                    # if no electronegativity data, assume it is okay
                    electroneg_OK = True
            else:
                electroneg_OK = True
            if electroneg_OK:
                for ratio in cn_r:
                    compositions.append(tuple([elem_symbols, ox_states, ratio]))
    compositions = [(i[0], i[2]) for i in compositions]
    compositions = list(set(compositions))
    if len(compositions) > 0:
        return True
    else:
        return False

In [None]:
path = "../../instruct_mat_7b_beam4_06282024"
path = "../../instruct_mat_8b_beam4_07022024"
path = "/msralaphilly2/ml-la/renqian/SFM/threedimargen/outputs/3dargenlan_v0.1_base_mp_nomad_qmdb_ddp_noniggli_layer24_head16_epoch50_warmup8000_lr1e-4_wd0.1_bs256/instructv1_mat_sample/"
# load data
files = glob.glob(os.path.join(path, "*.cif"))
print(len(files))

In [10]:
import pickle as pkl
import re
#path = "/msralaphilly2/ml-la/yinxia/wu2/backup/SFM_for_material.20240430/instruct_mat_7b_beam4_06282024.pkl"
#path = "/msralaphilly2/ml-la/yinxia/wu2/backup/SFM_for_material.20240430/instruct_mat_8b_beam4_07022024.pkl"
#path = "/msralaphilly2/ml-la/yinxia/wu2/backup/SFM_for_material.20240430/instruct_mat_7b_beam4_07082024.pkl"
path = "/msralaphilly2/ml-la/yinxia/wu2/backup/SFM_for_material.20240430/instruct_mat_8b_beam4_07092024.pkl"
#path = "/msralaphilly2/ml-la/yinxia/wu2/backup/SFM_for_material.20240430/instructv1_mat_sample.pkl"
with open(path, "rb") as f:
    data = pkl.load(f)
res = []
for i in range(len(data)):
    if len(data[i][1]) == 0:
        continue
    seq = data[i][1][0]
    # extract sequence in seq before <sg*> tag
    elements = re.findall(r"([A-Z][a-z]*)", seq)
    res.append(elements)
 

In [11]:
from collections import Counter
from tqdm import tqdm
import numpy as np

ret_mat = {}
fail, success, total = 0, 0, 0
total = len(res)
for structure in tqdm(res):
    comp, count = [], []
    temp = Counter(structure)
    for k, v in temp.items():
        comp.append(k)
        count.append(v)
    count = np.array(count)
    count = count / np.gcd.reduce(count)
    count = count.astype(int)
    try:
        if smact_validity(tuple(comp), tuple(count), use_element_symbol=True):
            k = tuple([tuple(comp), tuple(count)])
            if k not in ret_mat:
                ret_mat[k] = 0
            ret_mat[k] += 1
            success += 1
        else:
            fail += 1
        #print(f"Failed: {structure}")
    except Exception as e:
        fail += 1
print(f"Success: {success}, Fail: {fail}, Total: {total}")
print(f"Success rate: {success/total:.4f}")

100%|██████████| 20000/20000 [00:15<00:00, 1331.99it/s]

Success: 16716, Fail: 3284, Total: 20000
Success rate: 0.8358





In [None]:
ret_mat = {}
fail, success, total = 0, 0, 0

for fname in tqdm(files):
    # read cif
    total += 1
    try:
        structure = Structure.from_file(fname)
        #sg = structure.get_space_group_info()[1]
        sg = fname.split("_")[-1].split(".")[0]
        comp, count = [], []
        for k, v in structure.composition.as_dict().items():
            comp.append(k)
            count.append(int(v))
        if smact_validity(tuple(comp), tuple(count), use_element_symbol=True):
            k = tuple([tuple(comp), tuple(count)])
            v = sg
            if k not in ret_mat:
                ret_mat[k] = {}
            if v not in ret_mat[k]:
                ret_mat[k][v] = 0
            ret_mat[k][v] += 1
            success += 1
        else:
            fail += 1
    except Exception as e:
        print(f"{fname}\n{e}")
        fail += 1
print(f"success: {success}, fail: {fail}, total: {total}")
print(f"success rate: {success/total:.2f}")

In [24]:
# unique materials
unique_success = 0
for k, v in ret_mat.items():
    unique_success += v #len(v)

print(f"unique success: {unique_success}, total: {total}")
print(f"unique success rate: {unique_success/total:.2f}")

unique success: 9109, total: 20000
unique success rate: 0.46


In [25]:
print(f"unique number: {len(ret_mat.keys())}")
print(f"unique rate: {len(ret_mat.keys())/total:.2f}")

unique number: 7116
unique rate: 0.36
