Skip to content

Commit

Permalink
Refactor the cursor folder to avoid GC overhead
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Oct 3, 2016
1 parent 57d67f3 commit 0624760
Show file tree
Hide file tree
Showing 10 changed files with 330 additions and 120 deletions.
4 changes: 3 additions & 1 deletion bson/src/main/scala/buffer.scala
Expand Up @@ -202,7 +202,9 @@ object ArrayReadableBuffer {
}

/** An array-backed writable buffer. */
class ArrayBSONBuffer protected[buffer] (protected val buffer: ArrayBuffer[Byte]) extends WritableBuffer {
class ArrayBSONBuffer protected[buffer] (
protected val buffer: ArrayBuffer[Byte]
) extends WritableBuffer {
def index = buffer.length // useless

def bytebuffer(size: Int) = {
Expand Down
Expand Up @@ -581,7 +581,7 @@ trait GenericCollection[P <: SerializationPack with Singleton] extends Collectio
* @param reader $readerParam
* @param cf $cursorFlattenerParam
*/
def aggregatingWith[T](explain: Boolean = false, allowDiskUse: Boolean = false, bypassDocumentValidation: Boolean = false, readConcern: Option[ReadConcern] = None, readPreference: ReadPreference = ReadPreference.primary, batchSize: Option[Int] = None)(f: AggregationFramework => (PipelineOperator, List[PipelineOperator]))(implicit ec: ExecutionContext, reader: pack.Reader[T], cf: CursorFlattener[Cursor]): Cursor[T] = {
def aggregateWith[T](explain: Boolean = false, allowDiskUse: Boolean = false, bypassDocumentValidation: Boolean = false, readConcern: Option[ReadConcern] = None, readPreference: ReadPreference = ReadPreference.primary, batchSize: Option[Int] = None)(f: AggregationFramework => (PipelineOperator, List[PipelineOperator]))(implicit ec: ExecutionContext, reader: pack.Reader[T], cf: CursorFlattener[Cursor]): Cursor[T] = {
val (firstOp, otherOps) = f(BatchCommands.AggregationFramework)
val aggCursor = BatchCommands.AggregationFramework.
Cursor(batchSize.getOrElse(101))
Expand Down
283 changes: 200 additions & 83 deletions driver/src/main/scala/api/cursor.scala
Expand Up @@ -20,6 +20,8 @@ import scala.collection.generic.CanBuildFrom
import scala.collection.mutable.Builder
import scala.concurrent.{ ExecutionContext, Future }

import akka.actor.ActorSystem

import play.api.libs.iteratee.{ Enumerator, Enumeratee, Error, Input, Iteratee }

import reactivemongo.core.actors.Exceptions
Expand Down Expand Up @@ -466,7 +468,7 @@ object Cursor {
}

object DefaultCursor {
import Cursor.{ ErrorHandler, State, Cont, Done, Fail, logger }
import Cursor.{ ErrorHandler, State, Cont, Fail, logger }
import reactivemongo.api.commands.ResultCursor
import CursorOps.Unrecoverable

Expand Down Expand Up @@ -654,15 +656,15 @@ object DefaultCursor {

private def killCursors(cursorID: Long, logCat: String): Unit = {
if (cursorID != 0) {
logger.debug(s"[$logCat] Clean up ${cursorID}, sending KillCursors")
logger.debug(s"[$logCat] Clean up $cursorID, sending KillCursors")

val killReq = RequestMaker(
KillCursors(Set(cursorID)),
readPreference = preference
)

connection.send(killReq)
} else logger.trace(s"[$logCat] Cursor exhausted (${cursorID})")
} else logger.trace(s"[$logCat] Cursor exhausted ($cursorID)")
}

def head(implicit ctx: ExecutionContext): Future[A] =
Expand All @@ -686,9 +688,15 @@ object DefaultCursor {

@inline private def syncSuccess[A, B](f: (A, B) => State[A])(implicit ec: ExecutionContext): (A, B) => Future[State[A]] = { (a: A, b: B) => Future(f(a, b)) }

def foldResponses[T](z: => T, maxDocs: Int = -1)(suc: (T, Response) => State[T], err: (T, Throwable) => State[T])(implicit ctx: ExecutionContext): Future[T] = new FoldResponses(z, maxDocs, syncSuccess(suc), err)(ctx)()
def foldResponses[T](z: => T, maxDocs: Int = -1)(suc: (T, Response) => State[T], err: (T, Throwable) => State[T])(implicit ctx: ExecutionContext): Future[T] = FoldResponses(z, makeRequest(maxDocs)(_: ExecutionContext),
nextResponse(maxDocs), killCursors _, syncSuccess(suc), err, maxDocs)(
connection.actorSystem, ctx
)

def foldResponsesM[T](z: => T, maxDocs: Int = -1)(suc: (T, Response) => Future[State[T]], err: (T, Throwable) => State[T])(implicit ctx: ExecutionContext): Future[T] = new FoldResponses(z, maxDocs, suc, err)(ctx)()
def foldResponsesM[T](z: => T, maxDocs: Int = -1)(suc: (T, Response) => Future[State[T]], err: (T, Throwable) => State[T])(implicit ctx: ExecutionContext): Future[T] = FoldResponses(z, makeRequest(maxDocs)(_: ExecutionContext),
nextResponse(maxDocs), killCursors _, suc, err, maxDocs)(
connection.actorSystem, ctx
)

def foldBulks[T](z: => T, maxDocs: Int = -1)(suc: (T, Iterator[A]) => State[T], err: (T, Throwable) => State[T])(implicit ctx: ExecutionContext): Future[T] = foldBulksM[T](z, maxDocs)(syncSuccess[T, Iterator[A]](suc), err)

Expand Down Expand Up @@ -830,95 +838,204 @@ object DefaultCursor {
tailResponse(r, maxDocs)(ec)
}
}
}
}

private class FoldResponses[T](
z: => T, maxDocs: Int,
suc: (T, Response) => Future[State[T]],
err: ErrorHandler[T]
)(implicit ctx: ExecutionContext) {
/**
* @define curParam the current value
* @define cParam the value index (first = 0)
* @define lastParam the last [[reactivemongo.core.protocol.Response]]
*/
private[api] final class FoldResponses[T](
nextResponse: (ExecutionContext, Response) => Future[Option[Response]],
killCursors: (Long, String) => Unit,
maxDocs: Int,
suc: (T, Response) => Future[Cursor.State[T]],
err: Cursor.ErrorHandler[T]
)(implicit actorSys: ActorSystem, ec: ExecutionContext) { self =>
import Cursor.{ Cont, Done, Fail, State, logger }
import CursorOps.Unrecoverable

private val nextResp: Response => Future[Option[Response]] =
nextResponse(maxDocs)(ctx, _: Response)
private val promise = scala.concurrent.Promise[T]()
lazy val result: Future[T] = promise.future

@inline def ok(r: Response, v: T) = {
// Releases cursor before ending
killCursors(r.reply.cursorID, "FoldResponses")
Future.successful(v)
}
private val handle: Any => Unit = {
case ProcResponses(makeReq, cur, c, id) =>
procResponses(makeReq(), cur, c, id)

@inline def kill(r: Response, f: Throwable) = {
killCursors(r.reply.cursorID, "FoldResponses")
Future.failed[T](f)
}
case HandleResponse(last, cur, c) =>
handleResponse(last, cur, c)

def procResp(resp: Response, cur: T, c: Int): Future[T] = {
logger.trace(s"Process response: $resp")

suc(cur, resp).transform(resp -> _, { error =>
killCursors(resp.reply.cursorID, "FoldResponses")
error
}).flatMap {
case (r, next) =>
val nc = c + resp.reply.numberReturned

next match {
case Done(d) => ok(r, d)
case Cont(v) if (
maxDocs > -1 && nc >= maxDocs
) => ok(r, v)

case Fail(u @ Unrecoverable(_)) =>
/* already marked recoverable */ kill(r, u)

case Fail(f) => kill(r, Unrecoverable(f))
case Cont(v) =>
nextResp(r).flatMap(_.fold(Future successful v) { x =>
procResponses(Future.successful(x), v, nc)
})
}
}.recoverWith {
case Unrecoverable(e) => Future.failed(e)
case e =>
val nc = c + 1 // resp.reply.numberReturned

err(cur, e) match {
case Done(d) => Future.successful(d)
case Cont(v) if (
maxDocs > -1 && nc >= maxDocs
) => Future.successful(v)

case Fail(f) => Future.failed[T](f)
case Cont(v) =>
procResp(resp, cur, nc) // retry
}
case ProcNext(last, cur, next, c) =>
procNext(last, cur, next, c)

case OnError(last, cur, error, c) =>
onError(last, cur, error, c)
}

@inline private def kill(cursorID: Long): Unit = try {
killCursors(cursorID, "FoldResponses")
} catch {
case cause: Throwable =>
logger.warn(s"fails to kill cursor: $cursorID", cause)
}

@inline private def ok(r: Response, v: T): Unit = {
kill(r.reply.cursorID) // Releases cursor before ending
promise.success(v)
}

@inline private def ko(r: Response, f: Throwable): Unit = {
kill(r.reply.cursorID) // Releases cursor before ending
promise.failure(f)
}

@inline private def handleResponse(last: Response, cur: T, c: Int): Unit = {
logger.trace(s"Process response: $last")

suc(cur, last).onComplete({
case Success(next) => self ! ProcNext(last, cur, next, c)
case Failure(error) => self ! OnError(last, cur, error, c)
})(ec)
}

@inline
private def onError(last: Response, cur: T, error: Throwable, c: Int): Unit =
error match {
case Unrecoverable(e) => ko(last, e) // already marked recoverable

case _ => {
val nc = c + 1 // resp.reply.numberReturned

err(cur, error) match {
case Done(d) => ok(last, d)

case Cont(v) if (
maxDocs > -1 && nc >= maxDocs
) => ok(last, v)

case Fail(f) => ko(last, f)

case Cont(v) => self ! HandleResponse(last, cur, nc)
}
}
}

def procResponses(done: Future[Response], cur: T, c: Int): Future[T] =
done.map[Try[Response]](Success(_)).
recover { case err => Failure(err) }.flatMap {
case Success(r) => procResp(r, cur, c)
case Failure(error) => {
logger.error("fails to send request", error)

err(cur, error) match {
case Done(v) => Future.successful(v)
case Fail(e) => Future.failed(e)
case Cont(v) => {
logger.warn(
"cannot continue after fatal request error", error
)

Future.successful(v)
}
}
}
}
@inline private def procNext(last: Response, cur: T, next: State[T], c: Int): Unit = {
val nc = c + last.reply.numberReturned

next match {
case Done(d) => ok(last, d)

case Cont(v) if (
maxDocs > -1 && nc >= maxDocs
) => ok(last, v)

def apply(): Future[T] =
Future(z).flatMap(v => procResponses(makeRequest(maxDocs), v, 0))
case Fail(f) => self ! OnError(last, cur, f, c)

case Cont(v) => nextResponse(ec, last).onSuccess({
case Some(r) => self ! ProcResponses(
() => Future.successful(r), v, nc, r.reply.cursorID
)

case _ => ok(last, v)
})(ec)
}
}

@inline private def procResponses(last: Future[Response], cur: T, c: Int, lastID: Long): Unit = last.onComplete({
case Success(r) => self ! HandleResponse(r, cur, c)

case Failure(error) => {
logger.error("fails to send request", error)

err(cur, error) match {
case Done(v) => {
if (lastID > 0) kill(lastID)
promise.success(v)
}

case Fail(e) => {
if (lastID > 0) kill(lastID)
promise.failure(e)
}

case Cont(v) => {
logger.warn("cannot continue after fatal request error", error)

promise.success(v)
}
}
}
})(ec)

/**
* Enqueues a `message` to be processed while fold the cursor results.
*/
def !(message: Any): Unit = actorSys.scheduler.scheduleOnce(
// TODO: on retry, add some delay according FailoverStrategy
scala.concurrent.duration.Duration.Zero
)(handle(message))(ec)

// Messages

/**
* @param requester the function the perform the next request
* @param cur $curParam
* @param c $cParam
* @param lastID the last ID for the cursor (or `-1` if unknown)
*/
private[api] case class ProcResponses(
requester: () => Future[Response],
cur: T,
c: Int,
lastID: Long
)

/**
* @param last $lastParam
* @param cur $curParam
* @param c $cParam
*/
private case class HandleResponse(last: Response, cur: T, c: Int)

/**
* @param last $lastParam
* @param cur $curParam
* @param next the next state
* @param c $cParam
*/
private case class ProcNext(last: Response, cur: T, next: State[T], c: Int)

/**
* @param last $lastParam
* @param cur $curParam
* @param error the error details
* @param c $cParam
*/
private case class OnError(last: Response, cur: T, error: Throwable, c: Int)
}

private[api] object FoldResponses {
def apply[T](
z: => T,
makeRequest: ExecutionContext => Future[Response],
nextResponse: (ExecutionContext, Response) => Future[Option[Response]],
killCursors: (Long, String) => Unit,
suc: (T, Response) => Future[Cursor.State[T]],
err: Cursor.ErrorHandler[T],
maxDocs: Int
)(implicit actorSys: ActorSystem, ec: ExecutionContext): Future[T] = {
Future(z)(ec).flatMap({ v =>
val f = new FoldResponses[T](
nextResponse, killCursors, maxDocs, suc, err
)(actorSys, ec)

f ! f.ProcResponses(() => makeRequest(ec), v, 0, -1L)

f.result
})(ec)
}
}

/** Allows to enrich a base cursor. */
Expand Down
5 changes: 4 additions & 1 deletion driver/src/main/scala/core/netty.scala
Expand Up @@ -133,7 +133,10 @@ object ChannelBufferWritableBuffer {
}
}

case class BufferSequence(private val head: ChannelBuffer, private val tail: ChannelBuffer*) {
case class BufferSequence(
private val head: ChannelBuffer,
private val tail: ChannelBuffer*
) {
def merged: ChannelBuffer = mergedBuffer.duplicate()

private lazy val mergedBuffer =
Expand Down
13 changes: 11 additions & 2 deletions driver/src/main/scala/core/protocol/operations.scala
Expand Up @@ -39,8 +39,17 @@ private[protocol] object BufferAccessors {
def apply(buffer: ChannelBuffer, l: Long) = buffer writeLong l
}

implicit object StringChannelInteroperable extends BufferInteroperable[String] {
def apply(buffer: ChannelBuffer, s: String) = buffer writeCString s
implicit object StringChannelInteroperable
extends BufferInteroperable[String] {

private def writeCString(buffer: ChannelBuffer, s: String): ChannelBuffer = {
val bytes = s.getBytes("utf-8")
buffer writeBytes bytes
buffer writeByte 0
buffer
}

def apply(buffer: ChannelBuffer, s: String) = writeCString(buffer, s)
}

/**
Expand Down

0 comments on commit 0624760

Please sign in to comment.