In [30]:
import gc
import time
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import smact
import torch
from datasets import load_dataset
from pymatgen.core import Structure
from sklearn.metrics import r2_score
from smact.screening import smact_filter, smact_validity
from tqdm import tqdm

from lemat_genbench.benchmarks.validity_benchmark import ValidityBenchmark
from lemat_genbench.preprocess.validity_preprocess import ValidityPreprocessor
from lemat_genbench.utils.logging import logger

In [18]:
def lematbulk_item_to_structure(item: dict):
    sites = item["species_at_sites"]
    coords = item["cartesian_site_positions"]
    cell = item["lattice_vectors"]

    structure = Structure(
        species=sites, coords=coords, lattice=cell, coords_are_cartesian=True
    )

    return structure

In [23]:
dataset_name = "Lematerial/LeMat-Bulk"
name = "compatible_pbe"
split = "train"
dataset = load_dataset(dataset_name, name=name, split=split, streaming=False)

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/17 [00:00<?, ?it/s]

In [24]:
np.random.seed(2)
indicies = np.random.randint(0, len(dataset), 1000)

In [25]:
len(pd.DataFrame(indicies, columns = ["index_col"]).index_col.unique())

1000

In [27]:
# row = []
structures = []
for i in tqdm(range(len(indicies))):
    # print(index)
    index = int(indicies[i])
    strut = lematbulk_item_to_structure(dataset[index])
    name = dataset[index]["immutable_id"]
    # strut.to(filename="lematbulk_cifs/"+name+".cif")
    structures.append(strut)

100%|█████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:01<00:00, 544.25it/s]


In [11]:
config = {
    "validity": True,  # ALWAYS run validity preprocessing
    "distribution": False,
    "stability": False,
    "embeddings": False,
    "fingerprint": False,
}

In [None]:
def run_validity_preprocessing_and_filtering(
    structures, config: Dict[str, Any], monitor_memory: bool = False
):
    """Run validity benchmark and preprocessing, then filter to valid structures only.

    Returns
    -------
    tuple
        (validity_benchmark_result, valid_structures, validity_filtering_metadata)
    """
    # Log initial memory usage
    log_memory_usage("before validity processing", force_log=monitor_memory)

    n_total_structures = len(structures)
    logger.info(
        f"🔍 Starting MANDATORY validity processing for {n_total_structures} structures..."
    )

    # Step 1: Run validity benchmark on ALL structures
    logger.info("🔍 Running MANDATORY validity benchmark on ALL structures...")
    start_time = time.time()

    validity_settings = config.get("validity_settings", {})
    validity_benchmark = ValidityBenchmark(
        charge_tolerance=validity_settings.get("charge_tolerance", 0.1),
        distance_scaling=validity_settings.get("distance_scaling", 0.5),
        min_density=validity_settings.get("min_density", 0.01),
        max_density=validity_settings.get("max_density", 25.0),
        check_format=validity_settings.get("check_format", True),
        check_symmetry=validity_settings.get("check_symmetry", True),
    )

    validity_benchmark_result = validity_benchmark.evaluate(structures)

    elapsed_time = time.time() - start_time
    logger.info(
        f"✅ MANDATORY validity benchmark complete for {n_total_structures} structures in {elapsed_time:.1f}s"
    )

    # Clean up after validity benchmark
    cleanup_after_benchmark("validity", monitor_memory)

    # Step 2: Run validity preprocessor on ALL structures
    logger.info("🔍 Running MANDATORY validity preprocessor on ALL structures...")
    start_time = time.time()

    charge_tolerance = validity_settings.get("charge_tolerance", 0.1)
    distance_scaling = validity_settings.get("distance_scaling", 0.5)
    min_density = validity_settings.get("min_density", 0.01)
    max_density = validity_settings.get("max_density", 25.0)
    check_format = validity_settings.get("check_format", True)
    check_symmetry = validity_settings.get("check_symmetry", True)

    validity_preprocessor = ValidityPreprocessor(
        charge_tolerance=charge_tolerance,
        distance_scaling_factor=distance_scaling,
        plausibility_min_density=min_density,
        plausibility_max_density=max_density,
        plausibility_check_format=check_format,
        plausibility_check_symmetry=check_symmetry,
    )

    # Create source IDs for tracking
    structure_sources = [f"structure_{i}" for i in range(len(structures))]
    validity_preprocessor_result = validity_preprocessor.run(
        structures, structure_sources=structure_sources
    )
    processed_structures = validity_preprocessor_result.processed_structures

    elapsed_time = time.time() - start_time
    logger.info(
        f"✅ MANDATORY validity preprocessing complete for {len(processed_structures)} structures in {elapsed_time:.1f}s"
    )

    # Clean up after validity preprocessor
    cleanup_after_preprocessor("validity", monitor_memory)

    # Step 3: Filter to only valid structures
    logger.info("🔍 Filtering to valid structures only...")

    valid_structures = []
    valid_structure_ids = []
    valid_structure_sources = []

    for structure in processed_structures:
        is_valid = structure.properties.get("overall_valid", False)
        if is_valid:
            valid_structures.append(structure)
            valid_structure_ids.append(
                structure.properties.get("structure_id", "unknown")
            )
            valid_structure_sources.append(
                structure.properties.get("original_source", "unknown")
            )

    n_valid_structures = len(valid_structures)
    n_invalid_structures = n_total_structures - n_valid_structures

    # Log filtering results
    logger.info(
        f"✅ Filtering complete: {n_valid_structures} valid structures out of {n_total_structures} total"
    )
    logger.info(f"📊 Valid: {n_valid_structures}, Invalid: {n_invalid_structures}")

    if n_valid_structures == 0:
        logger.warning(
            "⚠️  No valid structures found! All subsequent benchmarks will be skipped."
        )

    # Create filtering metadata
    validity_filtering_metadata = {
        "total_input_structures": n_total_structures,
        "valid_structures": n_valid_structures,
        "invalid_structures": n_invalid_structures,
        "validity_rate": n_valid_structures / n_total_structures
        if n_total_structures > 0
        else 0.0,
        "valid_structure_ids": valid_structure_ids,
        "valid_structure_sources": valid_structure_sources,
    }

    # Log final memory usage
    log_memory_usage("after validity filtering", force_log=monitor_memory)

    return validity_benchmark_result, valid_structures, validity_filtering_metadata


