In [None]:
import sys
# Presumably makes the repository-root `gentrl` package importable from this notebook's directory — confirm.
# (Was `sys.path.append(..)`, which is a SyntaxError.)
sys.path.append('..')
import gentrl
import torch
import pandas as pd
import numpy as np
# Make CUDA device 0 the default target for the `.cuda()` calls below.
torch.cuda.set_device(0)

from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol
from moses.utils import disable_rdkit_log
# Silence RDKit log output (invalid molecules would otherwise spam the notebook).
disable_rdkit_log()
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Descriptors

In [None]:
# Latent width shared by encoder, decoder and the latent description below. All three must agree,
# and presumably must also match the checkpoint loaded from 'pretrained_model/' — confirm.
LATENT_SIZE = 50
enc = gentrl.RNNEncoder(latent_size=LATENT_SIZE)
dec = gentrl.DilConvDecoder(latent_input_size=LATENT_SIZE)
# One ('c', 20) descriptor per latent dimension, plus a single ('c', 20) feature descriptor.
# NOTE(review): ('c', 20) presumably means "continuous, 20 components" — confirm against gentrl.
model = gentrl.GENTRL(enc, dec, LATENT_SIZE * [('c', 20)], [('c', 20)], beta=0.001)
model.cuda();

In [None]:
# Load pretrained weights; the path is relative to the notebook's working directory.
model.load('pretrained_model/')
# Move to GPU again after loading — presumably load() may (re)create parts of the model on CPU; TODO confirm.
model.cuda();

### Reward 1: Baseline — penalized logP only

In [None]:
def get_num_rings_6(mol):
    """Return how many rings of ``mol`` contain more than six atoms."""
    ring_info = mol.GetRingInfo()
    return sum(1 for atom_ring in ring_info.AtomRings() if len(atom_ring) > 6)


def penalized_logP(mol_or_smiles, masked=False, default=-5):
    """Penalized logP reward: logP - SA score - number of rings with more than 6 atoms.

    Args:
        mol_or_smiles: RDKit Mol or SMILES string (resolved with moses ``get_mol``).
        masked: if True, molecules failing ``mol_passes_filters`` receive ``default``.
        default: reward returned for unparsable or filtered-out molecules.

    Returns:
        The penalized logP value, or ``default``.
    """
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    # Apply the filter before scoring so logP/SA are not computed for molecules we discard anyway.
    if masked and not mol_passes_filters(mol):
        return default
    return logP(mol) - SA(mol) - get_num_rings_6(mol)

### Reward 2: Promoting diversity — penalized logP + novelty relative to the training set

In [None]:
def calculate_diversity(mol, reference_mol_fps, radius=2, n_bits=1024):
    """Novelty of ``mol`` relative to a reference set: 1 - max Tanimoto similarity.

    Args:
        mol: RDKit Mol to score.
        reference_mol_fps: non-empty sequence of Morgan bit-vector fingerprints. They must be
            computed with the same ``radius``/``n_bits`` as used here, or similarities are meaningless.
        radius: Morgan fingerprint radius (default 2, as in the original implementation).
        n_bits: fingerprint length (default 1024, as in the original implementation).

    Returns:
        A float in [0, 1]; 1 means no bit overlap with any reference molecule.

    Raises:
        ValueError: if ``reference_mol_fps`` is empty (the maximum similarity is undefined).
    """
    if len(reference_mol_fps) == 0:
        raise ValueError("reference_mol_fps must contain at least one fingerprint")
    mol_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    similarities = DataStructs.BulkTanimotoSimilarity(mol_fp, reference_mol_fps)
    return 1 - max(similarities)

def penalized_logP_diversity(mol_or_smiles, reference_mol_fps, masked=False, default=-5):
    """Penalized logP plus novelty with respect to a reference fingerprint set.

    Args:
        mol_or_smiles: RDKit Mol or SMILES string (resolved with moses ``get_mol``).
        reference_mol_fps: Morgan bit-vector fingerprints of the reference molecules.
        masked: if True, molecules failing ``mol_passes_filters`` receive ``default``.
        default: reward returned for unparsable or filtered-out molecules.

    Returns:
        ``penalized_logP(mol) + calculate_diversity(mol, reference_mol_fps)``, or ``default``.
    """
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    # Filter before scoring so logP/SA and fingerprinting are skipped for rejected molecules.
    if masked and not mol_passes_filters(mol):
        return default
    # Reuse the unmasked baseline reward instead of duplicating its formula here.
    return penalized_logP(mol) + calculate_diversity(mol, reference_mol_fps)

In [None]:
# Reference set for the novelty term: Morgan fingerprints of the first N_REFERENCE training molecules.
N_REFERENCE = 5000
ref = pd.read_csv('../data/train_plogp_plogpm.csv')  # SMILES of training-set molecules ('smiles' column)
# Chem.MolFromSmiles returns None for SMILES it cannot parse; drop those so fingerprinting does not fail.
reference_mols = [
    mol
    for mol in (Chem.MolFromSmiles(smiles) for smiles in ref['smiles'].iloc[:N_REFERENCE])
    if mol is not None
]
# Radius 2 / 1024 bits must match the fingerprints computed in calculate_diversity.
reference_mol_fps = [AllChem.GetMorganFingerprintAsBitVect(ref_mol, 2, nBits=1024) for ref_mol in reference_mols]
assert reference_mol_fps, "no valid reference molecules were parsed"

### Reward 3: Targeted molecule design — penalized logP + novelty, restricted to molecules containing the melatonin substructure

This is the reward used for RL fine-tuning below.

In [None]:
# Melatonin, used as the required substructure. Parsed once here instead of on every reward call.
MELATONIN_QUERY = Chem.MolFromSmiles('CC(=O)NCCC1=CNC2=C1C=C(C=C2)OC')


def substructure(mol, query=MELATONIN_QUERY):
    """Return True if ``mol`` contains ``query`` (melatonin by default) as a substructure, else False.

    Args:
        mol: RDKit Mol to test.
        query: RDKit Mol used as the substructure pattern.
    """
    return mol.HasSubstructMatch(query)

def penalized_logP_diversity_substructure(mol_or_smiles, reference_mol_fps, masked=False, default=-5):
    """Penalized logP + novelty, but only for molecules containing the melatonin substructure.

    Args:
        mol_or_smiles: RDKit Mol or SMILES string (resolved with moses ``get_mol``).
        reference_mol_fps: Morgan bit-vector fingerprints of the reference molecules.
        masked: if True, molecules failing ``mol_passes_filters`` receive ``default``.
        default: reward for unparsable molecules, molecules lacking the substructure,
            or (when ``masked``) filtered-out molecules.

    Returns:
        The same value as ``penalized_logP_diversity`` for molecules that pass the
        substructure gate, otherwise ``default``.
    """
    mol = get_mol(mol_or_smiles)
    if mol is None or not substructure(mol):
        return default
    # Everything past the substructure gate is exactly the Reward 2 computation.
    return penalized_logP_diversity(mol, reference_mol_fps, masked=masked, default=default)

In [None]:
def reward_fn(mol_or_smiles):
    """RL reward (Reward 3): penalized logP + novelty, gated on the melatonin substructure."""
    return penalized_logP_diversity_substructure(mol_or_smiles, reference_mol_fps=reference_mol_fps)


model.train_as_rl(reward_fn)

In [None]:
import os

# Output directory for the fine-tuned model. Pure Python, so it also works where `mkdir -p` is unavailable.
os.makedirs('rl_model', exist_ok=True)

In [None]:
# Persist the RL-fine-tuned model into rl_model/ (the directory is created in the previous cell).
model.save('./rl_model/')