From e73d5ed5e47ca564492215690d11881708379a03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Tue, 2 Aug 2022 19:32:55 +0200 Subject: [PATCH] Fix indexing problems related to tf.gather --- deeptrack/models/gnns/generators.py | 38 +++++++++++++++-------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/deeptrack/models/gnns/generators.py b/deeptrack/models/gnns/generators.py index 73a08b1ba..952412e88 100644 --- a/deeptrack/models/gnns/generators.py +++ b/deeptrack/models/gnns/generators.py @@ -116,7 +116,7 @@ class ContinuousGraphGenerator(ContinuousGenerator): the speed gained from reusing images. The generator will continuously create new trainingdata during training, until `max_data_size` is reached, at which point the oldest data point is replaced. - + Parameters ---------- feature : dt.Feature @@ -153,9 +153,9 @@ def __getitem__(self, idx): batch, labels = super().__getitem__(idx) # Extracts minimum number of nodes in the batch - cropNodesTo = np.min( - list(map(lambda _batch: np.shape(_batch[0])[0], batch)) - ) + numofnodes = list(map(lambda _batch: np.shape(_batch[0])[0], batch)) + bgraph_idx = np.argmin(numofnodes) + cropTo = int(numofnodes[bgraph_idx]) inputs = [[], [], [], []] outputs = [[], [], []] @@ -166,26 +166,28 @@ def __getitem__(self, idx): # Clip node features to the minimum number of nodes # in the batch - nodef = batch[i][0][:cropNodesTo, :] - - last_node_idx = 0 - # Extracts index of the last node in the adjacency matrix - try: - last_node_idx = int( - np.where(batch[i][2][:, 1] <= cropNodesTo - 1)[0][-1] + 1 + nodef = batch[i][0][:cropTo, :] + + edge_dropouts = ( + np.any(batch[i][2] > cropTo - 1, axis=-1) + if i != bgraph_idx + else np.array( + [ + False, + ] + * np.shape(batch[i][2])[0] ) - except IndexError: - continue + ) # Clips edge features and adjacency matrix to the index # of the last node - edgef = batch[i][1][:last_node_idx] - adjmx = batch[i][2][:last_node_idx] - wghts = batch[i][3][:last_node_idx] + edgef = batch[i][1][~edge_dropouts] + adjmx = batch[i][2][~edge_dropouts] + wghts = batch[i][3][~edge_dropouts] # Clips node and edge solutions - nodesol = labels[i][0][:cropNodesTo] - edgesol = labels[i][1][:last_node_idx] + nodesol = labels[i][0][:cropTo] + edgesol = labels[i][1][~edge_dropouts] globsol = labels[i][2].astype(np.float) inputs[0].append(nodef)