In [1]:
import sys
# Make the parent (project) directory importable so `logics_pack` can be found.
# Guarded so re-running this cell does not keep appending duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
from logics_pack import global_settings, chemistry, predictor, reward_functions, augmem
from logics_pack import analysis, smiles_vocab, smiles_lstm
import pandas as pd
import numpy as np
import json
import torch

# Resolve every project path relative to the parent directory, then open the
# shared experiment-settings store (read via get_setting / written via update_setting).
project_paths = global_settings.build_project_paths(
    project_dir='../',
)
expset_obj = global_settings.ExperimentSettings(
    project_paths['EXPERIMENT_SETTINGS_JSON']
)

Perform Augmented Memory (AugMem) fine-tuning to build the agent generator

In [3]:
# ---- AugMem fine-tuning configuration ----
config = global_settings.Object()

# Vocabulary, pretrained prior and activity predictor
config.tokens_path = project_paths['SMILES_TOKENS_PATH']
config.pretrain_setting_path = project_paths['PRETRAIN_SETTING_JSON']
config.pretrained_model_path = project_paths['PROJECT_DIR'] + 'model-prior/prior_e10.ckpt'
config.featurizer = predictor.featurizer
config.predictor_path = (
    project_paths['PROJECT_DIR']
    + "model-kor/predictor/kor_rfr_cv%s.pkl" % expset_obj.get_setting("kor-pred-best-cv")
)

# Training schedule and checkpointing.
# An "epoch" here is one RL training batch, not a pass over a dataset.
config.max_epoch = 2000
config.save_period = 20
config.save_size = 20000
config.save_ckpt_fmt = project_paths['PROJECT_DIR'] + 'model-kor/augmem/kor_augmem_e%d.ckpt'
config.sample_fmt = project_paths['PROJECT_DIR'] + 'model-kor/augmem/kor_augmem_e%d.txt'

# Augmented Memory (AugMem) and diversity-filter (DF) hyper-parameters
config.sigma = 20
config.memory_size = 200  # AugMem (memory size; the log reports "size mem")
config.aug_rounds = 2     # AugMem
config.nbmax = 25         # DF
config.minscore = 0.5     # rewards in use lie in the -1.0 ~ 1.0 range
config.dfmode = "binary"  # DF
config.rewarding = reward_functions.pAff_to_reward_t1

# Optimisation and sampling
config.train_batch_size = 100
config.finetune_lr = 4e-4
config.sampling_bs = 256

# Hard-coded GPU index for the training machine -- adjust to the local hardware.
# NOTE(review): avoid auto-detecting via torch.cuda here; initialising CUDA early
# may defeat the CUDA_VISIBLE_DEVICES='-1' set later for TensorFlow.
config.device_name = 'cuda:1'

In [4]:
# Run Augmented Memory RL fine-tuning of the pretrained prior using `config`.
# config.save_period / save_ckpt_fmt / sample_fmt presumably control how often and
# where checkpoints and sampled SMILES are written (the evaluation cells below read
# sample_fmt % epoch for every multiple of save_period) -- confirm in augmem.
augmem.AugmentedMemory_training(config)

--- 0 ---
---
uniq valid count:  93
avg pkx:  6.039153953746372
avg filtered scores  -0.0473272438907125
size _scaffolds  0
mem score avg:  -0.047327243890712495
size mem:  93
---
uniq valid count:  98
avg pkx:  6.078028324829931
avg filtered scores  -0.038378463481088726
size _scaffolds  4
mem score avg:  0.1352968975304069
size mem:  200
---
uniq valid count:  94
avg pkx:  6.129725008865247
avg filtered scores  -0.021052195631330452
size _scaffolds  11
mem score avg:  0.2286075966446143
size mem:  200
---
uniq valid count:  99
avg pkx:  6.173630132275132
avg filtered scores  -0.008855604475336232
size _scaffolds  16
mem score avg:  0.2751283045911838
size mem:  200
--- 20 ---
---
uniq valid count:  94
avg pkx:  6.152590732016209
avg filtered scores  -0.014499881790216634
size _scaffolds  18
mem score avg:  0.32190479625076857
size mem:  200
---
uniq valid count:  93
avg pkx:  6.2188747311827965
avg filtered scores  0.004733298771583871
size _scaffolds  24
mem score avg:  0.3659890555

