In [None]:
import sys
# Make the project package (`logics_pack`, imported below) importable from the
# parent directory of this notebook.
sys.path.extend(['..'])

In [None]:
from logics_pack import global_settings, chemistry, predictor, reward_functions, augmem
from logics_pack import analysis, smiles_vocab, smiles_lstm
import pandas as pd
import numpy as np
import json
import torch

# Resolve the project's file/directory paths relative to the parent directory,
# then open the experiment-settings JSON used to read/write shared settings
# (e.g. the best predictor CV fold and, at the end, the best AugMem epoch).
project_paths = global_settings.build_project_paths(project_dir='../')
expset_obj = global_settings.ExperimentSettings(project_paths['EXPERIMENT_SETTINGS_JSON'])

Perform Augmented Memory (AugMem) fine-tuning of the pretrained prior to build the agent generator.

In [None]:
# AugMem fine-tuning config
# Bind the predictor *module* under its own name: a later cell rebinds the global
# `predictor` to the unpickled predictor model, so reading `predictor.featurizer`
# would fail if this cell were re-run afterwards.
from logics_pack import predictor as predictor_module

config = global_settings.Object()
# vocabulary, prior (pretrained) model, and PIK3CA activity predictor
config.tokens_path = project_paths['SMILES_TOKENS_PATH']
config.pretrain_setting_path = project_paths['PRETRAIN_SETTING_JSON']
config.pretrained_model_path = project_paths['PROJECT_DIR'] + 'model-prior/prior_e10.ckpt'
config.featurizer = predictor_module.featurizer
config.predictor_path = project_paths['PROJECT_DIR'] + \
                "model-pik3ca/predictor/pik3ca_rfr_cv%s.pkl"%expset_obj.get_setting("pik3ca-pred-best-cv")
# training length and output cadence
config.max_epoch = 2000  # "epoch" is actually the training batches for reinforcement learning models
config.save_period = 20
config.save_size = 20000
# output templates, formatted with the epoch number
config.save_ckpt_fmt = project_paths['PROJECT_DIR'] + 'model-pik3ca/augmem/pik3ca_augmem_e%d.ckpt'
config.sample_fmt = project_paths['PROJECT_DIR'] + 'model-pik3ca/augmem/pik3ca_augmem_e%d.txt'
# NOTE(review): sigma is presumably the reward scale in the augmented likelihood — confirm in augmem
config.sigma = 20
config.memory_size = 200  ## AugMem
config.aug_rounds = 2  ## AugMem
# NOTE(review): "DF" settings presumably configure a diversity filter — confirm in augmem
config.nbmax = 25  ## DF
config.minscore = 0.5  ## we are using -1.0 ~ 1.0 range rewards
config.dfmode = "binary"  ## DF
config.rewarding = reward_functions.pAff_to_reward_t2
config.train_batch_size = 100
config.finetune_lr = 0.0004
config.sampling_bs = 256

# Prefer CUDA, but fall back to CPU instead of crashing on machines without a GPU.
config.device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
if config.device_name == 'cpu':
    print("CUDA is not available: AugMem fine-tuning will run on CPU (slow).")

In [None]:
# Run AugMem fine-tuning with the config above. The evaluation cells below read
# per-epoch sample files from `config.sample_fmt`, so this call is expected to
# produce them (presumably also checkpoints via `config.save_ckpt_fmt` — confirm in augmem).
augmem.AugmentedMemory_training(config)

Build subsidiary files for the evaluation phase: valid canonical SMILES, fingerprints, and Frechet ChemNet vectors for each saved epoch.

In [None]:
import os
# Hide GPUs so the FCD model below runs on CPU (per the original comment, `fcd`
# uses TensorFlow). This must be set before `fcd` is imported.
# NOTE(review): torch may already have initialized CUDA in this kernel during
# training; whether the variable still takes effect then is not guaranteed —
# confirm, or run this section in a fresh kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # use tensorflow cpu

import fcd
import pickle
from logics_pack import frechet_chemnet
# reference ChemNet model used to compute Frechet ChemNet vectors
fc_ref_model = fcd.load_ref_model()

# per-epoch output templates (formatted with the epoch number)
config.vc_fmt = project_paths['PROJECT_DIR'] + 'model-pik3ca/augmem/pik3ca_augmem_vc_e%d.smi'  # save valid & canonical smiles
config.npfps_fmt = project_paths['PROJECT_DIR'] + 'model-pik3ca/augmem/pik3ca_augmem_npfps_e%d.npy'  # save fingerprint in npy
config.fcvec_fmt = project_paths['PROJECT_DIR'] + 'model-pik3ca/augmem/pik3ca_augmem_fcvec_e%d.npy'  # save Frechet ChemNet vectors

# epochs 0, save_period, 2*save_period, ..., max_epoch (inclusive)
# NOTE(review): assumes a sample file exists for every one of these, including epoch 0 — confirm augmem writes it.
epochs = list(range(0, config.max_epoch+1, config.save_period))
np.array(epochs)

