diff --git a/deeptrack/models/gnns/generators.py b/deeptrack/models/gnns/generators.py index 73a08b1ba..952412e88 100644 --- a/deeptrack/models/gnns/generators.py +++ b/deeptrack/models/gnns/generators.py @@ -116,7 +116,7 @@ class ContinuousGraphGenerator(ContinuousGenerator): the speed gained from reusing images. The generator will continuously create new trainingdata during training, until `max_data_size` is reached, at which point the oldest data point is replaced. - + Parameters ---------- feature : dt.Feature @@ -153,9 +153,9 @@ def __getitem__(self, idx): batch, labels = super().__getitem__(idx) # Extracts minimum number of nodes in the batch - cropNodesTo = np.min( - list(map(lambda _batch: np.shape(_batch[0])[0], batch)) - ) + numofnodes = list(map(lambda _batch: np.shape(_batch[0])[0], batch)) + bgraph_idx = np.argmin(numofnodes) + cropTo = int(numofnodes[bgraph_idx]) inputs = [[], [], [], []] outputs = [[], [], []] @@ -166,26 +166,28 @@ def __getitem__(self, idx): # Clip node features to the minimum number of nodes # in the batch - nodef = batch[i][0][:cropNodesTo, :] - - last_node_idx = 0 - # Extracts index of the last node in the adjacency matrix - try: - last_node_idx = int( - np.where(batch[i][2][:, 1] <= cropNodesTo - 1)[0][-1] + 1 + nodef = batch[i][0][:cropTo, :] + + edge_dropouts = ( + np.any(batch[i][2] > cropTo - 1, axis=-1) + if i != bgraph_idx + else np.array( + [ + False, + ] + * np.shape(batch[i][2])[0] ) - except IndexError: - continue + ) # Clips edge features and adjacency matrix to the index # of the last node - edgef = batch[i][1][:last_node_idx] - adjmx = batch[i][2][:last_node_idx] - wghts = batch[i][3][:last_node_idx] + edgef = batch[i][1][~edge_dropouts] + adjmx = batch[i][2][~edge_dropouts] + wghts = batch[i][3][~edge_dropouts] # Clips node and edge solutions - nodesol = labels[i][0][:cropNodesTo] - edgesol = labels[i][1][:last_node_idx] + nodesol = labels[i][0][:cropTo] + edgesol = labels[i][1][~edge_dropouts] globsol = labels[i][2].astype(np.float) inputs[0].append(nodef)