In [15]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import multiprocessing 
from collections import defaultdict
from itertools import product

# tf.config.set_visible_devices([], 'GPU')

from rdkit import Chem
from rdkit.Chem import Lipinski
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.Draw import SimilarityMaps
from rdkit.Chem import Draw
from rdkit import DataStructs

import sys
sys.path.append('../src/')

from ops import transform_ops
from gcn.datasets import GCNDataset
from gcn.models import GCNModel
from gcn import saliency

## 1. Define function for serving trained model later

In [20]:
def serve_model(model):
    """Wrap `model` in a `tf.function` usable as a SavedModel signature.

    The returned function accepts a single example, given as a list of two
    rank-2 float32 tensors named 'A' and 'H' (presumably the adjacency matrix
    and the node-feature matrix -- confirm against GCNModel), adds a leading
    batch axis of size 1 to each and runs the model in inference mode.

    Args:
        model: object exposing `call(inputs=..., training=...)`.

    Returns:
        A `tf.function` mapping `[A, H]` to `{'prediction': <model output>}`.
    """
    # Both dimensions are left unspecified (None) so that inputs of any
    # height/width are accepted by the traced function.
    example_spec = [
        tf.TensorSpec([None, None], dtype='float32', name='A'),
        tf.TensorSpec([None, None], dtype='float32', name='H'),
    ]

    @tf.function(input_signature=[example_spec])
    def serve(inputs):
        # The model is fed a batch, so prepend a length-1 axis to each input.
        batched_inputs = [inputs[0][tf.newaxis], inputs[1][tf.newaxis]]
        prediction = model.call(inputs=batched_inputs, training=False)
        return {'prediction': prediction}

    return serve

## 2. Train on Fiehn HILIC

In [None]:
# Settings for this run; `batch_size` is handed to the dataset below and
# `num_epochs` to `fit`.
batch_size = 32
num_epochs = 200

# Train on all data: the train, valid and test records are used together.
train_files = [
    f'../input/tfrecords/Fiehn_HILIC/train.tfrec',
    f'../input/tfrecords/Fiehn_HILIC/valid.tfrec',
    f'../input/tfrecords/Fiehn_HILIC/test.tfrec',
]
train_dataset = GCNDataset(train_files, batch_size, training=True)

# Build the model with its default hyper-parameters.
gcn_model = GCNModel()

# Fit for `num_epochs` epochs on the batches produced by the dataset.
train_iterator = train_dataset.get_iterator()
gcn_model.fit(train_iterator, epochs=num_epochs, verbose=1)

# Recompile with vanilla gradient descent (SGD) before exporting.
# NOTE(review): presumably done so the export does not carry the training
# optimizer -- confirm the intent.
gcn_model.compile(optimizer=tf.keras.optimizers.SGD())

# Export the trained model together with its serving signature.
serving_fn = serve_model(gcn_model)
tf.saved_model.save(
    gcn_model, f'../output/models/gcn_model_Fiehn_HILIC', serving_fn)

In [None]:
# Load the model exported by the training cell for saliency computation.
sal = saliency.Saliency(import_dir=f'../output/models/gcn_model_Fiehn_HILIC')

# One PNG per molecule is written into this folder; create it up front so the
# loop does not fail on a fresh checkout where the folder is missing.
os.makedirs('../output/saliency', exist_ok=True)

# Define a new dataset over the same records, this time with training=False.
# NOTE: `batch_size` comes from the training cell above, which therefore has
# to be executed before this one.
saliency_dataset = GCNDataset(
    [f'../input/tfrecords/Fiehn_HILIC/train.tfrec',
     f'../input/tfrecords/Fiehn_HILIC/valid.tfrec',
     f'../input/tfrecords/Fiehn_HILIC/test.tfrec'],
    batch_size, training=False)

# Obtain the dataset as a numpy iterator (each batch is a dict of arrays).
dataset = saliency_dataset.get_iterator().as_numpy_iterator()