In [None]:
# For each saved epoch: canonicalize the sampled SMILES, report how many were
# invalid, and cache the valid canonical SMILES, their fingerprints (numpy array)
# and their Frechet ChemNet vectors.
for epoch_no in epochs:
    print(epoch_no)
    with open(config.sample_fmt % epoch_no, 'r') as sample_file:
        sampled = [raw.strip() for raw in sample_file]
    valid_canons, invalid_ids = chemistry.get_valid_canons(sampled)
    print("- count invalids: ", len(invalid_ids))
    with open(config.vc_fmt % epoch_no, 'w') as vc_file:
        vc_file.write(''.join(smi + '\n' for smi in valid_canons))
    np.save(config.npfps_fmt % epoch_no,
            chemistry.rdk2npfps(chemistry.get_fps_from_smilist(valid_canons)))
    np.save(config.fcvec_fmt % epoch_no,
            fcd.get_predictions(fc_ref_model, valid_canons))  # ChemNet vectors

Evaluate FCD and OTD on the validation set, and pick the best epoch.

In [None]:
# which validation fold recorded
# Same setting that selected the predictor checkpoint in the config cell; it is
# used below as the key into the fold-split JSON.
vfold = expset_obj.get_setting("pik3ca-pred-best-cv")
vfold

In [None]:
affinity_data = pd.read_csv(project_paths['PIK3CA_DATA_PATH'])

# data split info
with open(project_paths['PIK3CA_FOLD_JSON'], 'r') as f:
    folds = json.load(f)

# retrieve validation set (fold entries are used as positional row indices)
val_ids = folds[vfold]
val_data = affinity_data.iloc[val_ids]

# get validation set actives (vsa): affinity above the activity threshold
vsa_data = val_data[val_data['affinity']>global_settings.PIK3CA_ACT_THRS]  # active among validation set
print("- count validation actives: ", len(vsa_data))
if vsa_data.empty:
    # FCD/OTD references (and the OT demand size) would be empty below
    raise ValueError("No actives (affinity > PIK3CA_ACT_THRS) in validation fold %s" % vfold)

vsa_smis = vsa_data['smiles'].tolist()
vsa_rdkfps = chemistry.get_fps_from_smilist(vsa_smis)
vsa_fc_vecs = fcd.get_predictions(fc_ref_model, vsa_smis)

dsize = len(vsa_rdkfps)  # demand size for OT
ssize = dsize*global_settings.OT_CALC_REPEATS  # supply size for repeated OT

# load predictor for PredAct (avg. predicted activity) calculation
# NOTE: this rebinds `predictor` (the logics_pack module imported at the top) to
# the unpickled model; the next cell calls `predictor.predict`.
# pickle.load can execute arbitrary code — only load trusted, locally produced files.
with open(config.predictor_path, 'rb') as f:
    predictor = pickle.load(f)

In [None]:
# Per-epoch validation metrics: FCD and mean OTD of the generated molecules
# against the validation actives, plus PredAct (mean predicted activity).
val_fcd_list, val_otd_list, predact_list = [], [], []

n_ot_repeats = global_settings.OT_CALC_REPEATS
for epo in epochs:
    print(epo)

    # FCD between generated and validation-active ChemNet vectors
    gen_fcvecs = np.load(config.fcvec_fmt % epo)
    val_fcd_list.append(frechet_chemnet.fcd_calculation(gen_fcvecs, vsa_fc_vecs))

    # OTD: only the first `ssize` generated fingerprints are needed as OT supply
    gen_npfps = np.load(config.npfps_fmt % epo)[:ssize]
    simmat = analysis.calculate_simmat(chemistry.np2rdkfps(gen_npfps), vsa_rdkfps)  # row:gen, col:data
    distmat = analysis.transport_distmat(analysis.tansim_to_dist, simmat, n_ot_repeats)
    _, _, motds = analysis.repeated_optimal_transport(distmat, repeat=n_ot_repeats)
    val_otd_list.append(np.mean(motds))

    # PredAct, computed on the same (truncated) fingerprint set
    predact_list.append(np.mean(predictor.predict(gen_npfps)))

In [None]:
# Validation FCD x OTD product per saved epoch (lower is better), tabulated
# together with its components and PredAct.
val_FCDxOTD = np.multiply(val_fcd_list, val_otd_list)
v_perf = pd.DataFrame({
    'epoch': epochs,
    'v-OTDxFCD': val_FCDxOTD,
    'v-OTD': val_otd_list,
    'v-FCD': val_fcd_list,
    'PredAct': predact_list,
})
v_perf

In [None]:
# we are only interested in epochs that achieved PredAct > (activity threshold)
subv = v_perf[v_perf['PredAct']>global_settings.PIK3CA_ACT_THRS].copy()
if subv.empty:
    # idxmin() on an empty column would fail with an opaque message
    raise ValueError(
        "No saved epoch reached PredAct > PIK3CA_ACT_THRS (%s); cannot pick a best epoch."
        % global_settings.PIK3CA_ACT_THRS
    )

# find the best epoch: lowest validation OTDxFCD among the eligible epochs
vbest = subv.loc[subv['v-OTDxFCD'].idxmin()]
print(vbest)

# register the best epoch
expset_obj.update_setting('pik3ca-augmem-best-epoch', int(vbest['epoch']))