From 96fffa56cd0b334a1dd3b678b68bc9224f5b4c02 Mon Sep 17 00:00:00 2001 From: Zinoviev Alexey Date: Mon, 27 Aug 2018 20:44:07 +0300 Subject: [PATCH 1/2] IGNITE-8924: Result of merge --- .../ml/clustering/kmeans/KMeansTrainer.java | 60 +++++++++++++------ .../ml/knn/ann/ANNClassificationTrainer.java | 12 ++-- .../KNNClassificationModel.java | 9 ++- .../ml/math/isolve/lsqr/LSQROnHeap.java | 8 ++- .../linear/LinearRegressionSGDTrainer.java | 10 +++- .../LogisticRegressionSGDTrainer.java | 12 +++- .../LogRegressionMultiClassTrainer.java | 25 +++++--- .../SVMLinearBinaryClassificationTrainer.java | 40 ++++++++++--- ...LinearMultiClassClassificationTrainer.java | 8 ++- 9 files changed, 136 insertions(+), 48 deletions(-) diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java index c005312cb0069..5b880fcc95ced 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java @@ -65,13 +65,13 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer { /** * Trains model based on the specified data. * - * @param datasetBuilder Dataset builder. + * @param datasetBuilder Dataset builder. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. + * @param lbExtractor Label extractor. * @return Model. */ @Override public KMeansModel fit(DatasetBuilder datasetBuilder, - IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) { + IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) { assert datasetBuilder != null; PartitionDataBuilder> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( @@ -85,7 +85,14 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer { (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { - final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> a == null ? b : a); + final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> { + if (a == null) + return b == null ? 0 : b; + if (b == null) + return a; + return b; + }); + centers = initClusterCentersRandomly(dataset, k); boolean converged = false; @@ -113,7 +120,8 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer { centers[i] = newCentroids[i]; } } - } catch (Exception e) { + } + catch (Exception e) { throw new RuntimeException(e); } return new KMeansModel(centers, distance); @@ -124,15 +132,14 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer { * * @param centers Current centers on the current iteration. * @param dataset Dataset. - * @param cols Amount of columns. + * @param cols Amount of columns. * @return Helper data to calculate the new centroids. */ private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers, - Dataset> dataset, int cols) { + Dataset> dataset, int cols) { final Vector[] finalCenters = centers; return dataset.compute(data -> { - TotalCostAndCounts res = new TotalCostAndCounts(); for (int i = 0; i < data.rowSize(); i++) { @@ -147,20 +154,29 @@ private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers, int finalI = i; res.sums.compute(centroidIdx, - (IgniteBiFunction) (ind, v) -> v.plus(data.getRow(finalI).features())); + (IgniteBiFunction)(ind, v) -> { + Vector features = data.getRow(finalI).features(); + return v == null ? 
features : v.plus(features); + }); res.counts.merge(centroidIdx, 1, - (IgniteBiFunction) (i1, i2) -> i1 + i2); + (IgniteBiFunction)(i1, i2) -> i1 + i2); } return res; - }, (a, b) -> a == null ? b : a.merge(b)); + }, (a, b) -> { + if (a == null) + return b == null ? new TotalCostAndCounts() : b; + if (b == null) + return a; + return a.merge(b); + }); } /** * Find the closest cluster center index and distance to it from a given point. * * @param centers Centers to look in. - * @param pnt Point. + * @param pnt Point. */ private IgniteBiTuple findClosestCentroid(Vector[] centers, LabeledVector pnt) { double bestDistance = Double.POSITIVE_INFINITY; @@ -180,12 +196,11 @@ private IgniteBiTuple findClosestCentroid(Vector[] centers, Lab * K cluster centers are initialized randomly. * * @param dataset The dataset to pick up random centers. - * @param k Amount of clusters. + * @param k Amount of clusters. * @return K cluster centers. */ private Vector[] initClusterCentersRandomly(Dataset> dataset, - int k) { - + int k) { Vector[] initCenters = new DenseVector[k]; // Gets k or less vectors from each partition. @@ -211,12 +226,19 @@ private Vector[] initClusterCentersRandomly(Dataset a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList())); + }, (a, b) -> { + if (a == null) + return b == null ? new ArrayList<>() : b; + if (b == null) + return a; + return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()); + }); // Shuffle them. Collections.shuffle(rndPnts); @@ -228,7 +250,8 @@ private Vector[] initClusterCentersRandomly(Dataset counts = new ConcurrentHashMap<>(); - /** Count of points closest to the center with a given index. */ ConcurrentHashMap> centroidStat = new ConcurrentHashMap<>(); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java index 282be3c15aabd..1c45812908797 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java @@ -149,9 +149,7 @@ private CentroidStat getCentroidStat(DatasetBuilder datasetBuilder, (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { - return dataset.compute(data -> { - CentroidStat res = new CentroidStat(); for (int i = 0; i < data.rowSize(); i++) { @@ -171,7 +169,7 @@ private CentroidStat getCentroidStat(DatasetBuilder datasetBuilder, centroidStat.put(lb, 1); res.centroidStat.put(centroidIdx, centroidStat); } else { - int cnt = centroidStat.containsKey(lb) ? centroidStat.get(lb) : 0; + int cnt = centroidStat.getOrDefault(lb, 0); centroidStat.put(lb, cnt + 1); } @@ -179,7 +177,13 @@ private CentroidStat getCentroidStat(DatasetBuilder datasetBuilder, (IgniteBiFunction) (i1, i2) -> i1 + i2); } return res; - }, (a, b) -> a == null ? b : a.merge(b)); + }, (a, b) -> { + if (a == null) + return b == null ? 
new CentroidStat() : b; + if (b == null) + return a; + return a.merge(b); + }); } catch (Exception e) { throw new RuntimeException(e); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java index 3404ae80fe939..0b88f8181cf4f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.knn.classification; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -79,7 +80,13 @@ protected List findKNearestNeighbors(Vector v) { List neighborsFromPartitions = dataset.compute(data -> { TreeMap> distanceIdxPairs = getDistances(v, data); return Arrays.asList(getKClosestVectors(data, distanceIdxPairs)); - }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList())); + }, (a, b) -> { + if (a == null) + return b == null ? new ArrayList<>() : b; + if (b == null) + return a; + return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()); + }); LabeledVectorSet neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java index e138cf3ff3db9..f75caefab76a4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java @@ -103,7 +103,13 @@ else if (b == null) @Override protected int getColumns() { return dataset.compute( data -> data.getFeatures() == null ? null : data.getFeatures().length / data.getRows(), - (a, b) -> a == null ? b : a + (a, b) -> { + if (a == null) + return b == null ? 0 : b; + if (b == null) + return a; + return b; + } ); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java index 2237c95f1dc12..44f60d1db6a16 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java @@ -82,7 +82,13 @@ public LinearRegressionSGDTrainer(UpdatesStrategy a == null ? b : a); + }, (a, b) -> { + if (a == null) + return b == null ? 
0 : b; + if (b == null) + return a; + return b; + }); MLPArchitecture architecture = new MLPArchitecture(cols); architecture = architecture.withAddedLayer(1, true, Activators.LINEAR); @@ -100,7 +106,7 @@ public LinearRegressionSGDTrainer(UpdatesStrategy lbE = (IgniteBiFunction)(k, v) -> new double[]{lbExtractor.apply(k, v)}; + IgniteBiFunction lbE = (IgniteBiFunction)(k, v) -> new double[] {lbExtractor.apply(k, v)}; MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java index 840a18dcd29bc..639627950ac70 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java @@ -64,7 +64,7 @@ public class LogisticRegressionSGDTrainer
extends Single * @param seed Seed for random generator. */ public LogisticRegressionSGDTrainer(UpdatesStrategy updatesStgy, int maxIterations, - int batchSize, int locIterations, long seed) { + int batchSize, int locIterations, long seed) { this.updatesStgy = updatesStgy; this.maxIterations = maxIterations; this.batchSize = batchSize; @@ -82,7 +82,13 @@ public LogisticRegressionSGDTrainer(UpdatesStrategy a == null ? b : a); + }, (a, b) -> { + if (a == null) + return b == null ? 0 : b; + if (b == null) + return a; + return b; + }); MLPArchitecture architecture = new MLPArchitecture(cols); architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID); @@ -100,7 +106,7 @@ public LogisticRegressionSGDTrainer(UpdatesStrategy new double[]{lbExtractor.apply(k, v)}); + MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[] {lbExtractor.apply(k, v)}); double[] params = mlp.parameters().getStorage().data(); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java index 1ed938a4440f0..4885373ae094f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java @@ -61,14 +61,14 @@ public class LogRegressionMultiClassTrainer
/** * Trains model based on the specified data. * - * @param datasetBuilder Dataset builder. + * @param datasetBuilder Dataset builder. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. + * @param lbExtractor Label extractor. * @return Model. */ @Override public LogRegressionMultiClassModel fit(DatasetBuilder datasetBuilder, - IgniteBiFunction featureExtractor, - IgniteBiFunction lbExtractor) { + IgniteBiFunction featureExtractor, + IgniteBiFunction lbExtractor) { List classes = extractClassLabels(datasetBuilder, lbExtractor); LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel(); @@ -92,7 +92,8 @@ public class LogRegressionMultiClassTrainer
} /** Iterates among dataset and collects class labels. */ - private List extractClassLabels(DatasetBuilder datasetBuilder, IgniteBiFunction lbExtractor) { + private List extractClassLabels(DatasetBuilder datasetBuilder, + IgniteBiFunction lbExtractor) { assert datasetBuilder != null; PartitionDataBuilder partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); @@ -108,14 +109,22 @@ private List extractClassLabels(DatasetBuilder datasetBuild final double[] lbs = data.getY(); - for (double lb : lbs) locClsLabels.add(lb); + for (double lb : lbs) + locClsLabels.add(lb); return locClsLabels; - }, (a, b) -> a == null ? b : Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet())); + }, (a, b) -> { + if (a == null) + return b == null ? new HashSet<>() : b; + if (b == null) + return a; + return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()); + }); res.addAll(clsLabels); - } catch (Exception e) { + } + catch (Exception e) { throw new RuntimeException(e); } return res; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index 4f11318bd3a27..1f369dff55cd1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -50,9 +50,9 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai /** * Trains model based on the specified data. * - * @param datasetBuilder Dataset builder. + * @param datasetBuilder Dataset builder. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. + * @param lbExtractor Label extractor. * @return Model. */ @Override public SVMLinearBinaryClassificationModel fit(DatasetBuilder datasetBuilder, @@ -67,19 +67,28 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai Vector weights; - try(Dataset> dataset = datasetBuilder.build( + try (Dataset> dataset = datasetBuilder.build( (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { - final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> a == null ? b : a); + final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> { + if (a == null) + return b == null ? 
0 : b; + if (b == null) + return a; + return b; + }); + final int weightVectorSizeWithIntercept = cols + 1; + weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept); for (int i = 0; i < this.getAmountOfIterations(); i++) { Vector deltaWeights = calculateUpdates(weights, dataset); weights = weights.plus(deltaWeights); // creates new vector } - } catch (Exception e) { + } + catch (Exception e) { throw new RuntimeException(e); } return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0)); @@ -87,11 +96,12 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai /** */ @NotNull private Vector initializeWeightsWithZeros(int vectorSize) { - return new DenseVector(vectorSize); + return new DenseVector(vectorSize); } /** */ - private Vector calculateUpdates(Vector weights, Dataset> dataset) { + private Vector calculateUpdates(Vector weights, + Dataset> dataset) { return dataset.compute(data -> { Vector copiedWeights = weights.copy(); Vector deltaWeights = initializeWeightsWithZeros(weights.size()); @@ -112,12 +122,18 @@ private Vector calculateUpdates(Vector weights, Dataset a == null ? b : a.plus(b)); + }, (a, b) -> { + if (a == null) + return b == null ? new DenseVector() : b; + if (b == null) + return a; + return a.plus(b); + }); } /** */ private Deltas getDeltas(LabeledVectorSet data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas, - int randomIdx) { + int randomIdx) { LabeledVector row = (LabeledVector)data.getRow(randomIdx); Double lb = (Double)row.label(); Vector v = makeVectorWithInterceptElement(row); @@ -191,6 +207,7 @@ else if (alpha >= 1.0) /** * Set up the regularization parameter. + * * @param lambda The regularization parameter. Should be more than 0.0. * @return Trainer with new lambda parameter value. */ @@ -202,6 +219,7 @@ public SVMLinearBinaryClassificationTrainer withLambda(double lambda) { /** * Gets the regularization lambda. + * * @return The parameter value. */ public double lambda() { @@ -210,6 +228,7 @@ public double lambda() { /** * Gets the amount of outer iterations of SCDA algorithm. + * * @return The parameter value. */ public int getAmountOfIterations() { @@ -218,6 +237,7 @@ public int getAmountOfIterations() { /** * Set up the amount of outer iterations of SCDA algorithm. + * * @param amountOfIterations The parameter value. * @return Trainer with new amountOfIterations parameter value. */ @@ -228,6 +248,7 @@ public SVMLinearBinaryClassificationTrainer withAmountOfIterations(int amountOfI /** * Gets the amount of local iterations of SCDA algorithm. + * * @return The parameter value. */ public int getAmountOfLocIterations() { @@ -236,6 +257,7 @@ public int getAmountOfLocIterations() { /** * Set up the amount of local iterations of SCDA algorithm. + * * @param amountOfLocIterations The parameter value. * @return Trainer with new amountOfLocIterations parameter value. 
*/ diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index 7069c4d0d5765..6a6e1ed48259e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -106,7 +106,13 @@ private List extractClassLabels(DatasetBuilder datasetBuild for (double lb : lbs) locClsLabels.add(lb); return locClsLabels; - }, (a, b) -> a == null ? b : Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet())); + }, (a, b) -> { + if (a == null) + return b == null ? new HashSet<>() : b; + if (b == null) + return a; + return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()); + }); res.addAll(clsLabels); From aa4ba0fa8aca1441cd0fdf0ecb9c5ab36fba1898 Mon Sep 17 00:00:00 2001 From: Zinoviev Alexey Date: Mon, 27 Aug 2018 23:06:11 +0300 Subject: [PATCH 2/2] Fixed random seed in tests and algorithms --- .../SVMLinearBinaryClassificationTrainer.java | 28 ++++- ...LinearMultiClassClassificationTrainer.java | 26 ++++- .../ignite/ml/knn/ANNClassificationTest.java | 3 - .../svm/SVMBinaryTrainerIntegrationTest.java | 102 ------------------ .../ignite/ml/svm/SVMBinaryTrainerTest.java | 3 +- .../ml/svm/SVMMultiClassTrainerTest.java | 3 +- .../apache/ignite/ml/svm/SVMTestSuite.java | 1 - 7 files changed, 55 insertions(+), 111 deletions(-) delete mode 100644 modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index 1f369dff55cd1..933a7128731be 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -17,7 +17,7 @@ package org.apache.ignite.ml.svm; -import java.util.concurrent.ThreadLocalRandom; +import java.util.Random; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; @@ -47,6 +47,9 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai /** Regularization parameter. */ private double lambda = 0.4; + /** The seed number. */ + private long seed; + /** * Trains model based on the specified data. * @@ -110,8 +113,10 @@ private Vector calculateUpdates(Vector weights, Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation); Vector deltaAlphas = initializeWeightsWithZeros(amountOfObservation); + Random random = new Random(seed); + for (int i = 0; i < this.getAmountOfLocIterations(); i++) { - int randomIdx = ThreadLocalRandom.current().nextInt(amountOfObservation); + int randomIdx = random.nextInt(amountOfObservation); Deltas deltas = getDeltas(data, copiedWeights, amountOfObservation, tmpAlphas, randomIdx); @@ -266,6 +271,25 @@ public SVMLinearBinaryClassificationTrainer withAmountOfLocIterations(int amount return this; } + /** + * Gets the seed number. + * + * @return The parameter value. + */ + public long getSeed() { + return seed; + } + + /** + * Set up the seed. + * + * @param seed The parameter value. + * @return Model with new seed parameter value. 
+ */ + public SVMLinearBinaryClassificationTrainer withSeed(long seed) { + this.seed = seed; + return this; + } } /** This is a helper class to handle pair results which are returned from the calculation method. */ diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index 6a6e1ed48259e..4b7cc95ebe1c8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -51,6 +51,9 @@ public class SVMLinearMultiClassClassificationTrainer /** Regularization parameter. */ private double lambda = 0.2; + /** The seed number. */ + private long seed; + /** * Trains model based on the specified data. * @@ -70,7 +73,8 @@ public class SVMLinearMultiClassClassificationTrainer SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() .withAmountOfIterations(this.amountOfIterations()) .withAmountOfLocIterations(this.amountOfLocIterations()) - .withLambda(this.lambda()); + .withLambda(this.lambda()) + .withSeed(this.seed); IgniteBiFunction lbTransformer = (k, v) -> { Double lb = lbExtractor.apply(k, v); @@ -182,4 +186,24 @@ public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int a this.amountOfLocIterations = amountOfLocIterations; return this; } + + /** + * Gets the seed number. + * + * @return The parameter value. + */ + public long getSeed() { + return seed; + } + + /** + * Set up the seed. + * + * @param seed The parameter value. + * @return Model with new seed parameter value. + */ + public SVMLinearMultiClassClassificationTrainer withSeed(long seed) { + this.seed = seed; + return this; + } } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java index aed638788e091..7289b1dfaf1f7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java @@ -62,9 +62,6 @@ public void testBinaryClassification() { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(NNStrategy.SIMPLE); - TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(550, 550)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-550, -550)), PRECISION); - Assert.assertNotNull(((ANNClassificationModel) mdl).getCandidates()); Assert.assertTrue(mdl.toString().contains(NNStrategy.SIMPLE.name())); diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java deleted file mode 100644 index d227de7e329e0..0000000000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.svm; - -import java.util.Arrays; -import java.util.UUID; -import java.util.concurrent.ThreadLocalRandom; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; - -/** - * Tests for {@link SVMLinearBinaryClassificationTrainer} that require to start the whole Ignite infrastructure. - */ -public class SVMBinaryTrainerIntegrationTest extends GridCommonAbstractTest { - /** Fixed size of Dataset. */ - private static final int AMOUNT_OF_OBSERVATIONS = 1000; - - /** Fixed size of columns in Dataset. */ - private static final int AMOUNT_OF_FEATURES = 2; - - /** Precision in test checks. */ - private static final double PRECISION = 1e-2; - - /** Number of nodes in grid */ - private static final int NODE_COUNT = 3; - - /** Ignite instance. */ - private Ignite ignite; - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() { - stopAllGrids(); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() throws Exception { - /* Grid instance. */ - ignite = grid(NODE_COUNT); - ignite.configuration().setPeerClassLoadingEnabled(true); - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - } - - /** - * Test trainer on classification model y = x. - */ - public void testTrainWithTheLinearlySeparableCase() { - IgniteCache data = ignite.getOrCreateCache(UUID.randomUUID().toString()); - - ThreadLocalRandom rndX = ThreadLocalRandom.current(); - ThreadLocalRandom rndY = ThreadLocalRandom.current(); - - for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) { - double x = rndX.nextDouble(-1000, 1000); - double y = rndY.nextDouble(-1000, 1000); - double[] vec = new double[AMOUNT_OF_FEATURES + 1]; - vec[0] = y - x > 0 ? 1 : -1; // assign label. 
- vec[1] = x; - vec[2] = y; - data.put(i, vec); - } - - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); - - SVMLinearBinaryClassificationModel mdl = trainer.fit( - ignite, - data, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - TestUtils.assertEquals(-1, mdl.apply(new DenseVector(new double[]{100, 10})), PRECISION); - TestUtils.assertEquals(1, mdl.apply(new DenseVector(new double[]{10, 100})), PRECISION); - } -} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java index b7721774d4098..5630beea5f370 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java @@ -39,7 +39,8 @@ public void testTrainWithTheLinearlySeparableCase() { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); + SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer() + .withSeed(1234L); SVMLinearBinaryClassificationModel mdl = trainer.fit( cacheMock, diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java index f2328f8752a4b..7ea28c2493f26 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java @@ -42,7 +42,8 @@ public void testTrainWithTheLinearlySeparableCase() { SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer() .withLambda(0.3) .withAmountOfLocIterations(10) - .withAmountOfIterations(20); + .withAmountOfIterations(20) + .withSeed(1234L); SVMLinearMultiClassClassificationModel mdl = trainer.fit( cacheMock, diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java index 822ad184ea929..df7263f9d47f7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java @@ -28,7 +28,6 @@ SVMModelTest.class, SVMBinaryTrainerTest.class, SVMMultiClassTrainerTest.class, - SVMBinaryTrainerIntegrationTest.class }) public class SVMTestSuite { // No-op.
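
The two commits above apply the same pair of fixes across the ML trainers: reducers passed to dataset.compute(...) are rewritten to tolerate a null argument on either side (previously only the left side was checked), and ThreadLocalRandom is replaced by a Random built from a configurable seed so that SVM training runs are reproducible. The standalone Java sketch below illustrates both patterns under those assumptions; the class name NullSafeReduceSketch and the nullSafe(...) helper are illustrative only and are not part of the Ignite ML API.

import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.function.BinaryOperator;

/** Standalone illustration of the patterns used in this patch (not Ignite API). */
public class NullSafeReduceSketch {
    /**
     * Wraps a merge function so that either argument may be null, mirroring the
     * reducers this patch rewrites from {@code (a, b) -> a == null ? b : a.merge(b)}.
     * The identity value stands in for defaults such as {@code new TotalCostAndCounts()}.
     */
    static <T> BinaryOperator<T> nullSafe(BinaryOperator<T> merge, T identity) {
        return (a, b) -> {
            if (a == null)
                return b == null ? identity : b;
            if (b == null)
                return a;
            return merge.apply(a, b);
        };
    }

    public static void main(String[] args) {
        // Partial results from partitions; some partitions may produce null.
        List<Integer> partials = Arrays.asList(null, 3, null, 4);

        BinaryOperator<Integer> sum = nullSafe(Integer::sum, 0);
        Integer total = partials.stream().reduce(null, sum);
        System.out.println("total = " + total); // prints 7; nulls on either side are handled

        // Reproducible sampling: a fixed seed instead of ThreadLocalRandom,
        // as done for the coordinate picks in the SVM trainer's local iterations.
        Random rnd = new Random(1234L);
        System.out.println("first index = " + rnd.nextInt(10)); // same value on every run
    }
}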