MAHOUT-1162: Adding BallKMeans and StreamingKMeans clustering algorithms
In addition to the new algorithms, there are new utility methods in
ClusteringUtils for casting between Vector types and computing various
clustering quality metrics: the Dunn Index, the Davies-Bouldin Index and the
Adjusted Rand Index.

git-svn-id: https://svn.apache.org/repos/asf/mahout/trunk@1480954 13f79535-47bb-0310-9956-ffa450edef68
Dan Filimon committed May 10, 2013
1 parent 1c33a8c commit 5a7100a
Showing 7 changed files with 1,518 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG
@@ -2,6 +2,8 @@ Mahout Change Log

Release 0.8 - unreleased

MAHOUT-1162: Adding BallKMeans and StreamingKMeans clustering algorithms (dfilimon)

MAHOUT-1205: ParallelALSFactorizationJob should leverage the distributed cache (ssc)

MAHOUT-1156: Adding nearest neighbor Searchers (dfilimon)
283 changes: 283 additions & 0 deletions core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
@@ -0,0 +1,283 @@
package org.apache.mahout.clustering;

import java.util.List;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.Searcher;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;
import org.apache.mahout.math.stats.OnlineSummarizer;

public class ClusteringUtils {
  /**
   * Computes the summaries for the distances in each cluster.
   * @param datapoints iterable of datapoints.
   * @param centroids iterable of Centroids.
   * @param distanceMeasure distance measure used to compute the distance from each point to its closest centroid.
   * @return a list of OnlineSummarizers where the i-th element is the summarizer corresponding to the cluster whose
   * index is i.
   */
  public static List<OnlineSummarizer> summarizeClusterDistances(Iterable<? extends Vector> datapoints,
                                                                 Iterable<? extends Vector> centroids,
                                                                 DistanceMeasure distanceMeasure) {
    UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1);
    searcher.addAll(centroids);
    List<OnlineSummarizer> summarizers = Lists.newArrayList();
    if (searcher.size() == 0) {
      return summarizers;
    }
    for (int i = 0; i < searcher.size(); ++i) {
      summarizers.add(new OnlineSummarizer());
    }
    for (Vector v : datapoints) {
      Centroid closest = (Centroid) searcher.search(v, 1).get(0).getValue();
      OnlineSummarizer summarizer = summarizers.get(closest.getIndex());
      summarizer.add(distanceMeasure.distance(v, closest));
    }
    return summarizers;
  }

  /**
   * Adds up the distances from each point to its closest cluster and returns the sum.
   * The distances are Euclidean.
   * @param datapoints iterable of datapoints.
   * @param centroids iterable of Centroids.
   * @return the total cost described above.
   */
  public static double totalClusterCost(Iterable<? extends Vector> datapoints, Iterable<? extends Vector> centroids) {
    DistanceMeasure distanceMeasure = new EuclideanDistanceMeasure();
    UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1);
    searcher.addAll(centroids);
    return totalClusterCost(datapoints, searcher);
  }

  /**
   * Adds up the distances from each point to its closest cluster and returns the sum.
   * @param datapoints iterable of datapoints.
   * @param centroids searcher of Centroids.
   * @return the total cost described above.
   */
  public static double totalClusterCost(Iterable<? extends Vector> datapoints, Searcher centroids) {
    double totalCost = 0;
    for (Vector vector : datapoints) {
      // searchFirst() returns the closest centroid in a WeightedThing whose weight is the distance
      // to the query point; that distance (not the centroid's own weight) is what gets accumulated.
      totalCost += centroids.searchFirst(vector, false).getWeight();
    }
    return totalCost;
  }

  /**
   * Estimates the distance cutoff. In StreamingKMeans, the distance between two vectors divided
   * by this value is used as a probability threshold when deciding whether to form a new cluster
   * or not.
   * Small values (comparable to the minimum distance between two points) are preferred as they
   * guarantee with high likelihood that all but very close points are put in separate clusters
   * initially. The clusters themselves are collapsed periodically when their number goes over the
   * maximum number of clusters and the distanceCutoff is increased.
   * So, the returned value is only an initial estimate.
   * @param data the datapoints whose distance is to be estimated.
   * @param distanceMeasure the distance measure used to compute the distance between two points.
   * @param sampleLimit the maximum number of datapoints to use when estimating the cutoff.
   * @return the minimum distance found between the first sampleLimit points.
   * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans#clusterInternal(Iterable, boolean)
   */
  public static double estimateDistanceCutoff(Iterable<? extends Vector> data,
                                              DistanceMeasure distanceMeasure, int sampleLimit) {
    Iterable<? extends Vector> limitedData = Iterables.limit(data, sampleLimit);
    ProjectionSearch searcher = new ProjectionSearch(distanceMeasure, 3, 1);
    searcher.add(limitedData.iterator().next());
    double minDistance = Double.POSITIVE_INFINITY;
    for (Vector vector : Iterables.skip(limitedData, 1)) {
      double closest = searcher.searchFirst(vector, false).getWeight();
      if (closest < minDistance) {
        minDistance = closest;
      }
      searcher.add(vector);
    }
    return minDistance;
  }

  /**
   * Computes the Davies-Bouldin Index for a given clustering.
   * See http://en.wikipedia.org/wiki/Clustering_algorithm#Internal_evaluation
   * @param centroids list of centroids
   * @param distanceMeasure distance measure for inter-cluster distances
   * @param clusterDistanceSummaries summaries of the clusters; see summarizeClusterDistances
   * @return the Davies-Bouldin Index
   */
  public static double daviesBouldinIndex(List<? extends Vector> centroids, DistanceMeasure distanceMeasure,
                                          List<OnlineSummarizer> clusterDistanceSummaries) {
    Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(),
        "Number of centroids and cluster summaries differ.");
    int n = centroids.size();
    double totalDBIndex = 0;
    // The inner loop shouldn't be reduced to j = i + 1 to n because the computation of the Davies-Bouldin
    // index is not really symmetric.
    // For a given cluster i, we look for a cluster j that maximizes the ratio of the sum of the average
    // distances from points in cluster i to its center and from points in cluster j to its center, to the
    // distance between cluster i and cluster j.
    // The maximization is the key issue, as the cluster that maximizes this ratio might be j for i but is NOT
    // NECESSARILY i for j.
    for (int i = 0; i < n; ++i) {
      double averageDistanceI = clusterDistanceSummaries.get(i).getMean();
      double maxDBIndex = 0;
      for (int j = 0; j < n; ++j) {
        if (i != j) {
          double dbIndex = (averageDistanceI + clusterDistanceSummaries.get(j).getMean())
              / distanceMeasure.distance(centroids.get(i), centroids.get(j));
          if (dbIndex > maxDBIndex) {
            maxDBIndex = dbIndex;
          }
        }
      }
      totalDBIndex += maxDBIndex;
    }
    return totalDBIndex / n;
  }

  /**
   * Computes the Dunn Index of a given clustering. See http://en.wikipedia.org/wiki/Dunn_index
   * @param centroids list of centroids
   * @param distanceMeasure distance measure to compute inter-centroid distances with
   * @param clusterDistanceSummaries summaries of the clusters; see summarizeClusterDistances
   * @return the Dunn Index
   */
  public static double dunnIndex(List<? extends Vector> centroids, DistanceMeasure distanceMeasure,
                                 List<OnlineSummarizer> clusterDistanceSummaries) {
    Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(),
        "Number of centroids and cluster summaries differ.");
    int n = centroids.size();
    // Intra-cluster distances will come from the OnlineSummarizer, and will be the median distance (noting that
    // the median for just one value is that value).
    // A variety of metrics can be used for the intra-cluster distance, including the max distance between two
    // points, the mean distance, etc. The median distance was chosen as it is more robust to outliers and
    // better characterizes the distribution of distances (from a point to the center).
    double maxIntraClusterDistance = 0;
    for (OnlineSummarizer summarizer : clusterDistanceSummaries) {
      if (summarizer.getCount() > 0) {
        double intraClusterDistance;
        if (summarizer.getCount() == 1) {
          intraClusterDistance = summarizer.getMean();
        } else {
          intraClusterDistance = summarizer.getMedian();
        }
        if (maxIntraClusterDistance < intraClusterDistance) {
          maxIntraClusterDistance = intraClusterDistance;
        }
      }
    }
    double minInterClusterDistance = Double.POSITIVE_INFINITY;
    for (int i = 0; i < n; ++i) {
      // Distances are symmetric, so d(i, j) = d(j, i).
      for (int j = i + 1; j < n; ++j) {
        double distance = distanceMeasure.distance(centroids.get(i), centroids.get(j));
        if (minInterClusterDistance > distance) {
          minInterClusterDistance = distance;
        }
      }
    }
    return minInterClusterDistance / maxIntraClusterDistance;
  }

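  /**
   * Computes n choose 2, i.e. n * (n - 1) / 2, the number of unordered pairs that can be formed
   * from n items. Takes and returns a double so it can be applied directly to weighted counts.
   */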
  public static double choose2(double n) {
    return n * (n - 1) / 2;
  }

  /**
   * Creates a confusion matrix by searching for the closest cluster of both the row clustering and column clustering
   * of a point and adding its weight to that cell of the matrix.
   * It doesn't matter which clustering is the row clustering and which is the column clustering. If they're
   * interchanged, the resulting matrix is the transpose of the original one.
   * @param rowCentroids clustering one
   * @param columnCentroids clustering two
   * @param datapoints datapoints whose closest cluster we need to find
   * @param distanceMeasure distance measure to use
   * @return the confusion matrix
   */
  public static Matrix getConfusionMatrix(List<? extends Vector> rowCentroids, List<? extends Vector> columnCentroids,
                                          Iterable<? extends Vector> datapoints, DistanceMeasure distanceMeasure) {
    Searcher rowSearcher = new BruteSearch(distanceMeasure);
    rowSearcher.addAll(rowCentroids);
    Searcher columnSearcher = new BruteSearch(distanceMeasure);
    columnSearcher.addAll(columnCentroids);

    int numRows = rowCentroids.size();
    int numCols = columnCentroids.size();
    Matrix confusionMatrix = new DenseMatrix(numRows, numCols);

    for (Vector vector : datapoints) {
      WeightedThing<Vector> closestRowCentroid = rowSearcher.search(vector, 1).get(0);
      WeightedThing<Vector> closestColumnCentroid = columnSearcher.search(vector, 1).get(0);
      int row = ((Centroid) closestRowCentroid.getValue()).getIndex();
      int column = ((Centroid) closestColumnCentroid.getValue()).getIndex();
      double vectorWeight;
      if (vector instanceof WeightedVector) {
        vectorWeight = ((WeightedVector) vector).getWeight();
      } else {
        vectorWeight = 1;
      }
      confusionMatrix.set(row, column, confusionMatrix.get(row, column) + vectorWeight);
    }

    return confusionMatrix;
  }

  /**
   * Computes the Adjusted Rand Index for a given confusion matrix.
   * @param confusionMatrix confusion matrix; not to be confused with the more restrictive ConfusionMatrix class
   * @return the Adjusted Rand Index
   */
  public static double getAdjustedRandIndex(Matrix confusionMatrix) {
    int numRows = confusionMatrix.numRows();
    int numCols = confusionMatrix.numCols();
    double rowChoiceSum = 0;
    double columnChoiceSum = 0;
    double totalChoiceSum = 0;
    double total = 0;
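    // With n_ij the entries of the confusion matrix, a_i its row sums, b_j its column sums and
    // n the total weight, the index computed below is
    //   ARI = (sum_ij C(n_ij, 2) - E) / ((sum_i C(a_i, 2) + sum_j C(b_j, 2)) / 2 - E),
    // where E = sum_i C(a_i, 2) * sum_j C(b_j, 2) / C(n, 2) is the value expected by chance and
    // C(x, 2) = choose2(x). totalChoiceSum, rowChoiceSum and columnChoiceSum accumulate these sums.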
    for (int i = 0; i < numRows; ++i) {
      double rowSum = 0;
      for (int j = 0; j < numCols; ++j) {
        rowSum += confusionMatrix.get(i, j);
        totalChoiceSum += choose2(confusionMatrix.get(i, j));
      }
      total += rowSum;
      rowChoiceSum += choose2(rowSum);
    }
    for (int j = 0; j < numCols; ++j) {
      double columnSum = 0;
      for (int i = 0; i < numRows; ++i) {
        columnSum += confusionMatrix.get(i, j);
      }
      columnChoiceSum += choose2(columnSum);
    }
    double rowColumnChoiceSumDivTotal = rowChoiceSum * columnChoiceSum / choose2(total);
    return (totalChoiceSum - rowColumnChoiceSumDivTotal)
        / ((rowChoiceSum + columnChoiceSum) / 2 - rowColumnChoiceSumDivTotal);
  }

  /**
   * Computes the total weight of the points in the given Vector iterable.
   * @param data iterable of points
   * @return total weight
   */
  public static double totalWeight(Iterable<? extends Vector> data) {
    double sum = 0;
    for (Vector row : data) {
      Preconditions.checkNotNull(row);
      if (row instanceof WeightedVector) {
        sum += ((WeightedVector) row).getWeight();
      } else {
        sum++;
      }
    }
    return sum;
  }
}
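For review purposes, here is a minimal sketch of how the new metrics compose on a hypothetical two-cluster toy dataset. It assumes the imports from the file above plus org.apache.mahout.math.DenseVector; the Centroid(key, vector, weight) constructor comes from Mahout's math module, and the toy data and comments are illustrative rather than part of this commit:

// Hypothetical toy data: two well-separated clusters with two points each.
List<Centroid> centroids = Lists.newArrayList(
    new Centroid(0, new DenseVector(new double[]{0, 0}), 1),
    new Centroid(1, new DenseVector(new double[]{5, 5}), 1));
List<Vector> datapoints = Lists.<Vector>newArrayList(
    new DenseVector(new double[]{0.1, 0.2}), new DenseVector(new double[]{-0.1, 0.0}),
    new DenseVector(new double[]{4.9, 5.1}), new DenseVector(new double[]{5.2, 4.8}));
DistanceMeasure measure = new EuclideanDistanceMeasure();

// Per-cluster distance summaries feed both internal quality indexes.
List<OnlineSummarizer> summaries =
    ClusteringUtils.summarizeClusterDistances(datapoints, centroids, measure);
double daviesBouldin = ClusteringUtils.daviesBouldinIndex(centroids, measure, summaries); // lower is better
double dunn = ClusteringUtils.dunnIndex(centroids, measure, summaries);                   // higher is better

// Comparing a clustering against itself yields an Adjusted Rand Index of 1.0.
Matrix confusion = ClusteringUtils.getConfusionMatrix(centroids, centroids, datapoints, measure);
double ari = ClusteringUtils.getAdjustedRandIndex(confusion);

// Initial distance cutoff for StreamingKMeans, estimated from (at most) the first 100 points.
double cutoff = ClusteringUtils.estimateDistanceCutoff(datapoints, measure, 100);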
