[SPARK-24666][ML] Fix infinity vectors produced by Word2Vec when numIterations are large

### What changes were proposed in this pull request?

This patch adds normalization of the aggregated word vectors when fitting a Word2Vec model.

### Why are the changes needed?

Running Word2Vec on some datasets, when numIterations is large, can produce infinity word vectors.
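The failure mode is ordinary IEEE-754 saturation: once a running `Float` sum passes `Float.MaxValue`, it becomes `Infinity` and stays there. A minimal sketch outside Spark (the object name, step size, and iteration count are illustrative, not values from an actual training run):

```scala
object FloatOverflowDemo {
  // Repeatedly adding large partial weights into a Float accumulator,
  // as un-normalized aggregation across many partitions and iterations
  // effectively does, saturates to Infinity once the sum passes
  // Float.MaxValue (~3.4e38).
  def accumulate(times: Int, step: Float): Float = {
    var acc = 0.0f
    for (_ <- 1 to times) acc += step
    acc
  }

  def main(args: Array[String]): Unit =
    println(accumulate(10, Float.MaxValue / 4)) // prints Infinity
}
```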

### Does this PR introduce any user-facing change?

Yes. After this patch, Word2Vec won't produce infinity word vectors.

### How was this patch tested?

Manually. The issue does not reproduce reliably on arbitrary datasets, and the dataset known to reproduce it is too large (925M) to upload. The script used was:

```scala
case class Sentences(name: String, words: Array[String])
val dataset = spark.read
  .option("header", "true").option("sep", "\t")
  .option("quote", "").option("nullValue", "\\N")
  .csv("/tmp/title.akas.tsv")
  .filter("region = 'US' or language = 'en'")
  .select("title")
  .as[String]
  .map(s => Sentences(s, s.split(' ')))
  .persist()

println("Training model...")
val word2Vec = new Word2Vec()
  .setInputCol("words")
  .setOutputCol("vector")
  .setVectorSize(64)
  .setWindowSize(4)
  .setNumPartitions(50)
  .setMinCount(5)
  .setMaxIter(30)
val model = word2Vec.fit(dataset)
model.getVectors.show()
```

Before:
```
Training model...
+-------------+--------------------+
|         word|              vector|
+-------------+--------------------+
|     Unspoken|[-Infinity,-Infin...|
|       Talent|[-Infinity,Infini...|
|    Hourglass|[2.02805806500023...|
|Nickelodeon's|[-4.2918617120906...|
|      Priests|[-1.3570403355926...|
|    Religion:|[-6.7049072282803...|
|           Bu|[5.05591774315586...|
|      Totoro:|[-1.0539840178632...|
|     Trouble,|[-3.5363592836003...|
|       Hatter|[4.90413981352826...|
|          '79|[7.50436471285412...|
|         Vile|[-2.9147142985312...|
|         9/11|[-Infinity,Infini...|
|      Santino|[1.30005911270850...|
|      Motives|[-1.2538958306253...|
|          '13|[-4.5040152427657...|
|       Fierce|[Infinity,Infinit...|
|       Stover|[-2.6326895394029...|
|          'It|[1.66574533864436...|
|        Butts|[Infinity,Infinit...|
+-------------+--------------------+
only showing top 20 rows
```

After:
```
Training model...
+-------------+--------------------+
|         word|              vector|
+-------------+--------------------+
|     Unspoken|[-0.0454501919448...|
|       Talent|[-0.2657704949378...|
|    Hourglass|[-0.1399687677621...|
|Nickelodeon's|[-0.1767119318246...|
|      Priests|[-0.0047509293071...|
|    Religion:|[-0.0411605164408...|
|           Bu|[0.11837736517190...|
|      Totoro:|[0.05258282646536...|
|     Trouble,|[0.09482011198997...|
|       Hatter|[0.06040831282734...|
|          '79|[0.04783720895648...|
|         Vile|[-0.0017210749210...|
|         9/11|[-0.0713915303349...|
|      Santino|[-0.0412711687386...|
|      Motives|[-0.0492418706417...|
|          '13|[-0.0073119504377...|
|       Fierce|[-0.0565455369651...|
|       Stover|[0.06938160210847...|
|          'It|[0.01117012929171...|
|        Butts|[0.05374567210674...|
+-------------+--------------------+
only showing top 20 rows
```
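A quick way to verify the fixed behavior is to scan the collected vectors for non-finite components. A hedged, non-Spark sketch (`wordsWithInfinities` is a hypothetical helper; in practice the input would be `model.getVectors` collected to the driver):

```scala
object InfinityCheckDemo {
  // Returns the set of words whose vectors contain any non-finite
  // component (Infinity, -Infinity, or NaN).
  def wordsWithInfinities(vectors: Map[String, Array[Float]]): Set[String] =
    vectors.collect {
      case (word, vec) if vec.exists(v => v.isInfinite || v.isNaN) => word
    }.toSet
}
```

After the patch, this set should be empty for the model trained above.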

Closes #26722 from viirya/SPARK-24666-2.

Lead-authored-by: Liang-Chi Hsieh <liangchi@uber.com>
Co-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <liangchi@uber.com>
viirya committed Dec 6, 2019 · commit 755d889 · 1 parent 7782b61
2 changed files with 14 additions and 11 deletions.
`mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala` (14 additions, 3 deletions):

```diff
@@ -439,9 +439,20 @@ class Word2Vec extends Serializable with Logging {
         }
       }.flatten
     }
-    val synAgg = partial.reduceByKey { case (v1, v2) =>
-      blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
-      v1
+    // SPARK-24666: do normalization for aggregating weights from partitions.
+    // Original Word2Vec either single-thread or multi-thread which do Hogwild-style aggregation.
+    // Our approach needs to do extra normalization, otherwise adding weights continuously may
+    // cause overflow on float and lead to infinity/-infinity weights.
+    val synAgg = partial.mapPartitions { iter =>
+      iter.map { case (id, vec) =>
+        (id, (vec, 1))
+      }
+    }.reduceByKey { case ((v1, count1), (v2, count2)) =>
+      blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+      (v1, count1 + count2)
+    }.map { case (id, (vec, count)) =>
+      blas.sscal(vectorSize, 1.0f / count, vec, 1)
+      (id, vec)
     }.collect()
     var i = 0
     while (i < synAgg.length) {
```
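The aggregation pattern in the patch — pair each partial vector with a count of 1, sum vectors and counts together, then scale by the reciprocal of the count — can be sketched on plain Scala collections (a simplified, non-Spark illustration; `averageByKey` is a hypothetical helper, and the real code uses BLAS `saxpy`/`sscal` on an RDD):

```scala
object AverageByKeyDemo {
  // Averages partial vectors per key: tags each vector with a count of 1,
  // sums vectors element-wise while summing counts, then divides by the
  // total count, mirroring the mapPartitions -> reduceByKey -> map
  // pipeline in the patch.
  def averageByKey(partials: Seq[(Int, Array[Float])]): Map[Int, Array[Float]] =
    partials
      .map { case (id, vec) => (id, (vec, 1)) }
      .groupBy(_._1)
      .map { case (id, group) =>
        val (sum, count) = group.map(_._2).reduce {
          case ((v1, c1), (v2, c2)) =>
            (v1.zip(v2).map { case (x, y) => x + y }, c1 + c2)
        }
        id -> sum.map(_ / count)
      }
}
```

Averaging keeps per-word weights on the same scale no matter how many partitions contributed, so repeated additions can no longer push `Float` values toward `Infinity`.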
`Word2VecSuite.scala` (8 deletions, removing the hard-coded expected vectors that the normalization invalidates):

```diff
@@ -78,14 +78,6 @@ class Word2VecSuite extends MLTest with DefaultReadWriteTest {
   test("getVectors") {
     val sentence = "a b " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
-
-    val codes = Map(
-      "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
-      "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
-      "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
-    )
-    val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) }
-
     val docDF = doc.zip(doc).toDF("text", "alsotext")

     val model = new Word2Vec()
```
