IGNITE-9393: Fixed bug in reduce function in dataset.compute #4628

Closed
wants to merge 2 commits
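This PR hardens the reduce lambdas passed to dataset.compute across the ML trainers. The old pattern, (a, b) -> a == null ? b : a.merge(b), assumes only the first partial result can be null: it yields null when both arguments are null (all partitions empty) and can throw a NullPointerException inside merge when only b is null. The patched reducers check both sides and fall back to an identity value (0, an empty collection, or a fresh accumulator). Below is a minimal sketch of the shared pattern as a generic helper; NullSafeReduce and nullSafe are illustrative names, not part of the patch:

import java.util.function.BinaryOperator;

public final class NullSafeReduce {
    /** Wraps a merge function so that null partial results from empty partitions are tolerated on either side. */
    public static <T> BinaryOperator<T> nullSafe(BinaryOperator<T> merge, T identity) {
        return (a, b) -> {
            if (a == null)
                return b == null ? identity : b;
            if (b == null)
                return a;
            return merge.apply(a, b);
        };
    }
}

With such a helper, the KMeansTrainer reducer could read nullSafe((a, b) -> a.merge(b), new TotalCostAndCounts()); the patch instead inlines the same three-way check at each call site.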
@@ -65,13 +65,13 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
/**
* Trains model based on the specified data.
*
- * @param datasetBuilder Dataset builder.
+ * @param datasetBuilder Dataset builder.
* @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
+ * @param lbExtractor Label extractor.
* @return Model.
*/
@Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
assert datasetBuilder != null;

PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
@@ -85,7 +85,14 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
(upstream, upstreamSize) -> new EmptyContext(),
partDataBuilder
)) {
- final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> a == null ? b : a);
+ final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
+ if (a == null)
+ return b == null ? 0 : b;
+ if (b == null)
+ return a;
+ return b;
+ });
+
centers = initClusterCentersRandomly(dataset, k);

boolean converged = false;
@@ -113,7 +120,8 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
centers[i] = newCentroids[i];
}
}
- } catch (Exception e) {
+ }
+ catch (Exception e) {
throw new RuntimeException(e);
}
return new KMeansModel(centers, distance);
@@ -124,15 +132,14 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
*
* @param centers Current centers on the current iteration.
* @param dataset Dataset.
- * @param cols Amount of columns.
+ * @param cols Amount of columns.
* @return Helper data to calculate the new centroids.
*/
private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers,
- Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset, int cols) {
+ Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset, int cols) {
final Vector[] finalCenters = centers;

return dataset.compute(data -> {
-
TotalCostAndCounts res = new TotalCostAndCounts();

for (int i = 0; i < data.rowSize(); i++) {
@@ -147,20 +154,29 @@ private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers,

int finalI = i;
res.sums.compute(centroidIdx,
- (IgniteBiFunction<Integer, Vector, Vector>) (ind, v) -> v.plus(data.getRow(finalI).features()));
+ (IgniteBiFunction<Integer, Vector, Vector>)(ind, v) -> {
+ Vector features = data.getRow(finalI).features();
+ return v == null ? features : v.plus(features);
+ });

res.counts.merge(centroidIdx, 1,
- (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
+ (IgniteBiFunction<Integer, Integer, Integer>)(i1, i2) -> i1 + i2);
}
return res;
- }, (a, b) -> a == null ? b : a.merge(b));
+ }, (a, b) -> {
+ if (a == null)
+ return b == null ? new TotalCostAndCounts() : b;
+ if (b == null)
+ return a;
+ return a.merge(b);
+ });
}

/**
* Find the closest cluster center index and distance to it from a given point.
*
* @param centers Centers to look in.
- * @param pnt Point.
+ * @param pnt Point.
*/
private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
double bestDistance = Double.POSITIVE_INFINITY;
@@ -180,12 +196,11 @@ private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, Lab
* K cluster centers are initialized randomly.
*
* @param dataset The dataset to pick up random centers.
- * @param k Amount of clusters.
+ * @param k Amount of clusters.
* @return K cluster centers.
*/
private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset,
- int k) {
-
+ int k) {
Vector[] initCenters = new DenseVector[k];

// Gets k or less vectors from each partition.
@@ -211,12 +226,19 @@ private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledVectorS

rndPnt.add(data.getRow(nextIdx));
}
- } else // If it's not enough vectors to pick k vectors.
+ }
+ else // If it's not enough vectors to pick k vectors.
for (int i = 0; i < data.rowSize(); i++)
rndPnt.add(data.getRow(i));
}
return rndPnt;
- }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));
+ }, (a, b) -> {
+ if (a == null)
+ return b == null ? new ArrayList<>() : b;
+ if (b == null)
+ return a;
+ return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
+ });

// Shuffle them.
Collections.shuffle(rndPnts);
@@ -228,7 +250,8 @@ private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledVectorS
rndPnts.remove(rndPnt);
initCenters[i] = rndPnt.features();
}
- } else
+ }
+ else
throw new RuntimeException("The KMeans Trainer required more than " + k + " vectors to find " + k + " clusters");

return initCenters;
@@ -245,7 +268,6 @@ public static class TotalCostAndCounts {
/** Count of points closest to the center with a given index. */
ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();

-
/** Count of points closest to the center with a given index. */
ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();

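For background on the res.sums.compute(...) change above: java.util.Map#compute always invokes its remapping function, passing null as the current value when the key is absent, so the old unconditional v.plus(...) threw a NullPointerException on the first point assigned to a centroid; Map#merge, by contrast, only calls its function once a previous value exists, which is why the res.counts.merge(...) call needed no guard. A standalone illustration, not taken from the patch:

import java.util.concurrent.ConcurrentHashMap;

public class MapComputeDemo {
    public static void main(String[] args) {
        ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();

        // merge inserts the given value (1) when the key is absent and only
        // applies Integer::sum once a previous value exists.
        counts.merge(1, 1, Integer::sum);
        counts.merge(1, 1, Integer::sum);
        System.out.println(counts.get(1)); // 2

        // compute always calls the function; v is null for an absent key,
        // which is the case the patched lambda now guards with v == null.
        counts.compute(2, (k, v) -> v == null ? 1 : v + 1);
        System.out.println(counts.get(2)); // 1
    }
}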
@@ -149,9 +149,7 @@ private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder,
(upstream, upstreamSize) -> new EmptyContext(),
partDataBuilder
)) {
-
return dataset.compute(data -> {
-
CentroidStat res = new CentroidStat();

for (int i = 0; i < data.rowSize(); i++) {
@@ -171,15 +169,21 @@ private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder,
centroidStat.put(lb, 1);
res.centroidStat.put(centroidIdx, centroidStat);
} else {
- int cnt = centroidStat.containsKey(lb) ? centroidStat.get(lb) : 0;
+ int cnt = centroidStat.getOrDefault(lb, 0);
centroidStat.put(lb, cnt + 1);
}

res.counts.merge(centroidIdx, 1,
(IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
}
return res;
- }, (a, b) -> a == null ? b : a.merge(b));
+ }, (a, b) -> {
+ if (a == null)
+ return b == null ? new CentroidStat() : b;
+ if (b == null)
+ return a;
+ return a.merge(b);
+ });

} catch (Exception e) {
throw new RuntimeException(e);
@@ -17,6 +17,7 @@

package org.apache.ignite.ml.knn.classification;

+ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
@@ -79,7 +80,13 @@ protected List<LabeledVector> findKNearestNeighbors(Vector v) {
List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> {
TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, data);
return Arrays.asList(getKClosestVectors(data, distanceIdxPairs));
- }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));
+ }, (a, b) -> {
+ if (a == null)
+ return b == null ? new ArrayList<>() : b;
+ if (b == null)
+ return a;
+ return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
+ });

LabeledVectorSet<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);

@@ -103,7 +103,13 @@ else if (b == null)
@Override protected int getColumns() {
return dataset.compute(
data -> data.getFeatures() == null ? null : data.getFeatures().length / data.getRows(),
- (a, b) -> a == null ? b : a
+ (a, b) -> {
+ if (a == null)
+ return b == null ? 0 : b;
+ if (b == null)
+ return a;
+ return b;
+ }
);
}

@@ -82,7 +82,13 @@ public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron,
if (data.getFeatures() == null)
return null;
return data.getFeatures().length / data.getRows();
- }, (a, b) -> a == null ? b : a);
+ }, (a, b) -> {
+ if (a == null)
+ return b == null ? 0 : b;
+ if (b == null)
+ return a;
+ return b;
+ });

MLPArchitecture architecture = new MLPArchitecture(cols);
architecture = architecture.withAddedLayer(1, true, Activators.LINEAR);
@@ -100,7 +106,7 @@ public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron,
seed
);

- IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, double[]>)(k, v) -> new double[]{lbExtractor.apply(k, v)};
+ IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, double[]>)(k, v) -> new double[] {lbExtractor.apply(k, v)};

MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE);

@@ -64,7 +64,7 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
* @param seed Seed for random generator.
*/
public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations,
- int batchSize, int locIterations, long seed) {
+ int batchSize, int locIterations, long seed) {
this.updatesStgy = updatesStgy;
this.maxIterations = maxIterations;
this.batchSize = batchSize;
@@ -82,7 +82,13 @@ public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron
if (data.getFeatures() == null)
return null;
return data.getFeatures().length / data.getRows();
- }, (a, b) -> a == null ? b : a);
+ }, (a, b) -> {
+ if (a == null)
+ return b == null ? 0 : b;
+ if (b == null)
+ return a;
+ return b;
+ });

MLPArchitecture architecture = new MLPArchitecture(cols);
architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);
@@ -100,7 +106,7 @@ public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron
seed
);

- MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[]{lbExtractor.apply(k, v)});
+ MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[] {lbExtractor.apply(k, v)});

double[] params = mlp.parameters().getStorage().data();

@@ -61,14 +61,14 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
/**
* Trains model based on the specified data.
*
- * @param datasetBuilder Dataset builder.
+ * @param datasetBuilder Dataset builder.
* @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
+ * @param lbExtractor Label extractor.
* @return Model.
*/
@Override public <K, V> LogRegressionMultiClassModel fit(DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor) {
+ IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, Double> lbExtractor) {
List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);

LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel();
@@ -92,7 +92,8 @@
}

/** Iterates among dataset and collects class labels. */
- private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
+ private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Double> lbExtractor) {
assert datasetBuilder != null;

PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
@@ -108,14 +109,22 @@ private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuild

final double[] lbs = data.getY();

- for (double lb : lbs) locClsLabels.add(lb);
+ for (double lb : lbs)
+ locClsLabels.add(lb);

return locClsLabels;
- }, (a, b) -> a == null ? b : Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()));
+ }, (a, b) -> {
+ if (a == null)
+ return b == null ? new HashSet<>() : b;
+ if (b == null)
+ return a;
+ return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
+ });

res.addAll(clsLabels);

- } catch (Exception e) {
+ }
+ catch (Exception e) {
throw new RuntimeException(e);
}
return res;