In [None]:
import sys

# Make the repository root importable so `logics_pack` can be found. '..' is
# resolved against the kernel's working directory (normally this notebook's
# folder), matching project_dir='../' used when building project paths.
# The membership check keeps the cell idempotent: re-running it no longer
# appends duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
from logics_pack import global_settings, chemistry, segler, predictor, reward_functions, analysis, smiles_vocab, smiles_lstm
import pandas as pd
import numpy as np
import json
import torch

# Resolve the project's file layout relative to the repository root ('../'),
# then open the experiment-settings store read/updated by later cells.
project_paths = global_settings.build_project_paths(project_dir='../')
settings_json_path = project_paths['EXPERIMENT_SETTINGS_JSON']
expset_obj = global_settings.ExperimentSettings(settings_json_path)

Fine-tune the pretrained prior with the Segler method to build the agent generator

In [None]:
# Settings object consumed by segler.Segler_training (and reused by later cells).
config = global_settings.Object()

project_dir = project_paths['PROJECT_DIR']
segler_dir = project_dir + 'model-pik3ca/segler/'  # all fine-tuning outputs live here

# -- inputs: vocabulary, prior model and predictor --
config.tokens_path = project_paths['SMILES_TOKENS_PATH']
config.pretrain_setting_path = project_paths['PRETRAIN_SETTING_JSON']
config.pretrained_model_path = project_dir + 'model-prior/prior_e10.ckpt'
config.featurizer = predictor.featurizer
# predictor trained on the CV fold recorded as best in the experiment settings
best_pred_cv = expset_obj.get_setting("pik3ca-pred-best-cv")
config.predictor_path = project_dir + "model-pik3ca/predictor/pik3ca_rfr_cv%s.pkl" % best_pred_cv

# -- fine-tuning schedule and hyper-parameters --
config.score_thrs = global_settings.PIK3CA_ACT_THRS
config.max_epoch = 150
config.save_period = 3
config.save_size = 20000
config.save_ckpt_fmt = segler_dir + 'pik3ca_segler_e%d.ckpt'  # %d = epoch
config.sample_fmt = segler_dir + 'pik3ca_segler_e%d.txt'      # %d = epoch
config.init_gen_size = 100000
config.ssz_per_epoch = 10000
config.ft_period = 5
config.finetune_lr = 1e-4
config.finetune_bs = 128
config.sampling_bs = 256
config.record_actives_size = 100000
config.record_actives_fmt = segler_dir + 'pik3ca_segler_recacts_e%d.smi'

config.device_name = 'cpu'

In [None]:
# perform fine-tuning
# Seed the Python, NumPy and PyTorch RNGs so a rerun of this (long) cell is
# reproducible. Assumes Segler_training draws its randomness only from these
# RNGs -- TODO confirm against segler.Segler_training.
import random

FINETUNE_SEED = 0
random.seed(FINETUNE_SEED)
np.random.seed(FINETUNE_SEED)
torch.manual_seed(FINETUNE_SEED)

segler.Segler_training(config)

Load the Segler agent generator and sample a few examples

In [None]:
vocab_obj = smiles_vocab.Vocabulary(init_from_file=config.tokens_path)
smtk = smiles_vocab.SmilesTokenizer(vocab_obj)

# Model architecture (embedding size, hidden units) comes from the pretraining settings.
with open(config.pretrain_setting_path, 'r') as f:
    model_setting = json.load(f)

