[KYUUBI #3336] Use StageAttempt instead of StageId in SQLOperationListener

### _Why are the changes needed?_
Currently `activeStages` only tracks the stage id; it does not record or report the stage attempt number, which can cause incorrect bookkeeping when a failed stage is retried.
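
As an illustration of the failure mode, here is a minimal, hypothetical sketch (the `StageAttempt` and `StageStatus` types below are stand-ins for illustration, not the listener's actual classes): keying by stage id alone collapses a stage and its retry into one entry, while keying by `(stageId, attemptNumber)` keeps the two attempts separate.

```scala
import java.util.concurrent.ConcurrentHashMap

// Stand-in types for illustration only.
case class StageAttempt(stageId: Int, attemptNumber: Int)
class StageStatus(var numActiveTasks: Int = 0)

object StageKeyDemo extends App {
  // Keyed by stage id only: attempt 1 (the retry) clobbers attempt 0,
  // so events from the two attempts get mixed into a single entry.
  val byId = new ConcurrentHashMap[Int, StageStatus]()
  byId.put(7, new StageStatus(numActiveTasks = 10)) // stage 7, attempt 0
  byId.put(7, new StageStatus(numActiveTasks = 10)) // stage 7, attempt 1 (retry)
  println(s"keyed by stage id: ${byId.size()} entry for two live attempts")

  // Keyed by (stageId, attemptNumber): each attempt is tracked independently.
  val byAttempt = new ConcurrentHashMap[StageAttempt, StageStatus]()
  byAttempt.put(StageAttempt(7, 0), new StageStatus(numActiveTasks = 10))
  byAttempt.put(StageAttempt(7, 1), new StageStatus(numActiveTasks = 10))
  println(s"keyed by stage attempt: ${byAttempt.size()} entries")
}
```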

### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #3336 from cxzl25/spark_listener_stage_attempt.

Closes #3336

a6da472 [sychen] use StageAttempt instead of StageId

Authored-by: sychen <sychen@ctrip.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
(cherry picked from commit 53bbff8)
Signed-off-by: Cheng Pan <chengpan@apache.org>
cxzl25 authored and pan3793 committed Aug 26, 2022
1 parent 99753de commit d1297bd
Showing 1 changed file with 17 additions and 16 deletions.
@@ -45,16 +45,15 @@ class SQLOperationListener(
 
   private val operationId: String = operation.getHandle.identifier.toString
   private lazy val activeJobs = new java.util.HashSet[Int]()
-  private lazy val activeStages = new java.util.HashSet[Int]()
+  private lazy val activeStages = new ConcurrentHashMap[StageAttempt, StageInfo]()
   private var executionId: Option[Long] = None
-  private lazy val liveStages = new ConcurrentHashMap[StageAttempt, StageInfo]()
 
   private val conf: KyuubiConf = operation.getSession.sessionManager.getConf
   private lazy val consoleProgressBar =
     if (conf.get(ENGINE_SPARK_SHOW_PROGRESS)) {
       Some(new SparkConsoleProgressBar(
         operation,
-        liveStages,
+        activeStages,
         conf.get(ENGINE_SPARK_SHOW_PROGRESS_UPDATE_INTERVAL),
         conf.get(ENGINE_SPARK_SHOW_PROGRESS_TIME_FORMAT)))
     } else {
@@ -120,41 +119,43 @@ class SQLOperationListener(
       if (sameGroupId(stageSubmitted.properties)) {
         val stageInfo = stageSubmitted.stageInfo
         val stageId = stageInfo.stageId
-        activeStages.add(stageId)
-        liveStages.put(
-          StageAttempt(stageId, stageInfo.attemptNumber()),
+        val attemptNumber = stageInfo.attemptNumber()
+        val stageAttempt = StageAttempt(stageId, attemptNumber)
+        activeStages.put(
+          stageAttempt,
           new StageInfo(stageId, stageInfo.numTasks))
         withOperationLog {
-          info(s"Query [$operationId]: Stage $stageId started with ${stageInfo.numTasks} tasks," +
-            s" ${activeStages.size()} active stages running")
+          info(s"Query [$operationId]: Stage $stageId.$attemptNumber started " +
+            s"with ${stageInfo.numTasks} tasks, ${activeStages.size()} active stages running")
         }
       }
     }
   }
 
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
     val stageInfo = stageCompleted.stageInfo
-    val stageId = stageInfo.stageId
+    val stageAttempt = StageAttempt(stageInfo.stageId, stageInfo.attemptNumber())
     activeStages.synchronized {
-      if (activeStages.remove(stageId)) {
-        liveStages.remove(StageAttempt(stageId, stageInfo.attemptNumber()))
+      if (activeStages.remove(stageAttempt) != null) {
         withOperationLog(super.onStageCompleted(stageCompleted))
       }
     }
   }
 
   override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = activeStages.synchronized {
-    if (activeStages.contains(taskStart.stageId)) {
-      liveStages.get(StageAttempt(taskStart.stageId, taskStart.stageAttemptId)).numActiveTasks += 1
+    val stageAttempt = StageAttempt(taskStart.stageId, taskStart.stageAttemptId)
+    if (activeStages.contains(stageAttempt)) {
+      activeStages.get(stageAttempt).numActiveTasks += 1
       super.onTaskStart(taskStart)
     }
   }
 
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = activeStages.synchronized {
-    if (activeStages.contains(taskEnd.stageId)) {
-      liveStages.get(StageAttempt(taskEnd.stageId, taskEnd.stageAttemptId)).numActiveTasks -= 1
+    val stageAttempt = StageAttempt(taskEnd.stageId, taskEnd.stageAttemptId)
+    if (activeStages.contains(stageAttempt)) {
+      activeStages.get(stageAttempt).numActiveTasks -= 1
       if (taskEnd.reason == org.apache.spark.Success) {
-        liveStages.get(StageAttempt(taskEnd.stageId, taskEnd.stageAttemptId)).numCompleteTasks += 1
+        activeStages.get(stageAttempt).numCompleteTasks += 1
       }
       super.onTaskEnd(taskEnd)
     }
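
On the new `ConcurrentHashMap[StageAttempt, StageInfo]`, the `activeStages.remove(stageAttempt) != null` check in `onStageCompleted` works because `ConcurrentHashMap.remove(key)` returns the previously mapped value, or `null` if the key was absent. A small sketch of that behaviour, again with a stand-in `StageAttempt` case class rather than Kyuubi's real type:

```scala
import java.util.concurrent.ConcurrentHashMap

// Stand-in key type; a case class supplies the equals/hashCode needed for map keys.
case class StageAttempt(stageId: Int, attemptNumber: Int)

object RemoveSemanticsDemo extends App {
  val activeStages = new ConcurrentHashMap[StageAttempt, String]()
  activeStages.put(StageAttempt(7, 1), "running")

  // The first removal returns the mapped value, so the completion is processed once...
  assert(activeStages.remove(StageAttempt(7, 1)) != null)
  // ...and later removals (e.g. an event for an untracked attempt) return null.
  assert(activeStages.remove(StageAttempt(7, 1)) == null)
  println("remove() returned the previous value once, then null")
}
```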
