In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import os
from tqdm import tqdm

from rdkit import Chem
from rdkit.Chem import Lipinski
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.Draw import SimilarityMaps
from rdkit.Chem import Draw
from rdkit import DataStructs

import sys
sys.path.append('../src/')

import logging
tf.get_logger().setLevel(logging.ERROR)

from gcn.datasets import GCNDataset
from gcn.models import GCNModel

### 1. Define functions

In [None]:
def vanilla_atom_saliency(model, A, H, y):
    """Compute a gradient-based ("vanilla") saliency score for each row of ``H``.

    The Huber loss between ``y`` and the model prediction is differentiated
    with respect to the input feature matrix; the absolute gradients are then
    summed over the feature axis, giving one non-negative score per kept row.

    Args:
        model: callable invoked as ``model([A, H], training=False)``.
        A: 2-D NumPy adjacency matrix. Rows summing to zero are treated as
            padding and removed together with the matching columns.
            NOTE(review): a real atom whose row sums to zero would also be
            dropped; confirm the adjacency always has a non-zero row per atom.
        H: 2-D NumPy feature matrix whose rows line up with the rows of ``A``.
            Presumably floating point; verify, since gradients are taken
            with respect to it.
        y: target value(s) for this single example.

    Returns:
        1-D NumPy array with one saliency value per non-padded row.
    """
    # Indices of rows that carry at least one non-zero adjacency entry.
    real_rows = np.nonzero(A.sum(axis=1) != 0)[0]

    # Strip padding, then add a leading batch axis of size one to every input.
    adjacency_rows = A[real_rows]
    adjacency_in = tf.expand_dims(
        tf.convert_to_tensor(adjacency_rows[:, real_rows]), axis=0)
    features_in = tf.expand_dims(tf.convert_to_tensor(H[real_rows]), axis=0)
    target_in = tf.expand_dims(tf.convert_to_tensor(y), axis=0)

    # Only the input features are tracked; model variables are not watched.
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(features_in)
        prediction = model([adjacency_in, features_in], training=False)
        loss = tf.compat.v1.losses.huber_loss(target_in, prediction)

    # Drop the batch axis and collapse the feature axis into one score per row.
    abs_gradients = tf.abs(tape.gradient(loss, features_in))
    return tf.reduce_sum(abs_gradients[0], axis=1).numpy()

def draw_atom_saliency_on_mol(mol, saliency, path, size=(1000, 1000)):
    """Draw ``mol`` with its atom saliency as a contour map and save it to ``path``.

    Args:
        mol: RDKit molecule to draw.
        saliency: array-like of per-atom weights, in the molecule's atom order.
            Values are scaled by their maximum before drawing; an all-zero
            input is drawn as-is instead of being turned into NaNs.
        path: destination file. Missing parent directories are created.
        size: ``(width, height)`` passed to the Cairo drawer and to RDKit's
            similarity-map routine.

    Returns:
        None. The drawing is written to ``path``.
    """
    # A bare file name has no parent directory; os.makedirs('') would raise.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    drawer = Draw.MolDraw2DCairo(*size)
    drawer.drawOptions().bondLineWidth = 3

    # Normalise by the largest weight, but never divide by zero.
    weights = np.asarray(saliency)
    peak = weights.max()
    if peak != 0:
        weights = weights / peak

    Draw.SimilarityMaps.GetSimilarityMapFromWeights(
        mol=mol,
        weights=[float(w) for w in weights],
        size=size,
        coordScale=1.0,
        colors='g',
        alpha=0.4,
        contourLines=10,
        draw2d=drawer)

    drawer.FinishDrawing()
    drawer.WriteDrawingText(path)

### 2. Obtain identical/similar compounds between the data sets

In [None]:
# Both data sets are read with the positional arguments (1, False), exactly as
# before; presumably batch_size=1 and training=False (verify against
# GCNDataset's signature).
hilic_dataset = GCNDataset(
    ['../input/tfrecords/Fiehn_HILIC/train.tfrec',
     '../input/tfrecords/Fiehn_HILIC/valid.tfrec',
     '../input/tfrecords/Fiehn_HILIC/test_1.tfrec',
     '../input/tfrecords/Fiehn_HILIC/test_2.tfrec'],
    1, False).get_iterator()

rplc_dataset = GCNDataset(
    ['../input/tfrecords/RIKEN/train.tfrec',
     '../input/tfrecords/RIKEN/valid.tfrec',
     '../input/tfrecords/RIKEN/test_1.tfrec',
     '../input/tfrecords/RIKEN/test_2.tfrec'],
    1, False
).get_iterator()


