Skip to content
Permalink
Browse files
RNG-177: Add stream methods to the sampler API
  • Loading branch information
aherbert committed May 4, 2022
1 parent 0ead6fd commit dc779b4a919ae88b6094ec2c8f7811e38070c3bf
Showing 8 changed files with 410 additions and 5 deletions.
@@ -17,6 +17,8 @@

package org.apache.commons.rng.sampling;

import java.util.stream.Stream;

/**
* Sampler that generates values of a specified type.
*
@@ -25,9 +27,37 @@
*/
public interface ObjectSampler<T> {
/**
* Create a sample.
* Create an object sample.
*
* @return a sample
* @return a sample.
*/
T sample();

/**
* Returns an effectively unlimited stream of object sample values.
*
* <p>The default implementation produces a sequential stream that repeatedly
* calls {@link #sample sample}().
*
* @return a stream of object values.
* @since 1.5
*/
default Stream<T> samples() {
return Stream.generate(this::sample).sequential();
}

/**
* Returns a stream producing the given {@code streamSize} number of object
* sample values.
*
* <p>The default implementation produces a sequential stream that repeatedly
* calls {@link #sample sample}(); the stream is limited to the given {@code streamSize}.
*
* @param streamSize Number of values to generate.
* @return a stream of object values.
* @since 1.5
*/
default Stream<T> samples(long streamSize) {
return samples().limit(streamSize);
}
}
@@ -16,16 +16,46 @@
*/
package org.apache.commons.rng.sampling.distribution;

import java.util.stream.DoubleStream;

/**
* Sampler that generates values of type {@code double}.
*
* @since 1.0
*/
public interface ContinuousSampler {
/**
* Creates a sample.
* Creates a {@code double} sample.
*
* @return a sample.
*/
double sample();

/**
* Returns an effectively unlimited stream of {@code double} sample values.
*
* <p>The default implementation produces a sequential stream that repeatedly
* calls {@link #sample sample}().
*
* @return a stream of {@code double} values.
* @since 1.5
*/
default DoubleStream samples() {
return DoubleStream.generate(this::sample).sequential();
}

/**
* Returns a stream producing the given {@code streamSize} number of {@code double}
* sample values.
*
* <p>The default implementation produces a sequential stream that repeatedly
* calls {@link #sample sample}(); the stream is limited to the given {@code streamSize}.
*
* @param streamSize Number of values to generate.
* @return a stream of {@code double} values.
* @since 1.5
*/
default DoubleStream samples(long streamSize) {
return samples().limit(streamSize);
}
}
@@ -16,16 +16,46 @@
*/
package org.apache.commons.rng.sampling.distribution;

import java.util.stream.IntStream;

/**
* Sampler that generates values of type {@code int}.
*
* @since 1.0
*/
public interface DiscreteSampler {
/**
* Creates a sample.
* Creates an {@code int} sample.
*
* @return a sample.
*/
int sample();

/**
* Returns an effectively unlimited stream of {@code int} sample values.
*
* <p>The default implementation produces a sequential stream that repeatedly
* calls {@link #sample sample}().
*
* @return a stream of {@code int} values.
* @since 1.5
*/
default IntStream samples() {
return IntStream.generate(this::sample).sequential();
}

/**
* Returns a stream producing the given {@code streamSize} number of {@code int}
* sample values.
*
* <p>The default implementation produces a sequential stream that repeatedly
* calls {@link #sample sample}(); the stream is limited to the given {@code streamSize}.
*
* @param streamSize Number of values to generate.
* @return a stream of {@code int} values.
* @since 1.5
*/
default IntStream samples(long streamSize) {
return samples().limit(streamSize);
}
}
@@ -16,16 +16,46 @@
*/
package org.apache.commons.rng.sampling.distribution;

import java.util.stream.LongStream;

