Skip to content
Permalink
Browse files
Allow ziggurat sampling from only the overhangs in the performance test
Add additional ternary variant for testing.
  • Loading branch information
aherbert committed Sep 16, 2021
1 parent 9a25d0b commit 9f492d05b47c51427228ffce1e48ce5c2099accb
Showing 2 changed files with 193 additions and 13 deletions.
@@ -115,6 +115,8 @@ public class ZigguratSamplerPerformance {
static final String MOD_EXPONENTIAL_E_MAX_2 = "ModExponentialEmax2";
/** The name for the {@link ModifiedZigguratExponentialSamplerTernary}. */
static final String MOD_EXPONENTIAL_TERNARY = "ModExponentialTernary";
/** The name for the {@link ModifiedZigguratExponentialSamplerTernarySubtract}. */
static final String MOD_EXPONENTIAL_TERNARY_SUBTRACT = "ModExponentialTernarySubtract";
/** The name for the {@link ModifiedZigguratExponentialSampler512} using a table size of 512. */
static final String MOD_EXPONENTIAL_512 = "ModExponential512";

@@ -580,8 +582,57 @@ public abstract static class Sources {
MOD_EXPONENTIAL_LOOP, MOD_EXPONENTIAL_LOOP2,
MOD_EXPONENTIAL_RECURSION, MOD_EXPONENTIAL_INT_MAP,
MOD_EXPONENTIAL_E_MAX_TABLE, MOD_EXPONENTIAL_E_MAX_2,
MOD_EXPONENTIAL_TERNARY, MOD_EXPONENTIAL_512})
protected String type;
MOD_EXPONENTIAL_TERNARY, MOD_EXPONENTIAL_TERNARY_SUBTRACT, MOD_EXPONENTIAL_512})
private String type;


/** Flag to indicate that the sample targets the overhangs.
* This is applicable to the McFarland Ziggurat sampler and
* requires manipulation of the final bits of the RNG. */
@Param({"true", "false"})
private boolean overhang;

/**
* Creates the sampler.
*
* <p>This may return a specialisation for the McFarland sampler that only samples
* from overhangs.
*
* @param rng RNG
* @return the sampler
*/
protected ContinuousSampler createSampler(UniformRandomProvider rng) {
if (!overhang) {
return createSampler(type, rng);
}
// For the Marsaglia Ziggurat sampler overhangs are only tested once then
// the method recurses the entire sample method. Overhang sampling cannot be forced
// for this sampler.
if (GAUSSIAN_128.equals(type) ||
GAUSSIAN_256.equals(type) ||
EXPONENTIAL.equals(type)) {
return createSampler(type, rng);
}
// Assume the sampler is a McFarland Ziggurat sampler.
// Manipulate the final bits of the long from the RNG to force sampling
// from the overhang. Assume most of the samplers use an 8-bit look-up table.
int numberOfBits = 8;
if (type.contains("512")) {
// 9-bit look-up table
numberOfBits = 9;
}
// Use an RNG that can set the lower bits of the long.
final ModifiedRNG modRNG = new ModifiedRNG(rng, numberOfBits);
final ContinuousSampler sampler = createSampler(type, modRNG);
// Create a sampler where each call should force overhangs/tail sampling
return new ContinuousSampler() {
@Override
public double sample() {
modRNG.modifyNextLong();
return sampler.sample();
}
};
}

/**
* Creates the sampler.
@@ -641,12 +692,85 @@ static ContinuousSampler createSampler(String type, UniformRandomProvider rng) {
return new ModifiedZigguratExponentialSamplerEMax2(rng);
} else if (MOD_EXPONENTIAL_TERNARY.equals(type)) {
return new ModifiedZigguratExponentialSamplerTernary(rng);
} else if (MOD_EXPONENTIAL_TERNARY_SUBTRACT.equals(type)) {
return new ModifiedZigguratExponentialSamplerTernarySubtract(rng);
} else if (MOD_EXPONENTIAL_512.equals(type)) {
return new ModifiedZigguratExponentialSampler512(rng);
} else {
throw new IllegalStateException("Unknown type: " + type);
}
}

/**
* A class that can modify the lower bits to be all set for the next invocation of
* {@link UniformRandomProvider#nextLong()}.
*/
private static class ModifiedRNG implements UniformRandomProvider {
/** Underlying source of randomness. */
private final UniformRandomProvider rng;
/** The bits to set in the output long using a bitwise or ('|'). */
private final long bits;
/** The next bits to set in the output long using a bitwise or ('|'). */
private long nextBits;

/**
* @param rng Underlying source of randomness
* @param numberOfBits Number of least significant bits to set for a call to nextLong()
*/
ModifiedRNG(UniformRandomProvider rng, int numberOfBits) {
this.rng = rng;
bits = (1L << numberOfBits) - 1;
}

/**
* Set the state to modify the lower bits on the next call to nextLong().
*/
void modifyNextLong() {
nextBits = bits;
}

@Override
public long nextLong() {
final long x = rng.nextLong() | nextBits;
nextBits = 0;
return x;
}

// The following methods should not be used.

@Override
public void nextBytes(byte[] bytes) {
throw new IllegalStateException();
}
@Override
public void nextBytes(byte[] bytes, int start, int len) {
throw new IllegalStateException();
}
@Override
public int nextInt() {
throw new IllegalStateException();
}
@Override
public int nextInt(int n) {
throw new IllegalStateException();
}
@Override
public long nextLong(long n) {
throw new IllegalStateException();
}
@Override
public boolean nextBoolean() {
throw new IllegalStateException();
}
@Override
public float nextFloat() {
throw new IllegalStateException();
}
@Override
public double nextDouble() {
throw new IllegalStateException();
}
}
}

