Skip to content

Commit

Permalink
MATH-1623: Add parameterized unit tests for simplex-based optimizers.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gilles Sadowski committed Aug 9, 2021
1 parent c40a306 commit fdbb8b9
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;

import java.util.Arrays;
import org.opentest4j.AssertionFailedError;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.aggregator.ArgumentsAggregator;
import org.junit.jupiter.params.aggregator.ArgumentsAccessor;
import org.junit.jupiter.params.aggregator.ArgumentsAggregationException;
import org.junit.jupiter.params.aggregator.AggregateWith;
import org.junit.jupiter.params.provider.CsvFileSource;
import org.apache.commons.math4.legacy.core.MathArrays;
import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
import org.apache.commons.math4.legacy.optim.InitialGuess;
import org.apache.commons.math4.legacy.optim.MaxEval;
import org.apache.commons.math4.legacy.optim.PointValuePair;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.TestFunction;

/**
* Tests for {@link SimplexOptimizer simplex-based algorithms}.
*/
public class SimplexOptimizerTest {
private static final String NELDER_MEAD_INPUT_FILE = "std_test_func.simplex.nelder_mead.csv";
private static final String MULTIDIRECTIONAL_INPUT_FILE = "std_test_func.simplex.multidirectional.csv";

@ParameterizedTest
@CsvFileSource(resources = NELDER_MEAD_INPUT_FILE)
void testFunctionWithNelderMead(@AggregateWith(TaskAggregator.class) Task task) {
task.run(new NelderMeadTransform());
}

@ParameterizedTest
@CsvFileSource(resources = MULTIDIRECTIONAL_INPUT_FILE)
void testFunctionWithMultiDirectional(@AggregateWith(TaskAggregator.class) Task task) {
task.run(new MultiDirectionalTransform());
}

/**
* Optimization task.
*/
public static class Task {
/** Function evaluations hard count (debugging). */
private static final int FUNC_EVAL_DEBUG = 20000;
/** Default convergence criterion. */
private static final double CONVERGENCE_CHECK = 1e-9;
/** Default simplex size. */
private static final double SIDE_LENGTH = 1;
/** Function. */
private final MultivariateFunction f;
/** Initial value. */
private final double[] start;
/** Optimum. */
private final double[] optimum;
/** Tolerance. */
private final double pointTolerance;
/** Allowed function evaluations. */
private final int functionEvaluations;
/** Repeats on failure. */
private final int repeatsOnFailure;
/** Range of random noise. */
private double jitter;

/**
* @param f Test function.
* @param start Start point.
* @param optimum Optimum.
* @param pointTolerance Allowed distance between result and
* {@code optimum}.
* @param functionEvaluations Allowed number of function evaluations.
* @param repeatsOnFailure Maximum number of times to rerun when an
* {@link AssertionFailedError} is thrown.
* @param jitter Size of random jitter.
*/
Task(MultivariateFunction f,
double[] start,
double[] optimum,
double pointTolerance,
int functionEvaluations,
int repeatsOnFailure,
double jitter) {
this.f = f;
this.start = start;
this.optimum = optimum;
this.pointTolerance = pointTolerance;
this.functionEvaluations = functionEvaluations;
this.repeatsOnFailure = repeatsOnFailure;
this.jitter = jitter;
}

@Override
public String toString() {
return f.toString();
}

/**
* @param factory Simplex transform factory.
*/
public void run(Simplex.TransformFactory factory) {
// Let run with a maximum number of evaluations larger than expected
// (as specified by "functionEvaluations") in order to have the unit
// test failure message (see assertion below) report the actual number
// required by the current code.
final int maxEval = Math.max(functionEvaluations, FUNC_EVAL_DEBUG);

int currentRetry = -1;
AssertionFailedError lastFailure = null;
while (currentRetry++ <= repeatsOnFailure) {
try {
final String name = f.toString();

final SimplexOptimizer optim = new SimplexOptimizer(-1, CONVERGENCE_CHECK);
final Simplex initialSimplex =
Simplex.alongAxes(OptimTestUtils.point(start.length,
SIDE_LENGTH,
jitter));
final double[] startPoint = OptimTestUtils.point(start, jitter);
final PointValuePair result =
optim.optimize(new MaxEval(maxEval),
new ObjectiveFunction(f),
GoalType.MINIMIZE,
new InitialGuess(startPoint),
initialSimplex,
factory);

final double[] endPoint = result.getPoint();
final double funcValue = result.getValue();
final double dist = MathArrays.distance(optimum, endPoint);
Assertions.assertEquals(0d, dist, pointTolerance,
name + ": distance to optimum" +
" f(" + Arrays.toString(endPoint) + ")=" +
funcValue);

final int nEval = optim.getEvaluations();
Assertions.assertTrue(nEval < functionEvaluations,
name + ": nEval=" + nEval);

break; // Assertions passed: Retry not neccessary.
} catch (AssertionFailedError e) {
if (currentRetry >= repeatsOnFailure) {
// Allowed repeats have been exhausted: Bail out.
throw e;
}
}
}
}
}

/**
* Helper for preparing a {@link Task}.
*/
public static class TaskAggregator implements ArgumentsAggregator {
@Override
public Object aggregateArguments(ArgumentsAccessor a,
ParameterContext context)
throws ArgumentsAggregationException {

int index = 0; // Argument index.

final TestFunction funcGen = a.get(index++, TestFunction.class);
final int dim = a.getInteger(index++);
final double[] start = toArrayOfDoubles(a.getString(index++), dim);
final double[] optimum = toArrayOfDoubles(a.getString(index++), dim);
final double pointTol = a.getDouble(index++);
final int funcEval = a.getInteger(index++);
final int repeat = a.getInteger(index++);
final double jitter = a.getDouble(index++);

return new Task(funcGen.withDimension(dim),
start,
optimum,
pointTol,
funcEval,
repeat,
jitter);
}

/**
* @param params Comma-separated list of values.
* @param dim Expected number of values.
* @return an array of {@code double} values.
* @throws ArgumentsAggregationException if the number of values
* is not equal to {@code dim}.
*/
private static double[] toArrayOfDoubles(String params,
int dim) {
final String[] s = params.trim().split("\\s+");

if (s.length != dim) {
final String msg = "Expected " + dim + " values: " + Arrays.toString(s);
throw new ArgumentsAggregationException(msg);
}

final double[] p = new double[dim];
for (int i = 0; i < dim; i++) {
p[i] = Double.valueOf(s[i]);
}

return p;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# CSV Format defined in "SimplexOptimizerTest.Task" class.
# Columns are:
# 0: function name (value from "TestFunction" enum)
# 1: space dimension (n)
# 2: nominal start point ("n" space-separated values)
# 3: optimum ("n" space-separated values)
# 4: maximum expected distance from the result to the optimum
# 5: expected number of function evaluations
# 6: number of retries in case of assertion failure
# 7: size of the random noise (to generate slightly different initial conditions)
#
# Caveat: Some tests are commented out (cf. JIRA: MATH-1552).
#
PARABOLA, 4, 2.5 3.1 4.6 5.8, 0 0 0 0, 1e-4, 340, 3, 2e-1
ROSENBROCK, 2, -1.2 1, 1 1, 2e-3, 9200, 3, 1e-1
ROSENBROCK, 10, -0.1 0.1 0.2 -0.1 -0.2 0.3 0.2 -0.1 0.2 -0.3, 1 1 1 1 1 1 1 1 1 1, 3e-3, 100000, 3, 5e-2
#POWELL, 4, 3 -1 -2 1, 0 0 0 0, 5e-3, 420, 3, 1e-1
CIGAR, 13, -1.2 2.3 -3.2 2.1 1.2 -2.3 3.2 -2.1 -1.2 2.3 -3.2 2.1 -1.2, 0 0 0 0 0 0 0 0 0 0 0 0 0, 1e-6, 7000, 3, 1e-1
SPHERE, 13, -1.2 2.3 -3.2 2.1 1.2 -2.3 3.2 -2.1 -1.2 2.3 -3.2 2.1 -1.2, 0 0 0 0 0 0 0 0 0 0 0 0 0, 5e-5, 3600, 3, 1e-1
ELLI, 10, 2 3 4 -3 -2 -1 2 3 4 3, 0 0 0 0 0 0 0 0 0 0, 1e-4, 50000, 3, 1e-1
TWO_AXES, 10, 2 3 4 -3 -2 -1 2 3 4 3, 0 0 0 0 0 0 0 0 0 0, 1e-6, 3200, 3, 1e-1
CIG_TAB, 10, 2 3 4 -3 -2 -1 2 3 4 3, 0 0 0 0 0 0 0 0 0 0, 5e-6, 2700, 3, 1e-1
TABLET, 11, 2 3 4 -3 -2 -1 2 3 4 3 -1, 0 0 0 0 0 0 0 0 0 0 0, 5e-6, 3000, 3, 1e-1
DIFF_POW, 7, 1 -1 1 -1 1 -1 1, 0 0 0 0 0 0 0, 5e-4, 2500, 3, 1e-1
SS_DIFF_POW, 6, -3.2 2.1 1.2 -2.3 3.2 -2.1, 0 0 0 0 0 0, 1e-3, 4000, 3, 1e-1
ACKLEY, 4, 3 4 -3 -2, 0 0 0 0, 1e-6, 700, 3, 5e-1
#RASTRIGIN, 4, 3 4 -3 -2, 0 0 0 0, 1e-6, 10000, 3, 5e-1
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# CSV Format defined in "SimplexOptimizerTest.Task" class.
# Columns are:
# 0: function name (value from "TestFunction" enum)
# 1: space dimension (n)
# 2: nominal start point ("n" space-separated values)
# 3: optimum ("n" space-separated values)
# 4: maximum expected distance from the result to the optimum
# 5: expected number of function evaluations
# 6: number of retries in case of assertion failure
# 7: size of the random noise (to generate slightly different initial conditions)
#
# Caveat: Some tests are commented out (cf. JIRA: MATH-1552).
#
PARABOLA, 4, 2.5 3.1 4.6 5.8, 0 0 0 0, 1e-4, 200, 3, 2e-1
ROSENBROCK, 2, -1.2 1, 1 1, 1e-4, 180, 3, 1e-1
ROSENBROCK, 10, -0.1 0.1 0.2 -0.1 -0.2 0.3 0.2 -0.1 0.2 -0.3, 1 1 1 1 1 1 1 1 1 1, 5e-5, 9000, 3, 5e-2
POWELL, 4, 3 -1 -2 1, 0 0 0 0, 5e-3, 420, 3, 1e-1
CIGAR, 13, -1.2 2.3 -3.2 2.1 1.2 -2.3 3.2 -2.1 -1.2 2.3 -3.2 2.1 -1.2, 0 0 0 0 0 0 0 0 0 0 0 0 0, 5e-5, 7000, 3, 1e-1
SPHERE, 13, -1.2 2.3 -3.2 2.1 1.2 -2.3 3.2 -2.1 -1.2 2.3 -3.2 2.1 -1.2, 0 0 0 0 0 0 0 0 0 0 0 0 0, 5e-4, 3000, 3, 1e-1
ELLI, 10, 2 3 4 -3 -2 -1 2 3 4 3, 0 0 0 0 0 0 0 0 0 0, 1e-4, 50000, 3, 1e-1
#TWO_AXES, 10, 2 3 4 -3 -2 -1 2 3 4 3, 0 0 0 0 0 0 0 0 0 0, 1e-4, 5000, 3, 1e-1
#CIG_TAB, 10, 2 3 4 -3 -2 -1 2 3 4 3, 0 0 0 0 0 0 0 0 0 0, 1e-4, 7000, 3, 1e-1
TABLET, 11, 2 3 4 -3 -2 -1 2 3 4 3 -1, 0 0 0 0 0 0 0 0 0 0 0, 2e-4, 3100, 3, 1e-1
#DIFF_POW, 7, 1 -1 1 -1 1 -1 1, 0 0 0 0 0 0 0, 5e-4, 2500, 3, 1e-1
SS_DIFF_POW, 6, -3.2 2.1 1.2 -2.3 3.2 -2.1, 0 0 0 0 0 0, 1e-3, 4000, 3, 1e-1
ACKLEY, 4, 3 4 -3 -2, 0 0 0 0, 1e-6, 350, 3, 5e-1
#RASTRIGIN, 4, 3 4 -3 -2, 0 0 0 0, 1e-6, 10000, 3, 5e-1

0 comments on commit fdbb8b9

Please sign in to comment.