diff --git a/core/src/main/scala/org/apache/predictionio/controller/LAlgorithm.scala b/core/src/main/scala/org/apache/predictionio/controller/LAlgorithm.scala index 27d1d14e96..c30057964b 100644 --- a/core/src/main/scala/org/apache/predictionio/controller/LAlgorithm.scala +++ b/core/src/main/scala/org/apache/predictionio/controller/LAlgorithm.scala @@ -24,6 +24,9 @@ import org.apache.predictionio.workflow.PersistentModelManifest import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future, blocking} +import scala.language.postfixOps import scala.reflect._ /** Base class of a local algorithm. @@ -72,11 +75,11 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P] val glomQs: RDD[Array[(Long, Q)]] = qs.glom() val cartesian: RDD[(M, Array[(Long, Q)])] = mRDD.cartesian(glomQs) cartesian.flatMap { case (m, qArray) => - qArray.map { case (qx, q) => (qx, predict(m, q)) } + qArray.map { case (qx, q) => (qx, blocking { Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) }) } } } - def predictBase(localBaseModel: Any, q: Q): P = { + def predictBase(localBaseModel: Any, q: Q)(implicit ec: ExecutionContext): Future[P] = { predict(localBaseModel.asInstanceOf[M], q) } @@ -87,7 +90,7 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P] * @param q An input query. * @return A prediction. */ - def predict(m: M, q: Q): P + def predict(m: M, q: Q)(implicit ec: ExecutionContext): Future[P] /** :: DeveloperApi :: * Engine developers should not use this directly (read on to see how local diff --git a/core/src/main/scala/org/apache/predictionio/controller/P2LAlgorithm.scala b/core/src/main/scala/org/apache/predictionio/controller/P2LAlgorithm.scala index c617d2c50a..a22bf5e1c3 100644 --- a/core/src/main/scala/org/apache/predictionio/controller/P2LAlgorithm.scala +++ b/core/src/main/scala/org/apache/predictionio/controller/P2LAlgorithm.scala @@ -25,6 +25,9 @@ import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future, blocking} +import scala.language.postfixOps import scala.reflect._ /** Base class of a parallel-to-local algorithm. @@ -67,10 +70,10 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P] * @return Batch of predicted results */ def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] = { - qs.mapValues { q => predict(m, q) } + qs.mapValues { q => blocking { Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) } } } - def predictBase(bm: Any, q: Q): P = predict(bm.asInstanceOf[M], q) + def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] = predict(bm.asInstanceOf[M], q) /** Implement this method to produce a prediction from a query and trained * model. @@ -79,7 +82,7 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P] * @param query An input query. * @return A prediction. 
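 *
 * Under the asynchronous signature introduced below, implementations return a
 * `Future[P]` evaluated on the supplied `ExecutionContext`. A minimal sketch of
 * a conforming implementation (`score` is a hypothetical synchronous scoring
 * helper, not part of this API):
 * {{{
 * def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] =
 *   Future {
 *     score(model, query) // run the potentially slow scoring off the caller's thread
 *   }
 * }}}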
*/ - def predict(model: M, query: Q): P + def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] /** :: DeveloperApi :: * Engine developers should not use this directly (read on to see how diff --git a/core/src/main/scala/org/apache/predictionio/controller/PAlgorithm.scala b/core/src/main/scala/org/apache/predictionio/controller/PAlgorithm.scala index 55f8363fdb..5c6b591319 100644 --- a/core/src/main/scala/org/apache/predictionio/controller/PAlgorithm.scala +++ b/core/src/main/scala/org/apache/predictionio/controller/PAlgorithm.scala @@ -24,6 +24,8 @@ import org.apache.predictionio.workflow.PersistentModelManifest import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import scala.concurrent.{ExecutionContext, Future} + /** Base class of a parallel algorithm. * * A parallel algorithm can be run in parallel on a cluster and produces a @@ -72,7 +74,7 @@ abstract class PAlgorithm[PD, M, Q, P] def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] = throw new NotImplementedError("batchPredict not implemented") - def predictBase(baseModel: Any, query: Q): P = { + def predictBase(baseModel: Any, query: Q)(implicit ec: ExecutionContext): Future[P] = { predict(baseModel.asInstanceOf[M], query) } @@ -83,7 +85,7 @@ abstract class PAlgorithm[PD, M, Q, P] * @param query An input query. * @return A prediction. */ - def predict(model: M, query: Q): P + def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] /** :: DeveloperApi :: * Engine developers should not use this directly (read on to see how parallel diff --git a/core/src/main/scala/org/apache/predictionio/core/BaseAlgorithm.scala b/core/src/main/scala/org/apache/predictionio/core/BaseAlgorithm.scala index 8b9edc147b..f6703c1a64 100644 --- a/core/src/main/scala/org/apache/predictionio/core/BaseAlgorithm.scala +++ b/core/src/main/scala/org/apache/predictionio/core/BaseAlgorithm.scala @@ -26,6 +26,8 @@ import net.jodah.typetools.TypeResolver import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import scala.concurrent.{ExecutionContext, Future} + /** :: DeveloperApi :: * Base trait with default custom query serializer, exposed to engine developer * via [[org.apache.predictionio.controller.CustomQuerySerializer]] @@ -90,7 +92,7 @@ abstract class BaseAlgorithm[PD, M, Q, P] * @return Predicted result */ @DeveloperApi - def predictBase(bm: Any, q: Q): P + def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] /** :: DeveloperApi :: * Engine developers should not use this directly. Prepare a model for diff --git a/core/src/main/scala/org/apache/predictionio/workflow/BatchPredict.scala b/core/src/main/scala/org/apache/predictionio/workflow/BatchPredict.scala index 69525b11cf..2e80f292b0 100644 --- a/core/src/main/scala/org/apache/predictionio/workflow/BatchPredict.scala +++ b/core/src/main/scala/org/apache/predictionio/workflow/BatchPredict.scala @@ -32,7 +32,12 @@ import org.apache.predictionio.workflow.CleanupFunctions import org.apache.spark.rdd.RDD import org.json4s._ import org.json4s.native.JsonMethods._ +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.concurrent.blocking +import scala.concurrent.{Await, Future} import scala.language.existentials +import scala.concurrent.ExecutionContext.Implicits.global case class BatchPredictConfig( inputFilePath: String = "batchpredict-input.json", @@ -207,23 +212,26 @@ object BatchPredict extends Logging { // Deploy logic. 
First call Serving.supplement, then Algo.predict, // finally Serving.serve. val supplementedQuery = serving.supplementBase(query) - // TODO: Parallelize the following. - val predictions = algorithms.zip(models).map { case (a, m) => + val predictionsFuture = Future.sequence(algorithms.zip(models).map { case (a, m) => a.predictBase(m, supplementedQuery) - } + }) // Notice that it is by design to call Serving.serve with the // *original* query. - val prediction = serving.serveBase(query, predictions) - // Combine query with prediction, so the batch results are - // self-descriptive. - val predictionJValue = JsonExtractor.toJValue( - jsonExtractorOption, - Map("query" -> query, - "prediction" -> prediction), - algorithms.head.querySerializer, - algorithms.head.gsonTypeAdapterFactories) - // Return JSON string - compact(render(predictionJValue)) + val predFutureRdds = predictionsFuture.map { + predictions => + val prediction = serving.serveBase(query, predictions) + // Combine query with prediction, so the batch results are + // self-descriptive. + val predictionJValue = JsonExtractor.toJValue( + jsonExtractorOption, + Map("query" -> query, + "prediction" -> prediction), + algorithms.head.querySerializer, + algorithms.head.gsonTypeAdapterFactories) + // Return JSON string + compact(render(predictionJValue)) + } + blocking { Await.result(predFutureRdds, 60 minutes) } } predictionsRDD.saveAsTextFile(config.outputFilePath) diff --git a/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala b/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala index 5642114f8b..12ea70a7a7 100644 --- a/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala +++ b/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala @@ -51,6 +51,7 @@ import akka.http.scaladsl.server.Directives._ import akka.stream.ActorMaterializer import org.apache.predictionio.akkahttpjson4s.Json4sSupport._ import org.apache.predictionio.configuration.SSLConfiguration +import org.json4s import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.{Await, Future} @@ -487,7 +488,6 @@ class PredictionServer[Q, P]( try { val servingStartTime = DateTime.now val jsonExtractorOption = args.jsonExtractor - val queryTime = DateTime.now // Extract Query from Json val query = JsonExtractor.extract( jsonExtractorOption, @@ -504,107 +504,64 @@ class PredictionServer[Q, P]( // Deploy logic. First call Serving.supplement, then Algo.predict, // finally Serving.serve. val supplementedQuery = serving.supplementBase(query) - // TODO: Parallelize the following. - val predictions = algorithms.zip(models).map { case (a, m) => + + val predictionsFuture = Future.sequence(algorithms.zip(models).map { case (a, m) => a.predictBase(m, supplementedQuery) - } + }) // Notice that it is by design to call Serving.serve with the // *original* query. 
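// The per-algorithm futures are gathered with Future.sequence, so serveBase
// runs only after every algorithm has produced its result, and the combined
// future fails the whole request if any single predictBase fails.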
- val prediction = serving.serveBase(query, predictions) - val predictionJValue = JsonExtractor.toJValue( - jsonExtractorOption, - prediction, - algorithms.head.querySerializer, - algorithms.head.gsonTypeAdapterFactories) - /** Handle feedback to Event Server - * Send the following back to the Event Server - * - appId - * - engineInstanceId - * - query - * - prediction - * - prId - */ - val result = if (feedbackEnabled) { - implicit val formats = - algorithms.headOption map { alg => - alg.querySerializer - } getOrElse { - Utils.json4sDefaultFormats - } - // val genPrId = Random.alphanumeric.take(64).mkString - def genPrId: String = Random.alphanumeric.take(64).mkString - val newPrId = prediction match { - case id: WithPrId => - val org = id.prId - if (org.isEmpty) genPrId else org - case _ => genPrId - } - - // also save Query's prId as prId of this pio_pr predict events - val queryPrId = - query match { - case id: WithPrId => - Map("prId" -> id.prId) - case _ => - Map.empty - } - val data = Map( - // "appId" -> dataSourceParams.asInstanceOf[ParamsWithAppId].appId, - "event" -> "predict", - "eventTime" -> queryTime.toString(), - "entityType" -> "pio_pr", // prediction result - "entityId" -> newPrId, - "properties" -> Map( - "engineInstanceId" -> engineInstance.id, - "query" -> query, - "prediction" -> prediction)) ++ queryPrId - // At this point args.accessKey should be Some(String). - val accessKey = args.accessKey.getOrElse("") - val f: Future[Int] = Future { - scalaj.http.Http( - s"http://${args.eventServerIp}:${args.eventServerPort}/" + - s"events.json?accessKey=$accessKey").postData( - write(data)).header( - "content-type", "application/json").asString.code - } - f onComplete { - case Success(code) => { - if (code != 201) { - log.error(s"Feedback event failed. Status code: $code." - + s"Data: ${write(data)}.") + val pluginResultFuture = predictionsFuture.map { + predictions => + val prediction = serving.serveBase(query, predictions) + val predictionJValue = JsonExtractor.toJValue( + jsonExtractorOption, + prediction, + algorithms.head.querySerializer, + algorithms.head.gsonTypeAdapterFactories) + /** Handle feedback to Event Server + * Send the following back to the Event Server + * - appId + * - engineInstanceId + * - query + * - prediction + * - prId + */ + val result: json4s.JValue = if (feedbackEnabled) { + sendFeedback(prediction, query, predictionJValue) + } else predictionJValue + + val pluginResult = + pluginContext.outputBlockers.values.foldLeft(result) { case (r, p) => + p.process(engineInstance, queryJValue, r, pluginContext) } - } - case Failure(t) => { - log.error(s"Feedback event failed: ${t.getMessage}") } - } - // overwrite prId in predictedResult - // - if it is WithPrId, - // then overwrite with new prId - // - if it is not WithPrId, no prId injection - if (prediction.isInstanceOf[WithPrId]) { - predictionJValue merge parse(s"""{"prId" : "$newPrId"}""") - } else { - predictionJValue - } - } else predictionJValue - - val pluginResult = - pluginContext.outputBlockers.values.foldLeft(result) { case (r, p) => - p.process(engineInstance, queryJValue, r, pluginContext) - } - pluginsActorRef ! (engineInstance, queryJValue, result) - - // Bookkeeping - val servingEndTime = DateTime.now - lastServingSec = - (servingEndTime.getMillis - servingStartTime.getMillis) / 1000.0 - avgServingSec = - ((avgServingSec * requestCount) + lastServingSec) / - (requestCount + 1) - requestCount += 1 - - complete(compact(render(pluginResult))) + pluginsActorRef ! 
(engineInstance, queryJValue, result) + + // Bookkeeping + val servingEndTime = DateTime.now + lastServingSec = + (servingEndTime.getMillis - servingStartTime.getMillis) / 1000.0 + avgServingSec = + ((avgServingSec * requestCount) + lastServingSec) / + (requestCount + 1) + requestCount += 1 + + pluginResult + } + onComplete(pluginResultFuture) { + case Success(pluginResult) => complete(compact(render(pluginResult))) + case Failure(t) => + val msg = s"Query:\n$queryString\n\nStack Trace:\n" + + s"${ExceptionUtils.getStackTrace(t)}\n\n" + log.error(msg) + args.logUrl map { url => + remoteLog( + url, + args.logPrefix.getOrElse(""), + msg) + } + complete(StatusCodes.InternalServerError, msg) + } } catch { case e: MappingException => val msg = s"Query:\n$queryString\n\nStack Trace:\n" + @@ -703,4 +660,69 @@ class PredictionServer[Q, P]( myRoute } + + def sendFeedback(prediction: P, query: Q, predictionJValue: JsonAST.JValue): json4s.JValue = { + val queryTime = DateTime.now + implicit val formats = + algorithms.headOption map { alg => + alg.querySerializer + } getOrElse { + Utils.json4sDefaultFormats + } + // val genPrId = Random.alphanumeric.take(64).mkString + def genPrId: String = Random.alphanumeric.take(64).mkString + val newPrId = prediction match { + case id: WithPrId => + val org = id.prId + if (org.isEmpty) genPrId else org + case _ => genPrId + } + + // also save Query's prId as prId of this pio_pr predict events + val queryPrId = + query match { + case id: WithPrId => + Map("prId" -> id.prId) + case _ => + Map.empty + } + val data = Map( + // "appId" -> dataSourceParams.asInstanceOf[ParamsWithAppId].appId, + "event" -> "predict", + "eventTime" -> queryTime.toString(), + "entityType" -> "pio_pr", // prediction result + "entityId" -> newPrId, + "properties" -> Map( + "engineInstanceId" -> engineInstance.id, + "query" -> query, + "prediction" -> prediction)) ++ queryPrId + // At this point args.accessKey should be Some(String). + val accessKey = args.accessKey.getOrElse("") + val f: Future[Int] = Future { + scalaj.http.Http( + s"http://${args.eventServerIp}:${args.eventServerPort}/" + + s"events.json?accessKey=$accessKey").postData( + write(data)).header( + "content-type", "application/json").asString.code + } + f onComplete { + case Success(code) => { + if (code != 201) { + log.error(s"Feedback event failed. Status code: $code." 
+ + s"Data: ${write(data)}.") + } + } + case Failure(t) => { + log.error(s"Feedback event failed: ${t.getMessage}") } + } + // overwrite prId in predictedResult + // - if it is WithPrId, + // then overwrite with new prId + // - if it is not WithPrId, no prId injection + if (prediction.isInstanceOf[WithPrId]) { + predictionJValue merge parse(s"""{"prId" : "$newPrId"}""") + } else { + predictionJValue + } + } } diff --git a/core/src/test/scala/org/apache/predictionio/controller/SampleEngine.scala b/core/src/test/scala/org/apache/predictionio/controller/SampleEngine.scala index c53e98e827..8b17fe959f 100644 --- a/core/src/test/scala/org/apache/predictionio/controller/SampleEngine.scala +++ b/core/src/test/scala/org/apache/predictionio/controller/SampleEngine.scala @@ -19,13 +19,14 @@ package org.apache.predictionio.controller import org.apache.predictionio.controller.{Params => PIOParams} import org.apache.predictionio.core._ - import grizzled.slf4j.Logger import org.apache.predictionio.workflow.WorkflowParams import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import scala.concurrent.{ExecutionContext, Future} + object Engine0 { @transient lazy val logger = Logger[this.type] @@ -204,8 +205,8 @@ object Engine0 { qs.mapValues(q => Prediction(id, q, Some(m))) } - def predict(m: PAlgo0.Model, q: Query): Prediction = { - Prediction(id, q, Some(m)) + def predict(m: PAlgo0.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(id, q, Some(m))) } } @@ -224,8 +225,8 @@ object Engine0 { qs.mapValues(q => Prediction(id, q, Some(m))) } - def predict(m: PAlgo1.Model, q: Query): Prediction = { - Prediction(id, q, Some(m)) + def predict(m: PAlgo1.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(id, q, Some(m))) } } @@ -247,8 +248,8 @@ object Engine0 { qs.mapValues(q => Prediction(id, q, Some(m))) } - def predict(m: PAlgo2.Model, q: Query): Prediction = { - Prediction(id, q, Some(m)) + def predict(m: PAlgo2.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(id, q, Some(m))) } } @@ -274,8 +275,8 @@ object Engine0 { qs.mapValues(q => Prediction(id, q, Some(m))) } - def predict(m: PAlgo3.Model, q: Query): Prediction = { - Prediction(id, q, Some(m)) + def predict(m: PAlgo3.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(id, q, Some(m))) } } @@ -287,8 +288,8 @@ object Engine0 { extends LAlgorithm[ProcessedData, LAlgo0.Model, Query, Prediction] { def train(pd: ProcessedData): LAlgo0.Model = LAlgo0.Model(id, pd) - def predict(m: LAlgo0.Model, q: Query): Prediction = { - Prediction(id, q, Some(m)) + def predict(m: LAlgo0.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(id, q, Some(m))) } } @@ -300,8 +301,8 @@ object Engine0 { extends LAlgorithm[ProcessedData, LAlgo1.Model, Query, Prediction] { def train(pd: ProcessedData): LAlgo1.Model = LAlgo1.Model(id, pd) - def predict(m: LAlgo1.Model, q: Query): Prediction = { - Prediction(id, q, Some(m)) + def predict(m: LAlgo1.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(id, q, Some(m))) } } @@ -318,8 +319,8 @@ object Engine0 { extends LAlgorithm[ProcessedData, LAlgo2.Model, Query, Prediction] { def train(pd: ProcessedData): LAlgo2.Model = LAlgo2.Model(params.id, pd) - def predict(m: LAlgo2.Model, q: 
Query): Prediction = { - Prediction(params.id, q, Some(m)) + def predict(m: LAlgo2.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(params.id, q, Some(m))) } } @@ -333,8 +334,8 @@ object Engine0 { extends LAlgorithm[ProcessedData, LAlgo3.Model, Query, Prediction] { def train(pd: ProcessedData): LAlgo3.Model = LAlgo3.Model(params.id, pd) - def predict(m: LAlgo3.Model, q: Query): Prediction = { - Prediction(params.id, q, Some(m)) + def predict(m: LAlgo3.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(params.id, q, Some(m))) } } @@ -348,8 +349,8 @@ object Engine0 { def train(sc: SparkContext, pd: ProcessedData) : NAlgo0.Model = NAlgo0.Model(id, pd) - def predict(m: NAlgo0.Model, q: Query): Prediction = { - Prediction(id, q, Some(m)) + def predict(m: NAlgo0.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(id, q, Some(m))) } } @@ -362,8 +363,8 @@ object Engine0 { def train(sc: SparkContext, pd: ProcessedData) : NAlgo1.Model = NAlgo1.Model(id, pd) - def predict(m: NAlgo1.Model, q: Query): Prediction = { - Prediction(id, q, Some(m)) + def predict(m: NAlgo1.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(id, q, Some(m))) } } @@ -381,8 +382,8 @@ object Engine0 { def train(sc: SparkContext, pd: ProcessedData) : NAlgo2.Model = NAlgo2.Model(params.id, pd) - def predict(m: NAlgo2.Model, q: Query): Prediction = { - Prediction(params.id, q, Some(m)) + def predict(m: NAlgo2.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(params.id, q, Some(m))) } } @@ -397,8 +398,8 @@ object Engine0 { def train(sc: SparkContext, pd: ProcessedData) : NAlgo3.Model = NAlgo3.Model(params.id, pd) - def predict(m: NAlgo3.Model, q: Query): Prediction = { - Prediction(params.id, q, Some(m)) + def predict(m: NAlgo3.Model, q: Query)(implicit ec: ExecutionContext): Future[Prediction] = { + Future.successful(Prediction(params.id, q, Some(m))) } } diff --git a/data/src/main/scala/org/apache/predictionio/data/store/LEventStore.scala b/data/src/main/scala/org/apache/predictionio/data/store/LEventStore.scala index bcb6acd591..678d81385e 100644 --- a/data/src/main/scala/org/apache/predictionio/data/store/LEventStore.scala +++ b/data/src/main/scala/org/apache/predictionio/data/store/LEventStore.scala @@ -118,59 +118,6 @@ object LEventStore { timeout) } - /** Reads events of the specified entity. May use this in Algorithm's predict() - * or Serving logic to have fast event store access. - * - * @param appName return events of this app - * @param entityType return events of this entityType - * @param entityId return events of this entityId - * @param channelName return events of this channel (default channel if it's None) - * @param eventNames return events with any of these event names. - * @param targetEntityType return events of this targetEntityType: - * - None means no restriction on targetEntityType - * - Some(None) means no targetEntityType for this event - * - Some(Some(x)) means targetEntityType should match x. - * @param targetEntityId return events of this targetEntityId - * - None means no restriction on targetEntityId - * - Some(None) means no targetEntityId for this event - * - Some(Some(x)) means targetEntityId should match x. 
- * @param startTime return events with eventTime >= startTime - * @param untilTime return events with eventTime < untilTime - * @param limit Limit number of events. Get all events if None or Some(-1) - * @param latest Return latest event first (default true) - * @return Future[Iterator[Event]] - */ - def findByEntityAsync( - appName: String, - entityType: String, - entityId: String, - channelName: Option[String] = None, - eventNames: Option[Seq[String]] = None, - targetEntityType: Option[Option[String]] = None, - targetEntityId: Option[Option[String]] = None, - startTime: Option[DateTime] = None, - untilTime: Option[DateTime] = None, - limit: Option[Int] = None, - latest: Boolean = true)(implicit ec: ExecutionContext): Future[Iterator[Event]] = { - - val (appId, channelId) = Common.appNameToId(appName, channelName) - - eventsDb.futureFind( - appId = appId, - channelId = channelId, - startTime = startTime, - untilTime = untilTime, - entityType = Some(entityType), - entityId = Some(entityId), - eventNames = eventNames, - targetEntityType = targetEntityType, - targetEntityId = targetEntityId, - startTime = startTime, - untilTime = untilTime, - limit = limit, - reversed = Some(latest)) - } - /** Reads events of the specified entity. May use this in Algorithm's predict() * or Serving logic to have fast event store access. * @@ -276,59 +223,6 @@ object LEventStore { limit = limit), timeout) } - /** Reads events generically. If entityType or entityId is not specified, it - * results in table scan. - * - * @param appName return events of this app - * @param entityType return events of this entityType - * - None means no restriction on entityType - * - Some(x) means entityType should match x. - * @param entityId return events of this entityId - * - None means no restriction on entityId - * - Some(x) means entityId should match x. - * @param channelName return events of this channel (default channel if it's None) - * @param eventNames return events with any of these event names. - * @param targetEntityType return events of this targetEntityType: - * - None means no restriction on targetEntityType - * - Some(None) means no targetEntityType for this event - * - Some(Some(x)) means targetEntityType should match x. - * @param targetEntityId return events of this targetEntityId - * - None means no restriction on targetEntityId - * - Some(None) means no targetEntityId for this event - * - Some(Some(x)) means targetEntityId should match x. - * @param startTime return events with eventTime >= startTime - * @param untilTime return events with eventTime < untilTime - * @param limit Limit number of events. 
Get all events if None or Some(-1) - * @return Future[Iterator[Event]] - */ - def findAsync( - appName: String, - entityType: Option[String] = None, - entityId: Option[String] = None, - channelName: Option[String] = None, - eventNames: Option[Seq[String]] = None, - targetEntityType: Option[Option[String]] = None, - targetEntityId: Option[Option[String]] = None, - startTime: Option[DateTime] = None, - untilTime: Option[DateTime] = None, - limit: Option[Int] = None)(implicit ec: ExecutionContext): Future[Iterator[Event]] = { - - val (appId, channelId) = Common.appNameToId(appName, channelName) - - eventsDb.futureFind( - appId = appId, - channelId = channelId, - startTime = startTime, - untilTime = untilTime, - entityType = entityType, - entityId = entityId, - channelName = channelName, - eventNames = eventNames, - targetEntityType = targetEntityType, - targetEntityId = targetEntityId, - limit = limit) - } - /** Reads events generically. If entityType or entityId is not specified, it * results in table scan. * diff --git a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESAccessKeys.scala b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESAccessKeys.scala index 15f223f81a..488d4b0505 100644 --- a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESAccessKeys.scala +++ b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESAccessKeys.scala @@ -20,7 +20,6 @@ package org.apache.predictionio.data.storage.elasticsearch import java.io.IOException import scala.collection.JavaConverters.mapAsJavaMapConverter - import org.apache.http.entity.ContentType import org.apache.http.nio.entity.NStringEntity import org.apache.http.util.EntityUtils @@ -32,8 +31,10 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization.write - import grizzled.slf4j.Logging +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.concurrent.{Await, ExecutionContext, Future} /** Elasticsearch implementation of AccessKeys. 
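 *
 * The synchronous accessors below keep their `Seq[AccessKey]` results but now
 * delegate to the Future-based `ESUtils.getAll` and block for completion. A
 * minimal sketch of the pattern, assuming a hypothetical
 * `futureSeq: Future[Seq[AccessKey]]` built from the query:
 * {{{
 * import scala.concurrent.ExecutionContext.Implicits.global
 * Await.result(futureSeq.recover { case e: IOException => Nil }, 1 minute)
 * }}}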
*/ class ESAccessKeys(client: RestClient, config: StorageClientConfig, index: String) @@ -90,30 +91,32 @@ class ESAccessKeys(client: RestClient, config: StorageClientConfig, index: Strin } def getAll(): Seq[AccessKey] = { - try { - val json = - ("query" -> - ("match_all" -> List.empty)) - ESUtils.getAll[AccessKey](client, internalIndex, estype, compact(render(json))) - } catch { - case e: IOException => - error("Failed to access to /$internalIndex/$estype/_search", e) - Nil - } + val json = + ("query" -> + ("match_all" -> List.empty)) + import scala.concurrent.ExecutionContext.Implicits.global + Await.result(ESUtils + .getAll[AccessKey](client, internalIndex, estype, compact(render(json))) + .recover { + case e: IOException => + error(s"Failed to access to /$internalIndex/$estype/_search", e) + Nil + }, 1 minute) } def getByAppid(appid: Int): Seq[AccessKey] = { - try { - val json = - ("query" -> - ("term" -> - ("appid" -> appid))) - ESUtils.getAll[AccessKey](client, internalIndex, estype, compact(render(json))) - } catch { + val json = + ("query" -> + ("term" -> + ("appid" -> appid))) + import scala.concurrent.ExecutionContext.Implicits.global + Await.result(ESUtils + .getAll[AccessKey](client, internalIndex, estype, compact(render(json))) + .recover { case e: IOException => error("Failed to access to /$internalIndex/$estype/_search", e) Nil - } + }, 1 minute) } def update(accessKey: AccessKey): Unit = { diff --git a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESApps.scala b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESApps.scala index cb17af8ebc..b34a0c5793 100644 --- a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESApps.scala +++ b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESApps.scala @@ -20,7 +20,6 @@ package org.apache.predictionio.data.storage.elasticsearch import java.io.IOException import scala.collection.JavaConverters.mapAsJavaMapConverter - import org.apache.http.entity.ContentType import org.apache.http.nio.entity.NStringEntity import org.apache.http.util.EntityUtils @@ -32,8 +31,10 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization.write - import grizzled.slf4j.Logging +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.concurrent.Await /** Elasticsearch implementation of Items. 
*/ class ESApps(client: RestClient, config: StorageClientConfig, index: String) @@ -127,16 +128,18 @@ class ESApps(client: RestClient, config: StorageClientConfig, index: String) } def getAll(): Seq[App] = { - try { - val json = - ("query" -> - ("match_all" -> Nil)) - ESUtils.getAll[App](client, internalIndex, estype, compact(render(json))) - } catch { - case e: IOException => - error("Failed to access to /$internalIndex/$estype/_search", e) - Nil - } + val json = + ("query" -> + ("match_all" -> Nil)) + import scala.concurrent.ExecutionContext.Implicits.global + + Await.result(ESUtils + .getAll[App](client, internalIndex, estype, compact(render(json))) + .recover { + case e: IOException => + error(s"Failed to access to /$internalIndex/$estype/_search", e) + Nil + }, 1 minute) } def update(app: App): Unit = { diff --git a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESChannels.scala b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESChannels.scala index 63b108f107..7d36dad548 100644 --- a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESChannels.scala +++ b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESChannels.scala @@ -20,7 +20,6 @@ package org.apache.predictionio.data.storage.elasticsearch import java.io.IOException import scala.collection.JavaConverters.mapAsJavaMapConverter - import org.apache.http.entity.ContentType import org.apache.http.nio.entity.NStringEntity import org.apache.http.util.EntityUtils @@ -32,8 +31,10 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization.write - import grizzled.slf4j.Logging +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.concurrent.Await class ESChannels(client: RestClient, config: StorageClientConfig, index: String) extends Channels with Logging { @@ -97,17 +98,18 @@ } def getByAppid(appid: Int): Seq[Channel] = { - try { - val json = - ("query" -> - ("term" -> - ("appid" -> appid))) - ESUtils.getAll[Channel](client, internalIndex, estype, compact(render(json))) - } catch { - case e: IOException => - error(s"Failed to access to /$internalIndex/$estype/_search", e) - Nil - } + val json = + ("query" -> + ("term" -> + ("appid" -> appid))) + import scala.concurrent.ExecutionContext.Implicits.global + Await.result(ESUtils + .getAll[Channel](client, internalIndex, estype, compact(render(json))) + .recover { + case e: IOException => + error(s"Failed to access to /$internalIndex/$estype/_search", e) + Nil + }, 1 minute) } def update(channel: Channel): Boolean = { diff --git a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESEngineInstances.scala b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESEngineInstances.scala index 02f7b98248..8db59f9107 100644 --- a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESEngineInstances.scala +++ b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESEngineInstances.scala @@ -20,7 +20,6 @@ package org.apache.predictionio.data.storage.elasticsearch import java.io.IOException import scala.collection.JavaConverters.mapAsJavaMapConverter - import org.apache.http.entity.ContentType import 
org.apache.http.nio.entity.NStringEntity import org.apache.http.util.EntityUtils @@ -33,8 +32,10 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization.write - import grizzled.slf4j.Logging +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.concurrent.Await class ESEngineInstances(client: RestClient, config: StorageClientConfig, index: String) extends EngineInstances with Logging { @@ -133,44 +134,47 @@ } def getAll(): Seq[EngineInstance] = { - try { - val json = - ("query" -> - ("match_all" -> List.empty)) - ESUtils.getAll[EngineInstance](client, index, estype, compact(render(json))) - } catch { - case e: IOException => - error("Failed to access to /$index/$estype/_search", e) - Nil - } + val json = + ("query" -> + ("match_all" -> List.empty)) + import scala.concurrent.ExecutionContext.Implicits.global + + Await.result(ESUtils + .getAll[EngineInstance](client, index, estype, compact(render(json))) + .recover { + case e: IOException => + error(s"Failed to access to /$index/$estype/_search", e) + Nil + }, 1 minute) } def getCompleted( engineId: String, engineVersion: String, engineVariant: String): Seq[EngineInstance] = { - try { - val json = - ("query" -> - ("bool" -> - ("must" -> List( - ("term" -> - ("status" -> "COMPLETED")), - ("term" -> - ("engineId" -> engineId)), - ("term" -> - ("engineVersion" -> engineVersion)), - ("term" -> - ("engineVariant" -> engineVariant)))))) ~ - ("sort" -> List( - ("startTime" -> - ("order" -> "desc")))) - ESUtils.getAll[EngineInstance](client, index, estype, compact(render(json))) - } catch { - case e: IOException => - error(s"Failed to access to /$index/$estype/_search", e) - Nil - } + val json = + ("query" -> + ("bool" -> + ("must" -> List( + ("term" -> + ("status" -> "COMPLETED")), + ("term" -> + ("engineId" -> engineId)), + ("term" -> + ("engineVersion" -> engineVersion)), + ("term" -> + ("engineVariant" -> engineVariant)))))) ~ + ("sort" -> List( + ("startTime" -> + ("order" -> "desc")))) + import scala.concurrent.ExecutionContext.Implicits.global + Await.result(ESUtils + .getAll[EngineInstance](client, index, estype, compact(render(json))) + .recover { + case e: IOException => + error(s"Failed to access to /$index/$estype/_search", e) + Nil + }, 1 minute) } def getLatestCompleted( diff --git a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESEvaluationInstances.scala b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESEvaluationInstances.scala index 03b851d496..7baa0c60f5 100644 --- a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESEvaluationInstances.scala +++ b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESEvaluationInstances.scala @@ -20,7 +20,6 @@ package org.apache.predictionio.data.storage.elasticsearch import java.io.IOException import scala.collection.JavaConverters._ - import org.apache.http.entity.ContentType import org.apache.http.nio.entity.NStringEntity import org.apache.http.util.EntityUtils @@ -34,8 +33,10 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization.write - import grizzled.slf4j.Logging +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.concurrent.Await class 
ESEvaluationInstances(client: RestClient, config: StorageClientConfig, index: String) extends EvaluationInstances with Logging { @@ -107,33 +108,35 @@ class ESEvaluationInstances(client: RestClient, config: StorageClientConfig, ind } def getAll(): Seq[EvaluationInstance] = { - try { - val json = - ("query" -> - ("match_all" -> List.empty)) - ESUtils.getAll[EvaluationInstance](client, internalIndex, estype, compact(render(json))) - } catch { - case e: IOException => - error("Failed to access to /$internalIndex/$estype/_search", e) - Nil - } + val json = + ("query" -> + ("match_all" -> List.empty)) + import scala.concurrent.ExecutionContext.Implicits.global + Await.result(ESUtils + .getAll[EvaluationInstance](client, internalIndex, estype, compact(render(json))) + .recover { + case e: IOException => + error("Failed to access to /$internalIndex/$estype/_search", e) + Nil + }, 1 minute) } def getCompleted(): Seq[EvaluationInstance] = { - try { - val json = - ("query" -> - ("term" -> - ("status" -> "EVALCOMPLETED"))) ~ - ("sort" -> - ("startTime" -> - ("order" -> "desc"))) - ESUtils.getAll[EvaluationInstance](client, internalIndex, estype, compact(render(json))) - } catch { - case e: IOException => - error("Failed to access to /$internalIndex/$estype/_search", e) - Nil - } + val json = + ("query" -> + ("term" -> + ("status" -> "EVALCOMPLETED"))) ~ + ("sort" -> + ("startTime" -> + ("order" -> "desc"))) + import scala.concurrent.ExecutionContext.Implicits.global + Await.result(ESUtils + .getAll[EvaluationInstance](client, internalIndex, estype, compact(render(json))) + .recover { + case e: IOException => + error("Failed to access to /$internalIndex/$estype/_search", e) + Nil + }, 1 minute) } def update(i: EvaluationInstance): Unit = { diff --git a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESLEvents.scala b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESLEvents.scala index f275ec9210..eedbd417a2 100644 --- a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESLEvents.scala +++ b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESLEvents.scala @@ -37,6 +37,7 @@ import org.json4s.native.Serialization.write import org.json4s.ext.JodaTimeSerializers import grizzled.slf4j.Logging import org.apache.http.message.BasicHeader +import org.apache.predictionio.data.storage.elasticsearch.ScalaRestClient.ExtendedScalaRestClient class ESLEvents(val client: RestClient, config: StorageClientConfig, val baseIndex: String) extends LEvents with Logging { @@ -107,7 +108,6 @@ class ESLEvents(val client: RestClient, config: StorageClientConfig, val baseInd event: Event, appId: Int, channelId: Option[Int])(implicit ec: ExecutionContext): Future[String] = { - Future { val estype = getEsType(appId, channelId) val index = baseIndex + "_" + estype try { @@ -127,33 +127,39 @@ class ESLEvents(val client: RestClient, config: StorageClientConfig, val baseInd ("creationTime" -> ESUtils.formatUTCDateTime(event.creationTime)) ~ ("properties" -> write(event.properties.toJObject)) val entity = new NStringEntity(compact(render(json)), ContentType.APPLICATION_JSON) - val response = client.performRequest( + val futureResponse = client.performRequestFuture( "POST", s"/$index/$estype/$id", - Map("refresh" -> ESUtils.getEventDataRefresh(config)).asJava, + Map("refresh" -> ESUtils.getEventDataRefresh(config)), entity) - val jsonResponse = 
parse(EntityUtils.toString(response.getEntity)) - val result = (jsonResponse \ "result").extract[String] - result match { - case "created" => id - case "updated" => id - case _ => - error(s"[$result] Failed to update $index/$estype/$id") + + futureResponse.map { + response => + val jsonResponse = parse(EntityUtils.toString(response.getEntity)) + val result = (jsonResponse \ "result").extract[String] + result match { + case "created" => id + case "updated" => id + case _ => + error(s"[$result] Failed to update $index/$estype/$id") + "" + } + }.recover { + case t => + error(s"Failed to update $index/$estype/", t) "" } } catch { case e: IOException => error(s"Failed to update $index/$estype/", e) - "" + Future.successful("") } - } } override def futureInsertBatch( events: Seq[Event], appId: Int, channelId: Option[Int])(implicit ec: ExecutionContext): Future[Seq[String]] = { - Future { val estype = getEsType(appId, channelId) val index = baseIndex + "_" + estype try { @@ -187,34 +193,40 @@ class ESLEvents(val client: RestClient, config: StorageClientConfig, val baseInd }.mkString("", "\n", "\n") val entity = new StringEntity(json) - val response = client.performRequest( + val responseFuture = client.performRequestFuture( "POST", "/_bulk", - Map("refresh" -> ESUtils.getEventDataRefresh(config)).asJava, + Map("refresh" -> ESUtils.getEventDataRefresh(config)), entity, new BasicHeader("Content-Type", "application/x-ndjson")) - val responseJson = parse(EntityUtils.toString(response.getEntity)) - val items = (responseJson \ "items").asInstanceOf[JArray] + responseFuture.map { + response => + val responseJson = parse(EntityUtils.toString(response.getEntity)) + val items = (responseJson \ "items").asInstanceOf[JArray] - items.arr.map { case value: JObject => - val result = (value \ "index" \ "result").extract[String] - val id = (value \ "index" \ "_id").extract[String] + items.arr.map { case value: JObject => + val result = (value \ "index" \ "result").extract[String] + val id = (value \ "index" \ "_id").extract[String] - result match { - case "created" => id - case "updated" => id - case _ => - error(s"[$result] Failed to update $index/$estype/$id") - "" - } + result match { + case "created" => id + case "updated" => id + case _ => + error(s"[$result] Failed to update $index/$estype/$id") + "" + } + } + }.recover { + case t => + error(s"Failed to update $index/$estype/", t) + Nil } } catch { case e: IOException => error(s"Failed to update $index/$estype/", e) - Nil + Future.successful(Nil) } - } } private def exists(client: RestClient, estype: String, id: Int): Boolean = { @@ -245,7 +257,6 @@ class ESLEvents(val client: RestClient, config: StorageClientConfig, val baseInd eventId: String, appId: Int, channelId: Option[Int])(implicit ec: ExecutionContext): Future[Option[Event]] = { - Future { val estype = getEsType(appId, channelId) val index = baseIndex + "_" + estype try { @@ -254,32 +265,37 @@ class ESLEvents(val client: RestClient, config: StorageClientConfig, val baseInd ("term" -> ("eventId" -> eventId))) val entity = new NStringEntity(compact(render(json)), ContentType.APPLICATION_JSON) - val response = client.performRequest( + val responseFuture = client.performRequestFuture( "POST", s"/$index/$estype/_search", - Map.empty[String, String].asJava, + Map.empty[String, String], entity) - val jsonResponse = parse(EntityUtils.toString(response.getEntity)) - (jsonResponse \ "hits" \ "total").extract[Long] match { - case 0 => None - case _ => - val results = (jsonResponse \ "hits" \ 
"hits").extract[Seq[JValue]] - val result = (results.head \ "_source").extract[Event] - Some(result) + responseFuture.map { + response => + val jsonResponse = parse(EntityUtils.toString(response.getEntity)) + (jsonResponse \ "hits" \ "total").extract[Long] match { + case 0 => None + case _ => + val results = (jsonResponse \ "hits" \ "hits").extract[Seq[JValue]] + val result = (results.head \ "_source").extract[Event] + Some(result) + } + }.recover { + case t => + error("Failed to access to /$index/$estype/_search", t) + None } } catch { case e: IOException => error("Failed to access to /$index/$estype/_search", e) - None + Future.successful(None) } - } } override def futureDelete( eventId: String, appId: Int, channelId: Option[Int])(implicit ec: ExecutionContext): Future[Boolean] = { - Future { val estype = getEsType(appId, channelId) val index = baseIndex + "_" + estype try { @@ -288,19 +304,25 @@ class ESLEvents(val client: RestClient, config: StorageClientConfig, val baseInd ("term" -> ("eventId" -> eventId))) val entity = new NStringEntity(compact(render(json)), ContentType.APPLICATION_JSON) - val response = client.performRequest( + val responseFuture = client.performRequestFuture( "POST", s"/$index/$estype/_delete_by_query", - Map("refresh" -> ESUtils.getEventDataRefresh(config)).asJava, + Map("refresh" -> ESUtils.getEventDataRefresh(config)), entity) - val jsonResponse = parse(EntityUtils.toString(response.getEntity)) - (jsonResponse \ "deleted").extract[Int] > 0 + responseFuture.map { + response => + val jsonResponse = parse(EntityUtils.toString(response.getEntity)) + (jsonResponse \ "deleted").extract[Int] > 0 + }.recover { + case t => + error(s"Failed to delete $index/$estype:$eventId", t) + false + } } catch { case e: IOException => error(s"Failed to delete $index/$estype:$eventId", e) - false + Future.successful(false) } - } } override def futureFind( @@ -316,23 +338,26 @@ class ESLEvents(val client: RestClient, config: StorageClientConfig, val baseInd limit: Option[Int] = None, reversed: Option[Boolean] = None) (implicit ec: ExecutionContext): Future[Iterator[Event]] = { - Future { val estype = getEsType(appId, channelId) val index = baseIndex + "_" + estype try { val query = ESUtils.createEventQuery( startTime, untilTime, entityType, entityId, eventNames, targetEntityType, targetEntityId, reversed) - limit.getOrElse(20) match { - case -1 => ESUtils.getEventAll(client, index, estype, query).toIterator - case size => ESUtils.getEvents(client, index, estype, query, size).toIterator + val eventsFuture = limit.getOrElse(20) match { + case -1 => ESUtils.getEventAll(client, index, estype, query).map(_.toIterator) + case size => ESUtils.getEvents(client, index, estype, query, size).map(_.toIterator) + } + eventsFuture.recover { + case t => + error(t.getMessage) + Iterator.empty } } catch { case e: IOException => error(e.getMessage) - Iterator.empty + Future.successful(Iterator.empty) } - } } } diff --git a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESUtils.scala b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESUtils.scala index cd9aa53a7c..8f439582ce 100644 --- a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESUtils.scala +++ b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ESUtils.scala @@ -19,7 +19,6 @@ package org.apache.predictionio.data.storage.elasticsearch import scala.collection.JavaConversions._ import 
scala.collection.JavaConverters._ - import org.apache.http.entity.ContentType import org.apache.http.entity.StringEntity import org.apache.http.nio.entity.NStringEntity @@ -36,6 +35,10 @@ import org.apache.predictionio.data.storage.StorageClientConfig import org.apache.http.HttpHost import org.apache.predictionio.data.storage.Event import org.apache.predictionio.data.storage.DataMap +import org.apache.predictionio.data.storage.elasticsearch.ScalaRestClient.ExtendedScalaRestClient + +import scala.concurrent.{ExecutionContext, Future} + object ESUtils { val scrollLife = "1m" @@ -86,8 +89,8 @@ object ESUtils { estype: String, query: String, size: Int)( - implicit formats: Formats): Seq[Event] = { - getDocList(client, index, estype, query, size).map(x => toEvent(x)) + implicit formats: Formats, ec: ExecutionContext): Future[Seq[Event]] = { + getDocList(client, index, estype, query, size).map(docs => docs.map(x => toEvent(x))) } def getDocList( @@ -96,16 +99,19 @@ object ESUtils { estype: String, query: String, size: Int)( - implicit formats: Formats): Seq[JValue] = { + implicit formats: Formats, ec: ExecutionContext): Future[Seq[JValue]] = { val entity = new NStringEntity(query, ContentType.APPLICATION_JSON) - val response = client.performRequest( + val responseFuture = client.performRequestFuture( "POST", s"/$index/$estype/_search", Map("size" -> s"${size}"), entity) - val responseJValue = parse(EntityUtils.toString(response.getEntity)) - val hits = (responseJValue \ "hits" \ "hits").extract[Seq[JValue]] - hits.map(h => (h \ "_source")) + responseFuture.map { + response => + val responseJValue = parse(EntityUtils.toString(response.getEntity)) + val hits = (responseJValue \ "hits" \ "hits").extract[Seq[JValue]] + hits.map(h => (h \ "_source")) + } } def getAll[T: Manifest]( @@ -113,8 +119,9 @@ object ESUtils { index: String, estype: String, query: String)( - implicit formats: Formats): Seq[T] = { - getDocAll(client, index, estype, query).map(x => x.extract[T]) + implicit formats: Formats, ec: ExecutionContext): Future[Seq[T]] = { + getDocAll(client, index, estype, query) + .map(docs => docs.map(x => x.extract[T])) } def getEventAll( @@ -122,8 +129,8 @@ object ESUtils { index: String, estype: String, query: String)( - implicit formats: Formats): Seq[Event] = { - getDocAll(client, index, estype, query).map(x => toEvent(x)) + implicit formats: Formats, ec: ExecutionContext): Future[Seq[Event]] = { + getDocAll(client, index, estype, query).map(docs => docs.map(x => toEvent(x))) } def getDocAll( @@ -131,7 +138,7 @@ object ESUtils { index: String, estype: String, query: String)( - implicit formats: Formats): Seq[JValue] = { + implicit formats: Formats, ec: ExecutionContext): Future[Seq[JValue]] = { @scala.annotation.tailrec def scroll(scrollId: String, hits: Seq[JValue], results: Seq[JValue]): Seq[JValue] = { @@ -152,15 +159,18 @@ object ESUtils { } val entity = new NStringEntity(query, ContentType.APPLICATION_JSON) - val response = client.performRequest( + val responseFuture = client.performRequestFuture( "POST", s"/$index/$estype/_search", Map("scroll" -> scrollLife), entity) - val responseJValue = parse(EntityUtils.toString(response.getEntity)) - scroll((responseJValue \ "_scroll_id").extract[String], - (responseJValue \ "hits" \ "hits").extract[Seq[JValue]], - Nil) + responseFuture.map { + response => + val responseJValue = parse(EntityUtils.toString(response.getEntity)) + scroll((responseJValue \ "_scroll_id").extract[String], + (responseJValue \ "hits" \ "hits").extract[Seq[JValue]], + Nil) + 
} } def createIndex( diff --git a/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ScalaRestClient.scala b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ScalaRestClient.scala new file mode 100644 index 0000000000..08a23f3f71 --- /dev/null +++ b/storage/elasticsearch/src/main/scala/org/apache/predictionio/data/storage/elasticsearch/ScalaRestClient.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.data.storage.elasticsearch + +import org.apache.http.{Header, HttpEntity} +import org.elasticsearch.client.{Response, ResponseListener, RestClient} + +import scala.collection.JavaConverters._ +import scala.concurrent.{Future, Promise} + +object ScalaRestClient { + + implicit class ExtendedScalaRestClient(restClient: RestClient) { + + def performRequestFuture(method: String, endpoint: String, params: Map[String, String], + entity: HttpEntity, headers: Header*): Future[Response] = { + val promise: Promise[Response] = Promise() + val responseListener = new ResponseListener { + override def onSuccess(response: Response): Unit = promise.success(response) + override def onFailure(exception: Exception): Unit = promise.failure(exception) + } + restClient.performRequestAsync( + method, endpoint, params.asJava, entity, responseListener, headers: _*) + promise.future + } + } +} diff --git a/storage/hbase/src/main/scala/org/apache/predictionio/data/storage/hbase/HBLEvents.scala b/storage/hbase/src/main/scala/org/apache/predictionio/data/storage/hbase/HBLEvents.scala index e95e7e82b1..608a269ce3 100644 --- a/storage/hbase/src/main/scala/org/apache/predictionio/data/storage/hbase/HBLEvents.scala +++ b/storage/hbase/src/main/scala/org/apache/predictionio/data/storage/hbase/HBLEvents.scala @@ -33,6 +33,7 @@ import org.joda.time.DateTime import scala.collection.JavaConversions._ import scala.concurrent.ExecutionContext import scala.concurrent.Future +import scala.concurrent.blocking class HBLEvents(val client: HBClient, config: StorageClientConfig, val namespace: String) extends LEvents with Logging { @@ -100,12 +101,14 @@ class HBLEvents(val client: HBClient, config: StorageClientConfig, val namespace event: Event, appId: Int, channelId: Option[Int])(implicit ec: ExecutionContext): Future[String] = { Future { - val table = getTable(appId, channelId) - val (put, rowKey) = HBEventsUtil.eventToPut(event, appId) - table.put(put) - table.flushCommits() - table.close() - rowKey.toString + blocking { + val table = getTable(appId, channelId) + val (put, rowKey) = HBEventsUtil.eventToPut(event, appId) + table.put(put) + table.flushCommits() + table.close() + rowKey.toString + } } } @@ -114,12 +117,14 @@ class HBLEvents(val client: 
HBClient, config: StorageClientConfig, val namespace events: Seq[Event], appId: Int, channelId: Option[Int])(implicit ec: ExecutionContext): Future[Seq[String]] = { Future { - val table = getTable(appId, channelId) - val (puts, rowKeys) = events.map { event => HBEventsUtil.eventToPut(event, appId) }.unzip - table.put(puts) - table.flushCommits() - table.close() - rowKeys.map(_.toString) + blocking { + val table = getTable(appId, channelId) + val (puts, rowKeys) = events.map { event => HBEventsUtil.eventToPut(event, appId) }.unzip + table.put(puts) + table.flushCommits() + table.close() + rowKeys.map(_.toString) + } } } @@ -128,18 +133,20 @@ class HBLEvents(val client: HBClient, config: StorageClientConfig, val namespace eventId: String, appId: Int, channelId: Option[Int])(implicit ec: ExecutionContext): Future[Option[Event]] = { Future { - val table = getTable(appId, channelId) - val rowKey = RowKey(eventId) - val get = new Get(rowKey.toBytes) - - val result = table.get(get) - table.close() - - if (!result.isEmpty()) { - val event = resultToEvent(result, appId) - Some(event) - } else { - None + blocking { + val table = getTable(appId, channelId) + val rowKey = RowKey(eventId) + val get = new Get(rowKey.toBytes) + + val result = table.get(get) + table.close() + + if (!result.isEmpty()) { + val event = resultToEvent(result, appId) + Some(event) + } else { + None + } } } } @@ -149,12 +156,14 @@ class HBLEvents(val client: HBClient, config: StorageClientConfig, val namespace eventId: String, appId: Int, channelId: Option[Int])(implicit ec: ExecutionContext): Future[Boolean] = { Future { - val table = getTable(appId, channelId) - val rowKey = RowKey(eventId) - val exists = table.exists(new Get(rowKey.toBytes)) - table.delete(new Delete(rowKey.toBytes)) - table.close() - exists + blocking { + val table = getTable(appId, channelId) + val rowKey = RowKey(eventId) + val exists = table.exists(new Get(rowKey.toBytes)) + table.delete(new Delete(rowKey.toBytes)) + table.close() + exists + } } } @@ -173,36 +182,39 @@ class HBLEvents(val client: HBClient, config: StorageClientConfig, val namespace reversed: Option[Boolean] = None)(implicit ec: ExecutionContext): Future[Iterator[Event]] = { Future { - - require(!((reversed == Some(true)) && (entityType.isEmpty || entityId.isEmpty)), - "the parameter reversed can only be used with both entityType and entityId specified.") - - val table = getTable(appId, channelId) - - val scan = HBEventsUtil.createScan( - startTime = startTime, - untilTime = untilTime, - entityType = entityType, - entityId = entityId, - eventNames = eventNames, - targetEntityType = targetEntityType, - targetEntityId = targetEntityId, - reversed = reversed) - val scanner = table.getScanner(scan) - table.close() - - val eventsIter = scanner.iterator() - - // Get all events if None or Some(-1) - val results: Iterator[Result] = limit match { - case Some(-1) => eventsIter - case None => eventsIter - case Some(x) => eventsIter.take(x) + blocking { + require(!((reversed == Some(true)) && (entityType.isEmpty || entityId.isEmpty)), + "the parameter reversed can only be used with both entityType and entityId specified.") + + val table = getTable(appId, channelId) + + val scan = HBEventsUtil.createScan( + startTime = startTime, + untilTime = untilTime, + entityType = entityType, + entityId = entityId, + eventNames = eventNames, + targetEntityType = targetEntityType, + targetEntityId = targetEntityId, + reversed = reversed) + val scanner = table.getScanner(scan) + table.close() + + val eventsIter = 
diff --git a/storage/jdbc/src/main/scala/org/apache/predictionio/data/storage/jdbc/JDBCLEvents.scala b/storage/jdbc/src/main/scala/org/apache/predictionio/data/storage/jdbc/JDBCLEvents.scala
index b4230ccd11..0ffa18786d
--- a/storage/jdbc/src/main/scala/org/apache/predictionio/data/storage/jdbc/JDBCLEvents.scala
+++ b/storage/jdbc/src/main/scala/org/apache/predictionio/data/storage/jdbc/JDBCLEvents.scala
@@ -32,6 +32,7 @@ import scalikejdbc._
 
 import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
+import scala.concurrent.blocking
 
 /** JDBC implementation of [[LEvents]] */
 class JDBCLEvents(
@@ -103,10 +104,11 @@ class JDBCLEvents(
 
   override def futureInsert(event: Event, appId: Int, channelId: Option[Int])(
     implicit ec: ExecutionContext): Future[String] = Future {
-    DB localTx { implicit session =>
-      val id = event.eventId.getOrElse(JDBCUtils.generateId)
-      val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
-      sql"""
+    blocking {
+      DB localTx { implicit session =>
+        val id = event.eventId.getOrElse(JDBCUtils.generateId)
+        val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
+        sql"""
       insert into $tableName values(
         $id,
         ${event.event},
@@ -123,34 +125,36 @@ class JDBCLEvents(
         ${event.creationTime.getZone.getID}
       )
       """.update().apply()
-      id
+        id
+      }
     }
   }
 
   override def futureInsertBatch(events: Seq[Event], appId: Int, channelId: Option[Int])(
     implicit ec: ExecutionContext): Future[Seq[String]] = Future {
-    DB localTx { implicit session =>
-      val ids = events.map(_.eventId.getOrElse(JDBCUtils.generateId))
-      val params = events.zip(ids).map { case (event, id) =>
-        Seq(
-          'id -> id,
-          'event -> event.event,
-          'entityType -> event.entityType,
-          'entityId -> event.entityId,
-          'targetEntityType -> event.targetEntityType,
-          'targetEntityId -> event.targetEntityId,
-          'properties -> write(event.properties.toJObject),
-          'eventTime -> event.eventTime,
-          'eventTimeZone -> event.eventTime.getZone.getID,
-          'tags -> (if(event.tags.nonEmpty) Some(event.tags.mkString(",")) else None),
-          'prId -> event.prId,
-          'creationTime -> event.creationTime,
-          'creationTimeZone -> event.creationTime.getZone.getID
-        )
-      }
+    blocking {
+      DB localTx { implicit session =>
+        val ids = events.map(_.eventId.getOrElse(JDBCUtils.generateId))
+        val params = events.zip(ids).map { case (event, id) =>
+          Seq(
+            'id -> id,
+            'event -> event.event,
+            'entityType -> event.entityType,
+            'entityId -> event.entityId,
+            'targetEntityType -> event.targetEntityType,
+            'targetEntityId -> event.targetEntityId,
+            'properties -> write(event.properties.toJObject),
+            'eventTime -> event.eventTime,
+            'eventTimeZone -> event.eventTime.getZone.getID,
+            'tags -> (if (event.tags.nonEmpty) Some(event.tags.mkString(",")) else None),
+            'prId -> event.prId,
+            'creationTime -> event.creationTime,
+            'creationTimeZone -> event.creationTime.getZone.getID
+          )
+        }
 
-      val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
-      sql"""
+        val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
+        sql"""
       insert into $tableName values(
         {id},
         {event},
@@ -168,15 +172,17 @@ class JDBCLEvents(
       )
     """.batchByName(params: _*).apply()
-      ids
+        ids
+      }
     }
   }
 
   override def futureGet(eventId: String, appId: Int, channelId: Option[Int])(
     implicit ec: ExecutionContext): Future[Option[Event]] = Future {
-    DB readOnly { implicit session =>
-      val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
-      sql"""
+    blocking {
+      DB readOnly { implicit session =>
+        val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
+        sql"""
       select
         id,
         event,
@@ -194,17 +200,20 @@ class JDBCLEvents(
       from $tableName
       where id = $eventId
       """.map(resultToEvent).single().apply()
+      }
     }
   }
 
   override def futureDelete(eventId: String, appId: Int, channelId: Option[Int])(
     implicit ec: ExecutionContext): Future[Boolean] = Future {
-    DB localTx { implicit session =>
-      val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
-      sql"""
+    blocking {
+      DB localTx { implicit session =>
+        val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
+        sql"""
       delete from $tableName where id = $eventId
       """.update().apply()
-      true
+        true
+      }
     }
   }
 
@@ -221,30 +230,32 @@ class JDBCLEvents(
     limit: Option[Int] = None,
     reversed: Option[Boolean] = None
   )(implicit ec: ExecutionContext): Future[Iterator[Event]] = Future {
-    DB readOnly { implicit session =>
-      val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
-      val whereClause = sqls.toAndConditionOpt(
-        startTime.map(x => sqls"eventTime >= $x"),
-        untilTime.map(x => sqls"eventTime < $x"),
-        entityType.map(x => sqls"entityType = $x"),
-        entityId.map(x => sqls"entityId = $x"),
-        eventNames.map(x =>
-          sqls.toOrConditionOpt(x.map(y =>
-            Some(sqls"event = $y")
-          ): _*)
-        ).getOrElse(None),
-        targetEntityType.map(x => x.map(y => sqls"targetEntityType = $y")
+    blocking {
+      DB readOnly { implicit session =>
+        val tableName = sqls.createUnsafely(JDBCUtils.eventTableName(namespace, appId, channelId))
+        val whereClause = sqls.toAndConditionOpt(
+          startTime.map(x => sqls"eventTime >= $x"),
+          untilTime.map(x => sqls"eventTime < $x"),
+          entityType.map(x => sqls"entityType = $x"),
+          entityId.map(x => sqls"entityId = $x"),
+          eventNames.map(x =>
+            sqls.toOrConditionOpt(x.map(y =>
+              Some(sqls"event = $y")
+            ): _*)
+          ).getOrElse(None),
+          targetEntityType.map(x => x.map(y => sqls"targetEntityType = $y")
           .getOrElse(sqls"targetEntityType IS NULL")),
-        targetEntityId.map(x => x.map(y => sqls"targetEntityId = $y")
+          targetEntityId.map(x => x.map(y => sqls"targetEntityId = $y")
           .getOrElse(sqls"targetEntityId IS NULL"))
-      ).map(sqls.where(_)).getOrElse(sqls"")
-      val orderByClause = reversed.map(x =>
-        if (x) sqls"eventTime desc" else sqls"eventTime asc"
-      ).getOrElse(sqls"eventTime asc")
-      val limitClause = limit.map(x =>
-        if (x < 0) sqls"" else sqls.limit(x)
-      ).getOrElse(sqls"")
-      val q = sql"""
+        ).map(sqls.where(_)).getOrElse(sqls"")
+        val orderByClause = reversed.map(x =>
+          if (x) sqls"eventTime desc" else sqls"eventTime asc"
+        ).getOrElse(sqls"eventTime asc")
+        val limitClause = limit.map(x =>
+          if (x < 0) sqls"" else sqls.limit(x)
+        ).getOrElse(sqls"")
+        val q =
+          sql"""
       select
         id,
         event,
@@ -264,7 +275,8 @@ class JDBCLEvents(
       order by $orderByClause
       $limitClause
       """
-      q.map(resultToEvent).list().apply().toIterator
+        q.map(resultToEvent).list().apply().toIterator
+      }
     }
   }
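With the JDBC backend wrapped the same way, every `LEvents` implementation keeps its `Future`-returning signatures while declaring its blocking sections explicitly. A hypothetical caller-side sketch; `events` (an `LEvents` instance) and `event` are assumed to be in scope:

```scala
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}

// futureInsert returns immediately; the generated event ID (or the
// failure) arrives asynchronously on the supplied ExecutionContext.
events.futureInsert(event, appId = 1, channelId = None).onComplete {
  case Success(id) => println(s"stored event $id")
  case Failure(e)  => println(s"insert failed: ${e.getMessage}")
}
```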