def run_remaining_preprocessors(
    valid_structures,
    preprocessor_config: Dict[str, Any],
    config: Dict[str, Any],
    monitor_memory: bool = False,
):
    """Run remaining preprocessors on valid structures only.

    Note: validity preprocessing is already complete.
    """
    processed_structures = valid_structures
    preprocessor_results = {}

    if len(valid_structures) == 0:
        logger.warning(
            "⚠️  No valid structures to preprocess. Skipping remaining preprocessors."
        )
        return processed_structures, preprocessor_results

    # Log initial memory usage
    log_memory_usage("before remaining preprocessing", force_log=monitor_memory)

    # Fingerprint preprocessor (for BAWL/short-BAWL methods only)
    if preprocessor_config["fingerprint"]:
        logger.info(
            f"Running fingerprint preprocessor on {len(processed_structures)} valid structures..."
        )
        start_time = time.time()

        fingerprint_method = config.get("fingerprint_method", "short-bawl")
        fingerprint_preprocessor = FingerprintPreprocessor(
            fingerprint_method=fingerprint_method
        )
        fingerprint_result = fingerprint_preprocessor(processed_structures)
        processed_structures = fingerprint_result.processed_structures
        preprocessor_results["fingerprint"] = fingerprint_result
        elapsed_time = time.time() - start_time
        logger.info(
            f"✅ Fingerprint preprocessing complete for {len(processed_structures)} valid structures in {elapsed_time:.1f}s"
        )

        # Clean up after fingerprint preprocessor
        cleanup_after_preprocessor("fingerprint", monitor_memory)

    # Distribution preprocessor (for MMD, JSDistance)
    if preprocessor_config["distribution"]:
        logger.info(
            f"Running distribution preprocessor on {len(processed_structures)} valid structures..."
        )
        start_time = time.time()
        dist_preprocessor = DistributionPreprocessor()
        dist_result = dist_preprocessor(processed_structures)
        processed_structures = dist_result.processed_structures
        preprocessor_results["distribution"] = dist_result
        elapsed_time = time.time() - start_time
        logger.info(
            f"✅ Distribution preprocessing complete for {len(processed_structures)} valid structures in {elapsed_time:.1f}s"
        )

        # Clean up after distribution preprocessor
        cleanup_after_preprocessor("distribution", monitor_memory)

    # Multi-MLIP preprocessor (for stability, embeddings)
    if preprocessor_config["stability"] or preprocessor_config["embeddings"]:
        logger.info(
            f"Running Multi-MLIP preprocessor on {len(processed_structures)} valid structures..."
        )
        start_time = time.time()

        # Configure MLIP models
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        mlip_configs = {
            "orb": {"model_type": "orb_v3_conservative_inf_omat", "device": device},
            "mace": {"model_type": "mp", "device": device},
            "uma": {"task": "omat", "device": device},
        }

        # Determine what to extract based on requirements
        extract_embeddings = preprocessor_config["embeddings"]
        relax_structures = preprocessor_config["stability"]

        # Show progress for MLIP model loading
        logger.info("🔥 Initializing MLIP models (this may take 1-2 minutes)...")

        mlip_preprocessor = MultiMLIPStabilityPreprocessor(
            mlip_names=["orb", "mace", "uma"],
            mlip_configs=mlip_configs,
            relax_structures=relax_structures,
            relaxation_config={"fmax": 0.02, "steps": 50},
            calculate_formation_energy=relax_structures,
            calculate_energy_above_hull=relax_structures,
            extract_embeddings=extract_embeddings,
            timeout=300,
        )

        # Add progress bar for structure processing
        logger.info(
            f"🔥 Processing {len(processed_structures)} valid structures with MLIP models..."
        )
        mlip_result = mlip_preprocessor(processed_structures)
        processed_structures = mlip_result.processed_structures
        preprocessor_results["multi_mlip"] = mlip_result
        elapsed_time = time.time() - start_time
        logger.info(
            f"✅ Multi-MLIP preprocessing complete for {len(processed_structures)} valid structures in {elapsed_time:.1f}s"
        )

        # Clean up after MLIP preprocessor (this is crucial for memory management)
        cleanup_after_preprocessor("multi_mlip", monitor_memory)

    # Log final memory usage
    log_memory_usage("after remaining preprocessing")

    return processed_structures, preprocessor_results

