Skip to content

Commit

Permalink
[SPARK-22332][ML][TEST] Fix NaiveBayes unit test occasionly fail (cau…
Browse files Browse the repository at this point in the history
…se by test dataset not deterministic)

## What changes were proposed in this pull request?

Fix NaiveBayes unit test occasionly fail:
Set seed for `BrzMultinomial.sample`, make `generateNaiveBayesInput` output deterministic dataset.
(If we do not set seed, the generated dataset will be random, and the model will be possible to exceed the tolerance in the test, which trigger this failure)

## How was this patch tested?

Manually run tests multiple times and check each time output models contains the same values.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19558 from WeichenXu123/fix_nb_test_seed.
  • Loading branch information
WeichenXu123 authored and jkbradley committed Oct 25, 2017
1 parent b377ef1 commit 841f1d7
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import scala.util.Random

import breeze.linalg.{DenseVector => BDV, Vector => BV}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis}

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial}
Expand Down Expand Up @@ -335,6 +335,7 @@ object NaiveBayesSuite {
val _pi = pi.map(math.exp)
val _theta = theta.map(row => row.map(math.exp))

implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed)
for (i <- 0 until nPoints) yield {
val y = calcLabel(rnd.nextDouble(), _pi)
val xi = modelType match {
Expand Down

0 comments on commit 841f1d7

Please sign in to comment.