
[SPARK-29967][ML][PYTHON] KMeans support instance weighting #26739

Closed
wants to merge 5 commits into from

Conversation

huaxingao (Contributor)

What changes were proposed in this pull request?

Add weight support in KMeans.

Why are the changes needed?

KMeans should support instance weighting.

Does this PR introduce any user-facing change?

Yes. It adds KMeans.setWeightCol.

How was this patch tested?

Unit tests.
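To illustrate what instance weighting means here, a minimal plain-Python sketch (not Spark's implementation) of a weighted centroid update: a point with weight 2.0 pulls the center exactly as if it were duplicated.

```python
# Hypothetical sketch, not Spark code: weighted mean of 1-D points,
# center = sum(w_i * x_i) / sum(w_i).
def weighted_centroid(points, weights):
    total_w = sum(weights)
    return sum(w * x for x, w in zip(points, weights)) / total_w

# Weight 2.0 on the point 4.0 ...
c_weighted = weighted_centroid([1.0, 4.0], [1.0, 2.0])
# ... behaves like duplicating that point with unit weights.
c_duplicated = weighted_centroid([1.0, 4.0, 4.0], [1.0, 1.0, 1.0])
assert c_weighted == c_duplicated  # both 3.0
```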

if (iteration == 0) {
instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
}

Contributor Author:

Don't have counts any more. Is it OK to remove this?

Member:

Can you just log the sum of weights? It keeps the same info in the unweighted case, and it's still sort of meaningful as 'number of examples' in the weighted case.

Contributor:

1. I guess we need to add a new var count: Long to get the total count of the dataset, since in other algs like LinearSVC and LogisticRegression, instr.logNumExamples logs the unweighted count;
2. Since more and more algs support weightCol, I think we may add a new method like instr.logSumOfWeights.

Contributor Author:

I guess maybe leave the code this way for now and open a separate PR later on to add method instr.logSumOfWeights and use it in all the algs that support weight?

Contributor:

I am OK with adding the new instr.log method in another PR.
Here I prefer to keep instr.logNumExamples logging the unweighted count, to keep it in sync with other algs.


Contributor Author:

Updated. Thanks!
I also added logSumOfWeights. I will update the other algs that have weightCol once this PR is merged.

@SparkQA

SparkQA commented Dec 2, 2019

Test build #114735 has finished for PR 26739 at commit f6b44d0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen (Member) left a comment:

Looks reasonable.

}

private[spark] def run(
data: RDD[Vector],
private[spark] def runWithweight(
Member:

Nit: runWithWeight


val clusterWeightSum = Array.fill(thisCenters.length)(0.0)

pointsAndWeights.foreach { case (point, weight) =>
var (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
Member:

Total nit, but you can use val and then pass cost * weight to costAccum.add
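The assignment step under review can be sketched in plain Python (a hypothetical illustration, not the Scala code in this PR): each point goes to its closest center, and the cost accumulator receives cost * weight, per the suggestion above to keep cost immutable and apply the weight only when accumulating.

```python
# Hypothetical sketch of weighted k-means assignment (not Spark code).
def find_closest(centers, point):
    """Return (index of closest center, squared distance to it)."""
    best_index, best_cost = 0, float("inf")
    for i, c in enumerate(centers):
        d2 = sum((ci - pi) ** 2 for ci, pi in zip(c, point))
        if d2 < best_cost:
            best_index, best_cost = i, d2
    return best_index, best_cost

def assign(centers, points_and_weights):
    cost_accum = 0.0
    cluster_weight_sum = [0.0] * len(centers)
    for point, weight in points_and_weights:
        best_center, cost = find_closest(centers, point)
        cost_accum += cost * weight          # weight applied at accumulation
        cluster_weight_sum[best_center] += weight
    return cost_accum, cluster_weight_sum
```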

@SparkQA

SparkQA commented Dec 3, 2019

Test build #114794 has finished for PR 26739 at commit d26d83d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@huaxingao
Contributor Author

@zhengruifeng

}

val instances: RDD[(OldVector, Double)] = dataset.select(
DatasetUtils.columnToVector(dataset, getFeaturesCol),
Contributor:

nit: why break this line? e.g.

dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w)
.rdd.map { ...

Contributor Author:

I will update the format.


@@ -100,6 +100,18 @@ private[spark] abstract class DistanceMeasure extends Serializable {
new VectorWithNorm(sum)
}

/**
Contributor:

Is the above def centroid(sum: Vector, count: Long): VectorWithNorm still needed?

Contributor Author:

Yes. It is still used by BisectingKMeans.

// clusterWeightSum is needed to calculate cluster center
// cluster center =
// sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
val clusterWeightSum = Array.fill(thisCenters.length)(0.0)
Contributor:

nit, Array.ofDim[Double](thisCenters.length) or new Array[Double](thisCenters.length)
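The cluster-center formula quoted in the code comment above can be sketched in plain Python (a hypothetical illustration, not the Scala implementation): each new center is the weighted sum of its assigned points divided by that cluster's clusterWeightSum. This sketch assumes every cluster received at least one point, so no division by zero occurs.

```python
# Hypothetical sketch of the weighted center update (not Spark code).
# center = sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
def update_centers(assignments, dim, k):
    """assignments: list of (cluster_index, point, weight) tuples."""
    sums = [[0.0] * dim for _ in range(k)]
    cluster_weight_sum = [0.0] * k
    for idx, point, weight in assignments:
        for d in range(dim):
            sums[idx][d] += point[d] * weight
        cluster_weight_sum[idx] += weight
    # divide each cluster's weighted sum by its total weight
    return [[s / cluster_weight_sum[i] for s in sums[i]] for i in range(k)]
```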

@SparkQA

SparkQA commented Dec 4, 2019

Test build #114831 has finished for PR 26739 at commit f55917d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Dec 4, 2019

Test build #114869 has finished for PR 26739 at commit c664833.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

}.collectAsMap()

if (iteration == 0) {
instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
instr.foreach(_.logNumExamples(data.count()))
Contributor:

Nit, what about using a sc.longAccumulator to accumulate the count? like costAccum
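The accumulator idea suggested above can be sketched in plain Python (not Spark's sc.longAccumulator): fold the example count into the same pass that accumulates the cost, so no separate data.count() pass over the data is needed. The cost_of function here is a hypothetical stand-in for the closest-center cost.

```python
# Hypothetical sketch: count examples in the same pass as the cost.
def one_pass_stats(points_and_weights, cost_of):
    cost_accum = 0.0
    count_accum = 0  # plays the role of a sc.longAccumulator
    for point, weight in points_and_weights:
        cost_accum += cost_of(point) * weight
        count_accum += 1
    return cost_accum, count_accum
```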

Contributor Author:

Updated. Thanks!

@SparkQA

SparkQA commented Dec 9, 2019

Test build #115046 has finished for PR 26739 at commit 2e9f683.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen srowen closed this in 1cac9b2 Dec 10, 2019
@srowen
Member

srowen commented Dec 10, 2019

Merged to master

@huaxingao
Contributor Author

Thanks! @srowen @zhengruifeng
