Skip to content

Commit

Permalink
[SPARK-34189][ML] w2v findSynonyms optimization
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
1, use GuavaOrdering instead of BoundedPriorityQueue;
2, use local variables;
3, avoid conversion: ml.vector -> mllib.vector

### Why are the changes needed?
This PR is about 30% faster than the existing implementation.

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
existing testsuites

Closes #31276 from zhengruifeng/w2v_findSynonyms_opt.

Authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
  • Loading branch information
zhengruifeng committed Jan 27, 2021
1 parent 3c68670 commit 2c4e4f8
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -256,7 +255,7 @@ class Word2VecModel private[ml] (
*/
@Since("2.2.0")
def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = {
wordVectors.findSynonyms(vec, num)
wordVectors.findSynonyms(vec.toArray, num, None)
}

/**
Expand Down
93 changes: 45 additions & 48 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.google.common.collect.{Ordering => GuavaOrdering}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
Expand All @@ -37,7 +38,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

Expand Down Expand Up @@ -506,11 +506,14 @@ class Word2VecModel private[spark] (
wordIndex.toSeq.sortBy(_._2).iterator.map(_._1).toArray
}

// wordVecNorms: Array of length numWords, each value being the Euclidean norm
// of the wordVector.
private lazy val wordVecNorms: Array[Float] = {
// wordVecInvNorms: Array of length numWords, each value being the inverse of
// Euclidean norm of the wordVector.
private lazy val wordVecInvNorms: Array[Float] = {
val size = vectorSize
Array.tabulate(numWords)(i => blas.snrm2(size, wordVectors, i * size, 1))
Array.tabulate(numWords) { i =>
val norm = blas.snrm2(size, wordVectors, i * size, 1)
if (norm != 0) 1 / norm else 0.0F
}
}

@Since("1.5.0")
Expand Down Expand Up @@ -552,7 +555,7 @@ class Word2VecModel private[spark] (
@Since("1.1.0")
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
findSynonyms(vector, num, Some(word))
findSynonyms(vector.toArray, num, Some(word))
}

/**
Expand All @@ -565,7 +568,7 @@ class Word2VecModel private[spark] (
*/
@Since("1.1.0")
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
findSynonyms(vector, num, None)
findSynonyms(vector.toArray, num, None)
}

/**
Expand All @@ -576,54 +579,48 @@ class Word2VecModel private[spark] (
* @param wordOpt optionally, a word to reject from the results list
* @return array of (word, cosineSimilarity)
*/
private def findSynonyms(
vector: Vector,
private[spark] def findSynonyms(
vector: Array[Double],
num: Int,
wordOpt: Option[String]): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
val localVectorSize = vectorSize

val floatVec = vector.map(_.toFloat)
val vecNorm = blas.snrm2(localVectorSize, floatVec, 1)

val localWordList = wordList
val localNumWords = numWords
if (vecNorm == 0) {
Iterator.tabulate(num + 1)(i => (localWordList(i), 0.0))
.filterNot(t => wordOpt.contains(t._1))
.take(num)
.toArray
} else {
// Normalize input vector before blas.sgemv to avoid Inf value
blas.sscal(localVectorSize, 1 / vecNorm, floatVec, 0, 1)

val cosineVec = Array.ofDim[Float](localNumWords)
blas.sgemv("T", localVectorSize, localNumWords, 1.0F, wordVectors, localVectorSize,
floatVec, 1, 0.0F, cosineVec, 1)

val localWordVecInvNorms = wordVecInvNorms
var i = 0
while (i < cosineVec.length) { cosineVec(i) *= localWordVecInvNorms(i); i += 1 }

val fVector = vector.toArray.map(_.toFloat)
val cosineVec = new Array[Float](numWords)
val alpha: Float = 1
val beta: Float = 0
// Normalize input vector before blas.sgemv to avoid Inf value
val vecNorm = blas.snrm2(vectorSize, fVector, 1)
if (vecNorm != 0.0f) {
blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)
}
blas.sgemv(
"T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)

var i = 0
while (i < numWords) {
val norm = wordVecNorms(i)
if (norm == 0.0f) {
cosineVec(i) = 0.0f
} else {
cosineVec(i) /= norm
val idxOrd = new GuavaOrdering[Int] {
override def compare(left: Int, right: Int): Int = {
Ordering[Float].compare(cosineVec(left), cosineVec(right))
}
}
i += 1
}

val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2))

var j = 0
while (j < numWords) {
pq += Tuple2(wordList(j), cosineVec(j))
j += 1
idxOrd.greatestOf(Iterator.range(0, localNumWords).asJava, num + 1)
.iterator.asScala
.map(i => (localWordList(i), cosineVec(i).toDouble))
.filterNot(t => wordOpt.contains(t._1))
.take(num)
.toArray
}

val scored = pq.toSeq.sortBy(-_._2)

val filtered = wordOpt match {
case Some(w) => scored.filter(tup => w != tup._1)
case None => scored
}

filtered
.take(num)
.map { case (word, score) => (word, score.toDouble) }
.toArray
}

/**
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -4688,7 +4688,7 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
+----+--------------------+
...
>>> model.findSynonymsArray("a", 2)
[('b', 0.015859870240092278), ('c', -0.5680795907974243)]
[('b', 0.015859...), ('c', -0.568079...)]
>>> from pyspark.sql.functions import format_number as fmt
>>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show()
+----+----------+
Expand Down

0 comments on commit 2c4e4f8

Please sign in to comment.