[SPARK-33609][ML] word2vec reduce broadcast size #30548

Closed
32 changes: 19 additions & 13 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -285,27 +285,33 @@ class Word2VecModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema, logging = true)
-    val vectors = wordVectors.getVectors
-      .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
-      .map(identity).toMap // mapValues doesn't return a serializable map (SI-7005)
-    val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
-    val d = $(vectorSize)
-    val emptyVec = Vectors.sparse(d, Array.emptyIntArray, Array.emptyDoubleArray)
-    val word2Vec = udf { sentence: Seq[String] =>
+
+    val bcModel = dataset.sparkSession.sparkContext.broadcast(this.wordVectors)
Member:
At first glance this makes more sense. But we can't call bcModel.destroy() at the end here anyway, so we have a broadcast we can't explicitly close no matter what. And now, I guess, this would re-broadcast every time? That could be bad. What do you think? I know this is not consistent in the code either way.

Contributor Author:
> And now, I guess, this would re-broadcast every time? That could be bad. What do you think?

I agree. I'd prefer not to use broadcasting in transform, but this may need more discussion. We can keep the current behavior for now.

GBT models have also been broadcast in this way for performance since SPARK-7127.
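
For illustration, one way to avoid re-broadcasting on every transform() call would be to cache the broadcast handle on the driver side. This is only a sketch of that pattern, not something this PR does; BroadcastOnce and getOrBroadcast are invented names.

    import scala.reflect.ClassTag

    import org.apache.spark.SparkContext
    import org.apache.spark.broadcast.Broadcast

    // Hypothetical helper: broadcast the wrapped value at most once per
    // driver-side instance and reuse the handle across calls.
    class BroadcastOnce[T: ClassTag](value: T) extends Serializable {
      @transient private var cached: Option[Broadcast[T]] = None

      def getOrBroadcast(sc: SparkContext): Broadcast[T] = synchronized {
        if (cached.isEmpty) {
          cached = Some(sc.broadcast(value))
        }
        cached.get
      }
    }

Even with such caching, there is still no safe point to call destroy(), which is the other half of the concern above.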

Member:
Looks good, but I'd back out this part of the change.

+    val size = $(vectorSize)
+    val emptyVec = Vectors.sparse(size, Array.emptyIntArray, Array.emptyDoubleArray)
+    val transformer = udf { sentence: Seq[String] =>
       if (sentence.isEmpty) {
         emptyVec
       } else {
-        val sum = Vectors.zeros(d)
+        val wordIndices = bcModel.value.wordIndex
+        val wordVectors = bcModel.value.wordVectors
+        val array = Array.ofDim[Double](size)
+        var count = 0
         sentence.foreach { word =>
-          bVectors.value.get(word).foreach { v =>
-            BLAS.axpy(1.0, v, sum)
+          wordIndices.get(word).foreach { index =>
+            val offset = index * size
+            var i = 0
+            while (i < size) { array(i) += wordVectors(offset + i); i += 1 }
           }
+          count += 1
         }
-        BLAS.scal(1.0 / sentence.size, sum)
-        sum
+        val vec = Vectors.dense(array)
+        BLAS.scal(1.0 / count, vec)
+        vec
       }
     }
-    dataset.withColumn($(outputCol), word2Vec(col($(inputCol))),
+
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
       outputSchema($(outputCol)).metadata)
   }
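
The size reduction in the PR title comes from what gets broadcast: the old code shipped a Map[String, Vector] of boxed Doubles, while the new code ships the mllib model, whose vectors live in a single flat Array[Float]. A back-of-envelope sketch, with an assumed vocabulary and dimension (not numbers from this PR):

    // Raw vector payload only, ignoring object headers and the word index,
    // which both representations need in some form.
    val vocab = 1000000                    // assumed vocabulary size
    val dim = 100                          // assumed vectorSize
    val oldBytes = vocab.toLong * dim * 8L // Vectors.dense over Double
    val newBytes = vocab.toLong * dim * 4L // one flat Array[Float]
    println(s"old ~${oldBytes >> 20} MB, new ~${newBytes >> 20} MB")
    // old ~762 MB, new ~381 MB

On top of the Float-versus-Double halving, the old map also paid per-entry overhead for a million DenseVector objects and boxed hash-map entries.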

27 changes: 12 additions & 15 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -502,22 +502,15 @@ class Word2VecModel private[spark] (
   private val vectorSize = wordVectors.length / numWords

   // wordList: Ordered list of words obtained from wordIndex.
-  private val wordList: Array[String] = {
-    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
-    wl.toArray
+  private lazy val wordList: Array[String] = {
+    wordIndex.toSeq.sortBy(_._2).iterator.map(_._1).toArray
   }

   // wordVecNorms: Array of length numWords, each value being the Euclidean norm
   // of the wordVector.
-  private val wordVecNorms: Array[Float] = {
-    val wordVecNorms = new Array[Float](numWords)
-    var i = 0
-    while (i < numWords) {
-      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
Contributor Author:
avoid this slicing

-      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
-      i += 1
-    }
-    wordVecNorms
+  private lazy val wordVecNorms: Array[Float] = {
Member:
How much does this save, if it only happens once and has to happen to use the model?

Contributor Author:
This var wordVecNorms is only used in method findSynonyms in .mllib's Word2Vec; however, findSynonyms is never used on the .ml side. So I think we can make it lazy.

Member:
OK fair enough. There are use cases here that would never need this calculated?

Contributor Author:
> however, findSynonyms is never used on the .ml side. So I think we can make it lazy.

I was wrong: this var wordVecNorms is used in methods findSynonyms and findSynonymsArray on the .ml side. Since it is not used in transform, we can still mark it lazy.
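
A minimal standalone sketch of the laziness under discussion (toy class and data, not Spark code): the norm pass runs only on first access, so a transform-only workload never triggers it.

    class TinyModel(flat: Array[Float], size: Int) {
      // Computed on first access only; a transform-only workload never pays for it.
      lazy val norms: Array[Float] = {
        println("computing norms") // visible side effect to demonstrate laziness
        Array.tabulate(flat.length / size) { i =>
          var s = 0.0
          var j = 0
          while (j < size) { val v = flat(i * size + j); s += v * v; j += 1 }
          math.sqrt(s).toFloat
        }
      }
      def row(i: Int): Array[Float] = flat.slice(i * size, (i + 1) * size)
    }

    val m = new TinyModel(Array(3f, 4f, 5f, 12f), 2)
    m.row(0)  // prints nothing: norms is untouched
    m.norms   // prints "computing norms" once, returns Array(5.0, 13.0)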

+    val size = vectorSize
+    Array.tabulate(numWords)(i => blas.snrm2(size, wordVectors, i * size, 1))
   }

@Since("1.5.0")
@@ -538,9 +531,13 @@
@Since("1.1.0")
def transform(word: String): Vector = {
wordIndex.get(word) match {
case Some(ind) =>
val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
Contributor Author:
avoid this slicing

-        Vectors.dense(vec.map(_.toDouble))
+      case Some(index) =>
+        val size = vectorSize
+        val offset = index * size
+        val array = Array.ofDim[Double](size)
+        var i = 0
+        while (i < size) { array(i) = wordVectors(offset + i); i += 1 }
Member:
Is this actually more efficient than slice? Likewise above.

Contributor Author:
I guess so; I will do a simple test.
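
A simple test along those lines might look like the sketch below (invented sizes; a proper measurement would use JMH). The slice path allocates an intermediate Array[Float] that map then copies into a fresh Array[Double]; the loop path writes once into a preallocated array.

    import scala.util.Random

    val numWords = 100000
    val size = 100
    val flat = Array.fill(numWords * size)(Random.nextFloat())

    def viaSlice(index: Int): Array[Double] =
      flat.slice(index * size, index * size + size).map(_.toDouble)

    def viaLoop(index: Int): Array[Double] = {
      val out = Array.ofDim[Double](size)
      val offset = index * size
      var i = 0
      while (i < size) { out(i) = flat(offset + i); i += 1 }
      out
    }

    def time(label: String)(work: => Unit): Unit = {
      val t0 = System.nanoTime()
      work
      println(s"$label: ${(System.nanoTime() - t0) / 1000000} ms")
    }

    time("slice")(Range(0, numWords).foreach(viaSlice)) // two allocations + two passes per word
    time("loop")(Range(0, numWords).foreach(viaLoop))   // one allocation + one pass per word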

+        Vectors.dense(array)
       case None =>
         throw new IllegalStateException(s"$word not in vocabulary")
     }