In [1]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [2]:
import tensorflow_gnn as tfgnn

In [3]:
from utils import ProteinInteractionGraph

In [4]:
# Build the protein-interaction graph. `negative_interaction_method` selects how
# negative (non-interacting) protein pairs are produced; 'most_distant'
# presumably picks the most dissimilar pairs — confirm in
# utils.ProteinInteractionGraph.
# NOTE: the variable name keeps its original spelling because later cells use it.
interacation_graph = ProteinInteractionGraph(
    negative_interaction_method='most_distant',
)

In [5]:
# Materialise the interaction graph as a GraphTensor (its structure is displayed below).
graph_tensor = interacation_graph.get_graph_tensor()

In [6]:
import tensorflow as tf

In [7]:
# Show the GraphTensor summary: node set and edge set names plus context sizes.
graph_tensor

GraphTensor(
  context=Context(features={}, sizes=[1], shape=(), indices_dtype=tf.int32),
  node_set_names=['Proteins'],
  edge_set_names=['Interactions'])

In [8]:
def edge_batch_merge(graph):
    """Flatten a batched graph and split off the edge labels.

    Merges the batch dimension into graph components, removes the 'labels'
    feature from the 'Interactions' edge set, and returns it separately so the
    result can be consumed by Keras as an (inputs, targets) pair.

    Args:
        graph: a batched GraphTensor with a 'Proteins' node set and an
            'Interactions' edge set that carries a 'labels' feature.

    Returns:
        A tuple ``(graph_without_labels, labels)``.
    """
    merged = graph.merge_batch_to_components()

    edge_feats = merged.edge_sets['Interactions'].get_features_dict()
    edge_labels = edge_feats.pop('labels')
    node_feats = merged.node_sets['Proteins'].get_features_dict()

    unlabeled = merged.replace_features(
        node_sets={'Proteins': node_feats},
        edge_sets={'Interactions': edge_feats},
    )
    return unlabeled, edge_labels
    

In [9]:
def create_dataset(graph, function, batch_size=32):
    """Wrap one graph in a tf.data pipeline and map ``function`` over it.

    ``Dataset.from_tensors`` builds a dataset with exactly one element, so
    ``batch(batch_size)`` yields a single batch holding that one graph. Its
    effect is to add the leading batch dimension that a function such as
    ``edge_batch_merge`` (which calls ``merge_batch_to_components``) expects.

    Args:
        graph: the value to wrap (here a GraphTensor).
        function: callable applied to each batch via ``Dataset.map``.
        batch_size: value passed to ``Dataset.batch``; defaults to 32, the
            previously hard-coded value.

    Returns:
        The mapped ``tf.data.Dataset``.
    """
    dataset = tf.data.Dataset.from_tensors(graph)
    dataset = dataset.batch(batch_size)
    return dataset.map(function)

# Dataset whose single element is a (graph_without_labels, edge_labels) pair.
dataset = create_dataset(graph=graph_tensor, function=edge_batch_merge)

In [10]:
# element_spec is a (graph spec, label spec) pair; only the graph part is model input.
graph_spec, _label_spec = dataset.element_spec
input_graph = tf.keras.layers.Input(type_spec=graph_spec)

def set_initial_node_state(node_set, node_set_name, units=32):
    """Build the initial hidden state of a protein node from its feature groups.

    Each feature group gets its own ``Dense(units, activation="relu")``
    projection. The projections are concatenated, in the fixed order below,
    along the last axis.

    Args:
        node_set: node-set features, indexable by feature name.
        node_set_name: name of the node set; unused, but passed by MapFeatures.
        units: width of each per-group projection (default 32, as before).

    Returns:
        The concatenation of the four projections (4 * ``units`` features).
    """
    # Order matters: it fixes the column layout of the concatenated state.
    feature_names = (
        'basic_protein_properties',
        'secondary_structure_content',
        'other_properties',
        'amino_acid_composition',
    )
    projections = [
        tf.keras.layers.Dense(units, activation="relu")(node_set[name])
        for name in feature_names
    ]
    return tf.keras.layers.Concatenate()(projections)

def set_initial_edge_state(edge_set, edge_set_name):
    """Give every edge an empty initial hidden state.

    No raw edge features are used as input; edge states are computed later
    by the message-passing layers. ``edge_set_name`` is unused but is part of
    the MapFeatures callback signature.
    """
    empty_feature_layer = tfgnn.keras.layers.MakeEmptyFeature()
    return empty_feature_layer(edge_set)

# Apply the per-set initial-state functions to the symbolic input graph.
initial_state_mapper = tfgnn.keras.layers.MapFeatures(
    node_sets_fn=set_initial_node_state,
    edge_sets_fn=set_initial_edge_state,
)
graph = initial_state_mapper(input_graph)



In [11]:
# Show the input spec: feature names, shapes and dtypes for every node/edge set.
graph_spec

