diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index c3623a83fb961..ebfd1299431be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -431,19 +431,28 @@ class Word2Vec extends Serializable with Logging { class Word2VecModel private[mllib] ( model: Map[String, Array[Float]]) extends Serializable with Saveable { - // Maintain a ordered list of words based on the index in the initial model. + // wordList: Ordered list of words obtained from model. + // wordIndex: Maps each word to an index, which can retrieve the corresponding + // vector from wordVectors (see below). + // vectorSize: Dimension of each vector. + // numWords: Number of words. private val wordList: Array[String] = model.keys.toArray private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap - private val numDim = model.head._2.size + private val vectorSize = model.head._2.size private val numWords = wordIndex.size + // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word + // mapped with index i can be retrieved by the slice + // (i * vectorSize, i * vectorSize + vectorSize) + // wordVecNorms: Array of length numWords, each value being the Euclidean norm + // of the wordVector. 
private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = { - val wordVectors = model.toSeq.flatMap { case (w, v) => v }.toArray + val wordVectors = wordList.flatMap(word => model.get(word).get).toArray val wordVecNorms = new Array[Double](numWords) var i = 0 while (i < numWords) { val vec = model.get(wordList(i)).get - wordVecNorms(i) = blas.snrm2(numDim, vec, 1) + wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) i += 1 } (wordVectors, wordVecNorms) @@ -501,10 +510,10 @@ class Word2VecModel private[mllib] ( val fVector = vector.toArray.map(_.toFloat) val cosineVec = Array.fill[Float](numWords)(0) val alpha: Float = 1 - val beta: Float = 1 + val beta: Float = 0 blas.sgemv( - "T", numDim, numWords, alpha, wordVectors, numDim, fVector, 1, beta, cosineVec, 1) + "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) // Need not divide with the norm of the given vector since it is constant. val updatedCosines = new Array[Double](numWords) @@ -526,7 +535,7 @@ class Word2VecModel private[mllib] ( */ def getVectors: Map[String, Array[Float]] = { wordIndex.map { case (word, ind) => - (word, wordVectors.slice(numDim * ind, numDim * ind + numDim)) + (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) } } }