---
uniq valid count:  93
avg pkx:  6.872132060931898
avg filtered scores  0.17302269262209993
size _scaffolds  2238
mem score avg:  0.6806396897555408
size mem:  200
--- 300 ---
---
uniq valid count:  93
avg pkx:  6.84034219406042
avg filtered scores  0.17474581977919884
size _scaffolds  2335
mem score avg:  0.6822597343602027
size mem:  200
---
uniq valid count:  92
avg pkx:  6.674557454710146
avg filtered scores  0.10780025137557157
size _scaffolds  2422
mem score avg:  0.6835469640898567
size mem:  200
---
uniq valid count:  95
avg pkx:  6.593886877192983
avg filtered scores  0.10607989381830045
size _scaffolds  2510
mem score avg:  0.6838376640904553
size mem:  200
---
uniq valid count:  96
avg pkx:  6.839599045138889
avg filtered scores  0.16388101281491269
size _scaffolds  2594
mem score avg:  0.6846748966143895
size mem:  200
--- 320 ---
---
uniq valid count:  95
avg pkx:  6.805908621553884
avg filtered scores  0.14352706125973802
size _scaffolds  2687
mem score avg:  0.6857927

---
uniq valid count:  97
avg pkx:  7.153847979872361
avg filtered scores  0.21005094130705282
size _scaffolds  7119
mem score avg:  0.713811912238504
size mem:  200
---
uniq valid count:  96
avg pkx:  7.031666052414021
avg filtered scores  0.20311907463664744
size _scaffolds  7214
mem score avg:  0.7141922302000568
size mem:  200
--- 600 ---
---
uniq valid count:  99
avg pkx:  7.284932104377104
avg filtered scores  0.2641521279924064
size _scaffolds  7326
mem score avg:  0.7146041307394756
size mem:  200
---
uniq valid count:  97
avg pkx:  7.230927378497791
avg filtered scores  0.26424582794551915
size _scaffolds  7417
mem score avg:  0.7149359559672399
size mem:  200
---
uniq valid count:  98
avg pkx:  7.094393612730808
avg filtered scores  0.20423424288130615
size _scaffolds  7522
mem score avg:  0.7151628146445813
size mem:  200
---
uniq valid count:  97
avg pkx:  7.0494611634757005
avg filtered scores  0.17973806731642514
size _scaffolds  7611
mem score avg:  0.7145906151450272
si

--- 880 ---
---
uniq valid count:  97
avg pkx:  6.97341660652921
avg filtered scores  0.18957111361512932
size _scaffolds  11868
mem score avg:  0.7280801867739464
size mem:  200
---
uniq valid count:  95
avg pkx:  7.268826280701754
avg filtered scores  0.2706071771859943
size _scaffolds  11969
mem score avg:  0.7286653579142833
size mem:  200
---
uniq valid count:  98
avg pkx:  7.329215184645287
avg filtered scores  0.25729095040878364
size _scaffolds  12070
mem score avg:  0.7288337634695361
size mem:  200
--- 900 ---
---
uniq valid count:  99
avg pkx:  7.107653367003367
avg filtered scores  0.2165292770287332
size _scaffolds  12171
mem score avg:  0.7291748096215619
size mem:  200
---
uniq valid count:  98
avg pkx:  7.1848168707483
avg filtered scores  0.26101465418365793
size _scaffolds  12268
mem score avg:  0.728363958697404
size mem:  200
---
uniq valid count:  92
avg pkx:  7.282946394927536
avg filtered scores  0.27905663326120456
size _scaffolds  12372
mem score avg:  0.730423

---
uniq valid count:  97
avg pkx:  7.176245314187531
avg filtered scores  0.23722888529804786
size _scaffolds  16656
mem score avg:  0.7424128834796527
size mem:  198
---
uniq valid count:  94
avg pkx:  7.329793971631208
avg filtered scores  0.24301337480692287
size _scaffolds  16751
mem score avg:  0.742630650146811
size mem:  199
--- 1180 ---
---
uniq valid count:  96
avg pkx:  7.385291860119048
avg filtered scores  0.28640826129210517
size _scaffolds  16843
mem score avg:  0.742518716398024
size mem:  200
---
uniq valid count:  96
avg pkx:  7.475362425595239
avg filtered scores  0.2645132685686143
size _scaffolds  16929
mem score avg:  0.7427137614360546
size mem:  200
---
uniq valid count:  98
avg pkx:  7.458151275510205
avg filtered scores  0.32198159158386774
size _scaffolds  17039
mem score avg:  0.742757866996902
size mem:  200
--- 1200 ---
---
uniq valid count:  98
avg pkx:  7.2298566326530604
avg filtered scores  0.23346154743558392
size _scaffolds  17131
mem score avg:  0.7

