[SPARK-33609][ML] word2vec reduce broadcast size #30548

Closed
38 changes: 25 additions & 13 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.config.Kryo.KRYO_SERIALIZER_MAX_BUFFER_SIZE
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -278,34 +279,45 @@ class Word2VecModel private[ml] (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

private var bcModel: Broadcast[Word2VecModel] = _
Member:

I don't suppose we have a way to clean this up after use - will just have to get GCed?

Contributor Author:

Yes. I followed the implementation in CountVectorizer here.
Since other .ml implementations do not use a mutable var for a broadcast variable like this, I will remove this var.

Contributor Author:

As for CountVectorizer, should we also remove the var broadcastDict in it? It looks like other .mllib implementations do not use a mutable broadcast variable like that either.


/**
* Transform a sentence column to a vector column to represent the whole sentence. The transform
* is performed by averaging all word vectors it contains.
*/
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
val vectors = wordVectors.getVectors
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
.map(identity).toMap // mapValues doesn't return a serializable map (SI-7005)
val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
val d = $(vectorSize)
val emptyVec = Vectors.sparse(d, Array.emptyIntArray, Array.emptyDoubleArray)
val word2Vec = udf { sentence: Seq[String] =>

if (bcModel == null) {
bcModel = dataset.sparkSession.sparkContext.broadcast(this)
Member:

Looks like you only use this.wordVectors below? Maybe just broadcast that.

Contributor Author:

Both wordVectors and wordIndex are used.

Contributor Author:

Yes, there are two wordVectors...

Member:

Oops, right, I think I meant to say that you only use those two. Is there any saving from just broadcasting those rather than the whole model? If not, that's fine.

}

val size = $(vectorSize)
val emptyVec = Vectors.sparse(size, Array.emptyIntArray, Array.emptyDoubleArray)
val transformer = udf { sentence: Seq[String] =>
if (sentence.isEmpty) {
emptyVec
} else {
val sum = Vectors.zeros(d)
val wordIndices = bcModel.value.wordVectors.wordIndex
val wordVectors = bcModel.value.wordVectors.wordVectors
val array = Array.ofDim[Double](size)
var count = 0L
sentence.foreach { word =>
bVectors.value.get(word).foreach { v =>
BLAS.axpy(1.0, v, sum)
wordIndices.get(word).foreach { index =>
val offset = index * size
var i = 0
while (i < size) { array(i) += wordVectors(offset + i); i += 1 }
}
count += 1
}
BLAS.scal(1.0 / sentence.size, sum)
sum
val vec = Vectors.dense(array)
BLAS.scal(1.0 / count, vec)
vec
}
}
dataset.withColumn($(outputCol), word2Vec(col($(inputCol))),

dataset.withColumn($(outputCol), transformer(col($(inputCol))),
outputSchema($(outputCol)).metadata)
}
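The question above about what the udf actually needs can be made concrete with a dependency-free sketch of the new averaging path. The object and method names here are hypothetical, and the flat array plus index map stand in for `wordVectors.wordVectors` and `wordVectors.wordIndex` from the broadcast model:

```scala
object AverageSketch {
  // Average the vectors of the words in a sentence, reading each word's
  // vector directly from the flat Float array at offset index * size,
  // as the patched udf does, instead of looking up per-word dense Vectors
  // in a Map[String, Vector]. Unknown words contribute a zero vector.
  def averageSentence(
      sentence: Seq[String],
      wordIndex: Map[String, Int],
      wordVectors: Array[Float],
      size: Int): Array[Double] = {
    val sum = Array.ofDim[Double](size)
    sentence.foreach { word =>
      wordIndex.get(word).foreach { index =>
        val offset = index * size
        var i = 0
        while (i < size) { sum(i) += wordVectors(offset + i); i += 1 }
      }
    }
    if (sentence.nonEmpty) {
      var i = 0
      while (i < size) { sum(i) /= sentence.size; i += 1 }
    }
    sum
  }
}
```

Reading at `offset = index * size` touches only the single flat `Array[Float]` that is already part of the broadcast model, so no per-word `Vector` objects or a separate `Map[String, Vector]` need to be built and serialized.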

26 changes: 15 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -502,19 +502,19 @@ class Word2VecModel private[spark] (
private val vectorSize = wordVectors.length / numWords

// wordList: Ordered list of words obtained from wordIndex.
private val wordList: Array[String] = {
val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
wl.toArray
private lazy val wordList: Array[String] = {
wordIndex.toSeq.sortBy(_._2).iterator.map(_._1).toArray
}

// wordVecNorms: Array of length numWords, each value being the Euclidean norm
// of the wordVector.
private val wordVecNorms: Array[Float] = {
val wordVecNorms = new Array[Float](numWords)
private lazy val wordVecNorms: Array[Float] = {
Member:

How much does this save, if it only happens once and has to happen to use the model?

Contributor Author:

This var wordVecNorms is only used by the findSynonyms method in the .mllib Word2Vec; however, findSynonyms is never used on the .ml side. So I think we can make it lazy.

Member:

OK, fair enough. There are use cases here that would never need this calculated?

Contributor Author:

> however, this findSynonyms is never used in the .ml side. So I think we can make it lazy.

I was wrong: this var wordVecNorms is used by the findSynonyms and findSynonymsArray methods on the .ml side. Since it is not used in transform, we can still mark it lazy.

val num = numWords
val size = vectorSize
val wordVecNorms = new Array[Float](num)
var i = 0
while (i < numWords) {
val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
Contributor Author:

avoid this slicing

wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
while (i < num) {
wordVecNorms(i) = blas.snrm2(size, wordVectors, i * size, 1)
i += 1
}
wordVecNorms
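The patched loop calls the offset/stride overload `blas.snrm2(size, wordVectors, i * size, 1)` so that no per-word slice is allocated. A plain-Scala sketch of the same offset arithmetic, with a hypothetical object name and `math.sqrt` standing in for the BLAS call:

```scala
object NormSketch {
  // Euclidean norm of each word's vector, read directly from the flat
  // array via offsets rather than slicing out a sub-array per word.
  def wordNorms(wordVectors: Array[Float], numWords: Int, size: Int): Array[Float] = {
    val norms = new Array[Float](numWords)
    var i = 0
    while (i < numWords) {
      val offset = i * size
      var sum = 0.0
      var j = 0
      while (j < size) { val v = wordVectors(offset + j); sum += v * v; j += 1 }
      norms(i) = math.sqrt(sum).toFloat
      i += 1
    }
    norms
  }
}
```

Compared with slicing, this avoids allocating `numWords` temporary arrays just to compute one float per word.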
@@ -538,9 +538,13 @@
@Since("1.1.0")
def transform(word: String): Vector = {
wordIndex.get(word) match {
case Some(ind) =>
val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
Contributor Author:

avoid this slicing

Vectors.dense(vec.map(_.toDouble))
case Some(index) =>
val size = vectorSize
val offset = index * size
val array = Array.ofDim[Double](size)
var i = 0
while (i < size) { array(i) = wordVectors(offset + i); i += 1 }
Member:

Is this actually more efficient than slice? Likewise above.

Contributor Author:

I guess so, I will do a simple test.

Vectors.dense(array)
case None =>
throw new IllegalStateException(s"$word not in vocabulary")
}
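On whether the manual loop beats `slice`: `wordVectors.slice(...)` allocates an intermediate `Array[Float]`, and `.map(_.toDouble)` then allocates and copies again, while the while loop widens Float to Double in a single pass into the final array. A standalone sketch with a hypothetical name:

```scala
object LookupSketch {
  // Copy one word's vector out of the flat Float array into a fresh
  // Double array in one pass, avoiding the two allocations made by
  // wordVectors.slice(...).map(_.toDouble).
  def wordVector(wordVectors: Array[Float], index: Int, size: Int): Array[Double] = {
    val offset = index * size
    val array = Array.ofDim[Double](size)
    var i = 0
    while (i < size) { array(i) = wordVectors(offset + i); i += 1 }
    array
  }
}
```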