In [1]:
import warnings
# Silence noisy library warnings for the rest of the notebook session.
# Order matters only in that each call is inserted at the front of the filter list.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

In [2]:
import tensorflow_gnn as tfgnn

In [3]:
from utils import *

In [4]:
# Create a graph of protein interactions.
# `GraphToTensor` is not defined in this notebook; it arrives through
# `from utils import *` (presumably it loads the data and builds the graph
# in its constructor — confirm in utils).
graphToTensor = GraphToTensor()

In [5]:
# Full graph as exposed by the helper; later used to build `dataset`.
# presumably a tfgnn.GraphTensor — TODO confirm against utils.
graph_tensor = graphToTensor.graph_tensor

In [6]:
# Ask the helper for a split with train_size=0.8 and keep only the first result.
# NOTE(review): the second return value is discarded (presumably the held-out
# part — confirm), and validation later runs on the dataset built from the
# full `graph_tensor`, so there is no truly unseen evaluation data.
train_graph, _ = graphToTensor.split_graph(train_size=0.8)

In [7]:
import tensorflow as tf

In [8]:
from model import *

In [9]:
# Build one dataset from the full graph and one from the training split.
# `create_dataset` and `edge_batch_merge` come from the star imports above
# (presumably `model`); their exact batching behaviour is not visible here.
dataset = create_dataset(graph_tensor, edge_batch_merge)
train_dataset = create_dataset(train_graph, edge_batch_merge)

In [10]:
# Take the spec of the first component of each dataset element
# (presumably elements are (graph, labels) pairs — TODO confirm) and use it
# to declare the symbolic Keras input the model is built on.
graph_spec = dataset.element_spec[0]
input_graph = tf.keras.layers.Input(type_spec=graph_spec)

In [11]:
# Feature-mapping layer: the node/edge callbacks (defined outside this
# notebook) produce the initial per-node and per-edge states.
initial_state_layer = tfgnn.keras.layers.MapFeatures(
    node_sets_fn=set_initial_node_state,
    edge_sets_fn=set_initial_edge_state,
)
graph = initial_state_layer(input_graph)



In [12]:
def dense_layer(units=64, l2_reg=0.1, dropout=0.25, activation='relu'):
    """Build a small regularised Dense + Dropout stack.

    Args:
        units: Output width of the Dense layer.
        l2_reg: L2 penalty factor applied to both the kernel and the bias.
        dropout: Drop rate of the Dropout layer that follows the Dense layer.
        activation: Activation passed to the Dense layer.

    Returns:
        A `tf.keras.Sequential` containing the Dense layer followed by Dropout.
    """
    # This is a module-level function, so it must not take `self`: with the
    # old signature the first positional argument (the intended `units`) was
    # swallowed by `self` and the default width was used instead.
    regularizer = tf.keras.regularizers.l2(l2_reg)
    return tf.keras.Sequential([
        tf.keras.layers.Dense(units,
                              # Previously accepted but never forwarded, which
                              # left the layer linear.
                              activation=activation,
                              kernel_regularizer=regularizer,
                              bias_regularizer=regularizer),
        tf.keras.layers.Dropout(dropout)])

In [13]:
# --- Edge update: new 'Interactions' state from concatenated inputs ---------
edge_next_state = tfgnn.keras.layers.NextStateFromConcat(
    dense_layer(64, activation='relu'))
interactions_update = tfgnn.keras.layers.EdgeSetUpdate(next_state=edge_next_state)

# --- Node update: sum-pool edge hidden states at the TARGET endpoint --------
pooled_interactions = tfgnn.keras.layers.Pool(
    tag=tfgnn.TARGET,
    reduce_type="sum",
    feature_name=tfgnn.HIDDEN_STATE)
node_next_state = tfgnn.keras.layers.NextStateFromConcat(dense_layer(64))
proteins_update = tfgnn.keras.layers.NodeSetUpdate(
    {'Interactions': pooled_interactions},
    node_next_state)

# --- One round of graph update over both sets -------------------------------
message_passing = tfgnn.keras.layers.GraphUpdate(
    edge_sets={'Interactions': interactions_update},
    node_sets={'Proteins': proteins_update})
graph = message_passing(graph)

# Per-edge head: a single sigmoid unit on each edge's hidden state.
# Despite the variable name, `logits` therefore holds sigmoid outputs.
edge_head = tf.keras.layers.Dense(1, activation='sigmoid')
edge_hidden_state = graph.edge_sets['Interactions'][tfgnn.HIDDEN_STATE]
logits = edge_head(edge_hidden_state)

model = tf.keras.Model(input_graph, logits)


In [14]:
# Adam with binary cross-entropy, matching the single sigmoid output per edge.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='binary_crossentropy',
    # The string 'Accuracy' (capital A) resolves to tf.keras.metrics.Accuracy,
    # which checks exact equality between labels and the raw sigmoid outputs
    # and is therefore meaningless here. Only lowercase 'accuracy' is
    # auto-mapped by Keras; use the binary metric explicitly (0.5 threshold).
    metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy')],
)

model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [()]                      0         
                                                                 
 map_features (MapFeatures)  ()                        1120      
                                                                 
 graph_update (GraphUpdate)  ()                        28800     
                                                                 
 input.edge_sets (InstancePr  {'Interactions': ()}     0         
 operty)                                                         
                                                                 
 input._get_features_ref_4 (  {'hidden_state': (None,   0        
 InstanceProperty)           64)}                                
                                                                 
 dense_2 (Dense)             (None, 1)                 65    

In [15]:
# Stop automatically once validation loss stops improving, instead of running
# all 1000 epochs or relying on a manual kernel interrupt; keep the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# NOTE(review): `dataset` is built from the full graph while `train_dataset`
# comes from a split of that same graph, so the validation data presumably
# overlaps the training data — confirm, and validate on the held-out split.
# Both datasets are repeated, so epoch length is set by the step counts below.
history = model.fit(
    train_dataset.repeat(),
    epochs=1000,
    steps_per_epoch=100,
    validation_data=dataset.repeat(),
    validation_steps=10,
    callbacks=[early_stopping],
)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000

KeyboardInterrupt: 