Skip to content

Commit

Permalink
BUG: transform_to_displacement_field_filter type inference
Browse files Browse the repository at this point in the history
Customize implementation since we do not have a image arg that can used
to infer the filter type.

Output type is itk.Imgae[itk.Vector[itk.F, output_dim], output_dim]
because that is what is wrapped.

Closes #3860
  • Loading branch information
thewtex committed Jan 13, 2023
1 parent f187319 commit f130702
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 2 deletions.
@@ -1,4 +1,6 @@
if(ITK_WRAP_PYTHON)
itk_python_add_test(NAME itkDisplacementFieldTransformPythonTest
COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/itkDisplacementFieldTransformTest.py)
itk_python_add_test(NAME itkTransformToDisplacementFieldPythonTest
COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/itkTransformToDisplacementFieldTest.py)
endif()
@@ -0,0 +1,29 @@
# ==========================================================================
#
# Copyright NumFOCUS
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==========================================================================*/
import numpy as np
import itk

itk.auto_progress(2)

img = itk.image_from_array(np.zeros((10, 10, 10)))
transform = itk.IdentityTransform[itk.D, 3].New()
field = itk.transform_to_displacement_field_filter(
transform,
reference_image=img,
use_reference_image=True)
print(field)
31 changes: 29 additions & 2 deletions Wrapping/Generators/SwigInterface/igenerator.py
Expand Up @@ -1210,6 +1210,34 @@ def generate_process_object_snake_case_functions(self, typedefs):
elif any([b.startswith("PathSource") for b in bases]):
return_typehint = "-> itkt.PathSourceReturn"

instantiation = f"""
instance = itk.{process_object}.New(*args, **kwargs)
"""
if snake_case == 'transform_to_displacement_field_filter':
instantiation = f"""
transform = None
if len(args):
transform = args[0]
elif 'transform' in kwargs:
transform = kwargs.pop('transform')
elif 'transform_input' in kwargs:
transform = kwargs.pop('transform_input')
else:
raise ValueError('A transform is required. Pass as the first argument.')
input_dim = transform.GetInputSpaceDimension()
output_dim = transform.GetOutputSpaceDimension()
ParametersType = itk.template(transform)[1][0]
decorator = itk.DataObjectDecorator[itk.Transform[ParametersType, input_dim, output_dim]].New()
decorator.Set(transform)
FieldType = itk.Image[itk.Vector[itk.F, output_dim], output_dim]
args = (decorator,)
instance = itk.TransformToDisplacementFieldFilter[FieldType, ParametersType].New(*args, **kwargs)
"""
# print(args_typehint, kwargs_typehints, return_typehint)
self.outputFile.write(
f"""from itk.support import helpers
Expand All @@ -1224,8 +1252,7 @@ def {snake_case}(*args{args_typehint}, {kwargs_typehints}**kwargs){return_typehi
kwarg_typehints = {{ {kwarg_dict} }}
specified_kwarg_typehints = {{ k:v for (k,v) in kwarg_typehints.items() if kwarg_typehints[k] is not ... }}
kwargs.update(specified_kwarg_typehints)
instance = itk.{process_object}.New(*args, **kwargs)
{instantiation}
return instance.__internal_call__()
def {snake_case}_init_docstring():
Expand Down

0 comments on commit f130702

Please sign in to comment.