In [None]:
import io
import json
import gzip
import requests
import pandas as pd
from typing import Optional, Union


def load_json_from_raw_github(
    url: str,
    as_frame: bool = True,
    lines: Optional[bool] = None,
    timeout: Optional[float] = 60.0,
) -> Union[pd.DataFrame, list, dict]:
    """Download a (optionally gzip-compressed) JSON or NDJSON file and parse it.

    Parameters
    ----------
    url:
        Direct download URL, e.g. a GitHub ``/raw/`` link.
    as_frame:
        If True, try ``pandas.read_json`` first and fall back to plain Python
        objects when pandas cannot build a table from the payload.
    lines:
        True forces NDJSON (one JSON document per line), False forces a single
        JSON document, None lets a simple heuristic decide.
    timeout:
        Seconds passed to ``requests.get``; None waits indefinitely.

    Returns
    -------
    A DataFrame when ``as_frame`` is True and pandas succeeds; otherwise the
    parsed JSON (a ``list`` of records in NDJSON mode, else whatever
    ``json.loads`` returns).

    Raises
    ------
    requests.HTTPError
        If the server responds with an error status.
    gzip.BadGzipFile / EOFError
        If the payload has a gzip header but is corrupt or truncated.
    json.JSONDecodeError
        If the payload is not valid JSON / NDJSON.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    raw = response.content

    # Detect gzip by its magic number rather than by trial decompression, so a
    # corrupt archive raises instead of being silently decoded as plain text.
    if raw[:2] == b"\x1f\x8b":
        raw = gzip.decompress(raw)
    # utf-8-sig also strips a leading BOM, which json.loads would reject.
    text = raw.decode("utf-8-sig")

    if lines is None:
        # Heuristic: NDJSON starts with an object and has another object
        # starting at the beginning of a later line.
        lines_mode = text.lstrip().startswith("{") and "\n{" in text
    else:
        lines_mode = bool(lines)

    def _parse_plain():
        # Plain-Python parse that honours NDJSON mode.
        if lines_mode:
            return [json.loads(line) for line in text.splitlines() if line.strip()]
        return json.loads(text)

    if as_frame:
        try:
            return pd.read_json(io.StringIO(text), lines=lines_mode)
        except ValueError:
            # pandas could not build a table (nested / non-tabular data).
            # Fall back to plain objects; for NDJSON this must be per-line,
            # since json.loads on the whole text would fail.
            return _parse_plain()
    return _parse_plain()


In [None]:
url = "https://github.com/TieuLongPhan/SynKit/raw/refs/heads/att_graph/Data/Benchmark/data_aam.json.gz"
# NOTE: despite its name, `df` holds the parsed JSON (not a DataFrame,
# since as_frame=False); the next cell iterates over it as records.
df = load_json_from_raw_github(url=url, as_frame=False)

In [None]:
from joblib import Parallel, delayed
from synkit.Chem.Reaction.canon_rsmi import CanonRSMI
from synkit.IO.debug import configure_warnings_and_logs

# Configure synkit warnings/logging (positional flags as defined by synkit).
configure_warnings_and_logs(False, True)

# Record keys whose values are canonicalised below, one per atom mapper.
maps = ["rxn_mapper", "graphormer", "local_mapper"]

def _canon_task(value, mapper):
    smi = value.get(mapper)
    if not smi:
        return None
    try:
        
        return CanonRSMI().canonicalise(smi).canonical_rsmi
    except Exception as exc:
        return None

# One task per (record, mapper) pair; joblib returns results in task order.
tasks = [(row, mapper) for row in df for mapper in maps]
results = Parallel(n_jobs=4, verbose=2, backend="loky")(
    delayed(_canon_task)(*task) for task in tasks
)

# Drop failures (None) and keep every successful canonical RSMI string.
final = [rsmi for rsmi in results if rsmi is not None]


In [None]:
# Deduplicate the canonical reaction SMILES in a reproducible order.
# list(set(...)) orders strings by hash, which changes with PYTHONHASHSEED
# between kernel restarts; later cells assign 'R-{idx}' ids by position in
# this list and drop a row by index, so the order must be stable.
final = sorted(set(final))
print(len(final))

In [None]:
from synkit.Chem.Reaction.balance_check import BalanceReactionCheck
from synkit.Chem.Reaction.standardize import Standardize

# Shared helpers read as globals by `_process` in the next cell.
std = Standardize()
check = BalanceReactionCheck()

In [None]:

def _process(idx, value):
    try:
        rxn = std.fit(value, True)
        balance = check.rsmi_balance_check(rxn)
        new = {
            'R-id': f'R-{idx}',
            'rxn': rxn,
            'aam': value,
            'balance': balance
        }
        return idx, new
    except Exception as exc:
        
        return idx, None

# Standardise every unique reaction and flag whether it is balanced.
tasks = list(enumerate(final))
try:
    results = Parallel(n_jobs=4, backend="loky", verbose=2)(
        delayed(_process)(idx, val) for idx, val in tasks
    )
except Exception:
    # Fallback if objects aren't picklable for the process-based loky
    # backend: threads share memory, so nothing needs pickling. The chosen
    # backend must actually be passed to Parallel for the fallback to differ.
    backend = "threading"
    results = Parallel(n_jobs=4, backend=backend, verbose=2)(
        delayed(_process)(idx, val) for idx, val in tasks
    )

# Parallel already returns results in task order; sorting is a cheap guard.
results.sort(key=lambda t: t[0])
# Keep only reactions that were processed successfully and are balanced.
data = [item for _, item in results if item is not None and item['balance']]

In [None]:
from joblib import Parallel, delayed
from synkit.IO import rsmi_to_its

def _compute_its(idx, item):
    try:
        its = rsmi_to_its(item['aam'], core=False)
        rc  = rsmi_to_its(item['aam'], core=True)
        return idx, its, rc
    except Exception:
        return idx, None, None


results = Parallel(n_jobs=4, backend="loky", verbose=2)(
    delayed(_compute_its)(position, record) for position, record in enumerate(data)
)

# Write the graphs back onto the records in place; failed conversions
# store None under both keys.
for position, its_graph, rc_graph in results:
    data[position]['ITS'] = its_graph
    data[position]['RC'] = rc_graph


In [None]:
from synkit.Graph.Hyrogen.hcomplete import HComplete

comp = HComplete()
complete = comp.process_graph_data_parallel(
    data,
    its_key='ITS',
    rc_key='RC',
    n_jobs=4,
    verbose=2,
)
# Records whose 'ITS' is None after HComplete (presumably those where
# hydrogen completion was ambiguous or failed — confirm against synkit).
amb_hydrogen = [record for record in complete if record['ITS'] is None]

In [None]:
# Keep only the identifying fields; ITS/RC are recomputed in the next cell.
amb_hydrogen = [
    {key: record[key] for key in ('R-id', 'rxn', 'aam')}
    for record in amb_hydrogen
]

In [None]:
results = Parallel(n_jobs=4, backend="loky", verbose=2)(
    delayed(_compute_its)(position, record)
    for position, record in enumerate(amb_hydrogen)
)

# Attach freshly computed graphs to the ambiguous-hydrogen records in place.
for position, its_graph, rc_graph in results:
    amb_hydrogen[position]['ITS'] = its_graph
    amb_hydrogen[position]['RC'] = rc_graph


In [None]:
from synkit.Graph.Matcher.graph_cluster import GraphCluster
from synkit.Utils.utils import stratified_random_sample
# Cluster the ambiguous-hydrogen reactions on their 'RC' graphs; the result is
# presumably annotated with a 'class' label (used below) — confirm in synkit.
cls = GraphCluster()
result = cls.fit(amb_hydrogen, 'RC', None)
# Stratified subsample by 'class'. The positional 1 and 42 look like
# (per-class count, random seed) — TODO confirm against synkit's signature.
# NOTE(review): this rebinds `data`, replacing the earlier pipeline output.
data = stratified_random_sample(result, 'class', 1, 42)



In [None]:
import re
from typing import Iterable, Tuple, List, Pattern

def _compile_bracketed_elements_pattern(elements: Tuple[str, ...]) -> Pattern:
    elems_sorted = sorted(elements, key=lambda s: -len(s))
    alt = "|".join(re.escape(e) for e in elems_sorted)
    pattern = rf'\[(?:\d+)?(?:{alt})(?![A-Za-z])(?:[:@+\-\d]*)?\]'
    return re.compile(pattern)

def is_single_bracketed_elements_collapse_dots(smiles: str, elements: Iterable[str] = ('H','O')) -> bool:
    if not smiles:
        return False
    elems = tuple(elements)
    token_re = _compile_bracketed_elements_pattern(elems)

    # find tokens with spans
    tokens: List[Tuple[int,int,str]] = [(m.start(), m.end(), m.group()) for m in token_re.finditer(smiles)]
    if not tokens:
        return False

    effective_count = 0
    prev_token = None
    prev_end = -1

    for start, end, text in tokens:
        if prev_token is not None and text == prev_token:
            # check that the chars between prev_end and start are only dots or whitespace
            between = smiles[prev_end:start]
            if re.fullmatch(r'[\s\.]*', between):
                # same run -> skip counting this token
                prev_end = end
                continue

        # new effective token
        effective_count += 1
        prev_token = text
        prev_end = end

    return effective_count == 1


In [None]:
# Positions (in the current `data`) of reactions with exactly one run of
# bracketed H/O atoms; these are excluded below.
bug = [
    position
    for position, record in enumerate(data)
    if is_single_bracketed_elements_collapse_dots(record['rxn'])
]

data = [record for position, record in enumerate(data) if position not in bug]
# NOTE(review): position 72 is a magic index into the filtered list; prefer
# excluding by 'R-id' so it survives any reordering.
easier = [record for position, record in enumerate(data) if position != 72]


In [None]:
from synkit.IO import save_to_pickle

# Persist the filtered reactions so the baseline section can reload them.
save_to_pickle(data, "./hydrogen.pkl.gz")


# Baseline: run `HExtend` on the filtered reactions saved to `hydrogen.pkl.gz`

In [16]:
from synkit.IO import load_from_pickle
from synkit.Graph.Hyrogen.hextend import HExtend

# Reload the reactions saved to hydrogen.pkl.gz (a locally written, trusted file).
data = load_from_pickle("./hydrogen.pkl.gz")

# For a quick smoke test, HExtend().fit can be run on a slice such as data[50:60].

In [None]:
# `data` was reloaded from hydrogen.pkl.gz, which the save cell writes after
# the `bug` rows have already been removed. Re-applying `bug` here would use
# indices computed on the pre-filter list and drop the wrong rows; `bug` is
# also undefined after a kernel restart. Only the known problematic entry is
# excluded, matching the `easier` list built before saving.
# TODO: identify that entry by its 'R-id' instead of by position.
EXCLUDED_POSITION = 72
easier = [value for key, value in enumerate(data) if key != EXCLUDED_POSITION]


In [None]:
# Fit HExtend on a shallow copy of all remaining reactions with a single job.
result = HExtend().fit(list(easier), 'ITS', 'RC', n_jobs=1, verbose=2)

In [None]:
# Display the HExtend output produced by the previous cell.
result