In [89]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdmolops
import numpy as np
from scipy.sparse import block_diag
from spektral.data import Graph, Dataset, Loader, DisjointLoader
from spektral.data import SingleLoader
from spektral.layers import GCNConv, GlobalSumPool
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense
import tensorflow as tf

## Scratch 

In [274]:
# Only the first 10,000 rows are used, so let the parser stop there instead of
# reading the whole file and discarding the rest with .head().
# NOTE(review): the file has a .tsv extension but is parsed with the default
# comma separator; the displayed columns suggest this is correct — confirm.
df = pd.read_csv('../data/train_subset.tsv', nrows=10000)
df

Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name,binds
0,0,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C#CCOc1ccc(CN)cc1.Cl,Br.Br.NCC1CCCN1c1cccnn1,C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H]...,BRD4,0
1,1,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C#CCOc1ccc(CN)cc1.Cl,Br.Br.NCC1CCCN1c1cccnn1,C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H]...,HSA,0
2,2,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C#CCOc1ccc(CN)cc1.Cl,Br.Br.NCC1CCCN1c1cccnn1,C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H]...,sEH,0
3,3,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C#CCOc1ccc(CN)cc1.Cl,Br.NCc1cccc(Br)n1,C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC...,BRD4,0
4,4,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C#CCOc1ccc(CN)cc1.Cl,Br.NCc1cccc(Br)n1,C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC...,HSA,0
...,...,...,...,...,...,...,...
9995,9995,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C=C(C)COCCN.Cl,NCCNC(=O)c1cccnc1,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCNC(=O)c2cccnc2...,sEH,0
9996,9996,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C=C(C)COCCN.Cl,NCCOc1cccnc1,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCOCC(=C)C)nc(NC...,BRD4,0
9997,9997,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C=C(C)COCCN.Cl,NCCOc1cccnc1,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCOCC(=C)C)nc(NC...,HSA,0
9998,9998,C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21,C=C(C)COCCN.Cl,NCCOc1cccnc1,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCOCC(=C)C)nc(NC...,sEH,0


In [275]:
def mol_to_graph(mol):
    """Convert an RDKit molecule into a Spektral ``Graph``.

    Args:
        mol: An RDKit molecule object.

    Returns:
        Graph: ``x`` is a float32 array of shape (n_atoms, 1) holding each
        atom's atomic number; ``a`` is the matrix returned by
        ``rdmolops.GetAdjacencyMatrix``.
    """
    a = rdmolops.GetAdjacencyMatrix(mol)
    # One feature column per atom: its atomic number, stored as float32.
    x = np.array(
        [atom.GetAtomicNum() for atom in mol.GetAtoms()],
        dtype=np.float32,
    ).reshape(-1, 1)
    return Graph(x=x, a=a)

# Convert SMILES to Graph
def smiles_to_graph(smiles):
    """Parse a SMILES string and turn it into a graph.

    Returns ``None`` when ``Chem.MolFromSmiles`` yields a falsy result
    (i.e. the string could not be parsed); otherwise returns whatever
    ``mol_to_graph`` builds for the parsed molecule.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        return None
    return mol_to_graph(parsed)

def pad_features(features, max_features):
    """Right-pad a 2-D feature matrix with zero columns up to ``max_features``.

    Args:
        features: 2-D array of shape (n_rows, n_cols).
        max_features: Target number of columns.

    Returns:
        A float32 array of shape (n_rows, max_features); the original values
        come first and the added columns are filled with 0.
    """
    extra_cols = max_features - features.shape[1]
    # No padding on the row axis; append `extra_cols` columns on the right.
    pad_spec = ((0, 0), (0, extra_cols))
    widened = np.pad(features, pad_spec, 'constant', constant_values=0)
    return widened.astype(np.float32)

def process_df(df):
    """Build one combined graph per DataFrame row.

    For every row, the four SMILES columns (three building blocks plus the
    full molecule) are each converted with ``smiles_to_graph``. When all four
    parse, their node features are stacked vertically and their adjacency
    matrices are placed on a block diagonal, giving a single graph whose
    label is the row's ``binds`` value. Rows with any unparseable SMILES are
    skipped.

    Args:
        df: DataFrame with the columns ``buildingblock1_smiles``,
            ``buildingblock2_smiles``, ``buildingblock3_smiles``,
            ``molecule_smiles`` and ``binds``.

    Returns:
        tuple: ``(graphs, max_features)`` where ``graphs`` is the list of
        combined ``Graph`` objects and ``max_features`` is the largest
        feature width seen across all accepted rows (``0`` when no row was
        accepted).
    """
    smile_columns = ['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles', 'molecule_smiles']
    unified_graph_list = []
    # Initialised up front so the return below is valid even if no row
    # produces a graph (previously this raised UnboundLocalError).
    max_features = 0
    for _, row in df.iterrows():
        graphs = [smiles_to_graph(row[col]) for col in smile_columns]
        if any(graph is None for graph in graphs):
            continue  # at least one SMILES failed to parse; drop the row
        row_max_features = max(graph.x.shape[1] for graph in graphs)
        # Track the maximum over the whole dataset, not just the last row.
        max_features = max(max_features, row_max_features)
        combined_x = np.vstack([pad_features(graph.x, row_max_features) for graph in graphs])
        # Keep adjacency matrix in COO format
        combined_a = block_diag([graph.a for graph in graphs]).tocoo()
        label = row['binds']
        unified_graph_list.append(Graph(x=combined_x, a=combined_a, y=label))
    return unified_graph_list, max_features




In [276]:
import numpy as np
from scipy.sparse import csr_matrix
from spektral.data import Graph, Dataset
from tensorflow.keras.utils import to_categorical


# Function to ensure adjacency matrices are in the correct format
def ensure_sparse_format(adj):
    """Return ``adj`` as a SciPy sparse matrix when it is a TF SparseTensor.

    Args:
        adj: An adjacency matrix. If it is a ``tf.sparse.SparseTensor`` it is
            converted; any other object is returned unchanged.

    Returns:
        A ``scipy.sparse.csr_matrix`` built from the tensor's indices, values
        and dense shape, or ``adj`` itself when no conversion was needed.
    """
    if isinstance(adj, tf.sparse.SparseTensor):
        # Convert TensorFlow SparseTensor to a scipy.sparse matrix.
        # `csr_matrix` is the name this cell imports; the former `sp.` prefix
        # was never defined and raised NameError on this branch.
        indices = adj.indices.numpy()
        rows, cols = indices[:, 0], indices[:, 1]
        # SciPy expects a plain tuple of ints for the shape, not a tensor.
        shape = tuple(int(dim) for dim in adj.dense_shape.numpy())
        adj = csr_matrix((adj.values.numpy(), (rows, cols)), shape=shape)
    return adj

# Function to process your dataset and ensure all graphs are in the correct format
def process_graph_data(graphs):
    """Run ``ensure_sparse_format`` over the adjacency of every graph.

    Each graph's ``a`` attribute is replaced in place; the same ``graphs``
    object that was passed in is returned.
    """
    for graph in graphs:
        converted = ensure_sparse_format(graph.a)
        graph.a = converted
    return graphs

class MoleculeDataset(Dataset):
    """Spektral dataset backed by an in-memory list of ``Graph`` objects."""

    def __init__(self, graph_list, **kwargs):
        """Store the graphs and initialise the base ``Dataset``.

        Args:
            graph_list: List of ``Graph`` objects to expose through ``read``.
            **kwargs: Forwarded unchanged to ``Dataset.__init__``.
        """
        # Assigned before the base initialiser runs so that `read` can
        # return it (Spektral's Dataset.__init__ invokes read()).
        self.graph_list = graph_list
        super().__init__(**kwargs)

    def read(self):
        """Return the list of graphs supplied at construction time."""
        return self.graph_list

    @staticmethod
    def collate(batch):
        """Merge a batch of graphs into one disjoint-union graph.

        Declared as a static method so it can be called both as
        ``MoleculeDataset.collate(batch)`` and ``dataset.collate(batch)``.

        Args:
            batch: Non-empty sequence of graph objects exposing ``x`` (2-D
                node features), ``a`` (SciPy sparse adjacency) and ``y``.

        Returns:
            tuple: ``([features, adjacency], labels)`` where ``features`` is
            the row-wise stack of all node features, ``adjacency`` is a
            ``tf.SparseTensor`` holding the block-diagonal combination of the
            individual adjacency matrices (so its shape matches the stacked
            features), and ``labels`` is a NumPy array of the ``y`` values.

        Raises:
            ValueError: If ``batch`` is empty (raised by ``np.vstack``).
        """
        features = np.vstack([item.x for item in batch])

        # Place every adjacency on the diagonal so node indices line up with
        # the stacked feature rows, regardless of each graph's size.
        merged = block_diag([item.a for item in batch]).tocoo()
        indices = np.column_stack((merged.row, merged.col)).astype(np.int64)

        # Convert to TensorFlow SparseTensor
        adj_matrices = tf.SparseTensor(
            indices=indices,
            values=merged.data,
            dense_shape=merged.shape,
        )
        # Put the indices in canonical row-major order.
        adj_matrices = tf.sparse.reorder(adj_matrices)

        labels = np.array([item.y for item in batch])  # Ensure labels are a numpy array

        return [features, adj_matrices], labels
        
# Build one combined graph per DataFrame row (`mf` receives the feature width
# returned by process_df), wrap the list in a MoleculeDataset, and serve it
# through a DisjointLoader in batches of 32.
graphs, mf = process_df(df)
dataset = MoleculeDataset(graphs)
# NOTE(review): node_level=True is passed although each Graph carries a single
# `binds` label for the whole graph — confirm this is intended.
loader = DisjointLoader(dataset, node_level=True, batch_size=32)


In [272]:
import spektral
import tensorflow as tf
from spektral.layers import GraphSageConv, GlobalAvgPool
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense



class GNN(tf.keras.Model):
    """Graph-level binary classifier.

    Two GraphSage convolutions (32 units, ReLU) are followed by global
    average pooling and a single sigmoid output unit.
    """

    def __init__(self):
        super().__init__()
        self.graph_conv1 = GraphSageConv(32, activation='relu')
        self.graph_conv2 = GraphSageConv(32, activation='relu')
        self.pool = GlobalAvgPool()
        self.classifier = Dense(1, activation='sigmoid')

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: A 3-element sequence that is unpacked into node
                features, adjacency, and the index passed to the pooling
                layer.

        Returns:
            The output of the sigmoid classifier applied to the pooled
            representation.
        """
        node_features, adjacency, graph_index = inputs
        # The convolutions only consume features and adjacency.
        hidden = self.graph_conv1([node_features, adjacency])
        hidden = self.graph_conv2([hidden, adjacency])
        # The index is handed to the pooling layer together with the features.
        pooled = self.pool([hidden, graph_index])
        return self.classifier(pooled)

# Instantiate the network and configure it for binary classification:
# Adam optimiser, binary cross-entropy loss, accuracy as the reported metric.
model = GNN()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])



