diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java index 44fb77e929e37..98745a4dd4c16 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java @@ -30,7 +30,7 @@ import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.apache.ignite.ml.selection.cv.CrossValidation; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.trainers.DatasetTrainer; diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java index 52ee330870ae0..8530045805936 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java @@ -31,8 +31,8 @@ import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; /** * Run logistic regression model based on diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java deleted file mode 100644 index 962fdac2c6c19..0000000000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.regression.logistic.multiclass; - -import java.io.FileNotFoundException; -import java.util.Arrays; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.examples.ml.util.MLSandboxDatasets; -import org.apache.ignite.examples.ml.util.SandboxMLCache; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer; - -/** - * Run Logistic Regression multi-class classification trainer ({@link LogRegressionMultiClassModel}) over distributed - * dataset to build two models: one with minmaxscaling and one without minmaxscaling. - *

- * Code in this example launches Ignite grid and fills the cache with test data points (preprocessed - * Glass dataset).

- *

- * After that it trains two logistic regression models based on the specified data - one model is with minmaxscaling - * and one without minmaxscaling.

- *

- * Finally, this example loops over the test set of data points, applies the trained models to predict the target value, - * compares prediction to expected outcome (ground truth), and builds - * confusion matrices.

- *

- * You can change the test data used in this example and re-run it to explore this algorithm further.

- */ -public class LogRegressionMultiClassClassificationExample { - /** Run example. */ - public static void main(String[] args) throws FileNotFoundException { - System.out.println(); - System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - - IgniteCache dataCache = new SandboxMLCache(ignite) - .fillCacheWith(MLSandboxDatasets.GLASS_IDENTIFICATION); - - LogRegressionMultiClassTrainer trainer = new LogRegressionMultiClassTrainer<>() - .withUpdatesStgy(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - )) - .withAmountOfIterations(100000) - .withAmountOfLocIterations(10) - .withBatchSize(100) - .withSeed(123L); - - LogRegressionMultiClassModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()), - (k, v) -> v.get(0) - ); - - System.out.println(">>> SVM Multi-class model"); - System.out.println(mdl.toString()); - - MinMaxScalerTrainer normalizationTrainer = new MinMaxScalerTrainer<>(); - - IgniteBiFunction preprocessor = normalizationTrainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()) - ); - - LogRegressionMultiClassModel mdlWithNormalization = trainer.fit( - ignite, - dataCache, - preprocessor, - (k, v) -> v.get(0) - ); - - System.out.println(">>> Logistic Regression Multi-class model with normalization"); - System.out.println(mdlWithNormalization.toString()); - - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|"); - System.out.println(">>> ----------------------------------------------------------------"); - - int amountOfErrors = 0; - int amountOfErrorsWithNormalization = 0; - int totalAmount = 0; - - // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - - try (QueryCursor> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry observation : observations) { - Vector val = observation.getValue(); - Vector inputs = val.copyOfRange(1, val.size()); - double groundTruth = val.get(0); - - double prediction = mdl.apply(inputs); - double predictionWithNormalization = mdlWithNormalization.apply(inputs); - - totalAmount++; - - // Collect data for model - if(groundTruth != prediction) - amountOfErrors++; - - int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2); - int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); - - confusionMtx[idx1][idx2]++; - - // Collect data for model with normalization - if(groundTruth != predictionWithNormalization) - amountOfErrorsWithNormalization++; - - idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2); - idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); - - confusionMtxWithNormalization[idx1][idx2]++; - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth); - } - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println("\n>>> -----------------Logistic Regression model-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); - - System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization)); - - System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example completed."); - } - } - } -} diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java deleted file mode 100644 index c7b7fe81690b4..0000000000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * - * ML multi-class logistic regression examples. - */ -package org.apache.ignite.examples.ml.regression.logistic.multiclass; diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java similarity index 90% rename from examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java rename to examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java index 679bd776a3772..d9d1805df9fdb 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.examples.ml.svm.binary; +package org.apache.ignite.examples.ml.svm; import java.io.FileNotFoundException; import java.util.Arrays; @@ -28,11 +28,11 @@ import org.apache.ignite.examples.ml.util.MLSandboxDatasets; import org.apache.ignite.examples.ml.util.SandboxMLCache; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; -import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer; /** - * Run SVM binary-class classification model ({@link SVMLinearBinaryClassificationModel}) over distributed dataset. + * Run SVM binary-class classification model ({@link SVMLinearClassificationModel}) over distributed dataset. *

* Code in this example launches Ignite grid and fills the cache with test data points (based on the * Iris dataset).

