# Scaffold-based reinforcement learning and molecule generation

In some cases, you might already have an idea of a scaffold or fragments that your new molecules should contain. In this case, it is useful to perform the reinforcement learning and the molecule generation with preselected fragments (a single fragment or a combination of multiple fragments).

In this tutorial, we show an example with a single pyrazine and a combination of a pyrazine and a thiophene for the graph-based transformer.

To understand this tutorial and to have all the necessary files, you should first work through the basic [graph transformer](../Graph-Transformer.ipynb) tutorial.

In [12]:
import sys
sys.path.append('../')
from utils import smilesToGrid

In [13]:
# a single pyrazine, and a pyrazine combined with a thiophene ('.' separates fragments)
frags = ['c1cnccn1', 'c1cnccn1.c1ccsc1']

smilesToGrid(frags)

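Each entry in `frags` is one input for the generator, and a `.` separates the fragments that make up a combination, so the second entry holds two fragments. The plain-Python sketch below only splits the SMILES strings to make this explicit. The helper `split_fragments` is hypothetical, and reading `n_frags=4` (used for `VocGraph` later) as the maximum number of fragments per input is an assumption.

```python
# Input fragments: a single pyrazine, and a pyrazine combined with a thiophene
frags = ['c1cnccn1', 'c1cnccn1.c1ccsc1']

# Assumption: the graph vocabulary used later is built with n_frags=4,
# which we read as the maximum number of fragments per input
MAX_FRAGS = 4


def split_fragments(smiles):
    """Hypothetical helper: split a SMILES string into its disconnected parts."""
    return smiles.split('.')  # '.' is the SMILES disconnection symbol


for entry in frags:
    parts = split_fragments(entry)
    assert len(parts) <= MAX_FRAGS, f"too many fragments in {entry}"
    print(f"{entry}: {len(parts)} fragment(s) {parts}")
```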

# Building the environment

In this example, we build both the graph- and SMILES-based models with the same vocabularies, pretrained and finetuned generators, and QSAR models as in the [graph transformer](../Graph-Transformer.ipynb) tutorial.

First, we build the environment, which is identical for both models:

In [14]:
from drugex.training.scorers.properties import Property
from drugex.training.scorers.modifiers import ClippedScore
from drugex.training.environment import DrugExEnvironment
from drugex.training.rewards import WeightedSum
from drugex.training.scorers.qsprpred import QSPRPredScorer

from qsprpred.models.models import QSPRsklearn

scorers, thresholds = [], []

# QSAR model for A2A (active target); uncomment the lines below to include it
# scorer_a2a = QSPRPredScorer(QSPRsklearn(name='A2AR_RandomForestClassifier', base_dir='../data/models/qsar'))
# scorer_a2a.setModifier(ClippedScore(lower_x=6.5-2, upper_x=6.5))
# scorers.append(scorer_a2a)
# thresholds.append(0.99)

# SAscore
sascore = Property("SA", modifier=ClippedScore(lower_x=7, upper_x=3))
scorers.append(sascore)
thresholds.append(0.5)

# QED
qed = Property("QED", modifier=ClippedScore(lower_x=0.2, upper_x=0.8))
scorers.append(qed)
thresholds.append(0.5)

# Create environment
environment = DrugExEnvironment(scorers, thresholds, reward_scheme=WeightedSum())
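To make the modifiers and thresholds concrete, here is a minimal pure-Python sketch. It assumes that `ClippedScore` maps a raw value linearly from 0 at `lower_x` to 1 at `upper_x` and clips outside that range (so `lower_x=7, upper_x=3` rewards low SA scores), and that a molecule counts as desired when every modified score reaches its threshold. The helpers `clipped_score` and `is_desired` are hypothetical, not DrugEx functions, and the `WeightedSum` reward scheme is not reproduced here.

```python
def clipped_score(x, lower_x, upper_x):
    """Map a raw value linearly to [0, 1]: 0 at lower_x, 1 at upper_x, clipped outside.

    If upper_x < lower_x (as for the SA score), lower raw values score higher.
    """
    t = (x - lower_x) / (upper_x - lower_x)
    return min(1.0, max(0.0, t))


def is_desired(scores, thresholds):
    """Assumed desirability rule: every modified score must reach its threshold."""
    return all(s >= t for s, t in zip(scores, thresholds))


# Hypothetical raw property values for a single molecule
sa_raw, qed_raw = 4.0, 0.65
scores = [
    clipped_score(sa_raw, lower_x=7, upper_x=3),       # (4 - 7) / (3 - 7), about 0.75
    clipped_score(qed_raw, lower_x=0.2, upper_x=0.8),  # (0.65 - 0.2) / 0.6, about 0.75
]
print(scores, is_desired(scores, thresholds=[0.5, 0.5]))
```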

# Graph-based Transformer
## Data Preprocessing

We use the same encoder as in the [graph transformer](../Graph-Transformer.ipynb) tutorial to create molecules from the fragments and to encode the fragment-molecule pairs, with a few small modifications:
1. Instead of a `fragmenter` that splits molecules into fragments, we use `dummyMolsFromFragments`, which creates dummy molecules from the fragments.
2. `splitter` is set to `None`, and `n_proc` and `chunk_size` are set to 1.

In [15]:
import os
from drugex.data.datasets import GraphFragDataSet
from drugex.molecules.converters.dummy_molecules import dummyMolsFromFragments
from drugex.data.fragments import FragmentCorpusEncoder, GraphFragmentEncoder
from drugex.data.corpus.vocabulary import VocGraph

fragmenter = dummyMolsFromFragments()
splitter = None

encoder = FragmentCorpusEncoder(
    fragmenter=fragmenter, 
    encoder=GraphFragmentEncoder(
        VocGraph(n_frags=4) 
    ),
    pairs_splitter=splitter, 
    n_proc=1,
    chunk_size=1
)

graph_input_folder = "datasets/encoded/graph"
os.makedirs(graph_input_folder, exist_ok=True)

dataset = GraphFragDataSet(f"{graph_input_folder}/scaffolds.tsv", rewrite=True)

In [16]:
encoder.apply(list(frags), encodingCollectors=[dataset])


## Reinforcement learning

Next, we build the explorer, which combines the agent, the prior and the environment.

In [17]:
from drugex.training.explorers import FragGraphExplorer
from drugex.training.generators import GraphTransformer
from drugex.data.corpus.vocabulary import VocGraph

# indices of the GPUs to use; adjust to your hardware
GPUS = (1,)

vocabulary = VocGraph.fromFile('../data/models/finetuned/graph/A2AR_FT.vocab')
agent = GraphTransformer(voc_trg=vocabulary, use_gpus=GPUS)
agent.loadStatesFromFile('../data/models/finetuned/graph/A2AR_FT.pkg')
prior = GraphTransformer(voc_trg=vocabulary, use_gpus=GPUS)
prior.loadStatesFromFile('../data/models/finetuned/graph/A2AR_FT.pkg')

explorer = FragGraphExplorer(agent=agent, env=environment, mutate=prior, epsilon=0.1, use_gpus=GPUS)
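The `epsilon` argument controls exploration. The conceptual sketch below assumes that `epsilon` is the probability that the prior (passed as `mutate`) rather than the agent makes a sampling decision; `choose_network` is a hypothetical helper, not the DrugEx implementation, and the granularity at which DrugEx applies the choice is not modelled here.

```python
import random


def choose_network(rng, epsilon):
    """Hypothetical helper: pick the network for the next sampling decision.

    With probability `epsilon` the frozen prior is used, otherwise the agent.
    """
    return "prior" if rng.random() < epsilon else "agent"


rng = random.Random(42)  # fixed seed so the sketch is reproducible
picks = [choose_network(rng, epsilon=0.1) for _ in range(10_000)]
print("fraction sampled from prior:", picks.count("prior") / len(picks))
```

Keeping `epsilon` small lets the agent exploit what it has learned while still occasionally sampling from the prior.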

However, only the selected scaffolds are used as input fragments for training and validation. Since the initial set contains only two inputs, they are resampled to give 100 training inputs and 100 × 0.2 = 20 test inputs.

In [18]:
from drugex.data.datasets import GraphFragDataSet

data_path = f'{graph_input_folder}/scaffolds.tsv'
train_loader = GraphFragDataSet(data_path).asDataLoader(batch_size=1024, n_samples=100)
test_loader = GraphFragDataSet(data_path).asDataLoader(batch_size=1024, n_samples=100, n_samples_ratio=0.2)
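The arithmetic behind these two loaders can be sketched in plain Python. The cyclic `resample` helper is hypothetical; DrugEx may draw the samples differently (for example at random), but the resulting sizes are the same: 100 training inputs and 20 test inputs, each set fitting into a single batch of 1024.

```python
import math


def resample(inputs, n_samples):
    """Hypothetical helper: repeat the inputs cyclically until there are n_samples."""
    return [inputs[i % len(inputs)] for i in range(n_samples)]


frags = ['c1cnccn1', 'c1cnccn1.c1ccsc1']
n_samples, test_ratio, batch_size = 100, 0.2, 1024

train = resample(frags, n_samples)
test = resample(frags, round(n_samples * test_ratio))  # 100 * 0.2 = 20

# Number of batches per epoch for the training set
n_batches = math.ceil(len(train) / batch_size)
print(len(train), len(test), n_batches)
```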

Finally, we can start the training loop:

In [19]:
from drugex.training.monitors import FileMonitor

monitor = FileMonitor("models/reinforced/graph/scaffolds", save_smiles=True) 
explorer.fit(train_loader, test_loader, monitor=monitor, epochs=3)

Fitting graph explorer: 100%|██████████| 3/3 [03:10<00:00, 63.65s/it]


We can check that all generated molecules contain either a pyrazine, or both a pyrazine and a thiophene.

In [20]:
import pandas as pd 

df_smiles = pd.read_csv('models/reinforced/graph/scaffolds_smiles.tsv', sep='\t')
smilesToGrid(df_smiles.SMILES.tolist())


## *de novo* generation

Once we have an optimized model (not the case in this tutorial, where the number of epochs is set to 3 instead of 1000 for speed), it can be used to sample *novel* molecules.

In [21]:
# reuse the vocabulary the agent was trained with
reinforced = GraphTransformer(voc_trg=vocabulary, use_gpus=GPUS)
reinforced.loadStatesFromFile('models/reinforced/graph/scaffolds.pkg')

denovo = reinforced.generate(input_frags=frags, num_samples=10, evaluator=environment)

Initialized empty dataset. The data set file does not exist (yet): /tmp/tmpchtw301s. You can add data by calling this instance with the appropriate parameters.



Alternatively, the molecules can be generated without applying the modifiers, to better evaluate the predicted properties.