In [268]:
# Hand-built smoke-test inputs: random single-column node features and an
# identity (self-loops only) sparse adjacency over the same number of nodes.
n_nodes = 2884
test_x = tf.random.normal([n_nodes, 1])
test_a_indices = tf.convert_to_tensor([[k, k] for k in range(n_nodes)], dtype=tf.int64)
test_a_values = tf.ones([n_nodes], dtype=tf.float32)
test_a_shape = [n_nodes, n_nodes]
test_a = tf.SparseTensor(indices=test_a_indices, values=test_a_values, dense_shape=test_a_shape)
# Every node gets its own index value, so the model returns one row per node.
test_i = tf.range(n_nodes)

# Call the model directly (no fit/predict) and show the raw output tensor.
test_output = model([test_x, test_a, test_i])
print("Test output:", test_output)



Test output: tf.Tensor(
[[0.5465763 ]
 [0.5465763 ]
 [0.5465763 ]
 ...
 [0.5465763 ]
 [0.48460877]
 [0.48460877]], shape=(2884, 1), dtype=float32)


In [277]:
# Create a new GNN instance (class defined in an earlier cell) and compile it
# with the same settings used before: Adam, binary cross-entropy, accuracy.
model = GNN()  # new instance, so this run does not reuse the earlier `model`
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# `loader` is the DisjointLoader built in an earlier cell; train for 10 epochs,
# taking the number of steps per epoch from the loader itself.
model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch, epochs=10)

Epoch 1/10
[1m 32/313[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m1s[0m 5ms/step - accuracy: 0.4597 - loss: 0.6948

  np.random.shuffle(a)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.8931 - loss: 0.3647
Epoch 2/10
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.9990 - loss: 0.0786
Epoch 3/10
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.9995 - loss: 0.0409
Epoch 4/10
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.9984 - loss: 0.0299
Epoch 5/10
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.9990 - loss: 0.0197
Epoch 6/10
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.9988 - loss: 0.0154
Epoch 7/10
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.9988 - loss: 0.0129
Epoch 8/10
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.9992 - loss: 0.0095
Epoch 9/10
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x3322ebf90>