/**
* Sampler that generates values of type {@code long}.
*
* @since 1.4
*/
public interface LongSampler {
/**
* Creates a sample.
* Creates a {@code long} sample.
*
* @return a sample.
*/
long sample();

/**
* Returns an effectively unlimited stream of {@code long} sample values.
*
* <p>The default implementation produces a sequential stream that repeatedly
* calls {@link #sample sample}().
*
* @return a stream of {@code long} values.
* @since 1.5
*/
default LongStream samples() {
return LongStream.generate(this::sample).sequential();
}

/**
* Returns a stream producing the given {@code streamSize} number of {@code long}
* sample values.
*
* <p>The default implementation produces a sequential stream that repeatedly
* calls {@link #sample sample}(); the stream is limited to the given {@code streamSize}.
*
* @param streamSize Number of values to generate.
* @return a stream of {@code long} values.
* @since 1.5
*/
default LongStream samples(long streamSize) {
return samples().limit(streamSize);
}
}
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling;

import java.util.concurrent.ThreadLocalRandom;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/**
* Tests the default methods in the {@link ObjectSampler} interface.
*/
class ObjectSamplerTest {
@Test
void testSamplesUnlimitedSize() {
final ObjectSampler<Double> s = RandomSource.SPLIT_MIX_64.create()::nextDouble;
Assertions.assertEquals(Long.MAX_VALUE, s.samples().spliterator().estimateSize());
}

@RepeatedTest(value = 3)
void testSamples() {
final long seed = RandomSource.createLong();
final ObjectSampler<Double> s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble;
final ObjectSampler<Double> s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble;
final int count = ThreadLocalRandom.current().nextInt(3, 13);
Assertions.assertArrayEquals(createSamples(s1, count),
s2.samples().limit(count).toArray());
}

@ParameterizedTest
@ValueSource(ints = {0, 1, 2, 5, 13})
void testSamples(int streamSize) {
final long seed = RandomSource.createLong();
final ObjectSampler<Double> s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble;
final ObjectSampler<Double> s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble;
Assertions.assertArrayEquals(createSamples(s1, streamSize),
s2.samples(streamSize).toArray());
}

/**
* Creates an array of samples.
*
* @param sampler Source of samples.
* @param count Number of samples.
* @return the samples
*/
private static Double[] createSamples(ObjectSampler<Double> sampler, int count) {
final Double[] data = new Double[count];
for (int i = 0; i < count; i++) {
// Explicit boxing
data[i] = Double.valueOf(sampler.sample());
}
return data;
}
}
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling.distribution;

import java.util.concurrent.ThreadLocalRandom;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/**
* Tests the default methods in the {@link ContinuousSampler} interface.
*/
class ContinuousSamplerTest {
@Test
void testSamplesUnlimitedSize() {
final ContinuousSampler s = RandomSource.SPLIT_MIX_64.create()::nextDouble;
Assertions.assertEquals(Long.MAX_VALUE, s.samples().spliterator().estimateSize());
}

@RepeatedTest(value = 3)
void testSamples() {
final long seed = RandomSource.createLong();
final ContinuousSampler s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble;
final ContinuousSampler s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble;
final int count = ThreadLocalRandom.current().nextInt(3, 13);
Assertions.assertArrayEquals(createSamples(s1, count),
s2.samples().limit(count).toArray());
}

@ParameterizedTest
@ValueSource(ints = {0, 1, 2, 5, 13})
void testSamples(int streamSize) {
final long seed = RandomSource.createLong();
final ContinuousSampler s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble;
final ContinuousSampler s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble;
Assertions.assertArrayEquals(createSamples(s1, streamSize),
s2.samples(streamSize).toArray());
}

/**
* Creates an array of samples.
*
* @param sampler Source of samples.
* @param count Number of samples.
* @return the samples
*/
private static double[] createSamples(ContinuousSampler sampler, int count) {
final double[] data = new double[count];
for (int i = 0; i < count; i++) {
data[i] = sampler.sample();
}
return data;
}
}

0 comments on commit dc779b4

Please sign in to comment.