From 2e7f1a3bd519d79ce9b08d388247e9a1d7f67635 Mon Sep 17 00:00:00 2001 From: Benjamin Radford Date: Thu, 9 Mar 2017 23:42:33 -0500 Subject: [PATCH 1/4] Added findAnalogies method to Word2VecModel --- .../apache/spark/mllib/feature/Word2Vec.scala | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 531c8b07910fc..b9d6fa845467c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -609,6 +609,71 @@ class Word2VecModel private[spark] ( filtered.take(num).toArray } + /** + * Find words similar to the words supplied to 'positive' and dissimilar + * to the words supplied to 'negative'. + * @param positive array of words similar to the results list + * @param negative array of words dissimilar to the results list + * @param num number of synonyms to find + * @return array of (word, cosineSimilarity) + */ + def findAnalogies(positive: Array[String] = Array(), + negative: Array[String] = Array(), + num: Int = 1): Array[(String, Double)] = { + require(num > 0, "Number of similar words should be > 0") + require(positive.length > 0 || negative.length > 0, + "Either positive or negative argument must be supplied") + + var positiveVectors = Array[Array[Double]]() + var negativeVectors = Array[Array[Double]]() + + for(pp <- positive) + positiveVectors :+= transform(pp).toArray + for(nn <- negative) + negativeVectors :+= transform(nn).toArray + // Normalize positive and negative vectors before summation + positiveVectors = if (positiveVectors.size > 0) { + positiveVectors.map(x => { + val sumsqr = x.map(y => y * y).reduce((a, b) => a + b) + x.map(y => y / math.pow(sumsqr, .5)) + }) + } else { + Array(Array.fill(vectorSize)(0.0)) + } + negativeVectors = if (negativeVectors.size > 0) { + negativeVectors.map(x => { + val sumsqr = x.map(y => y * y).reduce((a, b) => a + b) + x.map(y => y / math.pow(sumsqr, .5)) + }) + } else { + Array(Array.fill(vectorSize)(0.0)) + } + // Sum positive vectors + val positiveSum = if (positiveVectors.size > 1) { + positiveVectors.reduce((x, y) => { + x.zip(y).map(a => a._1 + a._2) + }) + } else { + positiveVectors(0) + } + // Sum negative vectors + val negativeSum = if (negativeVectors.size > 1) { + negativeVectors.reduce((x, y) => { + x.zip(y).map(a => a._1 + a._2) + }) + } else { + negativeVectors(0) + } + + // Subtract negative vectors from positive vectors + var cosVec = positiveSum.zip(negativeSum).map(a => a._1 - a._2) + val vecnorm = math.pow(cosVec.map(y => y * y).reduce((a, b) => a + b), 0.5) + cosVec = cosVec.map(x => x / vecnorm) + + // Find synonyms of calcualted vector + findSynonyms(Vectors.dense(cosVec), num) + } + /** * Returns a map of words to their vector representations. */ From 9aefebfcd2e6eaad117727901ad70d0d26b03a1a Mon Sep 17 00:00:00 2001 From: Benjamin Radford Date: Fri, 10 Mar 2017 00:16:46 -0500 Subject: [PATCH 2/4] Fixed comment indentation to conform to style guide. --- R/pkg/DESCRIPTION | 2 +- .../org/apache/spark/mllib/feature/Word2Vec.scala | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index cc471edc376b3..062187bc12673 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -54,5 +54,5 @@ Collate: 'types.R' 'utils.R' 'window.R' -RoxygenNote: 5.0.1 +RoxygenNote: 6.0.1 VignetteBuilder: knitr diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b9d6fa845467c..15939360a7825 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -610,13 +610,13 @@ class Word2VecModel private[spark] ( } /** - * Find words similar to the words supplied to 'positive' and dissimilar - * to the words supplied to 'negative'. - * @param positive array of words similar to the results list - * @param negative array of words dissimilar to the results list - * @param num number of synonyms to find - * @return array of (word, cosineSimilarity) - */ + * Find words similar to the words supplied to 'positive' and dissimilar + * to the words supplied to 'negative'. + * @param positive array of words similar to the results list + * @param negative array of words dissimilar to the results list + * @param num number of synonyms to find + * @return array of (word, cosineSimilarity) + */ def findAnalogies(positive: Array[String] = Array(), negative: Array[String] = Array(), num: Int = 1): Array[(String, Double)] = { From a309887ceb2f20e243ffb8bd72925c3b92b31324 Mon Sep 17 00:00:00 2001 From: Benjamin Radford Date: Fri, 10 Mar 2017 11:27:33 -0500 Subject: [PATCH 3/4] Streamlined findAnalogies method. Fixed indentation. Minimized code re-use. --- .../apache/spark/mllib/feature/Word2Vec.scala | 60 +++++-------------- 1 file changed, 16 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 15939360a7825..dd6ce3c6a6c92 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -617,58 +617,30 @@ class Word2VecModel private[spark] ( * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ - def findAnalogies(positive: Array[String] = Array(), - negative: Array[String] = Array(), - num: Int = 1): Array[(String, Double)] = { + def findAnalogies( + positive: Array[String] = Array(), + negative: Array[String] = Array(), + num: Int = 1): Array[(String, Double)] = { require(num > 0, "Number of similar words should be > 0") require(positive.length > 0 || negative.length > 0, "Either positive or negative argument must be supplied") - var positiveVectors = Array[Array[Double]]() - var negativeVectors = Array[Array[Double]]() + var vectors = Array[Array[Double]]() for(pp <- positive) - positiveVectors :+= transform(pp).toArray + vectors :+= transform(pp).toArray for(nn <- negative) - negativeVectors :+= transform(nn).toArray + vectors :+= transform(nn).toArray.map(x => -x) // Normalize positive and negative vectors before summation - positiveVectors = if (positiveVectors.size > 0) { - positiveVectors.map(x => { - val sumsqr = x.map(y => y * y).reduce((a, b) => a + b) - x.map(y => y / math.pow(sumsqr, .5)) - }) - } else { - Array(Array.fill(vectorSize)(0.0)) - } - negativeVectors = if (negativeVectors.size > 0) { - negativeVectors.map(x => { - val sumsqr = x.map(y => y * y).reduce((a, b) => a + b) - x.map(y => y / math.pow(sumsqr, .5)) - }) - } else { - Array(Array.fill(vectorSize)(0.0)) - } - // Sum positive vectors - val positiveSum = if (positiveVectors.size > 1) { - positiveVectors.reduce((x, y) => { - x.zip(y).map(a => a._1 + a._2) - }) - } else { - positiveVectors(0) - } - // Sum negative vectors - val negativeSum = if (negativeVectors.size > 1) { - negativeVectors.reduce((x, y) => { - x.zip(y).map(a => a._1 + a._2) - }) - } else { - negativeVectors(0) - } - - // Subtract negative vectors from positive vectors - var cosVec = positiveSum.zip(negativeSum).map(a => a._1 - a._2) - val vecnorm = math.pow(cosVec.map(y => y * y).reduce((a, b) => a + b), 0.5) - cosVec = cosVec.map(x => x / vecnorm) + vectors = vectors.map(x => { + val norm = blas.snrm2(vectorSize, x.map(_.toFloat), 1) + x.map(y => y / norm) + }) + val vectorSum = vectors.reduce((x, y) => { + x.zip(y).map(a => a._1 + a._2) + }) + val norm = blas.snrm2(vectorSize, vectorSum.map(_.toFloat), 1) + val cosVec = vectorSum.map(x => x / norm) // Find synonyms of calcualted vector findSynonyms(Vectors.dense(cosVec), num) From 2ea62591a38f46a1d06fb796f6a50a7b827b7682 Mon Sep 17 00:00:00 2001 From: benradford Date: Sat, 11 Mar 2017 00:12:21 -0500 Subject: [PATCH 4/4] Fixed Roxygen version 6.0.1 -> 5.0.1 --- R/pkg/DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 062187bc12673..cc471edc376b3 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -54,5 +54,5 @@ Collate: 'types.R' 'utils.R' 'window.R' -RoxygenNote: 6.0.1 +RoxygenNote: 5.0.1 VignetteBuilder: knitr