def _load_examples(dataset):
    """Read ``dataset`` once and return its examples as plain tuples.

    Each returned tuple is ``(A, H, y, mol, index)``: the first batch element
    of 'adjacency_matrix', 'feature_matrix' and 'label', the RDKit molecule
    parsed from the SMILES stored under 'string', and the scalar taken from
    'index'. Examples whose SMILES RDKit cannot parse are skipped, because a
    ``None`` molecule cannot take part in substructure matching.
    """
    records = []
    for example in dataset:
        smiles = example['string'].numpy()[0][0].decode('utf-8')
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        records.append((
            example['adjacency_matrix'].numpy()[0],
            example['feature_matrix'].numpy()[0],
            example['label'].numpy()[0],
            mol,
            example['index'].numpy()[0][0],
        ))
    return records


# Materialise each data set a single time so the pairwise comparison below
# neither re-reads the records nor re-parses SMILES on every inner iteration.
hilic_examples = _load_examples(hilic_dataset)
rplc_examples = _load_examples(rplc_dataset)

hilic_data = []
rplc_data = []

for hilic_example in tqdm(hilic_examples):
    mol_a = hilic_example[3]
    num_atoms_a = mol_a.GetNumAtoms()

    for rplc_example in rplc_examples:
        mol_b = rplc_example[3]

        # Two molecules can only be substructures of each other if they have
        # the same number of atoms, so reject unequal sizes cheaply first.
        if num_atoms_a != mol_b.GetNumAtoms():
            continue

        if mol_a.HasSubstructMatch(mol_b) and mol_b.HasSubstructMatch(mol_a):
            # Entry k of hilic_data and entry k of rplc_data form one pair.
            hilic_data.append(hilic_example)
            rplc_data.append(rplc_example)
            
        

### 3. Train models 

In [None]:
batch_size = 32
num_epochs = 200


def _fit_on_all_splits(dataset_dir, batch_size, num_epochs):
    """Fit a GCNModel with default hyper-parameters on every split of a data set.

    All four record files (train, valid, test_1, test_2) found under
    ``../input/tfrecords/<dataset_dir>/`` are used for fitting.

    NOTE(review): no random seed is set in this cell, so repeated runs may
    give different weights; confirm whether GCNDataset/GCNModel seed
    internally.

    Args:
        dataset_dir: folder name below ``../input/tfrecords/``.
        batch_size: batch size handed to ``GCNDataset``.
        num_epochs: number of epochs handed to ``fit``.

    Returns:
        ``(dataset, model)``: the ``GCNDataset`` used for training and the
        ``GCNModel`` after ``fit`` has been called on it.
    """
    dataset = GCNDataset(
        [f'../input/tfrecords/{dataset_dir}/{split}.tfrec'
         for split in ('train', 'valid', 'test_1', 'test_2')],
        batch_size, training=True)

    # build model (with default hyper-parameters)
    model = GCNModel()

    # fit model for {num_epochs} with a batch_size of {batch_size}
    model.fit(dataset.get_iterator(), epochs=num_epochs, verbose=1)
    return dataset, model


# Fiehn HILIC
# Train on all data
train_dataset_hilic, gcn_model_hilic = _fit_on_all_splits(
    'Fiehn_HILIC', batch_size, num_epochs)

# RIKEN
# Train on all data
train_dataset_rplc, gcn_model_rplc = _fit_on_all_splits(
    'RIKEN', batch_size, num_epochs)

### 4. Compute saliency maps for the compound pairs obtained in (2)

In [None]:
# Running pair number; it becomes part of both output file names of a pair.
pair = 0

for hilic_record, rplc_record in zip(hilic_data, rplc_data):

    # HILIC member of the pair, explained with the model trained on HILIC data.
    A, H, y, mol, i = hilic_record
    draw_atom_saliency_on_mol(
        mol,
        vanilla_atom_saliency(gcn_model_hilic, A, H, y),
        f'../output/saliency/pair-{pair}_Fiehn-HILIC-index-{i}.png',
    )

    # RIKEN member of the pair, explained with the model trained on RIKEN data.
    A, H, y, mol, i = rplc_record
    draw_atom_saliency_on_mol(
        mol,
        vanilla_atom_saliency(gcn_model_rplc, A, H, y),
        f'../output/saliency/pair-{pair}_RIKEN-index-{i}.png',
    )

    pair += 1