In [31]:
import os
import pickle
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr

from CGRtools import smiles
from CGRtools.reactor.reactor import Reactor
from synplan.chem.precursor import Precursor
from synplan.chem.reaction import apply_reaction_rule
from synplan.mcts.expansion import PolicyNetworkFunction
from synplan.utils.config import PolicyNetworkConfig
from synplan.utils.loading import load_reaction_rules

In [32]:
from collections import defaultdict

def get_applied_stat(mol, reaction_rules, reactors=None):
    """
    Apply every reaction rule to a molecule and report which rules fired.

    A rule counts as "applied" when ``apply_reaction_rule`` yields at least
    one result for the molecule.

    Parameters
    ----------
    mol
        Molecule to test. It is wrapped in a ``Precursor`` and that
        precursor's ``.molecule`` is passed to every rule.
    reaction_rules
        Sequence of ``(rule_obj, _)`` pairs; only the first element of each
        pair is used. ``rule_obj.meta`` is read for an optional ``"label"``.
    reactors : sequence, optional
        Pre-built ``Reactor`` objects, one per entry of ``reaction_rules`` and
        in the same order (e.g. ``[Reactor(r) for r, _ in reaction_rules]``).
        Pass these when calling this function for many molecules so the
        reactors are not rebuilt on every call. When ``None`` (default) each
        ``Reactor`` is built here, as before.

    Returns
    -------
    dict
        ``"APPLIED_RULES"``: indices into ``reaction_rules`` of the rules that
        applied, in ascending order (each index at most once).
        ``"BY_LABEL"``: maps each rule label (``"unknown"`` when the rule's
        meta has no ``"label"``) to the list of applied rule indices with
        that label.

    Raises
    ------
    ValueError
        If ``reactors`` is given and its length differs from ``reaction_rules``.
    """
    rule_objs = [rule_obj for rule_obj, _ in reaction_rules]

    if reactors is None:
        # Build lazily, one per rule, interleaved with application (as before).
        reactor_iter = (Reactor(rule_obj) for rule_obj in rule_objs)
    else:
        if len(reactors) != len(rule_objs):
            raise ValueError(
                f"reactors has {len(reactors)} entries but reaction_rules "
                f"has {len(rule_objs)}"
            )
        reactor_iter = reactors

    prec = Precursor(mol)
    applied_rule_indices = []
    applied_by_label = defaultdict(list)

    for i, (rule_obj, rule) in enumerate(zip(rule_objs, reactor_iter)):
        # We only need to know whether the rule yields anything, so stop at the
        # first result instead of materialising every product set with list().
        if not any(True for _ in apply_reaction_rule(prec.molecule, rule)):
            continue
        applied_rule_indices.append(i)
        label = rule_obj.meta.get("label", "unknown")
        applied_by_label[label].append(i)

    return {
        "APPLIED_RULES": applied_rule_indices,
        "BY_LABEL": dict(applied_by_label)
    }

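### 1. Calculate the true rule application rate

Apply every reaction rule to each molecule and count how often each rule (and each rule label) yields at least one product.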

In [33]:
# Pickled rule set; each entry is unpacked downstream as a (rule_obj, _) pair.
# NOTE: pickle.load can execute arbitrary code — only load rule files you trust.
RULES_PATH = "training_hybrid/reaction_rules.pickle"
with open(RULES_PATH, mode="rb") as f:
    reaction_rules = pickle.load(f)

In [40]:
# Optional cap on how many molecules to profile:
#   None -> the full set; an int (e.g. 1_000) -> only the first N, for a quick run.
N_MOLECULES = None

# Only the first column is used as the SMILES string.
# NOTE(review): read_csv splits on commas, so this assumes one SMILES per line
# with no tab/space-separated name column — TODO confirm against the .smi file.
smi_list = pd.read_csv(
    "synplan_data/chembl/molecules_for_filtering_policy_training_all.smi",
    header=None,
)[0].to_list()
smi_list = smi_list[:N_MOLECULES]

In [None]:
# Apply every rule to every molecule and accumulate application statistics,
# checkpointing to `save_path` so a long run can be inspected while it runs.
save_every = 50  # save progress every N molecules
save_path = "rule_profile.pkl"

res = defaultdict(int)                # global scalar stats (+ "BY_LABEL"/"RULE_COUNTS" snapshots)
rule_counts = defaultdict(int)        # rule index -> number of molecules the rule applied to
res_by_label = defaultdict(int)       # label -> total applications of rules with that label


def _save_rule_profile():
    """Snapshot the accumulators into `res` and pickle a plain-dict copy to `save_path`.

    Writes to a temporary file and then renames it over `save_path` with
    os.replace, so interrupting the kernel mid-dump cannot leave a truncated
    checkpoint behind.
    """
    res["BY_LABEL"] = dict(res_by_label)
    res["RULE_COUNTS"] = dict(rule_counts)
    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(dict(res), f)
    os.replace(tmp_path, save_path)


for n, smi in enumerate(smi_list, start=1):
    # 1. Read molecule
    mol = smiles(smi)
    # NOTE(review): 2D layout is presumably not needed for rule matching —
    # confirm; dropping it would save time per molecule.
    mol.clean2d()

    # 2. Which rules apply to this molecule
    stats = get_applied_stat(mol, reaction_rules)
    applied = stats["APPLIED_RULES"]

    # 3. Accumulate results
    res["NUM_MOLS_TRIED"] += 1
    res["TOTAL_APPLIED"] += len(applied)

    for label, rule_indices in stats["BY_LABEL"].items():
        res_by_label[label] += len(rule_indices)

    for ri in applied:
        rule_counts[ri] += 1

    # 4. Periodic checkpoint
    if n % save_every == 0:
        _save_rule_profile()
        print(f"[{n}] Progress saved to {save_path}", end="\r")

# --- Final save ---
_save_rule_profile()