Add Trigger and ProcessingTime to control how to execute a batch
zsxwing committed Mar 30, 2016
1 parent 6f5c6ed commit 92d204c
Showing 8 changed files with 134 additions and 32 deletions.
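
For orientation (not part of the diff below), a minimal sketch of how the new trigger API is meant to be used from DataFrameWriter; `df` is assumed to be a streaming DataFrame, and the sink format and checkpoint path are placeholders:

import java.util.concurrent.TimeUnit
import scala.concurrent.duration._

// Assumes `df` is a streaming DataFrame; the format and path below are placeholders.
val query = df.write
  .format("org.apache.spark.sql.streaming.test")     // placeholder sink format
  .option("checkpointLocation", "/tmp/checkpoints")  // placeholder path
  .trigger(10.seconds)                               // stored as ProcessingTime(10000)
  .startStream()

// The overload taking an explicit unit is equivalent:
//   .trigger(10, TimeUnit.SECONDS)

query.stop()
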
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution, Trigger}
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.util.ContinuousQueryListener

@@ -172,7 +172,7 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
checkpointLocation: String,
df: DataFrame,
sink: Sink,
triggerIntervalMs: Long): ContinuousQuery = {
trigger: Trigger): ContinuousQuery = {
activeQueriesLock.synchronized {
if (activeQueries.contains(name)) {
throw new IllegalArgumentException(
@@ -184,7 +184,7 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
checkpointLocation,
df.logicalPlan,
sink,
triggerIntervalMs)
trigger)
query.start()
activeQueries.put(name, query)
query
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.{ProcessingTime, StreamExecution, Trigger}
import org.apache.spark.sql.sources.HadoopFsRelation

/**
@@ -84,7 +84,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
def trigger(interval: Duration): DataFrameWriter = {
this.extraOptions += ("triggerInterval" -> interval.toMillis.toString)
trigger = ProcessingTime(interval.toMillis)
this
}

@@ -94,7 +94,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
def trigger(interval: Long, unit: TimeUnit): DataFrameWriter = {
this.extraOptions += ("triggerInterval" -> unit.toMillis(interval).toString)
trigger = ProcessingTime(unit.toMillis(interval))
this
}

@@ -278,14 +278,12 @@ final class DataFrameWriter private[sql](df: DataFrame) {
val checkpointLocation = extraOptions.getOrElse("checkpointLocation", {
new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString
})
val triggerIntervalMs = extraOptions.getOrElse("triggerInterval", "0").toLong
require(triggerIntervalMs >= 0, "the interval of trigger should not be negative")
df.sqlContext.sessionState.continuousQueryManager.startQuery(
queryName,
checkpointLocation,
df,
dataSource.createSink(),
triggerIntervalMs)
trigger)
}

/**
@@ -576,6 +574,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {

private var mode: SaveMode = SaveMode.ErrorIfExists

private var trigger: Trigger = ProcessingTime(0L)

private var extraOptions = new scala.collection.mutable.HashMap[String, String]

private var partitioningColumns: Option[Seq[String]] = None
@@ -47,7 +47,7 @@ class StreamExecution(
val checkpointRoot: String,
private[sql] val logicalPlan: LogicalPlan,
val sink: Sink,
triggerIntervalMs: Long) extends ContinuousQuery with Logging {
val trigger: Trigger) extends ContinuousQuery with Logging {

/** A monitor used to wait/notify when batches complete. */
private val awaitBatchLock = new Object
@@ -209,20 +209,15 @@
SQLContext.setActive(sqlContext)
populateStartOffsets()
logDebug(s"Stream running from $committedOffsets to $availableOffsets")
while (isActive) {
val batchStartTimeMs = System.currentTimeMillis()
if (dataAvailable) runBatch()
commitAndConstructNextBatch()
if (triggerIntervalMs > 0) {
val batchElapsedTime = System.currentTimeMillis() - batchStartTimeMs
if (batchElapsedTime > triggerIntervalMs) {
logWarning("Current batch is falling behind. The trigger interval is " +
s"${triggerIntervalMs} milliseconds, but spent ${batchElapsedTime} milliseconds")
} else {
Thread.sleep(triggerIntervalMs - batchElapsedTime)
}
trigger.execute(() => {
if (isActive) {
if (dataAvailable) runBatch()
commitAndConstructNextBatch()
true
} else {
false
}
}
})
} catch {
case _: InterruptedException if state == TERMINATED => // interrupted by stop()
case NonFatal(e) =>
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging

/**
* An interface that indicates how to run a batch.
*/
trait Trigger {

/**
* Execute batches using `batchRunner`. If `batchRunner` returns `false`, terminate the execution.
*/
def execute(batchRunner: () => Boolean): Unit
}

/**
* A trigger that runs a batch every `intervalMs` milliseconds.
*/
case class ProcessingTime(intervalMs: Long) extends Trigger with Logging {

require(intervalMs >= 0, "the interval of trigger should not be negative")

override def execute(batchRunner: () => Boolean): Unit = {
while (true) {
val batchStartTimeMs = System.currentTimeMillis()
if (!batchRunner()) {
return
}
if (intervalMs > 0) {
val batchEndTimeMs = System.currentTimeMillis()
val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
if (batchElapsedTimeMs > intervalMs) {
logWarning("Current batch is falling behind. The trigger interval is " +
s"${intervalMs} milliseconds, but spent ${batchElapsedTimeMs} milliseconds")
}
waitUntil(nextBatchTime(batchEndTimeMs))
}
}
}

private def waitUntil(time: Long): Unit = {
var now = System.currentTimeMillis()
while (now < time) {
Thread.sleep(time - now)
now = System.currentTimeMillis()
}
}

/** Return the earliest multiple of `intervalMs` that is greater than or equal to `now`. */
def nextBatchTime(now: Long): Long = {
(now - 1) / intervalMs * intervalMs + intervalMs
}
}
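
The Trigger contract above is small: `execute` keeps invoking `batchRunner` and stops once it returns false. As an illustration only (not part of this commit), a hypothetical trigger that additionally caps the number of batches could look like this:

package org.apache.spark.sql.execution.streaming

// Hypothetical example, not part of this commit: run batches back-to-back,
// stopping when `batchRunner` returns false or after `maxBatches` batches.
case class MaxBatchesTrigger(maxBatches: Int) extends Trigger {

  require(maxBatches > 0, "maxBatches must be positive")

  override def execute(batchRunner: () => Boolean): Unit = {
    var batches = 0
    var active = true
    while (active && batches < maxBatches) {
      active = batchRunner()   // false signals that the query has stopped
      batches += 1
    }
  }
}
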
@@ -276,7 +276,12 @@ trait StreamTest extends QueryTest with Timeouts {
currentStream =
sqlContext
.streams
.startQuery(StreamExecution.nextName, metadataRoot, stream, sink, 10L)
.startQuery(
StreamExecution.nextName,
metadataRoot,
stream,
sink,
ProcessingTime(0L))
.asInstanceOf[StreamExecution]
currentStream.microBatchThread.setUncaughtExceptionHandler(
new UncaughtExceptionHandler {
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import org.apache.spark.SparkFunSuite

class ProcessingTimeSuite extends SparkFunSuite {

test("nextBatchTime") {
val processingTime = ProcessingTime(100)
assert(processingTime.nextBatchTime(1) === 100)
assert(processingTime.nextBatchTime(99) === 100)
assert(processingTime.nextBatchTime(100) === 100)
assert(processingTime.nextBatchTime(101) === 200)
assert(processingTime.nextBatchTime(150) === 200)
}
}
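
The expected values follow directly from the integer arithmetic in nextBatchTime: with intervalMs = 100, (99 - 1) / 100 * 100 + 100 = 100, (100 - 1) / 100 * 100 + 100 = 100, and (101 - 1) / 100 * 100 + 100 = 200.
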
@@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkException
import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -244,7 +244,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
metadataRoot,
df,
new MemorySink(df.schema),
10L)
ProcessingTime(0))
.asInstanceOf[StreamExecution]
} catch {
case NonFatal(e) =>
@@ -284,22 +284,22 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
.format("org.apache.spark.sql.streaming.test")
.stream("/test")

df.write
var q = df.write
.format("org.apache.spark.sql.streaming.test")
.option("checkpointLocation", newMetadataDir)
.trigger(10.seconds)
.startStream()
.stop()
q.stop()

assert(LastOptions.parameters("triggerInterval") == "10000")
assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000))

df.write
q = df.write
.format("org.apache.spark.sql.streaming.test")
.option("checkpointLocation", newMetadataDir)
.trigger(100, TimeUnit.SECONDS)
.startStream()
.stop()
q.stop()

assert(LastOptions.parameters("triggerInterval") == "100000")
assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000))
}
}
