Skip to content

Commit 4f3876f

Browse files
authoredNov 14, 2022
Merge pull request #3728 from PranjalSahu/distancethreshold
ENH: Add DistanceThreshold parameter in EuclideanDistance Metricv4
2 parents 5e726de + c510229 commit 4f3876f

4 files changed

+308
-3
lines changed
 

‎Modules/Registration/Metricsv4/include/itkEuclideanDistancePointSetToPointSetMetricv4.h

+17
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,20 @@ class ITK_TEMPLATE_EXPORT EuclideanDistancePointSetToPointSetMetricv4
6767
using typename Superclass::PixelType;
6868
using typename Superclass::PointIdentifier;
6969

70+
using RealType = MeasureType;
71+
/**
72+
* Distance threshold to be used to calculate the metric value.
73+
* Only point pairs that have distance lesser than this threshold
74+
* contribute to the metric. Default is -1 to include all the pairs.
75+
*/
76+
itkSetMacro(DistanceThreshold, RealType);
77+
78+
/**
79+
* Get the Distance threshold to be used to calculate the metric value
80+
* Default = -1.
81+
*/
82+
itkGetConstMacro(DistanceThreshold, RealType);
83+
7084
/**
7185
* Calculates the local metric value for a single point.
7286
*/
@@ -95,6 +109,9 @@ class ITK_TEMPLATE_EXPORT EuclideanDistancePointSetToPointSetMetricv4
95109
/** PrintSelf function */
96110
void
97111
PrintSelf(std::ostream & os, Indent indent) const override;
112+
113+
private:
114+
RealType m_DistanceThreshold = -1.0;
98115
};
99116
} // end namespace itk
100117

‎Modules/Registration/Metricsv4/include/itkEuclideanDistancePointSetToPointSetMetricv4.hxx

+22-3
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,14 @@ typename EuclideanDistancePointSetToPointSetMetricv4<TFixedPointSet, TMovingPoin
3535
closestPoint = this->m_MovingTransformedPointSet->GetPoint(pointId);
3636

3737
const MeasureType distance = point.EuclideanDistanceTo(closestPoint);
38-
return distance;
38+
if (this->m_DistanceThreshold <= 0 || distance < this->m_DistanceThreshold)
39+
{
40+
return distance;
41+
}
42+
else
43+
{
44+
return 0;
45+
}
3946
}
4047

4148
template <typename TFixedPointSet, typename TMovingPointSet, class TInternalComputationValueType>
@@ -52,8 +59,20 @@ EuclideanDistancePointSetToPointSetMetricv4<TFixedPointSet, TMovingPointSet, TIn
5259
PointIdentifier pointId = this->m_MovingTransformedPointsLocator->FindClosestPoint(point);
5360
closestPoint = this->m_MovingTransformedPointSet->GetPoint(pointId);
5461

55-
measure = point.EuclideanDistanceTo(closestPoint);
56-
localDerivative = closestPoint - point;
62+
auto distance = point.EuclideanDistanceTo(closestPoint);
63+
64+
if (this->m_DistanceThreshold <= 0 || distance < this->m_DistanceThreshold)
65+
{
66+
measure = distance;
67+
localDerivative = closestPoint - point;
68+
}
69+
else
70+
{
71+
// Skip the points that are beyond the threshold by making value and derivative as 0.
72+
measure = 0;
73+
closestPoint.Fill(0.0);
74+
localDerivative = closestPoint;
75+
}
5776
}
5877

5978
/** PrintSelf method */

