
[SPARK-23033][SS] Don't use task level retry for continuous processing

## What changes were proposed in this pull request?

Continuous processing tasks will fail on any attempt number greater than 0. ContinuousExecution will catch these failures and restart globally from the last recorded checkpoints.
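In code terms, the change adds a guard at the top of `ContinuousDataSourceRDD.compute` (see the hunk below). A minimal sketch of that guard; the `failOnTaskRetry` wrapper is illustrative and not part of the patch:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTaskRetryException

// Reject any attempt other than the first, so the driver can restart the whole
// query from the last recorded checkpoint instead of retrying the single task.
def failOnTaskRetry(context: TaskContext): Unit = {
  if (context.attemptNumber() != 0) {
    throw new ContinuousTaskRetryException()
  }
}
```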
## How was this patch tested?
Unit test.

Author: Jose Torres <jose@databricks.com>

Closes #20225 from jose-torres/no-retry.

(cherry picked from commit 86a8450)
Signed-off-by: Tathagata Das <tathagata.das1565@gmail.com>
jose-torres authored and tdas committed Jan 17, 2018
1 parent 1a6dfaf commit dbd2a5566d8924ab340c3c840d31e83e5af92242
@@ -808,16 +808,14 @@ class KafkaSourceSuiteBase extends KafkaSourceTest {
val query = kafka
.writeStream
.format("memory")
.outputMode("append")
.queryName("kafkaColumnTypes")
.trigger(defaultTrigger)
.start()
var rows: Array[Row] = Array()
eventually(timeout(streamingTimeout)) {
rows = spark.table("kafkaColumnTypes").collect()
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
assert(spark.table("kafkaColumnTypes").count == 1,
s"Unexpected results: ${spark.table("kafkaColumnTypes").collectAsList()}")
}
val row = rows(0)
val row = spark.table("kafkaColumnTypes").head()
assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row")
assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row")
assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row")
@@ -52,6 +52,11 @@ class ContinuousDataSourceRDD(
}

override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
// If attempt number isn't 0, this is a task retry, which we don't support.
if (context.attemptNumber() != 0) {
throw new ContinuousTaskRetryException()
}

val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader()

val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
@@ -24,7 +24,7 @@ import java.util.function.UnaryOperator
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}

import org.apache.spark.SparkEnv
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.continuous

import org.apache.spark.SparkException

/**
* An exception thrown when a continuous processing task runs with a nonzero attempt ID.
*/
class ContinuousTaskRetryException
extends SparkException("Continuous execution does not support task retry", null)
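The driver-side handling is only hinted at in this diff by the added SparkException import in ContinuousExecution; per the description above, ContinuousExecution is meant to catch these failures and restart globally from the last recorded checkpoints. A rough sketch of how such a catch-and-restart loop could look in principle; this is not ContinuousExecution's actual code, and every name in it (`RestartOnTaskRetry`, `causedByTaskRetry`, `runOnce`, `restartFromLastCheckpoint`) is hypothetical:

```scala
import scala.annotation.tailrec

import org.apache.spark.SparkException
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTaskRetryException

object RestartOnTaskRetry {
  // Walk the cause chain looking for the retry marker; the test added below
  // inspects a similar nesting via e.getCause.getCause.
  @tailrec
  def causedByTaskRetry(t: Throwable): Boolean = t match {
    case null => false
    case _: ContinuousTaskRetryException => true
    case other => causedByTaskRetry(other.getCause)
  }

  // Hypothetical driver loop: a task-retry failure triggers a global restart
  // from the last checkpoint instead of surfacing as a per-task retry.
  def run(runOnce: () => Unit, restartFromLastCheckpoint: () => Unit): Unit = {
    try runOnce()
    catch {
      case e: SparkException if causedByTaskRetry(e) => restartFromLastCheckpoint()
    }
  }
}
```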
@@ -472,8 +472,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
currentStream.awaitInitialization(streamingTimeout.toMillis)
currentStream match {
case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
s.lastExecution.executedPlan // will fail if lastExecution is null
}
s.lastExecution.executedPlan // will fail if lastExecution is null
}
case _ =>
}
} catch {
@@ -645,7 +645,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
}

case CheckAnswerRowsContains(expectedAnswer, lastOnly) =>
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
val sparkAnswer = currentStream match {
case null => fetchStreamAnswer(lastStream, lastOnly)
case s => fetchStreamAnswer(s, lastOnly)
}
QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach {
error => failTest(error)
}
@@ -17,36 +17,18 @@

package org.apache.spark.sql.streaming.continuous

import java.io.{File, InterruptedIOException, IOException, UncheckedIOException}
import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit}
import java.util.UUID

import scala.reflect.ClassTag
import scala.util.control.ControlThrowable

import com.google.common.util.concurrent.UncheckedExecutionException
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.{SparkContext, SparkEnv, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.test.TestSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class ContinuousSuiteBase extends StreamTest {
// We need more than the default local[2] to be able to schedule all partitions simultaneously.
@@ -219,6 +201,41 @@ class ContinuousSuite extends ContinuousSuiteBase {
StopStream)
}

test("task failure kills the query") {
val df = spark.readStream
.format("rate")
.option("numPartitions", "5")
.option("rowsPerSecond", "5")
.load()
.select('value)

// Get an arbitrary task from this query to kill. It doesn't matter which one.
var taskId: Long = -1
val listener = new SparkListener() {
override def onTaskStart(start: SparkListenerTaskStart): Unit = {
taskId = start.taskInfo.taskId
}
}
spark.sparkContext.addSparkListener(listener)
try {
testStream(df, useV2Sink = true)(
StartStream(Trigger.Continuous(100)),
Execute(waitForRateSourceTriggers(_, 2)),
Execute { _ =>
// Wait until a task is started, then kill its first attempt.
eventually(timeout(streamingTimeout)) {
assert(taskId != -1)
}
spark.sparkContext.killTaskAttempt(taskId)
},
ExpectFailure[SparkException] { e =>
e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException]
})
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}

test("query without test harness") {
val df = spark.readStream
.format("rate")
@@ -258,13 +275,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
AwaitEpoch(0),
Execute(waitForRateSourceTriggers(_, 201)),
IncrementEpoch(),
Execute { query =>
val data = query.sink.asInstanceOf[MemorySinkV2].allData
val vals = data.map(_.getLong(0)).toSet
assert(scala.Range(0, 25000).forall { i =>
vals.contains(i)
})
})
StopStream,
CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))
)
}

test("automatic epoch advancement") {
@@ -280,6 +293,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
AwaitEpoch(0),
Execute(waitForRateSourceTriggers(_, 201)),
IncrementEpoch(),
StopStream,
CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))))
}

@@ -311,6 +325,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
StopStream,
StartStream(Trigger.Continuous(2012)),
AwaitEpoch(50),
StopStream,
CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))))
}
}
