### Packages

In [None]:
!pip install spektral
!pip install lime
!pip install rdkit
!pip install py3Dmol
!pip install Bio

Collecting spektral
  Downloading spektral-1.3.1-py3-none-any.whl (140 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.1/140.1 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Collecting lxml (from spektral)
  Downloading lxml-5.2.2-cp310-cp310-manylinux_2_28_x86_64.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: lxml, spektral
Successfully installed lxml-5.2.2 spektral-1.3.1
Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... [?25l[?25hdone
  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283835 sha256=95c77f9a883e1b511738fb63bfb0f656c4f6ba5a2031af05459fd7eef3

In [None]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

from keras.layers import concatenate
from tensorflow.keras import backend as K

from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Embedding, LSTM, Dropout, BatchNormalization, Dense, concatenate, Reshape, Flatten, Attention, Bidirectional, MultiHeadAttention
from tensorflow.keras.optimizers import Adam, AdamW
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

from tensorflow.keras.metrics import RootMeanSquaredError
from tensorflow.keras.regularizers import l2
from tensorflow.keras.initializers import GlorotUniform

from spektral.layers import GCNConv, GlobalSumPool, GATConv
from scipy.sparse import csr_matrix

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D

import lime
from lime.lime_tabular import LimeTabularExplainer

from IPython.display import display, Image

from collections import defaultdict

import py3Dmol

from Bio import SeqIO
from Bio.pairwise2 import align



### Data Preprocessing

#### Data Optimization

In [None]:
# Ligand table carrying the Lipinski descriptors (molecular_weight, logP,
# numH_donors, numH_acceptors); preview the first two rows.
lipinski_csv_path = '/content/drive/MyDrive/Dataframes/df_lipinski.csv'
df_lipinski = pd.read_csv(lipinski_csv_path)
df_lipinski.head(2)

Unnamed: 0.1,Unnamed: 0,interval,subset,docking_score,pdb_id,zinc_id,smiles,molecular_weight,logP,numH_donors,numH_acceptors
0,0,"(-15.0, -14.0]",validation,-14.881847,6IIU,ZINC001129722346,C#C[C@@H](NC(=O)[C@@H]1CCCN(c2nc3ccccc3s2)C1)c...,409.942,4.6568,1,4
1,1,"(-15.0, -14.0]",validation,-14.196672,6IIU,ZINC001600492567,Cc1ccc2c(CN3[C@@H]4C[C@H](C(=O)O)O[C@H]4CC[C@H...,357.406,2.69642,1,5


In [None]:
# Extended-connectivity fingerprint vectors plus the encoded receptor
# sequence, one row per docking record; preview the first two rows.
fingerprints_pickle_path = '/content/drive/MyDrive/Dataframes/df_fingerprints_extended_connectivity_opt.pkl'
df_fingerprints_opt = pd.read_pickle(fingerprints_pickle_path)
df_fingerprints_opt.head(2)

Unnamed: 0,docking_score,extended_connectivity_fps,encoded_seq
0,-14.881847,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ..."
1,-14.196672,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ..."


In [None]:
# Molecular-graph representation of each ligand (node features, edge
# features, adjacency matrix); preview the first two rows.
graphs_pickle_path = '/content/drive/MyDrive/Dataframes/df_graphs_opt.pkl'
graphs = pd.read_pickle(graphs_pickle_path)
graphs.head(2)

Unnamed: 0,docking_score,node_features,edge_features,adjacency_matrix,encoded_seq
0,-14.881847,"[6, 6, 6, 7, 6, 8, 6, 6, 6, 6, 7, 6, 7, 6, 6, ...","[3.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ..."
1,-14.196672,"[6, 6, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 8, 8, 8, ...","[1.0, 1.5, 1.5, 1.5, 1.5, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ..."


In [None]:
# Receptor table: identifiers, SMILES, raw amino-acid sequence and its
# integer encoding; preview the first two rows.
receptors_pickle_path = '/content/drive/MyDrive/Dataframes/df_receptors.pkl'
df_receptors = pd.read_pickle(receptors_pickle_path)
df_receptors.head(2)

Unnamed: 0,interval,subset,docking_score,pdb_id,zinc_id,smiles,sequence,encoded_seq
0,"(-15.0, -14.0]",validation,-14.881847,6IIU,ZINC001129722346,C#C[C@@H](NC(=O)[C@@H]1CCCN(c2nc3ccccc3s2)C1)c...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ..."
1,"(-15.0, -14.0]",validation,-14.196672,6IIU,ZINC001600492567,Cc1ccc2c(CN3[C@@H]4C[C@H](C(=O)O)O[C@H]4CC[C@H...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ..."


In [None]:
# Merge the four source tables into one frame. Columns are copied one at a
# time, in the order listed, so pandas aligns every column on the row index
# established by the first assignment (df_receptors).
# NOTE(review): this assumes all four frames share the same row index/order —
# nothing here verifies it.
fusion_column_sources = [
    (df_receptors, ['pdb_id', 'zinc_id', 'docking_score',
                    'smiles', 'sequence', 'encoded_seq']),
    (df_lipinski, ['molecular_weight', 'logP', 'numH_donors', 'numH_acceptors']),
    (df_fingerprints_opt, ['extended_connectivity_fps']),
    (graphs, ['node_features', 'edge_features', 'adjacency_matrix']),
]

df_fusion = pd.DataFrame()
for source_frame, column_names in fusion_column_sources:
    for column_name in column_names:
        df_fusion[column_name] = source_frame[column_name]

df_fusion.head()

Unnamed: 0,pdb_id,zinc_id,docking_score,smiles,sequence,encoded_seq,molecular_weight,logP,numH_donors,numH_acceptors,extended_connectivity_fps,node_features,edge_features,adjacency_matrix
0,6IIU,ZINC001129722346,-14.881847,C#C[C@@H](NC(=O)[C@@H]1CCCN(c2nc3ccccc3s2)C1)c...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",409.942,4.6568,1,4,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 6, 6, 7, 6, 8, 6, 6, 6, 6, 7, 6, 7, 6, 6, ...","[3.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
1,6IIU,ZINC001600492567,-14.196672,Cc1ccc2c(CN3[C@@H]4C[C@H](C(=O)O)O[C@H]4CC[C@H...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",357.406,2.69642,1,5,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[6, 6, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 8, 8, 8, ...","[1.0, 1.5, 1.5, 1.5, 1.5, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
2,6IIU,ZINC001320490156,-14.227941,C[C@H](CCNCC1CCCC1)NC(=O)[C@@H]1COc2ccc(F)cc2C1,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",348.462,3.0514,2,3,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[6, 6, 6, 6, 7, 6, 6, 6, 6, 6, 6, 7, 6, 8, 6, ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
3,6IIU,ZINC001702816479,-14.018111,C[C@H](C#N)C(=O)NC[C@@H]1CN(C(=O)c2c(F)c(F)cc(...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",371.334,2.22698,1,3,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 6, 6, 7, 6, 8, 7, 6, 6, 6, 7, 6, 8, 6, 6, ...","[1.0, 1.0, 3.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
4,6IIU,ZINC000546135608,-14.656972,CCN1c2ccc(NC(=O)NCc3c(C)cc(C)cc3C)cc2CCC1=O,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",365.477,4.23266,2,2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 6, 7, 6, 6, 6, 6, 7, 6, 8, 7, 6, 6, 6, 6, ...","[1.0, 1.0, 1.0, 1.5, 1.5, 1.5, 1.0, 1.0, 2.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [None]:
# Cache the merged frame on Drive; the "Data Processing" section reloads it
# from this same path.
fusion_pickle_path = "/content/drive/MyDrive/Dataframes/df_fusion.pkl"
df_fusion.to_pickle(fusion_pickle_path)

#### Data Processing

In [None]:
# Reload the merged frame written at the end of the previous section and
# preview the first two rows.
fusion_pickle_path = '/content/drive/MyDrive/Dataframes/df_fusion.pkl'
df_fusion = pd.read_pickle(fusion_pickle_path)
df_fusion.head(2)

Unnamed: 0,pdb_id,zinc_id,docking_score,smiles,sequence,encoded_seq,molecular_weight,logP,numH_donors,numH_acceptors,extended_connectivity_fps,node_features,edge_features,adjacency_matrix
0,6IIU,ZINC001129722346,-14.881847,C#C[C@@H](NC(=O)[C@@H]1CCCN(c2nc3ccccc3s2)C1)c...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",409.942,4.6568,1,4,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 6, 6, 7, 6, 8, 6, 6, 6, 6, 7, 6, 7, 6, 6, ...","[3.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
1,6IIU,ZINC001600492567,-14.196672,Cc1ccc2c(CN3[C@@H]4C[C@H](C(=O)O)O[C@H]4CC[C@H...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",357.406,2.69642,1,5,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[6, 6, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 8, 8, 8, ...","[1.0, 1.5, 1.5, 1.5, 1.5, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [None]:
# Number of docking records before any filtering.
df_fusion.shape[0]

1200000

In [None]:
# Keep only ligands whose node list, edge list and adjacency matrix each have
# between 20 and 35 entries (both bounds inclusive).
MIN_GRAPH_SIZE, MAX_GRAPH_SIZE = 20, 35

node_features_lengths, edge_features_lengths, adjacency_matrix_lengths = (
    df_fusion[graph_column].apply(len)
    for graph_column in ('node_features', 'edge_features', 'adjacency_matrix')
)

# Series.between is inclusive on both ends by default, matching `>=` / `<=`.
graph_size_mask = (
    node_features_lengths.between(MIN_GRAPH_SIZE, MAX_GRAPH_SIZE)
    & edge_features_lengths.between(MIN_GRAPH_SIZE, MAX_GRAPH_SIZE)
    & adjacency_matrix_lengths.between(MIN_GRAPH_SIZE, MAX_GRAPH_SIZE)
)
filtered_df = df_fusion[graph_size_mask]

filtered_df.head()

Unnamed: 0,pdb_id,zinc_id,docking_score,smiles,sequence,encoded_seq,molecular_weight,logP,numH_donors,numH_acceptors,extended_connectivity_fps,node_features,edge_features,adjacency_matrix
0,6IIU,ZINC001129722346,-14.881847,C#C[C@@H](NC(=O)[C@@H]1CCCN(c2nc3ccccc3s2)C1)c...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",409.942,4.6568,1,4,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 6, 6, 7, 6, 8, 6, 6, 6, 6, 7, 6, 7, 6, 6, ...","[3.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
1,6IIU,ZINC001600492567,-14.196672,Cc1ccc2c(CN3[C@@H]4C[C@H](C(=O)O)O[C@H]4CC[C@H...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",357.406,2.69642,1,5,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[6, 6, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 8, 8, 8, ...","[1.0, 1.5, 1.5, 1.5, 1.5, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
2,6IIU,ZINC001320490156,-14.227941,C[C@H](CCNCC1CCCC1)NC(=O)[C@@H]1COc2ccc(F)cc2C1,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",348.462,3.0514,2,3,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[6, 6, 6, 6, 7, 6, 6, 6, 6, 6, 6, 7, 6, 8, 6, ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
3,6IIU,ZINC001702816479,-14.018111,C[C@H](C#N)C(=O)NC[C@@H]1CN(C(=O)c2c(F)c(F)cc(...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",371.334,2.22698,1,3,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 6, 6, 7, 6, 8, 7, 6, 6, 6, 7, 6, 8, 6, 6, ...","[1.0, 1.0, 3.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
4,6IIU,ZINC000546135608,-14.656972,CCN1c2ccc(NC(=O)NCc3c(C)cc(C)cc3C)cc2CCC1=O,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",365.477,4.23266,2,2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 6, 7, 6, 6, 6, 6, 7, 6, 8, 7, 6, 6, 6, 6, ...","[1.0, 1.0, 1.0, 1.5, 1.5, 1.5, 1.0, 1.0, 2.0, ...","[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [None]:
# Number of docking records that survive the graph-size filter.
filtered_df.shape[0]

1054642

In [None]:
# Identifier / raw-string columns, kept as arrays so they can be split in
# lock-step with the model inputs.
pdb_id = np.array(filtered_df['pdb_id'].tolist())
zinc_id = np.array(filtered_df['zinc_id'].tolist())
smiles =  np.array(filtered_df['smiles'].tolist())

# Receptor sequence: raw string and its integer encoding (model input).
sequence = np.array(filtered_df['sequence'].tolist())
X_text = np.array(filtered_df['encoded_seq'].tolist())

# Lipinski descriptors, standardised column-wise.
X_numerical = filtered_df[['molecular_weight', 'logP', 'numH_donors', 'numH_acceptors']].values

# Fit the scaler on training rows only, so validation/test statistics do not
# leak into preprocessing. train_test_split's row partition depends only on the
# number of samples, test_size and random_state, so splitting a plain index
# array with the same arguments as the 70/30 split performed later in this
# notebook selects exactly the rows that end up in the training set.
# Keep test_size / random_state here in sync with that split.
train_row_idx, holdout_row_idx = train_test_split(
    np.arange(len(filtered_df)), test_size=0.3, random_state=42
)
scaler = StandardScaler()
scaler.fit(X_numerical[train_row_idx])
X_numerical = scaler.transform(X_numerical)

# Extended-connectivity fingerprint vectors (model input).
X_extended_connectivity = np.array(filtered_df['extended_connectivity_fps'].tolist())

# Regression target.
y = filtered_df['docking_score'].values

In [None]:
# Zero-pad every adjacency matrix to the size of the largest one so they stack
# into a single (n_samples, padding_length, padding_length) array.
padding_length = max(len(matrix) for matrix in filtered_df['adjacency_matrix'])

# A single (before, after) pad width is applied by np.pad to every axis, so
# rows and columns are both extended at the end. padding_length is the
# maximum size, so the pad amount is never negative.
filtered_matrices = [
    np.pad(matrix, (0, padding_length - len(matrix)), mode='constant', constant_values=0)
    for matrix in filtered_df['adjacency_matrix']
]

X_adjacency_matrix = np.array(filtered_matrices, dtype='float32')
print("Shape of filtered adjacency matrix:", X_adjacency_matrix.shape)

Shape of filtered adjacency matrix: (1054642, 35, 35)


In [None]:
# Pad the variable-length node and edge feature lists at the end ('post') to a
# common length per column; float32 keeps fractional edge values such as 1.5.
X_node_features, X_edge_features = (
    pad_sequences(filtered_df[feature_column], padding='post', dtype='float32')
    for feature_column in ('node_features', 'edge_features')
)

In [None]:
# First split: 70% train / 30% "rest". All eleven arrays (identifiers, raw
# strings, every model input and the target) are passed together so they are
# partitioned row-for-row identically; train_test_split returns one
# (train, rest) pair per input, in input order.
# NOTE(review): the split is purely random over rows — ligands docked against
# the same receptor can land in different subsets; confirm that is intended.
pdb_id_train, pdb_id_rest, zinc_id_train, zinc_id_rest, smiles_train, smiles_rest, sequence_train, sequence_rest, X_text_train, X_text_rest, X_numerical_train, X_numerical_rest, X_extended_connectivity_train, X_extended_connectivity_rest, X_node_train, X_node_rest, X_edge_train, X_edge_rest, X_adjacency_train, X_adjacency_rest, y_train, y_rest = train_test_split(
    pdb_id, zinc_id, smiles, sequence, X_text, X_numerical, X_extended_connectivity, X_node_features, X_edge_features, X_adjacency_matrix, y, test_size=0.3, random_state=42
)

# Second split: halve the 30% "rest" into validation and test
# (15% / 15% of the filtered data), again keeping all arrays aligned.
pdb_id_val, pdb_id_test, zinc_id_val, zinc_id_test, smiles_val, smiles_test, sequence_val, sequence_test, X_text_val, X_text_test, X_numerical_val, X_numerical_test, X_extended_connectivity_val, X_extended_connectivity_test, X_node_val, X_node_test, X_edge_val, X_edge_test, X_adjacency_val, X_adjacency_test, y_val, y_test = train_test_split(
     pdb_id_rest, zinc_id_rest, smiles_rest, sequence_rest, X_text_rest, X_numerical_rest, X_extended_connectivity_rest, X_node_rest, X_edge_rest, X_adjacency_rest, y_rest, test_size=0.5, random_state=42
)

In [None]:
# Collect every distinct token id that occurs in the encoded receptor sequences.
all_tokens = set()
for seq in filtered_df['encoded_seq']:
    all_tokens.update(seq)

# The Embedding layers below use input_dim=vocab_size, so every token id must
# be < vocab_size. Counting distinct tokens alone is only enough when the ids
# are contiguous; also bounding by the largest id keeps lookups in range for
# sparse id sets, and yields the same value as before when ids are contiguous.
vocab_size = int(max(len(all_tokens), max(all_tokens, default=0))) + 1
# Embedding width grows logarithmically with the vocabulary size.
embedding_dim = int(np.ceil(np.log2(vocab_size)))
vocab_size, embedding_dim

(22, 5)

### TPU Parallelism

In [None]:
# Pick the distribution strategy: connect to and initialise a TPU when one can
# be resolved, otherwise fall back to TensorFlow's default strategy.
# NOTE(review): the fallback is triggered by ValueError only — presumably what
# TPUClusterResolver raises when no TPU is attached; confirm for the TF version in use.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.get_strategy()
    print("Running on CPU/GPU")
else:
    print("Running on TPU")

print("Number of accelerators: ", strategy.num_replicas_in_sync)

Running on TPU
Number of accelerators:  8


### AI Models

#### Descriptors

In [None]:
def r2_score(y_true, y_pred):
    """Coefficient of determination (R^2) written with Keras backend ops.

    Returns ``1 - SS_res / (SS_tot + K.epsilon())`` computed over all elements
    of the tensors passed in; the epsilon guards against a zero denominator
    when ``y_true`` is constant.

    Note: this definition rebinds the name ``r2_score`` that the import cell
    brought in from ``sklearn.metrics``.
    NOTE(review): when used in ``metrics=[...]`` Keras presumably evaluates it
    per batch and averages, so the logged value need not equal a
    whole-dataset R^2 — compare with the sklearn figures reported below.
    """
    residual_ss = K.sum(K.square(y_true - y_pred))
    centered_true = y_true - K.mean(y_true)
    total_ss = K.sum(K.square(centered_true))
    return 1 - residual_ss / (total_ss + K.epsilon())

In [None]:
# Descriptors model: receptor sequence (LSTM branch) + scaled Lipinski
# descriptors -> docking score regression. Built and compiled inside the
# distribution strategy scope, trained on the train split, then scored on the
# validation split.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

with strategy.scope():
    # Inputs: encoded receptor sequence and the four numerical descriptors
    text_input = Input(shape=(X_text.shape[1],))
    numerical_input = Input(shape=(X_numerical.shape[1],))

    # Text embedding and processing: only the LSTM's final output is kept
    embedded_text = Embedding(input_dim=vocab_size, output_dim=embedding_dim, embeddings_initializer=GlorotUniform(seed=42))(text_input)
    lstm_output = LSTM(units=128)(embedded_text)
    lstm_output = Dropout(0.5)(lstm_output)
    lstm_output = BatchNormalization()(lstm_output)

    # Concatenate LSTM output with numerical input
    concatenated_text_numerical = concatenate([lstm_output, numerical_input])
    #concatenated_text_numerical = Dropout(0.5)(concatenated_text_numerical)
    #concatenated_text_numerical = BatchNormalization()(concatenated_text_numerical)
    # Single linear unit: the predicted docking score
    output = Dense(1, kernel_regularizer=l2(0.01))(concatenated_text_numerical)

    # Model definition
    descriptors_model = Model(inputs=[text_input, numerical_input], outputs=output)

    # Optimizer and compilation (gradient norm clipped at 1.0)
    optimizer = AdamW(learning_rate=0.0001, clipnorm=1.0)
    descriptors_model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=[RootMeanSquaredError(), r2_score])

# Callbacks
# NOTE(review): both EarlyStopping callbacks restore "best" weights, but by
# different criteria (val_loss vs val_r2_score) — whichever fires decides.
# NOTE(review): ReduceLROnPlateau uses the same patience (3) as early stopping
# on val_loss, so a reduced learning rate will rarely get an epoch to act
# before training stops — confirm this is intended.
early_stopping_loss = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
early_stopping_r2 = EarlyStopping(monitor='val_r2_score', mode='max', patience=3, restore_best_weights=True)
reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
# Best-by-val_loss checkpoint; this file is reloaded in the Late Fusion section
model_checkpoint = ModelCheckpoint('/content/drive/MyDrive/Saved_Models/descriptors_best_model.h5', save_best_only=True, monitor='val_loss', mode='min')

# Training
descriptors_model.fit([X_text_train, X_numerical_train], y_train,
                validation_data=([X_text_val, X_numerical_val], y_val),
                epochs=100, batch_size=256, callbacks=[early_stopping_loss, early_stopping_r2, reduce_lr_loss, model_checkpoint])

# Evaluation on the validation split; results come back in compile order:
# loss, RootMeanSquaredError, r2_score
evaluation_results = descriptors_model.evaluate([X_text_val, X_numerical_val], y_val)
descriptors_loss, descriptors_rmse, descriptors_r2 = evaluation_results

print("Descriptors Model Val Loss:", descriptors_loss)
print("Descriptors Model Val RMSE:", descriptors_rmse)
print("Descriptors Model Val R2 Score:", descriptors_r2)

Number of accelerators:  8
Epoch 1/100

  saving_api.save_model(


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Descriptors Model Val Loss: 4.094377040863037
Descriptors Model Val RMSE: 1.996248483657837
Descriptors Model Val R2 Score: -0.11734286695718765


In [None]:
# Re-import the scikit-learn metrics: the Keras metric defined above reuses the
# name `r2_score`, and the array-based sklearn version is the one wanted here.
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predictions of the descriptors model on the held-out test split.
y_pred = descriptors_model.predict([X_text_test, X_numerical_test])

mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

for metric_label, metric_value in (
    ("Mean Squared Error (MSE):", mse),
    ("Mean Absolute Error (MAE):", mae),
    ("R-squared (R2):", r2),
):
    print(metric_label, metric_value)

Mean Squared Error (MSE): 4.000508282406109
Mean Absolute Error (MAE): 1.6475300583407293
R-squared (R2): 0.6426270484969454


#### Fingerprints

In [None]:
def r2_score(y_true, y_pred):
    """Coefficient of determination (R^2) written with Keras backend ops.

    Returns ``1 - SS_res / (SS_tot + K.epsilon())`` computed over all elements
    of the tensors passed in; the epsilon guards against a zero denominator
    when ``y_true`` is constant.

    Re-defined here because the preceding evaluation cell re-imports
    ``r2_score`` from ``sklearn.metrics``, which rebinds the name.
    """
    residual_ss = K.sum(K.square(y_true - y_pred))
    centered_true = y_true - K.mean(y_true)
    total_ss = K.sum(K.square(centered_true))
    return 1 - residual_ss / (total_ss + K.epsilon())

In [None]:
# Fingerprints model: receptor sequence (BiLSTM + self-attention branch) +
# extended-connectivity fingerprints -> docking score regression. Built and
# compiled inside the distribution strategy scope, trained on the train split,
# then scored on the validation split.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

with strategy.scope():
    # Inputs: encoded receptor sequence and the fingerprint vector
    text_input = Input(shape=(X_text.shape[1],))
    extended_connectivity_input = Input(shape=(X_extended_connectivity.shape[1],))

    # Text embedding and processing: per-position BiLSTM outputs go through
    # self-attention (query = key = value) and are then flattened
    embedded_text = Embedding(input_dim=vocab_size, output_dim=embedding_dim, embeddings_initializer=GlorotUniform(seed=42))(text_input)
    lstm_text = Bidirectional(LSTM(units=128, return_sequences=True))(embedded_text)
    attention = MultiHeadAttention(num_heads=4, key_dim=128)(query=lstm_text, key=lstm_text, value=lstm_text)
    lstm_text = Flatten()(attention)
    lstm_text = Dropout(0.5)(lstm_text)
    lstm_text = BatchNormalization()(lstm_text)

    # Extended connectivity processing
    dense_extended_connectivity = Dense(128, activation='relu', kernel_regularizer=l2(0.01))(extended_connectivity_input)
    dense_extended_connectivity = Dropout(0.5)(dense_extended_connectivity)
    dense_extended_connectivity = BatchNormalization()(dense_extended_connectivity)

    # Concatenate processed text and extended connectivity features
    concatenated_text_extended = concatenate([lstm_text, dense_extended_connectivity])
    concatenated_text_extended = Dropout(0.5)(concatenated_text_extended)
    concatenated_text_extended = BatchNormalization()(concatenated_text_extended)
    # Single linear unit: the predicted docking score
    output = Dense(1, kernel_regularizer=l2(0.01))(concatenated_text_extended)

    # Model definition
    fingerprints_model = Model(inputs=[text_input, extended_connectivity_input], outputs=output)

    # Optimizer and compilation (gradient norm clipped at 1.0)
    optimizer = AdamW(learning_rate=0.0001, clipnorm=1.0)
    fingerprints_model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=[RootMeanSquaredError(), r2_score])


# Callbacks
# NOTE(review): both EarlyStopping callbacks restore "best" weights, but by
# different criteria (val_loss vs val_r2_score) — whichever fires decides.
# NOTE(review): ReduceLROnPlateau uses the same patience (3) as early stopping
# on val_loss, so a reduced learning rate will rarely get an epoch to act
# before training stops — confirm this is intended.
early_stopping_loss = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
early_stopping_r2 = EarlyStopping(monitor='val_r2_score', mode='max', patience=3, restore_best_weights=True)
reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
# Best-by-val_loss checkpoint; this file is reloaded in the Late Fusion section
model_checkpoint = ModelCheckpoint('/content/drive/MyDrive/Saved_Models/fingerprints_best_model.h5', save_best_only=True, monitor='val_loss', mode='min')

# Training
fingerprints_model.fit([X_text_train, X_extended_connectivity_train], y_train,
                validation_data=([X_text_val, X_extended_connectivity_val], y_val),
                epochs=100, batch_size=256, callbacks=[early_stopping_loss, early_stopping_r2, reduce_lr_loss, model_checkpoint])

# Evaluation on the validation split; results come back in compile order:
# loss, RootMeanSquaredError, r2_score
evaluation_results = fingerprints_model.evaluate([X_text_val, X_extended_connectivity_val], y_val)
fingerprints_loss, fingerprints_rmse, fingerprints_r2 = evaluation_results

print("Fingerprints Model Val Loss:", fingerprints_loss)
print("Fingerprints Model Val RMSE:", fingerprints_rmse)
print("Fingerprints Model Val R2 Score:", fingerprints_r2)

Number of accelerators:  8
Epoch 1/100

  saving_api.save_model(


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Fingerprints Model Val Loss: 1.1787914037704468
Fingerprints Model Val RMSE: 1.0752952098846436
Fingerprints Model Val R2 Score: 0.6563764810562134


In [None]:
# Re-import the scikit-learn metrics: the Keras metric defined above reuses the
# name `r2_score`, and the array-based sklearn version is the one wanted here.
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predictions of the fingerprints model on the held-out test split.
y_pred = fingerprints_model.predict([X_text_test, X_extended_connectivity_test])

mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

for metric_label, metric_value in (
    ("Mean Squared Error (MSE):", mse),
    ("Mean Absolute Error (MAE):", mae),
    ("R-squared (R2):", r2),
):
    print(metric_label, metric_value)

Mean Squared Error (MSE): 1.160435495311947
Mean Absolute Error (MAE): 0.898628012459866
R-squared (R2): 0.8963361081359609


#### Graph

In [None]:
def r2_score(y_true, y_pred):
    """Coefficient of determination (R^2) written with Keras backend ops.

    Returns ``1 - SS_res / (SS_tot + K.epsilon())`` computed over all elements
    of the tensors passed in; the epsilon guards against a zero denominator
    when ``y_true`` is constant.

    Re-defined here because the preceding evaluation cell re-imports
    ``r2_score`` from ``sklearn.metrics``, which rebinds the name.
    """
    residual_ss = K.sum(K.square(y_true - y_pred))
    centered_true = y_true - K.mean(y_true)
    total_ss = K.sum(K.square(centered_true))
    return 1 - residual_ss / (total_ss + K.epsilon())

In [None]:
# Graph model: receptor sequence (BiLSTM + self-attention branch) + ligand
# graph (GAT and GCN over node features / adjacency, plus padded edge
# features) -> docking score regression. Built and compiled inside the
# distribution strategy scope, trained on the train split, then scored on the
# validation split.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

with strategy.scope():
    # Inputs
    text_input = Input(shape=(X_text.shape[1],))
    node_input = Input(shape=(X_node_features.shape[1],))
    edge_input = Input(shape=(X_edge_features.shape[1],))
    adjacency_input = Input(shape=(X_adjacency_matrix.shape[1], X_adjacency_matrix.shape[2]))

    # Text embedding and processing: per-position BiLSTM outputs go through
    # self-attention (query = key = value) and are then flattened
    embedded_text = Embedding(input_dim=vocab_size, output_dim=embedding_dim, embeddings_initializer=GlorotUniform(seed=42))(text_input)
    lstm_text = Bidirectional(LSTM(units=128, return_sequences=True))(embedded_text)
    attention = MultiHeadAttention(num_heads=4, key_dim=128)(query=lstm_text, key=lstm_text, value=lstm_text)
    lstm_text = Flatten()(attention)
    lstm_text = Dropout(0.5)(lstm_text)
    lstm_text = BatchNormalization()(lstm_text)

    # Node input reshaping: one scalar feature per node -> (n_nodes, 1)
    node_input_reshaped = Reshape((X_adjacency_matrix.shape[1], 1))(node_input)

    # Graph Attention Network (GAT)
    gat_layer = GATConv(channels=128, attn_heads=4, dropout_rate=0.5)([node_input_reshaped, adjacency_input])
    gat_layer = Flatten()(gat_layer)
    gat_layer = BatchNormalization()(gat_layer)

    # Graph Convolutional Network (GCN)
    # NOTE(review): the zero-padded raw adjacency matrix is passed as-is;
    # Spektral documents GCNConv's second input as a normalized adjacency
    # (e.g. from GCNConv.preprocess / gcn_filter) — confirm this is intended.
    gcn_layer = GCNConv(channels=128)([node_input_reshaped, adjacency_input])
    gcn_layer = Flatten()(gcn_layer)
    gcn_layer = Dropout(0.5)(gcn_layer)
    gcn_layer = BatchNormalization()(gcn_layer)

    # Combine the two graph-convolution branches. Edge and text features are
    # added once, in the final concatenation below; previously they were
    # concatenated here as well and therefore entered the head twice.
    combined_gat_gcn = concatenate([gat_layer, gcn_layer])

    # Final combination and output: graph branches + edge features + text features
    concatenated_node_edge_text = concatenate([combined_gat_gcn, edge_input, lstm_text])
    concatenated_node_edge_text = Dropout(0.5)(concatenated_node_edge_text)
    concatenated_node_edge_text = BatchNormalization()(concatenated_node_edge_text)
    # Single linear unit: the predicted docking score
    output = Dense(1, kernel_regularizer=l2(0.01))(concatenated_node_edge_text)

    # Model definition
    graph_model = Model(inputs=[text_input, node_input, edge_input, adjacency_input], outputs=output)

    # Optimizer and compilation (gradient norm clipped at 1.0)
    optimizer = AdamW(learning_rate=0.0001, clipnorm=1.0)
    graph_model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=[RootMeanSquaredError(), r2_score])

# Callbacks
# NOTE(review): both EarlyStopping callbacks restore "best" weights, but by
# different criteria (val_loss vs val_r2_score) — whichever fires decides.
# NOTE(review): ReduceLROnPlateau uses the same patience (3) as early stopping
# on val_loss, so a reduced learning rate will rarely get an epoch to act
# before training stops — confirm this is intended.
early_stopping_loss = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
early_stopping_r2 = EarlyStopping(monitor='val_r2_score', mode='max', patience=3, restore_best_weights=True)
reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
# Best-by-val_loss checkpoint; this file is reloaded in the Late Fusion section
model_checkpoint = ModelCheckpoint('/content/drive/MyDrive/Saved_Models/graph_best_model.h5', save_best_only=True, monitor='val_loss', mode='min')

# Training
graph_model.fit([X_text_train, X_node_train, X_edge_train, X_adjacency_train], y_train,
                validation_data=([X_text_val, X_node_val, X_edge_val, X_adjacency_val], y_val),
                epochs=100, batch_size=256, callbacks=[early_stopping_loss, early_stopping_r2, reduce_lr_loss, model_checkpoint])

# Evaluation on the validation split; results come back in compile order:
# loss, RootMeanSquaredError, r2_score
evaluation_results = graph_model.evaluate([X_text_val, X_node_val, X_edge_val, X_adjacency_val], y_val)
graph_loss, graph_rmse, graph_r2 = evaluation_results

print("Graph Model Val Loss:", graph_loss)
print("Graph Model Val RMSE:", graph_rmse)
print("Graph Model Val R2 Score:", graph_r2)

Number of accelerators:  8
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Graph Model Val Loss: 0.7748922109603882
Graph Model Val RMSE: 0.8662426471710205
Graph Model Val R2 Score: 0.7904540300369263


In [None]:
# Re-import the scikit-learn metrics: the Keras metric defined above reuses the
# name `r2_score`, and the array-based sklearn version is the one wanted here.
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predictions of the graph model on the held-out test split.
y_pred = graph_model.predict([X_text_test, X_node_test, X_edge_test, X_adjacency_test])

mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

for metric_label, metric_value in (
    ("Mean Squared Error (MSE):", mse),
    ("Mean Absolute Error (MAE):", mae),
    ("R-squared (R2):", r2),
):
    print(metric_label, metric_value)

Mean Squared Error (MSE): 0.7693927145610397
Mean Absolute Error (MAE): 0.6661486760419947
R-squared (R2): 0.9312686974110572


#### Early Fusion

In [None]:
def r2_score(y_true, y_pred):
    """Coefficient of determination (R^2) written with Keras backend ops.

    Returns ``1 - SS_res / (SS_tot + K.epsilon())`` computed over all elements
    of the tensors passed in; the epsilon guards against a zero denominator
    when ``y_true`` is constant.

    Re-defined here because the preceding evaluation cell re-imports
    ``r2_score`` from ``sklearn.metrics``, which rebinds the name.
    """
    residual_ss = K.sum(K.square(y_true - y_pred))
    centered_true = y_true - K.mean(y_true)
    total_ss = K.sum(K.square(centered_true))
    return 1 - residual_ss / (total_ss + K.epsilon())

In [None]:
# Early-fusion model: every modality (receptor sequence, ligand graph,
# fingerprints, Lipinski descriptors) is fed to one network and fused before
# the regression head. Built and compiled inside the distribution strategy
# scope, trained on the train split, then scored on the validation split.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

with strategy.scope():
    # Inputs
    text_input = Input(shape=(X_text.shape[1],))
    node_input = Input(shape=(X_node_features.shape[1],))
    edge_input = Input(shape=(X_edge_features.shape[1],))
    adjacency_input = Input(shape=(X_adjacency_matrix.shape[1], X_adjacency_matrix.shape[2]))
    extended_connectivity_input = Input(shape=(X_extended_connectivity.shape[1],))
    numerical_input = Input(shape=(X_numerical.shape[1],))

    # Text embedding and processing: per-position BiLSTM outputs go through
    # self-attention (query = key = value) and are then flattened
    embedded_text = Embedding(input_dim=vocab_size, output_dim=embedding_dim, embeddings_initializer=GlorotUniform(seed=42))(text_input)
    lstm_text = Bidirectional(LSTM(units=128, return_sequences=True))(embedded_text)
    attention = MultiHeadAttention(num_heads=4, key_dim=128)(query=lstm_text, key=lstm_text, value=lstm_text)
    lstm_text = Flatten()(attention)
    lstm_text = Dropout(0.5)(lstm_text)
    lstm_text = BatchNormalization()(lstm_text)

    # Node input reshaping: one scalar feature per node -> (n_nodes, 1)
    node_input_reshaped = Reshape((X_adjacency_matrix.shape[1], 1))(node_input)

    # Graph Attention Network (GAT)
    gat_layer = GATConv(channels=128, attn_heads=4)([node_input_reshaped, adjacency_input])
    gat_layer = Flatten()(gat_layer)
    gat_layer = Dropout(0.5)(gat_layer)
    gat_layer = BatchNormalization()(gat_layer)

    # Graph Convolutional Network (GCN)
    # NOTE(review): the zero-padded raw adjacency matrix is passed as-is;
    # Spektral documents GCNConv's second input as a normalized adjacency
    # (e.g. from GCNConv.preprocess / gcn_filter) — confirm this is intended.
    gcn_layer = GCNConv(channels=128)([node_input_reshaped, adjacency_input])
    gcn_layer = Flatten()(gcn_layer)
    gcn_layer = Dropout(0.5)(gcn_layer)
    gcn_layer = BatchNormalization()(gcn_layer)

    # Graph branch: both convolution outputs plus the padded edge features
    combined_gat_gcn = concatenate([gat_layer, gcn_layer, edge_input])

    # Extended connectivity processing
    dense_extended_connectivity = Dense(128, activation='relu', kernel_regularizer=l2(0.01))(extended_connectivity_input)

    dense_extended_connectivity = Dropout(0.5)(dense_extended_connectivity)
    dense_extended_connectivity = BatchNormalization()(dense_extended_connectivity)

    # Concatenate processed inputs from all modalities, then mix them in a
    # shared 128-unit hidden layer
    concatenated_inputs = concatenate([lstm_text, combined_gat_gcn, dense_extended_connectivity, numerical_input])
    concatenated_inputs = Dense(128, activation='relu')(concatenated_inputs)
    concatenated_inputs = Dropout(0.5)(concatenated_inputs)
    concatenated_inputs = BatchNormalization()(concatenated_inputs)

    # Final output layer: single linear unit, the predicted docking score
    final_output = Dense(1, kernel_regularizer=l2(0.01))(concatenated_inputs)

    # Define the model
    early_fusion_model = Model(inputs=[text_input, node_input, edge_input, adjacency_input, extended_connectivity_input, numerical_input], outputs=final_output)

    # Optimizer and compilation (gradient norm clipped at 1.0)
    optimizer = AdamW(learning_rate=0.0001, clipnorm=1.0)
    early_fusion_model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=[RootMeanSquaredError(), r2_score])

# Callbacks
# NOTE(review): both EarlyStopping callbacks restore "best" weights, but by
# different criteria (val_loss vs val_r2_score) — whichever fires decides.
# NOTE(review): ReduceLROnPlateau uses the same patience (3) as early stopping
# on val_loss, so a reduced learning rate will rarely get an epoch to act
# before training stops — confirm this is intended.
early_stopping_loss = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
early_stopping_r2 = EarlyStopping(monitor='val_r2_score', mode='max', patience=3, restore_best_weights=True)
reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
# Best-by-val_loss checkpoint written to Drive
model_checkpoint = ModelCheckpoint('/content/drive/MyDrive/Saved_Models/early_fusion_best_model.h5', save_best_only=True, monitor='val_loss', mode='min')

# Training
early_fusion_model.fit([X_text_train, X_node_train, X_edge_train, X_adjacency_train, X_extended_connectivity_train, X_numerical_train], y_train,
                       validation_data=([X_text_val, X_node_val, X_edge_val, X_adjacency_val, X_extended_connectivity_val, X_numerical_val], y_val),
                       epochs=100, batch_size=256, callbacks=[early_stopping_loss, early_stopping_r2, reduce_lr_loss, model_checkpoint])

# Evaluation on the validation split; results come back in compile order:
# loss, RootMeanSquaredError, r2_score
evaluation_results = early_fusion_model.evaluate([X_text_val, X_node_val, X_edge_val, X_adjacency_val, X_extended_connectivity_val, X_numerical_val], y_val)
early_fusion_loss, early_fusion_rmse, early_fusion_r2 = evaluation_results

print("Early Fusion Model Val Loss:", early_fusion_loss)
print("Early Fusion Model Val RMSE:", early_fusion_rmse)
print("Early Fusion Model Val R2 Score:", early_fusion_r2)

Number of accelerators:  8




Epoch 1/100

  saving_api.save_model(


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Early Fusion Model Val Loss: 0.6943427324295044
Early Fusion Model Val RMSE: 0.8140344619750977
Early Fusion Model Val R2 Score: 0.8091449737548828


In [None]:
# Re-import the scikit-learn metrics: the Keras metric defined above reuses the
# name `r2_score`, and the array-based sklearn version is the one wanted here.
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predictions of the early-fusion model on the held-out test split.
y_pred = early_fusion_model.predict([X_text_test, X_node_test, X_edge_test, X_adjacency_test, X_extended_connectivity_test, X_numerical_test])

mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

for metric_label, metric_value in (
    ("Mean Squared Error (MSE):", mse),
    ("Mean Absolute Error (MAE):", mae),
    ("R-squared (R2):", r2),
):
    print(metric_label, metric_value)

Mean Squared Error (MSE): 0.6643379402861843
Mean Absolute Error (MAE): 0.6474200329574276
R-squared (R2): 0.9406534385743754


#### Late Fusion

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Late fusion by simple averaging: three independently trained single-modality
# models (graph, fingerprints, descriptors) each predict the docking score on
# the test set, and the final prediction is the unweighted mean of the three.

# NOTE(review): the import above (re)binds `r2_score` to scikit-learn's function,
# and that same object is passed as the `r2_score` custom object to `load_model`
# below. If the saved models were compiled with a Keras metric of that name, the
# loaded models carry the scikit-learn function as their metric instead. That is
# harmless for `predict`, but confirm before calling `evaluate`/`fit` on them.

# Saved models
model_1 = load_model('/content/drive/MyDrive/Saved_Models/graph_best_model.h5', custom_objects={'r2_score': r2_score, 'GCNConv': GCNConv, 'GATConv': GATConv})
model_2 = load_model('/content/drive/MyDrive/Saved_Models/fingerprints_best_model.h5', custom_objects={'r2_score': r2_score})
model_3 = load_model('/content/drive/MyDrive/Saved_Models/descriptors_best_model.h5', custom_objects={'r2_score': r2_score})

# Predictions (every model also receives the encoded protein sequence)
preds_1_test = model_1.predict([X_text_test, X_node_test, X_edge_test, X_adjacency_test])
preds_2_test = model_2.predict([X_text_test, X_extended_connectivity_test])
preds_3_test = model_3.predict([X_text_test, X_numerical_test])

# Combine predictions (Averaging)
final_preds_test = (preds_1_test + preds_2_test + preds_3_test) / 3

# Evaluation
mse = mean_squared_error(y_test, final_preds_test)
rmse = np.sqrt(mse)
r2 = r2_score(y_test, final_preds_test)

# These metrics are computed on the test split (y_test / *_test inputs), so they
# are labelled "Test" rather than "Val".
print("Late Fusion Model Test MSE:", mse)
print("Late Fusion Model Test RMSE:", rmse)
print("Late Fusion Model Test R2 Score:", r2)



Late Fusion Model Val MSE: 1.1053051114446097
Late Fusion Model Val RMSE: 1.0513349187792678
Late Fusion Model Val R2 Score: 0.901261009325846


In [None]:
# Late fusion by weighted averaging.
# This cell reuses `preds_1_test`, `preds_2_test` and `preds_3_test` (graph,
# fingerprints and descriptors model predictions on the test set) computed by the
# averaging late-fusion cell; run that cell first.

# Coefficients based on R^2 scores, in the order graph, fingerprints, descriptors
# (the same order as preds_1_test, preds_2_test, preds_3_test).
# NOTE(review): these values are hard-coded; presumably they are each model's
# individually measured R^2. If they were measured on the test set, the weights
# leak test information -- confirm, and prefer validation-set R^2.
r2_scores = np.array([0.9312, 0.8963, 0.6426])

# Normalize R^2 scores to get coefficients (they sum to 1)
coefficients = r2_scores / np.sum(r2_scores)

# Combine predictions using coefficients
final_preds_test = (preds_1_test * coefficients[0] + preds_2_test * coefficients[1] + preds_3_test * coefficients[2])

# Evaluation
mse = mean_squared_error(y_test, final_preds_test)
rmse = np.sqrt(mse)
r2 = r2_score(y_test, final_preds_test)

# These metrics are computed on the test split (y_test / *_test predictions), so
# they are labelled "Test" rather than "Val".
print("Late Fusion Model Test MSE:", mse)
print("Late Fusion Model Test RMSE:", rmse)
print("Late Fusion Model Test R2 Score:", r2)

Late Fusion Model Val MSE: 0.9761214457987977
Late Fusion Model Val RMSE: 0.9879885858646332
Late Fusion Model Val R2 Score: 0.9128012298725362


### Predictions

In [None]:
# Assemble the training split into a single DataFrame (one row per
# protein-ligand pair) and persist it to Drive as a pickle.
train_df = pd.DataFrame({
    # Identifiers and regression target
    'pdb_id': pdb_id_train,
    'zinc_id': zinc_id_train,
    'docking_score': y_train,
    # Raw ligand / protein strings
    'smiles': smiles_train,
    'sequence': sequence_train,
    # Encoded protein sequence, stored as a plain list per row
    'encoded_seq': list(map(list, X_text_train)),
    # Numerical descriptors: one DataFrame column per position in each row
    'molecular_weight': [row[0] for row in X_numerical_train],
    'logP': [row[1] for row in X_numerical_train],
    'numH_donors': [row[2] for row in X_numerical_train],
    'numH_acceptors': [row[3] for row in X_numerical_train],
    # Fingerprints and graph inputs, one object per row
    'extended_connectivity_fps': list(X_extended_connectivity_train),
    'node_features': list(map(list, X_node_train)),
    'edge_features': list(map(list, X_edge_train)),
    'adjacency_matrix': list(map(list, X_adjacency_train)),
})
train_df.to_pickle("/content/drive/MyDrive/Dataframes/train_df.pkl")

In [None]:
# Assemble the test split into a single DataFrame (one row per protein-ligand
# pair) and persist it to Drive as a pickle.
test_df = pd.DataFrame({
    # Identifiers and regression target
    'pdb_id': pdb_id_test,
    'zinc_id': zinc_id_test,
    'docking_score': y_test,
    # Raw ligand / protein strings
    'smiles': smiles_test,
    'sequence': sequence_test,
    # Encoded protein sequence, stored as a plain list per row
    'encoded_seq': list(map(list, X_text_test)),
    # Numerical descriptors: one DataFrame column per position in each row
    'molecular_weight': [row[0] for row in X_numerical_test],
    'logP': [row[1] for row in X_numerical_test],
    'numH_donors': [row[2] for row in X_numerical_test],
    'numH_acceptors': [row[3] for row in X_numerical_test],
    # Fingerprints and graph inputs, one object per row
    'extended_connectivity_fps': list(X_extended_connectivity_test),
    'node_features': list(map(list, X_node_test)),
    'edge_features': list(map(list, X_edge_test)),
    'adjacency_matrix': list(map(list, X_adjacency_test)),
})
test_df.to_pickle("/content/drive/MyDrive/Dataframes/test_df.pkl")

In [None]:
# Reload the test DataFrame from the pickle on Drive, rebinding `test_df`.
# NOTE: unpickling can execute arbitrary code; only load pickle files that this
# notebook (or another trusted source) produced.
test_df = pd.read_pickle('/content/drive/MyDrive/Dataframes/test_df.pkl')
# Last expression of the cell: displays the first five rows.
test_df.head()

Unnamed: 0,pdb_id,zinc_id,docking_score,smiles,sequence,encoded_seq,molecular_weight,logP,numH_donors,numH_acceptors,extended_connectivity_fps,node_features,edge_features,adjacency_matrix
0,1T7R,ZINC001100776028,-4.472657,CCCCOCCNC1CC(CNC(=O)c2cccc3nc[nH]c32)C1,GSPGISGGGGGSHIEGYECQPIFLNVLEAIEPGVVCAGHDNNQPDS...,"[6, 16, 13, 6, 8, 16, 6, 6, 6, 6, 6, 16, 7, 8,...",-0.559848,-0.108683,1.745065,-0.48246,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, ...","[6.0, 6.0, 6.0, 6.0, 8.0, 6.0, 6.0, 7.0, 6.0, ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
1,4F8H,ZINC000737231383,-8.28065,CCN(CC)S(=O)(=O)c1cc(NC(=O)C(=O)N[C@H](C)CCSC)...,GQDMVSPPPPIADEPLTVNTGIYLIECYSLDDKAETFKVNAFLSLS...,"[6, 14, 3, 11, 18, 16, 13, 13, 13, 13, 8, 1, 3...",0.949198,-0.301354,0.687259,0.134442,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[6.0, 6.0, 7.0, 6.0, 6.0, 16.0, 8.0, 8.0, 6.0,...","[1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.5, ...","[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
2,4F8H,ZINC001723483319,-5.810636,CC(C)CN1CCO[C@H](Cn2c(C(C)C)nnc2N2Cc3cc(Br)ccc...,GQDMVSPPPPIADEPLTVNTGIYLIECYSLDDKAETFKVNAFLSLS...,"[6, 14, 3, 11, 18, 16, 13, 13, 13, 13, 8, 1, 3...",2.538595,1.523571,-1.428353,0.751344,"[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6.0, 6.0, 6.0, 6.0, 7.0, 6.0, 6.0, 8.0, 6.0, ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
3,4F8H,ZINC001709265465,-5.239971,CC[C@@H]1CCCN(c2nnc(CC(C)C)n2C[C@@H](C2CCCCC2)...,GQDMVSPPPPIADEPLTVNTGIYLIECYSLDDKAETFKVNAFLSLS...,"[6, 14, 3, 11, 18, 16, 13, 13, 13, 13, 8, 1, 3...",0.398619,1.50022,-1.428353,0.134442,"[0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 6.0, 7.0, ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.5, ...","[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
4,6IIU,ZINC000890624425,-12.731217,O=C(c1cc2c(cc1F)NC(=O)CC2)N1Cc2ncccc2N2CCC[C@H...,DYKDDDDGAPADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAA...,"[3, 20, 9, 3, 3, 3, 3, 6, 1, 13, 1, 3, 10, 4, ...",0.203226,0.081592,-0.370547,-0.48246,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 9.0, ...","[2.0, 1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 1.0, 1.0, ...","[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."


In [None]:
# Split the test set into one DataFrame per protein target, keyed by PDB id.
# Iterating a GroupBy yields (key, sub-frame) pairs, which dict() consumes directly.
test_dfs_by_pdb_id = dict(tuple(test_df.groupby('pdb_id')))

# Also expose every per-target frame as a notebook-level variable named
# test_df_<pdb_id>. The dictionary and the variable refer to the same object.
# NOTE(review): writing into globals() creates names that appear nowhere in the
# source; prefer looking frames up through test_dfs_by_pdb_id.
for pdb_id in test_dfs_by_pdb_id:
    globals()[f"test_df_{pdb_id}"] = test_dfs_by_pdb_id[pdb_id]

# List the variable names that were just created.
print([f"test_df_{pdb_id}" for pdb_id in test_dfs_by_pdb_id])

['test_df_1T7R', 'test_df_2ZV2', 'test_df_4F8H', 'test_df_5EK0', 'test_df_6D6T', 'test_df_6IIU']


In [None]:
# Load the best fingerprints model checkpoint for per-target predictions.
# NOTE(review): the same file is loaded as `model_2` in the late-fusion cell
# without the GCNConv/GATConv custom objects, so those two entries look
# unnecessary here -- confirm. `r2_score` is whatever that name is bound to when
# this cell runs.
fingerprints_model = load_model('/content/drive/MyDrive/Saved_Models/fingerprints_best_model.h5', custom_objects={'r2_score': r2_score, 'GCNConv': GCNConv, 'GATConv': GATConv})

In [None]:
# Score every per-target test frame with the fingerprints model and store the
# predictions in a new 'predicted_scores' column of that same frame.
for pdb_id, df in test_dfs_by_pdb_id.items():
    # Inputs: rebuild the model's two arrays from the per-row objects in the frame
    X_text_input = np.array(list(map(np.array, df['encoded_seq'].tolist())))
    X_extended_connectivity_input = np.array(df['extended_connectivity_fps'].tolist())

    # Predictions
    predictions = fingerprints_model.predict([X_text_input, X_extended_connectivity_input])
    df['predicted_scores'] = predictions
    # Re-store the frame under its key (replaces the value only, so it is safe
    # to do while iterating over the dictionary).
    test_dfs_by_pdb_id[pdb_id] = df

print(test_dfs_by_pdb_id.keys())

dict_keys(['1T7R', '2ZV2', '4F8H', '5EK0', '6D6T', '6IIU'])


In [None]:
# A prediction counts as "well predicted" when it lies within `threshold`
# docking-score units of the ground truth.
threshold = 0.25
well_predicted_dfs_by_pdb_id = {}

for pdb_id, df in test_dfs_by_pdb_id.items():
    # Absolute error per row, stored on the per-target frame itself
    df['difference'] = (df['docking_score'] - df['predicted_scores']).abs()

    # Keep only rows whose absolute error is at most the threshold (inclusive)
    well_predicted_df = df[df['difference'].le(threshold)]
    well_predicted_dfs_by_pdb_id[pdb_id] = well_predicted_df

# Report how many rows survived the filter for each target
for pdb_id, df in well_predicted_dfs_by_pdb_id.items():
    print(f"Number of well-predicted data points for pdb_id {pdb_id}: {len(df)}")

Number of well-predicted data points for pdb_id 1T7R: 3893
Number of well-predicted data points for pdb_id 2ZV2: 17
Number of well-predicted data points for pdb_id 4F8H: 6797
Number of well-predicted data points for pdb_id 5EK0: 2005
Number of well-predicted data points for pdb_id 6D6T: 1134
Number of well-predicted data points for pdb_id 6IIU: 6170


In [None]:
# Write one CSV per target containing only its well-predicted rows; the row
# index is not written.
# NOTE(review): several columns of these frames hold Python lists per row
# (e.g. 'encoded_seq', 'node_features'); CSV stores them as their string
# representation, so they will need parsing when read back -- confirm this is
# acceptable for downstream use.
for pdb_id, df in well_predicted_dfs_by_pdb_id.items():
    filename = f'/content/drive/MyDrive/Dataframes/well_predicted_df_pdb_id_{pdb_id}.csv'
    df.to_csv(filename, index=False)