In [2]:
###############################################################################
# This notebook shows application of the RhNet2 model to aqueous solubility
# prediction.
#
# The RhNet2 models presented here were trained to predict intrinsic aqueous
# solubility from one-electron density matrices (obtained from DFT calculations)
# on a set of 94 drug-like molecules, as specified by the terms of the
# Solubility Challenge (2008).
#
# The details of the model architecture are described in the corresponding paper:
# https://doi.org/10.26434/chemrxiv-2024-k2k3l
#
# Implementation:
# https://github.com/Shorku/rhnet2
#
# Solubility Challenge:
# https://doi.org/10.1021/ci800058v
# https://doi.org/10.1021/ci800436c
###############################################################################

In [None]:
import os
import json
import requests

import tensorflow as tf

from data_utils import graph_from_orca
from data_utils import read_schema
from data_utils import GraphToTensorsLayer
from data_utils import graph2rest

###############################################################################
# paths and graph-construction settings
###############################################################################
# Folder with the example DFT (ORCA) calculation outputs.
orca_out_path = 'data_example/dft'
# Folder with pre-built molecular graphs stored as .tfrecord files.
tfrecords_path = 'data_example/tfrecords'
# Passed to graph_from_orca() as `overlap_thresh`; presumably a cutoff on
# overlap values used when building the graph -- TODO confirm in data_utils.
overlap_threshold = 0.035
# Graph schema file, read with read_schema().
schema_path = 'rhnet2.pbtxt'
# Folder with the models saved after the leave-one-out validation runs.
models_path = 'models/LOO_G8D1W1Hd0Hw0_GKANTrueWKANFalseHKANFalseL2H0.08L2W0.08_batch8_lr0.0001'

In [3]:
###############################################################################
# As an example, we will now predict water solubility of, say, Benzocaine
# (test compound, not seen by the models) from its electronic structure.
# The results of DFT calculation are available in the data_example/dft folder.
# First, we will convert the output of quantum chemical package (ORCA) into
# a TF-GNN graph:

# Archive with the ORCA output for benzocaine.
benzocaine_orca_out = os.path.join(orca_out_path, "Benzocaine.zip")

# Six placeholder targets ("dummy0" ... "dummy5"), all set to 0.0; presumably
# real labels are not needed for inference -- TODO confirm in data_utils.
benzocaine_graphs = graph_from_orca(
    out_file=benzocaine_orca_out,
    overlap_thresh=overlap_threshold,
    target_features=dict.fromkeys((f"dummy{i}" for i in range(6)), 0.0),
)
# Only the first item of the returned collection is used.
benzocaine_graph = benzocaine_graphs[0]

In [None]:
###############################################################################
# Note: the resulting graph can be saved, for example, in .tfrecord format
# (this requires `import tensorflow_gnn as tfgnn`):
#
# record_options = tf.io.TFRecordOptions(compression_type='GZIP')
# with tf.io.TFRecordWriter(os.path.join(tfrecords_path, "Benzocaine.tfrecord"),
#                           options=record_options) as writer:
#     example = tfgnn.write_example(benzocaine_graph)
#     writer.write(example.SerializeToString())
#
# And later restored:
#
# from data_utils import get_decode_fn
#
# graph_schema = read_schema(schema_path)
# decode_fn = get_decode_fn(graph_spec=graph_schema)
# benzocaine_data = tf.data.TFRecordDataset(
#     [os.path.join(tfrecords_path, "Benzocaine.tfrecord")],
#     compression_type="GZIP").map(decode_fn)
#
# The data_example/tfrecords folder contains pre-built molecular graphs
# for all compounds from the Solubility Challenge (2008) test set

In [4]:
###############################################################################
# Now let's convert the benzocaine graph into a format compatible with saved
# TensorFlow models, i.e. a set of tensors.
#
# As a side note, since saved TF models (serving especially) generally
# prefer the most basic data structures as inputs, it might seem redundant
# to convert quantum chemical data to graphs first. Indeed, all the input
# tensors, such as the ones containing node embeddings or adjacency matrices,
# can be obtained directly. However, I still prefer to use graphs as an
# intermediate step to guarantee that the order of the data is preserved.

