From 739f5bd40d5cde7e69aeee2c9da4f5d287ac47fa Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 16 Jul 2015 21:03:05 -0700 Subject: [PATCH 1/2] do not assert on time taken by Thread.sleep() --- .../apache/spark/ml/util/stopwatches.scala | 4 +- .../apache/spark/ml/util/StopwatchSuite.scala | 73 ++++++++++++++++--- 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala index 5fdf878a3df72..8d4174124b5c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -67,6 +67,8 @@ private[spark] abstract class Stopwatch extends Serializable { */ def elapsed(): Long + override def toString: String = s"$name: ${elapsed()}ms" + /** * Gets the current time in milliseconds. */ @@ -145,7 +147,7 @@ private[spark] class MultiStopwatch(@transient private val sc: SparkContext) ext override def toString: String = { stopwatches.values.toArray.sortBy(_.name) - .map(c => s" ${c.name}: ${c.elapsed()}ms") + .map(c => s" $c") .mkString("{\n", ",\n", "\n}") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 8df6617fe0228..fcf3907da0a20 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.ml.util +import java.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + import StopwatchSuite._ + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { assert(sw.name === "sw") assert(sw.elapsed() === 0L) @@ -29,18 +33,27 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[AssertionError] { sw.stop() } + val ubStart = now sw.start() - Thread.sleep(50) + val lbStart = now + runTask() + val lb = now - lbStart val duration = sw.stop() - assert(duration >= 50 && duration < 100) // using a loose upper bound + val ub = now - ubStart + assert(duration >= lb && duration <= ub) val elapsed = sw.elapsed() assert(elapsed === duration) + val ubStart2 = now sw.start() - Thread.sleep(50) + val lbStart2 = now + runTask() + val lb2 = now - lbStart2 val duration2 = sw.stop() - assert(duration2 >= 50 && duration2 < 100) + val ub2 = now - ubStart2 + assert(duration2 >= lb2 && duration2 <= ub2) val elapsed2 = sw.elapsed() assert(elapsed2 === duration + duration2) + assert(sw.toString === s"sw: ${elapsed2}ms") sw.start() assert(sw.isRunning) intercept[AssertionError] { @@ -61,14 +74,22 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { test("DistributedStopwatch on executors") { val sw = new DistributedStopwatch(sc, "sw") val rdd = sc.parallelize(0 until 4, 4) + val ubAcc = sc.accumulator(0L) + val lbAcc = sc.accumulator(0L) rdd.foreach { i => + val ubStart = now sw.start() - Thread.sleep(50) + val lbStart = now + runTask() + val lb = now - lbStart sw.stop() + val ub = now - ubStart + lbAcc += lb + ubAcc += ub } assert(!sw.isRunning) val elapsed = sw.elapsed() - assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + assert(elapsed >= lbAcc.value && elapsed <= ubAcc.value) } test("MultiStopwatch") { @@ -81,29 +102,57 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { sw("some") } assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") + val localUbStart = now sw("local").start() + val localLbStart = now + val sparkUbStart = now sw("spark").start() - Thread.sleep(50) + val sparkLbStart = now + runTask() + val localLb = now - localLbStart sw("local").stop() - Thread.sleep(50) + val localUb = now - localUbStart + runTask() + val sparkLb = now - sparkLbStart sw("spark").stop() + val sparkUb = now - sparkUbStart val localElapsed = sw("local").elapsed() val sparkElapsed = sw("spark").elapsed() - assert(localElapsed >= 50 && localElapsed < 100) - assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(localElapsed >= localLb && localElapsed <= localUb) + assert(sparkElapsed >= sparkLb && sparkElapsed <= sparkUb) assert(sw.toString === s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") val rdd = sc.parallelize(0 until 4, 4) + val lbAcc = sc.accumulator(0L) + val ubAcc = sc.accumulator(0L) rdd.foreach { i => sw("local").start() + val ubStart = now sw("spark").start() - Thread.sleep(50) + val lbStart = now + runTask() + val lb = now - lbStart sw("spark").stop() + val ub = now - ubStart sw("local").stop() + lbAcc += lb + ubAcc += ub } val localElapsed2 = sw("local").elapsed() assert(localElapsed2 === localElapsed) val sparkElapsed2 = sw("spark").elapsed() - assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + assert(sparkElapsed2 >= sparkElapsed + lbAcc.value + && sparkElapsed2 <= sparkElapsed + ubAcc.value) } } + +private object StopwatchSuite { + + /** Runs a task that takes a random time. */ + def runTask(): Unit = { + Thread.sleep(new Random().nextInt(10)) + } + + /** The current time in milliseconds. */ + def now: Long = System.currentTimeMillis() +} From 4b40faa0903497889b791939cb434c9de2cc3d2a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 16 Jul 2015 22:16:57 -0700 Subject: [PATCH 2/2] simplify tests --- .../apache/spark/ml/util/StopwatchSuite.scala | 89 ++++++------------- 1 file changed, 28 insertions(+), 61 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index fcf3907da0a20..9e6bc7193c13b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -33,24 +33,10 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[AssertionError] { sw.stop() } - val ubStart = now - sw.start() - val lbStart = now - runTask() - val lb = now - lbStart - val duration = sw.stop() - val ub = now - ubStart - assert(duration >= lb && duration <= ub) + val duration = checkStopwatch(sw) val elapsed = sw.elapsed() assert(elapsed === duration) - val ubStart2 = now - sw.start() - val lbStart2 = now - runTask() - val lb2 = now - lbStart2 - val duration2 = sw.stop() - val ub2 = now - ubStart2 - assert(duration2 >= lb2 && duration2 <= ub2) + val duration2 = checkStopwatch(sw) val elapsed2 = sw.elapsed() assert(elapsed2 === duration + duration2) assert(sw.toString === s"sw: ${elapsed2}ms") @@ -74,22 +60,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { test("DistributedStopwatch on executors") { val sw = new DistributedStopwatch(sc, "sw") val rdd = sc.parallelize(0 until 4, 4) - val ubAcc = sc.accumulator(0L) - val lbAcc = sc.accumulator(0L) + val acc = sc.accumulator(0L) rdd.foreach { i => - val ubStart = now - sw.start() - val lbStart = now - runTask() - val lb = now - lbStart - sw.stop() - val ub = now - ubStart - lbAcc += lb - ubAcc += ub + acc += checkStopwatch(sw) } assert(!sw.isRunning) val elapsed = sw.elapsed() - assert(elapsed >= lbAcc.value && elapsed <= ubAcc.value) + assert(elapsed === acc.value) } test("MultiStopwatch") { @@ -102,57 +79,47 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { sw("some") } assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") - val localUbStart = now - sw("local").start() - val localLbStart = now - val sparkUbStart = now - sw("spark").start() - val sparkLbStart = now - runTask() - val localLb = now - localLbStart - sw("local").stop() - val localUb = now - localUbStart - runTask() - val sparkLb = now - sparkLbStart - sw("spark").stop() - val sparkUb = now - sparkUbStart + val localDuration = checkStopwatch(sw("local")) + val sparkDuration = checkStopwatch(sw("spark")) val localElapsed = sw("local").elapsed() val sparkElapsed = sw("spark").elapsed() - assert(localElapsed >= localLb && localElapsed <= localUb) - assert(sparkElapsed >= sparkLb && sparkElapsed <= sparkUb) + assert(localElapsed === localDuration) + assert(sparkElapsed === sparkDuration) assert(sw.toString === s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") val rdd = sc.parallelize(0 until 4, 4) - val lbAcc = sc.accumulator(0L) - val ubAcc = sc.accumulator(0L) + val acc = sc.accumulator(0L) rdd.foreach { i => sw("local").start() - val ubStart = now - sw("spark").start() - val lbStart = now - runTask() - val lb = now - lbStart - sw("spark").stop() - val ub = now - ubStart + val duration = checkStopwatch(sw("spark")) sw("local").stop() - lbAcc += lb - ubAcc += ub + acc += duration } val localElapsed2 = sw("local").elapsed() assert(localElapsed2 === localElapsed) val sparkElapsed2 = sw("spark").elapsed() - assert(sparkElapsed2 >= sparkElapsed + lbAcc.value - && sparkElapsed2 <= sparkElapsed + ubAcc.value) + assert(sparkElapsed2 === sparkElapsed + acc.value) } } -private object StopwatchSuite { +private object StopwatchSuite extends SparkFunSuite { - /** Runs a task that takes a random time. */ - def runTask(): Unit = { + /** + * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and + * returns the duration reported by the stopwatch. + */ + def checkStopwatch(sw: Stopwatch): Long = { + val ubStart = now + sw.start() + val lbStart = now Thread.sleep(new Random().nextInt(10)) + val lb = now - lbStart + val duration = sw.stop() + val ub = now - ubStart + assert(duration >= lb && duration <= ub) + duration } /** The current time in milliseconds. */ - def now: Long = System.currentTimeMillis() + private def now: Long = System.currentTimeMillis() }