---
uniq valid count:  93
avg pkx:  7.439955967741936
avg filtered scores  0.29583091720558913
size _scaffolds  21888
mem score avg:  0.7476585611691788
size mem:  200
--- 1460 ---
---
uniq valid count:  97
avg pkx:  7.340943917525774
avg filtered scores  0.2605906325464003
size _scaffolds  22001
mem score avg:  0.7483433798707168
size mem:  200
---
uniq valid count:  97
avg pkx:  7.801201757486502
avg filtered scores  0.36096196470594183
size _scaffolds  22125
mem score avg:  0.7483928913966339
size mem:  200
---
uniq valid count:  95
avg pkx:  7.1703778320802005
avg filtered scores  0.23207464384732873
size _scaffolds  22226
mem score avg:  0.7484341035644231
size mem:  200
--- 1480 ---
---
uniq valid count:  92
avg pkx:  7.476029420289855
avg filtered scores  0.28796175389376943
size _scaffolds  22326
mem score avg:  0.7482702338761159
size mem:  200
---
uniq valid count:  86
avg pkx:  7.531248903654484
avg filtered scores  0.3418504494586406
size _scaffolds  22432
mem score avg:  0

---
uniq valid count:  83
avg pkx:  7.575530020080322
avg filtered scores  0.3571217763938866
size _scaffolds  26973
mem score avg:  0.7546680500992291
size mem:  200
---
uniq valid count:  87
avg pkx:  7.640322318007667
avg filtered scores  0.3521154988803618
size _scaffolds  27072
mem score avg:  0.7546680500992291
size mem:  200
---
uniq valid count:  87
avg pkx:  7.62274990421456
avg filtered scores  0.3474895068567709
size _scaffolds  27170
mem score avg:  0.7545500079683564
size mem:  200
--- 1760 ---
---
uniq valid count:  82
avg pkx:  7.428860618466899
avg filtered scores  0.3161844712447383
size _scaffolds  27282
mem score avg:  0.7546042928810951
size mem:  200
---
uniq valid count:  85
avg pkx:  7.6222993529411776
avg filtered scores  0.32816725258077756
size _scaffolds  27383
mem score avg:  0.7546042928810951
size mem:  200
---
uniq valid count:  83
avg pkx:  7.775257013769364
avg filtered scores  0.3864359598162962
size _scaffolds  27491
mem score avg:  0.7547743666930886

Build subsidiary files for the evaluation phase

In [4]:
import os
# TensorFlow (pulled in by `fcd`) reads CUDA_VISIBLE_DEVICES when it initialises
# CUDA, so GPUs must be hidden *before* the import below to force CPU execution.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # use tensorflow cpu

import pickle

import fcd
from logics_pack import frechet_chemnet

# Reference ChemNet model used to embed SMILES for the FCD metric.
fc_ref_model = fcd.load_ref_model()

# Per-epoch artefacts derived from the sampled SMILES files.
config.vc_fmt = (project_paths['PROJECT_DIR']
                 + 'model-kor/augmem/kor_augmem_vc_e%d.smi')      # valid & canonical SMILES
config.npfps_fmt = (project_paths['PROJECT_DIR']
                    + 'model-kor/augmem/kor_augmem_npfps_e%d.npy')  # fingerprints as .npy
config.fcvec_fmt = (project_paths['PROJECT_DIR']
                    + 'model-kor/augmem/kor_augmem_fcvec_e%d.npy')  # Frechet ChemNet vectors

# Every saved epoch: 0, save_period, 2*save_period, ... (max_epoch included when
# it is a multiple of save_period).
epochs = list(range(0, config.max_epoch + 1, config.save_period))
np.array(epochs)

Using TensorFlow backend.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


