# Step 3: Policy Network Training

This tutorial demonstrates how to train ranking and filtering policy networks in ``SynPlanner``.

## Basic recommendations

**1. Prefer ranking policy network over filtering policy network**

The current implementation of the filtering policy network is computationally expensive: with large training sets, training is practically feasible only with many CPUs and several dozen GB of RAM. The bottleneck is the preparation of the training dataset, particularly the generation of binary vectors of successfully applied reaction rules for each training molecule. Thus, with limited computational resources, it is recommended to use a ranking policy network.
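
To make the bottleneck concrete, here is a minimal sketch of how such a binary label vector could be built. It is not the ``SynPlanner`` implementation: `applicability_vector` and `toy_rules` are hypothetical names, and the toy rules are plain substring checks on SMILES strings that stand in for the expensive step of actually applying a reaction rule to a molecule.

```python
from typing import Callable, List, Sequence


def applicability_vector(
    molecule: str,
    rules: Sequence[Callable[[str], bool]],
) -> List[int]:
    """Return a binary vector: 1 where the rule applies to the molecule, else 0."""
    return [1 if rule(molecule) else 0 for rule in rules]


# Toy "rules" (assumption): substring checks standing in for real rule application.
toy_rules = [
    lambda smi: "C(=O)O" in smi,  # acid/ester-like fragment
    lambda smi: "N" in smi,  # any nitrogen
    lambda smi: "Br" in smi,  # any bromine
]

labels = applicability_vector("CC(=O)OCC", toy_rules)
print(labels)  # [1, 0, 0]
```

In this sketch every rule is tried on every molecule, so the work grows with the number of molecules multiplied by the number of rules, which is consistent with the resource demands described above.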

**2. Use a filtering policy network for the portability of reaction rules between different tools**

Filtering policy networks can be trained with any set of reaction rules, including those generated with other software, because filtering network training does not depend on the original reaction dataset from which the reaction rules were extracted. The filtering policy network can therefore be used to compare reaction rules extracted with different tools.

**3. Reduce the size of the training set of molecules for the filtering policy network**

The computational cost of training filtering policy networks can be partially reduced by drastically shrinking the training set of molecules.
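
One simple way to shrink the training set is random subsampling of the molecule file. The sketch below assumes one molecule per line, as in a typical `.smi` file; `subsample_smiles` and the file names in the demo are hypothetical and not part of ``SynPlanner``.

```python
import random
import tempfile
from pathlib import Path


def subsample_smiles(src: Path, dst: Path, n: int, seed: int = 0) -> int:
    """Write a random sample of at most n non-empty lines from src to dst."""
    lines = [ln for ln in src.read_text().splitlines() if ln.strip()]
    rng = random.Random(seed)  # fixed seed keeps the subset reproducible
    sample = lines if len(lines) <= n else rng.sample(lines, n)
    dst.write_text("\n".join(sample) + "\n")
    return len(sample)


# Small self-contained demo on a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "molecules.smi"
    dst = Path(tmp) / "molecules_subset.smi"
    src.write_text("\n".join(["CCO", "CCN", "c1ccccc1", "CC(=O)O", "CCBr"]) + "\n")
    kept = subsample_smiles(src, dst, n=3)
    subset = dst.read_text().splitlines()

print(kept, len(subset))  # 3 3
```

The helper reads the whole file into memory, which is fine for moderately sized files. The resulting subset file could then be passed as the molecule file when creating the filtering policy dataset.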

## 1. Set up input and output data locations

The ``SynPlanner`` input data will be downloaded from the ``HuggingFace`` repository to the specified directory.

In [2]:
import os
import pickle
import shutil
from pathlib import Path
from synplan.utils.loading import download_all_data

# download SynPlanner data
data_folder = Path("synplan_data").resolve()
download_all_data(save_to=data_folder)

# results folder
results_folder = Path("tutorial_results").resolve()
results_folder.mkdir(exist_ok=True)

# input data
# use the default filtered data from the tutorial folder, or replace it with custom data prepared in the data curation tutorial
# make sure to use the same reaction dataset from which the reaction rules were extracted

reaction_rules_path = results_folder.joinpath("uspto_reaction_rules.pickle") # needed for both ranking and filtering policy network training

filtered_data_path = results_folder.joinpath("uspto_filtered.smi") # needed for ranking policy network training
molecules_data_path = data_folder.joinpath("synplan_data/chembl/molecules_for_filtering_policy_training.smi") # needed for filtering policy network training

# output data
ranking_policy_network_folder = results_folder.joinpath("ranking_policy_network")
filtering_policy_network_folder = results_folder.joinpath("filtering_policy_network")

# generated training sets
ranking_policy_dataset_path = ranking_policy_network_folder.joinpath("ranking_policy_dataset.pt") # the generated training set for ranking network
filtering_policy_dataset_path = filtering_policy_network_folder.joinpath("filtering_policy_dataset.pt") # the generated training set for filtering network


## 2. Ranking policy training

### Ranking network configuration

In [2]:
from synplan.utils.config import PolicyNetworkConfig
from synplan.ml.training.supervised import create_policy_dataset, run_policy_training

training_config = PolicyNetworkConfig(
    policy_type="ranking",  # the type of policy network
    num_conv_layers=5,  # the number of graph convolutional layers in the network
    vector_dim=512,  # the dimensionality of the final embedding vector
    learning_rate=0.0008,  # the learning rate for the training process
    dropout=0.4,  # the dropout rate
    num_epoch=100,  # the number of epochs for training
    batch_size=100,  # the size of training batch of input data
)

### Creating ranking network training set

Next, we create the policy dataset using the `create_policy_dataset` function. This involves specifying paths to the reaction rules and the reaction data:

In [3]:
datamodule = create_policy_dataset(
    dataset_type="ranking",
    reaction_rules_path=reaction_rules_path,
    molecules_or_reactions_path=filtered_data_path,
    output_path=ranking_policy_dataset_path,
    batch_size=training_config.batch_size,
    num_cpus=4,
)

Number of reactions processed: 1019304 [2:36:15]


Training set size: 616841, validation set size: 154211
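
The reported sizes can be turned into a quick sanity check of the train/validation split and the number of batches per epoch. This is plain arithmetic on the numbers printed above; it assumes the last partial batch of each set is kept rather than dropped, which the output does not confirm.

```python
import math

train_size = 616_841  # from the output above
val_size = 154_211  # from the output above
batch_size = 100  # training_config.batch_size

total = train_size + val_size
train_fraction = train_size / total
train_batches = math.ceil(train_size / batch_size)
val_batches = math.ceil(val_size / batch_size)

print(f"total samples: {total}")  # 771052
print(f"training fraction: {train_fraction:.3f}")  # 0.800
print(f"training batches per epoch: {train_batches}")  # 6169
print(f"validation batches per epoch: {val_batches}")  # 1543
```

The split is therefore about 80/20. The total of 771,052 samples is smaller than the 1,019,304 reactions processed; the output does not say why.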


### Running ranking policy network training

Finally, we train the policy network using the `run_policy_training` function. This step involves feeding the dataset and the training configuration into the network:

In [4]:
run_policy_training(
    datamodule,  # the prepared data module for training
    config=training_config,  # the training configuration
    results_path=ranking_policy_network_folder,  # path to save the training results
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type           | Params | Mode 
-------------------------------------------------------
0 | embedder    | GraphEmbedding | 1.3 M  | train
1 | y_predictor | Linear         | 17.9 M | train
-------------------------------------------------------
19.2 M    Trainable params
0         Non-trainable params
19.2 M    Total params
76.944    Total estimated model params size (MB)


Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief


`Trainer.fit` stopped: `max_epochs=100` reached.


Policy network balanced accuracy: 0.88
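
Balanced accuracy is commonly defined as the mean of the per-class recall, which keeps frequent classes from dominating the score. The sketch below illustrates that common definition on toy labels; whether ``SynPlanner`` computes its reported value in exactly this way is an assumption, and `balanced_accuracy` is a hypothetical helper.

```python
from collections import defaultdict
from typing import Sequence


def balanced_accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Mean per-class recall over the classes present in y_true."""
    total = defaultdict(int)
    correct = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
    return sum(correct[c] / total[c] for c in total) / len(total)


# Toy labels: class 0 is frequent, class 1 is rare.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]

plain = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
balanced = balanced_accuracy(y_true, y_pred)
print(plain, balanced)  # 0.9 0.75
```

Plain accuracy is 0.9 here because the frequent class is always predicted correctly, while balanced accuracy drops to 0.75 because half of the rare class is missed.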


## 3. Filtering policy training

### Filtering network configuration

In [5]:
from synplan.utils.config import PolicyNetworkConfig
from synplan.ml.training.supervised import create_policy_dataset, run_policy_training

training_config = PolicyNetworkConfig(
    policy_type="filtering",  # the type of policy network
    num_conv_layers=5,  # the number of graph convolutional layers in the network
    vector_dim=512,  # the dimensionality of the final embedding vector
    learning_rate=0.0008,  # the learning rate for the training process
    dropout=0.4,  # the dropout rate
    num_epoch=100,  # the number of epochs for training
    batch_size=100,  # the size of training batch of input data
)

### Creating filtering network training set

Next, we create the policy dataset using the `create_policy_dataset` function. This involves specifying paths to the reaction rules and the molecules dataset:

In [1]:
datamodule = create_policy_dataset(
    dataset_type="filtering",
    reaction_rules_path=reaction_rules_path,
    molecules_or_reactions_path=molecules_data_path,
    output_path=filtering_policy_dataset_path,
    batch_size=training_config.batch_size,
    num_cpus=4,
)

### Running filtering policy network training

Finally, we train the policy network using the `run_policy_training` function. This step involves feeding the dataset and the training configuration into the network:

In [None]:
run_policy_training(
    datamodule,  # the prepared data module for training
    config=training_config,  # the training configuration
    results_path=filtering_policy_network_folder,  # path to save the training results
)

## Results

If the tutorial is executed successfully, the results folder will contain three reaction data files (from the reaction curation tutorial), the corresponding extracted reaction rules (from the reaction rules extraction tutorial), and the trained ranking and filtering policy networks:
- original reaction data
- standardized reaction data
- filtered reaction data
- extracted reaction rules
- ranking policy network folder (the training set and trained network)
- filtering policy network folder (the training set and trained network)

In [3]:
sorted(Path(results_folder).iterdir(), key=os.path.getmtime, reverse=False)

[PosixPath('/home1/dima/synplanner/tutorials/tutorial_results/uspto_original.smi'),
 PosixPath('/home1/dima/synplanner/tutorials/tutorial_results/uspto_standardized.smi'),
 PosixPath('/home1/dima/synplanner/tutorials/tutorial_results/uspto_filtered.smi'),
 PosixPath('/home1/dima/synplanner/tutorials/tutorial_results/uspto_reaction_rules.pickle'),
 PosixPath('/home1/dima/synplanner/tutorials/tutorial_results/ranking_policy_network')]
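
The listing above shows five of the six expected entries; `filtering_policy_network` is absent, presumably because the filtering cells were not executed in this run. A small check like the following can report what is missing. `missing_entries` is a hypothetical helper, and the demo runs on a temporary folder rather than the real results folder.

```python
import tempfile
from pathlib import Path
from typing import Iterable, List

# Entry names taken from this tutorial.
EXPECTED_RESULTS = [
    "uspto_original.smi",
    "uspto_standardized.smi",
    "uspto_filtered.smi",
    "uspto_reaction_rules.pickle",
    "ranking_policy_network",
    "filtering_policy_network",
]


def missing_entries(folder: Path, expected: Iterable[str]) -> List[str]:
    """Return the expected names that are not present in the folder."""
    present = {p.name for p in folder.iterdir()}
    return [name for name in expected if name not in present]


# Demo: a temporary folder that lacks the filtering policy network entry.
with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    for name in EXPECTED_RESULTS[:-1]:
        (folder / name).touch()
    missing = missing_entries(folder, EXPECTED_RESULTS)

print(missing)  # ['filtering_policy_network']
```

In the notebook, the same check would be `missing_entries(results_folder, EXPECTED_RESULTS)`.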