Skip to content

Commit

Permalink
Use a single instance of the standard normal distribution
Browse files Browse the repository at this point in the history
The distribution is immutable. The underlying computations use static
methods and can be used across threads.

Note: A micro-optimisation to eliminate the mean=0 and SD=1 from the
computations used from NormalDistribution(0, 1) has not been performed
to aid maintenance. It is presumed that multiplication by 1 or addition
of 0 is insignificant compared to the calls to:

- Math.exp (density)
- Erf.value (cumulativeProbability)
- InverseErf.value (inverseCumulativeProbability)

Only the logDensity function would be expected to improve performance as
it would eliminate 1 subtraction and 1 devision to leave the computation
as 2 multiplications and a subtraction:

return -0.5 * x * x - halfLog2Pi
  • Loading branch information
aherbert committed Sep 14, 2021
1 parent da71b36 commit 81cbacd
Showing 1 changed file with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
* Truncated normal distribution (Wikipedia)</a>
*/
public class TruncatedNormalDistribution extends AbstractContinuousDistribution {
/** A standard normal distribution used for calculations.
* This is immutable and thread-safe and can be used across instances. */
private static final NormalDistribution STANDARD_NORMAL = new NormalDistribution(0, 1);

/** Mean of parent normal distribution. */
private final double parentMean;
/** Standard deviation of parent normal distribution. */
Expand All @@ -37,8 +41,6 @@ public class TruncatedNormalDistribution extends AbstractContinuousDistribution
/** Upper bound of this distribution. */
private final double upper;

/** A standard normal distribution used for calculations. */
private final NormalDistribution standardNormal;
/** Stored value of @{code standardNormal.cumulativeProbability((lower - mean) / sd)} for faster computations. */
private final double cdfAlpha;
/**
Expand Down Expand Up @@ -75,21 +77,20 @@ public TruncatedNormalDistribution(double mean, double sd, double lower, double

parentMean = mean;
parentSd = sd;
standardNormal = new NormalDistribution(0, 1);

final double alpha = (lower - mean) / sd;
final double beta = (upper - mean) / sd;

final double cdfBeta = standardNormal.cumulativeProbability(beta);
cdfAlpha = standardNormal.cumulativeProbability(alpha);
final double cdfBeta = STANDARD_NORMAL.cumulativeProbability(beta);
cdfAlpha = STANDARD_NORMAL.cumulativeProbability(alpha);
cdfDelta = cdfBeta - cdfAlpha;

parentSdByCdfDelta = parentSd * cdfDelta;
logParentSdByCdfDelta = Math.log(parentSdByCdfDelta);

// Calculation of variance and mean.
final double pdfAlpha = standardNormal.density(alpha);
final double pdfBeta = standardNormal.density(beta);
final double pdfAlpha = STANDARD_NORMAL.density(alpha);
final double pdfBeta = STANDARD_NORMAL.density(beta);
final double pdfCdfDelta = (pdfAlpha - pdfBeta) / cdfDelta;
final double alphaBetaDelta = (alpha * pdfAlpha - beta * pdfBeta) / cdfDelta;

Expand Down Expand Up @@ -124,7 +125,7 @@ public double density(double x) {
if (x < lower || x > upper) {
return 0;
}
return standardNormal.density((x - parentMean) / parentSd) / parentSdByCdfDelta;
return STANDARD_NORMAL.density((x - parentMean) / parentSd) / parentSdByCdfDelta;
}

/** {@inheritDoc} */
Expand All @@ -133,7 +134,7 @@ public double logDensity(double x) {
if (x < lower || x > upper) {
return Double.NEGATIVE_INFINITY;
}
return standardNormal.logDensity((x - parentMean) / parentSd) - logParentSdByCdfDelta;
return STANDARD_NORMAL.logDensity((x - parentMean) / parentSd) - logParentSdByCdfDelta;
}

/** {@inheritDoc} */
Expand All @@ -144,7 +145,7 @@ public double cumulativeProbability(double x) {
} else if (x >= upper) {
return 1;
}
return (standardNormal.cumulativeProbability((x - parentMean) / parentSd) - cdfAlpha) / cdfDelta;
return (STANDARD_NORMAL.cumulativeProbability((x - parentMean) / parentSd) - cdfAlpha) / cdfDelta;
}

/** {@inheritDoc} */
Expand All @@ -159,7 +160,7 @@ public double inverseCumulativeProbability(double p) {
} else if (p == 1) {
return upper;
}
final double x = standardNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta) * parentSd + parentMean;
final double x = STANDARD_NORMAL.inverseCumulativeProbability(cdfAlpha + p * cdfDelta) * parentSd + parentMean;
// Clip to support to handle floating-point error at the support bound
if (x <= lower) {
return lower;
Expand Down

0 comments on commit 81cbacd

Please sign in to comment.