From e0b2efc2acc999ffe7df6dd6d2799de294dd811f Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Tue, 20 Oct 2020 16:24:29 -0600 Subject: [PATCH] MATH-1558: Fix MidPointIntegrator incremental implementation --- .../integration/MidPointIntegrator.java | 36 ++++++++++++------- .../integration/MidPointIntegratorTest.java | 36 +++++++++++++++---- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/apache/commons/math4/analysis/integration/MidPointIntegrator.java b/src/main/java/org/apache/commons/math4/analysis/integration/MidPointIntegrator.java index bb8b76390a..de252a4a3e 100644 --- a/src/main/java/org/apache/commons/math4/analysis/integration/MidPointIntegrator.java +++ b/src/main/java/org/apache/commons/math4/analysis/integration/MidPointIntegrator.java @@ -25,7 +25,7 @@ import org.apache.commons.math4.util.FastMath; /** - * Implements the + * Implements the * Midpoint Rule for integration of real univariate functions. For * reference, see Numerical Mathematics, ISBN 0387989595, * chapter 9.2. @@ -36,8 +36,10 @@ */ public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator { - /** Maximum number of iterations for midpoint. */ - private static final int MIDPOINT_MAX_ITERATIONS_COUNT = 63; + /** Maximum number of iterations for midpoint. 39 = floor(log_3(2^63)), the + * maximum number of triplings allowed before exceeding 64-bit bounds. + */ + private static final int MIDPOINT_MAX_ITERATIONS_COUNT = 39; /** * Build a midpoint integrator with given accuracies and iterations counts. @@ -50,7 +52,7 @@ public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator { * @exception NumberIsTooSmallException if maximal number of iterations * is lesser than or equal to the minimal number of iterations * @exception NumberIsTooLargeException if maximal number of iterations - * is greater than 63. + * is greater than 39. */ public MidPointIntegrator(final double relativeAccuracy, final double absoluteAccuracy, @@ -73,7 +75,7 @@ public MidPointIntegrator(final double relativeAccuracy, * @exception NumberIsTooSmallException if maximal number of iterations * is lesser than or equal to the minimal number of iterations * @exception NumberIsTooLargeException if maximal number of iterations - * is greater than 63. + * is greater than 39. */ public MidPointIntegrator(final int minimalIterationCount, final int maximalIterationCount) @@ -98,11 +100,11 @@ public MidPointIntegrator() { * This function should only be called by API integrate() in the package. * To save time it does not verify arguments - caller does. *

- * The interval is divided equally into 2^n sections rather than an + * The interval is divided equally into 3^n sections rather than an * arbitrary m sections because this configuration can best utilize the * already computed values.

* - * @param n the stage of 1/2 refinement. Must be larger than 0. + * @param n the stage of 1/3 refinement. Must be larger than 0. * @param previousStageResult Result from the previous call to the * {@code stage} method. * @param min Lower bound of the integration interval. @@ -118,21 +120,29 @@ private double stage(final int n, double diffMaxMin) throws TooManyEvaluationsException { - // number of new points in this stage - final long np = 1L << (n - 1); + // number of points in the previous stage. This stage will contribute + // 2*3^{n-1} more points. + final long np = (long) FastMath.pow(3, n - 1); double sum = 0; // spacing between adjacent new points final double spacing = diffMaxMin / np; + final double leftOffset = spacing / 6; + final double rightOffset = 5 * leftOffset; - // the first new point - double x = min + 0.5 * spacing; + double x = min; for (long i = 0; i < np; i++) { - sum += computeObjectiveValue(x); + // The first and second new points are located at the new midpoints + // generated when each previous integration slice is split into 3. + // + // |--------x--------| + // |--x--|--x--|--x–-| + sum += computeObjectiveValue(x + leftOffset); + sum += computeObjectiveValue(x + rightOffset); x += spacing; } // add the new sum to previously calculated result - return 0.5 * (previousStageResult + sum * spacing); + return (previousStageResult + sum * spacing) / 3.0; } diff --git a/src/test/java/org/apache/commons/math4/analysis/integration/MidPointIntegratorTest.java b/src/test/java/org/apache/commons/math4/analysis/integration/MidPointIntegratorTest.java index 0474d27bdc..1d227dda64 100644 --- a/src/test/java/org/apache/commons/math4/analysis/integration/MidPointIntegratorTest.java +++ b/src/test/java/org/apache/commons/math4/analysis/integration/MidPointIntegratorTest.java @@ -35,6 +35,25 @@ public final class MidPointIntegratorTest { private static final int NUM_ITER = 30; + /** + * The initial iteration contributes 1 evaluation. Each successive iteration + * contributes 2 points to each previous slice. + * + * The total evaluation count == 1 + 2*3^0 + 2*3^1 + ... 2*3^n + * + * the series 3^0 + 3^1 + ... + 3^n sums to 3^(n-1) / (3-1), so the total + * expected evaluations == 1 + 2*(3^(n-1) - 1)/2 == 3^(n-1). + * + * The n in the series above is offset by 1 from the MidPointIntegrator + * iteration count so the actual result == 3^n. + * + * Without the incremental implementation, the same result would require + * (3^(n + 1) - 1) / 2 evaluations; just under 50% more. + */ + private long expectedEvaluations(int iterations) { + return (long) FastMath.pow(3, iterations); + } + /** * Test of integrator for the sine function. */ @@ -48,8 +67,9 @@ public void testLowAccuracy() { double expected = -3697001.0 / 48.0; double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy()); double result = integrator.integrate(Integer.MAX_VALUE, f, min, max); - Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2); + Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3); Assert.assertTrue(integrator.getIterations() < NUM_ITER); + Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations()); Assert.assertEquals(expected, result, tolerance); } @@ -67,8 +87,9 @@ public void testSinFunction() { double expected = 2; double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy()); double result = integrator.integrate(Integer.MAX_VALUE, f, min, max); - Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2); + Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3); Assert.assertTrue(integrator.getIterations() < NUM_ITER); + Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations()); Assert.assertEquals(expected, result, tolerance); min = -FastMath.PI/3; @@ -76,8 +97,9 @@ public void testSinFunction() { expected = -0.5; tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy()); result = integrator.integrate(Integer.MAX_VALUE, f, min, max); - Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2); + Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3); Assert.assertTrue(integrator.getIterations() < NUM_ITER); + Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations()); Assert.assertEquals(expected, result, tolerance); } @@ -95,8 +117,9 @@ public void testQuinticFunction() { double expected = -1.0 / 48; double tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy()); double result = integrator.integrate(Integer.MAX_VALUE, f, min, max); - Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2); + Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3); Assert.assertTrue(integrator.getIterations() < NUM_ITER); + Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations()); Assert.assertEquals(expected, result, tolerance); min = 0; @@ -104,7 +127,7 @@ public void testQuinticFunction() { expected = 11.0 / 768; tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy()); result = integrator.integrate(Integer.MAX_VALUE, f, min, max); - Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2); + Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3); Assert.assertTrue(integrator.getIterations() < NUM_ITER); Assert.assertEquals(expected, result, tolerance); @@ -113,8 +136,9 @@ public void testQuinticFunction() { expected = 2048 / 3.0 - 78 + 1.0 / 48; tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy()); result = integrator.integrate(Integer.MAX_VALUE, f, min, max); - Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2); + Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3); Assert.assertTrue(integrator.getIterations() < NUM_ITER); + Assert.assertEquals(expectedEvaluations(integrator.getIterations()), integrator.getEvaluations()); Assert.assertEquals(expected, result, tolerance); }