From cee80c04bc709d9e4730ab1cad69b9b075738a75 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 21 Nov 2015 11:34:52 +0800 Subject: [PATCH 1/3] broadcast global table --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a47f27b0afb14..f353166b22732 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -315,13 +315,15 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) + val bcSyn0Global = sc.broadcast(syn0Global) + val bcSyn1Global = sc.broadcast(syn1Global) var alpha = learningRate for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { + val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount From b1c65a91b6925543b2dc63885c0368e806f4df9d Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 24 Nov 2015 18:25:24 +0800 Subject: [PATCH 2/3] create broadcast in iterations for W2V --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f353166b22732..f9a2bc8e6c8ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -315,10 +315,11 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) - val bcSyn0Global = sc.broadcast(syn0Global) - val bcSyn1Global = sc.broadcast(syn1Global) var alpha = learningRate + for (k <- 1 to numIterations) { + val bcSyn0Global = sc.broadcast(syn0Global) + val bcSyn1Global = sc.broadcast(syn1Global) val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) From cf4b9e798b00ff438c886559ebc162989a68e2bd Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 26 Nov 2015 16:16:31 +0800 Subject: [PATCH 3/3] add broadcast unpersist --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f9a2bc8e6c8ca..655ac0bb5545b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -408,6 +408,8 @@ class Word2Vec extends Serializable with Logging { } i += 1 } + bcSyn0Global.unpersist(false) + bcSyn1Global.unpersist(false) } newSentences.unpersist()