GraphTensorSpec({'context': ContextSpec({'features': {}, 'sizes': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, None), 'node_sets': {'Proteins': NodeSetSpec({'features': {'basic_protein_properties': TensorSpec(shape=(None, 5), dtype=tf.float64, name=None), 'secondary_structure_content': TensorSpec(shape=(None, 3), dtype=tf.float64, name=None), 'other_properties': TensorSpec(shape=(None, 3), dtype=tf.float64, name=None), 'amino_acid_composition': TensorSpec(shape=(None, 20), dtype=tf.float64, name=None)}, 'sizes': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, None)}, 'edge_sets': {'Interactions': EdgeSetSpec({'features': {}, 'sizes': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'adjacency': AdjacencySpec({'#index.0': TensorSpec(shape=(None,), dtype=tf.int32, name=None), '#index.1': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, {'#index.0': 'Proteins', '#index.1': 'Protein

In [12]:
def dense_layer(units=64, l2_reg=0.1, dropout=0.25, activation='relu'):
    """Return an L2-regularised Dense layer followed by Dropout.

    Args:
        units: output width of the Dense layer.
        l2_reg: L2 factor applied to both the kernel and the bias.
        dropout: dropout rate applied after the Dense layer.
        activation: activation of the Dense layer.

    Returns:
        A ``tf.keras.Sequential`` of ``[Dense, Dropout]``.
    """
    # This is a module-level function, so there is no ``self`` parameter.
    # Existing calls like ``dense_layer(64)`` still resolve to units=64.
    regularizer = tf.keras.regularizers.l2(l2_reg)
    return tf.keras.Sequential([
        tf.keras.layers.Dense(
            units,
            # Previously accepted but never forwarded, which made the layer linear.
            activation=activation,
            kernel_regularizer=regularizer,
            bias_regularizer=regularizer,
        ),
        tf.keras.layers.Dropout(dropout),
    ])

In [13]:
# Message passing over the 'Interactions' edges, repeated `graph_updates` times.
# Each round:
#   1. recomputes every edge state with NextStateFromConcat over the inputs
#      EdgeSetUpdate supplies (by default, presumably the endpoint protein
#      states plus the edge's own state; check the TF-GNN docs);
#   2. sums the new edge states into their TARGET protein (Pool) and updates
#      each protein state with NextStateFromConcat.
# NOTE(review): the 'Interactions' edge set carries the 'labels' feature, so it
# appears to contain both positive and sampled negative pairs. Messages
# therefore also flow along negative edges; confirm this is intended.
graph_updates = 3

# Use a separate name so that re-running this cell always starts from the
# initial-state `graph` instead of stacking more updates onto its own output.
updated_graph = graph
for _round in range(graph_updates):
    updated_graph = tfgnn.keras.layers.GraphUpdate(
        edge_sets={
            'Interactions': tfgnn.keras.layers.EdgeSetUpdate(
                next_state=tfgnn.keras.layers.NextStateFromConcat(
                    dense_layer(64, activation='relu'))),
        },
        node_sets={
            'Proteins': tfgnn.keras.layers.NodeSetUpdate(
                {'Interactions': tfgnn.keras.layers.Pool(
                    tag=tfgnn.TARGET,
                    reduce_type="sum",
                    feature_name=tfgnn.HIDDEN_STATE)},
                tfgnn.keras.layers.NextStateFromConcat(dense_layer(64))),
        },
    )(updated_graph)

# The classification head is built once, after the loop, so only one head is
# created and it reads the final edge states.
# Despite its name, `logits` holds sigmoid probabilities. This matches the
# 'binary_crossentropy' loss string used at compile time (from_logits=False).
logits = tf.keras.layers.Dense(1, activation='sigmoid')(
    updated_graph.edge_sets['Interactions'][tfgnn.HIDDEN_STATE])

edge_model = tf.keras.Model(input_graph, logits)

In [14]:
# Binary edge classification: the model outputs one sigmoid probability per edge.
edge_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss='binary_crossentropy',
    # Use BinaryAccuracy (threshold 0.5) explicitly. The string 'Accuracy'
    # (capital A) resolves to tf.keras.metrics.Accuracy, which requires exact
    # equality between the float prediction and the 0/1 label and therefore
    # stays near zero.
    metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy')],
)

edge_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [()]                      0         
                                                                 
 map_features (MapFeatures)  ()                        1120      
                                                                 
 graph_update (GraphUpdate)  ()                        28800     
                                                                 
 graph_update_1 (GraphUpdate  ()                       20608     
 )                                                               
                                                                 
 graph_update_2 (GraphUpdate  ()                       20608     
 )                                                               
                                                                 
 input.edge_sets_2 (Instance  {'Interactions': ()}     0     

In [15]:
# The fit() call in this notebook passes no validation data, so 'val_loss' is
# never logged. Monitoring it made this callback a silent no-op (Keras only
# warns that the metric is unavailable), and training ran until it was
# interrupted manually. Monitor the training loss instead; switch back to
# 'val_loss' once a validation split is passed to fit().
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='loss',
    patience=10,
    mode='min',
    verbose=1,
    restore_best_weights=True,
)

In [16]:
# Train on the single-graph dataset. `dataset.repeat()` is infinite, so
# `steps_per_epoch` defines what one epoch means. Keep the returned History so
# loss/accuracy curves can be inspected afterwards (history.history[...]).
history = edge_model.fit(
    dataset.repeat(),
    epochs=1000,  # only an upper bound; early_stopping is meant to end training sooner
    steps_per_epoch=10,
    callbacks=[early_stopping],
)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
E

KeyboardInterrupt: 