In [None]:
import gentrl
import torch
import os
import minisom
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol

import pickle
import joblib

In [None]:
# Silence RDKit logging: DisableLog is given the pattern 'rdApp.*', so every
# logger under the rdApp namespace is turned off. Presumably this keeps parse
# warnings for invalid generated SMILES out of the cell output — confirm intent.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

In [None]:
# --- Hardware selection ------------------------------------------------------
# The visible-device list is written to the environment before the current
# CUDA device is chosen; keep these two statements in this order.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0,1,2,3'})
torch.cuda.set_device(0)

# --- Input artefacts ---------------------------------------------------------
# Checkpoint directory handed to model.load() and the pickled regressor file
# handed to joblib.load() in the loading cell.
MODEL_PATH = "model/Bayes_model_iter599/"
RL_PATH = "ddr1_bayes_regression_1201.pkl"

# --- Reinforcement-learning hyperparameters ----------------------------------
# LR is passed as lr_lp, Iteration as num_iterations and BATCH_SIZE as
# batch_size when train_as_rl is invoked.
LR = 0.001
Iteration, BATCH_SIZE = 3000, 500

# --- Output ------------------------------------------------------------------
# Directory name embeds the learning rate so runs with different LR values
# do not share an output folder.
SAVE_PATH = f"DDR1_BAYES_lr{LR}"

In [None]:
def init_model():
    """Assemble a fresh GENTRL model from a new encoder/decoder pair.

    The encoder is an ``RNNEncoder`` (latent_size=50, hidden_size=128) and the
    decoder a ``DilConvDecoder`` (latent_input_size=50). The GENTRL wrapper is
    given 50 ``('c', 10)`` latent descriptors, one ``('c', 10)`` feature
    descriptor, ``tt_int=30`` and ``beta=0.001``.

    Returns:
        The constructed ``gentrl.GENTRL`` instance. No checkpoint is loaded
        and the model is not moved to any device here.
    """
    latent_dim = 50

    encoder = gentrl.RNNEncoder(latent_size=latent_dim, hidden_size=128)
    decoder = gentrl.DilConvDecoder(latent_input_size=latent_dim)

    # One ('c', 10) entry per latent dimension, plus a single feature entry.
    latent_descr = [('c', 10)] * latent_dim
    feature_descr = [('c', 10)]

    return gentrl.GENTRL(
        encoder,
        decoder,
        latent_descr,
        feature_descr,
        tt_int=30,
        beta=0.001,
    )

In [None]:
# reward function
def reward_fn(mol_or_smiles, cur_iteration=0, bayes_regression=None, default=0):
    """Score one molecule for reinforcement learning.

    The reward is the product of two terms:

    * a fingerprint-density term ``1 / (1 + exp(-(n_bits - 60) / 10))`` where
      ``n_bits`` is the number of set bits in the radius-2, 2048-bit Morgan
      fingerprint of the molecule;
    * ``exp(-prediction)``, where ``prediction`` is the first value returned by
      ``bayes_regression.predict`` on the (1, 2048) fingerprint row.
      NOTE(review): a smaller prediction yields a larger reward — confirm this
      matches the units the regressor was trained on.

    Args:
        mol_or_smiles: Whatever ``get_mol`` accepts (the name suggests a SMILES
            string or an RDKit molecule).
        cur_iteration: Unused; kept so existing callers keep working.
        bayes_regression: Object exposing ``predict(X)``. When ``None`` the
            module-level ``bayes_regression`` is looked up at call time, so
            this function can be defined before the regressor is loaded and
            always sees the most recently loaded one.
        default: Value returned for molecules that cannot be scored.

    Returns:
        ``default`` if ``get_mol`` returns ``None`` or the fingerprint has
        fewer than 20 set bits; otherwise the reward described above.

    Raises:
        RuntimeError: If a scorable molecule is given but no regressor was
            passed and no global ``bayes_regression`` has been defined.
    """
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default

    # Fingerprint the molecule get_mol already produced rather than re-parsing
    # the input, so RDKit molecule objects are handled as well as SMILES.
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    features = np.vstack(fingerprint).reshape(1, -1)

    mfp_sum = features.sum()
    if mfp_sum < 20:
        # Too few set bits: treat as unscorable.
        return default

    # Resolve the regressor lazily; a def-time default would require the
    # model to be loaded before this cell runs and would pin a stale object.
    if bayes_regression is None:
        bayes_regression = globals().get('bayes_regression')
        if bayes_regression is None:
            raise RuntimeError(
                "reward_fn needs a regression model: pass bayes_regression=... "
                "or load the global `bayes_regression` first."
            )

    # Logistic weighting centred on 60 set bits with a scale of 10.
    mfp = 1 / (1 + np.exp(-(mfp_sum - 60) / 10))

    predicted_term = np.exp(-bayes_regression.predict(features)[0])

    return mfp * predicted_term

In [None]:
# model and reinforcement model load
# Build the generator, load the checkpoint stored under MODEL_PATH and move
# the model to the GPU.
model = init_model()
model.load(MODEL_PATH)
model.cuda()

# Load the regressor used by reward_fn. Passing the path lets joblib open and
# close the file itself, so no file handle is left dangling.
# SECURITY: joblib.load unpickles arbitrary objects — only load files you trust.
bayes_regression = joblib.load(RL_PATH)

# exist_ok=True keeps this cell re-runnable when the output folder already
# exists (os.mkdir would raise FileExistsError on the second run).
os.makedirs(SAVE_PATH, exist_ok=True)

In [None]:

# Run reinforcement-learning fine-tuning with reward_fn as the objective.
# Per the original note, the reward value and valid-molecule percentage are
# recorded every 100 epochs; that interval is decided inside train_as_rl and is
# not visible in this notebook.
records_mean = model.train_as_rl(
    reward_fn,
    num_iterations=Iteration,
    batch_size=BATCH_SIZE,
    lr_lp=LR,
    lr_dec=1e-8,
    exploration_ratio=0.1,
    file_path=SAVE_PATH,
    topN=1,
)

global_stats, record_mean_reward, record_valid_perc = records_mean

# Write the curves into the run's output folder. The previous code referenced
# an undefined `file_path` variable here and failed with NameError only after
# training had finished; SAVE_PATH is the directory passed to train_as_rl above.
pd.DataFrame(record_mean_reward).to_csv(os.path.join(SAVE_PATH, "mean_reward.csv"))
pd.DataFrame(record_valid_perc).to_csv(os.path.join(SAVE_PATH, "valid_perc.csv"))
