In [1]:
import pickle
import numpy as np
import pandas as pd
from collections import Counter

from scipy.stats import pearsonr
from synplan.utils.loading import load_reaction_rules

from CGRtools import smiles
from CGRtools.reactor.reactor import Reactor
from synplan.chem.precursor import Precursor
from synplan.chem.reaction import apply_reaction_rule
from synplan.utils.config import PolicyNetworkConfig
from synplan.mcts.expansion import PolicyNetworkFunction

import seaborn as sns
import matplotlib.pyplot as plt

  from pkg_resources import resource_string


In [12]:
def get_applied_stat(mol, policy_function, reaction_rules):
    """Count the reaction rules that produce at least one product for ``mol``.

    Parameters
    ----------
    mol
        Molecule that is wrapped in a ``Precursor``.
    policy_function
        Object exposing ``predict_reaction_rules(precursor, reaction_rules)``;
        its result is consumed as ``(prob, rule, idx)`` triples.
    reaction_rules
        Collection of rules. It is handed to the policy and then iterated
        again here; every rule is wrapped in ``Reactor`` before being applied.

    Returns
    -------
    tuple
        ``(n_applied, n_max)`` where ``n_applied`` is the number of rules
        returned by the policy that gave a non-empty product list, and
        ``n_max`` is the number of rules in ``reaction_rules`` that did.
    """
    precursor = Precursor(mol)

    def yields_products(rule):
        # A rule counts as applied when applying it gives a non-empty result.
        reactor = Reactor(rule)
        return bool(list(apply_reaction_rule(precursor.molecule, reactor)))

    # Materialise the policy predictions before any rule is applied.
    predicted = list(policy_function.predict_reaction_rules(precursor, reaction_rules))

    # Rules suggested by the policy that actually apply to the molecule.
    n_applied = sum(1 for _prob, rule, _idx in predicted if yields_products(rule))

    # Rules from the complete set that apply to the molecule.
    n_max = sum(1 for rule in reaction_rules if yields_products(rule))

    return n_applied, n_max

### 1. One-step expansion rate

In [13]:
# Load the pickled rule entries and keep only the first element of each entry.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open("training_hybrid/reaction_rules.pickle", "rb") as f:
    reaction_rules = [entry[0] for entry in pickle.load(f)]

# Policy network built from the checkpoint, configured with top_rules=50.
policy_config = PolicyNetworkConfig(
    weights_path="training_hybrid/ranking_policy_network/policy_network.ckpt",
    top_rules=50,
)
policy_function = PolicyNetworkFunction(policy_config=policy_config)

In [14]:
# Only the first N_MOLECULES SMILES are evaluated, so let pandas stop parsing
# after that many rows instead of reading the whole file and slicing it.
N_MOLECULES = 100
smi_list = pd.read_csv(
    "synplan_data/chembl/molecules_for_filtering_policy_training_all.smi",
    header=None,  # the file has no header row; column 0 holds the SMILES
    nrows=N_MOLECULES,
)[0].to_list()

In [15]:
# Evaluate each molecule and collect one record of expansion statistics.
records = []
for n, smi in enumerate(smi_list, start=1):

    # Parse the SMILES and compute 2D coordinates
    mol = smiles(smi)
    mol.clean2d()

    # Count applicable rules among the policy predictions and among all rules
    n_applied, n_max = get_applied_stat(mol, policy_function, reaction_rules)

    # 50 mirrors top_rules=50 from the policy configuration cell.
    # Observed rate: applied rules relative to the fixed budget of 50.
    obs_exp_rate = 100 * n_applied / 50
    # Rate relative to what is achievable: at most 50, at most n_max, and
    # at least 1 so the division is always defined.
    act_exp_rate = 100 * n_applied / min(50, max(n_max, 1))

    records.append(
        dict(
            N_APPLIED=n_applied,
            MAX_APPLIED=n_max,
            OBS_EXP_RATE=obs_exp_rate,
            TRUE_EXP_RATE=act_exp_rate,
        )
    )

    # Progress counter overwritten in place via the carriage return
    print(f"{n} / {len(smi_list)}", end="\r")

res = pd.DataFrame(records)

100 / 100

In [20]:
# Round all statistics to one decimal, then display the ten molecules with
# the smallest number of applicable rules.
res = res.round(decimals=1)
res.sort_values("MAX_APPLIED").iloc[:10]

Unnamed: 0,N_APPLIED,MAX_APPLIED,OBS_EXP_RATE,TRUE_EXP_RATE
34,0,16,0.0,0.0
49,1,25,2.0,4.0
10,0,32,0.0,0.0
99,3,42,6.0,7.1
92,4,47,8.0,8.5
78,7,47,14.0,14.9
73,2,49,4.0,4.1
19,1,49,2.0,2.0
43,0,52,0.0,0.0
63,0,54,0.0,0.0


In [22]:
# Column-wise averages over all evaluated molecules, to one decimal place.
res.mean(axis=0).round(decimals=1)

N_APPLIED          7.8
MAX_APPLIED      150.9
OBS_EXP_RATE      15.6
TRUE_EXP_RATE     15.6
dtype: float64