In [None]:
from __future__ import annotations

import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from tqdm.auto import tqdm
from synkit.Chem.Reaction.standardize import Standardize
from synkit.Chem.Reaction.canon_rsmi import CanonRSMI

log = logging.getLogger(__name__)

def process_aam(
    entries: Iterable[Dict[str, Any]],
    *,
    rxn_fn: Optional[Callable[[Any], Any]] = None,
    canon_fn: Optional[Callable[[Any], Any]] = None,
    std: Optional[Standardize] = None,
    canon: Optional[CanonRSMI] = None,
    reactions_key: str = "reactions",
    rxn_key: str = "rxn",
    gt_key: str = "ground_truth",
    inplace: bool = False,
    swallow_exceptions: bool = True,
    return_failures: bool = False,
    progress: bool = True,
    progress_desc: Optional[str] = None,
    progress_disable: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]]:
    """Apply a reaction transform and a ground-truth transform to every entry.

    For each entry, ``rxn_fn`` is called on ``entry[reactions_key]`` and the
    result is stored under ``rxn_key``. If ``entry[gt_key]`` is present and not
    ``None`` it is replaced by ``canon_fn(entry[gt_key])``. Entries whose
    ``rxn_key`` value is truthy are kept; their ``reactions_key`` field is
    dropped (unless it is the same key as ``rxn_key``). All other entries are
    left out of the result.

    Parameters
    ----------
    entries:
        Iterable of dict-like records.
    rxn_fn:
        Callable applied to the raw reactions value. Defaults to
        ``std.fit`` (a fresh ``Standardize()`` is created when ``std`` is None).
    canon_fn:
        Callable applied to the ground-truth value. Defaults to
        ``canon.canonicalise(value).canonical_rsmi`` (a
        ``CanonRSMI(backend="wl", wl_iterations=3)`` is created when ``canon``
        is None).
    std, canon:
        Pre-built helper objects used only to construct the default callables.
    reactions_key, rxn_key, gt_key:
        Names of the input field, the output field and the ground-truth field.
    inplace:
        When True the entry dicts are mutated directly and, if ``entries``
        supports ``clear``/``extend``, its contents are replaced by the kept
        entries. When False each entry is shallow-copied first.
    swallow_exceptions:
        When True an exception raised while processing an entry is recorded in
        the failure list, and the entry gets ``rxn_key`` and ``gt_key`` set to
        ``None``. When False the exception is re-raised.
    return_failures:
        When True return ``(processed, failures)`` instead of ``processed``.
    progress, progress_desc, progress_disable:
        Wrap the iteration in a ``tqdm`` bar when ``progress`` is True;
        the other two are forwarded as ``desc`` and ``disable``.

    Returns
    -------
    list of dict, or (list of dict, list of dict)
        The kept entries, optionally with failure records of the form
        ``{"index": int, "entry": dict, "error": Exception}``.

    Raises
    ------
    Exception
        Whatever ``rxn_fn``/``canon_fn`` raise, when ``swallow_exceptions`` is
        False.
    """
    if inplace:
        working: List[Dict[str, Any]] = entries  # type: ignore[assignment]
    else:
        # Shallow copies: top-level keys of the caller's dicts stay untouched.
        working = [dict(e) for e in entries]

    if rxn_fn is None:
        if std is None:
            std = Standardize()
        rxn_fn = std.fit

    if canon_fn is None:
        if canon is None:
            canon = CanonRSMI(backend="wl", wl_iterations=3)

        def _canon_fn(gt_val: Any) -> Any:
            if gt_val is None:
                return None
            return canon.canonicalise(gt_val).canonical_rsmi

        canon_fn = _canon_fn

    processed: List[Dict[str, Any]] = []
    failures: List[Dict[str, Any]] = []

    bar = None
    iterator: Iterable[Dict[str, Any]] = working
    if progress:
        try:
            total = len(working)  # type: ignore[arg-type]
        except TypeError:
            # Unsized iterables (e.g. generators) simply get a bar without a total.
            total = None
        bar = tqdm(working, desc=(progress_desc or "process_aam"), disable=progress_disable, total=total)
        iterator = bar

    try:
        for idx, entry in enumerate(iterator):
            try:
                entry[rxn_key] = rxn_fn(entry.get(reactions_key))
                if entry.get(gt_key) is not None:
                    entry[gt_key] = canon_fn(entry[gt_key])
            except Exception as exc:
                log.debug("process_aam: error processing entry %s: %s", idx, exc, exc_info=True)
                if not swallow_exceptions:
                    raise
                entry[rxn_key] = None
                entry[gt_key] = None
                failures.append({"index": idx, "entry": dict(entry), "error": exc})

            if entry.get(rxn_key):
                # When input and output share one key, popping would discard
                # the result that was just stored.
                if reactions_key != rxn_key:
                    entry.pop(reactions_key, None)
                processed.append(entry)
    finally:
        # Close the bar even when an exception propagates out of the loop.
        if bar is not None:
            try:
                bar.close()
            except Exception:
                log.debug("process_aam: could not close progress bar", exc_info=True)

    if inplace:
        try:
            entries.clear()  # type: ignore[attr-defined]
            entries.extend(processed)  # type: ignore[arg-type]
        except Exception:
            # Best effort: containers without clear/extend are left as they are.
            log.debug("Could not replace original 'entries' in-place; returning processed list.")

    if return_failures:
        return processed, failures

    return processed


