In [30]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
from tensorflow.keras import models
import pandas as pd
import numpy as np
from tqdm import tqdm
from functools import lru_cache
from smiles_tools import return_tokens
from c_wrapper import seqOneHot

In [2]:
dopa_model = models.load_model(f'{os.getcwd()}//dopa_rnn_model.h5')
sero_model = models.load_model(f'{os.getcwd()}//sero_rnn_model.h5')

In [3]:
dopa_model.compile()
sero_model.compile()

In [4]:
rnn_models = [dopa_model, sero_model]

In [5]:
test_string = 'COC1=CC=C(C=C1)NC(=O)CC2=NC3=CC=CC=C3N2'

In [6]:
vocab = pd.read_csv('../preprocessor/vocab.csv')['tokens'].to_list()

In [7]:
tokenizer = {i : n for n, i in enumerate(vocab)}

In [8]:
reverse_tokenizer = {value: key for key, value in tokenizer.items()}
convert_back = lambda x: ''.join(reverse_tokenizer.get(np.argmax(i)-1, '') for i in x)

In [9]:
max_len = 190
seq_shape = np.array([max_len, max([i+1 for i in tokenizer.values()])+1], dtype=np.int32)

In [10]:
initial_seq = np.array([tokenizer[i]+1 for i in return_tokens(test_string, vocab)[0]])
full_seq = np.hstack([np.zeros(max_len-len(initial_seq)), initial_seq])
full_seq = seqOneHot(np.array(full_seq, dtype=np.int32), seq_shape).reshape(1, *seq_shape)

In [11]:
full_seq[0]

