From 6eccf23009c1e0a03638255555e8c93a77667126 Mon Sep 17 00:00:00 2001 From: artemmalykh Date: Tue, 16 Jan 2018 13:29:04 +0300 Subject: [PATCH] IGNITE-7350: Distributed MLP cleanup/refactoring this closes #3368 --- .../ml/nn/MLPGroupTrainerExample.java | 2 +- .../ml/nn/MLPLocalTrainerExample.java | 6 +- .../nn/MLPGroupUpdateTrainerCacheInput.java | 28 ++++++++- .../ignite/ml/nn/MultilayerPerceptron.java | 2 +- .../AbstractMLPGroupUpdateTrainerInput.java | 2 +- .../distributed/MLPGroupUpdateTrainer.java | 11 ++-- .../MLPGroupUpdateTrainerDataCache.java | 2 +- .../MLPGroupUpdateTrainingData.java | 29 +++++++--- .../MLPGroupUpdateTrainingLoopData.java | 30 ++++++---- .../distributed/MLPMetaoptimizer.java | 2 +- .../trainers/local/MLPLocalBatchTrainer.java | 9 +-- .../optimization/BarzilaiBorweinUpdater.java | 8 ++- .../ml/optimization/BaseParametrized.java | 58 +++++++++++++++++++ .../ml/optimization/GradientDescent.java | 3 +- .../ml/optimization/GradientFunction.java | 2 +- .../{nn => optimization}/LossFunctions.java | 2 +- .../Parametrized.java} | 11 ++-- .../ignite/ml/optimization/SimpleUpdater.java | 3 +- .../SmoothParametrized.java} | 25 +------- .../NesterovParameterUpdate.java | 2 +- .../NesterovUpdateCalculator.java | 19 +++--- .../ParameterUpdateCalculator.java | 9 +-- .../RPropParameterUpdate.java | 2 +- .../RPropUpdateCalculator.java | 3 +- .../updatecalculators}/SimpleGDParameter.java | 2 +- .../SimpleGDUpdateCalculator.java | 3 +- .../updatecalculators/package-info.java | 22 +++++++ .../SparseDistributedMatrixMapReducer.java | 2 +- .../apache/ignite/ml/trainers/Trainer.java | 3 +- .../group/MetaoptimizerDistributedStep.java | 13 ++++- .../group/chain/ComputationsChain.java | 4 +- .../trainers/local/LocalBatchTrainer.java | 7 +-- .../local}/LocalBatchTrainerInput.java | 2 +- .../local}/package-info.java | 4 +- .../java/org/apache/ignite/ml/util/Utils.java | 22 +++++-- .../apache/ignite/ml/IgniteMLTestSuite.java | 4 +- 
.../ignite/ml/nn/MLPGroupTrainerTest.java | 6 +- .../ignite/ml/nn/MLPLocalTrainerTest.java | 9 +-- .../java/org/apache/ignite/ml/nn/MLPTest.java | 1 + .../nn/SimpleMLPLocalBatchTrainerInput.java | 1 + .../ml/nn/performance/MnistDistributed.java | 2 +- .../ignite/ml/nn/performance/MnistLocal.java | 4 +- .../optimization/OptimizationTestSuite.java | 33 +++++++++++ .../ml/trainers/group/TestGroupTrainer.java | 2 +- 44 files changed, 297 insertions(+), 119 deletions(-) create mode 100644 modules/ml/src/main/java/org/apache/ignite/ml/optimization/BaseParametrized.java rename modules/ml/src/main/java/org/apache/ignite/ml/{nn => optimization}/LossFunctions.java (97%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters/SmoothParametrized.java => optimization/Parametrized.java} (71%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters/BaseSmoothParametrized.java => optimization/SmoothParametrized.java} (79%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters => optimization/updatecalculators}/NesterovParameterUpdate.java (98%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters => optimization/updatecalculators}/NesterovUpdateCalculator.java (83%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters => optimization/updatecalculators}/ParameterUpdateCalculator.java (89%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters => optimization/updatecalculators}/RPropParameterUpdate.java (99%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters => optimization/updatecalculators}/RPropUpdateCalculator.java (97%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters => optimization/updatecalculators}/SimpleGDParameter.java (97%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters => optimization/updatecalculators}/SimpleGDUpdateCalculator.java (95%) create mode 100644 modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/package-info.java 
rename modules/ml/src/main/java/org/apache/ignite/ml/{nn => }/trainers/local/LocalBatchTrainer.java (95%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn => trainers/local}/LocalBatchTrainerInput.java (96%) rename modules/ml/src/main/java/org/apache/ignite/ml/{nn/updaters => trainers/local}/package-info.java (91%) create mode 100644 modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java index 8d4a1512f6927..d106fadfee423 100644 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java @@ -35,7 +35,7 @@ import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.initializers.RandomInitializer; import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer; -import org.apache.ignite.ml.nn.updaters.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.thread.IgniteThread; diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java index 3f4adc47aa5ad..b5574587e3d72 100644 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java @@ -25,13 +25,13 @@ import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.nn.Activators; -import org.apache.ignite.ml.nn.LocalBatchTrainerInput; -import org.apache.ignite.ml.nn.LossFunctions; +import 
org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput; +import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.nn.MultilayerPerceptron; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.initializers.RandomInitializer; import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer; -import org.apache.ignite.ml.nn.updaters.RPropUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; import org.apache.ignite.ml.util.Utils; /** diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java index 14db261c305e1..05f52e1ce3976 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Random; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.ignite.Ignite; @@ -56,6 +57,11 @@ public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrain */ private final MultilayerPerceptron mlp; + /** + * Random number generator. + */ + private final Random rand; + /** * Construct instance of this class with given parameters. * @@ -64,15 +70,32 @@ public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrain * @param networksCnt Count of networks to be trained in parallel by {@link MLPGroupUpdateTrainer}. * @param cache Cache with labeled vectors. * @param batchSize Size of batch to return on each training iteration. + * @param rand RNG. 
*/ public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, MLPInitializer init, int networksCnt, IgniteCache> cache, - int batchSize) { + int batchSize, Random rand) { super(networksCnt); this.batchSize = batchSize; this.cache = cache; this.mlp = new MultilayerPerceptron(arch, init); + this.rand = rand; + } + + /** + * Construct instance of this class with given parameters. + * + * @param arch Architecture of multilayer perceptron. + * @param init Initializer of multilayer perceptron. + * @param networksCnt Count of networks to be trained in parallel by {@link MLPGroupUpdateTrainer}. + * @param cache Cache with labeled vectors. + * @param batchSize Size of batch to return on each training iteration. + */ + public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, MLPInitializer init, + int networksCnt, IgniteCache> cache, + int batchSize) { + this(arch, init, networksCnt, cache, batchSize, new Random()); } /** @@ -93,6 +116,7 @@ public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, int networksCnt, @Override public IgniteSupplier> batchSupplier() { String cName = cache.getName(); int bs = batchSize; + Random r = rand; // IMPL NOTE this is intended to make below lambda more lightweight. return () -> { Ignite ignite = Ignition.localIgnite(); @@ -105,7 +129,7 @@ public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, int networksCnt, int locKeysCnt = keys.size(); - int[] selected = Utils.selectKDistinct(locKeysCnt, Math.min(bs, locKeysCnt)); + int[] selected = Utils.selectKDistinct(locKeysCnt, Math.min(bs, locKeysCnt), r); // Get dimensions of vectors in cache. We suppose that every feature vector has // same dimension d 1 and every label has the same dimension d2. 
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MultilayerPerceptron.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MultilayerPerceptron.java index d55e0e9aa9315..7bf238d04518d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MultilayerPerceptron.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MultilayerPerceptron.java @@ -35,7 +35,7 @@ import org.apache.ignite.ml.nn.architecture.TransformationLayerArchitecture; import org.apache.ignite.ml.nn.initializers.MLPInitializer; import org.apache.ignite.ml.nn.initializers.RandomInitializer; -import org.apache.ignite.ml.nn.updaters.SmoothParametrized; +import org.apache.ignite.ml.optimization.SmoothParametrized; import static org.apache.ignite.ml.math.util.MatrixUtil.elementWiseTimes; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/AbstractMLPGroupUpdateTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/AbstractMLPGroupUpdateTrainerInput.java index ed65af7d7f2d3..f2d95d5027f1e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/AbstractMLPGroupUpdateTrainerInput.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/AbstractMLPGroupUpdateTrainerInput.java @@ -20,7 +20,7 @@ import java.util.UUID; import java.util.stream.Stream; import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.nn.LocalBatchTrainerInput; +import org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput; import org.apache.ignite.ml.nn.MultilayerPerceptron; import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; import org.apache.ignite.ml.trainers.group.GroupTrainerInput; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java index 1f98b533f0a8f..f4647d5b984ad 100644 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java @@ -33,11 +33,11 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.math.util.MatrixUtil; -import org.apache.ignite.ml.nn.LossFunctions; +import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.nn.MultilayerPerceptron; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; -import org.apache.ignite.ml.nn.updaters.RPropParameterUpdate; -import org.apache.ignite.ml.nn.updaters.RPropUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; import org.apache.ignite.ml.trainers.group.MetaoptimizerGroupTrainer; import org.apache.ignite.ml.trainers.group.ResultAndUpdates; @@ -227,8 +227,7 @@ MLPGroupUpdateTrainingContext>, MLPGroupUpdateTrainingLoopData> trainingLo UUID uuid = ctx.trainingUUID(); return () -> { - MLPGroupUpdateTrainingData data = MLPGroupUpdateTrainerDataCache - .getOrCreate(Ignition.localIgnite()).get(uuid); + MLPGroupUpdateTrainingData data = MLPGroupUpdateTrainerDataCache.getOrCreate(Ignition.localIgnite()).get(uuid); return new MLPGroupUpdateTrainingContext<>(data, prevUpdate); }; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerDataCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerDataCache.java index c237f86ba8789..42003219faf9b 100644 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerDataCache.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerDataCache.java @@ -55,7 +55,7 @@ public static IgniteCache getOrCreate(Ignite i CacheConfiguration cfg = new CacheConfiguration<>(); // Write to primary. - cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); + cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.FULL_SYNC); // Atomic transactions only. cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java index 86074dd83b091..740fac6712127 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java @@ -25,21 +25,36 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.nn.MultilayerPerceptron; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; /** Multilayer perceptron group update training data. */ public class MLPGroupUpdateTrainingData { - /** */ + /** {@link ParameterUpdateCalculator}. */ private final ParameterUpdateCalculator updateCalculator; - /** */ + + /** + * Count of steps which should be done by each of parallel trainings before sending it's update for combining with + * other parallel trainings updates. + */ private final int stepsCnt; - /** */ + + /** + * Function used to reduce updates in one training (for example, sum all sequential gradient updates to get one + * gradient update). 
+ */ private final IgniteFunction, U> updateReducer; - /** */ + + /** + * Supplier of batches in the form of (inputs, groundTruths). + */ private final IgniteSupplier> batchSupplier; - /** */ + + /** + * Loss function. + */ private final IgniteFunction loss; - /** */ + + /** Error tolerance. */ private final double tolerance; /** Construct multilayer perceptron group update training data with all parameters provided. */ diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java index 0f3d97448a7f8..2050ee59d9e1c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java @@ -26,29 +26,39 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.nn.MultilayerPerceptron; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; /** Multilayer perceptron group update training loop data. */ public class MLPGroupUpdateTrainingLoopData

implements Serializable { - /** */ + /** {@link ParameterUpdateCalculator}. */ private final ParameterUpdateCalculator updateCalculator; - /** */ + + /** + * Count of steps which should be done by each of parallel trainings before sending it's update for combining with + * other parallel trainings updates. + */ private final int stepsCnt; - /** */ + + /** Function used to reduce updates of all steps of given parallel training. */ private final IgniteFunction, P> updateReducer; - /** */ + + /** Previous update. */ private final P previousUpdate; - /** */ + + /** Supplier of batches. */ private final IgniteSupplier> batchSupplier; - /** */ + + /** Loss function. */ private final IgniteFunction loss; - /** */ + + /** Error tolerance. */ private final double tolerance; - /** */ + /** Key. */ private final GroupTrainerCacheKey key; - /** */ + + /** MLP. */ private final MultilayerPerceptron mlp; /** Create multilayer perceptron group update training loop data. */ diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java index 249136b4286b7..6e314f1d4e705 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java @@ -26,7 +26,7 @@ /** Meta-optimizer for multilayer perceptron. */ public class MLPMetaoptimizer

implements Metaoptimizer, P, P, P, ArrayList

> { - /** */ + /** Function used for reducing updates produced by parallel trainings. */ private final IgniteFunction, P> allUpdatesReducer; /** Construct metaoptimizer. */ diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java index 0c923952f3c33..059d15afe1ab9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java @@ -21,11 +21,12 @@ import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.nn.LossFunctions; +import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.nn.MultilayerPerceptron; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; -import org.apache.ignite.ml.nn.updaters.RPropParameterUpdate; -import org.apache.ignite.ml.nn.updaters.RPropUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; +import org.apache.ignite.ml.trainers.local.LocalBatchTrainer; /** * Local batch trainer for MLP. 
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/BarzilaiBorweinUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/BarzilaiBorweinUpdater.java index 2190d86d879e9..9b98ad8e92a0e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/BarzilaiBorweinUpdater.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/BarzilaiBorweinUpdater.java @@ -34,8 +34,11 @@ public class BarzilaiBorweinUpdater implements Updater { /** * {@inheritDoc} */ - @Override public Vector compute(Vector oldWeights, Vector oldGradient, Vector weights, Vector gradient, int iteration) { - double learningRate = computeLearningRate(oldWeights != null ? oldWeights.copy() : null, oldGradient != null ? oldGradient.copy() : null, weights.copy(), gradient.copy()); + @Override public Vector compute(Vector oldWeights, Vector oldGradient, Vector weights, Vector gradient, + int iteration) { + double learningRate = computeLearningRate(oldWeights != null ? oldWeights.copy() : null, + oldGradient != null ? oldGradient.copy() : null, weights.copy(), gradient.copy()); + return weights.copy().minus(gradient.copy().times(learningRate)); } @@ -45,6 +48,7 @@ private double computeLearningRate(Vector oldWeights, Vector oldGradient, Vector return INITIAL_LEARNING_RATE; else { Vector gradientDiff = gradient.minus(oldGradient); + return weights.minus(oldWeights).dot(gradientDiff) / Math.pow(gradientDiff.kNorm(2.0), 2.0); } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/BaseParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/BaseParametrized.java new file mode 100644 index 0000000000000..c5b2423ff136b --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/BaseParametrized.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.optimization; + +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.util.Utils; + +/** + * Base interface for parametrized models. + * + * @param Model class. + */ +interface BaseParametrized> { + /** + * Get parameters vector. + * + * @return Parameters vector. + */ + Vector parameters(); + + /** + * Set parameters. + * + * @param vector Parameters vector. + */ + M setParameters(Vector vector); + + /** + * Return new model with given parameters vector. + * + * @param vector Parameters vector. + */ + default M withParameters(Vector vector) { + return Utils.copy(this).setParameters(vector); + } + + /** + * Get count of parameters of this model. + * + * @return Count of parameters of this model. 
+ */ + int parametersCount(); +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/GradientDescent.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/GradientDescent.java index f02bcb34251f4..15ed914e4ee16 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/GradientDescent.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/GradientDescent.java @@ -141,6 +141,7 @@ private Vector calculateDistributedGradient(SparseDistributedMatrix data, Vector cnt++; } } + return resGradient.divide(cnt); }, weights); @@ -187,7 +188,7 @@ private Matrix extractInputs(Matrix data) { /** Makes carrying of the gradient function and fixes data matrix. */ private IgniteFunction getLossGradientFunction(Matrix data) { if (data instanceof SparseDistributedMatrix) { - SparseDistributedMatrix distributedMatrix = (SparseDistributedMatrix) data; + SparseDistributedMatrix distributedMatrix = (SparseDistributedMatrix)data; if (distributedMatrix.getStorage().storageMode() == StorageConstants.ROW_STORAGE_MODE) return weights -> calculateDistributedGradient(distributedMatrix, weights); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/GradientFunction.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/GradientFunction.java index 7dc667401c8b0..a6a1e71a38701 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/GradientFunction.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/GradientFunction.java @@ -27,5 +27,5 @@ @FunctionalInterface public interface GradientFunction extends Serializable { /** */ - Vector compute(Matrix inputs, Vector groundTruth, Vector point); + Vector compute(Matrix inputs, Vector groundTruth, Vector pnt); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/LossFunctions.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java similarity index 97% rename from 
modules/ml/src/main/java/org/apache/ignite/ml/nn/LossFunctions.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java index dff239ce64171..13fcb601d78de 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/LossFunctions.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/LossFunctions.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn; +package org.apache.ignite.ml.optimization; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/Parametrized.java similarity index 71% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/Parametrized.java index 1534a6dc8837e..d64d9c35a2550 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/Parametrized.java @@ -15,13 +15,12 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn.updaters; - -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.Matrix; +package org.apache.ignite.ml.optimization; /** - * Interface for models which are smooth functions of their parameters. + * Interface for parametrized models. + * + * @param Type of model. 
*/ -public interface SmoothParametrized> extends BaseSmoothParametrized, Model { +public interface Parametrized> extends BaseParametrized { } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/SimpleUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/SimpleUpdater.java index 0f6d5206bbc75..decbb86122938 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/SimpleUpdater.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/SimpleUpdater.java @@ -39,7 +39,8 @@ public SimpleUpdater(double learningRate) { /** * {@inheritDoc} */ - @Override public Vector compute(Vector oldWeights, Vector oldGradient, Vector weights, Vector gradient, int iteration) { + @Override public Vector compute(Vector oldWeights, Vector oldGradient, Vector weights, Vector gradient, + int iteration) { return weights.minus(gradient.times(learningRate)); } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/SmoothParametrized.java similarity index 79% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/SmoothParametrized.java index 8e2f0dff5fe13..862a78fd965b1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/SmoothParametrized.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn.updaters; +package org.apache.ignite.ml.optimization; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.math.Matrix; @@ -26,7 +26,7 @@ /** * Interface for models which are smooth functions of their parameters. 
*/ -interface BaseSmoothParametrized & Model> { +public interface SmoothParametrized> extends Parametrized, Model { /** * Compose function in the following way: feed output of this model as input to second argument to loss function. * After that we have a function g of three arguments: input, ground truth, parameters. @@ -42,25 +42,4 @@ interface BaseSmoothParametrized & Model loss, Matrix inputsBatch, Matrix truthBatch); - - /** - * Get parameters vector. - * - * @return Parameters vector. - */ - Vector parameters(); - - /** - * Set parameters. - * - * @param vector Parameters vector. - */ - M setParameters(Vector vector); - - /** - * Get count of parameters of this model. - * - * @return Count of parameters of this model. - */ - int parametersCount(); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovParameterUpdate.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovParameterUpdate.java similarity index 98% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovParameterUpdate.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovParameterUpdate.java index 8671285f70eab..b494b14fbbf95 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovParameterUpdate.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovParameterUpdate.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.ignite.ml.nn.updaters; +package org.apache.ignite.ml.optimization.updatecalculators; import java.io.Serializable; import java.util.List; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovUpdateCalculator.java similarity index 83% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdateCalculator.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovUpdateCalculator.java index 5caddd4cb235f..2bee506c11f89 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdateCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovUpdateCalculator.java @@ -15,17 +15,18 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn.updaters; +package org.apache.ignite.ml.optimization.updatecalculators; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.optimization.SmoothParametrized; /** * Class encapsulating Nesterov algorithm for MLP parameters updateCache. */ -public class NesterovUpdateCalculator +public class NesterovUpdateCalculator> implements ParameterUpdateCalculator { /** * Learning rate. @@ -53,20 +54,20 @@ public NesterovUpdateCalculator(double learningRate, double momentum) { } /** {@inheritDoc} */ - @Override public NesterovParameterUpdate calculateNewUpdate(SmoothParametrized mdl, + @Override public NesterovParameterUpdate calculateNewUpdate(M mdl, NesterovParameterUpdate updaterParameters, int iteration, Matrix inputs, Matrix groundTruth) { - // TODO:IGNITE-7350 create new updateCache object here instead of in-place change. 
+ Vector prevUpdates = updaterParameters.prevIterationUpdates(); + + M newMdl = mdl; if (iteration > 0) { Vector curParams = mdl.parameters(); - mdl.setParameters(curParams.minus(updaterParameters.prevIterationUpdates().times(momentum))); + newMdl = mdl.withParameters(curParams.minus(prevUpdates.times(momentum))); } - Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth); - updaterParameters.setPreviousUpdates(updaterParameters.prevIterationUpdates() - .plus(gradient.times(learningRate))); + Vector gradient = newMdl.differentiateByParameters(loss, inputs, groundTruth); - return updaterParameters; + return new NesterovParameterUpdate(prevUpdates.plus(gradient.times(learningRate))); } /** {@inheritDoc} */ diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator.java similarity index 89% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdateCalculator.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator.java index 77e37631ef75f..92f758365e04d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdateCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn.updaters; +package org.apache.ignite.ml.optimization.updatecalculators; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; @@ -26,19 +26,20 @@ * Interface for classes encapsulating parameters updateCache logic. * * @param Type of model to be updated. - * @param

Type of parameters needed for this updater. + * @param

Type of parameters needed for this update calculator. */ public interface ParameterUpdateCalculator { /** - * Initializes the updater. + * Initializes the update calculator. * * @param mdl Model to be trained. * @param loss Loss function. + * @return Initialized parameters. */ P init(M mdl, IgniteFunction loss); /** - * Calculate new updateCache. + * Calculate new update. * * @param mdl Model to be updated. * @param updaterParameters Updater parameters to updateCache. diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropParameterUpdate.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropParameterUpdate.java similarity index 99% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropParameterUpdate.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropParameterUpdate.java index e2fa4d5a6bc10..fd0a045390c3a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropParameterUpdate.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropParameterUpdate.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.ignite.ml.nn.updaters; +package org.apache.ignite.ml.optimization.updatecalculators; import java.io.Serializable; import java.util.List; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java similarity index 97% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdateCalculator.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java index 99f39c9d138b9..80345d9b49d37 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdateCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn.updaters; +package org.apache.ignite.ml.optimization.updatecalculators; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; @@ -23,6 +23,7 @@ import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.util.MatrixUtil; +import org.apache.ignite.ml.optimization.SmoothParametrized; /** * Class encapsulating RProp algorithm. 
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParameter.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java similarity index 97% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParameter.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java index 7159621638b6f..22fc18afb1983 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParameter.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn.updaters; +package org.apache.ignite.ml.optimization.updatecalculators; import java.io.Serializable; import org.apache.ignite.ml.math.Vector; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java similarity index 95% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdateCalculator.java rename to modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java index d2197d9c88eb4..291e63dbbb47e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdateCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java @@ -15,12 +15,13 @@ * limitations under the License. 
*/ -package org.apache.ignite.ml.nn.updaters; +package org.apache.ignite.ml.optimization.updatecalculators; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.optimization.SmoothParametrized; /** * Simple gradient descent parameters updater. diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/package-info.java new file mode 100644 index 0000000000000..071dc13389ad9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Contains update calculators. 
+ */ +package org.apache.ignite.ml.optimization.updatecalculators; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java index 7a5f90bb6a491..20f861e84305e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java @@ -70,7 +70,7 @@ public R mapReduce(IgniteBiFunction mapper, IgniteFunction< for (RowColMatrixKey key : locKeys) { Map row = storage.cache().get(key); - for (Map.Entry cell : row.entrySet()) + for (Map.Entry cell : row.entrySet()) locMatrix.set(idx, cell.getKey(), cell.getValue()); idx++; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java index 7540d6f8da807..5efdf57dcbad9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java @@ -21,7 +21,8 @@ /** Trainer interface. */ public interface Trainer { - /** Train the model based on provided data. + /** + * Train the model based on provided data. * * @param data Data for training. * @return Trained model. 
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java index 67dcf7f5823d6..08e1f47d5ccb0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java @@ -27,8 +27,17 @@ import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID; /** - * Distributed step - * TODO: IGNITE-7350: add full description. + * Distributed step based on {@link Metaoptimizer}. + * + * @param Type of local context. + * @param Type of data in {@link GroupTrainerCacheKey}. + * @param Type of values of cache on which training is done. + * @param Type of distributed context. + * @param Type of data to which data returned by distributed initialization is mapped (see {@link Metaoptimizer}). + * @param Type of data to which data returned by data processor is mapped (see {@link Metaoptimizer}). + * @param Type of data which is processed in training loop step (see {@link Metaoptimizer}). + * @param Type of data returned by training loop step data processor (see {@link Metaoptimizer}). + * @param Type of data returned by initialization (see {@link Metaoptimizer}). */ class MetaoptimizerDistributedStep implements DistributedEntryProcessingStep { diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java index 534b5f9909526..3c3bdab833863 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java @@ -53,7 +53,7 @@ * @param Type of cache values. * @param Type of input of this chain. * @param Type of output of this chain. 
- * // TODO: IGNITE-7350 check if it is possible to integrate with {@link EntryProcessor}. + * // TODO: IGNITE-7405 check if it is possible to integrate with {@link EntryProcessor}. */ @FunctionalInterface public interface ComputationsChain { @@ -229,7 +229,7 @@ default ComputationsChain thenWhile(IgniteBiPredicate cond, } /** - * Combine two this chain to other: feed this chain as input to other, pass same context as second argument to both chains + * Combine this chain with other: feed this chain as input to other, pass same context as second argument to both chains * process method. * * @param next Next chain. diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java similarity index 95% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java rename to modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java index 8579b825e32b3..ab31f9ff8d60f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn.trainers.local; +package org.apache.ignite.ml.trainers.local; import org.apache.ignite.IgniteLogger; import org.apache.ignite.lang.IgniteBiTuple; @@ -27,8 +27,7 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.math.util.MatrixUtil; -import org.apache.ignite.ml.nn.LocalBatchTrainerInput; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; /** * Batch trainer. 
This trainer is not distributed on the cluster, but input can theoretically read data from @@ -83,7 +82,7 @@ public LocalBatchTrainer(IgniteFunction updater = updaterSupplier.get(); + ParameterUpdateCalculator updater = updaterSupplier.get(); P updaterParams = updater.init(mdl, loss); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/LocalBatchTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java similarity index 96% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/LocalBatchTrainerInput.java rename to modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java index 3a87d026ff355..38b7592999225 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/LocalBatchTrainerInput.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.nn; +package org.apache.ignite.ml.trainers.local; import org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.ml.Model; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/package-info.java similarity index 91% rename from modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java rename to modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/package-info.java index 13bc3c8899483..8a15b73fae889 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/package-info.java @@ -17,6 +17,6 @@ /** * - * Contains parameters updaters. + * Contains local trainers. 
*/ -package org.apache.ignite.ml.nn.updaters; \ No newline at end of file +package org.apache.ignite.ml.trainers.local; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java index 4472300e84503..206e1e9fc71cf 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java @@ -60,21 +60,22 @@ public static T copy(T orig) { } /** - * Select k distinct integers from range [0, n) with reservoir sampling: https://en.wikipedia.org/wiki/Reservoir_sampling. + * Select k distinct integers from range [0, n) with reservoir sampling: + * https://en.wikipedia.org/wiki/Reservoir_sampling. * * @param n Number specifying left end of range of integers to pick values from. * @param k Count specifying how many integers should be picked. + * @param rand RNG. * @return Array containing k distinct integers from range [0, n); */ - public static int[] selectKDistinct(int n, int k) { + public static int[] selectKDistinct(int n, int k, Random rand) { int i; + Random r = rand != null ? rand : new Random(); int res[] = new int[k]; for (i = 0; i < k; i++) res[i] = i; - Random r = new Random(); - for (; i < n; i++) { int j = r.nextInt(i + 1); @@ -84,4 +85,17 @@ public static int[] selectKDistinct(int n, int k) { return res; } + + /** + * Select k distinct integers from range [0, n) with reservoir sampling: + * https://en.wikipedia.org/wiki/Reservoir_sampling. + * Equivalent to {@code selectKDistinct(n, k, new Random())}. + * + * @param n Number specifying left end of range of integers to pick values from. + * @param k Count specifying how many integers should be picked. 
+ * @return Array containing k distinct integers from range [0, n); + */ + public static int[] selectKDistinct(int n, int k) { + return selectKDistinct(n, k, new Random()); + } } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index 35ffdbcba00af..d5d6d9484f3a7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -21,6 +21,7 @@ import org.apache.ignite.ml.knn.KNNTestSuite; import org.apache.ignite.ml.math.MathImplMainTestSuite; import org.apache.ignite.ml.nn.MLPTestSuite; +import org.apache.ignite.ml.optimization.OptimizationTestSuite; import org.apache.ignite.ml.regressions.RegressionsTestSuite; import org.apache.ignite.ml.trainers.group.TrainersGroupTestSuite; import org.apache.ignite.ml.trees.DecisionTreesTestSuite; @@ -39,7 +40,8 @@ KNNTestSuite.class, LocalModelsTest.class, MLPTestSuite.class, - TrainersGroupTestSuite.class + TrainersGroupTestSuite.class, + OptimizationTestSuite.class }) public class IgniteMLTestSuite { // No-op. 
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java index 7f990c9e41a6d..151fead16381a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java @@ -31,7 +31,7 @@ import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.initializers.RandomInitializer; import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer; -import org.apache.ignite.ml.nn.updaters.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; @@ -93,7 +93,7 @@ public void testXOR() { } } - int totalCnt = 100; + int totalCnt = 20; int failCnt = 0; double maxFailRatio = 0.3; MLPGroupUpdateTrainer trainer = MLPGroupUpdateTrainer.getDefault(ignite). 
@@ -104,7 +104,7 @@ public void testXOR() { for (int i = 0; i < totalCnt; i++) { MLPGroupUpdateTrainerCacheInput trainerInput = new MLPGroupUpdateTrainerCacheInput(conf, - new RandomInitializer(rnd), 6, cache, 4); + new RandomInitializer(new Random(123L)), 6, cache, 4, new Random(123L)); MultilayerPerceptron mlp = trainer.train(trainerInput); diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java index e659e161b7d9b..b4c14e1d79560 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java @@ -27,10 +27,11 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer; -import org.apache.ignite.ml.nn.updaters.NesterovUpdateCalculator; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; -import org.apache.ignite.ml.nn.updaters.RPropUpdateCalculator; -import org.apache.ignite.ml.nn.updaters.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.optimization.LossFunctions; +import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.junit.Test; /** diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java index d757fcb7c11f8..555abce1a5e67 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java @@ -25,6 +25,7 @@ import 
org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; +import org.apache.ignite.ml.optimization.LossFunctions; import org.junit.Assert; import org.junit.Test; diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java index 07a9e7442ef17..8bc0a6d348b4e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java @@ -24,6 +24,7 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.initializers.RandomInitializer; +import org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput; import org.apache.ignite.ml.util.Utils; /** diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java index d9e4060e315b4..112aaded644e9 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java @@ -37,7 +37,7 @@ import org.apache.ignite.ml.nn.MultilayerPerceptron; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer; -import org.apache.ignite.ml.nn.updaters.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.util.MnistUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java index eab52883430e6..cda0413afeed8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java @@ -28,12 +28,12 @@ import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.Activators; -import org.apache.ignite.ml.nn.LossFunctions; +import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.nn.MultilayerPerceptron; import org.apache.ignite.ml.nn.SimpleMLPLocalBatchTrainerInput; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer; -import org.apache.ignite.ml.nn.updaters.RPropUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; import org.junit.Test; import static org.apache.ignite.ml.nn.performance.MnistMLPTestUtil.createDataset; diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java new file mode 100644 index 0000000000000..0ae6e4c70bbe4 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.optimization; + +import org.apache.ignite.ml.optimization.util.SparseDistributedMatrixMapReducerTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for optimization tests. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + GradientDescentTest.class, + SparseDistributedMatrixMapReducerTest.class +}) +public class OptimizationTestSuite { +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java index d5b4edee05945..0a49fe0fd7b74 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java @@ -83,7 +83,7 @@ public TestGroupTrainer(Ignite ignite) { /** {@inheritDoc} */ @Override protected ComputationsChain trainingLoopStep() { - // TODO:IGNITE-7350 here we should explicitly create variable because we cannot infer context type, think about it. + // TODO:IGNITE-7405 here we should explicitly create variable because we cannot infer context type, think about it. ComputationsChain chain = Chains. create(new TestTrainingLoopStep()); return chain.