From 1a66c0b1879fffb9079f190ae6234c10f3a83e69 Mon Sep 17 00:00:00 2001 From: Yun Ni Date: Thu, 16 Feb 2017 12:54:22 -0800 Subject: [PATCH 01/61] Scala API Change for AND-amplification --- .../feature/BucketedRandomProjectionLSH.scala | 8 +++++--- .../org/apache/spark/ml/feature/LSH.scala | 19 ++++++++++++++++++- .../apache/spark/ml/feature/MinHashLSH.scala | 8 +++++--- .../BucketedRandomProjectionLSHSuite.scala | 9 ++++++++- .../spark/ml/feature/MinHashLSHSuite.scala | 15 +++++++++------ 5 files changed, 45 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index cbac16345a292..a0141bef935b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -79,8 +79,7 @@ class BucketedRandomProjectionLSHModel private[ml]( val hashValues: Array[Double] = randUnitVectors.map({ randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength)) }) - // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 - hashValues.map(Vectors.dense(_)) + hashValues.grouped($(numHashFunctions)).map(Vectors.dense).toArray } } @@ -137,6 +136,9 @@ class BucketedRandomProjectionLSH(override val uid: String) @Since("2.1.0") override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value) + @Since("2.2.0") + override def setNumHashFunctions(value: Int): this.type = super.setNumHashFunctions(value) + @Since("2.1.0") def this() = { this(Identifiable.randomUID("brp-lsh")) @@ -155,7 +157,7 @@ class BucketedRandomProjectionLSH(override val uid: String) inputDim: Int): BucketedRandomProjectionLSHModel = { val rand = new Random($(seed)) val randUnitVectors: Array[Vector] = { - Array.fill($(numHashTables)) { + Array.fill($(numHashTables) * $(numHashFunctions)) { val randArray = Array.fill(inputDim)(rand.nextGaussian()) Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 1c9f47a0b201d..3a632b3a5baca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -43,10 +43,24 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol { "tables, where increasing number of hash tables lowers the false negative rate, and " + "decreasing it improves the running performance", ParamValidators.gt(0)) + /** + * Param for the number of hash functions used in LSH AND-amplification. + * + * LSH AND-amplification can be used to reduce the false positive rate. Higher values for this + * param lead to a reduced false positive rate and lower computational complexity. 
+ * @group param + */ + final val numHashFunctions: IntParam = new IntParam(this, "numHashFunctions", "number of hash " + + "functions, where increasing number of hash functions lowers the false positive rate, and " + + "decreasing it improves the false negative rate", ParamValidators.gt(0)) + /** @group getParam */ final def getNumHashTables: Int = $(numHashTables) - setDefault(numHashTables -> 1) + /** @group getParam */ + final def getNumHashFunctions: Int = $(numHashFunctions) + + setDefault(numHashTables -> 1, numHashFunctions -> 1) /** * Transform the Schema for LSH @@ -308,6 +322,9 @@ private[ml] abstract class LSH[T <: LSHModel[T]] /** @group setParam */ def setNumHashTables(value: Int): this.type = set(numHashTables, value) + /** @group setParam */ + def setNumHashFunctions(value: Int): this.type = set(numHashFunctions, value) + /** * Validate and create a new instance of concrete LSHModel. Because different LSHModel may have * different initial setting, developer needs to define how their LSHModel is created instead of diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 620e1fbb09ff7..4c8ab330d13ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -61,8 +61,7 @@ class MinHashLSHModel private[ml]( ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME }.min.toDouble } - // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 - hashValues.map(Vectors.dense(_)) + hashValues.grouped($(numHashFunctions)).map(Vectors.dense).toArray } } @@ -119,6 +118,9 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with Has @Since("2.1.0") override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value) + @Since("2.2.0") + override def setNumHashFunctions(value: Int): this.type = super.setNumHashFunctions(value) + @Since("2.1.0") def this() = { this(Identifiable.randomUID("mh-lsh")) @@ -133,7 +135,7 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with Has require(inputDim <= MinHashLSH.HASH_PRIME, s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.HASH_PRIME}.") val rand = new Random($(seed)) - val randCoefs: Array[(Int, Int)] = Array.fill($(numHashTables)) { + val randCoefs: Array[(Int, Int)] = Array.fill($(numHashTables) * $(numHashFunctions)) { (1 + rand.nextInt(MinHashLSH.HASH_PRIME - 1), rand.nextInt(MinHashLSH.HASH_PRIME - 1)) } new MinHashLSHModel(uid, randCoefs) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index ab937685a555c..9c5929fce156e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -52,6 +52,7 @@ class BucketedRandomProjectionLSHSuite test("BucketedRandomProjectionLSH: default params") { val brp = new BucketedRandomProjectionLSH assert(brp.getNumHashTables === 1.0) + assert(brp.getNumHashFunctions === 1.0) } test("read/write") { @@ -85,6 +86,7 @@ class BucketedRandomProjectionLSHSuite test("BucketedRandomProjectionLSH: randUnitVectors") { val brp = new BucketedRandomProjectionLSH() .setNumHashTables(20) + .setNumHashFunctions(10) .setInputCol("keys") .setOutputCol("values") .setBucketLength(1.0) @@ 
-119,6 +121,7 @@ class BucketedRandomProjectionLSHSuite // Project from 100 dimensional Euclidean Space to 10 dimensions val brp = new BucketedRandomProjectionLSH() .setNumHashTables(10) + .setNumHashFunctions(5) .setInputCol("keys") .setOutputCol("values") .setBucketLength(2.5) @@ -133,7 +136,8 @@ class BucketedRandomProjectionLSHSuite val key = Vectors.dense(1.2, 3.4) val brp = new BucketedRandomProjectionLSH() - .setNumHashTables(2) + .setNumHashTables(8) + .setNumHashFunctions(2) .setInputCol("keys") .setOutputCol("values") .setBucketLength(4.0) @@ -150,6 +154,7 @@ class BucketedRandomProjectionLSHSuite val brp = new BucketedRandomProjectionLSH() .setNumHashTables(20) + .setNumHashFunctions(10) .setInputCol("keys") .setOutputCol("values") .setBucketLength(1.0) @@ -182,6 +187,7 @@ class BucketedRandomProjectionLSHSuite val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys") val brp = new BucketedRandomProjectionLSH() + .setNumHashTables(8) .setNumHashTables(2) .setInputCol("keys") .setOutputCol("values") @@ -200,6 +206,7 @@ class BucketedRandomProjectionLSHSuite val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") val brp = new BucketedRandomProjectionLSH() + .setNumHashTables(8) .setNumHashTables(2) .setInputCol("keys") .setOutputCol("values") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 3461cdf82460f..8bac0b628136a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -44,8 +44,9 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } test("MinHashLSH: default params") { - val rp = new MinHashLSH - assert(rp.getNumHashTables === 1.0) + val mh = new MinHashLSH + assert(mh.getNumHashTables === 1.0) + assert(mh.getNumHashFunctions === 1.0) } test("read/write") { @@ -109,7 +110,8 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa test("approxNearestNeighbors for min hash") { val mh = new MinHashLSH() - .setNumHashTables(20) + .setNumHashTables(64) + .setNumHashFunctions(2) .setInputCol("keys") .setOutputCol("values") .setSeed(12345) @@ -119,8 +121,8 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20, singleProbe = true) - assert(precision >= 0.7) - assert(recall >= 0.7) + assert(precision >= 0.6) + assert(recall >= 0.6) } test("approxNearestNeighbors for numNeighbors <= 0") { @@ -149,7 +151,8 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys") val mh = new MinHashLSH() - .setNumHashTables(20) + .setNumHashTables(64) + .setNumHashFunctions(2) .setInputCol("keys") .setOutputCol("values") .setSeed(12345) From afdebf68881bd5b711a62143d502249ac26b4e67 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 19 Feb 2017 04:24:11 -0800 Subject: [PATCH 02/61] [SPARK-19550][BUILD][WIP] Addendum: select Java 1.7 for scalac 2.10, still ## What changes were proposed in this pull request? Go back to selecting source/target 1.7 for Scala 2.10 builds, because the SBT-based build for 2.10 won't work otherwise. ## How was this patch tested? 
Existing tests, but, we need to verify this vs what the SBT build would exactly run on Jenkins Author: Sean Owen Closes #16983 from srowen/SPARK-19550.3. --- project/SparkBuild.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b48879faa4fb8..93a31897c9fc1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -234,7 +234,9 @@ object SparkBuild extends PomBuild { }, javacJVMVersion := "1.8", - scalacJVMVersion := "1.8", + // SBT Scala 2.10 build still doesn't support Java 8, because scalac 2.10 doesn't, but, + // it also doesn't touch Java 8 code and it's OK to emit Java 7 bytecode in this case + scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8"), javacOptions in Compile ++= Seq( "-encoding", "UTF-8", From 486607abd60844dbadde861015282139bf021c7c Mon Sep 17 00:00:00 2001 From: jinxing Date: Sun, 19 Feb 2017 04:34:07 -0800 Subject: [PATCH 03/61] [SPARK-19450] Replace askWithRetry with askSync. ## What changes were proposed in this pull request? `askSync` is already added in `RpcEndpointRef` (see SPARK-19347 and https://github.com/apache/spark/pull/16690#issuecomment-276850068) and `askWithRetry` is marked as deprecated. As mentioned SPARK-18113(https://github.com/apache/spark/pull/16503#event-927953218): >askWithRetry is basically an unneeded API, and a leftover from the akka days that doesn't make sense anymore. It's prone to cause deadlocks (exactly because it's blocking), it imposes restrictions on the caller (e.g. idempotency) and other things that people generally don't pay that much attention to when using it. Since `askWithRetry` is just used inside spark and not in user logic. It might make sense to replace all of them with `askSync`. ## How was this patch tested? This PR doesn't change code logic, existing unit test can cover. Author: jinxing Closes #16790 from jinxing64/SPARK-19450. 
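For quick reference, a minimal sketch of the conversion applied throughout this patch. It is not taken from the patch itself: `StatusRequest`, `AskSyncSketch`, and `askStatus` are placeholder names, and the snippet assumes it lives under the `org.apache.spark` package because `RpcEndpointRef` is `private[spark]`.

```scala
package org.apache.spark.sketch

import org.apache.spark.rpc.RpcEndpointRef

// Placeholder request message, for illustration only.
case object StatusRequest

object AskSyncSketch {
  def askStatus(ref: RpcEndpointRef): Boolean = {
    // Before (deprecated): a blocking ask that retried internally, which
    // required the receiving endpoint to handle the message idempotently.
    // ref.askWithRetry[Boolean](StatusRequest)

    // After: a single blocking ask; a timeout or failure propagates to the caller.
    ref.askSync[Boolean](StatusRequest)
  }
}
```

Dropping the internal retry loop means callers that still need retry semantics must implement them explicitly, but it removes the idempotency requirement on receivers for the common single-attempt case.
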
--- .../org/apache/spark/MapOutputTracker.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/deploy/Client.scala | 2 +- .../apache/spark/deploy/master/Master.scala | 2 +- .../deploy/master/ui/ApplicationPage.scala | 2 +- .../spark/deploy/master/ui/MasterPage.scala | 2 +- .../deploy/rest/StandaloneRestServer.scala | 6 +- .../spark/deploy/worker/ui/WorkerPage.scala | 4 +- .../CoarseGrainedExecutorBackend.scala | 2 +- .../org/apache/spark/executor/Executor.scala | 2 +- .../org/apache/spark/rpc/RpcEndpointRef.scala | 60 ------------------- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 4 +- .../spark/storage/BlockManagerMaster.scala | 32 +++++----- .../StandaloneDynamicAllocationSuite.scala | 4 +- .../spark/deploy/client/AppClientSuite.scala | 2 +- .../spark/deploy/master/MasterSuite.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 21 ++++--- .../spark/storage/BlockManagerSuite.scala | 2 +- ...osCoarseGrainedSchedulerBackendSuite.scala | 2 +- .../spark/deploy/yarn/YarnAllocator.scala | 2 +- .../state/StateStoreCoordinator.scala | 8 +-- .../receiver/ReceiverSupervisorImpl.scala | 4 +- .../streaming/scheduler/ReceiverTracker.scala | 6 +- 24 files changed, 58 insertions(+), 119 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4ca442b629fd9..4ef6656222455 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -99,7 +99,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected def askTracker[T: ClassTag](message: Any): T = { try { - trackerEndpoint.askWithRetry[T](message) + trackerEndpoint.askSync[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 7e564061e69bb..e4d83893e740e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -605,7 +605,7 @@ class SparkContext(config: SparkConf) extends Logging { Some(Utils.getThreadDump()) } else { val endpointRef = env.blockManager.master.getExecutorEndpointRef(executorId).get - Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) + Some(endpointRef.askSync[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index a4de3d7eaf458..bf6093236d92b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -123,7 +123,7 @@ private class ClientEndpoint( Thread.sleep(5000) logInfo("... 
polling master for driver state") val statusResponse = - activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) + activeMasterEndpoint.askSync[DriverStatusResponse](RequestDriverStatus(driverId)) if (statusResponse.found) { logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c5f7c077fe202..816bf37e39fee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -1045,7 +1045,7 @@ private[deploy] object Master extends Logging { val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) - val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + val portsResponse = masterEndpoint.askSync[BoundPortsResponse](BoundPortsRequest) (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 18cff3125d6b4..946a92882141c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -34,7 +34,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val state = master.askWithRetry[MasterStateResponse](RequestMasterState) + val state = master.askSync[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId) .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index ebbbbd3b715b0..7dbe32975435d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -33,7 +33,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private val master = parent.masterEndpointRef def getMasterState: MasterStateResponse = { - master.askWithRetry[MasterStateResponse](RequestMasterState) + master.askSync[MasterStateResponse](RequestMasterState) } override def renderJson(request: HttpServletRequest): JValue = { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index c19296c7b3e00..56620064c57fa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -71,7 +71,7 @@ private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { - val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + val response = masterEndpoint.askSync[DeployMessages.KillDriverResponse]( 
DeployMessages.RequestKillDriver(submissionId)) val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion @@ -89,7 +89,7 @@ private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRe extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + val response = masterEndpoint.askSync[DeployMessages.DriverStatusResponse]( DeployMessages.RequestDriverStatus(submissionId)) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } val d = new SubmissionStatusResponse @@ -174,7 +174,7 @@ private[rest] class StandaloneSubmitRequestServlet( requestMessage match { case submitRequest: CreateSubmissionRequest => val driverDescription = buildDriverDescription(submitRequest) - val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + val response = masterEndpoint.askSync[DeployMessages.SubmitDriverResponse]( DeployMessages.RequestSubmitDriver(driverDescription)) val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 8ebcbcb6a1738..1ad973122b609 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -34,12 +34,12 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) + val workerState = workerEndpoint.askSync[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) + val workerState = workerEndpoint.askSync[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 4a38560d8deca..b376ecd301eab 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -199,7 +199,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { new SecurityManager(executorConf), clientMode = true) val driver = fetcher.setupEndpointRefByURI(driverUrl) - val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig) + val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig) val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index db5d0d85ceb87..d762f11125516 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -677,7 +677,7 @@ private[spark] class Executor( val message = 
Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) if (response.reregisterBlockManager) { logInfo("Told to re-register on heartbeat") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index a5778876d4901..4d39f144dd198 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -92,64 +92,4 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf) timeout.awaitResult(future) } - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a - * default timeout, throw a SparkException if this fails even after the default number of retries. - * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this - * method retries, the message handling in the receiver side should be idempotent. - * - * Note: this is a blocking action which may cost a lot of time, so don't call it in a message - * loop of [[RpcEndpoint]]. - * - * @param message the message to send - * @tparam T type of the reply message - * @return the reply message from the corresponding [[RpcEndpoint]] - */ - @deprecated("use 'askSync' instead.", "2.2.0") - def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout) - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a - * specified timeout, throw a SparkException if this fails even after the specified number of - * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method - * retries, the message handling in the receiver side should be idempotent. - * - * Note: this is a blocking action which may cost a lot of time, so don't call it in a message - * loop of [[RpcEndpoint]]. 
- * - * @param message the message to send - * @param timeout the timeout duration - * @tparam T type of the reply message - * @return the reply message from the corresponding [[RpcEndpoint]] - */ - @deprecated("use 'askSync' instead.", "2.2.0") - def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = { - // TODO: Consider removing multiple attempts - var attempts = 0 - var lastException: Exception = null - while (attempts < maxRetries) { - attempts += 1 - try { - val future = ask[T](message, timeout) - val result = timeout.awaitResult(future) - if (result == null) { - throw new SparkException("RpcEndpoint returned null") - } - return result - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - - if (attempts < maxRetries) { - Thread.sleep(retryWaitMs) - } - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 0b7d3716c19da..692ed8083475c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -232,7 +232,7 @@ class DAGScheduler( accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates)) - blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( + blockManagerMaster.driverEndpoint.askSync[Boolean]( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index e006cc96569af..94abe30bb12f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -372,7 +372,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp try { if (driverEndpoint != null) { logInfo("Shutting down all executors") - driverEndpoint.askWithRetry[Boolean](StopExecutors) + driverEndpoint.askSync[Boolean](StopExecutors) } } catch { case e: Exception => @@ -384,7 +384,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp stopExecutors() try { if (driverEndpoint != null) { - driverEndpoint.askWithRetry[Boolean](StopDriver) + driverEndpoint.askSync[Boolean](StopDriver) } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7a600068912b1..3ca690db9e79f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -60,7 +60,7 @@ class BlockManagerMaster( maxMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { logInfo(s"Registering BlockManager $blockManagerId") - val updatedId = driverEndpoint.askWithRetry[BlockManagerId]( + val updatedId = driverEndpoint.askSync[BlockManagerId]( RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) logInfo(s"Registered BlockManager 
$updatedId") updatedId @@ -72,7 +72,7 @@ class BlockManagerMaster( storageLevel: StorageLevel, memSize: Long, diskSize: Long): Boolean = { - val res = driverEndpoint.askWithRetry[Boolean]( + val res = driverEndpoint.askSync[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) logDebug(s"Updated info of block $blockId") res @@ -80,12 +80,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { - driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]]( + driverEndpoint.askSync[IndexedSeq[Seq[BlockManagerId]]]( GetLocationsMultipleBlockIds(blockIds)) } @@ -99,11 +99,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId)) } def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { - driverEndpoint.askWithRetry[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) + driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } /** @@ -111,12 +111,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - driverEndpoint.askWithRetry[Boolean](RemoveBlock(blockId)) + driverEndpoint.askSync[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askSync[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) @@ -128,7 +128,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askSync[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) @@ -140,7 +140,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = driverEndpoint.askWithRetry[Future[Seq[Int]]]( + val future = driverEndpoint.askSync[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -159,11 +159,11 @@ class BlockManagerMaster( * amount of remaining memory. 
*/ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - driverEndpoint.askWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askSync[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - driverEndpoint.askWithRetry[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askSync[Array[StorageStatus]](GetStorageStatus) } /** @@ -184,7 +184,7 @@ class BlockManagerMaster( * master endpoint for a response to a prior message. */ val response = driverEndpoint. - askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + askSync[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip implicit val sameThread = ThreadUtils.sameThread val cbf = @@ -214,7 +214,7 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askSync[Future[Seq[BlockId]]](msg) timeout.awaitResult(future) } @@ -223,7 +223,7 @@ class BlockManagerMaster( * since they are not reported the master. */ def hasCachedBlocks(executorId: String): Boolean = { - driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId)) + driverEndpoint.askSync[Boolean](HasCachedBlocks(executorId)) } /** Stop the driver endpoint, called only on the Spark driver node */ @@ -237,7 +237,7 @@ class BlockManagerMaster( /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!driverEndpoint.askWithRetry[Boolean](message)) { + if (!driverEndpoint.askSync[Boolean](message)) { throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 54ea72737c5b2..9839dcf8535db 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -547,7 +547,7 @@ class StandaloneDynamicAllocationSuite /** Get the Master state */ private def getMasterState: MasterStateResponse = { - master.self.askWithRetry[MasterStateResponse](RequestMasterState) + master.self.askSync[MasterStateResponse](RequestMasterState) } /** Get the applications that are active from Master */ @@ -620,7 +620,7 @@ class StandaloneDynamicAllocationSuite when(endpointRef.address).thenReturn(mockAddress) val message = RegisterExecutor(id, endpointRef, "localhost", 10, Map.empty) val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend] - backend.driverEndpoint.askWithRetry[Boolean](message) + backend.driverEndpoint.askSync[Boolean](message) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index bc58fb2a362a4..936639b845789 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -171,7 +171,7 @@ class AppClientSuite /** Get the Master state */ private def getMasterState: MasterStateResponse = { - master.self.askWithRetry[MasterStateResponse](RequestMasterState) + master.self.askSync[MasterStateResponse](RequestMasterState) } /** Get 
the applications that are active from Master */ diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index da7253b2a56df..2127da48ece49 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -432,7 +432,7 @@ class MasterSuite extends SparkFunSuite val master = makeMaster() master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) eventually(timeout(10.seconds)) { - val masterState = master.self.askWithRetry[MasterStateResponse](RequestMasterState) + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index b4037d7a9c6e8..31d9dd3de8acc 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -118,8 +118,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) - val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") - val reply = newRpcEndpointRef.askWithRetry[String]("Echo") + val newRpcEndpointRef = rpcEndpointRef.askSync[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askSync[String]("Echo") assert("Echo" === reply) } @@ -132,7 +132,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { context.reply(msg) } }) - val reply = rpcEndpointRef.askWithRetry[String]("hello") + val reply = rpcEndpointRef.askSync[String]("hello") assert("hello" === reply) } @@ -150,7 +150,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") try { - val reply = rpcEndpointRef.askWithRetry[String]("hello") + val reply = rpcEndpointRef.askSync[String]("hello") assert("hello" === reply) } finally { anotherEnv.shutdown() @@ -177,14 +177,13 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-timeout") try { - // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause - val e = intercept[SparkException] { - rpcEndpointRef.askWithRetry[String]("hello", new RpcTimeout(1 millis, shortProp)) + val e = intercept[RpcTimeoutException] { + rpcEndpointRef.askSync[String]("hello", new RpcTimeout(1 millis, shortProp)) } // The SparkException cause should be a RpcTimeoutException with message indicating the // controlling timeout property - assert(e.getCause.isInstanceOf[RpcTimeoutException]) - assert(e.getCause.getMessage.contains(shortProp)) + assert(e.isInstanceOf[RpcTimeoutException]) + assert(e.getMessage.contains(shortProp)) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -677,7 +676,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") - val reply = rpcEndpointRef.askWithRetry[String]("hello") + val reply = rpcEndpointRef.askSync[String]("hello") assert("hello" === reply) } finally { localEnv.shutdown() @@ -894,7 
+893,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val ref = anotherEnv.setupEndpointRef(env.address, "SPARK-14699") // Make sure the connect is set up - assert(ref.askWithRetry[String]("hello") === "hello") + assert(ref.askSync[String]("hello") === "hello") anotherEnv.shutdown() anotherEnv.awaitTermination() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 705c355234425..64a67b4c4cbab 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -394,7 +394,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - val reregister = !master.driverEndpoint.askWithRetry[Boolean]( + val reregister = !master.driverEndpoint.askSync[Boolean]( BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index a674da4066456..cdb3b68489654 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -442,7 +442,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend.stop() // Any method of the backend involving sending messages to the driver endpoint should not // be called after the backend is stopped. 
- verify(driverEndpoint, never()).askWithRetry(isA(classOf[RemoveExecutor]))(any[ClassTag[_]]) + verify(driverEndpoint, never()).askSync(isA(classOf[RemoveExecutor]))(any[ClassTag[_]]) } test("mesos supports spark.executor.uri") { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 8a76dbd1bf0e3..abd2de75c6450 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -101,7 +101,7 @@ private[yarn] class YarnAllocator( * @see SPARK-12864 */ private var executorIdCounter: Int = - driverRef.askWithRetry[Int](RetrieveLastAllocatedExecutorId) + driverRef.askSync[Int](RetrieveLastAllocatedExecutorId) // Queue to store the timestamp of failed executors private val failedExecutorsTimeStamps = new Queue[Long]() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 267d17623d5e5..d0f81887e62d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -88,21 +88,21 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { /** Verify whether the given executor has the active instance of a state store */ private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { - rpcEndpointRef.askWithRetry[Boolean](VerifyIfInstanceActive(storeId, executorId)) + rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(storeId, executorId)) } /** Get the location of the state store */ private[state] def getLocation(storeId: StateStoreId): Option[String] = { - rpcEndpointRef.askWithRetry[Option[String]](GetLocation(storeId)) + rpcEndpointRef.askSync[Option[String]](GetLocation(storeId)) } /** Deactivate instances related to a set of operator */ private[state] def deactivateInstances(storeRootLocation: String): Unit = { - rpcEndpointRef.askWithRetry[Boolean](DeactivateInstances(storeRootLocation)) + rpcEndpointRef.askSync[Boolean](DeactivateInstances(storeRootLocation)) } private[state] def stop(): Unit = { - rpcEndpointRef.askWithRetry[Boolean](StopCoordinator) + rpcEndpointRef.askSync[Boolean](StopCoordinator) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 722024b8a6d57..f5c8a88f42af6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -188,13 +188,13 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( streamId, receiver.getClass.getSimpleName, host, executorId, endpoint) - trackerEndpoint.askWithRetry[Boolean](msg) + trackerEndpoint.askSync[Boolean](msg) } override protected def onReceiverStop(message: String, error: Option[Throwable]) { logInfo("Deregistering receiver " + streamId) val errorString = error.map(Throwables.getStackTraceAsString).getOrElse("") - 
trackerEndpoint.askWithRetry[Boolean](DeregisterReceiver(streamId, message, errorString)) + trackerEndpoint.askSync[Boolean](DeregisterReceiver(streamId, message, errorString)) logInfo("Stopped receiver " + streamId) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 8f55d982a904c..bd7ab0b9bf5eb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -170,7 +170,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false trackerState = Stopping if (!skipReceiverLaunch) { // Send the stop signal to all the receivers - endpoint.askWithRetry[Boolean](StopAllReceivers) + endpoint.askSync[Boolean](StopAllReceivers) // Wait for the Spark job that runs the receivers to be over // That is, for the receivers to quit gracefully. @@ -183,7 +183,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } // Check if all the receivers have been deregistered or not - val receivers = endpoint.askWithRetry[Seq[Int]](AllReceiverIds) + val receivers = endpoint.askSync[Seq[Int]](AllReceiverIds) if (receivers.nonEmpty) { logWarning("Not all of the receivers have deregistered, " + receivers) } else { @@ -249,7 +249,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false */ def allocatedExecutors(): Map[Int, Option[String]] = synchronized { if (isTrackerStarted) { - endpoint.askWithRetry[Map[Int, ReceiverTrackingInfo]](GetAllReceiverInfo).mapValues { + endpoint.askSync[Map[Int, ReceiverTrackingInfo]](GetAllReceiverInfo).mapValues { _.runningExecutor.map { _.executorId } From 96c9392e24f301e85da79c0a072248a7145c4772 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 19 Feb 2017 09:37:56 -0800 Subject: [PATCH 04/61] [SPARK-19533][EXAMPLES] Convert Java tests to use lambdas, Java 8 features ## What changes were proposed in this pull request? Convert Java tests to use lambdas, Java 8 features. ## How was this patch tested? Jenkins tests. Author: Sean Owen Closes #16961 from srowen/SPARK-19533. 
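To show the mechanical shape of these conversions before the full diff, here is a small self-contained word-count fragment in the post-Java-8 style; it is illustrative only and not taken from the patch (class and variable names are made up).

```java
import java.util.Arrays;

import scala.Tuple2;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public final class LambdaConversionSketch {
  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext("local[2]", "LambdaConversionSketch");
    JavaRDD<String> lines = jsc.parallelize(Arrays.asList("a b c", "b c"));

    // The pre-Java-8 examples spelled each step out with anonymous inner classes,
    // e.g. new FlatMapFunction<String, String>() { ... } for every operation.
    // With Java 8, the same pipeline reads as plain lambdas and method references:
    JavaRDD<String> words = lines.flatMap(s -> Arrays.asList(s.split(" ")).iterator());
    JavaPairRDD<String, Integer> counts = words
        .mapToPair(w -> new Tuple2<>(w, 1))
        .reduceByKey((a, b) -> a + b);

    counts.collect().forEach(t -> System.out.println(t._1() + ": " + t._2()));
    jsc.stop();
  }
}
```
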
--- .../apache/spark/examples/JavaLogQuery.java | 21 +-- .../apache/spark/examples/JavaPageRank.java | 49 ++----- .../apache/spark/examples/JavaSparkPi.java | 20 +-- .../spark/examples/JavaStatusTrackerDemo.java | 5 +- .../org/apache/spark/examples/JavaTC.java | 8 +- .../apache/spark/examples/JavaWordCount.java | 27 +--- .../spark/examples/ml/JavaALSExample.java | 7 +- ...lectionViaTrainValidationSplitExample.java | 3 - .../examples/ml/JavaTokenizerExample.java | 13 +- .../examples/ml/JavaVectorSlicerExample.java | 7 +- .../mllib/JavaAssociationRulesExample.java | 6 +- ...avaBinaryClassificationMetricsExample.java | 33 ++--- .../mllib/JavaBisectingKMeansExample.java | 7 +- .../mllib/JavaChiSqSelectorExample.java | 38 ++--- ...JavaDecisionTreeClassificationExample.java | 26 +--- .../JavaDecisionTreeRegressionExample.java | 33 ++--- .../mllib/JavaElementwiseProductExample.java | 27 +--- .../mllib/JavaGaussianMixtureExample.java | 19 +-- ...GradientBoostingClassificationExample.java | 21 +-- ...JavaGradientBoostingRegressionExample.java | 30 +--- .../mllib/JavaIsotonicRegressionExample.java | 39 ++--- .../examples/mllib/JavaKMeansExample.java | 19 +-- .../examples/mllib/JavaLBFGSExample.java | 23 +-- .../JavaLatentDirichletAllocationExample.java | 28 ++-- .../JavaLinearRegressionWithSGDExample.java | 47 +++--- ...avaLogisticRegressionWithLBFGSExample.java | 14 +- ...ulticlassClassificationMetricsExample.java | 13 +- .../examples/mllib/JavaNaiveBayesExample.java | 19 +-- .../JavaPowerIterationClusteringExample.java | 6 +- ...JavaRandomForestClassificationExample.java | 23 +-- .../JavaRandomForestRegressionExample.java | 37 ++--- .../mllib/JavaRankingMetricsExample.java | 135 ++++++------------ .../mllib/JavaRecommendationExample.java | 58 +++----- .../mllib/JavaRegressionMetricsExample.java | 31 ++-- .../examples/mllib/JavaSVMWithSGDExample.java | 13 +- .../examples/mllib/JavaSimpleFPGrowth.java | 12 +- .../mllib/JavaStreamingTestExample.java | 40 ++---- .../sql/JavaSQLDataSourceExample.java | 8 +- .../examples/sql/JavaSparkSQLExample.java | 60 +++----- .../sql/hive/JavaSparkHiveExample.java | 9 +- .../JavaStructuredKafkaWordCount.java | 10 +- .../JavaStructuredNetworkWordCount.java | 11 +- ...avaStructuredNetworkWordCountWindowed.java | 16 +-- .../streaming/JavaCustomReceiver.java | 34 +---- .../streaming/JavaDirectKafkaWordCount.java | 31 +--- .../streaming/JavaFlumeEventCount.java | 8 +- .../streaming/JavaKafkaWordCount.java | 33 +---- .../streaming/JavaNetworkWordCount.java | 25 +--- .../examples/streaming/JavaQueueStream.java | 24 +--- .../JavaRecoverableNetworkWordCount.java | 91 ++++-------- .../streaming/JavaSqlNetworkWordCount.java | 51 +++---- .../JavaStatefulNetworkWordCount.java | 30 +--- 52 files changed, 380 insertions(+), 1018 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 7775443861661..cf12de390f608 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -17,18 +17,16 @@ package org.apache.spark.examples; -import com.google.common.collect.Lists; import scala.Tuple2; import scala.Tuple3; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import 
org.apache.spark.sql.SparkSession; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -40,7 +38,7 @@ */ public final class JavaLogQuery { - public static final List exampleApacheLogs = Lists.newArrayList( + public static final List exampleApacheLogs = Arrays.asList( "10.10.10.10 - \"FRED\" [18/Jan/2013:17:56:07 +1100] \"GET http://images.com/2013/Generic.jpg " + "HTTP/1.1\" 304 315 \"http://referall.com/\" \"Mozilla/4.0 (compatible; MSIE 7.0; " + "Windows NT 5.1; GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; " + @@ -109,19 +107,10 @@ public static void main(String[] args) { JavaRDD dataSet = (args.length == 1) ? jsc.textFile(args[0]) : jsc.parallelize(exampleApacheLogs); - JavaPairRDD, Stats> extracted = dataSet.mapToPair(new PairFunction, Stats>() { - @Override - public Tuple2, Stats> call(String s) { - return new Tuple2<>(extractKey(s), extractStats(s)); - } - }); + JavaPairRDD, Stats> extracted = + dataSet.mapToPair(s -> new Tuple2<>(extractKey(s), extractStats(s))); - JavaPairRDD, Stats> counts = extracted.reduceByKey(new Function2() { - @Override - public Stats call(Stats stats, Stats stats2) { - return stats.merge(stats2); - } - }); + JavaPairRDD, Stats> counts = extracted.reduceByKey(Stats::merge); List, Stats>> output = counts.collect(); for (Tuple2 t : output) { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index bcc493bdcb225..b5b4703932f0f 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -19,7 +19,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Iterator; import java.util.regex.Pattern; import scala.Tuple2; @@ -28,10 +27,7 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFlatMapFunction; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.sql.SparkSession; /** @@ -90,52 +86,35 @@ public static void main(String[] args) throws Exception { JavaRDD lines = spark.read().textFile(args[0]).javaRDD(); // Loads all URLs from input file and initialize their neighbors. - JavaPairRDD> links = lines.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - String[] parts = SPACES.split(s); - return new Tuple2<>(parts[0], parts[1]); - } - }).distinct().groupByKey().cache(); + JavaPairRDD> links = lines.mapToPair(s -> { + String[] parts = SPACES.split(s); + return new Tuple2<>(parts[0], parts[1]); + }).distinct().groupByKey().cache(); // Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. - JavaPairRDD ranks = links.mapValues(new Function, Double>() { - @Override - public Double call(Iterable rs) { - return 1.0; - } - }); + JavaPairRDD ranks = links.mapValues(rs -> 1.0); // Calculates and updates URL ranks continuously using PageRank algorithm. for (int current = 0; current < Integer.parseInt(args[1]); current++) { // Calculates URL contributions to the rank of other URLs. 
JavaPairRDD contribs = links.join(ranks).values() - .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { - @Override - public Iterator> call(Tuple2, Double> s) { - int urlCount = Iterables.size(s._1); - List> results = new ArrayList<>(); - for (String n : s._1) { - results.add(new Tuple2<>(n, s._2() / urlCount)); - } - return results.iterator(); + .flatMapToPair(s -> { + int urlCount = Iterables.size(s._1()); + List> results = new ArrayList<>(); + for (String n : s._1) { + results.add(new Tuple2<>(n, s._2() / urlCount)); } - }); + return results.iterator(); + }); // Re-calculates URL ranks based on neighbor contributions. - ranks = contribs.reduceByKey(new Sum()).mapValues(new Function() { - @Override - public Double call(Double sum) { - return 0.15 + sum * 0.85; - } - }); + ranks = contribs.reduceByKey(new Sum()).mapValues(sum -> 0.15 + sum * 0.85); } // Collects all URL ranks and dump them to console. List> output = ranks.collect(); for (Tuple2 tuple : output) { - System.out.println(tuple._1() + " has rank: " + tuple._2() + "."); + System.out.println(tuple._1() + " has rank: " + tuple._2() + "."); } spark.stop(); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 89855e81f1f7a..cb4b26569088a 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -19,8 +19,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; import org.apache.spark.sql.SparkSession; import java.util.ArrayList; @@ -49,19 +47,11 @@ public static void main(String[] args) throws Exception { JavaRDD dataSet = jsc.parallelize(l, slices); - int count = dataSet.map(new Function() { - @Override - public Integer call(Integer integer) { - double x = Math.random() * 2 - 1; - double y = Math.random() * 2 - 1; - return (x * x + y * y <= 1) ? 1 : 0; - } - }).reduce(new Function2() { - @Override - public Integer call(Integer integer, Integer integer2) { - return integer + integer2; - } - }); + int count = dataSet.map(integer -> { + double x = Math.random() * 2 - 1; + double y = Math.random() * 2 - 1; + return (x * x + y * y <= 1) ? 1 : 0; + }).reduce((integer, integer2) -> integer + integer2); System.out.println("Pi is roughly " + 4.0 * count / n); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java index 6f899c772eb98..b0ebedfed6a8b 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java @@ -25,7 +25,6 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.SparkSession; - import java.util.Arrays; import java.util.List; @@ -50,11 +49,11 @@ public static void main(String[] args) throws Exception { .appName(APP_NAME) .getOrCreate(); - final JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); + JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); // Example of implementing a progress reporter for a simple job. 
JavaRDD rdd = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map( - new IdentityWithDelay()); + new IdentityWithDelay<>()); JavaFutureAction> jobFuture = rdd.collectAsync(); while (!jobFuture.isDone()) { Thread.sleep(1000); // 1 second diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index f12ca77ed1eb0..bde30b84d6cf3 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -80,13 +80,7 @@ public static void main(String[] args) { // the graph to obtain the path (x, z). // Because join() joins on keys, the edges are stored in reversed order. - JavaPairRDD edges = tc.mapToPair( - new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 e) { - return new Tuple2<>(e._2(), e._1()); - } - }); + JavaPairRDD edges = tc.mapToPair(e -> new Tuple2<>(e._2(), e._1())); long oldCount; long nextCount = tc.count(); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 8f18604c0750c..f1ce1e958580f 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -21,13 +21,9 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.sql.SparkSession; import java.util.Arrays; -import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -48,28 +44,11 @@ public static void main(String[] args) throws Exception { JavaRDD lines = spark.read().textFile(args[0]).javaRDD(); - JavaRDD words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String s) { - return Arrays.asList(SPACE.split(s)).iterator(); - } - }); + JavaRDD words = lines.flatMap(s -> Arrays.asList(SPACE.split(s)).iterator()); - JavaPairRDD ones = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2<>(s, 1); - } - }); + JavaPairRDD ones = words.mapToPair(s -> new Tuple2<>(s, 1)); - JavaPairRDD counts = ones.reduceByKey( - new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); + JavaPairRDD counts = ones.reduceByKey((i1, i2) -> i1 + i2); List> output = counts.collect(); for (Tuple2 tuple : output) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 739558e81ffb0..33ba668b32fc2 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -25,7 +25,6 @@ import java.io.Serializable; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.recommendation.ALS; import org.apache.spark.ml.recommendation.ALSModel; @@ -88,11 +87,7 @@ public static void main(String[] args) { // $example on$ JavaRDD ratingsRDD = spark .read().textFile("data/mllib/als/sample_movielens_ratings.txt").javaRDD() - .map(new Function() { - public Rating call(String str) { - return 
Rating.parseRating(str); - } - }); + .map(Rating::parseRating); Dataset ratings = spark.createDataFrame(ratingsRDD, Rating.class); Dataset[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); Dataset training = splits[0]; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java index 0f96293f0348b..9a4722b90cf1b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java @@ -32,9 +32,6 @@ /** * Java example demonstrating model selection using TrainValidationSplit. * - * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} - * using linear regression. - * * Run with * {{{ * bin/run-example ml.JavaModelSelectionViaTrainValidationSplitExample diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java index 004e9b12f6260..3f809eba7fffb 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -69,20 +69,17 @@ public static void main(String[] args) { .setOutputCol("words") .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); - spark.udf().register("countTokens", new UDF1, Integer>() { - @Override - public Integer call(WrappedArray words) { - return words.size(); - } - }, DataTypes.IntegerType); + spark.udf().register("countTokens", (WrappedArray words) -> words.size(), DataTypes.IntegerType); Dataset tokenized = tokenizer.transform(sentenceDataFrame); tokenized.select("sentence", "words") - .withColumn("tokens", callUDF("countTokens", col("words"))).show(false); + .withColumn("tokens", callUDF("countTokens", col("words"))) + .show(false); Dataset regexTokenized = regexTokenizer.transform(sentenceDataFrame); regexTokenized.select("sentence", "words") - .withColumn("tokens", callUDF("countTokens", col("words"))).show(false); + .withColumn("tokens", callUDF("countTokens", col("words"))) + .show(false); // $example off$ spark.stop(); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java index 1922514c87dff..1ae48be2660bc 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -20,10 +20,9 @@ import org.apache.spark.sql.SparkSession; // $example on$ +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; - import org.apache.spark.ml.attribute.Attribute; import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; @@ -43,14 +42,14 @@ public static void main(String[] args) { .getOrCreate(); // $example on$ - Attribute[] attrs = new Attribute[]{ + Attribute[] attrs = { NumericAttribute.defaultAttr().withName("f1"), NumericAttribute.defaultAttr().withName("f2"), NumericAttribute.defaultAttr().withName("f3") }; AttributeGroup group = new AttributeGroup("userFeatures", attrs); - List data = Lists.newArrayList( + List data = Arrays.asList( 
RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) ); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java index 189560e3fe1f1..5f43603f4ff5c 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java @@ -38,9 +38,9 @@ public static void main(String[] args) { // $example on$ JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( - new FreqItemset(new String[] {"a"}, 15L), - new FreqItemset(new String[] {"b"}, 35L), - new FreqItemset(new String[] {"a", "b"}, 12L) + new FreqItemset<>(new String[] {"a"}, 15L), + new FreqItemset<>(new String[] {"b"}, 35L), + new FreqItemset<>(new String[] {"a", "b"}, 12L) )); AssociationRules arules = new AssociationRules() diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java index 12aa14f7107f7..b9d0313c6bb56 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -21,7 +21,6 @@ import scala.Tuple2; import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.LogisticRegressionModel; import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; @@ -46,7 +45,7 @@ public static void main(String[] args) { JavaRDD test = splits[1]; // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(2) .run(training.rdd()); @@ -54,15 +53,8 @@ public static void main(String[] args) { model.clearThreshold(); // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - @Override - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); + JavaPairRDD predictionAndLabels = test.mapToPair(p -> + new Tuple2<>(model.predict(p.features()), p.label())); // Get evaluation metrics. 
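For reference, the shape change in the metrics examples above, a JavaPairRDD of predictions and labels feeding the metrics constructor through .rdd(), can be exercised without a trained model. A minimal sketch with made-up scores and labels; the local master and the literal values are illustrative only, and the Object type parameters match the Java-facing constructor signature:

    import java.util.Arrays;

    import scala.Tuple2;

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;

    public class BinaryMetricsSketch {
      public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "BinaryMetricsSketch");

        // Made-up (score, label) pairs standing in for model.predict output.
        JavaPairRDD<Object, Object> scoreAndLabels = jsc.parallelizePairs(Arrays.asList(
          new Tuple2<>((Object) 0.9, (Object) 1.0),
          new Tuple2<>((Object) 0.2, (Object) 0.0),
          new Tuple2<>((Object) 0.6, (Object) 1.0)));

        // JavaPairRDD.rdd() hands the underlying RDD of (score, label) tuples
        // straight to the metrics constructor, with no extra conversion step.
        BinaryClassificationMetrics metrics =
          new BinaryClassificationMetrics(scoreAndLabels.rdd());

        System.out.println("Area under ROC = " + metrics.areaUnderROC());
        jsc.stop();
      }
    }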
BinaryClassificationMetrics metrics = @@ -73,32 +65,25 @@ public Tuple2 call(LabeledPoint p) { System.out.println("Precision by threshold: " + precision.collect()); // Recall by threshold - JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + JavaRDD recall = metrics.recallByThreshold().toJavaRDD(); System.out.println("Recall by threshold: " + recall.collect()); // F Score by threshold - JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + JavaRDD f1Score = metrics.fMeasureByThreshold().toJavaRDD(); System.out.println("F1 Score by threshold: " + f1Score.collect()); - JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + JavaRDD f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); System.out.println("F2 Score by threshold: " + f2Score.collect()); // Precision-recall curve - JavaRDD> prc = metrics.pr().toJavaRDD(); + JavaRDD prc = metrics.pr().toJavaRDD(); System.out.println("Precision-recall curve: " + prc.collect()); // Thresholds - JavaRDD thresholds = precision.map( - new Function, Double>() { - @Override - public Double call(Tuple2 t) { - return new Double(t._1().toString()); - } - } - ); + JavaRDD thresholds = precision.map(t -> Double.parseDouble(t._1().toString())); // ROC Curve - JavaRDD> roc = metrics.roc().toJavaRDD(); + JavaRDD roc = metrics.roc().toJavaRDD(); System.out.println("ROC curve: " + roc.collect()); // AUPRC diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java index c600094947d5a..f878b55a98adf 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -17,10 +17,9 @@ package org.apache.spark.examples.mllib; -import java.util.ArrayList; - // $example on$ -import com.google.common.collect.Lists; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; @@ -41,7 +40,7 @@ public static void main(String[] args) { JavaSparkContext sc = new JavaSparkContext(sparkConf); // $example on$ - ArrayList localData = Lists.newArrayList( + List localData = Arrays.asList( Vectors.dense(0.1, 0.1), Vectors.dense(0.3, 0.3), Vectors.dense(10.1, 10.1), Vectors.dense(10.3, 10.3), Vectors.dense(20.1, 20.1), Vectors.dense(20.3, 20.3), diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaChiSqSelectorExample.java index ad44acb4cd6e3..ce354af2b5793 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaChiSqSelectorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaChiSqSelectorExample.java @@ -19,10 +19,8 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.VoidFunction; // $example on$ import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.feature.ChiSqSelector; import org.apache.spark.mllib.feature.ChiSqSelectorModel; import org.apache.spark.mllib.linalg.Vectors; @@ -42,41 +40,25 @@ public static void main(String[] args) { // Discretize data in 16 equal bins since ChiSqSelector requires categorical features // Although features are doubles, the ChiSqSelector treats each unique value as a category - 
JavaRDD discretizedData = points.map( - new Function() { - @Override - public LabeledPoint call(LabeledPoint lp) { - final double[] discretizedFeatures = new double[lp.features().size()]; - for (int i = 0; i < lp.features().size(); ++i) { - discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); - } - return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); - } + JavaRDD discretizedData = points.map(lp -> { + double[] discretizedFeatures = new double[lp.features().size()]; + for (int i = 0; i < lp.features().size(); ++i) { + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); } - ); + return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); + }); // Create ChiSqSelector that will select top 50 of 692 features ChiSqSelector selector = new ChiSqSelector(50); // Create ChiSqSelector model (selecting features) - final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); + ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); // Filter the top 50 features from each feature vector - JavaRDD filteredData = discretizedData.map( - new Function() { - @Override - public LabeledPoint call(LabeledPoint lp) { - return new LabeledPoint(lp.label(), transformer.transform(lp.features())); - } - } - ); + JavaRDD filteredData = discretizedData.map(lp -> + new LabeledPoint(lp.label(), transformer.transform(lp.features()))); // $example off$ System.out.println("filtered data: "); - filteredData.foreach(new VoidFunction() { - @Override - public void call(LabeledPoint labeledPoint) throws Exception { - System.out.println(labeledPoint.toString()); - } - }); + filteredData.foreach(System.out::println); jsc.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java index 66387b9df51c7..032c168b946d6 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java @@ -27,8 +27,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.DecisionTree; import org.apache.spark.mllib.tree.model.DecisionTreeModel; @@ -53,31 +51,21 @@ public static void main(String[] args) { // Set parameters. // Empty categoricalFeaturesInfo indicates all features are continuous. - Integer numClasses = 2; + int numClasses = 2; Map categoricalFeaturesInfo = new HashMap<>(); String impurity = "gini"; - Integer maxDepth = 5; - Integer maxBins = 32; + int maxDepth = 5; + int maxBins = 32; // Train a DecisionTree model for classification. 
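The unbound method references used in several of these conversions (Rating::parseRating, Tuple2::swap, Rating::user) drop in anywhere a one-line Function body would otherwise be written out. A minimal sketch, with an illustrative local session and input values:

    import java.util.Arrays;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;

    public class MethodRefSketch {
      public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "MethodRefSketch");

        JavaRDD<String> data = jsc.parallelize(Arrays.asList("1", "2", "3"));

        // A static method reference replaces a Function<String, Integer>
        // whose call() would only have returned Integer.parseInt(s).
        JavaRDD<Integer> parsed = data.map(Integer::parseInt);

        System.out.println(parsed.collect());  // [1, 2, 3]
        jsc.stop();
      }
    }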
- final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, + DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins); // Evaluate model on test instances and compute test error JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2<>(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); + testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label())); + double testErr = + predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification tree model:\n" + model.toDebugString()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java index 904e7f7e9505e..f222c38fc82b6 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java @@ -27,9 +27,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.DecisionTree; import org.apache.spark.mllib.tree.model.DecisionTreeModel; @@ -56,34 +53,20 @@ public static void main(String[] args) { // Empty categoricalFeaturesInfo indicates all features are continuous. Map categoricalFeaturesInfo = new HashMap<>(); String impurity = "variance"; - Integer maxDepth = 5; - Integer maxBins = 32; + int maxDepth = 5; + int maxBins = 32; // Train a DecisionTree model. 
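The test-error idiom introduced above recurs in the tree and forest examples: pair each prediction with its label, filter the mismatches, and divide the mismatch count by the test-set size. A minimal self-contained sketch; the three LabeledPoints and the sign rule standing in for model.predict are made up for illustration:

    import java.util.Arrays;

    import scala.Tuple2;

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;

    public class TestErrorSketch {
      public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "TestErrorSketch");

        // Tiny stand-in test set; the patched examples load LibSVM files instead.
        JavaRDD<LabeledPoint> testData = jsc.parallelize(Arrays.asList(
          new LabeledPoint(1.0, Vectors.dense(1.0)),
          new LabeledPoint(0.0, Vectors.dense(-1.0)),
          new LabeledPoint(1.0, Vectors.dense(-2.0))));

        // Stand-in for model.predict(p.features()): predict 1.0 for non-negative features.
        JavaPairRDD<Double, Double> predictionAndLabel = testData.mapToPair(p ->
          new Tuple2<>(p.features().apply(0) >= 0 ? 1.0 : 0.0, p.label()));

        // Fraction of records whose prediction differs from the label.
        double testErr =
          predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count();

        System.out.println("Test Error: " + testErr);
        jsc.stop();
      }
    }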
- final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, + DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, maxDepth, maxBins); // Evaluate model on test instances and compute test error JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2<>(model.predict(p.features()), p.label()); - } - }); - Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); + testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label())); + double testMSE = predictionAndLabel.mapToDouble(pl -> { + double diff = pl._1() - pl._2(); + return diff * diff; + }).mean(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression tree model:\n" + model.toDebugString()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaElementwiseProductExample.java index c8ce6ab284b07..2d45c6166fee3 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaElementwiseProductExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaElementwiseProductExample.java @@ -25,12 +25,10 @@ import org.apache.spark.api.java.JavaSparkContext; // $example on$ import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.feature.ElementwiseProduct; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; // $example off$ -import org.apache.spark.api.java.function.VoidFunction; public class JavaElementwiseProductExample { public static void main(String[] args) { @@ -43,35 +41,18 @@ public static void main(String[] args) { JavaRDD data = jsc.parallelize(Arrays.asList( Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); - final ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); // Batch transform and per-row transform give the same results: JavaRDD transformedData = transformer.transform(data); - JavaRDD transformedData2 = data.map( - new Function() { - @Override - public Vector call(Vector v) { - return transformer.transform(v); - } - } - ); + JavaRDD transformedData2 = data.map(transformer::transform); // $example off$ System.out.println("transformedData: "); - transformedData.foreach(new VoidFunction() { - @Override - public void call(Vector vector) throws Exception { - System.out.println(vector.toString()); - } - }); + transformedData.foreach(System.out::println); System.out.println("transformedData2: "); - transformedData2.foreach(new VoidFunction() { - @Override - public void call(Vector vector) throws Exception { - System.out.println(vector.toString()); - } - }); + transformedData2.foreach(System.out::println); jsc.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGaussianMixtureExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGaussianMixtureExample.java index 3124411c8227c..5792e5a71cb09 100644 --- 
a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGaussianMixtureExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGaussianMixtureExample.java @@ -22,7 +22,6 @@ // $example on$ import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.GaussianMixture; import org.apache.spark.mllib.clustering.GaussianMixtureModel; import org.apache.spark.mllib.linalg.Vector; @@ -39,18 +38,14 @@ public static void main(String[] args) { // Load and parse data String path = "data/mllib/gmm_data.txt"; JavaRDD data = jsc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.trim().split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) { - values[i] = Double.parseDouble(sarray[i]); - } - return Vectors.dense(values); - } + JavaRDD parsedData = data.map(s -> { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) { + values[i] = Double.parseDouble(sarray[i]); } - ); + return Vectors.dense(values); + }); parsedData.cache(); // Cluster the data into two classes using GaussianMixture diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java index 213949e525dc2..521ee96fbdf4b 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java @@ -27,8 +27,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.GradientBoostedTrees; import org.apache.spark.mllib.tree.configuration.BoostingStrategy; @@ -61,24 +59,13 @@ public static void main(String[] args) { Map categoricalFeaturesInfo = new HashMap<>(); boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); - final GradientBoostedTreesModel model = - GradientBoostedTrees.train(trainingData, boostingStrategy); + GradientBoostedTreesModel model = GradientBoostedTrees.train(trainingData, boostingStrategy); // Evaluate model on test instances and compute test error JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2<>(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); + testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label())); + double testErr = + predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification GBT model:\n" + model.toDebugString()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java 
b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java index 78db442dbc99d..b345d19f59ab6 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java @@ -24,12 +24,9 @@ import scala.Tuple2; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.GradientBoostedTrees; import org.apache.spark.mllib.tree.configuration.BoostingStrategy; @@ -60,30 +57,15 @@ public static void main(String[] args) { Map categoricalFeaturesInfo = new HashMap<>(); boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); - final GradientBoostedTreesModel model = - GradientBoostedTrees.train(trainingData, boostingStrategy); + GradientBoostedTreesModel model = GradientBoostedTrees.train(trainingData, boostingStrategy); // Evaluate model on test instances and compute test error JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2<>(model.predict(p.features()), p.label()); - } - }); - Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); + testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label())); + double testMSE = predictionAndLabel.mapToDouble(pl -> { + double diff = pl._1() - pl._2(); + return diff * diff; + }).mean(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression GBT model:\n" + model.toDebugString()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java index a30b5f1f73eaf..adebafe4b89d7 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java @@ -20,9 +20,6 @@ import scala.Tuple2; import scala.Tuple3; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaRDD; @@ -42,14 +39,8 @@ public static void main(String[] args) { jsc.sc(), "data/mllib/sample_isotonic_regression_libsvm_data.txt").toJavaRDD(); // Create label, feature, weight tuples from input data with weight set to default value 1.0. 
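The mapToDouble(...).mean() form that replaces the old map/reduce-and-divide MSE computation is easiest to see over plain (prediction, label) pairs. A minimal sketch; the numbers are made up and the local master is illustrative:

    import java.util.Arrays;

    import scala.Tuple2;

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;

    public class MseSketch {
      public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "MseSketch");

        // (prediction, label) pairs; in the patched examples these come from
        // model.predict(p.features()) and p.label().
        JavaPairRDD<Double, Double> predictionAndLabel = jsc.parallelizePairs(Arrays.asList(
          new Tuple2<>(1.1, 1.0),
          new Tuple2<>(0.4, 0.5),
          new Tuple2<>(2.0, 2.2)));

        // Squared error per record, then the mean over the resulting JavaDoubleRDD:
        // no explicit reduce and divide-by-count step is needed.
        double testMSE = predictionAndLabel.mapToDouble(pl -> {
          double diff = pl._1() - pl._2();
          return diff * diff;
        }).mean();

        System.out.println("Test Mean Squared Error: " + testMSE);
        jsc.stop();
      }
    }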
- JavaRDD> parsedData = data.map( - new Function>() { - public Tuple3 call(LabeledPoint point) { - return new Tuple3<>(new Double(point.label()), - new Double(point.features().apply(0)), 1.0); - } - } - ); + JavaRDD> parsedData = data.map(point -> + new Tuple3<>(point.label(), point.features().apply(0), 1.0)); // Split data into training (60%) and test (40%) sets. JavaRDD>[] splits = @@ -59,29 +50,17 @@ public Tuple3 call(LabeledPoint point) { // Create isotonic regression model from training data. // Isotonic parameter defaults to true so it is only shown for demonstration - final IsotonicRegressionModel model = - new IsotonicRegression().setIsotonic(true).run(training); + IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); // Create tuples of predicted and real labels. - JavaPairRDD predictionAndLabel = test.mapToPair( - new PairFunction, Double, Double>() { - @Override - public Tuple2 call(Tuple3 point) { - Double predictedLabel = model.predict(point._2()); - return new Tuple2<>(predictedLabel, point._1()); - } - } - ); + JavaPairRDD predictionAndLabel = test.mapToPair(point -> + new Tuple2<>(model.predict(point._2()), point._1())); // Calculate mean squared error between predicted and real labels. - Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( - new Function, Object>() { - @Override - public Object call(Tuple2 pl) { - return Math.pow(pl._1() - pl._2(), 2); - } - } - ).rdd()).mean(); + double meanSquaredError = predictionAndLabel.mapToDouble(pl -> { + double diff = pl._1() - pl._2(); + return diff * diff; + }).mean(); System.out.println("Mean Squared Error = " + meanSquaredError); // Save and load model diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java index 2d89c768fcfca..f17275617ad5e 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java @@ -22,7 +22,6 @@ // $example on$ import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.KMeans; import org.apache.spark.mllib.clustering.KMeansModel; import org.apache.spark.mllib.linalg.Vector; @@ -39,18 +38,14 @@ public static void main(String[] args) { // Load and parse data String path = "data/mllib/kmeans_data.txt"; JavaRDD data = jsc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) { - values[i] = Double.parseDouble(sarray[i]); - } - return Vectors.dense(values); - } + JavaRDD parsedData = data.map(s -> { + String[] sarray = s.split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) { + values[i] = Double.parseDouble(sarray[i]); } - ); + return Vectors.dense(values); + }); parsedData.cache(); // Cluster the data into two classes using KMeans diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java index f6f91f486fef7..3fdc03a92ad7a 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java @@ -23,7 +23,6 @@ import scala.Tuple2; 
import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.LogisticRegressionModel; import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; import org.apache.spark.mllib.linalg.Vector; @@ -50,12 +49,8 @@ public static void main(String[] args) { JavaRDD test = data.subtract(trainingInit); // Append 1 into the training data as intercept. - JavaRDD> training = data.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - return new Tuple2(p.label(), MLUtils.appendBias(p.features())); - } - }); + JavaPairRDD training = data.mapToPair(p -> + new Tuple2<>(p.label(), MLUtils.appendBias(p.features()))); training.cache(); // Run training algorithm to build the model. @@ -77,7 +72,7 @@ public Tuple2 call(LabeledPoint p) { Vector weightsWithIntercept = result._1(); double[] loss = result._2(); - final LogisticRegressionModel model = new LogisticRegressionModel( + LogisticRegressionModel model = new LogisticRegressionModel( Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); @@ -85,13 +80,8 @@ public Tuple2 call(LabeledPoint p) { model.clearThreshold(); // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - }); + JavaPairRDD scoreAndLabels = test.mapToPair(p -> + new Tuple2<>(model.predict(p.features()), p.label())); // Get evaluation metrics. BinaryClassificationMetrics metrics = @@ -99,8 +89,9 @@ public Tuple2 call(LabeledPoint p) { double auROC = metrics.areaUnderROC(); System.out.println("Loss of each step in training process"); - for (double l : loss) + for (double l : loss) { System.out.println(l); + } System.out.println("Area under ROC = " + auROC); // $example off$ diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLatentDirichletAllocationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLatentDirichletAllocationExample.java index 578564eeb23dd..887edf8c21210 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLatentDirichletAllocationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLatentDirichletAllocationExample.java @@ -25,7 +25,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.DistributedLDAModel; import org.apache.spark.mllib.clustering.LDA; import org.apache.spark.mllib.clustering.LDAModel; @@ -44,28 +43,17 @@ public static void main(String[] args) { // Load and parse the data String path = "data/mllib/sample_lda_data.txt"; JavaRDD data = jsc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.trim().split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) { - values[i] = Double.parseDouble(sarray[i]); - } - return Vectors.dense(values); - } + JavaRDD parsedData = data.map(s -> { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) { + values[i] = Double.parseDouble(sarray[i]); } - ); + return Vectors.dense(values); + }); // Index documents with unique IDs JavaPairRDD corpus = - 
JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 doc_id) { - return doc_id.swap(); - } - } - ) - ); + JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map(Tuple2::swap)); corpus.cache(); // Cluster the documents into three topics using LDA diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java index 9ca9a7847c463..324a781c1a44a 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java @@ -23,9 +23,8 @@ // $example on$ import scala.Tuple2; -import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LinearRegressionModel; @@ -44,43 +43,31 @@ public static void main(String[] args) { // Load and parse the data String path = "data/mllib/ridge-data/lpsa.data"; JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(","); - String[] features = parts[1].split(" "); - double[] v = new double[features.length]; - for (int i = 0; i < features.length - 1; i++) { - v[i] = Double.parseDouble(features[i]); - } - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } + JavaRDD parsedData = data.map(line -> { + String[] parts = line.split(","); + String[] features = parts[1].split(" "); + double[] v = new double[features.length]; + for (int i = 0; i < features.length - 1; i++) { + v[i] = Double.parseDouble(features[i]); } - ); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + }); parsedData.cache(); // Building the model int numIterations = 100; double stepSize = 0.00000001; - final LinearRegressionModel model = + LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize); // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2<>(prediction, point.label()); - } - } - ); - double MSE = new JavaDoubleRDD(valuesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - return Math.pow(pair._1() - pair._2(), 2.0); - } - } - ).rdd()).mean(); + JavaPairRDD valuesAndPreds = parsedData.mapToPair(point -> + new Tuple2<>(model.predict(point.features()), point.label())); + + double MSE = valuesAndPreds.mapToDouble(pair -> { + double diff = pair._1() - pair._2(); + return diff * diff; + }).mean(); System.out.println("training Mean Squared Error = " + MSE); // Save and load model diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java index 7fc371ec0f990..26b8a6e9fa3ad 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java +++ 
b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java @@ -23,8 +23,8 @@ // $example on$ import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.LogisticRegressionModel; import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; import org.apache.spark.mllib.evaluation.MulticlassMetrics; @@ -49,19 +49,13 @@ public static void main(String[] args) { JavaRDD test = splits[1]; // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(10) .run(training.rdd()); // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); + JavaPairRDD predictionAndLabels = test.mapToPair(p -> + new Tuple2<>(model.predict(p.features()), p.label())); // Get evaluation metrics. MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java index 2d12bdd2a6440..03670383b794f 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -21,7 +21,6 @@ import scala.Tuple2; import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.LogisticRegressionModel; import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; import org.apache.spark.mllib.evaluation.MulticlassMetrics; @@ -46,19 +45,13 @@ public static void main(String[] args) { JavaRDD test = splits[1]; // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(3) .run(training.rdd()); // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); + JavaPairRDD predictionAndLabels = test.mapToPair(p -> + new Tuple2<>(model.predict(p.features()), p.label())); // Get evaluation metrics. 
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java index f4ec04b0c677c..d80dbe80000b3 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java @@ -19,8 +19,6 @@ // $example on$ import scala.Tuple2; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -41,20 +39,11 @@ public static void main(String[] args) { JavaRDD[] tmp = inputData.randomSplit(new double[]{0.6, 0.4}); JavaRDD training = tmp[0]; // training set JavaRDD test = tmp[1]; // test set - final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0); + NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0); JavaPairRDD predictionAndLabel = - test.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2<>(model.predict(p.features()), p.label()); - } - }); - double accuracy = predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return pl._1().equals(pl._2()); - } - }).count() / (double) test.count(); + test.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label())); + double accuracy = + predictionAndLabel.filter(pl -> pl._1().equals(pl._2())).count() / (double) test.count(); // Save and load model model.save(jsc.sc(), "target/tmp/myNaiveBayesModel"); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java index 91c3bd72da3a7..5155f182ba20e 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -17,9 +17,9 @@ package org.apache.spark.examples.mllib; -import scala.Tuple3; +import java.util.Arrays; -import com.google.common.collect.Lists; +import scala.Tuple3; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; @@ -39,7 +39,7 @@ public static void main(String[] args) { @SuppressWarnings("unchecked") // $example on$ - JavaRDD> similarities = sc.parallelize(Lists.newArrayList( + JavaRDD> similarities = sc.parallelize(Arrays.asList( new Tuple3<>(0L, 1L, 0.9), new Tuple3<>(1L, 2L, 0.9), new Tuple3<>(2L, 3L, 0.9), diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java index 24af5d0180ce4..6998ce2156c25 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java @@ -19,6 +19,7 @@ // $example on$ import java.util.HashMap; +import java.util.Map; import scala.Tuple2; @@ -26,8 +27,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; 
-import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.RandomForest; import org.apache.spark.mllib.tree.model.RandomForestModel; @@ -50,7 +49,7 @@ public static void main(String[] args) { // Train a RandomForest model. // Empty categoricalFeaturesInfo indicates all features are continuous. Integer numClasses = 2; - HashMap categoricalFeaturesInfo = new HashMap<>(); + Map categoricalFeaturesInfo = new HashMap<>(); Integer numTrees = 3; // Use more in practice. String featureSubsetStrategy = "auto"; // Let the algorithm choose. String impurity = "gini"; @@ -58,25 +57,15 @@ public static void main(String[] args) { Integer maxBins = 32; Integer seed = 12345; - final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, + RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); // Evaluate model on test instances and compute test error JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2<>(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); + testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label())); + double testErr = + predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification forest model:\n" + model.toDebugString()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java index afa9045878db3..4a0f55f529801 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java @@ -23,12 +23,9 @@ import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.RandomForest; import org.apache.spark.mllib.tree.model.RandomForestModel; @@ -52,37 +49,23 @@ public static void main(String[] args) { // Set parameters. // Empty categoricalFeaturesInfo indicates all features are continuous. Map categoricalFeaturesInfo = new HashMap<>(); - Integer numTrees = 3; // Use more in practice. + int numTrees = 3; // Use more in practice. String featureSubsetStrategy = "auto"; // Let the algorithm choose. String impurity = "variance"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; + int maxDepth = 4; + int maxBins = 32; + int seed = 12345; // Train a RandomForest model. 
- final RandomForestModel model = RandomForest.trainRegressor(trainingData, + RandomForestModel model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); // Evaluate model on test instances and compute test error JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2<>(model.predict(p.features()), p.label()); - } - }); - Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); + testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label())); + double testMSE = predictionAndLabel.mapToDouble(pl -> { + double diff = pl._1() - pl._2(); + return diff * diff; + }).mean(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression forest model:\n" + model.toDebugString()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java index 54dfc404ca6e9..bd49f059b29fd 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -23,7 +23,6 @@ import scala.Tuple2; import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.evaluation.RegressionMetrics; import org.apache.spark.mllib.evaluation.RankingMetrics; import org.apache.spark.mllib.recommendation.ALS; @@ -39,93 +38,61 @@ public static void main(String[] args) { // $example on$ String path = "data/mllib/sample_movielens_data.txt"; JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - @Override - public Rating call(String line) { - String[] parts = line.split("::"); - return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double - .parseDouble(parts[2]) - 2.5); - } - } - ); + JavaRDD ratings = data.map(line -> { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double + .parseDouble(parts[2]) - 2.5); + }); ratings.cache(); // Train an ALS model - final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); // Get top 10 recommendations for every user and scale ratings from 0 to 1 JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); - JavaRDD> userRecsScaled = userRecs.map( - new Function, Tuple2>() { - @Override - public Tuple2 call(Tuple2 t) { - Rating[] scaledRatings = new Rating[t._2().length]; - for (int i = 0; i < scaledRatings.length; i++) { - double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); - scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); - } - return new Tuple2<>(t._1(), scaledRatings); + JavaRDD> userRecsScaled = userRecs.map(t -> { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), 
t._2()[i].product(), newRating); } - } - ); + return new Tuple2<>(t._1(), scaledRatings); + }); JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); // Map ratings to 1 or 0, 1 indicating a movie that should be recommended - JavaRDD binarizedRatings = ratings.map( - new Function() { - @Override - public Rating call(Rating r) { - double binaryRating; - if (r.rating() > 0.0) { - binaryRating = 1.0; - } else { - binaryRating = 0.0; - } - return new Rating(r.user(), r.product(), binaryRating); + JavaRDD binarizedRatings = ratings.map(r -> { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } else { + binaryRating = 0.0; } - } - ); + return new Rating(r.user(), r.product(), binaryRating); + }); // Group ratings by common user - JavaPairRDD> userMovies = binarizedRatings.groupBy( - new Function() { - @Override - public Object call(Rating r) { - return r.user(); - } - } - ); + JavaPairRDD> userMovies = binarizedRatings.groupBy(Rating::user); // Get true relevant documents from all user ratings - JavaPairRDD> userMoviesList = userMovies.mapValues( - new Function, List>() { - @Override - public List call(Iterable docs) { - List products = new ArrayList<>(); - for (Rating r : docs) { - if (r.rating() > 0.0) { - products.add(r.product()); - } + JavaPairRDD> userMoviesList = userMovies.mapValues(docs -> { + List products = new ArrayList<>(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); } - return products; } - } - ); + return products; + }); // Extract the product id from each recommendation - JavaPairRDD> userRecommendedList = userRecommended.mapValues( - new Function>() { - @Override - public List call(Rating[] docs) { - List products = new ArrayList<>(); - for (Rating r : docs) { - products.add(r.product()); - } - return products; + JavaPairRDD> userRecommendedList = userRecommended.mapValues(docs -> { + List products = new ArrayList<>(); + for (Rating r : docs) { + products.add(r.product()); } - } - ); + return products; + }); JavaRDD, List>> relevantDocs = userMoviesList.join( userRecommendedList).values(); @@ -143,33 +110,15 @@ public List call(Rating[] docs) { System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); // Evaluate the model using numerical ratings and regression metrics - JavaRDD> userProducts = ratings.map( - new Function>() { - @Override - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); + JavaRDD> userProducts = + ratings.map(r -> new Tuple2<>(r.user(), r.product())); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Object>>() { - @Override - public Tuple2, Object> call(Rating r) { - return new Tuple2, Object>( - new Tuple2<>(r.user(), r.product()), r.rating()); - } - } - )); + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(r -> + new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()))); JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Object>>() { - @Override - public Tuple2, Object> call(Rating r) { - return new Tuple2, Object>( - new Tuple2<>(r.user(), r.product()), r.rating()); - } - } + JavaPairRDD.fromJavaRDD(ratings.map(r -> + new Tuple2, Object>(new Tuple2<>(r.user(), r.product()), r.rating()) )).join(predictions).values(); // Create regression metrics object diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java 
b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java index f69aa4b75a56c..1ee68da35e81a 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -21,7 +21,6 @@ import scala.Tuple2; import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.recommendation.ALS; import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; import org.apache.spark.mllib.recommendation.Rating; @@ -37,15 +36,12 @@ public static void main(String[] args) { // Load and parse the data String path = "data/mllib/als/test.data"; JavaRDD data = jsc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String s) { - String[] sarray = s.split(","); - return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), - Double.parseDouble(sarray[2])); - } - } - ); + JavaRDD ratings = data.map(s -> { + String[] sarray = s.split(","); + return new Rating(Integer.parseInt(sarray[0]), + Integer.parseInt(sarray[1]), + Double.parseDouble(sarray[2])); + }); // Build the recommendation model using ALS int rank = 10; @@ -53,37 +49,19 @@ public Rating call(String s) { MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); + JavaRDD> userProducts = + ratings.map(r -> new Tuple2<>(r.user(), r.product())); JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - Double err = pair._1() - pair._2(); - return err * err; - } - } - ).rdd()).mean(); + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD() + .map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating())) + ); + JavaRDD> ratesAndPreds = JavaPairRDD.fromJavaRDD( + ratings.map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()))) + .join(predictions).values(); + double MSE = ratesAndPreds.mapToDouble(pair -> { + double err = pair._1() - pair._2(); + return err * err; + }).mean(); System.out.println("Mean Squared Error = " + MSE); // Save and load model diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java index b3e5c04759575..7bb9993b84168 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -21,7 +21,6 @@ import scala.Tuple2; import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vectors; import 
org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LinearRegressionModel; @@ -38,34 +37,24 @@ public static void main(String[] args) { // Load and parse the data String path = "data/mllib/sample_linear_regression_data.txt"; JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(" "); - double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) { - v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); - } - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } + JavaRDD parsedData = data.map(line -> { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) { + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); } - ); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + }); parsedData.cache(); // Building the model int numIterations = 100; - final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), + LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); + JavaPairRDD valuesAndPreds = parsedData.mapToPair(point -> + new Tuple2<>(model.predict(point.features()), point.label())); // Instantiate metrics object RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java index 720b167b2cadf..866a221fdb592 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java @@ -24,7 +24,6 @@ import scala.Tuple2; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.SVMModel; import org.apache.spark.mllib.classification.SVMWithSGD; import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; @@ -50,20 +49,14 @@ public static void main(String[] args) { // Run training algorithm to build the model. int numIterations = 100; - final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); + SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); // Clear the default threshold. model.clearThreshold(); // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - } - ); + JavaRDD> scoreAndLabels = test.map(p -> + new Tuple2<>(model.predict(p.features()), p.label())); // Get evaluation metrics. 
BinaryClassificationMetrics metrics = diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java index 7f4fe600422b2..f9198e75c2ff5 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java @@ -23,9 +23,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -// $example off$ -import org.apache.spark.api.java.function.Function; -// $example on$ import org.apache.spark.mllib.fpm.AssociationRules; import org.apache.spark.mllib.fpm.FPGrowth; import org.apache.spark.mllib.fpm.FPGrowthModel; @@ -42,14 +39,7 @@ public static void main(String[] args) { // $example on$ JavaRDD data = sc.textFile("data/mllib/sample_fpgrowth.txt"); - JavaRDD> transactions = data.map( - new Function>() { - public List call(String line) { - String[] parts = line.split(" "); - return Arrays.asList(parts); - } - } - ); + JavaRDD> transactions = data.map(line -> Arrays.asList(line.split(" "))); FPGrowth fpg = new FPGrowth() .setMinSupport(0.2) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java index cfaa577b51161..4be702c2ba6ad 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java @@ -17,10 +17,6 @@ package org.apache.spark.examples.mllib; - -import org.apache.spark.api.java.function.VoidFunction; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; // $example on$ import org.apache.spark.mllib.stat.test.BinarySample; import org.apache.spark.mllib.stat.test.StreamingTest; @@ -75,16 +71,12 @@ public static void main(String[] args) throws Exception { ssc.checkpoint(Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark").toString()); // $example on$ - JavaDStream data = ssc.textFileStream(dataDir).map( - new Function() { - @Override - public BinarySample call(String line) { - String[] ts = line.split(","); - boolean label = Boolean.parseBoolean(ts[0]); - double value = Double.parseDouble(ts[1]); - return new BinarySample(label, value); - } - }); + JavaDStream data = ssc.textFileStream(dataDir).map(line -> { + String[] ts = line.split(","); + boolean label = Boolean.parseBoolean(ts[0]); + double value = Double.parseDouble(ts[1]); + return new BinarySample(label, value); + }); StreamingTest streamingTest = new StreamingTest() .setPeacePeriod(0) @@ -98,21 +90,11 @@ public BinarySample call(String line) { // Stop processing if test becomes significant or we time out timeoutCounter = numBatchesTimeout; - out.foreachRDD(new VoidFunction>() { - @Override - public void call(JavaRDD rdd) { - timeoutCounter -= 1; - - boolean anySignificant = !rdd.filter(new Function() { - @Override - public Boolean call(StreamingTestResult v) { - return v.pValue() < 0.05; - } - }).isEmpty(); - - if (timeoutCounter <= 0 || anySignificant) { - rdd.context().stop(); - } + out.foreachRDD(rdd -> { + timeoutCounter -= 1; + boolean anySignificant = !rdd.filter(v -> v.pValue() < 0.05).isEmpty(); + if (timeoutCounter <= 0 || anySignificant) { + rdd.context().stop(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java 
b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index b687fae5a1da0..adb96dd8bf00c 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -139,11 +139,9 @@ private static void runBasicParquetExample(SparkSession spark) { // Parquet files can also be used to create a temporary view and then used in SQL statements parquetFileDF.createOrReplaceTempView("parquetFile"); Dataset namesDF = spark.sql("SELECT name FROM parquetFile WHERE age BETWEEN 13 AND 19"); - Dataset namesDS = namesDF.map(new MapFunction() { - public String call(Row row) { - return "Name: " + row.getString(0); - } - }, Encoders.STRING()); + Dataset namesDS = namesDF.map( + (MapFunction) row -> "Name: " + row.getString(0), + Encoders.STRING()); namesDS.show(); // +------------+ // | value| diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java index c5770d147a6b5..8605852d0881c 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java @@ -227,12 +227,9 @@ private static void runDatasetCreationExample(SparkSession spark) { // Encoders for most common types are provided in class Encoders Encoder integerEncoder = Encoders.INT(); Dataset primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder); - Dataset transformedDS = primitiveDS.map(new MapFunction() { - @Override - public Integer call(Integer value) throws Exception { - return value + 1; - } - }, integerEncoder); + Dataset transformedDS = primitiveDS.map( + (MapFunction) value -> value + 1, + integerEncoder); transformedDS.collect(); // Returns [2, 3, 4] // DataFrames can be converted to a Dataset by providing a class. 
Mapping based on name @@ -255,15 +252,12 @@ private static void runInferSchemaExample(SparkSession spark) { JavaRDD peopleRDD = spark.read() .textFile("examples/src/main/resources/people.txt") .javaRDD() - .map(new Function() { - @Override - public Person call(String line) throws Exception { - String[] parts = line.split(","); - Person person = new Person(); - person.setName(parts[0]); - person.setAge(Integer.parseInt(parts[1].trim())); - return person; - } + .map(line -> { + String[] parts = line.split(","); + Person person = new Person(); + person.setName(parts[0]); + person.setAge(Integer.parseInt(parts[1].trim())); + return person; }); // Apply a schema to an RDD of JavaBeans to get a DataFrame @@ -276,12 +270,9 @@ public Person call(String line) throws Exception { // The columns of a row in the result can be accessed by field index Encoder stringEncoder = Encoders.STRING(); - Dataset teenagerNamesByIndexDF = teenagersDF.map(new MapFunction() { - @Override - public String call(Row row) throws Exception { - return "Name: " + row.getString(0); - } - }, stringEncoder); + Dataset teenagerNamesByIndexDF = teenagersDF.map( + (MapFunction) row -> "Name: " + row.getString(0), + stringEncoder); teenagerNamesByIndexDF.show(); // +------------+ // | value| @@ -290,12 +281,9 @@ public String call(Row row) throws Exception { // +------------+ // or by field name - Dataset teenagerNamesByFieldDF = teenagersDF.map(new MapFunction() { - @Override - public String call(Row row) throws Exception { - return "Name: " + row.getAs("name"); - } - }, stringEncoder); + Dataset teenagerNamesByFieldDF = teenagersDF.map( + (MapFunction) row -> "Name: " + row.getAs("name"), + stringEncoder); teenagerNamesByFieldDF.show(); // +------------+ // | value| @@ -324,12 +312,9 @@ private static void runProgrammaticSchemaExample(SparkSession spark) { StructType schema = DataTypes.createStructType(fields); // Convert records of the RDD (people) to Rows - JavaRDD rowRDD = peopleRDD.map(new Function() { - @Override - public Row call(String record) throws Exception { - String[] attributes = record.split(","); - return RowFactory.create(attributes[0], attributes[1].trim()); - } + JavaRDD rowRDD = peopleRDD.map((Function) record -> { + String[] attributes = record.split(","); + return RowFactory.create(attributes[0], attributes[1].trim()); }); // Apply the schema to the RDD @@ -343,12 +328,9 @@ public Row call(String record) throws Exception { // The results of SQL queries are DataFrames and support all the normal RDD operations // The columns of a row in the result can be accessed by field index or by field name - Dataset namesDS = results.map(new MapFunction() { - @Override - public String call(Row row) throws Exception { - return "Name: " + row.getString(0); - } - }, Encoders.STRING()); + Dataset namesDS = results.map( + (MapFunction) row -> "Name: " + row.getString(0), + Encoders.STRING()); namesDS.show(); // +-------------+ // | value| diff --git a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java index 2fe1307d8efbe..47638565b1663 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java @@ -90,12 +90,9 @@ public static void main(String[] args) { Dataset sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key"); // The items in DaraFrames are of type Row, which 
lets you to access each column by ordinal. - Dataset stringsDS = sqlDF.map(new MapFunction() { - @Override - public String call(Row row) throws Exception { - return "Key: " + row.get(0) + ", Value: " + row.get(1); - } - }, Encoders.STRING()); + Dataset stringsDS = sqlDF.map( + (MapFunction) row -> "Key: " + row.get(0) + ", Value: " + row.get(1), + Encoders.STRING()); stringsDS.show(); // +--------------------+ // | value| diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredKafkaWordCount.java index 0f45cfeca4429..4e02719e043ad 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredKafkaWordCount.java @@ -25,7 +25,6 @@ import org.apache.spark.sql.streaming.StreamingQuery; import java.util.Arrays; -import java.util.Iterator; /** * Consumes messages from one or more topics in Kafka and does wordcount. @@ -78,12 +77,9 @@ public static void main(String[] args) throws Exception { .as(Encoders.STRING()); // Generate running word count - Dataset wordCounts = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(x.split(" ")).iterator(); - } - }, Encoders.STRING()).groupBy("value").count(); + Dataset wordCounts = lines.flatMap( + (FlatMapFunction) x -> Arrays.asList(x.split(" ")).iterator(), + Encoders.STRING()).groupBy("value").count(); // Start running the query that prints the running counts to the console StreamingQuery query = wordCounts.writeStream() diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java index 5f342e1ead6ca..3af786978b167 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java @@ -21,7 +21,6 @@ import org.apache.spark.sql.streaming.StreamingQuery; import java.util.Arrays; -import java.util.Iterator; /** * Counts words in UTF8 encoded, '\n' delimited text received from the network. 
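Note the casts such as (MapFunction<Row, String>) and (FlatMapFunction<String, String>) that the Dataset rewrites keep in front of the lambdas: Dataset.map and Dataset.flatMap are overloaded for both Scala function types and the Java functional interfaces, so a bare lambda would be ambiguous to javac and the cast pins down the intended overload. A minimal sketch of the idea, with a local SparkSession and sample data that are assumptions for illustration, not part of the patch:

import java.util.Arrays;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class DatasetLambdaCastSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local[2]")
      .appName("DatasetLambdaCastSketch")
      .getOrCreate();

    Dataset<String> lines = spark.createDataset(
      Arrays.asList("hello world", "hello spark"), Encoders.STRING());

    // Without the cast this call is ambiguous: flatMap is overloaded for
    // scala.Function1 and FlatMapFunction, and a bare lambda matches neither uniquely.
    Dataset<String> words = lines.flatMap(
      (FlatMapFunction<String, String>) line -> Arrays.asList(line.split(" ")).iterator(),
      Encoders.STRING());

    words.groupBy("value").count().show();
    spark.stop();
  }
}

The same reasoning applies to the (MapFunction<Row, String>) casts in the SQL example hunks above; the extra Encoder argument is unchanged by the lambda conversion.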
@@ -61,13 +60,9 @@ public static void main(String[] args) throws Exception { .load(); // Split the lines into words - Dataset words = lines.as(Encoders.STRING()) - .flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(x.split(" ")).iterator(); - } - }, Encoders.STRING()); + Dataset words = lines.as(Encoders.STRING()).flatMap( + (FlatMapFunction) x -> Arrays.asList(x.split(" ")).iterator(), + Encoders.STRING()); // Generate running word count Dataset wordCounts = words.groupBy("value").count(); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java index 172d053c29a1f..93ec5e2695157 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java @@ -18,13 +18,11 @@ import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.sql.*; -import org.apache.spark.sql.functions; import org.apache.spark.sql.streaming.StreamingQuery; import scala.Tuple2; import java.sql.Timestamp; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; /** @@ -86,16 +84,12 @@ public static void main(String[] args) throws Exception { // Split the lines into words, retaining timestamps Dataset words = lines .as(Encoders.tuple(Encoders.STRING(), Encoders.TIMESTAMP())) - .flatMap( - new FlatMapFunction, Tuple2>() { - @Override - public Iterator> call(Tuple2 t) { - List> result = new ArrayList<>(); - for (String word : t._1.split(" ")) { - result.add(new Tuple2<>(word, t._2)); - } - return result.iterator(); + .flatMap((FlatMapFunction, Tuple2>) t -> { + List> result = new ArrayList<>(); + for (String word : t._1.split(" ")) { + result.add(new Tuple2<>(word, t._2)); } + return result.iterator(); }, Encoders.tuple(Encoders.STRING(), Encoders.TIMESTAMP()) ).toDF("word", "timestamp"); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index e20b94d5b03f2..47692ec982890 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -20,9 +20,6 @@ import com.google.common.io.Closeables; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; @@ -38,7 +35,6 @@ import java.net.Socket; import java.nio.charset.StandardCharsets; import java.util.Arrays; -import java.util.Iterator; import java.util.regex.Pattern; /** @@ -74,23 +70,9 @@ public static void main(String[] args) throws Exception { // words in input stream of \n delimited text (eg. 
generated by 'nc') JavaReceiverInputDStream lines = ssc.receiverStream( new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(SPACE.split(x)).iterator(); - } - }); - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override public Tuple2 call(String s) { - return new Tuple2<>(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); + JavaDStream words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator()); + JavaPairDStream wordCounts = words.mapToPair(s -> new Tuple2<>(s, 1)) + .reduceByKey((i1, i2) -> i1 + i2); wordCounts.print(); ssc.start(); @@ -108,15 +90,13 @@ public JavaCustomReceiver(String host_ , int port_) { port = port_; } + @Override public void onStart() { // Start the thread that receives data over a connection - new Thread() { - @Override public void run() { - receive(); - } - }.start(); + new Thread(this::receive).start(); } + @Override public void onStop() { // There is nothing much to do as the thread calling receive() // is designed to stop by itself isStopped() returns false @@ -127,13 +107,13 @@ private void receive() { try { Socket socket = null; BufferedReader reader = null; - String userInput = null; try { // connect to the server socket = new Socket(host, port); reader = new BufferedReader( new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); // Until stopped or connection broken continue reading + String userInput; while (!isStopped() && (userInput = reader.readLine()) != null) { System.out.println("Received data '" + userInput + "'"); store(userInput); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index ed118f86c058b..5e5ae6213d5d9 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -20,7 +20,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Arrays; -import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.regex.Pattern; @@ -30,7 +29,6 @@ import kafka.serializer.StringDecoder; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.*; import org.apache.spark.streaming.api.java.*; import org.apache.spark.streaming.kafka.KafkaUtils; import org.apache.spark.streaming.Durations; @@ -82,31 +80,10 @@ public static void main(String[] args) throws Exception { ); // Get the lines, split them into words, count the words and print - JavaDStream lines = messages.map(new Function, String>() { - @Override - public String call(Tuple2 tuple2) { - return tuple2._2(); - } - }); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(SPACE.split(x)).iterator(); - } - }); - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2<>(s, 1); - } - }).reduceByKey( - new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); + JavaDStream lines = messages.map(Tuple2::_2); + JavaDStream words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator()); + 
JavaPairDStream wordCounts = words.mapToPair(s -> new Tuple2<>(s, 1)) + .reduceByKey((i1, i2) -> i1 + i2); wordCounts.print(); // Start the computation diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java index 33c0a2df2fe43..0c651049d0ffa 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java @@ -18,7 +18,6 @@ package org.apache.spark.examples.streaming; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; import org.apache.spark.streaming.*; import org.apache.spark.streaming.api.java.*; import org.apache.spark.streaming.flume.FlumeUtils; @@ -62,12 +61,7 @@ public static void main(String[] args) throws Exception { flumeStream.count(); - flumeStream.count().map(new Function() { - @Override - public String call(Long in) { - return "Received " + in + " flume events."; - } - }).print(); + flumeStream.count().map(in -> "Received " + in + " flume events.").print(); ssc.start(); ssc.awaitTermination(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index 8a5fd53372041..ce5acdca92666 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -18,7 +18,6 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; -import java.util.Iterator; import java.util.Map; import java.util.HashMap; import java.util.regex.Pattern; @@ -26,10 +25,6 @@ import scala.Tuple2; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -78,32 +73,12 @@ public static void main(String[] args) throws Exception { JavaPairReceiverInputDStream messages = KafkaUtils.createStream(jssc, args[0], args[1], topicMap); - JavaDStream lines = messages.map(new Function, String>() { - @Override - public String call(Tuple2 tuple2) { - return tuple2._2(); - } - }); + JavaDStream lines = messages.map(Tuple2::_2); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(SPACE.split(x)).iterator(); - } - }); + JavaDStream words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator()); - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2<>(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); + JavaPairDStream wordCounts = words.mapToPair(s -> new Tuple2<>(s, 1)) + .reduceByKey((i1, i2) -> i1 + i2); wordCounts.print(); jssc.start(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java index 7a8fe99f48f27..b217672def88e 100644 --- 
a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java @@ -18,15 +18,11 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; -import java.util.Iterator; import java.util.regex.Pattern; import scala.Tuple2; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.StorageLevels; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; @@ -66,24 +62,9 @@ public static void main(String[] args) throws Exception { // Replication necessary in distributed scenario for fault tolerance. JavaReceiverInputDStream lines = ssc.socketTextStream( args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(SPACE.split(x)).iterator(); - } - }); - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2<>(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); + JavaDStream words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator()); + JavaPairDStream wordCounts = words.mapToPair(s -> new Tuple2<>(s, 1)) + .reduceByKey((i1, i2) -> i1 + i2); wordCounts.print(); ssc.start(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java index 62413b4606ff2..e86f8ab38a74f 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java @@ -17,19 +17,15 @@ package org.apache.spark.examples.streaming; - +import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import java.util.Queue; import scala.Tuple2; -import com.google.common.collect.Lists; - import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -49,14 +45,14 @@ public static void main(String[] args) throws Exception { // Create the queue through which RDDs can be pushed to // a QueueInputDStream - Queue> rddQueue = new LinkedList<>(); // Create and push some RDDs into the queue - List list = Lists.newArrayList(); + List list = new ArrayList<>(); for (int i = 0; i < 1000; i++) { list.add(i); } + Queue> rddQueue = new LinkedList<>(); for (int i = 0; i < 30; i++) { rddQueue.add(ssc.sparkContext().parallelize(list)); } @@ -64,19 +60,9 @@ public static void main(String[] args) throws Exception { // Create the QueueInputDStream and use it do some processing JavaDStream inputStream = ssc.queueStream(rddQueue); JavaPairDStream mappedStream = inputStream.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i % 10, 1); - } - }); + i -> new Tuple2<>(i % 10, 1)); JavaPairDStream reducedStream = mappedStream.reduceByKey( - 
new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); + (i1, i2) -> i1 + i2); reducedStream.print(); ssc.start(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index acbc34524328b..45a876decff8b 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -18,10 +18,8 @@ package org.apache.spark.examples.streaming; import java.io.File; -import java.io.IOException; import java.nio.charset.Charset; import java.util.Arrays; -import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -30,12 +28,10 @@ import com.google.common.io.Files; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.streaming.Durations; -import org.apache.spark.streaming.Time; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; @@ -120,7 +116,7 @@ private static JavaStreamingContext createContext(String ip, // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint System.out.println("Creating new context"); - final File outputFile = new File(outputPath); + File outputFile = new File(outputPath); if (outputFile.exists()) { outputFile.delete(); } @@ -132,52 +128,31 @@ private static JavaStreamingContext createContext(String ip, // Create a socket stream on target ip:port and count the // words in input stream of \n delimited text (eg. 
generated by 'nc') JavaReceiverInputDStream lines = ssc.socketTextStream(ip, port); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(SPACE.split(x)).iterator(); - } - }); - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2<>(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; + JavaDStream words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator()); + JavaPairDStream wordCounts = words.mapToPair(s -> new Tuple2<>(s, 1)) + .reduceByKey((i1, i2) -> i1 + i2); + + wordCounts.foreachRDD((rdd, time) -> { + // Get or register the blacklist Broadcast + Broadcast> blacklist = + JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the droppedWordsCounter Accumulator + LongAccumulator droppedWordsCounter = + JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + // Use blacklist to drop words and use droppedWordsCounter to count them + String counts = rdd.filter(wordCount -> { + if (blacklist.value().contains(wordCount._1())) { + droppedWordsCounter.add(wordCount._2()); + return false; + } else { + return true; } - }); - - wordCounts.foreachRDD(new VoidFunction2, Time>() { - @Override - public void call(JavaPairRDD rdd, Time time) throws IOException { - // Get or register the blacklist Broadcast - final Broadcast> blacklist = - JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); - // Get or register the droppedWordsCounter Accumulator - final LongAccumulator droppedWordsCounter = - JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); - // Use blacklist to drop words and use droppedWordsCounter to count them - String counts = rdd.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 wordCount) { - if (blacklist.value().contains(wordCount._1())) { - droppedWordsCounter.add(wordCount._2()); - return false; - } else { - return true; - } - } - }).collect().toString(); - String output = "Counts at time " + time + " " + counts; - System.out.println(output); - System.out.println("Dropped " + droppedWordsCounter.value() + " word(s) totally"); - System.out.println("Appending to " + outputFile.getAbsolutePath()); - Files.append(output + "\n", outputFile, Charset.defaultCharset()); - } + }).collect().toString(); + String output = "Counts at time " + time + " " + counts; + System.out.println(output); + System.out.println("Dropped " + droppedWordsCounter.value() + " word(s) totally"); + System.out.println("Appending to " + outputFile.getAbsolutePath()); + Files.append(output + "\n", outputFile, Charset.defaultCharset()); }); return ssc; @@ -198,19 +173,15 @@ public static void main(String[] args) throws Exception { System.exit(1); } - final String ip = args[0]; - final int port = Integer.parseInt(args[1]); - final String checkpointDirectory = args[2]; - final String outputPath = args[3]; + String ip = args[0]; + int port = Integer.parseInt(args[1]); + String checkpointDirectory = args[2]; + String outputPath = args[3]; // Function to create JavaStreamingContext without any output operations // (used to detect the new context) - Function0 createContextFunc = new Function0() { - @Override - public JavaStreamingContext call() { - return createContext(ip, port, checkpointDirectory, outputPath); - } - }; + Function0 createContextFunc = + () -> createContext(ip, 
port, checkpointDirectory, outputPath); JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, createContextFunc); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index b8e9e125ba596..948d1a2111780 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -18,20 +18,15 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; -import java.util.Iterator; import java.util.regex.Pattern; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.VoidFunction2; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.api.java.StorageLevels; import org.apache.spark.streaming.Durations; -import org.apache.spark.streaming.Time; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -48,7 +43,6 @@ * and then run the example * `$ bin/run-example org.apache.spark.examples.streaming.JavaSqlNetworkWordCount localhost 9999` */ - public final class JavaSqlNetworkWordCount { private static final Pattern SPACE = Pattern.compile(" "); @@ -70,39 +64,28 @@ public static void main(String[] args) throws Exception { // Replication necessary in distributed scenario for fault tolerance. 
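The streaming rewrites lean on ordinary overload resolution as well: JavaDStream.foreachRDD accepts a VoidFunction2 over the RDD and its batch Time, so the two-argument anonymous class converts to a two-parameter lambda with no cast. A minimal sketch of that shape, using a queue stream purely for illustration (the queue contents, batch interval, and class name are assumptions, not taken from the patch):

import java.util.Arrays;
import java.util.LinkedList;
import java.util.Queue;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class ForeachRDDLambdaSketch {
  public static void main(String[] args) throws InterruptedException {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("ForeachRDDLambdaSketch");
    JavaStreamingContext ssc = new JavaStreamingContext(conf, Durations.seconds(1));

    Queue<JavaRDD<String>> queue = new LinkedList<>();
    queue.add(ssc.sparkContext().parallelize(Arrays.asList("a", "b", "c")));
    JavaDStream<String> words = ssc.queueStream(queue);

    // VoidFunction2<JavaRDD<String>, Time> written as a two-parameter lambda.
    words.foreachRDD((rdd, time) -> {
      System.out.println("Batch at " + time + " had " + rdd.count() + " records");
    });

    ssc.start();
    ssc.awaitTerminationOrTimeout(5000);
    ssc.stop();
  }
}

When the batch time is not needed, the single-argument overload resolves the same way: foreachRDD(rdd -> ...).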
JavaReceiverInputDStream lines = ssc.socketTextStream( args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(SPACE.split(x)).iterator(); - } - }); + JavaDStream words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator()); // Convert RDDs of the words DStream to DataFrame and run SQL query - words.foreachRDD(new VoidFunction2, Time>() { - @Override - public void call(JavaRDD rdd, Time time) { - SparkSession spark = JavaSparkSessionSingleton.getInstance(rdd.context().getConf()); + words.foreachRDD((rdd, time) -> { + SparkSession spark = JavaSparkSessionSingleton.getInstance(rdd.context().getConf()); - // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame - JavaRDD rowRDD = rdd.map(new Function() { - @Override - public JavaRecord call(String word) { - JavaRecord record = new JavaRecord(); - record.setWord(word); - return record; - } - }); - Dataset wordsDataFrame = spark.createDataFrame(rowRDD, JavaRecord.class); + // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame + JavaRDD rowRDD = rdd.map(word -> { + JavaRecord record = new JavaRecord(); + record.setWord(word); + return record; + }); + Dataset wordsDataFrame = spark.createDataFrame(rowRDD, JavaRecord.class); - // Creates a temporary view using the DataFrame - wordsDataFrame.createOrReplaceTempView("words"); + // Creates a temporary view using the DataFrame + wordsDataFrame.createOrReplaceTempView("words"); - // Do word count on table using SQL and print it - Dataset wordCountsDataFrame = - spark.sql("select word, count(*) as total from words group by word"); - System.out.println("========= " + time + "========="); - wordCountsDataFrame.show(); - } + // Do word count on table using SQL and print it + Dataset wordCountsDataFrame = + spark.sql("select word, count(*) as total from words group by word"); + System.out.println("========= " + time + "========="); + wordCountsDataFrame.show(); }); ssc.start(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index ed36df852ace6..9d8bd7fd11ebd 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -18,7 +18,6 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; -import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -72,32 +71,17 @@ public static void main(String[] args) throws Exception { JavaReceiverInputDStream lines = ssc.socketTextStream( args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(SPACE.split(x)).iterator(); - } - }); + JavaDStream words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator()); - JavaPairDStream wordsDstream = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2<>(s, 1); - } - }); + JavaPairDStream wordsDstream = words.mapToPair(s -> new Tuple2<>(s, 1)); // Update the cumulative count function Function3, State, Tuple2> mappingFunc = - new Function3, State, Tuple2>() { - @Override - public Tuple2 call(String word, Optional one, - State 
state) { - int sum = one.orElse(0) + (state.exists() ? state.get() : 0); - Tuple2 output = new Tuple2<>(word, sum); - state.update(sum); - return output; - } + (word, one, state) -> { + int sum = one.orElse(0) + (state.exists() ? state.get() : 0); + Tuple2 output = new Tuple2<>(word, sum); + state.update(sum); + return output; }; // DStream made of get cumulative counts that get updated in every batch From ec0ea7f6ae4bbf803a22102faffdb88d0398ebed Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 19 Feb 2017 09:42:50 -0800 Subject: [PATCH 05/61] [SPARK-19534][TESTS] Convert Java tests to use lambdas, Java 8 features ## What changes were proposed in this pull request? Convert tests to use Java 8 lambdas, and modest related fixes to surrounding code. ## How was this patch tested? Jenkins tests Author: Sean Owen Closes #16964 from srowen/SPARK-19534. --- .../spark/network/TransportContext.java | 6 +- .../spark/network/util/MapConfigProvider.java | 8 +- .../network/ChunkFetchIntegrationSuite.java | 37 +- .../RequestTimeoutIntegrationSuite.java | 3 +- .../network/TransportClientFactorySuite.java | 51 +- .../TransportResponseHandlerSuite.java | 14 +- .../network/crypto/AuthIntegrationSuite.java | 19 +- .../spark/network/sasl/SparkSaslSuite.java | 65 +-- .../util/TransportFrameDecoderSuite.java | 44 +- .../network/sasl/SaslIntegrationSuite.java | 34 +- .../ExternalShuffleBlockHandlerSuite.java | 2 +- .../shuffle/ExternalShuffleCleanupSuite.java | 6 +- .../ExternalShuffleIntegrationSuite.java | 13 +- .../shuffle/OneForOneBlockFetcherSuite.java | 78 ++- .../shuffle/RetryingBlockFetcherSuite.java | 64 +-- .../unsafe/sort/UnsafeExternalSorter.java | 1 - .../org/apache/spark/JavaJdbcRDDSuite.java | 26 +- .../sort/UnsafeShuffleWriterSuite.java | 65 +-- .../map/AbstractBytesToBytesMapSuite.java | 25 +- .../sort/UnsafeExternalSorterSuite.java | 25 +- .../org/apache/spark/Java8RDDAPISuite.java | 7 +- .../test/org/apache/spark/JavaAPISuite.java | 492 ++++------------ .../kafka010/JavaConsumerStrategySuite.java | 24 +- .../SparkSubmitCommandBuilderSuite.java | 2 +- .../SparkSubmitOptionParserSuite.java | 8 +- .../apache/spark/ml/feature/JavaPCASuite.java | 35 +- .../classification/JavaNaiveBayesSuite.java | 10 +- .../clustering/JavaBisectingKMeansSuite.java | 4 +- .../spark/mllib/clustering/JavaLDASuite.java | 40 +- .../mllib/fpm/JavaAssociationRulesSuite.java | 6 +- .../regression/JavaLinearRegressionSuite.java | 11 +- .../mllib/tree/JavaDecisionTreeSuite.java | 15 +- .../SpecificParquetRecordReaderBase.java | 2 +- .../sql/Java8DatasetAggregatorSuite.java | 16 +- .../spark/sql/JavaApplySchemaSuite.java | 22 +- .../apache/spark/sql/JavaDataFrameSuite.java | 47 +- .../spark/sql/JavaDatasetAggregatorSuite.java | 49 +- .../sql/JavaDatasetAggregatorSuiteBase.java | 14 +- .../apache/spark/sql/JavaDatasetSuite.java | 147 ++--- .../org/apache/spark/sql/JavaUDFSuite.java | 37 +- .../streaming/JavaMapWithStateSuite.java | 81 +-- .../spark/streaming/JavaReceiverAPISuite.java | 24 +- .../streaming/JavaWriteAheadLogSuite.java | 10 +- .../apache/spark/streaming/Java8APISuite.java | 21 +- .../apache/spark/streaming/JavaAPISuite.java | 526 ++++-------------- 45 files changed, 662 insertions(+), 1574 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 37ba543380f07..965c4ae307667 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java 
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,9 +17,9 @@ package org.apache.spark.network; +import java.util.ArrayList; import java.util.List; -import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; import io.netty.handler.timeout.IdleStateHandler; @@ -100,7 +100,7 @@ public TransportClientFactory createClientFactory(List } public TransportClientFactory createClientFactory() { - return createClientFactory(Lists.newArrayList()); + return createClientFactory(new ArrayList<>()); } /** Create a server which will attempt to bind to a specific port. */ @@ -120,7 +120,7 @@ public TransportServer createServer(List bootstraps) { } public TransportServer createServer() { - return createServer(0, Lists.newArrayList()); + return createServer(0, new ArrayList<>()); } public TransportChannelHandler initializePipeline(SocketChannel channel) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java index b6667998b5b9d..9cfee7f08d155 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -17,22 +17,20 @@ package org.apache.spark.network.util; -import com.google.common.collect.Maps; - import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.NoSuchElementException; /** ConfigProvider based on a Map (copied in the constructor). */ public class MapConfigProvider extends ConfigProvider { - public static final MapConfigProvider EMPTY = new MapConfigProvider( - Collections.emptyMap()); + public static final MapConfigProvider EMPTY = new MapConfigProvider(Collections.emptyMap()); private final Map config; public MapConfigProvider(Map config) { - this.config = Maps.newHashMap(config); + this.config = new HashMap<>(config); } @Override diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 5bb8819132e3d..824482af08dd4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.RandomAccessFile; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; @@ -29,7 +30,6 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.google.common.io.Closeables; import org.junit.AfterClass; @@ -179,49 +179,49 @@ public void onFailure(int chunkIndex, Throwable e) { @Test public void fetchBufferChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); - assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX)); + assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + assertBufferListsEqual(Arrays.asList(bufferChunk), 
res.buffers); res.releaseBuffers(); } @Test public void fetchFileChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); - assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); + FetchResult res = fetchChunks(Arrays.asList(FILE_CHUNK_INDEX)); + assertEquals(Sets.newHashSet(FILE_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); + assertBufferListsEqual(Arrays.asList(fileChunk), res.buffers); res.releaseBuffers(); } @Test public void fetchNonExistentChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(12345)); + FetchResult res = fetchChunks(Arrays.asList(12345)); assertTrue(res.successChunks.isEmpty()); - assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertEquals(Sets.newHashSet(12345), res.failedChunks); assertTrue(res.buffers.isEmpty()); } @Test public void fetchBothChunks() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); - assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); + assertBufferListsEqual(Arrays.asList(bufferChunk, fileChunk), res.buffers); res.releaseBuffers(); } @Test public void fetchChunkAndNonExistent() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); - assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); - assertEquals(res.failedChunks, Sets.newHashSet(12345)); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, 12345)); + assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks); + assertEquals(Sets.newHashSet(12345), res.failedChunks); + assertBufferListsEqual(Arrays.asList(bufferChunk), res.buffers); res.releaseBuffers(); } - private void assertBufferListsEqual(List list0, List list1) + private static void assertBufferListsEqual(List list0, List list1) throws Exception { assertEquals(list0.size(), list1.size()); for (int i = 0; i < list0.size(); i ++) { @@ -229,7 +229,8 @@ private void assertBufferListsEqual(List list0, List configMap = Maps.newHashMap(); + Map configMap = new HashMap<>(); configMap.put("spark.shuffle.io.connectionTimeout", "10s"); conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 205ab88c84313..e95d25fe6ae91 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -19,19 +19,20 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; -import com.google.common.collect.Maps; import org.junit.After; import org.junit.Assert; import org.junit.Before; import 
org.junit.Test; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertTrue; import org.apache.spark.network.client.TransportClient; @@ -71,39 +72,36 @@ public void tearDown() { * * If concurrent is true, create multiple threads to create clients in parallel. */ - private void testClientReuse(final int maxConnections, boolean concurrent) + private void testClientReuse(int maxConnections, boolean concurrent) throws IOException, InterruptedException { - Map configMap = Maps.newHashMap(); + Map configMap = new HashMap<>(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); - final TransportClientFactory factory = context.createClientFactory(); - final Set clients = Collections.synchronizedSet( + TransportClientFactory factory = context.createClientFactory(); + Set clients = Collections.synchronizedSet( new HashSet()); - final AtomicInteger failed = new AtomicInteger(); + AtomicInteger failed = new AtomicInteger(); Thread[] attempts = new Thread[maxConnections * 10]; // Launch a bunch of threads to create new clients. for (int i = 0; i < attempts.length; i++) { - attempts[i] = new Thread() { - @Override - public void run() { - try { - TransportClient client = - factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assertTrue(client.isActive()); - clients.add(client); - } catch (IOException e) { - failed.incrementAndGet(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } + attempts[i] = new Thread(() -> { + try { + TransportClient client = + factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(client.isActive()); + clients.add(client); + } catch (IOException e) { + failed.incrementAndGet(); + } catch (InterruptedException e) { + throw new RuntimeException(e); } - }; + }); if (concurrent) { attempts[i].start(); @@ -113,8 +111,8 @@ public void run() { } // Wait until all the threads complete. 
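The non-API cleanups in this commit follow the same idea: an anonymous Thread subclass that only overrides run() becomes a Thread constructed from a Runnable lambda (or a method reference, as in new Thread(this::receive) earlier in the series). A minimal standalone sketch of that rewrite, with a made-up worker body for illustration:

public class ThreadLambdaSketch {
  public static void main(String[] args) throws InterruptedException {
    // Before: subclassing Thread anonymously just to supply run().
    Thread before = new Thread() {
      @Override
      public void run() {
        System.out.println("working in " + Thread.currentThread().getName());
      }
    };

    // After: pass the body as a Runnable lambda; Thread itself is not subclassed.
    Thread after = new Thread(() ->
      System.out.println("working in " + Thread.currentThread().getName()));

    before.start();
    after.start();
    before.join();
    after.join();
  }
}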
- for (int i = 0; i < attempts.length; i++) { - attempts[i].join(); + for (Thread attempt : attempts) { + attempt.join(); } Assert.assertEquals(0, failed.get()); @@ -150,7 +148,7 @@ public void returnDifferentClientsForDifferentServers() throws IOException, Inte TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); assertTrue(c1.isActive()); assertTrue(c2.isActive()); - assertTrue(c1 != c2); + assertNotSame(c1, c2); factory.close(); } @@ -167,7 +165,7 @@ public void neverReturnInactiveClients() throws IOException, InterruptedExceptio assertFalse(c1.isActive()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assertFalse(c1 == c2); + assertNotSame(c1, c2); assertTrue(c2.isActive()); factory.close(); } @@ -207,8 +205,7 @@ public Iterable> getAll() { } }); TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); - TransportClientFactory factory = context.createClientFactory(); - try { + try (TransportClientFactory factory = context.createClientFactory()) { TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); assertTrue(c1.isActive()); long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds @@ -216,8 +213,6 @@ public Iterable> getAll() { Thread.sleep(10); } assertFalse(c1.isActive()); - } finally { - factory.close(); } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 128f7cba74350..4477c9a935f21 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -24,8 +24,6 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; @@ -54,7 +52,7 @@ public void handleSuccessfulFetch() throws Exception { assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); - verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + verify(callback, times(1)).onSuccess(eq(0), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -67,7 +65,7 @@ public void handleFailedFetch() throws Exception { assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); - verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); + verify(callback, times(1)).onFailure(eq(0), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -84,9 +82,9 @@ public void clearAllOutstandingRequests() throws Exception { handler.exceptionCaught(new Exception("duh duh duhhhh")); // should fail both b2 and b3 - verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); - verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); - verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); + verify(callback, times(1)).onSuccess(eq(0), any()); + verify(callback, times(1)).onFailure(eq(1), any()); + verify(callback, times(1)).onFailure(eq(2), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -118,7 +116,7 @@ public void handleFailedRPC() throws Exception { assertEquals(1, handler.numOutstandingRequests()); 
handler.handle(new RpcFailure(12345, "oh no")); - verify(callback, times(1)).onFailure((Throwable) any()); + verify(callback, times(1)).onFailure(any()); assertEquals(0, handler.numOutstandingRequests()); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 21609d5aa2a20..8751944a1c2a3 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.network.crypto; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.List; import java.util.Map; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import io.netty.channel.Channel; import org.junit.After; import org.junit.Test; @@ -163,20 +163,17 @@ void createServer(String secret) throws Exception { } void createServer(String secret, boolean enableAes) throws Exception { - TransportServerBootstrap introspector = new TransportServerBootstrap() { - @Override - public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { - AuthTestCtx.this.serverChannel = channel; - if (rpcHandler instanceof AuthRpcHandler) { - AuthTestCtx.this.authRpcHandler = (AuthRpcHandler) rpcHandler; - } - return rpcHandler; + TransportServerBootstrap introspector = (channel, rpcHandler) -> { + this.serverChannel = channel; + if (rpcHandler instanceof AuthRpcHandler) { + this.authRpcHandler = (AuthRpcHandler) rpcHandler; } + return rpcHandler; }; SecretKeyHolder keyHolder = createKeyHolder(secret); TransportServerBootstrap auth = enableAes ? new AuthServerBootstrap(conf, keyHolder) : new SaslServerBootstrap(conf, keyHolder); - this.server = ctx.createServer(Lists.newArrayList(auth, introspector)); + this.server = ctx.createServer(Arrays.asList(auth, introspector)); } void createClient(String secret) throws Exception { @@ -186,7 +183,7 @@ void createClient(String secret) throws Exception { void createClient(String secret, boolean enableAes) throws Exception { TransportConf clientConf = enableAes ? 
conf : new TransportConf("rpc", MapConfigProvider.EMPTY); - List bootstraps = Lists.newArrayList( + List bootstraps = Arrays.asList( new AuthClientBootstrap(clientConf, appId, createKeyHolder(secret))); this.client = ctx.createClientFactory(bootstraps) .createClient(TestUtils.getLocalHost(), server.getPort()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 87129b900bf0b..6f15718bd8705 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -23,6 +23,7 @@ import java.io.File; import java.lang.reflect.Method; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -35,7 +36,6 @@ import javax.security.sasl.SaslException; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import com.google.common.io.ByteStreams; import com.google.common.io.Files; import io.netty.buffer.ByteBuf; @@ -45,8 +45,6 @@ import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; @@ -137,18 +135,15 @@ public void testSaslEncryption() throws Throwable { testBasicSasl(true); } - private void testBasicSasl(boolean encrypt) throws Throwable { + private static void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) { - ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; - RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", JavaUtils.bytesToString(message)); - cb.onSuccess(JavaUtils.stringToBytes("Pong")); - return null; - } - }) + doAnswer(invocation -> { + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; + RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); + return null; + }) .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); @@ -255,21 +250,17 @@ public void testEncryptedMessageChunking() throws Exception { @Test public void testFileRegionEncryption() throws Exception { - final Map testConf = ImmutableMap.of( + Map testConf = ImmutableMap.of( "spark.network.sasl.maxEncryptedBlockSize", "1k"); - final AtomicReference response = new AtomicReference<>(); - final File file = File.createTempFile("sasltest", ".txt"); + AtomicReference response = new AtomicReference<>(); + File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); StreamManager sm = mock(StreamManager.class); - when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { - @Override - public ManagedBuffer answer(InvocationOnMock invocation) { - return new FileSegmentManagedBuffer(conf, file, 0, file.length()); - } - }); + 
when(sm.getChunk(anyLong(), anyInt())).thenAnswer(invocation -> + new FileSegmentManagedBuffer(conf, file, 0, file.length())); RpcHandler rpcHandler = mock(RpcHandler.class); when(rpcHandler.getStreamManager()).thenReturn(sm); @@ -280,18 +271,15 @@ public ManagedBuffer answer(InvocationOnMock invocation) { ctx = new SaslTestCtx(rpcHandler, true, false, testConf); - final CountDownLatch lock = new CountDownLatch(1); + CountDownLatch lock = new CountDownLatch(1); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) { - response.set((ManagedBuffer) invocation.getArguments()[1]); - response.get().retain(); - lock.countDown(); - return null; - } - }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + doAnswer(invocation -> { + response.set((ManagedBuffer) invocation.getArguments()[1]); + response.get().retain(); + lock.countDown(); + return null; + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); ctx.client.fetchChunk(0, 0, callback); lock.await(10, TimeUnit.SECONDS); @@ -388,7 +376,7 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - this(rpcHandler, encrypt, disableClientEncryption, Collections.emptyMap()); + this(rpcHandler, encrypt, disableClientEncryption, Collections.emptyMap()); } SaslTestCtx( @@ -416,7 +404,7 @@ private static class SaslTestCtx { checker)); try { - List clientBootstraps = Lists.newArrayList(); + List clientBootstraps = new ArrayList<>(); clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder)); if (disableClientEncryption) { clientBootstraps.add(new EncryptionDisablerBootstrap()); @@ -467,11 +455,6 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) ctx.write(msg, promise); } - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - super.handlerRemoved(ctx); - } - @Override public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { channel.pipeline().addFirst("encryptionChecker", this); diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index d4de4a941d480..b53e41303751c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -28,8 +28,6 @@ import io.netty.channel.ChannelHandlerContext; import org.junit.AfterClass; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -52,7 +50,7 @@ public void testFrameDecoding() throws Exception { @Test public void testInterception() throws Exception { - final int interceptedReads = 3; + int interceptedReads = 3; TransportFrameDecoder decoder = new TransportFrameDecoder(); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); ChannelHandlerContext ctx = mockChannelHandlerContext(); @@ -84,22 +82,19 @@ public void testInterception() throws Exception { public void testRetainedFrames() throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); - final AtomicInteger count = new AtomicInteger(); - final List retained = new ArrayList<>(); + AtomicInteger count = new 
AtomicInteger(); + List retained = new ArrayList<>(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock in) { - // Retain a few frames but not others. - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - if (count.incrementAndGet() % 2 == 0) { - retained.add(buf); - } else { - buf.release(); - } - return null; + when(ctx.fireChannelRead(any())).thenAnswer(in -> { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { + retained.add(buf); + } else { + buf.release(); } + return null; }); ByteBuf data = createAndFeedFrames(100, decoder, ctx); @@ -150,12 +145,6 @@ public void testEmptyFrame() throws Exception { testInvalidFrame(8); } - @Test(expected = IllegalArgumentException.class) - public void testLargeFrame() throws Exception { - // Frame length includes the frame size field, so need to add a few more bytes. - testInvalidFrame(Integer.MAX_VALUE + 9); - } - /** * Creates a number of randomly sized frames and feed them to the given decoder, verifying * that the frames were read. @@ -210,13 +199,10 @@ private void testInvalidFrame(long size) throws Exception { private ChannelHandlerContext mockChannelHandlerContext() { ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock in) { - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - buf.release(); - return null; - } + when(ctx.fireChannelRead(any())).thenAnswer(in -> { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + return null; }); return ctx; } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 52f50a3409b98..c0e170e5b9353 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -19,11 +19,11 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -38,7 +38,6 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; @@ -105,8 +104,7 @@ public void afterEach() { @Test public void testGoodClient() throws IOException, InterruptedException { clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; @@ -120,8 +118,7 @@ public void testBadClient() { when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app"); 
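// Editor's note (illustrative sketch, not part of the patch): the Lists.newArrayList
// replacements in the hunks around this point fall into two cases. A fixed argument list
// becomes java.util.Arrays.asList, and an empty growable list becomes new ArrayList<>().
// The element type and variable names below are made up for illustration only.
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ListFactorySketch {
  public static void main(String[] args) {
    // Before (Guava): List<String> bootstraps = Lists.newArrayList("sasl", "auth");
    // After (JDK only): Arrays.asList returns a fixed-size list over the given elements.
    List<String> bootstraps = Arrays.asList("sasl", "auth");

    // Before (Guava): List<String> extras = Lists.newArrayList();
    // After: with the diamond operator, the plain constructor is just as concise.
    List<String> extras = new ArrayList<>();
    extras.add("encryptionDisabler");

    System.out.println(bootstraps + " " + extras);
  }
}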
when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password"); clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "unknown-app", badKeyHolder))); + Arrays.asList(new SaslClientBootstrap(conf, "unknown-app", badKeyHolder))); try { // Bootstrap should fail on startup. @@ -134,8 +131,7 @@ public void testBadClient() { @Test public void testNoSaslClient() throws IOException, InterruptedException { - clientFactory = context.createClientFactory( - Lists.newArrayList()); + clientFactory = context.createClientFactory(new ArrayList<>()); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { @@ -159,15 +155,11 @@ public void testNoSaslServer() { RpcHandler handler = new TestRpcHandler(); TransportContext context = new TransportContext(conf, handler); clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); - TransportServer server = context.createServer(); - try { + Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + try (TransportServer server = context.createServer()) { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); - } finally { - server.close(); } } @@ -191,14 +183,13 @@ public void testAppIsolation() throws Exception { try { // Create a client, and make a request to fetch blocks from a different app. clientFactory = blockServerContext.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); client1 = clientFactory.createClient(TestUtils.getLocalHost(), blockServer.getPort()); - final AtomicReference exception = new AtomicReference<>(); + AtomicReference exception = new AtomicReference<>(); - final CountDownLatch blockFetchLatch = new CountDownLatch(1); + CountDownLatch blockFetchLatch = new CountDownLatch(1); BlockFetchingListener listener = new BlockFetchingListener() { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { @@ -235,12 +226,11 @@ public void onBlockFetchFailure(String blockId, Throwable t) { // Create a second client, authenticated with a different app ID, and try to read from // the stream created for the previous app. 
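// Editor's note (standalone sketch, not part of the patch): the testNoSaslServer hunk a few
// lines above replaces a try/finally { server.close(); } with try-with-resources. The same
// idea in isolation, using a made-up stand-in for a closeable server such as TransportServer.
public class TryWithResourcesSketch {
  // Hypothetical resource used only for this illustration.
  static class FakeServer implements AutoCloseable {
    int getPort() { return 4242; }
    @Override
    public void close() { System.out.println("closed"); }
  }

  public static void main(String[] args) {
    // Before: an explicit finally block guarantees the close.
    FakeServer s1 = new FakeServer();
    try {
      System.out.println("port " + s1.getPort());
    } finally {
      s1.close();
    }

    // After: the resource is closed automatically when the block exits,
    // whether it finishes normally or via an exception.
    try (FakeServer s2 = new FakeServer()) {
      System.out.println("port " + s2.getPort());
    }
  }
}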
clientFactory2 = blockServerContext.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); + Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); client2 = clientFactory2.createClient(TestUtils.getLocalHost(), blockServer.getPort()); - final CountDownLatch chunkReceivedLatch = new CountDownLatch(1); + CountDownLatch chunkReceivedLatch = new CountDownLatch(1); ChunkReceivedCallback callback = new ChunkReceivedCallback() { @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { @@ -284,7 +274,7 @@ public StreamManager getStreamManager() { } } - private void checkSecurityException(Throwable t) { + private static void checkSecurityException(Throwable t) { assertNotNull("No exception was caught.", t); assertTrue("Expected SecurityException.", t.getMessage().contains(SecurityException.class.getName())); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index c036bc2e8d256..e47a72c9d16cc 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -93,7 +93,7 @@ public void testOpenShuffleBlocks() { ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, never()).onFailure(any()); StreamHandle handle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 7757500b41a6f..47c087088a8a2 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -60,12 +60,10 @@ public void noCleanupAndCleanup() throws IOException { public void cleanupUsesExecutor() throws IOException { TestShuffleDataContext dataContext = createSomeData(); - final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + AtomicBoolean cleanupCalled = new AtomicBoolean(false); // Executor which does nothing to ensure we're actually using it. 
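// Editor's note (illustrative sketch, not part of the patch): the Executor conversion that
// follows is the rule applied throughout these suites: an anonymous class implementing a
// single-abstract-method interface becomes a lambda, and any local it captures only needs
// to be effectively final, which is why redundant 'final' modifiers are dropped as well.
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;

public class SamLambdaSketch {
  public static void main(String[] args) {
    AtomicBoolean cleanupCalled = new AtomicBoolean(false);

    // Before: anonymous inner class implementing execute(Runnable).
    Executor verbose = new Executor() {
      @Override
      public void execute(Runnable runnable) {
        cleanupCalled.set(true);
      }
    };

    // After: the same behaviour as a lambda over the single abstract method.
    Executor concise = runnable -> cleanupCalled.set(true);

    verbose.execute(() -> { });
    concise.execute(() -> { });
    System.out.println(cleanupCalled.get()); // prints true
  }
}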
- Executor noThreadExecutor = new Executor() { - @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } - }; + Executor noThreadExecutor = runnable -> cleanupCalled.set(true); ExternalShuffleBlockResolver manager = new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 88de6fb83c637..b8ae04eefb972 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; @@ -29,7 +30,6 @@ import java.util.concurrent.TimeUnit; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.junit.After; import org.junit.AfterClass; @@ -173,7 +173,7 @@ public void testFetchOneSort() throws Exception { FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }); assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); assertTrue(exec0Fetch.failedBlocks.isEmpty()); - assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0])); + assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(exec0Blocks[0])); exec0Fetch.releaseBuffers(); } @@ -185,7 +185,7 @@ public void testFetchThreeSort() throws Exception { assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), exec0Fetch.successBlocks); assertTrue(exec0Fetch.failedBlocks.isEmpty()); - assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks)); + assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(exec0Blocks)); exec0Fetch.releaseBuffers(); } @@ -241,7 +241,7 @@ public void testFetchNoServer() throws Exception { assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); } - private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + private static void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException { ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); @@ -249,7 +249,7 @@ private void registerExecutor(String executorId, ExecutorShuffleInfo executorInf executorId, executorInfo); } - private void assertBufferListsEqual(List list0, List list1) + private static void assertBufferListsEqual(List list0, List list1) throws Exception { assertEquals(list0.size(), list1.size()); for (int i = 0; i < list0.size(); i ++) { @@ -257,7 +257,8 @@ private void assertBufferListsEqual(List list0, List list } } - private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + private static void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) + throws Exception { ByteBuffer nio0 = buffer0.nioByteBuffer(); ByteBuffer nio1 = buffer1.nioByteBuffer(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 
2590b9ce4c1f1..3e51fea3cf0e5 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -25,8 +25,6 @@ import com.google.common.collect.Maps; import io.netty.buffer.Unpooled; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -85,8 +83,8 @@ public void testFailure() { // Each failure will cause a failure to be invoked in all remaining block fetches. verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); - verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); - verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any()); + verify(listener, times(2)).onBlockFetchFailure(eq("b2"), any()); } @Test @@ -100,15 +98,15 @@ public void testFailureAndSuccess() { // We may call both success and failure for the same block. verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); - verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any()); verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2")); - verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchFailure(eq("b2"), any()); } @Test public void testEmptyBlockFetch() { try { - fetchBlocks(Maps.newLinkedHashMap()); + fetchBlocks(Maps.newLinkedHashMap()); fail(); } catch (IllegalArgumentException e) { assertEquals("Zero-sized blockIds array", e.getMessage()); @@ -123,52 +121,46 @@ public void testEmptyBlockFetch() { * * If a block's buffer is "null", an exception will be thrown instead. 
*/ - private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) { + private static BlockFetchingListener fetchBlocks(LinkedHashMap blocks) { TransportClient client = mock(TransportClient.class); BlockFetchingListener listener = mock(BlockFetchingListener.class); - final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); - // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( - (ByteBuffer) invocationOnMock.getArguments()[0]); - RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); - assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); - return null; - } + // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123 + doAnswer(invocationOnMock -> { + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) invocationOnMock.getArguments()[0]); + RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); + assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); + return null; }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); // Respond to each chunk request with a single buffer from our blocks array. 
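// Editor's note (hedged sketch, not part of the patch): the doAnswer/thenAnswer rewrites in
// this hunk, and in the SASL and frame-decoder suites earlier, all rely on Mockito's
// Answer<T> being a functional interface. The Greeter interface below is hypothetical and
// exists only to show the shape of the conversion; only standard Mockito calls are used.
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

public class AnswerLambdaSketch {
  interface Greeter {
    String greet(String name);
  }

  public static void main(String[] args) {
    Greeter greeter = mock(Greeter.class);

    // Before: a full anonymous Answer implementation.
    when(greeter.greet(anyString())).thenAnswer(new Answer<String>() {
      @Override
      public String answer(InvocationOnMock invocation) {
        return "Hello, " + invocation.getArguments()[0];
      }
    });
    System.out.println(greeter.greet("patch")); // Hello, patch

    // After: a lambda over the invocation; getArguments() exposes the stubbed call's arguments.
    when(greeter.greet(anyString())).thenAnswer(invocation -> "Hi, " + invocation.getArguments()[0]);
    System.out.println(greeter.greet("patch")); // Hi, patch
  }
}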
- final AtomicInteger expectedChunkIndex = new AtomicInteger(0); - final Iterator blockIterator = blocks.values().iterator(); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - try { - long streamId = (Long) invocation.getArguments()[0]; - int myChunkIndex = (Integer) invocation.getArguments()[1]; - assertEquals(123, streamId); - assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); - - ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; - ManagedBuffer result = blockIterator.next(); - if (result != null) { - callback.onSuccess(myChunkIndex, result); - } else { - callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); - } - } catch (Exception e) { - e.printStackTrace(); - fail("Unexpected failure"); + AtomicInteger expectedChunkIndex = new AtomicInteger(0); + Iterator blockIterator = blocks.values().iterator(); + doAnswer(invocation -> { + try { + long streamId = (Long) invocation.getArguments()[0]; + int myChunkIndex = (Integer) invocation.getArguments()[1]; + assertEquals(123, streamId); + assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); + + ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; + ManagedBuffer result = blockIterator.next(); + if (result != null) { + callback.onSuccess(myChunkIndex, result); + } else { + callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); } - return null; + } catch (Exception e) { + e.printStackTrace(); + fail("Unexpected failure"); } - }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); + return null; + }).when(client).fetchChunk(anyLong(), anyInt(), any()); fetcher.start(); return listener; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 6db71eea6e8b5..a530e16734db4 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -28,7 +28,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.stubbing.Stubber; @@ -84,7 +83,7 @@ public void testUnrecoverableFailure() throws IOException, InterruptedException performInteractions(interactions, listener); - verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchFailure(eq("b0"), any()); verify(listener).onBlockFetchSuccess("b1", block1); verifyNoMoreInteractions(listener); } @@ -190,7 +189,7 @@ public void testThreeIOExceptions() throws IOException, InterruptedException { performInteractions(interactions, listener); verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), any()); verifyNoMoreInteractions(listener); } @@ -220,7 +219,7 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException performInteractions(interactions, listener); verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), 
(Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), any()); verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); verifyNoMoreInteractions(listener); } @@ -249,40 +248,37 @@ private static void performInteractions(List> inte Stubber stub = null; // Contains all blockIds that are referenced across all interactions. - final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + LinkedHashSet blockIds = Sets.newLinkedHashSet(); - for (final Map interaction : interactions) { + for (Map interaction : interactions) { blockIds.addAll(interaction.keySet()); - Answer answer = new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - try { - // Verify that the RetryingBlockFetcher requested the expected blocks. - String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; - String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); - assertArrayEquals(desiredBlockIds, requestedBlockIds); - - // Now actually invoke the success/failure callbacks on each block. - BlockFetchingListener retryListener = - (BlockFetchingListener) invocationOnMock.getArguments()[1]; - for (Map.Entry block : interaction.entrySet()) { - String blockId = block.getKey(); - Object blockValue = block.getValue(); - - if (blockValue instanceof ManagedBuffer) { - retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); - } else if (blockValue instanceof Exception) { - retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); - } else { - fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); - } + Answer answer = invocationOnMock -> { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. 
+ BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); } - return null; - } catch (Throwable e) { - e.printStackTrace(); - throw e; } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; } }; @@ -295,7 +291,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { } assertNotNull(stub); - stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + stub.when(fetchStarter).createAndStart(any(), anyObject()); String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 189d607fa6c5f..29aca04a3d11b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -37,7 +37,6 @@ import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.util.TaskCompletionListener; import org.apache.spark.util.Utils; /** diff --git a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java index 7fe452a48d89b..a6589d2898144 100644 --- a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java +++ b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java @@ -20,14 +20,11 @@ import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; -import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; import org.apache.spark.rdd.JdbcRDD; import org.junit.After; import org.junit.Assert; @@ -89,30 +86,13 @@ public void tearDown() throws SQLException { public void testJavaJdbcRDD() throws Exception { JavaRDD rdd = JdbcRDD.create( sc, - new JdbcRDD.ConnectionFactory() { - @Override - public Connection getConnection() throws SQLException { - return DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb"); - } - }, + () -> DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb"), "SELECT DATA FROM FOO WHERE ? 
<= ID AND ID <= ?", 1, 100, 1, - new Function() { - @Override - public Integer call(ResultSet r) throws Exception { - return r.getInt(1); - } - } + r -> r.getInt(1) ).cache(); Assert.assertEquals(100, rdd.count()); - Assert.assertEquals( - Integer.valueOf(10100), - rdd.reduce(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - })); + Assert.assertEquals(Integer.valueOf(10100), rdd.reduce((i1, i2) -> i1 + i2)); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 088b68132d905..24a55df84a240 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -34,8 +34,6 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.apache.spark.HashPartitioner; import org.apache.spark.ShuffleDependency; @@ -119,9 +117,7 @@ public void setUp() throws IOException { any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { - @Override - public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> { Object[] args = invocationOnMock.getArguments(); return new DiskBlockObjectWriter( (File) args[1], @@ -132,33 +128,24 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (ShuffleWriteMetrics) args[4], (BlockId) args[0] ); - } - }); + }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; - File tmp = (File) invocationOnMock.getArguments()[3]; - mergedOutputFile.delete(); - tmp.renameTo(mergedOutputFile); - return null; - } + doAnswer(invocationOnMock -> { + partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + File tmp = (File) invocationOnMock.getArguments()[3]; + mergedOutputFile.delete(); + tmp.renameTo(mergedOutputFile); + return null; }).when(shuffleBlockResolver) .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); - when(diskBlockManager.createTempShuffleBlock()).thenAnswer( - new Answer>() { - @Override - public Tuple2 answer( - InvocationOnMock invocationOnMock) throws Throwable { - TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - spillFilesCreated.add(file); - return Tuple2$.MODULE$.apply(blockId, file); - } - }); + when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { + TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); + }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(shuffleDep.serializer()).thenReturn(serializer); @@ -243,7 +230,7 @@ class BadRecords extends scala.collection.AbstractIterator writer = createWriter(true); - writer.write(Iterators.>emptyIterator()); + writer.write(Iterators.emptyIterator()); final Option 
mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); assertTrue(mergedOutputFile.exists()); @@ -259,7 +246,7 @@ public void writeWithoutSpilling() throws Exception { // In this example, each partition should have exactly one record: final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < NUM_PARTITITONS; i++) { - dataToWrite.add(new Tuple2(i, i)); + dataToWrite.add(new Tuple2<>(i, i)); } final UnsafeShuffleWriter writer = createWriter(true); writer.write(dataToWrite.iterator()); @@ -315,7 +302,7 @@ private void testMergingSpills( final UnsafeShuffleWriter writer = createWriter(transferToEnabled); final ArrayList> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { - dataToWrite.add(new Tuple2(i, i)); + dataToWrite.add(new Tuple2<>(i, i)); } writer.insertRecordIntoSorter(dataToWrite.get(0)); writer.insertRecordIntoSorter(dataToWrite.get(1)); @@ -424,7 +411,7 @@ public void writeEnoughDataToTriggerSpill() throws Exception { final ArrayList> dataToWrite = new ArrayList<>(); final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10]; for (int i = 0; i < 10 + 1; i++) { - dataToWrite.add(new Tuple2(i, bigByteArray)); + dataToWrite.add(new Tuple2<>(i, bigByteArray)); } writer.write(dataToWrite.iterator()); assertEquals(2, spillFilesCreated.size()); @@ -458,7 +445,7 @@ private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exc final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) { - dataToWrite.add(new Tuple2(i, i)); + dataToWrite.add(new Tuple2<>(i, i)); } writer.write(dataToWrite.iterator()); writer.stop(true); @@ -478,7 +465,7 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception final ArrayList> dataToWrite = new ArrayList<>(); final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; new Random(42).nextBytes(bytes); - dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); + dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(bytes))); writer.write(dataToWrite.iterator()); writer.stop(true); assertEquals( @@ -491,15 +478,15 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); - dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(new byte[1]))); + dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(new byte[1]))); // We should be able to write a record that's right _at_ the max record size final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4]; new Random(42).nextBytes(atMaxRecordSize); - dataToWrite.add(new Tuple2(2, ByteBuffer.wrap(atMaxRecordSize))); + dataToWrite.add(new Tuple2<>(2, ByteBuffer.wrap(atMaxRecordSize))); // Inserting a record that's larger than the max record size final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()]; new Random(42).nextBytes(exceedsMaxRecordSize); - dataToWrite.add(new Tuple2(3, ByteBuffer.wrap(exceedsMaxRecordSize))); + dataToWrite.add(new Tuple2<>(3, ByteBuffer.wrap(exceedsMaxRecordSize))); writer.write(dataToWrite.iterator()); writer.stop(true); assertEquals( @@ -511,10 +498,10 @@ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { @Test public void 
spillFilesAreDeletedWhenStoppingAfterError() throws IOException { final UnsafeShuffleWriter writer = createWriter(false); - writer.insertRecordIntoSorter(new Tuple2(1, 1)); - writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.insertRecordIntoSorter(new Tuple2<>(1, 1)); + writer.insertRecordIntoSorter(new Tuple2<>(2, 2)); writer.forceSorterToSpill(); - writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.insertRecordIntoSorter(new Tuple2<>(2, 2)); writer.stop(false); assertSpillFilesWereCleanedUp(); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 26568146bf4d7..03cec8ed81b72 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -22,7 +22,6 @@ import java.nio.ByteBuffer; import java.util.*; -import scala.Tuple2; import scala.Tuple2$; import org.junit.After; @@ -31,8 +30,6 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.apache.spark.SparkConf; import org.apache.spark.executor.ShuffleWriteMetrics; @@ -88,25 +85,18 @@ public void setup() { spillFilesCreated.clear(); MockitoAnnotations.initMocks(this); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(diskBlockManager.createTempLocalBlock()).thenAnswer( - new Answer>() { - @Override - public Tuple2 answer(InvocationOnMock invocationOnMock) - throws Throwable { - TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - spillFilesCreated.add(file); - return Tuple2$.MODULE$.apply(blockId, file); - } + when(diskBlockManager.createTempLocalBlock()).thenAnswer(invocationOnMock -> { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); }); when(blockManager.getDiskWriter( any(BlockId.class), any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { - @Override - public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> { Object[] args = invocationOnMock.getArguments(); return new DiskBlockObjectWriter( @@ -118,8 +108,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (ShuffleWriteMetrics) args[4], (BlockId) args[0] ); - } - }); + }); } @After diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index fbbe530a132e1..771d39016c188 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -23,7 +23,6 @@ import java.util.LinkedList; import java.util.UUID; -import scala.Tuple2; import scala.Tuple2$; import org.junit.After; @@ -31,8 +30,6 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import 
org.mockito.stubbing.Answer; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; @@ -96,25 +93,18 @@ public void setUp() { taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(diskBlockManager.createTempLocalBlock()).thenAnswer( - new Answer>() { - @Override - public Tuple2 answer(InvocationOnMock invocationOnMock) - throws Throwable { - TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - spillFilesCreated.add(file); - return Tuple2$.MODULE$.apply(blockId, file); - } + when(diskBlockManager.createTempLocalBlock()).thenAnswer(invocationOnMock -> { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); }); when(blockManager.getDiskWriter( any(BlockId.class), any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { - @Override - public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> { Object[] args = invocationOnMock.getArguments(); return new DiskBlockObjectWriter( @@ -126,8 +116,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (ShuffleWriteMetrics) args[4], (BlockId) args[0] ); - } - }); + }); } @After diff --git a/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java b/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java index e22ad89c1d6ea..1d2b05ebc2503 100644 --- a/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java @@ -64,12 +64,7 @@ public void tearDown() { public void foreachWithAnonymousClass() { foreachCalls = 0; JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach(new VoidFunction() { - @Override - public void call(String s) { - foreachCalls++; - } - }); + rdd.foreach(s -> foreachCalls++); Assert.assertEquals(2, foreachCalls); } diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 80aab100aced4..512149127d72f 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -31,7 +31,6 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.*; import org.apache.spark.Accumulator; @@ -208,7 +207,7 @@ public void sortByKey() { assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator - sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); + sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); @@ -266,13 +265,7 @@ public void sortBy() { JavaRDD> rdd = sc.parallelize(pairs); // compare on first value - JavaRDD> sortedRDD = - rdd.sortBy(new Function, Integer>() { - @Override - public Integer call(Tuple2 t) { - return t._1(); - } - }, true, 2); + JavaRDD> sortedRDD = rdd.sortBy(Tuple2::_1, true, 2); assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = 
sortedRDD.collect(); @@ -280,12 +273,7 @@ public Integer call(Tuple2 t) { assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value - sortedRDD = rdd.sortBy(new Function, Integer>() { - @Override - public Integer call(Tuple2 t) { - return t._2(); - } - }, true, 2); + sortedRDD = rdd.sortBy(Tuple2::_2, true, 2); assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); @@ -294,28 +282,20 @@ public Integer call(Tuple2 t) { @Test public void foreach() { - final LongAccumulator accum = sc.sc().longAccumulator(); + LongAccumulator accum = sc.sc().longAccumulator(); JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach(new VoidFunction() { - @Override - public void call(String s) { - accum.add(1); - } - }); + rdd.foreach(s -> accum.add(1)); assertEquals(2, accum.value().intValue()); } @Test public void foreachPartition() { - final LongAccumulator accum = sc.sc().longAccumulator(); + LongAccumulator accum = sc.sc().longAccumulator(); JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreachPartition(new VoidFunction>() { - @Override - public void call(Iterator iter) { - while (iter.hasNext()) { - iter.next(); - accum.add(1); - } + rdd.foreachPartition(iter -> { + while (iter.hasNext()) { + iter.next(); + accum.add(1); } }); assertEquals(2, accum.value().intValue()); @@ -361,12 +341,7 @@ public void lookup() { @Test public void groupBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function isOdd = new Function() { - @Override - public Boolean call(Integer x) { - return x % 2 == 0; - } - }; + Function isOdd = x -> x % 2 == 0; JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); assertEquals(2, oddsAndEvens.count()); assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens @@ -383,12 +358,7 @@ public void groupByOnPairRDD() { // Regression test for SPARK-4459 JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); Function, Boolean> areOdd = - new Function, Boolean>() { - @Override - public Boolean call(Tuple2 x) { - return (x._1() % 2 == 0) && (x._2() % 2 == 0); - } - }; + x -> (x._1() % 2 == 0) && (x._2() % 2 == 0); JavaPairRDD pairRDD = rdd.zip(rdd); JavaPairRDD>> oddsAndEvens = pairRDD.groupBy(areOdd); assertEquals(2, oddsAndEvens.count()); @@ -406,13 +376,7 @@ public Boolean call(Tuple2 x) { public void keyByOnPairRDD() { // Regression test for SPARK-4459 JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function, String> sumToString = - new Function, String>() { - @Override - public String call(Tuple2 x) { - return String.valueOf(x._1() + x._2()); - } - }; + Function, String> sumToString = x -> String.valueOf(x._1() + x._2()); JavaPairRDD pairRDD = rdd.zip(rdd); JavaPairRDD> keyed = pairRDD.keyBy(sumToString); assertEquals(7, keyed.count()); @@ -516,25 +480,14 @@ public void leftOuterJoin() { rdd1.leftOuterJoin(rdd2).collect(); assertEquals(5, joined.size()); Tuple2>> firstUnmatched = - rdd1.leftOuterJoin(rdd2).filter( - new Function>>, Boolean>() { - @Override - public Boolean call(Tuple2>> tup) { - return !tup._2()._2().isPresent(); - } - }).first(); + rdd1.leftOuterJoin(rdd2).filter(tup -> !tup._2()._2().isPresent()).first(); assertEquals(3, firstUnmatched._1().intValue()); } @Test public void foldReduce() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function2 add = new Function2() { - @Override - public Integer call(Integer a, Integer b) { - 
return a + b; - } - }; + Function2 add = (a, b) -> a + b; int sum = rdd.fold(0, add); assertEquals(33, sum); @@ -546,12 +499,7 @@ public Integer call(Integer a, Integer b) { @Test public void treeReduce() { JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); - Function2 add = new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }; + Function2 add = (a, b) -> a + b; for (int depth = 1; depth <= 10; depth++) { int sum = rdd.treeReduce(add, depth); assertEquals(-5, sum); @@ -561,12 +509,7 @@ public Integer call(Integer a, Integer b) { @Test public void treeAggregate() { JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); - Function2 add = new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }; + Function2 add = (a, b) -> a + b; for (int depth = 1; depth <= 10; depth++) { int sum = rdd.treeAggregate(0, add, add, depth); assertEquals(-5, sum); @@ -584,21 +527,15 @@ public void aggregateByKey() { new Tuple2<>(5, 1), new Tuple2<>(5, 3)), 2); - Map> sets = pairs.aggregateByKey(new HashSet(), - new Function2, Integer, Set>() { - @Override - public Set call(Set a, Integer b) { - a.add(b); - return a; - } - }, - new Function2, Set, Set>() { - @Override - public Set call(Set a, Set b) { - a.addAll(b); - return a; - } - }).collectAsMap(); + Map> sets = pairs.aggregateByKey(new HashSet(), + (a, b) -> { + a.add(b); + return a; + }, + (a, b) -> { + a.addAll(b); + return a; + }).collectAsMap(); assertEquals(3, sets.size()); assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); @@ -616,13 +553,7 @@ public void foldByKey() { new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - JavaPairRDD sums = rdd.foldByKey(0, - new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); + JavaPairRDD sums = rdd.foldByKey(0, (a, b) -> a + b); assertEquals(1, sums.lookup(1).get(0).intValue()); assertEquals(2, sums.lookup(2).get(0).intValue()); assertEquals(3, sums.lookup(3).get(0).intValue()); @@ -639,13 +570,7 @@ public void reduceByKey() { new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - JavaPairRDD counts = rdd.reduceByKey( - new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); + JavaPairRDD counts = rdd.reduceByKey((a, b) -> a + b); assertEquals(1, counts.lookup(1).get(0).intValue()); assertEquals(2, counts.lookup(2).get(0).intValue()); assertEquals(3, counts.lookup(3).get(0).intValue()); @@ -655,12 +580,7 @@ public Integer call(Integer a, Integer b) { assertEquals(2, localCounts.get(2).intValue()); assertEquals(3, localCounts.get(3).intValue()); - localCounts = rdd.reduceByKeyLocally(new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); + localCounts = rdd.reduceByKeyLocally((a, b) -> a + b); assertEquals(1, localCounts.get(1).intValue()); assertEquals(2, localCounts.get(2).intValue()); assertEquals(3, localCounts.get(3).intValue()); @@ -692,20 +612,8 @@ public void isEmpty() { assertTrue(sc.emptyRDD().isEmpty()); assertTrue(sc.parallelize(new ArrayList()).isEmpty()); assertFalse(sc.parallelize(Arrays.asList(1)).isEmpty()); - assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter( - new Function() { - @Override - public Boolean call(Integer i) { - return i < 0; - } - }).isEmpty()); - assertFalse(sc.parallelize(Arrays.asList(1, 2, 
3)).filter( - new Function() { - @Override - public Boolean call(Integer i) { - return i > 1; - } - }).isEmpty()); + assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter(i -> i < 0).isEmpty()); + assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter(i -> i > 1).isEmpty()); } @Test @@ -721,12 +629,7 @@ public void javaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaDoubleRDD distinct = rdd.distinct(); assertEquals(5, distinct.count()); - JavaDoubleRDD filter = rdd.filter(new Function() { - @Override - public Boolean call(Double x) { - return x > 2.0; - } - }); + JavaDoubleRDD filter = rdd.filter(x -> x > 2.0); assertEquals(3, filter.count()); JavaDoubleRDD union = rdd.union(rdd); assertEquals(12, union.count()); @@ -764,7 +667,7 @@ public void javaDoubleRDDHistoGram() { // SPARK-5744 assertArrayEquals( new long[] {0}, - sc.parallelizeDoubles(new ArrayList(0), 1).histogram(new double[]{0.0, 1.0})); + sc.parallelizeDoubles(new ArrayList<>(0), 1).histogram(new double[]{0.0, 1.0})); } private static class DoubleComparator implements Comparator, Serializable { @@ -833,12 +736,7 @@ public void reduce() { @Test public void reduceOnJavaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); - double sum = rdd.reduce(new Function2() { - @Override - public Double call(Double v1, Double v2) { - return v1 + v2; - } - }); + double sum = rdd.reduce((v1, v2) -> v1 + v2); assertEquals(10.0, sum, 0.001); } @@ -859,27 +757,11 @@ public void aggregate() { @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction() { - @Override - public double call(Integer x) { - return x.doubleValue(); - } - }).cache(); + JavaDoubleRDD doubles = rdd.mapToDouble(Integer::doubleValue).cache(); doubles.collect(); - JavaPairRDD pairs = rdd.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer x) { - return new Tuple2<>(x, x); - } - }).cache(); + JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)).cache(); pairs.collect(); - JavaRDD strings = rdd.map(new Function() { - @Override - public String call(Integer x) { - return x.toString(); - } - }).cache(); + JavaRDD strings = rdd.map(Object::toString).cache(); strings.collect(); } @@ -887,39 +769,27 @@ public String call(Integer x) { public void flatMap() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", "The quick brown fox jumps over the lazy dog.")); - JavaRDD words = rdd.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(x.split(" ")).iterator(); - } - }); + JavaRDD words = rdd.flatMap(x -> Arrays.asList(x.split(" ")).iterator()); assertEquals("Hello", words.first()); assertEquals(11, words.count()); - JavaPairRDD pairsRDD = rdd.flatMapToPair( - new PairFlatMapFunction() { - @Override - public Iterator> call(String s) { - List> pairs = new LinkedList<>(); - for (String word : s.split(" ")) { - pairs.add(new Tuple2<>(word, word)); - } - return pairs.iterator(); + JavaPairRDD pairsRDD = rdd.flatMapToPair(s -> { + List> pairs = new LinkedList<>(); + for (String word : s.split(" ")) { + pairs.add(new Tuple2<>(word, word)); } + return pairs.iterator(); } ); assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); assertEquals(11, pairsRDD.count()); - JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { - @Override - public Iterator call(String s) { - List lengths = new 
LinkedList<>(); - for (String word : s.split(" ")) { - lengths.add((double) word.length()); - } - return lengths.iterator(); + JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { + List lengths = new LinkedList<>(); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); } + return lengths.iterator(); }); assertEquals(5.0, doubles.first(), 0.01); assertEquals(11, pairsRDD.count()); @@ -937,37 +807,23 @@ public void mapsFromPairsToPairs() { // Regression test for SPARK-668: JavaPairRDD swapped = pairRDD.flatMapToPair( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterator> call(Tuple2 item) { - return Collections.singletonList(item.swap()).iterator(); - } - }); + item -> Collections.singletonList(item.swap()).iterator()); swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.mapToPair(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) { - return item.swap(); - } - }).collect(); + pairRDD.mapToPair(Tuple2::swap).collect(); } @Test public void mapPartitions() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - JavaRDD partitionSums = rdd.mapPartitions( - new FlatMapFunction, Integer>() { - @Override - public Iterator call(Iterator iter) { - int sum = 0; - while (iter.hasNext()) { - sum += iter.next(); - } - return Collections.singletonList(sum).iterator(); + JavaRDD partitionSums = rdd.mapPartitions(iter -> { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); } - }); + return Collections.singletonList(sum).iterator(); + }); assertEquals("[3, 7]", partitionSums.collect().toString()); } @@ -975,17 +831,13 @@ public Iterator call(Iterator iter) { @Test public void mapPartitionsWithIndex() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - JavaRDD partitionSums = rdd.mapPartitionsWithIndex( - new Function2, Iterator>() { - @Override - public Iterator call(Integer index, Iterator iter) { - int sum = 0; - while (iter.hasNext()) { - sum += iter.next(); - } - return Collections.singletonList(sum).iterator(); + JavaRDD partitionSums = rdd.mapPartitionsWithIndex((index, iter) -> { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); } - }, false); + return Collections.singletonList(sum).iterator(); + }, false); assertEquals("[3, 7]", partitionSums.collect().toString()); } @@ -1124,21 +976,12 @@ public void sequenceFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); // Try reading the output back as an object file JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, - Text.class).mapToPair(new PairFunction, Integer, String>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(pair._1().get(), pair._2().toString()); - } - }); + Text.class).mapToPair(pair -> new Tuple2<>(pair._1().get(), pair._2().toString())); assertEquals(pairs, readRDD.collect()); } @@ -1179,12 +1022,7 @@ public void binaryFilesCaching() throws Exception { channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); - readRDD.foreach(new VoidFunction>() { - @Override - public 
void call(Tuple2 pair) { - pair._2().toArray(); // force the file to read - } - }); + readRDD.foreach(pair -> pair._2().toArray()); // force the file to read List> result = readRDD.collect(); for (Tuple2 res : result) { @@ -1229,23 +1067,13 @@ public void writeWithNewAPIHadoopFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsNewAPIHadoopFile( - outputDir, IntWritable.class, Text.class, + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class); - assertEquals(pairs.toString(), output.map(new Function, String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); + assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @SuppressWarnings("unchecked") @@ -1259,22 +1087,13 @@ public void readWithNewAPIHadoopFile() throws IOException { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, Text.class, Job.getInstance().getConfiguration()); - assertEquals(pairs.toString(), output.map(new Function, String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); + assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @Test @@ -1315,21 +1134,12 @@ public void hadoopFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - assertEquals(pairs.toString(), output.map(new Function, String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); + assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @SuppressWarnings("unchecked") @@ -1343,34 +1153,19 @@ public void hadoopFileCompressed() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, 
SequenceFileOutputFormat.class, - DefaultCodec.class); + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - assertEquals(pairs.toString(), output.map(new Function, String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); + assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @Test public void zip() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction() { - @Override - public double call(Integer x) { - return x.doubleValue(); - } - }); + JavaDoubleRDD doubles = rdd.mapToDouble(Integer::doubleValue); JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } @@ -1380,12 +1175,7 @@ public void zipPartitions() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); FlatMapFunction2, Iterator, Integer> sizesFn = - new FlatMapFunction2, Iterator, Integer>() { - @Override - public Iterator call(Iterator i, Iterator s) { - return Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator(); - } - }; + (i, s) -> Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator(); JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); @@ -1396,22 +1186,12 @@ public Iterator call(Iterator i, Iterator s) { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(new VoidFunction() { - @Override - public void call(Integer x) { - intAccum.add(x); - } - }); + Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(intAccum::add); assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); - rdd.foreach(new VoidFunction() { - @Override - public void call(Integer x) { - doubleAccum.add((double) x); - } - }); + Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(x -> doubleAccum.add((double) x)); assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type @@ -1432,13 +1212,8 @@ public Float zero(Float initialValue) { } }; - final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); - rdd.foreach(new VoidFunction() { - @Override - public void call(Integer x) { - floatAccum.add((float) x); - } - }); + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + rdd.foreach(x -> floatAccum.add((float) x)); assertEquals((Float) 25.0f, floatAccum.value()); // Test the setValue method @@ -1449,12 +1224,7 @@ public void call(Integer x) { @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(new Function() { - @Override - public String call(Integer t) { - return t.toString(); - } - }).collect(); + List> s = rdd.keyBy(Object::toString).collect(); assertEquals(new Tuple2<>("1", 1), s.get(0)); assertEquals(new Tuple2<>("2", 2), s.get(1)); } @@ -1487,26 +1257,10 @@ public void checkpointAndRestore() { @Test public void combineByKey() { JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); - Function keyFunction = new Function() { - @Override - public Integer call(Integer v1) { - return v1 % 3; - } - }; 
- Function createCombinerFunction = new Function() { - @Override - public Integer call(Integer v1) { - return v1; - } - }; + Function keyFunction = v1 -> v1 % 3; + Function createCombinerFunction = v1 -> v1; - Function2 mergeValueFunction = - new Function2() { - @Override - public Integer call(Integer v1, Integer v2) { - return v1 + v2; - } - }; + Function2 mergeValueFunction = (v1, v2) -> v1 + v2; JavaPairRDD combinedRDD = originalRDD.keyBy(keyFunction) .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); @@ -1534,20 +1288,8 @@ public Integer call(Integer v1, Integer v2) { @Test public void mapOnPairRDD() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); - JavaPairRDD rdd2 = rdd1.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i, i % 2); - } - }); - JavaPairRDD rdd3 = rdd2.mapToPair( - new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 in) { - return new Tuple2<>(in._2(), in._1()); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); + JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); assertEquals(Arrays.asList( new Tuple2<>(1, 1), new Tuple2<>(0, 2), @@ -1561,13 +1303,7 @@ public Tuple2 call(Tuple2 in) { public void collectPartitions() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); - JavaPairRDD rdd2 = rdd1.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i, i % 2); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); List[] parts = rdd1.collectPartitions(new int[] {0}); assertEquals(Arrays.asList(1, 2), parts[0]); @@ -1623,13 +1359,7 @@ public void countApproxDistinctByKey() { public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 JavaRDD rdd = sc.parallelize(Arrays.asList(1)); - JavaPairRDD pairRDD = rdd.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer x) { - return new Tuple2<>(x, new int[]{x}); - } - }); + JavaPairRDD pairRDD = rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine pairRDD.collectAsMap(); // Used to crash with ClassCastException } @@ -1651,13 +1381,7 @@ public void collectAsMapAndSerialize() throws Exception { @SuppressWarnings("unchecked") public void sampleByKey() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); - JavaPairRDD rdd2 = rdd1.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i % 2, 1); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i % 2, 1)); Map fractions = new HashMap<>(); fractions.put(0, 0.5); fractions.put(1, 1.0); @@ -1677,13 +1401,7 @@ public Tuple2 call(Integer i) { @SuppressWarnings("unchecked") public void sampleByKeyExact() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); - JavaPairRDD rdd2 = rdd1.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i % 2, 1); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i % 2, 1)); Map fractions = new HashMap<>(); fractions.put(0, 0.5); fractions.put(1, 1.0); @@ -1754,14 +1472,7 @@ public void takeAsync() throws Exception { public void foreachAsync() throws Exception { List data = Arrays.asList(1, 2, 3, 4, 5); JavaRDD rdd = sc.parallelize(data, 1); - JavaFutureAction future = rdd.foreachAsync( - new VoidFunction() { - @Override - public void call(Integer integer) { - // intentionally 
left blank. - } - } - ); + JavaFutureAction future = rdd.foreachAsync(integer -> {}); future.get(); assertFalse(future.isCancelled()); assertTrue(future.isDone()); @@ -1784,11 +1495,8 @@ public void countAsync() throws Exception { public void testAsyncActionCancellation() throws Exception { List data = Arrays.asList(1, 2, 3, 4, 5); JavaRDD rdd = sc.parallelize(data, 1); - JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { - @Override - public void call(Integer integer) throws InterruptedException { - Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. - } + JavaFutureAction future = rdd.foreachAsync(integer -> { + Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. }); future.cancel(true); assertTrue(future.isCancelled()); @@ -1805,7 +1513,7 @@ public void call(Integer integer) throws InterruptedException { public void testAsyncActionErrorWrapping() throws Exception { List data = Arrays.asList(1, 2, 3, 4, 5); JavaRDD rdd = sc.parallelize(data, 1); - JavaFutureAction future = rdd.map(new BuggyMapFunction()).countAsync(); + JavaFutureAction future = rdd.map(new BuggyMapFunction<>()).countAsync(); try { future.get(2, TimeUnit.SECONDS); fail("Expected future.get() for failed job to throw ExcecutionException"); diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java index ba57b6beb247d..938cc8ddfb5d9 100644 --- a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java +++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java @@ -59,39 +59,39 @@ public Object apply(Long x) { ); final ConsumerStrategy sub1 = - ConsumerStrategies.Subscribe(sTopics, sKafkaParams, sOffsets); + ConsumerStrategies.Subscribe(sTopics, sKafkaParams, sOffsets); final ConsumerStrategy sub2 = - ConsumerStrategies.Subscribe(sTopics, sKafkaParams); + ConsumerStrategies.Subscribe(sTopics, sKafkaParams); final ConsumerStrategy sub3 = - ConsumerStrategies.Subscribe(topics, kafkaParams, offsets); + ConsumerStrategies.Subscribe(topics, kafkaParams, offsets); final ConsumerStrategy sub4 = - ConsumerStrategies.Subscribe(topics, kafkaParams); + ConsumerStrategies.Subscribe(topics, kafkaParams); Assert.assertEquals( sub1.executorKafkaParams().get("bootstrap.servers"), sub3.executorKafkaParams().get("bootstrap.servers")); final ConsumerStrategy psub1 = - ConsumerStrategies.SubscribePattern(pat, sKafkaParams, sOffsets); + ConsumerStrategies.SubscribePattern(pat, sKafkaParams, sOffsets); final ConsumerStrategy psub2 = - ConsumerStrategies.SubscribePattern(pat, sKafkaParams); + ConsumerStrategies.SubscribePattern(pat, sKafkaParams); final ConsumerStrategy psub3 = - ConsumerStrategies.SubscribePattern(pat, kafkaParams, offsets); + ConsumerStrategies.SubscribePattern(pat, kafkaParams, offsets); final ConsumerStrategy psub4 = - ConsumerStrategies.SubscribePattern(pat, kafkaParams); + ConsumerStrategies.SubscribePattern(pat, kafkaParams); Assert.assertEquals( psub1.executorKafkaParams().get("bootstrap.servers"), psub3.executorKafkaParams().get("bootstrap.servers")); final ConsumerStrategy asn1 = - ConsumerStrategies.Assign(sParts, sKafkaParams, sOffsets); + ConsumerStrategies.Assign(sParts, sKafkaParams, sOffsets); final ConsumerStrategy asn2 = - ConsumerStrategies.Assign(sParts, sKafkaParams); + 
ConsumerStrategies.Assign(sParts, sKafkaParams); final ConsumerStrategy asn3 = - ConsumerStrategies.Assign(parts, kafkaParams, offsets); + ConsumerStrategies.Assign(parts, kafkaParams, offsets); final ConsumerStrategy asn4 = - ConsumerStrategies.Assign(parts, kafkaParams); + ConsumerStrategies.Assign(parts, kafkaParams); Assert.assertEquals( asn1.executorKafkaParams().get("bootstrap.servers"), diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index d569b6688deca..2e050f8413074 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -217,7 +217,7 @@ private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) th String deployMode = isDriver ? "client" : "cluster"; SparkSubmitCommandBuilder launcher = - newCommandBuilder(Collections.emptyList()); + newCommandBuilder(Collections.emptyList()); launcher.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); launcher.master = "yarn"; diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java index 3bc35da7cc27c..9ff7aceb581f4 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java @@ -44,7 +44,7 @@ public void testAllOptions() { count++; verify(parser).handle(eq(optNames[0]), eq(value)); verify(parser, times(count)).handle(anyString(), anyString()); - verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); + verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); } } @@ -54,9 +54,9 @@ public void testAllOptions() { parser.parse(Arrays.asList(name)); count++; switchCount++; - verify(parser, times(switchCount)).handle(eq(switchNames[0]), same((String) null)); + verify(parser, times(switchCount)).handle(eq(switchNames[0]), same(null)); verify(parser, times(count)).handle(anyString(), any(String.class)); - verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); + verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); } } } @@ -80,7 +80,7 @@ public void testEqualSeparatedOption() { List args = Arrays.asList(parser.MASTER + "=" + parser.MASTER); parser.parse(args); verify(parser).handle(eq(parser.MASTER), eq(parser.MASTER)); - verify(parser).handleExtraArgs(eq(Collections.emptyList())); + verify(parser).handleExtraArgs(eq(Collections.emptyList())); } private static class DummyParser extends SparkSubmitOptionParser { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index 8c0338e2844f0..683ceffeaed0e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -21,16 +21,14 @@ import java.util.Arrays; import java.util.List; -import scala.Tuple2; - import org.junit.Assert; import org.junit.Test; import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.Vectors; +import 
org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.sql.Dataset; @@ -69,35 +67,22 @@ public void testPCA() { JavaRDD dataRDD = jsc.parallelize(points, 2); RowMatrix mat = new RowMatrix(dataRDD.map( - new Function() { - public org.apache.spark.mllib.linalg.Vector call(Vector vector) { - return new org.apache.spark.mllib.linalg.DenseVector(vector.toArray()); - } - } + (Vector vector) -> (org.apache.spark.mllib.linalg.Vector) new DenseVector(vector.toArray()) ).rdd()); Matrix pc = mat.computePrincipalComponents(3); mat.multiply(pc).rows().toJavaRDD(); - JavaRDD expected = mat.multiply(pc).rows().toJavaRDD().map( - new Function() { - public Vector call(org.apache.spark.mllib.linalg.Vector vector) { - return vector.asML(); - } - } - ); + JavaRDD expected = mat.multiply(pc).rows().toJavaRDD() + .map(org.apache.spark.mllib.linalg.Vector::asML); - JavaRDD featuresExpected = dataRDD.zip(expected).map( - new Function, VectorPair>() { - public VectorPair call(Tuple2 pair) { - VectorPair featuresExpected = new VectorPair(); - featuresExpected.setFeatures(pair._1()); - featuresExpected.setExpected(pair._2()); - return featuresExpected; - } - } - ); + JavaRDD featuresExpected = dataRDD.zip(expected).map(pair -> { + VectorPair featuresExpected1 = new VectorPair(); + featuresExpected1.setFeatures(pair._1()); + featuresExpected1.setExpected(pair._2()); + return featuresExpected1; + }); Dataset df = spark.createDataFrame(featuresExpected, VectorPair.class); PCAModel pca = new PCA() diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 6ded42e928250..65db3d014fdcd 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -25,7 +25,6 @@ import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; @@ -42,7 +41,7 @@ public class JavaNaiveBayesSuite extends SharedSparkSession { new LabeledPoint(2, Vectors.dense(0.0, 0.0, 2.0)) ); - private int validatePrediction(List points, NaiveBayesModel model) { + private static int validatePrediction(List points, NaiveBayesModel model) { int correct = 0; for (LabeledPoint p : points) { if (model.predict(p.features()) == p.label()) { @@ -80,12 +79,7 @@ public void runUsingStaticMethods() { public void testPredictJavaRDD() { JavaRDD examples = jsc.parallelize(POINTS, 2).cache(); NaiveBayesModel model = NaiveBayes.train(examples.rdd()); - JavaRDD vectors = examples.map(new Function() { - @Override - public Vector call(LabeledPoint v) throws Exception { - return v.features(); - } - }); + JavaRDD vectors = examples.map(LabeledPoint::features); JavaRDD predictions = model.predict(vectors); // Should be able to get the first prediction. 
predictions.first(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java index 3d62b273d2210..b4196c6ecdf72 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering; -import com.google.common.collect.Lists; +import java.util.Arrays; import org.junit.Assert; import org.junit.Test; @@ -31,7 +31,7 @@ public class JavaBisectingKMeansSuite extends SharedSparkSession { @Test public void twoDimensionalData() { - JavaRDD points = jsc.parallelize(Lists.newArrayList( + JavaRDD points = jsc.parallelize(Arrays.asList( Vectors.dense(4, -1), Vectors.dense(4, 1), Vectors.sparse(2, new int[]{0}, new double[]{1.0}) diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 08d6713ab2bc3..38ee2507f2e1c 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import scala.Tuple2; import scala.Tuple3; @@ -30,7 +31,6 @@ import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -39,7 +39,7 @@ public class JavaLDASuite extends SharedSparkSession { @Override public void setUp() throws IOException { super.setUp(); - ArrayList> tinyCorpus = new ArrayList<>(); + List> tinyCorpus = new ArrayList<>(); for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(), LDASuite.tinyCorpus()[i]._2())); @@ -53,7 +53,7 @@ public void localLDAModel() { Matrix topics = LDASuite.tinyTopics(); double[] topicConcentration = new double[topics.numRows()]; Arrays.fill(topicConcentration, 1.0D / topics.numRows()); - LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); + LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1.0, 100.0); // Check: basic parameters assertEquals(model.k(), tinyK); @@ -87,17 +87,17 @@ public void distributedLDAModel() { // Check: basic parameters LocalLDAModel localModel = model.toLocal(); - assertEquals(model.k(), k); - assertEquals(localModel.k(), k); - assertEquals(model.vocabSize(), tinyVocabSize); - assertEquals(localModel.vocabSize(), tinyVocabSize); - assertEquals(model.topicsMatrix(), localModel.topicsMatrix()); + assertEquals(k, model.k()); + assertEquals(k, localModel.k()); + assertEquals(tinyVocabSize, model.vocabSize()); + assertEquals(tinyVocabSize, localModel.vocabSize()); + assertEquals(localModel.topicsMatrix(), model.topicsMatrix()); // Check: topic summaries Tuple2[] roundedTopicSummary = model.describeTopics(); - assertEquals(roundedTopicSummary.length, k); + assertEquals(k, roundedTopicSummary.length); Tuple2[] roundedLocalTopicSummary = localModel.describeTopics(); - assertEquals(roundedLocalTopicSummary.length, k); + assertEquals(k, 
roundedLocalTopicSummary.length); // Check: log probabilities assertTrue(model.logLikelihood() < 0.0); @@ -107,12 +107,8 @@ public void distributedLDAModel() { JavaPairRDD topicDistributions = model.javaTopicDistributions(); // SPARK-5562. since the topicDistribution returns the distribution of the non empty docs // over topics. Compare it against nonEmptyCorpus instead of corpus - JavaPairRDD nonEmptyCorpus = corpus.filter( - new Function, Boolean>() { - public Boolean call(Tuple2 tuple2) { - return Vectors.norm(tuple2._2(), 1.0) != 0.0; - } - }); + JavaPairRDD nonEmptyCorpus = + corpus.filter(tuple2 -> Vectors.norm(tuple2._2(), 1.0) != 0.0); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); // Check: javaTopTopicsPerDocuments @@ -155,14 +151,14 @@ public void onlineOptimizerCompatibility() { LDAModel model = lda.run(corpus); // Check: basic parameters - assertEquals(model.k(), k); - assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(k, model.k()); + assertEquals(tinyVocabSize, model.vocabSize()); // Check: topic summaries Tuple2[] roundedTopicSummary = model.describeTopics(); - assertEquals(roundedTopicSummary.length, k); + assertEquals(k, roundedTopicSummary.length); Tuple2[] roundedLocalTopicSummary = model.describeTopics(); - assertEquals(roundedLocalTopicSummary.length, k); + assertEquals(k, roundedLocalTopicSummary.length); } @Test @@ -177,7 +173,7 @@ public void localLdaMethods() { double logPerplexity = toyModel.logPerplexity(pairedDocs); // check: logLikelihood. - ArrayList> docsSingleWord = new ArrayList<>(); + List> docsSingleWord = new ArrayList<>(); docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0))); JavaPairRDD single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord)); double logLikelihood = toyModel.logLikelihood(single); @@ -190,6 +186,6 @@ public void localLdaMethods() { LDASuite.tinyTopicDescription(); private JavaPairRDD corpus; private LocalLDAModel toyModel = LDASuite.toyModel(); - private ArrayList> toyData = LDASuite.javaToyData(); + private List> toyData = LDASuite.javaToyData(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index 3451e0773759b..15de566c886de 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -31,9 +31,9 @@ public void runAssociationRules() { @SuppressWarnings("unchecked") JavaRDD> freqItemsets = jsc.parallelize(Arrays.asList( - new FreqItemset(new String[]{"a"}, 15L), - new FreqItemset(new String[]{"b"}, 35L), - new FreqItemset(new String[]{"a", "b"}, 12L) + new FreqItemset<>(new String[]{"a"}, 15L), + new FreqItemset<>(new String[]{"b"}, 35L), + new FreqItemset<>(new String[]{"a", "b"}, 12L) )); JavaRDD> results = (new AssociationRules()).run(freqItemsets); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java index a46b1321b3ca2..86c723aa00746 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -24,13 +24,13 @@ import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import 
org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.util.LinearDataGenerator; public class JavaLinearRegressionSuite extends SharedSparkSession { - int validatePrediction(List validationData, LinearRegressionModel model) { + private static int validatePrediction( + List validationData, LinearRegressionModel model) { int numAccurate = 0; for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); @@ -87,12 +87,7 @@ public void testPredictJavaRDD() { LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); - JavaRDD vectors = testRDD.map(new Function() { - @Override - public Vector call(LabeledPoint v) throws Exception { - return v.features(); - } - }); + JavaRDD vectors = testRDD.map(LabeledPoint::features); JavaRDD predictions = model.predict(vectors); // Should be able to get the first prediction. predictions.first(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 1dcbbcaa0223c..0f71deb9ea528 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -25,8 +25,6 @@ import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.configuration.Algo; import org.apache.spark.mllib.tree.configuration.Strategy; @@ -35,7 +33,7 @@ public class JavaDecisionTreeSuite extends SharedSparkSession { - int validatePrediction(List validationData, DecisionTreeModel model) { + private static int validatePrediction(List validationData, DecisionTreeModel model) { int numCorrect = 0; for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); @@ -63,7 +61,7 @@ public void runDTUsingConstructor() { DecisionTreeModel model = learner.run(rdd.rdd()); int numCorrect = validatePrediction(arr, model); - Assert.assertTrue(numCorrect == rdd.count()); + Assert.assertEquals(numCorrect, rdd.count()); } @Test @@ -82,15 +80,10 @@ public void runDTUsingStaticMethods() { DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); // java compatibility test - JavaRDD predictions = model.predict(rdd.map(new Function() { - @Override - public Vector call(LabeledPoint v1) { - return v1.features(); - } - })); + JavaRDD predictions = model.predict(rdd.map(LabeledPoint::features)); int numCorrect = validatePrediction(arr, model); - Assert.assertTrue(numCorrect == rdd.count()); + Assert.assertEquals(numCorrect, rdd.count()); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 06cd9ea2d242c..bf8717483575f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -157,7 +157,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont 
// to the accumulator. So we can check if the row groups are filtered or not in test case. TaskContext taskContext = TaskContext$.MODULE$.get(); if (taskContext != null) { - Option> accu = (Option>) taskContext.taskMetrics() + Option> accu = taskContext.taskMetrics() .lookForAccumulatorByName("numRowGroups"); if (accu.isDefined()) { ((LongAccumulator)accu.get()).add((long)blocks.size()); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java index 8b8a403e2b197..6ffccee52c0fe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java @@ -35,27 +35,35 @@ public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase public void testTypedAggregationAverage() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); Dataset> agged = grouped.agg(typed.avg(v -> (double)(v._2() * 2))); - Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)), + agged.collectAsList()); } @Test public void testTypedAggregationCount() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); Dataset> agged = grouped.agg(typed.count(v -> v)); - Assert.assertEquals(Arrays.asList(tuple2("a", 2L), tuple2("b", 1L)), agged.collectAsList()); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)), + agged.collectAsList()); } @Test public void testTypedAggregationSumDouble() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); Dataset> agged = grouped.agg(typed.sum(v -> (double)v._2())); - Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)), + agged.collectAsList()); } @Test public void testTypedAggregationSumLong() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); Dataset> agged = grouped.agg(typed.sumLong(v -> (long)v._2())); - Assert.assertEquals(Arrays.asList(tuple2("a", 3L), tuple2("b", 3L)), agged.collectAsList()); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)), + agged.collectAsList()); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 573d0e3594363..bf8ff61eae39e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -30,7 +30,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -95,12 +94,7 @@ public void applySchema() { personList.add(person2); JavaRDD rowRDD = jsc.parallelize(personList).map( - new Function() { - @Override - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); + person -> RowFactory.create(person.getName(), person.getAge())); List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); @@ -131,12 +125,7 @@ public void 
dataFrameRDDOperations() { personList.add(person2); JavaRDD rowRDD = jsc.parallelize(personList).map( - new Function() { - @Override - public Row call(Person person) { - return RowFactory.create(person.getName(), person.getAge()); - } - }); + person -> RowFactory.create(person.getName(), person.getAge())); List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); @@ -146,12 +135,7 @@ public Row call(Person person) { Dataset df = spark.createDataFrame(rowRDD, schema); df.createOrReplaceTempView("people"); List actual = spark.sql("SELECT * FROM people").toJavaRDD() - .map(new Function() { - @Override - public String call(Row row) { - return row.getString(0) + "_" + row.get(1); - } - }).collect(); + .map(row -> row.getString(0) + "_" + row.get(1)).collect(); List expected = new ArrayList<>(2); expected.add("Michael_29"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index c44fc3d393862..c3b94a44c2e91 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -189,7 +189,7 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { for (int i = 0; i < d.length(); i++) { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } - // Java.math.BigInteger is equavient to Spark Decimal(38,0) + // Java.math.BigInteger is equivalent to Spark Decimal(38,0) Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4)); } @@ -231,13 +231,10 @@ public void testCreateStructTypeFromList(){ Assert.assertEquals(0, schema2.fieldIndex("id")); } - private static final Comparator crosstabRowComparator = new Comparator() { - @Override - public int compare(Row row1, Row row2) { - String item1 = row1.getString(0); - String item2 = row2.getString(0); - return item1.compareTo(item2); - } + private static final Comparator crosstabRowComparator = (row1, row2) -> { + String item1 = row1.getString(0); + String item2 = row2.getString(0); + return item1.compareTo(item2); }; @Test @@ -249,7 +246,7 @@ public void testCrosstab() { Assert.assertEquals("1", columnNames[1]); Assert.assertEquals("2", columnNames[2]); List rows = crosstab.collectAsList(); - Collections.sort(rows, crosstabRowComparator); + rows.sort(crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); @@ -284,7 +281,7 @@ public void testCovariance() { @Test public void testSampleBy() { Dataset df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); - Dataset sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Dataset sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); List actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); Assert.assertEquals(0, actual.get(0).getLong(0)); Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8); @@ -296,7 +293,7 @@ public void testSampleBy() { public void pivot() { Dataset df = spark.table("courseSales"); List actual = df.groupBy("year") - .pivot("course", Arrays.asList("dotNET", "Java")) + .pivot("course", Arrays.asList("dotNET", "Java")) .agg(sum("earnings")).orderBy("year").collectAsList(); Assert.assertEquals(2012, actual.get(0).getInt(0)); @@ -352,24 +349,24 @@ public void testCountMinSketch() { Dataset df = spark.range(1000); CountMinSketch sketch1 = 
df.stat().countMinSketch("id", 10, 20, 42); - Assert.assertEquals(sketch1.totalCount(), 1000); - Assert.assertEquals(sketch1.depth(), 10); - Assert.assertEquals(sketch1.width(), 20); + Assert.assertEquals(1000, sketch1.totalCount()); + Assert.assertEquals(10, sketch1.depth()); + Assert.assertEquals(20, sketch1.width()); CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42); - Assert.assertEquals(sketch2.totalCount(), 1000); - Assert.assertEquals(sketch2.depth(), 10); - Assert.assertEquals(sketch2.width(), 20); + Assert.assertEquals(1000, sketch2.totalCount()); + Assert.assertEquals(10, sketch2.depth()); + Assert.assertEquals(20, sketch2.width()); CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42); - Assert.assertEquals(sketch3.totalCount(), 1000); - Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4); - Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3); + Assert.assertEquals(1000, sketch3.totalCount()); + Assert.assertEquals(0.001, sketch3.relativeError(), 1.0e-4); + Assert.assertEquals(0.99, sketch3.confidence(), 5.0e-3); CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42); - Assert.assertEquals(sketch4.totalCount(), 1000); - Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4); - Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3); + Assert.assertEquals(1000, sketch4.totalCount()); + Assert.assertEquals(0.001, sketch4.relativeError(), 1.0e-4); + Assert.assertEquals(0.99, sketch4.confidence(), 5.0e-3); } @Test @@ -389,13 +386,13 @@ public void testBloomFilter() { } BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5); - Assert.assertTrue(filter3.bitSize() == 64 * 5); + Assert.assertEquals(64 * 5, filter3.bitSize()); for (int i = 0; i < 1000; i++) { Assert.assertTrue(filter3.mightContain(i)); } BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5); - Assert.assertTrue(filter4.bitSize() == 64 * 5); + Assert.assertEquals(64 * 5, filter4.bitSize()); for (int i = 0; i < 1000; i++) { Assert.assertTrue(filter4.mightContain(i * 3)); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java index fe863715162f5..d3769a74b9789 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java @@ -24,7 +24,6 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; @@ -41,7 +40,9 @@ public void testTypedAggregationAnonClass() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); Dataset> agged = grouped.agg(new IntSumOf().toColumn()); - Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 3), new Tuple2<>("b", 3)), + agged.collectAsList()); Dataset> agged2 = grouped.agg(new IntSumOf().toColumn()) .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); @@ -87,48 +88,36 @@ public Encoder outputEncoder() { @Test public void testTypedAggregationAverage() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); - Dataset> agged = grouped.agg(typed.avg( - new MapFunction, Double>() { - public Double call(Tuple2 value) throws Exception { - return (double)(value._2() * 2); - } - })); - 
Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + Dataset> agged = grouped.agg(typed.avg(value -> (double)(value._2() * 2))); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)), + agged.collectAsList()); } @Test public void testTypedAggregationCount() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); - Dataset> agged = grouped.agg(typed.count( - new MapFunction, Object>() { - public Object call(Tuple2 value) throws Exception { - return value; - } - })); - Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList()); + Dataset> agged = grouped.agg(typed.count(value -> value)); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)), + agged.collectAsList()); } @Test public void testTypedAggregationSumDouble() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); - Dataset> agged = grouped.agg(typed.sum( - new MapFunction, Double>() { - public Double call(Tuple2 value) throws Exception { - return (double)value._2(); - } - })); - Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + Dataset> agged = grouped.agg(typed.sum(value -> (double) value._2())); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)), + agged.collectAsList()); } @Test public void testTypedAggregationSumLong() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); - Dataset> agged = grouped.agg(typed.sumLong( - new MapFunction, Long>() { - public Long call(Tuple2 value) throws Exception { - return (long)value._2(); - } - })); - Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + Dataset> agged = grouped.agg(typed.sumLong(value -> (long) value._2())); + Assert.assertEquals( + Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)), + agged.collectAsList()); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java index 8fc4eff55ddd0..e62db7d2cff61 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java @@ -52,23 +52,13 @@ public void tearDown() { spark = null; } - protected Tuple2 tuple2(T1 t1, T2 t2) { - return new Tuple2<>(t1, t2); - } - protected KeyValueGroupedDataset> generateGroupedDataset() { Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); List> data = - Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); + Arrays.asList(new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3)); Dataset> ds = spark.createDataset(data, encoder); - return ds.groupByKey( - new MapFunction, String>() { - @Override - public String call(Tuple2 value) throws Exception { - return value._1(); - } - }, + return ds.groupByKey((MapFunction, String>) value -> value._1(), Encoders.STRING()); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a94a37cb21b3f..577672ca8e083 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -96,12 +96,7 @@ public void testToLocalIterator() { @Test public void testTypedFilterPreservingSchema() { Dataset ds = 
spark.range(10); - Dataset ds2 = ds.filter(new FilterFunction() { - @Override - public boolean call(Long value) throws Exception { - return value > 3; - } - }); + Dataset ds2 = ds.filter((FilterFunction) value -> value > 3); Assert.assertEquals(ds.schema(), ds2.schema()); } @@ -111,44 +106,28 @@ public void testCommonOperation() { Dataset ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals("hello", ds.first()); - Dataset filtered = ds.filter(new FilterFunction() { - @Override - public boolean call(String v) throws Exception { - return v.startsWith("h"); - } - }); + Dataset filtered = ds.filter((FilterFunction) v -> v.startsWith("h")); Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); - Dataset mapped = ds.map(new MapFunction() { - @Override - public Integer call(String v) throws Exception { - return v.length(); - } - }, Encoders.INT()); + Dataset mapped = ds.map((MapFunction) v -> v.length(), Encoders.INT()); Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); - Dataset parMapped = ds.mapPartitions(new MapPartitionsFunction() { - @Override - public Iterator call(Iterator it) { - List ls = new LinkedList<>(); - while (it.hasNext()) { - ls.add(it.next().toUpperCase(Locale.ENGLISH)); - } - return ls.iterator(); + Dataset parMapped = ds.mapPartitions((MapPartitionsFunction) it -> { + List ls = new LinkedList<>(); + while (it.hasNext()) { + ls.add(it.next().toUpperCase(Locale.ENGLISH)); } + return ls.iterator(); }, Encoders.STRING()); Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList()); - Dataset flatMapped = ds.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String s) { - List ls = new LinkedList<>(); - for (char c : s.toCharArray()) { - ls.add(String.valueOf(c)); - } - return ls.iterator(); + Dataset flatMapped = ds.flatMap((FlatMapFunction) s -> { + List ls = new LinkedList<>(); + for (char c : s.toCharArray()) { + ls.add(String.valueOf(c)); } + return ls.iterator(); }, Encoders.STRING()); Assert.assertEquals( Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"), @@ -157,16 +136,11 @@ public Iterator call(String s) { @Test public void testForeach() { - final LongAccumulator accum = jsc.sc().longAccumulator(); + LongAccumulator accum = jsc.sc().longAccumulator(); List data = Arrays.asList("a", "b", "c"); Dataset ds = spark.createDataset(data, Encoders.STRING()); - ds.foreach(new ForeachFunction() { - @Override - public void call(String s) throws Exception { - accum.add(1); - } - }); + ds.foreach((ForeachFunction) s -> accum.add(1)); Assert.assertEquals(3, accum.value().intValue()); } @@ -175,12 +149,7 @@ public void testReduce() { List data = Arrays.asList(1, 2, 3); Dataset ds = spark.createDataset(data, Encoders.INT()); - int reduced = ds.reduce(new ReduceFunction() { - @Override - public Integer call(Integer v1, Integer v2) throws Exception { - return v1 + v2; - } - }); + int reduced = ds.reduce((ReduceFunction) (v1, v2) -> v1 + v2); Assert.assertEquals(6, reduced); } @@ -189,52 +158,38 @@ public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = spark.createDataset(data, Encoders.STRING()); KeyValueGroupedDataset grouped = ds.groupByKey( - new MapFunction() { - @Override - public Integer call(String v) throws Exception { - return v.length(); - } - }, + (MapFunction) v -> v.length(), Encoders.INT()); - Dataset mapped = grouped.mapGroups(new MapGroupsFunction() { - @Override - public String call(Integer key, Iterator values) throws Exception { - 
StringBuilder sb = new StringBuilder(key.toString()); - while (values.hasNext()) { - sb.append(values.next()); - } - return sb.toString(); + Dataset mapped = grouped.mapGroups((MapGroupsFunction) (key, values) -> { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); } + return sb.toString(); }, Encoders.STRING()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); Dataset flatMapped = grouped.flatMapGroups( - new FlatMapGroupsFunction() { - @Override - public Iterator call(Integer key, Iterator values) { + (FlatMapGroupsFunction) (key, values) -> { StringBuilder sb = new StringBuilder(key.toString()); while (values.hasNext()) { sb.append(values.next()); } return Collections.singletonList(sb.toString()).iterator(); - } - }, + }, Encoders.STRING()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); Dataset mapped2 = grouped.mapGroupsWithState( - new MapGroupsWithStateFunction() { - @Override - public String call(Integer key, Iterator values, KeyedState s) { + (MapGroupsWithStateFunction) (key, values, s) -> { StringBuilder sb = new StringBuilder(key.toString()); while (values.hasNext()) { sb.append(values.next()); } return sb.toString(); - } }, Encoders.LONG(), Encoders.STRING()); @@ -242,27 +197,19 @@ public String call(Integer key, Iterator values, KeyedState s) { Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList())); Dataset flatMapped2 = grouped.flatMapGroupsWithState( - new FlatMapGroupsWithStateFunction() { - @Override - public Iterator call(Integer key, Iterator values, KeyedState s) { + (FlatMapGroupsWithStateFunction) (key, values, s) -> { StringBuilder sb = new StringBuilder(key.toString()); while (values.hasNext()) { sb.append(values.next()); } return Collections.singletonList(sb.toString()).iterator(); - } - }, + }, Encoders.LONG(), Encoders.STRING()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); - Dataset> reduced = grouped.reduceGroups(new ReduceFunction() { - @Override - public String call(String v1, String v2) throws Exception { - return v1 + v2; - } - }); + Dataset> reduced = grouped.reduceGroups((ReduceFunction) (v1, v2) -> v1 + v2); Assert.assertEquals( asSet(tuple2(1, "a"), tuple2(3, "foobar")), @@ -271,29 +218,21 @@ public String call(String v1, String v2) throws Exception { List data2 = Arrays.asList(2, 6, 10); Dataset ds2 = spark.createDataset(data2, Encoders.INT()); KeyValueGroupedDataset grouped2 = ds2.groupByKey( - new MapFunction() { - @Override - public Integer call(Integer v) throws Exception { - return v / 2; - } - }, + (MapFunction) v -> v / 2, Encoders.INT()); Dataset cogrouped = grouped.cogroup( grouped2, - new CoGroupFunction() { - @Override - public Iterator call(Integer key, Iterator left, Iterator right) { - StringBuilder sb = new StringBuilder(key.toString()); - while (left.hasNext()) { - sb.append(left.next()); - } - sb.append("#"); - while (right.hasNext()) { - sb.append(right.next()); - } - return Collections.singletonList(sb.toString()).iterator(); + (CoGroupFunction) (key, left, right) -> { + StringBuilder sb = new StringBuilder(key.toString()); + while (left.hasNext()) { + sb.append(left.next()); + } + sb.append("#"); + while (right.hasNext()) { + sb.append(right.next()); } + return Collections.singletonList(sb.toString()).iterator(); }, Encoders.STRING()); @@ -703,11 +642,11 @@ public void testJavaBeanEncoder() { obj1.setD(new String[]{"hello", null}); 
obj1.setE(Arrays.asList("a", "b")); obj1.setF(Arrays.asList(100L, null, 200L)); - Map map1 = new HashMap(); + Map map1 = new HashMap<>(); map1.put(1, "a"); map1.put(2, "b"); obj1.setG(map1); - Map nestedMap1 = new HashMap(); + Map nestedMap1 = new HashMap<>(); nestedMap1.put("x", "1"); nestedMap1.put("y", "2"); Map, Map> complexMap1 = new HashMap<>(); @@ -721,11 +660,11 @@ public void testJavaBeanEncoder() { obj2.setD(new String[]{null, "world"}); obj2.setE(Arrays.asList("x", "y")); obj2.setF(Arrays.asList(300L, null, 400L)); - Map map2 = new HashMap(); + Map map2 = new HashMap<>(); map2.put(3, "c"); map2.put(4, "d"); obj2.setG(map2); - Map nestedMap2 = new HashMap(); + Map nestedMap2 = new HashMap<>(); nestedMap2.put("q", "1"); nestedMap2.put("w", "2"); Map, Map> complexMap2 = new HashMap<>(); @@ -1328,7 +1267,7 @@ public NestedComplicatedJavaBean build() { @Test public void test() { /* SPARK-15285 Large numbers of Nested JavaBeans generates more than 64KB java bytecode */ - List data = new ArrayList(); + List data = new ArrayList<>(); data.add(NestedComplicatedJavaBean.newBuilder().build()); NestedComplicatedJavaBean obj3 = new NestedComplicatedJavaBean(); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index bbaac5a33975b..250fa674d8ecc 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -27,7 +27,6 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.sql.types.DataTypes; @@ -54,16 +53,7 @@ public void tearDown() { @SuppressWarnings("unchecked") @Test public void udf1Test() { - // With Java 8 lambdas: - // sqlContext.registerFunction( - // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); - - spark.udf().register("stringLengthTest", new UDF1() { - @Override - public Integer call(String str) { - return str.length(); - } - }, DataTypes.IntegerType); + spark.udf().register("stringLengthTest", (String str) -> str.length(), DataTypes.IntegerType); Row result = spark.sql("SELECT stringLengthTest('test')").head(); Assert.assertEquals(4, result.getInt(0)); @@ -72,18 +62,8 @@ public Integer call(String str) { @SuppressWarnings("unchecked") @Test public void udf2Test() { - // With Java 8 lambdas: - // sqlContext.registerFunction( - // "stringLengthTest", - // (String str1, String str2) -> str1.length() + str2.length, - // DataType.IntegerType); - - spark.udf().register("stringLengthTest", new UDF2() { - @Override - public Integer call(String str1, String str2) { - return str1.length() + str2.length(); - } - }, DataTypes.IntegerType); + spark.udf().register("stringLengthTest", + (String str1, String str2) -> str1.length() + str2.length(), DataTypes.IntegerType); Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); @@ -91,8 +71,8 @@ public Integer call(String str1, String str2) { public static class StringLengthTest implements UDF2 { @Override - public Integer call(String str1, String str2) throws Exception { - return new Integer(str1.length() + str2.length()); + public Integer call(String str1, String str2) { + return str1.length() + str2.length(); } } @@ -113,12 +93,7 @@ public void udf3Test() { @SuppressWarnings("unchecked") @Test public void udf4Test() { - 
spark.udf().register("inc", new UDF1() { - @Override - public Long call(Long i) { - return i + 1; - } - }, DataTypes.LongType); + spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); spark.range(10).toDF("x").createOrReplaceTempView("tmp"); // This tests when Java UDFs are required to be the semantically same (See SPARK-9435). diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java index 9b7701003d8d0..cb8ed83e5a49d 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -27,9 +27,6 @@ import scala.Tuple2; import com.google.common.collect.Sets; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.util.ManualClock; import org.junit.Assert; @@ -53,18 +50,14 @@ public void testAPI() { JavaPairDStream wordsDstream = null; Function4, State, Optional> mappingFunc = - new Function4, State, Optional>() { - @Override - public Optional call( - Time time, String word, Optional one, State state) { - // Use all State's methods here - state.exists(); - state.get(); - state.isTimingOut(); - state.remove(); - state.update(true); - return Optional.of(2.0); - } + (time, word, one, state) -> { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); }; JavaMapWithStateDStream stateDstream = @@ -78,17 +71,14 @@ public Optional call( stateDstream.stateSnapshots(); Function3, State, Double> mappingFunc2 = - new Function3, State, Double>() { - @Override - public Double call(String key, Optional one, State state) { - // Use all State's methods here - state.exists(); - state.get(); - state.isTimingOut(); - state.remove(); - state.update(true); - return 2.0; - } + (key, one, state) -> { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; }; JavaMapWithStateDStream stateDstream2 = @@ -136,13 +126,10 @@ public void testBasicFunction() { ); Function3, State, Integer> mappingFunc = - new Function3, State, Integer>() { - @Override - public Integer call(String key, Optional value, State state) { - int sum = value.orElse(0) + (state.exists() ? state.get() : 0); - state.update(sum); - return sum; - } + (key, value, state) -> { + int sum = value.orElse(0) + (state.exists() ? 
state.get() : 0); + state.update(sum); + return sum; }; testOperation( inputData, @@ -159,29 +146,15 @@ private void testOperation( int numBatches = expectedOutputs.size(); JavaDStream inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); JavaMapWithStateDStream mapWithStateDStream = - JavaPairDStream.fromJavaDStream(inputStream.map(new Function>() { - @Override - public Tuple2 call(K x) { - return new Tuple2<>(x, 1); - } - })).mapWithState(mapWithStateSpec); - - final List> collectedOutputs = + JavaPairDStream.fromJavaDStream(inputStream.map(x -> new Tuple2<>(x, 1))).mapWithState(mapWithStateSpec); + + List> collectedOutputs = Collections.synchronizedList(new ArrayList>()); - mapWithStateDStream.foreachRDD(new VoidFunction>() { - @Override - public void call(JavaRDD rdd) { - collectedOutputs.add(Sets.newHashSet(rdd.collect())); - } - }); - final List>> collectedStateSnapshots = + mapWithStateDStream.foreachRDD(rdd -> collectedOutputs.add(Sets.newHashSet(rdd.collect()))); + List>> collectedStateSnapshots = Collections.synchronizedList(new ArrayList>>()); - mapWithStateDStream.stateSnapshots().foreachRDD(new VoidFunction>() { - @Override - public void call(JavaPairRDD rdd) { - collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); - } - }); + mapWithStateDStream.stateSnapshots().foreachRDD(rdd -> + collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()))); BatchCounter batchCounter = new BatchCounter(ssc.ssc()); ssc.start(); ((ManualClock) ssc.ssc().scheduler().clock()) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 091ccbfd85cad..91560472446a9 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -58,24 +58,16 @@ public void testReceiver() throws InterruptedException { TestServer server = new TestServer(0); server.start(); - final AtomicLong dataCounter = new AtomicLong(0); + AtomicLong dataCounter = new AtomicLong(0); try { JavaStreamingContext ssc = new JavaStreamingContext("local[2]", "test", new Duration(200)); JavaReceiverInputDStream input = ssc.receiverStream(new JavaSocketReceiver("localhost", server.port())); - JavaDStream mapped = input.map(new Function() { - @Override - public String call(String v1) { - return v1 + "."; - } - }); - mapped.foreachRDD(new VoidFunction>() { - @Override - public void call(JavaRDD rdd) { - long count = rdd.count(); - dataCounter.addAndGet(count); - } + JavaDStream mapped = input.map((Function) v1 -> v1 + "."); + mapped.foreachRDD((VoidFunction>) rdd -> { + long count = rdd.count(); + dataCounter.addAndGet(count); }); ssc.start(); @@ -110,11 +102,7 @@ private static class JavaSocketReceiver extends Receiver { @Override public void onStart() { - new Thread() { - @Override public void run() { - receive(); - } - }.start(); + new Thread(this::receive).start(); } @Override diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index f02fa87f6194b..3f4e6ddb216ec 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -23,7 +23,6 @@ import java.util.Iterator; import java.util.List; -import com.google.common.base.Function; import 
com.google.common.collect.Iterators; import org.apache.spark.SparkConf; import org.apache.spark.network.util.JavaUtils; @@ -81,12 +80,7 @@ public ByteBuffer read(WriteAheadLogRecordHandle handle) { @Override public Iterator readAll() { - return Iterators.transform(records.iterator(), new Function() { - @Override - public ByteBuffer apply(Record input) { - return input.buffer; - } - }); + return Iterators.transform(records.iterator(), input -> input.buffer); } @Override @@ -114,7 +108,7 @@ public void testCustomWAL() { String data1 = "data1"; WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234); Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); - Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1); + Assert.assertEquals(data1, JavaUtils.bytesToString(wal.read(handle))); wal.write(JavaUtils.stringToBytes("data2"), 1235); wal.write(JavaUtils.stringToBytes("data3"), 1236); diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java index 646cb97066f35..9948a4074cdc7 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java @@ -28,7 +28,6 @@ import org.apache.spark.streaming.Time; import scala.Tuple2; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.junit.Assert; import org.junit.Test; @@ -101,7 +100,7 @@ public void testMapPartitions() { while (in.hasNext()) { out = out + in.next().toUpperCase(); } - return Lists.newArrayList(out).iterator(); + return Arrays.asList(out).iterator(); }); JavaTestUtils.attachTestOutputStream(mapped); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -240,7 +239,7 @@ public void testTransformWith() { JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = new ArrayList<>(); for (List>> res : result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -315,7 +314,7 @@ public void testStreamingContextTransform() { JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairStream1input, 1)); - List> listOfDStreams1 = Arrays.>asList(stream1, stream2); + List> listOfDStreams1 = Arrays.asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles JavaDStream transformed1 = ssc.transform( @@ -325,7 +324,7 @@ public void testStreamingContextTransform() { }); List> listOfDStreams2 = - Arrays.>asList(stream1, stream2, pairStream1.toJavaDStream()); + Arrays.asList(stream1, stream2, pairStream1.toJavaDStream()); JavaPairDStream> transformed2 = ssc.transformToPair( listOfDStreams2, (List> listOfRDDs, Time time) -> { @@ -358,7 +357,7 @@ public void testFlatMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream flatMapped = stream.flatMap( - s -> Lists.newArrayList(s.split("(?!^)")).iterator()); + s -> Arrays.asList(s.split("(?!^)")).iterator()); JavaTestUtils.attachTestOutputStream(flatMapped); List> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -401,7 +400,7 @@ public void testPairFlatMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair(s -> { - List> out = Lists.newArrayList(); + List> out = new ArrayList<>(); for (String letter : 
s.split("(?!^)")) { out.add(new Tuple2<>(s.length(), letter)); } @@ -420,7 +419,7 @@ public void testPairFlatMap() { */ public static > void assertOrderInvariantEquals( List> expected, List> actual) { - expected.forEach(list -> Collections.sort(list)); + expected.forEach(Collections::sort); List> sortedActual = new ArrayList<>(); actual.forEach(list -> { List sortedList = new ArrayList<>(list); @@ -491,7 +490,7 @@ public void testPairMap() { // Maps pair -> pair of different type JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.mapToPair(x -> x.swap()); + JavaPairDStream reversed = pairStream.mapToPair(Tuple2::swap); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -543,7 +542,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaDStream reversed = pairStream.map(in -> in._2()); + JavaDStream reversed = pairStream.map(Tuple2::_2); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -629,7 +628,7 @@ public void testCombineByKey() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream combined = pairStream.combineByKey(i -> i, + JavaPairDStream combined = pairStream.combineByKey(i -> i, (x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2)); JavaTestUtils.attachTestOutputStream(combined); diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index 8d24104d7870b..b966cbdca076d 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -33,7 +33,6 @@ import scala.Tuple2; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; @@ -123,12 +122,7 @@ public void testMap() { Arrays.asList(9,4)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) { - return s.length(); - } - }); + JavaDStream letterCount = stream.map(String::length); JavaTestUtils.attachTestOutputStream(letterCount); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -194,12 +188,7 @@ public void testFilter() { Arrays.asList("yankees")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream filtered = stream.filter(new Function() { - @Override - public Boolean call(String s) { - return s.contains("a"); - } - }); + JavaDStream filtered = stream.filter(s -> s.contains("a")); JavaTestUtils.attachTestOutputStream(filtered); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -276,17 +265,13 @@ public void testMapPartitions() { Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream mapped = stream.mapPartitions( - new FlatMapFunction, String>() { - @Override - public Iterator call(Iterator in) { - StringBuilder out = new StringBuilder(); - while (in.hasNext()) { - 
out.append(in.next().toUpperCase(Locale.ENGLISH)); - } - return Arrays.asList(out.toString()).iterator(); - } - }); + JavaDStream mapped = stream.mapPartitions(in -> { + StringBuilder out = new StringBuilder(); + while (in.hasNext()) { + out.append(in.next().toUpperCase(Locale.ENGLISH)); + } + return Arrays.asList(out.toString()).iterator(); + }); JavaTestUtils.attachTestOutputStream(mapped); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -416,18 +401,7 @@ public void testTransform() { Arrays.asList(9,10,11)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream transformed = stream.transform( - new Function, JavaRDD>() { - @Override - public JavaRDD call(JavaRDD in) { - return in.map(new Function() { - @Override - public Integer call(Integer i) { - return i + 2; - } - }); - } - }); + JavaDStream transformed = stream.transform(in -> in.map(i -> i + 2)); JavaTestUtils.attachTestOutputStream(transformed); List> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -448,71 +422,21 @@ public void testVariousTransform() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - stream.transform( - new Function, JavaRDD>() { - @Override - public JavaRDD call(JavaRDD in) { - return null; - } - } - ); + stream.transform(in -> null); - stream.transform( - new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaRDD in, Time time) { - return null; - } - } - ); + stream.transform((in, time) -> null); - stream.transformToPair( - new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in) { - return null; - } - } - ); + stream.transformToPair(in -> null); - stream.transformToPair( - new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in, Time time) { - return null; - } - } - ); + stream.transformToPair((in, time) -> null); - pairStream.transform( - new Function, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in) { - return null; - } - } - ); + pairStream.transform(in -> null); - pairStream.transform( - new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in, Time time) { - return null; - } - } - ); + pairStream.transform((in, time) -> null); - pairStream.transformToPair( - new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in) { - return null; - } - } - ); + pairStream.transformToPair(in -> null); - pairStream.transformToPair( - new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, - Time time) { - return null; - } - } - ); + pairStream.transformToPair((in, time) -> null); } @@ -558,19 +482,7 @@ public void testTransformWith() { JavaPairDStream> joined = pairStream1.transformWithToPair( pairStream2, - new Function3< - JavaPairRDD, - JavaPairRDD, - Time, - JavaPairRDD>>() { - @Override - public JavaPairRDD> call( - JavaPairRDD rdd1, - JavaPairRDD rdd2, - Time time) { - return rdd1.join(rdd2); - } - } + (rdd1, rdd2, time) -> rdd1.join(rdd2) ); JavaTestUtils.attachTestOutputStream(joined); @@ -603,100 +515,21 @@ public void testVariousTransformWith() { JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - stream1.transformWith( - stream2, - new Function3, JavaRDD, Time, JavaRDD>() { - @Override - public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { - return null; - } - } - ); + stream1.transformWith(stream2, (rdd1, rdd2, time) -> null); - stream1.transformWith( 
- pairStream1, - new Function3, JavaPairRDD, Time, JavaRDD>() { - @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, - Time time) { - return null; - } - } - ); + stream1.transformWith(pairStream1, (rdd1, rdd2, time) -> null); - stream1.transformWithToPair( - stream2, - new Function3, JavaRDD, Time, JavaPairRDD>() { - @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, - Time time) { - return null; - } - } - ); + stream1.transformWithToPair(stream2, (rdd1, rdd2, time) -> null); - stream1.transformWithToPair( - pairStream1, - new Function3, JavaPairRDD, Time, - JavaPairRDD>() { - @Override - public JavaPairRDD call(JavaRDD rdd1, - JavaPairRDD rdd2, - Time time) { - return null; - } - } - ); + stream1.transformWithToPair(pairStream1, (rdd1, rdd2, time) -> null); - pairStream1.transformWith( - stream2, - new Function3, JavaRDD, Time, JavaRDD>() { - @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, - Time time) { - return null; - } - } - ); + pairStream1.transformWith(stream2, (rdd1, rdd2, time) -> null); - pairStream1.transformWith( - pairStream1, - new Function3, JavaPairRDD, Time, - JavaRDD>() { - @Override - public JavaRDD call(JavaPairRDD rdd1, - JavaPairRDD rdd2, - Time time) { - return null; - } - } - ); + pairStream1.transformWith(pairStream1, (rdd1, rdd2, time) -> null); - pairStream1.transformWithToPair( - stream2, - new Function3, JavaRDD, Time, - JavaPairRDD>() { - @Override - public JavaPairRDD call(JavaPairRDD rdd1, - JavaRDD rdd2, - Time time) { - return null; - } - } - ); + pairStream1.transformWithToPair(stream2, (rdd1, rdd2, time) -> null); - pairStream1.transformWithToPair( - pairStream2, - new Function3, JavaPairRDD, Time, - JavaPairRDD>() { - @Override - public JavaPairRDD call(JavaPairRDD rdd1, - JavaPairRDD rdd2, - Time time) { - return null; - } - } - ); + pairStream1.transformWithToPair(pairStream2, (rdd1, rdd2, time) -> null); } @SuppressWarnings("unchecked") @@ -727,44 +560,32 @@ public void testStreamingContextTransform(){ JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairStream1input, 1)); - List> listOfDStreams1 = Arrays.>asList(stream1, stream2); + List> listOfDStreams1 = Arrays.asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles ssc.transform( listOfDStreams1, - new Function2>, Time, JavaRDD>() { - @Override - public JavaRDD call(List> listOfRDDs, Time time) { - Assert.assertEquals(2, listOfRDDs.size()); - return null; - } + (listOfRDDs, time) -> { + Assert.assertEquals(2, listOfRDDs.size()); + return null; } ); List> listOfDStreams2 = - Arrays.>asList(stream1, stream2, pairStream1.toJavaDStream()); + Arrays.asList(stream1, stream2, pairStream1.toJavaDStream()); JavaPairDStream> transformed2 = ssc.transformToPair( listOfDStreams2, - new Function2>, Time, JavaPairRDD>>() { - @Override - public JavaPairRDD> call(List> listOfRDDs, - Time time) { - Assert.assertEquals(3, listOfRDDs.size()); - JavaRDD rdd1 = (JavaRDD)listOfRDDs.get(0); - JavaRDD rdd2 = (JavaRDD)listOfRDDs.get(1); - JavaRDD> rdd3 = - (JavaRDD>)listOfRDDs.get(2); - JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); - PairFunction mapToTuple = - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i, i); - } - }; - return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); - } + (listOfRDDs, time) -> { + Assert.assertEquals(3, listOfRDDs.size()); + JavaRDD rdd1 = (JavaRDD)listOfRDDs.get(0); + JavaRDD rdd2 = 
(JavaRDD)listOfRDDs.get(1); + JavaRDD> rdd3 = + (JavaRDD>)listOfRDDs.get(2); + JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); + PairFunction mapToTuple = + (PairFunction) i -> new Tuple2<>(i, i); + return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); } ); JavaTestUtils.attachTestOutputStream(transformed2); @@ -787,12 +608,7 @@ public void testFlatMap() { Arrays.asList("a","t","h","l","e","t","i","c","s")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(x.split("(?!^)")).iterator(); - } - }); + JavaDStream flatMapped = stream.flatMap(x -> Arrays.asList(x.split("(?!^)")).iterator()); JavaTestUtils.attachTestOutputStream(flatMapped); List> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -811,25 +627,13 @@ public void testForeachRDD() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output - stream.foreachRDD(new VoidFunction>() { - @Override - public void call(JavaRDD rdd) { - accumRdd.add(1); - rdd.foreach(new VoidFunction() { - @Override - public void call(Integer i) { - accumEle.add(1); - } - }); - } + stream.foreachRDD(rdd -> { + accumRdd.add(1); + rdd.foreach(i -> accumEle.add(1)); }); // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java - stream.foreachRDD(new VoidFunction2, Time>() { - @Override - public void call(JavaRDD rdd, Time time) { - } - }); + stream.foreachRDD((rdd, time) -> {}); JavaTestUtils.runStreams(ssc, 2, 2); @@ -873,16 +677,12 @@ public void testPairFlatMap() { new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream flatMapped = stream.flatMapToPair( - new PairFlatMapFunction() { - @Override - public Iterator> call(String in) { - List> out = new ArrayList<>(); - for (String letter: in.split("(?!^)")) { - out.add(new Tuple2<>(in.length(), letter)); - } - return out.iterator(); + JavaPairDStream flatMapped = stream.flatMapToPair(in -> { + List> out = new ArrayList<>(); + for (String letter : in.split("(?!^)")) { + out.add(new Tuple2<>(in.length(), letter)); } + return out.iterator(); }); JavaTestUtils.attachTestOutputStream(flatMapped); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -949,21 +749,10 @@ public void testPairFilter() { Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = stream.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String in) { - return new Tuple2<>(in, in.length()); - } - }); + JavaPairDStream pairStream = + stream.mapToPair(in -> new Tuple2<>(in, in.length())); - JavaPairDStream filtered = pairStream.filter( - new Function, Boolean>() { - @Override - public Boolean call(Tuple2 in) { - return in._1().contains("a"); - } - }); + JavaPairDStream filtered = pairStream.filter(in -> in._1().contains("a")); JavaTestUtils.attachTestOutputStream(filtered); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1014,13 +803,7 @@ public void testPairMap() { // Maps pair -> pair of different type JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.mapToPair( - new PairFunction, Integer, String>() { - @Override - public Tuple2 call(Tuple2 in) { - 
return in.swap(); - } - }); + JavaPairDStream reversed = pairStream.mapToPair(Tuple2::swap); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1048,18 +831,14 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.mapPartitionsToPair( - new PairFlatMapFunction>, Integer, String>() { - @Override - public Iterator> call(Iterator> in) { - List> out = new LinkedList<>(); - while (in.hasNext()) { - Tuple2 next = in.next(); - out.add(next.swap()); - } - return out.iterator(); - } - }); + JavaPairDStream reversed = pairStream.mapPartitionsToPair(in -> { + List> out = new LinkedList<>(); + while (in.hasNext()) { + Tuple2 next = in.next(); + out.add(next.swap()); + } + return out.iterator(); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1079,13 +858,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) { - return in._2(); - } - }); + JavaDStream reversed = pairStream.map(in -> in._2()); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1119,17 +892,13 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream flatMapped = pairStream.flatMapToPair( - new PairFlatMapFunction, Integer, String>() { - @Override - public Iterator> call(Tuple2 in) { - List> out = new LinkedList<>(); - for (Character s : in._1().toCharArray()) { - out.add(new Tuple2<>(in._2(), s.toString())); - } - return out.iterator(); - } - }); + JavaPairDStream flatMapped = pairStream.flatMapToPair(in -> { + List> out = new LinkedList<>(); + for (Character s : in._1().toCharArray()) { + out.add(new Tuple2<>(in._2(), s.toString())); + } + return out.iterator(); + }); JavaTestUtils.attachTestOutputStream(flatMapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1216,12 +985,7 @@ public void testCombineByKey() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream combined = pairStream.combineByKey( - new Function() { - @Override - public Integer call(Integer i) { - return i; - } - }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); + i -> i, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); JavaTestUtils.attachTestOutputStream(combined); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1345,20 +1109,16 @@ public void testUpdateStateByKey() { JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream updated = pairStream.updateStateByKey( - new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out += state.get(); - } - for (Integer v : values) { - out += v; - } - return Optional.of(out); - } - }); + JavaPairDStream updated = pairStream.updateStateByKey((values, 
state) -> { + int out = 0; + if (state.isPresent()) { + out += state.get(); + } + for (Integer v : values) { + out += v; + } + return Optional.of(out); + }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1389,20 +1149,16 @@ public void testUpdateStateByKeyWithInitial() { JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream updated = pairStream.updateStateByKey( - new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out += state.get(); - } - for (Integer v : values) { - out += v; - } - return Optional.of(out); - } - }, new HashPartitioner(1), initialRDD); + JavaPairDStream updated = pairStream.updateStateByKey((values, state) -> { + int out = 0; + if (state.isPresent()) { + out += state.get(); + } + for (Integer v : values) { + out += v; + } + return Optional.of(out); + }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1500,13 +1256,7 @@ public void testPairTransform() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream sorted = pairStream.transformToPair( - new Function, JavaPairRDD>() { - @Override - public JavaPairRDD call(JavaPairRDD in) { - return in.sortByKey(); - } - }); + JavaPairDStream sorted = pairStream.transformToPair(in -> in.sortByKey()); JavaTestUtils.attachTestOutputStream(sorted); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1537,18 +1287,7 @@ public void testPairToNormalRDDTransform() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaDStream firstParts = pairStream.transform( - new Function, JavaRDD>() { - @Override - public JavaRDD call(JavaPairRDD in) { - return in.map(new Function, Integer>() { - @Override - public Integer call(Tuple2 in2) { - return in2._1(); - } - }); - } - }); + JavaDStream firstParts = pairStream.transform(in -> in.map(in2 -> in2._1())); JavaTestUtils.attachTestOutputStream(firstParts); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1575,12 +1314,7 @@ public void testMapValues() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream mapped = pairStream.mapValues(new Function() { - @Override - public String call(String s) { - return s.toUpperCase(Locale.ENGLISH); - } - }); + JavaPairDStream mapped = pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1616,16 +1350,12 @@ public void testFlatMapValues() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream flatMapped = pairStream.flatMapValues( - new Function>() { - @Override - public Iterable call(String in) { - List out = new ArrayList<>(); - out.add(in + "1"); - out.add(in + "2"); - return out; - } - }); + JavaPairDStream flatMapped = pairStream.flatMapValues(in -> { + List out = new ArrayList<>(); + out.add(in + "1"); + out.add(in + "2"); + return out; + }); JavaTestUtils.attachTestOutputStream(flatMapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1795,12 +1525,7 @@ public void testCheckpointMasterRecovery() throws InterruptedException { ssc.checkpoint(tempDir.getAbsolutePath()); JavaDStream stream = 
JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) { - return s.length(); - } - }); + JavaDStream letterCount = stream.map(String::length); JavaCheckpointTestUtils.attachTestOutputStream(letterCount); List> initialResult = JavaTestUtils.runStreams(ssc, 1, 1); @@ -1822,7 +1547,7 @@ public Integer call(String s) { public void testContextGetOrCreate() throws InterruptedException { ssc.stop(); - final SparkConf conf = new SparkConf() + SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") .set("newContext", "true"); @@ -1835,13 +1560,10 @@ public void testContextGetOrCreate() throws InterruptedException { // Function to create JavaStreamingContext without any output operations // (used to detect the new context) - final AtomicBoolean newContextCreated = new AtomicBoolean(false); - Function0 creatingFunc = new Function0() { - @Override - public JavaStreamingContext call() { - newContextCreated.set(true); - return new JavaStreamingContext(conf, Seconds.apply(1)); - } + AtomicBoolean newContextCreated = new AtomicBoolean(false); + Function0 creatingFunc = () -> { + newContextCreated.set(true); + return new JavaStreamingContext(conf, Seconds.apply(1)); }; newContextCreated.set(false); @@ -1912,18 +1634,15 @@ public void testSocketString() { ssc.socketStream( "localhost", 12345, - new Function>() { - @Override - public Iterable call(InputStream in) throws IOException { - List out = new ArrayList<>(); - try (BufferedReader reader = new BufferedReader( - new InputStreamReader(in, StandardCharsets.UTF_8))) { - for (String line; (line = reader.readLine()) != null;) { - out.add(line); - } + in -> { + List out = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(in, StandardCharsets.UTF_8))) { + for (String line; (line = reader.readLine()) != null;) { + out.add(line); } - return out; } + return out; }, StorageLevel.MEMORY_ONLY()); } @@ -1952,21 +1671,10 @@ public void testFileStream() throws IOException { LongWritable.class, Text.class, TextInputFormat.class, - new Function() { - @Override - public Boolean call(Path v1) { - return Boolean.TRUE; - } - }, + v1 -> Boolean.TRUE, true); - JavaDStream test = inputStream.map( - new Function, String>() { - @Override - public String call(Tuple2 v1) { - return v1._2().toString(); - } - }); + JavaDStream test = inputStream.map(v1 -> v1._2().toString()); JavaTestUtils.attachTestOutputStream(test); List> result = JavaTestUtils.runStreams(ssc, 1, 1); From 405ec0f0e4e12c3c9832c1052258472839898387 Mon Sep 17 00:00:00 2001 From: windpiger Date: Sun, 19 Feb 2017 16:50:16 -0800 Subject: [PATCH 06/61] [SPARK-19598][SQL] Remove the alias parameter in UnresolvedRelation ## What changes were proposed in this pull request? Remove the alias parameter in `UnresolvedRelation`, and use `SubqueryAlias` to replace it. This can simplify some `match case` situations. For example, the broadcast hint pull request can have one fewer case https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala#L57-L61 ## How was this patch tested? add some unit tests Author: windpiger Closes #16956 from windpiger/removeUnresolveTableAlias. 
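To make the shape of this change concrete, here is a minimal sketch of how an aliased relation is constructed before and after the patch. The import paths and the three-argument `SubqueryAlias` constructor are taken from the diff below; treat this as an illustration, not code from the patch itself:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias

// Before this patch the alias rode along on the unresolved relation itself:
//   UnresolvedRelation(TableIdentifier("t"), Some("x"))

// After this patch UnresolvedRelation only names the table; an alias is
// expressed by wrapping the node in SubqueryAlias, as the parser now does:
val plain   = UnresolvedRelation(TableIdentifier("t"))
val aliased = SubqueryAlias("x", UnresolvedRelation(TableIdentifier("t")), None)
```

With the alias gone from `UnresolvedRelation`, analysis rules can match on the table name directly instead of first materializing an optional alias, which is what saves a case in `ResolveHints`.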
--- .../sql/catalyst/analysis/Analyzer.scala | 10 ++---- .../sql/catalyst/analysis/ResolveHints.scala | 6 ++-- .../sql/catalyst/analysis/unresolved.scala | 5 +-- .../sql/catalyst/catalog/SessionCatalog.scala | 12 +++---- .../spark/sql/catalyst/dsl/package.scala | 10 ++---- .../sql/catalyst/parser/AstBuilder.scala | 16 ++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 32 ++++++++++++------- .../catalog/SessionCatalogSuite.scala | 22 ------------- .../sql/execution/datasources/rules.scala | 3 +- .../org/apache/spark/sql/JoinSuite.scala | 4 +-- .../benchmark/TPCDSQueryBenchmark.scala | 4 +-- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 12 files changed, 51 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cd517a98aca1c..39a276284c35e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -180,12 +180,8 @@ class Analyzer( def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { plan transformDown { case u : UnresolvedRelation => - val substituted = cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) - .map(_._2).map { relation => - val withAlias = u.alias.map(SubqueryAlias(_, relation, None)) - withAlias.getOrElse(relation) - } - substituted.getOrElse(u) + cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) + .map(_._2).getOrElse(u) case other => // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. other transformExpressions { @@ -623,7 +619,7 @@ class Analyzer( val tableIdentWithDb = u.tableIdentifier.copy( database = u.tableIdentifier.database.orElse(defaultDatabase)) try { - catalog.lookupRelation(tableIdentWithDb, u.alias) + catalog.lookupRelation(tableIdentWithDb) } catch { case _: NoSuchTableException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 2124177461b3b..70438eb5912b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -54,10 +54,8 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case r: UnresolvedRelation => - val alias = r.alias.getOrElse(r.tableIdentifier.table) - if (toBroadcast.exists(resolver(_, alias))) BroadcastHint(plan) else plan - + case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => + BroadcastHint(plan) case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => BroadcastHint(plan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 36ed9ba50372b..262b894e2a0a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -37,10 +37,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str /** * Holds the name of a 
relation that has yet to be looked up in a catalog. */ -case class UnresolvedRelation( - tableIdentifier: TableIdentifier, - alias: Option[String] = None) extends LeafNode { - +case class UnresolvedRelation(tableIdentifier: TableIdentifier) extends LeafNode { /** Returns a `.` separated name for this relation. */ def tableName: String = tableIdentifier.unquotedString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index dd0c5cb7066f5..73ef0e6a1869e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -572,16 +572,14 @@ class SessionCatalog( * wrap the logical plan in a [[SubqueryAlias]] which will track the name of the view. * * @param name The name of the table/view that we look up. - * @param alias The alias name of the table/view that we look up. */ - def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = { + def lookupRelation(name: TableIdentifier): LogicalPlan = { synchronized { val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) - val relationAlias = alias.getOrElse(table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(relationAlias, viewDef, None) + SubqueryAlias(table, viewDef, None) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) @@ -594,12 +592,12 @@ class SessionCatalog( desc = metadata, output = metadata.schema.toAttributes, child = parser.parsePlan(viewText)) - SubqueryAlias(relationAlias, child, Some(name.copy(table = table, database = Some(db)))) + SubqueryAlias(table, child, Some(name.copy(table = table, database = Some(db)))) } else { - SubqueryAlias(relationAlias, SimpleCatalogRelation(metadata), None) + SubqueryAlias(table, SimpleCatalogRelation(metadata), None) } } else { - SubqueryAlias(relationAlias, tempTables(table), None) + SubqueryAlias(table, tempTables(table), None) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 66e52ca68af19..3c531323397e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -280,11 +280,10 @@ package object dsl { object expressions extends ExpressionConversions // scalastyle:ignore object plans { // scalastyle:ignore - def table(ref: String): LogicalPlan = - UnresolvedRelation(TableIdentifier(ref), None) + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) def table(db: String, ref: String): LogicalPlan = - UnresolvedRelation(TableIdentifier(ref, Option(db)), None) + UnresolvedRelation(TableIdentifier(ref, Option(db))) implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { def select(exprs: Expression*): LogicalPlan = { @@ -369,10 +368,7 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) - def as(alias: String): LogicalPlan = logicalPlan match { - case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) - case plan => 
SubqueryAlias(alias, plan, None) - } + def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan, None) def repartition(num: Integer): LogicalPlan = Repartition(num, shuffle = true, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bbb9922c187de..08a6dd136b857 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -179,7 +179,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } InsertIntoTable( - UnresolvedRelation(tableIdent, None), + UnresolvedRelation(tableIdent), partitionKeys, query, ctx.OVERWRITE != null, @@ -645,17 +645,21 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * }}} */ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { - UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier)) } /** * Create an aliased table reference. This is typically used in FROM clauses. */ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { - val table = UnresolvedRelation( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.strictIdentifier).map(_.getText)) - table.optionalMap(ctx.sample)(withSample) + val table = UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier)) + + val tableWithAlias = Option(ctx.strictIdentifier).map(_.getText) match { + case Some(strictIdentifier) => + SubqueryAlias(strictIdentifier, table, None) + case _ => table + } + tableWithAlias.optionalMap(ctx.sample)(withSample) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 81a97dc1ff3f2..786e0f49b4b25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -61,23 +61,23 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis( Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(TableIdentifier("TaBlE"), Some("TbL"))), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), Project(testRelation.output, testRelation)) assertAnalysisError( - Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation( - TableIdentifier("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("tBl.a")), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), Seq("cannot resolve")) checkAnalysis( - Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation( - TableIdentifier("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("TbL.a")), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), Project(testRelation.output, testRelation), caseSensitive = false) checkAnalysis( - Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation( - TableIdentifier("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("tBl.a")), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), Project(testRelation.output, testRelation), caseSensitive = false) } @@ -166,12 +166,12 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { } test("resolve relations") { - 
assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq()) - checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation) + assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe")), Seq()) + checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE")), testRelation) checkAnalysis( - UnresolvedRelation(TableIdentifier("tAbLe"), None), testRelation, caseSensitive = false) + UnresolvedRelation(TableIdentifier("tAbLe")), testRelation, caseSensitive = false) checkAnalysis( - UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation, caseSensitive = false) + UnresolvedRelation(TableIdentifier("TaBlE")), testRelation, caseSensitive = false) } test("divide should be casted into fractional types") { @@ -429,4 +429,14 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { assertAnalysisSuccess(r1) assertAnalysisSuccess(r2) } + + test("resolve as with an already existed alias") { + checkAnalysis( + Project(Seq(UnresolvedAttribute("tbl2.a")), + SubqueryAlias("tbl", testRelation, None).as("tbl2")), + Project(testRelation.output, testRelation), + caseSensitive = false) + + checkAnalysis(SubqueryAlias("tbl", testRelation, None).as("tbl2"), testRelation) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index db73f03c8bb73..44434324d3770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -444,28 +444,6 @@ class SessionCatalogSuite extends PlanTest { == SubqueryAlias("tbl1", SimpleCatalogRelation(metastoreTable1), None)) } - test("lookup table relation with alias") { - val catalog = new SessionCatalog(newBasicCatalog()) - val alias = "monster" - val tableMetadata = catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) - val relation = SubqueryAlias("tbl1", SimpleCatalogRelation(tableMetadata), None) - val relationWithAlias = - SubqueryAlias(alias, - SimpleCatalogRelation(tableMetadata), None) - assert(catalog.lookupRelation( - TableIdentifier("tbl1", Some("db2")), alias = None) == relation) - assert(catalog.lookupRelation( - TableIdentifier("tbl1", Some("db2")), alias = Some(alias)) == relationWithAlias) - } - - test("lookup view with view name in alias") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tmpView = Range(1, 10, 2, 10) - catalog.createTempView("vw1", tmpView, overrideIfExists = false) - val plan = catalog.lookupRelation(TableIdentifier("vw1"), Option("range")) - assert(plan == SubqueryAlias("range", tmpView, None)) - } - test("look up view relation") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1c3e7c6d52239..e7a59d4ad4dd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -52,8 +52,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { throw new AnalysisException("Unsupported data source type for direct query on files: " + s"${u.tableIdentifier.database.get}") } - val plan = LogicalRelation(dataSource.resolveRelation()) 
- u.alias.map(a => SubqueryAlias(a, plan, None)).getOrElse(plan) + LogicalRelation(dataSource.resolveRelation()) } catch { case _: ClassNotFoundException => u case e: Exception => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index f780fc0ec013c..2e006735d123e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -364,8 +364,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { upperCaseData.where('N <= 4).createOrReplaceTempView("`left`") upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") - val left = UnresolvedRelation(TableIdentifier("left"), None) - val right = UnresolvedRelation(TableIdentifier("right"), None) + val left = UnresolvedRelation(TableIdentifier("left")) + val right = UnresolvedRelation(TableIdentifier("right")) checkAnswer( left.join(right, $"left.N" === $"right.N", "full"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 3988d9750b585..239822b72034a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -73,13 +73,13 @@ object TPCDSQueryBenchmark { // per-row processing time for those cases. val queryRelations = scala.collection.mutable.HashSet[String]() spark.sql(queryString).queryExecution.logical.map { - case ur @ UnresolvedRelation(t: TableIdentifier, _) => + case ur @ UnresolvedRelation(t: TableIdentifier) => queryRelations.add(t.table) case lp: LogicalPlan => lp.expressions.foreach { _ foreach { case subquery: SubqueryExpression => subquery.plan.foreach { - case ur @ UnresolvedRelation(t: TableIdentifier, _) => + case ur @ UnresolvedRelation(t: TableIdentifier) => queryRelations.add(t.table) case _ => } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 3267c237c865a..fd139119472db 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -483,7 +483,7 @@ private[hive] class TestHiveQueryExecution( // Make sure any test tables referenced are loaded. val referencedTables = describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } + logical.collect { case UnresolvedRelation(tableIdent) => tableIdent.table } val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) From 88ca174a23a3b114c11e244951db3de82c8d69e3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 19 Feb 2017 18:13:12 -0800 Subject: [PATCH 07/61] [SPARK-19563][SQL] avoid unnecessary sort in FileFormatWriter ## What changes were proposed in this pull request? In `FileFormatWriter`, we will sort the input rows by partition columns and bucket id and sort columns, if we want to write data out partitioned or bucketed. However, if the data is already sorted, we will sort it again, which is unnecssary. This PR removes the sorting logic in `FileFormatWriter` and use `SortExec` instead. 
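Condensed, the new code path in `FileFormatWriter` compares the ordering the writer needs (partition columns, then the bucket id expression, then the sort columns) with the ordering the physical plan already produces, and only plans a `SortExec` when they do not line up. The following is a simplified excerpt of that logic as it appears in the diff below, with names kept as in the patch (a fragment of the surrounding method, not a standalone runnable snippet):

```scala
// Required ordering: partition columns, then bucket id expression, then sort columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
// Ordering already provided by the physical plan; the sort direction does not matter here.
val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)

val orderingMatched =
  requiredOrdering.length <= actualOrdering.length &&
    requiredOrdering.zip(actualOrdering).forall {
      case (required, actual) => required.semanticEquals(actual)
    }

// Only pay for a sort when the data is not already laid out the way the writer needs it.
val rdd = if (orderingMatched) {
  queryExecution.toRdd
} else {
  SortExec(
    requiredOrdering.map(SortOrder(_, Ascending)),
    global = false,
    child = queryExecution.executedPlan).execute()
}
```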
We will not add `SortExec` if the data is already sorted. ## How was this patch tested? I did a micro benchmark manually ``` val df = spark.range(10000000).select($"id", $"id" % 10 as "part").sort("part") spark.time(df.write.partitionBy("part").parquet("/tmp/test")) ``` The result was about 6.4 seconds before this PR, and is 5.7 seconds afterwards. close https://github.com/apache/spark/pull/16724 Author: Wenchen Fan Closes #16898 from cloud-fan/writer. --- .../datasources/FileFormatWriter.scala | 189 +++++++++--------- 1 file changed, 90 insertions(+), 99 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index be13cbc51a9d3..644358493e2eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter} -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A helper object for writing FileFormat data out to a location. */ @@ -64,9 +63,9 @@ object FileFormatWriter extends Logging { val serializableHadoopConf: SerializableConfiguration, val outputWriterFactory: OutputWriterFactory, val allColumns: Seq[Attribute], - val partitionColumns: Seq[Attribute], val dataColumns: Seq[Attribute], - val bucketSpec: Option[BucketSpec], + val partitionColumns: Seq[Attribute], + val bucketIdExpression: Option[Expression], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], val maxRecordsPerFile: Long) @@ -108,9 +107,21 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) + val allColumns = queryExecution.logical.output val partitionSet = AttributeSet(partitionColumns) val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains) + val bucketIdExpression = bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } + val sortColumns = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + } + // Note: prepareWrite has side effect. It sets "job". 
val outputWriterFactory = fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType) @@ -119,23 +130,45 @@ object FileFormatWriter extends Logging { uuid = UUID.randomUUID().toString, serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), outputWriterFactory = outputWriterFactory, - allColumns = queryExecution.logical.output, - partitionColumns = partitionColumns, + allColumns = allColumns, dataColumns = dataColumns, - bucketSpec = bucketSpec, + partitionColumns = partitionColumns, + bucketIdExpression = bucketIdExpression, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong) .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile) ) + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns + // the sort order doesn't matter + val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) + val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { + false + } else { + requiredOrdering.zip(actualOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. committer.setupJob(job) try { - val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd, + val rdd = if (orderingMatched) { + queryExecution.toRdd + } else { + SortExec( + requiredOrdering.map(SortOrder(_, Ascending)), + global = false, + child = queryExecution.executedPlan).execute() + } + + val ret = sparkSession.sparkContext.runJob(rdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -189,7 +222,7 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) val writeTask = - if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { + if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { new SingleDirectoryWriteTask(description, taskAttemptContext, committer) } else { new DynamicPartitionWriteTask(description, taskAttemptContext, committer) @@ -287,31 +320,16 @@ object FileFormatWriter extends Logging { * multiple directories (partitions) or files (bucketing). 
*/ private class DynamicPartitionWriteTask( - description: WriteJobDescription, + desc: WriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends ExecuteWriteTask { // currentWriter is initialized whenever we see a new key private var currentWriter: OutputWriter = _ - private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap { - spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get) - } - - private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get) - } - - private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec => - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression - } - - /** Expressions that given a partition key build a string like: col1=val/col2=val/... */ - private def partitionStringExpression: Seq[Expression] = { - description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ + private def partitionPathExpression: Seq[Expression] = { + desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => // TODO: use correct timezone for partition values. val escaped = ScalaUDF( ExternalCatalogUtils.escapePathName _, @@ -325,35 +343,46 @@ object FileFormatWriter extends Logging { } /** - * Open and returns a new OutputWriter given a partition key and optional bucket id. + * Opens a new OutputWriter given a partition key and optional bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet * - * @param key vaues for fields consisting of partition keys for the current row - * @param partString a function that projects the partition values into a string + * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the + * current row. + * @param getPartitionPath a function that projects the partition values into a path string. * @param fileCounter the number of files that have been written in the past for this specific * partition. This is used to limit the max number of records written for a * single file. The value should start from 0. + * @param updatedPartitions the set of updated partition paths, we should add the new partition + * path of this writer to it. 
*/ private def newOutputWriter( - key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = { - val partDir = - if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0)) + partColsAndBucketId: InternalRow, + getPartitionPath: UnsafeProjection, + fileCounter: Int, + updatedPartitions: mutable.Set[String]): Unit = { + val partDir = if (desc.partitionColumns.isEmpty) { + None + } else { + Option(getPartitionPath(partColsAndBucketId).getString(0)) + } + partDir.foreach(updatedPartitions.add) - // If the bucket spec is defined, the bucket column is right after the partition columns - val bucketId = if (description.bucketSpec.isDefined) { - BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length)) + // If the bucketId expression is defined, the bucketId column is right after the partition + // columns. + val bucketId = if (desc.bucketIdExpression.isDefined) { + BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length)) } else { "" } // This must be in a form that matches our bucketing format. See BucketingUtils. val ext = f"$bucketId.c$fileCounter%03d" + - description.outputWriterFactory.getFileExtension(taskAttemptContext) + desc.outputWriterFactory.getFileExtension(taskAttemptContext) val customPath = partDir match { case Some(dir) => - description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) case _ => None } @@ -363,80 +392,42 @@ object FileFormatWriter extends Logging { committer.newTaskTempFile(taskAttemptContext, partDir, ext) } - currentWriter = description.outputWriterFactory.newInstance( + currentWriter = desc.outputWriterFactory.newInstance( path = path, - dataSchema = description.dataColumns.toStructType, + dataSchema = desc.dataColumns.toStructType, context = taskAttemptContext) } override def execute(iter: Iterator[InternalRow]): Set[String] = { - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val sortingExpressions: Seq[Expression] = - description.partitionColumns ++ bucketIdExpression ++ sortColumns - val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns) - - val sortingKeySchema = StructType(sortingExpressions.map { - case a: Attribute => StructField(a.name, a.dataType, a.nullable) - // The sorting expressions are all `Attribute` except bucket id. - case _ => StructField("bucketId", IntegerType, nullable = false) - }) - - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create( - description.dataColumns, description.allColumns) - - // Returns the partition path given a partition key. - val getPartitionStringFunc = UnsafeProjection.create( - Seq(Concat(partitionStringExpression)), description.partitionColumns) - - // Sorts the data before write, so that we only need one writer at the same time. 
- val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(description.dataColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) - - while (iter.hasNext) { - val currentRow = iter.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } + val getPartitionColsAndBucketId = UnsafeProjection.create( + desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns) - val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { - identity - } else { - UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { - case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) - }) - } + // Generates the partition path given the row generated by `getPartitionColsAndBucketId`. + val getPartPath = UnsafeProjection.create( + Seq(Concat(partitionPathExpression)), desc.partitionColumns) - val sortedIterator = sorter.sortedIterator() + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns) // If anything below fails, we should abort the task. var recordsInFile: Long = 0L var fileCounter = 0 - var currentKey: UnsafeRow = null + var currentPartColsAndBucketId: UnsafeRow = null val updatedPartitions = mutable.Set[String]() - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - // See a new key - write to a new partition (new file). - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") + for (row <- iter) { + val nextPartColsAndBucketId = getPartitionColsAndBucketId(row) + if (currentPartColsAndBucketId != nextPartColsAndBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + currentPartColsAndBucketId = nextPartColsAndBucketId.copy() + logDebug(s"Writing partition: $currentPartColsAndBucketId") recordsInFile = 0 fileCounter = 0 releaseResources() - newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) - val partitionPath = getPartitionStringFunc(currentKey).getString(0) - if (partitionPath.nonEmpty) { - updatedPartitions.add(partitionPath) - } - } else if (description.maxRecordsPerFile > 0 && - recordsInFile >= description.maxRecordsPerFile) { + newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) + } else if (desc.maxRecordsPerFile > 0 && + recordsInFile >= desc.maxRecordsPerFile) { // Exceeded the threshold in terms of the number of records per file. // Create a new file by increasing the file counter. 
recordsInFile = 0 @@ -445,10 +436,10 @@ object FileFormatWriter extends Logging { s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") releaseResources() - newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) + newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) } - currentWriter.write(sortedIterator.getValue) + currentWriter.write(getOutputRow(row)) recordsInFile += 1 } releaseResources() From fecfd217a79e9e518efecb6af868aa90100ed50d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 20 Feb 2017 09:02:09 -0800 Subject: [PATCH 08/61] [SPARK-19646][CORE][STREAMING] binaryRecords replicates records in scala API ## What changes were proposed in this pull request? Use `BytesWritable.copyBytes`, not `getBytes`, because `getBytes` returns the underlying array, which may be reused when repeated reads don't need a different size, as is the case with binaryRecords APIs ## How was this patch tested? Existing tests Author: Sean Owen Closes #16974 from srowen/SPARK-19646. --- .../scala/org/apache/spark/SparkContext.scala | 5 +- .../scala/org/apache/spark/FileSuite.scala | 178 ++++-------------- .../spark/streaming/StreamingContext.scala | 5 +- .../spark/streaming/InputStreamsSuite.scala | 21 ++- 4 files changed, 53 insertions(+), 156 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e4d83893e740e..17194b9f06d35 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -961,12 +961,11 @@ class SparkContext(config: SparkConf) extends Logging { classOf[LongWritable], classOf[BytesWritable], conf = conf) - val data = br.map { case (k, v) => - val bytes = v.getBytes + br.map { case (k, v) => + val bytes = v.copyBytes() assert(bytes.length == recordLength, "Byte array does not have correct length") bytes } - data } /** diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 6538507d407e0..a2d3177c5c711 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io._ +import java.nio.ByteBuffer import java.util.zip.GZIPOutputStream import scala.io.Source @@ -30,7 +31,6 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel @@ -237,24 +237,26 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } - test("binary file input as byte array") { - sc = new SparkContext("local", "test") + private def writeBinaryData(testOutput: Array[Byte], testOutputCopies: Int): File = { val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file - val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) + val file = new 
FileOutputStream(outFile) val channel = file.getChannel - channel.write(bbuf) + for (i <- 0 until testOutputCopies) { + // Shift values by i so that they're different in the output + val alteredOutput = testOutput.map(b => (b + i).toByte) + channel.write(ByteBuffer.wrap(alteredOutput)) + } channel.close() file.close() + outFile + } - val inRdd = sc.binaryFiles(outFileName) - val (infile: String, indata: PortableDataStream) = inRdd.collect.head - + test("binary file input as byte array") { + sc = new SparkContext("local", "test") + val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath) + val (infile, indata) = inRdd.collect().head // Make sure the name and array match assert(infile.contains(outFile.toURI.getPath)) // a prefix may get added assert(indata.toArray === testOutput) @@ -262,159 +264,55 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { test("portabledatastream caching tests") { sc = new SparkContext("local", "test") - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName).cache() - inRdd.foreach{ - curData: (String, PortableDataStream) => - curData._2.toArray() // force the file to read - } - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head - + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath).cache() + inRdd.foreach(_._2.toArray()) // force the file to read // Try reading the output back as an object file - - assert(indata.toArray === testOutput) + assert(inRdd.values.collect().head.toArray === testOutput) } test("portabledatastream persist disk storage") { sc = new SparkContext("local", "test") - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName).persist(StorageLevel.DISK_ONLY) - inRdd.foreach{ - curData: (String, PortableDataStream) => - curData._2.toArray() // force the file to read - } - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head - - // Try reading the output back as an object file - - assert(indata.toArray === testOutput) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath).persist(StorageLevel.DISK_ONLY) + inRdd.foreach(_._2.toArray()) // force the file to read + assert(inRdd.values.collect().head.toArray === testOutput) } test("portabledatastream flatmap tests") { sc = new SparkContext("local", "test") - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = 
Array[Byte](1, 2, 3, 4, 5, 6) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath) val numOfCopies = 3 - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName) - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val copyRdd = mappedRdd.flatMap { - curData: (String, PortableDataStream) => - for (i <- 1 to numOfCopies) yield (i, curData._2) - } - - val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect() - - // Try reading the output back as an object file + val copyRdd = inRdd.flatMap(curData => (0 until numOfCopies).map(_ => curData._2)) + val copyArr = copyRdd.collect() assert(copyArr.length == numOfCopies) - copyArr.foreach{ - cEntry: (Int, PortableDataStream) => - assert(cEntry._2.toArray === testOutput) + for (i <- copyArr.indices) { + assert(copyArr(i).toArray === testOutput) } - } test("fixed record length binary file as byte array") { - // a fixed length of 6 bytes - sc = new SparkContext("local", "test") - - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) val testOutputCopies = 10 - - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - for(i <- 1 to testOutputCopies) { - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - channel.write(bbuf) - } - channel.close() - file.close() - - val inRdd = sc.binaryRecords(outFileName, testOutput.length) - // make sure there are enough elements + val outFile = writeBinaryData(testOutput, testOutputCopies) + val inRdd = sc.binaryRecords(outFile.getAbsolutePath, testOutput.length) assert(inRdd.count == testOutputCopies) - - // now just compare the first one - val indata: Array[Byte] = inRdd.collect.head - assert(indata === testOutput) + val inArr = inRdd.collect() + for (i <- inArr.indices) { + assert(inArr(i) === testOutput.map(b => (b + i).toByte)) + } } test ("negative binary record length should raise an exception") { - // a fixed length of 6 bytes sc = new SparkContext("local", "test") - - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file - val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val testOutputCopies = 10 - - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - for(i <- 1 to testOutputCopies) { - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - channel.write(bbuf) - } - channel.close() - file.close() - - val inRdd = sc.binaryRecords(outFileName, -1) - + val outFile = writeBinaryData(Array[Byte](1, 2, 3, 4, 5, 6), 1) intercept[SparkException] { - inRdd.count + sc.binaryRecords(outFile.getAbsolutePath, -1).count() } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 0a4c141e5be38..a34f6c73fea86 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -435,13 +435,12 @@ class StreamingContext private[streaming] ( conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, 
recordLength) val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) - val data = br.map { case (k, v) => - val bytes = v.getBytes + br.map { case (k, v) => + val bytes = v.copyBytes() require(bytes.length == recordLength, "Byte array does not have correct length. " + s"${bytes.length} did not equal recordLength: $recordLength") bytes } - data } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 6fb50a4052712..b5d36a36513ab 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -84,7 +84,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected // (whether the elements were received one in each interval is not verified) - val output: Array[String] = outputQueue.asScala.flatMap(x => x).toArray + val output = outputQueue.asScala.flatten.toArray assert(output.length === expectedOutput.size) for (i <- output.indices) { assert(output(i) === expectedOutput(i)) @@ -155,14 +155,15 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // not enough to trigger a batch clock.advance(batchDuration.milliseconds / 2) - val input = Seq(1, 2, 3, 4, 5) - input.foreach { i => + val numCopies = 3 + val input = Array[Byte](1, 2, 3, 4, 5) + for (i <- 0 until numCopies) { Thread.sleep(batchDuration.milliseconds) val file = new File(testDir, i.toString) - Files.write(Array[Byte](i.toByte), file) + Files.write(input.map(b => (b + i).toByte), file) assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) - logInfo("Created file " + file) + logInfo(s"Created file $file") // Advance the clock after creating the file to avoid a race when // setting its modification time clock.advance(batchDuration.milliseconds) @@ -170,10 +171,10 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(batchCounter.getNumCompletedBatches === i) } } - - val expectedOutput = input.map(i => i.toByte) - val obtainedOutput = outputQueue.asScala.flatten.toList.map(i => i(0).toByte) - assert(obtainedOutput.toSeq === expectedOutput) + val obtainedOutput = outputQueue.asScala.map(_.flatten).toSeq + for (i <- obtainedOutput.indices) { + assert(obtainedOutput(i) === input.map(b => (b + i).toByte)) + } } } finally { if (testDir != null) Utils.deleteRecursively(testDir) @@ -258,7 +259,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) MultiThreadTestReceiver.haveAllThreadsFinished = false val outputQueue = new ConcurrentLinkedQueue[Seq[Long]] - def output: Iterable[Long] = outputQueue.asScala.flatMap(x => x) + def output: Iterable[Long] = outputQueue.asScala.flatten // set up the network stream using the test receiver withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => From 121f2bf63eb1f99199e11e43bb55dc0452f9b791 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 20 Feb 2017 09:04:22 -0800 Subject: [PATCH 09/61] [SPARK-15453][SQL][FOLLOW-UP] FileSourceScanExec to extract `outputOrdering` information ### What changes were proposed in this pull request? 
`outputOrdering` is also dependent on whether the bucket has more than one files. The test cases fail when we try to move them to sql/core. This PR is to fix the test cases introduced in https://github.com/apache/spark/pull/14864 and add a test case to verify [the related logics](https://github.com/tejasapatil/spark/blob/070c24994747c0479fb2520774ede27ff1cf8cac/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L197-L206). ### How was this patch tested? N/A Author: Xiao Li Closes #16994 from gatorsmile/bucketingTS. --- .../spark/sql/sources/BucketedReadSuite.scala | 229 +++++++++++------- 1 file changed, 137 insertions(+), 92 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index d9ddcbd57ca83..4fc72b9e47597 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -227,6 +227,12 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") + case class BucketedTableTestSpec( + bucketSpec: Option[BucketSpec], + numPartitions: Int = 10, + expectedShuffle: Boolean = true, + expectedSort: Boolean = true) + /** * A helper method to test the bucket read functionality using join. It will save `df1` and `df2` * to hive tables, bucketed or not, according to the given bucket specifics. Next we will join @@ -234,14 +240,15 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet * exists as user expected according to the `shuffleLeft` and `shuffleRight`. 
*/ private def testBucketing( - bucketSpecLeft: Option[BucketSpec], - bucketSpecRight: Option[BucketSpec], + bucketedTableTestSpecLeft: BucketedTableTestSpec, + bucketedTableTestSpecRight: BucketedTableTestSpec, joinType: String = "inner", - joinCondition: (DataFrame, DataFrame) => Column, - shuffleLeft: Boolean, - shuffleRight: Boolean, - sortLeft: Boolean = true, - sortRight: Boolean = true): Unit = { + joinCondition: (DataFrame, DataFrame) => Column): Unit = { + val BucketedTableTestSpec(bucketSpecLeft, numPartitionsLeft, shuffleLeft, sortLeft) = + bucketedTableTestSpecLeft + val BucketedTableTestSpec(bucketSpecRight, numPartitionsRight, shuffleRight, sortRight) = + bucketedTableTestSpecRight + withTable("bucketed_table1", "bucketed_table2") { def withBucket( writer: DataFrameWriter[Row], @@ -263,8 +270,10 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet }.getOrElse(writer) } - withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1") - withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2") + withBucket(df1.repartition(numPartitionsLeft).write.format("parquet"), bucketSpecLeft) + .saveAsTable("bucketed_table1") + withBucket(df2.repartition(numPartitionsRight).write.format("parquet"), bucketSpecRight) + .saveAsTable("bucketed_table2") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { @@ -291,10 +300,10 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet // check existence of sort assert( joinOperator.left.find(_.isInstanceOf[SortExec]).isDefined == sortLeft, - s"expected sort in plan to be $shuffleLeft but found\n${joinOperator.left}") + s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}") assert( joinOperator.right.find(_.isInstanceOf[SortExec]).isDefined == sortRight, - s"expected sort in plan to be $shuffleRight but found\n${joinOperator.right}") + s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}") } } } @@ -305,138 +314,174 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet test("avoid shuffle when join 2 bucketed tables") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704 ignore("avoid shuffle when join keys are a super-set of bucket keys") { val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", 
"j")) ) } test("only shuffle one side when join bucketed table and non-bucketed table") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(None, expectedShuffle = true) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = None, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } test("only shuffle one side when 2 bucketed tables have different bucket number") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil)) - val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil)) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketSpecRight = Some(BucketSpec(5, Seq("i", "j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpecLeft, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpecRight, expectedShuffle = true) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } test("only shuffle one side when 2 bucketed tables have different bucket keys") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil)) - val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil)) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i"), Nil)) + val bucketSpecRight = Some(BucketSpec(8, Seq("j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpecLeft, expectedShuffle = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpecRight, expectedShuffle = true) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i")), - shuffleLeft = false, - shuffleRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i")) ) } test("shuffle when join keys are not equal to bucket keys") { val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = true) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = true) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("j")), - shuffleLeft = true, - shuffleRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("j")) ) } test("shuffle when join 2 bucketed tables with bucketing disabled") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = true) + val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = true) withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = true, - shuffleRight = true + bucketedTableTestSpecLeft = 
bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } } - test("avoid shuffle and sort when bucket and sort columns are join keys") { + test("check sort and shuffle when bucket and sort columns are join keys") { + // In case of bucketing, its possible to have multiple files belonging to the + // same bucket in a given relation. Each of these files are locally sorted + // but those files combined together are not globally sorted. Given that, + // the RDD partition will not be sorted even if the relation has sort columns set + // Therefore, we still need to keep the Sort in both sides. val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) + + val bucketedTableTestSpecLeft1 = BucketedTableTestSpec( + bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true) + val bucketedTableTestSpecRight1 = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = false + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft1, + bucketedTableTestSpecRight = bucketedTableTestSpecRight1, + joinCondition = joinCondition(Seq("i", "j")) + ) + + val bucketedTableTestSpecLeft2 = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight2 = BucketedTableTestSpec( + bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft2, + bucketedTableTestSpecRight = bucketedTableTestSpecRight2, + joinCondition = joinCondition(Seq("i", "j")) + ) + + val bucketedTableTestSpecLeft3 = BucketedTableTestSpec( + bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true) + val bucketedTableTestSpecRight3 = BucketedTableTestSpec( + bucketSpec, numPartitions = 50, expectedShuffle = false, expectedSort = true) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft3, + bucketedTableTestSpecRight = bucketedTableTestSpecRight3, + joinCondition = joinCondition(Seq("i", "j")) + ) + + val bucketedTableTestSpecLeft4 = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight4 = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft4, + bucketedTableTestSpecRight = bucketedTableTestSpecRight4, + joinCondition = joinCondition(Seq("i", "j")) ) } test("avoid shuffle and sort when sort columns are a super set of join keys") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Seq("i", "j"))) - val bucketSpec2 = Some(BucketSpec(8, Seq("i"), Seq("i", "k"))) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i"), Seq("i", "j"))) + val bucketSpecRight = Some(BucketSpec(8, Seq("i"), Seq("i", "k"))) + val bucketedTableTestSpecLeft = BucketedTableTestSpec( + bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec( + bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = false) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i")), - 
shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = false + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i")) ) } test("only sort one side when sort columns are different") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) - val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("k"))) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) + val bucketSpecRight = Some(BucketSpec(8, Seq("i", "j"), Seq("k"))) + val bucketedTableTestSpecLeft = BucketedTableTestSpec( + bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec( + bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = true) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } test("only sort one side when sort columns are same but their ordering is different") { - val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) - val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i"))) + val bucketSpecLeft = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) + val bucketSpecRight = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i"))) + val bucketedTableTestSpecLeft = BucketedTableTestSpec( + bucketSpecLeft, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec( + bucketSpecRight, numPartitions = 1, expectedShuffle = false, expectedSort = true) testBucketing( - bucketSpecLeft = bucketSpec1, - bucketSpecRight = bucketSpec2, - joinCondition = joinCondition(Seq("i", "j")), - shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = true + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, + joinCondition = joinCondition(Seq("i", "j")) ) } @@ -470,20 +515,20 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet test("SPARK-17698 Join predicates should not contain filter clauses") { val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i"))) + val bucketedTableTestSpecLeft = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) + val bucketedTableTestSpecRight = BucketedTableTestSpec( + bucketSpec, numPartitions = 1, expectedShuffle = false, expectedSort = false) testBucketing( - bucketSpecLeft = bucketSpec, - bucketSpecRight = bucketSpec, + bucketedTableTestSpecLeft = bucketedTableTestSpecLeft, + bucketedTableTestSpecRight = bucketedTableTestSpecRight, joinType = "fullouter", joinCondition = (left: DataFrame, right: DataFrame) => { val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _) val filterLeft = left("i") === Literal("1") val filterRight = right("i") === Literal("1") joinPredicates && filterLeft && filterRight - }, - shuffleLeft = false, - shuffleRight = false, - sortLeft = false, - sortRight = false + } ) } From 2e95b165a1a8f15141e6b47004e9920c5016a21e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 20 Feb 2017 12:21:07 -0800 Subject: [PATCH 10/61] [SPARK-19669][SQL] 
Open up visibility for sharedState, sessionState, and a few other functions ## What changes were proposed in this pull request? To ease debugging, most of Spark SQL internals have public level visibility. Two of the most important internal states, sharedState and sessionState, however, are package private. It would make more sense to open these up as well with clear documentation that they are internal. In addition, users currently have way to set active/default SparkSession, but no way to actually get them back. We should open those up as well. ## How was this patch tested? N/A - only visibility change. Author: Reynold Xin Closes #17002 from rxin/SPARK-19669. --- .../org/apache/spark/internal/Logging.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 29 +++++++++++++++---- .../spark/sql/internal/SharedState.scala | 4 +-- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala index 013cd1c1bc037..c7f2847731fcb 100644 --- a/core/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.Utils * logging messages at different levels using methods that only evaluate parameters lazily if the * log level is enabled. */ -private[spark] trait Logging { +trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1975a56cafe85..72af55c1fa147 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -95,18 +95,28 @@ class SparkSession private( /** * State shared across sessions, including the `SparkContext`, cached data, listener, * and a catalog that interacts with external systems. + * + * This is internal to Spark and there is no guarantee on interface stability. + * + * @since 2.2.0 */ + @InterfaceStability.Unstable @transient - private[sql] lazy val sharedState: SharedState = { + lazy val sharedState: SharedState = { existingSharedState.getOrElse(new SharedState(sparkContext)) } /** * State isolated across sessions, including SQL configurations, temporary tables, registered * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. + * + * This is internal to Spark and there is no guarantee on interface stability. + * + * @since 2.2.0 */ + @InterfaceStability.Unstable @transient - private[sql] lazy val sessionState: SessionState = { + lazy val sessionState: SessionState = { SparkSession.reflect[SessionState, SparkSession]( SparkSession.sessionStateClassName(sparkContext.conf), self) @@ -613,7 +623,6 @@ class SparkSession private( * * @since 2.1.0 */ - @InterfaceStability.Stable def time[T](f: => T): T = { val start = System.nanoTime() val ret = f @@ -928,9 +937,19 @@ object SparkSession { defaultSession.set(null) } - private[sql] def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + /** + * Returns the active SparkSession for the current thread, returned by the builder. 
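+ * For a sense of how this getter pairs with the existing `setActiveSession` (a minimal
+ * illustration; assumes a `spark` session created via the builder):
+ * {{{
+ *   SparkSession.setActiveSession(spark)
+ *   assert(SparkSession.getActiveSession.contains(spark))
+ * }}}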
+ * + * @since 2.2.0 + */ + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) - private[sql] def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + /** + * Returns the default SparkSession that is returned by the builder. + * + * @since 2.2.0 + */ + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** A global SQL listener used for the SQL UI. */ private[sql] val sqlListener = new AtomicReference[SQLListener]() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 8de95fe64e663..7ce9938f0d075 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -39,7 +39,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // Load hive-site.xml into hadoopConf and determine the warehouse path we want to use, based on // the config from both hive and Spark SQL. Finally set the warehouse config value to sparkConf. - val warehousePath = { + val warehousePath: String = { val configFile = Utils.getContextOrSparkClassLoader.getResource("hive-site.xml") if (configFile != null) { sparkContext.hadoopConfiguration.addResource(configFile) @@ -103,7 +103,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A manager for global temporary views. */ - val globalTempViewManager = { + val globalTempViewManager: GlobalTempViewManager = { // System preserved database should not exists in metastore. However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. From d305b042a82e61b22de15f0432127cd0d1ad96b5 Mon Sep 17 00:00:00 2001 From: windpiger Date: Mon, 20 Feb 2017 19:20:23 -0800 Subject: [PATCH 11/61] [SPARK-19669][HOTFIX][SQL] sessionState access privileges compiled failed in TestSQLContext ## What changes were proposed in this pull request? In [SPARK-19669](https://github.com/apache/spark/commit/0733a54a4517b82291efed9ac7f7407d9044593c) change the sessionState access privileges from private to public, this lead to the compile failed in TestSQLContext this pr is a hotfix for this. ## How was this patch tested? N/A Author: windpiger Closes #17008 from windpiger/hotfixcompile. --- .../test/scala/org/apache/spark/sql/test/TestSQLContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 2f247ca3e8b7f..8ab6db175da5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,7 +35,7 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } @transient - protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { + override lazy val sessionState: SessionState = new SessionState(self) { override lazy val conf: SQLConf = { new SQLConf { clear() From cd7971a5fe266c03bb412dc6f9d3e816f9d45725 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Feb 2017 21:25:21 -0800 Subject: [PATCH 12/61] [SPARK-19508][CORE] Improve error message when binding service fails ## What changes were proposed in this pull request? 
Utils provides a helper function to bind service on port. This function can bind the service to a random free port. However, if the binding fails on a random free port, the retrying and final exception messages look confusing. 17/02/06 16:25:43 WARN Utils: Service 'sparkDriver' could not bind on port 0. Attempting port 1. 17/02/06 16:25:43 WARN Utils: Service 'sparkDriver' could not bind on port 0. Attempting port 1. 17/02/06 16:25:43 WARN Utils: Service 'sparkDriver' could not bind on port 0. Attempting port 1. 17/02/06 16:25:43 WARN Utils: Service 'sparkDriver' could not bind on port 0. Attempting port 1. 17/02/06 16:25:43 WARN Utils: Service 'sparkDriver' could not bind on port 0. Attempting port 1. ... 17/02/06 16:25:43 ERROR SparkContext: Error initializing SparkContext. java.net.BindException: Can't assign requested address: Service 'sparkDriver' failed after 16 retries (starting from 0)! Consider explicitly setting the appropriate port for the service 'sparkDriver' (for example spark.ui.port for SparkUI) to an available port or increasing spark.port.maxRetries. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #16851 from viirya/better-log-message. --- .../scala/org/apache/spark/util/Utils.scala | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1e6e9a223e295..55382899a34d7 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2210,17 +2210,32 @@ private[spark] object Utils extends Logging { } catch { case e: Exception if isBindCollision(e) => if (offset >= maxRetries) { - val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " + - s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " + - s"the appropriate port for the service$serviceString (for example spark.ui.port " + - s"for SparkUI) to an available port or increasing spark.port.maxRetries." + val exceptionMessage = if (startPort == 0) { + s"${e.getMessage}: Service$serviceString failed after " + + s"$maxRetries retries (on a random free port)! " + + s"Consider explicitly setting the appropriate binding address for " + + s"the service$serviceString (for example spark.driver.bindAddress " + + s"for SparkDriver) to the correct binding address." + } else { + s"${e.getMessage}: Service$serviceString failed after " + + s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " + + s"the appropriate port for the service$serviceString (for example spark.ui.port " + + s"for SparkUI) to an available port or increasing spark.port.maxRetries." + } val exception = new BindException(exceptionMessage) // restore original stack trace exception.setStackTrace(e.getStackTrace) throw exception } - logWarning(s"Service$serviceString could not bind on port $tryPort. " + - s"Attempting port ${tryPort + 1}.") + if (startPort == 0) { + // As startPort 0 is for a random free port, it is most possibly binding address is + // not correct. + logWarning(s"Service$serviceString could not bind on a random free port. " + + "You may check whether configuring an appropriate binding address.") + } else { + logWarning(s"Service$serviceString could not bind on port $tryPort. 
" + + s"Attempting port ${tryPort + 1}.") + } } } // Should never happen From 75f07ace355bf2ed1296dd009450c84583bc9b22 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 20 Feb 2017 21:26:54 -0800 Subject: [PATCH 13/61] [SPARK-18922][TESTS] Fix new test failures on Windows due to path and resource not closed ## What changes were proposed in this pull request? This PR proposes to fix new test failures on WIndows as below: **Before** ``` KafkaRelationSuite: - test late binding start offsets *** FAILED *** (7 seconds, 679 milliseconds) Cause: java.nio.file.FileSystemException: C:\projects\spark\target\tmp\spark-4c4b0cd1-4cb7-4908-949d-1b0cc8addb50\topic-4-0\00000000000000000000.log -> C:\projects\spark\target\tmp\spark-4c4b0cd1-4cb7-4908-949d-1b0cc8addb50\topic-4-0\00000000000000000000.log.deleted: The process cannot access the file because it is being used by another process. KafkaSourceSuite: - deserialization of initial offset with Spark 2.1.0 *** FAILED *** (3 seconds, 542 milliseconds) java.io.IOException: Failed to delete: C:\projects\spark\target\tmp\spark-97ef64fc-ae61-4ce3-ac59-287fd38bd824 - deserialization of initial offset written by Spark 2.1.0 *** FAILED *** (60 milliseconds) java.nio.file.InvalidPathException: Illegal char <:> at index 2: /C:/projects/spark/external/kafka-0-10-sql/target/scala-2.11/test-classes/kafka-source-initial-offset-version-2.1.0.b HiveDDLSuite: - partitioned table should always put partition columns at the end of table schema *** FAILED *** (657 milliseconds) org.apache.spark.sql.AnalysisException: Path does not exist: file:/C:projectsspark arget mpspark-f1b83d09-850a-4bba-8e43-a2a28dfaa757; DDLSuite: - create a data source table without schema *** FAILED *** (94 milliseconds) org.apache.spark.sql.AnalysisException: Path does not exist: file:/C:projectsspark arget mpspark-a3f3c161-afae-4d6f-9182-e8642f77062b; - SET LOCATION for managed table *** FAILED *** (219 milliseconds) org.apache.spark.sql.catalyst.errors.package$TreeNodeException: execute, tree: Exchange SinglePartit +- *HashAggregate(keys=[], functions=[partial_count(1)], output=[count#99367L]) +- *FileScan parquet default.tbl[] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/C:projectsspark arget mpspark-15be2f2f-4ea9-4c47-bfee-1b7b49363033], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<> - insert data to a data source table which has a not existed location should succeed *** FAILED *** (16 milliseconds) org.apache.spark.sql.AnalysisException: Path does not exist: file:/C:projectsspark arget mpspark-34987671-e8d1-4624-ba5b-db1012e1246b; - insert into a data source table with no existed partition location should succeed *** FAILED *** (16 milliseconds) org.apache.spark.sql.AnalysisException: Path does not exist: file:/C:projectsspark arget mpspark-4c6ccfbf-4091-4032-9fbc-3d40c58267d5; - read data from a data source table which has a not existed location should succeed *** FAILED *** (0 milliseconds) - read data from a data source table with no existed partition location should succeed *** FAILED *** (0 milliseconds) org.apache.spark.sql.AnalysisException: Path does not exist: file:/C:projectsspark arget mpspark-6af39e37-abd1-44e8-ac68-e2dfcf67a2f3; InputOutputMetricsSuite: - output metrics on records written *** FAILED *** (0 milliseconds) java.lang.IllegalArgumentException: Wrong FS: file://C:\projects\spark\target\tmp\spark-cd69ee77-88f2-4202-bed6-19c0ee05ef55\InputOutputMetricsSuite, expected: file:/// - output metrics on records written - new Hadoop API *** 
FAILED *** (16 milliseconds) java.lang.IllegalArgumentException: Wrong FS: file://C:\projects\spark\target\tmp\spark-b69e8fcb-047b-4de8-9cdf-5f026efb6762\InputOutputMetricsSuite, expected: file:/// ``` **After** ``` KafkaRelationSuite: - test late binding start offsets !!! CANCELED !!! (62 milliseconds) KafkaSourceSuite: - deserialization of initial offset with Spark 2.1.0 (5 seconds, 341 milliseconds) - deserialization of initial offset written by Spark 2.1.0 (910 milliseconds) HiveDDLSuite: - partitioned table should always put partition columns at the end of table schema (2 seconds) DDLSuite: - create a data source table without schema (828 milliseconds) - SET LOCATION for managed table (406 milliseconds) - insert data to a data source table which has a not existed location should succeed (406 milliseconds) - insert into a data source table with no existed partition location should succeed (453 milliseconds) - read data from a data source table which has a not existed location should succeed (94 milliseconds) - read data from a data source table with no existed partition location should succeed (265 milliseconds) InputOutputMetricsSuite: - output metrics on records written (172 milliseconds) - output metrics on records written - new Hadoop API (297 milliseconds) ``` ## How was this patch tested? Fixed tests in `InputOutputMetricsSuite`, `KafkaRelationSuite`, `KafkaSourceSuite`, `DDLSuite.scala` and `HiveDDLSuite`. Manually tested via AppVeyor as below: `InputOutputMetricsSuite`: https://ci.appveyor.com/project/spark-test/spark/build/633-20170219-windows-test/job/ex8nvwa6tsh7rmto `KafkaRelationSuite`: https://ci.appveyor.com/project/spark-test/spark/build/633-20170219-windows-test/job/h8dlcowew52y8ncw `KafkaSourceSuite`: https://ci.appveyor.com/project/spark-test/spark/build/634-20170219-windows-test/job/9ybgjl7yeubxcre4 `DDLSuite`: https://ci.appveyor.com/project/spark-test/spark/build/635-20170219-windows-test `HiveDDLSuite`: https://ci.appveyor.com/project/spark-test/spark/build/633-20170219-windows-test/job/up6o9n47er087ltb Author: hyukjinkwon Closes #16999 from HyukjinKwon/windows-fix. 
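As background for the path handling in the test changes below: on Windows, concatenating `"file://"` with `File#getAbsolutePath` keeps the drive letter and backslashes, which Hadoop's `Path` rejects (the "Wrong FS" failures above), whereas going through `File#toURI` normalizes the separators and scheme. A minimal sketch of the two forms; the directory value is copied from the log purely for illustration, and the quoted outputs describe what a Windows worker would see:

```scala
import java.io.File

// A temp directory as it appears on a Windows CI worker (illustrative value from the log above).
val file = new File("""C:\projects\spark\target\tmp\InputOutputMetricsSuite""")

// Fragile: plain concatenation keeps the backslashes and leaves an empty URI authority,
// producing something like "file://C:\projects\spark\..." which Hadoop's Path cannot parse.
val concatenated = "file://" + file.getAbsolutePath

// Robust: File/URI conversion normalizes to roughly "file:/C:/projects/spark/...",
// which is the form the updated tests rely on.
val viaUri = file.toURI.toURL.toString
```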
--- .../metrics/InputOutputMetricsSuite.scala | 4 +-- .../sql/kafka010/KafkaRelationSuite.scala | 4 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 20 ++++++----- .../sql/execution/command/DDLSuite.scala | 33 ++++++++++--------- .../sql/hive/execution/HiveDDLSuite.scala | 2 +- 5 files changed, 35 insertions(+), 28 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index becf3829e7248..5d522189a0c29 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -259,7 +259,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext test("output metrics on records written") { val file = new File(tmpDir, getClass.getSimpleName) - val filePath = "file://" + file.getAbsolutePath + val filePath = file.toURI.toURL.toString val records = runAndReturnRecordsWritten { sc.parallelize(1 to numRecords).saveAsTextFile(filePath) @@ -269,7 +269,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext test("output metrics on records written - new Hadoop API") { val file = new File(tmpDir, getClass.getSimpleName) - val filePath = "file://" + file.getAbsolutePath + val filePath = file.toURI.toURL.toString val records = runAndReturnRecordsWritten { sc.parallelize(1 to numRecords).map(key => (key.toString, key.toString)) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 673d60ff6f87a..68bc3e3e2e9a8 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { @@ -147,6 +148,9 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon } test("test late binding start offsets") { + // Kafka fails to remove the logs on Windows. See KAFKA-1194. 
+ assume(!Utils.isWindows) + var kafkaUtils: KafkaTestUtils = null try { /** diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 4f82b133cb4c8..534fb77c9ce18 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.util.Utils abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -161,11 +162,12 @@ class KafkaSourceSuite extends KafkaSourceTest { // Make sure Spark 2.1.0 will throw an exception when reading the new log intercept[java.lang.IllegalArgumentException] { // Simulate how Spark 2.1.0 reads the log - val in = new FileInputStream(metadataPath.getAbsolutePath + "/0") - val length = in.read() - val bytes = new Array[Byte](length) - in.read(bytes) - KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) + Utils.tryWithResource(new FileInputStream(metadataPath.getAbsolutePath + "/0")) { in => + val length = in.read() + val bytes = new Array[Byte](length) + in.read(bytes) + KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) + } } } } @@ -181,13 +183,13 @@ class KafkaSourceSuite extends KafkaSourceTest { "subscribe" -> topic ) - val from = Paths.get( - getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").getPath) + val from = new File( + getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").toURI).toPath val to = Paths.get(s"${metadataPath.getAbsolutePath}/0") Files.copy(from, to) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) + val source = provider.createSource( + spark.sqlContext, metadataPath.toURI.toString, None, "", parameters) val deserializedOffset = source.getOffset.get val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) assert(referenceOffset == deserializedOffset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e1a3b247fd4fc..b44f20e367f0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1520,7 +1520,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val e = intercept[AnalysisException] { sql("CREATE TABLE tab1 USING json") }.getMessage assert(e.contains("Unable to infer schema for JSON. 
It must be specified manually")) - sql(s"CREATE TABLE tab2 using json location '${tempDir.getCanonicalPath}'") + sql(s"CREATE TABLE tab2 using json location '${tempDir.toURI}'") checkAnswer(spark.table("tab2"), Row("a", "b")) } } @@ -1814,7 +1814,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val defaultTablePath = spark.sessionState.catalog .getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get - sql(s"ALTER TABLE tbl SET LOCATION '${dir.getCanonicalPath}'") + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") spark.catalog.refreshTable("tbl") // SET LOCATION won't move data from previous table path to new table path. assert(spark.table("tbl").count() == 0) @@ -1836,15 +1836,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("insert data to a data source table which has a not existed location should succeed") { withTable("t") { withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") spark.sql( s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$dir") + |OPTIONS(path "$path") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - val expectedPath = dir.getAbsolutePath.stripSuffix("/") - assert(table.location.stripSuffix("/") == expectedPath) + assert(table.location == path) dir.delete val tableLocFile = new File(table.location.stripPrefix("file:")) @@ -1859,8 +1859,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(tableLocFile.exists) checkAnswer(spark.table("t"), Row("c", 1) :: Nil) - val newDir = dir.getAbsolutePath.stripSuffix("/") + "/x" - val newDirFile = new File(newDir) + val newDirFile = new File(dir, "x") + val newDir = newDirFile.toURI.toString spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") spark.sessionState.catalog.refreshTable(TableIdentifier("t")) @@ -1878,16 +1878,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("insert into a data source table with no existed partition location should succeed") { withTable("t") { withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") spark.sql( s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "$dir" + |LOCATION "$path" """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - val expectedPath = dir.getAbsolutePath.stripSuffix("/") - assert(table.location.stripSuffix("/") == expectedPath) + assert(table.location == path) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1906,25 +1906,26 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("read data from a data source table which has a not existed location should succeed") { withTable("t") { withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") spark.sql( s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$dir") + |OPTIONS(path "$path") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - val expectedPath = dir.getAbsolutePath.stripSuffix("/") - assert(table.location.stripSuffix("/") == expectedPath) + assert(table.location == path) dir.delete() checkAnswer(spark.table("t"), Nil) - val newDir = dir.getAbsolutePath.stripSuffix("/") + "/x" + val newDirFile = new File(dir, "x") + val newDir = newDirFile.toURI.toString spark.sql(s"ALTER TABLE t SET LOCATION 
'$newDir'") val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table1.location == newDir) - assert(!new File(newDir).exists()) + assert(!newDirFile.exists()) checkAnswer(spark.table("t"), Nil) } } @@ -1938,7 +1939,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "$dir" + |LOCATION "${dir.toURI}" """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c04b9ee0f2cd5..792ac1e259494 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1570,7 +1570,7 @@ class HiveDDLSuite val dataPath = new File(new File(path, "d=1"), "b=1").getCanonicalPath Seq(1 -> 1).toDF("a", "c").write.save(dataPath) - sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.getCanonicalPath}'") + sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}'") assert(getTableColumns("t3") == Seq("a", "c", "d", "b")) } From 66434d6c304541ff6e0d649058a0cd80cc515401 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 21 Feb 2017 09:38:14 -0800 Subject: [PATCH 14/61] [SPARK-19337][ML][DOC] Documentation and examples for LinearSVC ## What changes were proposed in this pull request? Documentation and examples (Java, scala, python, R) for LinearSVC ## How was this patch tested? local doc generation Author: Yuhao Yang Closes #16968 from hhbyyh/mlsvmdoc. --- docs/ml-classification-regression.md | 44 +++++++++++++++ .../examples/ml/JavaLinearSVCExample.java | 54 +++++++++++++++++++ examples/src/main/python/ml/linearsvc.py | 46 ++++++++++++++++ .../spark/examples/ml/LinearSVCExample.scala | 52 ++++++++++++++++++ 4 files changed, 196 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java create mode 100644 examples/src/main/python/ml/linearsvc.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 782ee58188934..37862f82c3386 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -363,6 +363,50 @@ Refer to the [R API docs](api/R/spark.mlp.html) for more details. +## Linear Support Vector Machine + +A [support vector machine](https://en.wikipedia.org/wiki/Support_vector_machine) constructs a hyperplane +or set of hyperplanes in a high- or infinite-dimensional space, which can be used for classification, +regression, or other tasks. Intuitively, a good separation is achieved by the hyperplane that has +the largest distance to the nearest training-data points of any class (so-called functional margin), +since in general the larger the margin the lower the generalization error of the classifier. LinearSVC +in Spark ML supports binary classification with linear SVM. Internally, it optimizes the +[Hinge Loss](https://en.wikipedia.org/wiki/Hinge_loss) using OWLQN optimizer. + + +**Examples** + +
+ +<div data-lang="scala" markdown="1">
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.LinearSVC) for more details. + +{% include_example scala/org/apache/spark/examples/ml/LinearSVCExample.scala %} +</div>
+ +<div data-lang="java" markdown="1">
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/LinearSVC.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLinearSVCExample.java %} +</div>
+ +<div data-lang="python" markdown="1">
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.LinearSVC) for more details. + +{% include_example python/ml/linearsvc.py %} +</div>
+ +<div data-lang="r" markdown="1">
+ +Refer to the [R API docs](api/R/spark.svmLinear.html) for more details. + +{% include_example r/ml/svmLinear.R %} +</div>
+ +</div>
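For background on the loss named above (a standard formulation, stated here as an assumption rather than quoted from the Spark implementation): for labels \(y_i \in \{-1, +1\}\), the L2-regularized hinge-loss objective that LinearSVC minimizes can be written as

\[
\min_{w,\,b}\; \frac{\lambda}{2}\lVert w \rVert_2^2 + \frac{1}{n}\sum_{i=1}^{n} \max\left(0,\, 1 - y_i\,(w^\top x_i + b)\right)
\]

where \(\lambda\) plays the role of `regParam` and OWL-QN is the quasi-Newton method used to minimize the objective.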
## One-vs-Rest classifier (a.k.a. One-vs-All) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java new file mode 100644 index 0000000000000..a18ed1d0b48fa --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.ml.classification.LinearSVC; +import org.apache.spark.ml.classification.LinearSVCModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +// $example off$ + +public class JavaLinearSVCExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaLinearSVCExample") + .getOrCreate(); + + // $example on$ + // Load training data + Dataset training = spark.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LinearSVC lsvc = new LinearSVC() + .setMaxIter(10) + .setRegParam(0.1); + + // Fit the model + LinearSVCModel lsvcModel = lsvc.fit(training); + + // Print the coefficients and intercept for LinearSVC + System.out.println("Coefficients: " + + lsvcModel.coefficients() + " Intercept: " + lsvcModel.intercept()); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/linearsvc.py b/examples/src/main/python/ml/linearsvc.py new file mode 100644 index 0000000000000..18cbf87a10695 --- /dev/null +++ b/examples/src/main/python/ml/linearsvc.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.classification import LinearSVC +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("linearSVC Example")\ + .getOrCreate() + + # $example on$ + # Load training data + training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + lsvc = LinearSVC(maxIter=10, regParam=0.1) + + # Fit the model + lsvcModel = lsvc.fit(training) + + # Print the coefficients and intercept for linearsSVC + print("Coefficients: " + str(lsvcModel.coefficients)) + print("Intercept: " + str(lsvcModel.intercept)) + + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala new file mode 100644 index 0000000000000..5f43e65712b5d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearSVCExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.classification.LinearSVC +// $example off$ +import org.apache.spark.sql.SparkSession + +object LinearSVCExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("LinearSVCExample") + .getOrCreate() + + // $example on$ + // Load training data + val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lsvc = new LinearSVC() + .setMaxIter(10) + .setRegParam(0.1) + + // Fit the model + val lsvcModel = lsvc.fit(training) + + // Print the coefficients and intercept for linear svc + println(s"Coefficients: ${lsvcModel.coefficients} Intercept: ${lsvcModel.intercept}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println From 1d84b30c8147e7c084db20cfafe7557ab5aae0fc Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 21 Feb 2017 09:57:40 -0800 Subject: [PATCH 15/61] [SPARK-19626][YARN] Using the correct config to set credentials update time ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/14065, we introduced a configurable credential manager for Spark running on YARN. Also two configs `spark.yarn.credentials.renewalTime` and `spark.yarn.credentials.updateTime` were added, one is for the credential renewer and the other updater. But now we just query `spark.yarn.credentials.renewalTime` by mistake during CREDENTIALS UPDATING, where should be actually `spark.yarn.credentials.updateTime` . This PR fixes this mistake. ## How was this patch tested? 
existing test cc jerryshao vanzin Author: Kent Yao Closes #16955 from yaooqinn/cred_update. --- .../apache/spark/deploy/yarn/security/CredentialUpdater.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 5df4fbd9c1537..2fdb70a73c754 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -55,7 +55,7 @@ private[spark] class CredentialUpdater( /** Start the credential updater task */ def start(): Unit = { - val startTime = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + val startTime = sparkConf.get(CREDENTIALS_UPDATE_TIME) val remainingTime = startTime - System.currentTimeMillis() if (remainingTime <= 0) { credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) From 0bd7ec5b891942892f729bb91d436f7e08b565cb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 21 Feb 2017 16:14:34 -0800 Subject: [PATCH 16/61] [SPARK-19652][UI] Do auth checks for REST API access. The REST API has a security filter that performs auth checks based on the UI root's security manager. That works fine when the UI root is the app's UI, but not when it's the history server. In the SHS case, all users would be allowed to see all applications through the REST API, even if the UI itself wouldn't be available to them. This change adds auth checks for each app access through the API too, so that only authorized users can see the app's data. The change also modifies the existing security filter to use `HttpServletRequest.getRemoteUser()`, which is used in other places. That is not necessarily the same as the principal's name; for example, when using Hadoop's SPNEGO auth filter, the remote user strips the realm information, which then matches the user name registered as the owner of the application. I also renamed the UIRootFromServletContext trait to a more generic name since I'm using it to store more context information now. Tested manually with an authentication filter enabled. Author: Marcelo Vanzin Closes #16978 from vanzin/SPARK-19652. --- .../scala/org/apache/spark/TestUtils.scala | 6 +- .../spark/status/api/v1/ApiRootResource.scala | 78 +++++++++++-------- .../spark/status/api/v1/SecurityFilter.scala | 6 +- .../org/apache/spark/ui/JettyUtils.scala | 4 +- .../deploy/history/HistoryServerSuite.scala | 62 ++++++++++++++- project/MimaExcludes.scala | 4 + .../api/v1/streaming/ApiStreamingApp.scala | 8 +- 7 files changed, 123 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 109104f0a537b..3f912dc191515 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -200,9 +200,13 @@ private[spark] object TestUtils { /** * Returns the response code from an HTTP(S) URL. 
*/ - def httpResponseCode(url: URL, method: String = "GET"): Int = { + def httpResponseCode( + url: URL, + method: String = "GET", + headers: Seq[(String, String)] = Nil): Int = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod(method) + headers.foreach { case (k, v) => connection.setRequestProperty(k, v) } // Disable cert and host name validation for HTTPS tests. if (connection.isInstanceOf[HttpsURLConnection]) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 17bc04303fa8b..67ccf43afa44a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -18,6 +18,7 @@ package org.apache.spark.status.api.v1 import java.util.zip.ZipOutputStream import javax.servlet.ServletContext +import javax.servlet.http.HttpServletRequest import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI * HistoryServerSuite. */ @Path("/v1") -private[v1] class ApiRootResource extends UIRootFromServletContext { +private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications") def getApplicationList(): ApplicationListResource = { @@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJobs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs") def getJobs(@PathParam("appId") appId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs/{jobId: \\d+}") def getJob(@PathParam("appId") appId: String): OneJobResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneJobResource(ui) } } @@ -79,21 +80,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJob( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneJobResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneJobResource(ui) } } @Path("applications/{appId}/executors") def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new ExecutorListResource(ui) } } @Path("applications/{appId}/allexecutors") def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllExecutorListResource(ui) } } @@ -102,7 +103,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new ExecutorListResource(ui) } } @@ -111,15 +112,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getAllExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, 
Some(attemptId)) { ui => new AllExecutorListResource(ui) } } - @Path("applications/{appId}/stages") def getStages(@PathParam("appId") appId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllStagesResource(ui) } } @@ -128,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStages( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") def getStage(@PathParam("appId") appId: String): OneStageResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneStageResource(ui) } } @@ -144,14 +144,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStage( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneStageResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneStageResource(ui) } } @Path("applications/{appId}/storage/rdd") def getRdds(@PathParam("appId") appId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllRDDResource(ui) } } @@ -160,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdds( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllRDDResource(ui) } } @Path("applications/{appId}/storage/rdd/{rddId: \\d+}") def getRdd(@PathParam("appId") appId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneRDDResource(ui) } } @@ -176,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdd( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneRDDResource(ui) } } @@ -234,19 +234,6 @@ private[spark] trait UIRoot { .status(Response.Status.SERVICE_UNAVAILABLE) .build() } - - /** - * Get the spark UI with the given appID, and apply a function - * to it. If there is no such app, throw an appropriate exception - */ - def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) - getSparkUI(appKey) match { - case Some(ui) => - f(ui) - case None => throw new NotFoundException("no such app: " + appId) - } - } def securityManager: SecurityManager } @@ -263,13 +250,38 @@ private[v1] object UIRootFromServletContext { } } -private[v1] trait UIRootFromServletContext { +private[v1] trait ApiRequestContext { + @Context + protected var servletContext: ServletContext = _ + @Context - var servletContext: ServletContext = _ + protected var httpRequest: HttpServletRequest = _ def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) + + + /** + * Get the spark UI with the given appID, and apply a function + * to it. 
If there is no such app, throw an appropriate exception + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + uiRoot.getSparkUI(appKey) match { + case Some(ui) => + val user = httpRequest.getRemoteUser() + if (!ui.securityManager.checkUIViewPermissions(user)) { + throw new ForbiddenException(raw"""user "$user" is not authorized""") + } + f(ui) + case None => throw new NotFoundException("no such app: " + appId) + } + } + } +private[v1] class ForbiddenException(msg: String) extends WebApplicationException( + Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + private[v1] class NotFoundException(msg: String) extends WebApplicationException( new NoSuchElementException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala index b4a991eda35f3..1cd37185d6601 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala @@ -21,14 +21,14 @@ import javax.ws.rs.core.Response import javax.ws.rs.ext.Provider @Provider -private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext { +private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext { override def filter(req: ContainerRequestContext): Unit = { - val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull + val user = httpRequest.getRemoteUser() if (!uiRoot.securityManager.checkUIViewPermissions(user)) { req.abortWith( Response .status(Response.Status.FORBIDDEN) - .entity(raw"""user "$user"is not authorized""") + .entity(raw"""user "$user" is not authorized""") .build() ) } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 7909821db954b..bdbdba5780856 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -90,9 +90,9 @@ private[spark] object JettyUtils extends Logging { response.setHeader("X-Frame-Options", xFrameOptionsValue) response.getWriter.print(servletParams.extractFn(result)) } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setStatus(HttpServletResponse.SC_FORBIDDEN) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + response.sendError(HttpServletResponse.SC_FORBIDDEN, "User is not authorized to access this page.") } } catch { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index b2eded43ba71f..dcf83cb530a91 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -20,7 +20,8 @@ import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} import java.net.{HttpURLConnection, URL} import java.nio.charset.StandardCharsets import java.util.zip.ZipInputStream -import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import javax.servlet._ +import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse} import scala.concurrent.duration._ import 
scala.language.postfixOps @@ -68,11 +69,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers private var server: HistoryServer = null private var port: Int = -1 - def init(): Unit = { + def init(extraConf: (String, String)*): Unit = { val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") + conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() val securityManager = HistoryServer.createSecurityManager(conf) @@ -566,6 +568,39 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } + test("ui and api authorization checks") { + val appId = "app-20161115172038-0000" + val owner = "jose" + val admin = "root" + val other = "alice" + + stop() + init( + "spark.ui.filters" -> classOf[FakeAuthFilter].getName(), + "spark.history.ui.acls.enable" -> "true", + "spark.history.ui.admin.acls" -> admin) + + val tests = Seq( + (owner, HttpServletResponse.SC_OK), + (admin, HttpServletResponse.SC_OK), + (other, HttpServletResponse.SC_FORBIDDEN), + // When the remote user is null, the code behaves as if auth were disabled. + (null, HttpServletResponse.SC_OK)) + + val port = server.boundPort + val testUrls = Seq( + s"http://localhost:$port/api/v1/applications/$appId/jobs", + s"http://localhost:$port/history/$appId/jobs/") + + tests.foreach { case (user, expectedCode) => + testUrls.foreach { url => + val headers = if (user != null) Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) else Nil + val sc = TestUtils.httpResponseCode(new URL(url), headers = headers) + assert(sc === expectedCode, s"Unexpected status code $sc for $url (user = $user)") + } + } + } + def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path")) } @@ -648,3 +683,26 @@ object HistoryServerSuite { } } } + +/** + * A filter used for auth tests; sets the request's user to the value of the "HTTP_USER" header. + */ +class FakeAuthFilter extends Filter { + + override def destroy(): Unit = { } + + override def init(config: FilterConfig): Unit = { } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + val wrapped = new HttpServletRequestWrapper(hreq) { + override def getRemoteUser(): String = hreq.getHeader(FakeAuthFilter.FAKE_HTTP_USER) + } + chain.doFilter(wrapped, res) + } + +} + +object FakeAuthFilter { + val FAKE_HTTP_USER = "HTTP_USER" +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9d359427f27a6..511686fb4f37f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,10 @@ object MimaExcludes { // Exclude rules for 2.2.x lazy val v22excludes = v21excludes ++ Seq( + // [SPARK-19652][UI] Do auth checks for REST API access. 
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"), + // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"), diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala index e64830a9459b1..aea75d5a9c8d0 100644 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala @@ -19,14 +19,14 @@ package org.apache.spark.status.api.v1.streaming import javax.ws.rs.{Path, PathParam} -import org.apache.spark.status.api.v1.UIRootFromServletContext +import org.apache.spark.status.api.v1.ApiRequestContext @Path("/v1") -private[v1] class ApiStreamingApp extends UIRootFromServletContext { +private[v1] class ApiStreamingApp extends ApiRequestContext { @Path("applications/{appId}/streaming") def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new ApiStreamingRootResource(ui) } } @@ -35,7 +35,7 @@ private[v1] class ApiStreamingApp extends UIRootFromServletContext { def getStreamingRoot( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): ApiStreamingRootResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new ApiStreamingRootResource(ui) } } From 0167cc39597216c5f0860ff1fca1e14b55011b2c Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 21 Feb 2017 19:30:36 -0800 Subject: [PATCH 17/61] [SPARK-19670][SQL][TEST] Enable Bucketed Table Reading and Writing Testing Without Hive Support ### What changes were proposed in this pull request? Bucketed table reading and writing does not need Hive support. We can move the test cases from `sql/hive` to `sql/core`. After this PR, we can improve the test case coverage. Bucket table reading and writing can be tested with and without Hive support. ### How was this patch tested? N/A Author: Xiao Li Closes #17004 from gatorsmile/mvTestCaseForBuckets. 
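For readers unfamiliar with the feature under test, here is a minimal sketch (not part of this patch) of the bucketed write/read pattern these suites exercise, assuming an existing SparkSession named `spark`:

import org.apache.spark.sql.functions.max
import spark.implicits._

// Write a table hash-partitioned into 8 buckets by column "i", with each bucket sorted by "j".
val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
df.write
  .format("parquet")
  .bucketBy(8, "i")
  .sortBy("j")
  .saveAsTable("bucketed_table")

// Read it back; joins and aggregations on the bucketing column "i" can then
// reuse the bucketing to avoid a shuffle.
spark.table("bucketed_table").groupBy("i").agg(max("k")).show()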
--- .../spark/sql/sources/BucketedReadSuite.scala | 30 ++++++++++------ .../sql/sources/BucketedWriteSuite.scala | 34 +++++++++++++------ .../BucketedReadWithHiveSupportSuite.scala | 28 +++++++++++++++ .../BucketedWriteWithHiveSupportSuite.scala | 30 ++++++++++++++++ 4 files changed, 101 insertions(+), 21 deletions(-) rename sql/{hive => core}/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala (95%) rename sql/{hive => core}/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala (88%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala similarity index 95% rename from sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 4fc72b9e47597..9b65419dba234 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -29,17 +29,25 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet -class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { +class BucketedReadWithoutHiveSupportSuite extends BucketedReadSuite with SharedSQLContext { + protected override def beforeAll(): Unit = { + super.beforeAll() + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + } +} + + +abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { import testImplicits._ - private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - private val nullDF = (for { + private lazy val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private lazy val nullDF = (for { i <- 0 to 50 s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") } yield (i % 5, s, i % 13)).toDF("i", "j", "k") @@ -224,8 +232,10 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet } } - private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") - private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") + private lazy val df1 = + (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") + private lazy val df2 = + (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") case class BucketedTableTestSpec( bucketSpec: Option[BucketSpec], @@ -535,7 +545,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet test("error if there exists any malformed bucket files") { withTable("bucketed_table") { 
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") - val warehouseFilePath = new URI(hiveContext.sparkSession.getWarehousePath).getPath + val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath val tableDir = new File(warehouseFilePath, "bucketed_table") Utils.deleteRecursively(tableDir) df1.write.parquet(tableDir.getAbsolutePath) @@ -553,9 +563,9 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") - checkAnswer(hiveContext.table("bucketed_table").select("j"), df1.select("j")) + checkAnswer(spark.table("bucketed_table").select("j"), df1.select("j")) - checkAnswer(hiveContext.table("bucketed_table").groupBy("j").agg(max("k")), + checkAnswer(spark.table("bucketed_table").groupBy("j").agg(max("k")), df1.groupBy("j").agg(max("k"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala similarity index 88% rename from sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 61cef2a8008f2..9082261af7b00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -20,19 +20,29 @@ package org.apache.spark.sql.sources import java.io.File import java.net.URI -import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} -class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { +class BucketedWriteWithoutHiveSupportSuite extends BucketedWriteSuite with SharedSQLContext { + protected override def beforeAll(): Unit = { + super.beforeAll() + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + } + + override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "json") +} + +abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { import testImplicits._ + protected def fileFormatsToTest: Seq[String] + test("bucketed by non-existing column") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) @@ -76,11 +86,13 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle assert(e.getMessage == "'insertInto' does not support bucketing right now;") } - private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private lazy val df = { + (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + } def tableDir: File = { val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table") - new File(URI.create(hiveContext.sessionState.catalog.hiveDefaultTableFilePath(identifier))) + new 
File(URI.create(spark.sessionState.catalog.defaultTablePath(identifier))) } /** @@ -141,7 +153,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } test("write bucketed data") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) @@ -157,7 +169,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } test("write bucketed data with sortBy") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) @@ -190,7 +202,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } test("write bucketed data without partitionBy") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) @@ -203,7 +215,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } test("write bucketed data without partitionBy with sortBy") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) @@ -219,7 +231,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle test("write bucketed data with bucketing disabled") { // The configuration BUCKETING_ENABLED does not affect the writing path withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { - for (source <- Seq("parquet", "json", "orc")) { + for (source <- fileFormatsToTest) { withTable("bucketed_table") { df.write .format(source) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala new file mode 100644 index 0000000000000..f277f99805a4a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION + +class BucketedReadWithHiveSupportSuite extends BucketedReadSuite with TestHiveSingleton { + protected override def beforeAll(): Unit = { + super.beforeAll() + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala new file mode 100644 index 0000000000000..454e2f65d5d88 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION + +class BucketedWriteWithHiveSupportSuite extends BucketedWriteSuite with TestHiveSingleton { + protected override def beforeAll(): Unit = { + super.beforeAll() + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") + } + + override protected def fileFormatsToTest: Seq[String] = Seq("parquet", "orc") +} From 40f926b851da0f12eb44016b79a49adfff33aeef Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 22 Feb 2017 16:33:14 +0200 Subject: [PATCH 18/61] [SPARK-19694][ML] Add missing 'setTopicDistributionCol' for LDAModel ## What changes were proposed in this pull request? Add missing 'setTopicDistributionCol' for LDAModel ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #17021 from zhengruifeng/lda_outputCol. 
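As a usage illustration (a sketch, not taken from the patch; `dataset` is an assumed DataFrame with a vector `features` column), the new setter lets callers rename the column that receives the per-document topic distribution:

import org.apache.spark.ml.clustering.LDA

val lda = new LDA().setK(10).setMaxIter(10)
val model = lda.fit(dataset)

// Rename the output column before transforming, using the setter added here.
val withTopics = model
  .setTopicDistributionCol("docTopicDist")
  .transform(dataset)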
--- mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index bbcef3502d1dc..55720e2d613d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -437,6 +437,9 @@ abstract class LDAModel private[ml] ( @Since("1.6.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) + @Since("2.2.0") + def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value) + /** @group setParam */ @Since("1.6.0") def setSeed(value: Long): this.type = set(seed, value) From d5c014dfd80e069552980fdf5ee71c57497837b9 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 22 Feb 2017 16:36:03 +0200 Subject: [PATCH 19/61] [SPARK-19679][ML] Destroy broadcasted object without blocking ## What changes were proposed in this pull request? Destroy broadcasted object without blocking use `find mllib -name '*.scala' | xargs -i bash -c 'egrep "destroy" -n {} && echo {}'` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #17016 from zhengruifeng/destroy_without_block. --- .../org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala | 2 +- .../main/scala/org/apache/spark/mllib/optimization/LBFGS.scala | 2 +- .../org/apache/spark/mllib/tree/model/treeEnsembleModels.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index f3bace8181570..4c525c0714ec5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -226,7 +226,7 @@ private[spark] object GradientBoostedTrees extends Logging { (a, b) => treesIndices.map(idx => a(idx) + b(idx))) .map(_ / dataCount) - broadcastTrees.destroy() + broadcastTrees.destroy(blocking = false) evaluation.toArray } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 7a714db853353..efedebe301387 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -261,7 +261,7 @@ object LBFGS extends Logging { val (gradientSum, lossSum) = data.treeAggregate((zeroSparseVector, 0.0))(seqOp, combOp) // broadcasted model is not needed anymore - bcW.destroy() + bcW.destroy(blocking = false) /** * regVal is sum of weight squares if it's L2 updater; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index fc1d4125a5649..b1e82656a2405 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -167,7 +167,7 @@ class GradientBoostedTreesModel @Since("1.2.0") ( (a, b) => treesIndices.map(idx => a(idx) + b(idx))) .map(_ / dataCount) - broadcastTrees.destroy() + broadcastTrees.destroy(blocking = false) evaluation.toArray } From 022d91956626eac8b5efa0514ac8d2dceb54aafe Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Wed, 22 Feb 2017 
15:42:40 +0100 Subject: [PATCH 20/61] [SPARK-13721][SQL] Make GeneratorOuter unresolved. ## What changes were proposed in this pull request? This is a small change to make GeneratorOuter always unresolved. It is mostly no-op change but makes it more clear since GeneratorOuter shouldn't survive analysis phase. This requires also handling in ResolveAliases rule. ## How was this patch tested? Existing generator tests. Author: Bogdan Raducanu Author: Reynold Xin Closes #17026 from bogdanrdc/PR16958. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../spark/sql/catalyst/expressions/generators.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 39a276284c35e..c477cb48d0b07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -222,6 +222,7 @@ class Analyzer( expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) => child match { case ne: NamedExpression => ne + case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)() @@ -1665,7 +1666,6 @@ class Analyzer( var resolvedGenerator: Generate = null val newProjectList = projectList.flatMap { - case AliasedGenerator(generator, names, outer) if generator.childrenResolved => // It's a sanity check, this should not happen as the previous case will throw // exception earlier. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 1b98c30d3760b..e84796f2edad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -173,7 +173,6 @@ case class Stack(children: Seq[Expression]) extends Generator { } } - /** * Only support code generation when stack produces 50 rows or less. */ @@ -204,6 +203,10 @@ case class Stack(children: Seq[Expression]) extends Generator { } } +/** + * Wrapper around another generator to specify outer behavior. This is used to implement functions + * such as explode_outer. This expression gets replaced during analysis. + */ case class GeneratorOuter(child: Generator) extends UnaryExpression with Generator { final override def eval(input: InternalRow = null): TraversableOnce[InternalRow] = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") @@ -212,7 +215,10 @@ case class GeneratorOuter(child: Generator) extends UnaryExpression with Generat throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") override def elementSchema: StructType = child.elementSchema + + override lazy val resolved: Boolean = false } + /** * A base class for [[Explode]] and [[PosExplode]]. 
*/ From 93c647770b5e02ce238fc58d9a7fe271086b9751 Mon Sep 17 00:00:00 2001 From: Adam Budde Date: Wed, 22 Feb 2017 11:32:36 -0500 Subject: [PATCH 21/61] [SPARK-19405][STREAMING] Support for cross-account Kinesis reads via STS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add dependency on aws-java-sdk-sts - Replace SerializableAWSCredentials with new SerializableCredentialsProvider interface - Make KinesisReceiver take SerializableCredentialsProvider as argument and pass credential provider to KCL - Add new implementations of KinesisUtils.createStream() that take STS arguments - Make JavaKinesisStreamSuite test the entire KinesisUtils Java API - Update KCL/AWS SDK dependencies to 1.7.x/1.11.x ## What changes were proposed in this pull request? [JIRA link with detailed description.](https://issues.apache.org/jira/browse/SPARK-19405) * Replace SerializableAWSCredentials with new SerializableKCLAuthProvider class that takes 5 optional config params for configuring AWS auth and returns the appropriate credential provider object * Add new public createStream() APIs for specifying these parameters in KinesisUtils ## How was this patch tested? * Manually tested using explicit keypair and instance profile to read data from Kinesis stream in separate account (difficult to write a test orchestrating creation and assumption of IAM roles across separate accounts) * Expanded JavaKinesisStreamSuite to test the entire Java API in KinesisUtils ## License acknowledgement This contribution is my original work and that I license the work to the project under the project’s open source license. Author: Budde Closes #16744 from budde/master. --- external/kinesis-asl/pom.xml | 5 + .../streaming/JavaKinesisWordCountASL.java | 2 +- .../streaming/KinesisExampleUtils.scala | 35 ++++ .../streaming/KinesisWordCountASL.scala | 2 +- .../kinesis/KinesisBackedBlockRDD.scala | 8 +- .../kinesis/KinesisCheckpointer.scala | 2 +- .../kinesis/KinesisInputDStream.scala | 7 +- .../streaming/kinesis/KinesisReceiver.scala | 51 ++--- .../kinesis/KinesisRecordProcessor.scala | 2 +- .../streaming/kinesis/KinesisTestUtils.scala | 14 +- .../streaming/kinesis/KinesisUtils.scala | 192 ++++++++++++++++-- .../SerializableCredentialsProvider.scala | 85 ++++++++ .../kinesis/JavaKinesisStreamSuite.java | 35 +++- .../kinesis/KinesisReceiverSuite.scala | 25 ++- .../kinesis/KinesisStreamSuite.scala | 9 +- pom.xml | 4 +- python/pyspark/streaming/kinesis.py | 12 +- 17 files changed, 407 insertions(+), 83 deletions(-) create mode 100644 external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala create mode 100644 external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index b2bac7c938ab5..daa79e79163b9 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -58,6 +58,11 @@ amazon-kinesis-client ${aws.kinesis.client.version} + + com.amazonaws + aws-java-sdk-sts + ${aws.java.sdk.version} + com.amazonaws amazon-kinesis-producer diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index d40bd3ff560d6..d1274a687fc70 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ 
b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -127,7 +127,7 @@ public static void main(String[] args) throws Exception { // Get the region name from the endpoint URL to save Kinesis Client Library metadata in // DynamoDB of the same region as the Kinesis stream - String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName(); + String regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl); // Setup the Spark config and StreamingContext SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL"); diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala new file mode 100644 index 0000000000000..2eebd6130d4da --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.streaming + +import scala.collection.JavaConverters._ + +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.AmazonKinesis + +private[streaming] object KinesisExampleUtils { + def getRegionNameByEndpoint(endpoint: String): String = { + val uri = new java.net.URI(endpoint) + RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX) + .asScala + .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost)) + .map(_.getName) + .getOrElse( + throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint")) + } +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index a70c13d7d68a8..f14117b708a0d 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -127,7 +127,7 @@ object KinesisWordCountASL extends Logging { // Get the region name from the endpoint URL to save Kinesis Client Library metadata in // DynamoDB of the same region as the Kinesis stream - val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl) // Setup the SparkConfig and StreamingContext val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL") diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 45dc3c388cb8d..23c4d99e50f51 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -79,7 +79,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, - val awsCredentialsOption: Option[SerializableAWSCredentials] = None + val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -105,9 +105,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = awsCredentialsOption.getOrElse { - new DefaultAWSCredentialsProviderChain().getCredentials() - } + val credentials = kinesisCredsProvider.provider.getCredentials partition.seqNumberRanges.ranges.iterator.flatMap { range => new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, range, retryTimeoutMs).map(messageHandler) @@ -143,7 +141,7 @@ class KinesisSequenceRangeIterator( private var lastSeqNumber: String = null private var internalIterator: Iterator[Record] = null - client.setEndpoint(endpointUrl, "kinesis", regionId) + client.setEndpoint(endpointUrl) override protected def getNext(): Record = { var nextRecord: Record = null diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala index c445c15a5f644..5fb83b26f8382 100644 --- 
a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -21,7 +21,7 @@ import java.util.concurrent._ import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import org.apache.spark.internal.Logging import org.apache.spark.streaming.Duration diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 5223c81a8e0e0..fbc6b99443ed7 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -39,7 +39,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - awsCredentialsOption: Option[SerializableAWSCredentials] + kinesisCredsProvider: SerializableCredentialsProvider ) extends ReceiverInputDStream[T](_ssc) { private[streaming] @@ -61,7 +61,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( isBlockIdValid = isBlockIdValid, retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, messageHandler = messageHandler, - awsCredentialsOption = awsCredentialsOption) + kinesisCredsProvider = kinesisCredsProvider) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + " it may not be possible to recover from failures") @@ -71,6 +71,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption) + checkpointAppName, checkpointInterval, storageLevel, messageHandler, + kinesisCredsProvider) } } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 393e56a39320c..13fc54e531dda 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal -import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record @@ -34,13 +33,6 @@ import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils -private[kinesis] -case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) - extends AWSCredentials { 
- override def getAWSAccessKeyId: String = accessKeyId - override def getAWSSecretKey: String = secretKey -} - /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: @@ -78,8 +70,9 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies - * the credentials + * @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to + * generate the AWSCredentialsProvider instance used for KCL + * authorization. */ private[kinesis] class KinesisReceiver[T]( val streamName: String, @@ -90,7 +83,7 @@ private[kinesis] class KinesisReceiver[T]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - awsCredentialsOption: Option[SerializableAWSCredentials]) + kinesisCredsProvider: SerializableCredentialsProvider) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -147,14 +140,15 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) - // KCL config instance - val awsCredProvider = resolveAWSCredentialsProvider() - val kinesisClientLibConfiguration = - new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId) - .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream) - .withTaskBackoffTimeMillis(500) - .withRegionName(regionName) + val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( + checkpointAppName, + streamName, + kinesisCredsProvider.provider, + workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) /* * RecordProcessorFactory creates impls of IRecordProcessor. @@ -305,25 +299,6 @@ private[kinesis] class KinesisReceiver[T]( } } - /** - * If AWS credential is provided, return a AWSCredentialProvider returning that credential. - * Otherwise, return the DefaultAWSCredentialsProviderChain. - */ - private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { - awsCredentialsOption match { - case Some(awsCredentials) => - logInfo("Using provided AWS credentials") - new AWSCredentialsProvider { - override def getCredentials: AWSCredentials = awsCredentials - override def refresh(): Unit = { } - } - case None => - logInfo("Using DefaultAWSCredentialsProviderChain") - new DefaultAWSCredentialsProviderChain() - } - } - - /** * Class to handle blocks generated by this receiver's block generator. Specifically, in * the context of the Kinesis Receiver, this handler does the following. 
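The change above replaces the receiver's ad-hoc credential resolution with the serializable provider hierarchy defined in `SerializableCredentialsProvider.scala` later in this patch. The sketch below is illustrative only (it is not part of the patch, and the role ARN, session name, and keypair are placeholders); it shows how the pieces compose from inside `org.apache.spark.streaming.kinesis`:

```scala
import com.amazonaws.auth.AWSCredentialsProvider

// Built on the driver and serialized along with the receiver; every value is a placeholder.
val kinesisCreds: SerializableCredentialsProvider =
  STSCredentialsProvider(
    stsRoleArn = "arn:aws:iam::123456789012:role/kinesis-read",
    stsSessionName = "spark-kinesis-session",
    stsExternalId = Some("example-external-id"),
    longLivedCredsProvider = BasicCredentialsProvider(
      awsAccessKeyId = "fakeAccessKey",
      awsSecretKey = "fakeSecretKey"))

// The non-serializable AWS provider is only materialized where it is needed,
// e.g. when the KCL configuration is built in the receiver's onStart():
val awsProvider: AWSCredentialsProvider = kinesisCreds.provider
```

Only these small case classes travel with the DStream; the actual `AWSCredentialsProvider` instances are constructed lazily on the executors, which is what the trait's scaladoc describes.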
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 73ccc4ad23f6d..8c6a399dd763e 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.apache.spark.internal.Logging diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index f183ef00b33cd..73ac7a3cd2355 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -30,7 +30,7 @@ import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.regions.RegionUtils import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB -import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.{AmazonKinesis, AmazonKinesisClient} import com.amazonaws.services.kinesis.model._ import org.apache.spark.internal.Logging @@ -43,7 +43,7 @@ import org.apache.spark.internal.Logging private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging { val endpointUrl = KinesisTestUtils.endpointUrl - val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val regionName = KinesisTestUtils.getRegionNameByEndpoint(endpointUrl) private val createStreamTimeoutSeconds = 300 private val describeStreamPollTimeSeconds = 1 @@ -205,6 +205,16 @@ private[kinesis] object KinesisTestUtils { val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL" val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + def getRegionNameByEndpoint(endpoint: String): String = { + val uri = new java.net.URI(endpoint) + RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX) + .asScala + .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost)) + .map(_.getName) + .getOrElse( + throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint")) + } + lazy val shouldRunTests = { val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1") if (isEnvSet) { diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index b2daffa34ccbf..2d777982e760c 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -73,7 +73,7 @@ object KinesisUtils { 
ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, None) + cleanedHandler, DefaultCredentialsProvider) } } @@ -123,9 +123,80 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + cleanedHandler, kinesisCredsProvider) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from + * Kinesis stream. + * @param stsSessionName Name to uniquely identify STS sessions if multiple princples assume + * the same role. + * @param stsExternalId External ID that can be used to validate against the assumed IAM role's + * trust policy. + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. 
+ */ + // scalastyle:off + def createStream[T: ClassTag]( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T, + awsAccessKeyId: String, + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): ReceiverInputDStream[T] = { + // scalastyle:on + val cleanedHandler = ssc.sc.clean(messageHandler) + ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = STSCredentialsProvider( + stsRoleArn = stsAssumeRoleArn, + stsSessionName = stsSessionName, + stsExternalId = Option(stsExternalId), + longLivedCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey)) + new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + cleanedHandler, kinesisCredsProvider) } } @@ -169,7 +240,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, None) + defaultMessageHandler, DefaultCredentialsProvider) } } @@ -213,9 +284,12 @@ object KinesisUtils { awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + defaultMessageHandler, kinesisCredsProvider) } } @@ -319,6 +393,68 @@ object KinesisUtils { awsAccessKeyId, awsSecretKey) } + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. 
+ * @param recordClass Class of the records in DStream + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from + * Kinesis stream. + * @param stsSessionName Name to uniquely identify STS sessions if multiple princples assume + * the same role. + * @param stsExternalId External ID that can be used to validate against the assumed IAM role's + * trust policy. + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + */ + // scalastyle:off + def createStream[T]( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: JFunction[Record, T], + recordClass: Class[T], + awsAccessKeyId: String, + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): JavaReceiverInputDStream[T] = { + // scalastyle:on + implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) + createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler, + awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId) + } + /** * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. 
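The following is a hedged usage sketch of the new STS-aware `createStream` overloads added above (the Scala variant is used here). It is not part of the patch, and the application name, stream name, keys, role ARN, session name, and external ID are all placeholders:

```scala
import java.nio.charset.StandardCharsets

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

val sparkConf = new SparkConf().setAppName("KinesisSTSExample")
val ssc = new StreamingContext(sparkConf, Seconds(10))

// Decode each Kinesis record payload into a UTF-8 string.
def recordToString(r: Record): String = {
  val buf = r.getData
  val bytes = new Array[Byte](buf.remaining())
  buf.get(bytes)
  new String(bytes, StandardCharsets.UTF_8)
}

val stream = KinesisUtils.createStream(
  ssc, "myKinesisApp", "myStream",
  "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
  InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_AND_DISK_2,
  recordToString _,
  "fakeAccessKey", "fakeSecretKey",             // long-lived basic credentials
  "arn:aws:iam::123456789012:role/kinesis",     // stsAssumeRoleArn
  "my-sts-session",                             // stsSessionName
  "my-external-id")                             // stsExternalId
```

Internally this path wraps the basic keypair in a `BasicCredentialsProvider` and hands it to an `STSCredentialsProvider` as the long-lived provider, exactly as the overload's body above shows.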
@@ -404,10 +540,6 @@ object KinesisUtils { defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } - private def getRegionByEndpoint(endpointUrl: String): String = { - RegionUtils.getRegionByEndpoint(endpointUrl).getName() - } - private def validateRegion(regionName: String): String = { Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") @@ -439,6 +571,7 @@ private class KinesisUtilsPythonHelper { } } + // scalastyle:off def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -449,22 +582,43 @@ private class KinesisUtilsPythonHelper { checkpointInterval: Duration, storageLevel: StorageLevel, awsAccessKeyId: String, - awsSecretKey: String - ): JavaReceiverInputDStream[Array[Byte]] = { + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): JavaReceiverInputDStream[Array[Byte]] = { + // scalastyle:on + if (!(stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) + && !(stsAssumeRoleArn == null && stsSessionName == null && stsExternalId == null)) { + throw new IllegalArgumentException("stsAssumeRoleArn, stsSessionName, and stsExtenalId " + + "must all be defined or all be null") + } + + if (stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) { + validateAwsCreds(awsAccessKeyId, awsSecretKey) + KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, + stsAssumeRoleArn, stsSessionName, stsExternalId) + } else { + validateAwsCreds(awsAccessKeyId, awsSecretKey) + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + } + + // Throw IllegalArgumentException unless both values are null or neither are. 
+ private def validateAwsCreds(awsAccessKeyId: String, awsSecretKey: String) { if (awsAccessKeyId == null && awsSecretKey != null) { throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") } if (awsAccessKeyId != null && awsSecretKey == null) { throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") } - if (awsAccessKeyId == null && awsSecretKey == null) { - KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) - } else { - KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - awsAccessKeyId, awsSecretKey) - } } - } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala new file mode 100644 index 0000000000000..aa6fe12edf74e --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ + +import com.amazonaws.auth._ + +import org.apache.spark.internal.Logging + +/** + * Serializable interface providing a method executors can call to obtain an + * AWSCredentialsProvider instance for authenticating to AWS services. + */ +private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable { + /** + * Return an AWSCredentialProvider instance that can be used by the Kinesis Client + * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). + */ + def provider: AWSCredentialsProvider +} + +/** Returns DefaultAWSCredentialsProviderChain for authentication. */ +private[kinesis] final case object DefaultCredentialsProvider + extends SerializableCredentialsProvider { + + def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain +} + +/** + * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using + * DefaultAWSCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain + * instance with the provided arguments (e.g. if they are null). 
+ */ +private[kinesis] final case class BasicCredentialsProvider( + awsAccessKeyId: String, + awsSecretKey: String) extends SerializableCredentialsProvider with Logging { + + def provider: AWSCredentialsProvider = try { + new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) + } catch { + case e: IllegalArgumentException => + logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + + "falling back to DefaultAWSCredentialsProviderChain.", e) + new DefaultAWSCredentialsProviderChain + } +} + +/** + * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM + * role in order to authenticate against resources in an external account. + */ +private[kinesis] final case class STSCredentialsProvider( + stsRoleArn: String, + stsSessionName: String, + stsExternalId: Option[String] = None, + longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider) + extends SerializableCredentialsProvider { + + def provider: AWSCredentialsProvider = { + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) + .withLongLivedCredentialsProvider(longLivedCredsProvider.provider) + stsExternalId match { + case Some(stsExternalId) => + builder.withExternalId(stsExternalId) + .build() + case None => + builder.build() + } + } +} diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java index f078973c6c285..26b1fda2ff511 100644 --- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -36,7 +36,7 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { @Test public void testKinesisStream() { String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); - String dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName(); + String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl); // Tests the API, does not actually test data receiving JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", @@ -45,6 +45,17 @@ dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration( ssc.stop(); } + @Test + public void testAwsCreds() { + String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); + String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl); + + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000), + StorageLevel.MEMORY_AND_DISK_2(), "fakeAccessKey", "fakeSecretKey"); + ssc.stop(); + } private static Function handler = new Function() { @Override @@ -62,4 +73,26 @@ public void testCustomHandler() { ssc.stop(); } + + @Test + public void testCustomHandlerAwsCreds() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, + new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, + "fakeAccessKey", "fakeSecretKey"); + + ssc.stop(); + } + + @Test + public 
void testCustomHandlerAwsStsCreds() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, + new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, + "fakeAccessKey", "fakeSecretKey", "fakeSTSRoleArn", "fakeSTSSessionName", "fakeSTSExternalId"); + + ssc.stop(); + } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 800502a77d120..deb411d73e588 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -22,7 +22,7 @@ import java.util.Arrays import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.mockito.Matchers._ import org.mockito.Matchers.{eq => meq} @@ -62,9 +62,26 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointerMock = mock[IRecordProcessorCheckpointer] } - test("check serializability of SerializableAWSCredentials") { - Utils.deserialize[SerializableAWSCredentials]( - Utils.serialize(new SerializableAWSCredentials("x", "y"))) + test("check serializability of credential provider classes") { + Utils.deserialize[BasicCredentialsProvider]( + Utils.serialize(BasicCredentialsProvider( + awsAccessKeyId = "x", + awsSecretKey = "y"))) + + Utils.deserialize[STSCredentialsProvider]( + Utils.serialize(STSCredentialsProvider( + stsRoleArn = "fakeArn", + stsSessionName = "fakeSessionName", + stsExternalId = Some("fakeExternalId")))) + + Utils.deserialize[STSCredentialsProvider]( + Utils.serialize(STSCredentialsProvider( + stsRoleArn = "fakeArn", + stsSessionName = "fakeSessionName", + stsExternalId = Some("fakeExternalId"), + longLivedCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = "x", + awsSecretKey = "y")))) } test("process records including store and set checkpointer") { diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 404b673c01171..387a96f26b305 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -49,7 +49,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun // Dummy parameters for API testing private val dummyEndpointUrl = defaultEndpointUrl - private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName() + private val dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl) private val dummyAWSAccessKey = "dummyAccessKey" private val dummyAWSSecretKey = "dummySecretKey" @@ -138,8 +138,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun assert(kinesisRDD.regionName === 
dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.awsCredentialsOption === - Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey))) + assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider( + awsAccessKeyId = dummyAWSAccessKey, + awsSecretKey = dummyAWSSecretKey)) assert(nonEmptyRDD.partitions.size === blockInfos.size) nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] } val partitions = nonEmptyRDD.partitions.map { @@ -201,7 +202,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, addFive, + Seconds(10), StorageLevel.MEMORY_ONLY, addFive(_), awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) stream shouldBe a [ReceiverInputDStream[_]] diff --git a/pom.xml b/pom.xml index 60e4c7269eafd..c1174593c1922 100644 --- a/pom.xml +++ b/pom.xml @@ -145,7 +145,9 @@ 1.7.7 hadoop2 0.9.3 - 1.6.2 + 1.7.3 + + 1.11.76 0.10.2 diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index 3a8d8b819fd37..b839859c45252 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -37,7 +37,8 @@ class KinesisUtils(object): def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, checkpointInterval, storageLevel=StorageLevel.MEMORY_AND_DISK_2, - awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder): + awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder, + stsAssumeRoleArn=None, stsSessionName=None, stsExternalId=None): """ Create an input stream that pulls messages from a Kinesis stream. This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. @@ -67,6 +68,12 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, :param awsSecretKey: AWS SecretKey (default is None. If None, will use DefaultAWSCredentialsProviderChain) :param decoder: A function used to decode value (default is utf8_decoder) + :param stsAssumeRoleArn: ARN of IAM role to assume when using STS sessions to read from + the Kinesis stream (default is None). + :param stsSessionName: Name to uniquely identify STS sessions used to read from Kinesis + stream, if STS is being used (default is None). + :param stsExternalId: External ID that can be used to validate against the assumed IAM + role's trust policy, if STS is being used (default is None). 
:return: A DStream object """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) @@ -81,7 +88,8 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, raise jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, jduration, jlevel, - awsAccessKeyId, awsSecretKey) + awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, + stsSessionName, stsExternalId) stream = DStream(jstream, ssc, NoOpSerializer()) return stream.map(lambda v: decoder(v)) From 7c32b6904f951eac1db4bd08a7bd6d70a692ea82 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 22 Feb 2017 11:50:24 -0800 Subject: [PATCH 22/61] [SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs ## What changes were proposed in this pull request? This is a follow-up PR of #16800 When doing SPARK-19456, we found that "" should be consider a NULL column name and should not be set. aggregationDepth should be exposed as an expert parameter. ## How was this patch tested? Existing tests. Author: wm624@hotmail.com Closes #16945 from wangmiao1981/svc. --- R/pkg/R/generics.R | 2 +- R/pkg/R/mllib_classification.R | 13 ++++++---- R/pkg/R/mllib_regression.R | 24 ++++++++++++------- .../testthat/test_mllib_classification.R | 10 +++++++- .../ml/r/AFTSurvivalRegressionWrapper.scala | 6 ++++- .../GeneralizedLinearRegressionWrapper.scala | 4 +++- .../ml/r/IsotonicRegressionWrapper.scala | 3 ++- .../ml/r/LogisticRegressionWrapper.scala | 7 ++++-- 8 files changed, 50 insertions(+), 19 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 11940d356039e..647cbbdd825e3 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest", #' @rdname spark.survreg #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) #' @rdname spark.svmLinear #' @export diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index fa0d795faa10f..05bb95266173a 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) { #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p #' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features +#' or the number of partitions are large, this param could be adjusted to a larger size. +#' This is an expert parameter. Default value should be good for most cases. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. 
#' @rdname spark.logit @@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) { setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL) { + thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) { formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", @@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol), as.character(family), as.logical(standardization), as.array(thresholds), - as.character(weightCol)) + weightCol, as.integer(aggregationDepth)) new("LogisticRegressionModel", jobj = jobj) }) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 96ee220bc4113..ac0578c4ab259 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), } formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, - tol, as.integer(maxIter), as.character(weightCol), regParam) + tol, as.integer(maxIter), weightCol, regParam) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit", data@sdf, formula, as.logical(isotonic), as.integer(featureIndex), - as.character(weightCol)) + weightCol) new("IsotonicRegressionModel", jobj = jobj) }) @@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' Note that operator '.' is not supported currently. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features +#' or the number of partitions are large, this param could be adjusted to a larger size. +#' This is an expert parameter. Default value should be good for most cases. +#' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. 
#' @rdname spark.survreg #' @seealso survival: \url{https://cran.r-project.org/package=survival} @@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' } #' @note spark.survreg since 2.0.0 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula) { + function(data, formula, aggregationDepth = 2) { formula <- paste(deparse(formula), collapse = "") jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", - "fit", formula, data@sdf) + "fit", formula, data@sdf, as.integer(aggregationDepth)) new("AFTSurvivalRegressionModel", jobj = jobj) }) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 620f528f2e6c8..459254d271a58 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -211,7 +211,15 @@ test_that("spark.logit", { df <- createDataFrame(data) model <- spark.logit(df, label ~ feature) prediction <- collect(select(predict(model, df), "prediction")) - expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0")) + expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) + + # Test prediction with weightCol + weight <- c(2.0, 2.0, 2.0, 1.0, 1.0) + data2 <- as.data.frame(cbind(label, feature, weight)) + df2 <- createDataFrame(data2) + model2 <- spark.logit(df2, label ~ feature, weightCol = "weight") + prediction2 <- collect(select(predict(model2, df2), "prediction")) + expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) }) test_that("spark.mlp", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index bd965acf56944..0bf543d88894e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg } - def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = { + def fit( + formula: String, + data: DataFrame, + aggregationDepth: Int): AFTSurvivalRegressionWrapper = { val (rewritedFormula, censorCol) = formulaRewrite(formula) @@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg .setCensorCol(censorCol) .setFitIntercept(rFormula.hasIntercept) .setFeaturesCol(rFormula.getFeaturesCol) + .setAggregationDepth(aggregationDepth) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, aft)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 78f401f29b004..cbd6cd1c7933c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) - .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) + + if (weightCol != null) glr.setWeightCol(weightCol) + val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data) diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala index 48632316f3950..d31ebb46afb97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala @@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper val isotonicRegression = new IsotonicRegression() .setIsotonic(isotonic) .setFeatureIndex(featureIndex) - .setWeightCol(weightCol) .setFeaturesCol(rFormula.getFeaturesCol) + if (weightCol != null) isotonicRegression.setWeightCol(weightCol) + val pipeline = new Pipeline() .setStages(Array(rFormulaModel, isotonicRegression)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 645bc7247f30f..c96f99cb83434 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper family: String, standardization: Boolean, thresholds: Array[Double], - weightCol: String + weightCol: String, + aggregationDepth: Int ): LogisticRegressionWrapper = { val rFormula = new RFormula() @@ -119,10 +120,10 @@ private[r] object LogisticRegressionWrapper .setFitIntercept(fitIntercept) .setFamily(family) .setStandardization(standardization) - .setWeightCol(weightCol) .setFeaturesCol(rFormula.getFeaturesCol) .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + .setAggregationDepth(aggregationDepth) if (thresholds.length > 1) { lr.setThresholds(thresholds) @@ -130,6 +131,8 @@ private[r] object LogisticRegressionWrapper lr.setThreshold(thresholds(0)) } + if (weightCol != null) lr.setWeightCol(weightCol) + val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) .setOutputCol(PREDICTED_LABEL_COL) From c282cf264b84d3db9ba6678dd3b0128d94186d92 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 22 Feb 2017 12:42:23 -0800 Subject: [PATCH 23/61] [SPARK-19666][SQL] Skip a property without getter in Java schema inference and allow empty bean in encoder creation ## What changes were proposed in this pull request? This PR proposes to fix two. 
**Skip a property without a getter in beans** Currently, if we use a JavaBean without the getter as below: ```java public static class BeanWithoutGetter implements Serializable { private String a; public void setA(String a) { this.a = a; } } BeanWithoutGetter bean = new BeanWithoutGetter(); List data = Arrays.asList(bean); spark.createDataFrame(data, BeanWithoutGetter.class).show(); ``` - Before It throws an exception as below: ``` java.lang.NullPointerException at org.spark_project.guava.reflect.TypeToken.method(TypeToken.java:465) at org.apache.spark.sql.catalyst.JavaTypeInference$$anonfun$2.apply(JavaTypeInference.scala:126) at org.apache.spark.sql.catalyst.JavaTypeInference$$anonfun$2.apply(JavaTypeInference.scala:125) ``` - After ``` ++ || ++ || ++ ``` **Supports empty bean in encoder creation** ```java public static class EmptyBean implements Serializable {} EmptyBean bean = new EmptyBean(); List data = Arrays.asList(bean); spark.createDataset(data, Encoders.bean(EmptyBean.class)).show(); ``` - Before throws an exception as below: ``` java.lang.UnsupportedOperationException: Cannot infer type for class EmptyBean because it is not bean-compliant at org.apache.spark.sql.catalyst.JavaTypeInference$.org$apache$spark$sql$catalyst$JavaTypeInference$$serializerFor(JavaTypeInference.scala:436) at org.apache.spark.sql.catalyst.JavaTypeInference$.serializerFor(JavaTypeInference.scala:341) ``` - After ``` ++ || ++ || ++ ``` ## How was this patch tested? Unit test in `JavaDataFrameSuite`. Author: hyukjinkwon Closes #17013 from HyukjinKwon/SPARK-19666. --- .../sql/catalyst/JavaTypeInference.scala | 45 +++++++++---------- .../org/apache/spark/sql/SQLContext.scala | 6 +-- .../org/apache/spark/sql/SparkSession.scala | 7 +-- .../apache/spark/sql/JavaDataFrameSuite.java | 17 +++++++ .../apache/spark/sql/JavaDatasetSuite.java | 11 +++++ 5 files changed, 54 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 8b53d988cbc59..e9d9508e5adfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -117,11 +117,10 @@ object JavaTypeInference { val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) - case _ => + case other => // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. 
- val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) - val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val properties = getJavaBeanReadableProperties(other) val fields = properties.map { property => val returnType = typeToken.method(property.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(returnType) @@ -131,10 +130,15 @@ object JavaTypeInference { } } - private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { val beanInfo = Introspector.getBeanInfo(beanClass) - beanInfo.getPropertyDescriptors - .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + .filter(_.getReadMethod != null) + } + + private def getJavaBeanReadableAndWritableProperties( + beanClass: Class[_]): Array[PropertyDescriptor] = { + getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null) } private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { @@ -298,9 +302,7 @@ object JavaTypeInference { keyData :: valueData :: Nil) case other => - val properties = getJavaBeanProperties(other) - assert(properties.length > 0) - + val properties = getJavaBeanReadableAndWritableProperties(other) val setters = properties.map { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType @@ -417,21 +419,16 @@ object JavaTypeInference { ) case other => - val properties = getJavaBeanProperties(other) - if (properties.length > 0) { - CreateNamedStruct(properties.flatMap { p => - val fieldName = p.getName - val fieldType = typeToken.method(p.getReadMethod).getReturnType - val fieldValue = Invoke( - inputObject, - p.getReadMethod.getName, - inferExternalType(fieldType.getRawType)) - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil - }) - } else { - throw new UnsupportedOperationException( - s"Cannot infer type for class ${other.getName} because it is not bean-compliant") - } + val properties = getJavaBeanReadableAndWritableProperties(other) + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil + }) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbe55090ea113..234ef2dffc6bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1090,14 +1090,14 @@ object SQLContext { */ private[sql] def beansToRows( data: Iterator[_], - beanInfo: BeanInfo, + beanClass: Class[_], attrs: Seq[AttributeReference]): Iterator[InternalRow] = { val extractors = - beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) + JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod) val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) } - data.map{ element => + data.map { element => new GenericInternalRow( methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) } ): InternalRow diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 72af55c1fa147..afc1827e7eece 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.beans.Introspector import java.io.Closeable import java.util.concurrent.atomic.AtomicReference @@ -347,8 +346,7 @@ class SparkSession private( val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. - val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) - SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) + SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq) } Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self)) } @@ -374,8 +372,7 @@ class SparkSession private( */ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { val attrSeq = getSchema(beanClass) - val beanInfo = Introspector.getBeanInfo(beanClass) - val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) + val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq) Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index c3b94a44c2e91..a8f814bfae530 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -397,4 +397,21 @@ public void testBloomFilter() { Assert.assertTrue(filter4.mightContain(i * 3)); } } + + public static class BeanWithoutGetter implements Serializable { + private String a; + + public void setA(String a) { + this.a = a; + } + } + + @Test + public void testBeanWithoutGetter() { + BeanWithoutGetter bean = new BeanWithoutGetter(); + List data = Arrays.asList(bean); + Dataset df = spark.createDataFrame(data, BeanWithoutGetter.class); + Assert.assertEquals(df.schema().length(), 0); + Assert.assertEquals(df.collectAsList().size(), 1); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 577672ca8e083..4581c6ebe9ef8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1276,4 +1276,15 @@ public void test() { spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class)); ds.collectAsList(); } + + public static class EmptyBean implements Serializable {} + + @Test + public void testEmptyBean() { + EmptyBean bean = new EmptyBean(); + List data = Arrays.asList(bean); + Dataset df = spark.createDataset(data, Encoders.bean(EmptyBean.class)); + Assert.assertEquals(df.schema().length(), 0); + Assert.assertEquals(df.collectAsList().size(), 1); + } } From 8812b971b73bdcca250c2cee3dca0d7093c8e3e4 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 22 Feb 2017 14:37:53 -0800 Subject: [PATCH 24/61] [SPARK-19554][UI,YARN] Allow SHS URL to be used for tracking in YARN RM. 
Allow an application to use the History Server URL as the tracking URL in the YARN RM, so there's still a link to the web UI somewhere in YARN even if the driver's UI is disabled. This is useful, for example, if an admin wants to disable the driver UI by default for applications, since it's harder to secure it (since it involves non trivial ssl certificate and auth management that admins may not want to expose to user apps). This needs to be opt-in, because of the way the YARN proxy works, so a new configuration was added to enable the option. The YARN RM will proxy requests to live AMs instead of redirecting the client, so pages in the SHS UI will not render correctly since they'll reference invalid paths in the RM UI. The proxy base support in the SHS cannot be used since that would prevent direct access to the SHS. So, to solve this problem, for the feature to work end-to-end, a new YARN-specific filter was added that detects whether the requests come from the proxy and redirects the client appropriatly. The SHS admin has to add this filter manually if they want the feature to work. Tested with new unit test, and by running with the documented configuration set in a test cluster. Also verified the driver UI is used when it's enabled. Author: Marcelo Vanzin Closes #16946 from vanzin/SPARK-19554. --- docs/running-on-yarn.md | 15 ++++ .../spark/deploy/yarn/ApplicationMaster.scala | 7 +- .../deploy/yarn/YarnProxyRedirectFilter.scala | 81 +++++++++++++++++++ .../spark/deploy/yarn/YarnRMClient.scala | 8 +- .../org/apache/spark/deploy/yarn/config.scala | 7 ++ .../yarn/YarnProxyRedirectFilterSuite.scala | 55 +++++++++++++ 6 files changed, 167 insertions(+), 6 deletions(-) create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilter.scala create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilterSuite.scala diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cf95b95afd2ee..e9ddaa76a797f 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -604,3 +604,18 @@ spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spn Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log will include a list of all tokens obtained, and their expiry details + +## Using the Spark History Server to replace the Spark Web UI + +It is possible to use the Spark History Server application page as the tracking URL for running +applications when the application UI is disabled. This may be desirable on secure clusters, or to +reduce the memory usage of the Spark driver. To set up tracking through the Spark History Server, +do the following: + +- On the application side, set spark.yarn.historyServer.allowTracking=true in Spark's + configuration. This will tell Spark to use the history server's URL as the tracking URL if + the application's UI is disabled. +- On the Spark History Server, add org.apache.spark.deploy.yarn.YarnProxyRedirectFilter + to the list of filters in the spark.ui.filters configuration. + +Be aware that the history server information may not be up-to-date with the application's state. 
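To make the two configuration steps above concrete, a minimal sketch follows. It is illustrative only; `spark.ui.enabled` is simply the standard switch for disabling the application UI rather than something introduced by this patch:

```scala
import org.apache.spark.SparkConf

// Application side: disable the driver UI and let YARN fall back to the SHS URL
// for tracking, via the spark.yarn.historyServer.allowTracking flag added here.
val conf = new SparkConf()
  .set("spark.ui.enabled", "false")
  .set("spark.yarn.historyServer.allowTracking", "true")

// History server side (in that server's own Spark configuration), install the
// redirect filter added below so SHS pages render correctly when reached
// through the YARN proxy:
//   spark.ui.filters  org.apache.spark.deploy.yarn.YarnProxyRedirectFilter
```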
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 9df43aea3f3d5..864c834d110fd 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -332,7 +332,7 @@ private[spark] class ApplicationMaster( _sparkConf: SparkConf, _rpcEnv: RpcEnv, driverRef: RpcEndpointRef, - uiAddress: String, + uiAddress: Option[String], securityMgr: SecurityManager) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() @@ -408,8 +408,7 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl).getOrElse(""), - securityMgr) + registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr) } else { // Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. @@ -435,7 +434,7 @@ private[spark] class ApplicationMaster( clientMode = true) val driverRef = waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), + registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"), securityMgr) // In client mode the actor will stop the reporter thread. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilter.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilter.scala new file mode 100644 index 0000000000000..ae625df75362a --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilter.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import javax.servlet._ +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import org.apache.spark.internal.Logging + +/** + * A filter to be used in the Spark History Server for redirecting YARN proxy requests to the + * main SHS address. This is useful for applications that are using the history server as the + * tracking URL, since the SHS-generated pages cannot be rendered in that case without extra + * configuration to set up a proxy base URI (meaning the SHS cannot be ever used directly). 
+ */ +class YarnProxyRedirectFilter extends Filter with Logging { + + import YarnProxyRedirectFilter._ + + override def destroy(): Unit = { } + + override def init(config: FilterConfig): Unit = { } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + + // The YARN proxy will send a request with the "proxy-user" cookie set to the YARN's client + // user name. We don't expect any other clients to set this cookie, since the SHS does not + // use cookies for anything. + Option(hreq.getCookies()).flatMap(_.find(_.getName() == COOKIE_NAME)) match { + case Some(_) => + doRedirect(hreq, res.asInstanceOf[HttpServletResponse]) + + case _ => + chain.doFilter(req, res) + } + } + + private def doRedirect(req: HttpServletRequest, res: HttpServletResponse): Unit = { + val redirect = req.getRequestURL().toString() + + // Need a client-side redirect instead of an HTTP one, otherwise the YARN proxy itself + // will handle the redirect and get into an infinite loop. + val content = s""" + | + | + | Spark History Server Redirect + | + | + | + |

<p>The requested page can be found at: <a href="$redirect">$redirect</a>.</p>
+ | + | + """.stripMargin + + logDebug(s"Redirecting YARN proxy request to $redirect.") + res.setStatus(HttpServletResponse.SC_OK) + res.setContentType("text/html") + res.getWriter().write(content) + } + +} + +private[spark] object YarnProxyRedirectFilter { + val COOKIE_NAME = "proxy-user" +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 163dfb5a605c8..53fb467f6408d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -55,7 +55,7 @@ private[spark] class YarnRMClient extends Logging { driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, - uiAddress: String, + uiAddress: Option[String], uiHistoryAddress: String, securityMgr: SecurityManager, localResources: Map[String, LocalResource] @@ -65,9 +65,13 @@ private[spark] class YarnRMClient extends Logging { amClient.start() this.uiHistoryAddress = uiHistoryAddress + val trackingUrl = uiAddress.getOrElse { + if (sparkConf.get(ALLOW_HISTORY_SERVER_TRACKING_URL)) uiHistoryAddress else "" + } + logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index f19a5b22a757d..d8c96c35ca71c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -82,6 +82,13 @@ package object config { .stringConf .createOptional + private[spark] val ALLOW_HISTORY_SERVER_TRACKING_URL = + ConfigBuilder("spark.yarn.historyServer.allowTracking") + .doc("Allow using the History Server URL for the application as the tracking URL for the " + + "application when the Web UI is not enabled.") + .booleanConf + .createWithDefault(false) + /* File distribution. */ private[spark] val SPARK_ARCHIVE = ConfigBuilder("spark.yarn.archive") diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilterSuite.scala new file mode 100644 index 0000000000000..54dbe9d50a68f --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnProxyRedirectFilterSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{PrintWriter, StringWriter} +import javax.servlet.FilterChain +import javax.servlet.http.{Cookie, HttpServletRequest, HttpServletResponse} + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite + +class YarnProxyRedirectFilterSuite extends SparkFunSuite { + + test("redirect proxied requests, pass-through others") { + val requestURL = "http://example.com:1234/foo?" + val filter = new YarnProxyRedirectFilter() + val cookies = Array(new Cookie(YarnProxyRedirectFilter.COOKIE_NAME, "dr.who")) + + val req = mock(classOf[HttpServletRequest]) + + // First request mocks a YARN proxy request (with the cookie set), second one has no cookies. + when(req.getCookies()).thenReturn(cookies, null) + when(req.getRequestURL()).thenReturn(new StringBuffer(requestURL)) + + val res = mock(classOf[HttpServletResponse]) + when(res.getWriter()).thenReturn(new PrintWriter(new StringWriter())) + + val chain = mock(classOf[FilterChain]) + + // First request is proxied. + filter.doFilter(req, res, chain) + verify(chain, never()).doFilter(req, res) + + // Second request is not, so should invoke the filter chain. + filter.doFilter(req, res, chain) + verify(chain, times(1)).doFilter(req, res) + } + +} From 0c2003b106c27cc4a68026b54ce16ab28f1e5b6e Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 22 Feb 2017 17:26:56 -0800 Subject: [PATCH 25/61] [SPARK-19658][SQL] Set NumPartitions of RepartitionByExpression In Parser ### What changes were proposed in this pull request? Currently, if `NumPartitions` is not set in RepartitionByExpression, we will set it using `spark.sql.shuffle.partitions` during Planner. However, this is not following the general resolution process. This PR is to set it in `Parser` and then `Optimizer` can use the value for plan optimization. ### How was this patch tested? Added a test case. Author: Xiao Li Closes #16988 from gatorsmile/resolveRepartition. 
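For illustration (a sketch, not part of the patch; df is assumed to be a DataFrame and t a registered table, each with a column named key), the affected plans are the ones produced by expression-based repartitioning. With this change, both forms below carry an explicit numPartitions taken from spark.sql.shuffle.partitions as soon as the plan is built, instead of having it filled in by the planner:

```
import org.apache.spark.sql.functions.col

// Sketch only: both produce a RepartitionByExpression node whose numPartitions is
// already set from spark.sql.shuffle.partitions.
val byApi = df.repartition(col("key"))
val bySql = spark.sql("SELECT * FROM t DISTRIBUTE BY key")
```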
--- .../spark/sql/catalyst/dsl/package.scala | 4 +-- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 16 ++++++++-- .../plans/logical/basicLogicalOperators.scala | 9 ++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 10 +++--- .../sql/catalyst/parser/PlanParserSuite.scala | 5 +-- .../scala/org/apache/spark/sql/Dataset.scala | 5 +-- .../spark/sql/execution/SparkSqlParser.scala | 17 ++++++++-- .../spark/sql/execution/SparkStrategies.scala | 6 ++-- .../sql/execution/SparkSqlParserSuite.scala | 32 +++++++++++++++++-- 10 files changed, 74 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3c531323397e4..c062e4e84bcdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -373,8 +373,8 @@ package object dsl { def repartition(num: Integer): LogicalPlan = Repartition(num, shuffle = true, logicalPlan) - def distribute(exprs: Expression*)(n: Int = -1): LogicalPlan = - RepartitionByExpression(exprs, logicalPlan, numPartitions = if (n < 0) None else Some(n)) + def distribute(exprs: Expression*)(n: Int): LogicalPlan = + RepartitionByExpression(exprs, logicalPlan, numPartitions = n) def analyze: LogicalPlan = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0c13e3e93a42c..af846a09a8d89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -578,7 +578,7 @@ object CollapseRepartition extends Rule[LogicalPlan] { RepartitionByExpression(exprs, child, numPartitions) // Case 3 case Repartition(numPartitions, _, r: RepartitionByExpression) => - r.copy(numPartitions = Some(numPartitions)) + r.copy(numPartitions = numPartitions) // Case 3 case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) => RepartitionByExpression(exprs, child, numPartitions) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 08a6dd136b857..926a37b363f1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -242,20 +242,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Sort(sort.asScala.map(visitSortItem), global = false, query) } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { // DISTRIBUTE BY ... - RepartitionByExpression(expressionList(distributeBy), query) + withRepartitionByExpression(ctx, expressionList(distributeBy), query) } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { // SORT BY ... DISTRIBUTE BY ... 
Sort( sort.asScala.map(visitSortItem), global = false, - RepartitionByExpression(expressionList(distributeBy), query)) + withRepartitionByExpression(ctx, expressionList(distributeBy), query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { // CLUSTER BY ... val expressions = expressionList(clusterBy) Sort( expressions.map(SortOrder(_, Ascending)), global = false, - RepartitionByExpression(expressions, query)) + withRepartitionByExpression(ctx, expressions, query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { // [EMPTY] query @@ -273,6 +273,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } + /** + * Create a clause for DISTRIBUTE BY. + */ + protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + throw new ParseException("DISTRIBUTE BY is not supported", ctx) + } + /** * Create a logical plan using a query specification. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index af57632516790..d17d12cd83091 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -844,18 +844,13 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) * information about the number of partitions during execution. Used when a specific ordering or * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like * `coalesce` and `repartition`. - * If `numPartitions` is not specified, the number of partitions will be the number set by - * `spark.sql.shuffle.partitions`. 
*/ case class RepartitionByExpression( partitionExpressions: Seq[Expression], child: LogicalPlan, - numPartitions: Option[Int] = None) extends UnaryNode { + numPartitions: Int) extends UnaryNode { - numPartitions match { - case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.") - case None => // Ok - } + require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 786e0f49b4b25..01737e0a17341 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,11 +21,12 @@ import java.util.TimeZone import org.scalatest.ShouldMatchers -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, Inner} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -192,12 +193,13 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { } test("pull out nondeterministic expressions from RepartitionByExpression") { - val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) + val plan = RepartitionByExpression(Seq(Rand(33)), testRelation, numPartitions = 10) val projected = Alias(Rand(33), "_nondeterministic")() val expected = Project(testRelation.output, RepartitionByExpression(Seq(projected.toAttribute), - Project(testRelation.output :+ projected, testRelation))) + Project(testRelation.output :+ projected, testRelation), + numPartitions = 10)) checkAnalysis(plan, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 2c1425242620e..67d5d2202b680 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -152,10 +152,7 @@ class PlanParserSuite extends PlanTest { val orderSortDistrClusterClauses = Seq( ("", basePlan), (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), - (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), - (" distribute by a, b", basePlan.distribute('a, 'b)()), - (" distribute by a sort by b", basePlan.distribute('a)().sortBy('b.asc)), - (" cluster by a, b", basePlan.distribute('a, 'b)().sortBy('a.asc, 'b.asc)) + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)) ) orderSortDistrClusterClauses.foreach { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 38a24cc8ed8c2..1ebc53d2bb84e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2410,7 +2410,7 @@ class Dataset[T] private[sql]( */ 
@scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions)) + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) } /** @@ -2425,7 +2425,8 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None) + RepartitionByExpression( + partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d50800235264f..1340aebc1ddd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -22,16 +22,17 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. @@ -1441,4 +1442,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { reader, writer, schemaLess) } + + /** + * Create a clause for DISTRIBUTE BY. + */ + override protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + RepartitionByExpression(expressions, query, conf.numShufflePartitions) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 557181ebd9590..0e3d5595df94b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -332,8 +332,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? 
object BasicOperators extends Strategy { - def numPartitions: Int = self.numPartitions - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommandExec(r) :: Nil @@ -414,9 +412,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range => execution.RangeExec(r) :: Nil - case logical.RepartitionByExpression(expressions, child, nPartitions) => + case logical.RepartitionByExpression(expressions, child, numPartitions) => exchange.ShuffleExchange(HashPartitioning( - expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil + expressions, numPartitions), planLater(child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 15e490fb30a27..bb6c486e880a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -36,7 +38,8 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType */ class SparkSqlParserSuite extends PlanTest { - private lazy val parser = new SparkSqlParser(new SQLConf) + val newConf = new SQLConf + private lazy val parser = new SparkSqlParser(newConf) /** * Normalizes plans: @@ -251,4 +254,29 @@ class SparkSqlParserSuite extends PlanTest { assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value", AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("t"))) + + assertEqual(s"$baseSql distribute by a, b", + RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, + basePlan, + numPartitions = newConf.numShufflePartitions)) + assertEqual(s"$baseSql distribute by a sort by b", + Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + global = false, + RepartitionByExpression(UnresolvedAttribute("a") :: Nil, + basePlan, + numPartitions = newConf.numShufflePartitions))) + assertEqual(s"$baseSql cluster by a, b", + 
Sort(SortOrder(UnresolvedAttribute("a"), Ascending) :: + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + global = false, + RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, + basePlan, + numPartitions = newConf.numShufflePartitions))) + } } From 4b5cf7679c04800277c8eb4bcc096801971ccdee Mon Sep 17 00:00:00 2001 From: "pj.fanning" Date: Wed, 22 Feb 2017 18:03:25 -0800 Subject: [PATCH 26/61] [SPARK-15615][SQL] Add an API to load DataFrame from Dataset[String] storing JSON ## What changes were proposed in this pull request? SPARK-15615 proposes replacing the sqlContext.read.json(rdd) with a dataset equivalent. SPARK-15463 adds a CSV API for reading from Dataset[String] so this keeps the API consistent. I am deprecating the existing RDD based APIs. ## How was this patch tested? There are existing tests. I left most tests to use the existing APIs as they delegate to the new json API. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: pj.fanning Author: PJ Fanning Closes #16895 from pjfanning/SPARK-15615. --- .../apache/spark/sql/DataFrameReader.scala | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index cb9493a575643..4c1341ed5da60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -323,6 +323,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param jsonRDD input RDD with one JSON object per record * @since 1.4.0 */ + @deprecated("Use json(Dataset[String]) instead.", "2.2.0") def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) /** @@ -335,7 +336,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param jsonRDD input RDD with one JSON object per record * @since 1.4.0 */ + @deprecated("Use json(Dataset[String]) instead.", "2.2.0") def json(jsonRDD: RDD[String]): DataFrame = { + json(sparkSession.createDataset(jsonRDD)(Encoders.STRING)) + } + + /** + * Loads a `Dataset[String]` storing JSON objects (JSON Lines + * text format or newline-delimited JSON) and returns the result as a `DataFrame`. + * + * Unless the schema is specified using `schema` function, this function goes through the + * input once to determine the input schema. + * + * @param jsonDataset input Dataset with one JSON object per record + * @since 2.2.0 + */ + def json(jsonDataset: Dataset[String]): DataFrame = { val parsedOptions = new JSONOptions( extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, @@ -344,12 +360,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val schema = userSpecifiedSchema.getOrElse { JsonInferSchema.infer( - jsonRDD, + jsonDataset.rdd, parsedOptions, createParser) } - val parsed = jsonRDD.mapPartitions { iter => + val parsed = jsonDataset.rdd.mapPartitions { iter => val parser = new JacksonParser(schema, parsedOptions) iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) } From 951323f6e0a0802e22d845144a4adfcd31765145 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 22 Feb 2017 20:03:01 -0800 Subject: [PATCH 27/61] [SPARK-16122][CORE] Add rest api for job environment ## What changes were proposed in this pull request? add rest api for job environment. ## How was this patch tested? 
existing ut. Author: uncleGen Closes #16949 from uncleGen/SPARK-16122. --- .../spark/status/api/v1/ApiRootResource.scala | 15 +++++++ .../v1/ApplicationEnvironmentResource.scala | 45 +++++++++++++++++++ .../org/apache/spark/status/api/v1/api.scala | 11 +++++ 3 files changed, 71 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 67ccf43afa44a..00f918c09c66b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -199,6 +199,21 @@ private[v1] class ApiRootResource extends ApiRequestContext { new VersionResource(uiRoot) } + @Path("applications/{appId}/environment") + def getEnvironment(@PathParam("appId") appId: String): ApplicationEnvironmentResource = { + withSparkUI(appId, None) { ui => + new ApplicationEnvironmentResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/environment") + def getEnvironment( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): ApplicationEnvironmentResource = { + withSparkUI(appId, Some(attemptId)) { ui => + new ApplicationEnvironmentResource(ui) + } + } } private[spark] object ApiRootResource { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala new file mode 100644 index 0000000000000..739a8aceae861 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs._ +import javax.ws.rs.core.MediaType + +import org.apache.spark.ui.SparkUI + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class ApplicationEnvironmentResource(ui: SparkUI) { + + @GET + def getEnvironmentInfo(): ApplicationEnvironmentInfo = { + val listener = ui.environmentListener + listener.synchronized { + val jvmInfo = Map(listener.jvmInformation: _*) + val runtime = new RuntimeInfo( + jvmInfo("Java Version"), + jvmInfo("Java Home"), + jvmInfo("Scala Version")) + + new ApplicationEnvironmentInfo( + runtime, + listener.sparkProperties, + listener.systemProperties, + listener.classpathEntries) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index c509398db1ecf..5b9227350edaa 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -252,3 +252,14 @@ class AccumulableInfo private[spark]( class VersionInfo private[spark]( val spark: String) + +class ApplicationEnvironmentInfo private[spark] ( + val runtime: RuntimeInfo, + val sparkProperties: Seq[(String, String)], + val systemProperties: Seq[(String, String)], + val classpathEntries: Seq[(String, String)]) + +class RuntimeInfo private[spark]( + val javaVersion: String, + val javaHome: String, + val scalaVersion: String) From 46f4a19d0f383740e2114723f8b9b3879952e103 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 22 Feb 2017 21:39:20 -0800 Subject: [PATCH 28/61] [SPARK-19695][SQL] Throw an exception if a `columnNameOfCorruptRecord` field violates requirements in json formats ## What changes were proposed in this pull request? This pr comes from #16928 and fixed a json behaviour along with the CSV one. ## How was this patch tested? Added tests in `JsonSuite`. Author: Takeshi Yamamuro Closes #17023 from maropu/SPARK-19695. 
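To make the new requirement concrete (an illustrative sketch mirroring the test added below; column names and the path are placeholders), declaring the corrupt-record column with a non-string type is now rejected with an AnalysisException on the driver side:

```
import org.apache.spark.sql.types._

// Sketch: the column named by columnNameOfCorruptRecord must be a nullable StringType.
val badSchema = new StructType()
  .add("_corrupt", IntegerType)   // violates the requirement: not a string column
  .add("a", StringType)

spark.read
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt")
  .schema(badSchema)
  .json("/path/to/records.json")
  .collect()  // fails: "The field for corrupt records must be string type and nullable"
```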
--- .../sql/catalyst/json/JacksonParser.scala | 5 ++- .../apache/spark/sql/DataFrameReader.scala | 11 ++++++- .../datasources/json/JsonFileFormat.scala | 13 ++++++-- .../datasources/json/JsonSuite.scala | 31 +++++++++++++++++++ 4 files changed, 56 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 995095969d7af..9b80c0fc87c93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -58,7 +58,10 @@ class JacksonParser( private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach(idx => require(schema(idx).dataType == StringType)) + corruptFieldIndex.foreach { corrFieldIndex => + require(schema(corrFieldIndex).dataType == StringType) + require(schema(corrFieldIndex).nullable) + } @transient private[this] var isWarningPrinted: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4c1341ed5da60..2be22761e8dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.JsonInferSchema -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String /** @@ -365,6 +365,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { createParser) } + // Check a field requirement for corrupt records here to throw an exception in a driver side + schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + val parsed = jsonDataset.rdd.mapPartitions { iter => val parser = new JacksonParser(schema, parsedOptions) iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 2cbf4ea7beaca..902fee5a7e3f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -22,13 +22,13 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs import 
org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { @@ -102,6 +102,15 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) + // Check a field requirement for corrupt records here to throw an exception in a driver side + dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = dataSchema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + (file: PartitionedFile) => { val parser = new JacksonParser(requiredSchema, parsedOptions) JsonDataSource(parsedOptions).readFile( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 05aa2ab2ce2d0..0e01be2410409 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1944,4 +1944,35 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) } } + + test("Throw an exception if a `columnNameOfCorruptRecord` field violates requirements") { + val columnNameOfCorruptRecord = "_unparsed" + val schema = StructType( + StructField(columnNameOfCorruptRecord, IntegerType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + val errMsg = intercept[AnalysisException] { + spark.read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema) + .json(corruptRecords) + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + + withTempPath { dir => + val path = dir.getCanonicalPath + corruptRecords.toDF("value").write.text(path) + val errMsg = intercept[AnalysisException] { + spark.read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema) + .json(path) + .collect + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + } + } } From a0ce01ec28ccbe02e845eb9add43eac55f30c730 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 23 Feb 2017 16:28:36 +0100 Subject: [PATCH 29/61] [SPARK-19691][SQL] Fix ClassCastException when calculating percentile of decimal column ## What changes were proposed in this pull request? 
This pr fixed a class-cast exception below; ``` scala> spark.range(10).selectExpr("cast (id as decimal) as x").selectExpr("percentile(x, 0.5)").collect() java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:141) at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:58) at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:514) at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171) at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171) at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187) at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:151) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.(ObjectAggregationIterator.scala:78) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:109) at ``` This fix simply converts catalyst values (i.e., `Decimal`) into scala ones by using `CatalystTypeConverters`. ## How was this patch tested? Added a test in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #17028 from maropu/SPARK-19691. --- .../expressions/aggregate/Percentile.scala | 48 ++++++++++--------- .../aggregate/PercentileSuite.scala | 29 ++++++----- .../org/apache/spark/sql/DataFrameSuite.scala | 5 ++ 3 files changed, 45 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 6b7cf7991d39d..8433a93ea3032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ @@ -61,7 +61,7 @@ case class Percentile( frequencyExpression : Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes { + extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes { def this(child: Expression, percentageExpression: Expression) = { this(child, percentageExpression, Literal(1L), 0, 0) @@ -130,15 +130,20 @@ case class Percentile( } } - override def createAggregationBuffer(): OpenHashMap[Number, Long] = { + private def toDoubleValue(d: Any): Double = d match { + case d: Decimal => 
d.toDouble + case n: Number => n.doubleValue + } + + override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = { // Initialize new counts map instance here. - new OpenHashMap[Number, Long]() + new OpenHashMap[AnyRef, Long]() } override def update( - buffer: OpenHashMap[Number, Long], - input: InternalRow): OpenHashMap[Number, Long] = { - val key = child.eval(input).asInstanceOf[Number] + buffer: OpenHashMap[AnyRef, Long], + input: InternalRow): OpenHashMap[AnyRef, Long] = { + val key = child.eval(input).asInstanceOf[AnyRef] val frqValue = frequencyExpression.eval(input) // Null values are ignored in counts map. @@ -155,32 +160,32 @@ case class Percentile( } override def merge( - buffer: OpenHashMap[Number, Long], - other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = { + buffer: OpenHashMap[AnyRef, Long], + other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = { other.foreach { case (key, count) => buffer.changeValue(key, count, _ + count) } buffer } - override def eval(buffer: OpenHashMap[Number, Long]): Any = { + override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { generateOutput(getPercentiles(buffer)) } - private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = { + private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = { if (buffer.isEmpty) { return Seq.empty } val sortedCounts = buffer.toSeq.sortBy(_._1)( - child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]]) + child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]]) val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { case ((key1, count1), (key2, count2)) => (key2, count1 + count2) }.tail val maxPosition = accumlatedCounts.last._2 - 1 percentages.map { percentile => - getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue() + getPercentile(accumlatedCounts, maxPosition * percentile) } } @@ -200,7 +205,7 @@ case class Percentile( * This function has been based upon similar function from HIVE * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`. 
*/ - private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { + private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = { // We may need to do linear interpolation to get the exact percentile val lower = position.floor.toLong val higher = position.ceil.toLong @@ -213,18 +218,17 @@ case class Percentile( val lowerKey = aggreCounts(lowerIndex)._1 if (higher == lower) { // no interpolation needed because position does not have a fraction - return lowerKey + return toDoubleValue(lowerKey) } val higherKey = aggreCounts(higherIndex)._1 if (higherKey == lowerKey) { // no interpolation needed because lower position and higher position has the same key - return lowerKey + return toDoubleValue(lowerKey) } // Linear interpolation to get the exact percentile - return (higher - position) * lowerKey.doubleValue() + - (position - lower) * higherKey.doubleValue() + (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey) } /** @@ -238,7 +242,7 @@ case class Percentile( } } - override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = { + override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = { val buffer = new Array[Byte](4 << 10) // 4K val bos = new ByteArrayOutputStream() val out = new DataOutputStream(bos) @@ -261,11 +265,11 @@ case class Percentile( } } - override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = { + override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = { val bis = new ByteArrayInputStream(bytes) val ins = new DataInputStream(bis) try { - val counts = new OpenHashMap[Number, Long] + val counts = new OpenHashMap[AnyRef, Long] // Read unsafeRow size and content in bytes. var sizeOfNextRow = ins.readInt() while (sizeOfNextRow >= 0) { @@ -274,7 +278,7 @@ case class Percentile( val row = new UnsafeRow(2) row.pointTo(bs, sizeOfNextRow) // Insert the pairs into counts map. - val key = row.get(0, child.dataType).asInstanceOf[Number] + val key = row.get(0, child.dataType) val count = row.get(1, LongType).asInstanceOf[Long] counts.update(key, count) sizeOfNextRow = ins.readInt() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 1533fe5f90ee2..2420ba513f287 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ @@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite { val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5)) // Check empty serialize and deserialize - val buffer = new OpenHashMap[Number, Long]() + val buffer = new OpenHashMap[AnyRef, Long]() assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) // Check non-empty buffer serializa and deserialize. 
data.foreach { key => - buffer.changeValue(key, 1L, _ + 1L) + buffer.changeValue(new Integer(key), 1L, _ + 1L) } assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } @@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite { val agg = new Percentile(childExpression, percentageExpression) // Test with rows without frequency - val rows = (1 to count).map( x => Seq(x)) - runTest( agg, rows, expectedPercentiles) + val rows = (1 to count).map(x => Seq(x)) + runTest(agg, rows, expectedPercentiles) // Test with row with frequency. Second and third columns are frequency in Int and Long val countForFrequencyTest = 1000 - val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong) + val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong) val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0) val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false) val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt) - runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency) + runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency) val frequencyExpressionLong = BoundReference(2, LongType, nullable = false) val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong) - runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency) + runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency) // Run test with Flatten data - val flattenRows = (1 to countForFrequencyTest).flatMap( current => - (1 to current).map( y => current )).map( Seq(_)) + val flattenRows = (1 to countForFrequencyTest).flatMap(current => + (1 to current).map(y => current )).map(Seq(_)) runTest(agg, flattenRows, expectedPercentilesWithFrquency) } @@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite { } val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType) - for ( dataType <- validDataTypes; + for (dataType <- validDataTypes; frequencyType <- validFrequencyTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite { StringType, DateType, TimestampType, CalendarIntervalType, NullType) - for( dataType <- invalidDataTypes; + for(dataType <- invalidDataTypes; frequencyType <- validFrequencyTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite { s"'`a`' is of ${dataType.simpleString} type.")) } - for( dataType <- validDataTypes; + for(dataType <- validDataTypes; frequencyType <- invalidFrequencyDataTypes) { val child = AttributeReference("a", dataType)() val frq = AttributeReference("frq", frequencyType)() @@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite { agg.update(buffer, InternalRow(1, -5)) agg.eval(buffer) } - assert( caught.getMessage.startsWith("Negative values found in ")) + assert(caught.getMessage.startsWith("Negative values found in ")) } private def compareEquals( - left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = { + left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = { left.size == right.size && left.forall { case (key, count) => right.apply(key) == count } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e6338ab7cd800..5e65436079db2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j") checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil) } + + test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") { + val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)") + checkAnswer(df, Row(BigDecimal(0.0)) :: Nil) + } } From ae4a6971e55837ef1a5a7ef7cdc2a0086695b46b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 23 Feb 2017 10:25:18 -0800 Subject: [PATCH 30/61] [SPARK-19459] Support for nested char/varchar fields in ORC ## What changes were proposed in this pull request? This PR is a small follow-up on https://github.com/apache/spark/pull/16804. This PR also adds support for nested char/varchar fields in orc. ## How was this patch tested? I have added a regression test to the OrcSourceSuite. Author: Herman van Hovell Closes #17030 from hvanhovell/SPARK-19459-follow-up. --- .../sql/catalyst/parser/AstBuilder.scala | 34 +++++---- .../spark/sql/types/HiveStringType.scala | 73 +++++++++++++++++++ .../spark/sql/hive/orc/OrcSourceSuite.scala | 12 ++- 3 files changed, 100 insertions(+), 19 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 926a37b363f1b..d2e091f4dda69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -76,7 +76,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { - visit(ctx.dataType).asInstanceOf[DataType] + visitSparkDataType(ctx.dataType) } /* ******************************************************************************************** @@ -1006,7 +1006,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[Cast]] expression. */ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { - Cast(expression(ctx.expression), typedVisit(ctx.dataType)) + Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } /** @@ -1424,6 +1424,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { /* ******************************************************************************************** * DataType parsing * ******************************************************************************************** */ + /** + * Create a Spark DataType. + */ + private def visitSparkDataType(ctx: DataTypeContext): DataType = { + HiveStringType.replaceCharType(typedVisit(ctx)) + } + /** * Resolve/create a primitive type. 
*/ @@ -1438,8 +1445,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case ("double", Nil) => DoubleType case ("date", Nil) => DateType case ("timestamp", Nil) => TimestampType - case ("char" | "varchar" | "string", Nil) => StringType - case ("char" | "varchar", _ :: Nil) => StringType + case ("string", Nil) => StringType + case ("char", length :: Nil) => CharType(length.getText.toInt) + case ("varchar", length :: Nil) => VarcharType(length.getText.toInt) case ("binary", Nil) => BinaryType case ("decimal", Nil) => DecimalType.USER_DEFAULT case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) @@ -1461,7 +1469,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case SqlBaseParser.MAP => MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) case SqlBaseParser.STRUCT => - createStructType(ctx.complexColTypeList()) + StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList)) } } @@ -1480,7 +1488,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a [[StructField]] from a column definition. + * Create a top level [[StructField]] from a column definition. */ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { import ctx._ @@ -1491,19 +1499,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { builder.putString("comment", string(STRING)) } // Add Hive type string to metadata. - dataType match { - case p: PrimitiveDataTypeContext => - p.identifier.getText.toLowerCase match { - case "varchar" | "char" => - builder.putString(HIVE_TYPE_STRING, dataType.getText.toLowerCase) - case _ => - } - case _ => + val rawDataType = typedVisit[DataType](ctx.dataType) + val cleanedDataType = HiveStringType.replaceCharType(rawDataType) + if (rawDataType != cleanedDataType) { + builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString) } StructField( identifier.getText, - typedVisit(dataType), + cleanedDataType, nullable = true, builder.build()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala new file mode 100644 index 0000000000000..b319eb70bc13c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.unsafe.types.UTF8String + +/** + * A hive string type for compatibility. These datatypes should only used for parsing, + * and should NOT be used anywhere else. 
Any instance of these data types should be + * replaced by a [[StringType]] before analysis. + */ +sealed abstract class HiveStringType extends AtomicType { + private[sql] type InternalType = UTF8String + + private[sql] val ordering = implicitly[Ordering[InternalType]] + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { + typeTag[InternalType] + } + + override def defaultSize: Int = length + + private[spark] override def asNullable: HiveStringType = this + + def length: Int +} + +object HiveStringType { + def replaceCharType(dt: DataType): DataType = dt match { + case ArrayType(et, nullable) => + ArrayType(replaceCharType(et), nullable) + case MapType(kt, vt, nullable) => + MapType(replaceCharType(kt), replaceCharType(vt), nullable) + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = replaceCharType(field.dataType)) + }) + case _: HiveStringType => StringType + case _ => dt + } +} + +/** + * Hive char type. + */ +case class CharType(length: Int) extends HiveStringType { + override def simpleString: String = s"char($length)" +} + +/** + * Hive varchar type. + */ +case class VarcharType(length: Int) extends HiveStringType { + override def simpleString: String = s"varchar($length)" +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 59ea8916efae9..11dda5425cf94 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -162,13 +162,16 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA |CREATE EXTERNAL TABLE hive_orc( | a STRING, | b CHAR(10), - | c VARCHAR(10)) + | c VARCHAR(10), + | d ARRAY) |STORED AS orc""".stripMargin) // Hive throws an exception if I assign the location in the create table statement. hiveClient.runSqlHive( s"ALTER TABLE hive_orc SET LOCATION '$uri'") hiveClient.runSqlHive( - "INSERT INTO TABLE hive_orc SELECT 'a', 'b', 'c' FROM (SELECT 1) t") + """INSERT INTO TABLE hive_orc + |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) + |FROM (SELECT 1) t""".stripMargin) // We create a different table in Spark using the same schema which points to // the same location. @@ -177,10 +180,11 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA |CREATE EXTERNAL TABLE spark_orc( | a STRING, | b CHAR(10), - | c VARCHAR(10)) + | c VARCHAR(10), + | d ARRAY) |STORED AS orc |LOCATION '$uri'""".stripMargin) - val result = Row("a", "b ", "c") + val result = Row("a", "b ", "c", Seq("d ")) checkAnswer(spark.table("hive_orc"), result) checkAnswer(spark.table("spark_orc"), result) } finally { From 027cd9c41134358ae21356ae65217a978791fafb Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 23 Feb 2017 11:12:02 -0800 Subject: [PATCH 31/61] [SPARK-19682][SPARKR] Issue warning (or error) when subset method "[[" takes vector index ## What changes were proposed in this pull request? The `[[` method is supposed to take a single index and return a column. This is different from base R which takes a vector index. We should check for this and issue warning or error when vector index is supplied (which is very likely given the behavior in base R). Currently I'm issuing a warning message and just take the first element of the vector index. We could change this to an error it that's better. ## How was this patch tested? 
new tests Author: actuaryzhang Closes #17017 from actuaryzhang/sparkRSubsetter. --- R/pkg/R/DataFrame.R | 8 ++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index cf331bab47c66..cc4cfa3423ced 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1804,6 +1804,10 @@ setClassUnion("numericOrcharacter", c("numeric", "character")) #' @note [[ since 1.4.0 setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), function(x, i) { + if (length(i) > 1) { + warning("Subset index has length > 1. Only the first index is used.") + i <- i[1] + } if (is.numeric(i)) { cols <- columns(x) i <- cols[[i]] @@ -1817,6 +1821,10 @@ setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), #' @note [[<- since 2.1.1 setMethod("[[<-", signature(x = "SparkDataFrame", i = "numericOrcharacter"), function(x, i, value) { + if (length(i) > 1) { + warning("Subset index has length > 1. Only the first index is used.") + i <- i[1] + } if (is.numeric(i)) { cols <- columns(x) i <- cols[[i]] diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index a7259f362ebeb..ce0f5a198a259 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1015,6 +1015,18 @@ test_that("select operators", { expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") + expect_warning(df[[1:2]], + "Subset index has length > 1. Only the first index is used.") + expect_is(suppressWarnings(df[[1:2]]), "Column") + expect_warning(df[[c("name", "age")]], + "Subset index has length > 1. Only the first index is used.") + expect_is(suppressWarnings(df[[c("name", "age")]]), "Column") + + expect_warning(df[[1:2]] <- df[[1]], + "Subset index has length > 1. Only the first index is used.") + expect_warning(df[[c("name", "age")]] <- df[[1]], + "Subset index has length > 1. Only the first index is used.") + expect_is(df[, 1, drop = F], "SparkDataFrame") expect_equal(columns(df[, 1, drop = F]), c("name")) expect_equal(columns(df[, "age", drop = F]), c("age")) From ee4366b5bebfa91437ad8186c08b417910a6c281 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 23 Feb 2017 11:25:39 -0800 Subject: [PATCH 32/61] [SPARK-19497][SS] Implement streaming deduplication ## What changes were proposed in this pull request? This PR adds a special streaming deduplication operator to support `dropDuplicates` with `aggregation` and watermark. It reuses the `dropDuplicates` API but creates new logical plan `Deduplication` and new physical plan `DeduplicationExec`. The following cases are supported: - one or multiple `dropDuplicates()` without aggregation (with or without watermark) - `dropDuplicates` before aggregation Not supported cases: - `dropDuplicates` after aggregation Breaking changes: - `dropDuplicates` without aggregation doesn't work with `complete` or `update` mode. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #16970 from zsxwing/dedup. 
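
To make the supported pattern concrete, here is a minimal sketch of watermarked deduplication feeding a windowed aggregation in append mode. It is illustrative only: the `events` DataFrame and its `userId`/`eventTime` columns are assumed inputs, not something defined by this patch.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// `events` is an assumed streaming DataFrame with a `userId: String` column and an
// `eventTime: Timestamp` column; neither is defined by this patch.
def dedupedCounts(events: DataFrame): DataFrame = {
  events
    .withWatermark("eventTime", "10 minutes")  // bounds how much deduplication state is kept
    .dropDuplicates("userId", "eventTime")     // planned as the new streaming deduplicate operator
    .groupBy(window(col("eventTime"), "5 minutes"))
    .count()
}

// Append mode is allowed because dropDuplicates runs before the aggregation:
// dedupedCounts(events).writeStream.outputMode("append").format("console").start()
```

The reverse order, `dropDuplicates` after the aggregation, is one of the unsupported cases listed above and is rejected by the `UnsupportedOperationChecker`.
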
--- python/pyspark/sql/dataframe.py | 6 + .../UnsupportedOperationChecker.scala | 6 +- .../sql/catalyst/optimizer/Optimizer.scala | 21 +- .../plans/logical/basicLogicalOperators.scala | 9 + .../analysis/UnsupportedOperationsSuite.scala | 56 +++- .../optimizer/ReplaceOperatorSuite.scala | 33 ++- .../scala/org/apache/spark/sql/Dataset.scala | 39 ++- .../spark/sql/execution/SparkStrategies.scala | 15 +- .../streaming/IncrementalExecution.scala | 10 + .../streaming/statefulOperators.scala | 140 +++++++--- .../sql/streaming/DeduplicateSuite.scala | 252 ++++++++++++++++++ .../streaming/MapGroupsWithStateSuite.scala | 9 +- .../sql/streaming/StateStoreMetricsTest.scala | 36 +++ .../spark/sql/streaming/StreamSuite.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 2 +- 15 files changed, 578 insertions(+), 58 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 70efeaf0160c9..bb6df22682095 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1158,6 +1158,12 @@ def dropDuplicates(self, subset=None): """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. + For a static batch :class:`DataFrame`, it just drops duplicate rows. For a streaming + :class:`DataFrame`, it will keep all data across triggers as intermediate state to drop + duplicates rows. You can use :func:`withWatermark` to limit how late the duplicate data can + be and system will accordingly limit the state. In addition, too late data older than + watermark will be dropped to avoid any possibility of duplicates. + :func:`drop_duplicates` is an alias for :func:`dropDuplicates`. 
>>> from pyspark.sql import Row diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 07b3558ee2f56..397f5cfe2a540 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -75,7 +75,7 @@ object UnsupportedOperationChecker { if (watermarkAttributes.isEmpty) { throwError( s"$outputMode output mode not supported when there are streaming aggregations on " + - s"streaming DataFrames/DataSets")(plan) + s"streaming DataFrames/DataSets without watermark")(plan) } case InternalOutputModes.Complete if aggregates.isEmpty => @@ -120,6 +120,10 @@ object UnsupportedOperationChecker { throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " + "streaming DataFrame/Dataset") + case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => + throwError("dropDuplicates is not supported after aggregation on a " + + "streaming DataFrame/Dataset") + case Join(left, right, joinType, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index af846a09a8d89..036da3ad2062f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -56,7 +56,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) ReplaceExpressions, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), - RewriteDistinctAggregates) :: + RewriteDistinctAggregates, + ReplaceDeduplicateWithAggregate) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// @@ -1142,6 +1143,24 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator. + */ +object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Deduplicate(keys, child, streaming) if !streaming => + val keyExprIds = keys.map(_.exprId) + val aggCols = child.output.map { attr => + if (keyExprIds.contains(attr.exprId)) { + attr + } else { + Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) + } + } + Aggregate(keys, aggCols, child) + } +} + /** * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator. 
* {{{ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d17d12cd83091..ce1c55dc089e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -864,3 +864,12 @@ case object OneRowRelation extends LeafNode { override def output: Seq[Attribute] = Nil override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1) } + +/** A logical plan for `dropDuplicates`. */ +case class Deduplicate( + keys: Seq[Attribute], + child: LogicalPlan, + streaming: Boolean) extends UnaryNode { + + override def output: Seq[Attribute] = child.output +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 3b756e89d9036..82be69a0f7d7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{IntegerType, LongType} +import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} +import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. 
*/ case class DummyCommand() extends Command @@ -36,6 +37,11 @@ case class DummyCommand() extends Command class UnsupportedOperationsSuite extends SparkFunSuite { val attribute = AttributeReference("a", IntegerType, nullable = true)() + val watermarkMetadata = new MetadataBuilder() + .withMetadata(attribute.metadata) + .putLong(EventTimeWatermark.delayKey, 1000L) + .build() + val attributeWithWatermark = attribute.withMetadata(watermarkMetadata) val batchRelation = LocalRelation(attribute) val streamRelation = new TestStreamingRelation(attribute) @@ -98,6 +104,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Update, expectedMsgs = Seq("multiple streaming aggregations")) + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations in update mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Update) + + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations in complete mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Complete) + + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations with watermark in append mode", + Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "aggregate - streaming aggregations without watermark in append mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + // Aggregation: Distinct aggregates not supported on streaming relation val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c")) assertSupportedInStreamingPlan( @@ -129,6 +156,33 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("(map/flatMap)GroupsWithState")) + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation), + outputMode = Append + ) + + // Deduplicate + assertSupportedInStreamingPlan( + "Deduplicate - Deduplicate on streaming relation before aggregation", + Aggregate( + Seq(attributeWithWatermark), + aggExprs("c"), + Deduplicate(Seq(att), streamRelation, streaming = true)), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "Deduplicate - Deduplicate on streaming relation after aggregation", + Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation), streaming = true), + outputMode = Complete, + expectedMsgs = Seq("dropDuplicates")) + + assertSupportedInStreamingPlan( + "Deduplicate - Deduplicate on batch relation inside a streaming query", + Deduplicate(Seq(att), batchRelation, streaming = false), + outputMode = Append + ) + // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index f23e262f286b8..e68423f85c92e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import 
org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -30,7 +32,8 @@ class ReplaceOperatorSuite extends PlanTest { Batch("Replace Operators", FixedPoint(100), ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, - ReplaceIntersectWithSemiJoin) :: Nil + ReplaceIntersectWithSemiJoin, + ReplaceDeduplicateWithAggregate) :: Nil } test("replace Intersect with Left-semi Join") { @@ -71,4 +74,32 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("replace batch Deduplicate with Aggregate") { + val input = LocalRelation('a.int, 'b.int) + val attrA = input.output(0) + val attrB = input.output(1) + val query = Deduplicate(Seq(attrA), input, streaming = false) // dropDuplicates("a") + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate( + Seq(attrA), + Seq( + attrA, + Alias(new First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId) + ), + input) + + comparePlans(optimized, correctAnswer) + } + + test("don't replace streaming Deduplicate") { + val input = LocalRelation('a.int, 'b.int) + val attrA = input.output(0) + val query = Deduplicate(Seq(attrA), input, streaming = true) // dropDuplicates("a") + val optimized = Optimize.execute(query.analyze) + + comparePlans(optimized, query) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1ebc53d2bb84e..3c212d656e371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -557,7 +557,8 @@ class Dataset[T] private[sql]( * Spark will use this watermark for several purposes: * - To know when a given time window aggregation can be finalized and thus can be emitted when * using output modes that do not allow updates. - * - To minimize the amount of state that we need to keep for on-going aggregations. + * - To minimize the amount of state that we need to keep for on-going aggregations, + * `mapGroupsWithState` and `dropDuplicates` operators. * * The current watermark is computed by looking at the `MAX(eventTime)` seen across * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost @@ -1981,6 +1982,12 @@ class Dataset[T] private[sql]( * Returns a new Dataset that contains only the unique rows from this Dataset. * This is an alias for `distinct`. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ @@ -1990,13 +1997,19 @@ class Dataset[T] private[sql]( * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. 
You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { val resolver = sparkSession.sessionState.analyzer.resolver val allColumns = queryExecution.analyzed.output - val groupCols = colNames.flatMap { colName => + val groupCols = colNames.toSet.toSeq.flatMap { (colName: String) => // It is possibly there are more than one columns with the same name, // so we call filter instead of find. val cols = allColumns.filter(col => resolver(col.name, colName)) @@ -2006,21 +2019,19 @@ class Dataset[T] private[sql]( } cols } - val groupColExprIds = groupCols.map(_.exprId) - val aggCols = logicalPlan.output.map { attr => - if (groupColExprIds.contains(attr.exprId)) { - attr - } else { - Alias(new First(attr).toAggregateExpression(), attr.name)() - } - } - Aggregate(groupCols, aggCols, logicalPlan) + Deduplicate(groupCols, logicalPlan, isStreaming) } /** * Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ @@ -2030,6 +2041,12 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. 
+ * * @group typedrel * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0e3d5595df94b..027b1481af96b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,9 +22,10 @@ import org.apache.spark.sql.{SaveMode, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -244,6 +245,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming deduplicate operator. + */ + object StreamingDeduplicationStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case Deduplicate(keys, child, true) => + StreamingDeduplicateExec(keys, planLater(child)) :: Nil + + case _ => Nil + } + } + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a3e108b29eda6..ffdcd9b19d058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -45,6 +45,7 @@ class IncrementalExecution( sparkSession.sessionState.planner.StatefulAggregationStrategy +: sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: + sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies // Modified planner with stateful operations. 
@@ -93,6 +94,15 @@ class IncrementalExecution( keys, Some(stateId), child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + + StreamingDeduplicateExec( + keys, + child, + Some(stateId), + Some(currentEventTimeWatermark)) case MapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => val stateId = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 1292452574594..d92529748b6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, NullType, StructType} import org.apache.spark.util.CompletionIterator @@ -68,6 +67,40 @@ trait StateStoreWriter extends StatefulOperator { "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) } +/** An operator that supports watermark. */ +trait WatermarkSupport extends SparkPlan { + + /** The keys that may have a watermark attribute. */ + def keyExpressions: Seq[Attribute] + + /** The watermark value. */ + def eventTimeWatermark: Option[Long] + + /** Generate a predicate that matches data older than the watermark */ + lazy val watermarkPredicate: Option[Predicate] = { + val optionalWatermarkAttribute = + keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + + optionalWatermarkAttribute.map { watermarkAttribute => + // If we are evicting based on a window, use the end of the window. Otherwise just + // use the attribute itself. + val evictionExpression = + if (watermarkAttribute.dataType.isInstanceOf[StructType]) { + LessThanOrEqual( + GetStructField(watermarkAttribute, 1), + Literal(eventTimeWatermark.get * 1000)) + } else { + LessThanOrEqual( + watermarkAttribute, + Literal(eventTimeWatermark.get * 1000)) + } + + logInfo(s"Filtering state store on: $evictionExpression") + newPredicate(evictionExpression, keyExpressions) + } + } +} + /** * For each input tuple, the key is calculated and the value from the [[StateStore]] is added * to the stream (in addition to the input tuple) if present. 
@@ -76,7 +109,7 @@ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], child: SparkPlan) - extends execution.UnaryExecNode with StateStoreReader { + extends UnaryExecNode with StateStoreReader { override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -113,31 +146,7 @@ case class StateStoreSaveExec( outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) - extends execution.UnaryExecNode with StateStoreWriter { - - /** Generate a predicate that matches data older than the watermark */ - private lazy val watermarkPredicate: Option[Predicate] = { - val optionalWatermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) - - optionalWatermarkAttribute.map { watermarkAttribute => - // If we are evicting based on a window, use the end of the window. Otherwise just - // use the attribute itself. - val evictionExpression = - if (watermarkAttribute.dataType.isInstanceOf[StructType]) { - LessThanOrEqual( - GetStructField(watermarkAttribute, 1), - Literal(eventTimeWatermark.get * 1000)) - } else { - LessThanOrEqual( - watermarkAttribute, - Literal(eventTimeWatermark.get * 1000)) - } - - logInfo(s"Filtering state store on: $evictionExpression") - newPredicate(evictionExpression, keyExpressions) - } - } + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -146,8 +155,8 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, + getStateId.operatorId, + getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, @@ -262,8 +271,8 @@ case class MapGroupsWithStateExec( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, + getStateId.operatorId, + getStateId.batchId, groupingAttributes.toStructType, child.output.toStructType, sqlContext.sessionState, @@ -321,3 +330,70 @@ case class MapGroupsWithStateExec( } } } + + +/** Physical operator for executing streaming Deduplicate. 
*/ +case class StreamingDeduplicateExec( + keyExpressions: Seq[Attribute], + child: SparkPlan, + stateId: Option[OperatorStateId] = None, + eventTimeWatermark: Option[Long] = None) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + /** Distribute by grouping attributes */ + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(keyExpressions) :: Nil + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + getStateId.operatorId, + getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + val baseIterator = watermarkPredicate match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + val result = baseIterator.filter { r => + val row = r.asInstanceOf[UnsafeRow] + val key = getKey(row) + val value = store.get(key) + if (value.isEmpty) { + store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW) + numUpdatedStateRows += 1 + numOutputRows += 1 + true + } else { + // Drop duplicated rows + false + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + watermarkPredicate.foreach(f => store.remove(f.eval _)) + store.commit() + numTotalStateRows += store.numKeys() + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} + +object StreamingDeduplicateExec { + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala new file mode 100644 index 0000000000000..7ea716231e5dc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.functions._ + +class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { + + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("deduplicate with all columns") { + val inputData = MemoryStream[String] + val result = inputData.toDS().dropDuplicates() + + testStream(result, Append)( + AddData(inputData, "a"), + CheckLastBatch("a"), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a"), + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + AddData(inputData, "b"), + CheckLastBatch("b"), + assertNumStateRows(total = 2, updated = 1) + ) + } + + test("deduplicate with some columns") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS().dropDuplicates("_1") + + testStream(result, Append)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a" -> 2), // Dropped + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1), + assertNumStateRows(total = 2, updated = 1) + ) + } + + test("multiple deduplicates") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS().dropDuplicates().dropDuplicates("_1") + + testStream(result, Append)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + + AddData(inputData, "a" -> 2), // Dropped from the second `dropDuplicates` + CheckLastBatch(), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(0L, 1L)), + + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with watermark") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), + CheckLastBatch(10 to 15: _*), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(25), + assertNumStateRows(total = 7, updated = 1), + + AddData(inputData, 25), // Drop states less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + AddData(inputData, 45), // Advance watermark to 35 seconds + CheckLastBatch(45), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, 45), // Drop states less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0) + ) + } + + test("deduplicate with aggregate - append mode") { + val inputData = MemoryStream[Int] + val windowedaggregate = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 
'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedaggregate)( + AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), + CheckLastBatch(), + // states in aggregate in [10, 14), [15, 20) (2 windows) + // states in deduplicate is 10 to 15 + assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(), + // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) + // states in deduplicate is 10 to 15 and 25 + assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), + + AddData(inputData, 25), // Emit items less than watermark and drop their state + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate + // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of + // window to evict items, so [15, 20) is still in the state store) + // states in deduplicate is 25 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + + AddData(inputData, 40), // Advance watermark to 30 seconds + CheckLastBatch(), + // states in aggregate in [15, 20), [25, 30) and [40, 45) + // states in deduplicate is 25 and 40, + assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), + + AddData(inputData, 40), // Emit items less than watermark and drop their state + CheckLastBatch((15 -> 1), (25 -> 1)), + // states in aggregate in [40, 45) + // states in deduplicate is 40, + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + ) + } + + test("deduplicate with aggregate - update mode") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS() + .select($"_1" as "str", $"_2" as "num") + .dropDuplicates() + .groupBy("str") + .agg(sum("num")) + .as[(String, Long)] + + testStream(result, Update)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + AddData(inputData, "a" -> 1), // Dropped + CheckLastBatch(), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)), + AddData(inputData, "a" -> 2), + CheckLastBatch("a" -> 3L), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)), + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1L), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with aggregate - complete mode") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS() + .select($"_1" as "str", $"_2" as "num") + .dropDuplicates() + .groupBy("str") + .agg(sum("num")) + .as[(String, Long)] + + testStream(result, Complete)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + AddData(inputData, "a" -> 1), // Dropped + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)), + AddData(inputData, "a" -> 2), + CheckLastBatch("a" -> 3L), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)), + AddData(inputData, "b" -> 1), + CheckLastBatch("a" -> 3L, "b" -> 1L), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with file sink") { + withTempDir { output => + withTempDir { checkpointDir => + val outputPath = output.getAbsolutePath + val inputData = 
MemoryStream[String] + val result = inputData.toDS().dropDuplicates() + val q = result.writeStream + .format("parquet") + .outputMode(Append) + .option("checkpointLocation", checkpointDir.getPath) + .start(outputPath) + try { + inputData.addData("a") + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a") + + inputData.addData("a") // Dropped + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a") + + inputData.addData("b") + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a", "b") + } finally { + q.stop() + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 0524898b15ead..6cf4d51f99333 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore /** Class to check custom state types */ case class RunningCount(count: Long) -class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { +class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -321,13 +321,6 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count ) } - - private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows") - assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows") - true - } } object MapGroupsWithStateSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala new file mode 100644 index 0000000000000..894786c50e238 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +trait StateStoreMetricsTest extends StreamTest { + + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = + AssertOnQuery { q => + val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + assert( + progressWithData.stateOperators.map(_.numRowsTotal) === total, + "incorrect total rows") + assert( + progressWithData.stateOperators.map(_.numRowsUpdated) === updated, + "incorrect updates rows") + true + } + + def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = + assertNumStateRows(Seq(total), Seq(updated)) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 0296a2ade3459..f44cfada29e2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -338,7 +338,7 @@ class StreamSuite extends StreamTest { .writeStream .format("memory") .queryName("testquery") - .outputMode("complete") + .outputMode("append") .start() try { query.processAllAvailable() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index eca2647dea52b..0c8015672bab4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -35,7 +35,7 @@ object FailureSinglton { var firstTime = true } -class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll { +class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { override def afterAll(): Unit = { super.afterAll() From f2317ce1ffcc987c04df5fa343fe946c25ca9d0d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 23 Feb 2017 12:09:36 -0800 Subject: [PATCH 33/61] [SPARK-18699][SQL] Put malformed tokens into a new field when parsing CSV data ## What changes were proposed in this pull request? This pr added a logic to put malformed tokens into a new field when parsing CSV data in case of permissive modes. In the current master, if the CSV parser hits these malformed ones, it throws an exception below (and then a job fails); ``` Caused by: java.lang.IllegalArgumentException at java.sql.Date.valueOf(Date.java:143) at org.apache.spark.sql.catalyst.util.DateTimeUtils$.stringToTime(DateTimeUtils.scala:137) at org.apache.spark.sql.execution.datasources.csv.CSVTypeCast$$anonfun$castTo$6.apply$mcJ$sp(CSVInferSchema.scala:272) at org.apache.spark.sql.execution.datasources.csv.CSVTypeCast$$anonfun$castTo$6.apply(CSVInferSchema.scala:272) at org.apache.spark.sql.execution.datasources.csv.CSVTypeCast$$anonfun$castTo$6.apply(CSVInferSchema.scala:272) at scala.util.Try.getOrElse(Try.scala:79) at org.apache.spark.sql.execution.datasources.csv.CSVTypeCast$.castTo(CSVInferSchema.scala:269) at ``` In case that users load large CSV-formatted data, the job failure makes users get some confused. So, this fix set NULL for original columns and put malformed tokens in a new field. ## How was this patch tested? Added tests in `CSVSuite`. Author: Takeshi Yamamuro Closes #16928 from maropu/SPARK-18699-2. 
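
As a rough sketch of how the new behavior is meant to be used from the read path; the schema, the `_malformed` column name, and the input path below are illustrative assumptions, not taken from this patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().appName("csv-permissive-sketch").getOrCreate()

// A nullable string field with the configured name is what keeps the raw malformed line.
val schema = new StructType()
  .add("a", IntegerType)
  .add("b", DateType)
  .add("_malformed", StringType)

val df = spark.read
  .format("csv")
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_malformed")
  .schema(schema)
  .load("/tmp/input.csv")  // hypothetical input, e.g. lines such as "0,2013-111-11 12:13:14"

// Rows whose tokens cannot be cast come back with `a` and `b` set to null and the
// original text preserved in `_malformed`, instead of failing the whole job.
df.show(false)
```

If the user-supplied schema omits the corrupt-record field, `PERMISSIVE` mode drops the malformed rows during parsing instead, per the documentation updated below.
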
--- python/pyspark/sql/readwriter.py | 32 +++++++--- python/pyspark/sql/streaming.py | 32 +++++++--- .../apache/spark/sql/DataFrameReader.scala | 18 ++++-- .../datasources/csv/CSVFileFormat.scala | 31 ++++++--- .../datasources/csv/CSVOptions.scala | 18 +++++- .../datasources/csv/UnivocityParser.scala | 62 ++++++++++++++---- .../sql/streaming/DataStreamReader.scala | 18 ++++-- .../resources/test-data/value-malformed.csv | 2 + .../execution/datasources/csv/CSVSuite.scala | 63 +++++++++++++++++-- 9 files changed, 223 insertions(+), 53 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/value-malformed.csv diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bed390e60c96..b5e5b18bcbefa 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -191,10 +191,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record and puts the malformed string into a new field configured by \ - ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ - ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ + schema. If a schema does not have the field, it drops corrupt records during \ + parsing. When inferring a schema, it implicitly adds a \ + ``columnNameOfCorruptRecord`` field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -304,7 +307,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None): + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + columnNameOfCorruptRecord=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -366,11 +370,22 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. - When a schema is set by user, it sets ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an \ + user-defined schema. If a schema does not have the field, it drops corrupt \ + records during parsing. 
When a length of parsed CSV tokens is shorter than \ + an expected length of a schema, it sets `null` for extra fields. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. + >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes [('_c0', 'string'), ('_c1', 'string')] @@ -382,7 +397,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone) + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + columnNameOfCorruptRecord=columnNameOfCorruptRecord) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 965c8f6b269e9..bd19fd4e385b4 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -463,10 +463,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record and puts the malformed string into a new field configured by \ - ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ - ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ + schema. If a schema does not have the field, it drops corrupt records during \ + parsing. When inferring a schema, it implicitly adds a \ + ``columnNameOfCorruptRecord`` field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -558,7 +561,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None): + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + columnNameOfCorruptRecord=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -618,11 +622,22 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. 
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. - When a schema is set by user, it sets ``null`` for extra fields. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record, and puts the malformed string into a field configured by \ + ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ + a string type field named ``columnNameOfCorruptRecord`` in an \ + user-defined schema. If a schema does not have the field, it drops corrupt \ + records during parsing. When a length of parsed CSV tokens is shorter than \ + an expected length of a schema, it sets `null` for extra fields. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. + >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming True @@ -636,7 +651,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone) + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + columnNameOfCorruptRecord=columnNameOfCorruptRecord) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2be22761e8dbc..59baf6e567721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -286,8 +286,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * during parsing. *
    * <ul>
    * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
-   * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
-   * a schema is set by user, it sets `null` for extra fields.</li>
+   * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+   * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
+   * in an user-defined schema. If a schema does not have the field, it drops corrupt records
+   * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord`
+   * field in an output schema.</li>
    * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
    * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
    * </ul>
@@ -447,12 +450,19 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
    * during parsing.
    * <ul>
-   * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
-   * a schema is set by user, it sets `null` for extra fields.</li>
+   * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+   * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+   * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
+   * in an user-defined schema. If a schema does not have the field, it drops corrupt records
+   * during parsing. When a length of parsed CSV tokens is shorter than an expected length
+   * of a schema, it sets `null` for extra fields.</li>
    * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
    * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
    * </ul>
    * </li>
+   * <li>`columnNameOfCorruptRecord` (default is the value specified in
+   * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
+   * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
  • * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 566f40f454393..59f2919edfe2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -27,9 +27,9 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} +import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.sources._ @@ -96,31 +96,44 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - + CSVUtils.verifySchema(dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val parsedOptions = new CSVOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + + // Check a field requirement for corrupt records here to throw an exception in a driver side + dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = dataSchema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + (file: PartitionedFile) => { val lines = { val conf = broadcastedHadoopConf.value.value val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) linesReader.map { line => - new String(line.getBytes, 0, line.getLength, csvOptions.charset) + new String(line.getBytes, 0, line.getLength, parsedOptions.charset) } } - val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) { + val linesWithoutHeader = if (parsedOptions.headerFlag && file.start == 0) { // Note that if there are only comments in the first block, the header would probably // be not dropped. 
- CSVUtils.dropHeaderLine(lines, csvOptions) + CSVUtils.dropHeaderLine(lines, parsedOptions) } else { lines } - val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions) - val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions) + val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, parsedOptions) + val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) filteredLines.flatMap(parser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index b7fbaa4f44a62..1caeec7c63945 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -27,11 +27,20 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} private[csv] class CSVOptions( - @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String) + @transient private val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { - def this(parameters: Map[String, String], defaultTimeZoneId: String) = - this(CaseInsensitiveMap(parameters), defaultTimeZoneId) + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } private def getChar(paramName: String, default: Char): Char = { val paramValue = parameters.get(paramName) @@ -95,6 +104,9 @@ private[csv] class CSVOptions( val dropMalformed = ParseModes.isDropMalformedMode(parseMode) val permissive = ParseModes.isPermissiveMode(parseMode) + val columnNameOfCorruptRecord = + parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + val nullValue = parameters.getOrElse("nullValue", "") val nanValue = parameters.getOrElse("nanValue", "NaN") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 2e409b3f5fbfc..eb471651db2e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -45,8 +45,16 @@ private[csv] class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. 
private type ValueConverter = String => Any + private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) + corruptFieldIndex.foreach { corrFieldIndex => + require(schema(corrFieldIndex).dataType == StringType) + require(schema(corrFieldIndex).nullable) + } + + private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) + private val valueConverters = - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray private val parser = new CsvParser(options.asParserSettings) @@ -54,7 +62,9 @@ private[csv] class UnivocityParser( private val row = new GenericInternalRow(requiredSchema.length) - private val indexArr: Array[Int] = { + // This parser loads an `indexArr._1`-th position value in input tokens, + // then put the value in `row(indexArr._2)`. + private val indexArr: Array[(Int, Int)] = { val fields = if (options.dropMalformed) { // If `dropMalformed` is enabled, then it needs to parse all the values // so that we can decide which row is malformed. @@ -62,7 +72,17 @@ private[csv] class UnivocityParser( } else { requiredSchema } - fields.map(schema.indexOf(_: StructField)).toArray + // TODO: Revisit this; we need to clean up code here for readability. + // See an URL below for related discussions: + // https://github.com/apache/spark/pull/16928#discussion_r102636720 + val fieldsWithIndexes = fields.zipWithIndex + corruptFieldIndex.map { case corrFieldIndex => + fieldsWithIndexes.filter { case (_, i) => i != corrFieldIndex } + }.getOrElse { + fieldsWithIndexes + }.map { case (f, i) => + (dataSchema.indexOf(f), i) + }.toArray } /** @@ -148,6 +168,7 @@ private[csv] class UnivocityParser( case udt: UserDefinedType[_] => (datum: String) => makeConverter(name, udt.sqlType, nullable, options) + // We don't actually hit this exception though, we keep it for understandability case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") } @@ -172,16 +193,16 @@ private[csv] class UnivocityParser( * the record is malformed). */ def parse(input: String): Option[InternalRow] = { - convertWithParseMode(parser.parseLine(input)) { tokens => + convertWithParseMode(input) { tokens => var i: Int = 0 while (i < indexArr.length) { - val pos = indexArr(i) + val (pos, rowIdx) = indexArr(i) // It anyway needs to try to parse since it decides if this row is malformed // or not after trying to cast in `DROPMALFORMED` mode even if the casted // value is not stored in the row. 
val value = valueConverters(pos).apply(tokens(pos)) if (i < requiredSchema.length) { - row(i) = value + row(rowIdx) = value } i += 1 } @@ -190,8 +211,9 @@ private[csv] class UnivocityParser( } private def convertWithParseMode( - tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { - if (options.dropMalformed && schema.length != tokens.length) { + input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = { + val tokens = parser.parseLine(input) + if (options.dropMalformed && dataSchema.length != tokens.length) { if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") } @@ -202,14 +224,24 @@ private[csv] class UnivocityParser( } numMalformedRecords += 1 None - } else if (options.failFast && schema.length != tokens.length) { + } else if (options.failFast && dataSchema.length != tokens.length) { throw new RuntimeException(s"Malformed line in FAILFAST mode: " + s"${tokens.mkString(options.delimiter.toString)}") } else { - val checkedTokens = if (options.permissive && schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) - } else if (options.permissive && schema.length < tokens.length) { - tokens.take(schema.length) + // If a length of parsed tokens is not equal to expected one, it makes the length the same + // with the expected. If the length is shorter, it adds extra tokens in the tail. + // If longer, it drops extra tokens. + // + // TODO: Revisit this; if a length of tokens does not match an expected length in the schema, + // we probably need to treat it as a malformed record. + // See an URL below for related discussions: + // https://github.com/apache/spark/pull/16928#discussion_r102657214 + val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) { + if (dataSchema.length > tokens.length) { + tokens ++ new Array[String](dataSchema.length - tokens.length) + } else { + tokens.take(dataSchema.length) + } } else { tokens } @@ -217,6 +249,10 @@ private[csv] class UnivocityParser( try { Some(convert(checkedTokens)) } catch { + case NonFatal(e) if options.permissive => + val row = new GenericInternalRow(requiredSchema.length) + corruptFieldIndex.foreach(row(_) = UTF8String.fromString(input)) + Some(row) case NonFatal(e) if options.dropMalformed => if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning("Parse exception. " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 99943944f3c6d..f78e73f319de7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -168,8 +168,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * during parsing. *
   *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
-  *     the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
-  *     a schema is set by user, it sets `null` for extra fields.</li>
+  *     the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+  *     corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
+  *     in an user-defined schema. If a schema does not have the field, it drops corrupt records
+  *     during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord`
+  *     field in an output schema.</li>
   *     <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
   *     <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
   *   </ul>
@@ -245,12 +248,19 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
   * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
   * during parsing.
   *   <ul>
-  *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
-  *     a schema is set by user, it sets `null` for extra fields.</li>
+  *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+  *     the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+  *     corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
+  *     in an user-defined schema. If a schema does not have the field, it drops corrupt records
+  *     during parsing. When a length of parsed CSV tokens is shorter than an expected length
+  *     of a schema, it sets `null` for extra fields.</li>
   *     <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
   *     <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
   *   </ul>
   * </li>
+  * <li>`columnNameOfCorruptRecord` (default is the value specified in
+  * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
+  * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/resources/test-data/value-malformed.csv b/sql/core/src/test/resources/test-data/value-malformed.csv new file mode 100644 index 0000000000000..8945ed73d2e83 --- /dev/null +++ b/sql/core/src/test/resources/test-data/value-malformed.csv @@ -0,0 +1,2 @@ +0,2013-111-11 12:13:14 +1,1983-08-04 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 0c9a7298c3fa0..371d4311baa3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -53,6 +53,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val numbersFile = "test-data/numbers.csv" private val datesFile = "test-data/dates.csv" private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" + private val valueMalformedFile = "test-data/value-malformed.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -700,12 +701,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[SparkException] { + msg = intercept[UnsupportedOperationException] { val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() - }.getCause.getMessage - assert(msg.contains("Unsupported type: array")) + }.getMessage + assert(msg.contains("CSV data source does not support array data type.")) } } @@ -958,4 +959,58 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(df, Row(1, null)) } } + + test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { + val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val df1 = spark + .read + .option("mode", "PERMISSIVE") + .schema(schema) + .csv(testFile(valueMalformedFile)) + checkAnswer(df1, + Row(null, null) :: + Row(1, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records + val columnNameOfCorruptRecord = "_unparsed" + val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) + val df2 = spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schemaWithCorrField1) + .csv(testFile(valueMalformedFile)) + checkAnswer(df2, + Row(null, null, "0,2013-111-11 12:13:14") :: + Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: + Nil) + + // We put a `columnNameOfCorruptRecord` field in the middle of a schema + val schemaWithCorrField2 = new StructType() + .add("a", IntegerType) + .add(columnNameOfCorruptRecord, StringType) + .add("b", TimestampType) + val df3 = spark + .read + 
.option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schemaWithCorrField2) + .csv(testFile(valueMalformedFile)) + checkAnswer(df3, + Row(null, "0,2013-111-11 12:13:14", null) :: + Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + val errMsg = intercept[AnalysisException] { + spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .csv(testFile(valueMalformedFile)) + .collect + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + } } From abe0198b7e93f41caa6c71a29c918e9f160f6d05 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Feb 2017 13:22:39 -0800 Subject: [PATCH 34/61] [SPARK-19706][PYSPARK] add Column.contains in pyspark ## What changes were proposed in this pull request? to be consistent with the scala API, we should also add `contains` to `Column` in pyspark. ## How was this patch tested? updated unit test Author: Wenchen Fan Closes #17036 from cloud-fan/pyspark. --- python/pyspark/sql/column.py | 1 + python/pyspark/sql/tests.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 0df187a9d3c3d..c10ab9638a21f 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -248,6 +248,7 @@ def __iter__(self): raise TypeError("Column is not iterable") # string methods + contains = _bin_op("contains") rlike = _bin_op("rlike") like = _bin_op("like") startswith = _bin_op("startsWith") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9058443285aca..abd68bfd391a0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -944,7 +944,8 @@ def test_column_operators(self): self.assertTrue(all(isinstance(c, Column) for c in cb)) cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in cbool)) - css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ + cs.startswith('a'), cs.endswith('a') self.assertTrue(all(isinstance(c, Column) for c in css)) self.assertTrue(isinstance(ci.cast(LongType()), Column)) From 589734afc96f4134724da8d809d4786161ba6e2e Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 23 Feb 2017 13:27:47 -0800 Subject: [PATCH 35/61] [SPARK-19684][DOCS] Remove developer info from docs. This commit moves developer-specific information from the release- specific documentation in this repo to the developer tools page on the main Spark website. This commit relies on this PR on the Spark website: https://github.com/apache/spark-website/pull/33. srowen Author: Kay Ousterhout Closes #17018 from kayousterhout/SPARK-19684. --- docs/building-spark.md | 43 +++++++++++------------------------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 56b892696ee2c..8353b7a520b8e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -132,20 +132,6 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m $ cd core $ ../build/mvn scala:cc -## Speeding up Compilation with Zinc - -[Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental -compiler. 
When run locally as a background process, it speeds up builds of Scala-based projects -like Spark. Developers who regularly recompile Spark with Maven will be the most interested in -Zinc. The project site gives instructions for building and running `zinc`; OS X users can -install it using `brew install zinc`. - -If using the `build/mvn` package `zinc` will automatically be downloaded and leveraged for all -builds. This process will auto-start after the first time `build/mvn` is called and bind to port -3030 unless the `ZINC_PORT` environment variable is set. The `zinc` process can subsequently be -shut down at any time by running `build/zinc-/bin/zinc -shutdown` and will automatically -restart whenever `build/mvn` is called. - ## Building with SBT Maven is the official build tool recommended for packaging Spark, and is the *build of reference*. @@ -159,8 +145,14 @@ can be set to control the SBT build. For example: To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt in interactive mode by running `build/sbt`, and then run all build commands at the command -prompt. For more recommendations on reducing build time, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). +prompt. + +## Speeding up Compilation + +Developers who compile Spark frequently may want to speed up compilation; e.g., by using Zinc +(for developers who build with Maven) or by avoiding re-compilation of the assembly JAR (for +developers who build with SBT). For more information about how to do this, refer to the +[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times). ## Encrypted Filesystems @@ -190,29 +182,16 @@ The following is an example of a command to run the tests: ./build/mvn test -The ScalaTest plugin also supports running only a specific Scala test suite as follows: - - ./build/mvn -P... -Dtest=none -DwildcardSuites=org.apache.spark.repl.ReplSuite test - ./build/mvn -P... -Dtest=none -DwildcardSuites=org.apache.spark.repl.* test - -or a Java test: - - ./build/mvn test -P... -DwildcardSuites=none -Dtest=org.apache.spark.streaming.JavaAPISuite - ## Testing with SBT The following is an example of a command to run the tests: ./build/sbt test -To run only a specific test suite as follows: - - ./build/sbt "test-only org.apache.spark.repl.ReplSuite" - ./build/sbt "test-only org.apache.spark.repl.*" - -To run test suites of a specific sub project as follows: +## Running Individual Tests - ./build/sbt core/test +For information about how to run individual tests, refer to the +[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#running-individual-tests). ## PySpark pip installable From 0ff97a5faa93b2dab2df17f1c42df0144832098f Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Thu, 23 Feb 2017 14:31:16 -0800 Subject: [PATCH 36/61] [SPARK-19674][SQL] Ignore driver accumulator updates don't belong to the execution when merging all accumulator updates ## What changes were proposed in this pull request? In SQLListener.getExecutionMetrics, driver accumulator updates don't belong to the execution should be ignored when merging all accumulator updates to prevent NoSuchElementException. ## How was this patch tested? Updated unit test. Author: Carson Wang Closes #17009 from carsonwang/FixSQLMetrics. 
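The gist of the fix, shown as a standalone sketch rather than the actual SQLListener code: concatenate the task and driver accumulator updates first, then keep only ids that the execution registered, so a stray driver update such as `(999L, 2L)` from another execution is dropped before merging instead of failing the later metric-type lookup. `knownMetrics` and the sample ids/values below are made-up stand-ins.

    object AccumFilterSketch {
      def main(args: Array[String]): Unit = {
        // Accumulator ids registered for this execution (id -> metric type).
        val knownMetrics = Map(1L -> "sum", 2L -> "size")

        // Updates reported by tasks, plus a driver update for a foreign execution (id 999).
        val taskUpdates = Seq((1L, 10L), (2L, 20L), (1L, 5L))
        val driverUpdates = Seq((999L, 2L))

        // Filter after concatenating both sources, mirroring the patch, so the
        // unknown driver id is dropped before the merge step.
        val merged = (taskUpdates ++ driverUpdates)
          .filter { case (id, _) => knownMetrics.contains(id) }
          .groupBy(_._1)
          .map { case (id, updates) => id -> updates.map(_._2).sum }

        println(merged)  // e.g. Map(1 -> 15, 2 -> 20)
      }
    }
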
--- .../org/apache/spark/sql/execution/ui/SQLListener.scala | 7 +++++-- .../apache/spark/sql/execution/ui/SQLListenerSuite.scala | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5daf21595d8a2..12d3bc9281f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -343,10 +343,13 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging { accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield { (accumulatorUpdate._1, accumulatorUpdate._2) } - }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } + } val driverUpdates = executionUIData.driverAccumUpdates.toSeq - mergeAccumulatorUpdates(accumulatorUpdates ++ driverUpdates, accumulatorId => + val totalUpdates = (accumulatorUpdates ++ driverUpdates).filter { + case (id, _) => executionUIData.accumulatorMetrics.contains(id) + } + mergeAccumulatorUpdates(totalUpdates, accumulatorId => executionUIData.accumulatorMetrics(accumulatorId).metricType) case None => // This execution has been dropped diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 8aea112897fb3..e41c00ecec271 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -147,6 +147,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + // Driver accumulator updates don't belong to this execution should be filtered and no + // exception will be thrown. + listener.onOtherEvent(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), From ef6e5d5f32e15867496d5cc58c9bf3bdbd4a4be3 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Thu, 23 Feb 2017 17:06:14 -0800 Subject: [PATCH 37/61] [SPARK-16122][DOCS] application environment rest api ## What changes were proposed in this pull request? follow up pr of #16949. ## How was this patch tested? jenkins Author: uncleGen Closes #17033 from uncleGen/doc-restapi-environment. --- docs/monitoring.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/monitoring.md b/docs/monitoring.md index 7ba4824d463fc..80519525af0c3 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -381,6 +381,10 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/streaming/batches/[batch-id]/operations/[outputOp-id] Details of the given operation and given batch. + + + /applications/[app-id]/environment + Environment details of the given application. 
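To see what the new `/applications/[app-id]/environment` endpoint returns, the REST API can be queried like any other; below is a minimal sketch assuming a local application UI on the default port 4040 and a placeholder application id.

    import scala.io.Source

    object EnvironmentEndpointExample {
      def main(args: Array[String]): Unit = {
        // Substitute the real application id reported by the UI or history server.
        val appId = "app-20170223123456-0000"
        val url = s"http://localhost:4040/api/v1/applications/$appId/environment"
        // Prints the environment details of the given application as JSON.
        println(Source.fromURL(url).mkString)
      }
    }
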
From 22771f24be10765ba66ec94b494d44d2939d039e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 23 Feb 2017 18:05:58 -0800 Subject: [PATCH 38/61] [SPARK-14772][PYTHON][ML] Fixed Params.copy method to match Scala implementation ## What changes were proposed in this pull request? Fixed the PySpark Params.copy method to behave like the Scala implementation. The main issue was that it did not account for the _defaultParamMap and merged it into the explicitly created param map. ## How was this patch tested? Added new unit test to verify the copy method behaves correctly for copying uid, explicitly created params, and default params. Author: Bryan Cutler Closes #16772 from BryanCutler/pyspark-ml-param_copy-Scala_sync-SPARK-14772. --- python/pyspark/ml/param/__init__.py | 17 +++++++++++------ python/pyspark/ml/tests.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index dc3d23ff1661d..99d8fa3a5b73e 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -372,6 +372,7 @@ def copy(self, extra=None): extra = dict() that = copy.copy(self) that._paramMap = {} + that._defaultParamMap = {} return self._copyValues(that, extra) def _shouldOwn(self, param): @@ -452,12 +453,16 @@ def _copyValues(self, to, extra=None): :param extra: extra params to be copied :return: the target instance with param values copied """ - if extra is None: - extra = dict() - paramMap = self.extractParamMap(extra) - for p in self.params: - if p in paramMap and to.hasParam(p.name): - to._set(**{p.name: paramMap[p]}) + paramMap = self._paramMap.copy() + if extra is not None: + paramMap.update(extra) + for param in self.params: + # copy default params + if param in self._defaultParamMap and to.hasParam(param.name): + to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param] + # copy explicitly set params + if param in paramMap and to.hasParam(param.name): + to._set(**{param.name: paramMap[param]}) return to def _resetUid(self, newUid): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 53204cde29b74..293c6c0b0f36a 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -389,6 +389,22 @@ def test_word2vec_param(self): # Check windowSize is set properly self.assertEqual(model.getWindowSize(), 6) + def test_copy_param_extras(self): + tp = TestParams(seed=42) + extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} + tp_copy = tp.copy(extra=extra) + self.assertEqual(tp.uid, tp_copy.uid) + self.assertEqual(tp.params, tp_copy.params) + for k, v in extra.items(): + self.assertTrue(tp_copy.isDefined(k)) + self.assertEqual(tp_copy.getOrDefault(k), v) + copied_no_extra = {} + for k, v in tp_copy._paramMap.items(): + if k not in extra: + copied_no_extra[k] = v + self.assertEqual(tp._paramMap, copied_no_extra) + self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + class EvaluatorTests(SparkSessionTestCase): From 1aad87afe10a7d868049839afa33b473ea696991 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Thu, 23 Feb 2017 20:18:21 -0800 Subject: [PATCH 39/61] [SPARK-17075][SQL] implemented filter estimation ## What changes were proposed in this pull request? We traverse predicate and evaluate the logical expressions to compute the selectivity of a FILTER operator. ## How was this patch tested? We add a new test suite to test various logical operators. Author: Ron Hu Closes #16395 from ron8hu/filterSelectivity. 
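Most of the machinery added below estimates a single predicate; how those per-predicate results are combined is small enough to show in isolation. The sketch mirrors the combination rules in `calculateFilterSelectivity` (AND multiplies, OR uses inclusion-exclusion and falls back to 1.0 when one side is unknown, NOT takes the complement, unsupported conditions yield None). The `Pred`/`Leaf` types are simplified stand-ins for Catalyst expressions, not part of the patch.

    object SelectivitySketch {
      // Toy predicate tree standing in for Catalyst's Expression hierarchy.
      sealed trait Pred
      case class Leaf(sel: Option[Double]) extends Pred  // a single condition, already estimated
      case class And(left: Pred, right: Pred) extends Pred
      case class Or(left: Pred, right: Pred) extends Pred
      case class Not(child: Pred) extends Pred

      def selectivity(p: Pred): Option[Double] = p match {
        case Leaf(s) => s
        case And(l, r) => (selectivity(l), selectivity(r)) match {
          case (Some(a), Some(b)) => Some(a * b)
          case (Some(a), None) => Some(a)
          case (None, Some(b)) => Some(b)
          case (None, None) => None
        }
        case Or(l, r) => (selectivity(l), selectivity(r)) match {
          case (Some(a), Some(b)) => Some(math.min(1.0, a + b - a * b))
          case (None, None) => None
          case _ => Some(1.0)  // one side not estimable: be conservative
        }
        case Not(c) => selectivity(c).map(1.0 - _)
      }

      def main(args: Array[String]): Unit = {
        // e.g. (c > x AND c <= y) with per-condition selectivities 0.5 and 0.25
        println(selectivity(And(Leaf(Some(0.5)), Leaf(Some(0.25)))))  // Some(0.125)
        println(selectivity(Not(Leaf(Some(0.25)))))                   // Some(0.75)
      }
    }
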
--- .../plans/logical/basicLogicalOperators.scala | 10 +- .../statsEstimation/FilterEstimation.scala | 511 ++++++++++++++++++ .../plans/logical/statsEstimation/Range.scala | 16 + .../FilterEstimationSuite.scala | 403 ++++++++++++++ 4 files changed, 939 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ce1c55dc089e6..ccebae3cc2701 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, JoinEstimation, ProjectEstimation} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -129,6 +129,14 @@ case class Filter(condition: Expression, child: LogicalPlan) .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(predicates.toSet) } + + override def computeStats(conf: CatalystConf): Statistics = { + if (conf.cboEnabled) { + FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) + } else { + super.computeStats(conf) + } + } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala new file mode 100644 index 0000000000000..fcc607a610fcc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -0,0 +1,511 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import java.sql.{Date, Timestamp} + +import scala.collection.immutable.{HashSet, Map} +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { + + /** + * We use a mutable colStats because we need to update the corresponding ColumnStat + * for a column after we apply a predicate condition. For example, column c has + * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50), + * we need to set the column's [min, max] value to [40, 100] after we evaluate the + * first condition c > 40. We need to set the column's [min, max] value to [40, 50] + * after we evaluate the second condition c <= 50. + */ + private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty + + /** + * Returns an option of Statistics for a Filter logical plan node. + * For a given compound expression condition, this method computes filter selectivity + * (or the percentage of rows meeting the filter condition), which + * is used to compute row count, size in bytes, and the updated statistics after a given + * predicated is applied. + * + * @return Option[Statistics] When there is no statistics collected, it returns None. + */ + def estimate: Option[Statistics] = { + // We first copy child node's statistics and then modify it based on filter selectivity. + val stats: Statistics = plan.child.stats(catalystConf) + if (stats.rowCount.isEmpty) return None + + // save a mutable copy of colStats so that we can later change it recursively + mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*) + + // estimate selectivity of this filter predicate + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { + case Some(percent) => percent + // for not-supported condition, set filter selectivity to a conservative estimate 100% + case None => 1.0 + } + + // attributeStats has mapping Attribute-to-ColumnStat. + // mutableColStats has mapping ExprId-to-ColumnStat. + // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat + val expridToAttrMap: Map[ExprId, Attribute] = + stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) + // copy mutableColStats contents to an immutable AttributeMap. + val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = + mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2) + val newColStats = AttributeMap(mutableAttributeStats.toSeq) + + val filteredRowCount: BigInt = + EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) + val filteredSizeInBytes = + EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) + + Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), + attributeStats = newColStats)) + } + + /** + * Returns a percentage of rows meeting a compound condition in Filter node. + * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. + * For logical AND conditions, we need to update stats after a condition estimation + * so that the stats will be more accurate for subsequent estimation. 
This is needed for + * range condition such as (c > 40 AND c <= 50) + * For logical OR conditions, we do not update stats after a condition estimation. + * + * @param condition the compound logical expression + * @param update a boolean flag to specify if we need to update ColumnStat of a column + * for subsequent conditions + * @return a double value to show the percentage of rows meeting a given condition. + * It returns None if the condition is not supported. + */ + def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { + + condition match { + case And(cond1, cond2) => + (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update)) + match { + case (Some(p1), Some(p2)) => Some(p1 * p2) + case (Some(p1), None) => Some(p1) + case (None, Some(p2)) => Some(p2) + case (None, None) => None + } + + case Or(cond1, cond2) => + // For ease of debugging, we compute percent1 and percent2 in 2 statements. + val percent1 = calculateFilterSelectivity(cond1, update = false) + val percent2 = calculateFilterSelectivity(cond2, update = false) + (percent1, percent2) match { + case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) + case (Some(p1), None) => Some(1.0) + case (None, Some(p2)) => Some(1.0) + case (None, None) => None + } + + case Not(cond) => calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + // for not-supported condition, set filter selectivity to a conservative estimate 100% + case None => None + } + + case _ => + calculateSingleCondition(condition, update) + } + } + + /** + * Returns a percentage of rows meeting a single condition in Filter node. + * Currently we only support binary predicates where one side is a column, + * and the other is a literal. + * + * @param condition a single logical expression + * @param update a boolean flag to specify if we need to update ColumnStat of a column + * for subsequent conditions + * @return Option[Double] value to show the percentage of rows meeting a given condition. + * It returns None if the condition is not supported. + */ + def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { + condition match { + // For evaluateBinary method, we assume the literal on the right side of an operator. + // So we will change the order if not. 
+ + // EqualTo does not care about the order + case op @ EqualTo(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ EqualTo(l: Literal, ar: AttributeReference) => + evaluateBinary(op, ar, l, update) + + case op @ LessThan(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ LessThan(l: Literal, ar: AttributeReference) => + evaluateBinary(GreaterThan(ar, l), ar, l, update) + + case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) => + evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) + + case op @ GreaterThan(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ GreaterThan(l: Literal, ar: AttributeReference) => + evaluateBinary(LessThan(ar, l), ar, l, update) + + case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) => + evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) + + case In(ar: AttributeReference, expList) + if expList.forall(e => e.isInstanceOf[Literal]) => + // Expression [In (value, seq[Literal])] will be replaced with optimized version + // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. + // Here we convert In into InSet anyway, because they share the same processing logic. + val hSet = expList.map(e => e.eval()) + evaluateInSet(ar, HashSet() ++ hSet, update) + + case InSet(ar: AttributeReference, set) => + evaluateInSet(ar, set, update) + + case IsNull(ar: AttributeReference) => + evaluateIsNull(ar, isNull = true, update) + + case IsNotNull(ar: AttributeReference) => + evaluateIsNull(ar, isNull = false, update) + + case _ => + // TODO: it's difficult to support string operators without advanced statistics. + // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) + // | EndsWith(_, _) are not supported yet + logDebug("[CBO] Unsupported filter condition: " + condition) + None + } + } + + /** + * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition. + * + * @param attrRef an AttributeReference (or a column) + * @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + * It returns None if no statistics collected for a given column. 
+ */ + def evaluateIsNull( + attrRef: AttributeReference, + isNull: Boolean, + update: Boolean) + : Option[Double] = { + if (!mutableColStats.contains(attrRef.exprId)) { + logDebug("[CBO] No statistics for " + attrRef) + return None + } + val aColStat = mutableColStats(attrRef.exprId) + val rowCountValue = plan.child.stats(catalystConf).rowCount.get + val nullPercent: BigDecimal = + if (rowCountValue == 0) 0.0 + else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue) + + if (update) { + val newStats = + if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None) + else aColStat.copy(nullCount = 0) + + mutableColStats += (attrRef.exprId -> newStats) + } + + val percent = + if (isNull) { + nullPercent.toDouble + } + else { + /** ISNOTNULL(column) */ + 1.0 - nullPercent.toDouble + } + + Some(percent) + } + + /** + * Returns a percentage of rows meeting a binary comparison expression. + * + * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param attrRef an AttributeReference (or a column) + * @param literal a literal value (or constant) + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + * It returns None if no statistics exists for a given column or wrong value. + */ + def evaluateBinary( + op: BinaryComparison, + attrRef: AttributeReference, + literal: Literal, + update: Boolean) + : Option[Double] = { + if (!mutableColStats.contains(attrRef.exprId)) { + logDebug("[CBO] No statistics for " + attrRef) + return None + } + + op match { + case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update) + case _ => + attrRef.dataType match { + case _: NumericType | DateType | TimestampType => + evaluateBinaryForNumeric(op, attrRef, literal, update) + case StringType | BinaryType => + // TODO: It is difficult to support other binary comparisons for String/Binary + // type without min/max and advanced statistics like histogram. + logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef) + None + } + } + } + + /** + * For a SQL data type, its internal data type may be different from its external type. + * For DateType, its internal type is Int, and its external data type is Java Date type. + * The min/max values in ColumnStat are saved in their corresponding external type. + * + * @param attrDataType the column data type + * @param litValue the literal value + * @return a BigDecimal value + */ + def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { + attrDataType match { + case DateType => + Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) + case TimestampType => + Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) + case StringType | BinaryType => + None + case _ => + Some(litValue) + } + } + + /** + * Returns a percentage of rows meeting an equality (=) expression. + * This method evaluates the equality predicate for all data types. 
+ * + * @param attrRef an AttributeReference (or a column) + * @param literal a literal value (or constant) + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateEqualTo( + attrRef: AttributeReference, + literal: Literal, + update: Boolean) + : Option[Double] = { + + val aColStat = mutableColStats(attrRef.exprId) + val ndv = aColStat.distinctCount + + // decide if the value is in [min, max] of the column. + // We currently don't store min/max for binary/string type. + // Hence, we assume it is in boundary for binary/string type. + val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType) + val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal) + + if (inBoundary) { + + if (update) { + // We update ColumnStat structure after apply this equality predicate. + // Set distinctCount to 1. Set nullCount to 0. + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attrRef.dataType, literal.value) + val newStats = aColStat.copy(distinctCount = 1, min = newValue, + max = newValue, nullCount = 0) + mutableColStats += (attrRef.exprId -> newStats) + } + + Some(1.0 / ndv.toDouble) + } else { + Some(0.0) + } + + } + + /** + * Returns a percentage of rows meeting "IN" operator expression. + * This method evaluates the equality predicate for all data types. + * + * @param attrRef an AttributeReference (or a column) + * @param hSet a set of literal values + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + * It returns None if no statistics exists for a given column. + */ + + def evaluateInSet( + attrRef: AttributeReference, + hSet: Set[Any], + update: Boolean) + : Option[Double] = { + if (!mutableColStats.contains(attrRef.exprId)) { + logDebug("[CBO] No statistics for " + attrRef) + return None + } + + val aColStat = mutableColStats(attrRef.exprId) + val ndv = aColStat.distinctCount + val aType = attrRef.dataType + var newNdv: Long = 0 + + // use [min, max] to filter the original hSet + aType match { + case _: NumericType | DateType | TimestampType => + val statsRange = + Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] + + // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal. + // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec. + val hSetBigdec = hSet.map(e => BigDecimal(e.toString)) + val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max) + // We use hSetBigdecToAnyMap to help us find the original hSet value. + val hSetBigdecToAnyMap: Map[BigDecimal, Any] = + hSet.map(e => BigDecimal(e.toString) -> e).toMap + + if (validQuerySet.isEmpty) { + return Some(0.0) + } + + // Need to save new min/max using the external type value of the literal + val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max)) + val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min)) + + // newNdv should not be greater than the old ndv. For example, column has only 2 values + // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. 
+ newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) + if (update) { + val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + mutableColStats += (attrRef.exprId -> newStats) + } + + // We assume the whole set since there is no min/max information for String/Binary type + case StringType | BinaryType => + newNdv = math.min(hSet.size.toLong, ndv.longValue()) + if (update) { + val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0) + mutableColStats += (attrRef.exprId -> newStats) + } + } + + // return the filter selectivity. Without advanced statistics such as histograms, + // we have to assume uniform distribution. + Some(math.min(1.0, newNdv.toDouble / ndv.toDouble)) + } + + /** + * Returns a percentage of rows meeting a binary comparison expression. + * This method evaluate expression for Numeric columns only. + * + * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param attrRef an AttributeReference (or a column) + * @param literal a literal value (or constant) + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateBinaryForNumeric( + op: BinaryComparison, + attrRef: AttributeReference, + literal: Literal, + update: Boolean) + : Option[Double] = { + + var percent = 1.0 + val aColStat = mutableColStats(attrRef.exprId) + val ndv = aColStat.distinctCount + val statsRange = + Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] + + // determine the overlapping degree between predicate range and column's range + val literalValueBD = BigDecimal(literal.value.toString) + val (noOverlap: Boolean, completeOverlap: Boolean) = op match { + case _: LessThan => + (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) + case _: LessThanOrEqual => + (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) + case _: GreaterThan => + (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) + case _: GreaterThanOrEqual => + (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) + } + + if (noOverlap) { + percent = 0.0 + } else if (completeOverlap) { + percent = 1.0 + } else { + // this is partial overlap case + var newMax = aColStat.max + var newMin = aColStat.min + var newNdv = ndv + val literalToDouble = literalValueBD.toDouble + val maxToDouble = BigDecimal(statsRange.max).toDouble + val minToDouble = BigDecimal(statsRange.min).toDouble + + // Without advanced statistics like histogram, we assume uniform data distribution. + // We just prorate the adjusted range over the initial range to compute filter selectivity. + // For ease of computation, we convert all relevant numeric values to Double. 
+ percent = op match { + case _: LessThan => + (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case _: LessThanOrEqual => + if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble + else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case _: GreaterThan => + (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + case _: GreaterThanOrEqual => + if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble + else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + } + + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attrRef.dataType, literal.value) + + if (update) { + op match { + case _: GreaterThan => newMin = newValue + case _: GreaterThanOrEqual => newMin = newValue + case _: LessThan => newMax = newValue + case _: LessThanOrEqual => newMax = newValue + } + + newNdv = math.max(math.round(ndv.toDouble * percent), 1) + val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + + mutableColStats += (attrRef.exprId -> newStats) + } + } + + Some(percent) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 5aa6b9353bc4c..455711453272d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import java.math.{BigDecimal => JDecimal} import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} @@ -57,6 +58,20 @@ object Range { n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 } + def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match { + case _: DefaultRange => true + case _: NullRange => false + case n: NumericRange => + val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) { + if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) + } else { + assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] || + lit.dataType.isInstanceOf[TimestampType]) + new JDecimal(lit.value.toString) + } + n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0 + } + /** * Intersected results of two ranges. This is only for two overlapped ranges. * The outputs are the intersected min/max values. @@ -113,4 +128,5 @@ object Range { DateTimeUtils.toJavaTimestamp(n.max.longValue())) } } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala new file mode 100644 index 0000000000000..f139c9e28c6c5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +/** + * In this test suite, we test predicates containing the following operators: + * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN + */ +class FilterEstimationSuite extends StatsEstimationTestBase { + + // Suppose our test table has 10 rows and 6 columns. + // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + val arInt = AttributeReference("cint", IntegerType)() + val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. + val dMin = Date.valueOf("2017-01-01") + val dMax = Date.valueOf("2017-01-10") + val arDate = AttributeReference("cdate", DateType)() + val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), + nullCount = 0, avgLen = 4, maxLen = 4) + + // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through + // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours). + val tsMin = Timestamp.valueOf("2017-01-01 01:00:00") + val tsMax = Timestamp.valueOf("2017-01-01 10:00:00") + val arTimestamp = AttributeReference("ctimestamp", TimestampType)() + val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax), + nullCount = 0, avgLen = 8, maxLen = 8) + + // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
+ val decMin = new java.math.BigDecimal("0.200000000000000000") + val decMax = new java.math.BigDecimal("0.800000000000000000") + val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), + nullCount = 0, avgLen = 8, maxLen = 8) + + // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + val arDouble = AttributeReference("cdouble", DoubleType)() + val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + nullCount = 0, avgLen = 8, maxLen = 8) + + // Sixth column cstring has 10 String values: + // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" + val arString = AttributeReference("cstring", StringType)() + val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2) + + test("cint = 2") { + validateEstimatedStats( + arInt, + Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(1L) + ) + } + + test("cint = 0") { + // This is an out-of-range case since 0 is outside the range [min, max] + validateEstimatedStats( + arInt, + Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } + + test("cint < 3") { + validateEstimatedStats( + arInt, + Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("cint < 0") { + // This is a corner case since literal 0 is smaller than min. + validateEstimatedStats( + arInt, + Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } + + test("cint <= 3") { + validateEstimatedStats( + arInt, + Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("cint > 6") { + validateEstimatedStats( + arInt, + Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(5L) + ) + } + + test("cint > 10") { + // This is a corner case since max value is 10. 
+ validateEstimatedStats( + arInt, + Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } + + test("cint >= 6") { + validateEstimatedStats( + arInt, + Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(5L) + ) + } + + test("cint IS NULL") { + validateEstimatedStats( + arInt, + Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 0, min = None, max = None, + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } + + test("cint IS NOT NULL") { + validateEstimatedStats( + arInt, + Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(10L) + ) + } + + test("cint > 3 AND cint <= 6") { + val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) + validateEstimatedStats( + arInt, + Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(4L) + ) + } + + test("cint = 3 OR cint = 6") { + val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) + validateEstimatedStats( + arInt, + Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(2L) + ) + } + + test("cint IN (3, 4, 5)") { + validateEstimatedStats( + arInt, + Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("cint NOT IN (3, 4, 5)") { + validateEstimatedStats( + arInt, + Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(7L) + ) + } + + test("cdate = cast('2017-01-02' AS DATE)") { + val d20170102 = Date.valueOf("2017-01-02") + validateEstimatedStats( + arDate, + Filter(EqualTo(arDate, Literal(d20170102)), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(1L) + ) + } + + test("cdate < cast('2017-01-03' AS DATE)") { + val d20170103 = Date.valueOf("2017-01-03") + validateEstimatedStats( + arDate, + Filter(LessThan(arDate, Literal(d20170103)), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("""cdate IN ( cast('2017-01-03' AS DATE), + cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") { + val d20170103 = Date.valueOf("2017-01-03") + val d20170104 = Date.valueOf("2017-01-04") + val d20170105 = Date.valueOf("2017-01-05") + validateEstimatedStats( + arDate, + Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") { + 
val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") + validateEstimatedStats( + arTimestamp, + Filter(EqualTo(arTimestamp, Literal(ts2017010102)), + childStatsTestPlan(Seq(arTimestamp), 10L)), + ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(1L) + ) + } + + test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") { + val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00") + validateEstimatedStats( + arTimestamp, + Filter(LessThan(arTimestamp, Literal(ts2017010103)), + childStatsTestPlan(Seq(arTimestamp), 10L)), + ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(3L) + ) + } + + test("cdecimal = 0.400000000000000000") { + val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") + validateEstimatedStats( + arDecimal, + Filter(EqualTo(arDecimal, Literal(dec_0_40)), + childStatsTestPlan(Seq(arDecimal), 4L)), + ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(1L) + ) + } + + test("cdecimal < 0.60 ") { + val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") + validateEstimatedStats( + arDecimal, + Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))), + childStatsTestPlan(Seq(arDecimal), 4L)), + ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(3L) + ) + } + + test("cdouble < 3.0") { + validateEstimatedStats( + arDouble, + Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), + ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(3L) + ) + } + + test("cstring = 'A2'") { + validateEstimatedStats( + arString, + Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), + ColumnStat(distinctCount = 1, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2), + Some(1L) + ) + } + + // There is no min/max statistics for String type. We estimate 10 rows returned. + test("cstring < 'A2'") { + validateEstimatedStats( + arString, + Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), + ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2), + Some(10L) + ) + } + + // This is a corner test case. We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6. + // The predicate is: column IN (1, 2, 3, 4, 5). 
+ test("cint IN (1, 2, 3, 4, 5)") { + val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildStatsTestplan = StatsTestPlan( + outputList = Seq(arInt), + rowCount = 2L, + attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt)) + ) + validateEstimatedStats( + arInt, + Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(2L) + ) + } + + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { + StatsTestPlan( + outputList = outList, + rowCount = tableRowCount, + attributeStats = AttributeMap(Seq( + arInt -> childColStatInt, + arDate -> childColStatDate, + arTimestamp -> childColStatTimestamp, + arDecimal -> childColStatDecimal, + arDouble -> childColStatDouble, + arString -> childColStatString + )) + ) + } + + private def validateEstimatedStats( + ar: AttributeReference, + filterNode: Filter, + expectedColStats: ColumnStat, + rowCount: Option[BigInt] = None) + : Unit = { + + val expectedRowCount: BigInt = rowCount.getOrElse(0L) + val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) + val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats) + + val filteredStats = filterNode.stats(conf) + assert(filteredStats.sizeInBytes == expectedSizeInBytes) + assert(filteredStats.rowCount == rowCount) + ar.dataType match { + case DecimalType() => + // Due to the internal transformation for DecimalType within engine, the new min/max + // in ColumnStat may have a different structure even it contains the right values. + // We convert them to Java BigDecimal values so that we can compare the entire object. + val generatedColumnStats = filteredStats.attributeStats(ar) + val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString) + val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString) + val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax)) + assert(outputColStats == expectedColStats) + case _ => + // For all other SQL types, we compare the entire object directly. + assert(filteredStats.attributeStats(ar) == expectedColStats) + } + } + +} From 9820db0cbfd5dced6e2d09185c13e0c008b7ea60 Mon Sep 17 00:00:00 2001 From: windpiger Date: Thu, 23 Feb 2017 22:57:23 -0800 Subject: [PATCH 40/61] [SPARK-19664][SQL] put hive.metastore.warehouse.dir in hadoopconf to overwrite its original value ## What changes were proposed in this pull request? In [SPARK-15959](https://issues.apache.org/jira/browse/SPARK-15959), we bring back the `hive.metastore.warehouse.dir` , while in the logic, when use the value of `spark.sql.warehouse.dir` to overwrite `hive.metastore.warehouse.dir` , it set it to `sparkContext.conf` which does not overwrite the value is hadoopConf, I think it should put in `sparkContext.hadoopConfiguration` and overwrite the original value of hadoopConf https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala#L64 ## How was this patch tested? N/A Author: windpiger Closes #16996 from windpiger/hivemetawarehouseConf. 
--- .../scala/org/apache/spark/sql/internal/SharedState.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 7ce9938f0d075..bce84de45c3d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -42,9 +42,12 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { val warehousePath: String = { val configFile = Utils.getContextOrSparkClassLoader.getResource("hive-site.xml") if (configFile != null) { + logInfo(s"loading hive config file: $configFile") sparkContext.hadoopConfiguration.addResource(configFile) } + // hive.metastore.warehouse.dir only stay in hadoopConf + sparkContext.conf.remove("hive.metastore.warehouse.dir") // Set the Hive metastore warehouse path to the one we use val hiveWarehouseDir = sparkContext.hadoopConfiguration.get("hive.metastore.warehouse.dir") if (hiveWarehouseDir != null && !sparkContext.conf.contains(WAREHOUSE_PATH.key)) { @@ -61,10 +64,11 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // When neither spark.sql.warehouse.dir nor hive.metastore.warehouse.dir is set, // we will set hive.metastore.warehouse.dir to the default value of spark.sql.warehouse.dir. val sparkWarehouseDir = sparkContext.conf.get(WAREHOUSE_PATH) - sparkContext.conf.set("hive.metastore.warehouse.dir", sparkWarehouseDir) + logInfo(s"Setting hive.metastore.warehouse.dir ('$hiveWarehouseDir') to the value of " + + s"${WAREHOUSE_PATH.key} ('$sparkWarehouseDir').") + sparkContext.hadoopConfiguration.set("hive.metastore.warehouse.dir", sparkWarehouseDir) sparkWarehouseDir } - } logInfo(s"Warehouse path is '$warehousePath'.") From a9a196d9bd9a75601f326ad6e2db6e465d5baef0 Mon Sep 17 00:00:00 2001 From: zero323 Date: Fri, 24 Feb 2017 08:22:30 -0800 Subject: [PATCH 41/61] [SPARK-19161][PYTHON][SQL] Improving UDF Docstrings ## What changes were proposed in this pull request? Replaces `UserDefinedFunction` object returned from `udf` with a function wrapper providing docstring and arguments information as proposed in [SPARK-19161](https://issues.apache.org/jira/browse/SPARK-19161). ### Backward incompatible changes: - `pyspark.sql.functions.udf` will return a `function` instead of `UserDefinedFunction`. To ensure backward compatible public API we use function attributes to mimic `UserDefinedFunction` API (`func` and `returnType` attributes). This should have a minimal impact on the user code. An alternative implementation could use dynamical sub-classing. This would ensure full backward compatibility but is more fragile in practice. ### Limitations: Full functionality (retained docstring and argument list) is achieved only in the recent Python version. Legacy Python version will preserve only docstrings, but not argument list. This should be an acceptable trade-off between achieved improvements and overall complexity. ### Possible impact on other tickets: This can affect [SPARK-18777](https://issues.apache.org/jira/browse/SPARK-18777). ## How was this patch tested? Existing unit tests to ensure backward compatibility, additional tests targeting proposed changes. Author: zero323 Closes #16534 from zero323/SPARK-19161. 
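Since this change is confined to PySpark, a short Python usage sketch (run inside an active PySpark session; the UDF name and body are made up) shows what the `functools.wraps`-based wrapper is meant to preserve; on legacy Python only the docstring, not the argument list, survives:

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

@udf(IntegerType())
def plus_one(x):
    """Adds one to the value of a column."""
    return x + 1

# The returned object is now a plain function wrapper, so introspection works:
plus_one.__name__     # 'plus_one' rather than an anonymous UserDefinedFunction
plus_one.__doc__      # the original docstring is kept
plus_one.func         # the wrapped Python function (backward-compatible attribute)
plus_one.returnType   # IntegerType(), as before
```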
--- python/pyspark/sql/functions.py | 11 ++++++++++- python/pyspark/sql/tests.py | 25 +++++++++++++++---------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d2617203140fa..426a4a8c93a67 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1940,7 +1940,16 @@ def udf(f=None, returnType=StringType()): +----------+--------------+------------+ """ def _udf(f, returnType=StringType()): - return UserDefinedFunction(f, returnType) + udf_obj = UserDefinedFunction(f, returnType) + + @functools.wraps(f) + def wrapper(*args): + return udf_obj(*args) + + wrapper.func = udf_obj.func + wrapper.returnType = udf_obj.returnType + + return wrapper # decorator @udf, @udf() or @udf(dataType()) if f is None or isinstance(f, (str, DataType)): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index abd68bfd391a0..fd083e4868cd6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -266,9 +266,6 @@ def test_explode(self): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") - with self.assertRaises(ValueError): - data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count() - def test_and_in_expression(self): self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) @@ -578,6 +575,21 @@ def as_double(x): [2, 3.0, "FOO", "foo", "foo", 3, 1.0] ) + def test_udf_wrapper(self): + from pyspark.sql.functions import udf + from pyspark.sql.types import IntegerType + + def f(x): + """Identity""" + return x + + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) @@ -963,13 +975,6 @@ def test_column_select(self): self.assertEqual(self.testData, df.select(df.key, df.value).collect()) self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) - def test_column_alias_metadata(self): - df = self.df - df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'})) - self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key') - with self.assertRaises(AssertionError): - df.select(df.key.alias('pk', metdata={'label': 'Primary Key'})) - def test_freqItems(self): vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] df = self.sc.parallelize(vals).toDF() From 6b696a4bffde35111170ceced136d3f0bcfd406b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 24 Feb 2017 09:28:59 -0800 Subject: [PATCH 42/61] [SPARK-19707][CORE] Improve the invalid path check for sc.addJar ## What changes were proposed in this pull request? Currently in Spark there're two issues when we add jars with invalid path: * If the jar path is a empty string {--jar ",dummy.jar"}, then Spark will resolve it to the current directory path and add to classpath / file server, which is unwanted. This is happened in our programatic way to submit Spark application. From my understanding Spark should defensively filter out such empty path. * If the jar path is a invalid path (file doesn't exist), `addJar` doesn't check it and will still add to file server, the exception will be delayed until job running. 
Actually this local path could be checked beforehand, no need to wait until task running. We have similar check in `addFile`, but lacks similar similar mechanism in `addJar`. ## How was this patch tested? Add unit test and local manual verification. Author: jerryshao Closes #17038 from jerryshao/SPARK-19707. --- .../scala/org/apache/spark/SparkContext.scala | 12 ++++++++++-- .../main/scala/org/apache/spark/util/Utils.scala | 2 +- .../org/apache/spark/SparkContextSuite.scala | 16 ++++++++++++++++ .../scala/org/apache/spark/util/UtilsSuite.scala | 1 + 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 17194b9f06d35..0e36a30c933d0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1815,10 +1815,18 @@ class SparkContext(config: SparkConf) extends Logging { // A JAR file which exists only on the driver node case null | "file" => try { + val file = new File(uri.getPath) + if (!file.exists()) { + throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") + } + if (file.isDirectory) { + throw new IllegalArgumentException( + s"Directory ${file.getAbsoluteFile} is not allowed for addJar") + } env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") + case NonFatal(e) => + logError(s"Failed to add $path to Spark environment", e) null } // A JAR file which exists locally on every worker node diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 55382899a34d7..480240a93d4e5 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1989,7 +1989,7 @@ private[spark] object Utils extends Logging { if (paths == null || paths.trim.isEmpty) { "" } else { - paths.split(",").map { p => Utils.resolveURI(p) }.mkString(",") + paths.split(",").filter(_.trim.nonEmpty).map { p => Utils.resolveURI(p) }.mkString(",") } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 5a41e1c61908e..f97a112ec1276 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -292,6 +292,22 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + test("add jar with invalid path") { + val tmpDir = Utils.createTempDir() + val tmpJar = File.createTempFile("test", ".jar", tmpDir) + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addJar(tmpJar.getAbsolutePath) + + // Invaid jar path will only print the error log, will not add to file server. 
+ sc.addJar("dummy.jar") + sc.addJar("") + sc.addJar(tmpDir.getAbsolutePath) + + sc.listJars().size should be (1) + sc.listJars().head should include (tmpJar.getName) + } + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 43f77e68c153c..c9cf651ecf759 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -507,6 +507,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assertResolves("""hdfs:/jar1,file:/jar2,jar3,C:\pi.py#py.pi,C:\path to\jar4""", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py%23py.pi,file:/C:/path%20to/jar4") } + assertResolves(",jar1,jar2", s"file:$cwd/jar1,file:$cwd/jar2") } test("nonLocalPaths") { From 09f5f996063bb6611c4d1c4c28d78dd8c0ddad9d Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 24 Feb 2017 09:31:52 -0800 Subject: [PATCH 43/61] [SPARK-19038][YARN] Avoid overwriting keytab configuration in yarn-client ## What changes were proposed in this pull request? Because yarn#client will reset the `spark.yarn.keytab` configuration to point to the location in distributed file, so if user still uses the old `SparkConf` to create `SparkSession` with Hive enabled, it will read keytab from the path in distributed cached. This is OK for yarn cluster mode, but in yarn client mode where driver is running out of container, it will be failed to fetch the keytab. So here we should avoid reseting this configuration in the `yarn#client` and only overwriting it for AM, so using `spark.yarn.keytab` could get correct keytab path no matter running in client (keytab in local fs) or cluster (keytab in distributed cache) mode. ## How was this patch tested? Verified in security cluster. Author: jerryshao Closes #16923 from jerryshao/SPARK-19038. --- .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 9 ++++++--- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 4 ---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index a00234c2b416c..fa99cd3b64a4d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -100,6 +100,7 @@ private[spark] class Client( private var principal: String = null private var keytab: String = null private var credentials: Credentials = null + private var amKeytabFileName: String = null private val launcherBackend = new LauncherBackend() { override def onStopRequest(): Unit = { @@ -471,7 +472,7 @@ private[spark] class Client( logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + " via the YARN Secure Distributed Cache.") val (_, localizedPath) = distribute(keytab, - destName = sparkConf.get(KEYTAB), + destName = Some(amKeytabFileName), appMasterOnly = true) require(localizedPath != null, "Keytab file already distributed.") } @@ -708,6 +709,9 @@ private[spark] class Client( // Save Spark configuration to a file in the archive. 
val props = new Properties() sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + // Override spark.yarn.key to point to the location in distributed cache which will be used + // by AM. + Option(amKeytabFileName).foreach { k => props.setProperty(KEYTAB.key, k) } confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) val writer = new OutputStreamWriter(confStream, StandardCharsets.UTF_8) props.store(writer, "Spark configuration.") @@ -995,8 +999,7 @@ private[spark] class Client( val f = new File(keytab) // Generate a file name that can be used for the keytab file, that does not conflict // with any user file. - val keytabFileName = f.getName + "-" + UUID.randomUUID().toString - sparkConf.set(KEYTAB.key, keytabFileName) + amKeytabFileName = f.getName + "-" + UUID.randomUUID().toString sparkConf.set(PRINCIPAL.key, principal) } // Defensive copy of the credentials diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index dc9c3ff33542d..24dfd33bc3682 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -106,10 +106,6 @@ private[hive] class HiveClientImpl( // Set up kerberos credentials for UserGroupInformation.loginUser within // current class loader - // Instead of using the spark conf of the current spark context, a new - // instance of SparkConf is needed for the original value of spark.yarn.keytab - // and spark.yarn.principal set in SparkSubmit, as yarn.Client resets the - // keytab configuration for the link name in distributed cache if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { val principalName = sparkConf.get("spark.yarn.principal") val keytabFileName = sparkConf.get("spark.yarn.keytab") From 65152362fedca52b88e16b97e611f64a4ea091df Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Fri, 24 Feb 2017 09:46:42 -0800 Subject: [PATCH 44/61] [SPARK-17495][SQL] Add more tests for hive hash ## What changes were proposed in this pull request? This PR adds tests hive-hash by comparing the outputs generated against Hive 1.2.1. Following datatypes are covered by this PR: - null - boolean - byte - short - int - long - float - double - string - array - map - struct Datatypes that I have _NOT_ covered but I will work on separately are: - Decimal (handled separately in https://github.com/apache/spark/pull/17056) - TimestampType - DateType - CalendarIntervalType ## How was this patch tested? NA Author: Tejas Patil Closes #17049 from tejasapatil/SPARK-17495_remaining_types. 
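For the primitive types exercised here, the expected values encode a small, easily restated contract. The following Scala sketch mirrors the semantics the new tests assert; it is derived from the asserted values, not a copy of `HiveHasher` itself:

```scala
import java.nio.charset.StandardCharsets

object HiveHashSketch {
  // Bytes, shorts and ints hash to their own value.
  def hashInt(i: Int): Int = i

  // Longs fold the upper 32 bits onto the lower 32 bits.
  def hashLong(l: Long): Int = ((l >>> 32) ^ l).toInt

  // Strings use a byte-wise polynomial hash with multiplier 31 over the UTF-8
  // bytes, with ordinary 32-bit overflow.
  def hashString(s: String): Int =
    s.getBytes(StandardCharsets.UTF_8).foldLeft(0)((h, b) => h * 31 + b)
}
```

For example, `HiveHashSketch.hashString("apache spark")` evaluates to `1142704523`, matching the value the suite asserts for `StringType`; nested structs combine their field hashes with the same 31-multiplier scheme, as the struct case in `HiveHashFunction` below shows.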
--- .../sql/catalyst/expressions/HiveHasher.java | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 11 +- .../expressions/HashExpressionsSuite.scala | 247 +++++++++++++++++- 3 files changed, 252 insertions(+), 8 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index c7ea9085eba66..73577437ac506 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -20,7 +20,7 @@ import org.apache.spark.unsafe.Platform; /** - * Simulates Hive's hashing function at + * Simulates Hive's hashing function from Hive v1.2.1 * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() */ public class HiveHasher { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index e14f0544c2b81..2d9c2e42064b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -573,10 +573,9 @@ object XxHash64Function extends InterpretedHashFunction { } } - /** - * Simulates Hive's hashing function at - * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive + * Simulates Hive's hashing function from Hive v1.2.1 at + * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() * * We should use this hash function for both shuffle and bucket of Hive tables, so that * we can guarantee shuffle and bucketing have same data distribution @@ -595,7 +594,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def hasherClassName: String = classOf[HiveHasher].getName override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { - HiveHashFunction.hash(value, dataType, seed).toInt + HiveHashFunction.hash(value, dataType, this.seed).toInt } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -781,12 +780,12 @@ object HiveHashFunction extends InterpretedHashFunction { var i = 0 val length = struct.numFields while (i < length) { - result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt + result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt i += 1 } result - case _ => super.hash(value, dataType, seed) + case _ => super.hash(value, dataType, 0) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 032629265269a..0cb3a79eee67d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -19,16 +19,20 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.types.{ArrayType, StructType, _} import org.apache.spark.unsafe.types.UTF8String class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + val random = new scala.util.Random test("md5") { checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))), @@ -71,6 +75,247 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } + + def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = { + // Note : All expected hashes need to be computed using Hive 1.2.1 + val actual = HiveHashFunction.hash(input, dataType, seed = 0) + + withClue(s"hash mismatch for input = `$input` of type `$dataType`.") { + assert(actual == expected) + } + } + + def checkHiveHashForIntegralType(dataType: DataType): Unit = { + // corner cases + checkHiveHash(null, dataType, 0) + checkHiveHash(1, dataType, 1) + checkHiveHash(0, dataType, 0) + checkHiveHash(-1, dataType, -1) + checkHiveHash(Int.MaxValue, dataType, Int.MaxValue) + checkHiveHash(Int.MinValue, dataType, Int.MinValue) + + // random values + for (_ <- 0 until 10) { + val input = random.nextInt() + checkHiveHash(input, dataType, input) + } + } + + test("hive-hash for null") { + checkHiveHash(null, NullType, 0) + } + + test("hive-hash for boolean") { + checkHiveHash(true, BooleanType, 1) + checkHiveHash(false, BooleanType, 0) + } + + test("hive-hash for byte") { + checkHiveHashForIntegralType(ByteType) + } + + test("hive-hash for short") { + checkHiveHashForIntegralType(ShortType) + } + + test("hive-hash for int") { + checkHiveHashForIntegralType(IntegerType) + } + + test("hive-hash for long") { + checkHiveHash(1L, LongType, 1L) + checkHiveHash(0L, LongType, 0L) + checkHiveHash(-1L, LongType, 0L) + checkHiveHash(Long.MaxValue, LongType, -2147483648) + // Hive's fails to parse this.. 
but the hashing function itself can handle this input + checkHiveHash(Long.MinValue, LongType, -2147483648) + + for (_ <- 0 until 10) { + val input = random.nextLong() + checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt) + } + } + + test("hive-hash for float") { + checkHiveHash(0F, FloatType, 0) + checkHiveHash(0.0F, FloatType, 0) + checkHiveHash(1.1F, FloatType, 1066192077L) + checkHiveHash(-1.1F, FloatType, -1081291571) + checkHiveHash(99999999.99999999999F, FloatType, 1287568416L) + checkHiveHash(Float.MaxValue, FloatType, 2139095039) + checkHiveHash(Float.MinValue, FloatType, -8388609) + } + + test("hive-hash for double") { + checkHiveHash(0, DoubleType, 0) + checkHiveHash(0.0, DoubleType, 0) + checkHiveHash(1.1, DoubleType, -1503133693) + checkHiveHash(-1.1, DoubleType, 644349955) + checkHiveHash(1000000000.000001, DoubleType, 1104006509) + checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501) + checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676) + checkHiveHash(Double.MaxValue, DoubleType, -2146435072) + checkHiveHash(Double.MinValue, DoubleType, 1048576) + } + + test("hive-hash for string") { + checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L) + checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L) + checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L) + checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L) + // scalastyle:off nonascii + checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L) + checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L) + // scalastyle:on nonascii + } + + test("hive-hash for array") { + // empty array + checkHiveHash( + input = new GenericArrayData(Array[Int]()), + dataType = ArrayType(IntegerType, containsNull = false), + expected = 0) + + // basic case + checkHiveHash( + input = new GenericArrayData(Array(1, 10000, Int.MaxValue)), + dataType = ArrayType(IntegerType, containsNull = false), + expected = -2147172688L) + + // with negative values + checkHiveHash( + input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)), + dataType = ArrayType(LongType, containsNull = false), + expected = -2147452680L) + + // with nulls only + val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true) + checkHiveHash( + input = new GenericArrayData(Array(null, null)), + dataType = arrayTypeWithNull, + expected = 0) + + // mix with null + checkHiveHash( + input = new GenericArrayData(Array(-12221, 89, null, 767)), + dataType = arrayTypeWithNull, + expected = -363989515) + + // nested with array + checkHiveHash( + input = new GenericArrayData( + Array( + new GenericArrayData(Array(1234L, -9L, 67L)), + new GenericArrayData(Array(null, null)), + new GenericArrayData(Array(55L, -100L, -2147452680L)) + )), + dataType = ArrayType(ArrayType(LongType)), + expected = -1007531064) + + // nested with map + checkHiveHash( + input = new GenericArrayData( + Array( + new ArrayBasedMapData( + new GenericArrayData(Array(-99, 1234)), + new GenericArrayData(Array(UTF8String.fromString("sql"), null))), + new ArrayBasedMapData( + new GenericArrayData(Array(67)), + new GenericArrayData(Array(UTF8String.fromString("apache spark")))) + )), + dataType = ArrayType(MapType(IntegerType, StringType)), + expected = 1139205955) + } + + test("hive-hash for map") { + val mapType = MapType(IntegerType, StringType) + + // empty map + checkHiveHash( + input = new 
ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())), + dataType = mapType, + expected = 0) + + // basic case + checkHiveHash( + input = new ArrayBasedMapData( + new GenericArrayData(Array(1, 2)), + new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))), + dataType = mapType, + expected = 198872) + + // with null value + checkHiveHash( + input = new ArrayBasedMapData( + new GenericArrayData(Array(55, -99)), + new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))), + dataType = mapType, + expected = 1142704473) + + // nesting (only values can be nested as keys have to be primitive datatype) + val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType)) + checkHiveHash( + input = new ArrayBasedMapData( + new GenericArrayData(Array(1, -100)), + new GenericArrayData( + Array( + new ArrayBasedMapData( + new GenericArrayData(Array(-99, 1234)), + new GenericArrayData(Array(UTF8String.fromString("sql"), null))), + new ArrayBasedMapData( + new GenericArrayData(Array(67)), + new GenericArrayData(Array(UTF8String.fromString("apache spark")))) + ))), + dataType = nestedMapType, + expected = -1142817416) + } + + test("hive-hash for struct") { + // basic + val row = new GenericInternalRow(Array[Any](1, 2, 3)) + checkHiveHash( + input = row, + dataType = + new StructType() + .add("col1", IntegerType) + .add("col2", IntegerType) + .add("col3", IntegerType), + expected = 1026) + + // mix of several datatypes + val structType = new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("arrayOfString", arrayOfString) + .add("mapOfString", mapOfString) + + val rowValues = new ArrayBuffer[Any]() + rowValues += null + rowValues += true + rowValues += 1 + rowValues += 2 + rowValues += Int.MaxValue + rowValues += Long.MinValue + rowValues += new GenericArrayData(Array( + UTF8String.fromString("apache spark"), + UTF8String.fromString("hello world") + )) + rowValues += new ArrayBasedMapData( + new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))), + new GenericArrayData(Array(UTF8String.fromString("apache spark"), null)) + ) + + val row2 = new GenericInternalRow(rowValues.toArray) + checkHiveHash( + input = row2, + dataType = structType, + expected = -2119012447) + } + private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) From 3375721948414ef8d0efa4a6e3a8eb05e0009d3e Mon Sep 17 00:00:00 2001 From: Shuai Lin Date: Fri, 24 Feb 2017 10:24:01 -0800 Subject: [PATCH 45/61] [SPARK-17075][SQL] Follow up: fix file line ending and improve the tests ## What changes were proposed in this pull request? Fixed the line ending of `FilterEstimation.scala` (It's still using `\n\r`). Also improved the tests to cover the cases where the literals are on the left side of a binary operator. ## How was this patch tested? Existing unit tests. Author: Shuai Lin Closes #17051 from lins05/fix-cbo-filter-file-encoding. 
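Because the bulk of the diff below is the same file re-committed with Unix line endings, its heart is still the uniform-distribution arithmetic. A compact Scala sketch of that selectivity math follows (it treats `<=` like `<`, and ignores nulls and the 1/ndv boundary adjustments the real `FilterEstimation` applies):

```scala
object SelectivitySketch {
  // Selectivity of a comparison against a numeric column with stats [min, max]
  // and ndv distinct values, assuming uniformly distributed data.
  def rangeSelectivity(op: String, lit: Double, min: Double, max: Double, ndv: Long): Double =
    op match {
      case "="        => if (lit >= min && lit <= max) 1.0 / ndv else 0.0
      case "<" | "<=" => if (lit <= min) 0.0 else if (lit >= max) 1.0 else (lit - min) / (max - min)
      case ">" | ">=" => if (lit >= max) 0.0 else if (lit <= min) 1.0 else (max - lit) / (max - min)
      case _          => 1.0   // unsupported predicate: fall back to a conservative 100%
    }
}
```

A literal on the left-hand side is estimated by flipping the comparison, e.g. `3 < c` is treated as `c > 3`, which is exactly the path the additional tests in this follow-up cover.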
--- .../statsEstimation/FilterEstimation.scala | 1022 ++++++++--------- .../FilterEstimationSuite.scala | 23 +- 2 files changed, 533 insertions(+), 512 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index fcc607a610fcc..37f29ba68a206 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -1,511 +1,511 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical.statsEstimation - -import java.sql.{Date, Timestamp} - -import scala.collection.immutable.{HashSet, Map} -import scala.collection.mutable - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ - -case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { - - /** - * We use a mutable colStats because we need to update the corresponding ColumnStat - * for a column after we apply a predicate condition. For example, column c has - * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50), - * we need to set the column's [min, max] value to [40, 100] after we evaluate the - * first condition c > 40. We need to set the column's [min, max] value to [40, 50] - * after we evaluate the second condition c <= 50. - */ - private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty - - /** - * Returns an option of Statistics for a Filter logical plan node. - * For a given compound expression condition, this method computes filter selectivity - * (or the percentage of rows meeting the filter condition), which - * is used to compute row count, size in bytes, and the updated statistics after a given - * predicated is applied. - * - * @return Option[Statistics] When there is no statistics collected, it returns None. - */ - def estimate: Option[Statistics] = { - // We first copy child node's statistics and then modify it based on filter selectivity. 
- val stats: Statistics = plan.child.stats(catalystConf) - if (stats.rowCount.isEmpty) return None - - // save a mutable copy of colStats so that we can later change it recursively - mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*) - - // estimate selectivity of this filter predicate - val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { - case Some(percent) => percent - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => 1.0 - } - - // attributeStats has mapping Attribute-to-ColumnStat. - // mutableColStats has mapping ExprId-to-ColumnStat. - // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat - val expridToAttrMap: Map[ExprId, Attribute] = - stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) - // copy mutableColStats contents to an immutable AttributeMap. - val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = - mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2) - val newColStats = AttributeMap(mutableAttributeStats.toSeq) - - val filteredRowCount: BigInt = - EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) - val filteredSizeInBytes = - EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) - - Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), - attributeStats = newColStats)) - } - - /** - * Returns a percentage of rows meeting a compound condition in Filter node. - * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. - * For logical AND conditions, we need to update stats after a condition estimation - * so that the stats will be more accurate for subsequent estimation. This is needed for - * range condition such as (c > 40 AND c <= 50) - * For logical OR conditions, we do not update stats after a condition estimation. - * - * @param condition the compound logical expression - * @param update a boolean flag to specify if we need to update ColumnStat of a column - * for subsequent conditions - * @return a double value to show the percentage of rows meeting a given condition. - * It returns None if the condition is not supported. - */ - def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { - - condition match { - case And(cond1, cond2) => - (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update)) - match { - case (Some(p1), Some(p2)) => Some(p1 * p2) - case (Some(p1), None) => Some(p1) - case (None, Some(p2)) => Some(p2) - case (None, None) => None - } - - case Or(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update = false) - val percent2 = calculateFilterSelectivity(cond2, update = false) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) - case (Some(p1), None) => Some(1.0) - case (None, Some(p2)) => Some(1.0) - case (None, None) => None - } - - case Not(cond) => calculateFilterSelectivity(cond, update = false) match { - case Some(percent) => Some(1.0 - percent) - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => None - } - - case _ => - calculateSingleCondition(condition, update) - } - } - - /** - * Returns a percentage of rows meeting a single condition in Filter node. 
- * Currently we only support binary predicates where one side is a column, - * and the other is a literal. - * - * @param condition a single logical expression - * @param update a boolean flag to specify if we need to update ColumnStat of a column - * for subsequent conditions - * @return Option[Double] value to show the percentage of rows meeting a given condition. - * It returns None if the condition is not supported. - */ - def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { - condition match { - // For evaluateBinary method, we assume the literal on the right side of an operator. - // So we will change the order if not. - - // EqualTo does not care about the order - case op @ EqualTo(ar: AttributeReference, l: Literal) => - evaluateBinary(op, ar, l, update) - case op @ EqualTo(l: Literal, ar: AttributeReference) => - evaluateBinary(op, ar, l, update) - - case op @ LessThan(ar: AttributeReference, l: Literal) => - evaluateBinary(op, ar, l, update) - case op @ LessThan(l: Literal, ar: AttributeReference) => - evaluateBinary(GreaterThan(ar, l), ar, l, update) - - case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => - evaluateBinary(op, ar, l, update) - case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) => - evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) - - case op @ GreaterThan(ar: AttributeReference, l: Literal) => - evaluateBinary(op, ar, l, update) - case op @ GreaterThan(l: Literal, ar: AttributeReference) => - evaluateBinary(LessThan(ar, l), ar, l, update) - - case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => - evaluateBinary(op, ar, l, update) - case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) => - evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - - case In(ar: AttributeReference, expList) - if expList.forall(e => e.isInstanceOf[Literal]) => - // Expression [In (value, seq[Literal])] will be replaced with optimized version - // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. - // Here we convert In into InSet anyway, because they share the same processing logic. - val hSet = expList.map(e => e.eval()) - evaluateInSet(ar, HashSet() ++ hSet, update) - - case InSet(ar: AttributeReference, set) => - evaluateInSet(ar, set, update) - - case IsNull(ar: AttributeReference) => - evaluateIsNull(ar, isNull = true, update) - - case IsNotNull(ar: AttributeReference) => - evaluateIsNull(ar, isNull = false, update) - - case _ => - // TODO: it's difficult to support string operators without advanced statistics. - // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) - // | EndsWith(_, _) are not supported yet - logDebug("[CBO] Unsupported filter condition: " + condition) - None - } - } - - /** - * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition. - * - * @param attrRef an AttributeReference (or a column) - * @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return an optional double value to show the percentage of rows meeting a given condition - * It returns None if no statistics collected for a given column. 
- */ - def evaluateIsNull( - attrRef: AttributeReference, - isNull: Boolean, - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) - return None - } - val aColStat = mutableColStats(attrRef.exprId) - val rowCountValue = plan.child.stats(catalystConf).rowCount.get - val nullPercent: BigDecimal = - if (rowCountValue == 0) 0.0 - else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue) - - if (update) { - val newStats = - if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None) - else aColStat.copy(nullCount = 0) - - mutableColStats += (attrRef.exprId -> newStats) - } - - val percent = - if (isNull) { - nullPercent.toDouble - } - else { - /** ISNOTNULL(column) */ - 1.0 - nullPercent.toDouble - } - - Some(percent) - } - - /** - * Returns a percentage of rows meeting a binary comparison expression. - * - * @param op a binary comparison operator uch as =, <, <=, >, >= - * @param attrRef an AttributeReference (or a column) - * @param literal a literal value (or constant) - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return an optional double value to show the percentage of rows meeting a given condition - * It returns None if no statistics exists for a given column or wrong value. - */ - def evaluateBinary( - op: BinaryComparison, - attrRef: AttributeReference, - literal: Literal, - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) - return None - } - - op match { - case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update) - case _ => - attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - evaluateBinaryForNumeric(op, attrRef, literal, update) - case StringType | BinaryType => - // TODO: It is difficult to support other binary comparisons for String/Binary - // type without min/max and advanced statistics like histogram. - logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef) - None - } - } - } - - /** - * For a SQL data type, its internal data type may be different from its external type. - * For DateType, its internal type is Int, and its external data type is Java Date type. - * The min/max values in ColumnStat are saved in their corresponding external type. - * - * @param attrDataType the column data type - * @param litValue the literal value - * @return a BigDecimal value - */ - def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { - attrDataType match { - case DateType => - Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) - case TimestampType => - Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) - case StringType | BinaryType => - None - case _ => - Some(litValue) - } - } - - /** - * Returns a percentage of rows meeting an equality (=) expression. - * This method evaluates the equality predicate for all data types. 
- * - * @param attrRef an AttributeReference (or a column) - * @param literal a literal value (or constant) - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return an optional double value to show the percentage of rows meeting a given condition - */ - def evaluateEqualTo( - attrRef: AttributeReference, - literal: Literal, - update: Boolean) - : Option[Double] = { - - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount - - // decide if the value is in [min, max] of the column. - // We currently don't store min/max for binary/string type. - // Hence, we assume it is in boundary for binary/string type. - val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType) - val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal) - - if (inBoundary) { - - if (update) { - // We update ColumnStat structure after apply this equality predicate. - // Set distinctCount to 1. Set nullCount to 0. - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attrRef.dataType, literal.value) - val newStats = aColStat.copy(distinctCount = 1, min = newValue, - max = newValue, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) - } - - Some(1.0 / ndv.toDouble) - } else { - Some(0.0) - } - - } - - /** - * Returns a percentage of rows meeting "IN" operator expression. - * This method evaluates the equality predicate for all data types. - * - * @param attrRef an AttributeReference (or a column) - * @param hSet a set of literal values - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return an optional double value to show the percentage of rows meeting a given condition - * It returns None if no statistics exists for a given column. - */ - - def evaluateInSet( - attrRef: AttributeReference, - hSet: Set[Any], - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) - return None - } - - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount - val aType = attrRef.dataType - var newNdv: Long = 0 - - // use [min, max] to filter the original hSet - aType match { - case _: NumericType | DateType | TimestampType => - val statsRange = - Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - - // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal. - // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec. - val hSetBigdec = hSet.map(e => BigDecimal(e.toString)) - val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max) - // We use hSetBigdecToAnyMap to help us find the original hSet value. - val hSetBigdecToAnyMap: Map[BigDecimal, Any] = - hSet.map(e => BigDecimal(e.toString) -> e).toMap - - if (validQuerySet.isEmpty) { - return Some(0.0) - } - - // Need to save new min/max using the external type value of the literal - val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max)) - val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min)) - - // newNdv should not be greater than the old ndv. For example, column has only 2 values - // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. 
- newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) - if (update) { - val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) - } - - // We assume the whole set since there is no min/max information for String/Binary type - case StringType | BinaryType => - newNdv = math.min(hSet.size.toLong, ndv.longValue()) - if (update) { - val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) - } - } - - // return the filter selectivity. Without advanced statistics such as histograms, - // we have to assume uniform distribution. - Some(math.min(1.0, newNdv.toDouble / ndv.toDouble)) - } - - /** - * Returns a percentage of rows meeting a binary comparison expression. - * This method evaluate expression for Numeric columns only. - * - * @param op a binary comparison operator uch as =, <, <=, >, >= - * @param attrRef an AttributeReference (or a column) - * @param literal a literal value (or constant) - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return an optional double value to show the percentage of rows meeting a given condition - */ - def evaluateBinaryForNumeric( - op: BinaryComparison, - attrRef: AttributeReference, - literal: Literal, - update: Boolean) - : Option[Double] = { - - var percent = 1.0 - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount - val statsRange = - Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] - - // determine the overlapping degree between predicate range and column's range - val literalValueBD = BigDecimal(literal.value.toString) - val (noOverlap: Boolean, completeOverlap: Boolean) = op match { - case _: LessThan => - (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) - case _: LessThanOrEqual => - (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) - case _: GreaterThan => - (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) - case _: GreaterThanOrEqual => - (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) - } - - if (noOverlap) { - percent = 0.0 - } else if (completeOverlap) { - percent = 1.0 - } else { - // this is partial overlap case - var newMax = aColStat.max - var newMin = aColStat.min - var newNdv = ndv - val literalToDouble = literalValueBD.toDouble - val maxToDouble = BigDecimal(statsRange.max).toDouble - val minToDouble = BigDecimal(statsRange.min).toDouble - - // Without advanced statistics like histogram, we assume uniform data distribution. - // We just prorate the adjusted range over the initial range to compute filter selectivity. - // For ease of computation, we convert all relevant numeric values to Double. 
- percent = op match { - case _: LessThan => - (literalToDouble - minToDouble) / (maxToDouble - minToDouble) - case _: LessThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble - else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) - case _: GreaterThan => - (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) - case _: GreaterThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble - else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) - } - - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attrRef.dataType, literal.value) - - if (update) { - op match { - case _: GreaterThan => newMin = newValue - case _: GreaterThanOrEqual => newMin = newValue - case _: LessThan => newMax = newValue - case _: LessThanOrEqual => newMax = newValue - } - - newNdv = math.max(math.round(ndv.toDouble * percent), 1) - val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) - - mutableColStats += (attrRef.exprId -> newStats) - } - } - - Some(percent) - } - -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import java.sql.{Date, Timestamp} + +import scala.collection.immutable.{HashSet, Map} +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { + + /** + * We use a mutable colStats because we need to update the corresponding ColumnStat + * for a column after we apply a predicate condition. For example, column c has + * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50), + * we need to set the column's [min, max] value to [40, 100] after we evaluate the + * first condition c > 40. We need to set the column's [min, max] value to [40, 50] + * after we evaluate the second condition c <= 50. + */ + private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty + + /** + * Returns an option of Statistics for a Filter logical plan node. + * For a given compound expression condition, this method computes filter selectivity + * (or the percentage of rows meeting the filter condition), which + * is used to compute row count, size in bytes, and the updated statistics after a given + * predicated is applied. + * + * @return Option[Statistics] When there is no statistics collected, it returns None. 
+ */ + def estimate: Option[Statistics] = { + // We first copy child node's statistics and then modify it based on filter selectivity. + val stats: Statistics = plan.child.stats(catalystConf) + if (stats.rowCount.isEmpty) return None + + // save a mutable copy of colStats so that we can later change it recursively + mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*) + + // estimate selectivity of this filter predicate + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { + case Some(percent) => percent + // for not-supported condition, set filter selectivity to a conservative estimate 100% + case None => 1.0 + } + + // attributeStats has mapping Attribute-to-ColumnStat. + // mutableColStats has mapping ExprId-to-ColumnStat. + // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat + val expridToAttrMap: Map[ExprId, Attribute] = + stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) + // copy mutableColStats contents to an immutable AttributeMap. + val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = + mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2) + val newColStats = AttributeMap(mutableAttributeStats.toSeq) + + val filteredRowCount: BigInt = + EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) + val filteredSizeInBytes = + EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) + + Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), + attributeStats = newColStats)) + } + + /** + * Returns a percentage of rows meeting a compound condition in Filter node. + * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. + * For logical AND conditions, we need to update stats after a condition estimation + * so that the stats will be more accurate for subsequent estimation. This is needed for + * range condition such as (c > 40 AND c <= 50) + * For logical OR conditions, we do not update stats after a condition estimation. + * + * @param condition the compound logical expression + * @param update a boolean flag to specify if we need to update ColumnStat of a column + * for subsequent conditions + * @return a double value to show the percentage of rows meeting a given condition. + * It returns None if the condition is not supported. + */ + def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { + + condition match { + case And(cond1, cond2) => + (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update)) + match { + case (Some(p1), Some(p2)) => Some(p1 * p2) + case (Some(p1), None) => Some(p1) + case (None, Some(p2)) => Some(p2) + case (None, None) => None + } + + case Or(cond1, cond2) => + // For ease of debugging, we compute percent1 and percent2 in 2 statements. 
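+ // As a purely illustrative example with hypothetical selectivities (and assuming the two
+ // conditions are independent): if cond1 matches 50% of rows and cond2 matches 20%, the And
+ // case above yields 0.5 * 0.2 = 0.1, while this Or case yields 0.5 + 0.2 - 0.1 = 0.6.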
+ val percent1 = calculateFilterSelectivity(cond1, update = false) + val percent2 = calculateFilterSelectivity(cond2, update = false) + (percent1, percent2) match { + case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) + case (Some(p1), None) => Some(1.0) + case (None, Some(p2)) => Some(1.0) + case (None, None) => None + } + + case Not(cond) => calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + // for not-supported condition, set filter selectivity to a conservative estimate 100% + case None => None + } + + case _ => + calculateSingleCondition(condition, update) + } + } + + /** + * Returns a percentage of rows meeting a single condition in Filter node. + * Currently we only support binary predicates where one side is a column, + * and the other is a literal. + * + * @param condition a single logical expression + * @param update a boolean flag to specify if we need to update ColumnStat of a column + * for subsequent conditions + * @return Option[Double] value to show the percentage of rows meeting a given condition. + * It returns None if the condition is not supported. + */ + def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { + condition match { + // For evaluateBinary method, we assume the literal on the right side of an operator. + // So we will change the order if not. + + // EqualTo does not care about the order + case op @ EqualTo(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ EqualTo(l: Literal, ar: AttributeReference) => + evaluateBinary(op, ar, l, update) + + case op @ LessThan(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ LessThan(l: Literal, ar: AttributeReference) => + evaluateBinary(GreaterThan(ar, l), ar, l, update) + + case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) => + evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) + + case op @ GreaterThan(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ GreaterThan(l: Literal, ar: AttributeReference) => + evaluateBinary(LessThan(ar, l), ar, l, update) + + case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => + evaluateBinary(op, ar, l, update) + case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) => + evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) + + case In(ar: AttributeReference, expList) + if expList.forall(e => e.isInstanceOf[Literal]) => + // Expression [In (value, seq[Literal])] will be replaced with optimized version + // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. + // Here we convert In into InSet anyway, because they share the same processing logic. + val hSet = expList.map(e => e.eval()) + evaluateInSet(ar, HashSet() ++ hSet, update) + + case InSet(ar: AttributeReference, set) => + evaluateInSet(ar, set, update) + + case IsNull(ar: AttributeReference) => + evaluateIsNull(ar, isNull = true, update) + + case IsNotNull(ar: AttributeReference) => + evaluateIsNull(ar, isNull = false, update) + + case _ => + // TODO: it's difficult to support string operators without advanced statistics. 
+ // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _)
+ // | EndsWith(_, _) are not supported yet
+ logDebug("[CBO] Unsupported filter condition: " + condition)
+ None
+ }
+ }
+
+ /**
+ * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition.
+ *
+ * @param attrRef an AttributeReference (or a column)
+ * @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition
+ * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+ * for subsequent conditions
+ * @return an optional double value to show the percentage of rows meeting a given condition
+ * It returns None if no statistics collected for a given column.
+ */
+ def evaluateIsNull(
+ attrRef: AttributeReference,
+ isNull: Boolean,
+ update: Boolean)
+ : Option[Double] = {
+ if (!mutableColStats.contains(attrRef.exprId)) {
+ logDebug("[CBO] No statistics for " + attrRef)
+ return None
+ }
+ val aColStat = mutableColStats(attrRef.exprId)
+ val rowCountValue = plan.child.stats(catalystConf).rowCount.get
+ val nullPercent: BigDecimal =
+ if (rowCountValue == 0) 0.0
+ else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
+
+ if (update) {
+ val newStats =
+ if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
+ else aColStat.copy(nullCount = 0)
+
+ mutableColStats += (attrRef.exprId -> newStats)
+ }
+
+ val percent =
+ if (isNull) {
+ nullPercent.toDouble
+ }
+ else {
+ /** ISNOTNULL(column) */
+ 1.0 - nullPercent.toDouble
+ }
+
+ Some(percent)
+ }
+
+ /**
+ * Returns a percentage of rows meeting a binary comparison expression.
+ *
+ * @param op a binary comparison operator such as =, <, <=, >, >=
+ * @param attrRef an AttributeReference (or a column)
+ * @param literal a literal value (or constant)
+ * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+ * for subsequent conditions
+ * @return an optional double value to show the percentage of rows meeting a given condition
+ * It returns None if no statistics exists for a given column or wrong value.
+ */
+ def evaluateBinary(
+ op: BinaryComparison,
+ attrRef: AttributeReference,
+ literal: Literal,
+ update: Boolean)
+ : Option[Double] = {
+ if (!mutableColStats.contains(attrRef.exprId)) {
+ logDebug("[CBO] No statistics for " + attrRef)
+ return None
+ }
+
+ op match {
+ case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
+ case _ =>
+ attrRef.dataType match {
+ case _: NumericType | DateType | TimestampType =>
+ evaluateBinaryForNumeric(op, attrRef, literal, update)
+ case StringType | BinaryType =>
+ // TODO: It is difficult to support other binary comparisons for String/Binary
+ // type without min/max and advanced statistics like histogram.
+ logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef)
+ None
+ }
+ }
+ }
+
+ /**
+ * For a SQL data type, its internal data type may be different from its external type.
+ * For DateType, its internal type is Int, and its external data type is Java Date type.
+ * The min/max values in ColumnStat are saved in their corresponding external type.
+ * + * @param attrDataType the column data type + * @param litValue the literal value + * @return a BigDecimal value + */ + def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { + attrDataType match { + case DateType => + Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) + case TimestampType => + Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) + case StringType | BinaryType => + None + case _ => + Some(litValue) + } + } + + /** + * Returns a percentage of rows meeting an equality (=) expression. + * This method evaluates the equality predicate for all data types. + * + * @param attrRef an AttributeReference (or a column) + * @param literal a literal value (or constant) + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateEqualTo( + attrRef: AttributeReference, + literal: Literal, + update: Boolean) + : Option[Double] = { + + val aColStat = mutableColStats(attrRef.exprId) + val ndv = aColStat.distinctCount + + // decide if the value is in [min, max] of the column. + // We currently don't store min/max for binary/string type. + // Hence, we assume it is in boundary for binary/string type. + val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType) + val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal) + + if (inBoundary) { + + if (update) { + // We update ColumnStat structure after apply this equality predicate. + // Set distinctCount to 1. Set nullCount to 0. + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attrRef.dataType, literal.value) + val newStats = aColStat.copy(distinctCount = 1, min = newValue, + max = newValue, nullCount = 0) + mutableColStats += (attrRef.exprId -> newStats) + } + + Some(1.0 / ndv.toDouble) + } else { + Some(0.0) + } + + } + + /** + * Returns a percentage of rows meeting "IN" operator expression. + * This method evaluates the equality predicate for all data types. + * + * @param attrRef an AttributeReference (or a column) + * @param hSet a set of literal values + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + * It returns None if no statistics exists for a given column. + */ + + def evaluateInSet( + attrRef: AttributeReference, + hSet: Set[Any], + update: Boolean) + : Option[Double] = { + if (!mutableColStats.contains(attrRef.exprId)) { + logDebug("[CBO] No statistics for " + attrRef) + return None + } + + val aColStat = mutableColStats(attrRef.exprId) + val ndv = aColStat.distinctCount + val aType = attrRef.dataType + var newNdv: Long = 0 + + // use [min, max] to filter the original hSet + aType match { + case _: NumericType | DateType | TimestampType => + val statsRange = + Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] + + // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal. + // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec. + val hSetBigdec = hSet.map(e => BigDecimal(e.toString)) + val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max) + // We use hSetBigdecToAnyMap to help us find the original hSet value. 
+ val hSetBigdecToAnyMap: Map[BigDecimal, Any] =
+ hSet.map(e => BigDecimal(e.toString) -> e).toMap
+
+ if (validQuerySet.isEmpty) {
+ return Some(0.0)
+ }
+
+ // Need to save new min/max using the external type value of the literal
+ val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max))
+ val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min))
+
+ // newNdv should not be greater than the old ndv. For example, column has only 2 values
+ // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
+ newNdv = math.min(validQuerySet.size.toLong, ndv.longValue())
+ if (update) {
+ val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+ max = newMax, nullCount = 0)
+ mutableColStats += (attrRef.exprId -> newStats)
+ }
+
+ // We assume the whole set since there is no min/max information for String/Binary type
+ case StringType | BinaryType =>
+ newNdv = math.min(hSet.size.toLong, ndv.longValue())
+ if (update) {
+ val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0)
+ mutableColStats += (attrRef.exprId -> newStats)
+ }
+ }
+
+ // return the filter selectivity. Without advanced statistics such as histograms,
+ // we have to assume uniform distribution.
+ Some(math.min(1.0, newNdv.toDouble / ndv.toDouble))
+ }
+
+ /**
+ * Returns a percentage of rows meeting a binary comparison expression.
+ * This method evaluates expressions for Numeric columns only.
+ *
+ * @param op a binary comparison operator such as =, <, <=, >, >=
+ * @param attrRef an AttributeReference (or a column)
+ * @param literal a literal value (or constant)
+ * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+ * for subsequent conditions
+ * @return an optional double value to show the percentage of rows meeting a given condition
+ */
+ def evaluateBinaryForNumeric(
+ op: BinaryComparison,
+ attrRef: AttributeReference,
+ literal: Literal,
+ update: Boolean)
+ : Option[Double] = {
+
+ var percent = 1.0
+ val aColStat = mutableColStats(attrRef.exprId)
+ val ndv = aColStat.distinctCount
+ val statsRange =
+ Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
+
+ // determine the overlapping degree between predicate range and column's range
+ val literalValueBD = BigDecimal(literal.value.toString)
+ val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
+ case _: LessThan =>
+ (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
+ case _: LessThanOrEqual =>
+ (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
+ case _: GreaterThan =>
+ (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
+ case _: GreaterThanOrEqual =>
+ (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
+ }
+
+ if (noOverlap) {
+ percent = 0.0
+ } else if (completeOverlap) {
+ percent = 1.0
+ } else {
+ // this is partial overlap case
+ var newMax = aColStat.max
+ var newMin = aColStat.min
+ var newNdv = ndv
+ val literalToDouble = literalValueBD.toDouble
+ val maxToDouble = BigDecimal(statsRange.max).toDouble
+ val minToDouble = BigDecimal(statsRange.min).toDouble
+
+ // Without advanced statistics like histogram, we assume uniform data distribution.
+ // We just prorate the adjusted range over the initial range to compute filter selectivity.
+ // For ease of computation, we convert all relevant numeric values to Double.
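+ // As a purely illustrative example with a hypothetical column range [0, 100]: for the
+ // predicate c < 40, the prorated selectivity below is (40 - 0) / (100 - 0) = 0.4, and for
+ // c >= 40 it is (100 - 40) / (100 - 0) = 0.6.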
+ percent = op match { + case _: LessThan => + (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case _: LessThanOrEqual => + if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble + else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case _: GreaterThan => + (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + case _: GreaterThanOrEqual => + if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble + else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + } + + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attrRef.dataType, literal.value) + + if (update) { + op match { + case _: GreaterThan => newMin = newValue + case _: GreaterThanOrEqual => newMin = newValue + case _: LessThan => newMax = newValue + case _: LessThanOrEqual => newMax = newValue + } + + newNdv = math.max(math.round(ndv.toDouble * percent), 1) + val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + + mutableColStats += (attrRef.exprId -> newStats) + } + } + + Some(percent) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index f139c9e28c6c5..f5e306f9e504d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -398,6 +398,27 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // For all other SQL types, we compare the entire object directly. assert(filteredStats.attributeStats(ar) == expectedColStats) } - } + // If the filter has a binary operator (including those nested inside + // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the + // operator, and then check again. + val rewrittenFilter = filterNode transformExpressionsDown { + case op @ EqualTo(ar: AttributeReference, l: Literal) => + EqualTo(l, ar) + + case op @ LessThan(ar: AttributeReference, l: Literal) => + GreaterThan(l, ar) + case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => + GreaterThanOrEqual(l, ar) + + case op @ GreaterThan(ar: AttributeReference, l: Literal) => + LessThan(l, ar) + case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => + LessThanOrEqual(l, ar) + } + + if (rewrittenFilter != filterNode) { + validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount) + } + } } From 9baadec43e44639998f91972f0955a0612de053e Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 24 Feb 2017 10:24:59 -0800 Subject: [PATCH 46/61] [SPARK-17078][SQL] Show stats when explain ## What changes were proposed in this pull request? Currently we can only check the estimated stats in logical plans by debugging. We need to provide an easier and more efficient way for developers/users. In this pr, we add EXPLAIN COST command to show stats in the optimized logical plan. E.g. ``` spark-sql> EXPLAIN COST select count(1) from store_returns; ... 
== Optimized Logical Plan == Aggregate [count(1) AS count(1)#24L], Statistics(sizeInBytes=16.0 B, rowCount=1, isBroadcastable=false) +- Project, Statistics(sizeInBytes=4.3 GB, rowCount=5.76E+8, isBroadcastable=false) +- Relation[sr_returned_date_sk#3,sr_return_time_sk#4,sr_item_sk#5,sr_customer_sk#6,sr_cdemo_sk#7,sr_hdemo_sk#8,sr_addr_sk#9,sr_store_sk#10,sr_reason_sk#11,sr_ticket_number#12,sr_return_quantity#13,sr_return_amt#14,sr_return_tax#15,sr_return_amt_inc_tax#16,sr_fee#17,sr_return_ship_cost#18,sr_refunded_cash#19,sr_reversed_charge#20,sr_store_credit#21,sr_net_loss#22] parquet, Statistics(sizeInBytes=28.6 GB, rowCount=5.76E+8, isBroadcastable=false) ... ``` ## How was this patch tested? Add test cases. Author: wangzhenhua Author: Zhenhua Wang Closes #16594 from wzhfy/showStats. --- .../scala/org/apache/spark/util/Utils.scala | 40 +++++++++++++------ .../org/apache/spark/util/UtilsSuite.scala | 5 ++- .../spark/sql/catalyst/parser/SqlBase.g4 | 6 ++- .../catalyst/plans/logical/LogicalPlan.scala | 4 ++ .../catalyst/plans/logical/Statistics.scala | 12 +++++- .../spark/sql/catalyst/trees/TreeNode.scala | 27 +++++++++---- .../parser/TableIdentifierParserSuite.scala | 4 +- .../spark/sql/execution/QueryExecution.scala | 16 +++++++- .../spark/sql/execution/SparkSqlParser.scala | 6 ++- .../sql/execution/WholeStageCodegenExec.scala | 6 ++- .../sql/execution/command/commands.scala | 6 ++- .../spark/sql/StatisticsCollectionSuite.scala | 21 ++++++++++ .../sql/hive/execution/HiveExplainSuite.scala | 13 ++++++ 13 files changed, 132 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 480240a93d4e5..10e5233679562 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} +import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -1109,26 +1110,39 @@ private[spark] object Utils extends Logging { /** * Convert a quantity in bytes to a human-readable string such as "4.0 MB". */ - def bytesToString(size: Long): String = { + def bytesToString(size: Long): String = bytesToString(BigInt(size)) + + def bytesToString(size: BigInt): String = { + val EB = 1L << 60 + val PB = 1L << 50 val TB = 1L << 40 val GB = 1L << 30 val MB = 1L << 20 val KB = 1L << 10 - val (value, unit) = { - if (size >= 2*TB) { - (size.asInstanceOf[Double] / TB, "TB") - } else if (size >= 2*GB) { - (size.asInstanceOf[Double] / GB, "GB") - } else if (size >= 2*MB) { - (size.asInstanceOf[Double] / MB, "MB") - } else if (size >= 2*KB) { - (size.asInstanceOf[Double] / KB, "KB") - } else { - (size.asInstanceOf[Double], "B") + if (size >= BigInt(1L << 11) * EB) { + // The number is too large, show it in scientific notation. 
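+ // For example, BigInt(1L << 11) * (1L << 60) bytes is rendered as "2.36E+21 B"
+ // (see the corresponding assertion in UtilsSuite).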
+ BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString() + " B" + } else { + val (value, unit) = { + if (size >= 2 * EB) { + (BigDecimal(size) / EB, "EB") + } else if (size >= 2 * PB) { + (BigDecimal(size) / PB, "PB") + } else if (size >= 2 * TB) { + (BigDecimal(size) / TB, "TB") + } else if (size >= 2 * GB) { + (BigDecimal(size) / GB, "GB") + } else if (size >= 2 * MB) { + (BigDecimal(size) / MB, "MB") + } else if (size >= 2 * KB) { + (BigDecimal(size) / KB, "KB") + } else { + (BigDecimal(size), "B") + } } + "%.1f %s".formatLocal(Locale.US, value, unit) } - "%.1f %s".formatLocal(Locale.US, value, unit) } /** diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index c9cf651ecf759..8ed09749ffd54 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -200,7 +200,10 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.bytesToString(2097152) === "2.0 MB") assert(Utils.bytesToString(2306867) === "2.2 MB") assert(Utils.bytesToString(5368709120L) === "5.0 GB") - assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB") + assert(Utils.bytesToString(5L * (1L << 40)) === "5.0 TB") + assert(Utils.bytesToString(5L * (1L << 50)) === "5.0 PB") + assert(Utils.bytesToString(5L * (1L << 60)) === "5.0 EB") + assert(Utils.bytesToString(BigInt(1L << 11) * (1L << 60)) === "2.36E+21 B") } test("copyStream") { diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d8cd68e2d9e90..59f93b3c469d5 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -123,7 +123,8 @@ statement | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING (USING resource (',' resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction - | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN)? statement #explain + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? + statement #explain | SHOW TABLES ((FROM | IN) db=identifier)? (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? 
@@ -693,7 +694,7 @@ nonReserved | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP - | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN + | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | COST | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF | SET | RESET | VIEW | REPLACE @@ -794,6 +795,7 @@ EXPLAIN: 'EXPLAIN'; FORMAT: 'FORMAT'; LOGICAL: 'LOGICAL'; CODEGEN: 'CODEGEN'; +COST: 'COST'; CAST: 'CAST'; SHOW: 'SHOW'; TABLES: 'TABLES'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0937825e273a2..e22b429aec68b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -115,6 +115,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product) } + override def verboseStringWithSuffix: String = { + super.verboseString + statsCache.map(", " + _.toString).getOrElse("") + } + /** * Returns the maximum number of rows that this plan may compute. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 91404d4bb81b8..f24b240956a61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import java.math.{MathContext, RoundingMode} + import scala.util.control.NonFatal import org.apache.spark.internal.Logging @@ -24,6 +26,7 @@ import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -54,8 +57,13 @@ case class Statistics( /** Readable string representation for the Statistics. */ def simpleString: String = { - Seq(s"sizeInBytes=$sizeInBytes", - if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", + Seq(s"sizeInBytes=${Utils.bytesToString(sizeInBytes)}", + if (rowCount.isDefined) { + // Show row count in scientific notation. 
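+ // For example, a rowCount of 2047 is shown as 2.05E+3 and 5555555555555 as 5.56E+12
+ // (see the "number format in statistics" test in StatisticsCollectionSuite).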
+ s"rowCount=${BigDecimal(rowCount.get, new MathContext(3, RoundingMode.HALF_UP)).toString()}" + } else { + "" + }, s"isBroadcastable=$isBroadcastable" ).filter(_.nonEmpty).mkString(", ") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f37661c315849..cc4c0835954ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -453,13 +453,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** ONE line description of this node with more information */ def verboseString: String + /** ONE line description of this node with some suffix information */ + def verboseStringWithSuffix: String = verboseString + override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ def treeString: String = treeString(verbose = true) - def treeString(verbose: Boolean): String = { - generateTreeString(0, Nil, new StringBuilder, verbose).toString + def treeString(verbose: Boolean, addSuffix: Boolean = false): String = { + generateTreeString(0, Nil, new StringBuilder, verbose = verbose, addSuffix = addSuffix).toString } /** @@ -524,7 +527,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { lastChildren: Seq[Boolean], builder: StringBuilder, verbose: Boolean, - prefix: String = ""): StringBuilder = { + prefix: String = "", + addSuffix: Boolean = false): StringBuilder = { if (depth > 0) { lastChildren.init.foreach { isLast => @@ -533,22 +537,29 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder.append(if (lastChildren.last) "+- " else ":- ") } + val str = if (verbose) { + if (addSuffix) verboseStringWithSuffix else verboseString + } else { + simpleString + } builder.append(prefix) - builder.append(if (verbose) verboseString else simpleString) + builder.append(str) builder.append("\n") if (innerChildren.nonEmpty) { innerChildren.init.foreach(_.generateTreeString( - depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose)) + depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose, + addSuffix = addSuffix)) innerChildren.last.generateTreeString( - depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose) + depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose, + addSuffix = addSuffix) } if (children.nonEmpty) { children.init.foreach(_.generateTreeString( - depth + 1, lastChildren :+ false, builder, verbose, prefix)) + depth + 1, lastChildren :+ false, builder, verbose, prefix, addSuffix)) children.last.generateTreeString( - depth + 1, lastChildren :+ true, builder, verbose, prefix) + depth + 1, lastChildren :+ true, builder, verbose, prefix, addSuffix) } builder diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 7d46011b410e2..170c469197e73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -25,8 +25,8 @@ class TableIdentifierParserSuite extends SparkFunSuite { // Add "$elem$", "$value$" & "$key$" val hiveNonReservedKeyword = Array("add", "admin", 
"after", "analyze", "archive", "asc", "before", "bucket", "buckets", "cascade", "change", "cluster", "clustered", "clusterstatus", "collection", - "columns", "comment", "compact", "compactions", "compute", "concatenate", "continue", "data", - "day", "databases", "datetime", "dbproperties", "deferred", "defined", "delimited", + "columns", "comment", "compact", "compactions", "compute", "concatenate", "continue", "cost", + "data", "day", "databases", "datetime", "dbproperties", "deferred", "defined", "delimited", "dependency", "desc", "directories", "directory", "disable", "distribute", "enable", "escaped", "exclusive", "explain", "export", "fields", "file", "fileformat", "first", "format", "formatted", "functions", "hold_ddltime", "hour", "idxproperties", "ignore", "index", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9d046c0766aa5..137f7ba04d572 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -197,7 +197,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { """.stripMargin.trim } - override def toString: String = { + override def toString: String = completeString(appendStats = false) + + def toStringWithStats: String = completeString(appendStats = true) + + private def completeString(appendStats: Boolean): String = { def output = Utils.truncatedString( analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ") val analyzedPlan = Seq( @@ -205,12 +209,20 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { stringOrError(analyzed.treeString(verbose = true)) ).filter(_.nonEmpty).mkString("\n") + val optimizedPlanString = if (appendStats) { + // trigger to compute stats for logical plans + optimizedPlan.stats(sparkSession.sessionState.conf) + optimizedPlan.treeString(verbose = true, addSuffix = true) + } else { + optimizedPlan.treeString(verbose = true) + } + s"""== Parsed Logical Plan == |${stringOrError(logical.treeString(verbose = true))} |== Analyzed Logical Plan == |$analyzedPlan |== Optimized Logical Plan == - |${stringOrError(optimizedPlan.treeString(verbose = true))} + |${stringOrError(optimizedPlanString)} |== Physical Plan == |${stringOrError(executedPlan.treeString(verbose = true))} """.stripMargin.trim diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 1340aebc1ddd5..65df688689397 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -283,7 +283,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (statement == null) { null // This is enough since ParseException will raise later. 
} else if (isExplainableStatement(statement)) { - ExplainCommand(statement, extended = ctx.EXTENDED != null, codegen = ctx.CODEGEN != null) + ExplainCommand( + logicalPlan = statement, + extended = ctx.EXTENDED != null, + codegen = ctx.CODEGEN != null, + cost = ctx.COST != null) } else { ExplainCommand(OneRowRelation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 2ead8f6baae6b..c58474eba05d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -254,7 +254,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp lastChildren: Seq[Boolean], builder: StringBuilder, verbose: Boolean, - prefix: String = ""): StringBuilder = { + prefix: String = "", + addSuffix: Boolean = false): StringBuilder = { child.generateTreeString(depth, lastChildren, builder, verbose, "") } } @@ -428,7 +429,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co lastChildren: Seq[Boolean], builder: StringBuilder, verbose: Boolean, - prefix: String = ""): StringBuilder = { + prefix: String = "", + addSuffix: Boolean = false): StringBuilder = { child.generateTreeString(depth, lastChildren, builder, verbose, "*") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 58f507119325d..5de45b159684c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -88,11 +88,13 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { * @param logicalPlan plan to explain * @param extended whether to do extended explain or not * @param codegen whether to output generated code from whole-stage codegen or not + * @param cost whether to show cost information for operators. 
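+ *             (set to true by the `EXPLAIN COST` syntax)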
*/ case class ExplainCommand( logicalPlan: LogicalPlan, extended: Boolean = false, - codegen: Boolean = false) + codegen: Boolean = false, + cost: Boolean = false) extends RunnableCommand { override val output: Seq[Attribute] = @@ -113,6 +115,8 @@ case class ExplainCommand( codegenString(queryExecution.executedPlan) } else if (extended) { queryExecution.toString + } else if (cost) { + queryExecution.toStringWithStats } else { queryExecution.simpleString } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index bd1ce8aa3eb13..b38bbd8e7eef2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -170,6 +170,27 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) } + + test("number format in statistics") { + val numbers = Seq( + BigInt(0) -> ("0.0 B", "0"), + BigInt(100) -> ("100.0 B", "100"), + BigInt(2047) -> ("2047.0 B", "2.05E+3"), + BigInt(2048) -> ("2.0 KB", "2.05E+3"), + BigInt(3333333) -> ("3.2 MB", "3.33E+6"), + BigInt(4444444444L) -> ("4.1 GB", "4.44E+9"), + BigInt(5555555555555L) -> ("5.1 TB", "5.56E+12"), + BigInt(6666666666666666L) -> ("5.9 PB", "6.67E+15"), + BigInt(1L << 10 ) * (1L << 60) -> ("1024.0 EB", "1.18E+21"), + BigInt(1L << 11) * (1L << 60) -> ("2.36E+21 B", "2.36E+21") + ) + numbers.foreach { case (input, (expectedSize, expectedRows)) => + val stats = Statistics(sizeInBytes = input, rowCount = Some(input)) + val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," + + s" isBroadcastable=${stats.isBroadcastable}" + assert(stats.simpleString == expectedString) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index f9751e3d5f2eb..cfca1d79836b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -27,6 +27,19 @@ import org.apache.spark.sql.test.SQLTestUtils */ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + test("show cost in explain command") { + // Only has sizeInBytes before ANALYZE command + checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes") + checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), "rowCount") + + // Has both sizeInBytes and rowCount after ANALYZE command + sql("ANALYZE TABLE src COMPUTE STATISTICS") + checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes", "rowCount") + + // No cost information + checkKeywordsNotExist(sql("EXPLAIN SELECT * FROM src "), "sizeInBytes", "rowCount") + } + test("explain extended command") { checkKeywordsExist(sql(" explain select * from src where key=123 "), "== Physical Plan ==") From c9ac47581878a40a05c4de522003ca781c5e84a9 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Fri, 24 Feb 2017 11:42:45 -0800 Subject: [PATCH 47/61] [SPARK-19560] Improve DAGScheduler tests. This commit improves the tests that check the case when a ShuffleMapTask completes successfully on an executor that has failed. 
This commit improves the commenting around the existing test for this, and adds some additional checks to make it more clear what went wrong if the tests fail (the fact that these tests are hard to understand came up in the context of markhamstra's proposed fix for #16620). This commit also removes a test that I realized tested exactly the same functionality. markhamstra, I verified that the new version of the test still fails (and in a more helpful way) for your proposed change for #16620. Author: Kay Ousterhout Closes #16892 from kayousterhout/SPARK-19560. --- .../spark/scheduler/DAGSchedulerSuite.scala | 58 ++++++++++++++++--- 1 file changed, 49 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index c735220da2e15..8eaf9dfcf49b1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1569,24 +1569,45 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } - test("run trivial shuffle with out-of-band failure and retry") { + /** + * In this test, we run a map stage where one of the executors fails but we still receive a + * "zombie" complete message from a task that ran on that executor. We want to make sure the + * stage is resubmitted so that the task that ran on the failed executor is re-executed, and + * that the stage is only marked as finished once that task completes. + */ + test("run trivial shuffle with out-of-band executor failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) - // blockManagerMaster.removeExecutor("exec-hostA") - // pretend we were told hostA went away + // Tell the DAGScheduler that hostA was lost. runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) - // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks - // rather than marking it is as failed and waiting. complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) + + // At this point, no more tasks are running for the stage (and the TaskSetManager considers the + // stage complete), but the tasks that ran on HostA need to be re-run, so the DAGScheduler + // should re-submit the stage with one task (the task that originally ran on HostA). + assert(taskSets.size === 2) + assert(taskSets(1).tasks.size === 1) + + // Make sure that the stage that was re-submitted was the ShuffleMapStage (not the reduce + // stage, which shouldn't be run until all of the tasks in the ShuffleMapStage complete on + // alive executors). + assert(taskSets(1).tasks(0).isInstanceOf[ShuffleMapTask]) + // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // Make sure that the reduce stage was now submitted. + assert(taskSets.size === 3) + assert(taskSets(2).tasks(0).isInstanceOf[ResultTask[_, _]]) + + // Complete the reduce stage. 
complete(taskSets(2), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty()
@@ -2031,6 +2052,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
* In this test, we run a map stage where one of the executors fails but we still receive a
* "zombie" complete message from that executor. We want to make sure the stage is not reported
* as done until all tasks have completed.
+ *
+ * Most of the functionality in this test is tested in "run trivial shuffle with out-of-band
+ * executor failure and retry". However, that test uses ShuffleMapStages that are followed by
+ * a ResultStage, whereas in this test, the ShuffleMapStage is tested in isolation, without a
+ * ResultStage after it.
*/
test("map stage submission with executor failure late map task completions") {
val shuffleMapRdd = new MyRDD(sc, 3, Nil)
@@ -2042,7 +2068,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
runEvent(makeCompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2)))
assert(results.size === 0) // Map stage job should not be complete yet
- // Pretend host A was lost
+ // Pretend host A was lost. This will cause the TaskSetManager to resubmit task 0, because it
+ // completed on hostA.
val oldEpoch = mapOutputTracker.getEpoch
runEvent(ExecutorLost("exec-hostA", ExecutorKilled))
val newEpoch = mapOutputTracker.getEpoch
@@ -2054,13 +2081,26 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
// A completion from another task should work because it's a non-failed host
runEvent(makeCompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2)))
- assert(results.size === 0) // Map stage job should not be complete yet
+
+ // At this point, no more tasks are running for the stage (and the TaskSetManager considers
+ // the stage complete), but the task that ran on hostA needs to be re-run, so the map stage
+ // shouldn't be marked as complete, and the DAGScheduler should re-submit the stage.
+ assert(results.size === 0)
+ assert(taskSets.size === 2)
// Now complete tasks in the second task set
val newTaskSet = taskSets(1)
- assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on hostA
+ // 2 tasks should have been re-submitted, for tasks 0 and 1 (which ran on hostA).
+ assert(newTaskSet.tasks.size === 2)
+ // Complete task 0 from the original task set (i.e., not the one that's currently active).
+ // This should still be counted towards the job being complete (but there's still one
+ // outstanding task).
runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2)))
- assert(results.size === 0) // Map stage job should not be complete yet
+ assert(results.size === 0)
+
+ // Complete the final task, from the currently active task set. There's still one
+ // running task, task 0 in the currently active stage attempt, but the success of task 0 means
+ // the DAGScheduler can mark the stage as finished.
runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2)))
assert(results.size === 1) // Map stage job should now finally be complete
assertDataStructuresEmpty()
From d33c915c1bbf7d8606c7ab427399adf9dc19b63f Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Fri, 24 Feb 2017 13:03:37 -0800
Subject: [PATCH 48/61] [SPARK-19597][CORE] test case for task deserialization errors

Adds a test case that ensures that Executors gracefully handle a task that fails to deserialize, by sending back a reasonable failure message.
This does not change any behavior (the prior behavior was already correct), it just adds a test case to prevent regression. Author: Imran Rashid Closes #16930 from squito/executor_task_deserialization. --- .../org/apache/spark/executor/Executor.scala | 2 + .../apache/spark/executor/ExecutorSuite.scala | 139 +++++++++++++----- 2 files changed, 108 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index d762f11125516..975a6e4eeb33a 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -148,6 +148,8 @@ private[spark] class Executor( startDriverHeartbeater() + private[executor] def numRunningTasks: Int = runningTasks.size() + def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { val tr = new TaskRunner(context, taskDescription) runningTasks.put(taskDescription.taskId, tr) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index f94baaa30d18d..b743ff5376c49 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -17,16 +17,21 @@ package org.apache.spark.executor +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.Map +import scala.concurrent.duration._ -import org.mockito.Matchers._ -import org.mockito.Mockito.{mock, when} +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{inOrder, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.concurrent.Eventually +import org.scalatest.mock.MockitoSugar import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -36,35 +41,15 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.{FakeTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer -class ExecutorSuite extends SparkFunSuite { +class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") { // mock some objects to make Executor.launchTask() happy val conf = new SparkConf val serializer = new JavaSerializer(conf) - val mockEnv = mock(classOf[SparkEnv]) - val mockRpcEnv = mock(classOf[RpcEnv]) - val mockMetricsSystem = mock(classOf[MetricsSystem]) - val mockMemoryManager = mock(classOf[MemoryManager]) - when(mockEnv.conf).thenReturn(conf) - when(mockEnv.serializer).thenReturn(serializer) - when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) - when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) - when(mockEnv.memoryManager).thenReturn(mockMemoryManager) - when(mockEnv.closureSerializer).thenReturn(serializer) - val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array() - val serializedTask = serializer.newInstance().serialize( - new FakeTask(0, 0, Nil, fakeTaskMetrics)) - val taskDescription = new TaskDescription( - taskId = 0, - attemptNumber = 0, - executorId = "", - name = "", - index = 0, - addedFiles = Map[String, Long](), - addedJars = Map[String, Long](), - properties = 
new Properties, - serializedTask) + val env = createMockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) + val taskDescription = createFakeTaskDescription(serializedTask) // we use latches to force the program to run in this order: // +-----------------------------+---------------------------------------+ @@ -86,7 +71,7 @@ class ExecutorSuite extends SparkFunSuite { val executorSuiteHelper = new ExecutorSuiteHelper - val mockExecutorBackend = mock(classOf[ExecutorBackend]) + val mockExecutorBackend = mock[ExecutorBackend] when(mockExecutorBackend.statusUpdate(any(), any(), any())) .thenAnswer(new Answer[Unit] { var firstTime = true @@ -102,8 +87,8 @@ class ExecutorSuite extends SparkFunSuite { val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState] executorSuiteHelper.taskState = taskState val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer] - executorSuiteHelper.testFailedReason - = serializer.newInstance().deserialize(taskEndReason) + executorSuiteHelper.testFailedReason = + serializer.newInstance().deserialize(taskEndReason) // let the main test thread check `taskState` and `testFailedReason` executorSuiteHelper.latch3.countDown() } @@ -112,16 +97,20 @@ class ExecutorSuite extends SparkFunSuite { var executor: Executor = null try { - executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true) + executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread executor.launchTask(mockExecutorBackend, taskDescription) - executorSuiteHelper.latch1.await() + if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) { + fail("executor did not send first status update in time") + } // we know the task will be started, but not yet deserialized, because of the latches we // use in mockExecutorBackend. 
executor.killAllTasks(true) executorSuiteHelper.latch2.countDown() - executorSuiteHelper.latch3.await() + if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) { + fail("executor did not send second status update in time") + } // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` assert(executorSuiteHelper.testFailedReason === TaskKilled) @@ -133,6 +122,79 @@ class ExecutorSuite extends SparkFunSuite { } } } + + test("Gracefully handle error in task deserialization") { + val conf = new SparkConf + val serializer = new JavaSerializer(conf) + val env = createMockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) + val taskDescription = createFakeTaskDescription(serializedTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + failReason match { + case ef: ExceptionFailure => + assert(ef.exception.isDefined) + assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg) + case _ => + fail(s"unexpected failure type: $failReason") + } + } + + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { + val mockEnv = mock[SparkEnv] + val mockRpcEnv = mock[RpcEnv] + val mockMetricsSystem = mock[MetricsSystem] + val mockMemoryManager = mock[MemoryManager] + when(mockEnv.conf).thenReturn(conf) + when(mockEnv.serializer).thenReturn(serializer) + when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) + when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) + when(mockEnv.memoryManager).thenReturn(mockMemoryManager) + when(mockEnv.closureSerializer).thenReturn(serializer) + SparkEnv.set(mockEnv) + mockEnv + } + + private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { + new TaskDescription( + taskId = 0, + attemptNumber = 0, + executorId = "", + name = "", + index = 0, + addedFiles = Map[String, Long](), + addedJars = Map[String, Long](), + properties = new Properties, + serializedTask) + } + + private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + val mockBackend = mock[ExecutorBackend] + var executor: Executor = null + try { + executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) + // the task will be launched in a dedicated worker thread + executor.launchTask(mockBackend, taskDescription) + eventually(timeout(5 seconds), interval(10 milliseconds)) { + assert(executor.numRunningTasks === 0) + } + } finally { + if (executor != null) { + executor.stop() + } + } + val orderedMock = inOrder(mockBackend) + val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + // first statusUpdate for RUNNING has empty data + assert(statusCaptor.getAllValues().get(0).remaining() === 0) + // second update is more interesting + val failureData = statusCaptor.getAllValues.get(1) + SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + } } // Helps to test("SPARK-15963") @@ -145,3 +207,14 @@ private class ExecutorSuiteHelper { @volatile var taskState: TaskState = _ @volatile var testFailedReason: TaskFailedReason = _ } + +private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { + def writeExternal(out: ObjectOutput): Unit = {} + def readExternal(in: ObjectInput): Unit = { + throw new 
RuntimeException(NonDeserializableTask.errorMsg) + } +} + +private object NonDeserializableTask { + val errorMsg = "failure in deserialization" +} From 5bfab2e340b34ade364a1a95fabe57e109f06ca6 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Fri, 24 Feb 2017 15:04:42 -0800 Subject: [PATCH 49/61] [SPARK-13330][PYSPARK] PYTHONHASHSEED is not propgated to python worker ## What changes were proposed in this pull request? self.environment will be propagated to executor. Should set PYTHONHASHSEED as long as the python version is greater than 3.3 ## How was this patch tested? Manually tested it. Author: Jeff Zhang Closes #11211 from zjffdu/SPARK-13330. --- .../main/scala/org/apache/spark/deploy/PythonRunner.scala | 1 + python/pyspark/context.py | 6 ++---- python/pyspark/rdd.py | 3 ++- .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 1 + 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 0b1cec2df8303..a8f732b11f6cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -85,6 +85,7 @@ object PythonRunner { // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) + sys.env.get("PYTHONHASHSEED").foreach(env.put("PYTHONHASHSEED", _)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize try { val process = builder.start() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ac4b2b035f5c1..2961cda553d6a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -173,10 +173,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if k.startswith("spark.executorEnv."): varName = k[len("spark.executorEnv."):] self.environment[varName] = v - if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: - # disable randomness of hash of string in worker, if this is not - # launched by spark-submit - self.environment["PYTHONHASHSEED"] = "0" + + self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0") # Create the Java SparkContext through Py4J self._jsc = jsc or self._initialize_context(self._conf._jconf) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b384b2b507332..a5e6e2b054963 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -68,7 +68,8 @@ def portable_hash(x): >>> portable_hash((None, 1)) & 0xffffffff 219750521 """ - if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + + if sys.version_info >= (3, 2, 3) and 'PYTHONHASHSEED' not in os.environ: raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED") if x is None: diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index fa99cd3b64a4d..e86bd5459311d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -817,6 +817,7 @@ private[spark] class Client( sys.env.get(envname).foreach(env(envname) = _) } } + sys.env.get("PYTHONHASHSEED").foreach(env.put("PYTHONHASHSEED", _)) } sys.env.get(ENV_DIST_CLASSPATH).foreach { dcp => From 
d21df6163aefbed8801410308c7f58109013b1b3 Mon Sep 17 00:00:00 2001 From: Shubham Chopra Date: Fri, 24 Feb 2017 15:40:01 -0800 Subject: [PATCH 50/61] [SPARK-15355][CORE] Proactive block replication ## What changes were proposed in this pull request? We are proposing addition of pro-active block replication in case of executor failures. BlockManagerMasterEndpoint does all the book-keeping to keep track of all the executors and the blocks they hold. It also keeps track of which executors are alive through heartbeats. When an executor is removed, all this book-keeping state is updated to reflect the lost executor. This step can be used to identify executors that are still in possession of a copy of the cached data and a message could be sent to them to use the existing "replicate" function to find and place new replicas on other suitable hosts. Blocks replicated this way will let the master know of their existence. This can happen when an executor is lost, and would that way be pro-active as opposed to being done at query time. ## How was this patch tested? This patch was tested with existing unit tests along with new unit tests added to test the functionality. Author: Shubham Chopra Closes #14412 from shubhamchopra/ProactiveBlockReplication. --- .../apache/spark/storage/BlockManager.scala | 43 +++++++-- .../storage/BlockManagerMasterEndpoint.scala | 24 +++++ .../spark/storage/BlockManagerMessages.scala | 4 + .../storage/BlockManagerSlaveEndpoint.scala | 4 + .../BlockManagerReplicationSuite.scala | 95 +++++++++++++++---- docs/configuration.md | 9 ++ 6 files changed, 154 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 6946a98cdda68..45b73380806dd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1159,6 +1159,34 @@ private[spark] class BlockManager( } } + /** + * Called for pro-active replenishment of blocks lost due to executor failures + * + * @param blockId blockId being replicate + * @param existingReplicas existing block managers that have a replica + * @param maxReplicas maximum replicas needed + */ + def replicateBlock( + blockId: BlockId, + existingReplicas: Set[BlockManagerId], + maxReplicas: Int): Unit = { + logInfo(s"Pro-actively replicating $blockId") + blockInfoManager.lockForReading(blockId).foreach { info => + val data = doGetLocalBytes(blockId, info) + val storageLevel = StorageLevel( + useDisk = info.level.useDisk, + useMemory = info.level.useMemory, + useOffHeap = info.level.useOffHeap, + deserialized = info.level.deserialized, + replication = maxReplicas) + try { + replicate(blockId, data, storageLevel, info.classTag, existingReplicas) + } finally { + releaseLock(blockId) + } + } + } + + /** * Replicate block to another node. Note that this is a blocking call that returns after * the block has been replicated. 
@@ -1167,7 +1195,8 @@ private[spark] class BlockManager( blockId: BlockId, data: ChunkedByteBuffer, level: StorageLevel, - classTag: ClassTag[_]): Unit = { + classTag: ClassTag[_], + existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) val tLevel = StorageLevel( @@ -1181,20 +1210,22 @@ private[spark] class BlockManager( val startTime = System.nanoTime - var peersReplicatedTo = mutable.HashSet.empty[BlockManagerId] + var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId] var numFailures = 0 + val initialPeers = getPeers(false).filterNot(existingReplicas.contains(_)) + var peersForReplication = blockReplicationPolicy.prioritize( blockManagerId, - getPeers(false), - mutable.HashSet.empty, + initialPeers, + peersReplicatedTo, blockId, numPeersToReplicateTo) while(numFailures <= maxReplicationFailures && - !peersForReplication.isEmpty && - peersReplicatedTo.size != numPeersToReplicateTo) { + !peersForReplication.isEmpty && + peersReplicatedTo.size < numPeersToReplicateTo) { val peer = peersForReplication.head try { val onePeerStartTime = System.nanoTime diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 145c434a4f0cf..84c04d22600ad 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -22,6 +22,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi @@ -65,6 +66,8 @@ class BlockManagerMasterEndpoint( mapper } + val proactivelyReplicate = conf.get("spark.storage.replication.proactive", "false").toBoolean + logInfo("BlockManagerMasterEndpoint up") override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -195,17 +198,38 @@ class BlockManagerMasterEndpoint( // Remove it from blockManagerInfo and remove all the blocks. blockManagerInfo.remove(blockManagerId) + val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next val locations = blockLocations.get(blockId) locations -= blockManagerId + // De-register the block if none of the block managers have it. Otherwise, if pro-active + // replication is enabled, and a block is either an RDD or a test block (the latter is used + // for unit testing), we send a message to a randomly chosen executor location to replicate + // the given block. Note that we ignore other block types (such as broadcast/shuffle blocks + // etc.) as replication doesn't make much sense in that context. 
if (locations.size == 0) { blockLocations.remove(blockId) + logWarning(s"No more replicas available for $blockId !") + } else if (proactivelyReplicate && (blockId.isRDD || blockId.isInstanceOf[TestBlockId])) { + // As a heursitic, assume single executor failure to find out the number of replicas that + // existed before failure + val maxReplicas = locations.size + 1 + val i = (new Random(blockId.hashCode)).nextInt(locations.size) + val blockLocations = locations.toSeq + val candidateBMId = blockLocations(i) + blockManagerInfo.get(candidateBMId).foreach { bm => + val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId) + val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas) + bm.slaveEndpoint.ask[Boolean](replicateMsg) + } } } + listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) logInfo(s"Removing block manager $blockManagerId") + } private def removeExecutor(execId: String) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index d71acbb4cf771..0aea438e7f473 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -32,6 +32,10 @@ private[spark] object BlockManagerMessages { // blocks that the master knows about. case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave + // Replicate blocks that were lost due to executor failure + case class ReplicateBlock(blockId: BlockId, replicas: Seq[BlockManagerId], maxReplicas: Int) + extends ToBlockManagerSlave + // Remove all blocks belonging to a specific RDD. case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index d17ddbc162579..1aaa42459df69 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -74,6 +74,10 @@ class BlockManagerSlaveEndpoint( case TriggerThreadDump => context.reply(Utils.getThreadDump()) + + case ReplicateBlock(blockId, replicas, maxReplicas) => + context.reply(blockManager.replicateBlock(blockId, replicas.toSet, maxReplicas)) + } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f4bfdc2fd69a9..ccede34b8cb4d 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -37,32 +37,31 @@ import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ -/** Testsuite that tests block replication in BlockManager */ -class BlockManagerReplicationSuite extends SparkFunSuite - with Matchers - with BeforeAndAfter - with LocalSparkContext { - - private val conf = new SparkConf(false).set("spark.app.id", "test") - private var rpcEnv: RpcEnv = null - private var master: BlockManagerMaster = null - private val securityMgr = new SecurityManager(conf) - private val bcastManager = new 
BroadcastManager(true, conf, securityMgr) - private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) - private val shuffleManager = new SortShuffleManager(conf) +trait BlockManagerReplicationBehavior extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { + + val conf: SparkConf + protected var rpcEnv: RpcEnv = null + protected var master: BlockManagerMaster = null + protected lazy val securityMgr = new SecurityManager(conf) + protected lazy val bcastManager = new BroadcastManager(true, conf, securityMgr) + protected lazy val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) + protected lazy val shuffleManager = new SortShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped // after the unit test. - private val allStores = new ArrayBuffer[BlockManager] + protected val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer", "1m") - private val serializer = new KryoSerializer(conf) + + protected lazy val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. - private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + protected implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) - private def makeBlockManager( + protected def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { conf.set("spark.testing.memory", maxMem.toString) @@ -355,7 +354,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite * is correct. Then it also drops the block from memory of each store (using LRU) and * again checks whether the master's knowledge gets updated. 
*/ - private def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) { + protected def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) { import org.apache.spark.storage.StorageLevel._ assert(maxReplication > 1, @@ -448,3 +447,61 @@ class BlockManagerReplicationSuite extends SparkFunSuite } } } + +class BlockManagerReplicationSuite extends BlockManagerReplicationBehavior { + val conf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") +} + +class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehavior { + val conf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.storage.replication.proactive", "true") + conf.set("spark.storage.exceptionOnPinLeak", "true") + + (2 to 5).foreach{ i => + test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { + testProactiveReplication(i) + } + } + + def testProactiveReplication(replicationFactor: Int) { + val blockSize = 1000 + val storeSize = 10000 + val initialStores = (1 to 10).map { i => makeBlockManager(storeSize, s"store$i") } + + val blockId = "a1" + + val storageLevel = StorageLevel(true, true, false, true, replicationFactor) + initialStores.head.putSingle(blockId, new Array[Byte](blockSize), storageLevel) + + val blockLocations = master.getLocations(blockId) + logInfo(s"Initial locations : $blockLocations") + + assert(blockLocations.size === replicationFactor) + + // remove a random blockManager + val executorsToRemove = blockLocations.take(replicationFactor - 1) + logInfo(s"Removing $executorsToRemove") + executorsToRemove.foreach{exec => + master.removeExecutor(exec.executorId) + // giving enough time for replication to happen and new block be reported to master + Thread.sleep(200) + } + + // giving enough time for replication complete and locks released + Thread.sleep(500) + + val newLocations = master.getLocations(blockId).toSet + logInfo(s"New locations : $newLocations") + assert(newLocations.size === replicationFactor) + // there should only be one common block manager between initial and new locations + assert(newLocations.intersect(blockLocations.toSet).size === 1) + + // check if all the read locks have been released + initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => + val locks = bm.releaseAllLocksForTask(BlockInfo.NON_TASK_WRITER) + assert(locks.size === 0, "Read locks unreleased!") + } + } +} diff --git a/docs/configuration.md b/docs/configuration.md index 2fcb3a096aea5..63392a741a1f0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1000,6 +1000,15 @@ Apart from these, the following properties are also available, and may be useful storage space to unroll the new block in its entirety. + + spark.storage.replication.proactive + false + + Enables proactive block replication for RDD blocks. Cached RDD block replicas lost due to + executor failures are replenished if there are any existing available replicas. This tries + to get the replication level of the block to the initial number. + + ### Execution Behavior From 4a58e76c93d22dff8ea6004b92dbb06817b9893a Mon Sep 17 00:00:00 2001 From: Ramkumar Venkataraman Date: Sat, 25 Feb 2017 02:18:22 +0000 Subject: [PATCH 51/61] [MINOR][DOCS] Fix few typos in structured streaming doc ## What changes were proposed in this pull request? Minor typo in `even-time`, which is changed to `event-time` and a couple of grammatical errors fix. 
## How was this patch tested? N/A - since this is a doc fix. I did a jekyll build locally though. Author: Ramkumar Venkataraman Closes #17037 from ramkumarvenkat/doc-fix. --- docs/structured-streaming-programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index ad3b2fb26dd6e..6af47b6efba2c 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -392,7 +392,7 @@ data, thus relieving the users from reasoning about it. As an example, let’s see how this model handles event-time based processing and late arriving data. ## Handling Event-time and Late Data -Event-time is the time embedded in the data itself. For many applications, you may want to operate on this event-time. For example, if you want to get the number of events generated by IoT devices every minute, then you probably want to use the time when the data was generated (that is, event-time in the data), rather than the time Spark receives them. This event-time is very naturally expressed in this model -- each event from the devices is a row in the table, and event-time is a column value in the row. This allows window-based aggregations (e.g. number of events every minute) to be just a special type of grouping and aggregation on the even-time column -- each time window is a group and each row can belong to multiple windows/groups. Therefore, such event-time-window-based aggregation queries can be defined consistently on both a static dataset (e.g. from collected device events logs) as well as on a data stream, making the life of the user much easier. +Event-time is the time embedded in the data itself. For many applications, you may want to operate on this event-time. For example, if you want to get the number of events generated by IoT devices every minute, then you probably want to use the time when the data was generated (that is, event-time in the data), rather than the time Spark receives them. This event-time is very naturally expressed in this model -- each event from the devices is a row in the table, and event-time is a column value in the row. This allows window-based aggregations (e.g. number of events every minute) to be just a special type of grouping and aggregation on the event-time column -- each time window is a group and each row can belong to multiple windows/groups. Therefore, such event-time-window-based aggregation queries can be defined consistently on both a static dataset (e.g. from collected device events logs) as well as on a data stream, making the life of the user much easier. Furthermore, this model naturally handles data that has arrived later than expected based on its event-time. Since Spark is updating the Result Table, @@ -401,7 +401,7 @@ as well as cleaning up old aggregates to limit the size of intermediate state data. Since Spark 2.1, we have support for watermarking which allows the user to specify the threshold of late data, and allows the engine to accordingly clean up old state. These are explained later in more -details in the [Window Operations](#window-operations-on-event-time) section. +detail in the [Window Operations](#window-operations-on-event-time) section. ## Fault Tolerance Semantics Delivering end-to-end exactly-once semantics was one of key goals behind the design of Structured Streaming. 
To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers) @@ -647,7 +647,7 @@ df.groupBy("deviceType").count() ### Window Operations on Event Time -Aggregations over a sliding event-time window are straightforward with Structured Streaming. The key idea to understand about window-based aggregations are very similar to grouped aggregations. In a grouped aggregation, aggregate values (e.g. counts) are maintained for each unique value in the user-specified grouping column. In case of window-based aggregations, aggregate values are maintained for each window the event-time of a row falls into. Let's understand this with an illustration. +Aggregations over a sliding event-time window are straightforward with Structured Streaming and are very similar to grouped aggregations. In a grouped aggregation, aggregate values (e.g. counts) are maintained for each unique value in the user-specified grouping column. In case of window-based aggregations, aggregate values are maintained for each window the event-time of a row falls into. Let's understand this with an illustration. Imagine our [quick example](#quick-example) is modified and the stream now contains lines along with the time when the line was generated. Instead of running word counts, we want to count words within 10 minute windows, updating every 5 minutes. That is, word counts in words received between 10 minute windows 12:00 - 12:10, 12:05 - 12:15, 12:10 - 12:20, etc. Note that 12:00 - 12:10 means data that arrived after 12:00 but before 12:10. Now, consider a word that was received at 12:07. This word should increment the counts corresponding to two windows 12:00 - 12:10 and 12:05 - 12:15. So the counts will be indexed by both, the grouping key (i.e. the word) and the window (can be calculated from the event-time). @@ -713,7 +713,7 @@ old windows correctly, as illustrated below. ![Handling Late Data](img/structured-streaming-late-data.png) -However, to run this query for days, its necessary for the system to bound the amount of +However, to run this query for days, it's necessary for the system to bound the amount of intermediate in-memory state it accumulates. This means the system needs to know when an old aggregate can be dropped from the in-memory state because the application is not going to receive late data for that aggregate any more. To enable this, in Spark 2.1, we have introduced From 44b5215d0bc45854523a0a91cc6dc446c03f8510 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 24 Feb 2017 23:03:59 -0800 Subject: [PATCH 52/61] [SPARK-19735][SQL] Remove HOLD_DDLTIME from Catalog APIs ### What changes were proposed in this pull request? As explained in Hive JIRA https://issues.apache.org/jira/browse/HIVE-12224, HOLD_DDLTIME was broken as soon as it landed. Hive 2.0 removes HOLD_DDLTIME from the API. In Spark SQL, we always set it to FALSE. Like Hive, we should also remove it from our Catalog APIs. ### How was this patch tested? N/A Author: Xiao Li Closes #17063 from gatorsmile/removalHoldDDLTime. 
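For context, a minimal sketch (not part of the patch) of the user-facing path that exercises these catalog APIs; behavior is unchanged by this cleanup because Spark always passed `holdDDLTime = false`. The table name and file path below are hypothetical, and a Hive-enabled session is assumed:

```scala
// Illustrative sketch only. LOAD DATA is routed through SessionCatalog.loadTable /
// loadPartition and on to the Hive client shims, the APIs whose holdDDLTime
// parameter is removed in the diff below.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("load-data-sketch")
  .master("local[*]")
  .enableHiveSupport()   // requires Hive classes on the classpath
  .getOrCreate()

spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
spark.sql("LOAD DATA LOCAL INPATH '/tmp/kv1.txt' INTO TABLE src")
```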
--- .../catalyst/catalog/ExternalCatalog.scala | 5 +--- .../catalyst/catalog/InMemoryCatalog.scala | 5 +--- .../sql/catalyst/catalog/SessionCatalog.scala | 6 ++--- .../spark/sql/execution/command/tables.scala | 2 -- .../spark/sql/hive/HiveExternalCatalog.scala | 10 ++------ .../spark/sql/hive/client/HiveClient.scala | 5 +--- .../sql/hive/client/HiveClientImpl.scala | 8 +------ .../spark/sql/hive/client/HiveShim.scala | 24 ++++++------------- .../hive/execution/InsertIntoHiveTable.scala | 9 +------ .../spark/sql/hive/client/VersionsSuite.scala | 5 +--- 10 files changed, 17 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 5233699facae0..a3a4ab37ea714 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -125,7 +125,6 @@ abstract class ExternalCatalog { table: String, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit /** @@ -140,7 +139,6 @@ abstract class ExternalCatalog { loadPath: String, partition: TablePartitionSpec, isOverwrite: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit @@ -150,8 +148,7 @@ abstract class ExternalCatalog { loadPath: String, partition: TablePartitionSpec, replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit + numDP: Int): Unit // -------------------------------------------------------------------------- // Partitions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 15aed5f9b1bdf..6bb2b2d4ff72e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -325,7 +325,6 @@ class InMemoryCatalog( table: String, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = { throw new UnsupportedOperationException("loadTable is not implemented") } @@ -336,7 +335,6 @@ class InMemoryCatalog( loadPath: String, partition: TablePartitionSpec, isOverwrite: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit = { throw new UnsupportedOperationException("loadPartition is not implemented.") @@ -348,8 +346,7 @@ class InMemoryCatalog( loadPath: String, partition: TablePartitionSpec, replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit = { + numDP: Int): Unit = { throw new UnsupportedOperationException("loadDynamicPartitions is not implemented.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 73ef0e6a1869e..0230626a6644e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -322,13 +322,12 @@ class SessionCatalog( name: TableIdentifier, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = 
formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) - externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime, isSrcLocal) + externalCatalog.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) } /** @@ -341,7 +340,6 @@ class SessionCatalog( loadPath: String, spec: TablePartitionSpec, isOverwrite: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) @@ -350,7 +348,7 @@ class SessionCatalog( requireTableExists(TableIdentifier(table, Some(db))) requireNonEmptyValueInPartitionSpec(Seq(spec)) externalCatalog.loadPartition( - db, table, loadPath, spec, isOverwrite, holdDDLTime, inheritTableSpecs, isSrcLocal) + db, table, loadPath, spec, isOverwrite, inheritTableSpecs, isSrcLocal) } def defaultTablePath(tableIdent: TableIdentifier): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index d646a215c38c4..49407b44d7b8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -312,7 +312,6 @@ case class LoadDataCommand( loadPath.toString, partition.get, isOverwrite, - holdDDLTime = false, inheritTableSpecs = true, isSrcLocal = isLocal) } else { @@ -320,7 +319,6 @@ case class LoadDataCommand( targetTable.identifier, loadPath.toString, isOverwrite, - holdDDLTime = false, isSrcLocal = isLocal) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ea48256147857..50bb44f7d4e6e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -736,14 +736,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = withClient { requireTableExists(db, table) client.loadTable( loadPath, s"$db.$table", isOverwrite, - holdDDLTime, isSrcLocal) } @@ -753,7 +751,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat loadPath: String, partition: TablePartitionSpec, isOverwrite: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit = withClient { requireTableExists(db, table) @@ -773,7 +770,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table, orderedPartitionSpec, isOverwrite, - holdDDLTime, inheritTableSpecs, isSrcLocal) } @@ -784,8 +780,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat loadPath: String, partition: TablePartitionSpec, replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit = withClient { + numDP: Int): Unit = withClient { requireTableExists(db, table) val orderedPartitionSpec = new util.LinkedHashMap[String, String]() @@ -803,8 +798,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table, orderedPartitionSpec, replace, - numDP, - holdDDLTime) + numDP) } // -------------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 8bdcf3111d8e1..16a80f9fff452 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -208,7 +208,6 @@ private[hive] trait HiveClient { tableName: String, partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit @@ -217,7 +216,6 @@ private[hive] trait HiveClient { loadPath: String, // TODO URI tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit /** Loads new dynamic partitions into an existing table. */ @@ -227,8 +225,7 @@ private[hive] trait HiveClient { tableName: String, partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit + numDP: Int): Unit /** Create a function in an existing database. */ def createFunction(db: String, func: CatalogFunction): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 24dfd33bc3682..c326ac4cc1a53 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -664,7 +664,6 @@ private[hive] class HiveClientImpl( tableName: String, partSpec: java.util.LinkedHashMap[String, String], replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSrcLocal: Boolean): Unit = withHiveState { val hiveTable = client.getTable(dbName, tableName, true /* throw exception */) @@ -674,7 +673,6 @@ private[hive] class HiveClientImpl( s"$dbName.$tableName", partSpec, replace, - holdDDLTime, inheritTableSpecs, isSkewedStoreAsSubdir = hiveTable.isStoredAsSubDirectories, isSrcLocal = isSrcLocal) @@ -684,14 +682,12 @@ private[hive] class HiveClientImpl( loadPath: String, // TODO URI tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = withHiveState { shim.loadTable( client, new Path(loadPath), tableName, replace, - holdDDLTime, isSrcLocal) } @@ -701,8 +697,7 @@ private[hive] class HiveClientImpl( tableName: String, partSpec: java.util.LinkedHashMap[String, String], replace: Boolean, - numDP: Int, - holdDDLTime: Boolean): Unit = withHiveState { + numDP: Int): Unit = withHiveState { val hiveTable = client.getTable(dbName, tableName, true /* throw exception */) shim.loadDynamicPartitions( client, @@ -711,7 +706,6 @@ private[hive] class HiveClientImpl( partSpec, replace, numDP, - holdDDLTime, listBucketingEnabled = hiveTable.isStoredAsSubDirectories) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index b052f1e7e43f5..9fe1c76d3325d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -96,7 +96,6 @@ private[client] sealed abstract class Shim { tableName: String, partSpec: JMap[String, String], replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit @@ -106,7 +105,6 @@ private[client] sealed abstract class Shim { loadPath: Path, tableName: String, 
replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit def loadDynamicPartitions( @@ -116,7 +114,6 @@ private[client] sealed abstract class Shim { partSpec: JMap[String, String], replace: Boolean, numDP: Int, - holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit def createFunction(hive: Hive, db: String, func: CatalogFunction): Unit @@ -332,12 +329,11 @@ private[client] class Shim_v0_12 extends Shim with Logging { tableName: String, partSpec: JMap[String, String], replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean) + JBoolean.FALSE, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean) } override def loadTable( @@ -345,9 +341,8 @@ private[client] class Shim_v0_12 extends Shim with Logging { loadPath: Path, tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean) + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE) } override def loadDynamicPartitions( @@ -357,10 +352,9 @@ private[client] class Shim_v0_12 extends Shim with Logging { partSpec: JMap[String, String], replace: Boolean, numDP: Int, - holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) + numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean) } override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { @@ -703,12 +697,11 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { tableName: String, partSpec: JMap[String, String], replace: Boolean, - holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + JBoolean.FALSE, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, isSrcLocal: JBoolean, JBoolean.FALSE) } @@ -717,9 +710,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { loadPath: Path, tableName: String, replace: Boolean, - holdDDLTime: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE, isSrcLocal: JBoolean, JBoolean.FALSE, JBoolean.FALSE) } @@ -730,10 +722,9 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { partSpec: JMap[String, String], replace: Boolean, numDP: Int, - holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE) + numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE) } override def dropTable( @@ -818,10 +809,9 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { partSpec: JMap[String, String], replace: Boolean, numDP: Int, - holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, 
loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE, 0L: JLong) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3e654d8eeb355..5d5688ecb36b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -301,10 +301,6 @@ case class InsertIntoHiveTable( refreshFunction = _ => (), options = Map.empty) - // TODO: Correctly set holdDDLTime. - // In most of the time, we should have holdDDLTime = false. - // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. - val holdDDLTime = false if (partition.nonEmpty) { if (numDynamicPartitions > 0) { externalCatalog.loadDynamicPartitions( @@ -313,8 +309,7 @@ case class InsertIntoHiveTable( tmpLocation.toString, partitionSpec, overwrite, - numDynamicPartitions, - holdDDLTime = holdDDLTime) + numDynamicPartitions) } else { // scalastyle:off // ifNotExists is only valid with static partition, refer to @@ -357,7 +352,6 @@ case class InsertIntoHiveTable( tmpLocation.toString, partitionSpec, isOverwrite = doHiveOverwrite, - holdDDLTime = holdDDLTime, inheritTableSpecs = inheritTableSpecs, isSrcLocal = false) } @@ -368,7 +362,6 @@ case class InsertIntoHiveTable( table.catalogTable.identifier.table, tmpLocation.toString, // TODO: URI overwrite, - holdDDLTime, isSrcLocal = false) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index fe14824cf0967..6feb277ca88e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -175,7 +175,6 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w emptyDir, tableName = "src", replace = false, - holdDDLTime = false, isSrcLocal = false) } @@ -313,7 +312,6 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w "src_part", partSpec, replace = false, - holdDDLTime = false, inheritTableSpecs = false, isSrcLocal = false) } @@ -329,8 +327,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w "src_part", partSpec, replace = false, - numDP = 1, - holdDDLTime = false) + numDP = 1) } test(s"$version: renamePartitions") { From 0d899c6e312be8e97b5c02f790ef99c751bdc179 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 24 Feb 2017 23:05:36 -0800 Subject: [PATCH 53/61] [SPARK-19650] Commands should not trigger a Spark job Spark executes SQL commands eagerly. It does this by creating an RDD which contains the command's results. The downside to this is that any action on this RDD triggers a Spark job which is expensive and is unnecessary. This PR fixes this by avoiding the materialization of an `RDD` for `Command`s; it just materializes the result and puts them in a `LocalRelation`. Added a regression test to `SQLQuerySuite`. Author: Herman van Hovell Closes #17027 from hvanhovell/no-job-command. 
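As a minimal sketch (not part of the patch) of how this can be observed, the snippet below registers a listener and checks that evaluating a command starts no job; it mirrors the regression test added in `SQLQuerySuite` and assumes an existing `SparkSession` named `spark`:

```scala
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

// Flag flipped if any Spark job starts while the command runs.
val jobStarted = new AtomicBoolean(false)
val listener = new SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit = jobStarted.set(true)
}

spark.sparkContext.addSparkListener(listener)
try {
  // A command: with this change its result is collected eagerly into a LocalRelation,
  // so no RDD is materialized and no job should be triggered.
  spark.sql("SHOW DATABASES").head()
} finally {
  spark.sparkContext.removeSparkListener(listener)
}
// Note: listener events are delivered asynchronously; the added test also waits for
// the listener bus to drain before asserting.
assert(!jobStarted.get(), "Command should not trigger a Spark job.")
```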
--- .../scala/org/apache/spark/sql/Dataset.scala | 20 ++++++--------- .../spark/sql/execution/QueryExecution.scala | 2 -- .../spark/sql/execution/SparkStrategies.scala | 3 +-- .../sql-tests/results/change-column.sql.out | 4 +-- .../results/group-by-ordinal.sql.out | 2 +- .../results/order-by-ordinal.sql.out | 2 +- .../sql-tests/results/outer-join.sql.out | 4 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 25 +++++++++++++++++++ 8 files changed, 39 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3c212d656e371..1b04623596073 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} +import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.DataStreamWriter @@ -175,19 +175,13 @@ class Dataset[T] private[sql]( } @transient private[sql] val logicalPlan: LogicalPlan = { - def hasSideEffects(plan: LogicalPlan): Boolean = plan match { - case _: Command | - _: InsertIntoTable => true - case _ => false - } - + // For various commands (like DDL) and queries with side effects, we force query execution + // to happen right away to let these side effects take place eagerly. queryExecution.analyzed match { - // For various commands (like DDL) and queries with side effects, we force query execution - // to happen right away to let these side effects take place eagerly. - case p if hasSideEffects(p) => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sparkSession) - case Union(children) if children.forall(hasSideEffects) => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sparkSession) + case c: Command => + LocalRelation(c.output, queryExecution.executedPlan.executeCollect()) + case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => + LocalRelation(u.output, queryExecution.executedPlan.executeCollect()) case _ => queryExecution.analyzed } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 137f7ba04d572..6ec2f4d840862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -125,8 +125,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // SHOW TABLES in Hive only output table names, while ours outputs database, table name, isTemp. 
case command: ExecutedCommandExec if command.cmd.isInstanceOf[ShowTablesCommand] => command.executeCollect().map(_.getString(1)) - case command: ExecutedCommandExec => - command.executeCollect().map(_.getString(0)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq // We need the types so we can output struct field names diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 027b1481af96b..20bf4925dbec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SaveMode, Strategy} +import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.execution.streaming._ diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index 59eb56920cdcf..ba8bc936f0c79 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -196,7 +196,7 @@ SET spark.sql.caseSensitive=false -- !query 19 schema struct -- !query 19 output -spark.sql.caseSensitive +spark.sql.caseSensitive false -- !query 20 @@ -212,7 +212,7 @@ SET spark.sql.caseSensitive=true -- !query 21 schema struct -- !query 21 output -spark.sql.caseSensitive +spark.sql.caseSensitive true -- !query 22 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index c64520ff93c83..c0930bbde69a4 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -177,7 +177,7 @@ set spark.sql.groupByOrdinal=false -- !query 17 schema struct -- !query 17 output -spark.sql.groupByOrdinal +spark.sql.groupByOrdinal false -- !query 18 diff --git a/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out index 03a4e72d0fa3e..cc47cc67c87c8 100644 --- a/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out @@ -114,7 +114,7 @@ set spark.sql.orderByOrdinal=false -- !query 9 schema struct -- !query 9 output -spark.sql.orderByOrdinal +spark.sql.orderByOrdinal false -- !query 10 diff --git a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out index cc50b9444bb4b..5db3bae5d0379 100644 --- a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out @@ -63,7 +63,7 @@ set spark.sql.crossJoin.enabled = true -- !query 5 schema struct -- !query 5 output -spark.sql.crossJoin.enabled +spark.sql.crossJoin.enabled true -- !query 6 @@ -85,4 +85,4 @@ set spark.sql.crossJoin.enabled = false -- !query 7 schema struct -- !query 7 output -spark.sql.crossJoin.enabled +spark.sql.crossJoin.enabled false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 40d0ce0992170..03cdfccdda555 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext import java.sql.Timestamp +import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} @@ -2564,4 +2566,27 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql(badQuery), Row(1) :: Nil) } + test("SPARK-19650: An action on a Command should not trigger a Spark job") { + // Create a listener that checks if new jobs have started. + val jobStarted = new AtomicBoolean(false) + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobStarted.set(true) + } + } + + // Make sure no spurious job starts are pending in the listener bus. + sparkContext.listenerBus.waitUntilEmpty(500) + sparkContext.addSparkListener(listener) + try { + // Execute the command. + sql("show databases").head() + + // Make sure we have seen all events triggered by DataFrame.show() + sparkContext.listenerBus.waitUntilEmpty(500) + } finally { + sparkContext.removeSparkListener(listener) + } + assert(!jobStarted.get(), "Command should not trigger a Spark job.") + } } From 310f34265441eb04571c20300f71efcaa0892164 Mon Sep 17 00:00:00 2001 From: Boaz Mohar Date: Sat, 25 Feb 2017 11:32:09 -0800 Subject: [PATCH 54/61] [MINOR][DOCS] Fixes two problems in the SQL programing guide page ## What changes were proposed in this pull request? Removed duplicated lines in sql python example and found a typo. ## How was this patch tested? Searched for other typo's in the page to minimize PR's. Author: Boaz Mohar Closes #17066 from boazmohar/doc-fix. 
--- docs/sql-programming-guide.md | 2 +- examples/src/main/python/sql/basic.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 235f5ecc40c9f..2dd1ab6ef3de1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1410,7 +1410,7 @@ Thrift JDBC server also supports sending thrift RPC messages over HTTP transport Use the following setting to enable HTTP mode as system property or in `hive-site.xml` file in `conf/`: hive.server2.transport.mode - Set this to value: http - hive.server2.thrift.http.port - HTTP port number fo listen on; default is 10001 + hive.server2.thrift.http.port - HTTP port number to listen on; default is 10001 hive.server2.http.endpoint - HTTP endpoint; default is cliservice To test, use beeline to connect to the JDBC/ODBC server in http mode with: diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index ebcf66995b477..c07fa8f2752b3 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -187,9 +187,6 @@ def programmatic_schema_example(spark): # Creates a temporary view using the DataFrame schemaPeople.createOrReplaceTempView("people") - # Creates a temporary view using the DataFrame - schemaPeople.createOrReplaceTempView("people") - # SQL can be run over DataFrames that have been registered as a table. results = spark.sql("SELECT name FROM people") From 4df0f94989df5df8c2697398cb17773f40123afa Mon Sep 17 00:00:00 2001 From: lvdongr Date: Sat, 25 Feb 2017 21:47:02 +0000 Subject: [PATCH 55/61] [SPARK-19673][SQL] "ThriftServer default app name is changed wrong" ## What changes were proposed in this pull request? In spark 1.x ,the name of ThriftServer is SparkSQL:localHostName. While the ThriftServer default name is changed to the className of HiveThfift2 , which is not appropriate. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: lvdongr Closes #17010 from lvdongr/ThriftserverName. --- .../org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 78a309497ab57..c0b299411e94a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -40,6 +40,7 @@ private[hive] object SparkSQLEnv extends Logging { val maybeAppName = sparkConf .getOption("spark.app.name") .filterNot(_ == classOf[SparkSQLCLIDriver].getName) + .filterNot(_ == classOf[HiveThriftServer2].getName) sparkConf .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) From de0300fc690a17b22245962b611b8d78d62a7ee7 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Sat, 25 Feb 2017 21:48:41 +0000 Subject: [PATCH 56/61] [SPARK-15288][MESOS] Mesos dispatcher should handle gracefully when any thread gets UncaughtException ## What changes were proposed in this pull request? Adding the default UncaughtExceptionHandler to the MesosClusterDispatcher. ## How was this patch tested? I verified it manually, when any of the dispatcher thread gets uncaught exceptions then the default UncaughtExceptionHandler will handle those exceptions. 
Author: Devaraj K Closes #13072 from devaraj-kavali/SPARK-15288. --- .../org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 792ade8f0bdbd..38b082ac01197 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, Utils} +import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, SparkUncaughtExceptionHandler, Utils} /* * A dispatcher that is responsible for managing and launching drivers, and is intended to be @@ -97,6 +97,7 @@ private[mesos] object MesosClusterDispatcher with CommandLineUtils { override def main(args: Array[String]) { + Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) Utils.initDaemon(log) val conf = new SparkConf val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) From 418d1a9bd8e74da7bbea89e90521b8dc4838dd68 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 25 Feb 2017 22:24:08 -0800 Subject: [PATCH 57/61] [MINOR][ML][DOC] Document default value for GeneralizedLinearRegression.linkPower Add Scaladoc for GeneralizedLinearRegression.linkPower default value Follow-up to https://github.com/apache/spark/pull/16344 Author: Joseph K. Bradley Closes #17069 from jkbradley/tweedie-comment. --- .../spark/ml/regression/GeneralizedLinearRegression.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index fdeadaf274971..110764dc074f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -109,6 +109,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam * Param for the index in the power link function. Only applicable for the Tweedie family. * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt * link, respectively. + * When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod" + * package. * * @group param */ From 71ef12fd33ff2df55b975cedb41e0abf8a884173 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 25 Feb 2017 23:01:44 -0800 Subject: [PATCH 58/61] [SPARK-17075][SQL][FOLLOWUP] fix some minor issues and clean up the code ## What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/16395. It fixes some code style issues, naming issues, some missing cases in pattern match, etc. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #17065 from cloud-fan/follow-up. 
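The estimation logic this follow-up cleans up relies on per-column range statistics and a uniform-distribution assumption. A simplified standalone sketch of that prorating for a numeric `<` predicate (the `SimpleColumnStat` case class and `estimateLessThan` helper are illustrative names, not Spark APIs):

```scala
// Simplified sketch of range-prorated selectivity under a uniform-distribution
// assumption; SimpleColumnStat and estimateLessThan are illustrative, not Spark APIs.
case class SimpleColumnStat(distinctCount: Long, min: Double, max: Double)

object SelectivitySketch {
  /** Estimated fraction of rows satisfying `col < literal`. */
  def estimateLessThan(stat: SimpleColumnStat, literal: Double): Double = {
    if (literal <= stat.min) 0.0                        // no overlap with [min, max]
    else if (literal > stat.max) 1.0                    // predicate covers the whole range
    else (literal - stat.min) / (stat.max - stat.min)   // partial overlap: prorate linearly
  }

  def main(args: Array[String]): Unit = {
    // A column with values spread over [1, 10]: "col < 3" keeps roughly 2/9 of the range,
    // which a 10-row table rounds up to an estimate of 3 rows.
    val stat = SimpleColumnStat(distinctCount = 10, min = 1.0, max = 10.0)
    println(estimateLessThan(stat, 3.0))                        // ~0.222
    println(math.ceil(10 * estimateLessThan(stat, 3.0)).toLong) // 3
  }
}
```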
--- .../catalyst/expressions/AttributeMap.scala | 2 +- .../statsEstimation/FilterEstimation.scala | 330 +++++++++--------- .../statsEstimation/JoinEstimation.scala | 91 ++--- .../plans/logical/statsEstimation/Range.scala | 36 +- .../FilterEstimationSuite.scala | 174 ++++----- 5 files changed, 296 insertions(+), 337 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 1504a522798b0..9f4a0f2b7017a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -28,7 +28,7 @@ object AttributeMap { } } -class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) +class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) extends Map[Attribute, A] with Serializable { override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 37f29ba68a206..0c928832d7d22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import java.sql.{Date, Timestamp} - -import scala.collection.immutable.{HashSet, Map} +import scala.collection.immutable.HashSet import scala.collection.mutable import org.apache.spark.internal.Logging @@ -31,15 +29,16 @@ import org.apache.spark.sql.types._ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { + private val childStats = plan.child.stats(catalystConf) + /** - * We use a mutable colStats because we need to update the corresponding ColumnStat - * for a column after we apply a predicate condition. For example, column c has - * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50), - * we need to set the column's [min, max] value to [40, 100] after we evaluate the - * first condition c > 40. We need to set the column's [min, max] value to [40, 50] + * We will update the corresponding ColumnStats for a column after we apply a predicate condition. + * For example, column c has [min, max] value as [0, 100]. In a range condition such as + * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we + * evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50] * after we evaluate the second condition c <= 50. */ - private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty + private val colStatsMap = new ColumnStatsMap /** * Returns an option of Statistics for a Filter logical plan node. @@ -51,12 +50,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @return Option[Statistics] When there is no statistics collected, it returns None. */ def estimate: Option[Statistics] = { - // We first copy child node's statistics and then modify it based on filter selectivity. 
- val stats: Statistics = plan.child.stats(catalystConf) - if (stats.rowCount.isEmpty) return None + if (childStats.rowCount.isEmpty) return None // save a mutable copy of colStats so that we can later change it recursively - mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*) + colStatsMap.setInitValues(childStats.attributeStats) // estimate selectivity of this filter predicate val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { @@ -65,22 +62,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case None => 1.0 } - // attributeStats has mapping Attribute-to-ColumnStat. - // mutableColStats has mapping ExprId-to-ColumnStat. - // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat - val expridToAttrMap: Map[ExprId, Attribute] = - stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) - // copy mutableColStats contents to an immutable AttributeMap. - val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = - mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2) - val newColStats = AttributeMap(mutableAttributeStats.toSeq) + val newColStats = colStatsMap.toColumnStats val filteredRowCount: BigInt = - EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) - val filteredSizeInBytes = + EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) + val filteredSizeInBytes: BigInt = EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) - Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), + Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), attributeStats = newColStats)) } @@ -95,15 +84,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions - * @return a double value to show the percentage of rows meeting a given condition. + * @return an optional double value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. */ def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { - condition match { case And(cond1, cond2) => - (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update)) - match { + // For ease of debugging, we compute percent1 and percent2 in 2 statements. + val percent1 = calculateFilterSelectivity(cond1, update) + val percent2 = calculateFilterSelectivity(cond2, update) + (percent1, percent2) match { case (Some(p1), Some(p2)) => Some(p1 * p2) case (Some(p1), None) => Some(p1) case (None, Some(p2)) => Some(p2) @@ -127,8 +117,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case None => None } - case _ => - calculateSingleCondition(condition, update) + case _ => calculateSingleCondition(condition, update) } } @@ -140,7 +129,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param condition a single logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions - * @return Option[Double] value to show the percentage of rows meeting a given condition. + * @return an optional double value to show the percentage of rows meeting a given condition. 
* It returns None if the condition is not supported. */ def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { @@ -148,33 +137,33 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. - // EqualTo does not care about the order - case op @ EqualTo(ar: AttributeReference, l: Literal) => - evaluateBinary(op, ar, l, update) - case op @ EqualTo(l: Literal, ar: AttributeReference) => - evaluateBinary(op, ar, l, update) + // EqualTo/EqualNullSafe does not care about the order + case op @ Equality(ar: Attribute, l: Literal) => + evaluateEquality(ar, l, update) + case op @ Equality(l: Literal, ar: Attribute) => + evaluateEquality(ar, l, update) - case op @ LessThan(ar: AttributeReference, l: Literal) => + case op @ LessThan(ar: Attribute, l: Literal) => evaluateBinary(op, ar, l, update) - case op @ LessThan(l: Literal, ar: AttributeReference) => + case op @ LessThan(l: Literal, ar: Attribute) => evaluateBinary(GreaterThan(ar, l), ar, l, update) - case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => + case op @ LessThanOrEqual(ar: Attribute, l: Literal) => evaluateBinary(op, ar, l, update) - case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) => + case op @ LessThanOrEqual(l: Literal, ar: Attribute) => evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) - case op @ GreaterThan(ar: AttributeReference, l: Literal) => + case op @ GreaterThan(ar: Attribute, l: Literal) => evaluateBinary(op, ar, l, update) - case op @ GreaterThan(l: Literal, ar: AttributeReference) => + case op @ GreaterThan(l: Literal, ar: Attribute) => evaluateBinary(LessThan(ar, l), ar, l, update) - case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => + case op @ GreaterThanOrEqual(ar: Attribute, l: Literal) => evaluateBinary(op, ar, l, update) - case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) => + case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - case In(ar: AttributeReference, expList) + case In(ar: Attribute, expList) if expList.forall(e => e.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. @@ -182,14 +171,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val hSet = expList.map(e => e.eval()) evaluateInSet(ar, HashSet() ++ hSet, update) - case InSet(ar: AttributeReference, set) => + case InSet(ar: Attribute, set) => evaluateInSet(ar, set, update) - case IsNull(ar: AttributeReference) => - evaluateIsNull(ar, isNull = true, update) + case IsNull(ar: Attribute) => + evaluateNullCheck(ar, isNull = true, update) - case IsNotNull(ar: AttributeReference) => - evaluateIsNull(ar, isNull = false, update) + case IsNotNull(ar: Attribute) => + evaluateNullCheck(ar, isNull = false, update) case _ => // TODO: it's difficult to support string operators without advanced statistics. @@ -203,44 +192,43 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo /** * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition. * - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param isNull set to true for "IS NULL" condition. 
set to false for "IS NOT NULL" condition * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions * @return an optional double value to show the percentage of rows meeting a given condition * It returns None if no statistics collected for a given column. */ - def evaluateIsNull( - attrRef: AttributeReference, + def evaluateNullCheck( + attr: Attribute, isNull: Boolean, - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) return None } - val aColStat = mutableColStats(attrRef.exprId) - val rowCountValue = plan.child.stats(catalystConf).rowCount.get - val nullPercent: BigDecimal = - if (rowCountValue == 0) 0.0 - else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue) + val colStat = colStatsMap(attr) + val rowCountValue = childStats.rowCount.get + val nullPercent: BigDecimal = if (rowCountValue == 0) { + 0 + } else { + BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue) + } if (update) { - val newStats = - if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None) - else aColStat.copy(nullCount = 0) - - mutableColStats += (attrRef.exprId -> newStats) + val newStats = if (isNull) { + colStat.copy(distinctCount = 0, min = None, max = None) + } else { + colStat.copy(nullCount = 0) + } + colStatsMap(attr) = newStats } - val percent = - if (isNull) { - nullPercent.toDouble - } - else { - /** ISNOTNULL(column) */ - 1.0 - nullPercent.toDouble - } + val percent = if (isNull) { + nullPercent.toDouble + } else { + 1.0 - nullPercent.toDouble + } Some(percent) } @@ -249,7 +237,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting a binary comparison expression. * * @param op a binary comparison operator uch as =, <, <=, >, >= - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions @@ -258,27 +246,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def evaluateBinary( op: BinaryComparison, - attrRef: AttributeReference, + attr: Attribute, literal: Literal, - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) - return None - } - - op match { - case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update) + update: Boolean): Option[Double] = { + attr.dataType match { + case _: NumericType | DateType | TimestampType => + evaluateBinaryForNumeric(op, attr, literal, update) + case StringType | BinaryType => + // TODO: It is difficult to support other binary comparisons for String/Binary + // type without min/max and advanced statistics like histogram. + logDebug("[CBO] No range comparison statistics for String/Binary type " + attr) + None case _ => - attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - evaluateBinaryForNumeric(op, attrRef, literal, update) - case StringType | BinaryType => - // TODO: It is difficult to support other binary comparisons for String/Binary - // type without min/max and advanced statistics like histogram. 
- logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef) - None - } + // TODO: support boolean type. + None } } @@ -297,6 +278,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) case TimestampType => Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) + case _: DecimalType => + Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal) case StringType | BinaryType => None case _ => @@ -308,37 +291,36 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. * - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions * @return an optional double value to show the percentage of rows meeting a given condition */ - def evaluateEqualTo( - attrRef: AttributeReference, + def evaluateEquality( + attr: Attribute, literal: Literal, - update: Boolean) - : Option[Double] = { - - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) + val ndv = colStat.distinctCount // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. // Hence, we assume it is in boundary for binary/string type. - val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType) - val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal) - - if (inBoundary) { - + val statsRange = Range(colStat.min, colStat.max, attr.dataType) + if (statsRange.contains(literal)) { if (update) { // We update ColumnStat structure after apply this equality predicate. // Set distinctCount to 1. Set nullCount to 0. // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attrRef.dataType, literal.value) - val newStats = aColStat.copy(distinctCount = 1, min = newValue, + val newValue = convertBoundValue(attr.dataType, literal.value) + val newStats = colStat.copy(distinctCount = 1, min = newValue, max = newValue, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) + colStatsMap(attr) = newStats } Some(1.0 / ndv.toDouble) @@ -352,7 +334,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting "IN" operator expression. * This method evaluates the equality predicate for all data types. 
* - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param hSet a set of literal values * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions @@ -361,57 +343,52 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def evaluateInSet( - attrRef: AttributeReference, + attr: Attribute, hSet: Set[Any], - update: Boolean) - : Option[Double] = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) return None } - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount - val aType = attrRef.dataType - var newNdv: Long = 0 + val colStat = colStatsMap(attr) + val ndv = colStat.distinctCount + val dataType = attr.dataType + var newNdv = ndv // use [min, max] to filter the original hSet - aType match { - case _: NumericType | DateType | TimestampType => - val statsRange = - Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - - // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal. - // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec. - val hSetBigdec = hSet.map(e => BigDecimal(e.toString)) - val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max) - // We use hSetBigdecToAnyMap to help us find the original hSet value. - val hSetBigdecToAnyMap: Map[BigDecimal, Any] = - hSet.map(e => BigDecimal(e.toString) -> e).toMap + dataType match { + case _: NumericType | BooleanType | DateType | TimestampType => + val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange] + val validQuerySet = hSet.filter { v => + v != null && statsRange.contains(Literal(v, dataType)) + } if (validQuerySet.isEmpty) { return Some(0.0) } // Need to save new min/max using the external type value of the literal - val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max)) - val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min)) + val newMax = convertBoundValue( + attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString))) + val newMin = convertBoundValue( + attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString))) // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. 
- newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) + newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + val newStats = colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) + colStatsMap(attr) = newStats } // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => - newNdv = math.min(hSet.size.toLong, ndv.longValue()) + newNdv = ndv.min(BigInt(hSet.size)) if (update) { - val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) + val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) + colStatsMap(attr) = newStats } } @@ -425,7 +402,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * This method evaluate expression for Numeric columns only. * * @param op a binary comparison operator uch as =, <, <=, >, >= - * @param attrRef an AttributeReference (or a column) + * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions @@ -433,16 +410,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def evaluateBinaryForNumeric( op: BinaryComparison, - attrRef: AttributeReference, + attr: Attribute, literal: Literal, - update: Boolean) - : Option[Double] = { + update: Boolean): Option[Double] = { var percent = 1.0 - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount + val colStat = colStatsMap(attr) val statsRange = - Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] + Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] // determine the overlapping degree between predicate range and column's range val literalValueBD = BigDecimal(literal.value.toString) @@ -463,33 +438,37 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo percent = 1.0 } else { // this is partial overlap case - var newMax = aColStat.max - var newMin = aColStat.min - var newNdv = ndv - val literalToDouble = literalValueBD.toDouble - val maxToDouble = BigDecimal(statsRange.max).toDouble - val minToDouble = BigDecimal(statsRange.min).toDouble + val literalDouble = literalValueBD.toDouble + val maxDouble = BigDecimal(statsRange.max).toDouble + val minDouble = BigDecimal(statsRange.min).toDouble // Without advanced statistics like histogram, we assume uniform data distribution. // We just prorate the adjusted range over the initial range to compute filter selectivity. // For ease of computation, we convert all relevant numeric values to Double. 
percent = op match { case _: LessThan => - (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + (literalDouble - minDouble) / (maxDouble - minDouble) case _: LessThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble - else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + if (literalValueBD == BigDecimal(statsRange.min)) { + 1.0 / colStat.distinctCount.toDouble + } else { + (literalDouble - minDouble) / (maxDouble - minDouble) + } case _: GreaterThan => - (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + (maxDouble - literalDouble) / (maxDouble - minDouble) case _: GreaterThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble - else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + if (literalValueBD == BigDecimal(statsRange.max)) { + 1.0 / colStat.distinctCount.toDouble + } else { + (maxDouble - literalDouble) / (maxDouble - minDouble) + } } - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attrRef.dataType, literal.value) - if (update) { + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attr.dataType, literal.value) + var newMax = colStat.max + var newMin = colStat.min op match { case _: GreaterThan => newMin = newValue case _: GreaterThanOrEqual => newMin = newValue @@ -497,11 +476,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _: LessThanOrEqual => newMax = newValue } - newNdv = math.max(math.round(ndv.toDouble * percent), 1) - val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1) + val newStats = colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) - mutableColStats += (attrRef.exprId -> newStats) + colStatsMap(attr) = newStats } } @@ -509,3 +488,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } } + +class ColumnStatsMap { + private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty + + def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = { + baseMap.clear() + baseMap ++= colStats.baseMap + } + + def contains(a: Attribute): Boolean = baseMap.contains(a.exprId) + + def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2 + + def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats) + + def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 982a5a8bb89be..9782c0bb0a939 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -59,7 +59,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging case _ if !rowCountsExist(conf, join.left, join.right) => None - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => // 1. 
Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) val selectivity = joinSelectivity(joinKeyPairs) @@ -94,9 +94,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { // The output is empty, we don't need to keep column stats. Nil - } else if (innerJoinedRows == 0) { + } else if (selectivity == 0) { joinType match { - // For outer joins, if the inner join part is empty, the number of output rows is the + // For outer joins, if the join selectivity is 0, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side // keep unchanged, while column stats of join keys from the other side should be updated // based on added null values. @@ -116,6 +116,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } case _ => Nil } + } else if (selectivity == 1) { + // Cartesian product, just propagate the original column stats + inputAttrStats.toSeq } else { val joinKeyStats = getIntersectedStats(joinKeyPairs) join.joinType match { @@ -138,8 +141,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), rowCount = Some(outputRows), - attributeStats = outputAttrStats, - isBroadcastable = false)) + attributeStats = outputAttrStats)) case _ => // When there is no equi-join condition, we do estimation like cartesian product. @@ -150,8 +152,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats), rowCount = Some(outputRows), - attributeStats = inputAttrStats, - isBroadcastable = false)) + attributeStats = inputAttrStats)) } // scalastyle:off @@ -189,8 +190,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } if (ndvDenom < 0) { - // There isn't join keys or column stats for any of the join key pairs, we do estimation like - // cartesian product. + // We can't find any join key pairs with column stats, estimate it as cartesian join. 1 } else if (ndvDenom == 0) { // One of the join key pairs is disjoint, thus the two sides of join is disjoint. @@ -202,9 +202,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging /** * Propagate or update column stats for output attributes. - * 1. For cartesian product, all values are preserved, so there's no need to change column stats. - * 2. For other cases, a) update max/min of join keys based on their intersected range. b) update - * distinct count of other attributes based on output rows after join. 
*/ private def updateAttrStats( outputRows: BigInt, @@ -214,35 +211,38 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - if (outputRows == leftRows * rightRows) { - // Cartesian product, just propagate the original column stats - attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a)) - } else { - val leftRatio = - if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0) - val rightRatio = - if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0) - attributes.foreach { a => - // check if this attribute is a join key - if (joinKeyStats.contains(a)) { - outputAttrStats += a -> joinKeyStats(a) + + attributes.foreach { a => + // check if this attribute is a join key + if (joinKeyStats.contains(a)) { + outputAttrStats += a -> joinKeyStats(a) + } else { + val leftRatio = if (leftRows != 0) { + BigDecimal(outputRows) / BigDecimal(leftRows) + } else { + BigDecimal(0) + } + val rightRatio = if (rightRows != 0) { + BigDecimal(outputRows) / BigDecimal(rightRows) } else { - val oldColStat = oldAttrStats(a) - val oldNdv = oldColStat.distinctCount - // We only change (scale down) the number of distinct values if the number of rows - // decreases after join, because join won't produce new values even if the number of - // rows increases. - val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) { - ceil(BigDecimal(oldNdv) * leftRatio) - } else if (join.right.outputSet.contains(a) && rightRatio < 1) { - ceil(BigDecimal(oldNdv) * rightRatio) - } else { - oldNdv - } - // TODO: support nullCount updates for specific outer joins - outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv) + BigDecimal(0) } + val oldColStat = oldAttrStats(a) + val oldNdv = oldColStat.distinctCount + // We only change (scale down) the number of distinct values if the number of rows + // decreases after join, because join won't produce new values even if the number of + // rows increases. 
+ val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) { + ceil(BigDecimal(oldNdv) * leftRatio) + } else if (join.right.outputSet.contains(a) && rightRatio < 1) { + ceil(BigDecimal(oldNdv) * rightRatio) + } else { + oldNdv + } + // TODO: support nullCount updates for specific outer joins + outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv) } + } outputAttrStats } @@ -263,12 +263,14 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging // Update intersected column stats assert(leftKey.dataType.sameType(rightKey.dataType)) - val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) + val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType) - intersectedStats.put(leftKey, - leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0)) - intersectedStats.put(rightKey, - rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0)) + val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) + val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + + intersectedStats.put(leftKey, newStats) + intersectedStats.put(rightKey, newStats) } AttributeMap(intersectedStats.toSeq) } @@ -298,8 +300,7 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats), rowCount = Some(outputRows), - attributeStats = leftStats.attributeStats, - isBroadcastable = false)) + attributeStats = leftStats.attributeStats)) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 455711453272d..3d13967cb62a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -26,19 +26,33 @@ import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} /** Value range of a column. */ -trait Range +trait Range { + def contains(l: Literal): Boolean +} /** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: JDecimal, max: JDecimal) extends Range +case class NumericRange(min: JDecimal, max: JDecimal) extends Range { + override def contains(l: Literal): Boolean = { + val decimal = l.dataType match { + case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) + case _ => new JDecimal(l.value.toString) + } + min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0 + } +} /** * This version of Spark does not have min/max for binary/string types, we define their default * behaviors by this class. */ -class DefaultRange extends Range +class DefaultRange extends Range { + override def contains(l: Literal): Boolean = true +} /** This is for columns with only null values. 
*/ -class NullRange extends Range +class NullRange extends Range { + override def contains(l: Literal): Boolean = false +} object Range { def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { @@ -58,20 +72,6 @@ object Range { n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 } - def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match { - case _: DefaultRange => true - case _: NullRange => false - case n: NumericRange => - val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) { - if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) - } else { - assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] || - lit.dataType.isInstanceOf[TimestampType]) - new JDecimal(lit.value.toString) - } - n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0 - } - /** * Intersected results of two ranges. This is only for two overlapped ranges. * The outputs are the intersected min/max values. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index f5e306f9e504d..8be74ced7bb71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.statsEstimation -import java.sql.{Date, Timestamp} +import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ /** @@ -38,6 +37,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) + // only 2 values + val arBool = AttributeReference("cbool", BooleanType)() + val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1) + // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") @@ -45,14 +49,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) - // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through - // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours). - val tsMin = Timestamp.valueOf("2017-01-01 01:00:00") - val tsMax = Timestamp.valueOf("2017-01-01 10:00:00") - val arTimestamp = AttributeReference("ctimestamp", TimestampType)() - val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax), - nullCount = 0, avgLen = 8, maxLen = 8) - // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") @@ -77,8 +73,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4), - Some(1L) - ) + 1) + } + + test("cint <=> 2") { + validateEstimatedStats( + arInt, + Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + 1) } test("cint = 0") { @@ -88,8 +92,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(0L) - ) + 0) } test("cint < 3") { @@ -98,8 +101,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) + 3) } test("cint < 0") { @@ -109,8 +111,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(0L) - ) + 0) } test("cint <= 3") { @@ -119,8 +120,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) + 3) } test("cint > 6") { @@ -129,8 +129,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(5L) - ) + 5) } test("cint > 10") { @@ -140,8 +139,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(0L) - ) + 0) } test("cint >= 6") { @@ -150,8 +148,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(5L) - ) + 5) } test("cint IS NULL") { @@ -160,8 +157,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, avgLen = 4, maxLen = 4), - Some(0L) - ) + 0) } test("cint IS NOT NULL") { @@ -170,8 +166,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(10L) - ) + 10) } test("cint > 3 AND cint <= 6") { @@ -181,8 +176,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 3, min = 
Some(3), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4), - Some(4L) - ) + 4) } test("cint = 3 OR cint = 6") { @@ -192,8 +186,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(2L) - ) + 2) } test("cint IN (3, 4, 5)") { @@ -202,8 +195,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) + 3) } test("cint NOT IN (3, 4, 5)") { @@ -212,8 +204,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - Some(7L) - ) + 7) + } + + test("cbool = true") { + validateEstimatedStats( + arBool, + Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)), + ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1), + 5) + } + + test("cbool > false") { + // bool comparison is not supported yet, so stats remain same. + validateEstimatedStats( + arBool, + Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)), + ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1), + 10) } test("cdate = cast('2017-01-02' AS DATE)") { @@ -224,8 +234,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4), - Some(1L) - ) + 1) } test("cdate < cast('2017-01-03' AS DATE)") { @@ -236,8 +245,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) + 3) } test("""cdate IN ( cast('2017-01-03' AS DATE), @@ -251,32 +259,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) - } - - test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") { - val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") - validateEstimatedStats( - arTimestamp, - Filter(EqualTo(arTimestamp, Literal(ts2017010102)), - childStatsTestPlan(Seq(arTimestamp), 10L)), - ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102), - nullCount = 0, avgLen = 8, maxLen = 8), - Some(1L) - ) - } - - test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") { - val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00") - validateEstimatedStats( - arTimestamp, - Filter(LessThan(arTimestamp, Literal(ts2017010103)), - childStatsTestPlan(Seq(arTimestamp), 10L)), - ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103), - nullCount = 0, avgLen = 8, maxLen = 8), - Some(3L) - ) + 3) } test("cdecimal = 0.400000000000000000") { @@ -287,20 +270,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase { childStatsTestPlan(Seq(arDecimal), 4L)), ColumnStat(distinctCount = 1, min = Some(dec_0_40), 
max = Some(dec_0_40), nullCount = 0, avgLen = 8, maxLen = 8), - Some(1L) - ) + 1) } test("cdecimal < 0.60 ") { val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") validateEstimatedStats( arDecimal, - Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))), + Filter(LessThan(arDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(arDecimal), 4L)), ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), nullCount = 0, avgLen = 8, maxLen = 8), - Some(3L) - ) + 3) } test("cdouble < 3.0") { @@ -309,8 +290,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), nullCount = 0, avgLen = 8, maxLen = 8), - Some(3L) - ) + 3) } test("cstring = 'A2'") { @@ -319,8 +299,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2), - Some(1L) - ) + 1) } // There is no min/max statistics for String type. We estimate 10 rows returned. @@ -330,8 +309,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2), - Some(10L) - ) + 10) } // This is a corner test case. We want to test if we can handle the case when the number of @@ -351,8 +329,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), - Some(2L) - ) + 2) } private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { @@ -361,8 +338,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { rowCount = tableRowCount, attributeStats = AttributeMap(Seq( arInt -> childColStatInt, + arBool -> childColStatBool, arDate -> childColStatDate, - arTimestamp -> childColStatTimestamp, arDecimal -> childColStatDecimal, arDouble -> childColStatDouble, arString -> childColStatString @@ -374,46 +351,31 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ar: AttributeReference, filterNode: Filter, expectedColStats: ColumnStat, - rowCount: Option[BigInt] = None) - : Unit = { + rowCount: Int): Unit = { - val expectedRowCount: BigInt = rowCount.getOrElse(0L) val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) - val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats) + val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats) val filteredStats = filterNode.stats(conf) assert(filteredStats.sizeInBytes == expectedSizeInBytes) - assert(filteredStats.rowCount == rowCount) - ar.dataType match { - case DecimalType() => - // Due to the internal transformation for DecimalType within engine, the new min/max - // in ColumnStat may have a different structure even it contains the right values. - // We convert them to Java BigDecimal values so that we can compare the entire object. 
- val generatedColumnStats = filteredStats.attributeStats(ar) - val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString) - val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString) - val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax)) - assert(outputColStats == expectedColStats) - case _ => - // For all other SQL types, we compare the entire object directly. - assert(filteredStats.attributeStats(ar) == expectedColStats) - } + assert(filteredStats.rowCount.get == rowCount) + assert(filteredStats.attributeStats(ar) == expectedColStats) // If the filter has a binary operator (including those nested inside // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the // operator, and then check again. val rewrittenFilter = filterNode transformExpressionsDown { - case op @ EqualTo(ar: AttributeReference, l: Literal) => + case EqualTo(ar: AttributeReference, l: Literal) => EqualTo(l, ar) - case op @ LessThan(ar: AttributeReference, l: Literal) => + case LessThan(ar: AttributeReference, l: Literal) => GreaterThan(l, ar) - case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => + case LessThanOrEqual(ar: AttributeReference, l: Literal) => GreaterThanOrEqual(l, ar) - case op @ GreaterThan(ar: AttributeReference, l: Literal) => + case GreaterThan(ar: AttributeReference, l: Literal) => LessThan(l, ar) - case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => + case GreaterThanOrEqual(ar: AttributeReference, l: Literal) => LessThanOrEqual(l, ar) } From ecf8a1b3eda1571bcabeaef10326c48469fe4af2 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 25 Feb 2017 23:56:57 -0800 Subject: [PATCH 59/61] [SQL] Duplicate test exception in SQLQueryTestSuite due to meta files(.DS_Store) on Mac ## What changes were proposed in this pull request? After adding the tests for subquery, we now have multiple level of directories under "sql-tests/inputs". Some times on Mac while using Finder application it creates the meta data files called ".DS_Store". When these files are present at different levels in directory hierarchy, we get duplicate test exception while running the tests as we just use the file name as the test case name. In this PR, we use the relative file path from the base directory along with the test file as the test name. Also after this change, we can have the same test file name under different directory like exists/basic.sql , in/basic.sql. Here is the truncated output of the test run after the change. ```SQL info] SQLQueryTestSuite: [info] - arithmetic.sql (5 seconds, 235 milliseconds) [info] - array.sql (536 milliseconds) [info] - blacklist.sql !!! IGNORED !!! [info] - cast.sql (550 milliseconds) .... .... .... [info] - union.sql (315 milliseconds) [info] - subquery/.DS_Store !!! IGNORED !!! [info] - subquery/exists-subquery/.DS_Store !!! IGNORED !!! [info] - subquery/exists-subquery/exists-aggregate.sql (2 seconds, 451 milliseconds) .... .... [info] - subquery/in-subquery/in-group-by.sql (12 seconds, 264 milliseconds) .... .... [info] - subquery/scalar-subquery/scalar-subquery-predicate.sql (7 seconds, 769 milliseconds) [info] - subquery/scalar-subquery/scalar-subquery-select.sql (4 seconds, 119 milliseconds) ``` Since this is a simple change, i haven't created a JIRA for it. ## How was this patch tested? Manually verified. This is change to test infrastructure Author: Dilip Biswal Closes #17060 from dilipbiswal/sqlquerytestsuite. 
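The naming scheme described above boils down to stripping the shared input directory prefix from each test file's absolute path. A tiny standalone illustration (the paths are made-up examples, not the real test resource layout):

```scala
import java.io.File

// Illustrative only: made-up paths showing how the relative test-case name is derived.
object TestNameSketch {
  def main(args: Array[String]): Unit = {
    val inputFilePath = "/repo/sql/core/src/test/resources/sql-tests/inputs"
    val absPath = inputFilePath + "/subquery/in-subquery/in-group-by.sql"

    // Strip the common prefix and the leading separator to get a unique, readable name.
    val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)
    println(testCaseName) // subquery/in-subquery/in-group-by.sql
  }
}
```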
--- .../scala/org/apache/spark/sql/SQLQueryTestSuite.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 91aecca537fb2..0b3da9aa8fbee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -98,7 +98,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { /** List of test cases to ignore, in lower cases. */ private val blackList = Set( - "blacklist.sql" // Do NOT remove this one. It is here to test the blacklist functionality. + "blacklist.sql", // Do NOT remove this one. It is here to test the blacklist functionality. + ".DS_Store" // A meta-file that may be created on Mac by Finder App. + // We should ignore this file from processing. ) // Create all the test cases. @@ -121,7 +123,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } private def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.contains(testCase.name.toLowerCase)) { + if (blackList.exists(t => testCase.name.toLowerCase.contains(t.toLowerCase))) { // Create a test case to ignore this case. ignore(testCase.name) { /* Do nothing */ } } else { @@ -241,7 +243,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def listTestCases(): Seq[TestCase] = { listFilesRecursively(new File(inputFilePath)).map { file => val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out" - TestCase(file.getName, file.getAbsolutePath, resultFile) + val absPath = file.getAbsolutePath + val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator) + TestCase(testCaseName, absPath, resultFile) } } From 389118b1fe4d6bcaed5c6860c0014224137b3810 Mon Sep 17 00:00:00 2001 From: Eyal Zituny Date: Sun, 26 Feb 2017 15:57:32 -0800 Subject: [PATCH 60/61] [SPARK-19594][STRUCTURED STREAMING] StreamingQueryListener fails to handle QueryTerminatedEvent if more then one listeners exists ## What changes were proposed in this pull request? currently if multiple streaming queries listeners exists, when a QueryTerminatedEvent is triggered, only one of the listeners will be invoked while the rest of the listeners will ignore the event. this is caused since the the streaming queries listeners bus holds a set of running queries ids and when a termination event is triggered, after the first listeners is handling the event, the terminated query id is being removed from the set. in this PR, the query id will be removed from the set only after all the listeners handles the event ## How was this patch tested? a test with multiple listeners has been added to StreamingQueryListenerSuite Author: Eyal Zituny Closes #16991 from eyalzit/master. 
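The scenario being fixed is easy to reproduce from user code: several listeners registered on one session, each of which should observe the termination event. A hedged usage sketch (the session setup and listener bodies are illustrative, not taken from the patch or its test):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// Illustrative usage only: registers several listeners that all expect to be
// told when a streaming query terminates.
object MultiListenerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("listener-sketch").getOrCreate()

    (1 to 3).foreach { i =>
      spark.streams.addListener(new StreamingQueryListener {
        override def onQueryStarted(event: QueryStartedEvent): Unit = ()
        override def onQueryProgress(event: QueryProgressEvent): Unit = ()
        override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
          // Before this fix, only one of these callbacks fired per termination.
          println(s"listener $i saw termination of run ${event.runId}")
      })
    }

    // ... start a streaming query and stop it; every listener should now be notified ...
    spark.stop()
  }
}
```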
---
 .../org/apache/spark/util/ListenerBus.scala        |  2 +-
 .../streaming/StreamingQueryListenerBus.scala      | 14 ++++++++++-
 .../StreamingQueryListenerSuite.scala              | 25 +++++++++++++++++++
 3 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index 79fc2e94599c7..fa5ad4e8d81e1 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -52,7 +52,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
    * Post the event to all registered listeners. The `postToAll` caller should guarantee calling
    * `postToAll` in the same thread for all events.
    */
-  final def postToAll(event: E): Unit = {
+  def postToAll(event: E): Unit = {
     // JavaConverters can create a JIterableWrapper if we use asScala.
     // However, this method will be called frequently. To avoid the wrapper cost, here we use
     // Java Iterator directly.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
index a2153d27e9fef..4207013c3f75d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
@@ -75,6 +75,19 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
     }
   }

+  /**
+   * Override the parent `postToAll` to remove the query id from `activeQueryRunIds` after all
+   * the listeners process `QueryTerminatedEvent`. (SPARK-19594)
+   */
+  override def postToAll(event: Event): Unit = {
+    super.postToAll(event)
+    event match {
+      case t: QueryTerminatedEvent =>
+        activeQueryRunIds.synchronized { activeQueryRunIds -= t.runId }
+      case _ =>
+    }
+  }
+
   override def onOtherEvent(event: SparkListenerEvent): Unit = {
     event match {
       case e: StreamingQueryListener.Event =>
@@ -112,7 +125,6 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
       case queryTerminated: QueryTerminatedEvent =>
         if (shouldReport(queryTerminated.runId)) {
           listener.onQueryTerminated(queryTerminated)
-          activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId }
         }
       case _ =>
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index 4596aa1d348e3..eb09b9ffcfc5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -133,6 +133,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
     }
   }

+  test("SPARK-19594: all of listeners should receive QueryTerminatedEvent") {
+    val df = MemoryStream[Int].toDS().as[Long]
+    val listeners = (1 to 5).map(_ => new EventCollector)
+    try {
+      listeners.foreach(listener => spark.streams.addListener(listener))
+      testStream(df, OutputMode.Append)(
+        StartStream(),
+        StopStream,
+        AssertOnQuery { query =>
+          eventually(Timeout(streamingTimeout)) {
+            listeners.foreach(listener => assert(listener.terminationEvent !== null))
+            listeners.foreach(listener => assert(listener.terminationEvent.id === query.id))
+            listeners.foreach(listener => assert(listener.terminationEvent.runId ===
+              query.runId))
+            listeners.foreach(listener => assert(listener.terminationEvent.exception === None))
+          }
+          listeners.foreach(listener => listener.checkAsyncErrors())
+          listeners.foreach(listener => listener.reset())
+          true
+        }
+      )
+    } finally {
+      listeners.foreach(spark.streams.removeListener)
+    }
+  }
+
   test("adding and removing listener") {
     def isListenerActive(listener: EventCollector): Boolean = {
       listener.reset()

From 0b4646199cf061d1f358a78122ef8bdf164ac839 Mon Sep 17 00:00:00 2001
From: Yunni
Date: Sun, 26 Feb 2017 23:04:37 -0500
Subject: [PATCH 61/61] Fix typos in unit tests

---
 .../spark/ml/feature/BucketedRandomProjectionLSHSuite.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index 9c5929fce156e..2497e8f4f6c62 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -187,7 +187,7 @@ class BucketedRandomProjectionLSHSuite
     val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")

     val brp = new BucketedRandomProjectionLSH()
-      .setNumHashTables(8)
+      .setNumHashFunctions(4)
       .setNumHashTables(2)
       .setInputCol("keys")
       .setOutputCol("values")
@@ -206,14 +206,14 @@ class BucketedRandomProjectionLSHSuite
     val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")

     val brp = new BucketedRandomProjectionLSH()
-      .setNumHashTables(8)
+      .setNumHashFunctions(4)
       .setNumHashTables(2)
       .setInputCol("keys")
       .setOutputCol("values")
       .setBucketLength(4.0)
       .setSeed(12345)

-    val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(brp, df, df, 3.0)
+    val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(brp, df, df, 2.0)
     assert(precision == 1.0)
     assert(recall >= 0.7)
   }
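As background on the parameter values these test fixes settle on: with OR-amplification over b hash tables and AND-amplification over r hash functions per table, a pair whose single-hash collision probability is p becomes a join candidate with probability 1 - (1 - p^r)^b. Moving the tests from 8 tables of 1 function to 2 tables of 4 functions keeps the total number of random projections at 8 but makes candidate selection much stricter, which is presumably why the distance threshold in the similarity-join check is tightened from 3.0 to 2.0. A quick sketch (editorial background, not part of the patch) for eyeballing the trade-off:

```scala
// Probability that a pair with per-hash collision probability `p` becomes a
// candidate under OR-amplification across `numHashTables` tables, each using
// AND-amplification across `numHashFunctions` hash functions.
def candidateProbability(p: Double, numHashTables: Int, numHashFunctions: Int): Double =
  1.0 - math.pow(1.0 - math.pow(p, numHashFunctions), numHashTables)

// Compare the old test setting (8 tables x 1 function) with the new one
// (2 tables x 4 functions) for a few per-hash collision probabilities.
Seq(0.9, 0.7, 0.5).foreach { p =>
  println(f"p=$p%.1f  8x1: ${candidateProbability(p, 8, 1)}%.3f  2x4: ${candidateProbability(p, 2, 4)}%.3f")
}
```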