[SPARK-18408][ML] API Improvements for LSH #15874
Conversation
…ni/spark into SPARK-18334-yunn-minhash-bug
Test build #68601 has finished for PR 15874 at commit
Thanks @Yunni, I can take a look at this today. I would prefer to separate the addition of "AND-amplification" into another PR, since I believe we'd like to get the other changes into 2.1, whereas the "AND-amplification" is not as urgent. That will make this easier to review and get merged in a timely manner. I'm open to other arguments, though.
Thanks, @sethah. I have reverted the "AND-amplification" related changes. PTAL.
Test build #68625 has finished for PR 15874 at commit
Can you please add "[ML]" to the PR title?
I didn't make it through a full pass, but I'm leaving these comments here for now.
@@ -144,12 +152,12 @@ class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed {
}

@Since("2.1.0")
object MinHash extends DefaultParamsReadable[MinHash] {
object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
// A large prime smaller than sqrt(2^63 − 1)
private[ml] val prime = 2038074743
We typically use all caps for constants like these. I prefer MinHashLSH.HASH_PRIME or MinHashLSH.PRIME_MODULUS.
Done.
@@ -106,22 +123,24 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
* transformed data when necessary.
*
* This method implements two ways of fetching k nearest neighbors:
* - Single Probing: Fast, return at most k elements (Probing only one buckets)
* - Multiple Probing: Slow, return exact k elements (Probing multiple buckets close to the key)
* - Single-probe: Fast, return at most k elements (Probing only one buckets)
"Probing only one bucket"
Done.
* - Single-probe: Fast, return at most k elements (Probing only one buckets)
* - Multi-probe: Slow, return exact k elements (Probing multiple buckets close to the key)
*
* Currently it is made private since more discussion is needed for Multi-probe
I don't understand the point here. Are you trying to make the approxNearestNeighbors method completely private? There is still a public overload of this method - which now shows up as the only method in the docs and just says "overloaded method for approxNearestNeighbors". This doc above does not show up.
As a general rule, we should always generate and closely inspect the docs to make sure that they are what we intend and that they make sense from an end user's perspective.
Sorry, I meant to make the approxNearestNeighbors public at line 163 of LSH.scala. I copied all the docs there except the single-probe/multi-probe difference.
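For intuition, the single-probe strategy under discussion can be sketched outside of Spark (illustrative code with hypothetical names, not the PR's implementation): it only considers points that share a hash bucket with the key in at least one table, which is exactly why it can return fewer than k results.

```python
# Illustrative sketch (hypothetical names, not the Spark implementation).
# Single-probe kNN: keep only candidates whose hash collides with the
# key's hash in at least one table, then sort those by true distance.
def single_probe_knn(points, hashes, key_hashes, key, k, dist):
    candidates = [p for p, h in zip(points, hashes)
                  if any(hp == hk for hp, hk in zip(h, key_hashes))]
    # May return fewer than k rows if too few candidates collide.
    return sorted(candidates, key=lambda p: dist(p, key))[:k]

points = [1.0, 2.0, 8.0]
hashes = [(0, 1), (0, 2), (4, 5)]   # one hash value per table
key, key_hashes = 1.5, (0, 9)
out = single_probe_knn(points, hashes, key_hashes, key, 2,
                       lambda a, b: abs(a - b))
print(out)  # only the two points colliding in table 0 are considered
```

A multi-probe variant would additionally examine buckets close to the key's buckets, trading speed for a guaranteed k results.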
@@ -35,26 +35,26 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
/**
* Param for the dimension of LSH OR-amplification.
*
* In this implementation, we use LSH OR-amplification to reduce the false negative rate. The
* higher the dimension is, the lower the false negative rate.
* LSH OR-amplification can be used to reduce the false negative rate. The higher the dimension
We are still using the word "dimension" here. It might also be useful to add that reducing false negatives comes at the cost of added computation. How does this sound?
* Param for the number of hash tables used in LSH OR-amplification.
*
* LSH OR-amplification can be used to reduce the false negative rate. Higher values for this
* param lead to a reduced false negative rate, at the expense of added computational complexity.
This sounds good.
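For intuition (a numeric sketch, not part of the patch): with OR-amplification over L hash tables, a pair whose single-hash collision probability is p collides in at least one table with probability 1 - (1 - p)^L, so the false negative rate (1 - p)^L shrinks as L grows while the per-row hashing cost grows linearly in L.

```python
# Sketch (not Spark code): how OR-amplification with L hash tables
# lowers the false negative rate for a pair whose single-hash
# collision probability is p.
def or_amplified_collision_prob(p, num_hash_tables):
    """P(collide in >= 1 of L independent tables) = 1 - (1 - p)^L."""
    return 1.0 - (1.0 - p) ** num_hash_tables

for num_tables in (1, 5, 10):
    print(num_tables, or_amplified_collision_prob(0.5, num_tables))
# The false negative rate (1 - p)^L falls as L grows, at the cost of
# computing (and storing) L hash values per row.
```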
@@ -66,10 +66,10 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
self: T =>

/**
* The hash function of LSH, mapping a predefined KeyType to a Vector
* The hash function of LSH, mapping an input feature to multiple vectors
"mapping an input feature vector to multiple hash vectors."
Fixed
@@ -46,21 +42,23 @@ import org.apache.spark.sql.types.StructType
@Since("2.1.0")
class MinHashModel private[ml] (
override val uid: String,
@Since("2.1.0") val numEntries: Int,
@Since("2.1.0") val randCoefficients: Array[Int])
@Since("2.1.0") private[ml] val numEntries: Int,
No @Since tags for private values.
Removed.
val rand = new Random($(seed))
val numEntry = inputDim * 2
val randCoofs: Array[Int] = Array.fill($(outputDim))(1 + rand.nextInt(MinHash.prime - 1))
val randCoofs: Array[Int] = Array.fill($(numHashTables))(1 + rand.nextInt(MinHashLSH.prime - 1))
randCoefs
Fixed.
// Since it's generated by hashing, it will be a pair of dense vectors.
x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
// TODO: This hashDistance function is controversial. Requires more discussion.
x.zip(y).map(vectorPair =>
At this point, I'm quite unsure, but this does not look to me like what was discussed here. @jkbradley Can you confirm this is what you wanted?
Since it's still under discussion, I am not sure which hashDistance to leave in the code. Do you just want me to change to the hashDistance @jkbradley suggested?
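For readers following along: the diff above is truncated, so here is a rough sketch of the shape being discussed (the per-table distance below is a stand-in, not the exact expression from the PR). The new hashDistance takes two arrays of hash vectors, one per hash table, and aggregates by taking the minimum per-table distance.

```python
# Rough sketch of the new hashDistance's shape: x and y are lists of
# hash vectors (one per hash table). The per-table distance here is a
# simple stand-in; the actual expression is cut off in the diff above.
def hash_distance(x, y):
    return min(
        max(abs(a - b) for a, b in zip(vx, vy))  # stand-in per-table distance
        for vx, vy in zip(x, y)
    )

x = [[1.0], [4.0], [9.0]]
y = [[1.0], [6.0], [3.0]]
print(hash_distance(x, y))  # 0.0: the pair collides in the first table
```

Taking the min over tables mirrors OR-amplification: the pair counts as close if any table hashes them together.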
val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
transform(dataset)
} else {
dataset.toDF()
}
modelDataset.select(
struct(col("*")).as(inputName),
explode(vectorToMap(col($(outputCol)))).as(explodeCols))
struct(col("*")).as(inputName), posexplode(col($(outputCol))).as(explodeCols))
Well here's a fun one. When I run this test:
test("memory leak test") {
val numDim = 50
val data = {
for (i <- 0 until numDim; j <- Seq(-2, -1, 1, 2))
yield Vectors.sparse(numDim, Seq((i, j.toDouble)))
}
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
// Project the 50-dimensional Euclidean data using 10 hash tables
val brp = new BucketedRandomProjectionLSH()
.setNumHashTables(10)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(2.5)
.setSeed(12345)
val model = brp.fit(df)
val joined = model.approxSimilarityJoin(df, df, Double.MaxValue, "distCol")
joined.show()
}
I get the following error:
[info] - BucketedRandomProjectionLSH with high dimension data: test of LSH property *** FAILED *** (7 seconds, 568 milliseconds)
[info] org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 4.0 failed 1 times, most recent failure: Lost task 0.0 in stage 4.0 (TID 205, localhost, executor driver): org.apache.spark.SparkException: Managed memory leak detected; size = 33816576 bytes, TID = 205
[info] at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:295)
[info] at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
[info] at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
[info] at java.lang.Thread.run(Thread.java:745)
Could you run the same test and see if you get an error?
I did not get the same error, and the result shows up successfully. Could you provide me with the full stack trace of the exception?
Yeah I still get it. Did you use the code above? It's not directly copy pasted from the existing tests.
- memory leak test *** FAILED *** (8 seconds, 938 milliseconds)
[info] org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 3.0 failed 1 times, most recent failure: Lost task 0.0 in stage 3.0 (TID 204, localhost, executor driver): org.apache.spark.SparkException: Managed memory leak detected; size = 33816576 bytes, TID = 204
[info] at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:295)
[info] at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
[info] at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
[info] at java.lang.Thread.run(Thread.java:745)
[info]
[info] Driver stacktrace:
[info] at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1435)
[info] at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1423)
[info] at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1422)
[info] at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
[info] at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
[info] at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1422)
[info] at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
[info] at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
[info] at scala.Option.foreach(Option.scala:257)
[info] at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:802)
[info] at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1650)
[info] at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1605)
[info] at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1594)
[info] at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
[info] at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:628)
[info] at org.apache.spark.SparkContext.runJob(SparkContext.scala:1896)
[info] at org.apache.spark.SparkContext.runJob(SparkContext.scala:1909)
[info] at org.apache.spark.SparkContext.runJob(SparkContext.scala:1922)
[info] at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:333)
[info] at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
[info] at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2323)
[info] at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
[info] at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2717)
[info] at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2322)
[info] at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2329)
[info] at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2065)
[info] at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2064)
[info] at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2747)
[info] at org.apache.spark.sql.Dataset.head(Dataset.scala:2064)
[info] at org.apache.spark.sql.Dataset.take(Dataset.scala:2279)
[info] at org.apache.spark.sql.Dataset.showString(Dataset.scala:247)
[info] at org.apache.spark.sql.Dataset.show(Dataset.scala:596)
[info] at org.apache.spark.sql.Dataset.show(Dataset.scala:555)
[info] at org.apache.spark.sql.Dataset.show(Dataset.scala:564)
[info] at org.apache.spark.ml.feature.BucketedRandomProjectionLSHSuite$$anonfun$3.apply$mcV$sp(BucketedRandomProjectionLSHSuite.scala:74)
Yes, I copied your code to BucketedRandomProjectionLSHSuite.scala and it runs fine for me with the following output:
+--------------------+--------------------+------------------+
| datasetA| datasetB| distCol|
+--------------------+--------------------+------------------+
|[(50,[0],[-2.0]),...|[(50,[0],[-1.0]),...| 1.0|
|[(50,[4],[-1.0]),...|[(50,[23],[-1.0])...|1.4142135623730951|
|[(50,[5],[-1.0]),...|[(50,[32],[-1.0])...|1.4142135623730951|
|[(50,[7],[1.0]),W...|[(50,[47],[1.0]),...|1.4142135623730951|
|[(50,[7],[2.0]),W...|[(50,[26],[-2.0])...|2.8284271247461903|
|[(50,[8],[-2.0]),...|[(50,[1],[-1.0]),...| 2.23606797749979|
|[(50,[8],[-1.0]),...|[(50,[23],[-2.0])...| 2.23606797749979|
|[(50,[10],[-1.0])...|[(50,[7],[2.0]),W...| 2.23606797749979|
|[(50,[10],[-1.0])...|[(50,[13],[2.0]),...| 2.23606797749979|
|[(50,[11],[-1.0])...|[(50,[39],[2.0]),...| 2.23606797749979|
|[(50,[12],[-2.0])...|[(50,[28],[1.0]),...| 2.23606797749979|
|[(50,[12],[-2.0])...|[(50,[29],[-1.0])...| 2.23606797749979|
|[(50,[13],[1.0]),...|[(50,[2],[-2.0]),...| 2.23606797749979|
|[(50,[14],[1.0]),...|[(50,[33],[2.0]),...| 2.23606797749979|
|[(50,[14],[2.0]),...|[(50,[28],[2.0]),...|2.8284271247461903|
|[(50,[15],[-1.0])...|[(50,[38],[-1.0])...|1.4142135623730951|
|[(50,[18],[1.0]),...|[(50,[8],[-1.0]),...|1.4142135623730951|
|[(50,[18],[1.0]),...|[(50,[12],[-2.0])...| 2.23606797749979|
|[(50,[18],[2.0]),...|[(50,[43],[1.0]),...| 2.23606797749979|
|[(50,[20],[1.0]),...|[(50,[25],[-1.0])...|1.4142135623730951|
+--------------------+--------------------+------------------+
only showing top 20 rows
Let me see if the test can pass jenkins or not.
If you look at line 292 of Executor.scala, it shows this is just an OOM exception on the DataFrame. That's the reason why it behaves differently on our machines and on Jenkins. model.approxSimilarityJoin(df, df, Double.MaxValue, "distCol") returns nearly 40000 rows when threshold = Double.MaxValue. If you reduce numDim, the test will pass.
See #15916
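For reference, the blow-up is easy to quantify: the test above builds 4 points per dimension across 50 dimensions, and a threshold of Double.MaxValue filters nothing, so the self-join is effectively a full cross join.

```python
# Back-of-the-envelope size of the self-join in the test above:
# 50 dimensions x 4 values per dimension = 200 rows, and a threshold
# of Double.MaxValue keeps every candidate pair.
num_dim, values_per_dim = 50, 4
num_rows = num_dim * values_per_dim
print(num_rows * num_rows)  # -> 40000 pairs, i.e. a full cross join
```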
* This [[RandomProjection]] implements Locality Sensitive Hashing functions for Euclidean
* distance metrics.
* This [[BucketedRandomProjectionLSH]] implements Locality Sensitive Hashing functions for
* Euclidean distance metrics.
*
* The input is dense or sparse vectors, each of which represents a point in the Euclidean
* distance space. The output will be vectors of configurable dimension. Hash value in the same
"Hash values in the same dimension are calculated"
Done.
Test build #68678 has finished for PR 15874 at commit
Test build #68683 has finished for PR 15874 at commit
Test build #68689 has finished for PR 15874 at commit
I'll take a look
override protected[this] def createRawLSHModel(inputDim: Int): RandomProjectionModel = {
override protected[this] def createRawLSHModel(
inputDim: Int
): BucketedRandomProjectionLSHModel = {
style nit: This should go on the previous line.
Done.
* Reference:
* [[https://en.wikipedia.org/wiki/Perfect_hash_function Wikipedia on Perfect Hash Function]]
* Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function is
* a perfect hash function for a specific set `S` with cardinality equal to `numEntries`:
Looking more at the Wikipedia entry, I'm still doubtful about whether this is a perfect hash function. It looks like the first of 2 parts in the construction of a perfect hash function. I also still don't see why mentioning "perfect hash functions" will help users.
Yes, removed.
* [[https://en.wikipedia.org/wiki/Perfect_hash_function Wikipedia on Perfect Hash Function]]
* Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function is
* a perfect hash function for a specific set `S` with cardinality equal to `numEntries`:
* `h_i(x) = ((x \cdot a_i + b_i) \mod prime) \mod numEntries`
*
* @param numEntries The number of entries of the hash functions.
* @param randCoefficients An array of random coefficients, each used by one hash function.
Need to update description
Updated
* Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated
* as binary "1" values.
* `Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0)))`
* means there are 10 elements in the space. This set contains non-zero values at indices 2, 3, and
I prefer the old terminology since all non-zero values are treated the same. How about "This set contains elements 2, 3, and 5." ?
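The convention being discussed, sketched outside of Spark: only the non-zero indices of the input vector matter, and every non-zero value is read as a binary "1".

```python
# Sketch of the MinHash input convention: only the indices of non-zero
# entries matter; all non-zero values are treated as a binary "1".
def as_binary_set(size, index_value_pairs):
    assert all(0 <= i < size for i, _ in index_value_pairs)
    return {i for i, v in index_value_pairs if v != 0.0}

# Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0))) corresponds to
# a set containing elements 2, 3, and 5 from a 10-element space:
print(as_binary_set(10, [(2, 1.0), (3, 1.0), (5, 1.0)]))
```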
model.set(model.bucketLength, 0.5)
val res = model.hashFunction(Vectors.dense(1.23, 4.56))
assert(res.equals(Vectors.dense(9.0, 2.0)))
assert(res(0).equals(Vectors.dense(9.0)))
also assert res.length == 2
Added.
val key = Vectors.dense(1.2, 3.4)

val brp = new BucketedRandomProjectionLSH()
.setNumHashTables(20)
No need to set some of these Params here
Removed.
.setSeed(12344)

val data = {
for (i <- 0 to 95) yield Vectors.sparse(Int.MaxValue, (i until i + 5).map((_, 1.0)))
Change to "0 to 2." I like keeping tests minimal.
Done.
Other comments:

MinHash: Looking yet again at this, I think it's using a technically incorrect hash function. It is not a perfect hash function. It can hash 2 input indices to the same hash bucket. (As before, check out the Wikipedia page to see how it's missing the 2nd stage in the construction of a perfect hash function.) If we want to fix this, then we could alternatively precompute a random permutation of indices, which also serves as a perfect hash function. That said, perhaps it does not matter in practice. If numEntries (inputDim) is large enough, then the current hash function will probably behave similarly to a perfect hash function.

approxNearestNeighbors: This is still not what I proposed, even for single-probe queries. It will still have the potential to consider (and sort) a number of candidates much larger than numNearestNeighbors. Since we're running out of time, I'm fine with leaving it as is for now and just changing the behavior for the next release. However, can you please add a note to the method documentation that this method is experimental and will likely change behavior in the next release? Thanks!
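The alternative mentioned above - precomputing a random permutation of indices - can be sketched as follows (illustrative Python with a hypothetical make_perm_minhash helper). A permutation is injective on the index universe, so two distinct indices can never share a hash value, which is the property the current (ax + b) mod p mod numEntries construction lacks.

```python
import random

# Sketch of the suggested alternative: a random permutation of the index
# universe is injective, so distinct indices never collide; min-hash then
# takes the minimum permuted index over the set's elements.
def make_perm_minhash(input_dim, seed):
    rng = random.Random(seed)
    perm = list(range(input_dim))
    rng.shuffle(perm)
    return lambda elems: min(perm[e] for e in elems)

h = make_perm_minhash(10, seed=42)
print(h({2, 3, 5}))  # min of the permuted positions of 2, 3, and 5
```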
…8408-yunn-api-improvements
@Yunni Thanks for the updates! I don't think we should include AND-amplification for 2.1 since we're already in QA. But it'd be nice to get it in 2.2. Also, 2.2 will give us plenty of time to discuss distributed approxNearestNeighbors.

FYI: I asked around about the managed memory leak warning/failure. It is usually just a warning, but some test suites are set to fail upon seeing that warning. That was apparently useful for debugging some memory leak bugs but is not cause to worry. I recommend we make tests small enough to avoid them for now. If the warning becomes an issue, we could configure ML suites to ignore the warning, or we could even downgrade the warning to a lower-priority log message for all of Spark.

This LGTM. What does everyone think? For 2.1, the main thing I'd still like to do is to send a PR to clarify terminology. That could be done in #15795.
I will take a look.
@jkbradley Awesome, thanks so much! :) Now that the API is finalized, I will work on the user doc.
}).min.toDouble
val hashValues = randCoefficients.map({ case (a: Int, b: Int) =>
elemsList.map { elem: Int =>
((1 + elem) * a + b) % MinHashLSH.HASH_PRIME % numEntries
I'm still looking at it, but I don't think this is correct. Why do we tack on % numEntries here? Could you point me to a resource? The paper linked above (and many other references that I've seen) uses (ax + b) mod p where p is a large prime.

I see the formula listed under the wiki article for perfect hash functions lists (kx mod p) mod n, but that's not the full picture. They are referencing a paper which simply uses that formula as the first part of a multilevel scheme.

If it's helpful - this seems to be the original paper on MinHash. The author mentions that:

This is further explored in [5] where it is shown that random linear transformations are likely to suffice in practice.

Reference 5 is here, which seems to be a more concise version of your reference. In that paper, they describe (ax + b) mod p.
@jkbradley Thanks for checking that, that is the conclusion I drew as well.
Hi @sethah, taking the result modulo the cardinality of the hash universe does not really affect the independence, since p is a much larger prime. For example, in http://people.csail.mit.edu/mip/papers/kwise-lb/kwise-lb.pdf, they use "mod b". Since we don't care about the hash universe here, I am OK with changing to
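For concreteness, the (ax + b) mod p family under discussion can be sketched outside of Spark like this (a toy illustration reusing the PR's prime; make_hash_fn and min_hash are hypothetical names, and the extra "mod numEntries" is dropped so values stay in [0, p)).

```python
import random

PRIME = 2038074743  # the large prime from the PR, smaller than sqrt(2^63 - 1)

# Toy sketch of the universal family h(x) = (a*x + b) mod p discussed
# above. The "1 + elem" mirrors the PR's use of 1-based element values.
def make_hash_fn(seed):
    rng = random.Random(seed)
    a = 1 + rng.randrange(PRIME - 1)  # a in [1, p-1]
    b = rng.randrange(PRIME)          # b in [0, p-1]
    return lambda x: (a * x + b) % PRIME

def min_hash(elems, hash_fns):
    # One min-hash value per hash function, as in OR-amplification.
    return [min(h(1 + e) for e in elems) for h in hash_fns]

fns = [make_hash_fn(s) for s in range(3)]
print(min_hash({2, 3, 5}, fns))  # three hash values, one per table
```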
…m/Yunni/spark into SPARK-18408-yunn-api-improvements
Test build #68880 has finished for PR 15874 at commit
* the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
* transformed data when necessary.
*
* NOTE: This method is experimental and will likely change behavior in the next release.
minor: use @note
Done.
*
* NOTE: This method is experimental and will likely change behavior in the next release.
*
* @param dataset the dataset to search for nearest neighbors of the key
nit: Capitalize first words and add periods to all fields
Done.
* @param key Feature vector representing the item to search for
* @param numNearestNeighbors The maximum number of nearest neighbors
* @param distCol Output column for storing the distance between each result row and the key
* @return A dataset containing at most k items closest to the key. A distCol is added to show
minor: A column "distCol" is added ...
Done.
* where `k_i` is the i-th coefficient, and both `x` and `k_i` are from `Z_prime^*`
* Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function is
* picked from a hash family for a specific set `S` with cardinality equal to `numEntries`:
* `h_i(x) = ((x \cdot a_i + b_i) \mod prime) \mod numEntries`
We should remove the numEntries part here if we have removed it from the code. I still haven't gotten around to properly digging into this. For now, I'd like to change the second sentence to: "Each hash function is picked from the following family of hash functions, where a_i and b_i are randomly chosen integers less than prime:"

And I prefer this paper: "http://www.combinatorics.org/ojs/index.php/eljc/article/download/v7i1r26/pdf" as a reference because it is concise and easier to parse. That said, since it's a direct download link we could maybe not put the link in the doc, and just list the reference.
Fixed
elemsList.map({elem: Int =>
(1 + elem) * randCoefficient.toLong % MinHash.prime % numEntries
}).min.toDouble
val hashValues = randCoefficients.map({ case (a: Int, b: Int) =>
minor: the "({" is redundant. Also, I don't think the type annotations are necessary.
Removed the parentheses.
})
Vectors.dense(hashValues)
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
hashValues.grouped(1).map(Vectors.dense).toArray
Why not hashValues.map(Vectors.dense(_))? We can just add the grouping later when it's needed. Same for BRP.
Vectors.dense takes an array instead of a single number.
There is an alternate constructor which takes a single value (or multiple values). I guess I just think the grouped(1) is a bit confusing; it's not really an efficiency concern.
I see. It's dense(firstValue: Double, otherValues: Double*).
val numEntry = inputDim * 2
val randCoofs: Array[Int] = Array.fill($(outputDim))(1 + rand.nextInt(MinHash.prime - 1))
new MinHashModel(uid, numEntry, randCoofs)
val randCoefs: Array[(Int, Int)] = Array.fill(2 * $(numHashTables)) {
Why is it 2 * $(numHashTables)?
Fixed
If this was an error before, we should have a unit test that catches this. Basically, the output of transform should be a vector of length equal to numHashTables.
Unit tests added in LSHTest.scala
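To make the review point concrete: each hash table needs exactly one coefficient pair, so filling the array with `numHashTables` pairs yields one min-hash value per table, and the output dimension equals `numHashTables`. Below is a hypothetical standalone sketch in plain Scala (no Spark); the hash formula is a generic min-hash over active indices, not necessarily the exact one in the patch:

```scala
// Large prime < sqrt(2^63 - 1), matching the modulus used in MinHashLSH.
val Prime = 2038074743L
val rng = new scala.util.Random(42)
val numHashTables = 3

// One (a, b) coefficient pair per hash table; a is non-zero mod Prime.
val randCoefs: Array[(Int, Int)] = Array.fill(numHashTables) {
  (1 + rng.nextInt(Prime.toInt - 1), rng.nextInt(Prime.toInt - 1))
}

// Min-hash the active indices of a sparse binary vector:
// for each table, take the minimum of ((1 + i) * a + b) mod Prime.
def minHash(activeIndices: Seq[Int]): Array[Long] =
  randCoefs.map { case (a, b) =>
    activeIndices.map(i => ((1L + i) * a + b) % Prime).min
  }

val sig = minHash(Seq(0, 3, 7))
assert(sig.length == numHashTables) // one hash value per table
```

With the coefficient array sized `numHashTables`, the signature length is exactly the number of hash tables, which is the invariant the added unit tests check.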
def checkModelData(
  model: BucketedRandomProjectionLSHModel,
  model2: BucketedRandomProjectionLSHModel
): Unit = {
nit: move this up a line
Done.
val key: Vector = Vectors.sparse(100,
  (0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))

val model = mh.fit(dataset)
It would be easier to just create a model artificially, then test the edge case. That will speed up the tests. Same in other places.
Done.
Test build #69012 has finished for PR 15874 at commit
Test build #69020 has finished for PR 15874 at commit
Test build #69031 has finished for PR 15874 at commit
@sethah PTAL
LGTM. I think we've made JIRAs for all of the follow-up items. Thanks!
Thanks @sethah ! Your comment was very helpful and detailed :-)
…8408-yunn-api-improvements
@jkbradley If you don't have more comments, can we merge this because I need to change the examples in #15795?
Test build #69215 has finished for PR 15874 at commit
LGTM
Well, I'm having trouble merging b/c of bad wifi during travel. Ping @yanboliang @MLnick @mengxr, would one of you mind merging this with master and branch-2.1? @sethah and I have both given LGTMs. Thanks!
## What changes were proposed in this pull request? (1) Change output schema to `Array of Vector` instead of `Vectors` (2) Use `numHashTables` as the dimension of Array (3) Rename `RandomProjection` to `BucketedRandomProjectionLSH`, `MinHash` to `MinHashLSH` (4) Make `randUnitVectors/randCoefficients` private (5) Make Multi-Probe NN Search and `hashDistance` private for future discussion Saved for future PRs: (1) AND-amplification and `numHashFunctions` as the dimension of Vector are saved for a future PR. (2) `hashDistance` and MultiProbe NN Search needs more discussion. The current implementation is just a backward compatible one. ## How was this patch tested? Related unit tests are modified to make sure the performance of LSH are ensured, and the outputs of the APIs meets expectation. Author: Yun Ni <yunn@uber.com> Author: Yunni <Euler57721@gmail.com> Closes #15874 from Yunni/SPARK-18408-yunn-api-improvements. (cherry picked from commit 05f7c6f) Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
## What changes were proposed in this pull request?

(1) Change output schema to `Array of Vector` instead of `Vectors`
(2) Use `numHashTables` as the dimension of the Array
(3) Rename `RandomProjection` to `BucketedRandomProjectionLSH`, `MinHash` to `MinHashLSH`
(4) Make `randUnitVectors`/`randCoefficients` private
(5) Make Multi-Probe NN Search and `hashDistance` private for future discussion

Saved for future PRs:
(1) AND-amplification and `numHashFunctions` as the dimension of Vector are saved for a future PR.
(2) `hashDistance` and Multi-Probe NN Search need more discussion. The current implementation is just a backward-compatible one.

## How was this patch tested?

Related unit tests were modified to make sure the performance of LSH is ensured and the outputs of the APIs meet expectations.
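Items (1) and (2) above mean each row's hash output is an array with one entry per hash table, and OR-amplification treats two items as candidate neighbors when any table's value collides. A plain-Scala sketch of that matching rule (a hypothetical helper for illustration, not Spark's implementation):

```scala
// Each item's LSH output: one hash value per table (numHashTables entries).
type Signature = Array[Long]

// OR-amplification: items are candidate neighbors if ANY table agrees.
def isCandidate(a: Signature, b: Signature): Boolean =
  a.zip(b).exists { case (x, y) => x == y }

val s1: Signature = Array(5L, 9L, 2L)
val s2: Signature = Array(7L, 9L, 4L) // agrees with s1 on the second table
val s3: Signature = Array(7L, 8L, 4L) // agrees with s1 nowhere

assert(isCandidate(s1, s2))
assert(!isCandidate(s1, s3))
```

AND-amplification (requiring all entries of a table's vector to match) is what the deferred `numHashFunctions` work would add on top of this per-table structure.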