Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions deeptrack/models/gnns/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class ContinuousGraphGenerator(ContinuousGenerator):
the speed gained from reusing images. The generator will continuously
create new trainingdata during training, until `max_data_size` is reached,
at which point the oldest data point is replaced.

Parameters
----------
feature : dt.Feature
Expand Down Expand Up @@ -153,9 +153,9 @@ def __getitem__(self, idx):
batch, labels = super().__getitem__(idx)

# Extracts minimum number of nodes in the batch
cropNodesTo = np.min(
list(map(lambda _batch: np.shape(_batch[0])[0], batch))
)
numofnodes = list(map(lambda _batch: np.shape(_batch[0])[0], batch))
bgraph_idx = np.argmin(numofnodes)
cropTo = int(numofnodes[bgraph_idx])

inputs = [[], [], [], []]
outputs = [[], [], []]
Expand All @@ -166,26 +166,28 @@ def __getitem__(self, idx):

# Clip node features to the minimum number of nodes
# in the batch
nodef = batch[i][0][:cropNodesTo, :]

last_node_idx = 0
# Extracts index of the last node in the adjacency matrix
try:
last_node_idx = int(
np.where(batch[i][2][:, 1] <= cropNodesTo - 1)[0][-1] + 1
nodef = batch[i][0][:cropTo, :]

edge_dropouts = (
np.any(batch[i][2] > cropTo - 1, axis=-1)
if i != bgraph_idx
else np.array(
[
False,
]
* np.shape(batch[i][2])[0]
)
except IndexError:
continue
)

# Clips edge features and adjacency matrix to the index
# of the last node
edgef = batch[i][1][:last_node_idx]
adjmx = batch[i][2][:last_node_idx]
wghts = batch[i][3][:last_node_idx]
edgef = batch[i][1][~edge_dropouts]
adjmx = batch[i][2][~edge_dropouts]
wghts = batch[i][3][~edge_dropouts]

# Clips node and edge solutions
nodesol = labels[i][0][:cropNodesTo]
edgesol = labels[i][1][:last_node_idx]
nodesol = labels[i][0][:cropTo]
edgesol = labels[i][1][~edge_dropouts]
globsol = labels[i][2].astype(np.float)

inputs[0].append(nodef)
Expand Down