In [1]:
import json
import pickle
import random

import pandas as pd
 
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm
from pathlib import Path

from CGRtools import smiles

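### 5. Ranking policy training

Build a ranking dataset from the reactions and reaction rules, then train the ranking policy network on it.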

In [2]:
from synplan.utils.config import PolicyNetworkConfig
from synplan.ml.training.supervised import create_policy_dataset, run_policy_training

/storage/dmitry/miniforge3/envs/synplan/lib/python3.10/site-packages/lightning_fabric/__init__.py:41: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.


In [3]:
# Hyperparameters for the ranking policy network. The whole config is passed to
# run_policy_training; batch_size is also reused by create_policy_dataset.
training_config = PolicyNetworkConfig(
    policy_type="ranking",  # which policy-network variant to build and train
    # --- architecture ---
    num_conv_layers=5,  # number of graph convolutional layers
    vector_dim=512,  # size of the final embedding vector
    dropout=0.4,  # dropout rate
    # --- optimisation ---
    learning_rate=8e-4,  # optimiser learning rate
    num_epoch=1000,  # number of training epochs
    batch_size=100,  # samples per batch
)

In [4]:
# Input files, relative to the notebook's working directory.
reaction_data_path = "passerini_reactions_stand.smi"  # reactions, one SMILES per line
reaction_rules_path = "passerini_reaction_rules.pickle"  # reaction rules

# Output location for everything produced while training the ranking policy.
ranking_policy_network_folder = Path("ranking_policy_network")
# File the generated training set is written to by create_policy_dataset.
ranking_policy_dataset_path = ranking_policy_network_folder / "ranking_policy_dataset.pt"

In [5]:
# Convert the reactions and reaction rules into a ranking dataset (written to
# ranking_policy_dataset_path) and get back the data module used for training.
datamodule = create_policy_dataset(
    dataset_type="ranking",
    molecules_or_reactions_path=reaction_data_path,
    reaction_rules_path=reaction_rules_path,
    output_path=ranking_policy_dataset_path,
    batch_size=training_config.batch_size,  # keep batches in sync with the training config
    num_cpus=4,  # CPUs used for dataset preparation
)

Training set size: 1489, validation set size: 373


In [6]:
# Train the ranking policy network on CPU using the configuration defined above.
run_policy_training(
    datamodule,  # data module returned by create_policy_dataset
    config=training_config,  # training hyperparameters
    results_path=ranking_policy_network_folder,  # directory where training results are saved
    silent=True,
    accelerator="cpu",
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Policy network balanced accuracy: 0.675


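### 6. Expansion rate estimation

Check, for a single example molecule, how many of the reaction rules proposed by the trained policy network can actually be applied to it.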

In [7]:
from synplan.mcts.expansion import PolicyNetworkFunction
from synplan.utils.loading import load_reaction_rules, load_building_blocks

from synplan.chem.precursor import Precursor
from synplan.chem.reaction import Reaction, apply_reaction_rule

from collections import Counter

In [8]:
# initialize the policy function
ranking_policy_network = ranking_policy_network_folder.joinpath("policy_network.ckpt")

policy_config = PolicyNetworkConfig(weights_path=ranking_policy_network)
policy_function = PolicyNetworkFunction(policy_config=policy_config)
reaction_rules = load_reaction_rules(reaction_rules_path)

In [9]:
# Example target molecule for the single-molecule expansion check; `mol` keeps
# the parsed molecule and `precursor` wraps it for the policy network.
precursor = Precursor(mol := smiles("c1cc(ccc1)NC(C2=CCCCC2)(OC(c3ccccc3)=O)O"))
precursor

c1cc(ccc1)NC(C2=CCCCC2)(OC(c3ccccc3)=O)O

In [10]:
# Collect the policy's predictions for the example precursor, reordered as
# (rule_id, probability, rule) tuples.
rule_prob_list = [
    (rule_id, prob, rule)
    for prob, rule, rule_id in policy_function.predict_reaction_rules(precursor, reaction_rules)
]
# Display the ten rules with the highest predicted probability.
sorted(rule_prob_list, key=lambda entry: entry[1], reverse=True)[:10]

[(1,
  0.012197372503578663,
  <CGRtools.reactor.reactor.Reactor at 0x7f4d31817fd0>),
 (1251,
  0.009791867807507515,
  <CGRtools.reactor.reactor.Reactor at 0x7f4d3140baf0>),
 (6,
  0.009302973747253418,
  <CGRtools.reactor.reactor.Reactor at 0x7f4e893ac940>),
 (7,
  0.008433681912720203,
  <CGRtools.reactor.reactor.Reactor at 0x7f4e893acb50>),
 (41,
  0.008429284207522869,
  <CGRtools.reactor.reactor.Reactor at 0x7f4e893a06a0>),
 (1784,
  0.00827113538980484,
  <CGRtools.reactor.reactor.Reactor at 0x7f4d3122f910>),
 (1665,
  0.008125271648168564,
  <CGRtools.reactor.reactor.Reactor at 0x7f4d313a0580>),
 (1258,
  0.007484037429094315,
  <CGRtools.reactor.reactor.Reactor at 0x7f4d31419360>),
 (135,
  0.007333867717534304,
  <CGRtools.reactor.reactor.Reactor at 0x7f4d3185f070>),
 (623,
  0.007105713710188866,
  <CGRtools.reactor.reactor.Reactor at 0x7f4d317c06a0>)]

In [11]:
# Expansion rate for the example precursor: how many of the predicted rules
# actually apply to it. A rule is counted once, as soon as apply_reaction_rule
# yields a non-empty product set. Counting every yielded product set instead
# would put the numerator in a different unit from the denominator
# (number of rules) and could push the "rate" above 100%.
n = sum(
    1
    for _rule_id, _prob, rule in rule_prob_list
    if any(apply_reaction_rule(precursor.molecule, rule))
)
print(f"Expansion rate {n}/{len(rule_prob_list)}")

Expansion rate 0/50


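### 7. Expansion rate benchmark

Repeat the expansion check for every reactant and product in the reaction dataset and summarise how many predicted rules apply per molecule.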

In [12]:
# Load every reaction from the same file used to build the training set.
# Lines are stripped so trailing newlines never reach the SMILES parser, and
# blank lines are skipped.
with open(reaction_data_path, "r") as f:
    reaction_smiles = [line.strip() for line in f if line.strip()]
reacts = [smiles(s) for s in reaction_smiles]

# Benchmark candidates: all reactants and products of all reactions
# (no sampling — every molecule is used, duplicates included).
cand_list = []
for r in reacts:
    cand_list.extend(r.reactants)
    cand_list.extend(r.products)

In [13]:
# Expansion benchmark: for every candidate molecule, count how many of the
# policy's predicted rules apply. A rule counts once if apply_reaction_rule
# yields at least one non-empty product set.
# Loop variables are prefixed with `cand_` so the example `mol` / `precursor`
# defined earlier in the notebook are not overwritten by this loop.
n_list = []
for cand_mol in tqdm(cand_list, desc="Expansion benchmark"):
    cand_precursor = Precursor(cand_mol)
    try:
        # Materialize the predictions once and reuse them below, instead of
        # running the policy network twice for the same molecule.
        predicted_rules = list(
            policy_function.predict_reaction_rules(cand_precursor, reaction_rules)
        )
    except Exception:
        # Skip molecules for which rule prediction fails. Catching Exception
        # (not a bare except) still lets KeyboardInterrupt stop the loop.
        continue

    n_list.append(
        sum(
            1
            for _prob, rule, _rule_id in predicted_rules
            if any(apply_reaction_rule(cand_precursor.molecule, rule))
        )
    )

In [14]:
# Distribution of the per-molecule counts collected in n_list:
# key = count, value = number of candidate molecules with that count.
expansion_histogram = Counter(n_list)
expansion_histogram

Counter({0: 5926, 1: 902, 2: 502})