From 2267d9d210b9a1f612ca31c92ab24e0c23a735de Mon Sep 17 00:00:00 2001 From: Chris Wewerka Date: Wed, 24 Oct 2018 11:52:40 +0200 Subject: [PATCH] support for Future-based predict-method, see https://github.com/apache/predictionio/pull/495 --- build.sbt | 2 +- src/main/scala/EsClient.scala | 30 ++-- src/main/scala/ScalaRestClient.scala | 23 +++ src/main/scala/URAlgorithm.scala | 238 ++++++++++++++------------- 4 files changed, 167 insertions(+), 126 deletions(-) create mode 100644 src/main/scala/ScalaRestClient.scala diff --git a/build.sbt b/build.sbt index d03457a..9ae4e51 100644 --- a/build.sbt +++ b/build.sbt @@ -14,7 +14,7 @@ scalaVersion in ThisBuild := "2.11.11" val mahoutVersion = "0.13.0" -val pioVersion = "0.12.0-incubating" +val pioVersion = "0.14.0-SNAPSHOT" val elasticsearchVersion = "5.5.2" diff --git a/src/main/scala/EsClient.scala b/src/main/scala/EsClient.scala index 8040153..bd4c3c1 100644 --- a/src/main/scala/EsClient.scala +++ b/src/main/scala/EsClient.scala @@ -22,13 +22,13 @@ import java.util import grizzled.slf4j.Logger import org.apache.http.util.EntityUtils -import org.apache.predictionio.data.storage.{ DataMap, Storage, StorageClientConfig } +import org.apache.predictionio.data.storage.{DataMap, Storage, StorageClientConfig} import org.apache.predictionio.workflow.CleanupFunctions import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.elasticsearch.client.RestClient import org.apache.http.HttpHost -import org.apache.http.auth.{ AuthScope, UsernamePasswordCredentials } +import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials} import org.apache.http.entity.ContentType import org.apache.http.entity.StringEntity import org.apache.http.impl.client.BasicCredentialsProvider @@ -42,6 +42,9 @@ import org.elasticsearch.spark._ import org.json4s.JValue import org.json4s.DefaultFormats import org.json4s.JsonAST.JString +import ScalaRestClient.ExtendedScalaRestClient + +import scala.concurrent.{ExecutionContext, Future} // import org.json4s.native.Serialization.writePretty import com.actionml.helpers.{ ItemID, ItemProps } @@ -367,20 +370,23 @@ object EsClient { * @param indexName the index to search * @return a [PredictedResults] collection */ - def search(query: String, indexName: String): Option[JValue] = { + def search(query: String, indexName: String)(implicit ec: ExecutionContext): Future[Option[JValue]] = { logger.info(s"Query:\n${query}") - val response = client.performRequest( + val responseFuture = client.performRequestFuture( "POST", s"/$indexName/_search", - Map.empty[String, String].asJava, + Map.empty[String, String], new StringEntity(query, ContentType.APPLICATION_JSON)) - response.getStatusLine.getStatusCode match { - case 200 => - logger.info(s"Got source from query: ${query}") - Some(parse(EntityUtils.toString(response.getEntity))) - case _ => - logger.info(s"Query: ${query}\nproduced status code: ${response.getStatusLine.getStatusCode}") - None + responseFuture.map { + response => + response.getStatusLine.getStatusCode match { + case 200 => + logger.info(s"Got source from query: ${query}") + Some(parse(EntityUtils.toString(response.getEntity))) + case _ => + logger.info(s"Query: ${query}\nproduced status code: ${response.getStatusLine.getStatusCode}") + None + } } } diff --git a/src/main/scala/ScalaRestClient.scala b/src/main/scala/ScalaRestClient.scala new file mode 100644 index 0000000..cad97bb --- /dev/null +++ b/src/main/scala/ScalaRestClient.scala @@ -0,0 +1,23 @@ +package com.actionml + +import org.apache.http.{Header, HttpEntity} +import org.elasticsearch.client.{Response, ResponseListener, RestClient} +import scala.collection.JavaConverters._ +import scala.concurrent.{Future, Promise} + +object ScalaRestClient { + + implicit class ExtendedScalaRestClient(restClient: RestClient) { + + def performRequestFuture(method: String, endpoint: String, params: Map[String, String], + entity: HttpEntity, headers: Header*): Future[Response] = { + val promise: Promise[Response] = Promise() + val responseListener = new ResponseListener { + override def onSuccess(response: Response): Unit = promise.success(response) + override def onFailure(exception: Exception): Unit = promise.failure(exception) + } + restClient.performRequestAsync(method, endpoint, params.asJava, entity, responseListener, headers: _*) + promise.future + } + } +} diff --git a/src/main/scala/URAlgorithm.scala b/src/main/scala/URAlgorithm.scala index cd58cd7..c92c4fc 100644 --- a/src/main/scala/URAlgorithm.scala +++ b/src/main/scala/URAlgorithm.scala @@ -20,10 +20,10 @@ package com.actionml import java.util import grizzled.slf4j.Logger -import org.apache.predictionio.controller.{ P2LAlgorithm, Params } -import org.apache.predictionio.data.storage.{ DataMap, Event, NullModel, PropertyMap } +import org.apache.predictionio.controller.{P2LAlgorithm, Params} +import org.apache.predictionio.data.storage.{DataMap, Event, NullModel, PropertyMap} import org.apache.predictionio.data.store.LEventStore -import org.apache.mahout.math.cf.{ DownsamplableCrossOccurrenceDataset, SimilarityAnalysis } +import org.apache.mahout.math.cf.{DownsamplableCrossOccurrenceDataset, SimilarityAnalysis} import org.apache.mahout.sparkbindings.indexeddataset.IndexedDatasetSpark import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -34,10 +34,12 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import com.actionml.helpers._ - +import scala.concurrent.ExecutionContext import scala.collection.JavaConverters._ +import scala.concurrent.Future import scala.concurrent.duration.Duration -import scala.language.{ implicitConversions, postfixOps } +import scala.language.{implicitConversions, postfixOps} +import ScalaRestClient.ExtendedScalaRestClient /** Available value for algorithm param "RecsModel" */ object RecsModels { // todo: replace this with rankings @@ -481,51 +483,58 @@ class URAlgorithm(val ap: URAlgorithmParams) * @todo Need to prune that query to minimum required for data include, for instance no need for the popularity * ranking if no PopModel is being used, same for "must" clause and dates. */ - def predict(model: NullModel, query: Query): PredictedResult = { + override def predictAsync(model: NullModel, query: Query)(implicit ec: ExecutionContext): Future[PredictedResult] = { queryEventNames = query.eventNames.getOrElse(modelEventNames) // eventNames in query take precedence - val (queryStr, blacklist) = buildQuery(ap, query, rankingFieldNames) - // old es1 query - // val searchHitsOpt = EsClient.search(queryStr, esIndex, queryEventNames) - val searchHitsOpt = EsClient.search(queryStr, esIndex) - - val withRanks = query.withRanks.getOrElse(false) - val predictedResults = searchHitsOpt match { - case Some(searchHits) => - val hits = (searchHits \ "hits" \ "hits").extract[Seq[JValue]] - val recs = hits.map { hit => - if (withRanks) { - val source = hit \ "source" - val ranks: Map[String, Double] = rankingsParams map { backfillParams => - val backfillType = backfillParams.`type`.getOrElse(DefaultURAlgoParams.BackfillType) - val backfillFieldName = backfillParams.name.getOrElse(PopModel.nameByType(backfillType)) - backfillFieldName -> (source \ backfillFieldName).extract[Double] - } toMap - - ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double], - ranks = if (ranks.nonEmpty) Some(ranks) else None) - } else { - ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double]) - } - }.toArray - logger.info(s"Results: ${hits.length} retrieved of a possible ${(searchHits \ "hits" \ "total").extract[Long]}") - PredictedResult(recs) - - case _ => - logger.info(s"No results for query ${parse(queryStr)}") - PredictedResult(Array.empty[ItemScore]) + val queryStrBlacklistFuture = buildQuery(ap, query, rankingFieldNames) + + queryStrBlacklistFuture.flatMap { + case (queryStr, blacklist) => + // old es1 query + // val searchHitsOpt = EsClient.search(queryStr, esIndex, queryEventNames) + val searchHitsOptFuture = EsClient.search(queryStr, esIndex) + + val withRanks = query.withRanks.getOrElse(false) + searchHitsOptFuture.map { + searchHitsOpt => + val predictedResults = searchHitsOpt match { + case Some(searchHits) => + val hits = (searchHits \ "hits" \ "hits").extract[Seq[JValue]] + val recs = hits.map { hit => + if (withRanks) { + val source = hit \ "source" + val ranks: Map[String, Double] = rankingsParams map { backfillParams => + val backfillType = backfillParams.`type`.getOrElse(DefaultURAlgoParams.BackfillType) + val backfillFieldName = backfillParams.name.getOrElse(PopModel.nameByType(backfillType)) + backfillFieldName -> (source \ backfillFieldName).extract[Double] + } toMap + + ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double], + ranks = if (ranks.nonEmpty) Some(ranks) else None) + } else { + ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double]) + } + }.toArray + logger.info(s"Results: ${hits.length} retrieved of a possible ${(searchHits \ "hits" \ "total").extract[Long]}") + PredictedResult(recs) + + case _ => + logger.info(s"No results for query ${parse(queryStr)}") + PredictedResult(Array.empty[ItemScore]) + } + + // todo: is this needed to remove ranked items from recs? + //if (recsModel == RecsModels.CF) { + // PredictedResult(predictedResults.filter(_.score != 0.0)) + //} else PredictedResult(predictedResults) + + // should have all blacklisted items excluded + // todo: need to add dithering, mean, sigma, seed required, make a seed that only changes on some fixed time + // period so the recs ordering stays fixed for that time period. + predictedResults + } } - - // todo: is this needed to remove ranked items from recs? - //if (recsModel == RecsModels.CF) { - // PredictedResult(predictedResults.filter(_.score != 0.0)) - //} else PredictedResult(predictedResults) - - // should have all blacklisted items excluded - // todo: need to add dithering, mean, sigma, seed required, make a seed that only changes on some fixed time - // period so the recs ordering stays fixed for that time period. - predictedResults } /** Calculate all fields and items needed for ranking. @@ -563,56 +572,60 @@ class URAlgorithm(val ap: URAlgorithmParams) def buildQuery( ap: URAlgorithmParams, query: Query, - backfillFieldNames: Seq[String] = Seq.empty): (String, Seq[Event]) = { + backfillFieldNames: Seq[String] = Seq.empty)(implicit ec: ExecutionContext): Future[(String, Seq[Event])] = { logger.info(s"Got query: \n${query}") val startPos = query.from.getOrElse(0) logger.info(s"from: ${startPos}") - try { - // create a list of all query correlators that can have a bias (boost or filter) attached - val (boostable, events) = getBiasedRecentUserActions(query) - logger.info(s"getBiasedRecentUserActions returned boostable: ${boostable} and events: ${events}") - - // since users have action history and items have correlators and both correspond to the same "actions" like - // purchase or view, we'll pass both to the query if the user history or items correlators are empty - // then metadata or backfill must be relied on to return results. - val numRecs = if (query.num.isDefined) query.num.get else limit // num in query orerrides num in config - logger.info(s"UR query num = ${query.num}") - logger.info(s"query.num.getOrElse returned numRecs: ${numRecs}") - - val should = buildQueryShould(query, boostable) - logger.info(s"buildQueryShould returned should: ${should}") - val must = buildQueryMust(query, boostable) - logger.info(s"buildQueryMust returned must: ${must}") - val mustNot = buildQueryMustNot(query, events) - logger.info(s"buildQueryMustNot returned mustNot: ${mustNot}") - val sort = buildQuerySort() - logger.info(s"buildQuerySort returned sort: ${sort}") - - val json = - ("from" -> startPos) ~ - ("size" -> numRecs) ~ - ("query" -> - ("bool" -> - ("should" -> should) ~ - ("must" -> must) ~ - ("must_not" -> mustNot) ~ - ("minimum_should_match" -> 1))) ~ - ("sort" -> sort) - - logger.info(s"json is: ${json}") - val compactJson = compact(render(json)) - logger.info(s"compact json is: ${compactJson}") - - //logger.info(s"Query:\n$compactJson") - (compactJson, events) - } catch { - case e: IllegalArgumentException => { - logger.warn("whoops, IllegalArgumentException for something in buildQuery.") - ("", Seq.empty[Event]) - } + // create a list of all query correlators that can have a bias (boost or filter) attached + val biasedRecentUserActionsFuture = getBiasedRecentUserActions(query) + + biasedRecentUserActionsFuture.map { + case (boostable, events) => + try { + logger.info(s"getBiasedRecentUserActions returned boostable: ${boostable} and events: ${events}") + + // since users have action history and items have correlators and both correspond to the same "actions" like + // purchase or view, we'll pass both to the query if the user history or items correlators are empty + // then metadata or backfill must be relied on to return results. + val numRecs = if (query.num.isDefined) query.num.get else limit // num in query orerrides num in config + logger.info(s"UR query num = ${query.num}") + logger.info(s"query.num.getOrElse returned numRecs: ${numRecs}") + + val should = buildQueryShould(query, boostable) + logger.info(s"buildQueryShould returned should: ${should}") + val must = buildQueryMust(query, boostable) + logger.info(s"buildQueryMust returned must: ${must}") + val mustNot = buildQueryMustNot(query, events) + logger.info(s"buildQueryMustNot returned mustNot: ${mustNot}") + val sort = buildQuerySort() + logger.info(s"buildQuerySort returned sort: ${sort}") + + val json = + ("from" -> startPos) ~ + ("size" -> numRecs) ~ + ("query" -> + ("bool" -> + ("should" -> should) ~ + ("must" -> must) ~ + ("must_not" -> mustNot) ~ + ("minimum_should_match" -> 1))) ~ + ("sort" -> sort) + + logger.info(s"json is: ${json}") + val compactJson = compact(render(json)) + logger.info(s"compact json is: ${compactJson}") + + //logger.info(s"Query:\n$compactJson") + (compactJson, events) + } catch { + case e: IllegalArgumentException => { + logger.warn("whoops, IllegalArgumentException for something in buildQuery.") + ("", Seq.empty[Event]) + } + } } } @@ -792,10 +805,10 @@ class URAlgorithm(val ap: URAlgorithmParams) } /** Get recent events of the user on items to create the recommendations query from */ - def getBiasedRecentUserActions(query: Query): (Seq[BoostableCorrelators], Seq[Event]) = { + def getBiasedRecentUserActions(query: Query)(implicit ec: ExecutionContext): Future[(Seq[BoostableCorrelators], Seq[Event])] = { - val recentEvents = try { - LEventStore.findByEntity( + val recentEventsFuture = + LEventStore.findByEntityAsync( appName = appName, // entityType and entityId is specified for fast lookup entityType = "user", @@ -806,13 +819,9 @@ class URAlgorithm(val ap: URAlgorithmParams) // targetEntityType = None, // limit = Some(maxQueryEvents), // this will get all history then each action can be limited before using in // the query - latest = true, - // set time limit to avoid super long DB access - timeout = Duration(200, "millis")).toSeq - } catch { - case e: scala.concurrent.TimeoutException => - logger.error(s"Timeout when reading recent events. Empty list is used. $e") - Seq.empty[Event] + latest = true).map(_.toSeq) + + val recoveredRecentEventsFuture = recentEventsFuture.recover { case e: NoSuchElementException => logger.info("No user id for recs, returning item-based recs if an item is specified in the query.") Seq.empty[Event] @@ -821,21 +830,24 @@ class URAlgorithm(val ap: URAlgorithmParams) Seq.empty[Event] } - val userEventBias = query.userBias.getOrElse(userBias) - val userEventsBoost = if (userEventBias > 0 && userEventBias != 1) Some(userEventBias) else None - val rActions = queryEventNames.map { action => - var items = Seq.empty[String] - - for (event <- recentEvents) { // todo: use indidatorParams for each indicator type - if (event.event == action && items.size < indicatorParams(action).maxItemsPerUser) { - items = event.targetEntityId.get +: items - // todo: may throw exception and we should ignore the event instead of crashing + recoveredRecentEventsFuture.map { + recentEvents => + val userEventBias = query.userBias.getOrElse(userBias) + val userEventsBoost = if (userEventBias > 0 && userEventBias != 1) Some(userEventBias) else None + val rActions = queryEventNames.map { action => + var items = Seq.empty[String] + + for (event <- recentEvents) { // todo: use indidatorParams for each indicator type + if (event.event == action && items.size < indicatorParams(action).maxItemsPerUser) { + items = event.targetEntityId.get +: items + // todo: may throw exception and we should ignore the event instead of crashing + } + // userBias may be None, which will cause no JSON output for this + } + BoostableCorrelators(action, items.distinct, userEventsBoost) } - // userBias may be None, which will cause no JSON output for this - } - BoostableCorrelators(action, items.distinct, userEventsBoost) + (rActions, recentEvents) } - (rActions, recentEvents) } /** get all metadata fields that potentially have boosts (not filters) */