‎Modules/Registration/Metricsv4/test/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(ITKMetricsv4Tests
2828
itkEuclideanDistancePointSetMetricRegistrationTest.cxx
2929
itkExpectationBasedPointSetMetricRegistrationTest.cxx
3030
itkEuclideanDistancePointSetMetricTest2.cxx
31+
itkEuclideanDistancePointSetMetricTest3.cxx
3132
itkObjectToObjectMultiMetricv4Test.cxx
3233
itkObjectToObjectMultiMetricv4RegistrationTest.cxx
3334
itkMeanSquaresImageToImageMetricv4SpeedTest.cxx
@@ -46,6 +47,9 @@ itk_add_test(NAME itkEuclideanDistancePointSetMetricTest
4647
itk_add_test(NAME itkEuclideanDistancePointSetMetricTest2
4748
COMMAND ITKMetricsv4TestDriver itkEuclideanDistancePointSetMetricTest2)
4849

50+
itk_add_test(NAME itkEuclideanDistancePointSetMetricTest3
51+
COMMAND ITKMetricsv4TestDriver itkEuclideanDistancePointSetMetricTest3)
52+
4953
itk_add_test(NAME itkExpectationBasedPointSetMetricTest
5054
COMMAND ITKMetricsv4TestDriver itkExpectationBasedPointSetMetricTest)
5155

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
/*=========================================================================
2+
*
3+
* Copyright NumFOCUS
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* https://www.apache.org/licenses/LICENSE-2.0.txt
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*=========================================================================*/
18+
19+
#include "itkEuclideanDistancePointSetToPointSetMetricv4.h"
20+
#include "itkTranslationTransform.h"
21+
#include "itkTestingMacros.h"
22+
23+
#include <fstream>
24+
#include "itkMath.h"
25+
26+
/*
27+
* Test with a translation transform
28+
*/
29+
30+
template <unsigned int Dimension>
31+
int
32+
itkEuclideanDistancePointSetMetricTest3Run(double distanceThreshold)
33+
{
34+
using PointSetType = itk::PointSet<float, Dimension>;
35+
using PointType = typename PointSetType::PointType;
36+
using IdentifierType = itk::IdentifierType;
37+
using PointsContainerType = itk::VectorContainer<IdentifierType, PointType>;
38+
using PointsLocatorType = itk::PointsLocator<PointsContainerType>;
39+
auto pointsLocator = PointsLocatorType::New();
40+
41+
auto fixedPoints = PointSetType::New();
42+
fixedPoints->Initialize();
43+
44+
auto movingPoints = PointSetType::New();
45+
movingPoints->Initialize();
46+
47+
// Create a few points and apply a small offset to make the moving points
48+
auto pointMax = static_cast<float>(1.0);
49+
PointType fixedPoint;
50+
fixedPoint.Fill(0.0);
51+
fixedPoint[0] = 0.0;
52+
fixedPoint[1] = 0.0;
53+
fixedPoints->SetPoint(0, fixedPoint);
54+
fixedPoint[0] = pointMax;
55+
fixedPoint[1] = 0.0;
56+
fixedPoints->SetPoint(1, fixedPoint);
57+
fixedPoint[0] = 0.0;
58+
fixedPoint[1] = pointMax;
59+
fixedPoints->SetPoint(2, fixedPoint);
60+
if (Dimension == 3)
61+
{
62+
fixedPoint[0] = 0.0;
63+
fixedPoint[1] = 0.0;
64+
fixedPoint[2] = pointMax;
65+
fixedPoints->SetPoint(3, fixedPoint);
66+
}
67+
unsigned int numberOfPoints = fixedPoints->GetNumberOfPoints();
68+
69+
PointType movingPoint;
70+
for (unsigned int n = 0; n < numberOfPoints; ++n)
71+
{
72+
movingPoint.Fill(0);
73+
fixedPoint = fixedPoints->GetPoint(n);
74+
if (n == 0)
75+
{
76+
movingPoint[0] = fixedPoint[0] + 0.5;
77+
movingPoint[1] = fixedPoint[1] + 0.75;
78+
}
79+
else if (n == 1)
80+
{
81+
movingPoint[0] = fixedPoint[0];
82+
movingPoint[1] = fixedPoint[1] + 0.25;
83+
}
84+
else if (n == 2)
85+
{
86+
movingPoint[0] = fixedPoint[0] + 0.25;
87+
movingPoint[1] = fixedPoint[1];
88+
}
89+
if (Dimension == 3)
90+
{
91+
movingPoint[2] = fixedPoint[2] + 0.75;
92+
}
93+
movingPoints->SetPoint(n, movingPoint);
94+
}
95+
96+
pointsLocator->SetPoints(movingPoints->GetPoints());
97+
pointsLocator->Initialize();
98+
99+
// Calculate distance between nearest points (correspondence points)
100+
std::vector<double> distanceArray;
101+
for (unsigned int n = 0; n < numberOfPoints; ++n)
102+
{
103+
auto tempFixedPoint = fixedPoints->GetPoint(n);
104+
auto pointId = pointsLocator->FindClosestPoint(tempFixedPoint);
105+
auto closestPoint = movingPoints->GetPoint(pointId);
106+
distanceArray.push_back(closestPoint.EuclideanDistanceTo(tempFixedPoint));
107+
std::cout << n << " " << tempFixedPoint << " , " << closestPoint << " "
108+
<< closestPoint.EuclideanDistanceTo(tempFixedPoint) << std::endl;
109+
}
110+
111+
// Test with Translation transform
112+
std::cout << "Testing with Translation Transform." << std::endl;
113+
using TranslationTransformType = itk::TranslationTransform<double, Dimension>;
114+
auto translationTransform = TranslationTransformType::New();
115+
116+
// Instantiate the metric
117+
using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4<PointSetType>;
118+
auto metric = PointSetMetricType::New();
119+
metric->SetFixedPointSet(fixedPoints);
120+
metric->SetMovingPointSet(movingPoints);
121+
metric->SetMovingTransform(translationTransform);
122+
metric->SetDistanceThreshold(distanceThreshold);
123+
124+
ITK_TEST_SET_GET_VALUE(distanceThreshold, metric->GetDistanceThreshold());
125+
126+
metric->Initialize();
127+
128+
// test
129+
typename PointSetMetricType::MeasureType value = metric->GetValue(), value2;
130+
typename PointSetMetricType::DerivativeType derivative, derivative2;
131+
metric->GetDerivative(derivative);
132+
metric->GetValueAndDerivative(value2, derivative2);
133+
134+
std::cout << "value: " << value << std::endl;
135+
136+
// Check for the same results from different methods
137+
if (itk::Math::NotExactlyEquals(value, value2))
138+
{
139+
std::cerr << "value does not match between calls to different methods: "
140+
<< "value: " << value << " value2: " << value2 << std::endl;
141+
}
142+
if (derivative != derivative2)
143+
{
144+
std::cerr << "derivative does not match between calls to different methods: "
145+
<< "derivative: " << derivative << " derivative2: " << derivative2 << std::endl;
146+
}
147+
148+
// Check if the points outside threshold are skipped in metric calculation
149+
double distanceSum = 0.0;
150+
for (unsigned int n = 0; n < numberOfPoints; ++n)
151+
{
152+
if (distanceThreshold <= 0 || distanceArray[n] < distanceThreshold)
153+
{
154+
distanceSum = distanceSum + distanceArray[n];
155+
}
156+
}
157+
158+
double valueTest = distanceSum / numberOfPoints;
159+
if (itk::Math::NotExactlyEquals(valueTest, value2))
160+
{
161+
std::cerr << "Value calculation is wrong when used threshold : " << distanceThreshold << "valueTest: " << valueTest
162+
<< " value2: " << value2 << std::endl;
163+
return EXIT_FAILURE;
164+
}
165+
166+
167+
// Check if the point outside threshold is skipped in derivative calculation
168+
typename PointSetMetricType::DerivativeType derivativeTest;
169+
derivativeTest.SetSize(Dimension);
170+
derivativeTest.Fill(0);
171+
for (unsigned int n = 0; n < numberOfPoints; ++n)
172+
{
173+
auto tempFixedPoint = fixedPoints->GetPoint(n);
174+
auto tempMovingPoint = movingPoints->GetPoint(n);
175+
auto tempDerivative = tempMovingPoint - tempFixedPoint;
176+
177+
if (distanceThreshold <= 0 || distanceArray[n] < distanceThreshold)
178+
{
179+
derivativeTest[0] = derivativeTest[0] + tempDerivative[0];
180+
derivativeTest[1] = derivativeTest[1] + tempDerivative[1];
181+
if (Dimension == 3)
182+
{
183+
derivativeTest[2] = derivativeTest[2] + tempDerivative[2];
184+
}
185+
}
186+
}
187+
188+
auto derivativeTestMean = derivativeTest / numberOfPoints;
189+
std::cout << "Derivative is [ " << derivativeTestMean << " ]" << std::endl;
190+
191+
for (unsigned int i = 0; i << Dimension; ++i)
192+
{
193+
if (itk::Math::NotExactlyEquals(derivativeTestMean[i], derivative2[i]))
194+
{
195+
std::cerr << "Derivative calculation is wrong when used threshold : " << distanceThreshold
196+
<< "derivativeTestMean: " << derivativeTestMean << " derivative2: " << derivative2 << std::endl;
197+
return EXIT_FAILURE;
198+
}
199+
}
200+
201+
return EXIT_SUCCESS;
202+
}
203+
204+
int
205+
itkEuclideanDistancePointSetMetricTest3(int, char *[])
206+
{
207+
int result = EXIT_SUCCESS;
208+
209+
double distanceThresholdPositive = 0.5;
210+
double distanceThresholdZero = 0.0;
211+
double distanceThresholdNegative = -8.0;
212+
213+
const unsigned int dimension2D = 2;
214+
const unsigned int dimension3D = 3;
215+
216+
// Test for positive distance threshold
217+
if (itkEuclideanDistancePointSetMetricTest3Run<dimension2D>(distanceThresholdPositive) == EXIT_FAILURE)
218+
{
219+
std::cerr << "Failed for Dimension " << dimension2D << " for distanceThrehold = " << distanceThresholdPositive
220+
<< std::endl;
221+
result = EXIT_FAILURE;
222+
}
223+
224+
// Test for zero distance threshold
225+
if (itkEuclideanDistancePointSetMetricTest3Run<dimension2D>(distanceThresholdZero) == EXIT_FAILURE)
226+
{
227+
std::cerr << "Failed for Dimension " << dimension2D << " for distanceThrehold = " << distanceThresholdZero
228+
<< std::endl;
229+
result = EXIT_FAILURE;
230+
}
231+
232+
// Test for negative distance threshold
233+
if (itkEuclideanDistancePointSetMetricTest3Run<dimension2D>(distanceThresholdNegative) == EXIT_FAILURE)
234+
{
235+
std::cerr << "Failed for Dimension " << dimension2D << " for distanceThrehold = " << distanceThresholdNegative
236+
<< std::endl;
237+
result = EXIT_FAILURE;
238+
}
239+
240+
// Test for positive distance threshold
241+
if (itkEuclideanDistancePointSetMetricTest3Run<dimension3D>(distanceThresholdPositive) == EXIT_FAILURE)
242+
{
243+
std::cerr << "Failed for Dimension " << dimension3D << " for distanceThrehold = " << distanceThresholdPositive
244+
<< std::endl;
245+
result = EXIT_FAILURE;
246+
}
247+
248+
// Test for zero distance threshold
249+
if (itkEuclideanDistancePointSetMetricTest3Run<dimension3D>(distanceThresholdZero) == EXIT_FAILURE)
250+
{
251+
std::cerr << "Failed for Dimension " << dimension3D << " for distanceThrehold = " << distanceThresholdZero
252+
<< std::endl;
253+
result = EXIT_FAILURE;
254+
}
255+
256+
// Test for negative distance threshold
257+
if (itkEuclideanDistancePointSetMetricTest3Run<dimension3D>(distanceThresholdNegative) == EXIT_FAILURE)
258+
{
259+
std::cerr << "Failed for Dimension " << dimension3D << " for distanceThrehold = " << distanceThresholdNegative
260+
<< std::endl;
261+
result = EXIT_FAILURE;
262+
}
263+
264+
return result;
265+
}

0 commit comments

Comments
 (0)
Please sign in to comment.