From f13070288b9088b16b35e0a378c237e811ffc3bc Mon Sep 17 00:00:00 2001 From: Matt McCormick Date: Thu, 12 Jan 2023 22:41:20 -0500 Subject: [PATCH] BUG: transform_to_displacement_field_filter type inference Customize implementation since we do not have a image arg that can used to infer the filter type. Output type is itk.Imgae[itk.Vector[itk.F, output_dim], output_dim] because that is what is wrapped. Closes #3860 --- .../wrapping/test/CMakeLists.txt | 2 ++ .../itkTransformToDisplacementFieldTest.py | 29 +++++++++++++++++ .../Generators/SwigInterface/igenerator.py | 31 +++++++++++++++++-- 3 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 Modules/Filtering/DisplacementField/wrapping/test/itkTransformToDisplacementFieldTest.py diff --git a/Modules/Filtering/DisplacementField/wrapping/test/CMakeLists.txt b/Modules/Filtering/DisplacementField/wrapping/test/CMakeLists.txt index dd1b938e977..2d1e7a97a80 100644 --- a/Modules/Filtering/DisplacementField/wrapping/test/CMakeLists.txt +++ b/Modules/Filtering/DisplacementField/wrapping/test/CMakeLists.txt @@ -1,4 +1,6 @@ if(ITK_WRAP_PYTHON) itk_python_add_test(NAME itkDisplacementFieldTransformPythonTest COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/itkDisplacementFieldTransformTest.py) + itk_python_add_test(NAME itkTransformToDisplacementFieldPythonTest + COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/itkTransformToDisplacementFieldTest.py) endif() diff --git a/Modules/Filtering/DisplacementField/wrapping/test/itkTransformToDisplacementFieldTest.py b/Modules/Filtering/DisplacementField/wrapping/test/itkTransformToDisplacementFieldTest.py new file mode 100644 index 00000000000..916bbc7a0ca --- /dev/null +++ b/Modules/Filtering/DisplacementField/wrapping/test/itkTransformToDisplacementFieldTest.py @@ -0,0 +1,29 @@ +# ========================================================================== +# +# Copyright NumFOCUS +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ==========================================================================*/ +import numpy as np +import itk + +itk.auto_progress(2) + +img = itk.image_from_array(np.zeros((10, 10, 10))) +transform = itk.IdentityTransform[itk.D, 3].New() +field = itk.transform_to_displacement_field_filter( + transform, + reference_image=img, + use_reference_image=True) +print(field) diff --git a/Wrapping/Generators/SwigInterface/igenerator.py b/Wrapping/Generators/SwigInterface/igenerator.py index 988a2f32333..62d293c0cba 100755 --- a/Wrapping/Generators/SwigInterface/igenerator.py +++ b/Wrapping/Generators/SwigInterface/igenerator.py @@ -1210,6 +1210,34 @@ def generate_process_object_snake_case_functions(self, typedefs): elif any([b.startswith("PathSource") for b in bases]): return_typehint = "-> itkt.PathSourceReturn" + instantiation = f""" + + instance = itk.{process_object}.New(*args, **kwargs) +""" + if snake_case == 'transform_to_displacement_field_filter': + instantiation = f""" + transform = None + if len(args): + transform = args[0] + elif 'transform' in kwargs: + transform = kwargs.pop('transform') + elif 'transform_input' in kwargs: + transform = kwargs.pop('transform_input') + else: + raise ValueError('A transform is required. Pass as the first argument.') + + input_dim = transform.GetInputSpaceDimension() + output_dim = transform.GetOutputSpaceDimension() + ParametersType = itk.template(transform)[1][0] + + decorator = itk.DataObjectDecorator[itk.Transform[ParametersType, input_dim, output_dim]].New() + decorator.Set(transform) + + FieldType = itk.Image[itk.Vector[itk.F, output_dim], output_dim] + + args = (decorator,) + instance = itk.TransformToDisplacementFieldFilter[FieldType, ParametersType].New(*args, **kwargs) +""" # print(args_typehint, kwargs_typehints, return_typehint) self.outputFile.write( f"""from itk.support import helpers @@ -1224,8 +1252,7 @@ def {snake_case}(*args{args_typehint}, {kwargs_typehints}**kwargs){return_typehi kwarg_typehints = {{ {kwarg_dict} }} specified_kwarg_typehints = {{ k:v for (k,v) in kwarg_typehints.items() if kwarg_typehints[k] is not ... }} kwargs.update(specified_kwarg_typehints) - - instance = itk.{process_object}.New(*args, **kwargs) +{instantiation} return instance.__internal_call__() def {snake_case}_init_docstring():