diff --git a/deeptrack/generators.py b/deeptrack/generators.py index 7d1cda5cf..4c7b1311e 100644 --- a/deeptrack/generators.py +++ b/deeptrack/generators.py @@ -215,6 +215,7 @@ def __init__( shuffle_batch=True, ndim=4, max_epochs_per_sample=np.inf, + use_multi_inputs=False, verbose=1, ): if label_function is None and batch_function is None: @@ -242,6 +243,7 @@ def __init__( self.batch_size = batch_size self.shuffle_batch = shuffle_batch self.max_epochs_per_sample = max_epochs_per_sample + self.use_multi_inputs = use_multi_inputs self.ndim = ndim self.augmentation = augmentation @@ -352,7 +354,20 @@ def __getitem__(self, idx): data = [self.batch_function(d["data"]) for d in subset] labels = [self.label_function(d["data"]) for d in subset] - return np.array(data), np.array(labels) + if self.use_multi_inputs: + return ( + tuple( + [ + np.stack(list(map(np.array, _data)), axis=0) + for _data in list(zip(*data)) + ] + ), + np.array(labels), + ) + else: + return np.array(data), np.array(labels) + + def __len__(self): steps = int((self.min_data_size // self._batch_size)) diff --git a/deeptrack/test/test_generators.py b/deeptrack/test/test_generators.py index 61f7e7825..8cd8d88b8 100644 --- a/deeptrack/test/test_generators.py +++ b/deeptrack/test/test_generators.py @@ -4,6 +4,7 @@ import unittest +from .. import features from .. import generators from ..optics import Fluorescence from ..scatterers import PointParticle @@ -68,6 +69,50 @@ def get_particle_position(result): with generator: self.assertGreater(len(generator.data), 10) self.assertLess(len(generator.data), 21) + + + def test_MultiInputs_ContinuousGenerator(self): + optics = Fluorescence( + NA=0.7, + wavelength=680e-9, + resolution=1e-6, + magnification=10, + output_region=(0, 0, 128, 128), + ) + scatterer_A = PointParticle( + intensity=100, + position_unit="pixel", + position=lambda: np.random.rand(2) * 128, + ) + scatterer_B = PointParticle( + intensity=10, + position_unit="pixel", + position=lambda: np.random.rand(2) * 128, + ) + imaged_scatterer_A = optics(scatterer_A) + imaged_scatterer_B = optics(scatterer_B) + + def get_particle_position(result): + result = result[0] + for property in result.properties: + if "position" in property: + return property["position"] + + imaged_scatterers = imaged_scatterer_A & imaged_scatterer_B + + generator = generators.ContinuousGenerator( + imaged_scatterers, + get_particle_position, + batch_size=8, + min_data_size=10, + max_data_size=20, + use_multi_inputs=True, + ) + + with generator: + data, _ = generator[0] + self.assertEqual(data[0].shape, (8, 128, 128, 1)) + self.assertEqual(data[1].shape, (8, 128, 128, 1)) def test_CappedContinuousGenerator(self):