[SPARK-29967][ML][PYTHON] KMeans support instance weighting #26739

Closed
wants to merge 5 commits
Changes from all commits
34 changes: 27 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -31,17 +31,18 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion

/**
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure {
with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure with HasWeightCol {

/**
* The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -313,12 +314,31 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr =>
transformSchema(dataset.schema, logging = true)

val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
col($(weightCol)).cast(DoubleType)
} else {
lit(1.0)
}

val instances: RDD[(OldVector, Double)] = dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w).rdd.map {
case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight)
}

if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
@@ -327,7 +347,7 @@
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
maxIter, seed, tol)
maxIter, seed, tol, weightCol)
val algo = new MLlibKMeans()
.setK($(k))
.setInitializationMode($(initMode))
@@ -336,7 +356,7 @@
.setSeed($(seed))
.setEpsilon($(tol))
.setDistanceMeasure($(distanceMeasure))
val parentModel = algo.run(instances, Option(instr))
val parentModel = algo.runWithWeight(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset),
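Not part of the diff: a minimal sketch of how the new weightCol parameter might be used through the DataFrame API, assuming a SparkSession named spark; the column names and data values are illustrative only.

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors

// Hypothetical weighted training data; "features" and "weight" are assumed column names.
val training = spark.createDataFrame(Seq(
  (Vectors.dense(0.0, 0.0), 3.0),
  (Vectors.dense(1.0, 1.0), 1.0),
  (Vectors.dense(9.0, 8.0), 2.0),
  (Vectors.dense(8.0, 9.0), 1.0)
)).toDF("features", "weight")

val kmeans = new KMeans()
  .setK(2)
  .setSeed(1L)
  .setWeightCol("weight") // when unset or empty, every instance keeps weight 1.0

val model = kmeans.fit(training)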
@@ -123,6 +123,10 @@ private[spark] class Instrumentation private () extends Logging with MLEvents {
logNamedValue(Instrumentation.loggerTags.numExamples, num)
}

def logSumOfWeights(num: Double): Unit = {
logNamedValue(Instrumentation.loggerTags.sumOfWeights, num)
}

/**
* Logs the value with customized name field.
*/
@@ -179,6 +183,7 @@ private[spark] object Instrumentation {
val numExamples = "numExamples"
val meanOfLabels = "meanOfLabels"
val varianceOfLabels = "varianceOfLabels"
val sumOfWeights = "sumOfWeights"
}

def instrumented[T](body: (Instrumentation => T)): T = {
@@ -84,8 +84,8 @@ private[spark] abstract class DistanceMeasure extends Serializable {
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
axpy(1.0, point.vector, sum)
def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = {
axpy(weight, point.vector, sum)
}

/**
@@ -100,6 +100,18 @@
new VectorWithNorm(sum)
}

Review thread:
Contributor: Is the above def centroid(sum: Vector, count: Long): VectorWithNorm still needed?
Contributor (Author): Yes. It is still used by BisectingKMeans.

/**
* Returns a centroid for a cluster given its `sum` vector and the weightSum of points.
*
* @param sum the `sum` for a cluster
* @param weightSum the weightSum of points in the cluster
* @return the centroid of the cluster
*/
def centroid(sum: Vector, weightSum: Double): VectorWithNorm = {
scal(1.0 / weightSum, sum)
new VectorWithNorm(sum)
}

/**
* Returns two new centroids symmetric to the specified centroid, applying `noise` with the
* specified `level`.
@@ -249,9 +261,9 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
override def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = {
assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.")
axpy(1.0 / point.norm, point.vector, sum)
axpy(weight / point.norm, point.vector, sum)
}

/**
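For intuition (not part of the diff), a tiny standalone Scala sketch of the arithmetic that the weighted updateClusterSum and centroid perform in the Euclidean case: each point contributes weight * point to the running sum, the cluster accumulates its total weight, and the centroid is the sum scaled by 1 / weightSum. The data values are made up.

// Standalone illustration of the weighted-centroid arithmetic (no Spark required).
val points  = Seq(Array(1.0, 2.0), Array(3.0, 6.0))
val weights = Seq(3.0, 1.0)

val dims = points.head.length
val sum = Array.fill(dims)(0.0)
var weightSum = 0.0

points.zip(weights).foreach { case (p, w) =>
  var i = 0
  while (i < dims) { sum(i) += w * p(i); i += 1 } // axpy(weight, point.vector, sum)
  weightSum += w
}

val centroid = sum.map(_ / weightSum) // scal(1.0 / weightSum, sum)
// centroid == Array(1.5, 3.0): the heavier first point pulls the center toward it.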
65 changes: 41 additions & 24 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -209,11 +209,14 @@ class KMeans private (
*/
@Since("0.8.0")
def run(data: RDD[Vector]): KMeansModel = {
run(data, None)
val instances: RDD[(Vector, Double)] = data.map {
case (point) => (point, 1.0)
}
runWithWeight(instances, None)
}

private[spark] def run(
data: RDD[Vector],
private[spark] def runWithWeight(
data: RDD[(Vector, Double)],
instr: Option[Instrumentation]): KMeansModel = {

if (data.getStorageLevel == StorageLevel.NONE) {
@@ -222,12 +225,15 @@
}

// Compute squared norms and cache them.
val norms = data.map(Vectors.norm(_, 2.0))
val zippedData = data.zip(norms).map { case (v, norm) =>
new VectorWithNorm(v, norm)
val norms = data.map { case (v, _) =>
Vectors.norm(v, 2.0)
}

val zippedData = data.zip(norms).map { case ((v, w), norm) =>
(new VectorWithNorm(v, norm), w)
}
zippedData.persist()
val model = runAlgorithm(zippedData, instr)
val model = runAlgorithmWithWeight(zippedData, instr)
zippedData.unpersist()

// Warn at the end of the run as well, for increased visibility.
@@ -241,8 +247,8 @@
/**
* Implementation of K-Means algorithm.
*/
private def runAlgorithm(
data: RDD[VectorWithNorm],
private def runAlgorithmWithWeight(
data: RDD[(VectorWithNorm, Double)],
instr: Option[Instrumentation]): KMeansModel = {

val sc = data.sparkContext
@@ -251,14 +257,17 @@

val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)

val dataVectorWithNorm = data.map(d => d._1)
val weights = data.map(d => d._2)

val centers = initialModel match {
case Some(kMeansCenters) =>
kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
initRandom(dataVectorWithNorm)
} else {
initKMeansParallel(data, distanceMeasureInstance)
initKMeansParallel(dataVectorWithNorm, distanceMeasureInstance)
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
@@ -275,35 +284,43 @@
// Execute iterations of Lloyd's algorithm until converged
while (iteration < maxIterations && !converged) {
val costAccum = sc.doubleAccumulator
val countAccum = sc.longAccumulator
val bcCenters = sc.broadcast(centers)

// Find the new centers
val collected = data.mapPartitions { points =>
val collected = data.mapPartitions { pointsAndWeights =>
val thisCenters = bcCenters.value
val dims = thisCenters.head.vector.size

val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
val counts = Array.fill(thisCenters.length)(0L)

points.foreach { point =>
// clusterWeightSum is needed to calculate cluster center
// cluster center =
// sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
val clusterWeightSum = Array.ofDim[Double](thisCenters.length)

pointsAndWeights.foreach { case (point, weight) =>
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
costAccum.add(cost)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
counts(bestCenter) += 1
costAccum.add(cost * weight)
countAccum.add(1)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter), weight)
clusterWeightSum(bestCenter) += weight
}

counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
clusterWeightSum.indices.filter(clusterWeightSum(_) > 0)
.map(j => (j, (sums(j), clusterWeightSum(j)))).iterator
}.reduceByKey { case ((sum1, clusterWeightSum1), (sum2, clusterWeightSum2)) =>
axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
(sum1, clusterWeightSum1 + clusterWeightSum2)
}.collectAsMap()

if (iteration == 0) {
instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
instr.foreach(_.logNumExamples(countAccum.value))
instr.foreach(_.logSumOfWeights(collected.values.map(_._2).sum))
}

Review thread:
Contributor (Author): Don't have counts any more. Is it OK to remove this?
Member: Can you just log the sum of weights? It keeps the same info in the unweighted case, and it's still sort of meaningful as 'number of examples' in the weighted case.
Contributor: 1. I guess we need to add a new var count: Long to get the total count of the dataset, since in other algorithms like LinearSVC and LogisticRegression, instr.logNumExamples logs the unweighted count. 2. Since more and more algorithms support weightCol, I think we could add a new method like instr.logSumOfWeights.
Contributor (Author): Maybe leave the code this way for now and open a separate PR later on to add instr.logSumOfWeights and use it in all the algorithms that support weights?
Contributor: I am OK with adding the new instr.log method in another PR. Here I prefer to keep instr.logNumExamples logging the unweighted count, to stay in sync with the other algorithms.
Contributor (Author): Updated. Thanks! I also added logSumOfWeights. I will update the other algorithms that have weightCol once this PR is merged.

val newCenters = collected.mapValues { case (sum, count) =>
distanceMeasureInstance.centroid(sum, count)
val newCenters = collected.mapValues { case (sum, weightSum) =>
distanceMeasureInstance.centroid(sum, weightSum)
}

bcCenters.destroy()
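Finally, a rough sketch (also not part of the diff) of the unchanged RDD-level entry point, assuming a live SparkContext named sc and made-up data: as the earlier hunk shows, run(data) now simply pairs every point with a weight of 1.0 and delegates to runWithWeight, so existing unweighted callers behave exactly as before.

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Assumed SparkContext named sc; run(data) internally becomes
// data.map(point => (point, 1.0)) feeding the new runWithWeight path.
val data: RDD[Vector] = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
  Vectors.dense(9.0, 8.0), Vectors.dense(8.0, 9.0)
)).cache()

val model = new KMeans()
  .setK(2)
  .setMaxIterations(10)
  .run(data)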