Skip to content

Commit

Permalink
MATH-1597: LowDiscrepancySequence supplier/jump for Halton and Sobol
Browse files Browse the repository at this point in the history
  • Loading branch information
samyBadjoudj committed Jul 6, 2021
1 parent 7f42535 commit 0412ad3
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.random;

import java.util.function.Supplier;
package org.apache.commons.math4.legacy.quasirandom;

import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.NotPositiveException;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;

import java.util.Arrays;

/**
* Implementation of a Halton sequence.
* <p>
Expand All @@ -43,15 +44,15 @@
* The generator supports two modes:
* <ul>
* <li>sequential generation of points: {@link #get()}</li>
* <li>random access to the i-th point in the sequence: {@link #skipTo(int)}</li>
* <li>random access to the i-th point in the sequence: {@link LowDiscrepancySequence#jump(long)}</li>
* </ul>
*
* @see <a href="http://en.wikipedia.org/wiki/Halton_sequence">Halton sequence (Wikipedia)</a>
* @see <a href="https://lirias.kuleuven.be/bitstream/123456789/131168/1/mcm2005_bartv.pdf">
* On the Halton sequence and its scramblings</a>
* @since 3.3
*/
public class HaltonSequenceGenerator implements Supplier<double[]> {
public class HaltonSequenceGenerator implements LowDiscrepancySequence {

/** The first 40 primes. */
private static final int[] PRIMES = new int[] {
Expand All @@ -71,7 +72,7 @@ public class HaltonSequenceGenerator implements Supplier<double[]> {
private final int dimension;

/** The current index in the sequence. */
private int count;
private long count;

/** The base numbers for each component. */
private final int[] base;
Expand All @@ -89,6 +90,13 @@ public HaltonSequenceGenerator(final int dimension) {
this(dimension, PRIMES, WEIGHTS);
}

private HaltonSequenceGenerator(HaltonSequenceGenerator original){
this.base = original.base;
this.count = original.count;
this.dimension = original.dimension;
this.weight = Arrays.copyOf(original.weight,original.weight.length);
}

/**
* Construct a new Halton sequence generator with the given base numbers and weights for each dimension.
* The length of the bases array defines the space dimension and is required to be &gt; 0.
Expand Down Expand Up @@ -123,12 +131,12 @@ public HaltonSequenceGenerator(final int dimension, final int[] bases, final int
public double[] get() {
final double[] v = new double[dimension];
for (int i = 0; i < dimension; i++) {
int index = count;
long index = count;
double f = 1.0 / base[i];

int j = 0;
while (index > 0) {
final int digit = scramble(i, j, base[i], index % base[i]);
final long digit = scramble(i, j, base[i], index % base[i]);
v[i] += f * digit;
index /= base[i]; // floor( index / base )
f /= base[i];
Expand All @@ -151,31 +159,44 @@ public double[] get() {
* @param digit the j-th digit
* @return the scrambled digit
*/
protected int scramble(final int i, final int j, final int b, final int digit) {
protected long scramble(final int i, final int j, final int b, final long digit) {
return weight != null ? (weight[i] * digit) % b : digit;
}

/**
* Skip to the i-th point in the Halton sequence.
* jump to the i-th point in the Halton sequence.
* <p>
* This operation can be performed in O(1).
*
* @param index the index in the sequence to skip to
* @return the i-th point in the Halton sequence
* @throws org.apache.commons.math4.legacy.exception.NotPositiveException NotPositiveException if index &lt; 0
* @return the copy of this sequence
* @throws NotPositiveException if index &lt; 0
*/
public double[] skipTo(final int index) {
@Override
public LowDiscrepancySequence jump(final long index) throws NotPositiveException {
if(index < 0) {
throw new NotPositiveException(index);
}
HaltonSequenceGenerator copy = this.copy();
copy.performJump(index);
return copy;
}

/**
* do jump at the index
* @param index
*/
private void performJump(long index) {
count = index;
return get();
}

/**
* Returns the index i of the next point in the Halton sequence that will be returned
* by calling {@link #get()}.
*
* @return the index of the next point
* private constructor avoid side effects
* @return copy of HaltonSequenceGenerator
*/
public int getNextIndex() {
return count;
private HaltonSequenceGenerator copy() {
return new HaltonSequenceGenerator(this);
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.quasirandom;

import java.util.function.Supplier;

/** Interface to Low Discrepancy Sequence generator and supplier
* Supplier of a low discrepancy vectors
* Offers navigation through underlying sequence
*/
public interface LowDiscrepancySequence extends Supplier<double[]> {
/**
* Skip to the index in the sequence
* @param index of the element to skip to
* @return T element at the index
*/
LowDiscrepancySequence jump(long index);
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.random;
package org.apache.commons.math4.legacy.quasirandom;

import java.io.BufferedReader;
import java.io.IOException;
Expand All @@ -24,10 +24,10 @@
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.StringTokenizer;
import java.util.function.Supplier;

import org.apache.commons.math4.legacy.exception.MathInternalError;
import org.apache.commons.math4.legacy.exception.MathParseException;
import org.apache.commons.math4.legacy.exception.NotPositiveException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.apache.commons.math4.legacy.core.jdkmath.AccurateMath;
Expand All @@ -45,15 +45,16 @@
* The generator supports two modes:
* <ul>
* <li>sequential generation of points: {@link #get()}</li>
* <li>random access to the i-th point in the sequence: {@link #skipTo(int)}</li>
* <li>random access to the i-th point in the sequence: {@link LowDiscrepancySequence#jump(long)}</li>
* </ul>
*
* @see <a href="http://en.wikipedia.org/wiki/Sobol_sequence">Sobol sequence (Wikipedia)</a>
* @see <a href="http://web.maths.unsw.edu.au/~fkuo/sobol/">Sobol sequence direction numbers</a>
*
* @since 3.3
*/
public class SobolSequenceGenerator implements Supplier<double[]> {
public class SobolSequenceGenerator implements LowDiscrepancySequence {

/** The number of bits to use. */
private static final int BITS = 52;

Expand All @@ -73,7 +74,7 @@ public class SobolSequenceGenerator implements Supplier<double[]> {
private final int dimension;

/** The current index in the sequence. */
private int count;
private long count;

/** The direction vector for each component. */
private final long[][] direction;
Expand Down Expand Up @@ -171,6 +172,13 @@ public SobolSequenceGenerator(final int dimension, final InputStream is)
}
}

private SobolSequenceGenerator(SobolSequenceGenerator sobolSequenceGenerator) {
this.dimension = sobolSequenceGenerator.dimension;
this.count = sobolSequenceGenerator.count;
this.direction = Arrays.copyOf(sobolSequenceGenerator.direction, sobolSequenceGenerator.direction.length);
this.x = Arrays.copyOf(sobolSequenceGenerator.x, sobolSequenceGenerator.x.length);
}

/**
* Load the direction vector for each dimension from the given stream.
* <p>
Expand All @@ -185,7 +193,7 @@ public SobolSequenceGenerator(final int dimension, final InputStream is)
private int initFromStream(final InputStream is) throws IOException {
// special case: dimension 1 -> use unit initialization
for (int i = 1; i <= BITS; i++) {
direction[0][i] = 1L << (BITS - i);
direction[0][i] = 1l << (BITS - i);
}

final Charset charset = Charset.forName(FILE_CHARSET);
Expand Down Expand Up @@ -261,7 +269,7 @@ public double[] get() {

// find the index c of the rightmost 0
int c = 1;
int value = count - 1;
long value = count - 1;
while ((value & 1) == 1) {
value >>= 1;
c++;
Expand All @@ -281,15 +289,30 @@ public double[] get() {
* This operation can be performed in O(1).
*
* @param index the index in the sequence to skip to
* @return the i-th point in the Sobol sequence
* @throws org.apache.commons.math4.legacy.exception.NotPositiveException NotPositiveException if index &lt; 0
* @return the sequence it self
* @throws NotPositiveException if index &lt; 0
*/
public double[] skipTo(final int index) {
@Override
public LowDiscrepancySequence jump(final long index) throws NotPositiveException {
if(index < 0) {
throw new NotPositiveException(index);
}

SobolSequenceGenerator copy = this.copy();
copy.performJump(index);
return copy;
}

/**
* do jump at the index
* @param index
*/
private void performJump(long index) {
if (index == 0) {
// reset x vector
Arrays.fill(x, 0);
} else {
final int i = index - 1;
final long i = index - 1;
final long grayCode = i ^ (i >> 1); // compute the gray code of i = i XOR floor(i / 2)
for (int j = 0; j < dimension; j++) {
long result = 0;
Expand All @@ -307,17 +330,14 @@ public double[] skipTo(final int index) {
}
}
count = index;
return get();
}

/**
* Returns the index i of the next point in the Sobol sequence that will be returned
* by calling {@link #get()}.
*
* @return the index of the next point
* private constructor avoid side effects
* @return copy of LowDiscrepancySequence
*/
public int getNextIndex() {
return count;
private SobolSequenceGenerator copy() {
return new SobolSequenceGenerator(this);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.quasirandom;
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.random;
package org.apache.commons.math4.legacy.quasirandom;

import org.junit.Assert;
import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.NotPositiveException;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

Expand Down Expand Up @@ -63,7 +64,6 @@ public void test3DReference() {
for (int i = 0; i < referenceValues.length; i++) {
double[] result = generator.get();
Assert.assertArrayEquals(referenceValues[i], result, 1e-3);
Assert.assertEquals(i + 1, generator.getNextIndex());
}
}

Expand All @@ -73,7 +73,6 @@ public void test2DUnscrambledReference() {
for (int i = 0; i < referenceValuesUnscrambled.length; i++) {
double[] result = generator.get();
Assert.assertArrayEquals(referenceValuesUnscrambled[i], result, 1e-3);
Assert.assertEquals(i + 1, generator.getNextIndex());
}
}

Expand Down Expand Up @@ -119,16 +118,30 @@ public void testConstructor2() throws Exception{
}

@Test
public void testSkip() {
double[] result = generator.skipTo(5);
public void testJump() {
LowDiscrepancySequence copyOfSeq = generator.jump(5);
double[] result = copyOfSeq.get();
Assert.assertArrayEquals(referenceValues[5], result, 1e-3);
Assert.assertEquals(6, generator.getNextIndex());


for (int i = 6; i < referenceValues.length; i++) {
result = generator.get();
result = copyOfSeq.get();
Assert.assertArrayEquals(referenceValues[i], result, 1e-3);
Assert.assertEquals(i + 1, generator.getNextIndex());
}
}


@Test(expected = NotPositiveException.class)
public void testJumpNegativeIndex() {
LowDiscrepancySequence copyOfSeq = generator.jump(-5);
}



@Test
public void testFirstSupplying() {
LowDiscrepancySequence sequence = new HaltonSequenceGenerator(3);
Assert.assertArrayEquals(new double[]{0.0, 0.0, 0.0}, sequence.get(),1e-6);
}

}

0 comments on commit 0412ad3

Please sign in to comment.