# Loop over the dataset in batches.
for batch in tqdm(dataset):

    # Loop over each example in the batch, compute its saliency map and
    # save the rendering to file.
    for index in range(batch['label'].shape[0]):
        adjacency = batch['adjacency_matrix'][index]
        features = batch['feature_matrix'][index]
        label = batch['label'][index]
        # 'string' and 'index' carry one extra axis; take its first entry.
        smiles = batch['string'][index][0]
        mol_index = batch['index'][index][0]

        saliency_map = sal.atom_importance(adjacency, features, label)

        # Build an RDKit mol object from the (bytes) SMILES string.
        mol = transform_ops.mol_from_string(smiles.decode('utf-8'))

        # Draw the saliency map on a 2-d depiction of the mol and save it.
        sal.draw_atom_saliency_on_mol(
            mol, saliency_map,
            f'../output/saliency/mol_Fiehn_HILIC_{mol_index}.png')
    

## 3. Train on RIKEN

In [28]:
# Settings for this run; `batch_size` is handed to the dataset below and
# `num_epochs` to `fit`.
batch_size = 32
num_epochs = 200

# Train on all data: the train, valid and test records are used together.
train_files = [
    f'../input/tfrecords/RIKEN/train.tfrec',
    f'../input/tfrecords/RIKEN/valid.tfrec',
    f'../input/tfrecords/RIKEN/test.tfrec',
]
train_dataset = GCNDataset(train_files, batch_size, training=True)

# Build the model with its default hyper-parameters.
gcn_model = GCNModel()

# Fit for `num_epochs` epochs on the batches produced by the dataset.
train_iterator = train_dataset.get_iterator()
gcn_model.fit(train_iterator, epochs=num_epochs, verbose=1)

# Recompile with vanilla gradient descent (SGD) before exporting.
# NOTE(review): presumably done so the export does not carry the training
# optimizer -- confirm the intent.
gcn_model.compile(optimizer=tf.keras.optimizers.SGD())

# Export the trained model together with its serving signature.
serving_fn = serve_model(gcn_model)
tf.saved_model.save(
    gcn_model, f'../output/models/gcn_model_RIKEN', serving_fn)

In [29]:
# Load the model exported by the training cell for saliency computation.
sal = saliency.Saliency(import_dir=f'../output/models/gcn_model_RIKEN')

# One PNG per molecule is written into this folder; create it up front so the
# loop does not fail on a fresh checkout where the folder is missing.
os.makedirs('../output/saliency', exist_ok=True)

# Define a new dataset over the same records, this time with training=False.
# NOTE: `batch_size` comes from the training cell above, which therefore has
# to be executed before this one.
saliency_dataset = GCNDataset(
    [f'../input/tfrecords/RIKEN/train.tfrec',
     f'../input/tfrecords/RIKEN/valid.tfrec',
     f'../input/tfrecords/RIKEN/test.tfrec'],
    batch_size, training=False)

# Obtain the dataset as a numpy iterator (each batch is a dict of arrays).
dataset = saliency_dataset.get_iterator().as_numpy_iterator()

# Loop over the dataset in batches.
for batch in tqdm(dataset):

    # Loop over each example in the batch, compute its saliency map and
    # save the rendering to file.
    for index in range(batch['label'].shape[0]):
        adjacency = batch['adjacency_matrix'][index]
        features = batch['feature_matrix'][index]
        label = batch['label'][index]
        # 'string' and 'index' carry one extra axis; take its first entry.
        smiles = batch['string'][index][0]
        mol_index = batch['index'][index][0]

        saliency_map = sal.atom_importance(adjacency, features, label)

        # Build an RDKit mol object from the (bytes) SMILES string.
        mol = transform_ops.mol_from_string(smiles.decode('utf-8'))

        # Draw the saliency map on a 2-d depiction of the mol and save it.
        sal.draw_atom_saliency_on_mol(
            mol, saliency_map,
            f'../output/saliency/mol_RIKEN_{mol_index}.png')
    