Skip to content

Commit

Permalink
Merge pull request #1051 from akka/wip-agent-rework-√
Browse files Browse the repository at this point in the history
 Migrating Agents to greener pastures
  • Loading branch information
viktorklang committed Jan 28, 2013
2 parents feb413d + a70db6a commit d6addd9
Show file tree
Hide file tree
Showing 9 changed files with 463 additions and 436 deletions.
Expand Up @@ -4,6 +4,7 @@ import java.util.concurrent.{ ExecutorService, Executor, Executors }
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent._
import akka.testkit.{ TestLatch, AkkaSpec, DefaultTimeout }
import akka.util.SerializedSuspendableExecutionContext

@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class ExecutionContextSpec extends AkkaSpec with DefaultTimeout {
Expand Down Expand Up @@ -81,4 +82,82 @@ class ExecutionContextSpec extends AkkaSpec with DefaultTimeout {
Await.ready(latch, timeout.duration)
}
}

"A SerializedSuspendableExecutionContext" must {
"be suspendable and resumable" in {
val sec = SerializedSuspendableExecutionContext(1)(ExecutionContext.global)
val counter = new AtomicInteger(0)
def perform(f: Int Int) = sec execute new Runnable { def run = counter.set(f(counter.get)) }
perform(_ + 1)
perform(x { sec.suspend(); x * 2 })
awaitCond(counter.get == 2)
perform(_ + 4)
perform(_ * 2)
sec.size must be === 2
Thread.sleep(500)
sec.size must be === 2
counter.get must be === 2
sec.resume()
awaitCond(counter.get == 12)
perform(_ * 2)
awaitCond(counter.get == 24)
sec.isEmpty must be === true
}

"execute 'throughput' number of tasks per sweep" in {
val submissions = new AtomicInteger(0)
val counter = new AtomicInteger(0)
val underlying = new ExecutionContext {
override def execute(r: Runnable) { submissions.incrementAndGet(); ExecutionContext.global.execute(r) }
override def reportFailure(t: Throwable) { ExecutionContext.global.reportFailure(t) }
}
val throughput = 25
val sec = SerializedSuspendableExecutionContext(throughput)(underlying)
sec.suspend()
def perform(f: Int Int) = sec execute new Runnable { def run = counter.set(f(counter.get)) }

val total = 1000
1 to total foreach { _ perform(_ + 1) }
sec.size() must be === total
sec.resume()
awaitCond(counter.get == total)
submissions.get must be === (total / throughput)
sec.isEmpty must be === true
}

"execute tasks in serial" in {
val sec = SerializedSuspendableExecutionContext(1)(ExecutionContext.global)
val total = 10000
val counter = new AtomicInteger(0)
def perform(f: Int Int) = sec execute new Runnable { def run = counter.set(f(counter.get)) }

1 to total foreach { i perform(c if (c == (i - 1)) c + 1 else c) }
awaitCond(counter.get == total)
sec.isEmpty must be === true
}

"should relinquish thread when suspended" in {
val submissions = new AtomicInteger(0)
val counter = new AtomicInteger(0)
val underlying = new ExecutionContext {
override def execute(r: Runnable) { submissions.incrementAndGet(); ExecutionContext.global.execute(r) }
override def reportFailure(t: Throwable) { ExecutionContext.global.reportFailure(t) }
}
val throughput = 25
val sec = SerializedSuspendableExecutionContext(throughput)(underlying)
sec.suspend()
def perform(f: Int Int) = sec execute new Runnable { def run = counter.set(f(counter.get)) }
perform(_ + 1)
1 to 10 foreach { _ perform(identity) }
perform(x { sec.suspend(); x * 2 })
perform(_ + 8)
sec.size must be === 13
sec.resume()
awaitCond(counter.get == 2)
sec.resume()
awaitCond(counter.get == 10)
sec.isEmpty must be === true
submissions.get must be === 2
}
}
}
@@ -0,0 +1,81 @@
package akka.util

import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
import scala.annotation.{ tailrec, switch }

private[akka] object SerializedSuspendableExecutionContext {
final val Off = 0
final val On = 1
final val Suspended = 2

def apply(batchSize: Int)(implicit context: ExecutionContext): SerializedSuspendableExecutionContext =
new SerializedSuspendableExecutionContext(batchSize)(context match {
case s: SerializedSuspendableExecutionContext s.context
case other other
})
}

/**
* This `ExecutionContext` allows to wrap an underlying `ExecutionContext` and provide guaranteed serial execution
* of tasks submitted to it. On top of that it also allows for *suspending* and *resuming* processing of tasks.
*
* WARNING: This type must never leak into User code as anything but `ExecutionContext`
*
* @param throughput maximum number of tasks to be executed in serial before relinquishing the executing thread.
* @param context the underlying context which will be used to actually execute the submitted tasks
*/
private[akka] final class SerializedSuspendableExecutionContext(throughput: Int)(val context: ExecutionContext)
extends ConcurrentLinkedQueue[Runnable] with Runnable with ExecutionContext {
import SerializedSuspendableExecutionContext._
require(throughput > 0, s"SerializedSuspendableExecutionContext.throughput must be greater than 0 but was $throughput")

private final val state = new AtomicInteger(Off)
@tailrec private final def addState(newState: Int): Boolean = {
val c = state.get
state.compareAndSet(c, c | newState) || addState(newState)
}
@tailrec private final def remState(oldState: Int) {
val c = state.get
if (state.compareAndSet(c, c & ~oldState)) attach() else remState(oldState)
}

/**
* Resumes execution of tasks until `suspend` is called,
* if it isn't currently suspended, it is a no-op.
* This operation is idempotent.
*/
final def resume(): Unit = remState(Suspended)

/**
* Suspends execution of tasks until `resume` is called,
* this operation is idempotent.
*/
final def suspend(): Unit = addState(Suspended)

final def run(): Unit = {
@tailrec def run(done: Int): Unit =
if (done < throughput && state.get == On) {
poll() match {
case null ()
case some
try some.run() catch { case NonFatal(t) context reportFailure t }
run(done + 1)
}
}
try run(0) finally remState(On)
}

final def attach(): Unit = if (!isEmpty && state.compareAndSet(Off, On)) context execute this
override final def execute(task: Runnable): Unit = try add(task) finally attach()
override final def reportFailure(t: Throwable): Unit = context reportFailure t

override final def toString: String = (state.get: @switch) match {
case 0 "Off"
case 1 "On"
case 2 "Off & Suspended"
case 3 "On & Suspended"
}
}

0 comments on commit d6addd9

Please sign in to comment.