From 81cbacd3f43fb3ec2c188143fb641c36e22a3853 Mon Sep 17 00:00:00 2001 From: aherbert Date: Tue, 14 Sep 2021 13:07:36 +0100 Subject: [PATCH] Use a single instance of the standard normal distribution The distribution is immutable. The underlying computations use static methods and can be used across threads. Note: A micro-optimisation to eliminate the mean=0 and SD=1 from the computations used from NormalDistribution(0, 1) has not been performed to aid maintenance. It is presumed that multiplication by 1 or addition of 0 is insignificant compared to the calls to: - Math.exp (density) - Erf.value (cumulativeProbability) - InverseErf.value (inverseCumulativeProbability) Only the logDensity function would be expected to improve performance as it would eliminate 1 subtraction and 1 devision to leave the computation as 2 multiplications and a subtraction: return -0.5 * x * x - halfLog2Pi --- .../TruncatedNormalDistribution.java | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java index 4f17d8b20..08751eaf2 100644 --- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java @@ -24,6 +24,10 @@ * Truncated normal distribution (Wikipedia) */ public class TruncatedNormalDistribution extends AbstractContinuousDistribution { + /** A standard normal distribution used for calculations. + * This is immutable and thread-safe and can be used across instances. */ + private static final NormalDistribution STANDARD_NORMAL = new NormalDistribution(0, 1); + /** Mean of parent normal distribution. */ private final double parentMean; /** Standard deviation of parent normal distribution. */ @@ -37,8 +41,6 @@ public class TruncatedNormalDistribution extends AbstractContinuousDistribution /** Upper bound of this distribution. */ private final double upper; - /** A standard normal distribution used for calculations. */ - private final NormalDistribution standardNormal; /** Stored value of @{code standardNormal.cumulativeProbability((lower - mean) / sd)} for faster computations. */ private final double cdfAlpha; /** @@ -75,21 +77,20 @@ public TruncatedNormalDistribution(double mean, double sd, double lower, double parentMean = mean; parentSd = sd; - standardNormal = new NormalDistribution(0, 1); final double alpha = (lower - mean) / sd; final double beta = (upper - mean) / sd; - final double cdfBeta = standardNormal.cumulativeProbability(beta); - cdfAlpha = standardNormal.cumulativeProbability(alpha); + final double cdfBeta = STANDARD_NORMAL.cumulativeProbability(beta); + cdfAlpha = STANDARD_NORMAL.cumulativeProbability(alpha); cdfDelta = cdfBeta - cdfAlpha; parentSdByCdfDelta = parentSd * cdfDelta; logParentSdByCdfDelta = Math.log(parentSdByCdfDelta); // Calculation of variance and mean. - final double pdfAlpha = standardNormal.density(alpha); - final double pdfBeta = standardNormal.density(beta); + final double pdfAlpha = STANDARD_NORMAL.density(alpha); + final double pdfBeta = STANDARD_NORMAL.density(beta); final double pdfCdfDelta = (pdfAlpha - pdfBeta) / cdfDelta; final double alphaBetaDelta = (alpha * pdfAlpha - beta * pdfBeta) / cdfDelta; @@ -124,7 +125,7 @@ public double density(double x) { if (x < lower || x > upper) { return 0; } - return standardNormal.density((x - parentMean) / parentSd) / parentSdByCdfDelta; + return STANDARD_NORMAL.density((x - parentMean) / parentSd) / parentSdByCdfDelta; } /** {@inheritDoc} */ @@ -133,7 +134,7 @@ public double logDensity(double x) { if (x < lower || x > upper) { return Double.NEGATIVE_INFINITY; } - return standardNormal.logDensity((x - parentMean) / parentSd) - logParentSdByCdfDelta; + return STANDARD_NORMAL.logDensity((x - parentMean) / parentSd) - logParentSdByCdfDelta; } /** {@inheritDoc} */ @@ -144,7 +145,7 @@ public double cumulativeProbability(double x) { } else if (x >= upper) { return 1; } - return (standardNormal.cumulativeProbability((x - parentMean) / parentSd) - cdfAlpha) / cdfDelta; + return (STANDARD_NORMAL.cumulativeProbability((x - parentMean) / parentSd) - cdfAlpha) / cdfDelta; } /** {@inheritDoc} */ @@ -159,7 +160,7 @@ public double inverseCumulativeProbability(double p) { } else if (p == 1) { return upper; } - final double x = standardNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta) * parentSd + parentMean; + final double x = STANDARD_NORMAL.inverseCumulativeProbability(cdfAlpha + p * cdfDelta) * parentSd + parentMean; // Clip to support to handle floating-point error at the support bound if (x <= lower) { return lower;