2023-07-06 10:16:44.188080: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1
2023-07-06 10:16:44.223327: E tensorflow/stream_executor/cuda/cuda_driver.cc:318] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-07-06 10:16:44.223359: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: shepherd5
2023-07-06 10:16:44.223367: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: shepherd5
2023-07-06 10:16:44.223451: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 465.19.1
2023-07-06 10:16:44.223483: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 465.19.1
2023-07-06 10:16:44.223490: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 465.19.1
2023-07-06 10:16:44.224420: I tensorflow/core/platform/cpu_featu




array([   0,   20,   40,   60,   80,  100,  120,  140,  160,  180,  200,
        220,  240,  260,  280,  300,  320,  340,  360,  380,  400,  420,
        440,  460,  480,  500,  520,  540,  560,  580,  600,  620,  640,
        660,  680,  700,  720,  740,  760,  780,  800,  820,  840,  860,
        880,  900,  920,  940,  960,  980, 1000, 1020, 1040, 1060, 1080,
       1100, 1120, 1140, 1160, 1180, 1200, 1220, 1240, 1260, 1280, 1300,
       1320, 1340, 1360, 1380, 1400, 1420, 1440, 1460, 1480, 1500, 1520,
       1540, 1560, 1580, 1600, 1620, 1640, 1660, 1680, 1700, 1720, 1740,
       1760, 1780, 1800, 1820, 1840, 1860, 1880, 1900, 1920, 1940, 1960,
       1980, 2000])

In [5]:
# For each saved epoch: canonicalise the sampled SMILES, then cache the valid
# canonical SMILES, their fingerprints and their ChemNet vectors on disk.
for epo in epochs:
    print(epo)

    with open(config.sample_fmt % epo, 'r') as sample_file:
        sampled = [row.strip() for row in sample_file]

    valid_canons, invalids = chemistry.get_valid_canons(sampled)
    print("- count invalids: ", len(invalids))

    with open(config.vc_fmt % epo, 'w') as vc_file:
        vc_file.writelines(smi + '\n' for smi in valid_canons)

    rdk_fps = chemistry.get_fps_from_smilist(valid_canons)
    np.save(config.npfps_fmt % epo, chemistry.rdk2npfps(rdk_fps))

    chemnet_vecs = fcd.get_predictions(fc_ref_model, valid_canons)  # ChemNet vectors
    np.save(config.fcvec_fmt % epo, chemnet_vecs)

0
- count invalids:  952
20
- count invalids:  1059
40
- count invalids:  1009
60
- count invalids:  1167
80
- count invalids:  1195
100
- count invalids:  1165
120
- count invalids:  1165
140
- count invalids:  1064
160
- count invalids:  1079
180
- count invalids:  986
200
- count invalids:  831
220
- count invalids:  890
240
- count invalids:  905
260
- count invalids:  785
280
- count invalids:  776
300
- count invalids:  849
320
- count invalids:  781
340
- count invalids:  817
360
- count invalids:  802
380
- count invalids:  788
400
- count invalids:  755
420
- count invalids:  681
440
- count invalids:  690
460
- count invalids:  619
480
- count invalids:  635
500
- count invalids:  642
520
- count invalids:  584
540
- count invalids:  561
560
- count invalids:  525
580
- count invalids:  570
600
- count invalids:  522
620
- count invalids:  500
640
- count invalids:  583
660
- count invalids:  550
680
- count invalids:  547
700
- count invalids:  535
720
- count invalids:  608

Evaluate FCD and OTD on the validation set and pick the best epoch

In [5]:
# which validation fold recorded
# Fold id of the best cross-validated predictor (the same setting used to build
# config.predictor_path). It comes back as a string (e.g. '3'), which is what is
# needed to index the fold JSON, since JSON object keys are always strings.
vfold = expset_obj.get_setting("kor-pred-best-cv")
vfold

'3'

In [6]:
affinity_data = pd.read_csv(project_paths['KOR_DATA_PATH'])

# data split info
with open(project_paths['KOR_FOLD_JSON'], 'r') as f:
    kor_folds = json.load(f)

# retrieve validation set
val_ids = kor_folds[vfold]
val_data = affinity_data.iloc[val_ids]

# get validation set activate (vsa)
vsa_data = val_data[val_data['affinity']>global_settings.KOR_ACT_THRS]  # active among validation set
len(vsa_data)

