In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="CGRtools")

import os
import pandas as pd
import pickle
import shutil
from pathlib import Path
from synplan.utils.loading import download_all_data

In [None]:
# All artifacts written by this notebook (rules, datasets, network weights)
# are placed under this directory, anchored at the current working directory.
results_folder = Path.cwd().joinpath("training_hybrid").resolve()
# exist_ok=True keeps the cell idempotent when it is re-run.
results_folder.mkdir(exist_ok=True)

In [None]:
# --- Input / output locations used by the cells below -----------------------

# Reaction dataset used both for rule extraction and for building the
# ranking-policy training set (relative to the working directory).
reaction_data_path = "radical_data/uspto_radical_filtered.csv"
# File the extracted reaction rules are written to / read from.
reaction_rules_path = results_folder.joinpath("reaction_rules.pickle")

# Folder holding the ranking policy dataset and training results.
ranking_policy_network_folder = results_folder.joinpath("ranking_policy_network")
# Create it up front so the dataset file below has an existing parent directory.
ranking_policy_network_folder.mkdir(parents=True, exist_ok=True)
ranking_policy_dataset_path = ranking_policy_network_folder.joinpath("ranking_policy_dataset.pt")  # the generated training set

# `data_folder` was referenced here without ever being defined, so a fresh
# kernel failed with NameError. It is now defined as the working directory so
# the relative path below keeps its literal meaning.
# NOTE(review): `download_all_data` is imported at the top of the notebook but
# never called; presumably a cell that downloaded the data and set
# `data_folder` was removed. Confirm the intended location of `synplan_data`
# and adjust `data_folder` if it lives elsewhere.
data_folder = Path(".").resolve()

# strict=True raises FileNotFoundError immediately if the building blocks
# file is missing, instead of failing later when the file is first read.
building_blocks_path = data_folder.joinpath(
    "synplan_data/building_blocks/building_blocks_em_sa_ln.smi"
).resolve(strict=True)

### Reaction rules extraction

In [None]:
from synplan.utils.config import RuleExtractionConfig
from synplan.chem.reaction_rules.extraction import extract_rules_from_reactions

### Rule extraction configuration

In [None]:
# Configuration object handed to extract_rules_from_reactions() below.
# NOTE(review): min_popularity is passed as a mapping keyed by "uspto" and
# "radical"; confirm the installed RuleExtractionConfig accepts a dict here
# rather than a single integer threshold.
extraction_config = RuleExtractionConfig(
    # popularity thresholds
    min_popularity=dict(uspto=3, radical=1),
    # rule shape
    environment_atom_count=1,
    multicenter_rules=True,
    include_rings=False,
    # which molecule parts are kept
    keep_leaving_groups=True,
    keep_incoming_groups=False,
    keep_reagents=False,
    # functional groups are disabled and the list is left empty
    include_func_groups=False,
    func_groups_list=[],
    # per-atom information flags, given separately for reaction-center atoms
    # and for environment atoms
    atom_info_retention=dict(
        reaction_center=dict(
            neighbors=True,
            hybridization=True,
            implicit_hydrogens=False,
            ring_sizes=False,
        ),
        environment=dict(
            neighbors=False,
            hybridization=False,
            implicit_hydrogens=False,
            ring_sizes=False,
        ),
    ),
)

### Running rule extraction

In [None]:
# Run rule extraction: reads the reactions at `reaction_data_path` and uses
# `reaction_rules_path` as the rules file, with the settings from
# `extraction_config`.
extract_rules_from_reactions(
    reaction_data_path=reaction_data_path,
    reaction_rules_path=reaction_rules_path,
    config=extraction_config,
    batch_size=100,
    num_cpus=4,
)

### Ranking policy config

In [None]:
from synplan.utils.config import PolicyNetworkConfig
from synplan.ml.training.supervised import create_policy_dataset, run_policy_training

In [None]:
# Hyperparameters for the ranking policy network; `batch_size` is also reused
# when the training dataset is built.
training_config = PolicyNetworkConfig(
    policy_type="ranking",
    # network size
    vector_dim=512,
    num_conv_layers=5,
    dropout=0.4,
    # optimisation schedule
    learning_rate=8e-4,
    num_epoch=100,
    batch_size=100,
)

### Policy training dataset

In [None]:
# Build the ranking-policy training data from the extracted rules and the
# reaction dataset; the result is passed to run_policy_training() below.
# NOTE(review): num_cpus=20 is hard-coded and differs from the 4 used for
# rule extraction; confirm it matches the machine this notebook runs on.
datamodule = create_policy_dataset(
    dataset_type="ranking",
    molecules_or_reactions_path=reaction_data_path,
    reaction_rules_path=reaction_rules_path,
    output_path=ranking_policy_dataset_path,
    num_cpus=20,
    batch_size=training_config.batch_size,
)

### Running policy training

In [None]:
# Train the ranking policy network on `datamodule`, using
# `ranking_policy_network_folder` as the results location.
# NOTE(review): accelerator="gpu" is hard-coded; presumably this fails on a
# machine without a GPU. Confirm before running elsewhere.
run_policy_training(
    datamodule,
    results_path=ranking_policy_network_folder,
    config=training_config,
    accelerator="gpu",
)