/**
@@ -683,7 +807,7 @@ public ContinuousSampler getSampler() {
public void setup() {
final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
final UniformRandomProvider rng = randomSource.create();
sampler = createSampler(type, rng);
sampler = createSampler(rng);
}
}

@@ -736,7 +860,7 @@ public ContinuousSampler getSampler() {
public void setup() {
final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
final UniformRandomProvider rng = randomSource.create();
final ContinuousSampler s = createSampler(type, rng);
final ContinuousSampler s = createSampler(rng);
sampler = createSequentialSampler(size, s);
}

@@ -1194,9 +1318,9 @@ static class ModifiedZigguratNormalizedGaussianSampler implements ContinuousSamp
// Ziggurat volumes:
// Inside the layers = 98.8281% (253/256)
// Fraction outside the layers:
// concave overhangs = 76.1941%
// convex overhangs = 76.1941%
// inflection overhang = 0.1358%
// convex overhangs = 21.3072%
// concave overhangs = 21.3072%
// tail = 2.3629%

/** The number of layers in the ziggurat. Maximum i value for early exit. */
@@ -2267,9 +2391,9 @@ public double sample() {
// Ziggurat volumes:
// Inside the layers = 98.8281% (253/256)
// Fraction outside the layers:
// concave overhangs = 76.1941%
// convex overhangs = 76.1941%
// inflection overhang = 0.1358%
// convex overhangs = 21.3072%
// concave overhangs = 21.3072%
// tail = 2.3629%

// Separation of convex overhangs:
@@ -2479,9 +2603,11 @@ public double sample() {
// Concave overhang
for (;;) {
// If u2 < u1 then reflect in the hypotenuse by swapping u1 and u2.
// Create a second uniform deviate (as u1 is recycled).
final long ua = u1;
final long ub = randomInt63();
// Sort u1 < u2 to sample the lower-left triangle
// Sort u1 < u2 to sample the lower-left triangle.
// Use conditional ternary to avoid a 50/50 branch statement to swap the pair.
u1 = ua < ub ? ua : ub;
final long u2 = ua < ub ? ub : ua;
x = sampleX(X, j, u1);
@@ -2518,9 +2644,9 @@ static class ModifiedZigguratNormalizedGaussianSampler512 implements ContinuousS
// Ziggurat volumes:
// Inside the layers = 99.4141% (509/512)
// Fraction outside the layers:
// concave overhangs = 75.5775%
// convex overhangs = 75.5775%
// inflection overhang = 0.0675%
// convex overhangs = 22.2196%
// concave overhangs = 22.2196%
// tail = 2.1354%

/** The number of layers in the ziggurat. Maximum i value for early exit. */
@@ -4121,7 +4247,7 @@ private double sampleOverhang(int j, long eMax) {
* <p>Uses the algorithm from McFarland, C.D. (2016).
*
* <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
* a ternary operator to sort the two random long values.
* two ternary operators to sort the two random long values.
*/
static class ModifiedZigguratExponentialSamplerTernary
extends ModifiedZigguratExponentialSampler {
@@ -4149,7 +4275,8 @@ protected double sampleOverhang(int j) {
// If u2 < u1 then reflect in the hypotenuse by swapping u1 and u2.
final long ua = randomInt63();
final long ub = randomInt63();
// Sort u1 < u2 to sample the lower-left triangle
// Sort u1 < u2 to sample the lower-left triangle.
// Use conditional ternary to avoid a 50/50 branch statement to swap the pair.
final long u1 = ua < ub ? ua : ub;
final long u2 = ua < ub ? ub : ua;
final double x = sampleX(X, j, u1);
@@ -4162,6 +4289,58 @@ protected double sampleOverhang(int j) {
}
}

/**
* Modified Ziggurat method for sampling from an exponential distribution.
*
* <p>Uses the algorithm from McFarland, C.D. (2016).
*
* <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
* a ternary operator to sort the two random long values and a subtraction
* to get the difference.
*/
static class ModifiedZigguratExponentialSamplerTernarySubtract
extends ModifiedZigguratExponentialSampler {

/**
* @param rng Generator of uniformly distributed random numbers.
*/
ModifiedZigguratExponentialSamplerTernarySubtract(UniformRandomProvider rng) {
super(rng);
}

@Override
protected double sampleOverhang(int j) {
// Sample from the triangle:
// X[j],Y[j]
// |\-->u1
// | \ |
// | \ |
// | \| Overhang j (with hypotenuse not pdf(x))
// | \
// | |\
// | | \
// | u2 \
// +-------- X[j-1],Y[j-1]
// If u2 < u1 then reflect in the hypotenuse by swapping u1 and u2.
final long ua = randomInt63();
final long ub = randomInt63();
// Sort u1 < u2 to sample the lower-left triangle.
// Use conditional ternary to avoid a 50/50 branch statement to swap the pair.
final long u1 = ua < ub ? ua : ub;
final double x = sampleX(X, j, u1);
// u2 = ua + ub - u1
// uDistance = ua + ub - u1 - u1
final long uDistance = ua + ub - (u1 << 1);
if (uDistance >= E_MAX) {
// Early Exit: x < y - epsilon
return x;
}

// u2 = u1 + uDistance
return sampleY(Y, j, u1 + uDistance) <= Math.exp(-x) ? x : sampleOverhang(j);
}
}

/**
* Modified Ziggurat method for sampling from an exponential distribution.
*
@@ -117,6 +117,7 @@ private static Stream<Arguments> exponentialSamplers() {
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_TABLE),
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_2),
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY),
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY_SUBTRACT),
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_512));
}

0 comments on commit 9f492d0

Please sign in to comment.