# 0. Get data

In [None]:
from synkit.IO import load_database

# Load the records stored in ./data_aam.json.gz and print how many were read.
data = load_database('./data_aam.json.gz')
print(len(data))

# 1. Chem package

## 1.1. Standardization and Canonicalization

In [None]:
from joblib import Parallel, delayed
from tqdm import tqdm
from typing import Any, Dict, List, Tuple, Optional
from synkit.IO.data_process import TqdmJoblib
from synkit.Chem.Reaction.standardize import Standardize
from synkit.Chem.Reaction.canon_rsmi import CanonRSMI

def safe_std_canon(aam: Any) -> Optional[str]:
    """Standardize and canonicalize one mapped reaction string, never raising.

    The input is passed through ``Standardize().fit(..., remove_aam=False)``
    and then ``CanonRSMI(wl_iterations=4).canonicalise(...)``.

    Args:
        aam: The mapped reaction string, or ``None``.

    Returns:
        The ``canonical_rsmi`` attribute of the canonicalization result, or
        ``None`` when the input is ``None``, the attribute is missing, or any
        step raises.
    """
    if aam is None:
        return None
    try:
        standardizer, canonicalizer = Standardize(), CanonRSMI(wl_iterations=4)
        mapped = standardizer.fit(aam, remove_aam=False)
        return getattr(canonicalizer.canonicalise(mapped), "canonical_rsmi", None)
    except Exception:
        # Deliberate best-effort: any failure marks this mapping as unusable.
        return None

def process_item(pair: Tuple[int, Dict[str, Any]], keys: Tuple[str, ...]) -> Dict[str, Any]:
    """Return a copy of one record with the given fields standardized.

    Args:
        pair: ``(index, record)``; the index is unpacked but not used.
        keys: Field names whose values are replaced by ``safe_std_canon(value)``.
            A key missing from the record is added (its value is looked up
            with ``dict.get`` and is therefore ``None``).

    Returns:
        A shallow copy of the record with the listed fields overwritten.
    """
    _, record = pair
    updated = record.copy()
    updated.update({name: safe_std_canon(record.get(name)) for name in keys})
    return updated

def parallel_std_canon_tqdm(
    data: List[Dict[str, Any]],
    n_jobs: int = -1,
    keys: Tuple[str, ...] = ("rxn_mapper", "graphormer", "local_mapper"),
    prefer_backend: str = "loky",
) -> List[Dict[str, Any]]:
    """Apply ``process_item`` to every record with joblib, showing a tqdm bar.

    The work is first attempted with ``prefer_backend``; if that raises any
    exception the whole job is retried once with the ``"threading"`` backend.

    Args:
        data: Records to process.
        n_jobs: Forwarded to ``joblib.Parallel``.
        keys: Field names handed to ``process_item``.
        prefer_backend: joblib backend used for the first attempt.

    Returns:
        Whatever ``joblib.Parallel`` returns for the per-record calls.
    """
    indexed = list(enumerate(data))
    n_items = len(indexed)

    def _proc(pair):
        return process_item(pair, keys)

    def _run(backend_name):
        # One full pass over the data with its own progress bar.
        with TqdmJoblib(tqdm(total=n_items, desc="std+canon", unit="it")):
            return Parallel(n_jobs=n_jobs, backend=backend_name)(
                delayed(_proc)(pair) for pair in indexed
            )

    try:
        return _run(prefer_backend)
    except Exception:
        # retry with threads if processes/pickling fail
        return _run("threading")

In [None]:
# Standardize/canonicalize every record with 4 workers, then replace the
# contents of `data` via slice assignment (the list object itself is kept).
result = parallel_std_canon_tqdm(data, n_jobs=4)
data[:] = result

## 1.2. Ensemble AAM

In [None]:
from typing import List, Dict, Tuple, Any, Optional, Union

