maxMargins = new TreeMap<>();
-
- models.forEach((k, v) -> maxMargins.put(1.0 / (1.0 + Math.exp(-(input.dot(v.weights()) + v.intercept()))), k));
-
- return maxMargins.lastEntry().getValue();
- }
-
- /** {@inheritDoc} */
- @Override public void saveModel(Exporter exporter, P path) {
- exporter.save(this, path);
- }
-
- /** {@inheritDoc} */
- @Override public boolean equals(Object o) {
- if (this == o)
- return true;
-
- if (o == null || getClass() != o.getClass())
- return false;
-
- LogRegressionMultiClassModel mdl = (LogRegressionMultiClassModel)o;
-
- return Objects.equals(models, mdl.models);
- }
-
- /** {@inheritDoc} */
- @Override public int hashCode() {
- return Objects.hash(models);
- }
-
- /** {@inheritDoc} */
- @Override public String toString() {
- StringBuilder wholeStr = new StringBuilder();
-
- models.forEach((clsLb, mdl) ->
- wholeStr
- .append("The class with label ")
- .append(clsLb)
- .append(" has classifier: ")
- .append(mdl.toString())
- .append(System.lineSeparator())
- );
-
- return wholeStr.toString();
- }
-
- /** {@inheritDoc} */
- @Override public String toString(boolean pretty) {
- return toString();
- }
-
- /**
- * Adds a specific Log Regression binary classifier to the bunch of same classifiers.
- *
- * @param clsLb The class label for the added model.
- * @param mdl The model.
- */
- public void add(double clsLb, LogisticRegressionModel mdl) {
- models.put(clsLb, mdl);
- }
-
- /**
- * @param clsLb Class label.
- * @return model for class label if it exists.
- */
- public Optional getModel(Double clsLb) {
- return Optional.ofNullable(models.get(clsLb));
- }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
deleted file mode 100644
index fd5a624aaa994..0000000000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions.logistic.multiclass;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-import org.apache.ignite.ml.dataset.Dataset;
-import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.dataset.PartitionDataBuilder;
-import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.nn.UpdatesStrategy;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
-import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
-import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
-import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
-
-/**
- * All common parameters are shared with bunch of binary classification trainers.
- */
-public class LogRegressionMultiClassTrainer
- extends SingleLabelDatasetTrainer {
- /** Update strategy. */
- private UpdatesStrategy updatesStgy = new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- );
-
- /** Max number of iteration. */
- private int amountOfIterations = 100;
-
- /** Batch size. */
- private int batchSize = 100;
-
- /** Number of local iterations. */
- private int amountOfLocIterations = 100;
-
- /** Seed for random generator. */
- private long seed = 1234L;
-
- /**
- * Trains model based on the specified data.
- *
- * @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
- * @return Model.
- */
- @Override public LogRegressionMultiClassModel fit(DatasetBuilder datasetBuilder,
- IgniteBiFunction featureExtractor,
- IgniteBiFunction lbExtractor) {
- List classes = extractClassLabels(datasetBuilder, lbExtractor);
-
- return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
- }
-
- /** {@inheritDoc} */
- @Override public LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel newMdl,
- DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor,
- IgniteBiFunction lbExtractor) {
-
- List classes = extractClassLabels(datasetBuilder, lbExtractor);
-
- if(classes.isEmpty())
- return getLastTrainedModelOrThrowEmptyDatasetException(newMdl);
-
- LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel();
-
- classes.forEach(clsLb -> {
- LogisticRegressionSGDTrainer> trainer =
- new LogisticRegressionSGDTrainer<>()
- .withBatchSize(batchSize)
- .withLocIterations(amountOfLocIterations)
- .withMaxIterations(amountOfIterations)
- .withSeed(seed);
-
- IgniteBiFunction lbTransformer = (k, v) -> {
- Double lb = lbExtractor.apply(k, v);
-
- if (lb.equals(clsLb))
- return 1.0;
- else
- return 0.0;
- };
-
- LogisticRegressionModel mdl = Optional.ofNullable(newMdl)
- .flatMap(multiClassModel -> multiClassModel.getModel(clsLb))
- .map(learnedModel -> trainer.update(learnedModel, datasetBuilder, featureExtractor, lbTransformer))
- .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, lbTransformer));
-
- multiClsMdl.add(clsLb, mdl);
- });
-
- return multiClsMdl;
- }
-
- /** {@inheritDoc} */
- @Override protected boolean checkState(LogRegressionMultiClassModel mdl) {
- return true;
- }
-
- /** Iterates among dataset and collects class labels. */
- private List extractClassLabels(DatasetBuilder datasetBuilder,
- IgniteBiFunction lbExtractor) {
- assert datasetBuilder != null;
-
- PartitionDataBuilder partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
-
- List res = new ArrayList<>();
-
- try (Dataset dataset = datasetBuilder.build(
- envBuilder,
- (env, upstream, upstreamSize) -> new EmptyContext(),
- partDataBuilder
- )) {
- final Set clsLabels = dataset.compute(data -> {
- final Set locClsLabels = new HashSet<>();
-
- final double[] lbs = data.getY();
-
- for (double lb : lbs)
- locClsLabels.add(lb);
-
- return locClsLabels;
- }, (a, b) -> {
- if (a == null)
- return b == null ? new HashSet<>() : b;
- if (b == null)
- return a;
- return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
- });
-
- if (clsLabels != null)
- res.addAll(clsLabels);
-
- }
- catch (Exception e) {
- throw new RuntimeException(e);
- }
- return res;
- }
-
- /**
- * Set up the regularization parameter.
- *
- * @param batchSize The size of learning batch.
- * @return Trainer with new batch size parameter value.
- */
- public LogRegressionMultiClassTrainer withBatchSize(int batchSize) {
- this.batchSize = batchSize;
- return this;
- }
-
- /**
- * Get the batch size.
- *
- * @return The parameter value.
- */
- public double getBatchSize() {
- return batchSize;
- }
-
- /**
- * Get the amount of outer iterations of SGD algorithm.
- *
- * @return The parameter value.
- */
- public int getAmountOfIterations() {
- return amountOfIterations;
- }
-
- /**
- * Set up the amount of outer iterations.
- *
- * @param amountOfIterations The parameter value.
- * @return Trainer with new amountOfIterations parameter value.
- */
- public LogRegressionMultiClassTrainer withAmountOfIterations(int amountOfIterations) {
- this.amountOfIterations = amountOfIterations;
- return this;
- }
-
- /**
- * Get the amount of local iterations.
- *
- * @return The parameter value.
- */
- public int getAmountOfLocIterations() {
- return amountOfLocIterations;
- }
-
- /**
- * Set up the amount of local iterations of SGD algorithm.
- *
- * @param amountOfLocIterations The parameter value.
- * @return Trainer with new amountOfLocIterations parameter value.
- */
- public LogRegressionMultiClassTrainer withAmountOfLocIterations(int amountOfLocIterations) {
- this.amountOfLocIterations = amountOfLocIterations;
- return this;
- }
-
- /**
- * Set up the random seed parameter.
- *
- * @param seed Seed for random generator.
- * @return Trainer with new seed parameter value.
- */
- public LogRegressionMultiClassTrainer withSeed(long seed) {
- this.seed = seed;
- return this;
- }
-
- /**
- * Get the seed for random generator.
- *
- * @return The parameter value.
- */
- public long seed() {
- return seed;
- }
-
- /**
- * Set up the updates strategy.
- *
- * @param updatesStgy Update strategy.
- * @return Trainer with new update strategy parameter value.
- */
- public LogRegressionMultiClassTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) {
- this.updatesStgy = updatesStgy;
- return this;
- }
-
- /**
- * Get the update strategy.
- *
- * @return The parameter value.
- */
- public UpdatesStrategy getUpdatesStgy() {
- return updatesStgy;
- }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java
deleted file mode 100644
index 2e7b9478f76f2..0000000000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- *
- * Contains multi-class logistic regression.
- */
-package org.apache.ignite.ml.regressions.logistic.multiclass;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java
similarity index 87%
rename from modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java
index f5d2b28732767..579fdb210dc33 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java
@@ -27,7 +27,7 @@
/**
* Base class for SVM linear classification model.
*/
-public class SVMLinearBinaryClassificationModel implements Model, Exportable, Serializable {
+public class SVMLinearClassificationModel implements Model, Exportable, Serializable {
/** */
private static final long serialVersionUID = -996984622291440226L;
@@ -44,7 +44,7 @@ public class SVMLinearBinaryClassificationModel implements Model
private double intercept;
/** */
- public SVMLinearBinaryClassificationModel(Vector weights, double intercept) {
+ public SVMLinearClassificationModel(Vector weights, double intercept) {
this.weights = weights;
this.intercept = intercept;
}
@@ -55,7 +55,7 @@ public SVMLinearBinaryClassificationModel(Vector weights, double intercept) {
* @param isKeepingRawLabels The parameter value.
* @return Model with new isKeepingRawLabels parameter value.
*/
- public SVMLinearBinaryClassificationModel withRawLabels(boolean isKeepingRawLabels) {
+ public SVMLinearClassificationModel withRawLabels(boolean isKeepingRawLabels) {
this.isKeepingRawLabels = isKeepingRawLabels;
return this;
}
@@ -66,7 +66,7 @@ public SVMLinearBinaryClassificationModel withRawLabels(boolean isKeepingRawLabe
* @param threshold The parameter value.
* @return Model with new threshold parameter value.
*/
- public SVMLinearBinaryClassificationModel withThreshold(double threshold) {
+ public SVMLinearClassificationModel withThreshold(double threshold) {
this.threshold = threshold;
return this;
}
@@ -77,7 +77,7 @@ public SVMLinearBinaryClassificationModel withThreshold(double threshold) {
* @param weights The parameter value.
* @return Model with new weights parameter value.
*/
- public SVMLinearBinaryClassificationModel withWeights(Vector weights) {
+ public SVMLinearClassificationModel withWeights(Vector weights) {
this.weights = weights;
return this;
}
@@ -88,7 +88,7 @@ public SVMLinearBinaryClassificationModel withWeights(Vector weights) {
* @param intercept The parameter value.
* @return Model with new intercept parameter value.
*/
- public SVMLinearBinaryClassificationModel withIntercept(double intercept) {
+ public SVMLinearClassificationModel withIntercept(double intercept) {
this.intercept = intercept;
return this;
}
@@ -139,7 +139,7 @@ public double intercept() {
}
/** {@inheritDoc} */
- @Override public void saveModel(Exporter exporter, P path) {
+ @Override public void saveModel(Exporter exporter, P path) {
exporter.save(this, path);
}
@@ -150,7 +150,7 @@ public double intercept() {
if (o == null || getClass() != o.getClass())
return false;
- SVMLinearBinaryClassificationModel mdl = (SVMLinearBinaryClassificationModel)o;
+ SVMLinearClassificationModel mdl = (SVMLinearClassificationModel)o;
return Double.compare(mdl.intercept, intercept) == 0
&& Double.compare(mdl.threshold, threshold) == 0
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
similarity index 92%
rename from modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
index 7ceb53b409c8f..67484ea59870b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
@@ -39,7 +39,7 @@
* and 1 labels for two classes and makes binary classification.
The paper about this algorithm could be found
* here https://arxiv.org/abs/1409.1458.
*/
-public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrainer {
+public class SVMLinearClassificationTrainer extends SingleLabelDatasetTrainer {
/** Amount of outer SDCA algorithm iterations. */
private int amountOfIterations = 200;
@@ -60,14 +60,14 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
* @param lbExtractor Label extractor.
* @return Model.
*/
- @Override public SVMLinearBinaryClassificationModel fit(DatasetBuilder datasetBuilder,
+ @Override public SVMLinearClassificationModel fit(DatasetBuilder datasetBuilder,
IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) {
return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
}
/** {@inheritDoc} */
- @Override protected SVMLinearBinaryClassificationModel updateModel(SVMLinearBinaryClassificationModel mdl,
+ @Override protected SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl,
DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor,
IgniteBiFunction lbExtractor) {
@@ -117,11 +117,11 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
} catch (Exception e) {
throw new RuntimeException(e);
}
- return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
+ return new SVMLinearClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
}
/** {@inheritDoc} */
- @Override protected boolean checkState(SVMLinearBinaryClassificationModel mdl) {
+ @Override protected boolean checkState(SVMLinearClassificationModel mdl) {
return true;
}
@@ -129,7 +129,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
* @param mdl Model.
* @return vector of model weights with intercept.
*/
- private Vector getStateVector(SVMLinearBinaryClassificationModel mdl) {
+ private Vector getStateVector(SVMLinearClassificationModel mdl) {
double intercept = mdl.intercept();
Vector weights = mdl.weights();
@@ -262,7 +262,7 @@ else if (alpha >= 1.0)
* @param lambda The regularization parameter. Should be more than 0.0.
* @return Trainer with new lambda parameter value.
*/
- public SVMLinearBinaryClassificationTrainer withLambda(double lambda) {
+ public SVMLinearClassificationTrainer withLambda(double lambda) {
assert lambda > 0.0;
this.lambda = lambda;
return this;
@@ -292,7 +292,7 @@ public int getAmountOfIterations() {
* @param amountOfIterations The parameter value.
* @return Trainer with new amountOfIterations parameter value.
*/
- public SVMLinearBinaryClassificationTrainer withAmountOfIterations(int amountOfIterations) {
+ public SVMLinearClassificationTrainer withAmountOfIterations(int amountOfIterations) {
this.amountOfIterations = amountOfIterations;
return this;
}
@@ -312,7 +312,7 @@ public int getAmountOfLocIterations() {
* @param amountOfLocIterations The parameter value.
* @return Trainer with new amountOfLocIterations parameter value.
*/
- public SVMLinearBinaryClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
+ public SVMLinearClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
this.amountOfLocIterations = amountOfLocIterations;
return this;
}
@@ -332,7 +332,7 @@ public long getSeed() {
* @param seed The parameter value.
* @return Model with new seed parameter value.
*/
- public SVMLinearBinaryClassificationTrainer withSeed(long seed) {
+ public SVMLinearClassificationTrainer withSeed(long seed) {
this.seed = seed;
return this;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
deleted file mode 100644
index 46bf4b2976c70..0000000000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm;
-
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.TreeMap;
-import org.apache.ignite.ml.Exportable;
-import org.apache.ignite.ml.Exporter;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-
-/** Base class for multi-classification model for set of SVM classifiers. */
-public class SVMLinearMultiClassClassificationModel implements Model, Exportable, Serializable {
- /** */
- private static final long serialVersionUID = -667986511191350227L;
-
- /** List of models associated with each class. */
- private Map models;
-
- /** */
- public SVMLinearMultiClassClassificationModel() {
- this.models = new HashMap<>();
- }
-
- /** {@inheritDoc} */
- @Override public Double apply(Vector input) {
- TreeMap maxMargins = new TreeMap<>();
-
- models.forEach((k, v) -> maxMargins.put(input.dot(v.weights()) + v.intercept(), k));
-
- return maxMargins.lastEntry().getValue();
- }
-
- /** {@inheritDoc} */
- @Override public void saveModel(Exporter exporter, P path) {
- exporter.save(this, path);
- }
-
- /** {@inheritDoc} */
- @Override public boolean equals(Object o) {
- if (this == o)
- return true;
-
- if (o == null || getClass() != o.getClass())
- return false;
-
- SVMLinearMultiClassClassificationModel mdl = (SVMLinearMultiClassClassificationModel)o;
-
- return Objects.equals(models, mdl.models);
- }
-
- /** {@inheritDoc} */
- @Override public int hashCode() {
- return Objects.hash(models);
- }
-
- /** {@inheritDoc} */
- @Override public String toString() {
- StringBuilder wholeStr = new StringBuilder();
-
- models.forEach((clsLb, mdl) ->
- wholeStr
- .append("The class with label ")
- .append(clsLb)
- .append(" has classifier: ")
- .append(mdl.toString())
- .append(System.lineSeparator())
- );
-
- return wholeStr.toString();
- }
-
- /** {@inheritDoc} */
- @Override public String toString(boolean pretty) {
- return toString();
- }
-
- /**
- * Adds a specific SVM binary classifier to the bunch of same classifiers.
- *
- * @param clsLb The class label for the added model.
- * @param mdl The model.
- */
- public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) {
- models.put(clsLb, mdl);
- }
-
- /**
- * @param clsLb Class label.
- * @return model trained for target class if it exists.
- */
- public Optional getModelForClass(double clsLb) {
- return Optional.of(models.get(clsLb));
- }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
deleted file mode 100644
index 94f2a990133b4..0000000000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-import org.apache.ignite.ml.dataset.Dataset;
-import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.dataset.PartitionDataBuilder;
-import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
-import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
-import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
-
-/**
- * Base class for a soft-margin SVM linear multiclass-classification trainer based on the communication-efficient
- * distributed dual coordinate ascent algorithm (CoCoA) with hinge-loss function.
- *
- * All common parameters are shared with bunch of binary classification trainers.
- */
-public class SVMLinearMultiClassClassificationTrainer
- extends SingleLabelDatasetTrainer {
- /** Amount of outer SDCA algorithm iterations. */
- private int amountOfIterations = 20;
-
- /** Amount of local SDCA algorithm iterations. */
- private int amountOfLocIterations = 50;
-
- /** Regularization parameter. */
- private double lambda = 0.2;
-
- /** The seed number. */
- private long seed = 1234L;
-
- /**
- * Trains model based on the specified data.
- *
- * @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
- * @return Model.
- */
- @Override public SVMLinearMultiClassClassificationModel fit(DatasetBuilder datasetBuilder,
- IgniteBiFunction featureExtractor,
- IgniteBiFunction lbExtractor) {
- return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
- }
-
- /** {@inheritDoc} */
- @Override public SVMLinearMultiClassClassificationModel updateModel(
- SVMLinearMultiClassClassificationModel mdl,
- DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor,
- IgniteBiFunction lbExtractor) {
-
- List classes = extractClassLabels(datasetBuilder, lbExtractor);
- if (classes.isEmpty())
- return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
-
- SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();
-
- classes.forEach(clsLb -> {
- SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
- .withAmountOfIterations(this.getAmountOfIterations())
- .withAmountOfLocIterations(this.getAmountOfLocIterations())
- .withLambda(this.getLambda())
- .withSeed(this.seed);
-
- IgniteBiFunction lbTransformer = (k, v) -> {
- Double lb = lbExtractor.apply(k, v);
-
- if (lb.equals(clsLb))
- return 1.0;
- else
- return 0.0;
- };
-
- SVMLinearBinaryClassificationModel updatedMdl;
-
- if (mdl == null)
- updatedMdl = learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer);
- else
- updatedMdl = updateModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer);
- multiClsMdl.add(clsLb, updatedMdl);
- });
-
- return multiClsMdl;
- }
-
- /** {@inheritDoc} */
- @Override protected boolean checkState(SVMLinearMultiClassClassificationModel mdl) {
- return true;
- }
-
- /**
- * Trains model based on the specified data.
- *
- * @param svmTrainer Prepared SVM trainer.
- * @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
- */
- private SVMLinearBinaryClassificationModel learnNewModel(SVMLinearBinaryClassificationTrainer svmTrainer,
- DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor,
- IgniteBiFunction lbExtractor) {
-
- return svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor);
- }
-
- /**
- * Updates already learned model or fit new model if there is no model for current class label.
- *
- * @param multiClsMdl Learning multi-class model.
- * @param clsLb Current class label.
- * @param svmTrainer Prepared SVM trainer.
- * @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
- */
- private SVMLinearBinaryClassificationModel updateModel(SVMLinearMultiClassClassificationModel multiClsMdl,
- Double clsLb, SVMLinearBinaryClassificationTrainer svmTrainer, DatasetBuilder datasetBuilder,
- IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) {
-
- return multiClsMdl.getModelForClass(clsLb)
- .map(learnedModel -> svmTrainer.update(learnedModel, datasetBuilder, featureExtractor, lbExtractor))
- .orElseGet(() -> svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor));
- }
-
- /** Iterates among dataset and collects class labels. */
- private List extractClassLabels(DatasetBuilder datasetBuilder,
- IgniteBiFunction lbExtractor) {
- assert datasetBuilder != null;
-
- PartitionDataBuilder partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
-
- List res = new ArrayList<>();
-
- try (Dataset dataset = datasetBuilder.build(
- envBuilder,
- (env, upstream, upstreamSize) -> new EmptyContext(),
- partDataBuilder
- )) {
- final Set clsLabels = dataset.compute(data -> {
- final Set locClsLabels = new HashSet<>();
-
- final double[] lbs = data.getY();
-
- for (double lb : lbs)
- locClsLabels.add(lb);
-
- return locClsLabels;
- }, (a, b) -> {
- if (a == null)
- return b == null ? new HashSet<>() : b;
- if (b == null)
- return a;
- return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
- });
-
- if (clsLabels != null)
- res.addAll(clsLabels);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- return res;
- }
-
- /**
- * Set up the regularization parameter.
- *
- * @param lambda The regularization parameter. Should be more than 0.0.
- * @return Trainer with new lambda parameter value.
- */
- public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
- assert lambda > 0.0;
- this.lambda = lambda;
- return this;
- }
-
- /**
- * Get the regularization lambda.
- *
- * @return The property value.
- */
- public double getLambda() {
- return lambda;
- }
-
- /**
- * Gets the amount of outer iterations of SCDA algorithm.
- *
- * @return The property value.
- */
- public int getAmountOfIterations() {
- return amountOfIterations;
- }
-
- /**
- * Set up the amount of outer iterations of SCDA algorithm.
- *
- * @param amountOfIterations The parameter value.
- * @return Trainer with new amountOfIterations parameter value.
- */
- public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
- this.amountOfIterations = amountOfIterations;
- return this;
- }
-
- /**
- * Gets the amount of local iterations of SCDA algorithm.
- *
- * @return The property value.
- */
- public int getAmountOfLocIterations() {
- return amountOfLocIterations;
- }
-
- /**
- * Set up the amount of local iterations of SCDA algorithm.
- *
- * @param amountOfLocIterations The parameter value.
- * @return Trainer with new amountOfLocIterations parameter value.
- */
- public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
- this.amountOfLocIterations = amountOfLocIterations;
- return this;
- }
-
- /**
- * Gets the seed number.
- *
- * @return The property value.
- */
- public long getSeed() {
- return seed;
- }
-
- /**
- * Set up the seed.
- *
- * @param seed The parameter value.
- * @return Model with new seed parameter value.
- */
- public SVMLinearMultiClassClassificationTrainer withSeed(long seed) {
- this.seed = seed;
- return this;
- }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
index 745eac946f4f7..e951145e948a6 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
@@ -34,16 +34,13 @@
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.structures.Dataset;
import org.apache.ignite.ml.structures.DatasetRow;
import org.apache.ignite.ml.structures.FeatureMetadata;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
-import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
-import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
+import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -92,17 +89,7 @@ public void test() {
test(new KNNClassificationModel(null).withK(1), new KNNClassificationModel(null).withK(2));
- LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel();
- mdl.add(1, new LogisticRegressionModel(new DenseVector(), 1.0));
- test(mdl, new LogRegressionMultiClassModel());
-
- test(new LinearRegressionModel(null, 1.0), new LinearRegressionModel(null, 0.5));
-
- SVMLinearMultiClassClassificationModel mdl1 = new SVMLinearMultiClassClassificationModel();
- mdl1.add(1, new SVMLinearBinaryClassificationModel(new DenseVector(), 1.0));
- test(mdl1, new SVMLinearMultiClassClassificationModel());
-
- test(new SVMLinearBinaryClassificationModel(null, 1.0), new SVMLinearBinaryClassificationModel(null, 0.5));
+ test(new SVMLinearClassificationModel(null, 1.0), new SVMLinearClassificationModel(null, 0.5));
test(new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()),
new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true), new ANNClassificationTrainer.CentroidStat()));
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java
index ca3f0b5549390..c5b2ffe2d59aa 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java
@@ -43,12 +43,10 @@
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
-import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
-import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
+import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.junit.Assert;
import org.junit.Test;
@@ -99,36 +97,11 @@ public void importExportLinearRegressionModelTest() throws IOException {
@Test
public void importExportSVMBinaryClassificationModelTest() throws IOException {
executeModelTest(mdlFilePath -> {
- SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{1, 2}), 3);
- Exporter exporter = new FileExporter<>();
+ SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(new DenseVector(new double[] {1, 2}), 3);
+ Exporter exporter = new FileExporter<>();
mdl.saveModel(exporter, mdlFilePath);
- SVMLinearBinaryClassificationModel load = exporter.load(mdlFilePath);
-
- Assert.assertNotNull(load);
- Assert.assertEquals("", mdl, load);
-
- return null;
- });
- }
-
- /** */
- @Test
- public void importExportSVMMultiClassClassificationModelTest() throws IOException {
- executeModelTest(mdlFilePath -> {
- SVMLinearBinaryClassificationModel binaryMdl1 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{1, 2}), 3);
- SVMLinearBinaryClassificationModel binaryMdl2 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{2, 3}), 4);
- SVMLinearBinaryClassificationModel binaryMdl3 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{3, 4}), 5);
-
- SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel();
- mdl.add(1, binaryMdl1);
- mdl.add(2, binaryMdl2);
- mdl.add(3, binaryMdl3);
-
- Exporter exporter = new FileExporter<>();
- mdl.saveModel(exporter, mdlFilePath);
-
- SVMLinearMultiClassClassificationModel load = exporter.load(mdlFilePath);
+ SVMLinearClassificationModel load = exporter.load(mdlFilePath);
Assert.assertNotNull(load);
Assert.assertEquals("", mdl, load);
@@ -154,23 +127,6 @@ public void importExportLogisticRegressionModelTest() throws IOException {
});
}
- /** */
- @Test
- public void importExportLogRegressionMultiClassModelTest() throws IOException {
- executeModelTest(mdlFilePath -> {
- LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel();
- Exporter exporter = new FileExporter<>();
- mdl.saveModel(exporter, mdlFilePath);
-
- LogRegressionMultiClassModel load = exporter.load(mdlFilePath);
-
- Assert.assertNotNull(load);
- Assert.assertEquals("", mdl, load);
-
- return null;
- });
- }
-
/** */
private void executeModelTest(Function code) throws IOException {
Path mdlPath = Files.createTempFile(null, null);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
index 9842d92d9673c..61f9fc40a83b6 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
@@ -28,8 +28,8 @@
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
import org.junit.Assert;
import org.junit.Test;
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java
index e59d51521e6a6..84459009a2fe8 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java
@@ -20,7 +20,7 @@
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.junit.Test;
/**
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
index d517ce6fe3377..fec62209c0f66 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
@@ -29,7 +29,7 @@
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
import org.junit.Test;
/**
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
index 021b567b4201f..2fa69ef20c007 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
@@ -20,7 +20,6 @@
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest;
import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainerTest;
-import org.apache.ignite.ml.regressions.logistic.LogRegMultiClassTrainerTest;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModelTest;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainerTest;
import org.junit.runner.RunWith;
@@ -35,8 +34,7 @@
LinearRegressionLSQRTrainerTest.class,
LinearRegressionSGDTrainerTest.class,
LogisticRegressionModelTest.class,
- LogisticRegressionSGDTrainerTest.class,
- LogRegMultiClassTrainerTest.class
+ LogisticRegressionSGDTrainerTest.class
})
public class RegressionsTestSuite {
// No-op.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java
index 66871b04445a8..36d0fc7301fb3 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java
@@ -21,8 +21,6 @@
import org.apache.ignite.ml.math.exceptions.CardinalityException;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
import org.junit.Test;
import static org.junit.Assert.assertTrue;
@@ -40,9 +38,9 @@ public void testPredict() {
Vector weights = new DenseVector(new double[]{2.0, 3.0});
LinearRegressionModel mdl = new LinearRegressionModel(weights, 1.0);
- assertTrue(mdl.toString().length() > 0);
- assertTrue(mdl.toString(true).length() > 0);
- assertTrue(mdl.toString(false).length() > 0);
+ assertTrue(!mdl.toString().isEmpty());
+ assertTrue(!mdl.toString(true).isEmpty());
+ assertTrue(!mdl.toString(false).isEmpty());
Vector observation = new DenseVector(new double[]{1.0, 1.0});
TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION);
@@ -60,21 +58,6 @@ public void testPredict() {
TestUtils.assertEquals(1.0 + 2.0 * 1.0 - 3.0 * 2.0, mdl.apply(observation), PRECISION);
}
- /** */
- @Test
- public void testPredictWithMultiClasses() {
- Vector weights1 = new DenseVector(new double[]{10.0, 0.0});
- Vector weights2 = new DenseVector(new double[]{0.0, 10.0});
- Vector weights3 = new DenseVector(new double[]{-1.0, -1.0});
- LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel();
- mdl.add(1, new LogisticRegressionModel(weights1, 0.0).withRawLabels(true));
- mdl.add(2, new LogisticRegressionModel(weights2, 0.0).withRawLabels(true));
- mdl.add(2, new LogisticRegressionModel(weights3, 0.0).withRawLabels(true));
-
- Vector observation = new DenseVector(new double[]{1.0, 1.0});
- TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION);
- }
-
/** */
@Test(expected = CardinalityException.class)
public void testPredictOnAnObservationWithWrongCardinality() {
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
deleted file mode 100644
index c99bf02fa2fd4..0000000000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions.logistic;
-
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.nn.UpdatesStrategy;
-import org.apache.ignite.ml.optimization.SmoothParametrized;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
-import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Tests for {@link LogRegressionMultiClassTrainer}.
- */
-public class LogRegMultiClassTrainerTest extends TrainerTest {
- /**
- * Test trainer on 4 sets grouped around of square vertices.
- */
- @Test
- public void testTrainWithTheLinearlySeparableCase() {
- Map cacheMock = new HashMap<>();
-
- for (int i = 0; i < fourSetsInSquareVertices.length; i++)
- cacheMock.put(i, fourSetsInSquareVertices[i]);
-
- final UpdatesStrategy stgy = new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- );
-
- LogRegressionMultiClassTrainer> trainer = new LogRegressionMultiClassTrainer<>()
- .withUpdatesStgy(stgy)
- .withAmountOfIterations(1000)
- .withAmountOfLocIterations(10)
- .withBatchSize(100)
- .withSeed(123L);
-
- Assert.assertEquals(trainer.getAmountOfIterations(), 1000);
- Assert.assertEquals(trainer.getAmountOfLocIterations(), 10);
- Assert.assertEquals(trainer.getBatchSize(), 100, PRECISION);
- Assert.assertEquals(trainer.seed(), 123L);
- Assert.assertEquals(trainer.getUpdatesStgy(), stgy);
-
- LogRegressionMultiClassModel mdl = trainer.fit(
- cacheMock,
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
-
- Assert.assertTrue(mdl.toString().length() > 0);
- Assert.assertTrue(mdl.toString(true).length() > 0);
- Assert.assertTrue(mdl.toString(false).length() > 0);
-
- TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 10)), PRECISION);
- TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-10, 10)), PRECISION);
- TestUtils.assertEquals(2, mdl.apply(VectorUtils.of(-10, -10)), PRECISION);
- TestUtils.assertEquals(3, mdl.apply(VectorUtils.of(10, -10)), PRECISION);
- }
-
- /** */
- @Test
- public void testUpdate() {
- Map cacheMock = new HashMap<>();
-
- for (int i = 0; i < fourSetsInSquareVertices.length; i++)
- cacheMock.put(i, fourSetsInSquareVertices[i]);
-
- LogRegressionMultiClassTrainer> trainer = new LogRegressionMultiClassTrainer<>()
- .withUpdatesStgy(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ))
- .withAmountOfIterations(1000)
- .withAmountOfLocIterations(10)
- .withBatchSize(100)
- .withSeed(123L);
-
- LogRegressionMultiClassModel originalMdl = trainer.fit(
- cacheMock,
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
-
- LogRegressionMultiClassModel updatedOnSameDS = trainer.update(
- originalMdl,
- cacheMock,
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
-
- LogRegressionMultiClassModel updatedOnEmptyDS = trainer.update(
- originalMdl,
- new HashMap(),
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
-
- List vectors = Arrays.asList(
- VectorUtils.of(10, 10),
- VectorUtils.of(-10, 10),
- VectorUtils.of(-10, -10),
- VectorUtils.of(10, -10)
- );
-
- for (Vector vec : vectors) {
- TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION);
- TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION);
- }
- }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
index e8aaacd9c23fe..4fae638599cf8 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
@@ -21,7 +21,6 @@
import org.apache.ignite.ml.math.exceptions.CardinalityException;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
index d9b6f7a600d10..7236820561e99 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
@@ -27,8 +27,6 @@
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
import org.junit.Test;
/**
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
index d6f77c00d4707..ccde0d7eb3b8c 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
@@ -27,7 +27,7 @@
import org.junit.Test;
/**
- * Tests for {@link SVMLinearBinaryClassificationTrainer}.
+ * Tests for {@link SVMLinearClassificationTrainer}.
*/
public class SVMBinaryTrainerTest extends TrainerTest {
/**
@@ -40,10 +40,10 @@ public void testTrainWithTheLinearlySeparableCase() {
for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
cacheMock.put(i, twoLinearlySeparableClasses[i]);
- SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
+ SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer()
.withSeed(1234L);
- SVMLinearBinaryClassificationModel mdl = trainer.fit(
+ SVMLinearClassificationModel mdl = trainer.fit(
cacheMock,
parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -62,18 +62,18 @@ public void testUpdate() {
for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
cacheMock.put(i, twoLinearlySeparableClasses[i]);
- SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
+ SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer()
.withAmountOfIterations(1000)
.withSeed(1234L);
- SVMLinearBinaryClassificationModel originalMdl = trainer.fit(
+ SVMLinearClassificationModel originalMdl = trainer.fit(
cacheMock,
parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
(k, v) -> v[0]
);
- SVMLinearBinaryClassificationModel updatedOnSameDS = trainer.update(
+ SVMLinearClassificationModel updatedOnSameDS = trainer.update(
originalMdl,
cacheMock,
parts,
@@ -81,7 +81,7 @@ public void testUpdate() {
(k, v) -> v[0]
);
- SVMLinearBinaryClassificationModel updatedOnEmptyDS = trainer.update(
+ SVMLinearClassificationModel updatedOnEmptyDS = trainer.update(
originalMdl,
new HashMap(),
parts,
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
index 9c452f944796f..3bac7906568c2 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
@@ -36,7 +36,7 @@ public class SVMModelTest {
@Test
public void testPredictWithRawLabels() {
Vector weights = new DenseVector(new double[]{2.0, 3.0});
- SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withRawLabels(true);
+ SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0).withRawLabels(true);
Vector observation = new DenseVector(new double[]{1.0, 1.0});
TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION);
@@ -55,36 +55,16 @@ public void testPredictWithRawLabels() {
Assert.assertTrue(mdl.isKeepingRawLabels());
- Assert.assertTrue(mdl.toString().length() > 0);
- Assert.assertTrue(mdl.toString(true).length() > 0);
- Assert.assertTrue(mdl.toString(false).length() > 0);
- }
-
-
- /** */
- @Test
- public void testPredictWithMultiClasses() {
- Vector weights1 = new DenseVector(new double[]{10.0, 0.0});
- Vector weights2 = new DenseVector(new double[]{0.0, 10.0});
- Vector weights3 = new DenseVector(new double[]{-1.0, -1.0});
- SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel();
- mdl.add(1, new SVMLinearBinaryClassificationModel(weights1, 0.0).withRawLabels(true));
- mdl.add(2, new SVMLinearBinaryClassificationModel(weights2, 0.0).withRawLabels(true));
- mdl.add(2, new SVMLinearBinaryClassificationModel(weights3, 0.0).withRawLabels(true));
-
- Assert.assertTrue(mdl.toString().length() > 0);
- Assert.assertTrue(mdl.toString(true).length() > 0);
- Assert.assertTrue(mdl.toString(false).length() > 0);
-
- Vector observation = new DenseVector(new double[]{1.0, 1.0});
- TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION);
+ Assert.assertTrue(!mdl.toString().isEmpty());
+ Assert.assertTrue(!mdl.toString(true).isEmpty());
+ Assert.assertTrue(!mdl.toString(false).isEmpty());
}
/** */
@Test
public void testPredictWithErasedLabels() {
Vector weights = new DenseVector(new double[]{1.0, 1.0});
- SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0);
+ SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0);
Vector observation = new DenseVector(new double[]{1.0, 1.0});
TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION);
@@ -101,7 +81,7 @@ public void testPredictWithErasedLabels() {
observation = new DenseVector(new double[]{-1.0, -2.0});
TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION);
- final SVMLinearBinaryClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0}));
+ final SVMLinearClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] {-2.0, -2.0}));
System.out.println("The SVM model is " + mdlWithNewData);
observation = new DenseVector(new double[]{-1.0, -2.0});
@@ -113,7 +93,7 @@ public void testPredictWithErasedLabels() {
@Test
public void testPredictWithErasedLabelsAndChangedThreshold() {
Vector weights = new DenseVector(new double[]{1.0, 1.0});
- SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withThreshold(5);
+ SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0).withThreshold(5);
Vector observation = new DenseVector(new double[]{1.0, 1.0});
TestUtils.assertEquals(0.0, mdl.apply(observation), PRECISION);
@@ -129,7 +109,7 @@ public void testPredictWithErasedLabelsAndChangedThreshold() {
public void testPredictOnAnObservationWithWrongCardinality() {
Vector weights = new DenseVector(new double[]{2.0, 3.0});
- SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0);
+ SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0);
Vector observation = new DenseVector(new double[]{1.0});
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
deleted file mode 100644
index 7c4809fd75c3c..0000000000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm;
-
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.junit.Test;
-
-/**
- * Tests for {@link SVMLinearBinaryClassificationTrainer}.
- */
-public class SVMMultiClassTrainerTest extends TrainerTest {
- /**
- * Test trainer on 4 sets grouped around of square vertices.
- */
- @Test
- public void testTrainWithTheLinearlySeparableCase() {
- Map cacheMock = new HashMap<>();
-
- for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
- cacheMock.put(i, twoLinearlySeparableClasses[i]);
-
- SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer()
- .withLambda(0.3)
- .withAmountOfLocIterations(10)
- .withAmountOfIterations(20)
- .withSeed(1234L);
-
- SVMLinearMultiClassClassificationModel mdl = trainer.fit(
- cacheMock,
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
- TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
- TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
- }
-
- /** */
- @Test
- public void testUpdate() {
- Map cacheMock = new HashMap<>();
-
- for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
- cacheMock.put(i, twoLinearlySeparableClasses[i]);
-
- SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer()
- .withLambda(0.3)
- .withAmountOfLocIterations(10)
- .withAmountOfIterations(100)
- .withSeed(1234L);
-
- SVMLinearMultiClassClassificationModel originalMdl = trainer.fit(
- cacheMock,
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
-
- SVMLinearMultiClassClassificationModel updatedOnSameDS = trainer.update(
- originalMdl,
- cacheMock,
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
-
- SVMLinearMultiClassClassificationModel updatedOnEmptyDS = trainer.update(
- originalMdl,
- new HashMap(),
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
-
- Vector v = VectorUtils.of(100, 10);
- TestUtils.assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), PRECISION);
- TestUtils.assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), PRECISION);
- }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
index df7263f9d47f7..a2aea6ef9798b 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
@@ -27,7 +27,6 @@
@Suite.SuiteClasses({
SVMModelTest.class,
SVMBinaryTrainerTest.class,
- SVMMultiClassTrainerTest.class,
})
public class SVMTestSuite {
// No-op.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
index 1b96ce29b7905..31fe8b31349c3 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
@@ -37,8 +37,8 @@
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
-import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
import org.junit.Test;
/**