Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MATH-1558] Fix MidPointIntegrator incremental implementation #161

Merged
merged 1 commit into from Oct 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -25,7 +25,7 @@
import org.apache.commons.math4.util.FastMath;

/**
* Implements the <a href="http://en.wikipedia.org/wiki/Midpoint_method">
* Implements the <a href="https://en.wikipedia.org/wiki/Riemann_sum#Midpoint_rule">
* Midpoint Rule</a> for integration of real univariate functions. For
* reference, see <b>Numerical Mathematics</b>, ISBN 0387989595,
* chapter 9.2.
Expand All @@ -36,8 +36,10 @@
*/
public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator {

/** Maximum number of iterations for midpoint. */
private static final int MIDPOINT_MAX_ITERATIONS_COUNT = 63;
/** Maximum number of iterations for midpoint. 39 = floor(log_3(2^63)), the
* maximum number of triplings allowed before exceeding 64-bit bounds.
*/
private static final int MIDPOINT_MAX_ITERATIONS_COUNT = 39;

/**
* Build a midpoint integrator with given accuracies and iterations counts.
Expand All @@ -50,7 +52,7 @@ public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator {
* @exception NumberIsTooSmallException if maximal number of iterations
* is lesser than or equal to the minimal number of iterations
* @exception NumberIsTooLargeException if maximal number of iterations
* is greater than 63.
* is greater than 39.
*/
public MidPointIntegrator(final double relativeAccuracy,
final double absoluteAccuracy,
Expand All @@ -73,7 +75,7 @@ public MidPointIntegrator(final double relativeAccuracy,
* @exception NumberIsTooSmallException if maximal number of iterations
* is lesser than or equal to the minimal number of iterations
* @exception NumberIsTooLargeException if maximal number of iterations
* is greater than 63.
* is greater than 39.
*/
public MidPointIntegrator(final int minimalIterationCount,
final int maximalIterationCount)
Expand All @@ -98,11 +100,11 @@ public MidPointIntegrator() {
* This function should only be called by API <code>integrate()</code> in the package.
* To save time it does not verify arguments - caller does.
* <p>
* The interval is divided equally into 2^n sections rather than an
* The interval is divided equally into 3^n sections rather than an
* arbitrary m sections because this configuration can best utilize the
* already computed values.</p>
*
* @param n the stage of 1/2 refinement. Must be larger than 0.
* @param n the stage of 1/3 refinement. Must be larger than 0.
* @param previousStageResult Result from the previous call to the
* {@code stage} method.
* @param min Lower bound of the integration interval.
Expand All @@ -118,21 +120,29 @@ private double stage(final int n,
double diffMaxMin)
throws TooManyEvaluationsException {

// number of new points in this stage
final long np = 1L << (n - 1);
// number of points in the previous stage. This stage will contribute
// 2*3^{n-1} more points.
final long np = (long) FastMath.pow(3, n - 1);
double sum = 0;

// spacing between adjacent new points
final double spacing = diffMaxMin / np;
final double leftOffset = spacing / 6;
final double rightOffset = 5 * leftOffset;

// the first new point
double x = min + 0.5 * spacing;
double x = min;
for (long i = 0; i < np; i++) {
sum += computeObjectiveValue(x);
// The first and second new points are located at the new midpoints
// generated when each previous integration slice is split into 3.
//
// |--------x--------|
// |--x--|--x--|--x–-|
sum += computeObjectiveValue(x + leftOffset);
sum += computeObjectiveValue(x + rightOffset);
x += spacing;
}
// add the new sum to previously calculated result
return 0.5 * (previousStageResult + sum * spacing);
return (previousStageResult + sum * spacing) / 3.0;
}


Expand Down
Expand Up @@ -35,6 +35,25 @@
public final class MidPointIntegratorTest {
private static final int NUM_ITER = 30;

/**
* The initial iteration contributes 1 evaluation. Each successive iteration
* contributes 2 points to each previous slice.
*
* The total evaluation count == 1 + 2*3^0 + 2*3^1 + ... 2*3^n
*
* the series 3^0 + 3^1 + ... + 3^n sums to 3^(n-1) / (3-1), so the total
* expected evaluations == 1 + 2*(3^(n-1) - 1)/2 == 3^(n-1).
*
* The n in the series above is offset by 1 from the MidPointIntegrator
* iteration count so the actual result == 3^n.
*
* Without the incremental implementation, the same result would require
* (3^(n + 1) - 1) / 2 evaluations; just under 50% more.
*/
private long expectedEvaluations(int iterations) {
return (long) FastMath.pow(3, iterations);
}

/**
* Test of integrator for the sine function.
*/
Expand All @@ -48,8 +67,9 @@ public void testLowAccuracy() {
double expected = -3697001.0 / 48.0;
double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test might be redundant now that we know exactly how many evaluations we should have, as a function of iterations.

Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);

}
Expand All @@ -67,17 +87,19 @@ public void testSinFunction() {
double expected = 2;
double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);

min = -FastMath.PI/3;
max = 0;
expected = -0.5;
tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);

}
Expand All @@ -95,16 +117,17 @@ public void testQuinticFunction() {
double expected = -1.0 / 48;
double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);

min = 0;
max = 0.5;
expected = 11.0 / 768;
tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expected, result, tolerance);

Expand All @@ -113,8 +136,9 @@ public void testQuinticFunction() {
expected = 2048 / 3.0 - 78 + 1.0 / 48;
tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
Assert.assertTrue(integrator.getIterations() < NUM_ITER);
Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations());
Assert.assertEquals(expected, result, tolerance);

}
Expand Down