def extract_aam(
    data: List[Dict[str, Any]],
    keys: Tuple[str, ...] = ("rxn_mapper", "graphormer", "local_mapper"),
    strip: bool = True,
    require_not_none: bool = True,
    return_indices: bool = False,
    inplace: bool = False,
) -> Union[
    Tuple[List[Dict[str, Any]], List[Dict[str, Any]]],
    Tuple[List[Dict[str, Any]], List[int], List[Dict[str, Any]], List[int]],
]:
    """Split records by whether the values under all ``keys`` agree exactly.

    For every record the values stored under ``keys`` are read with
    ``dict.get``. When they all compare equal, the record is copied, the
    ``keys`` entries are removed from the copy, the agreed value is stored
    under ``"aam"`` and the copy is collected as a match. Otherwise the
    original record object is collected as a non-match.

    Args:
        data: Records to inspect.
        keys: One or more field names to compare (any number is supported;
            the default compares three mapper outputs).
        strip: Strip surrounding whitespace from string values before
            comparing.
        require_not_none: Treat a record with any ``None`` value as a
            non-match without comparing.
        return_indices: Also return the positions (in ``data``) of matches
            and non-matches.
        inplace: Replace the contents of ``data`` with the matches.

    Returns:
        ``(matches, non_matches)``, or
        ``(matches, match_indices, non_matches, non_match_indices)`` when
        ``return_indices`` is true.

    Raises:
        ValueError: If ``keys`` is empty.
    """
    if not keys:
        raise ValueError("keys must contain at least one key")

    out: List[Dict[str, Any]] = []
    idxs: List[int] = []
    non_out: List[Dict[str, Any]] = []
    non_idxs: List[int] = []

    for i, item in enumerate(data):
        values = [item.get(k) for k in keys]
        if strip:
            values = [v.strip() if isinstance(v, str) else v for v in values]

        if require_not_none and any(v is None for v in values):
            non_out.append(item)
            non_idxs.append(i)
            continue

        try:
            # Adjacent pairwise comparison mirrors a chained ``a == b == c``.
            equal = all(a == b for a, b in zip(values, values[1:]))
        except Exception:
            # Values whose comparison misbehaves are treated as disagreeing.
            equal = False

        if not equal:
            non_out.append(item)
            non_idxs.append(i)
            continue

        new_item = item.copy()
        for k in keys:
            new_item.pop(k, None)
        new_item["aam"] = values[0]
        out.append(new_item)
        idxs.append(i)

    if inplace:
        data[:] = out

    if return_indices:
        return out, idxs, non_out, non_idxs
    return out, non_out

In [None]:
# Split the records into those whose mapper outputs agree exactly and the
# rest, then print the size of each group.
matches, non_matches = extract_aam(data)
print(len(matches), len(non_matches))

## 1.3. Filter by isomorphism

In [None]:
from typing import List, Dict, Any, Tuple, Optional
from synkit.Chem.Reaction.aam_validator import AAMValidator