In [None]:
# Notebook setup: extend the import path, import the I/O helpers, and define the dataset URLs.
import sys
# Put the parent directory on the module search path before the `synrxn` import below
# (presumably that is where the package lives — confirm).
sys.path.append('../')
from synrxn.io.io import load_json_from_raw_github, load_df_gz, save_df_gz
from synkit.IO import configure_warnings_and_logs
# NOTE(review): the meaning of the two positional flags is not visible here — check synkit.IO.
configure_warnings_and_logs(True, True)
# Raw-GitHub URLs of the gzipped JSON files, one per dataset folder under results_benchmark.
ecoli = "https://raw.githubusercontent.com/TieuLongPhan/SynTemp/main/Data/AAM/results_benchmark/ecoli/ecoli_aam_reactions.json.gz"
# NOTE(review): the variable is spelled `recond3d` while the dataset folder is `recon3d`.
recond3d = "https://raw.githubusercontent.com/TieuLongPhan/SynTemp/main/Data/AAM/results_benchmark/recon3d/recon3d_aam_reactions.json.gz"
golden = "https://raw.githubusercontent.com/TieuLongPhan/SynTemp/main/Data/AAM/results_benchmark/golden/golden_aam_reactions.json.gz"
natcomm = "https://raw.githubusercontent.com/TieuLongPhan/SynTemp/main/Data/AAM/results_benchmark/natcomm/natcomm_aam_reactions.json.gz"
uspto3k = "https://raw.githubusercontent.com/TieuLongPhan/SynTemp/main/Data/AAM/results_benchmark/uspto_3k/uspto_3k_aam_reactions.json.gz"


In [None]:
# E. coli dataset: load the raw entries, run process_aam on them, and save the result.
import pandas as pd
# NOTE(review): `ecoli` holds the URL string before this line and the loaded data after it,
# so re-running this cell without first re-running the URL definitions passes data to the loader.
ecoli = load_json_from_raw_github(ecoli, as_frame=False)
# `progress=True` and `rxn_key='rxn'` match the defaults of process_aam.
ecoli = process_aam(ecoli, progress=True, rxn_key='rxn')
# Number of entries kept by process_aam.
print(len(ecoli))
save_df_gz(pd.DataFrame(ecoli), '../Data/aam/ecoli.csv.gz')

In [None]:
# Recon3D dataset: load the raw entries, run process_aam on them, and save the result.
# NOTE(review): `recond3d` is rebound from the URL string to the loaded data, so this cell
# cannot be re-run without first re-running the URL definitions.
recond3d = load_json_from_raw_github(recond3d, as_frame=False)
recond3d = process_aam(recond3d)
# Number of entries kept by process_aam.
print(len(recond3d))
# Dataset-specific file name, so this output does not overwrite '../Data/aam/ecoli.csv.gz'.
save_df_gz(pd.DataFrame(recond3d), '../Data/aam/recon3d.csv.gz')

In [None]:
# Golden dataset: load the raw entries, run process_aam on them, and save the result.
# NOTE(review): `golden` is rebound from the URL string to the loaded data, so this cell
# cannot be re-run without first re-running the URL definitions.
golden = load_json_from_raw_github(golden, as_frame=False)
golden = process_aam(golden)
# Number of entries kept by process_aam.
print(len(golden))
# Dataset-specific file name, so this output does not overwrite '../Data/aam/ecoli.csv.gz'.
save_df_gz(pd.DataFrame(golden), '../Data/aam/golden.csv.gz')

In [None]:
# NatComm dataset: load the raw entries, run process_aam on them, and save the result.
# NOTE(review): `natcomm` is rebound from the URL string to the loaded data, so this cell
# cannot be re-run without first re-running the URL definitions.
natcomm = load_json_from_raw_github(natcomm, as_frame=False)
natcomm = process_aam(natcomm)
# Number of entries kept by process_aam.
print(len(natcomm))
# Dataset-specific file name, so this output does not overwrite '../Data/aam/ecoli.csv.gz'.
save_df_gz(pd.DataFrame(natcomm), '../Data/aam/natcomm.csv.gz')

In [None]:
# USPTO-3k dataset: load the raw entries, run process_aam on them, and save the result.
# NOTE(review): `uspto3k` is rebound from the URL string to the loaded data, so this cell
# cannot be re-run without first re-running the URL definitions.
uspto3k = load_json_from_raw_github(uspto3k, as_frame=False)
uspto3k = process_aam(uspto3k)
# Number of entries kept by process_aam.
print(len(uspto3k))
# Dataset-specific file name, so this output does not overwrite '../Data/aam/ecoli.csv.gz'.
save_df_gz(pd.DataFrame(uspto3k), '../Data/aam/uspto_3k.csv.gz')