-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[SPARK-33609][ML] word2vec reduce broadcast size #30548
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature | |
import org.apache.hadoop.fs.Path | ||
|
||
import org.apache.spark.annotation.Since | ||
import org.apache.spark.broadcast.Broadcast | ||
import org.apache.spark.internal.config.Kryo.KRYO_SERIALIZER_MAX_BUFFER_SIZE | ||
import org.apache.spark.ml.{Estimator, Model} | ||
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} | ||
|
@@ -278,34 +279,45 @@ class Word2VecModel private[ml] ( | |
@Since("1.4.0") | ||
def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
||
private var bcModel: Broadcast[Word2VecModel] = _ | ||
|
||
/** | ||
* Transform a sentence column to a vector column to represent the whole sentence. The transform | ||
* is performed by averaging all word vectors it contains. | ||
*/ | ||
@Since("2.0.0") | ||
override def transform(dataset: Dataset[_]): DataFrame = { | ||
val outputSchema = transformSchema(dataset.schema, logging = true) | ||
val vectors = wordVectors.getVectors | ||
.mapValues(vv => Vectors.dense(vv.map(_.toDouble))) | ||
.map(identity).toMap // mapValues doesn't return a serializable map (SI-7005) | ||
val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors) | ||
val d = $(vectorSize) | ||
val emptyVec = Vectors.sparse(d, Array.emptyIntArray, Array.emptyDoubleArray) | ||
val word2Vec = udf { sentence: Seq[String] => | ||
|
||
if (bcModel == null) { | ||
bcModel = dataset.sparkSession.sparkContext.broadcast(this) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like you only use this.wordVectors below? maybe just broadcast that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. both There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, there are two There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops, right, I think I meant to say that you only use those two. is there any savings from just broadcasting those rather than the whole model? if not that's fine. |
||
} | ||
|
||
val size = $(vectorSize) | ||
val emptyVec = Vectors.sparse(size, Array.emptyIntArray, Array.emptyDoubleArray) | ||
val transformer = udf { sentence: Seq[String] => | ||
if (sentence.isEmpty) { | ||
emptyVec | ||
} else { | ||
val sum = Vectors.zeros(d) | ||
val wordIndices = bcModel.value.wordVectors.wordIndex | ||
val wordVectors = bcModel.value.wordVectors.wordVectors | ||
val array = Array.ofDim[Double](size) | ||
var count = 0L | ||
sentence.foreach { word => | ||
bVectors.value.get(word).foreach { v => | ||
BLAS.axpy(1.0, v, sum) | ||
wordIndices.get(word).foreach { index => | ||
val offset = index * size | ||
var i = 0 | ||
while (i < size) { array(i) += wordVectors(offset + i); i += 1 } | ||
} | ||
count += 1 | ||
} | ||
BLAS.scal(1.0 / sentence.size, sum) | ||
sum | ||
val vec = Vectors.dense(array) | ||
BLAS.scal(1.0 / count, vec) | ||
vec | ||
} | ||
} | ||
dataset.withColumn($(outputCol), word2Vec(col($(inputCol))), | ||
|
||
dataset.withColumn($(outputCol), transformer(col($(inputCol))), | ||
outputSchema($(outputCol)).metadata) | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -502,19 +502,19 @@ class Word2VecModel private[spark] ( | |
private val vectorSize = wordVectors.length / numWords | ||
|
||
// wordList: Ordered list of words obtained from wordIndex. | ||
private val wordList: Array[String] = { | ||
val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip | ||
wl.toArray | ||
private lazy val wordList: Array[String] = { | ||
wordIndex.toSeq.sortBy(_._2).iterator.map(_._1).toArray | ||
} | ||
|
||
// wordVecNorms: Array of length numWords, each value being the Euclidean norm | ||
// of the wordVector. | ||
private val wordVecNorms: Array[Float] = { | ||
val wordVecNorms = new Array[Float](numWords) | ||
private lazy val wordVecNorms: Array[Float] = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How much does this save, if it only happens once and has to happen to use the model? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this var There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK fair enough. There are use cases here that would never need this calculated? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I am wrong. this var |
||
val num = numWords | ||
val size = vectorSize | ||
val wordVecNorms = new Array[Float](num) | ||
var i = 0 | ||
while (i < numWords) { | ||
val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. avoid this slicing |
||
wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) | ||
while (i < num) { | ||
wordVecNorms(i) = blas.snrm2(size, wordVectors, i * size, 1) | ||
i += 1 | ||
} | ||
wordVecNorms | ||
|
@@ -538,9 +538,13 @@ class Word2VecModel private[spark] ( | |
@Since("1.1.0") | ||
def transform(word: String): Vector = { | ||
wordIndex.get(word) match { | ||
case Some(ind) => | ||
val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. avoid this slicing |
||
Vectors.dense(vec.map(_.toDouble)) | ||
case Some(index) => | ||
val size = vectorSize | ||
val offset = index * size | ||
val array = Array.ofDim[Double](size) | ||
var i = 0 | ||
while (i < size) { array(i) = wordVectors(offset + i); i += 1 } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this actually more efficient than slice? Likewise above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess so, I will do a simple test. |
||
Vectors.dense(array) | ||
case None => | ||
throw new IllegalStateException(s"$word not in vocabulary") | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't suppose we have a way to clean this up after use - will just have to get GCed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes. I followed the impl of CountVectorizer here.
Since other .ml impls do not use a mutable `var` for a broadcast variable like this, I will remove this var.
There was a problem hiding this comment.
Choose a reason for hiding this comment
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As to `CountVectorizer`, should we also remove the `var broadcastDict` in it? It looks like other MLlib impls do not use a mutable broadcast variable like that.