In [2]:
using PyCall

In [3]:
py"""
from smiles_tools import return_tokens, SmilesEnumerator
from c_wrapper import seqOneHot
import pandas as pd
import numpy as np
import os

# Token vocabulary, read from the 'tokens' column of preprocessor/vocab.csv,
# which is looked up next to the parent of the current working directory.
vocab = pd.read_csv(f'{os.path.dirname(os.getcwd())}/preprocessor/vocab.csv')['tokens'].to_list()

# token -> zero-based position in `vocab`
tokenizer = {token: index for index, token in enumerate(vocab)}
# position -> token (inverse of `tokenizer`)
reverse_tokenizer = {index: token for token, index in tokenizer.items()}

def convert_back(x):
    # For every row of `x`, take the argmax, shift it down by one and look the
    # result up in `reverse_tokenizer`; ids without a token contribute ''.
    return ''.join(reverse_tokenizer.get(np.argmax(row) - 1, '') for row in x)

def augment_smiles(string, n):
    # Return `n` results of SmilesEnumerator.randomize_smiles for `string`.
    sme = SmilesEnumerator()
    return [sme.randomize_smiles(string) for _ in range(n)]
"""

In [4]:
# Julia-side handles to the objects defined in the Python namespace by the
# setup cell. PyCall converts each value to a Julia type where it knows how.
return_tokens = py"return_tokens"
convert_back = py"convert_back"
tokenizer = py"tokenizer"
reverse_tokenizer = py"reverse_tokenizer"

# Thin Julia wrapper that forwards both arguments to the Python `augment_smiles`.
function augment_smiles(str, n)
    return py"augment_smiles"(str, n)
end

augment_smiles (generic function with 1 method)

In [174]:
"""
    standardizeCase(str::AbstractString) -> String

Return `str` in title case with every remaining lowercase `h` replaced by `H`.

`titlecase` upper-cases the first letter of each run of letters and lower-cases
the rest, so for example `"cl"` becomes `"Cl"` and `"[nH]"` becomes `"[Nh]"`;
the second step then turns that into `"[NH]"`.

Accepts any `AbstractString` (e.g. a `SubString`), not only `String`.
"""
function standardizeCase(str::AbstractString)
    return replace(titlecase(str), "h" => "H")
end

"""
    return_augmented_list(str::String, n::Int64=5) -> Matrix{String}

Collect up to `n` augmented variants of `str` as a `k x 1` matrix of strings.

Each attempt asks `augment_smiles` for one variant, tokenizes it with
`return_tokens`, normalizes every token with `standardizeCase`, and keeps the
re-joined tokens only if all of them are keys of `tokenizer`. At most `2n`
attempts are made, so fewer than `n` rows (possibly zero) may be returned.
"""
function return_augmented_list(str::String, n::Int64=5)
    # Fetch the Python-side vocabulary once; it does not change between attempts.
    vocab = py"vocab"
    accepted = String[]
    attempts = 0

    while length(accepted) < n && attempts < n * 2
        candidate = augment_smiles(str, 1)[begin]
        candidate_tokens = standardizeCase.(return_tokens(candidate, vocab)[begin])
        if issubset(Set(candidate_tokens), keys(tokenizer))
            push!(accepted, join(candidate_tokens))
        end
        attempts += 1
    end

    # Callers receive a column matrix, so reshape once at the end instead of
    # re-allocating the whole matrix with `vcat` on every accepted variant.
    return reshape(accepted, :, 1)
end

"""
    return_augmented_tokens(str::String, n::Int64=5) -> Vector{Any}

Collect up to `n` token sequences, one per accepted augmented variant of `str`.

Each attempt asks `augment_smiles` for one variant, tokenizes it with
`return_tokens`, and normalizes every token with `standardizeCase`. The token
sequence is kept only if all of its tokens are keys of `tokenizer`. At most
`2n` attempts are made, so the result may hold fewer than `n` sequences.
"""
function return_augmented_tokens(str::String, n::Int64=5)
    # Fetch the Python-side vocabulary once; it does not change between attempts.
    vocab = py"vocab"
    # Kept as an untyped vector so the returned container type is unchanged.
    accepted = []
    attempts = 0

    while length(accepted) < n && attempts < n * 2
        candidate = augment_smiles(str, 1)[begin]
        candidate_tokens = standardizeCase.(return_tokens(candidate, vocab)[begin])
        if issubset(Set(candidate_tokens), keys(tokenizer))
            push!(accepted, candidate_tokens)
        end
        attempts += 1
    end

    return accepted
end
    
"""
    tokenize_and_pad(tokens_vector; max_len=190)

Map each token to `tokenizer[token] + 1` and left-pad the ids with zeros so the
result has `max_len` entries. The `+ 1` shift keeps id `0` free for padding.

# Throws
- `ArgumentError` if `tokens_vector` holds more than `max_len` tokens.
- `KeyError` if a token is not a key of `tokenizer`.
"""
function tokenize_and_pad(tokens_vector; max_len::Integer=190)
    pad = max_len - length(tokens_vector)
    # Fail with a message that names the lengths instead of the opaque
    # dimension error `zeros` raises for a negative size.
    pad >= 0 || throw(ArgumentError(
        "sequence has $(length(tokens_vector)) tokens but max_len is $(max_len)"
    ))
    ids = [tokenizer[token] + 1 for token in tokens_vector]
    return vcat(zeros(Int, pad), ids)
