Skip to content

Commit

Permalink
[FLINK-2371] improve AccumulatorLiveITCase
Browse files Browse the repository at this point in the history
Instead of using Thread.sleep() to synchronize the checks of the
accumulator values, we rely on message passing here to synchronize the
task process.

Therefore, we let the task process signal to the task manager that it
has updated its accumulator values. The task manager lets the job
manager know and sends out the heartbeat which contains the
accumulators. When the job manager receives the accumulators and has
been notified previously, it sends a message to the subscribed test case
with the current accumulators.

This assures that all processes are always synchronized correctly and we
can verify the live accumulator results correctly.

In the course of rewriting the test, I had to change two things in the
implementation:

a) Flink's internal accumulators are now immediately serialized as well
(just like the user accumulators). Otherwise, Akka does not serialize
messages in local single-JVM setups and passes the live accumulator map
through as-is.

b) The asynchronous update of the accumulators is disabled for
tests (via the dispatcher flag of the TestingCluster). This was
necessary because we cannot guarantee when the Future for updating the
accumulators is executed. In real setups this is negligible.

This closes #925.
  • Loading branch information
mxm committed Jul 21, 2015
1 parent 0f45f2b commit 0f589aa
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 156 deletions.
Expand Up @@ -40,9 +40,9 @@ public class AccumulatorSnapshot implements Serializable {
private final ExecutionAttemptID executionAttemptID; private final ExecutionAttemptID executionAttemptID;


/** /**
* Flink internal accumulators which can be serialized using the system class loader. * Flink internal accumulators which can be deserialized using the system class loader.
*/ */
private final Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> flinkAccumulators; private final SerializedValue<Map<AccumulatorRegistry.Metric, Accumulator<?, ?>>> flinkAccumulators;


/** /**
* Serialized user accumulators which may require the custom user class loader. * Serialized user accumulators which may require the custom user class loader.
Expand All @@ -54,7 +54,7 @@ public AccumulatorSnapshot(JobID jobID, ExecutionAttemptID executionAttemptID,
Map<String, Accumulator<?, ?>> userAccumulators) throws IOException { Map<String, Accumulator<?, ?>> userAccumulators) throws IOException {
this.jobID = jobID; this.jobID = jobID;
this.executionAttemptID = executionAttemptID; this.executionAttemptID = executionAttemptID;
this.flinkAccumulators = flinkAccumulators; this.flinkAccumulators = new SerializedValue<Map<AccumulatorRegistry.Metric, Accumulator<?, ?>>>(flinkAccumulators);
this.userAccumulators = new SerializedValue<Map<String, Accumulator<?, ?>>>(userAccumulators); this.userAccumulators = new SerializedValue<Map<String, Accumulator<?, ?>>>(userAccumulators);
} }


Expand All @@ -70,8 +70,8 @@ public ExecutionAttemptID getExecutionAttemptID() {
* Gets the Flink (internal) accumulators values. * Gets the Flink (internal) accumulators values.
* @return the serialized map * @return the serialized map
*/ */
public Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> getFlinkAccumulators() { public Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> deserializeFlinkAccumulators() throws IOException, ClassNotFoundException {
return flinkAccumulators; return flinkAccumulators.deserializeValue(ClassLoader.getSystemClassLoader());
} }


/** /**
Expand Down
Expand Up @@ -141,9 +141,10 @@ public class ExecutionGraph implements Serializable {
* @param accumulatorSnapshot The serialized flink and user-defined accumulators * @param accumulatorSnapshot The serialized flink and user-defined accumulators
*/ */
public void updateAccumulators(AccumulatorSnapshot accumulatorSnapshot) { public void updateAccumulators(AccumulatorSnapshot accumulatorSnapshot) {
Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> flinkAccumulators = accumulatorSnapshot.getFlinkAccumulators(); Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> flinkAccumulators;
Map<String, Accumulator<?, ?>> userAccumulators; Map<String, Accumulator<?, ?>> userAccumulators;
try { try {
flinkAccumulators = accumulatorSnapshot.deserializeFlinkAccumulators();
userAccumulators = accumulatorSnapshot.deserializeUserAccumulators(userClassLoader); userAccumulators = accumulatorSnapshot.deserializeUserAccumulators(userClassLoader);


ExecutionAttemptID execID = accumulatorSnapshot.getExecutionAttemptID(); ExecutionAttemptID execID = accumulatorSnapshot.getExecutionAttemptID();
Expand Down Expand Up @@ -889,7 +890,7 @@ public boolean updateState(TaskExecutionState state) {
Map<String, Accumulator<?, ?>> userAccumulators = null; Map<String, Accumulator<?, ?>> userAccumulators = null;
try { try {
AccumulatorSnapshot accumulators = state.getAccumulators(); AccumulatorSnapshot accumulators = state.getAccumulators();
flinkAccumulators = accumulators.getFlinkAccumulators(); flinkAccumulators = accumulators.deserializeFlinkAccumulators();
userAccumulators = accumulators.deserializeUserAccumulators(userClassLoader); userAccumulators = accumulators.deserializeUserAccumulators(userClassLoader);
} catch (Exception e) { } catch (Exception e) {
// Exceptions would be thrown in the future here // Exceptions would be thrown in the future here
Expand Down
Expand Up @@ -50,12 +50,6 @@ public class RecordWriter<T extends IOReadableWritable> {


private final int numChannels; private final int numChannels;


/**
* Counter for the number of records emitted and for the number of bytes written.
* @param counter
*/
private AccumulatorRegistry.Reporter reporter;

/** {@link RecordSerializer} per outgoing channel */ /** {@link RecordSerializer} per outgoing channel */
private final RecordSerializer<T>[] serializers; private final RecordSerializer<T>[] serializers;


Expand Down Expand Up @@ -88,7 +82,6 @@ public void emit(T record) throws IOException, InterruptedException {


synchronized (serializer) { synchronized (serializer) {
SerializationResult result = serializer.addRecord(record); SerializationResult result = serializer.addRecord(record);

while (result.isFullBuffer()) { while (result.isFullBuffer()) {
Buffer buffer = serializer.getCurrentBuffer(); Buffer buffer = serializer.getCurrentBuffer();


Expand All @@ -98,18 +91,8 @@ public void emit(T record) throws IOException, InterruptedException {
} }


buffer = writer.getBufferProvider().requestBufferBlocking(); buffer = writer.getBufferProvider().requestBufferBlocking();
if (reporter != null) {
// increase the number of written bytes by the memory segment's size
reporter.reportNumBytesOut(buffer.getSize());
}

result = serializer.setNextBuffer(buffer); result = serializer.setNextBuffer(buffer);
} }

if(reporter != null) {
// count number of emitted records
reporter.reportNumRecordsOut(1);
}
} }
} }
} }
Expand Down
Expand Up @@ -28,7 +28,7 @@ import grizzled.slf4j.Logger
import org.apache.flink.api.common.{ExecutionConfig, JobID} import org.apache.flink.api.common.{ExecutionConfig, JobID}
import org.apache.flink.configuration.{ConfigConstants, Configuration, GlobalConfiguration} import org.apache.flink.configuration.{ConfigConstants, Configuration, GlobalConfiguration}
import org.apache.flink.core.io.InputSplitAssigner import org.apache.flink.core.io.InputSplitAssigner
import org.apache.flink.runtime.accumulators.StringifiedAccumulatorResult import org.apache.flink.runtime.accumulators.{AccumulatorSnapshot, StringifiedAccumulatorResult}
import org.apache.flink.runtime.blob.BlobServer import org.apache.flink.runtime.blob.BlobServer
import org.apache.flink.runtime.client._ import org.apache.flink.runtime.client._
import org.apache.flink.runtime.executiongraph.{ExecutionGraph, ExecutionJobVertex} import org.apache.flink.runtime.executiongraph.{ExecutionGraph, ExecutionJobVertex}
Expand Down Expand Up @@ -404,15 +404,7 @@ class JobManager(
log.debug(s"Received hearbeat message from $instanceID.") log.debug(s"Received hearbeat message from $instanceID.")


Future { Future {
accumulators foreach { updateAccumulators(accumulators)
case accumulators =>
currentJobs.get(accumulators.getJobID) match {
case Some((jobGraph, jobInfo)) =>
jobGraph.updateAccumulators(accumulators)
case None =>
// ignore accumulator values for old job
}
}
}(context.dispatcher) }(context.dispatcher)


instanceManager.reportHeartBeat(instanceID, metricsReport) instanceManager.reportHeartBeat(instanceID, metricsReport)
Expand Down Expand Up @@ -770,6 +762,22 @@ class JobManager(
log.error(s"Could not properly unregister job $jobID form the library cache.", t) log.error(s"Could not properly unregister job $jobID form the library cache.", t)
} }
} }

/**
* Updates the accumulators reported from a task manager via the Heartbeat message.
*
* Each snapshot is looked up in `currentJobs` by its job ID and, if the job is found, handed to
* that job's `updateAccumulators` method. Snapshots whose job ID is not (or no longer) present
* in `currentJobs` are dropped silently.
*
* @param accumulators list of accumulator snapshots
*/
private def updateAccumulators(accumulators : Seq[AccumulatorSnapshot]) = {
accumulators foreach {
case accumulatorEvent =>
// Route the snapshot to the job it belongs to.
currentJobs.get(accumulatorEvent.getJobID) match {
case Some((jobGraph, jobInfo)) =>
// NOTE(review): despite its name, jobGraph is presumably the job's ExecutionGraph — confirm.
jobGraph.updateAccumulators(accumulatorEvent)
case None =>
// ignore accumulator values for old job
}
}
}
} }


/** /**
Expand Down
Expand Up @@ -155,7 +155,7 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {


private var blobService: Option[BlobService] = None private var blobService: Option[BlobService] = None
private var libraryCacheManager: Option[LibraryCacheManager] = None private var libraryCacheManager: Option[LibraryCacheManager] = None
private var currentJobManager: Option[ActorRef] = None protected var currentJobManager: Option[ActorRef] = None


private var instanceID: InstanceID = null private var instanceID: InstanceID = null


Expand Down Expand Up @@ -936,7 +936,7 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
* Sends a heartbeat message to the JobManager (if connected) with the current * Sends a heartbeat message to the JobManager (if connected) with the current
* metrics report. * metrics report.
*/ */
private def sendHeartbeatToJobManager(): Unit = { protected def sendHeartbeatToJobManager(): Unit = {
try { try {
log.debug("Sending heartbeat to JobManager") log.debug("Sending heartbeat to JobManager")
val metricsReport: Array[Byte] = metricRegistryMapper.writeValueAsBytes(metricRegistry) val metricsReport: Array[Byte] = metricRegistryMapper.writeValueAsBytes(metricRegistry)
Expand Down
Expand Up @@ -27,8 +27,10 @@ import org.apache.flink.runtime.jobgraph.JobStatus
import org.apache.flink.runtime.jobmanager.{JobManager, MemoryArchivist} import org.apache.flink.runtime.jobmanager.{JobManager, MemoryArchivist}
import org.apache.flink.runtime.messages.ExecutionGraphMessages.JobStatusChanged import org.apache.flink.runtime.messages.ExecutionGraphMessages.JobStatusChanged
import org.apache.flink.runtime.messages.Messages.Disconnect import org.apache.flink.runtime.messages.Messages.Disconnect
import org.apache.flink.runtime.messages.TaskManagerMessages.Heartbeat
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages._ import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages._
import org.apache.flink.runtime.testingUtils.TestingMessages.DisableDisconnect import org.apache.flink.runtime.testingUtils.TestingMessages.DisableDisconnect
import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages.AccumulatorsChanged


import scala.collection.convert.WrapAsScala import scala.collection.convert.WrapAsScala
import scala.concurrent.Future import scala.concurrent.Future
Expand All @@ -55,6 +57,8 @@ trait TestingJobManager extends ActorLogMessages with WrapAsScala {
val waitForJobStatus = scala.collection.mutable.HashMap[JobID, val waitForJobStatus = scala.collection.mutable.HashMap[JobID,
collection.mutable.HashMap[JobStatus, Set[ActorRef]]]() collection.mutable.HashMap[JobStatus, Set[ActorRef]]]()


val waitForAccumulatorUpdate = scala.collection.mutable.HashMap[JobID, (Boolean, Set[ActorRef])]()

var disconnectDisabled = false var disconnectDisabled = false


abstract override def receiveWithLogMessages: Receive = { abstract override def receiveWithLogMessages: Receive = {
Expand Down Expand Up @@ -130,6 +134,46 @@ trait TestingJobManager extends ActorLogMessages with WrapAsScala {
} }
} }


case NotifyWhenAccumulatorChange(jobID) =>

val (updated, registered) = waitForAccumulatorUpdate.
getOrElse(jobID, (false, Set[ActorRef]()))
waitForAccumulatorUpdate += jobID -> (updated, registered + sender)
sender ! true

/**
* Notification from the task manager that changed accumulators are transferred on the next
* Heartbeat. We need to keep this state to notify the listeners on the next Heartbeat report.
*/
case AccumulatorsChanged(jobID: JobID) =>
waitForAccumulatorUpdate.get(jobID) match {
case Some((updated, registered)) =>
waitForAccumulatorUpdate.put(jobID, (true, registered))
case None =>
}

/**
* Disables async processing of accumulator values and sends the accumulators to the listeners if
* we previously received an [[AccumulatorsChanged]] message.
*/
case msg : Heartbeat =>
super.receiveWithLogMessages(msg)

waitForAccumulatorUpdate foreach {
case (jobID, (updated, actors)) if updated =>
currentJobs.get(jobID) match {
case Some((graph, jobInfo)) =>
val flinkAccumulators = graph.getFlinkAccumulators
val userAccumulators = graph.aggregateUserAccumulators
actors foreach {
actor => actor ! UpdatedAccumulators(jobID, flinkAccumulators, userAccumulators)
}
case None =>
}
waitForAccumulatorUpdate.put(jobID, (false, actors))
case _ =>
}

case RequestWorkingTaskManager(jobID) => case RequestWorkingTaskManager(jobID) =>
currentJobs.get(jobID) match { currentJobs.get(jobID) match {
case Some((eg, _)) => case Some((eg, _)) =>
Expand All @@ -147,15 +191,6 @@ trait TestingJobManager extends ActorLogMessages with WrapAsScala {
case None => sender ! WorkingTaskManager(None) case None => sender ! WorkingTaskManager(None)
} }


case RequestAccumulatorValues(jobID) =>

val (flinkAccumulators, userAccumulators) = currentJobs.get(jobID) match {
case Some((graph, jobInfo)) =>
(graph.getFlinkAccumulators, graph.aggregateUserAccumulators)
case None => null
}

sender ! RequestAccumulatorValuesResponse(jobID, flinkAccumulators, userAccumulators)


case NotifyWhenJobStatus(jobID, state) => case NotifyWhenJobStatus(jobID, state) =>
val jobStatusListener = waitForJobStatus.getOrElseUpdate(jobID, val jobStatusListener = waitForJobStatus.getOrElseUpdate(jobID,
Expand Down
Expand Up @@ -20,12 +20,12 @@ package org.apache.flink.runtime.testingUtils


import akka.actor.ActorRef import akka.actor.ActorRef
import org.apache.flink.api.common.JobID import org.apache.flink.api.common.JobID
import org.apache.flink.api.common.accumulators.Accumulator
import org.apache.flink.runtime.accumulators.AccumulatorRegistry import org.apache.flink.runtime.accumulators.AccumulatorRegistry
import org.apache.flink.runtime.executiongraph.{ExecutionAttemptID, ExecutionGraph} import org.apache.flink.runtime.executiongraph.{ExecutionAttemptID, ExecutionGraph}
import org.apache.flink.runtime.instance.InstanceGateway import org.apache.flink.runtime.instance.InstanceGateway
import org.apache.flink.runtime.jobgraph.JobStatus import org.apache.flink.runtime.jobgraph.JobStatus
import java.util.Map import java.util.Map
import org.apache.flink.api.common.accumulators.Accumulator


object TestingJobManagerMessages { object TestingJobManagerMessages {


Expand Down Expand Up @@ -57,8 +57,18 @@ object TestingJobManagerMessages {
case class NotifyWhenTaskManagerTerminated(taskManager: ActorRef) case class NotifyWhenTaskManagerTerminated(taskManager: ActorRef)
case class TaskManagerTerminated(taskManager: ActorRef) case class TaskManagerTerminated(taskManager: ActorRef)


case class RequestAccumulatorValues(jobID: JobID) /* Registers a listener to receive a message when accumulators changed.
case class RequestAccumulatorValuesResponse(jobID: JobID, * The change must be explicitly triggered by the TestingTaskManager which can receive an
* [[AccumulatorsChanged]] message by a task that changed the accumulators. This message is then
* forwarded to the JobManager which will send the accumulators in the [[UpdatedAccumulators]]
* message when the next Heartbeat occurs.
* */
case class NotifyWhenAccumulatorChange(jobID: JobID)

/**
* Reports updated accumulators back to the listener.
*/
case class UpdatedAccumulators(jobID: JobID,
flinkAccumulators: Map[ExecutionAttemptID, Map[AccumulatorRegistry.Metric, Accumulator[_,_]]], flinkAccumulators: Map[ExecutionAttemptID, Map[AccumulatorRegistry.Metric, Accumulator[_,_]]],
userAccumulators: Map[String, Accumulator[_,_]]) userAccumulators: Map[String, Accumulator[_,_]])
} }
Expand Up @@ -94,7 +94,7 @@ class TestingTaskManager(config: TaskManagerConfiguration,
waitForRemoval += (executionID -> (set + sender)) waitForRemoval += (executionID -> (set + sender))
} }
} }

case TaskInFinalState(executionID) => case TaskInFinalState(executionID) =>
super.receiveWithLogMessages(TaskInFinalState(executionID)) super.receiveWithLogMessages(TaskInFinalState(executionID))
waitForRemoval.remove(executionID) match { waitForRemoval.remove(executionID) match {
Expand Down Expand Up @@ -144,6 +144,21 @@ class TestingTaskManager(config: TaskManagerConfiguration,
val waiting = waitForJobManagerToBeTerminated.getOrElse(jobManager.path.name, Set()) val waiting = waitForJobManagerToBeTerminated.getOrElse(jobManager.path.name, Set())
waitForJobManagerToBeTerminated += jobManager.path.name -> (waiting + sender) waitForJobManagerToBeTerminated += jobManager.path.name -> (waiting + sender)


/**
* Message to the task manager that accumulator values changed and need to be reported immediately
* instead of lazily through the
* [[org.apache.flink.runtime.messages.TaskManagerMessages.Heartbeat]] message. We forward this
* message to the job manager so that it knows it should report to the listeners.
*/
case msg: AccumulatorsChanged =>
currentJobManager match {
case Some(jobManager) =>
jobManager.forward(msg)
sendHeartbeatToJobManager()
sender ! true
case None =>
}

case msg@Terminated(jobManager) => case msg@Terminated(jobManager) =>
super.receiveWithLogMessages(msg) super.receiveWithLogMessages(msg)


Expand Down
Expand Up @@ -51,7 +51,14 @@ object TestingTaskManagerMessages {
case class NotifyWhenJobManagerTerminated(jobManager: ActorRef) case class NotifyWhenJobManagerTerminated(jobManager: ActorRef)


case class JobManagerTerminated(jobManager: ActorRef) case class JobManagerTerminated(jobManager: ActorRef)


/**
* Message to give a hint to the task manager that accumulator values were updated in the task.
* This message is forwarded to the job manager which knows that it needs to notify listeners
* of accumulator updates.
*/
case class AccumulatorsChanged(jobID: JobID)

// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// Utility methods to allow simpler case object access from Java // Utility methods to allow simpler case object access from Java
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
Expand Down

0 comments on commit 0f589aa

Please sign in to comment.