Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions deeptrack/models/gnns/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def GetEdge(
end: int,
radius: int,
parenthood: pd.DataFrame,
columns = [],
**kwargs,
):
"""
Expand Down Expand Up @@ -79,7 +80,7 @@ def GetEdge(
edges.append(combdf)
# Concatenate the dataframes in a single
# dataframe for the whole set of edges
edgedf = pd.concat(edges)
edgedf = pd.concat(edges) if len(edges) > 0 else pd.DataFrame(columns=columns)

# Merge columns contaning the labels into a single column
# of numpy arrays, i.e., label = [label_x, label_y]
Expand Down Expand Up @@ -120,6 +121,7 @@ def EdgeExtractor(nodesdf, nofframes=3, **kwargs):
"""
# Create a copy of the dataframe to avoid overwriting
df = nodesdf.copy()
columns = df.columns

edgedfs = []
sets = np.unique(df["set"])
Expand All @@ -140,7 +142,7 @@ def EdgeExtractor(nodesdf, nofframes=3, **kwargs):
window = [elem for elem in window if elem <= df_set["frame"].max()]

# Compute the edges for each frames window
edgedf = GetEdge(df_set, start=window[0], end=window[-1], **kwargs)
edgedf = GetEdge(df_set, start=window[0], end=window[-1], columns=columns, **kwargs)
edgedf["set"] = setid
edgedfs.append(edgedf)

Expand Down
33 changes: 32 additions & 1 deletion deeptrack/test/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from .. import generators
from ..optics import Fluorescence
from ..scatterers import PointParticle
from ..models import gnns
import numpy as np

import pandas as pd

class TestGenerators(unittest.TestCase):
def test_Generator(self):
Expand Down Expand Up @@ -154,7 +155,37 @@ def get_particle_position(result):
# a = generator[idx]

# [self.assertLess(d[-1], 8) for d in generator.data]


def test_GraphGenerator(self):
frame = np.arange(10)
centroid = np.random.normal(0.5, 0.1, (10, 2))

df = pd.DataFrame(
{
'frame': frame,
'centroid-0': centroid[:, 0],
'centroid-1': centroid[:, 1],
'label': 0,
'set': 0,
'solution': 0.0
}
)
# remove consecutive frames
df = df[~df["frame"].isin([3, 4, 5])]

generator = gnns.generators.GraphGenerator(
nodesdf=df,
properties=["centroid"],
min_data_size=8,
max_data_size=9,
batch_size=8,
feature_function=gnns.augmentations.GetGlobalFeature,
radius=0.2,
nofframes=3,
output_type="edges"
)
self.assertIsInstance(generator, gnns.generators.ContinuousGraphGenerator)

if __name__ == "__main__":
unittest.main()