# Read the graph schema from the file at schema_path.
schema = read_schema(schema_path=schema_path)
# Layer built from that schema; it is called on a graph below.
mapping_layer = GraphToTensorsLayer(graph_schema=schema)

# The result is later unpacked with ** into the model's serving signature,
# so it is expected to be a mapping from input names to tensors.
benzocaine_graph_mapped = mapping_layer(graph=benzocaine_graph)

In [56]:
###############################################################################
# Finally, we can predict the solubility of benzocaine. Let's pick an arbitrary
# model from the models folder. The folder contains 94 models saved after
# leave-one-out validation. The subfolder names indicate CAS numbers of
# training compounds excluded from the training set in a particular fitting run.

# Saved model from the fitting run whose subfolder is named "113-59-7"
# (the trailing "1" is presumably the model version directory -- TODO confirm).
model_path = os.path.join(models_path, "113-59-7/1")
model = tf.saved_model.load(model_path)

# Call the model through its default serving signature, feeding the mapped
# graph tensors as keyword arguments.
infer = model.signatures['serving_default']
benzocaine_prediction = infer(**benzocaine_graph_mapped)

# 'target' output: take the first row, convert it to NumPy, then take its
# first element.
target_row = benzocaine_prediction['target'][0]
benzocaine_sol = target_row.numpy()[0]

solubility_report = (
    f'Benzocaine intrinsic aqueous solubility (log \u03BCM):\n'
    f'Predicted: {benzocaine_sol:.2f}, '
    f'Reported: 3.81, '
    f'Error: {abs(benzocaine_sol - 3.81):.2f}'
)
print(solubility_report)

Benzocaine intrinsic aqueous solubility (log μM):
Predicted: 3.89, Reported: 3.81, Error: 0.08


In [7]:
###############################################################################
# The predicted solubility deviates from the reported value by only about 20%
# on the linear scale (10**0.08 ~ 1.2).
# But we can still improve it by averaging the predictions of an ensemble of
# models. This also helps to reduce the impact of model instability, which can
# show up in deep GNNs.
# Let's serve all the models using TensorFlow Model Server. Pull the docker
# image and run it.
#
# docker pull tensorflow/serving
# docker run -p 8501:8501 --name=rhnet2 --mount type=bind,source=$(pwd)/models,target=/models --mount type=bind,source=$(pwd)/models.config,target=/models.config -t tensorflow/serving --model_config_file=/models.config
#
# Now that the service is up, we need to convert data into a json-compatible format:

# graph2rest receives the graph and the mapping layer; its result is sent
# below as the "inputs" field of the JSON request body.
benzocaine_graph_rest = graph2rest(graph=benzocaine_graph,
                                   mapping_layer=mapping_layer)

# Define a function to make requests:
def predict_request(graph_json, model_name,
                    server_url='http://localhost:8501', timeout=60.0):
    """Query one served model through the model server's REST predict endpoint.

    Args:
        graph_json: JSON-serializable model inputs; sent as the "inputs"
            field of the request body.
        model_name: name of the served model, inserted into the request URL.
        server_url: base URL of the model server (scheme, host and port).
        timeout: seconds to wait for the server before giving up.

    Returns:
        dict with the keys 'S0', 'melt' and 'logK', holding element [0][0]
        of the 'target', 'melt' and 'logk' outputs respectively.

    Raises:
        requests.HTTPError: if the server replies with an error status.
        requests.Timeout: if the server does not reply within `timeout`.
        KeyError: if the reply lacks one of the expected outputs.
    """
    headers = {"content-type": "application/json"}
    payload = {"signature_name": "serving_default",
               "inputs": graph_json}
    response = requests.post(f'{server_url}/v1/models/{model_name}:predict',
                             json=payload, headers=headers, timeout=timeout)
    # Surface HTTP errors (unknown model, malformed inputs, ...) directly
    # instead of failing later with a KeyError on the missing 'outputs' key.
    response.raise_for_status()
    # Decode the body once rather than once per requested output.
    outputs = response.json()['outputs']
    return {'S0': outputs['target'][0][0],
            'melt': outputs['melt'][0][0],
            'logK': outputs['logk'][0][0]}

