New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-29967][ML][PYTHON] KMeans support instance weighting #26739
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,8 +84,8 @@ private[spark] abstract class DistanceMeasure extends Serializable { | |
* @param point a `VectorWithNorm` to be added to `sum` of a cluster | ||
* @param sum the `sum` for a cluster to be updated | ||
*/ | ||
def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { | ||
axpy(1.0, point.vector, sum) | ||
def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = { | ||
axpy(weight, point.vector, sum) | ||
} | ||
|
||
/** | ||
|
@@ -100,6 +100,18 @@ private[spark] abstract class DistanceMeasure extends Serializable { | |
new VectorWithNorm(sum) | ||
} | ||
|
||
/** | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. It is still used by |
||
* Returns a centroid for a cluster given its `sum` vector and the weightSum of points. | ||
* | ||
* @param sum the `sum` for a cluster | ||
* @param weightSum the weightSum of points in the cluster | ||
* @return the centroid of the cluster | ||
*/ | ||
def centroid(sum: Vector, weightSum: Double): VectorWithNorm = { | ||
scal(1.0 / weightSum, sum) | ||
new VectorWithNorm(sum) | ||
} | ||
|
||
/** | ||
* Returns two new centroids symmetric to the specified centroid applying `noise` with the | ||
* specified `level`. | ||
|
@@ -249,9 +261,9 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure { | |
* @param point a `VectorWithNorm` to be added to `sum` of a cluster | ||
* @param sum the `sum` for a cluster to be updated | ||
*/ | ||
override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { | ||
override def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = { | ||
assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.") | ||
axpy(1.0 / point.norm, point.vector, sum) | ||
axpy(weight / point.norm, point.vector, sum) | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Since | |
import org.apache.spark.broadcast.Broadcast | ||
import org.apache.spark.internal.Logging | ||
import org.apache.spark.ml.util.Instrumentation | ||
import org.apache.spark.mllib.linalg.{Vector, Vectors} | ||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} | ||
import org.apache.spark.mllib.linalg.BLAS.axpy | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.storage.StorageLevel | ||
|
@@ -209,11 +209,14 @@ class KMeans private ( | |
*/ | ||
@Since("0.8.0") | ||
def run(data: RDD[Vector]): KMeansModel = { | ||
run(data, None) | ||
val instances: RDD[(Vector, Double)] = data.map { | ||
case (point) => (point, 1.0) | ||
} | ||
runWithweight(instances, None) | ||
} | ||
|
||
private[spark] def run( | ||
data: RDD[Vector], | ||
private[spark] def runWithweight( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: runWithWeight |
||
data: RDD[(Vector, Double)], | ||
instr: Option[Instrumentation]): KMeansModel = { | ||
|
||
if (data.getStorageLevel == StorageLevel.NONE) { | ||
|
@@ -222,12 +225,15 @@ class KMeans private ( | |
} | ||
|
||
// Compute squared norms and cache them. | ||
val norms = data.map(Vectors.norm(_, 2.0)) | ||
val zippedData = data.zip(norms).map { case (v, norm) => | ||
new VectorWithNorm(v, norm) | ||
val norms = data.map { case (v, _) => | ||
Vectors.norm(v, 2.0) | ||
} | ||
|
||
val zippedData = data.zip(norms).map { case ((v, w), norm) => | ||
(new VectorWithNorm(v, norm), w) | ||
} | ||
zippedData.persist() | ||
val model = runAlgorithm(zippedData, instr) | ||
val model = runAlgorithmWithWeight(zippedData, instr) | ||
zippedData.unpersist() | ||
|
||
// Warn at the end of the run as well, for increased visibility. | ||
|
@@ -241,8 +247,8 @@ class KMeans private ( | |
/** | ||
* Implementation of K-Means algorithm. | ||
*/ | ||
private def runAlgorithm( | ||
data: RDD[VectorWithNorm], | ||
private def runAlgorithmWithWeight( | ||
data: RDD[(VectorWithNorm, Double)], | ||
instr: Option[Instrumentation]): KMeansModel = { | ||
|
||
val sc = data.sparkContext | ||
|
@@ -251,14 +257,17 @@ class KMeans private ( | |
|
||
val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure) | ||
|
||
val dataVectorWithNorm = data.map(d => d._1) | ||
val weights = data.map(d => d._2) | ||
|
||
val centers = initialModel match { | ||
case Some(kMeansCenters) => | ||
kMeansCenters.clusterCenters.map(new VectorWithNorm(_)) | ||
case None => | ||
if (initializationMode == KMeans.RANDOM) { | ||
initRandom(data) | ||
initRandom(dataVectorWithNorm) | ||
} else { | ||
initKMeansParallel(data, distanceMeasureInstance) | ||
initKMeansParallel(dataVectorWithNorm, distanceMeasureInstance) | ||
} | ||
} | ||
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 | ||
|
@@ -278,30 +287,32 @@ class KMeans private ( | |
val bcCenters = sc.broadcast(centers) | ||
|
||
// Find the new centers | ||
val collected = data.mapPartitions { points => | ||
val collected = data.mapPartitions { pointsAndWeights => | ||
val thisCenters = bcCenters.value | ||
val dims = thisCenters.head.vector.size | ||
|
||
val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims)) | ||
val counts = Array.fill(thisCenters.length)(0L) | ||
|
||
points.foreach { point => | ||
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) | ||
// clusterWeightSum is needed to calculate cluster center | ||
// cluster center = | ||
// sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ... | ||
val clusterWeightSum = Array.fill(thisCenters.length)(0.0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit, |
||
|
||
pointsAndWeights.foreach { case (point, weight) => | ||
var (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Total nit, but you can use val and then pass |
||
cost *= weight | ||
costAccum.add(cost) | ||
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter)) | ||
counts(bestCenter) += 1 | ||
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter), weight) | ||
clusterWeightSum(bestCenter) += weight | ||
} | ||
|
||
counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator | ||
}.reduceByKey { case ((sum1, count1), (sum2, count2)) => | ||
clusterWeightSum.indices.filter(clusterWeightSum(_) > 0) | ||
.map(j => (j, (sums(j), clusterWeightSum(j)))).iterator | ||
}.reduceByKey { case ((sum1, clusterWeightSum1), (sum2, clusterWeightSum2)) => | ||
axpy(1.0, sum2, sum1) | ||
(sum1, count1 + count2) | ||
(sum1, clusterWeightSum1 + clusterWeightSum2) | ||
}.collectAsMap() | ||
|
||
if (iteration == 0) { | ||
instr.foreach(_.logNumExamples(collected.values.map(_._2).sum)) | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you just log the sum of weights? it keeps the same info in the unweighted case and it's still sort of meaningful as 'number of examples' in the weighted case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 1, I guess we need to add a new var There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess maybe leave the code this way for now and open a separate PR later on to add method There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am OK to add new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am OK to add new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated. Thanks! |
||
val newCenters = collected.mapValues { case (sum, count) => | ||
distanceMeasureInstance.centroid(sum, count) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit, why breaking this line?
dataset .select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w) .rdd.map { ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will update the format.