In [1]:
import h5py
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pymatgen.core import Composition
from os import cpu_count
from matplotlib.patches import Patch
from multiprocessing import Pool
from decifer.utility import (
    extract_formula_nonreduced,
    space_group_symbol_to_number,
    element_to_atomic_number,
    replace_symmetry_loop_with_P1,
    generate_continuous_xrd_from_cif,
)

def classify_dataset(cif_name):
    """Map a CIF identifier to the name of the source database it came from.

    The source is inferred purely from the identifier's prefix. Prefixes are
    checked in the order NOMAD, OQMD, MP, and the first match wins.

    Parameters:
        cif_name (str): Identifier such as "OQMD_3484" or "MP_mp-1204627".

    Returns:
        str: 'NOMAD', 'OQMD' or 'MP' for a recognised prefix, otherwise 'CHILI-100K'.
    """
    for prefix in ('NOMAD', 'OQMD', 'MP'):
        if cif_name.startswith(prefix):
            # The label returned is the prefix itself
            return prefix
    # Anything without a known database prefix is attributed to CHILI-100K
    return 'CHILI-100K'

def process_single_index(args):
    """Read one entry from an HDF5 dataset file and summarise it as a flat record.

    Takes a single ``(i, file_path)`` tuple (the shape ``Pool.imap`` passes to
    workers) and opens the file itself on each call.

    Parameters:
        args (tuple[int, str]): Row index ``i`` and path to the ``.h5`` file.

    Returns:
        dict: A record with the keys ``cif_name``, ``cif_string``,
        ``reduced_formula``, ``spacegroup_num``, ``species``, ``atomic_numbers``,
        ``num_elements``, ``dataset`` and ``cif_token_len``.

    Raises:
        KeyError: If the file has neither a 'cif_tokens' nor a 'cif_tokenized'
            dataset (or lacks any of the other required datasets).
    """
    i, file_path = args
    with h5py.File(file_path, 'r') as f:
        cif_name = f['cif_name'][i].decode("utf-8")
        cif_string = f['cif_string'][i].decode("utf-8")
        spacegroup_symbol = f['spacegroup'][i].decode("utf-8")
        # Depending on the file, the per-entry tokens live under 'cif_tokens' or
        # 'cif_tokenized'. Pick whichever exists instead of catching *every*
        # exception (a bare except would also hide real errors and Ctrl-C).
        tokens_key = 'cif_tokens' if 'cif_tokens' in f else 'cif_tokenized'
        cif_tokens_len = len(f[tokens_key][i])

    # Everything below is pure parsing, so do it after the HDF5 handle is closed.
    comp_obj = Composition(extract_formula_nonreduced(cif_string))
    comp = comp_obj.reduced_composition.as_dict()
    red_formula = comp_obj.reduced_formula  # canonical reduced formula, used for composition lookup
    spacegroup_num = space_group_symbol_to_number(spacegroup_symbol)
    # element_to_atomic_number(e) multiplied by the element's amount n in the reduced composition
    atomic_numbers = [element_to_atomic_number(e) * n for e, n in comp.items()]
    species = list(comp.keys())

    return {
        'cif_name': cif_name,
        'cif_string': cif_string,
        'reduced_formula': red_formula,
        'spacegroup_num': spacegroup_num,
        'species': species,
        'atomic_numbers': atomic_numbers,
        'num_elements': len(species),
        'dataset': classify_dataset(cif_name),
        'cif_token_len': cif_tokens_len,
    }

def search_by_composition(df, query_formula):
    """
    Return the CIF strings of all rows whose reduced formula matches a query.

    The query is normalised to its reduced (canonical) formula with pymatgen
    before being compared against the precomputed 'reduced_formula' column.

    Parameters:
        df (pd.DataFrame): Processed dataset with 'reduced_formula' and 'cif_string' columns.
        query_formula (str): A composition formula to search for (e.g., "Fe2O3").

    Returns:
        List[str]: Matching 'cif_string' values, in the dataframe's row order.
    """
    target_formula = Composition(query_formula).reduced_formula
    is_match = df['reduced_formula'] == target_formula
    return df.loc[is_match, 'cif_string'].tolist()


def extract_data_from_file(file_path, debug_max=None, chunksize=256):
    """Process the entries of one HDF5 file in parallel with ``process_single_index``.

    Parameters:
        file_path (str): Path to the ``.h5`` file. Its 'cif_name' dataset's length
            defines how many entries the file holds.
        debug_max (int | None): If not None, only the first ``debug_max`` entries
            are processed (handy for quick debug runs); ``0`` processes nothing.
        chunksize (int): Number of indices sent to a worker per task. ``imap``
            defaults to 1, i.e. one inter-process round trip per entry; batching
            cuts that overhead without changing the results or their order.

    Returns:
        list[dict]: One record per processed entry, in index order.
    """
    with h5py.File(file_path, 'r') as f:
        total = len(f['cif_name'])
    # Compare against None explicitly: a plain truthiness test would treat
    # debug_max=0 as "no limit" and process the whole file.
    num_files = total if debug_max is None else max(0, min(debug_max, total))

    args = [(i, file_path) for i in range(num_files)]
    with Pool(cpu_count()) as pool:
        results = list(tqdm(
            pool.imap(process_single_index, args, chunksize=chunksize),
            total=num_files,
            desc=f'Processing {file_path}',
        ))
    return results