array([[1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)

In [12]:
convert_back(full_seq[0])

'COC1=CC=C(C=C1)NC(=O)CC2=NC3=CC=CC=C3N2'

In [14]:
np.hstack([i.predict(full_seq, verbose=0) for i in rnn_models])

array([[0.08524679, 0.9147532 , 0.7503976 , 0.24960242]], dtype=float32)

In [31]:
@lru_cache(maxsize=256)
def ensemble_predict(string, models_array, vocab=None, tokenizer=None):
    if vocab is None:
        vocab = pd.read_csv(f'{os.getcwd()}//vocab.csv')['tokens'].to_list()
    
    if tokenizer is None:
        tokenizer = {i : n for n, i in enumerate(vocab)}
        
    initial_seq = np.array([tokenizer[i]+1 for i in return_tokens(test_string, vocab)[0]])
    full_seq = np.hstack([np.zeros(max_len-len(initial_seq)), initial_seq])
    full_seq = seqOneHot(np.array(full_seq, dtype=np.int32), seq_shape).reshape(1, *seq_shape)
    
    return np.hstack([i.predict(full_seq, verbose=0) for i in models_array])

In [16]:
ensemble_predict(test_string, rnn_models, vocab=vocab, tokenizer=tokenizer)

array([[0.08524679, 0.9147532 , 0.7503976 , 0.24960242]], dtype=float32)

In [17]:
filtered_dataset = pd.read_csv('../preprocessor/sero_filtered_dataset.csv')

In [18]:
filtered_dataset['PUBCHEM_EXT_DATASOURCE_SMILES']

0       CC[C@H]([C@@H]1[C@H](C[C@@](O1)(CC)[C@H]2CC[C@...
1         CC1=C(OC2=C1C=C(C=C2)OC)C(=O)NC3=NC4=CC=CC=C4N3
2       CCCN1C[C@@H](C[C@H]2[C@H]1CC3=CNC4=CC=CC2=C34)...
3       C1CN(CCN1CC2=CC=CC=C2)C(=O)CCCN3C(=O)CSC4=C3C=...
4       CC1=CC(=C(C(=C1)C2=NNC3=C2C(N(C3=O)CC4=CC=CO4)...
                              ...                        
4825    CC1=CC\2=C(C=C1)SCC/C2=N\OS(=O)(=O)C3=CC=C(C=C...
4826    COCCN(C1=C(N(C(=O)NC1=O)CC2=CC=CC=C2)N)C(=O)CS...
4827    C1CC(OC1)C(=O)NC2=CC=CC(=C2)C3=NN=C(O3)C4=CC=C...
4828    CC1=CC=CC=C1NC(=O)C2=CC=C(C=C2)CN(C3=CC=CC=C3)...
4829           CC1=CC(=O)OC2=C1C=CC(=C2)OCC3=CC=C(C=C3)OC
Name: PUBCHEM_EXT_DATASOURCE_SMILES, Length: 4830, dtype: object

In [20]:
preds = []
for i in tqdm(range(len(filtered_dataset['PUBCHEM_EXT_DATASOURCE_SMILES']))):
    string = filtered_dataset['PUBCHEM_EXT_DATASOURCE_SMILES'][i]
    initial_seq = np.array([tokenizer[i]+1 for i in return_tokens(string, vocab)[0]])
    if len(initial_seq) > max_len:
        preds.append(None)
        continue
    full_seq = np.hstack([np.zeros(max_len-len(initial_seq)), initial_seq])
    full_seq = seqOneHot(np.array(full_seq, dtype=np.int32), seq_shape).reshape(1, *seq_shape)
    
    preds.append(sero_model.predict(full_seq, verbose=0))

100%|███████████████████████████████████████████████████████████████████████████████| 4830/4830 [15:33<00:00,  5.17it/s]


In [21]:
preds

[array([[9.9979585e-01, 2.0419070e-04]], dtype=float32),
 array([[0.7724961 , 0.22750385]], dtype=float32),
 array([[9.9999475e-01, 5.2228493e-06]], dtype=float32),
 array([[0.8768784 , 0.12312165]], dtype=float32),
 array([[0.996293  , 0.00370693]], dtype=float32),
 array([[0.7503976 , 0.24960242]], dtype=float32),
 array([[9.999689e-01, 3.111645e-05]], dtype=float32),
 array([[9.9997509e-01, 2.4955663e-05]], dtype=float32),
 array([[0.97133183, 0.02866822]], dtype=float32),
 array([[0.9813276 , 0.01867238]], dtype=float32),
 array([[9.9999952e-01, 4.4124627e-07]], dtype=float32),
 array([[9.9999976e-01, 2.8698872e-07]], dtype=float32),
 array([[9.999949e-01, 5.090257e-06]], dtype=float32),
 array([[9.9991035e-01, 8.9585737e-05]], dtype=float32),
 array([[0.99872476, 0.00127522]], dtype=float32),
 array([[0.9878677, 0.0121323]], dtype=float32),
 array([[9.999975e-01, 2.454457e-06]], dtype=float32),
 array([[0.9970227 , 0.00297734]], dtype=float32),
 array([[0.4834987 , 0.51650137]], d

In [None]:
len(return_tokens(filtered_dataset['PUBCHEM_EXT_DATASOURCE_SMILES'][1702], vocab)[0])

In [22]:
sum([np.argmax(i) for i in preds if i is not None])/len(preds)

0.4968944099378882

In [23]:
dopa_data = pd.read_csv('../filtered_dataset.csv')

In [24]:
d_preds = []
for i in tqdm(range(len(dopa_data['PUBCHEM_EXT_DATASOURCE_SMILES']))):
    string = dopa_data['PUBCHEM_EXT_DATASOURCE_SMILES'][i]
    initial_seq = np.array([tokenizer[i]+1 for i in return_tokens(string, vocab)[0]])
    if len(initial_seq) > max_len:
        d_preds.append(None)
        continue
    full_seq = np.hstack([np.zeros(max_len-len(initial_seq)), initial_seq])
    full_seq = seqOneHot(np.array(full_seq, dtype=np.int32), seq_shape).reshape(1, *seq_shape)
    
    d_preds.append(dopa_model.predict(full_seq, verbose=0))

100%|█████████████████████████████████████████████████████████████████████████████| 18234/18234 [59:45<00:00,  5.09it/s]


In [25]:
sum([np.argmax(i) for i in d_preds if i is not None])/len(d_preds)

0.5034002413074476

In [26]:
d_preds

[array([[0.99882036, 0.00117956]], dtype=float32),
 array([[0.98718643, 0.01281351]], dtype=float32),
 array([[0.30386385, 0.6961362 ]], dtype=float32),
 array([[0.4633632, 0.5366368]], dtype=float32),
 array([[9.999201e-01, 7.981102e-05]], dtype=float32),
 array([[9.9997008e-01, 2.9911082e-05]], dtype=float32),
 array([[0.9984352 , 0.00156479]], dtype=float32),
 array([[0.2620687, 0.7379314]], dtype=float32),
 array([[0.9989845 , 0.00101541]], dtype=float32),
 array([[0.9850293 , 0.01497071]], dtype=float32),
 array([[0.2750773 , 0.72492266]], dtype=float32),
 array([[0.6383268 , 0.36167312]], dtype=float32),
 array([[9.999243e-01, 7.571239e-05]], dtype=float32),
 array([[9.9984694e-01, 1.5301631e-04]], dtype=float32),
 array([[0.4206064, 0.5793936]], dtype=float32),
 array([[0.97816885, 0.02183115]], dtype=float32),
 array([[0.91993475, 0.08006527]], dtype=float32),
 array([[9.9963558e-01, 3.6438365e-04]], dtype=float32),
 array([[0.9556512 , 0.04434881]], dtype=float32),
 array([[0.

TypeError: tuple indices must be integers or slices, not str