# Load the agent from the last fine-tuning checkpoint instead of a hard-coded
# epoch. Assumes checkpoints are written every `config.save_period` epochs up to
# `config.max_epoch` (TODO confirm against segler.Segler_training), so the last
# one is the largest multiple of save_period <= max_epoch (150 with this config).
agent_epoch = (config.max_epoch // config.save_period) * config.save_period
# torch.load unpickles the file -- only load checkpoints produced by this project.
agent_ckpt = torch.load(config.save_ckpt_fmt % agent_epoch, map_location=config.device_name)
lstm_agent = smiles_lstm.SmilesLSTMGenerator(vocab_obj, model_setting['emb_size'], model_setting['hidden_units'],
                                             device_name=config.device_name)
lstm_agent.lstm.load_state_dict(agent_ckpt['model_state_dict'])

In [None]:
# Draw a small demo set from the fine-tuned agent. SafeSampler.sample_clean
# presumably returns only usable SMILES -- TODO confirm in logics_pack.analysis.
demo_size, demo_maxlen = 50, 150
ssplr = analysis.SafeSampler(lstm_agent, batch_size=16)
generated_smiles = ssplr.sample_clean(demo_size, maxlen=demo_maxlen)
display(generated_smiles)

Build auxiliary files for the evaluation phase (valid canonical SMILES, fingerprints, and Fréchet ChemNet vectors)

In [None]:
config.vc_fmt = project_paths['PROJECT_DIR'] + 'model-pik3ca/segler/pik3ca_segler_vc_e%d.smi'  # save valid & canonical smiles
config.npfps_fmt = project_paths['PROJECT_DIR'] + 'model-pik3ca/segler/pik3ca_segler_npfps_e%d.npy'  # save fingerprint in npy
config.fcvec_fmt = project_paths['PROJECT_DIR'] + 'model-pik3ca/segler/pik3ca_segler_fcvec_e%d.npy'  # save Frechet ChemNet vectors

# Epochs to evaluate, derived from the config instead of a hard-coded list.
# Samples/checkpoints are assumed to be written every `config.save_period`
# epochs from 0 up to `config.max_epoch` (TODO confirm against
# segler.Segler_training). By default only the last saved epoch is evaluated
# (150 with the current config); set EVALUATE_ALL_EPOCHS = True to score every
# saved epoch and let the best-epoch selection choose among them.
EVALUATE_ALL_EPOCHS = False
saved_epochs = list(range(0, config.max_epoch + 1, config.save_period))
epochs = saved_epochs if EVALUATE_ALL_EPOCHS else saved_epochs[-1:]

In [None]:
# Set up the reference model whose ChemNet vectors are used for FCD.
import os
# Hide all CUDA devices so TensorFlow runs on CPU. This must run before `fcd`
# (and presumably `frechet_chemnet`) pull in TensorFlow, so keep this
# statement above the imports below.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # use tensorflow cpu

import fcd
from logics_pack import frechet_chemnet
# Reference ChemNet model; passed to fcd.get_predictions to embed SMILES.
fc_ref_model = fcd.load_ref_model()

In [None]:
for epo in epochs:
    print(epo)

    # Raw generated SMILES for this epoch -> valid, canonicalised subset.
    with open(config.sample_fmt % epo, 'r') as f:
        gens = [line.strip() for line in f]
    vcs, invids = chemistry.get_valid_canons(gens)
    print("- count invalids: ", len(invids))
    with open(config.vc_fmt % epo, 'w') as f:
        f.writelines(smi + '\n' for smi in vcs)

    # Fingerprints of the valid molecules, stored as a numpy array.
    fps = chemistry.get_fps_from_smilist(vcs)
    np.save(config.npfps_fmt % epo, chemistry.rdk2npfps(fps))

    # ChemNet vectors of the valid molecules (input to the FCD computation).
    fcvecs = fcd.get_predictions(fc_ref_model, vcs)
    np.save(config.fcvec_fmt % epo, fcvecs)

Evaluate the Fréchet ChemNet Distance (FCD) and optimal transport distance (OTD) on the validation set, and pick the epoch with the lowest FCD × OTD

In [None]:
# loading validation dataset: fold indices plus precomputed fingerprints / ChemNet vectors
with open(project_paths['PIK3CA_FOLD_JSON'], 'r') as f:
    pik3_folds = json.load(f)
data_npfps = np.load(project_paths['PIK3CA_DATA_FP'])
data_fcvecs = np.load(project_paths['PIK3CA_DATA_FCVEC'])

# The validation split is the CV fold whose predictor was selected as best.
val_fold_id = expset_obj.get_setting('pik3ca-pred-best-cv')
# NOTE(review): JSON object keys are strings; this assumes get_setting returns
# the same key type the fold file uses -- confirm.
val_idx = pik3_folds[val_fold_id]
val_npfps = data_npfps[val_idx]
val_fcvecs = data_fcvecs[val_idx]
val_rdkfps = chemistry.np2rdkfps(val_npfps)

dsize = len(val_rdkfps)  # demand size for OT
ssize = dsize * global_settings.OT_CALC_REPEATS  # supply size for repeated OT

In [None]:
n_ot_repeats = global_settings.OT_CALC_REPEATS
val_fcd_list, val_otd_list = [], []

for epo in epochs:
    print(epo)

    # FCD between the validation set's and this epoch's ChemNet vectors.
    gen_fcvecs = np.load(config.fcvec_fmt % epo)
    fcdval = frechet_chemnet.fcd_calculation(val_fcvecs, gen_fcvecs)
    val_fcd_list.append(fcdval)

    # OTD: mean of repeated optimal-transport distances; only the first
    # `ssize` generated fingerprints are needed as supply.
    gen_npfps = np.load(config.npfps_fmt % epo)[:ssize]
    gen_rdkfps = chemistry.np2rdkfps(gen_npfps)
    simmat = analysis.calculate_simmat(gen_rdkfps, val_rdkfps)  # rows: generated, cols: validation data
    distmat = analysis.transport_distmat(analysis.tansim_to_dist, simmat, n_ot_repeats)
    _, _, motds = analysis.repeated_optimal_transport(distmat, repeat=n_ot_repeats)
    val_otd_list.append(np.mean(motds))

In [None]:
# validation FCDxOTD: per-epoch product of FCD and mean OTD (lower is better)
val_FCDxOTD = np.array(val_fcd_list) * np.array(val_otd_list)

# Guard against stale/hidden state (e.g. lists left over from an earlier run)
# and against scores that could not be computed.
if val_FCDxOTD.size == 0:
    raise ValueError("no epochs were evaluated; `epochs` is empty")
if len(val_FCDxOTD) != len(epochs):
    raise ValueError("expected exactly one FCDxOTD score per evaluated epoch")
if np.isnan(val_FCDxOTD).all():
    raise ValueError("FCDxOTD is NaN for every evaluated epoch; cannot pick a best epoch")

# find the best epoch: nanargmin skips NaN scores, whereas np.argmin would
# return the index of the first NaN and silently select a failed epoch.
best_epoch = int(epochs[np.nanargmin(val_FCDxOTD)])
# register the best epoch (plain int, since the settings are presumably
# persisted to the experiment-settings JSON file)
expset_obj.update_setting('pik3ca-segler-best-epoch', best_epoch)

In [None]:
# Read the value back from the settings store to confirm it was registered.
registered_best_epoch = expset_obj.get_setting('pik3ca-segler-best-epoch')
print(registered_best_epoch)