# Scaffold-based reinforcement learning and molecule generation

In some cases, you might already have an idea of a scaffold or fragments that your new molecules should contain. In this case, it is useful to run the [reinforcement learning](rl_optimization.ipynb) and the [molecule generation](generation.ipynb) with preselected fragments (a single fragment or a combination of multiple fragments).

In this tutorial, we show an example with a single pyrazine and a combination of two pyrazines for the SMILES- and graph-based transformers.
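
In SMILES notation, disconnected components are separated by a dot, which is how a combination of fragments can be written as a single string. The sketch below uses plain string handling only (no cheminformatics toolkit) to show how the two inputs of this tutorial decompose; the helper `split_fragments` is a hypothetical name introduced for illustration and is not part of DrugEx.

```python
# The two inputs used in this tutorial: one pyrazine, and two pyrazines
# combined into a single dot-separated SMILES string.
frags = ["c1cnccn1", "c1cnccn1.c1cnccn1"]


def split_fragments(smiles):
    """Split a dot-separated SMILES string into its component fragments.

    Hypothetical helper for illustration only.
    """
    return smiles.split(".")


for entry in frags:
    parts = split_fragments(entry)
    print(f"{entry!r}: {len(parts)} fragment(s) -> {parts}")
```

Each list entry is therefore one input to the generator, and the second entry asks for two pyrazine fragments at once.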

This tutorial assumes familiarity with the basic tutorials on [data preprocessing](../datasets.ipynb), [reinforcement learning](rl_optimization.ipynb) and [molecule generation](generation.ipynb).

## Data Preprocessing

In [1]:
import sys
sys.path.append('..')
sys.path.append('/zfsdata/data/sohvi/DrugEx/')  # path to a local DrugEx checkout; adjust to your setup

from utils import smilesToGrid

frags = ['c1cnccn1', 'c1cnccn1.c1cnccn1']  # a single pyrazine; two pyrazines joined with '.'

smilesToGrid(frags)


### For Graph-based Transformer

We use the same encoder as in [Preparing Data for the Graph-Based Transformer](../datasets.ipynb) to create molecules from the fragments and encode fragment-molecule pairs, with two small modifications:
1. Instead of a regular fragmenter, we pass `dummyMolsFromFragments` as the `fragmenter`, which creates dummy molecules from the fragments.
2. We set `splitter` to `None`, and `n_proc` and `chunk_size` to 1.

In [2]:
import os
from drugex.data.datasets import GraphFragDataSet
from drugex.molecules.converters.dummy_molecules import dummyMolsFromFragments
from drugex.data.fragments import FragmentCorpusEncoder, GraphFragmentEncoder
from drugex.data.corpus.vocabulary import VocGraph

fragmenter = dummyMolsFromFragments()
splitter = None

encoder = FragmentCorpusEncoder(
    fragmenter=fragmenter, 
    encoder=GraphFragmentEncoder(
        VocGraph(n_frags=4) 
    ),
    pairs_splitter=splitter, 
    n_proc=1,
    chunk_size=1
)

graph_input_folder = "data/sets/graph/"
if not os.path.exists(graph_input_folder):
    os.makedirs(graph_input_folder)
    
dataset = GraphFragDataSet(f"{graph_input_folder}/scaffold_graph.tsv", rewrite=True)

In [4]:
encoder.apply(list(frags), encodingCollectors=[dataset])

Creating fragment-molecule pairs (batch processing): 100%|██████████| 2/2 [00:00<00:00, 143.44it/s]
Encoding fragment-molecule pairs. (batch processing):   0%|          | 0/2 [00:00<?, ?it/s]The following exception occured while encoding fragment c1cnccn1.c1cnccn1 for molecule c1cn(-c2cnccn2)ccn1: 'NoneType' object has no attribute 'GetSubstructMatches'
Failed to convert item None to the new representation in <drugex.data.fragments.FragmentPairsEncodedSupplier object at 0x7f223c3a2af0>
	 Cause: FragmentEncodingException('Failed to encode fragment c1cnccn1.c1cnccn1 from molecule: c1cn(-c2cnccn2)ccn1')
Encoding fragment-molecule pairs. (batch processing): 100%|██████████| 2/2 [00:00<00:00, 49.02it/s]


### For SMILES-based Transformer

We use the same encoder as in [Preparing Data for the SMILES-Based Transformer](../datasets.ipynb) to create molecules from the fragments and encode fragment-molecule pairs, with two small modifications:
1. Instead of a regular fragmenter, we pass `dummyMolsFromFragments` as the `fragmenter`, which creates dummy molecules from the fragments.
2. We set `splitter` to `None`, `min_len` to 2, and `n_proc` and `chunk_size` to 1.

In [5]:
import os
from drugex.data.datasets import SmilesFragDataSet
from drugex.molecules.converters.dummy_molecules import dummyMolsFromFragments
from drugex.data.fragments import FragmentCorpusEncoder, SequenceFragmentEncoder
from drugex.data.corpus.vocabulary import VocSmiles

fragmenter = dummyMolsFromFragments()
splitter = None

encoder = FragmentCorpusEncoder(
    fragmenter=fragmenter, 
    encoder=SequenceFragmentEncoder(
        VocSmiles(min_len=2) 
    ),
    pairs_splitter=splitter, 
    n_proc=1,
    chunk_size=1
)

smiles_input_folder = "data/sets/smiles/"
if not os.path.exists(smiles_input_folder):
    os.makedirs(smiles_input_folder)
    
dataset = SmilesFragDataSet(f"{smiles_input_folder}/scaffold_smi.tsv", rewrite=True)

Initialized empty dataset. The data set file does not exist (yet): data/sets/smiles//scaffold_smi.tsv. You can add data by calling this instance with the appropriate parameters.


In [6]:
encoder.apply(list(frags), encodingCollectors=[dataset])

Creating fragment-molecule pairs (batch processing): 100%|██████████| 2/2 [00:00<00:00, 150.66it/s]
Encoding fragment-molecule pairs. (batch processing): 100%|██████████| 2/2 [00:00<00:00, 132.32it/s]


## Reinforcement learning

In this example, we build the model with the same vocabulary, pretrained and finetuned generators and QSAR models as in the [general RL example](../rl_optimization.ipynb).

First, we build the environment:

In [24]:
from drugex.training.scorers.properties import Property
from drugex.training.scorers.modifiers import ClippedScore
from drugex.training.environment import DrugExEnvironment
from drugex.training.rewards import ParetoCrowdingDistance
from drugex.training.scorers.predictors import Predictor
from drugex.training.interfaces import Scorer

class ModelScorer(Scorer):
    
    def __init__(self, model, prefix):
        super().__init__()
        self.model = model
        self.prefix = prefix
    
    def getScores(self, mols, frags=None):
        X = Predictor.calculateDescriptors(mols)
        return self.model.predict(X)
    
    def getKey(self):
        return f"{self.prefix}_{type(self.model)}"

qed = Property("QED", modifier=ClippedScore(lower_x=0, upper_x=1.0))
sascore = Property("SA", modifier=ClippedScore(lower_x=4.5, upper_x=0))
scorers = [qed, sascore]
thresholds = [0.5, 0.5]

# scorer_a1 = joblib.load('data/models/reinforced/graph/scorer_a1.pkg') #pickle.load(open('data/models/reinforced/graph/scorer_a1.pkg', 'rb'))
# scorer_a3 = joblib.load('data/models/reinforced/graph/scorer_a3.pkg') #pickle.load(open('data/models/reinforced/graph/scorer_a3.pkg', 'rb'))

# scorer_a1 = Predictor.fromFile()

# qed = Property("QED", modifier=ClippedScore(lower_x=0, upper_x=1.0))
# sascore = Property("SA", modifier=ClippedScore(lower_x=4.5, upper_x=0))
# scorers = [scorer_a1, scorer_a3, qed, sascore]
# thresholds = [0.99, 0.99, 0.0, 0.0]

environment = DrugExEnvironment(scorers, thresholds, reward_scheme=ParetoCrowdingDistance())
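
To build intuition for the `ClippedScore` modifiers used above, here is a plain-Python stand-in. It assumes that the modifier interpolates linearly between `lower_x` (mapped to 0) and `upper_x` (mapped to 1) and clips everything outside that range. It also assumes that `SA` is a synthetic accessibility score where lower values are better. The function `clipped_score` is a hypothetical illustration, not the DrugEx implementation.

```python
def clipped_score(x, lower_x, upper_x, low_score=0.0, high_score=1.0):
    """Map x linearly so that lower_x -> low_score and upper_x -> high_score,
    then clip the result to the interval [low_score, high_score].

    Illustrative stand-in only; the names mirror the arguments used above.
    """
    slope = (high_score - low_score) / (upper_x - lower_x)
    y = low_score + slope * (x - lower_x)
    return max(low_score, min(high_score, y))


# QED-like setup (lower_x=0, upper_x=1.0): higher raw values score higher.
print("QED 0.7 ->", round(clipped_score(0.7, lower_x=0, upper_x=1.0), 3))

# SA-like setup (lower_x=4.5, upper_x=0): because lower_x > upper_x, the
# direction is reversed and lower raw values score higher.
for sa in (1.0, 2.25, 4.5, 6.0):
    print("SA", sa, "->", round(clipped_score(sa, lower_x=4.5, upper_x=0), 3))
```

Under these assumptions, both modified scores lie in the range 0 to 1, which makes them directly comparable to the entries of `thresholds`.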

### Graph-based Transformer

Next, we build the explorer, composed of the agent, the prior and the environment.

In [25]:
from drugex.training.models.explorer import GraphExplorer
from drugex.training.models.transform import GraphModel
from drugex.data.corpus.vocabulary import VocGraph

GPUS = gpus = (0, 1)  # IDs of the GPUs to use; adjust to your hardware

vocabulary = VocGraph.fromFile('../data/models/finetuned/graph/ligand_finetuned.vocab')
finetuned = GraphModel(voc_trg=vocabulary, use_gpus=GPUS)
finetuned.loadStatesFromFile('../data/models/finetuned/graph/chembl_ligand.pkg')
pretrained = GraphModel(voc_trg=vocabulary, use_gpus=GPUS)
pretrained.loadStatesFromFile('../jupyter/models/pretrained/graph/chembl27/chembl27_graph.pkg')

explorer = GraphExplorer(agent=pretrained, env=environment, mutate=finetuned, epsilon=0.1, use_gpus=GPUS)

However, we use only the selected scaffolds as input fragments for training and validation. As the initial set contains only two inputs, they are sampled 100 times to create the training set and 100 * 0.2 = 20 times to create the test set.

In [26]:
from drugex.data.datasets import GraphFragDataSet

data_path = 'data/sets/graph/scaffold_graph.tsv'
train_loader = GraphFragDataSet(data_path).asDataLoader(batch_size=1024, n_samples=100)
test_loader = GraphFragDataSet(data_path).asDataLoader(batch_size=1024, n_samples=100, n_samples_ratio=0.2)
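
The sampling arithmetic behind `n_samples` and `n_samples_ratio` can be pictured with a small standard-library sketch. It assumes that the inputs are drawn with replacement and that the test-set size is `n_samples * n_samples_ratio`. The helper `sample_inputs` is hypothetical and only mirrors the parameter names used above; it is not how the DrugEx data loader is implemented.

```python
import random


def sample_inputs(inputs, n_samples, n_samples_ratio=None, seed=0):
    """Draw inputs with replacement (illustrative stand-in, not DrugEx code).

    If n_samples_ratio is given, only n_samples * n_samples_ratio items are drawn.
    """
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    n = n_samples if n_samples_ratio is None else round(n_samples * n_samples_ratio)
    return rng.choices(inputs, k=n)


scaffolds = ["c1cnccn1", "c1cnccn1.c1cnccn1"]
train = sample_inputs(scaffolds, n_samples=100)
test = sample_inputs(scaffolds, n_samples=100, n_samples_ratio=0.2)
print(len(train), len(test))  # prints: 100 20
```

With a batch size of 1024, both sets fit into a single batch, which matches the one batch per epoch shown in the training output.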

After that we can finally start the training loop:

In [27]:
from drugex.training.monitors import FileMonitor

monitor = FileMonitor("data/models/reinforced/graph/scaffold_rl", verbose=True) 
explorer.fit(train_loader, test_loader, monitor=monitor, epochs=3)

Batch: 100%|██████████| 1/1 [00:01<00:00,  1.21s/it]
Batch: 100%|██████████| 1/1 [00:01<00:00,  1.71s/it]
Batch: 100%|██████████| 1/1 [00:01<00:00,  1.46s/it]
100%|██████████| 3/3 [00:10<00:00,  3.47s/it]


We check that all created molecules include either one or two pyrazine rings:

In [28]:
import pandas as pd 

df_smiles = pd.read_csv('data/models/reinforced/graph/scaffold_rl_smiles.tsv', sep='\t')
smilesToGrid(df_smiles.Smiles.tolist())

### SMILES-based Transformer

Note: this part does not work yet, as no pretrained or finetuned SMILES-based transformer models are available.

As before, we build the explorer, composed of the agent, the prior and the environment.

In [None]:
from drugex.training.models.explorer import SmilesExplorer
from drugex.training.models.transform import GPT2Model
from drugex.data.corpus.vocabulary import VocSmiles

GPUS = gpus = (0, 1)  # IDs of the GPUs to use; adjust to your hardware

vocabulary = VocSmiles.fromFile('../data/models/finetuned/smiles/ligand_finetuned.vocab')
finetuned = GPT2Model(voc_trg=vocabulary, use_gpus=GPUS)
finetuned.loadStatesFromFile('../data/models/finetuned/smiles/chembl_ligand.pkg')
pretrained = GPT2Model(voc_trg=vocabulary, use_gpus=GPUS)
pretrained.loadStatesFromFile('../jupyter/models/pretrained/smiles/chembl27/chembl27_graph.pkg')

explorer = SmilesExplorer(agent=pretrained, env=environment, mutate=finetuned, epsilon=0.1, use_gpus=GPUS)

Again, we use only the selected scaffolds as input fragments for training and validation. As the initial set contains only two inputs, they are sampled 100 times to create the training set and 100 * 0.2 = 20 times to create the test set.

In [30]:
from drugex.data.datasets import SmilesFragDataSet

data_path = 'data/sets/smiles/scaffold_smi.tsv'
train_loader = SmilesFragDataSet(data_path).asDataLoader(batch_size=1024, n_samples=100)
test_loader = SmilesFragDataSet(data_path).asDataLoader(batch_size=1024, n_samples=100, n_samples_ratio=0.2)

After that we can finally start the training loop:

In [None]:
from drugex.training.monitors import FileMonitor

monitor = FileMonitor("data/models/reinforced/smiles/scaffold_rl", verbose=True) 
explorer.fit(train_loader, test_loader, monitor=monitor, epochs=3)

Batch: 100%|██████████| 1/1 [00:04<00:00,  4.75s/it]
Batch: 100%|██████████| 1/1 [00:01<00:00,  1.61s/it]
Batch: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]
100%|██████████| 3/3 [00:14<00:00,  4.75s/it]


# *de novo* Generation