STATISTICS-52: Add a high precision PDF to the normal distribution
Exploit information in the round-off from x*x to increase the precision
of exp(-0.5*x*x) when x is large.

Add a benchmark to demonstrate that this has only a minor impact on runtime
performance. Accuracy is increased to within 3 ULP (down from hundreds of ULP)
for large x values.
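
For illustration, the following standalone sketch (an editorial addition, not part of the commit) shows the correction that the new expmhxx method applies: Dekker's split recovers the round-off zz discarded when x*x is rounded, and the standard result is adjusted by the first-order term of exp(-0.5*zz). The chosen x is arbitrary; the size of the printed ULP gap depends on the round-off for that particular x.

public class ExpmhxxSketch {
    public static void main(String[] args) {
        double x = 37.123456789012345;   // large x with a full mantissa
        double z = x * x;                // standard precision square; the round-off is lost
        // Dekker's split of x into high and low parts using the 2^27 + 1 multiplier
        double c = 134217729.0 * x;
        double hx = c - (c - x);
        double lx = x - hx;
        // Round-off of the square: x * x = z + zz exactly
        double zz = lx * lx - ((z - hx * hx) - 2 * lx * hx);
        double standard = Math.exp(-0.5 * z);
        // exp(-0.5 * (z + zz)) ~= exp(-0.5 * z) * (1 - 0.5 * zz)
        double corrected = standard * (-0.5 * zz) + standard;
        System.out.println("ULP gap between standard and corrected result: "
            + (standard - corrected) / Math.ulp(corrected));
    }
}

The commit's new tests (expmhxx.csv, normpdf.csv, normpdf2.csv) check that the corrected form matches a high-precision reference to within a few ULP.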
aherbert committed Jan 21, 2022
1 parent 34f5373 commit f931dd5
Showing 8 changed files with 2,454 additions and 8 deletions.
ExtendedPrecision.java
@@ -54,6 +54,12 @@ final class ExtendedPrecision {
private static final double SQRT2PI_L;
/** Round-off from sqrt(2 pi) as a double. */
private static final double SQRT2PI_R;
/** Threshold on {@code x*x} at or below which the round-off from x squared cannot
* increase the accuracy of {@code exp(-0.5*x*x)}. */
private static final int EXP_M_HALF_XX_MIN_VALUE = 2;
/** Approximate value of {@code x*x} at which {@code exp(-0.5*x*x) == 0}. This is above
* {@code -2 * ln(2^-1074)} due to rounding performed within the exp function. */
private static final int EXP_M_HALF_XX_MAX_VALUE = 1491;

static {
// Initialise constants
@@ -179,6 +185,72 @@ private static double computeSqrt2aa(double a) {
return c + cc;
}

/**
* Compute {@code exp(-0.5*x*x)} with high accuracy. This is performed using information in the
* round-off from {@code x*x}.
*
* <p>This is accurate at large x to 1 ulp until exp(-0.5*x*x) is close to sub-normal. For very
* small exp(-0.5*x*x) the adjustment is sub-normal and bits can be lost in the adjustment for a
* max observed error of {@code < 2} ulp.
*
* <p>At small x the accuracy cannot be improved over using exp(-0.5*x*x). This occurs at
* {@code x <= sqrt(2)}.
*
* @param x Value
* @return exp(-0.5*x*x)
* @see <a href="https://issues.apache.org/jira/browse/STATISTICS-52">STATISTICS-52</a>
*/
static double expmhxx(double x) {
final double z = x * x;
if (z <= EXP_M_HALF_XX_MIN_VALUE) {
return Math.exp(-0.5 * z);
} else if (z >= EXP_M_HALF_XX_MAX_VALUE) {
// exp(-745.5) == 0
return 0;
}
// Split the number
final double hx = highPartUnscaled(x);
final double lx = x - hx;
// Compute the round-off
final double zz = squareLow(hx, lx, z);
return expxx(-0.5 * z, -0.5 * zz);
}

/**
* Compute {@code exp(a+b)} with high accuracy assuming {@code a+b = a}.
*
* <p>This is accurate for large positive a to 1 ulp. If a is negative and exp(a) is close to
* sub-normal, a bit of precision may be lost when adjusting the result as the adjustment is
* sub-normal (max observed error {@code < 2} ulp). For the use case of multiplying a number
* less than 1 by exp(-x*x), with a = -x*x, the result will be sub-normal and the rounding error is lost.
*
* <p>At small |a| the accuracy cannot be improved over using exp(a) as the round-off is too small
* to create terms that can adjust the standard result by more than 0.5 ulp. This occurs at
* {@code |a| <= 1}.
*
* @param a High bits of a split number
* @param b Low bits of a split number
* @return exp(a+b)
* @see <a href="https://issues.apache.org/jira/projects/NUMBERS/issues/NUMBERS-177">
* Numbers-177: Accurate scaling by exp(z*z)</a>
*/
private static double expxx(double a, double b) {
// exp(a+b) = exp(a) * exp(b)
// = exp(a) * (exp(b) - 1) + exp(a)
// Assuming:
// 1. -746 < a < 710 for no under/overflow of exp(a)
// 2. a+b = a
// As b -> 0 then exp(b) -> 1; expm1(b) -> b
// The round-off b is limited to ~ 0.5 * ulp(746) ~ 5.68e-14
// and we can use an approximation for expm1 (x/1! + x^2/2! + ...)
// The second term is required for the expm1 result but the
// bits are not significant to change the product with exp(a)

final double ea = Math.exp(a);
// b ~ expm1(b)
return ea * b + ea;
}
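
As a rough check of the comment above (an editorial sketch, not part of the commit), the quadratic term of expm1(b) is far below half an ULP of exp(a), which is why the single product ea * b is enough:

public class ExpxxBoundSketch {
    public static void main(String[] args) {
        double maxB = 0.5 * Math.ulp(746.0);     // largest possible round-off b, ~5.68e-14
        double quadraticTerm = maxB * maxB / 2;  // next expm1 term, ~1.6e-27
        double halfUlpRelative = 0x1.0p-53;      // 0.5 ulp expressed as a relative error, ~1.1e-16
        // The dropped term is ~11 orders of magnitude below 0.5 ulp of exp(a)
        System.out.printf("max |b| = %.3e, b^2/2 = %.3e, 0.5 ulp = %.3e%n",
            maxB, quadraticTerm, halfUlpRelative);
    }
}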

/**
* Implement Dekker's method to split a value into two parts. Multiplying by (2^s + 1) creates
* a big value from which to derive the two split parts.
@@ -238,4 +310,22 @@ private static double productLow(double hx, double lx, double hy, double ly, double xy) {
// low = lx * ly - err3
return lx * ly - (((xy - hx * hy) - lx * hy) - hx * ly);
}

/**
* Compute the low part of the double length number {@code (z,zz)} for the exact
* square of {@code x} using Dekker's mult12 algorithm. The standard precision product
* {@code x*x} must be provided. The number {@code x} should already be split into low
* and high parts.
*
* <p>Note: This is a specialisation of
* {@link #productLow(double, double, double, double, double)}.
*
* @param hx High part of factor.
* @param lx Low part of factor.
* @param xx Square of the factor.
* @return <code>lx * lx - (((xx - hx * hx) - lx * hx) - hx * lx)</code>
*/
private static double squareLow(double hx, double lx, double xx) {
return lx * lx - ((xx - hx * hx) - 2 * lx * hx);
}
}
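
To see the specialisation noted in the squareLow Javadoc, substitute hy = hx and ly = lx into the productLow expression; because every intermediate product and subtraction is exact under Dekker's splitting, the two cross terms can be collapsed into a single doubled term (an editorial sketch of the algebra, not text from the commit):

    lx * ly - (((xy - hx * hy) - lx * hy) - hx * ly)
  = lx * lx - (((xx - hx * hx) - lx * hx) - hx * lx)
  = lx * lx - ((xx - hx * hx) - 2 * lx * hx)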
NormalDistribution.java
@@ -110,9 +110,8 @@ public double getStandardDeviation() {
/** {@inheritDoc} */
@Override
public double density(double x) {
final double x0 = x - mean;
final double x1 = x0 / standardDeviation;
return Math.exp(-0.5 * x1 * x1) / sdSqrt2pi;
final double z = (x - mean) / standardDeviation;
return ExtendedPrecision.expmhxx(z) / sdSqrt2pi;
}

/** {@inheritDoc} */
@@ -131,9 +130,8 @@ public double probability(double x0,
/** {@inheritDoc} */
@Override
public double logDensity(double x) {
final double x0 = x - mean;
final double x1 = x0 / standardDeviation;
return -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi;
final double z = (x - mean) / standardDeviation;
return -0.5 * z * z - logStandardDeviationPlusHalfLog2Pi;
}

/** {@inheritDoc} */
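
For context, a usage sketch of the public API that benefits from this change (the NormalDistribution.of factory name is assumed from the released Commons Statistics API and does not appear in this diff):

import org.apache.commons.statistics.distribution.NormalDistribution;

public class DensitySketch {
    public static void main(String[] args) {
        // density(x) now evaluates exp(-0.5 * z * z) via ExtendedPrecision.expmhxx,
        // so the PDF keeps a few-ULP accuracy far into the tails
        NormalDistribution standardNormal = NormalDistribution.of(0, 1);
        System.out.println(standardNormal.density(37.5));
        System.out.println(standardNormal.logDensity(37.5));
    }
}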
ExtendedPrecisionTest.java
@@ -42,6 +42,10 @@ class ExtendedPrecisionTest {
private static final RMS SQRT2XX_RMS2 = new RMS();
/** The sum of the squared ULP error for the first computation for x * sqrt(2 pi). */
private static final RMS XSQRT2PI_RMS = new RMS();
/** The sum of the squared ULP error for the first computation for exp(-0.5*x*x). */
private static final RMS EXPMHXX_RMS1 = new RMS();
/** The sum of the squared ULP error for the second computation for exp(-0.5*x*x). */
private static final RMS EXPMHXX_RMS2 = new RMS();

/**
* Class to compute the root mean squared error (RMS).
@@ -186,6 +190,46 @@ void testXsqrt2piPrecision() {
assertPrecision(XSQRT2PI_RMS, 1.2, 0.6);
}

@ParameterizedTest
@ValueSource(doubles = {0, 0.5, 1, 2, 3, 4, 5, 38.5, Double.MAX_VALUE, Double.POSITIVE_INFINITY, Double.NaN})
void testExpmhxxEdgeCases(double x) {
final double expected = Math.exp(-0.5 * x * x);
Assertions.assertEquals(expected, ExtendedPrecision.expmhxx(x));
Assertions.assertEquals(expected, ExtendedPrecision.expmhxx(-x));
}

/**
* Test the extended precision {@code exp(-0.5 * x * x)}. The expected result
* is computed using extended precision. For comparison, ulp errors are also collected for
* the standard precision computation.
*
* @param x Value x
* @param expected Expected result of {@code exp(-0.5 * x * x)}.
*/
@ParameterizedTest
@Order(1)
@CsvFileSource(resources = "expmhxx.csv")
void testExpmhxx(double x, BigDecimal expected) {
final double e = expected.doubleValue();
final double actual = ExtendedPrecision.expmhxx(x);
Assertions.assertEquals(e, actual, Math.ulp(e) * 2);
// Compute errors
addError(actual, expected, e, EXPMHXX_RMS1);
addError(Math.exp(-0.5 * x * x), expected, e, EXPMHXX_RMS2);
}

@Test
void testExpmhxxHighPrecision() {
// Typical result: max 0.9727 rms 0.3481
assertPrecision(EXPMHXX_RMS1, 1.5, 0.5);
}

@Test
void testExpmhxxStandardPrecision() {
// Typical result: max 385.7193 rms 50.7769
assertPrecision(EXPMHXX_RMS2, 400, 60);
}

private static void assertPrecision(RMS rms, double maxError, double rmsError) {
Assertions.assertTrue(rms.getMax() < maxError, () -> "max error: " + rms.getMax());
Assertions.assertTrue(rms.getRMS() < rmsError, () -> "rms error: " + rms.getRMS());
NormalDistributionTest.java
@@ -177,13 +177,41 @@ void testInverseCDF() {
}
}

/**
* Test the PDF using high-accuracy uniform x data.
*
* <p>This dataset uses uniformly spaced, machine-representable x values that have no
* round-off component when squared. If the density is implemented using
* {@code exp(logDensity(x))} the test will fail. Using the log density requires a
* tolerance of approximately 53 ULP to pass at the larger x values.
*/
@ParameterizedTest
@CsvFileSource(resources = "normpdf.csv")
void testPDF(double x, BigDecimal expected) {
assertPDF(x, expected, 2);
}

/**
* Test the PDF using high-accuracy random x data.
*
* <p>This dataset uses random x values that use the full 52-bit mantissa to ensure
* that there is a round-off component when squared. It requires a high precision exponential
* function that uses the round-off to compute {@code exp(-0.5*x*x)} accurately.
* Using a standard precision computation requires a tolerance of approximately 383 ULP
* to pass at the larger x values.
*
* <p>See STATISTICS-52.
*/
@ParameterizedTest
@CsvFileSource(resources = "normpdf2.csv")
void testPDF2(double x, BigDecimal expected) {
assertPDF(x, expected, 3);
}
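
The two dataset properties described above can be checked directly; the grid and random seed below are assumptions for illustration, not the code used to build the CSV resources. Math.fma(x, x, -(x*x)) returns the exact round-off discarded when x is squared.

import java.util.SplittableRandom;

public class DatasetPropertySketch {
    public static void main(String[] args) {
        // Coarse uniform grid: x has few significant bits, so x * x is exact
        for (int i = 1; i <= 640; i++) {
            double x = i / 16.0;
            if (Math.fma(x, x, -(x * x)) != 0) {
                throw new AssertionError("unexpected round-off at x = " + x);
            }
        }
        // Random x using the full mantissa: squaring almost always leaves a round-off
        SplittableRandom rng = new SplittableRandom(42);
        int withRoundOff = 0;
        for (int i = 0; i < 1000; i++) {
            double x = 1 + rng.nextDouble() * 37;
            if (Math.fma(x, x, -(x * x)) != 0) {
                withRoundOff++;
            }
        }
        System.out.println(withRoundOff + " of 1000 random x values have a round-off in x*x");
    }
}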

private static void assertPDF(double x, BigDecimal expected, int ulpTolerance) {
final double e = expected.doubleValue();
final double a = STANDARD_NORMAL.density(x);
// Require high precision. Currently this does not work at 1 ULP.
Assertions.assertEquals(e, a, Math.ulp(e) * 2,
Assertions.assertEquals(e, a, Math.ulp(e) * ulpTolerance,
() -> "ULP error: " + expected.subtract(new BigDecimal(a)).doubleValue() / Math.ulp(e));
}
}
