diff --git a/deeptrack/models/gnns/graphs.py b/deeptrack/models/gnns/graphs.py index 4b3f15822..a52507443 100644 --- a/deeptrack/models/gnns/graphs.py +++ b/deeptrack/models/gnns/graphs.py @@ -16,6 +16,7 @@ def GetEdge( end: int, radius: int, parenthood: pd.DataFrame, + columns = [], **kwargs, ): """ @@ -79,7 +80,7 @@ def GetEdge( edges.append(combdf) # Concatenate the dataframes in a single # dataframe for the whole set of edges - edgedf = pd.concat(edges) + edgedf = pd.concat(edges) if len(edges) > 0 else pd.DataFrame(columns=columns) # Merge columns contaning the labels into a single column # of numpy arrays, i.e., label = [label_x, label_y] @@ -120,6 +121,7 @@ def EdgeExtractor(nodesdf, nofframes=3, **kwargs): """ # Create a copy of the dataframe to avoid overwriting df = nodesdf.copy() + columns = df.columns edgedfs = [] sets = np.unique(df["set"]) @@ -140,7 +142,7 @@ def EdgeExtractor(nodesdf, nofframes=3, **kwargs): window = [elem for elem in window if elem <= df_set["frame"].max()] # Compute the edges for each frames window - edgedf = GetEdge(df_set, start=window[0], end=window[-1], **kwargs) + edgedf = GetEdge(df_set, start=window[0], end=window[-1], columns=columns, **kwargs) edgedf["set"] = setid edgedfs.append(edgedf) diff --git a/deeptrack/test/test_generators.py b/deeptrack/test/test_generators.py index 8cd8d88b8..cc14ba327 100644 --- a/deeptrack/test/test_generators.py +++ b/deeptrack/test/test_generators.py @@ -8,8 +8,9 @@ from .. import generators from ..optics import Fluorescence from ..scatterers import PointParticle +from ..models import gnns import numpy as np - +import pandas as pd class TestGenerators(unittest.TestCase): def test_Generator(self): @@ -154,7 +155,37 @@ def get_particle_position(result): # a = generator[idx] # [self.assertLess(d[-1], 8) for d in generator.data] + + def test_GraphGenerator(self): + frame = np.arange(10) + centroid = np.random.normal(0.5, 0.1, (10, 2)) + + df = pd.DataFrame( + { + 'frame': frame, + 'centroid-0': centroid[:, 0], + 'centroid-1': centroid[:, 1], + 'label': 0, + 'set': 0, + 'solution': 0.0 + } + ) + # remove consecutive frames + df = df[~df["frame"].isin([3, 4, 5])] + + generator = gnns.generators.GraphGenerator( + nodesdf=df, + properties=["centroid"], + min_data_size=8, + max_data_size=9, + batch_size=8, + feature_function=gnns.augmentations.GetGlobalFeature, + radius=0.2, + nofframes=3, + output_type="edges" + ) + self.assertIsInstance(generator, gnns.generators.ContinuousGraphGenerator) if __name__ == "__main__": unittest.main() \ No newline at end of file