From e225214213e321ccb75c3cec9bd16743a055cf14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Wed, 21 Dec 2022 09:56:44 +0100 Subject: [PATCH 1/4] Update graphs.py --- deeptrack/models/gnns/graphs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deeptrack/models/gnns/graphs.py b/deeptrack/models/gnns/graphs.py index 4b3f15822..a52507443 100644 --- a/deeptrack/models/gnns/graphs.py +++ b/deeptrack/models/gnns/graphs.py @@ -16,6 +16,7 @@ def GetEdge( end: int, radius: int, parenthood: pd.DataFrame, + columns = [], **kwargs, ): """ @@ -79,7 +80,7 @@ def GetEdge( edges.append(combdf) # Concatenate the dataframes in a single # dataframe for the whole set of edges - edgedf = pd.concat(edges) + edgedf = pd.concat(edges) if len(edges) > 0 else pd.DataFrame(columns=columns) # Merge columns contaning the labels into a single column # of numpy arrays, i.e., label = [label_x, label_y] @@ -120,6 +121,7 @@ def EdgeExtractor(nodesdf, nofframes=3, **kwargs): """ # Create a copy of the dataframe to avoid overwriting df = nodesdf.copy() + columns = df.columns edgedfs = [] sets = np.unique(df["set"]) @@ -140,7 +142,7 @@ def EdgeExtractor(nodesdf, nofframes=3, **kwargs): window = [elem for elem in window if elem <= df_set["frame"].max()] # Compute the edges for each frames window - edgedf = GetEdge(df_set, start=window[0], end=window[-1], **kwargs) + edgedf = GetEdge(df_set, start=window[0], end=window[-1], columns=columns, **kwargs) edgedf["set"] = setid edgedfs.append(edgedf) From 2ada318b962a4cea4feb15c5fb39f53a7576b71b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Wed, 21 Dec 2022 10:31:26 +0100 Subject: [PATCH 2/4] Update test_generators.py --- deeptrack/test/test_generators.py | 37 ++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/deeptrack/test/test_generators.py b/deeptrack/test/test_generators.py index 8cd8d88b8..04d169ae0 100644 --- a/deeptrack/test/test_generators.py +++ b/deeptrack/test/test_generators.py @@ -8,8 +8,9 @@ from .. import generators from ..optics import Fluorescence from ..scatterers import PointParticle +from ..models import gnns import numpy as np - +import pandas as pd class TestGenerators(unittest.TestCase): def test_Generator(self): @@ -154,7 +155,41 @@ def get_particle_position(result): # a = generator[idx] # [self.assertLess(d[-1], 8) for d in generator.data] + + + def test_GraphGenerator(self): + frame = np.arange(10) + centroid = np.random.normal(0.5, 0.1, (10, 2)) + + df = pd.DataFrame( + { + 'frame': frame, + 'centroid-0': centroid[:, 0], + 'centroid-1': centroid[:, 1], + 'label': 0, + 'set': 0, + 'solution': 0.0 + } + ) + # remove consecutive frames + df = df[~df["frame"].isin([3, 4, 5])] + + generator = gnns.generators.GraphGenerator( + nodesdf=df, + properties=["centroid"], + min_data_size=8, + max_data_size=9, + batch_size=8, + feature_function=gnns.augmentations.GetGlobalFeature, + radius=0.2, + nofframes=3, + output_type="edges" + ) + with generator: + graphs, _ = generator[0] + self.assertEqual(graphs[0].shape[0], 8) + self.assertEqual(graphs[0].shape[2], 2) if __name__ == "__main__": unittest.main() \ No newline at end of file From 2a3a194a0db963fb0374931b0c471c4d6d74ba62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Wed, 21 Dec 2022 10:49:15 +0100 Subject: [PATCH 3/4] Update generators.py --- deeptrack/generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptrack/generators.py b/deeptrack/generators.py index 6766c4bf1..84454328f 100644 --- a/deeptrack/generators.py +++ b/deeptrack/generators.py @@ -363,7 +363,7 @@ def __getitem__(self, idx): np.array(labels), ) else: - return np.array(data), np.array(labels) + return np.array(data, dtype="object"), np.array(labels, dtype="object") def __len__(self): steps = int((self.min_data_size // self._batch_size)) From 10b850e7510112360bf3239a8f3680b1a1c479f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Pineda?= Date: Wed, 21 Dec 2022 11:00:47 +0100 Subject: [PATCH 4/4] fix test --- deeptrack/generators.py | 2 +- deeptrack/test/test_generators.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/deeptrack/generators.py b/deeptrack/generators.py index 84454328f..6766c4bf1 100644 --- a/deeptrack/generators.py +++ b/deeptrack/generators.py @@ -363,7 +363,7 @@ def __getitem__(self, idx): np.array(labels), ) else: - return np.array(data, dtype="object"), np.array(labels, dtype="object") + return np.array(data), np.array(labels) def __len__(self): steps = int((self.min_data_size // self._batch_size)) diff --git a/deeptrack/test/test_generators.py b/deeptrack/test/test_generators.py index 04d169ae0..cc14ba327 100644 --- a/deeptrack/test/test_generators.py +++ b/deeptrack/test/test_generators.py @@ -185,11 +185,7 @@ def test_GraphGenerator(self): nofframes=3, output_type="edges" ) - - with generator: - graphs, _ = generator[0] - self.assertEqual(graphs[0].shape[0], 8) - self.assertEqual(graphs[0].shape[2], 2) + self.assertIsInstance(generator, gnns.generators.ContinuousGraphGenerator) if __name__ == "__main__": unittest.main() \ No newline at end of file