In [1]:
import sys

# Make the parent directory importable (presumably where the `logics_pack`
# package imported in the next cell lives).
# Guarded so that re-running this cell does not keep appending duplicate
# '..' entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
from logics_pack import global_settings, chemistry, reinvent, predictor, reward_functions
from logics_pack import analysis, smiles_vocab, smiles_lstm
import pandas as pd
import numpy as np
import json
import torch

# Resolve the project's path table relative to the parent directory, then
# load the experiment settings file that the table points at.
project_paths = global_settings.build_project_paths(project_dir='../')
_settings_json = project_paths['EXPERIMENT_SETTINGS_JSON']
expset_obj = global_settings.ExperimentSettings(_settings_json)
del _settings_json  # temporary only; keep the notebook namespace unchanged

Perform REINVENT fine-tuning of the prior model to build the agent generator

In [3]:
# REINVENT fine-tuning config: settings passed to reinvent.REINVENT_training
# and read again by the agent-loading cell further down.
config = global_settings.Object()

_project_dir = project_paths['PROJECT_DIR']  # temporary shorthand, removed at the end of this cell

# --- inputs: token vocabulary, prior architecture settings and prior weights ---
config.tokens_path = project_paths['SMILES_TOKENS_PATH']
config.pretrain_setting_path = project_paths['PRETRAIN_SETTING_JSON']
config.pretrained_model_path = _project_dir + 'model-prior/prior_e10.ckpt'

# --- scoring: featurizer + KOR predictor for the CV index stored under the
#     'kor-pred-best-cv' experiment setting ---
config.featurizer = predictor.featurizer
_best_cv = expset_obj.get_setting("kor-pred-best-cv")
config.predictor_path = _project_dir + "model-kor/predictor/kor_rfr_cv%s.pkl" % _best_cv

# --- training schedule and output locations ---
config.max_epoch = 30000  # "epoch" is actually the training batches for reinforcement learning models
config.save_period = 600  # presumably checkpoint/sample interval in batches — TODO confirm
config.save_size = 20000  # presumably number of SMILES sampled per save — TODO confirm
config.save_ckpt_fmt = _project_dir + 'model-kor/reinvent/kor_reinv_e%d.ckpt'  # %d is filled with an epoch number
config.sample_fmt = _project_dir + 'model-kor/reinvent/kor_reinv_e%d.txt'

# --- REINVENT objective and optimisation ---
config.sigma = 10  # presumably the REINVENT augmented-likelihood score scale — verify in reinvent module
config.rewarding = reward_functions.pAff_to_reward_t1
config.train_batch_size = 128
config.finetune_lr = 0.0002
config.sampling_bs = 256

config.device_name = 'cpu'

del _project_dir, _best_cv  # drop the temporaries so only `config` is added to the namespace

In [None]:
# perform fine-tuning
# Runs REINVENT for config.max_epoch (30000) training batches, so this is
# presumably the expensive cell of the notebook. The next cell loads an agent
# checkpoint from config.save_ckpt_fmt — presumably one written by this call.
reinvent.REINVENT_training(config)

Load the REINVENT agent generator and sample some example SMILES

In [4]:
# Vocabulary/tokenizer built from the same token file used for fine-tuning.
vocab_obj = smiles_vocab.Vocabulary(init_from_file=config.tokens_path)
smtk = smiles_vocab.SmilesTokenizer(vocab_obj)

# Architecture hyper-parameters (emb_size, hidden_units) used to rebuild the
# network; they must match the checkpoint or load_state_dict below fails.
with open(config.pretrain_setting_path, 'r') as f:
    model_setting = json.load(f)

# Agent checkpoint to inspect; must be an epoch whose config.save_ckpt_fmt file exists.
AGENT_EPOCH = 3000

# Use the device from the config instead of repeating a hard-coded 'cpu'.
# NOTE: torch.load unpickles the file — only load checkpoints you produced/trust.
agent_ckpt = torch.load(config.save_ckpt_fmt % AGENT_EPOCH, map_location=config.device_name)
lstm_agent = smiles_lstm.SmilesLSTMGenerator(vocab_obj, model_setting['emb_size'], model_setting['hidden_units'], device_name=config.device_name)
# Last expression: its result is displayed as the key-matching report.
lstm_agent.lstm.load_state_dict(agent_ckpt['model_state_dict'])

<All keys matched successfully>

In [5]:
# sampling
# Seed the global RNGs so the displayed SMILES are reproducible on
# Restart & Run All.
# NOTE(review): assumes SafeSampler / SmilesLSTMGenerator draw from torch's
# (or numpy's) global RNG — confirm in logics_pack.
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED)
np.random.seed(SAMPLING_SEED)

# presumably: 50 = number of SMILES to return, maxlen = token-length cap — TODO confirm
ssplr = analysis.SafeSampler(lstm_agent, batch_size=16)
generated_smiles = ssplr.sample_clean(50, maxlen=150)
display(generated_smiles)

['O=C(Cn1ncc2c1-c1ccccc1OC2)N1CCN(c2ccccc2)CC1',
 'O=S(=O)(NCC(c1ccco1)N1CCCCC1)c1cccc2ccccc12',
 'Cc1nc(C(=O)N2CCN(C(=O)c3cccn3-c3ccccc3F)CC2)cs1',
 'CC1CN(S(=O)(=O)c2cccc(C(=O)N3CCN(c4ccccc4F)CC3)c2)CCN1c1cccc(Cl)c1',
 'CCOC(=O)C1CCN(C(=O)c2cc(=O)c3ccccc3o2)CC1',
 'Cc1cc(C(=O)N2CCC2(C)C(=O)Nc2ccc3sc(C)nc3c2)c(C)o1',
 'COCC1CN=C(Nc2ncco2)N1Cc1ccc(Cl)cc1',
 'Cc1cnn(C2CN3CC(COCc4ccc(Cl)nc4)CC(C2)O3)c1',
 'CCCN(CC(O)(Cn1cncn1)c1ccc(F)cc1F)C(C)CC',
 'CC1=C(C(=O)Nc2ccc3ncncc3c2)C(c2ccc(F)c(C(N)=O)c2)C2C(=O)N=C(S)N2C1',
 'CCOC(=O)CC1c2ccccc2C2CCN(C(=O)OCc3ccccc3)C(C(=O)OC)C21',
 'C=C(C)C1CCC2(C(=O)O)CCC3(C)C(CCC4C5(C)CCC(=O)C(C)(C)C5CCC43C)C12',
 'Cc1ccccc1C(=O)N1CC(N2CCN(S(=O)(=O)N3CCCCCCC3)CC2)C1',
 'Cc1cccc(NC(=O)N2CCCC2c2cnc(-c3ccccc3F)o2)c1',
 'CC(=O)NCCNC(=O)CCc1cc(C)c(O)c(Cl)c1',
 'Cc1ccccc1C(=O)N1CCC(C(=O)c2ccccc2)CC1',
 'CCCCCCCCCCCCCC(O)C(N)Cc1ccccc1',
 'CC1CCCN(C(=O)CSCc2ccccc2Cl)C1',
 'CCCCOC(=O)N1CCN(C(=O)C(CC)NC(=O)C2CCNCC2)CC1',
 'C=C(C)C1CCC2(CCOC)CCC3(C)C(CCC4C5(C)CCC(OC(C)