Skip to content

Commit

Permalink
Switch back to native blas calls
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Apr 21, 2015
1 parent da1642d commit 6b74c81
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,11 @@ class Word2VecModel private[mllib] (
// Maintain an ordered list of words based on the index in the initial model.
private val wordList: Array[String] = model.keys.toArray
private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
private val numDim = model.head._2.size
private val numWords = wordIndex.size

private val (wordVectors: DenseMatrix, wordVecNorms: Array[Double]) = {
val numDim = model.head._2.size
val numWords = wordIndex.size
val flatVec = model.toSeq.flatMap { case(w, v) =>
v.map(_.toDouble)}.toArray
val wordVectors = new DenseMatrix(numWords, numDim, flatVec, isTransposed=true)
private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
val wordVectors = model.toSeq.flatMap { case (w, v) => v }.toArray
val wordVecNorms = new Array[Double](numWords)
var i = 0
while (i < numWords) {
Expand Down Expand Up @@ -500,9 +498,13 @@ class Word2VecModel private[mllib] (
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")

val numWords = wordVectors.numRows
val cosineVec = Vectors.zeros(numWords).asInstanceOf[DenseVector]
BLAS.gemv(1.0, wordVectors, new DenseVector(vector.toArray), 0.0, cosineVec)
val fVector = vector.toArray.map(_.toFloat)
val cosineVec = Array.fill[Float](numWords)(0)
val alpha: Float = 1
val beta: Float = 1

blas.sgemv(
"T", numDim, numWords, alpha, wordVectors, numDim, fVector, 1, beta, cosineVec, 1)

// No need to divide by the norm of the given vector since it is constant.
val updatedCosines = new Array[Double](numWords)
Expand All @@ -523,11 +525,9 @@ class Word2VecModel private[mllib] (
* Returns a map of words to their vector representations.
*/
def getVectors: Map[String, Array[Float]] = {
val numDim = wordVectors.numCols
wordIndex.map { case (word, ind) =>
val startInd = numDim * ind
val endInd = startInd + numDim
(word, wordVectors.values.slice(startInd, endInd).map(_.toFloat)) }
(word, wordVectors.slice(numDim * ind, numDim * ind + numDim))
}
}
}

Expand Down

0 comments on commit 6b74c81

Please sign in to comment.