In [12]:
result = run_validity_preprocessing_and_filtering(structures, config=config)

(({'Cs': 1.0, 'Tl': 1.0, 'Cu': 4.0, 'Cl': -1.0},), (), (np.float64(-0.4095133960720269),))
((), (), ())
Index Error
((), (), ())
Index Error
((), (), ())
Index Error
(({'Y': 1.0, 'Mg': 1.0, 'Sn': -4.0, 'O': 1.0}, {'Y': 1.0, 'Mg': 2.0, 'Sn': -4.0, 'O': -1.0}, {'Y': 2.0, 'Mg': 1.5, 'Sn': -4.0, 'O': -1.0}, {'Y': 2.0, 'Mg': 2.0, 'Sn': -4.0, 'O': -2.0}, {'Y': 3.0, 'Mg': 1.0, 'Sn': -4.0, 'O': -1.0}, {'Y': 3.0, 'Mg': 1.5, 'Sn': -4.0, 'O': -2.0}), (), (np.float64(-0.6217569384378161), np.float64(-0.6166988299940601), np.float64(-0.470800007959056), np.float64(-0.4608628897146129), np.float64(-0.4266840804849635), np.float64(0.014618848472203377)))
((), (), ())
Index Error
((), (), ())
Index Error
(({'Be': 2.0, 'Fe': -2.0, 'Cl': 2.0}, {'Be': 2.0, 'Fe': -1.5, 'Cl': 1.0}), (), (np.float64(0.08942612630364373), np.float64(0.3621567262755949)))
failed penalty
((), (), ())
Index Error
(({'Ag': 1.0, 'O': 2.0, 'F': -1.0}, {'Ag': 2.0, 'O': 1.0, 'F': -1.0}), (), (np.float64(-0.8974668327582289), np.floa



(({'Ru': -2.0, 'Ir': 2.0, 'H': -1.0}, {'Ru': 4.0, 'Ir': -3.0, 'H': 1.0}, {'Ru': 4.0, 'Ir': -1.0, 'H': -1.0}, {'Ru': 8.0, 'Ir': -3.0, 'H': -1.0}), (), (np.float64(nan), np.float64(nan), np.float64(nan), np.float64(nan)))
failed penalty
(({'B': 1.0, 'Co': 1.0, 'Mg': 2.0, 'Pb': -4.0}, {'B': 1.0, 'Co': 2.0, 'Mg': 1.0, 'Pb': -4.0}, {'B': 2.0, 'Co': 1.0, 'Mg': 1.0, 'Pb': -4.0}, {'B': 3.0, 'Co': -1.0, 'Mg': 2.0, 'Pb': -4.0}), (), (np.float64(-0.7973114762142738), np.float64(-0.6338339433214192), np.float64(-0.6041837486820981), np.float64(-0.5879455130357054)))
(({'Pr': 2.0, 'As': -3.0, 'Al': 1.0},), (), (np.float64(-0.9599349481604783),))
((), (), ())
Index Error
((), (), ())
Index Error
(({'Ac': 3.0, 'Pb': -4.0, 'N': 1.0},), (), (np.float64(-0.4210209997417763),))
(({'Ba': 2.0, 'Se': 2.0, 'W': -2.0},), (), (np.float64(-0.4067039076395485),))
(({'Mn': -3.0, 'Tl': 1.0, 'Br': -1.0},), (), (np.float64(0.04404577866459287),))
failed penalty
((), (), ())
Index Error
((), (), ())
Index Error
((), 



(({'Ru': -2.0, 'Ir': 2.0, 'H': -1.0}, {'Ru': 4.0, 'Ir': -3.0, 'H': 1.0}, {'Ru': 4.0, 'Ir': -1.0, 'H': -1.0}, {'Ru': 8.0, 'Ir': -3.0, 'H': -1.0}), (), (np.float64(nan), np.float64(nan), np.float64(nan), np.float64(nan)))
failed penalty
(({'B': 1.0, 'Co': 1.0, 'Mg': 2.0, 'Pb': -4.0}, {'B': 1.0, 'Co': 2.0, 'Mg': 1.0, 'Pb': -4.0}, {'B': 2.0, 'Co': 1.0, 'Mg': 1.0, 'Pb': -4.0}, {'B': 3.0, 'Co': -1.0, 'Mg': 2.0, 'Pb': -4.0}), (), (np.float64(-0.7973114762142738), np.float64(-0.6338339433214192), np.float64(-0.6041837486820981), np.float64(-0.5879455130357054)))
(({'Pr': 2.0, 'As': -3.0, 'Al': 1.0},), (), (np.float64(-0.9599349481604783),))
((), (), ())
Index Error
((), (), ())
Index Error
(({'Ac': 3.0, 'Pb': -4.0, 'N': 1.0},), (), (np.float64(-0.4210209997417763),))
(({'Ba': 2.0, 'Se': 2.0, 'W': -2.0},), (), (np.float64(-0.4067039076395485),))
(({'Mn': -3.0, 'Tl': 1.0, 'Br': -1.0},), (), (np.float64(0.04404577866459287),))
failed penalty
((), (), ())
Index Error
((), (), ())
Index Error
((), 

Processing ValidityPreprocessor:   3%|█▍                                            | 30/1000 [00:00<00:07, 128.44it/s]

(({'Cs': 1.0, 'Tl': 1.0, 'Cu': 4.0, 'Cl': -1.0},), (), (np.float64(-0.4095133960720269),))
((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:   6%|██▉                                           | 65/1000 [00:00<00:07, 123.31it/s]

((), (), ())
Index Error


Processing ValidityPreprocessor:  10%|████▌                                        | 100/1000 [00:00<00:06, 144.80it/s]

(({'Y': 1.0, 'Mg': 1.0, 'Sn': -4.0, 'O': 1.0}, {'Y': 1.0, 'Mg': 2.0, 'Sn': -4.0, 'O': -1.0}, {'Y': 2.0, 'Mg': 1.5, 'Sn': -4.0, 'O': -1.0}, {'Y': 2.0, 'Mg': 2.0, 'Sn': -4.0, 'O': -2.0}, {'Y': 3.0, 'Mg': 1.0, 'Sn': -4.0, 'O': -1.0}, {'Y': 3.0, 'Mg': 1.5, 'Sn': -4.0, 'O': -2.0}), (), (np.float64(-0.6217569384378161), np.float64(-0.6166988299940601), np.float64(-0.470800007959056), np.float64(-0.4608628897146129), np.float64(-0.4266840804849635), np.float64(0.014618848472203377)))
((), (), ())
Index Error
((), (), ())
Index Error
(({'Be': 2.0, 'Fe': -2.0, 'Cl': 2.0}, {'Be': 2.0, 'Fe': -1.5, 'Cl': 1.0}), (), (np.float64(0.08942612630364373), np.float64(0.3621567262755949)))
failed penalty
((), (), ())
Index Error


Processing ValidityPreprocessor:  13%|█████▊                                       | 130/1000 [00:00<00:06, 132.01it/s]

(({'Ag': 1.0, 'O': 2.0, 'F': -1.0}, {'Ag': 2.0, 'O': 1.0, 'F': -1.0}), (), (np.float64(-0.8974668327582289), np.float64(-0.4323038176100851)))
((), (), ())
Index Error
(({'I': -1.0, 'K': -1.0, 'Li': 1.0, 'N': 2.0}, {'I': 1.0, 'K': -1.0, 'Li': 1.0, 'N': -2.0}), (), (np.float64(-0.2703726053406836), np.float64(0.3328413661414192)))
(({'Ac': 3.0, 'Pb': -4.0, 'Se': -2.0, 'N': -1.0}, {'Ac': 3.0, 'Pb': -4.0, 'Se': 2.0, 'N': -3.0}), (), (np.float64(-0.7052659078341142), np.float64(-0.6360638521026352)))
(({'Si': -3.5, 'Ba': 2.0, 'F': -1.0},), (), (np.float64(-0.36972653234580716),))


Processing ValidityPreprocessor:  17%|███████▌                                     | 169/1000 [00:01<00:05, 158.34it/s]

((), (), ())
Index Error
((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:  20%|█████████                                    | 201/1000 [00:01<00:06, 130.37it/s]

((), (), ())
Index Error
((), (), ())
Index Error
(({'Mg': 1.0, 'P': -3.0, 'Zn': 2.0}, {'Mg': 1.0, 'P': -2.0, 'Zn': 1.0}, {'Mg': 2.0, 'P': -3.0, 'Zn': 1.0}), (), (np.float64(-0.979465637726506), np.float64(-0.9237152324921261), np.float64(-0.8346750651930225)))


Processing ValidityPreprocessor:  23%|██████████▎                                  | 229/1000 [00:01<00:06, 116.39it/s]

(({'S': -2.0, 'N': -1.0, 'Sn': 2.0, 'Ti': -1.0}, {'S': -1.0, 'N': -2.0, 'Sn': 2.0, 'Ti': -1.0}, {'S': -1.0, 'N': 5.0, 'Sn': -4.0, 'Ti': 4.0}, {'S': 1.0, 'N': 3.0, 'Sn': -4.0, 'Ti': 4.0}, {'S': 1.0, 'N': 4.0, 'Sn': -4.0, 'Ti': 3.0}, {'S': 1.0, 'N': 5.0, 'Sn': -4.0, 'Ti': 2.0}, {'S': 2.0, 'N': 2.0, 'Sn': -4.0, 'Ti': 4.0}, {'S': 2.0, 'N': 3.0, 'Sn': -4.0, 'Ti': 3.0}, {'S': 2.0, 'N': 4.0, 'Sn': -4.0, 'Ti': 2.0}, {'S': 3.0, 'N': 1.0, 'Sn': -4.0, 'Ti': 4.0}, {'S': 3.0, 'N': 2.0, 'Sn': -4.0, 'Ti': 3.0}, {'S': 3.0, 'N': 3.0, 'Sn': -4.0, 'Ti': 2.0}, {'S': 4.0, 'N': 1.0, 'Sn': -4.0, 'Ti': 3.0}, {'S': 4.0, 'N': 2.0, 'Sn': -4.0, 'Ti': 2.0}, {'S': 4.0, 'N': 5.0, 'Sn': -4.0, 'Ti': -1.0}, {'S': 5.0, 'N': -1.0, 'Sn': -4.0, 'Ti': 4.0}, {'S': 5.0, 'N': 1.0, 'Sn': -4.0, 'Ti': 2.0}, {'S': 5.0, 'N': 4.0, 'Sn': -4.0, 'Ti': -1.0}, {'S': 6.0, 'N': -2.0, 'Sn': -4.0, 'Ti': 4.0}, {'S': 6.0, 'N': -1.0, 'Sn': -4.0, 'Ti': 3.0}, {'S': 6.0, 'N': 3.0, 'Sn': -4.0, 'Ti': -1.0}), (), (np.float64(-0.49947689022312486), np

Processing ValidityPreprocessor:  27%|████████████▎                                | 274/1000 [00:02<00:05, 131.57it/s]

(({'Ac': 3.0, 'Ir': -1.0, 'As': -3.0, 'Li': 1.0},), (), (np.float64(-0.8647894202689625),))
((), (), ())
Index Error


Processing ValidityPreprocessor:  32%|██████████████▌                              | 323/1000 [00:02<00:04, 149.21it/s]

(({'I': -1.0, 'Pd': 2.0, 'V': 1.0}, {'I': -1.0, 'Pd': 4.0, 'V': -1.0}), (), (np.float64(-0.6068914775777702), np.float64(0.06154189249905394)))
(({'K': -1.0, 'Cl': 4.5, 'F': -1.0}, {'K': 1.0, 'Cl': 3.5, 'F': -1.0}), (), (np.float64(-0.18753302899309088), np.float64(0.26758572884025533)))
((), (), ())
Index Error
((), (), ())
Index Error
(({'Cr': -2.0, 'Al': 1.0, 'P': 3.0}, {'Cr': -2.0, 'Al': 3.0, 'P': 1.0}, {'Cr': -1.5, 'Al': 1.0, 'P': 2.0}, {'Cr': -1.0, 'Al': 1.0, 'P': 1.0}, {'Cr': -1.0, 'Al': 3.0, 'P': -1.0}, {'Cr': 1.0, 'Al': 1.0, 'P': -3.0}), (), (np.float64(-0.9969702408703213), np.float64(-0.5658479745480199), np.float64(0.037089836668601675), np.float64(0.4311222663223012), np.float64(0.6644600709906311), np.float64(0.7541600122615626)))


Processing ValidityPreprocessor:  36%|████████████████                             | 356/1000 [00:02<00:04, 147.54it/s]

((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:  39%|█████████████████▍                           | 387/1000 [00:02<00:04, 149.37it/s]

((), (), ())
Index Error
(({'Na': -1.0, 'K': -1.0, 'Te': 4.0}, {'Na': -1.0, 'K': 1.0, 'Te': -2.0}, {'Na': 1.0, 'K': -1.0, 'Te': 2.0}), (), (np.float64(-0.8044097866128268), np.float64(0.8044097866128268), np.float64(0.9969899056466502)))
(({'Pb': -4.0, 'As': 2.0, 'Al': 1.0},), (), (np.float64(-0.536398944859522),))


Processing ValidityPreprocessor:  42%|██████████████████▋                          | 416/1000 [00:03<00:04, 136.88it/s]

((), (), ())
Index Error
(({'Fe': -1.0, 'Se': -2.0, 'Al': 1.0},), (), (np.float64(-0.8832327388159914),))
(({'Zn': 1.0, 'Cl': -1.0, 'Cd': 1.0},), (), (np.float64(-0.9997298514133506),))
((), (), ())
Index Error


Processing ValidityPreprocessor:  43%|███████████████████▍                         | 431/1000 [00:03<00:04, 134.25it/s]

((), (), ())
Index Error
(({'Co': -1.0, 'Mg': 1.0, 'Si': -1.0}, {'Co': -1.0, 'Mg': 1.5, 'Si': -2.0}, {'Co': -1.0, 'Mg': 2.0, 'Si': -3.0}, {'Co': 1.0, 'Mg': 1.0, 'Si': -3.0}, {'Co': 1.0, 'Mg': 1.5, 'Si': -4.0}, {'Co': 2.0, 'Mg': 1.0, 'Si': -4.0}), (), (np.float64(-0.9995544002749418), np.float64(-0.9686196045011366), np.float64(-0.9291150685589119), np.float64(-0.593777282050744), np.float64(-0.5256277449721678), np.float64(-0.3869079093827)))
((), (), ())
Index Error


Processing ValidityPreprocessor:  46%|████████████████████▌                        | 457/1000 [00:03<00:05, 105.74it/s]

(({'Ru': -2.0, 'Ir': 2.0, 'H': -1.0}, {'Ru': 4.0, 'Ir': -3.0, 'H': 1.0}, {'Ru': 4.0, 'Ir': -1.0, 'H': -1.0}, {'Ru': 8.0, 'Ir': -3.0, 'H': -1.0}), (), (np.float64(nan), np.float64(nan), np.float64(nan), np.float64(nan)))
failed penalty
(({'B': 1.0, 'Co': 1.0, 'Mg': 2.0, 'Pb': -4.0}, {'B': 1.0, 'Co': 2.0, 'Mg': 1.0, 'Pb': -4.0}, {'B': 2.0, 'Co': 1.0, 'Mg': 1.0, 'Pb': -4.0}, {'B': 3.0, 'Co': -1.0, 'Mg': 2.0, 'Pb': -4.0}), (), (np.float64(-0.7973114762142738), np.float64(-0.6338339433214192), np.float64(-0.6041837486820981), np.float64(-0.5879455130357054)))


Processing ValidityPreprocessor:  48%|██████████████████████                        | 479/1000 [00:03<00:06, 80.14it/s]

(({'Pr': 2.0, 'As': -3.0, 'Al': 1.0},), (), (np.float64(-0.9599349481604783),))
((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:  50%|██████████████████████▊                       | 495/1000 [00:03<00:05, 95.92it/s]

(({'Ac': 3.0, 'Pb': -4.0, 'N': 1.0},), (), (np.float64(-0.4210209997417763),))
(({'Ba': 2.0, 'Se': 2.0, 'W': -2.0},), (), (np.float64(-0.4067039076395485),))


Processing ValidityPreprocessor:  53%|████████████████████████▌                     | 533/1000 [00:04<00:05, 82.38it/s]

(({'Mn': -3.0, 'Tl': 1.0, 'Br': -1.0},), (), (np.float64(0.04404577866459287),))
failed penalty


Processing ValidityPreprocessor:  56%|█████████████████████████▌                    | 556/1000 [00:05<00:08, 55.02it/s]

((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:  58%|██████████████████████████▉                   | 585/1000 [00:06<00:11, 37.54it/s]

((), (), ())
Index Error
((), (), ())
Index Error
(({'Th': 2.0, 'Sn': 2.0, 'N': -2.6666666666666665}, {'Th': 2.3333333333333335, 'Sn': -4.0, 'N': -1.0}, {'Th': 2.3333333333333335, 'Sn': 2.0, 'N': -3.0}, {'Th': 2.6666666666666665, 'Sn': -4.0, 'N': -1.3333333333333333}, {'Th': 3.0, 'Sn': -4.0, 'N': -1.6666666666666667}, {'Th': 3.3333333333333335, 'Sn': -4.0, 'N': -2.0}, {'Th': 3.6666666666666665, 'Sn': -4.0, 'N': -2.3333333333333335}, {'Th': 4.0, 'Sn': -4.0, 'N': -2.6666666666666665}), (), (np.float64(-0.9462638940667149), np.float64(-0.9267495329019914), np.float64(-0.6835038732699298), np.float64(-0.6446482814757386), np.float64(-0.5986180929726627), np.float64(-0.544049427946467), np.float64(-0.4795054886323516), np.float64(-0.40365447665463955)))
(({'Sm': 2.0, 'Ru': -2.0, 'Se': -2.0},), (), (np.float64(-0.969789467952218),))


Processing ValidityPreprocessor:  61%|████████████████████████████                  | 610/1000 [00:06<00:06, 61.98it/s]

(({'Br': 1.0, 'Mg': 1.0, 'Mo': -1.0}, {'Br': 1.0, 'Mg': 2.0, 'Mo': -1.5}, {'Br': 3.0, 'Mg': 1.0, 'Mo': -2.0}), (), (np.float64(-0.2941141855276844), np.float64(-0.017492785713533094), np.float64(0.3812464258315116)))


Processing ValidityPreprocessor:  63%|████████████████████████████▉                 | 629/1000 [00:06<00:06, 60.25it/s]

((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:  66%|██████████████████████████████▎               | 660/1000 [00:06<00:03, 86.98it/s]

(({'Ni': -1.0, 'H': -1.0, 'V': 3.0}, {'Ni': -1.0, 'H': 1.0, 'V': -1.0}, {'Ni': 1.0, 'H': -1.0, 'V': 1.0}, {'Ni': 3.0, 'H': -1.0, 'V': -1.0}), (), (np.float64(-0.8710451982429083), np.float64(-0.8609167657051999), np.float64(-0.010128432537708452), np.float64(0.8710451982429083)))
((), (), ())
Index Error
(({'Th': 2.0, 'Sn': -4.0, 'Te': 2.0},), (), (np.float64(-0.35135135135135126),))


Processing ValidityPreprocessor:  72%|████████████████████████████████▎            | 718/1000 [00:07<00:02, 109.27it/s]

((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:  74%|█████████████████████████████████▍           | 744/1000 [00:07<00:01, 144.35it/s]

(({'La': 3.0, 'Ru': -2.0, 'Cl': -1.0},), (), (np.float64(-0.7810148683505013),))
((), (), ())
Index Error


Processing ValidityPreprocessor:  78%|██████████████████████████████████▉          | 776/1000 [00:07<00:01, 130.61it/s]

(({'Ce': 2.0, 'H': -1.0, 'Pb': -4.0},), (), (np.float64(-0.9107962159130372),))
(({'Y': 1.0, 'Si': -4.0, 'Ir': 2.5}, {'Y': 1.0, 'Si': -3.75, 'Ir': 2.3333333333333335}, {'Y': 1.0, 'Si': -3.5, 'Ir': 2.1666666666666665}, {'Y': 1.0, 'Si': -3.25, 'Ir': 2.0}, {'Y': 1.0, 'Si': -3.0, 'Ir': 1.8333333333333333}, {'Y': 1.0, 'Si': -2.75, 'Ir': 1.6666666666666667}, {'Y': 1.0, 'Si': -2.5, 'Ir': 1.5}, {'Y': 1.0, 'Si': -2.25, 'Ir': 1.3333333333333333}, {'Y': 1.0, 'Si': -2.0, 'Ir': 1.1666666666666667}, {'Y': 1.0, 'Si': -1.75, 'Ir': 1.0}, {'Y': 1.0, 'Si': 1.25, 'Ir': -1.0}, {'Y': 2.0, 'Si': -4.0, 'Ir': 2.3333333333333335}, {'Y': 2.0, 'Si': -3.75, 'Ir': 2.1666666666666665}, {'Y': 2.0, 'Si': -3.5, 'Ir': 2.0}, {'Y': 2.0, 'Si': -3.25, 'Ir': 1.8333333333333333}, {'Y': 2.0, 'Si': -3.0, 'Ir': 1.6666666666666667}, {'Y': 2.0, 'Si': -2.75, 'Ir': 1.5}, {'Y': 2.0, 'Si': -2.5, 'Ir': 1.3333333333333333}, {'Y': 2.0, 'Si': -2.25, 'Ir': 1.1666666666666667}, {'Y': 2.0, 'Si': -2.0, 'Ir': 1.0}, {'Y': 2.0, 'Si': 1.0, 'Ir': 

Processing ValidityPreprocessor:  81%|████████████████████████████████████▍        | 810/1000 [00:08<00:01, 145.53it/s]

(({'Mg': 1.0, 'Hg': 1.0, 'Cl': -1.0},), (), (np.float64(-0.9294197023849987),))
(({'Mg': 1.0, 'P': -3.0, 'Zr': 1.0},), (), (np.float64(-0.999801882117917),))
((), (), ())
Index Error


Processing ValidityPreprocessor:  84%|██████████████████████████████████████       | 845/1000 [00:08<00:01, 124.79it/s]

((), (), ())
Index Error
((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:  90%|█████████████████████████████████████████▍    | 902/1000 [00:09<00:00, 98.75it/s]

((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor:  92%|██████████████████████████████████████████▌   | 925/1000 [00:09<00:00, 86.72it/s]

((), (), ())
Index Error


Processing ValidityPreprocessor:  95%|███████████████████████████████████████████▋  | 949/1000 [00:09<00:00, 96.98it/s]

((), (), ())
Index Error
((), (), ())
Index Error


Processing ValidityPreprocessor: 100%|█████████████████████████████████████████████| 1000/1000 [00:10<00:00, 99.67it/s]

((), (), ())
Index Error
((), (), ())
Index Error





In [14]:
validity_benchmark_result, valid_structures, validity_filtering_metadata, processed_structures = result

In [15]:
invalid_indicies = []
count = 0
for val in validity_benchmark_result.evaluator_results["charge_neutrality"]["metric_results"]["charge_neutrality"].individual_values:
    if val == 10.0:
        invalid_indicies.append(count)
    count += 1

In [16]:
validity_filtering_metadata["validity_rate"]

{'total_input_structures': 1000,
 'valid_structures': 953,
 'invalid_structures': 47,
 'validity_rate': 0.953,
 'valid_structure_ids': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  46,
  47,
  48,
  49,
  50,
  51,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  91,
  92,
  94,
  95,
  96,
  97,
  98,
  99,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  113,
  114,
  115,
  116,
  117,
  118,
  119,
  120,
  121,
  122,
  123,
  124,
  125,
  126,
  127,
  129,
  130,
  131,
  132,
  133,
  134,
  135,
  136,
  137,
  138,
  139,
  140,
  141,
  142,
  143,
  144,
  146