@@ -57,9 +57,9 @@ public static void main(String[] args) throws FileNotFoundException { IgniteCache dataCache = new SandboxMLCache(ignite) .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS); - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); + SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer(); - SVMLinearBinaryClassificationModel mdl = trainer.fit( + SVMLinearClassificationModel mdl = trainer.fit( ignite, dataCache, (k, v) -> v.copyOfRange(1, v.size()), diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/package-info.java deleted file mode 100644 index 22c9ad7f2c77f..0000000000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * - * SVM Binary Classification Examples. - */ -package org.apache.ignite.examples.ml.svm.binary; \ No newline at end of file diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java deleted file mode 100644 index 987ac41cbfff1..0000000000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.svm.multiclass; - -import java.io.FileNotFoundException; -import java.util.Arrays; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.examples.ml.util.MLSandboxDatasets; -import org.apache.ignite.examples.ml.util.SandboxMLCache; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; -import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; -import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer; - -/** - * Run SVM multi-class classification trainer ({@link SVMLinearMultiClassClassificationModel}) over distributed dataset - * to build two models: one with minmaxscaling and one without minmaxscaling. - *

- * Code in this example launches Ignite grid and fills the cache with test data points (preprocessed - * Glass dataset).

- *

- * After that it trains two SVM multi-class models based on the specified data - one model is with minmaxscaling - * and one without minmaxscaling.

- *

- * Finally, this example loops over the test set of data points, applies the trained models to predict what cluster - * does this point belong to, compares prediction to expected outcome (ground truth), and builds - * confusion matrix.

- *

- * You can change the test data used in this example and re-run it to explore this algorithm further.

- * NOTE: the smallest 3rd class could be classified via linear SVM here. - */ -public class SVMMultiClassClassificationExample { - /** Run example. */ - public static void main(String[] args) throws FileNotFoundException { - System.out.println(); - System.out.println(">>> SVM Multi-class classification model over cached dataset usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - - IgniteCache dataCache = new SandboxMLCache(ignite) - .fillCacheWith(MLSandboxDatasets.GLASS_IDENTIFICATION); - - SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer(); - - SVMLinearMultiClassClassificationModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()), - (k, v) -> v.get(0) - ); - - System.out.println(">>> SVM Multi-class model"); - System.out.println(mdl.toString()); - - MinMaxScalerTrainer minMaxScalerTrainer = new MinMaxScalerTrainer<>(); - - IgniteBiFunction preprocessor = minMaxScalerTrainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()) - ); - - SVMLinearMultiClassClassificationModel mdlWithScaling = trainer.fit( - ignite, - dataCache, - preprocessor, - (k, v) -> v.get(0) - ); - - System.out.println(">>> SVM Multi-class model with MinMaxScaling"); - System.out.println(mdlWithScaling.toString()); - - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println(">>> | Prediction\t| Prediction with MinMaxScaling\t| Ground Truth\t|"); - System.out.println(">>> ----------------------------------------------------------------"); - - int amountOfErrors = 0; - int amountOfErrorsWithMinMaxScaling = 0; - int totalAmount = 0; - - // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - int[][] confusionMtxWithMinMaxScaling = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - - try (QueryCursor> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry observation : observations) { - Vector val = observation.getValue(); - Vector inputs = val.copyOfRange(1, val.size()); - double groundTruth = val.get(0); - - double prediction = mdl.apply(inputs); - double predictionWithMinMaxScaling = mdlWithScaling.apply(inputs); - - totalAmount++; - - // Collect data for model - if(groundTruth != prediction) - amountOfErrors++; - - int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2); - int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); - - confusionMtx[idx1][idx2]++; - - // Collect data for model with minmaxscaling - if (groundTruth != predictionWithMinMaxScaling) - amountOfErrorsWithMinMaxScaling++; - - idx1 = (int)predictionWithMinMaxScaling == 1 ? 0 : ((int)predictionWithMinMaxScaling == 3 ? 1 : 2); - idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); - - confusionMtxWithMinMaxScaling[idx1][idx2]++; - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithMinMaxScaling, groundTruth); - } - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println("\n>>> -----------------SVM model-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); - - System.out.println("\n>>> -----------------SVM model with MinMaxScaling-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithMinMaxScaling); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithMinMaxScaling / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithMinMaxScaling)); - - System.out.println(">>> Linear regression model over cache based dataset usage example completed."); - } - } - } -} diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/package-info.java deleted file mode 100644 index 8b685a4391122..0000000000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * - * SVM Multi-class Classification Examples. - */ -package org.apache.ignite.examples.ml.svm.multiclass; diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java index b98b0ebc210ed..2c6a820fa4b97 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java @@ -32,8 +32,8 @@ import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.apache.ignite.ml.selection.cv.CrossValidation; import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java similarity index 99% rename from modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java rename to modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java index f2065325a2354..5cc44f8909fd4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.regressions.logistic.binomial; +package org.apache.ignite.ml.regressions.logistic; import java.io.Serializable; import java.util.Objects; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java similarity index 99% rename from modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java rename to modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java index 47fa59da6adbc..cdbfe4cb14d8e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.ignite.ml.regressions.logistic.binomial; +package org.apache.ignite.ml.regressions.logistic; import java.io.Serializable; import java.util.Arrays; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java deleted file mode 100644 index d32b1ee876b30..0000000000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * - * Contains binomial logistic regression. - */ -package org.apache.ignite.ml.regressions.logistic.binomial; \ No newline at end of file diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java deleted file mode 100644 index a7c9118bb0b1f..0000000000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.logistic.multiclass; - -import java.io.Serializable; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.TreeMap; -import org.apache.ignite.ml.Exportable; -import org.apache.ignite.ml.Exporter; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; - -/** Base class for multi-classification model for set of Logistic Regression classifiers. */ -public class LogRegressionMultiClassModel implements Model, Exportable, Serializable { - /** */ - private static final long serialVersionUID = -114986533350117L; - - /** List of models associated with each class. */ - private Map models; - - /** */ - public LogRegressionMultiClassModel() { - this.models = new HashMap<>(); - } - - /** {@inheritDoc} */ - @Override public Double apply(Vector input) { - TreeMap maxMargins = new TreeMap<>(); - - models.forEach((k, v) -> maxMargins.put(1.0 / (1.0 + Math.exp(-(input.dot(v.weights()) + v.intercept()))), k)); - - return maxMargins.lastEntry().getValue(); - } - - /** {@inheritDoc} */ - @Override public

void saveModel(Exporter exporter, P path) { - exporter.save(this, path); - } - - /** {@inheritDoc} */ - @Override public boolean equals(Object o) { - if (this == o) - return true; - - if (o == null || getClass() != o.getClass()) - return false; - - LogRegressionMultiClassModel mdl = (LogRegressionMultiClassModel)o; - - return Objects.equals(models, mdl.models); - } - - /** {@inheritDoc} */ - @Override public int hashCode() { - return Objects.hash(models); - } - - /** {@inheritDoc} */ - @Override public String toString() { - StringBuilder wholeStr = new StringBuilder(); - - models.forEach((clsLb, mdl) -> - wholeStr - .append("The class with label ") - .append(clsLb) - .append(" has classifier: ") - .append(mdl.toString()) - .append(System.lineSeparator()) - ); - - return wholeStr.toString(); - } - - /** {@inheritDoc} */ - @Override public String toString(boolean pretty) { - return toString(); - } - - /** - * Adds a specific Log Regression binary classifier to the bunch of same classifiers. - * - * @param clsLb The class label for the added model. - * @param mdl The model. - */ - public void add(double clsLb, LogisticRegressionModel mdl) { - models.put(clsLb, mdl); - } - - /** - * @param clsLb Class label. - * @return model for class label if it exists. - */ - public Optional getModel(Double clsLb) { - return Optional.ofNullable(models.get(clsLb)); - } -} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java deleted file mode 100644 index fd5a624aaa994..0000000000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.logistic.multiclass; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.apache.ignite.ml.dataset.Dataset; -import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; -import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; -import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap; -import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; - -/** - * All common parameters are shared with bunch of binary classification trainers. - */ -public class LogRegressionMultiClassTrainer

- extends SingleLabelDatasetTrainer { - /** Update strategy. */ - private UpdatesStrategy updatesStgy = new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ); - - /** Max number of iteration. */ - private int amountOfIterations = 100; - - /** Batch size. */ - private int batchSize = 100; - - /** Number of local iterations. */ - private int amountOfLocIterations = 100; - - /** Seed for random generator. */ - private long seed = 1234L; - - /** - * Trains model based on the specified data. - * - * @param datasetBuilder Dataset builder. - * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @return Model. - */ - @Override public LogRegressionMultiClassModel fit(DatasetBuilder datasetBuilder, - IgniteBiFunction featureExtractor, - IgniteBiFunction lbExtractor) { - List classes = extractClassLabels(datasetBuilder, lbExtractor); - - return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); - } - - /** {@inheritDoc} */ - @Override public LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel newMdl, - DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor, - IgniteBiFunction lbExtractor) { - - List classes = extractClassLabels(datasetBuilder, lbExtractor); - - if(classes.isEmpty()) - return getLastTrainedModelOrThrowEmptyDatasetException(newMdl); - - LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel(); - - classes.forEach(clsLb -> { - LogisticRegressionSGDTrainer trainer = - new LogisticRegressionSGDTrainer<>() - .withBatchSize(batchSize) - .withLocIterations(amountOfLocIterations) - .withMaxIterations(amountOfIterations) - .withSeed(seed); - - IgniteBiFunction lbTransformer = (k, v) -> { - Double lb = lbExtractor.apply(k, v); - - if (lb.equals(clsLb)) - return 1.0; - else - return 0.0; - }; - - LogisticRegressionModel mdl = Optional.ofNullable(newMdl) - .flatMap(multiClassModel -> multiClassModel.getModel(clsLb)) - .map(learnedModel -> trainer.update(learnedModel, datasetBuilder, featureExtractor, lbTransformer)) - .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, lbTransformer)); - - multiClsMdl.add(clsLb, mdl); - }); - - return multiClsMdl; - } - - /** {@inheritDoc} */ - @Override protected boolean checkState(LogRegressionMultiClassModel mdl) { - return true; - } - - /** Iterates among dataset and collects class labels. */ - private List extractClassLabels(DatasetBuilder datasetBuilder, - IgniteBiFunction lbExtractor) { - assert datasetBuilder != null; - - PartitionDataBuilder partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); - - List res = new ArrayList<>(); - - try (Dataset dataset = datasetBuilder.build( - envBuilder, - (env, upstream, upstreamSize) -> new EmptyContext(), - partDataBuilder - )) { - final Set clsLabels = dataset.compute(data -> { - final Set locClsLabels = new HashSet<>(); - - final double[] lbs = data.getY(); - - for (double lb : lbs) - locClsLabels.add(lb); - - return locClsLabels; - }, (a, b) -> { - if (a == null) - return b == null ? new HashSet<>() : b; - if (b == null) - return a; - return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()); - }); - - if (clsLabels != null) - res.addAll(clsLabels); - - } - catch (Exception e) { - throw new RuntimeException(e); - } - return res; - } - - /** - * Set up the regularization parameter. - * - * @param batchSize The size of learning batch. - * @return Trainer with new batch size parameter value. - */ - public LogRegressionMultiClassTrainer withBatchSize(int batchSize) { - this.batchSize = batchSize; - return this; - } - - /** - * Get the batch size. - * - * @return The parameter value. - */ - public double getBatchSize() { - return batchSize; - } - - /** - * Get the amount of outer iterations of SGD algorithm. - * - * @return The parameter value. - */ - public int getAmountOfIterations() { - return amountOfIterations; - } - - /** - * Set up the amount of outer iterations. - * - * @param amountOfIterations The parameter value. - * @return Trainer with new amountOfIterations parameter value. - */ - public LogRegressionMultiClassTrainer withAmountOfIterations(int amountOfIterations) { - this.amountOfIterations = amountOfIterations; - return this; - } - - /** - * Get the amount of local iterations. - * - * @return The parameter value. - */ - public int getAmountOfLocIterations() { - return amountOfLocIterations; - } - - /** - * Set up the amount of local iterations of SGD algorithm. - * - * @param amountOfLocIterations The parameter value. - * @return Trainer with new amountOfLocIterations parameter value. - */ - public LogRegressionMultiClassTrainer withAmountOfLocIterations(int amountOfLocIterations) { - this.amountOfLocIterations = amountOfLocIterations; - return this; - } - - /** - * Set up the random seed parameter. - * - * @param seed Seed for random generator. - * @return Trainer with new seed parameter value. - */ - public LogRegressionMultiClassTrainer withSeed(long seed) { - this.seed = seed; - return this; - } - - /** - * Get the seed for random generator. - * - * @return The parameter value. - */ - public long seed() { - return seed; - } - - /** - * Set up the updates strategy. - * - * @param updatesStgy Update strategy. - * @return Trainer with new update strategy parameter value. - */ - public LogRegressionMultiClassTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) { - this.updatesStgy = updatesStgy; - return this; - } - - /** - * Get the update strategy. - * - * @return The parameter value. - */ - public UpdatesStrategy getUpdatesStgy() { - return updatesStgy; - } -} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java deleted file mode 100644 index 2e7b9478f76f2..0000000000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * - * Contains multi-class logistic regression. - */ -package org.apache.ignite.ml.regressions.logistic.multiclass; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java similarity index 87% rename from modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java rename to modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java index f5d2b28732767..579fdb210dc33 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java @@ -27,7 +27,7 @@ /** * Base class for SVM linear classification model. */ -public class SVMLinearBinaryClassificationModel implements Model, Exportable, Serializable { +public class SVMLinearClassificationModel implements Model, Exportable, Serializable { /** */ private static final long serialVersionUID = -996984622291440226L; @@ -44,7 +44,7 @@ public class SVMLinearBinaryClassificationModel implements Model private double intercept; /** */ - public SVMLinearBinaryClassificationModel(Vector weights, double intercept) { + public SVMLinearClassificationModel(Vector weights, double intercept) { this.weights = weights; this.intercept = intercept; } @@ -55,7 +55,7 @@ public SVMLinearBinaryClassificationModel(Vector weights, double intercept) { * @param isKeepingRawLabels The parameter value. * @return Model with new isKeepingRawLabels parameter value. */ - public SVMLinearBinaryClassificationModel withRawLabels(boolean isKeepingRawLabels) { + public SVMLinearClassificationModel withRawLabels(boolean isKeepingRawLabels) { this.isKeepingRawLabels = isKeepingRawLabels; return this; } @@ -66,7 +66,7 @@ public SVMLinearBinaryClassificationModel withRawLabels(boolean isKeepingRawLabe * @param threshold The parameter value. * @return Model with new threshold parameter value. */ - public SVMLinearBinaryClassificationModel withThreshold(double threshold) { + public SVMLinearClassificationModel withThreshold(double threshold) { this.threshold = threshold; return this; } @@ -77,7 +77,7 @@ public SVMLinearBinaryClassificationModel withThreshold(double threshold) { * @param weights The parameter value. * @return Model with new weights parameter value. */ - public SVMLinearBinaryClassificationModel withWeights(Vector weights) { + public SVMLinearClassificationModel withWeights(Vector weights) { this.weights = weights; return this; } @@ -88,7 +88,7 @@ public SVMLinearBinaryClassificationModel withWeights(Vector weights) { * @param intercept The parameter value. * @return Model with new intercept parameter value. */ - public SVMLinearBinaryClassificationModel withIntercept(double intercept) { + public SVMLinearClassificationModel withIntercept(double intercept) { this.intercept = intercept; return this; } @@ -139,7 +139,7 @@ public double intercept() { } /** {@inheritDoc} */ - @Override public

void saveModel(Exporter exporter, P path) { + @Override public

void saveModel(Exporter exporter, P path) { exporter.save(this, path); } @@ -150,7 +150,7 @@ public double intercept() { if (o == null || getClass() != o.getClass()) return false; - SVMLinearBinaryClassificationModel mdl = (SVMLinearBinaryClassificationModel)o; + SVMLinearClassificationModel mdl = (SVMLinearClassificationModel)o; return Double.compare(mdl.intercept, intercept) == 0 && Double.compare(mdl.threshold, threshold) == 0 diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java similarity index 92% rename from modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java rename to modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java index 7ceb53b409c8f..67484ea59870b 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java @@ -39,7 +39,7 @@ * and 1 labels for two classes and makes binary classification.

The paper about this algorithm could be found * here https://arxiv.org/abs/1409.1458. */ -public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrainer { +public class SVMLinearClassificationTrainer extends SingleLabelDatasetTrainer { /** Amount of outer SDCA algorithm iterations. */ private int amountOfIterations = 200; @@ -60,14 +60,14 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai * @param lbExtractor Label extractor. * @return Model. */ - @Override public SVMLinearBinaryClassificationModel fit(DatasetBuilder datasetBuilder, + @Override public SVMLinearClassificationModel fit(DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) { return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); } /** {@inheritDoc} */ - @Override protected SVMLinearBinaryClassificationModel updateModel(SVMLinearBinaryClassificationModel mdl, + @Override protected SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl, DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) { @@ -117,11 +117,11 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai } catch (Exception e) { throw new RuntimeException(e); } - return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0)); + return new SVMLinearClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0)); } /** {@inheritDoc} */ - @Override protected boolean checkState(SVMLinearBinaryClassificationModel mdl) { + @Override protected boolean checkState(SVMLinearClassificationModel mdl) { return true; } @@ -129,7 +129,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai * @param mdl Model. * @return vector of model weights with intercept. */ - private Vector getStateVector(SVMLinearBinaryClassificationModel mdl) { + private Vector getStateVector(SVMLinearClassificationModel mdl) { double intercept = mdl.intercept(); Vector weights = mdl.weights(); @@ -262,7 +262,7 @@ else if (alpha >= 1.0) * @param lambda The regularization parameter. Should be more than 0.0. * @return Trainer with new lambda parameter value. */ - public SVMLinearBinaryClassificationTrainer withLambda(double lambda) { + public SVMLinearClassificationTrainer withLambda(double lambda) { assert lambda > 0.0; this.lambda = lambda; return this; @@ -292,7 +292,7 @@ public int getAmountOfIterations() { * @param amountOfIterations The parameter value. * @return Trainer with new amountOfIterations parameter value. */ - public SVMLinearBinaryClassificationTrainer withAmountOfIterations(int amountOfIterations) { + public SVMLinearClassificationTrainer withAmountOfIterations(int amountOfIterations) { this.amountOfIterations = amountOfIterations; return this; } @@ -312,7 +312,7 @@ public int getAmountOfLocIterations() { * @param amountOfLocIterations The parameter value. * @return Trainer with new amountOfLocIterations parameter value. */ - public SVMLinearBinaryClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) { + public SVMLinearClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) { this.amountOfLocIterations = amountOfLocIterations; return this; } @@ -332,7 +332,7 @@ public long getSeed() { * @param seed The parameter value. * @return Model with new seed parameter value. */ - public SVMLinearBinaryClassificationTrainer withSeed(long seed) { + public SVMLinearClassificationTrainer withSeed(long seed) { this.seed = seed; return this; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java deleted file mode 100644 index 46bf4b2976c70..0000000000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.svm; - -import java.io.Serializable; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.TreeMap; -import org.apache.ignite.ml.Exportable; -import org.apache.ignite.ml.Exporter; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.primitives.vector.Vector; - -/** Base class for multi-classification model for set of SVM classifiers. */ -public class SVMLinearMultiClassClassificationModel implements Model, Exportable, Serializable { - /** */ - private static final long serialVersionUID = -667986511191350227L; - - /** List of models associated with each class. */ - private Map models; - - /** */ - public SVMLinearMultiClassClassificationModel() { - this.models = new HashMap<>(); - } - - /** {@inheritDoc} */ - @Override public Double apply(Vector input) { - TreeMap maxMargins = new TreeMap<>(); - - models.forEach((k, v) -> maxMargins.put(input.dot(v.weights()) + v.intercept(), k)); - - return maxMargins.lastEntry().getValue(); - } - - /** {@inheritDoc} */ - @Override public

void saveModel(Exporter exporter, P path) { - exporter.save(this, path); - } - - /** {@inheritDoc} */ - @Override public boolean equals(Object o) { - if (this == o) - return true; - - if (o == null || getClass() != o.getClass()) - return false; - - SVMLinearMultiClassClassificationModel mdl = (SVMLinearMultiClassClassificationModel)o; - - return Objects.equals(models, mdl.models); - } - - /** {@inheritDoc} */ - @Override public int hashCode() { - return Objects.hash(models); - } - - /** {@inheritDoc} */ - @Override public String toString() { - StringBuilder wholeStr = new StringBuilder(); - - models.forEach((clsLb, mdl) -> - wholeStr - .append("The class with label ") - .append(clsLb) - .append(" has classifier: ") - .append(mdl.toString()) - .append(System.lineSeparator()) - ); - - return wholeStr.toString(); - } - - /** {@inheritDoc} */ - @Override public String toString(boolean pretty) { - return toString(); - } - - /** - * Adds a specific SVM binary classifier to the bunch of same classifiers. - * - * @param clsLb The class label for the added model. - * @param mdl The model. - */ - public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) { - models.put(clsLb, mdl); - } - - /** - * @param clsLb Class label. - * @return model trained for target class if it exists. - */ - public Optional getModelForClass(double clsLb) { - return Optional.of(models.get(clsLb)); - } -} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java deleted file mode 100644 index 94f2a990133b4..0000000000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.svm; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.apache.ignite.ml.dataset.Dataset; -import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; -import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap; -import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; - -/** - * Base class for a soft-margin SVM linear multiclass-classification trainer based on the communication-efficient - * distributed dual coordinate ascent algorithm (CoCoA) with hinge-loss function. - * - * All common parameters are shared with bunch of binary classification trainers. - */ -public class SVMLinearMultiClassClassificationTrainer - extends SingleLabelDatasetTrainer { - /** Amount of outer SDCA algorithm iterations. */ - private int amountOfIterations = 20; - - /** Amount of local SDCA algorithm iterations. */ - private int amountOfLocIterations = 50; - - /** Regularization parameter. */ - private double lambda = 0.2; - - /** The seed number. */ - private long seed = 1234L; - - /** - * Trains model based on the specified data. - * - * @param datasetBuilder Dataset builder. - * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @return Model. - */ - @Override public SVMLinearMultiClassClassificationModel fit(DatasetBuilder datasetBuilder, - IgniteBiFunction featureExtractor, - IgniteBiFunction lbExtractor) { - return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); - } - - /** {@inheritDoc} */ - @Override public SVMLinearMultiClassClassificationModel updateModel( - SVMLinearMultiClassClassificationModel mdl, - DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor, - IgniteBiFunction lbExtractor) { - - List classes = extractClassLabels(datasetBuilder, lbExtractor); - if (classes.isEmpty()) - return getLastTrainedModelOrThrowEmptyDatasetException(mdl); - - SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel(); - - classes.forEach(clsLb -> { - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() - .withAmountOfIterations(this.getAmountOfIterations()) - .withAmountOfLocIterations(this.getAmountOfLocIterations()) - .withLambda(this.getLambda()) - .withSeed(this.seed); - - IgniteBiFunction lbTransformer = (k, v) -> { - Double lb = lbExtractor.apply(k, v); - - if (lb.equals(clsLb)) - return 1.0; - else - return 0.0; - }; - - SVMLinearBinaryClassificationModel updatedMdl; - - if (mdl == null) - updatedMdl = learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer); - else - updatedMdl = updateModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer); - multiClsMdl.add(clsLb, updatedMdl); - }); - - return multiClsMdl; - } - - /** {@inheritDoc} */ - @Override protected boolean checkState(SVMLinearMultiClassClassificationModel mdl) { - return true; - } - - /** - * Trains model based on the specified data. - * - * @param svmTrainer Prepared SVM trainer. - * @param datasetBuilder Dataset builder. - * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - */ - private SVMLinearBinaryClassificationModel learnNewModel(SVMLinearBinaryClassificationTrainer svmTrainer, - DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor, - IgniteBiFunction lbExtractor) { - - return svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor); - } - - /** - * Updates already learned model or fit new model if there is no model for current class label. - * - * @param multiClsMdl Learning multi-class model. - * @param clsLb Current class label. - * @param svmTrainer Prepared SVM trainer. - * @param datasetBuilder Dataset builder. - * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - */ - private SVMLinearBinaryClassificationModel updateModel(SVMLinearMultiClassClassificationModel multiClsMdl, - Double clsLb, SVMLinearBinaryClassificationTrainer svmTrainer, DatasetBuilder datasetBuilder, - IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) { - - return multiClsMdl.getModelForClass(clsLb) - .map(learnedModel -> svmTrainer.update(learnedModel, datasetBuilder, featureExtractor, lbExtractor)) - .orElseGet(() -> svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor)); - } - - /** Iterates among dataset and collects class labels. */ - private List extractClassLabels(DatasetBuilder datasetBuilder, - IgniteBiFunction lbExtractor) { - assert datasetBuilder != null; - - PartitionDataBuilder partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); - - List res = new ArrayList<>(); - - try (Dataset dataset = datasetBuilder.build( - envBuilder, - (env, upstream, upstreamSize) -> new EmptyContext(), - partDataBuilder - )) { - final Set clsLabels = dataset.compute(data -> { - final Set locClsLabels = new HashSet<>(); - - final double[] lbs = data.getY(); - - for (double lb : lbs) - locClsLabels.add(lb); - - return locClsLabels; - }, (a, b) -> { - if (a == null) - return b == null ? new HashSet<>() : b; - if (b == null) - return a; - return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()); - }); - - if (clsLabels != null) - res.addAll(clsLabels); - } catch (Exception e) { - throw new RuntimeException(e); - } - return res; - } - - /** - * Set up the regularization parameter. - * - * @param lambda The regularization parameter. Should be more than 0.0. - * @return Trainer with new lambda parameter value. - */ - public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) { - assert lambda > 0.0; - this.lambda = lambda; - return this; - } - - /** - * Get the regularization lambda. - * - * @return The property value. - */ - public double getLambda() { - return lambda; - } - - /** - * Gets the amount of outer iterations of SCDA algorithm. - * - * @return The property value. - */ - public int getAmountOfIterations() { - return amountOfIterations; - } - - /** - * Set up the amount of outer iterations of SCDA algorithm. - * - * @param amountOfIterations The parameter value. - * @return Trainer with new amountOfIterations parameter value. - */ - public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) { - this.amountOfIterations = amountOfIterations; - return this; - } - - /** - * Gets the amount of local iterations of SCDA algorithm. - * - * @return The property value. - */ - public int getAmountOfLocIterations() { - return amountOfLocIterations; - } - - /** - * Set up the amount of local iterations of SCDA algorithm. - * - * @param amountOfLocIterations The parameter value. - * @return Trainer with new amountOfLocIterations parameter value. - */ - public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) { - this.amountOfLocIterations = amountOfLocIterations; - return this; - } - - /** - * Gets the seed number. - * - * @return The property value. - */ - public long getSeed() { - return seed; - } - - /** - * Set up the seed. - * - * @param seed The parameter value. - * @return Model with new seed parameter value. - */ - public SVMLinearMultiClassClassificationTrainer withSeed(long seed) { - this.seed = seed; - return this; - } -} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java index 745eac946f4f7..e951145e948a6 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java @@ -34,16 +34,13 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix; -import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; import org.apache.ignite.ml.structures.Dataset; import org.apache.ignite.ml.structures.DatasetRow; import org.apache.ignite.ml.structures.FeatureMetadata; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.LabeledVectorSet; -import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; -import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -92,17 +89,7 @@ public void test() { test(new KNNClassificationModel(null).withK(1), new KNNClassificationModel(null).withK(2)); - LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel(); - mdl.add(1, new LogisticRegressionModel(new DenseVector(), 1.0)); - test(mdl, new LogRegressionMultiClassModel()); - - test(new LinearRegressionModel(null, 1.0), new LinearRegressionModel(null, 0.5)); - - SVMLinearMultiClassClassificationModel mdl1 = new SVMLinearMultiClassClassificationModel(); - mdl1.add(1, new SVMLinearBinaryClassificationModel(new DenseVector(), 1.0)); - test(mdl1, new SVMLinearMultiClassClassificationModel()); - - test(new SVMLinearBinaryClassificationModel(null, 1.0), new SVMLinearBinaryClassificationModel(null, 0.5)); + test(new SVMLinearClassificationModel(null, 1.0), new SVMLinearClassificationModel(null, 0.5)); test(new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()), new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true), new ANNClassificationTrainer.CentroidStat())); diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java index ca3f0b5549390..c5b2ffe2d59aa 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java @@ -43,12 +43,10 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.LabeledVectorSet; -import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; -import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; import org.junit.Assert; import org.junit.Test; @@ -99,36 +97,11 @@ public void importExportLinearRegressionModelTest() throws IOException { @Test public void importExportSVMBinaryClassificationModelTest() throws IOException { executeModelTest(mdlFilePath -> { - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{1, 2}), 3); - Exporter exporter = new FileExporter<>(); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(new DenseVector(new double[] {1, 2}), 3); + Exporter exporter = new FileExporter<>(); mdl.saveModel(exporter, mdlFilePath); - SVMLinearBinaryClassificationModel load = exporter.load(mdlFilePath); - - Assert.assertNotNull(load); - Assert.assertEquals("", mdl, load); - - return null; - }); - } - - /** */ - @Test - public void importExportSVMMultiClassClassificationModelTest() throws IOException { - executeModelTest(mdlFilePath -> { - SVMLinearBinaryClassificationModel binaryMdl1 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{1, 2}), 3); - SVMLinearBinaryClassificationModel binaryMdl2 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{2, 3}), 4); - SVMLinearBinaryClassificationModel binaryMdl3 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{3, 4}), 5); - - SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel(); - mdl.add(1, binaryMdl1); - mdl.add(2, binaryMdl2); - mdl.add(3, binaryMdl3); - - Exporter exporter = new FileExporter<>(); - mdl.saveModel(exporter, mdlFilePath); - - SVMLinearMultiClassClassificationModel load = exporter.load(mdlFilePath); + SVMLinearClassificationModel load = exporter.load(mdlFilePath); Assert.assertNotNull(load); Assert.assertEquals("", mdl, load); @@ -154,23 +127,6 @@ public void importExportLogisticRegressionModelTest() throws IOException { }); } - /** */ - @Test - public void importExportLogRegressionMultiClassModelTest() throws IOException { - executeModelTest(mdlFilePath -> { - LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel(); - Exporter exporter = new FileExporter<>(); - mdl.saveModel(exporter, mdlFilePath); - - LogRegressionMultiClassModel load = exporter.load(mdlFilePath); - - Assert.assertNotNull(load); - Assert.assertEquals("", mdl, load); - - return null; - }); - } - /** */ private void executeModelTest(Function code) throws IOException { Path mdlPath = Files.createTempFile(null, null); diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java index 9842d92d9673c..61f9fc40a83b6 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java @@ -28,8 +28,8 @@ import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.junit.Assert; import org.junit.Test; diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java index e59d51521e6a6..84459009a2fe8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java @@ -20,7 +20,7 @@ import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; import org.junit.Test; /** diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java index d517ce6fe3377..fec62209c0f66 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java @@ -29,7 +29,7 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.junit.Test; /** diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java index 021b567b4201f..2fa69ef20c007 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java @@ -20,7 +20,6 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest; import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest; import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainerTest; -import org.apache.ignite.ml.regressions.logistic.LogRegMultiClassTrainerTest; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModelTest; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainerTest; import org.junit.runner.RunWith; @@ -35,8 +34,7 @@ LinearRegressionLSQRTrainerTest.class, LinearRegressionSGDTrainerTest.class, LogisticRegressionModelTest.class, - LogisticRegressionSGDTrainerTest.class, - LogRegMultiClassTrainerTest.class + LogisticRegressionSGDTrainerTest.class }) public class RegressionsTestSuite { // No-op. diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java index 66871b04445a8..36d0fc7301fb3 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java @@ -21,8 +21,6 @@ import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; import org.junit.Test; import static org.junit.Assert.assertTrue; @@ -40,9 +38,9 @@ public void testPredict() { Vector weights = new DenseVector(new double[]{2.0, 3.0}); LinearRegressionModel mdl = new LinearRegressionModel(weights, 1.0); - assertTrue(mdl.toString().length() > 0); - assertTrue(mdl.toString(true).length() > 0); - assertTrue(mdl.toString(false).length() > 0); + assertTrue(!mdl.toString().isEmpty()); + assertTrue(!mdl.toString(true).isEmpty()); + assertTrue(!mdl.toString(false).isEmpty()); Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION); @@ -60,21 +58,6 @@ public void testPredict() { TestUtils.assertEquals(1.0 + 2.0 * 1.0 - 3.0 * 2.0, mdl.apply(observation), PRECISION); } - /** */ - @Test - public void testPredictWithMultiClasses() { - Vector weights1 = new DenseVector(new double[]{10.0, 0.0}); - Vector weights2 = new DenseVector(new double[]{0.0, 10.0}); - Vector weights3 = new DenseVector(new double[]{-1.0, -1.0}); - LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel(); - mdl.add(1, new LogisticRegressionModel(weights1, 0.0).withRawLabels(true)); - mdl.add(2, new LogisticRegressionModel(weights2, 0.0).withRawLabels(true)); - mdl.add(2, new LogisticRegressionModel(weights3, 0.0).withRawLabels(true)); - - Vector observation = new DenseVector(new double[]{1.0, 1.0}); - TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION); - } - /** */ @Test(expected = CardinalityException.class) public void testPredictOnAnObservationWithWrongCardinality() { diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java deleted file mode 100644 index c99bf02fa2fd4..0000000000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.logistic; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.common.TrainerTest; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.optimization.SmoothParametrized; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer; -import org.junit.Assert; -import org.junit.Test; - -/** - * Tests for {@link LogRegressionMultiClassTrainer}. - */ -public class LogRegMultiClassTrainerTest extends TrainerTest { - /** - * Test trainer on 4 sets grouped around of square vertices. - */ - @Test - public void testTrainWithTheLinearlySeparableCase() { - Map cacheMock = new HashMap<>(); - - for (int i = 0; i < fourSetsInSquareVertices.length; i++) - cacheMock.put(i, fourSetsInSquareVertices[i]); - - final UpdatesStrategy stgy = new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ); - - LogRegressionMultiClassTrainer trainer = new LogRegressionMultiClassTrainer<>() - .withUpdatesStgy(stgy) - .withAmountOfIterations(1000) - .withAmountOfLocIterations(10) - .withBatchSize(100) - .withSeed(123L); - - Assert.assertEquals(trainer.getAmountOfIterations(), 1000); - Assert.assertEquals(trainer.getAmountOfLocIterations(), 10); - Assert.assertEquals(trainer.getBatchSize(), 100, PRECISION); - Assert.assertEquals(trainer.seed(), 123L); - Assert.assertEquals(trainer.getUpdatesStgy(), stgy); - - LogRegressionMultiClassModel mdl = trainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - Assert.assertTrue(mdl.toString().length() > 0); - Assert.assertTrue(mdl.toString(true).length() > 0); - Assert.assertTrue(mdl.toString(false).length() > 0); - - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-10, 10)), PRECISION); - TestUtils.assertEquals(2, mdl.apply(VectorUtils.of(-10, -10)), PRECISION); - TestUtils.assertEquals(3, mdl.apply(VectorUtils.of(10, -10)), PRECISION); - } - - /** */ - @Test - public void testUpdate() { - Map cacheMock = new HashMap<>(); - - for (int i = 0; i < fourSetsInSquareVertices.length; i++) - cacheMock.put(i, fourSetsInSquareVertices[i]); - - LogRegressionMultiClassTrainer trainer = new LogRegressionMultiClassTrainer<>() - .withUpdatesStgy(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - )) - .withAmountOfIterations(1000) - .withAmountOfLocIterations(10) - .withBatchSize(100) - .withSeed(123L); - - LogRegressionMultiClassModel originalMdl = trainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - LogRegressionMultiClassModel updatedOnSameDS = trainer.update( - originalMdl, - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - LogRegressionMultiClassModel updatedOnEmptyDS = trainer.update( - originalMdl, - new HashMap(), - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - List vectors = Arrays.asList( - VectorUtils.of(10, 10), - VectorUtils.of(-10, 10), - VectorUtils.of(-10, -10), - VectorUtils.of(10, -10) - ); - - for (Vector vec : vectors) { - TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION); - TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION); - } - } -} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java index e8aaacd9c23fe..4fae638599cf8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java @@ -21,7 +21,6 @@ import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index d9b6f7a600d10..7236820561e99 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -27,8 +27,6 @@ import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; import org.junit.Test; /** diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java index d6f77c00d4707..ccde0d7eb3b8c 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java @@ -27,7 +27,7 @@ import org.junit.Test; /** - * Tests for {@link SVMLinearBinaryClassificationTrainer}. + * Tests for {@link SVMLinearClassificationTrainer}. */ public class SVMBinaryTrainerTest extends TrainerTest { /** @@ -40,10 +40,10 @@ public void testTrainWithTheLinearlySeparableCase() { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() + SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer() .withSeed(1234L); - SVMLinearBinaryClassificationModel mdl = trainer.fit( + SVMLinearClassificationModel mdl = trainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -62,18 +62,18 @@ public void testUpdate() { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() + SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer() .withAmountOfIterations(1000) .withSeed(1234L); - SVMLinearBinaryClassificationModel originalMdl = trainer.fit( + SVMLinearClassificationModel originalMdl = trainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), (k, v) -> v[0] ); - SVMLinearBinaryClassificationModel updatedOnSameDS = trainer.update( + SVMLinearClassificationModel updatedOnSameDS = trainer.update( originalMdl, cacheMock, parts, @@ -81,7 +81,7 @@ public void testUpdate() { (k, v) -> v[0] ); - SVMLinearBinaryClassificationModel updatedOnEmptyDS = trainer.update( + SVMLinearClassificationModel updatedOnEmptyDS = trainer.update( originalMdl, new HashMap(), parts, diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java index 9c452f944796f..3bac7906568c2 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java @@ -36,7 +36,7 @@ public class SVMModelTest { @Test public void testPredictWithRawLabels() { Vector weights = new DenseVector(new double[]{2.0, 3.0}); - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withRawLabels(true); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0).withRawLabels(true); Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION); @@ -55,36 +55,16 @@ public void testPredictWithRawLabels() { Assert.assertTrue(mdl.isKeepingRawLabels()); - Assert.assertTrue(mdl.toString().length() > 0); - Assert.assertTrue(mdl.toString(true).length() > 0); - Assert.assertTrue(mdl.toString(false).length() > 0); - } - - - /** */ - @Test - public void testPredictWithMultiClasses() { - Vector weights1 = new DenseVector(new double[]{10.0, 0.0}); - Vector weights2 = new DenseVector(new double[]{0.0, 10.0}); - Vector weights3 = new DenseVector(new double[]{-1.0, -1.0}); - SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel(); - mdl.add(1, new SVMLinearBinaryClassificationModel(weights1, 0.0).withRawLabels(true)); - mdl.add(2, new SVMLinearBinaryClassificationModel(weights2, 0.0).withRawLabels(true)); - mdl.add(2, new SVMLinearBinaryClassificationModel(weights3, 0.0).withRawLabels(true)); - - Assert.assertTrue(mdl.toString().length() > 0); - Assert.assertTrue(mdl.toString(true).length() > 0); - Assert.assertTrue(mdl.toString(false).length() > 0); - - Vector observation = new DenseVector(new double[]{1.0, 1.0}); - TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION); + Assert.assertTrue(!mdl.toString().isEmpty()); + Assert.assertTrue(!mdl.toString(true).isEmpty()); + Assert.assertTrue(!mdl.toString(false).isEmpty()); } /** */ @Test public void testPredictWithErasedLabels() { Vector weights = new DenseVector(new double[]{1.0, 1.0}); - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0); Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION); @@ -101,7 +81,7 @@ public void testPredictWithErasedLabels() { observation = new DenseVector(new double[]{-1.0, -2.0}); TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION); - final SVMLinearBinaryClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0})); + final SVMLinearClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0})); System.out.println("The SVM model is " + mdlWithNewData); observation = new DenseVector(new double[]{-1.0, -2.0}); @@ -113,7 +93,7 @@ public void testPredictWithErasedLabels() { @Test public void testPredictWithErasedLabelsAndChangedThreshold() { Vector weights = new DenseVector(new double[]{1.0, 1.0}); - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withThreshold(5); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0).withThreshold(5); Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION); @@ -129,7 +109,7 @@ public void testPredictWithErasedLabelsAndChangedThreshold() { public void testPredictOnAnObservationWithWrongCardinality() { Vector weights = new DenseVector(new double[]{2.0, 3.0}); - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0); Vector observation = new DenseVector(new double[]{1.0}); diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java deleted file mode 100644 index 7c4809fd75c3c..0000000000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.svm; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.common.TrainerTest; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.junit.Test; - -/** - * Tests for {@link SVMLinearBinaryClassificationTrainer}. - */ -public class SVMMultiClassTrainerTest extends TrainerTest { - /** - * Test trainer on 4 sets grouped around of square vertices. - */ - @Test - public void testTrainWithTheLinearlySeparableCase() { - Map cacheMock = new HashMap<>(); - - for (int i = 0; i < twoLinearlySeparableClasses.length; i++) - cacheMock.put(i, twoLinearlySeparableClasses[i]); - - SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer() - .withLambda(0.3) - .withAmountOfLocIterations(10) - .withAmountOfIterations(20) - .withSeed(1234L); - - SVMLinearMultiClassClassificationModel mdl = trainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION); - } - - /** */ - @Test - public void testUpdate() { - Map cacheMock = new HashMap<>(); - - for (int i = 0; i < twoLinearlySeparableClasses.length; i++) - cacheMock.put(i, twoLinearlySeparableClasses[i]); - - SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer() - .withLambda(0.3) - .withAmountOfLocIterations(10) - .withAmountOfIterations(100) - .withSeed(1234L); - - SVMLinearMultiClassClassificationModel originalMdl = trainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - SVMLinearMultiClassClassificationModel updatedOnSameDS = trainer.update( - originalMdl, - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - SVMLinearMultiClassClassificationModel updatedOnEmptyDS = trainer.update( - originalMdl, - new HashMap(), - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - Vector v = VectorUtils.of(100, 10); - TestUtils.assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), PRECISION); - TestUtils.assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), PRECISION); - } -} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java index df7263f9d47f7..a2aea6ef9798b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java @@ -27,7 +27,6 @@ @Suite.SuiteClasses({ SVMModelTest.class, SVMBinaryTrainerTest.class, - SVMMultiClassTrainerTest.class, }) public class SVMTestSuite { // No-op. diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java index 1b96ce29b7905..31fe8b31349c3 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java @@ -37,8 +37,8 @@ import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.junit.Test; /**