In [None]:
import sys
sys.path.insert(0, '..')

from keras.models import load_model
import keras.backend as K
import tensorflow as tf
# disable eager execution
tf.compat.v1.disable_eager_execution()

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw

from chemception.featurizer import ChemCeptionizer

from gradio import Interface

import numpy as np
import matplotlib.pyplot as plt
import cv2

# Dedicated TF1-style graph and session for the model. The functions defined
# further down re-enter this graph and re-bind this session before touching
# the model, so both objects must stay alive for the lifetime of the app.
graph = tf.Graph()
with graph.as_default():
    # Session is created while `graph` is the default graph, then registered
    # as the Keras backend session before the model is loaded.
    session = tf.compat.v1.Session()
    K.set_session(session)
    
    # Load the trained model from disk (path is relative to the working dir).
    model = load_model('../models/chemception_transfer.h5')
    # Featurizer that turns an RDKit molecule into the model's input array.
    # NOTE(review): meaning of embed=20 / fuse=True is defined in
    # chemception.featurizer; presumably they must match the settings the
    # model was trained with -- confirm.
    featurizer = ChemCeptionizer(embed=20, fuse=True)


    def plot_kernels(x):
        """Compute a Grad-CAM style heatmap for the model's first output unit.

        The gradient of ``model.output[:, 0]`` with respect to the output of
        the ``'mixed2'`` layer is averaged over axes 0-2 (batch and spatial
        axes for a channels-last feature map -- assumed), giving one weight
        per channel. Those weights scale the layer's activation maps of the
        first sample in ``x``; the weighted maps are averaged over channels,
        clipped at zero and rescaled so the maximum is 1.

        Args:
            x: Input batch accepted by ``model.input``. Only the first
                sample's activations contribute to the heatmap.

        Returns:
            A 2-D ``numpy`` array with values in ``[0, 1]``. If no weighted
            activation is positive the array is all zeros (instead of NaN).
        """
        with graph.as_default():
            K.set_session(session)

            # Build the gradient ops only once and cache the compiled backend
            # function on this function object: in graph mode every call to
            # K.gradients / K.function adds new nodes to the graph, so
            # rebuilding them per request grows the graph without bound.
            iterate = getattr(plot_kernels, '_iterate', None)
            if iterate is None:
                class_output = model.output[:, 0]
                last_conv_layer = model.get_layer('mixed2')
                grads = K.gradients(class_output, last_conv_layer.output)[0]
                pooled_grads = K.mean(grads, axis=(0, 1, 2))
                iterate = K.function(
                    [model.input],
                    [pooled_grads, last_conv_layer.output[0]],
                )
                plot_kernels._iterate = iterate

            pooled_grads_value, conv_layer_output_value = iterate([x])

        # Weight every channel by its pooled gradient. Broadcasting over the
        # last axis works for any channel count, not just a fixed 288.
        weighted_maps = conv_layer_output_value * pooled_grads_value

        heatmap = np.mean(weighted_maps, axis=-1)
        heatmap = np.maximum(heatmap, 0)

        # Normalise to [0, 1]; skip when the map is all zeros to avoid 0/0.
        peak = np.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak

        return heatmap


    def predict_smiles(smiles):
        """Classify a molecule given as SMILES and draw it with a heatmap overlay.

        Args:
            smiles: SMILES string describing the molecule.

        Returns:
            A tuple ``(label, fig)``. ``label`` is ``'Active'`` when the
            model's first output for the molecule exceeds 0.5, otherwise
            ``'Inactive'``. ``fig`` is a matplotlib figure showing the
            molecule drawing with the heatmap from ``plot_kernels`` on top.

        Raises:
            ValueError: If ``smiles`` is empty, cannot be parsed by RDKit, or
                describes a molecule without atoms.
        """
        print(smiles)

        # MolFromSmiles returns None (it does not raise) on unparsable input;
        # fail here with a clear message instead of deep inside the featurizer.
        mol = Chem.MolFromSmiles(smiles) if smiles else None
        if mol is None or mol.GetNumAtoms() == 0:
            raise ValueError(
                'Could not parse SMILES string: {!r}'.format(smiles)
            )

        # Featurize and add a leading batch axis for the model.
        mol_chemceptionized = featurizer.featurize(mol)
        mol_chemceptionized = np.expand_dims(mol_chemceptionized, axis=0)
        print(mol_chemceptionized.shape)

        # Re-enter the model's graph/session: this function may be called
        # from a thread other than the one that loaded the model.
        with graph.as_default():
            K.set_session(session)
            prediction = model.predict(mol_chemceptionized)

        heatmap = plot_kernels(mol_chemceptionized)

        img = Draw.MolToImage(mol, size=(300, 300))
        img = np.array(img)
        # NOTE(review): presumably rotates the drawing to line up with the
        # orientation of the featurized image -- confirm.
        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)

        # Draw both layers over the same extent so the heatmap is stretched
        # onto the 300x300 molecule image.
        extent = 0, 300, 0, 300
        fig, ax = plt.subplots(frameon=False)

        ax.imshow(img, extent=extent)
        ax.imshow(heatmap, cmap='viridis', alpha=0.4, extent=extent)

        # Remove the figure from pyplot's global registry so repeated
        # requests do not accumulate open figures; the Figure object itself
        # stays valid and can still be rendered by the caller.
        plt.close(fig)

        label = 'Active' if prediction[0][0] > 0.5 else 'Inactive'
        return label, fig


    # Sample SMILES strings handed to the interface as `examples`.
    example_smiles = [
        'O=c1c2cc(N3CCOCC3)ccc2[nH]c(=S)n1Cc1ccc(Cl)cc1',
        'O=C(C1CCN(c2ncnc3c2nc2n3CCCCC2)CC1)N1CCN(c2ccccc2)CC1',
    ]

    # One text input, routed through predict_smiles, producing a text
    # output and a plot output.
    iface = Interface(
        fn=predict_smiles,
        inputs="text",
        outputs=["text", "plot"],
        examples=example_smiles,
    )

    # Start serving the interface.
    iface.launch()