vsa_smis = vsa_data['smiles'].tolist()
vsa_rdkfps = chemistry.get_fps_from_smilist(vsa_smis)
vsa_fc_vecs = fcd.get_predictions(fc_ref_model, vsa_smis)

dsize = len(vsa_rdkfps)  # demand size for OT
ssize = dsize*global_settings.OT_CALC_REPEATS  # supply size for repeated OT

# load predictor for PredAct (avg. predicted activity) calculation
with open(config.predictor_path, 'rb') as f:
    predictor = pickle.load(f)

In [7]:
val_fcd_list = []
val_otd_list = []
predact_list = []

for epo in epochs:
    print(epo)
    # load fc vectors of generation
    gen_fcvecs = np.load(config.fcvec_fmt%epo)
    fcdval = frechet_chemnet.fcd_calculation(gen_fcvecs, vsa_fc_vecs)
    val_fcd_list.append(fcdval)
    
    gen_npfps = np.load(config.npfps_fmt%epo)[:ssize]  # only need this amount
    gen_rdkfps = chemistry.np2rdkfps(gen_npfps)
    simmat = analysis.calculate_simmat(gen_rdkfps, vsa_rdkfps)  # row:gen, col:data
    distmat = analysis.transport_distmat(analysis.tansim_to_dist, simmat, global_settings.OT_CALC_REPEATS)
    _, _, motds = analysis.repeated_optimal_transport(distmat, repeat=global_settings.OT_CALC_REPEATS)
    val_otd_list.append(np.mean(motds))

    # record PredAct
    predact_list.append(np.mean(predictor.predict(gen_npfps)))

0
20
40
60
80
100
120
140
160
180
200
220
240
260
280
300
320
340
360
380
400
420
440
460
480
500
520
540
560
580
600
620
640
660
680
700
720
740
760
780
800
820
840
860
880
900
920
940
960
980
1000
1020
1040
1060
1080
1100
1120
1140
1160
1180
1200
1220
1240
1260
1280
1300
1320
1340
1360
1380
1400
1420
1440
1460
1480
1500
1520
1540
1560
1580
1600
1620
1640
1660
1680
1700
1720
1740
1760
1780
1800
1820
1840
1860
1880
1900
1920
1940
1960
1980
2000


In [8]:
# Collect per-epoch validation metrics into one table. The combined score is the
# element-wise product FCD x OTD; the selection cell picks its minimum.
val_FCDxOTD = np.asarray(val_fcd_list) * np.asarray(val_otd_list)
v_perf = pd.DataFrame({
    'epoch': epochs,
    'v-OTDxFCD': val_FCDxOTD,
    'v-OTD': val_otd_list,
    'v-FCD': val_fcd_list,
    'PredAct': predact_list,
})
v_perf

Unnamed: 0,epoch,v-OTDxFCD,v-OTD,v-FCD,PredAct
0,0,151.870015,5.373718,28.261627,5.950143
1,20,134.184715,5.291485,25.358612,6.152877
2,40,124.779109,5.247167,23.780282,6.208436
3,60,123.527241,5.246429,23.545012,6.243472
4,80,124.012755,5.281934,23.478665,6.283508
...,...,...,...,...,...
96,1920,193.126009,5.591267,34.540654,7.698707
97,1940,209.781165,5.772098,36.344008,7.690865
98,1960,205.133997,5.735355,35.766572,7.694452
99,1980,208.908194,5.825910,35.858467,7.702056


In [9]:
# we are only interested in epochs that achieved PredAct > (activity threshold)
subv = v_perf[v_perf['PredAct']>global_settings.KOR_ACT_THRS].copy()
# idxmin on an empty frame fails with an opaque pandas error; report the real cause
# before anything is written to the experiment settings.
if subv.empty:
    raise ValueError("no epoch reached PredAct > KOR_ACT_THRS; cannot select a best epoch")

# find the best epoch: the eligible epoch with the smallest validation OTD x FCD
vbest = subv.loc[subv['v-OTDxFCD'].idxmin()]
print(vbest)

# register the best epoch in the experiment settings
expset_obj.update_setting('kor-augmem-best-epoch', int(vbest['epoch']))

epoch        400.000000
v-OTDxFCD    143.180629
v-OTD          5.285925
v-FCD         27.087145
PredAct        7.002252
Name: 20, dtype: float64
