From 04acc62f86b8b4fc5bd9018b11e1b64988ae4fc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Tue, 8 Nov 2022 13:30:58 +0100 Subject: [PATCH 1/2] Update layers.py --- deeptrack/models/gnns/layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deeptrack/models/gnns/layers.py b/deeptrack/models/gnns/layers.py index da03b1f2b..d90d04d71 100644 --- a/deeptrack/models/gnns/layers.py +++ b/deeptrack/models/gnns/layers.py @@ -215,7 +215,8 @@ def __init__( # node update layer self.update_layer = layers.GRU(filters, time_major=True) - + self.update_norm = tf.keras.layers.Layer() + def update_node_features(self, nodes, aggregated, learnable_embs, edges): Combined = tf.reshape( tf.stack([nodes, aggregated], axis=0), (2, -1, nodes.shape[-1]) From 19bd1a38e98ece40680fa32ef665f4a826966d40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Tue, 8 Nov 2022 13:37:23 +0100 Subject: [PATCH 2/2] Update test_layers.py --- deeptrack/test/test_layers.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/deeptrack/test/test_layers.py b/deeptrack/test/test_layers.py index b74859672..d06c0b715 100644 --- a/deeptrack/test/test_layers.py +++ b/deeptrack/test/test_layers.py @@ -282,6 +282,20 @@ def test_Masked_FGNN_layer(self): ), ) self.assertTrue(model.layers[-1], layers.MaskedFGNN) + + def test_GRUMPN_layer(self): + block = layers.GRUMPNLayer() + model = makeMinimalModel( + block(96), + input_layer=( + k_layers.Input(shape=(None, 96)), + k_layers.Input(shape=(None, 10)), + k_layers.Input(shape=(None, 2), dtype=tf.int32), + k_layers.Input(shape=(None, 1)), + k_layers.Input(shape=(None, 2)), + ), + ) + self.assertTrue(model.layers[-1], layers.GRUMPN) def test_GraphTransformer(self): block = layers.GraphTransformerLayer()