diff --git a/deeptrack/models/gnns/layers.py b/deeptrack/models/gnns/layers.py index da03b1f2b..d90d04d71 100644 --- a/deeptrack/models/gnns/layers.py +++ b/deeptrack/models/gnns/layers.py @@ -215,7 +215,8 @@ def __init__( # node update layer self.update_layer = layers.GRU(filters, time_major=True) - + self.update_norm = tf.keras.layers.Layer() + def update_node_features(self, nodes, aggregated, learnable_embs, edges): Combined = tf.reshape( tf.stack([nodes, aggregated], axis=0), (2, -1, nodes.shape[-1]) diff --git a/deeptrack/test/test_layers.py b/deeptrack/test/test_layers.py index b74859672..d06c0b715 100644 --- a/deeptrack/test/test_layers.py +++ b/deeptrack/test/test_layers.py @@ -282,6 +282,20 @@ def test_Masked_FGNN_layer(self): ), ) self.assertTrue(model.layers[-1], layers.MaskedFGNN) + + def test_GRUMPN_layer(self): + block = layers.GRUMPNLayer() + model = makeMinimalModel( + block(96), + input_layer=( + k_layers.Input(shape=(None, 96)), + k_layers.Input(shape=(None, 10)), + k_layers.Input(shape=(None, 2), dtype=tf.int32), + k_layers.Input(shape=(None, 1)), + k_layers.Input(shape=(None, 2)), + ), + ) + self.assertTrue(model.layers[-1], layers.GRUMPN) def test_GraphTransformer(self): block = layers.GraphTransformerLayer()