diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 9b5f5a619e02c..0b9c1b570d943 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -285,27 +285,33 @@ class Word2VecModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema, logging = true)
-    val vectors = wordVectors.getVectors
-      .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
-      .map(identity).toMap // mapValues doesn't return a serializable map (SI-7005)
-    val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
-    val d = $(vectorSize)
-    val emptyVec = Vectors.sparse(d, Array.emptyIntArray, Array.emptyDoubleArray)
-    val word2Vec = udf { sentence: Seq[String] =>
+
+    val bcModel = dataset.sparkSession.sparkContext.broadcast(this.wordVectors)
+    val size = $(vectorSize)
+    val emptyVec = Vectors.sparse(size, Array.emptyIntArray, Array.emptyDoubleArray)
+    val transformer = udf { sentence: Seq[String] =>
       if (sentence.isEmpty) {
         emptyVec
       } else {
-        val sum = Vectors.zeros(d)
+        val wordIndices = bcModel.value.wordIndex
+        val wordVectors = bcModel.value.wordVectors
+        val array = Array.ofDim[Double](size)
+        var count = 0
         sentence.foreach { word =>
-          bVectors.value.get(word).foreach { v =>
-            BLAS.axpy(1.0, v, sum)
+          wordIndices.get(word).foreach { index =>
+            val offset = index * size
+            var i = 0
+            while (i < size) { array(i) += wordVectors(offset + i); i += 1 }
           }
+          count += 1
         }
-        BLAS.scal(1.0 / sentence.size, sum)
-        sum
+        val vec = Vectors.dense(array)
+        BLAS.scal(1.0 / count, vec)
+        vec
       }
     }
-    dataset.withColumn($(outputCol), word2Vec(col($(inputCol))),
+
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
       outputSchema($(outputCol)).metadata)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index eeb583f84ca8b..8a6317a910146 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -502,22 +502,15 @@ class Word2VecModel private[spark] (
   private val vectorSize = wordVectors.length / numWords
 
   // wordList: Ordered list of words obtained from wordIndex.
-  private val wordList: Array[String] = {
-    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
-    wl.toArray
+  private lazy val wordList: Array[String] = {
+    wordIndex.toSeq.sortBy(_._2).iterator.map(_._1).toArray
   }
 
   // wordVecNorms: Array of length numWords, each value being the Euclidean norm
   // of the wordVector.
-  private val wordVecNorms: Array[Float] = {
-    val wordVecNorms = new Array[Float](numWords)
-    var i = 0
-    while (i < numWords) {
-      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
-      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
-      i += 1
-    }
-    wordVecNorms
+  private lazy val wordVecNorms: Array[Float] = {
+    val size = vectorSize
+    Array.tabulate(numWords)(i => blas.snrm2(size, wordVectors, i * size, 1))
   }
 
   @Since("1.5.0")
@@ -538,9 +531,13 @@ class Word2VecModel private[spark] (
   @Since("1.1.0")
   def transform(word: String): Vector = {
     wordIndex.get(word) match {
-      case Some(ind) =>
-        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
-        Vectors.dense(vec.map(_.toDouble))
+      case Some(index) =>
+        val size = vectorSize
+        val offset = index * size
+        val array = Array.ofDim[Double](size)
+        var i = 0
+        while (i < size) { array(i) = wordVectors(offset + i); i += 1 }
+        Vectors.dense(array)
       case None =>
         throw new IllegalStateException(s"$word not in vocabulary")
     }
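For context, a minimal smoke-test sketch of the `ml.feature` transform path the first hunk rewrites (the per-row averaging of in-vocabulary word vectors). The local session, sample sentences, and column names `text`/`result` are illustrative only, not part of the patch:

```scala
import org.apache.spark.ml.feature.Word2Vec
import org.apache.spark.sql.SparkSession

object Word2VecTransformCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("w2v-transform-check")
      .getOrCreate()
    import spark.implicits._

    // Tiny corpus: each row is a tokenized sentence.
    val docs = Seq(
      "spark is fast".split(" ").toSeq,
      "word vectors average into sentence vectors".split(" ").toSeq
    ).toDF("text")

    val model = new Word2Vec()
      .setInputCol("text")
      .setOutputCol("result")
      .setVectorSize(3)
      .setMinCount(0)
      .fit(docs)

    // Exercises the rewritten UDF: looks up each word's index in the
    // broadcast model and accumulates its raw float vector in place,
    // instead of materializing a Map[String, Vector] per executor.
    model.transform(docs).show(truncate = false)

    spark.stop()
  }
}
```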