diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
index 08fb07fdbfd4f..dec0fb7f9aef2 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
@@ -17,10 +17,6 @@
package org.apache.ignite.ml.nn;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -28,6 +24,7 @@
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
@@ -40,6 +37,11 @@
import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
import org.apache.ignite.ml.util.Utils;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
/**
* Multilayer perceptron trainer based on partition based {@link Dataset}.
*
@@ -378,4 +380,10 @@ static double[] batch(double[] data, int[] rows, int totalRows) {
return res;
}
+
+ /** {@inheritDoc} Narrows the returned trainer type to {@code MLPTrainer}. */
+ @Override public MLPTrainer<P> withEnvironmentBuilder(
+ LearningEnvironmentBuilder envBuilder) {
+ return (MLPTrainer<P>)super.withEnvironmentBuilder(envBuilder);
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
index 4fcef08a52e65..345a885005c05 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
@@ -17,7 +17,6 @@
package org.apache.ignite.ml.regressions.logistic;
-import java.util.Arrays;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -39,6 +38,8 @@
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;
+import java.util.Arrays;
+
/**
* Trainer of the logistic regression model based on stochastic gradient descent algorithm.
*/
@@ -103,7 +104,7 @@ public class LogisticRegressionSGDTrainer extends SingleLabelDatasetTrainer lbExtractorWrapper = (k, v) -> new double[] {lbExtractor.apply(k, v)};
MultilayerPerceptron mlp;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java
index 3e48a4ae570c4..6e8fd1e01ee40 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java
@@ -96,9 +96,10 @@ private AdaptableDatasetTrainer(IgniteFunction before, DatasetTrainer AdaptableDatasetModel fit(DatasetBuilder datasetBuilder,
FeatureLabelExtractor extractor) {
- M fit = wrapped.fit(
- datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
- extractor.andThen(afterExtractor));
+ M fit = wrapped.withEnvironmentBuilder(envBuilder)
+ .fit(
+ datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
+ extractor.andThen(afterExtractor));
return new AdaptableDatasetModel<>(before, fit, after);
}
@@ -112,10 +113,11 @@ private AdaptableDatasetTrainer(IgniteFunction before, DatasetTrainer AdaptableDatasetModel updateModel(
AdaptableDatasetModel mdl, DatasetBuilder datasetBuilder,
FeatureLabelExtractor extractor) {
- M updated = wrapped.updateModel(
- mdl.innerModel(),
- datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
- extractor.andThen(afterExtractor));
+ M updated = wrapped.withEnvironmentBuilder(envBuilder)
+ .updateModel(
+ mdl.innerModel(),
+ datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
+ extractor.andThen(afterExtractor));
return mdl.withInnerModel(updated);
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
index 06e9ef2ab9a06..052fc96111be8 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
@@ -17,7 +17,6 @@
package org.apache.ignite.ml;
-import java.util.stream.IntStream;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
@@ -26,6 +25,8 @@
import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.junit.Assert;
+import java.util.stream.IntStream;
+
import static org.junit.Assert.assertTrue;
/** */
@@ -170,6 +171,31 @@ public static void assertEquals(Matrix exp, Matrix actual) {
}
}
+ /**
+ * Verifies that two vectors have the same size and are element-wise equal within a tolerance.
+ *
+ * @param exp Expected vector.
+ * @param observed Actual vector.
+ * @param eps Maximal allowed absolute difference between corresponding elements.
+ */
+ public static void assertEquals(Vector exp, Vector observed, double eps) {
+ Assert.assertNotNull("Observed should not be null", observed);
+
+ if (exp.size() != observed.size()) {
+ // Report both sizes: a vector is one-dimensional, so an "n x n" message would be misleading.
+ Assert.fail("Observed has incorrect dimensions." +
+ "\nexpected size is " + exp.size() +
+ "\nobserved size is " + observed.size());
+ }
+
+ for (int i = 0; i < exp.size(); ++i) {
+ double expVal = exp.getX(i);
+ double obsVal = observed.getX(i);
+
+ Assert.assertEquals("Vectors differ at index " + i, expVal, obsVal, eps);
+ }
+ }
+
/**
* Verifies that two double arrays are close (sup norm).
*
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java
index 87d56cd62b1f3..bf0943c0d7433 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java
@@ -17,6 +17,7 @@
package org.apache.ignite.ml.composition;
+import org.apache.ignite.ml.composition.bagging.BaggingTest;
import org.apache.ignite.ml.composition.boosting.GDBTrainerTest;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregatorTest;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregatorTest;
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/bagging/BaggingTest.java
similarity index 85%
rename from modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
rename to modules/ml/src/test/java/org/apache/ignite/ml/composition/bagging/BaggingTest.java
index 7a84b64f0ce56..1fc218cdf431b 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/bagging/BaggingTest.java
@@ -15,15 +15,12 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.composition;
+package org.apache.ignite.ml.composition.bagging;
-import java.util.Arrays;
-import java.util.Map;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.composition.bagging.BaggedModel;
-import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
+import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
@@ -38,15 +35,37 @@
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.trainers.AdaptableDatasetModel;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.trainers.TrainerTransformers;
import org.junit.Test;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
/**
* Tests for bagging algorithm.
*/
public class BaggingTest extends TrainerTest {
+ /**
+ * Expected weights of the first model in the ensemble trained in
+ * {@link BaggingTest#testNaiveBaggingLogRegression()}, keyed by the {@code parts} value of the test run.
+ * The weights are compared exactly to ensure that training is fully determined by the provided seeds.
+ */
+ private static final Map<Integer, Vector> firstModelWeights;
+
+ static {
+ firstModelWeights = new HashMap<>();
+
+ firstModelWeights.put(1, VectorUtils.of(-0.14721735583126058, 4.366377931980097));
+ firstModelWeights.put(2, VectorUtils.of(-1.0092940937477968, 1.2950461550870134));
+ firstModelWeights.put(3, VectorUtils.of(-5.5345231104301655, -0.7554216668724918));
+ firstModelWeights.put(4, VectorUtils.of(0.136489632011201, 1.0937407007786915));
+ firstModelWeights.put(13, VectorUtils.of(-0.27321382073998685, 1.1199411864901687));
+ }
+
/**
* Test that count of entries in context is equal to initial dataset size * subsampleRatio.
*/
@@ -81,7 +100,7 @@ public void testNaiveBaggingLogRegression() {
BaggedTrainer baggedTrainer = TrainerTransformers.makeBagged(
trainer,
- 10,
+ 7,
0.7,
2,
2,
@@ -95,6 +114,10 @@ public void testNaiveBaggingLogRegression() {
(k, v) -> v[0]
);
+ Vector weights = ((LogisticRegressionModel)((AdaptableDatasetModel)((ModelsParallelComposition)((AdaptableDatasetModel)mdl
+ .model()).innerModel()).submodels().get(0)).innerModel()).weights();
+
+ TestUtils.assertEquals(firstModelWeights.get(parts), weights, 0.0);
TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 10)), PRECISION);
TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(10, 100)), PRECISION);
}