maxMargins = new TreeMap<>();
-
- models.forEach((k, v) -> maxMargins.put(input.dot(v.weights()) + v.intercept(), k));
-
- return maxMargins.lastEntry().getValue();
- }
-
- /** {@inheritDoc} */
- @Override public void saveModel(Exporter exporter, P path) {
- exporter.save(this, path);
- }
-
- /** {@inheritDoc} */
- @Override public boolean equals(Object o) {
- if (this == o)
- return true;
- if (o == null || getClass() != o.getClass())
- return false;
- SVMLinearMultiClassClassificationModel mdl = (SVMLinearMultiClassClassificationModel)o;
- return Objects.equals(models, mdl.models);
- }
-
- /** {@inheritDoc} */
- @Override public int hashCode() {
- return Objects.hash(models);
- }
-
- /** {@inheritDoc} */
- @Override public String toString() {
- StringBuilder wholeStr = new StringBuilder();
-
- models.forEach((clsLb, mdl) -> {
- wholeStr.append("The class with label " + clsLb + " has classifier: " + mdl.toString() + System.lineSeparator());
- });
-
- return wholeStr.toString();
- }
-
- /**
- * Adds a specific SVM binary classifier to the bunch of same classifiers.
- *
- * @param clsLb The class label for the added model.
- * @param mdl The model.
- */
- public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) {
- models.put(clsLb, mdl);
- }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
deleted file mode 100644
index 669e2e39b2f94..0000000000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm;
-
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-import org.apache.ignite.ml.Trainer;
-import org.apache.ignite.ml.structures.LabeledDataset;
-
-/**
- * Base class for a soft-margin SVM linear multiclass-classification trainer based on the communication-efficient
- * distributed dual coordinate ascent algorithm (CoCoA) with hinge-loss function.
- *
- * All common parameters are shared with bunch of binary classification trainers.
- */
-public class SVMLinearMultiClassClassificationTrainer implements Trainer {
- /** Amount of outer SDCA algorithm iterations. */
- private int amountOfIterations = 20;
-
- /** Amount of local SDCA algorithm iterations. */
- private int amountOfLocIterations = 50;
-
- /** Regularization parameter. */
- private double lambda = 0.2;
-
- /**
- * Returns model based on data.
- *
- * @param data data to build model.
- * @return model.
- */
- @Override public SVMLinearMultiClassClassificationModel train(LabeledDataset data) {
- List classes = getClassLabels(data);
-
- SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();
-
- classes.forEach(clsLb -> {
- LabeledDataset binarizedDataset = binarizeLabels(data, clsLb);
-
- SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
- .withAmountOfIterations(this.amountOfIterations())
- .withAmountOfLocIterations(this.amountOfLocIterations())
- .withLambda(this.lambda());
-
- multiClsMdl.add(clsLb, trainer.train(binarizedDataset));
- });
-
- return multiClsMdl;
- }
-
- /**
- * Copies the given data and changes class labels in +1 for chosen class and in -1 for the rest classes.
- *
- * @param data Data to transform.
- * @param clsLb Chosen class in schema One-vs-Rest.
- * @return Copy of dataset with new labels.
- */
- private LabeledDataset binarizeLabels(LabeledDataset data, double clsLb) {
- final LabeledDataset ds = data.copy();
-
- for (int i = 0; i < ds.rowSize(); i++)
- ds.setLabel(i, ds.label(i) == clsLb ? 1.0 : -1.0);
-
- return ds;
- }
-
- /** Iterates among dataset and collects class labels. */
- private List getClassLabels(LabeledDataset data) {
- final Set clsLabels = new HashSet<>();
-
- for (int i = 0; i < data.rowSize(); i++)
- clsLabels.add(data.label(i));
-
- List res = new ArrayList<>();
- res.addAll(clsLabels);
-
- return res;
- }
-
- /**
- * Set up the regularization parameter.
- *
- * @param lambda The regularization parameter. Should be more than 0.0.
- * @return Trainer with new lambda parameter value.
- */
- public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
- assert lambda > 0.0;
- this.lambda = lambda;
- return this;
- }
-
- /**
- * Gets the regularization lambda.
- *
- * @return The parameter value.
- */
- public double lambda() {
- return lambda;
- }
-
- /**
- * Gets the amount of outer iterations of SCDA algorithm.
- *
- * @return The parameter value.
- */
- public int amountOfIterations() {
- return amountOfIterations;
- }
-
- /**
- * Set up the amount of outer iterations of SCDA algorithm.
- *
- * @param amountOfIterations The parameter value.
- * @return Trainer with new amountOfIterations parameter value.
- */
- public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
- this.amountOfIterations = amountOfIterations;
- return this;
- }
-
- /**
- * Gets the amount of local iterations of SCDA algorithm.
- *
- * @return The parameter value.
- */
- public int amountOfLocIterations() {
- return amountOfLocIterations;
- }
-
- /**
- * Set up the amount of local iterations of SCDA algorithm.
- *
- * @param amountOfLocIterations The parameter value.
- * @return Trainer with new amountOfLocIterations parameter value.
- */
- public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
- this.amountOfLocIterations = amountOfLocIterations;
- return this;
- }
-}
-
-
-
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/DistributedLinearSVMBinaryClassificationTrainerTest.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java
similarity index 56%
rename from modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/DistributedLinearSVMBinaryClassificationTrainerTest.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java
index 1be1d1cb447c6..0aee0fbf140cb 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/DistributedLinearSVMBinaryClassificationTrainerTest.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java
@@ -15,21 +15,14 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.svm.binary;
+package org.apache.ignite.ml.svm;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
-import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
+import java.io.Serializable;
/**
- * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}.
+ * Partition context of the SVM classification algorithm.
*/
-public class DistributedLinearSVMBinaryClassificationTrainerTest extends GenericLinearSVMBinaryClassificationTrainerTest {
+public class SVMPartitionContext implements Serializable {
/** */
- public DistributedLinearSVMBinaryClassificationTrainerTest() {
- super(
- new SVMLinearBinaryClassificationTrainer(),
- true,
- 1e-2);
- }
+ private static final long serialVersionUID = -7212307112344430126L;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionDataBuilderOnHeap.java
new file mode 100644
index 0000000000000..ad85758a52d0d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionDataBuilderOnHeap.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.svm;
+
+import java.io.Serializable;
+import java.util.Iterator;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.LabeledVector;
+
+/**
+ * SVM partition data builder that builds {@link LabeledDataset}.
+ *
+ * @param Type of a key in upstream data.
+ * @param Type of a value in upstream data.
+ * @param Type of a partition context.
+ */
+public class SVMPartitionDataBuilderOnHeap
+ implements PartitionDataBuilder> {
+ /** */
+ private static final long serialVersionUID = -7820760153954269227L;
+
+ /** Extractor of X matrix row. */
+ private final IgniteBiFunction xExtractor;
+
+ /** Extractor of Y vector value. */
+ private final IgniteBiFunction yExtractor;
+
+ /** Number of columns. */
+ private final int cols;
+
+ /**
+ * Constructs a new instance of SVM partition data builder.
+ *
+ * @param xExtractor Extractor of X matrix row.
+ * @param yExtractor Extractor of Y vector value.
+ * @param cols Number of columns.
+ */
+ public SVMPartitionDataBuilderOnHeap(IgniteBiFunction xExtractor,
+ IgniteBiFunction yExtractor, int cols) {
+ this.xExtractor = xExtractor;
+ this.yExtractor = yExtractor;
+ this.cols = cols;
+ }
+
+ /** {@inheritDoc} */
+ @Override public LabeledDataset build(Iterator> upstreamData, long upstreamDataSize,
+ C ctx) {
+
+ double[][] x = new double[Math.toIntExact(upstreamDataSize)][cols];
+ double[] y = new double[Math.toIntExact(upstreamDataSize)];
+
+ int ptr = 0;
+ while (upstreamData.hasNext()) {
+ UpstreamEntry entry = upstreamData.next();
+ double[] row = xExtractor.apply(entry.getKey(), entry.getValue());
+
+ assert row.length == cols : "X extractor must return exactly " + cols + " columns";
+
+ x[ptr] = row;
+
+ y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue());
+
+ ptr++;
+ }
+
+ return new LabeledDataset<>(x, y);
+ }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
index 57d93d646fb44..421a19fb7070c 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
@@ -32,7 +32,6 @@
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
-import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
import org.junit.Assert;
import org.junit.Test;
@@ -97,33 +96,6 @@ public void importExportSVMBinaryClassificationModelTest() throws IOException {
}
- /** */
- @Test
- public void importExportSVMMulticlassClassificationModelTest() throws IOException {
- executeModelTest(mdlFilePath -> {
-
-
- SVMLinearBinaryClassificationModel binaryMdl1 = new SVMLinearBinaryClassificationModel(new DenseLocalOnHeapVector(new double[]{1, 2}), 3);
- SVMLinearBinaryClassificationModel binaryMdl2 = new SVMLinearBinaryClassificationModel(new DenseLocalOnHeapVector(new double[]{2, 3}), 4);
- SVMLinearBinaryClassificationModel binaryMdl3 = new SVMLinearBinaryClassificationModel(new DenseLocalOnHeapVector(new double[]{3, 4}), 5);
-
- SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel();
- mdl.add(1, binaryMdl1);
- mdl.add(2, binaryMdl2);
- mdl.add(3, binaryMdl3);
-
- Exporter exporter = new FileExporter<>();
- mdl.saveModel(exporter, mdlFilePath);
-
- SVMLinearMultiClassClassificationModel load = exporter.load(mdlFilePath);
-
- Assert.assertNotNull(load);
- Assert.assertEquals("", mdl, load);
-
- return null;
- });
- }
-
/** */
private void executeModelTest(Function code) throws IOException {
Path mdlPath = Files.createTempFile(null, null);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/BaseSVMTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/BaseSVMTest.java
deleted file mode 100644
index 424118de7521f..0000000000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/BaseSVMTest.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm;
-
-import org.apache.ignite.Ignite;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
-
-/**
- * Base class for SVM tests.
- */
-public class BaseSVMTest extends GridCommonAbstractTest {
- /** Count of nodes. */
- private static final int NODE_COUNT = 4;
-
- /** Grid instance. */
- protected Ignite ignite;
-
- /**
- * Default constructor.
- */
- public BaseSVMTest() {
- super(false);
- }
-
- /**
- * {@inheritDoc}
- */
- @Override protected void beforeTest() throws Exception {
- ignite = grid(NODE_COUNT);
- }
-
- /** {@inheritDoc} */
- @Override protected void beforeTestsStarted() throws Exception {
- for (int i = 1; i <= NODE_COUNT; i++)
- startGrid(i);
- }
-
- /** {@inheritDoc} */
- @Override protected void afterTestsStopped() throws Exception {
- stopAllGrids();
- }
-
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
new file mode 100644
index 0000000000000..353915ca8ce6c
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.svm;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.junit.Test;
+
+/**
+ * Tests for {@link SVMLinearBinaryClassificationTrainer}.
+ */
+public class SVMBinaryTrainerTest {
+ /** Fixed size of Dataset. */
+ private static final int AMOUNT_OF_OBSERVATIONS = 1000;
+
+ /** Fixed size of columns in Dataset. */
+ private static final int AMOUNT_OF_FEATURES = 2;
+
+ /**
+ * Test trainer on classification model y = x.
+ */
+ @Test
+ public void testTrainWithTheLinearlySeparableCase() {
+ Map data = new HashMap<>();
+
+
+ ThreadLocalRandom rndX = ThreadLocalRandom.current();
+ ThreadLocalRandom rndY = ThreadLocalRandom.current();
+
+ for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
+ double x = rndX.nextDouble(-1000, 1000);
+ double y = rndY.nextDouble(-1000, 1000);
+ double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+ vec[0] = y - x > 0 ? 1 : -1; // assign label.
+ vec[1] = x;
+ vec[2] = y;
+ data.put(i, vec);
+ }
+
+
+ SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer<>();
+
+ SVMLinearBinaryClassificationModel mdl = trainer.fit(
+ new LocalDatasetBuilder<>(data, 10),
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+ (k, v) -> v[0],
+ AMOUNT_OF_FEATURES);
+
+ double precision = 1e-2;
+
+ TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), precision);
+ TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), precision);
+ }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
index 35b6644ca5e55..25334662bbd8f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
@@ -57,21 +57,6 @@ public void testPredictWithRawLabels() {
}
- /** */
- @Test
- public void testPredictWithMultiClasses() {
- Vector weights1 = new DenseLocalOnHeapVector(new double[]{10.0, 0.0});
- Vector weights2 = new DenseLocalOnHeapVector(new double[]{0.0, 10.0});
- Vector weights3 = new DenseLocalOnHeapVector(new double[]{-1.0, -1.0});
- SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel();
- mdl.add(1, new SVMLinearBinaryClassificationModel(weights1, 0.0).withRawLabels(true));
- mdl.add(2, new SVMLinearBinaryClassificationModel(weights2, 0.0).withRawLabels(true));
- mdl.add(2, new SVMLinearBinaryClassificationModel(weights3, 0.0).withRawLabels(true));
-
- Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0});
- TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION);
- }
-
/** */
@Test
public void testPredictWithErasedLabels() {
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
index 853a43f8153e2..dd87fecf5b6e5 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
@@ -17,23 +17,16 @@
package org.apache.ignite.ml.svm;
-import org.apache.ignite.ml.svm.binary.DistributedLinearSVMBinaryClassificationTrainerTest;
-import org.apache.ignite.ml.svm.binary.LocalLinearSVMBinaryClassificationTrainerTest;
-import org.apache.ignite.ml.svm.multi.DistributedLinearSVMMultiClassClassificationTrainerTest;
-import org.apache.ignite.ml.svm.multi.LocalLinearSVMMultiClassClassificationTrainerTest;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
/**
- * Test suite for all tests located in org.apache.ignite.ml.regressions.* package.
+ * Test suite for all tests located in org.apache.ignite.ml.svm.* package.
*/
@RunWith(Suite.class)
@Suite.SuiteClasses({
- LocalLinearSVMBinaryClassificationTrainerTest.class,
- DistributedLinearSVMBinaryClassificationTrainerTest.class,
- LocalLinearSVMMultiClassClassificationTrainerTest.class,
- DistributedLinearSVMMultiClassClassificationTrainerTest.class,
- SVMModelTest.class
+ SVMModelTest.class,
+ SVMBinaryTrainerTest.class
})
public class SVMTestSuite {
// No-op.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/GenericLinearSVMBinaryClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/GenericLinearSVMBinaryClassificationTrainerTest.java
deleted file mode 100644
index f3905579517f8..0000000000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/GenericLinearSVMBinaryClassificationTrainerTest.java
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm.binary;
-
-import java.util.concurrent.ThreadLocalRandom;
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.Trainer;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.svm.BaseSVMTest;
-import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
-import org.junit.Test;
-
-/**
- * Base class for all linear regression trainers.
- */
-public class GenericLinearSVMBinaryClassificationTrainerTest extends BaseSVMTest {
- /** Fixed size of Dataset. */
- private static final int AMOUNT_OF_OBSERVATIONS = 100;
-
- /** Fixed size of columns in Dataset. */
- private static final int AMOUNT_OF_FEATURES = 2;
-
- /** */
- private final Trainer trainer;
-
- /** */
- private boolean isDistributed;
-
- /** */
- private final double precision;
-
- /** */
- GenericLinearSVMBinaryClassificationTrainerTest(
- Trainer trainer,
- boolean isDistributed,
- double precision) {
- super();
- this.trainer = trainer;
- this.precision = precision;
- this.isDistributed = isDistributed;
- }
-
- /**
- * Test trainer on classification model y = x.
- */
- @Test
- public void testTrainWithTheLinearlySeparableCase() {
- if (isDistributed)
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- LabeledDataset dataset = new LabeledDataset(AMOUNT_OF_OBSERVATIONS, AMOUNT_OF_FEATURES, isDistributed);
-
- ThreadLocalRandom rndX = ThreadLocalRandom.current();
- ThreadLocalRandom rndY = ThreadLocalRandom.current();
- for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
- double x = rndX.nextDouble(-1000, 1000);
- double y = rndY.nextDouble(-1000, 1000);
- dataset.features(i).set(0, x);
- dataset.features(i).set(1, y);
- double lb = y - x > 0 ? 1 : -1;
- dataset.setLabel(i, lb);
- }
-
- SVMLinearBinaryClassificationModel mdl = trainer.train(dataset);
-
- TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[] {100, 10})), precision);
- TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[] {10, 100})), precision);
- }
-
- /**
- * Test trainer on classification model y = x. Amount of generated points is increased 10 times.
- */
- @Test
- public void testTrainWithTheLinearlySeparableCase10() {
- if (isDistributed)
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- LabeledDataset dataset = new LabeledDataset(AMOUNT_OF_OBSERVATIONS * 10, AMOUNT_OF_FEATURES, isDistributed);
-
- ThreadLocalRandom rndX = ThreadLocalRandom.current();
- ThreadLocalRandom rndY = ThreadLocalRandom.current();
- for (int i = 0; i < AMOUNT_OF_OBSERVATIONS * 10; i++) {
- double x = rndX.nextDouble(-1000, 1000);
- double y = rndY.nextDouble(-1000, 1000);
- dataset.features(i).set(0, x);
- dataset.features(i).set(1, y);
- double lb = y - x > 0 ? 1 : -1;
- dataset.setLabel(i, lb);
- }
-
- SVMLinearBinaryClassificationModel mdl = trainer.train(dataset);
-
- TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[] {100, 10})), precision);
- TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[] {10, 100})), precision);
- }
-
- /**
- * Test trainer on classification model y = x. Amount of generated points is increased 100 times.
- */
- @Test
- public void testTrainWithTheLinearlySeparableCase100() {
- if (isDistributed)
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- LabeledDataset dataset = new LabeledDataset(AMOUNT_OF_OBSERVATIONS * 100, AMOUNT_OF_FEATURES, isDistributed);
-
- ThreadLocalRandom rndX = ThreadLocalRandom.current();
- ThreadLocalRandom rndY = ThreadLocalRandom.current();
- for (int i = 0; i < AMOUNT_OF_OBSERVATIONS * 100; i++) {
- double x = rndX.nextDouble(-1000, 1000);
- double y = rndY.nextDouble(-1000, 1000);
- dataset.features(i).set(0, x);
- dataset.features(i).set(1, y);
- double lb = y - x > 0 ? 1 : -1;
- dataset.setLabel(i, lb);
- }
-
- SVMLinearBinaryClassificationModel mdl = trainer.train(dataset);
-
- TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[] {100, 10})), precision);
- TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[] {10, 100})), precision);
- }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/LocalLinearSVMBinaryClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/LocalLinearSVMBinaryClassificationTrainerTest.java
deleted file mode 100644
index a9cb54ce80bc1..0000000000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/LocalLinearSVMBinaryClassificationTrainerTest.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm.binary;
-
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
-import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
-
-/**
- * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}.
- */
-public class LocalLinearSVMBinaryClassificationTrainerTest extends GenericLinearSVMBinaryClassificationTrainerTest {
- /** */
- public LocalLinearSVMBinaryClassificationTrainerTest() {
- super(
- new SVMLinearBinaryClassificationTrainer()
- .withLambda(0.2)
- .withAmountOfIterations(10)
- .withAmountOfLocIterations(20),
- false,
- 1e-2);
- }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/DistributedLinearSVMMultiClassClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/DistributedLinearSVMMultiClassClassificationTrainerTest.java
deleted file mode 100644
index 6806e0b75818c..0000000000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/DistributedLinearSVMMultiClassClassificationTrainerTest.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm.multi;
-
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
-import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer;
-
-/**
- * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}.
- */
-public class DistributedLinearSVMMultiClassClassificationTrainerTest extends GenericLinearSVMMultiClassClassificationTrainerTest {
- /** */
- public DistributedLinearSVMMultiClassClassificationTrainerTest() {
- super(
- new SVMLinearMultiClassClassificationTrainer(),
- true,
- 1e-2);
- }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/GenericLinearSVMMultiClassClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/GenericLinearSVMMultiClassClassificationTrainerTest.java
deleted file mode 100644
index 8c6083dd5b5e2..0000000000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/GenericLinearSVMMultiClassClassificationTrainerTest.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm.multi;
-
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.Trainer;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.svm.BaseSVMTest;
-import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
-import org.junit.Test;
-
-/**
- * Base class for all linear regression trainers.
- */
-public class GenericLinearSVMMultiClassClassificationTrainerTest extends BaseSVMTest {
- /** */
- private final Trainer trainer;
-
- /** */
- private boolean isDistributed;
-
- /** */
- private final double precision;
-
- /** */
- GenericLinearSVMMultiClassClassificationTrainerTest(
- Trainer trainer,
- boolean isDistributed,
- double precision) {
- super();
- this.trainer = trainer;
- this.precision = precision;
- this.isDistributed = isDistributed;
- }
-
- /**
- * Test trainer on classification model y = x.
- */
- @Test
- public void testTrainWithTheLinearlySeparableCase() {
- if (isDistributed)
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- double[][] mtx =
- new double[][] {
- {-10.0, 12.0},
- {-5.0, 14.0},
- {-3.0, 18.0},
- {13.0, -1.0},
- {10.0, -2.0},
- {15.0, -3.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
- LabeledDataset dataset = new LabeledDataset(mtx, lbs, null, isDistributed);
-
-
- SVMLinearMultiClassClassificationModel mdl = trainer.train(dataset);
- TestUtils.assertEquals(1.0, mdl.apply(new DenseLocalOnHeapVector(new double[] {-2.0, 15})), precision);
- TestUtils.assertEquals(2.0, mdl.apply(new DenseLocalOnHeapVector(new double[] {12, -5})), precision);
- }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/LocalLinearSVMMultiClassClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/LocalLinearSVMMultiClassClassificationTrainerTest.java
deleted file mode 100644
index a239c95556b01..0000000000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/LocalLinearSVMMultiClassClassificationTrainerTest.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm.multi;
-
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
-import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer;
-
-/**
- * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}.
- */
-public class LocalLinearSVMMultiClassClassificationTrainerTest extends GenericLinearSVMMultiClassClassificationTrainerTest {
- /** */
- public LocalLinearSVMMultiClassClassificationTrainerTest() {
- super(
- new SVMLinearMultiClassClassificationTrainer()
- .withLambda(0.2)
- .withAmountOfIterations(10)
- .withAmountOfLocIterations(20),
- false,
- 1e-2);
- }
-}