Skip to content

Commit

Permalink
FIX: compose multi type transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 23, 2024
1 parent 34536fc commit a930d72
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/antsTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,19 +344,32 @@ AntsImage<ImageType> transformImage( AntsTransform<TransformType> & myTx,
}


// need nb:list instead of std::vector<AntsTransform<TransformBaseType>> in order
// to support a mix of standard itk transform and itk displacementfieldtransform types
template <typename TransformBaseType, typename PrecisionType, unsigned int Dimension>
AntsTransform<TransformBaseType> composeTransforms( std::vector<AntsTransform<TransformBaseType>> tformlist,
std::string precision, unsigned int dimension)
AntsTransform<TransformBaseType> composeTransforms( nb::list tformlist,
std::string precision,
unsigned int dimension)
{
typedef typename itk::DisplacementFieldTransform<PrecisionType, Dimension> DisplacementTransformType;
typedef typename DisplacementTransformType::Pointer DisplacementTransformPointerType;
typedef typename TransformBaseType::Pointer TransformBasePointerType;
typedef typename itk::CompositeTransform<PrecisionType, Dimension> CompositeTransformType;

typename CompositeTransformType::Pointer comp_transform = CompositeTransformType::New();

for ( unsigned int i = 0; i < tformlist.size(); i++ )
{
TransformBasePointerType t = tformlist[i].ptr;
comp_transform->AddTransform( t );
for ( nb::handle_t<AntsTransform<TransformBaseType>> h: tformlist )
{
PyObject * a_py = h.ptr();
AntsTransform<TransformBaseType> mytx;
bool res = nb::try_cast<AntsTransform<TransformBaseType> &>(h, mytx);
if (res == false) {
// failed cast means its a displacement field transform
AntsTransform<DisplacementTransformType> &mytx = nb::cast<AntsTransform<DisplacementTransformType> &>(h);
comp_transform->AddTransform( mytx.ptr );
} else {
comp_transform->AddTransform( mytx.ptr );
}
}
AntsTransform<TransformBaseType> outTransform = { comp_transform.GetPointer() };
return outTransform;
Expand Down
17 changes: 17 additions & 0 deletions tests/test_bugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ def test_resample_returns_NaNs(self):

self.assertTrue(np.sum(np.isnan(img3dr.numpy())) == 0)

def test_compose_multi_type_transforms(self):
image = ants.image_read(ants.get_ants_data("r16"))

linear_transform = ants.create_ants_transform(transform_type=
"AffineTransform", precision='float', dimension=image.dimension)

displacement_field = ants.simulate_displacement_field(image,
field_type="bspline", number_of_random_points=1000,
sd_noise=10.0, enforce_stationary_boundary=True,
number_of_fitting_levels=4, mesh_size=1,
sd_smoothing=4.0)
displacement_field_xfrm = ants.transform_from_displacement_field(displacement_field)

xfrm = ants.compose_ants_transforms([linear_transform, displacement_field_xfrm])
xfrm = ants.compose_ants_transforms([linear_transform, linear_transform])
xfrm = ants.compose_ants_transforms([displacement_field_xfrm, linear_transform])


if __name__ == '__main__':
run_tests()

0 comments on commit a930d72

Please sign in to comment.