[SPARK-29967][ML][PYTHON] KMeans support instance weighting
huaxingao committed Dec 2, 2019
1 parent e04a634 commit f6b44d0
Showing 5 changed files with 323 additions and 46 deletions.
35 changes: 28 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -31,17 +31,18 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion

/**
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure {
with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure with HasWeightCol {

/**
* The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -313,12 +314,32 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr =>
transformSchema(dataset.schema, logging = true)

val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
col($(weightCol)).cast(DoubleType)
} else {
lit(1.0)
}

val instances: RDD[(OldVector, Double)] = dataset.select(
DatasetUtils.columnToVector(dataset, getFeaturesCol),
w).rdd.map {
case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight)
}

if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
@@ -327,7 +348,7 @@ class KMeans @Since("1.5.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
maxIter, seed, tol)
maxIter, seed, tol, weightCol)
val algo = new MLlibKMeans()
.setK($(k))
.setInitializationMode($(initMode))
@@ -336,7 +357,7 @@ class KMeans @Since("1.5.0") (
.setSeed($(seed))
.setEpsilon($(tol))
.setDistanceMeasure($(distanceMeasure))
val parentModel = algo.run(instances, Option(instr))
val parentModel = algo.runWithweight(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset),
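The change above adds `HasWeightCol` to `KMeansParams`, a `setWeightCol` setter, and a weighted `(features, weight)` RDD inside `fit`. As a minimal usage sketch of the new parameter (not part of the commit: the SparkSession `spark`, the column names, and the toy data are assumptions):

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors

// Hypothetical weighted training data: (features, weight). `spark` is an existing SparkSession.
val df = spark.createDataFrame(Seq(
  (Vectors.dense(0.0, 0.0), 1.0),
  (Vectors.dense(0.1, 0.1), 1.0),
  (Vectors.dense(9.0, 9.0), 5.0),  // heavier instance pulls its cluster center toward it
  (Vectors.dense(9.2, 9.2), 1.0)
)).toDF("features", "weight")

val kmeans = new KMeans()
  .setK(2)
  .setFeaturesCol("features")
  .setWeightCol("weight")  // new in this commit; omit it to treat every weight as 1.0
  .setSeed(1L)

val model = kmeans.fit(df)
model.clusterCenters.foreach(println)

Omitting `setWeightCol` keeps the previous behavior, since every instance then defaults to weight 1.0.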
mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
@@ -84,8 +84,8 @@ private[spark] abstract class DistanceMeasure extends Serializable {
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
axpy(1.0, point.vector, sum)
def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = {
axpy(weight, point.vector, sum)
}

/**
@@ -100,6 +100,18 @@ private[spark] abstract class DistanceMeasure extends Serializable {
new VectorWithNorm(sum)
}

/**
* Returns a centroid for a cluster given its `sum` vector and the weightSum of points.
*
* @param sum the `sum` for a cluster
* @param weightSum the weightSum of points in the cluster
* @return the centroid of the cluster
*/
def centroid(sum: Vector, weightSum: Double): VectorWithNorm = {
scal(1.0 / weightSum, sum)
new VectorWithNorm(sum)
}

/**
* Returns two new centroids symmetric to the specified centroid applying `noise` with the
* specified `level`.
@@ -249,9 +261,9 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
override def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = {
assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.")
axpy(1.0 / point.norm, point.vector, sum)
axpy(weight / point.norm, point.vector, sum)
}

/**
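For intuition on the weighted `updateClusterSum`/`centroid` pair added above, here is a small standalone Scala sketch (illustrative only, no Spark dependencies) of the same arithmetic: accumulate `weight * point` into `sum` as `axpy(weight, point.vector, sum)` does, track the cluster's `weightSum`, and divide at the end as `scal(1.0 / weightSum, sum)` does. With all weights equal to 1.0 this reduces to the old count-based mean.

// Standalone sketch of the weighted-centroid arithmetic:
//   sum += weight * point        (updateClusterSum / axpy)
//   weightSum += weight
//   centroid = sum / weightSum   (centroid / scal)
object WeightedCentroidSketch {
  def centroid(points: Seq[Array[Double]], weights: Seq[Double]): Array[Double] = {
    require(points.nonEmpty && points.length == weights.length)
    val dims = points.head.length
    val sum = new Array[Double](dims)
    var weightSum = 0.0
    points.zip(weights).foreach { case (p, w) =>
      for (i <- 0 until dims) sum(i) += w * p(i)
      weightSum += w
    }
    sum.map(_ / weightSum)
  }

  def main(args: Array[String]): Unit = {
    val pts = Seq(Array(0.0, 0.0), Array(2.0, 2.0))
    println(centroid(pts, Seq(1.0, 1.0)).mkString(", "))  // 1.0, 1.0 -- plain mean
    println(centroid(pts, Seq(1.0, 3.0)).mkString(", "))  // 1.5, 1.5 -- pulled toward the heavier point
  }
}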
61 changes: 36 additions & 25 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -209,11 +209,14 @@ class KMeans private (
*/
@Since("0.8.0")
def run(data: RDD[Vector]): KMeansModel = {
run(data, None)
val instances: RDD[(Vector, Double)] = data.map {
case (point) => (point, 1.0)
}
runWithweight(instances, None)
}

private[spark] def run(
data: RDD[Vector],
private[spark] def runWithweight(
data: RDD[(Vector, Double)],
instr: Option[Instrumentation]): KMeansModel = {

if (data.getStorageLevel == StorageLevel.NONE) {
@@ -222,12 +225,15 @@ class KMeans private (
}

// Compute squared norms and cache them.
val norms = data.map(Vectors.norm(_, 2.0))
val zippedData = data.zip(norms).map { case (v, norm) =>
new VectorWithNorm(v, norm)
val norms = data.map { case (v, _) =>
Vectors.norm(v, 2.0)
}

val zippedData = data.zip(norms).map { case ((v, w), norm) =>
(new VectorWithNorm(v, norm), w)
}
zippedData.persist()
val model = runAlgorithm(zippedData, instr)
val model = runAlgorithmWithWeight(zippedData, instr)
zippedData.unpersist()

// Warn at the end of the run as well, for increased visibility.
@@ -241,8 +247,8 @@ class KMeans private (
/**
* Implementation of K-Means algorithm.
*/
private def runAlgorithm(
data: RDD[VectorWithNorm],
private def runAlgorithmWithWeight(
data: RDD[(VectorWithNorm, Double)],
instr: Option[Instrumentation]): KMeansModel = {

val sc = data.sparkContext
@@ -251,14 +257,17 @@ class KMeans private (

val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)

val dataVectorWithNorm = data.map(d => d._1)
val weights = data.map(d => d._2)

val centers = initialModel match {
case Some(kMeansCenters) =>
kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
initRandom(dataVectorWithNorm)
} else {
initKMeansParallel(data, distanceMeasureInstance)
initKMeansParallel(dataVectorWithNorm, distanceMeasureInstance)
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
@@ -278,30 +287,32 @@ class KMeans private (
val bcCenters = sc.broadcast(centers)

// Find the new centers
val collected = data.mapPartitions { points =>
val collected = data.mapPartitions { pointsAndWeights =>
val thisCenters = bcCenters.value
val dims = thisCenters.head.vector.size

val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
val counts = Array.fill(thisCenters.length)(0L)

points.foreach { point =>
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
// clusterWeightSum is needed to calculate cluster center
// cluster center =
// sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
val clusterWeightSum = Array.fill(thisCenters.length)(0.0)

pointsAndWeights.foreach { case (point, weight) =>
var (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
cost *= weight
costAccum.add(cost)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
counts(bestCenter) += 1
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter), weight)
clusterWeightSum(bestCenter) += weight
}

counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
clusterWeightSum.indices.filter(clusterWeightSum(_) > 0)
.map(j => (j, (sums(j), clusterWeightSum(j)))).iterator
}.reduceByKey { case ((sum1, clusterWeightSum1), (sum2, clusterWeightSum2)) =>
axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
(sum1, clusterWeightSum1 + clusterWeightSum2)
}.collectAsMap()

if (iteration == 0) {
instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
}

val newCenters = collected.mapValues { case (sum, count) =>
distanceMeasureInstance.centroid(sum, count)
}
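After this refactoring, the public `run(data: RDD[Vector])` simply maps every point to weight 1.0 and delegates to `runWithweight`, so existing callers of the RDD-based API are unaffected. A hedged sketch of that unchanged entry point (the SparkContext `sc` and the toy data are assumptions, not part of the commit):

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Unweighted RDD input; with this commit it is mapped internally to (vector, 1.0)
// pairs and routed through the weighted implementation. `sc` is an existing SparkContext.
val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0),
  Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0)
)).cache()

val model = new KMeans()
  .setK(2)
  .setMaxIterations(10)
  .setSeed(1L)
  .run(data)  // equivalent to every instance carrying weight 1.0

println(model.clusterCenters.mkString("; "))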
