Skip to content

Commit

Permalink
Minor
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Apr 21, 2015
1 parent 6b74c81 commit ffc9240
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -431,19 +431,28 @@ class Word2Vec extends Serializable with Logging {
class Word2VecModel private[mllib] (
model: Map[String, Array[Float]]) extends Serializable with Saveable {

// Maintain a ordered list of words based on the index in the initial model.
// wordList: Ordered list of words obtained from model.
// wordIndex: Maps each word to an index, which can retrieve the corresponding
// vector from wordVectors (see below)
// vectorSize: Dimension of each vector.
// numWords: Number of words.
private val wordList: Array[String] = model.keys.toArray
private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
private val numDim = model.head._2.size
private val vectorSize = model.head._2.size
private val numWords = wordIndex.size

// wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
// mapped with index i can be retrieved by the slice
// (ind * vectorSize, ind * vectorSize + vectorSize)
// wordVecNorms: Array of length numWords, each value being the Euclidean norm
// of the wordVector.
private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
val wordVectors = model.toSeq.flatMap { case (w, v) => v }.toArray
val wordVectors = wordList.flatMap(word => model.get(word).get).toArray
val wordVecNorms = new Array[Double](numWords)
var i = 0
while (i < numWords) {
val vec = model.get(wordList(i)).get
wordVecNorms(i) = blas.snrm2(numDim, vec, 1)
wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
i += 1
}
(wordVectors, wordVecNorms)
Expand Down Expand Up @@ -501,10 +510,10 @@ class Word2VecModel private[mllib] (
val fVector = vector.toArray.map(_.toFloat)
val cosineVec = Array.fill[Float](numWords)(0)
val alpha: Float = 1
val beta: Float = 1
val beta: Float = 0

blas.sgemv(
"T", numDim, numWords, alpha, wordVectors, numDim, fVector, 1, beta, cosineVec, 1)
"T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)

// Need not divide with the norm of the given vector since it is constant.
val updatedCosines = new Array[Double](numWords)
Expand All @@ -526,7 +535,7 @@ class Word2VecModel private[mllib] (
*/
def getVectors: Map[String, Array[Float]] = {
wordIndex.map { case (word, ind) =>
(word, wordVectors.slice(numDim * ind, numDim * ind + numDim))
(word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
}
}
}
Expand Down

0 comments on commit ffc9240

Please sign in to comment.