def resolve_non_matches(
    non_matches: List[Dict[str, Any]],
    validator: Optional[AAMValidator] = None,
    keys: Tuple[str, str, str] = ("rxn_mapper", "graphormer", "local_mapper"),
    strip: bool = True,
    allow_interrupt: bool = True,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Rescue disagreeing records whose mappings pass ``validator.smiles_check``.

    The value under the first key is used as the reference. A record is
    resolved when ``smiles_check(reference, other)`` is truthy for both the
    third-key and the second-key value; the resolved copy drops the three
    key fields and stores the reference under ``"aam"``.

    Args:
        non_matches: Records to re-examine.
        validator: Object exposing ``smiles_check(a, b)``; a fresh
            ``AAMValidator()`` is created when ``None``.
        keys: The three field names, reference first.
        strip: Strip surrounding whitespace from string values first.
        allow_interrupt: When false, a ``KeyboardInterrupt`` raised inside a
            check counts as a failed check instead of propagating.

    Returns:
        ``(resolved, remaining)``. Records with a ``None`` value, a failed
        check, or an unexpected error are returned unchanged in
        ``remaining``.
    """
    checker = AAMValidator() if validator is None else validator
    rxn_key, graph_key, local_key = keys

    def _clean(value):
        return value.strip() if strip and isinstance(value, str) else value

    def _agrees(reference, candidate):
        # A check that raises is treated as a failed check.
        try:
            return bool(checker.smiles_check(reference, candidate))
        except KeyboardInterrupt:
            if allow_interrupt:
                raise
            return False
        except Exception:
            return False

    resolved: List[Dict[str, Any]] = []
    remaining: List[Dict[str, Any]] = []

    for record in non_matches:
        try:
            reference = _clean(record.get(rxn_key))
            graph_value = _clean(record.get(graph_key))
            local_value = _clean(record.get(local_key))

            if reference is None or graph_value is None or local_value is None:
                remaining.append(record)
                continue

            # Both checks are always evaluated, local first.
            local_ok = _agrees(reference, local_value)
            graph_ok = _agrees(reference, graph_value)
            if not (local_ok and graph_ok):
                remaining.append(record)
                continue

            merged = record.copy()
            for name in (rxn_key, graph_key, local_key):
                merged.pop(name, None)
            merged["aam"] = reference
            resolved.append(merged)
        except KeyboardInterrupt:
            raise
        except Exception:
            remaining.append(record)

    return resolved, remaining


In [None]:
# Re-check the disagreeing records with the validator, append the rescued
# ones to `matches`, and print the new total.
resolved, still_unresolved = resolve_non_matches(non_matches)
matches.extend(resolved)
print(len(matches))

## 1.4. Drop duplicates

In [None]:
import pandas as pd

# Round-trip through a DataFrame to keep only the first record for each
# distinct 'aam' value, then go back to a list of dicts.
data = (
    pd.DataFrame(matches)
    .drop_duplicates(subset='aam')
    .to_dict('records')
)
print(len(data))

# 2. IO
Convert to ITS and reaction center

In [None]:
from synkit.IO import rsmi_to_its
# For every record, derive two graphs from its 'aam' string with rsmi_to_its:
# core=False is stored under 'ITS' and core=True under 'RC'.
for value in data:
    value['ITS'] = rsmi_to_its(value['aam'], core=False)
    value['RC'] = rsmi_to_its(value['aam'], core=True)

# 3. Graph module

In [None]:
from synkit.Graph.Feature.wl_hash import WLHash
from synkit.Graph.Matcher.graph_cluster import GraphCluster
# Store a Weisfeiler-Lehman graph hash of each record's 'RC' graph under 'wl'.
# NOTE(review): a new WLHash is constructed per record; it could presumably be
# created once outside the loop -- confirm it keeps no per-call state.
for value in data:
    value['wl'] = WLHash().weisfeiler_lehman_graph_hash(value['RC'])
# Cluster the records, passing 'RC' as the rule key and 'wl' as the attribute key.
cls = GraphCluster()
result = cls.fit(data, rule_key='RC', attribute_key='wl') # quick filter

# 4. Rule

In [None]:
from synkit.Utils.utils import stratified_random_sample
from synkit.Rule.syn_rule import SynRule
# Sample from the clustered records, stratified on the 'class' field
# (third argument 1 -- presumably one record per class; confirm against
# stratified_random_sample), then keep only each sampled record's 'ITS' graph.
rule = stratified_random_sample(result, 'class', 1)
rule = [value['ITS'] for value in rule]

In [None]:
# Notebook cell output: the number of sampled ITS graphs in `rule`.
len(rule)

# 5. Rule application

In [None]:
# Shallow copy of `result`: a new sequence object whose elements are shared.
df = result[:]

In [None]:
from __future__ import annotations
from typing import Dict, Any, List, Optional, Union
import copy
import logging
import traceback
import os

from joblib import Parallel, delayed
import concurrent.futures

try:
    from tqdm import tqdm  # optional
except Exception:
    tqdm = None  # type: ignore

from synkit.IO.chem_converter import rsmi_to_graph, rsmi_to_its
from synkit.Graph.ITS.its_decompose import get_rc
from synkit.Chem.Reaction.standardize import Standardize
from synkit.Synthesis.Reactor.syn_reactor import SynReactor

logger = logging.getLogger(__name__)


def process_reaction_entry(
    entry: Dict[str, Any],
    smart_key: str = "aam",
    explicit_h: bool = False,
    invert: bool = False,
    implicit_temp: bool = True,
    strategy: str = "all",
    embed_pre_filter: bool = False,
    return_reactor: bool = False,
) -> Union[List[str], Dict[str, Any]]:
    """Apply a reaction's own reaction center to one of its sides via ``SynReactor``.

    Steps, in order:

    1. Read and validate ``entry[smart_key]``.
    2. ``Standardize().fit(smi, remove_aam=True)`` to get the reaction string
       handed to ``rsmi_to_graph``, which yields a ``(left, right)`` pair
       (presumably the reactant and product graphs).
    3. ``rsmi_to_its(smi)`` followed by ``get_rc`` to obtain the reaction
       center from the original (mapped) string.
    4. Build ``SynReactor(substrate, rc, ...)`` where ``substrate`` is
       ``right`` when ``invert`` is true and ``left`` otherwise.
    5. Read ``reactor.smarts``; if it is missing or ``None``, fall back to
       ``list(reactor.get_smarts())``.

    Args:
        entry: Record holding the reaction string.
        smart_key: Key of the reaction string inside ``entry``.
        explicit_h: Forwarded to ``SynReactor``.
        invert: Selects the substrate side and is forwarded to ``SynReactor``.
        implicit_temp: Forwarded to ``SynReactor``.
        strategy: Forwarded to ``SynReactor``.
        embed_pre_filter: Forwarded to ``SynReactor``.
        return_reactor: Return ``{"smarts": ..., "reactor": ...}`` instead of
            only the SMARTS.

    Returns:
        The SMARTS obtained from the reactor, or a dict containing the SMARTS
        and the reactor when ``return_reactor`` is true.

    Raises:
        KeyError: ``smart_key`` is not in ``entry``.
        ValueError: The value is not a non-empty string, or no reaction
            center could be determined.
        RuntimeError: Any processing step failed; the original exception is
            attached as ``__cause__``.
    """
    if smart_key not in entry:
        raise KeyError(f"Entry missing required key '{smart_key}'")
    smi = entry[smart_key]
    if not isinstance(smi, str) or not smi.strip():
        raise ValueError(f"Value for '{smart_key}' must be a non-empty SMILES string")

    try:
        rsmi = Standardize().fit(smi, remove_aam=True)
    except Exception as e:
        logger.exception("Standardization failed for SMILES: %s", smi)
        raise RuntimeError("Standardize().fit(...) failed") from e

    try:
        left, right = rsmi_to_graph(rsmi, drop_non_aam=False, use_index_as_atom_map=False)
    except Exception as e:
        logger.exception("rsmi_to_graph failed for RSMI: %s", rsmi)
        raise RuntimeError("rsmi_to_graph(...) failed") from e

    try:
        # The reaction center comes from the original mapped string, not the
        # unmapped standardized one.
        its = rsmi_to_its(smi)
        rc = get_rc(its)
    except Exception as e:
        logger.exception("ITS decomposition or get_rc failed for SMILES: %s", smi)
        raise RuntimeError("rsmi_to_its(...) or get_rc(...) failed") from e
    if rc is None:
        raise ValueError("Reaction center (rc) could not be determined from ITS")

    substrate = right if invert else left
    try:
        reactor = SynReactor(
            substrate,
            rc,
            explicit_h=explicit_h,
            invert=invert,
            strategy=strategy,
            embed_pre_filter=embed_pre_filter,
            implicit_temp=implicit_temp,
            automorphism=True,
        )
    except Exception as e:
        logger.exception("SynReactor initialization failed")
        raise RuntimeError("SynReactor(...) failed") from e

    smarts = getattr(reactor, "smarts", None)
    if smarts is None:
        try:
            smarts = list(reactor.get_smarts())
        except Exception as e:
            # Keep the traceback and the cause, like every other step above.
            logger.exception("SynReactor returned no 'smarts' attribute or get_smarts method")
            raise RuntimeError("Could not obtain SMARTS from SynReactor instance") from e

    if return_reactor:
        return {"smarts": smarts, "reactor": reactor}
    return smarts


def dict_process(*args, **kwargs):
    """Call ``process_reaction_entry``, accepting ``strat`` as an alias of ``strategy``.

    The alias is only translated when ``strategy`` itself was not given; if
    both are supplied, ``strat`` is passed through untouched.
    """
    if "strategy" not in kwargs and "strat" in kwargs:
        legacy_value = kwargs.pop("strat")
        kwargs["strategy"] = legacy_value
    return process_reaction_entry(*args, **kwargs)


def _safe_process_pair(entry: Dict[str, Any], *, process_fn, catch_traceback: bool = True) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    try:
        fw = process_fn(entry, invert=False)
        out["fw"] = fw
    except Exception as e:
        logger.exception("Forward processing failed for entry (preview): %r", entry.get("id") or entry.get("aam"))
        out["fw_error"] = str(e)
        if catch_traceback:
            out["fw_traceback"] = traceback.format_exc()
    try:
        bw = process_fn(entry, invert=True)
        out["bw"] = bw
    except Exception as e:
        logger.exception("Backward processing failed for entry (preview): %r", entry.get("id") or entry.get("aam"))
        out["bw_error"] = str(e)
        if catch_traceback:
            out["bw_traceback"] = traceback.format_exc()
    return out


def _run_entry(entry, process_fn, mutate_inplace):
    """Run ``_safe_process_pair`` on one entry and merge the output into it (or a deep copy)."""
    base = entry if mutate_inplace else copy.deepcopy(entry)
    base.update(_safe_process_pair(entry, process_fn=process_fn))
    return base


def process_entries_parallel(
    entries,
    *,
    process_fn,
    n_jobs: int = -1,
    backend: str = "loky",
    batch_size: Optional[int] = None,
    mutate_inplace: bool = True,
    show_progress: bool = False,
    use_threads_for_progress: bool = True,
):
    """Run ``_safe_process_pair`` over many entries in parallel.

    Two execution paths exist:

    * ``show_progress`` true and ``tqdm`` importable: a
      ``concurrent.futures`` executor with a tqdm progress bar. ``backend``
      and ``batch_size`` are not used on this path.
    * otherwise: ``joblib.Parallel``; if joblib itself fails, the entries are
      processed sequentially instead.

    On both paths the returned list is aligned with the input order.

    Args:
        entries: Iterable of entry dicts.
        process_fn: Callable accepting ``(entry, invert=...)``.
        n_jobs: Worker count; ``None``/``0`` means 1. On the executor path a
            negative value means ``os.cpu_count()`` workers.
        backend: joblib backend (joblib path only).
        batch_size: joblib batch size; ``None`` selects joblib's ``"auto"``.
        mutate_inplace: Merge results into the entry object itself instead of
            a deep copy. With process-based workers the worker receives a
            pickled copy, so the caller's dicts are presumably left
            untouched and only the returned list carries the results.
        show_progress: Use the executor path with a progress bar.
        use_threads_for_progress: Use threads (true) or processes (false) on
            the executor path.

    Returns:
        One dict per entry, in input order. If a worker itself fails on the
        executor path, that position holds ``{"_worker_error": traceback}``.
    """
    entries_list = list(entries)
    if not entries_list:
        return []
    if n_jobs is None or n_jobs == 0:
        n_jobs = 1
    n_workers = (os.cpu_count() or 1) if n_jobs < 0 else n_jobs

    if show_progress and tqdm is not None:
        executor_cls = (
            concurrent.futures.ThreadPoolExecutor
            if use_threads_for_progress
            else concurrent.futures.ProcessPoolExecutor
        )
        # Pre-sized so each result lands at its entry's original position;
        # futures finish in arbitrary order.
        results = [None] * len(entries_list)
        with executor_cls(max_workers=n_workers) as ex:
            # Submitting a module-level function (not a nested closure) keeps
            # the work item picklable for ProcessPoolExecutor.
            index_of = {
                ex.submit(_run_entry, e, process_fn, mutate_inplace): i
                for i, e in enumerate(entries_list)
            }
            try:
                for fut in tqdm(
                    concurrent.futures.as_completed(index_of),
                    total=len(index_of),
                    desc="process_entries",
                ):
                    try:
                        results[index_of[fut]] = fut.result()
                    except Exception:
                        tb = traceback.format_exc()
                        logger.exception("Worker raised unexpectedly: %s", tb)
                        results[index_of[fut]] = {"_worker_error": tb}
            except KeyboardInterrupt:
                logger.warning("KeyboardInterrupt received - cancelling pending futures")
                # Cancel before leaving the ``with`` block, whose exit waits
                # for every future that is still scheduled.
                for fut in index_of:
                    fut.cancel()
                raise
        return results

    try:
        parallel = Parallel(
            n_jobs=n_jobs,
            backend=backend,
            # joblib accepts only "auto" or a positive integer here; passing
            # None raises and would silently trigger the sequential fallback.
            batch_size="auto" if batch_size is None else batch_size,
        )
        return parallel(
            delayed(_run_entry)(e, process_fn, mutate_inplace) for e in entries_list
        )
    except KeyboardInterrupt:
        logger.warning("KeyboardInterrupt propagated from joblib.Parallel")
        raise
    except Exception:
        logger.exception("joblib.Parallel failed; falling back to sequential processing")
        return [_run_entry(e, process_fn, mutate_inplace) for e in entries_list]

In [None]:
import logging
from typing import Dict, Tuple, List

# Snapshot of the logging configuration taken by force_disable_logging();
# an empty dict means nothing is currently saved.
_prev_logging_state: Dict = {}

def force_disable_logging() -> None:
    """
    Silence Python logging process-wide (best-effort) and remember the old state.

    For every logger known to the logging manager, the level, propagate flag
    and handlers are saved, then the handlers are removed, the level is
    raised above CRITICAL and propagation is switched off. The root logger
    gets the same treatment, and finally ``logging.disable(CRITICAL)`` is
    applied. Calling this again while a snapshot exists does nothing.
    Call restore_logging() to bring things back.
    """
    if _prev_logging_state:
        return  # already disabled

    silence_level = logging.CRITICAL + 10
    manager = logging.root.manager
    saved_loggers = {}
    _prev_logging_state["disabled"] = manager.disable
    _prev_logging_state["root_level"] = logging.getLogger().level
    _prev_logging_state["loggers"] = saved_loggers

    # Each step is attempted on its own so one failure does not block the rest.
    silencing_steps = (
        lambda target: setattr(target, "handlers", []),
        lambda target: target.setLevel(silence_level),
        lambda target: setattr(target, "propagate", False),
    )

    # Iterate over a copy because the manager's table may change underneath us.
    for name, candidate in list(manager.loggerDict.items()):
        if not isinstance(candidate, logging.Logger):
            continue  # the table also holds non-Logger entries
        saved_loggers[name] = (
            candidate.level,
            candidate.propagate,
            list(candidate.handlers),
        )
        for step in silencing_steps:
            try:
                step(candidate)
            except Exception:
                # some loggers may not allow direct reassignment; ignore
                pass

    # The root logger is handled separately from the named loggers above.
    try:
        root = logging.getLogger()
        _prev_logging_state["root_handlers"] = list(root.handlers)
        root.handlers = []
        root.setLevel(silence_level)
    except Exception:
        pass

    # Globally disable every level up to and including CRITICAL.
    logging.disable(logging.CRITICAL)


def restore_logging() -> None:
    """
    Undo force_disable_logging() using the saved snapshot.

    Restores, in order: the global disable threshold, the root logger's
    handlers and level, and each saved logger's handlers, level and propagate
    flag. The snapshot is cleared afterwards. Does nothing when no snapshot
    is stored, so it is safe to call at any time.
    """
    snapshot = _prev_logging_state
    if not snapshot:
        return

    manager = logging.root.manager

    # Global disable threshold; fall back to "nothing disabled" on failure.
    try:
        manager.disable = snapshot.get("disabled", logging.NOTSET)
        logging.disable(manager.disable)
    except Exception:
        logging.disable(logging.NOTSET)

    # Root logger.
    try:
        root = logging.getLogger()
        root.handlers = snapshot.get("root_handlers", [])
        root.setLevel(snapshot.get("root_level", logging.NOTSET))
    except Exception:
        pass

    # Named loggers; each one is restored independently of the others.
    saved_loggers = snapshot.get("loggers", {})
    for name, (level, propagate, handlers) in saved_loggers.items():
        try:
            target = logging.getLogger(name)
            target.handlers = handlers
            target.setLevel(level)
            target.propagate = propagate
        except Exception:
            pass

    snapshot.clear()


In [None]:
# Silence logging before the run. Nothing in this cell re-enables it; call
# restore_logging() afterwards to get log output back.
force_disable_logging()
processed = process_entries_parallel(
    df[:],  # shallow copy of the sequence; its elements are shared with `df`
    process_fn=process_reaction_entry,
    n_jobs=4,
    backend="loky",  # joblib-only option; not used when the progress-bar path below is taken
    mutate_inplace=True,        # results are merged into the entry dicts themselves, not copies
        show_progress=True,        # True selects the concurrent.futures + tqdm path (when tqdm is importable)
)