Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
huaxingao committed Nov 19, 2019
1 parent a06c6f4 commit b56b489
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.ml.feature

import scala.util.Random

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.{IntParam, ParamValidators}
Expand Down Expand Up @@ -113,7 +112,8 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
numNearestNeighbors: Int,
singleProbe: Boolean,
distCol: String): Dataset[_] = {
require(numNearestNeighbors > 0 && numNearestNeighbors <= dataset.count(), "The number of" +
val count = dataset.count()
require(numNearestNeighbors > 0 && numNearestNeighbors <= count, "The number of" +
" nearest neighbors cannot be less than 1 or greater than the number of elements in dataset")
// Get Hash Value of the key
val keyHash = hashFunction(key)
Expand All @@ -139,23 +139,17 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
val hashDistCol = hashDistUDF(col($(outputCol)))

// Compute threshold to get around k elements.
var approxQuantile = numNearestNeighbors.toDouble / count + 0.05 // relative error = 0.05
val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
var filtered: DataFrame = null
var requestedNum = numNearestNeighbors
do {
requestedNum *= 2
if (requestedNum > modelDataset.count()) {
requestedNum = modelDataset.count().toInt
}
var quantile = requestedNum.toDouble / modelDataset.count()
var hashThreshold = modelDatasetWithDist.stat
.approxQuantile(distCol, Array(quantile), 0.001)

if (approxQuantile >= 1) {
modelDatasetWithDist
} else {
val hashThreshold = modelDatasetWithDist.stat
.approxQuantile(distCol, Array(approxQuantile), 0.05) // relative error = 0.05
// Filter the dataset where the hash value is less than the threshold.
filtered = modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
}
while (filtered.count() < numNearestNeighbors)
filtered
}

// Get the top k nearest neighbor by their distance to the key
Expand Down

0 comments on commit b56b489

Please sign in to comment.