From 8a648f178fd9f03cf5e06ba388c3db5f9c8c0738 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Tue, 29 Jul 2025 15:41:31 +0200 Subject: [PATCH 1/5] adding torch compatibility --- deeptrack/features.py | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index a4bf2c709..6398e07b3 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -218,7 +218,7 @@ def propagate_data_to_dependencies( "OneOfDict", "LoadImage", # TODO ***MG*** "SampleToMasks", # TODO ***MG*** - "AsType", # TODO ***MG*** + "AsType", "ChannelFirst2d", "Upscale", # TODO ***AL*** "NonOverlapping", # TODO ***AL*** @@ -7751,9 +7751,9 @@ def _process_and_get( class AsType(Feature): """Convert the data type of images. - This feature changes the data type (`dtype`) of input images to a specified - type. The accepted types are the same as those used by NumPy arrays, such - as `float64`, `int32`, `uint16`, `int16`, `uint8`, and `int8`. + This feature changes the data type (`dtype`) of input images to a specified + type. The accepted types are standard NumPy or PyTorch data types (e.g., + 'float64', 'int32', `uint8`, `int8`, and 'torch.float32'). Parameters ---------- @@ -7833,7 +7833,39 @@ def get( """ - return image.astype(dtype) + if apc.is_torch_array(image): + # Mapping from string to torch dtype + torch_dtypes = { + "float64": torch.float64, + "double": torch.float64, + "float32": torch.float32, + "float": torch.float32, + "float16": torch.float16, + "half": torch.float16, + "int64": torch.int64, + "int32": torch.int32, + "int16": torch.int16, + "int8": torch.int8, + "uint8": torch.uint8, + "bool": torch.bool, + "complex64": torch.complex64, + "complex128": torch.complex128, + } + + # Ensure 'torch.float32' and 'float32' are treated the same by + # normalizing the string + dtype_str = str(dtype).replace("torch.", "") + torch_dtype = torch_dtypes.get(dtype_str) + + if torch_dtype is None: + raise ValueError( + f"Unsupported dtype for torch.Tensor: {dtype}" + ) + + return image.to(dtype=torch_dtype) + + else: + return image.astype(dtype) class ChannelFirst2d(Feature): # DEPRECATED From dcc0f55d229f06ae8fc93e6955197f8dff1fd66f Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Tue, 29 Jul 2025 15:53:32 +0200 Subject: [PATCH 2/5] added unittesting for torch --- deeptrack/tests/test_features.py | 45 +++++++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/deeptrack/tests/test_features.py b/deeptrack/tests/test_features.py index 591547a68..b3a2252ec 100644 --- a/deeptrack/tests/test_features.py +++ b/deeptrack/tests/test_features.py @@ -1949,11 +1949,48 @@ def test_AsType(self): np.all(output_image == np.array([1, 2, 3], dtype=dtype)) ) - # Test for Image. - #TODO - # Test for PyTorch tensors. - #TODO + if TORCH_AVAILABLE: + input_image_torch = torch.tensor([1.5, 2.5, 3.5]) + + data_types_torch = [ + "float64", + "int32", + "int16", + "uint8", + "int8", + "torch.float64", + "torch.int32", + ] + + torch_dtypes_map = { + "float64": torch.float64, + "int32": torch.int32, + "int16": torch.int16, + "uint8": torch.uint8, + "int8": torch.int8, + "torch.float64": torch.float64, + "torch.int32": torch.int32, + } + + for dtype in data_types_torch: + astype_feature = features.AsType(dtype=dtype) + output_image = astype_feature.get( + input_image_torch, dtype=dtype + ) + expected_dtype = torch_dtypes_map[dtype] + self.assertEqual(output_image.dtype, expected_dtype) + + # Additional check for specific behavior of integers. + if expected_dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.uint8, + ]: + # Verify that fractional parts are truncated + expected = torch.tensor([1, 2, 3], dtype=expected_dtype) + self.assertTrue(torch.equal(output_image, expected)) def test_ChannelFirst2d(self): From a236ae381280af7e5359e45ac5e0fc351f384cca Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Wed, 30 Jul 2025 11:12:27 +0200 Subject: [PATCH 3/5] minor change --- deeptrack/tests/test_features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptrack/tests/test_features.py b/deeptrack/tests/test_features.py index b3a2252ec..d46aa101e 100644 --- a/deeptrack/tests/test_features.py +++ b/deeptrack/tests/test_features.py @@ -1949,7 +1949,7 @@ def test_AsType(self): np.all(output_image == np.array([1, 2, 3], dtype=dtype)) ) - # Test for PyTorch tensors. + ### Test with PyTorch tensor (if available) if TORCH_AVAILABLE: input_image_torch = torch.tensor([1.5, 2.5, 3.5]) From fc638dc9c9d7e26ce9710787c189323f9de3f276 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Tue, 12 Aug 2025 13:56:58 +0200 Subject: [PATCH 4/5] update astype --- deeptrack/features.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index 6398e07b3..6bc3fcfc5 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -7753,7 +7753,7 @@ class AsType(Feature): This feature changes the data type (`dtype`) of input images to a specified type. The accepted types are standard NumPy or PyTorch data types (e.g., - 'float64', 'int32', `uint8`, `int8`, and 'torch.float32'). + `"float64"`, `"int32"`, `"uint8"`, `"int8"`, and `"torch.float32"`). Parameters ---------- @@ -7776,7 +7776,7 @@ class AsType(Feature): >>> >>> input_image = np.array([1.5, 2.5, 3.5]) - Apply an AsType feature to convert to `int32`: + Apply an AsType feature to convert to "`int32"`: >>> astype_feature = dt.AsType(dtype="int32") >>> output_image = astype_feature.get(input_image, dtype="int32") >>> output_image @@ -7852,8 +7852,8 @@ def get( "complex128": torch.complex128, } - # Ensure 'torch.float32' and 'float32' are treated the same by - # normalizing the string + # Ensure `"torch.float32"` and `"float32"` are treated the same by + # removing the `torch.` prefix if present dtype_str = str(dtype).replace("torch.", "") torch_dtype = torch_dtypes.get(dtype_str) From c54b6868323049cb05d3ae09027795d43dbba2be Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Mon, 8 Sep 2025 09:36:00 +0200 Subject: [PATCH 5/5] update astype --- deeptrack/features.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index 6bc3fcfc5..29232eecc 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -7751,7 +7751,7 @@ def _process_and_get( class AsType(Feature): """Convert the data type of images. - This feature changes the data type (`dtype`) of input images to a specified + `Astype` changes the data type (`dtype`) of input images to a specified type. The accepted types are standard NumPy or PyTorch data types (e.g., `"float64"`, `"int32"`, `"uint8"`, `"int8"`, and `"torch.float32"`). @@ -7793,8 +7793,7 @@ def __init__( dtype: PropertyLike[str] = "float64", **kwargs: Any, ): - """ - Initialize the AsType feature. + """Initialize the AsType feature. Parameters ----------