Commit 033ae5d: Code Review Comments
Yunni committed Nov 15, 2016 (1 parent: c115ed3)
Showing 5 changed files with 147 additions and 95 deletions.
mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType
*
* Params for [[BucketedRandomProjectionLSH]].
*/
- private[ml] trait BucketedRandomProjectionParams extends Params {
+ private[ml] trait BucketedRandomProjectionLSHParams extends Params {

/**
* The length of each hash bucket, a larger bucket lowers the false negative rate. The number of
@@ -68,18 +68,18 @@ private[ml] trait BucketedRandomProjectionParams extends Params {
*/
@Experimental
@Since("2.1.0")
- class BucketedRandomProjectionModel private[ml](
+ class BucketedRandomProjectionLSHModel private[ml](
override val uid: String,
- @Since("2.1.0") private[ml] val randUnitVectors: Array[Vector])
- extends LSHModel[BucketedRandomProjectionModel] with BucketedRandomProjectionParams {
+ private[ml] val randUnitVectors: Array[Vector])
+ extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {

@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
key: Vector => {
val hashValues: Array[Double] = randUnitVectors.map({
randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
})
- // TODO: For AND-amplification, output vectors of dimension numHashFunctions
+ // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
hashValues.grouped(1).map(Vectors.dense).toArray
}
}
@@ -100,7 +100,7 @@ class BucketedRandomProjectionModel private[ml](

@Since("2.1.0")
override def write: MLWriter = {
- new BucketedRandomProjectionModel.BucketedRandomProjectionModelWriter(this)
+ new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
}
}

@@ -111,8 +111,8 @@ class BucketedRandomProjectionModel private[ml](
* Euclidean distance metrics.
*
* The input is dense or sparse vectors, each of which represents a point in the Euclidean
- distance space. The output will be vectors of configurable dimension. Hash value in the same
- dimension is calculated by the same hash function.
+ distance space. The output will be vectors of configurable dimension. Hash values in the
+ same dimension are calculated by the same hash function.
*
* References:
*
@@ -125,7 +125,8 @@ class BucketedRandomProjectionModel private[ml](
@Experimental
@Since("2.1.0")
class BucketedRandomProjectionLSH(override val uid: String)
- extends LSH[BucketedRandomProjectionModel] with BucketedRandomProjectionParams with HasSeed {
+ extends LSH[BucketedRandomProjectionLSHModel]
+ with BucketedRandomProjectionLSHParams with HasSeed {

@Since("2.1.0")
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -138,7 +139,7 @@ class BucketedRandomProjectionLSH(override val uid: String)

@Since("2.1.0")
def this() = {
- this(Identifiable.randomUID("random projection"))
+ this(Identifiable.randomUID("brp-lsh"))
}

/** @group setParam */
@@ -150,15 +151,17 @@ class BucketedRandomProjectionLSH(override val uid: String)
def setSeed(value: Long): this.type = set(seed, value)

@Since("2.1.0")
- override protected[this] def createRawLSHModel(inputDim: Int): BucketedRandomProjectionModel = {
+ override protected[this] def createRawLSHModel(
+ inputDim: Int
+ ): BucketedRandomProjectionLSHModel = {
val rand = new Random($(seed))
val randUnitVectors: Array[Vector] = {
Array.fill($(numHashTables)) {
val randArray = Array.fill(inputDim)(rand.nextGaussian())
Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
}
}
- new BucketedRandomProjectionModel(uid, randUnitVectors)
+ new BucketedRandomProjectionLSHModel(uid, randUnitVectors)
}

@Since("2.1.0")
@@ -179,18 +182,18 @@ object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomP
}

@Since("2.1.0")
- object BucketedRandomProjectionModel extends MLReadable[BucketedRandomProjectionModel] {
+ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProjectionLSHModel] {

@Since("2.1.0")
- override def read: MLReader[BucketedRandomProjectionModel] = {
- new BucketedRandomProjectionModelReader
+ override def read: MLReader[BucketedRandomProjectionLSHModel] = {
+ new BucketedRandomProjectionLSHModelReader
}

@Since("2.1.0")
- override def load(path: String): BucketedRandomProjectionModel = super.load(path)
+ override def load(path: String): BucketedRandomProjectionLSHModel = super.load(path)

- private[BucketedRandomProjectionModel] class BucketedRandomProjectionModelWriter(
- instance: BucketedRandomProjectionModel) extends MLWriter {
+ private[BucketedRandomProjectionLSHModel] class BucketedRandomProjectionLSHModelWriter(
+ instance: BucketedRandomProjectionLSHModel) extends MLWriter {

// TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
private case class Data(randUnitVectors: Matrix)
@@ -208,21 +211,22 @@ object BucketedRandomProjectionModel extends MLReadable[BucketedRandomProjection
}
}

- private class BucketedRandomProjectionModelReader
- extends MLReader[BucketedRandomProjectionModel] {
+ private class BucketedRandomProjectionLSHModelReader
+ extends MLReader[BucketedRandomProjectionLSHModel] {

/** Checked against metadata when loading model */
- private val className = classOf[BucketedRandomProjectionModel].getName
+ private val className = classOf[BucketedRandomProjectionLSHModel].getName

- override def load(path: String): BucketedRandomProjectionModel = {
+ override def load(path: String): BucketedRandomProjectionLSHModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
.select("randUnitVectors")
.head()
- val model = new BucketedRandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray)
+ val model = new BucketedRandomProjectionLSHModel(metadata.uid,
+ randUnitVectors.rowIter.toArray)

DefaultParamsReader.getAndSetParams(model, metadata)
model
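Note on the BucketedRandomProjectionLSH changes above: each hash function projects the input onto a random unit vector v_i and buckets the result, h_i(x) = floor((x . v_i) / bucketLength); until SPARK-18450 lands, each value is wrapped in its own 1-dimensional vector so OR-amplification can treat the hash tables independently. A minimal self-contained sketch of the same scheme in plain Scala, using arrays instead of Spark's Vector and BLAS (all names and values below are illustrative, not part of this patch):

import scala.util.Random

object BrpHashSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(42)
    val inputDim = 4
    val numHashTables = 3
    val bucketLength = 2.0

    // One random unit vector per hash table, as createRawLSHModel draws them.
    val randUnitVectors: Array[Array[Double]] = Array.fill(numHashTables) {
      val v = Array.fill(inputDim)(rng.nextGaussian())
      val norm = math.sqrt(v.map(x => x * x).sum)
      v.map(_ / norm)
    }

    // h_i(x) = floor((x . v_i) / bucketLength): nearby points tend to share buckets.
    def hash(x: Array[Double]): Array[Double] =
      randUnitVectors.map { v =>
        val dot = x.zip(v).map { case (a, b) => a * b }.sum
        math.floor(dot / bucketLength)
      }

    println(hash(Array(1.0, 2.0, 3.0, 4.0)).mkString(", "))
  }
}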
31 changes: 6 additions & 25 deletions mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -33,10 +33,10 @@ import org.apache.spark.sql.types._
*/
private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
/**
- * Param for the dimension of LSH OR-amplification.
+ * Param for the number of hash tables used in LSH OR-amplification.
*
- * LSH OR-amplification can be used to reduce the false negative rate. The higher the dimension
- * is, the lower the false negative rate.
+ * LSH OR-amplification can be used to reduce the false negative rate. Higher values for this
+ * param lead to a reduced false negative rate, at the expense of added computational complexity.
* @group param
*/
final val numHashTables: IntParam = new IntParam(this, "numHashTables", "number of hash " +
@@ -66,7 +66,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
self: T =>

/**
- * The hash function of LSH, mapping an input feature to multiple vectors
+ * The hash function of LSH, mapping an input feature vector to multiple hash vectors.
* @return The mapping of LSH function.
*/
protected[ml] val hashFunction: Vector => Array[Vector]
@@ -99,26 +99,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
validateAndTransformSchema(schema)
}

- /**
- * Given a large dataset and an item, approximately find at most k items which have the closest
- * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
- * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
- * transformed data when necessary.
- *
- * This method implements two ways of fetching k nearest neighbors:
- *  - Single-probe: Fast, return at most k elements (Probing only one buckets)
- *  - Multi-probe: Slow, return exact k elements (Probing multiple buckets close to the key)
- *
- * Currently it is made private since more discussion is needed for Multi-probe
- *
- * @param dataset the dataset to search for nearest neighbors of the key
- * @param key Feature vector representing the item to search for
- * @param numNearestNeighbors The maximum number of nearest neighbors
- * @param singleProbe True for using single-probe; false for multi-probe
- * @param distCol Output column for storing the distance between each result row and the key
- * @return A dataset containing at most k items closest to the key. A distCol is added to show
- * the distance between each row and the key.
- */
+ // TODO: Fix the MultiProbe NN Search in SPARK-18454
private[feature] def approxNearestNeighbors(
dataset: Dataset[_],
key: Vector,
@@ -179,7 +160,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
* @return A dataset containing at most k items closest to the key. A distCol is added to show
* the distance between each row and the key.
*/
- private[feature] def approxNearestNeighbors(
+ def approxNearestNeighbors(
dataset: Dataset[_],
key: Vector,
numNearestNeighbors: Int,
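Note on the LSH.scala changes above: numHashTables controls OR-amplification. With g hash tables, two points become a candidate pair if their hashes collide in any single table, so raising g lowers the false negative rate at the cost of extra hashing and more candidates to check. Since approxNearestNeighbors is now public, calling it looks roughly like the sketch below, e.g. pasted into spark-shell (the column names, parameter values, and data are illustrative assumptions, not taken from this patch):

import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("lsh-ann-sketch").getOrCreate()
import spark.implicits._

val df = Seq(
  (0, Vectors.dense(1.0, 1.0)),
  (1, Vectors.dense(1.0, -1.0)),
  (2, Vectors.dense(-1.0, -1.0))
).toDF("id", "features")

val brp = new BucketedRandomProjectionLSH()
  .setBucketLength(2.0)
  .setNumHashTables(3) // more tables: fewer false negatives, more computation
  .setInputCol("features")
  .setOutputCol("hashes")

val model = brp.fit(df)

// Returns at most 2 rows, with a distance column added for the key.
model.approxNearestNeighbors(df, Vectors.dense(1.0, 0.0), 2).show()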
77 changes: 41 additions & 36 deletions mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -32,31 +32,31 @@ import org.apache.spark.sql.types.StructType
* :: Experimental ::
*
* Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function is
- * a perfect hash function for a specific set `S` with cardinality equal to a half of `numEntries`:
- * `h_i(x) = ((x \cdot k_i) \mod prime) \mod numEntries`
+ * a perfect hash function for a specific set `S` with cardinality equal to `numEntries`:
+ * `h_i(x) = ((x \cdot a_i + b_i) \mod prime) \mod numEntries`
*
* @param numEntries The number of entries of the hash functions.
* @param randCoefficients An array of random coefficients, each used by one hash function.
*/
@Experimental
@Since("2.1.0")
- class MinHashModel private[ml] (
+ class MinHashLSHModel private[ml](
override val uid: String,
- @Since("2.1.0") private[ml] val numEntries: Int,
- @Since("2.1.0") private[ml] val randCoefficients: Array[Int])
- extends LSHModel[MinHashModel] {
+ private[ml] val numEntries: Int,
+ private[ml] val randCoefficients: Array[(Int, Int)])
+ extends LSHModel[MinHashLSHModel] {

@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
elems: Vector => {
require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
val elemsList = elems.toSparse.indices.toList
- val hashValues = randCoefficients.map({ randCoefficient: Int =>
- elemsList.map({ elem: Int =>
- (1 + elem) * randCoefficient.toLong % MinHashLSH.prime % numEntries
- }).min.toDouble
+ val hashValues = randCoefficients.map({ case (a: Int, b: Int) =>
+ elemsList.map { elem: Int =>
+ ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME % numEntries
+ }.min.toDouble
})
- // TODO: For AND-amplification, output vectors of dimension numHashFunctions
+ // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
hashValues.grouped(1).map(Vectors.dense).toArray
}
}
@@ -74,7 +74,7 @@ class MinHashModel private[ml] (
@Since("2.1.0")
override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
// Since it's generated by hashing, it will be a pair of dense vectors.
- // TODO: This hashDistance function is controversial. Requires more discussion.
+ // TODO: This hashDistance function requires more discussion in SPARK-18454
x.zip(y).map(vectorPair =>
vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
).min
@@ -84,7 +84,7 @@ class MinHashModel private[ml] (
override def copy(extra: ParamMap): this.type = defaultCopy(extra)

@Since("2.1.0")
- override def write: MLWriter = new MinHashModel.MinHashModelWriter(this)
+ override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
}

/**
@@ -93,17 +93,17 @@ class MinHashModel private[ml] (
* LSH class for Jaccard distance.
*
* The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
- `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])`
- means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
- Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated
- as binary "1" values.
+ `Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0)))`
+ means there are 10 elements in the space. This set contains non-zero values at indices 2, 3, and
+ 5. Also, any input vector must have at least 1 non-zero index, and all non-zero values are
+ treated as binary "1" values.
*
* References:
* [[https://en.wikipedia.org/wiki/MinHash Wikipedia on MinHash]]
*/
@Experimental
@Since("2.1.0")
- class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSeed {
+ class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with HasSeed {

@Since("2.1.0")
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -116,21 +116,23 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee

@Since("2.1.0")
def this() = {
- this(Identifiable.randomUID("min hash"))
+ this(Identifiable.randomUID("mh-lsh"))
}

/** @group setParam */
@Since("2.1.0")
def setSeed(value: Long): this.type = set(seed, value)

@Since("2.1.0")
- override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = {
- require(inputDim <= MinHashLSH.prime / 2,
- s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.prime / 2}.")
+ override protected[ml] def createRawLSHModel(inputDim: Int): MinHashLSHModel = {
+ require(inputDim <= MinHashLSH.HASH_PRIME,
+ s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.HASH_PRIME}.")
val rand = new Random($(seed))
- val numEntry = inputDim * 2
- val randCoofs: Array[Int] = Array.fill($(numHashTables))(1 + rand.nextInt(MinHashLSH.prime - 1))
- new MinHashModel(uid, numEntry, randCoofs)
+ val numEntry = inputDim
+ val randCoefs: Array[(Int, Int)] = Array.fill(2 * $(numHashTables)) {
+ (1 + rand.nextInt(MinHashLSH.HASH_PRIME - 1), rand.nextInt(MinHashLSH.HASH_PRIME - 1))
+ }
+ new MinHashLSHModel(uid, numEntry, randCoefs)
}

@Since("2.1.0")
@@ -146,46 +148,49 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee
@Since("2.1.0")
object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
// A large prime smaller than sqrt(2^63 − 1)
- private[ml] val prime = 2038074743
+ private[ml] val HASH_PRIME = 2038074743

@Since("2.1.0")
override def load(path: String): MinHashLSH = super.load(path)
}

@Since("2.1.0")
- object MinHashModel extends MLReadable[MinHashModel] {
+ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {

@Since("2.1.0")
- override def read: MLReader[MinHashModel] = new MinHashModelReader
+ override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader

@Since("2.1.0")
- override def load(path: String): MinHashModel = super.load(path)
+ override def load(path: String): MinHashLSHModel = super.load(path)

- private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter {
+ private[MinHashLSHModel] class MinHashLSHModelWriter(instance: MinHashLSHModel)
+ extends MLWriter {

private case class Data(numEntries: Int, randCoefficients: Array[Int])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
- val data = Data(instance.numEntries, instance.randCoefficients)
+ val data = Data(instance.numEntries, instance.randCoefficients
+ .flatMap(tuple => Array(tuple._1, tuple._2)))
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

- private class MinHashModelReader extends MLReader[MinHashModel] {
+ private class MinHashLSHModelReader extends MLReader[MinHashLSHModel] {

/** Checked against metadata when loading model */
- private val className = classOf[MinHashModel].getName
+ private val className = classOf[MinHashLSHModel].getName

- override def load(path: String): MinHashModel = {
+ override def load(path: String): MinHashLSHModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head()
val numEntries = data.getAs[Int](0)
- val randCoefficients = data.getAs[Seq[Int]](1).toArray
- val model = new MinHashModel(metadata.uid, numEntries, randCoefficients)
+ val randCoefficients = data.getAs[Seq[Int]](1).grouped(2)
+ .map(tuple => (tuple(0), tuple(1))).toArray
+ val model = new MinHashLSHModel(metadata.uid, numEntries, randCoefficients)

DefaultParamsReader.getAndSetParams(model, metadata)
model
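Note on the MinHashLSH changes above: each hash function now draws a coefficient pair (a_i, b_i) and hashes a set to the minimum of ((1 + elem) * a_i + b_i) mod HASH_PRIME mod numEntries over the set's non-zero indices. A self-contained sketch of that computation (illustrative only; it mirrors the patched hashFunction without Spark types and uses Long arithmetic to sidestep Int overflow):

import scala.util.Random

object MinHashSketch {
  val HashPrime = 2038074743 // the same large prime as MinHashLSH.HASH_PRIME

  def main(args: Array[String]): Unit = {
    val rng = new Random(7)
    val numEntries = 10 // input dimension, matching the patched createRawLSHModel
    val numHashTables = 4

    // One (a, b) coefficient pair per hash table.
    val coefs: Array[(Int, Int)] = Array.fill(numHashTables) {
      (1 + rng.nextInt(HashPrime - 1), rng.nextInt(HashPrime - 1))
    }

    // Non-zero indices of a sparse binary vector, e.g. the set {2, 3, 5}.
    val indices = List(2, 3, 5)

    // One minimum per hash table.
    val hashValues = coefs.map { case (a, b) =>
      indices.map(elem => ((1L + elem) * a + b) % HashPrime % numEntries).min.toDouble
    }
    println(hashValues.mkString(", "))
  }
}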
