Skip to content

Commit

Permalink
[SPARK-3578] Fix upper bound in GraphGenerators.sampleLogNormal
Browse files Browse the repository at this point in the history
GraphGenerators.sampleLogNormal is supposed to return an integer strictly less than maxVal. However, it violates this guarantee. It generates its return value as follows:

```scala
var X: Double = maxVal

while (X >= maxVal) {
  val Z = rand.nextGaussian()
  X = math.exp(mu + sigma*Z)
}
math.round(X.toFloat)
```

When X is sampled to be close to (but less than) maxVal, then it will pass the while loop condition, but the rounded result will be equal to maxVal, which will violate the guarantee. For example, if maxVal is 5 and X is 4.9, then X < maxVal, but `math.round(X.toFloat)` is 5.

This PR instead rounds X before checking the loop condition, guaranteeing that the condition will hold for the return value.

Author: Ankur Dave <ankurdave@gmail.com>

Closes apache#2439 from ankurdave/SPARK-3578 and squashes the following commits:

f6655e5 [Ankur Dave] Go back to math.floor
5900c22 [Ankur Dave] Round X in loop condition
6fd5fb1 [Ankur Dave] Run sampleLogNormal bounds check 1000 times
1638598 [Ankur Dave] Round down in sampleLogNormal to guarantee upper bound
  • Loading branch information
ankurdave authored and jegonzal committed Sep 22, 2014
1 parent 56dae30 commit f9d6220
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ object GraphGenerators {
val Z = rand.nextGaussian()
X = math.exp(mu + sigma*Z)
}
math.round(X.toFloat)
math.floor(X).toInt
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext {
val sigma = 1.3
val maxVal = 100

val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal)
assert(dstId < maxVal)
val trials = 1000
for (i <- 1 to trials) {
val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal)
assert(dstId < maxVal)
}

val dstId_round1 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345)
val dstId_round2 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345)
Expand Down

0 comments on commit f9d6220

Please sign in to comment.