# And test the service:
# Ask the served model named 'model_0' and keep only its solubility output.
single_model_prediction = predict_request(benzocaine_graph_rest, 'model_0')
benzocaine_sol = single_model_prediction['S0']

served_report = (
    f'Benzocaine intrinsic aqueous solubility (log \u03BCM):\n'
    f'Predicted: {benzocaine_sol:.2f}, '
    f'Reported: 3.81, '
    f'Error: {abs(benzocaine_sol - 3.81):.2f}'
)
print(served_report)

Benzocaine intrinsic aqueous solubility (log μM):
Predicted: 3.89, Reported: 3.81, Error: 0.08


In [8]:
###############################################################################
# Finally, make ensemble prediction:

# Number of served models to query (model_0 ... model_93).
n_models = 94
predictions = [predict_request(benzocaine_graph_rest, f'model_{i}')
               for i in range(n_models)]

# Average over the predictions actually collected, so the divisor cannot
# drift out of sync with the number of queried models.
benzocaine_sol = (sum(prediction['S0'] for prediction in predictions)
                  / len(predictions))

print(f'Benzocaine intrinsic aqueous solubility (log \u03BCM):\n'
      f'Predicted (ensemble): {benzocaine_sol:.2f}, '
      f'Reported: 3.81, '
      f'Error: {abs(benzocaine_sol - 3.81):.2f}')

Benzocaine intrinsic aqueous solubility (log μM):
Predicted (ensemble): 3.83, Reported: 3.81, Error: 0.02


In [10]:
###############################################################################
# Simultaneously with solubility, properties like melting point and Kow are
# predicted. These were introduced primarily for regularization purpose.
# The models were not optimized to perform well for anything but solubility.
# However, let's look at the predicted values:

# Divide by the number of collected predictions rather than a hard-coded 94.
n_predictions = len(predictions)

# The averaged 'melt' output is multiplied by 100 before being reported in
# degrees Celsius (presumably the models predict the melting point scaled by
# 1/100 -- TODO confirm against the training code).
benzocaine_melt = (sum(prediction['melt'] for prediction in predictions)
                   / n_predictions * 100)
benzocaine_kow = (sum(prediction['logK'] for prediction in predictions)
                  / n_predictions)

# \u00B0 is the degree sign (the superscript zero \u2070 only looks similar).
print(f'Benzocaine melting point (\u00B0C):\n'
      f'Predicted: {benzocaine_melt:.1f}, '
      f'Reported: 89.9, '
      f'Error: {abs(benzocaine_melt - 89.9):.1f}\n\n'
      f'Benzocaine logK_ow:\n'
      f'Predicted: {benzocaine_kow:.2f}, '
      f'Reported: 1.86, '
      f'Error: {abs(benzocaine_kow - 1.86):.2f}')

Benzocaine melting point (⁰C):
Predicted: 132.1, Reported: 89.9, Error: 42.2

Benzocaine logK_ow:
Predicted: 2.00, Reported: 1.86, Error: 0.14


In [18]:
###############################################################################
# Available pre-calculated compounds (SC1 test set)

# Compound names, one per line, in the same order as before.
test_compounds = [
    "Acebutolol",
    "Amoxicillin",
    "Bendroflumethiazide",
    "Benzocaine",
    "Benzthiazide",
    "Clozapine",
    "Dibucaine",
    "Diethylstilbestrol",
    "Diflunisal",
    "Dipyridamole",
    "Folic_acid",
    "Furosemide",
    "Hydrochlorothiazide",
    "Imipramine",
    "Indomethacin",
    "Ketoprofen",
    "Lidocaine",
    "Meclofenamic_acid",
    "Naphthoic_acid",
    "Probenecid",
    "Pyrimethamine",
    "Salicylic_acid",
    "Sulfamerazine",
    "Sulfamethizole",
    "Terfenadine",
    "Thiabendazole",
    "Tolbutamide",
    "Trazodone",
]