Skip to content
Permalink
Browse files
Replace benchmarking interfaces with java.util.function
  • Loading branch information
aherbert committed Aug 13, 2021
1 parent d7be418 commit da20fc81d265df172d595e70d5debf1f9d694a17
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 159 deletions.
@@ -18,6 +18,7 @@
package org.apache.commons.rng.examples.jmh.sampling;

import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.ObjectSampler;
import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
import org.apache.commons.rng.simple.RandomSource;
@@ -55,18 +56,6 @@ public class UnitSphereSamplerBenchmark {
/** Error message for an unknown sampler type. */
private static final String UNKNOWN_SAMPLER = "Unknown sampler type: ";

/**
* The sampler.
*/
private interface Sampler {
/**
* Gets the next sample.
*
* @return the sample
*/
double[] sample();
}

/**
* Base class for the sampler data.
* Contains the source of randomness and the number of samples.
@@ -75,7 +64,7 @@ private interface Sampler {
@State(Scope.Benchmark)
public abstract static class SamplerData {
/** The sampler. */
private Sampler sampler;
private ObjectSampler<double[]> sampler;

/** The number of samples. */
@Param({"100"})
@@ -95,7 +84,7 @@ public int getSize() {
*
* @return the sampler
*/
public Sampler getSampler() {
public ObjectSampler<double[]> getSampler() {
return sampler;
}

@@ -115,7 +104,7 @@ public void setup() {
* @param rng the source of randomness
* @return the sampler
*/
protected abstract Sampler createSampler(UniformRandomProvider rng);
protected abstract ObjectSampler<double[]> createSampler(UniformRandomProvider rng);
}

/**
@@ -142,7 +131,7 @@ public static class Sampler1D extends SamplerData {

/** {@inheritDoc} */
@Override
protected Sampler createSampler(final UniformRandomProvider rng) {
protected ObjectSampler<double[]> createSampler(final UniformRandomProvider rng) {
if (BASELINE.equals(type)) {
return () -> {
return new double[] {1.0};
@@ -185,11 +174,9 @@ public static class Sampler2D extends SamplerData {

/** {@inheritDoc} */
@Override
protected Sampler createSampler(final UniformRandomProvider rng) {
protected ObjectSampler<double[]> createSampler(final UniformRandomProvider rng) {
if (BASELINE.equals(type)) {
return () -> {
return new double[] {1.0, 0.0};
};
return () -> new double[] {1.0, 0.0};
} else if (ARRAY.equals(type)) {
return new ArrayBasedUnitSphereSampler(2, rng);
} else if (NON_ARRAY.equals(type)) {
@@ -201,7 +188,7 @@ protected Sampler createSampler(final UniformRandomProvider rng) {
/**
* Sample from a 2D unit sphere.
*/
private static class UnitSphereSampler2D implements Sampler {
private static class UnitSphereSampler2D implements ObjectSampler<double[]> {
/** Sampler used for generating the individual components of the vectors. */
private final NormalizedGaussianSampler sampler;

@@ -240,11 +227,9 @@ public static class Sampler3D extends SamplerData {

/** {@inheritDoc} */
@Override
protected Sampler createSampler(final UniformRandomProvider rng) {
protected ObjectSampler<double[]> createSampler(final UniformRandomProvider rng) {
if (BASELINE.equals(type)) {
return () -> {
return new double[] {1.0, 0.0, 0.0};
};
return () -> new double[] {1.0, 0.0, 0.0};
} else if (ARRAY.equals(type)) {
return new ArrayBasedUnitSphereSampler(3, rng);
} else if (NON_ARRAY.equals(type)) {
@@ -256,7 +241,7 @@ protected Sampler createSampler(final UniformRandomProvider rng) {
/**
* Sample from a 3D unit sphere.
*/
private static class UnitSphereSampler3D implements Sampler {
private static class UnitSphereSampler3D implements ObjectSampler<double[]> {
/** Sampler used for generating the individual components of the vectors. */
private final NormalizedGaussianSampler sampler;

@@ -296,11 +281,9 @@ public static class Sampler4D extends SamplerData {

/** {@inheritDoc} */
@Override
protected Sampler createSampler(final UniformRandomProvider rng) {
protected ObjectSampler<double[]> createSampler(final UniformRandomProvider rng) {
if (BASELINE.equals(type)) {
return () -> {
return new double[] {1.0, 0.0, 0.0, 0.0};
};
return () -> new double[] {1.0, 0.0, 0.0, 0.0};
} else if (ARRAY.equals(type)) {
return new ArrayBasedUnitSphereSampler(4, rng);
} else if (NON_ARRAY.equals(type)) {
@@ -312,7 +295,7 @@ protected Sampler createSampler(final UniformRandomProvider rng) {
/**
* Sample from a 4D unit hypersphere.
*/
private static class UnitSphereSampler4D implements Sampler {
private static class UnitSphereSampler4D implements ObjectSampler<double[]> {
/** Sampler used for generating the individual components of the vectors. */
private final NormalizedGaussianSampler sampler;

@@ -345,7 +328,7 @@ public double[] sample() {
/**
* Sample from a unit sphere using an array based method.
*/
private static class ArrayBasedUnitSphereSampler implements Sampler {
private static class ArrayBasedUnitSphereSampler implements ObjectSampler<double[]> {
/** Space dimension. */
private final int dimension;
/** Sampler used for generating the individual components of the vectors. */
@@ -394,7 +377,7 @@ public double[] sample() {
* @param data Input data.
*/
private static void runSampler(Blackhole bh, SamplerData data) {
final Sampler sampler = data.getSampler();
final ObjectSampler<double[]> sampler = data.getSampler();
for (int i = data.getSize() - 1; i >= 0; i--) {
bh.consume(sampler.sample());
}
@@ -43,6 +43,7 @@
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

/**
* Executes benchmark to compare the speed of generation of random numbers from an enumerated
@@ -128,23 +129,11 @@ public abstract static class SamplerSources extends LocalRandomSources {
private String samplerType;

/** The factory. */
private DiscreteSamplerFactory factory;
private Supplier<DiscreteSampler> factory;

/** The sampler. */
private DiscreteSampler sampler;

/**
* A factory for creating DiscreteSampler objects.
*/
interface DiscreteSamplerFactory {
/**
* Creates the sampler.
*
* @return the sampler
*/
DiscreteSampler create();
}

/**
* Gets the sampler.
*
@@ -162,7 +151,7 @@ public void setup() {

final double[] probabilities = createProbabilities();
createSamplerFactory(getGenerator(), probabilities);
sampler = factory.create();
sampler = factory.get();
}

/**
@@ -210,7 +199,7 @@ private void createSamplerFactory(final UniformRandomProvider rng,
* @return The sampler.
*/
public DiscreteSampler createSampler() {
return factory.create();
return factory.get();
}
}

@@ -18,7 +18,7 @@
package org.apache.commons.rng.examples.jmh.sampling.distribution;

import java.util.concurrent.TimeUnit;

import java.util.function.DoubleFunction;
import org.apache.commons.rng.RandomProviderState;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.UniformRandomProvider;
@@ -267,31 +267,18 @@ public double getMax() {
}
}

/**
* A factory for creating Poisson sampler objects.
*/
private interface PoissonSamplerFactory {
/**
* Creates a new Poisson sampler object.
*
* @param mean the mean
* @return The sampler
*/
DiscreteSampler createPoissonSampler(double mean);
}

/**
* Exercises a poisson sampler created for a single use with a range of means.
*
* @param factory The factory.
* @param range The range of means.
* @param bh Data sink.
*/
private static void runSample(PoissonSamplerFactory factory,
private static void runSample(DoubleFunction<DiscreteSampler> factory,
MeanRange range,
Blackhole bh) {
for (int i = 0; i < NUM_SAMPLES; i++) {
bh.consume(factory.createPoissonSampler(range.getMean(i)).sample());
bh.consume(factory.apply(range.getMean(i)).sample());
}
}

@@ -307,7 +294,7 @@ public void runPoissonSampler(Sources sources,
MeanRange range,
Blackhole bh) {
final UniformRandomProvider r = sources.getGenerator();
final PoissonSamplerFactory factory = mean -> PoissonSampler.of(r, mean);
final DoubleFunction<DiscreteSampler> factory = mean -> PoissonSampler.of(r, mean);
runSample(factory, range, bh);
}

@@ -322,7 +309,7 @@ public void runPoissonSamplerCacheWhenEmpty(Sources sources,
Blackhole bh) {
final UniformRandomProvider r = sources.getGenerator();
final PoissonSamplerCache cache = new PoissonSamplerCache(0, 0);
final PoissonSamplerFactory factory = mean -> cache.createSharedStateSampler(r, mean);
final DoubleFunction<DiscreteSampler> factory = mean -> cache.createSharedStateSampler(r, mean);
runSample(factory, range, bh);
}

@@ -338,7 +325,7 @@ public void runPoissonSamplerCache(Sources sources,
final UniformRandomProvider r = sources.getGenerator();
final PoissonSamplerCache cache = new PoissonSamplerCache(
range.getMin(), range.getMax());
final PoissonSamplerFactory factory = mean -> cache.createSharedStateSampler(r, mean);
final DoubleFunction<DiscreteSampler> factory = mean -> cache.createSharedStateSampler(r, mean);
runSample(factory, range, bh);
}
}
@@ -37,6 +37,7 @@
import org.openjdk.jmh.annotations.Warmup;

import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

/**
* Executes benchmark to compare the speed of generation of Poisson distributed random numbers.
@@ -136,23 +137,11 @@ public static class Sources {
private UniformRandomProvider generator;

/** The factory. */
private DiscreteSamplerFactory factory;
private Supplier<DiscreteSampler> factory;

/** The sampler. */
private DiscreteSampler sampler;

/**
* A factory for creating DiscreteSampler objects.
*/
interface DiscreteSamplerFactory {
/**
* Creates the sampler.
*
* @return the sampler
*/
DiscreteSampler create();
}

/**
* @return The RNG.
*/
@@ -193,7 +182,7 @@ public void setup() {
} else if ("TinyMeanPoissonSampler".equals(samplerType)) {
factory = () -> new TinyMeanPoissonSampler(generator, mean);
}
sampler = factory.create();
sampler = factory.get();
}

/**
@@ -202,7 +191,7 @@ public void setup() {
* @return The sampler.
*/
public DiscreteSampler createSampler() {
return factory.create();
return factory.get();
}
}

@@ -20,6 +20,7 @@
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
import org.apache.commons.rng.sampling.distribution.LongSampler;
import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
import org.apache.commons.rng.simple.RandomSource;
@@ -148,16 +149,6 @@ public void setup() {
}
}

/**
* Sampler that generates values of type {@code long}.
*/
interface LongSampler {
/**
* @return a sample.
*/
long sample();
}

/**
* Defines method to use for creating unsigned {@code long} values.
*/

0 comments on commit da20fc8

Please sign in to comment.