end

tokenize_and_pad (generic function with 2 methods)

In [270]:
"""
    predict(model, sample, shape; verbosity=0)

Convert `sample` to a NumPy array, reshape it to `shape`, and return the result
of the Python call `model.predict(...)` on it.

# Keywords
- `verbosity`: forwarded as the `verbose` argument of `model.predict`. The
  default `0` matches the value that was previously hard-coded.
"""
function predict(model, sample, shape; verbosity=0)
    # Everything stays in one Python expression so the reshaped array is never
    # round-tripped through a Julia array before `predict` sees it.
    return py"$(model).predict(np.array($(sample)).reshape($(shape)), verbose=$(verbosity))"
end

predict (generic function with 1 method)

In [250]:
# Display the current batch of encoded sequences.
# NOTE(review): `seqs` is not assigned in any of the visible cells; it must
# come from notebook state that is not part of this export.
seqs

1-element Vector{Any}:
 Matrix{Int32}[[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0], [1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0], [1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0], [1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0], [1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]]

In [255]:
# Run the Python-side `model` on `seqs`, reshaped to (5, 190, 72).
# NOTE(review): presumably 5 sequences x 190 positions x 72 one-hot channels;
# confirm. `model` is created by the model-loading cell, which appears later
# in this export, so that cell has to be executed first.
predict(py"model", seqs, [5, 190, 72])



5×2 Matrix{Float32}:
 0.0175267   0.982473
 0.651568    0.348432
 0.971154    0.0288456
 0.00101979  0.99898
 0.00186065  0.998139

In [269]:
"""
    augment_and_predict(strings, shape, models; n=5)

Predict a class for every entry of `strings` with every model in `models`,
using `n` augmented variants per string where possible.

For each string, `return_augmented_tokens` is asked for `n` token sequences.
If exactly `n` come back, all of them are encoded. Otherwise the string itself
is tokenized and encoded as a single sequence. Encoding pads the ids to
`shape[begin]` with `tokenize_and_pad` and passes them, together with `shape`,
to the Python function `seqOneHot`.

# Returns
One vector per model, holding one integer per input string: the zero-based
index of the column with the largest score after summing the prediction rows
of all encoded sequences for that string. This assumes `predict` returns a
matrix with one row per encoded sequence.
"""
function augment_and_predict(strings, shape, models; n::Int64=5)
    tokens = return_augmented_tokens.(strings, n)

    # Look the Python objects up once instead of once per encoded sequence.
    vocab = py"vocab"
    onehot = py"seqOneHot"
    encode(toks) = onehot(tokenize_and_pad(toks, max_len=shape[begin]), shape)

    encoded_seqs = []
    for (idx, variants) in enumerate(tokens)
        if length(variants) == n
            push!(encoded_seqs, [encode(variant) for variant in variants])
        else
            # Augmentation fell short: fall back to the un-augmented string.
            # NOTE(review): unlike the augmented path, these tokens are not
            # passed through `standardizeCase`; confirm that is intended.
            original_tokens = return_tokens(strings[idx], vocab)[begin]
            push!(encoded_seqs, [encode(original_tokens)])
        end
    end

    # Sum the prediction rows of each batch and report the winning column,
    # shifted to a zero-based class index.
    return [
        [
            argmax(sum(eachrow(predict(model, batch, [length(batch), shape...])))) - 1
            for batch in encoded_seqs
        ]
        for model in models
    ]
end

augment_and_predict (generic function with 2 methods)

In [92]:
# Load the Keras model `dopa_rnn_model.h5` from the current working directory
# into the Python-side name `model`, then call `model.compile()` on it.
# TF_CPP_MIN_LOG_LEVEL is set to '3' before TensorFlow is imported; this is
# presumably meant to silence TensorFlow's native log output (confirm).
# Relies on `os` having been imported by the earlier Python setup cell.
py"""
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.keras import models

model = models.load_model(f'{os.getcwd()}//dopa_rnn_model.h5')
model.compile()
"""

In [98]:
# Call `model.predict` directly on `seqs` reshaped to (5, 190, 72) and keep the
# result in `preds`. Unlike the `predict` helper, no `verbose` argument is
# passed here, so the Python default applies.
# NOTE(review): `seqs` is not assigned in any of the visible cells.
preds = py"model.predict(np.array($(seqs)).reshape(5, 190, 72))"



5×2 Matrix{Float32}:
 0.202561  0.797439
 0.26447   0.73553
 0.808269  0.191731
 0.557924  0.442076
 0.26121   0.73879

In [271]:
# `n` is a keyword argument of `augment_and_predict`, so it has to be passed by
# name; a fourth positional argument only works against a stale method left
# over from an earlier definition in the same session.
augment_and_predict(["CC1CCN(CC1)CCCN2C(=O)C3=CC=CC=C3N=C2SCC(=O)NC4=C(C=C(C=C4)OC)OC"], [190, 72], [py"model"]; n=6)

1-element Vector{Vector{Int64}}:
 [1]