def process_all_files(data_paths, debug_max=None):
    """Process several HDF5 files and stack all their records into one DataFrame.

    Parameters:
        data_paths (Iterable[str]): Paths to ``.h5`` files, processed in the given order.
        debug_max (int | None): Forwarded unchanged to ``extract_data_from_file``.

    Returns:
        pd.DataFrame: One row per processed entry across all files.
    """
    records = [
        record
        for path in data_paths
        for record in extract_data_from_file(path, debug_max)
    ]
    return pd.DataFrame(records)


In [2]:
# Configuration: which serialized splits to load, plus an optional per-file cap.
# The val.h5 / test.h5 splits sit next to train.h5 in
# '../../deCIFer/data/crystallm/full/serialized/' and can be appended to the list.
DATA_PATHS = ['../../deCIFer/data/crystallm/full/serialized/train.h5']
DEBUG_MAX = None  # a positive int limits each file to its first DEBUG_MAX entries

# Build the combined dataframe: one row per processed CIF
data_df = process_all_files(DATA_PATHS, debug_max=DEBUG_MAX)

Processing ../../deCIFer/data/crystallm/full/serialized/train.h5:   0%|          | 0/2169177 [00:00<?, ?it/s]

In [37]:
# Configuration for the training-split-only dataframe.
# (val.h5 / test.h5 live in the same directory if they are ever needed here.)
DATA_PATHS = ['../../deCIFer/data/crystallm/full/serialized/train.h5']
DEBUG_MAX = None

# NOTE(review): these settings match the earlier configuration cell, so this
# re-parses the same ~2.17M entries that already built data_df. If data_df is
# meant to stay train-only, `data_df_train = data_df.copy()` would skip the
# second pass.
data_df_train = process_all_files(DATA_PATHS, debug_max=DEBUG_MAX)

Processing ../../deCIFer/data/crystallm/full/serialized/train.h5:   0%|          | 0/2169177 [00:00<?, ?it/s]

In [10]:
from decifer.utility import pxrd_from_cif, extract_space_group_symbol

In [11]:
# Space group of every FeO entry. search_by_composition does a vectorised
# comparison on 'reduced_formula' instead of building a Python bool list over
# every row of the frame.
for cif_str in search_by_composition(data_df, "FeO"):
    print(extract_space_group_symbol(cif_str))

P-62c
Fm-3m
Immm
Fm-3m
Cmmm
Cmcm
P3_121
Pm-3n
R3m
P6_3/mmc
Cc
P1
C2/m
P4_2/mmc
Fm-3m
Amm2
I4/mmm
Fm-3m
P-6m2
P4/nmm
P4_2/m
P6_3/mmc
P6_3mc
Pm-3m


In [51]:
# All training-split entries whose reduced formula is elemental Si (vectorised mask)
data_df_train[data_df_train['reduced_formula'] == 'Si']

Unnamed: 0,cif_name,cif_string,reduced_formula,spacegroup_num,species,atomic_numbers,num_elements,dataset,cif_token_len
12308,NOMAD_VPKVqQk4HAbtWk1DLWKNWTLyOGGQ,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,198,[Si],[14.0],1,NOMAD,270
35550,NOMAD_nYJL6WulmbzrFC0_rKovws5jFl25,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,141,[Si],[14.0],1,NOMAD,195
54231,OQMD_22407,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,139,[Si],[14.0],1,OQMD,196
121250,OQMD_3484,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,206,[Si],[14.0],1,OQMD,200
131767,MP_mp-1204627,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,63,[Si],[14.0],1,MP,485
190466,OQMD_1214837,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,217,[Si],[14.0],1,OQMD,310
339479,MP_mp-1204046,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,137,[Si],[14.0],1,MP,525
465645,OQMD_1215104,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,191,[Si],[14.0],1,OQMD,196
507168,OQMD_1215461,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,194,[Si],[14.0],1,OQMD,231
579438,OQMD_10215,# generated using pymatgen\ndata_Si\nloop_\n _...,Si,227,[Si],[14.0],1,OQMD,278


In [35]:
# Example: search for entries with composition "CeO2"
query_comp = "CeO2"
matches = search_by_composition(data_df, query_comp)

# Print every matching CIF string, and say so explicitly when nothing matched
print(f"Entries matching {query_comp}:")
if not matches:
    print("  (none)")
for cif_str in matches:
    print(cif_str)

Entries matching CeO2:
# generated using pymatgen
data_CeO2
loop_
 _atom_type_symbol
 _atom_type_electronegativity
 _atom_type_radius
 _atom_type_ionic_radius
  Ce  1.1200  1.8500  1.0800
  O  3.4400  0.6000  1.2600
_symmetry_space_group_name_H-M   P-3m1
_cell_length_a   3.7990
_cell_length_b   3.7990
_cell_length_c   4.0065
_cell_angle_alpha   90.0000
_cell_angle_beta   90.0000
_cell_angle_gamma   120.0000
_symmetry_Int_Tables_number   164
_chemical_formula_structural   CeO2
_chemical_formula_sum   'Ce1 O2'
_cell_volume   50.0766
_cell_formula_units_Z   1
loop_
 _symmetry_equiv_pos_site_id
 _symmetry_equiv_pos_as_xyz
  1  'x, y, z'
  2  '-x, -y, -z'
  3  '-y, x-y, z'
  4  'y, -x+y, -z'
  5  '-x+y, -x, z'
  6  'x-y, x, -z'
  7  'y, x, -z'
  8  '-y, -x, z'
  9  'x-y, -y, -z'
  10  '-x+y, y, z'
  11  '-x, -x+y, -z'
  12  'x, x-y, z'
loop_
 _atom_site_type_symbol
 _atom_site_label
 _atom_site_symmetry_multiplicity
 _atom_site_fract_x
 _atom_site_fract_y
 _atom_site_fract_z
 _atom_site_occ