
[SPARK-18409][ML] LSH approxNearestNeighbors should use approxQuantile instead of sort #26415

Closed
wants to merge 5 commits

Conversation

huaxingao
Contributor

What changes were proposed in this pull request?

LSHModel.approxNearestNeighbors sorts the full dataset on the hashDistance in order to find a threshold. This PR uses approxQuantile instead.
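A minimal sketch of the idea (illustrative only, assuming the names used in LSHModel such as modelDataset, distCol, hashDistUDF, and some relativeError value; this is not the exact patch):

import org.apache.spark.sql.functions.col

// Attach the hash distance to every row, as before.
val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistUDF(col($(outputCol))))
// Quantile that should cover roughly numNearestNeighbors rows.
val quantile = numNearestNeighbors.toDouble / modelDataset.count()
// Estimate the threshold with approxQuantile instead of sorting the whole dataset;
// relativeError trades accuracy for speed.
val hashThreshold = modelDatasetWithDist.stat
  .approxQuantile(distCol, Array(quantile), relativeError)
// Keep only the candidates at or below the estimated threshold.
val candidates = modelDatasetWithDist.filter(col(distCol) <= hashThreshold(0))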

Why are the changes needed?

To improve performance.

Does this PR introduce any user-facing change?

Yes.
Changed LSH to make it extend HasRelativeError
LSH and LSHModel have new APIs setRelativeError/getRelativeError

How was this patch tested?

Existing tests. Also added a couple of doctests in Python to test the newly added getRelativeError.

val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
val quantile = numNearestNeighbors.toDouble / modelDataset.count()
val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistUDF(col($(outputCol))))
Member

hashDistCol was defined earlier as hashDistUDF(col($(outputCol))). We can use it.

Contributor Author

Thanks for your comment. I will update the code.


// Filter the dataset where the hash value is less than the threshold.
modelDataset.filter(hashDistCol <= hashThreshold)
modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
Member

Since the threshold is approximate, we still need to put a limit here so we return no more than numNearestNeighbors items.

Contributor Author

I guess maybe we don't need to put a limit here? There is a limit later in the code.

    val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
    modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)

Member

Hm, I guess it's also possible we get too few nearest neighbors. This is probably especially likely when the quantile is small. It may be a good idea to request more nearest neighbors than needed, to make the likelihood of returning too few pretty small.

On that note, is it meaningful to expose relativeError to callers? They want a number of nearest neighbors, not more or less. This is a pretty internal implementation detail. How about simply setting some fixed value, plus oversampling, which should virtually always give enough results yet gets some efficiency gains? I don't know what that value is; might bear a little testing.

Member

@huaxingao ah, I saw it. Yeah, I think we don't need another limit. We probably do need oversampling to get more neighbors, in case of what @srowen said.

Contributor Author

@srowen Thanks for your review. I will remove the exposure of relativeError. Will also do some tests to find out a good number to use for oversampling.

@SparkQA

SparkQA commented Nov 6, 2019

Test build #113340 has finished for PR 26415 at commit 6d3f79e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class _LSHParams(HasInputCol, HasOutputCol, HasRelativeError):

@@ -71,6 +72,10 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group expertSetParam */
@Since("3.0.0")
def setRelativeError(value: Double): this.type = set(relativeError, value)
Member

I agree with @srowen. This relative error does not seem to mean much to users in LSH. If they want more or fewer NN, they can change numNearestNeighbors.

@zhengruifeng
Contributor

zhengruifeng commented Nov 9, 2019

Maybe we can add a new param like method, supporting several options:
1, exact: the existing method.
2, approx: using approxQuantile.
3, stack: also an exact method, using org.apache.spark.util.BoundedPriorityQueue or org.apache.spark.ml.recommendation.TopByKeyAggregator. It only supports a relatively small numNearestNeighbors (maybe < 1000; the threshold is related to the RAM config) to avoid OOM. numNearestNeighbors is usually a small number, so this should be much faster than approaches 1 and 2 (see the sketch below).
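A minimal sketch of option 3 (assuming the hash distances for the query key are already available as an RDD of (id, distance) pairs; the names here are illustrative, not actual LSHModel code):

import org.apache.spark.rdd.RDD

// takeOrdered keeps a bounded priority queue of at most k elements per partition and
// merges the per-partition winners on the driver: no global sort and no count job,
// but k must be small enough that k results fit comfortably in driver memory.
def exactTopK(distances: RDD[(Long, Double)], k: Int): Array[(Long, Double)] =
  distances.takeOrdered(k)(Ordering.by[(Long, Double), Double](_._2))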

@zhengruifeng
Contributor

In some cases (e.g. small datasets), we may want an exact result, so we may need to keep the current method.

@srowen
Member

srowen commented Nov 9, 2019

I think that's too much complexity for the caller, and changes the API. How about: start with a quantile that should yield 2x the number of results. Use a fixed relative error that still achieves some good speedup over a sort. While not enough results, double the quantile.

I guess we need to check, if not already, that there are more items than nearest neighbors to begin with (i.e. can't ask for 10 nearest neighbors from 8 items). Also, cap quantile at 1 (in which case return all items anyway)
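A rough sketch of that loop (illustrative only; assumes count was computed once and modelDatasetWithDist is a DataFrame that already carries distCol):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

val relativeError = 0.05  // fixed, loose enough to still beat a full sort (assumed value)
// Start by over-requesting: a quantile that should yield about 2x the neighbors.
var quantile = math.min(1.0, 2.0 * numNearestNeighbors.toDouble / count)
var candidates: DataFrame = modelDatasetWithDist
var enough = false
while (!enough) {
  val threshold = modelDatasetWithDist.stat
    .approxQuantile(distCol, Array(quantile), relativeError)(0)
  candidates = modelDatasetWithDist.filter(col(distCol) <= threshold)
  // Stop when enough rows made it through, or when the quantile is already capped at 1.
  enough = quantile >= 1.0 || candidates.count() >= numNearestNeighbors
  quantile = math.min(1.0, quantile * 2.0)
}
// The exact-distance sort + limit(numNearestNeighbors) then runs only on `candidates`.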

@huaxingao
Contributor Author

Sounds good to me. @srowen

@@ -112,7 +113,8 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
numNearestNeighbors: Int,
singleProbe: Boolean,
distCol: String): Dataset[_] = {
require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
require(numNearestNeighbors > 0 && numNearestNeighbors <= dataset.count(), "The number of" +
Member

This count() is expensive. Hm, I guess we have to compute it at least once, but save it. Otherwise it is computed something like 3 times per loop below.

}
var quantile = requestedNum.toDouble / modelDataset.count()
var hashThreshold = modelDatasetWithDist.stat
.approxQuantile(distCol, Array(quantile), 0.001)
Member

Is 0.001 enough to give a good speedup? This can be pretty loose.

.approxQuantile(distCol, Array(quantile), 0.001)

// Filter the dataset where the hash value is less than the threshold.
filtered = modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
Member

I don't think this is necessary; just ask for progressively higher quantiles of the original dataset. You don't want to both double the number requested and filter out things that were under the previous quantile, I think? Just one, and doubling the number requested is much cheaper.

Contributor Author

Seems to me that I have to filter to find out whether I can get enough nearest neighbors. If not, I go back into the loop to double the quantile.

I am debating whether I should continue this PR. The purpose of this PR is to improve performance. If the first round of the loop doesn't get enough nearest neighbors and we have to go through the loop multiple times, the performance could be worse than the original code.

In the doc of approxNearestNeighbors, it says "Given a large dataset and an item, approximately find at most k items which have the closest distance to the item." If this is true, then I guess we can just use a quantile that should yield 2x the number of results; if we get fewer than k elements, that's OK. However, the original implementation returns exactly k elements. I am not sure if we can change the original behavior.

Member

Nevermind my comment; you're not reusing that filtered subset in the next loop, oops. This is correct.
Yes you have to count the filtered result each time.

Certainly the hope is that the first pass will yield enough neighbors, and the loop is just there as a fallback. I think it can be a win if the relative error tolerance is loose enough that 1 or even 2 approx quantile checks is faster than a full sort, but, I don't know how it plays out at scale.

You raise a good point; the docs below do say 'at most k elements', although the current implementation will return exactly k (assuming there are at least k points in the input). It repeats that twice. Hm. I'd also support just making one pass and picking a larger multiple of the requested number of neighbors. But I don't mind the approach you have here (modulo a few optimizations above).

We could also optimize the requestedNum > modelDataset.count() case by just returning the whole input rather than continuing with another pass.

I wonder if it's reasonable to construct a simple synthetic large input in a test, and test out whether it seems to be faster at some scale, and how you have to set the relative error to get that speedup, and how likely it is that it needs even a second loop.

Contributor Author

I thought it over. It seems to me that it's safe not to have the second loop, and we might not need to double the quantile at the beginning.

If the relative error is set to zero, the exact quantile is computed. Filtering at this exact quantile gives exactly the k items of numNearestNeighbors we ask for. Of course, this could be very expensive. If users want to trade off accuracy for performance, they can increase the relative error. In that case, using the approximate quantile as the threshold may return fewer rows than numNearestNeighbors, but this is OK since the docs say "approximately find at most k items". I tried to do some tests, but it seems hard to find a universal relative error that works for all cases, since the approximate quantile has the range

 floor((p - err) * N) <= rank(x) <= ceil((p + err) * N)

In the extreme unlucky case, the approximate quantile has rank (p - err) * N and the elements that meet the requirement happen to be in the range (p - err) * N to p * N. It seems to me the appropriate value of the relative error depends on the distribution of the dataset, and users may need to adjust the relative error for their own dataset.
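For illustration, with N = 160000 and M = 40000 (the sizes used in the test further down), p = M/N = 0.25; with err = 0.05 the returned value x is only guaranteed to satisfy

 floor((0.25 - 0.05) * 160000) <= rank(x) <= ceil((0.25 + 0.05) * 160000)

i.e. 32000 <= rank(x) <= 48000, so filtering at x may return as few as 32000 of the 40000 requested items.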

So I guess 1) we don't need the loop, and 2) we expose the relative error to the user?
I am OK with either doubling the quantile or not doubling it.

Contributor Author

I take back my last sentence. We should double the quantile to make it safer.

Member

To do it in one pass, we need (p - err) * N >= M, where M is the number of nearest neighbors. We'd choose the quantile p to be some multiple k of the desired quantile M/N. I think that works out to needing k >= err * N / M + 1. So maybe it's a matter of fixing some err that gets a good speedup, like 0.1, and then just picking quantile p = err + M / N. Makes sense, really: you have to increase the desired quantile by err.

How about something like that as a one-pass solution? It also means checking that if p > 1 we just fall back to sort + take.
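A one-pass sketch of that (illustrative names, assuming count, distCol and modelDatasetWithDist as above):

import org.apache.spark.sql.functions.col

val relativeError = 0.05  // assumed fixed err
// By construction (p - err) * N >= M, so the guaranteed lower bound on the
// threshold's rank already covers the M = numNearestNeighbors requested rows.
val p = numNearestNeighbors.toDouble / count + relativeError
val candidates =
  if (p > 1.0) {
    // No headroom for the approximate quantile: fall back to exact sort + take.
    modelDatasetWithDist.sort(distCol).limit(numNearestNeighbors)
  } else {
    val threshold = modelDatasetWithDist.stat
      .approxQuantile(distCol, Array(p), relativeError)(0)
    modelDatasetWithDist.filter(col(distCol) <= threshold)
  }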

Contributor Author

Sounds good.

I did a small test to try to find a good value of relative error to use. I used the existing BucketedRandomProjectionLSHSuite but made the dataset bigger: one dataset with incremental values and one with random values.

val data1 = {
  for (i <- -200 until 200; j <- -200 until 200) yield Vectors.dense(i * 10, j * 10)
}
val dataset1 = spark.createDataFrame(data1.map(Tuple1.apply)).toDF("keys")

val data2 = {
  for (i <- -200 until 200; j <- -200 until 200) yield Vectors.dense(Random.nextInt(), Random.nextInt())
}
val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")

So the dataset count is 160000 and I asked for 40000 nearest neighbors. I tested relative error 0.001, 0.005, 0.01, 0.05, 0.1 and 0.2 but didn't see much change in performance. Not sure if it is because the dataset is too small.

Dataset1 with incremental values

relative error   1st run   2nd run   3rd run   4th run   5th run   average
0.001            7.61s     6.56s     6.39s     7.32s     7.49s     6.998s
0.005            6.39s     6.44s     6.62s     6.39s     7.54s     6.67s
0.01             6.56s     6.38s     7.34s     6.58s     6.68s     6.71s
0.05             6.51s     6.24s     7.24s     6.34s     6.54s     6.57s
0.1              6.28s     6.20s     6.34s     6.68s     7.07s     6.51s
0.2              6.39s     6.21s     6.25s     6.22s     6.30s     6.27s

Dataset2 with random values

relative error   1st run   2nd run   3rd run   4th run   5th run   average
0.001            7.66s     6.77s     6.75s     7.78s     6.64s     7.11s
0.005            6.57s     6.61s     6.75s     7.42s     6.60s     6.79s
0.01             6.68s     7.44s     6.25s     6.69s     7.48s     6.91s
0.05             6.59s     6.54s     6.75s     6.62s     6.63s     6.62s
0.1              7.73s     6.58s     6.61s     6.68s     6.55s     6.83s
0.2              6.61s     6.62s     6.54s     6.51s     6.59s     6.57s

Seems to me that it may not be good to have a fixed value for the relative error. For example, 0.05 might be a good relative error for getting 40000 nearest neighbors out of 160000 rows, but it's too big for getting 400 nearest neighbors out of 160000 rows. I guess I will pick err = 0.2 * M / N. Since p = err + M / N, we actually have p = 1.2 * M / N. Hope this makes sense.
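Plugging in the numbers: for M = 40000 and N = 160000 this gives err = 0.2 * 0.25 = 0.05 and p = 0.3; for M = 400 and N = 160000 it gives err = 0.0005 and p = 0.003, so the relative error scales down with the requested fraction instead of staying fixed.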

Member

BTW I wonder how this compares to sort + take?

I think the only issue with making err a function of M/N is that you maybe won't get much speedup when M/N is small, and that's the common case. But I guess you're saying the speed difference isn't that big. The relative error could be smaller than 0.001 though... like 10 nearest neighbors out of 1M = 0.000001. I wonder if that's notably slower?

The downside to err + M/N is that you filter in a lot more elements, although you subsequently sort and take anyway; they won't come back to the driver.

It is a good question whether it overall speeds things up but I think it will if the LSH has a not-tiny relative error.

Contributor Author

It is much faster than sort + take.
So I guess we will pick 0.01 as relative error? Or 0.05?

Member

Sure, let's start with 0.05. Doesn't look like much gain after that.

@SparkQA

SparkQA commented Nov 14, 2019

Test build #113811 has finished for PR 26415 at commit a06c6f4.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 19, 2019

Test build #114064 has finished for PR 26415 at commit b56b489.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Member

@srowen srowen left a comment

Yep that looks good pending tests. Just a few minor comments.

// Filter the dataset where the hash value is less than the threshold.
modelDataset.filter(hashDistCol <= hashThreshold)
// Compute threshold to get around k elements.
var approxQuantile = numNearestNeighbors.toDouble / count + 0.05 // relative error = 0.05
Member

Nit: can this be a val? Also, do you want a val relativeError = 0.05 and to reuse it?
Might also be worth a short comment saying that the "err + M/N" quantile should be guaranteed to give enough neighbors.
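Something like this, as a sketch of the suggested tweak:

// The "err + M/N" quantile guarantees (p - err) * N >= numNearestNeighbors,
// so the filtered set should contain at least numNearestNeighbors rows.
val relativeError = 0.05
val approxQuantile = numNearestNeighbors.toDouble / count + relativeError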

@SparkQA

SparkQA commented Nov 19, 2019

Test build #114109 has finished for PR 26415 at commit ddc278b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen srowen closed this in 56a65b9 Nov 20, 2019
@srowen
Member

srowen commented Nov 20, 2019

Merged to master

@huaxingao
Contributor Author

Thank you all for the help!

@huaxingao huaxingao deleted the spark-18409 branch November 20, 2019 16:01
@@ -112,7 +112,9 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
numNearestNeighbors: Int,
singleProbe: Boolean,
distCol: String): Dataset[_] = {
require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
val count = dataset.count()
Contributor

Sorry for the late reply.
1, Since approxNearestNeighbors queries with only one key, it is likely to be called many times in practice. Is val count = dataset.count() too expensive to repeat across calls? Can it be precomputed somewhere?
2, Do we need numNearestNeighbors <= count? Refer to Scala's and the RDD's take behavior:

scala> Array(1).take(2)
res1: Array[Int] = Array(1)
scala> val rdd = sc.range(0, 1)
rdd: org.apache.spark.rdd.RDD[Long] = MapPartitionsRDD[1] at range at <console>:24

scala> rdd.count
res0: Long = 1

scala> rdd.take
take   takeAsync   takeOrdered   takeSample

scala> rdd.take(3)
res1: Array[Long] = Array(0)

Member

Fair point, but it's hardly the most expensive operation here, before or after. I'd expect a caller has to cache the input for multiple calls to make sense anyway. I think the count + approxQuantile still beats sort + take.

I also take the point about take() but relaxing the requirement won't remove the need for count() here. Does it make more sense semantically? You're right it would be more consistent with the original implementation. Hm. I guess I'm neutral on it, but it's a valid question.
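For example, a hypothetical caller issuing several queries (dataset, model, key1 and key2 are illustrative names):

// Cache the input once so that repeated approxNearestNeighbors calls (and their
// count() and approxQuantile jobs) reuse the materialized data.
val cached = dataset.cache()
val neighbors1 = model.approxNearestNeighbors(cached, key1, 10)
val neighbors2 = model.approxNearestNeighbors(cached, key2, 10)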

Contributor

According to the doc
"@param numNearestNeighbors The maximum number of nearest neighbors."
it implies that the output need not contain exactly numNearestNeighbors items.

1, If we do not require numNearestNeighbors <= count, then when singleProbe is true, this count job can be avoided.
2, If we find that numNearestNeighbors >= count, we can directly return modelDatasetWithDist.

Member

It implies it could have fewer, yeah. This would be the only case to return fewer, yes, and you'd just return the input. I don't think singleProbe is used (and could be removed). But the count() is still needed below in any event.

I mean, if the caller cares, they can check count() vs number of neighbors before calling it a bunch of times anyway.


// Filter the dataset where the hash value is less than the threshold.
modelDataset.filter(hashDistCol <= hashThreshold)
// Compute threshold to get around k elements.
Contributor

Can we add a check on numNearestNeighbors here in the future?
1, If it is a small number, using TopByKeyAggregator will skip the count job above and return an exact threshold;
2, otherwise, use approxQuantile.

Similar logic can be found in PCA & GMM: with small numFeatures & k, the impls are different from those with large numbers.

Member

From ALS? OK I could see that being an efficient alternative, to get top k per partition and merge them to a final top k -- if k isn't big. The existing impl works on keyed data but a simplified version would work here. I think that could be a valid further improvement.

Contributor

Yes, this logic is quite similar to ALS.

@huaxingao
Contributor Author

Thanks for the comments. I will have a follow-up for the improvement.
