diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java index 08fb07fdbfd4f..dec0fb7f9aef2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java @@ -17,10 +17,6 @@ package org.apache.ignite.ml.nn; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; import org.apache.ignite.ml.composition.CompositionUtils; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -28,6 +24,7 @@ import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.primitives.matrix.Matrix; @@ -40,6 +37,11 @@ import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer; import org.apache.ignite.ml.util.Utils; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + /** * Multilayer perceptron trainer based on partition based {@link Dataset}. * @@ -378,4 +380,10 @@ static double[] batch(double[] data, int[] rows, int totalRows) { return res; } + + /** {@inheritDoc} */ + @Override public MLPTrainer

withEnvironmentBuilder( + LearningEnvironmentBuilder envBuilder) { + return (MLPTrainer

)super.withEnvironmentBuilder(envBuilder); + } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java index 4fcef08a52e65..345a885005c05 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java @@ -17,7 +17,6 @@ package org.apache.ignite.ml.regressions.logistic; -import java.util.Arrays; import org.apache.ignite.ml.composition.CompositionUtils; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -39,6 +38,8 @@ import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.jetbrains.annotations.NotNull; +import java.util.Arrays; + /** * Trainer of the logistic regression model based on stochastic gradient descent algorithm. */ @@ -103,7 +104,7 @@ public class LogisticRegressionSGDTrainer extends SingleLabelDatasetTrainer lbExtractorWrapper = (k, v) -> new double[] {lbExtractor.apply(k, v)}; MultilayerPerceptron mlp; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java index 3e48a4ae570c4..6e8fd1e01ee40 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java @@ -96,9 +96,10 @@ private AdaptableDatasetTrainer(IgniteFunction before, DatasetTrainer AdaptableDatasetModel fit(DatasetBuilder datasetBuilder, FeatureLabelExtractor extractor) { - M fit = wrapped.fit( - datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder), - extractor.andThen(afterExtractor)); + M fit = wrapped. + withEnvironmentBuilder(envBuilder) + .fit(datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder), + extractor.andThen(afterExtractor)); return new AdaptableDatasetModel<>(before, fit, after); } @@ -112,10 +113,11 @@ private AdaptableDatasetTrainer(IgniteFunction before, DatasetTrainer AdaptableDatasetModel updateModel( AdaptableDatasetModel mdl, DatasetBuilder datasetBuilder, FeatureLabelExtractor extractor) { - M updated = wrapped.updateModel( - mdl.innerModel(), - datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder), - extractor.andThen(afterExtractor)); + M updated = wrapped.withEnvironmentBuilder(envBuilder) + .updateModel( + mdl.innerModel(), + datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder), + extractor.andThen(afterExtractor)); return mdl.withInnerModel(updated); } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java index 06e9ef2ab9a06..052fc96111be8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java @@ -17,7 +17,6 @@ package org.apache.ignite.ml; -import java.util.stream.IntStream; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.primitives.matrix.Matrix; @@ -26,6 +25,8 @@ import org.apache.ignite.ml.trainers.FeatureLabelExtractor; import org.junit.Assert; +import java.util.stream.IntStream; + import static org.junit.Assert.assertTrue; /** */ @@ -170,6 +171,31 @@ public static void assertEquals(Matrix exp, Matrix actual) { } } + /** + * Verifies that two vectors are equal. + * + * @param exp Expected vector. + * @param observed Actual vector. + */ + public static void assertEquals(Vector exp, Vector observed, double eps) { + Assert.assertNotNull("Observed should not be null", observed); + + if (exp.size() != observed.size()) { + String msgBuff = "Observed has incorrect dimensions." + + "\nobserved is " + observed.size() + + " x " + observed.size(); + + Assert.fail(msgBuff); + } + + for (int i = 0; i < exp.size(); ++i) { + double eij = exp.getX(i); + double aij = observed.getX(i); + + Assert.assertEquals(eij, aij, eps); + } + } + /** * Verifies that two double arrays are close (sup norm). * diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java index 87d56cd62b1f3..bf0943c0d7433 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.composition; +import org.apache.ignite.ml.composition.bagging.BaggingTest; import org.apache.ignite.ml.composition.boosting.GDBTrainerTest; import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregatorTest; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregatorTest; diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/bagging/BaggingTest.java similarity index 85% rename from modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java rename to modules/ml/src/test/java/org/apache/ignite/ml/composition/bagging/BaggingTest.java index 7a84b64f0ce56..1fc218cdf431b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/bagging/BaggingTest.java @@ -15,15 +15,12 @@ * limitations under the License. */ -package org.apache.ignite.ml.composition; +package org.apache.ignite.ml.composition.bagging; -import java.util.Arrays; -import java.util.Map; import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; -import org.apache.ignite.ml.composition.bagging.BaggedModel; -import org.apache.ignite.ml.composition.bagging.BaggedTrainer; +import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition; import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; import org.apache.ignite.ml.dataset.Dataset; @@ -38,15 +35,37 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.trainers.AdaptableDatasetModel; import org.apache.ignite.ml.trainers.DatasetTrainer; import org.apache.ignite.ml.trainers.FeatureLabelExtractor; import org.apache.ignite.ml.trainers.TrainerTransformers; import org.junit.Test; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + /** * Tests for bagging algorithm. */ public class BaggingTest extends TrainerTest { + /** + * Dependency of weights of first model in ensemble after training in + * {@link BaggingTest#testNaiveBaggingLogRegression()}. This dependency is tested to ensure that it is + * fully determined by provided seeds. + */ + private static Map firstModelWeights; + + static { + firstModelWeights = new HashMap<>(); + + firstModelWeights.put(1, VectorUtils.of(-0.14721735583126058, 4.366377931980097)); + firstModelWeights.put(2, VectorUtils.of(-1.0092940937477968, 1.2950461550870134)); + firstModelWeights.put(3, VectorUtils.of(-5.5345231104301655, -0.7554216668724918)); + firstModelWeights.put(4, VectorUtils.of(0.136489632011201, 1.0937407007786915)); + firstModelWeights.put(13, VectorUtils.of(-0.27321382073998685, 1.1199411864901687)); + } + /** * Test that count of entries in context is equal to initial dataset size * subsampleRatio. */ @@ -81,7 +100,7 @@ public void testNaiveBaggingLogRegression() { BaggedTrainer baggedTrainer = TrainerTransformers.makeBagged( trainer, - 10, + 7, 0.7, 2, 2, @@ -95,6 +114,10 @@ public void testNaiveBaggingLogRegression() { (k, v) -> v[0] ); + Vector weights = ((LogisticRegressionModel)((AdaptableDatasetModel)((ModelsParallelComposition)((AdaptableDatasetModel)mdl + .model()).innerModel()).submodels().get(0)).innerModel()).weights(); + + TestUtils.assertEquals(firstModelWeights.get(parts), weights, 0.0); TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 10)), PRECISION); TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(10, 100)), PRECISION); }