
Add Azure Search Sink

caseyhong authored and mhamilton723 committed Dec 3, 2018
1 parent a3ab017 commit 833024e90826b8334c3f40e8297d66a369f04d8a
@@ -122,8 +122,7 @@ trait HasMiniBatcher extends Params {

object FixedMiniBatchTransformer extends DefaultParamsReadable[FixedMiniBatchTransformer]

-class FixedMiniBatchTransformer(val uid: String)
-  extends MiniBatchBase {
+trait HasBatchSize extends Params {

  val batchSize: Param[Int] = new IntParam(
    this, "batchSize", "The max size of the buffer")
@@ -134,6 +133,11 @@ class FixedMiniBatchTransformer(val uid: String)
  /** @group setParam */
  def setBatchSize(value: Int): this.type = set(batchSize, value)

+}
+
+class FixedMiniBatchTransformer(val uid: String)
+  extends MiniBatchBase with HasBatchSize {
+
  val maxBufferSize: Param[Int] = new IntParam(
    this, "maxBufferSize", "The max size of the buffer")
@@ -0,0 +1,54 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.cognitive

import com.microsoft.ml.spark.schema.SparkBindings
import spray.json.DefaultJsonProtocol._

object ASResponses extends SparkBindings[ASResponses]

case class ASResponses(value: Seq[ASResponse])

case class ASResponse(key: String, status: Boolean, errorMessage: Option[String], statusCode: Int)

case class IndexInfo(
  name: Option[String],
  fields: Array[IndexField],
  suggesters: Option[Array[String]],
  scoringProfiles: Option[Array[String]],
  analyzers: Option[Array[String]],
  charFilters: Option[Array[String]],
  tokenizers: Option[Array[String]],
  tokenFilters: Option[Array[String]],
  defaultScoringProfile: Option[Array[String]],
  corsOptions: Option[Array[String]]
)

case class IndexField(
  name: String,
  `type`: String,
  searchable: Option[Boolean],
  filterable: Option[Boolean],
  sortable: Option[Boolean],
  facetable: Option[Boolean],
  key: Option[Boolean],
  retrievable: Option[Boolean],
  analyzer: Option[String],
  searchAnalyzer: Option[String],
  indexAnalyzer: Option[String],
  synonymMap: Option[String]
)

case class IndexStats(documentCount: Int, storageSize: Int)

case class IndexList(`@odata.context`: String, value: Seq[IndexName])
case class IndexName(name: String)

object AzureSearchProtocol {
  implicit val ifEnc = jsonFormat12(IndexField.apply)
  implicit val iiEnc = jsonFormat10(IndexInfo.apply)
  implicit val isEnc = jsonFormat2(IndexStats.apply)
  implicit val inEnc = jsonFormat1(IndexName.apply)
  implicit val ilEnc = jsonFormat2(IndexList.apply)
}
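Not part of the commit, but a minimal sketch of how these formats are consumed: the index-definition JSON below is hypothetical, and any optional IndexField flag that is absent simply deserializes to None.

import com.microsoft.ml.spark.cognitive._
import com.microsoft.ml.spark.cognitive.AzureSearchProtocol._
import spray.json._

// Hypothetical index definition; only "name" and "type" are required per field.
val indexJson =
  """{"name": "test-index",
    | "fields": [
    |   {"name": "id", "type": "Edm.String", "key": true},
    |   {"name": "text", "type": "Edm.String", "searchable": true}
    | ]}""".stripMargin

val info: IndexInfo = indexJson.parseJson.convertTo[IndexInfo]
assert(info.fields.length == 2 && info.fields.head.key.contains(true))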
@@ -0,0 +1,26 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark

import spray.json._

case class IndexSchema(name: String, fields: Seq[Field])

case class Field(name: String,
                 `type`: String,
                 searchable: Boolean,
                 filterable: Boolean,
                 sortable: Boolean,
                 facetable: Boolean,
                 key: Boolean,
                 retrievable: Boolean,
                 analyzer: Option[String],
                 searchAnalyzer: Option[String],
                 indexAnalyzer: Option[String],
                 synonymMaps: Option[String])

object IndexJsonProtocol extends DefaultJsonProtocol {
  implicit val fieldFormat = jsonFormat12(Field)
  implicit val indexFormat = jsonFormat2(IndexSchema)
}
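Again not part of the commit, a small round-trip sketch with this protocol; note that, unlike IndexField above, every boolean flag on Field is required, while fields set to None are simply omitted from the emitted JSON.

import com.microsoft.ml.spark.{Field, IndexSchema}
import com.microsoft.ml.spark.IndexJsonProtocol._
import spray.json._

// Hypothetical single-field schema, used only for illustration.
val schema = IndexSchema("test-index", Seq(
  Field(name = "id", `type` = "Edm.String",
    searchable = false, filterable = false, sortable = false, facetable = false,
    key = true, retrievable = true,
    analyzer = None, searchAnalyzer = None, indexAnalyzer = None, synonymMaps = None)))

val json = schema.toJson.compactPrint            // None fields are dropped
val roundTripped = json.parseJson.convertTo[IndexSchema]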
@@ -0,0 +1,220 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark

import com.microsoft.ml.spark.cognitive._
import org.apache.http.entity.{AbstractHttpEntity, StringEntity}
import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{NamespaceInjections, PipelineModel}
import org.apache.spark.sql.functions.{array, col, struct, to_json}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}

import scala.collection.JavaConverters._
import scala.util.{Try, Success, Failure}

object AddDocuments extends ComplexParamsReadable[AddDocuments] with Serializable

trait HasActionCol extends HasServiceParams {

  val actionCol = new Param[String](this, "actionCol",
    s"""
       |You can combine actions, such as an upload and a delete, in the same batch.
       |
       |upload: An upload action is similar to an 'upsert'
       |where the document will be inserted if it is new and updated/replaced
       |if it exists. Note that all fields are replaced in the update case.
       |
       |merge: Merge updates an existing document with the specified fields.
       |If the document doesn't exist, the merge will fail. Any field
       |you specify in a merge will replace the existing field in the document.
       |This includes fields of type Collection(Edm.String). For example, if
       |the document contains a field 'tags' with value ['budget'] and you execute
       |a merge with value ['economy', 'pool'] for 'tags', the final value
       |of the 'tags' field will be ['economy', 'pool'].
       |It will not be ['budget', 'economy', 'pool'].
       |
       |mergeOrUpload: This action behaves like merge if a document
       |with the given key already exists in the index.
       |If the document does not exist, it behaves like upload with a new document.
       |
       |delete: Delete removes the specified document from the index.
       |Note that any field you specify in a delete operation,
       |other than the key field, will be ignored. If you want to
       |remove an individual field from a document, use merge
       |instead and simply set the field explicitly to null.
    """.stripMargin.replace("\n", ""))

  def setActionCol(v: String): this.type = set(actionCol, v)

  def getActionCol: String = $(actionCol)

}

trait HasIndexName extends HasServiceParams {

  val indexName = new Param[String](this, "indexName", "")

  def setIndexName(v: String): this.type = set(indexName, v)

  def getIndexName: String = $(indexName)

}

trait HasServiceName extends HasServiceParams {

  val serviceName = new Param[String](this, "serviceName", "")

  def setServiceName(v: String): this.type = set(serviceName, v)

  def getServiceName: String = $(serviceName)

}

class AddDocuments(override val uid: String) extends CognitiveServicesBase(uid)
  with HasCognitiveServiceInput with HasInternalJsonOutputParser
  with HasActionCol with HasServiceName with HasIndexName with HasBatchSize {

  def this() = this(Identifiable.randomUID("AddDocuments"))

  setDefault(actionCol -> "@search.action")

  override val subscriptionKeyHeaderName = "api-key"

  setDefault(batchSize -> 100)

  override protected def getInternalTransformer(schema: StructType): PipelineModel = {
    val stages = Array(
      Lambda(df =>
        df.withColumnRenamed(getActionCol, "@search.action")
          .select(struct("*").alias("arr"))
      ),
      new FixedMiniBatchTransformer().setBuffered(false).setBatchSize(getBatchSize),
      Lambda(df =>
        df.select(struct(
          to_json(struct(col("arr").alias("value")))
        ).alias("input"))
      ),
      new SimpleHTTPTransformer()
        .setInputCol("input")
        .setOutputCol(getOutputCol)
        .setInputParser(getInternalInputParser(schema))
        .setOutputParser(getInternalOutputParser(schema))
        .setHandler(handlingFunc)
        .setConcurrency(getConcurrency)
        .setConcurrentTimeout(getConcurrentTimeout)
        .setErrorCol(getErrorCol)
    )

    NamespaceInjections.pipelineModel(stages)
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    if (get(url).isEmpty) {
      setUrl(s"https://$getServiceName.search.windows.net" +
        s"/indexes/$getIndexName/docs/index?api-version=2017-11-11")
    }
    super.transform(dataset)
  }

  override def prepareEntity: Row => Option[AbstractHttpEntity] = { row =>
    Some(new StringEntity(row.getString(0)))
  }

  override def responseDataType: DataType = ASResponses.schema
}
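For orientation (not part of the commit), a hypothetical direct use of the transformer; the key, service, index, and column names are placeholders, and these are exactly the setters chained by AzureSearchWriter.prepareDF below.

import com.microsoft.ml.spark.AddDocuments
import org.apache.spark.sql.DataFrame

// docsDf: one row per document, plus a string column holding the per-row action.
def pushDocs(docsDf: DataFrame): DataFrame =
  new AddDocuments()
    .setSubscriptionKey("<admin-api-key>")  // placeholder
    .setServiceName("my-search-service")    // placeholder: <serviceName>.search.windows.net
    .setIndexName("test-index")             // placeholder: name of an existing index
    .setActionCol("searchAction")           // values: upload / merge / mergeOrUpload / delete
    .setBatchSize(10)
    .transform(docsDf)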

private[ml] class StreamMaterializer2 extends ForeachWriter[Row] {

  override def open(partitionId: Long, version: Long): Boolean = true

  override def process(value: Row): Unit = ()

  override def close(errorOrNull: Throwable): Unit = ()

}

object AzureSearchWriter extends IndexParser {

  val logger: Logger = LogManager.getRootLogger

  private def prepareDF(df: DataFrame, options: Map[String, String] = Map()): DataFrame = {
    val applicableOptions = Set(
      "subscriptionKey", "actionCol", "serviceName", "indexName", "indexJson",
      "apiVersion", "batchSize"
    )

    options.keys.foreach(k =>
      assert(applicableOptions(k), s"$k not an applicable option ${applicableOptions.toList}"))

    val subscriptionKey = options("subscriptionKey")
    val actionCol = options.getOrElse("actionCol", "@search.action")
    val serviceName = options("serviceName")
    val indexJson = options("indexJson")
    val apiVersion = options.getOrElse("apiVersion", "2017-11-11")
    val indexName = parseIndexJson(indexJson).name.get
    val batchSize = options.getOrElse("batchSize", "100").toInt

    SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)

    checkSchemaParity(df, indexJson) match {
      case Success(_) => ()
      case Failure(e) =>
        println("Exception: Schema mismatch found in dataframe and json")
        throw e
    }

    new AddDocuments()
      .setSubscriptionKey(subscriptionKey)
      .setServiceName(serviceName)
      .setIndexName(indexName)
      .setActionCol(actionCol)
      .setBatchSize(batchSize)
      .transform(df)
  }

  private def checkSchemaParity(df: DataFrame, indexJson: String): Try[Boolean] = {
    val edmTypes = Map(
      "Edm.String" -> "string",
      "Collection(Edm.String)" -> "array<string>",
      "Edm.Boolean" -> "boolean",
      "Edm.Int64" -> "bigint",
      "Edm.Int32" -> "int",
      "Edm.Double" -> "double",
      "Edm.DateTimeOffset" -> "string",
      "Edm.GeographyPoint" -> "string")
    val fieldNames = parseIndexJson(indexJson).fields.map(f => f.name).toList.map(n => n.toString)
    val fieldTypes = parseIndexJson(indexJson).fields.map(f => f.`type`).toList.map(t => edmTypes(t.toString))

    // drop the first comparison element because the first field in the dataframe corresponds to the search action
    val isValid = df.schema.toList.drop(1).map(field =>
      fieldNames.contains(field.name) && fieldTypes.contains(field.dataType.simpleString)
    )

    val result = isValid.forall(x => x)

    if (result) Success(result)
    else Failure(new IllegalArgumentException)
  }

  def stream(df: DataFrame, options: Map[String, String] = Map()): DataStreamWriter[Row] = {
    prepareDF(df, options).writeStream.foreach(new StreamMaterializer2)
  }

  def write(df: DataFrame, options: Map[String, String] = Map()): Unit = {
    prepareDF(df, options).foreachPartition(it => it.foreach(_ => ()))
  }

  def stream(df: DataFrame, options: java.util.HashMap[String, String]): DataStreamWriter[Row] = {
    stream(df, options.asScala.toMap)
  }

  def write(df: DataFrame, options: java.util.HashMap[String, String]): Unit = {
    write(df, options.asScala.toMap)
  }

}
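Finally, a hypothetical end-to-end call for reference (not part of the commit), reusing the indexJson literal from the AzureSearchProtocol sketch above; the key and service name are placeholders, and only the option keys listed in applicableOptions are accepted.

// df: its first column holds the per-row action ("upload", "merge", "mergeOrUpload", "delete");
// the remaining columns match the index fields by name and Spark type
// (e.g. Edm.String -> string, Collection(Edm.String) -> array<string>).
AzureSearchWriter.write(df, Map(
  "subscriptionKey" -> "<admin-api-key>",  // placeholder
  "serviceName" -> "my-search-service",    // placeholder
  "indexJson" -> indexJson,                // e.g. the hypothetical literal sketched earlier
  "actionCol" -> "searchAction",           // defaults to "@search.action" when omitted
  "batchSize" -> "10"                      // note: all option values are strings
))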