# deCIFer: Validity Assessment of datasets

In [2]:
# Standard library imports
import os
import sys
import argparse
import multiprocessing as mp
from queue import Empty
from queue import Empty
from glob import glob
import pickle
import gzip
from typing import Any, Dict, Optional, Tuple
from warnings import warn

# Third-party library imports
import torch
import numpy as np
from tqdm.notebook import tqdm
from pymatgen.io.cif import CifParser
from pymatgen.analysis.structure_matcher import StructureMatcher

# Pick whichever string-based CifParser constructor this pymatgen install
# exposes: `from_str` is tried first, `from_string` is the fallback.
if hasattr(CifParser, "from_str"):
    parser_from_string = CifParser.from_str
else:
    parser_from_string = CifParser.from_string

from decifer.decifer_model import Decifer, DeciferConfig
from decifer.decifer_dataset import DeciferDataset
from decifer.tokenizer import Tokenizer
from bin.evaluate import load_model_from_checkpoint, get_cif_statistics, safe_extract, safe_extract_boolean
from decifer.utility import (
    get_rmsd,
    replace_symmetry_loop_with_P1,
    extract_space_group_symbol,
    reinstate_symmetry_loop,
    is_sensible,
    extract_numeric_property,
    get_unit_cell_volume,
    extract_volume,
    is_space_group_consistent,
    is_atom_site_multiplicity_consistent,
    is_formula_consistent,
    bond_length_reasonableness_score,
    extract_species,
    discrete_to_continuous_xrd,
    generate_continuous_xrd_from_cif,
)
from bin.train import TrainConfig

# Shared tokenizer instance plus the handful of values the notebook reads from it.
TOKENIZER = Tokenizer()
VOCAB_SIZE, PADDING_ID, DECODE = (
    TOKENIZER.vocab_size,
    TOKENIZER.padding_id,
    TOKENIZER.decode,
)
# IDs of the start-of-CIF, newline and space-group-name tokens.
START_ID, NEWLINE_ID, SPACEGROUP_ID = (
    TOKENIZER.token_to_id[token]
    for token in ("data_", "\n", "_symmetry_space_group_name_H-M")
)


In [1]:
from tqdm.auto import tqdm
from multiprocessing import Pool

# Define dataset path and initialize dataset
# NOTE(review): the '.h' extension is unusual -- confirm it is not meant to be '.h5'.
dataset_path = '../data/chili100k/full/serialized/train.h'
# The second argument is presumably the list of per-sample fields to load
# (confirm against DeciferDataset); of these, process_data reads only 'cif_string'.
dataset = DeciferDataset(dataset_path, ["cif_name", "cif_tokens", "xrd.q", "xrd.iq", "cif_string", "spacegroup"])

# Maximum number of samples drawn from the dataset by dataset_generator;
# also used as the progress bar's total.
num_generations = 5000
pbar = tqdm(total=num_generations)

# Helper function to process a single data point
def process_data(data):
    """Run the four validity checks on one sample's CIF string.

    Args:
        data: Mapping with a 'cif_string' entry holding the CIF text.

    Returns:
        Tuple ``(form, sm, bl, sg, valid)``: the formula, site-multiplicity,
        bond-length and space-group check results, followed by their
        conjunction. If any check raises, every entry is ``False``.

    Raises:
        KeyError: If ``data`` has no 'cif_string' entry (looked up outside
            the guarded section).
    """
    cif = data['cif_string']
    all_failed = (False, False, False, False, False)

    try:
        spacegroup_ok = is_space_group_consistent(cif)
        formula_ok = is_formula_consistent(cif)
        multiplicity_ok = is_atom_site_multiplicity_consistent(cif)
        # The bond-length check passes only when the score reaches 1.0.
        bonds_ok = bond_length_reasonableness_score(cif) >= 1.0
        overall_ok = formula_ok and multiplicity_ok and bonds_ok and spacegroup_ok
    except ZeroDivisionError:
        # Treated as an invalid sample without logging.
        return all_failed
    except Exception as e:
        # Any other failure is reported and the sample is counted as invalid.
        print(f"Error processing data: {e}")
        return all_failed

    return formula_ok, multiplicity_ok, bonds_ok, spacegroup_ok, overall_ok

# Wrap dataset generation for multiprocessing compatibility
def dataset_generator():
    """Yield at most ``num_generations`` samples from the module-level ``dataset``.

    Stops early if the dataset is exhausted first. No sample beyond the
    limit is fetched from the dataset.
    """
    samples = iter(dataset)
    for _ in range(num_generations):
        try:
            sample = next(samples)
        except StopIteration:
            return
        yield sample

# Run the validity checks in parallel worker processes.
num_cores = 7
results = []
try:
    with Pool(processes=num_cores) as pool:
        # Results arrive in completion order, which is fine because only
        # aggregate counts are reported below.
        for result in pool.imap_unordered(process_data, dataset_generator(), chunksize=10):
            results.append(result)
            pbar.update(1)
finally:
    # Close the progress bar even if a worker or the dataset iterator fails.
    pbar.close()

# The dataset may hold fewer samples than requested, so percentages are based
# on the number of samples actually processed rather than on num_generations.
num_processed = len(results)
if num_processed != num_generations:
    print(f"Note: processed {num_processed} of {num_generations} requested samples.")

# Unpack results
if results:
    forms, sms, bls, sgs, valids = zip(*results)
else:
    # Nothing was processed: keep the names defined instead of failing on unpack.
    forms = sms = bls = sgs = valids = ()

# Calculate percentages (the denominator is clamped to 1 so an empty run reports 0%).
denominator = max(num_processed, 1)
form_percent = (sum(forms) / denominator) * 100
sm_percent = (sum(sms) / denominator) * 100
bl_percent = (sum(bls) / denominator) * 100
sg_percent = (sum(sgs) / denominator) * 100
valid_percent = (sum(valids) / denominator) * 100

# Print summarized results
print("Summary of Validity Checks:")
print(f"Formulas: {form_percent:.2f}% valid")
print(f"Site Multiplicities: {sm_percent:.2f}% valid")
print(f"Bond Lengths: {bl_percent:.2f}% valid")
print(f"Spacegroups: {sg_percent:.2f}% valid")
print(f"Overall Valid: {valid_percent:.2f}%")
