From f0968b46530f1254574d1fea5cf0269d0cbe2b7d Mon Sep 17 00:00:00 2001 From: aherbert Date: Fri, 8 Oct 2021 14:20:18 +0100 Subject: [PATCH] NUMBERS-167: Allow precomputation of regularized gamma arguments --- .../numbers/gamma/RegularizedGamma.java | 444 ++++++++++++++++-- .../numbers/gamma/RegularizedGammaTest.java | 171 +++++-- 2 files changed, 527 insertions(+), 88 deletions(-) diff --git a/commons-numbers-gamma/src/main/java/org/apache/commons/numbers/gamma/RegularizedGamma.java b/commons-numbers-gamma/src/main/java/org/apache/commons/numbers/gamma/RegularizedGamma.java index f607fb9b7..ab0d321d2 100644 --- a/commons-numbers-gamma/src/main/java/org/apache/commons/numbers/gamma/RegularizedGamma.java +++ b/commons-numbers-gamma/src/main/java/org/apache/commons/numbers/gamma/RegularizedGamma.java @@ -17,29 +17,175 @@ package org.apache.commons.numbers.gamma; import java.text.MessageFormat; - +import java.util.function.DoubleUnaryOperator; import org.apache.commons.numbers.fraction.ContinuedFraction; /** * * Regularized Gamma functions. - * - * Class is immutable. */ public final class RegularizedGamma { /** Maximum allowed numerical error. */ private static final double DEFAULT_EPSILON = 1e-15; + /** Maximum allowed iterations. */ + private static final int DEFAULT_ITERATIONS = Integer.MAX_VALUE; /** Private constructor. */ private RegularizedGamma() { // intentionally empty. } + /** + * Encapsulates values for argument {@code a} of the Regularized Gamma functions. + * + *

Class is immutable. + */ + public static final class ArgumentA { + /** Argument a. */ + private final double a; + /** logGamma(a). */ + private final double logGammaA; + + /** + * @param a Argument a + */ + private ArgumentA(double a) { + this.a = a; + this.logGammaA = LogGamma.value(a); + } + + /** + * Gets the value of the argument. + * + * @return a + */ + public double get() { + return a; + } + + /** + * Gets the value for logGamma(a). + * + *

Note: This method has an argument to allow it to be used as a method reference. + * It will not compute the log value on the input argument. + * + * @param ignore Value to ignore + * @return logGamma(a) + */ + double getLogGamma(double ignore) { + return logGammaA; + } + + /** + * Pre-compute values for argument {@code a} of the Regularized Gamma functions. + * + * @param a Argument a + * @return the argument + * @throws IllegalArgumentException if {@code a <= 0} or is NaN + */ + public static ArgumentA of(double a) { + if (invalid(a)) { + throw new IllegalArgumentException("Value is not strictly positive: " + a); + } + return new ArgumentA(a); + } + + /** + * Check if argument {@code a} is invalid. + * + * @param a Argument a + * @return true if {@code a <= 0} or is NaN + */ + static boolean invalid(double a) { + return Double.isNaN(a) || a <= 0; + } + } + + /** + * Encapsulates values for argument {@code x} of the Regularized Gamma functions. + * + *

Class is immutable. + */ + public static final class ArgumentX { + /** Argument x. */ + private final double x; + /** log(x). */ + private final double logX; + + /** + * @param x Argument x + */ + private ArgumentX(double x) { + this.x = x; + this.logX = Math.log(x); + } + + /** + * Gets the value of the argument. + * + * @return x + */ + public double get() { + return x; + } + + /** + * Gets the value for log(x). + * + *

Note: This method has an argument to allow it to be used as a method reference. + * It will not compute the log value on the input argument. + * + * @param ignore Value to ignore + * @return log(x) + */ + double getLog(double ignore) { + return logX; + } + + /** + * Pre-compute values for argument {@code x} of the Regularized Gamma functions. + * + * @param x Argument x + * @return the argument + * @throws IllegalArgumentException if {@code x < 0} or is NaN + */ + public static ArgumentX of(double x) { + if (invalid(x)) { + throw new IllegalArgumentException("Value is not positive: " + x); + } + return new ArgumentX(x); + } + + /** + * Check if argument {@code x} is invalid. + * + * @param x Argument x + * @return true if {@code x < 0} or is NaN + */ + static boolean invalid(double x) { + return Double.isNaN(x) || x < 0; + } + } + /** * \( P(a, x) \) * regularized Gamma function. * - * Class is immutable. + *

The implementation of this method is based on: + *

*/ public static final class P { /** Prevent instantiation. */ @@ -55,27 +201,45 @@ private P() {} */ public static double value(double a, double x) { - return value(a, x, DEFAULT_EPSILON, Integer.MAX_VALUE); + return value(a, x, DEFAULT_EPSILON, DEFAULT_ITERATIONS); + } + + /** + * Computes the regularized gamma function \( P(a, x) \). + * + *

This is a specialization of the function \( P(a, x) \) that allows pre-computation + * of values required for argument \( a \). + * + * @param a Argument. + * @param x Argument. + * @return \( P(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + * @see ArgumentA + */ + public static double value(ArgumentA a, + double x) { + return value(a, x, DEFAULT_EPSILON, DEFAULT_ITERATIONS); } /** * Computes the regularized gamma function \( P(a, x) \). * - * The implementation of this method is based on: - *

+ *

This is a specialization of the function \( P(a, x) \) that allows pre-computation + * of values required for argument \( x \). + * + * @param a Argument. + * @param x Argument. + * @return \( P(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + * @see ArgumentX + */ + public static double value(double a, + ArgumentX x) { + return value(a, x, DEFAULT_EPSILON, DEFAULT_ITERATIONS); + } + + /** + * Computes the regularized gamma function \( P(a, x) \). * * @param a Argument. * @param x Argument. @@ -88,16 +252,90 @@ public static double value(double a, double x, double epsilon, int maxIterations) { - if (Double.isNaN(a) || - Double.isNaN(x) || - a <= 0 || - x < 0) { + if (ArgumentA.invalid(a) || ArgumentX.invalid(x)) { + return Double.NaN; + } + return compute(a, x, epsilon, maxIterations, Math::log, LogGamma::value); + } + + /** + * Computes the regularized gamma function \( P(a, x) \). + * + *

This is a specialization of the function \( P(a, x) \) that allows pre-computation + * of values required for argument \( a \). + * + * @param a Argument. + * @param x Argument. + * @param epsilon Tolerance in continued fraction evaluation. + * @param maxIterations Maximum number of iterations in continued fraction evaluation. + * @return \( P(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + * @see ArgumentA + */ + public static double value(ArgumentA a, + double x, + double epsilon, + int maxIterations) { + if (ArgumentX.invalid(x)) { + return Double.NaN; + } + return compute(a.get(), x, epsilon, maxIterations, Math::log, a::getLogGamma); + } + + /** + * Computes the regularized gamma function \( P(a, x) \). + * + *

This is a specialization of the function \( P(a, x) \) that allows pre-computation + * of values required for argument \( x \). + * + * @param a Argument. + * @param x Argument. + * @param epsilon Tolerance in continued fraction evaluation. + * @param maxIterations Maximum number of iterations in continued fraction evaluation. + * @return \( P(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + * @see ArgumentX + */ + public static double value(double a, + ArgumentX x, + double epsilon, + int maxIterations) { + if (ArgumentA.invalid(a)) { return Double.NaN; - } else if (x == 0) { + } + return compute(a, x.get(), epsilon, maxIterations, x::getLog, LogGamma::value); + } + + /** + * Computes the regularized gamma function \( P(a, x) \). + * + *

Note: Assumes argument validation has been performed. + * + * @param a Argument ({@code a > 0}). + * @param x Argument ({@code a >= 0}). + * @param epsilon Tolerance in continued fraction evaluation. + * @param maxIterations Maximum number of iterations in continued fraction evaluation. + * @param fLogX Function to compute the log of x + * @param fLogGammaA Function to compute logGamma(a) + * @return \( P(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + */ + static double compute(double a, + double x, + double epsilon, + int maxIterations, + DoubleUnaryOperator fLogX, + DoubleUnaryOperator fLogGammaA) { + // Assume validation has been performed: + // a > 0 + // x >= 0 + // NaN is not allowed + + if (x == 0) { return 0; } else if (x >= a + 1) { // Q should converge faster in this case. - return 1 - RegularizedGamma.Q.value(a, x, epsilon, maxIterations); + return 1 - RegularizedGamma.Q.compute(a, x, epsilon, maxIterations, fLogX, fLogGammaA); } else { // Series. double n = 0; // current element index @@ -119,8 +357,10 @@ public static double value(double a, } else if (Double.isInfinite(sum)) { return 1; } else { + final double logX = fLogX.applyAsDouble(x); + final double logGammaA = fLogGammaA.applyAsDouble(a); + final double result = Math.exp(-x + (a * logX) - logGammaA) * sum; // Ensure result is in the range [0, 1] - final double result = Math.exp(-x + (a * Math.log(x)) - LogGamma.value(a)) * sum; return result > 1.0 ? 1.0 : result; } } @@ -131,7 +371,18 @@ public static double value(double a, * Creates the \( Q(a, x) \equiv 1 - P(a, x) \) * regularized Gamma function. * - * Class is immutable. + *

The implementation of this method is based on: + *

*/ public static final class Q { /** Prevent instantiation. */ @@ -147,24 +398,45 @@ private Q() {} */ public static double value(double a, double x) { - return value(a, x, DEFAULT_EPSILON, Integer.MAX_VALUE); + return value(a, x, DEFAULT_EPSILON, DEFAULT_ITERATIONS); + } + + /** + * Computes the regularized gamma function \( Q(a, x) = 1 - P(a, x) \). + * + *

This is a specialization of the function \( Q(a, x) \) that allows pre-computation + * of values required for argument \( a \). + * + * @param a Argument. + * @param x Argument. + * @return \( Q(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + * @see ArgumentA + */ + public static double value(ArgumentA a, + double x) { + return value(a, x, DEFAULT_EPSILON, DEFAULT_ITERATIONS); } /** * Computes the regularized gamma function \( Q(a, x) = 1 - P(a, x) \). * - * The implementation of this method is based on: - *

+ *

This is a specialization of the function \( Q(a, x) \) that allows pre-computation + * of values required for argument \( x \). + * + * @param a Argument. + * @param x Argument. + * @return \( Q(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + * @see ArgumentX + */ + public static double value(double a, + ArgumentX x) { + return value(a, x, DEFAULT_EPSILON, DEFAULT_ITERATIONS); + } + + /** + * Computes the regularized gamma function \( Q(a, x) = 1 - P(a, x) \). * * @param a Argument. * @param x Argument. @@ -177,16 +449,90 @@ public static double value(final double a, double x, double epsilon, int maxIterations) { - if (Double.isNaN(a) || - Double.isNaN(x) || - a <= 0 || - x < 0) { + if (ArgumentA.invalid(a) || ArgumentX.invalid(x)) { + return Double.NaN; + } + return compute(a, x, epsilon, maxIterations, Math::log, LogGamma::value); + } + + /** + * Computes the regularized gamma function \( Q(a, x) = 1 - P(a, x) \). + * + *

This is a specialization of the function \( Q(a, x) \) that allows pre-computation + * of values required for argument \( a \). + * + * @param a Argument. + * @param x Argument. + * @param epsilon Tolerance in continued fraction evaluation. + * @param maxIterations Maximum number of iterations in continued fraction evaluation. + * @return \( Q(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + * @see ArgumentA + */ + public static double value(ArgumentA a, + double x, + double epsilon, + int maxIterations) { + if (ArgumentX.invalid(x)) { + return Double.NaN; + } + return compute(a.get(), x, epsilon, maxIterations, Math::log, a::getLogGamma); + } + + /** + * Computes the regularized gamma function \( Q(a, x) = 1 - P(a, x) \). + * + *

This is a specialization of the function \( Q(a, x) \) that allows pre-computation + * of values required for argument \( x \). + * + * @param a Argument. + * @param x Argument. + * @param epsilon Tolerance in continued fraction evaluation. + * @param maxIterations Maximum number of iterations in continued fraction evaluation. + * @return \( Q(a, x) \). + * @throws ArithmeticException if the continued fraction fails to converge. + * @see ArgumentX + */ + public static double value(double a, + ArgumentX x, + double epsilon, + int maxIterations) { + if (ArgumentA.invalid(a)) { return Double.NaN; - } else if (x == 0) { + } + return compute(a, x.get(), epsilon, maxIterations, x::getLog, LogGamma::value); + } + + /** + * Computes the regularized gamma function \( Q(a, x) = 1 - P(a, x) \). + * + *

Note: Assumes argument validation has been performed. + * + * @param a Argument ({@code a > 0}). + * @param x Argument ({@code a >= 0}). + * @param epsilon Tolerance in continued fraction evaluation. + * @param fLogX Function to compute the log of x + * @param fLogGammaA Function to compute logGamma(a) + * @param maxIterations Maximum number of iterations in continued fraction evaluation. + * @throws ArithmeticException if the continued fraction fails to converge. + * @return \( Q(a, x) \). + */ + static double compute(final double a, + double x, + double epsilon, + int maxIterations, + DoubleUnaryOperator fLogX, + DoubleUnaryOperator fLogGammaA) { + // Assume validation has been performed: + // a > 0 + // x >= 0 + // NaN is not allowed + + if (x == 0) { return 1; } else if (x < a + 1) { // P should converge faster in this case. - return 1 - RegularizedGamma.P.value(a, x, epsilon, maxIterations); + return 1 - RegularizedGamma.P.compute(a, x, epsilon, maxIterations, fLogX, fLogGammaA); } else { final ContinuedFraction cf = new ContinuedFraction() { /** {@inheritDoc} */ @@ -202,7 +548,9 @@ protected double getB(int n, double x) { } }; - return Math.exp(-x + (a * Math.log(x)) - LogGamma.value(a)) / + final double logX = fLogX.applyAsDouble(x); + final double logGammaA = fLogGammaA.applyAsDouble(a); + return Math.exp(-x + (a * logX) - logGammaA) / cf.evaluate(x, epsilon, maxIterations); } } diff --git a/commons-numbers-gamma/src/test/java/org/apache/commons/numbers/gamma/RegularizedGammaTest.java b/commons-numbers-gamma/src/test/java/org/apache/commons/numbers/gamma/RegularizedGammaTest.java index e461a1a46..30f110401 100644 --- a/commons-numbers-gamma/src/test/java/org/apache/commons/numbers/gamma/RegularizedGammaTest.java +++ b/commons-numbers-gamma/src/test/java/org/apache/commons/numbers/gamma/RegularizedGammaTest.java @@ -16,46 +16,42 @@ */ package org.apache.commons.numbers.gamma; +import org.apache.commons.numbers.gamma.RegularizedGamma.ArgumentA; +import org.apache.commons.numbers.gamma.RegularizedGamma.ArgumentX; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; /** * Tests for {@link RegularizedGamma}. */ class RegularizedGammaTest { - @Test - void testRegularizedGammaNanPositive() { - testRegularizedGamma(Double.NaN, Double.NaN, 1.0); + /** + * Test argument X cannot be NaN, negative or zero. + * + * @param a Argument a + */ + @ParameterizedTest + @ValueSource(doubles = {Double.NaN, 0, -1}) + void testInvalidArgumentA(double a) { + Assertions.assertThrows(IllegalArgumentException.class, () -> ArgumentA.of(a)); + // No exception thrown. The result is NaN. + testRegularizedGamma(Double.NaN, Double.NaN, a, 1.0); } - @Test - void testRegularizedGammaPositiveNan() { - testRegularizedGamma(Double.NaN, 1.0, Double.NaN); - } - - @Test - void testRegularizedGammaNegativePositive() { - testRegularizedGamma(Double.NaN, -1.5, 1.0); - } - - @Test - void testRegularizedGammaPositiveNegative() { - testRegularizedGamma(Double.NaN, 1.0, -1.0); - } - - @Test - void testRegularizedGammaZeroPositive() { - testRegularizedGamma(Double.NaN, 0.0, 1.0); - } - - @Test - void testRegularizedGammaPositiveZero() { - testRegularizedGamma(0.0, 1.0, 0.0); - } - - @Test - void testRegularizedGammaPositivePositive() { - testRegularizedGamma(0.632120558828558, 1.0, 1.0); + /** + * Test argument X cannot be NaN or negative. + * + * @param x Argument x + */ + @ParameterizedTest + @ValueSource(doubles = {Double.NaN, -1}) + void testInvalidArgumentX(double x) { + Assertions.assertThrows(IllegalArgumentException.class, () -> ArgumentX.of(x)); + // No exception thrown. The result is NaN. + testRegularizedGamma(Double.NaN, Double.NaN, 1.0, x); } @Test @@ -65,7 +61,7 @@ void testRegularizedGammaPWithACloseToZero() { final double a = 1e-18; // x must force use of the series in regularized gamma P using x < a + 1 final double x = 0.5; - testRegularizedGamma(1.0, a, x); + testRegularizedGamma(1.0, 0.0, a, x); } @Test @@ -74,25 +70,120 @@ void testRegularizedGammaPWithAVeryCloseToZero() { final double a = Double.MIN_VALUE; // x must force use of the series in regularized gamma P using x < a + 1 final double x = 0.5; - testRegularizedGamma(1.0, a, x); + testRegularizedGamma(1.0, 0.0, a, x); } - private void testRegularizedGamma(double expected, double a, double x) { + /** + * Test the regularized gamma P and Q functions. + * + *

Note that the identity P + Q = 1 is tested. It should not be used to generate + * data for P from Q (or vice versa) when the values approach 0 or 1 due to floating + * point error. + * + *

Tests the methods with double arguments and then the methods with pre-computed + * arguments which must be an exact match. + * + * @param expectedP Expected P(a, x) + * @param expectedQ Expected Q(a, x) + * @param a Argument a + * @param x Argument x + */ + @ParameterizedTest + @CsvSource(value = { + "0.0, 1.0, 1.0, 0.0", + // Values computed using Wolfram Mathematica + "0.63212055882855767840, 0.36787944117144232160, 1.0, 1.0", + "0.080301397071394196011, 0.91969860292860580399, 3, 1", + "0.877050191685244, 0.1229498083147559, 0.52, 1.23", + "0.01101451006216559, 0.988985489937834, 46.34, 32.18", + "1.0, 1.0922956375456871032e-43, 10, 130", + "7.6002090267819442301e-95, 1.0, 130, 10", + }) + void testRegularizedGamma(double expectedP, double expectedQ, double a, double x) { double actualP = RegularizedGamma.P.value(a, x); double actualQ = RegularizedGamma.Q.value(a, x); - Assertions.assertEquals(expected, actualP, 1e-15); - Assertions.assertEquals(actualP, 1 - actualQ, 1e-15); + Assertions.assertEquals(expectedP, actualP, 1e-15, "p"); + Assertions.assertEquals(expectedQ, actualQ, 1e-15, "q"); + + // Note: If the expected values are NaN then assume this is due to invalid parameters. + if (Double.isNaN(expectedP)) { + // Try to construct the arguments. + // If one is valid then the function should compute the same result (NaN) + try { + final ArgumentA argA = ArgumentA.of(a); + Assertions.assertEquals(actualP, RegularizedGamma.P.value(argA, x)); + Assertions.assertEquals(actualQ, RegularizedGamma.Q.value(argA, x)); + } catch (IllegalArgumentException ex) { + // argument a is invalid + } + try { + final ArgumentX argX = ArgumentX.of(x); + Assertions.assertEquals(actualP, RegularizedGamma.P.value(a, argX)); + Assertions.assertEquals(actualQ, RegularizedGamma.Q.value(a, argX)); + } catch (IllegalArgumentException ex) { + // argument x is invalid + } + return; + } + + // Test the identity P + Q = 1 + Assertions.assertEquals(1.0, actualP + actualQ, 1e-15, "p+q"); + + // Verify the versions with pre-computed arguments. + // The results must be binary equal so do not use a tolerance. + + final ArgumentA argA = ArgumentA.of(a); + Assertions.assertEquals(actualP, RegularizedGamma.P.value(argA, x)); + Assertions.assertEquals(actualQ, RegularizedGamma.Q.value(argA, x)); + + final ArgumentX argX = ArgumentX.of(x); + Assertions.assertEquals(actualP, RegularizedGamma.P.value(a, argX)); + Assertions.assertEquals(actualQ, RegularizedGamma.Q.value(a, argX)); } @Test - void testRegularizedGammaMaxIterationsExceededThrows() { - final double a = 1.0; - final double x = 1.0; + void testRegularizedGammaPMaxIterationsExceededThrows() { + // x < a + 1 + final double a = 13.0; + final double x = 10.0; // OK without - Assertions.assertEquals(0.632120558828558, RegularizedGamma.P.value(a, x), 1e-15); + final double actual = RegularizedGamma.P.value(a, x); + // mathematica: N[GammaRegularized[13, 0, 10], 20] + Assertions.assertEquals(0.20844352360512566106, actual, 1e-15); + final ArgumentA argA = ArgumentA.of(a); + final ArgumentX argX = ArgumentX.of(x); + Assertions.assertEquals(actual, RegularizedGamma.P.value(argA, x)); + Assertions.assertEquals(actual, RegularizedGamma.P.value(a, argX)); final int maxIterations = 3; Assertions.assertThrows(ArithmeticException.class, () -> RegularizedGamma.P.value(a, x, 1e-15, maxIterations)); + Assertions.assertThrows(ArithmeticException.class, () -> + RegularizedGamma.P.value(argA, x, 1e-15, maxIterations)); + Assertions.assertThrows(ArithmeticException.class, () -> + RegularizedGamma.P.value(a, argX, 1e-15, maxIterations)); + } + + @Test + void testRegularizedGammaQMaxIterationsExceededThrows() { + // x >= a + 1 + final double a = 10.0; + final double x = 13.0; + // OK without + final double actual = RegularizedGamma.Q.value(a, x); + // mathematica: N[GammaRegularized[10, 13], 20] + Assertions.assertEquals(0.16581187661729210469, actual, 1e-15); + final ArgumentA argA = ArgumentA.of(a); + final ArgumentX argX = ArgumentX.of(x); + Assertions.assertEquals(actual, RegularizedGamma.Q.value(argA, x)); + Assertions.assertEquals(actual, RegularizedGamma.Q.value(a, argX)); + + final int maxIterations = 3; + Assertions.assertThrows(ArithmeticException.class, () -> + RegularizedGamma.Q.value(a, x, 1e-15, maxIterations)); + Assertions.assertThrows(ArithmeticException.class, () -> + RegularizedGamma.Q.value(argA, x, 1e-15, maxIterations)); + Assertions.assertThrows(ArithmeticException.class, () -> + RegularizedGamma.Q.value(a, argX, 1e-15, maxIterations)); } }