Permalink
Browse files

Update to spark 2.4.0

  • Loading branch information...
mhamilton723 committed Nov 9, 2018
1 parent e5017db commit d87516ea4b0e4cf823bf8b8f5f69a09e83da432c
@@ -107,14 +107,20 @@ private[ml] object ComplexParamsWriter {
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq
.filter(!_.param.isInstanceOf[ComplexParam[_]]).asInstanceOf[Seq[ParamPair[Any]]]
val defaultParams = instance.defaultParamMap.toSeq
.filter(!_.param.isInstanceOf[ComplexParam[_]]).asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))
val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList)
val basicMetadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
("paramMap" -> jsonParams) ~
("defaultParamMap" -> jsonDefaultParams)
val metadata = extraMetadata match {
case Some(jObject) =>
basicMetadata ~ jObject
@@ -141,7 +147,7 @@ private[ml] class ComplexParamsReader[T] extends MLReader[T] {
val cls = Utils.classForName(metadata.className)
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
DefaultParamsReader.getAndSetParams(instance, metadata)
metadata.getAndSetParams(instance)
ComplexParamsReader.getAndSetComplexParams(instance, metadata, path)
instance.asInstanceOf[T]
}
@@ -20,15 +20,17 @@ import org.apache.http.impl.client.HttpClientBuilder
import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.HTTPServerUtils
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.streaming._
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.{DataSourceRegister, v2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.json4s.DefaultFormats
import org.json4s.jackson.Serialization
@@ -39,7 +41,7 @@ import scala.util.Try
object HTTPSourceStateHolder {
val factories: mutable.Map[(String, Int), HTTPContinuousDataReader] = mutable.Map()
val factories: mutable.Map[(String, Int), HTTPContinuousInputPartitionReader] = mutable.Map()
val serviceInformation: mutable.Map[String, ParHashSet[ServiceInfo]] = mutable.Map()
@@ -199,7 +201,7 @@ class HTTPContinuousReader(options: DataSourceOptions)
// Returns `offset` as this reader's start offset.
override def getStartOffset: Offset = offset
override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = {
val partitionStartMap = offset match {
case off: HTTPOffset => off.partitionToValue
case off =>
@@ -215,11 +217,10 @@ class HTTPContinuousReader(options: DataSourceOptions)
Range(0, numPartitions).map { i =>
val start = partitionStartMap(i)
HTTPContinuousDataReaderFactory(
HTTPContinuousInputPartition(
host, port, name, start, i, forwardingOptions,
DriverServiceUtils.getDriverHost, driverService.getAddress.getPort
)
.asInstanceOf[DataReaderFactory[Row]]
).asInstanceOf[InputPartition[InternalRow]]
}.asJava
}
@@ -239,7 +240,7 @@ case class HTTPReaderInfo(host: String,
startValue: Long,
partitionIndex: Int)
case class HTTPContinuousDataReaderFactory(host: String,
case class HTTPContinuousInputPartition(host: String,
port: Int,
name: String,
startValue: Long,
@@ -249,15 +250,27 @@ case class HTTPContinuousDataReaderFactory(host: String,
driverServicePort: Int
)
extends DataReaderFactory[Row] {
override def createDataReader(): DataReader[Row] =
new HTTPContinuousDataReader(
extends ContinuousInputPartition[InternalRow] {
/** Builds a reader for this partition given a partition offset.
  *
  * @param offset partition offset supplied by the caller; it is cast to
  *               `RateStreamPartitionOffset` and not otherwise used in this method
  * @return a new `HTTPContinuousInputPartitionReader` built from this partition's own fields
  */
override def createContinuousReader(
offset: PartitionOffset): InputPartitionReader[InternalRow] = {
// NOTE(review): the cast result is never read, so its only visible effect is to fail
// for an offset of another type. The reader below is started from `startValue`,
// not from a value carried by `offset` -- confirm this is intended.
val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset]
// The partition-index check below is disabled in the source.
//require(HTTPOffset.partition == partitionIndex,
// s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}")
new HTTPContinuousInputPartitionReader(
host, port, name, startValue, partitionIndex, forwardingOptions,
driverServiceHost, driverServicePort
)
}
/** Builds a reader for this partition when no partition offset is supplied.
  *
  * @return a new `HTTPContinuousInputPartitionReader` constructed with the same
  *         arguments that `createContinuousReader` passes
  */
override def createPartitionReader(): InputPartitionReader[InternalRow] =
new HTTPContinuousInputPartitionReader(
host, port, name, startValue, partitionIndex, forwardingOptions,
driverServiceHost, driverServicePort
)
}
class HTTPContinuousDataReader(host: String,
class HTTPContinuousInputPartitionReader(host: String,
port: Int,
name: String,
startValue: Long,
@@ -266,7 +279,7 @@ class HTTPContinuousDataReader(host: String,
driverServiceHost: String,
driverServicePort: Int)
extends ContinuousDataReader[Row] {
extends ContinuousInputPartitionReader[InternalRow] {
HTTPSourceStateHolder.factories.update((name, partitionIndex), this)
@@ -362,20 +375,21 @@ class HTTPContinuousDataReader(host: String,
private val routingTable: ParHashMap[String, HttpExchange] = ParHashMap()
private var currentValue = startValue
private var currentRow: Row = _
private var currentRow: InternalRow = _
private val requestDataToRow = HTTPRequestData.makeToRowConverter
private val requestDataToRow = HTTPRequestData.makeToInternalRowConverter
override def next(): Boolean = {
currentValue += 1
val request = requests.take()
val id = UUID.randomUUID().toString
routingTable.put(id, request)
currentRow = Row(Row(id, partitionIndex), requestDataToRow(HTTPRequestData.fromHTTPExchange(request)))
currentRow = InternalRow(InternalRow(
UTF8String.fromString(id), partitionIndex), requestDataToRow(HTTPRequestData.fromHTTPExchange(request)))
true
}
override def get: Row = currentRow
override def get: InternalRow = currentRow
override def close(): Unit = {
server.stop(0)
@@ -415,33 +429,34 @@ class HTTPWriter(schema: StructType, options: DataSourceOptions)
val replyColIndex: Int = schema.fieldIndex(replyCol)
assert(SparkSession.getActiveSession.isDefined)
def createWriterFactory(): DataWriterFactory[Row] = HTTPWriterFactory(idColIndex, replyColIndex, name)
def createWriterFactory(): DataWriterFactory[InternalRow] = HTTPWriterFactory(idColIndex, replyColIndex, name)
// No-op: the body is empty, so `epochId` and `messages` are ignored on commit.
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
// No-op: the body is empty, so `epochId` and `messages` are ignored on abort.
def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}
case class HTTPWriterFactory(idColIndex: Int, replyColIndex: Int, name: String) extends DataWriterFactory[Row] {
def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
case class HTTPWriterFactory(idColIndex: Int, replyColIndex: Int, name: String)
extends DataWriterFactory[InternalRow] {
/** Creates the data writer for one partition.
  *
  * @param partitionId forwarded to the new `HTTPDataWriter`
  * @param taskId      not used by this implementation
  * @param epochId     not used by this implementation
  * @return a new `HTTPDataWriter` built from `partitionId` and this factory's
  *         `idColIndex`, `replyColIndex` and `name`
  */
def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = {
new HTTPDataWriter(partitionId, idColIndex, replyColIndex, name)
}
}
class HTTPDataWriter(partitionId: Int, val idColIndex: Int,
val replyColIndex: Int, val name: String)
extends DataWriter[Row] with Logging {
extends DataWriter[InternalRow] with Logging {
var ids: mutable.ListBuffer[(String, Int)] = new mutable.ListBuffer[(String, Int)]()
val fromRow = HTTPResponseData.makeFromRowConverter
val fromRow = HTTPResponseData.makeFromInternalRowConverter
override def write(row: Row): Unit = {
val id = row.getStruct(idColIndex)
override def write(row: InternalRow): Unit = {
val id = row.getStruct(idColIndex,2)
val rid = id.getString(0)
val pid = id.getInt(1)
val reply = fromRow(row.getStruct(replyColIndex))
val reply = fromRow(row.getStruct(replyColIndex, 4))
HTTPSourceStateHolder.factories((name, pid)).replyTo(rid, reply)
ids.append((rid, pid))
}
@@ -101,7 +101,7 @@ class LightGBMRegressionModel(override val uid: String, model: LightGBMBooster,
set(featuresCol, featuresColName)
set(predictionCol, predictionColName)
override protected def predict(features: Vector): Double = {
/** Predicts a value for a single feature vector by scoring it with `model`.
  *
  * @param features feature vector passed to `model.score`
  * @return the result of `model.score(features, raw = false)`
  */
override def predict(features: Vector): Double = {
model.score(features, raw = false)
}
@@ -7,7 +7,6 @@ import org.apache.spark.SparkContext
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{ParamMap, UDFParam}
import org.apache.spark.ml.util.{ComplexParamsReadable, ComplexParamsWritable, Identifiable}
import org.apache.spark.sql.execution.python.{PythonUDF, UserDefinedPythonFunction}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -35,6 +35,8 @@ object Extras {
"com.microsoft.cntk" % "cntk" % cntkVer,
"org.openpnp" % "opencv" % "3.2.0-1",
"com.jcraft" % "jsch" % "0.1.54",
"com.jcraft" % "jsch" % "0.1.54",
"org.apache.httpcomponents" % "httpclient" % "4.5.6",
"com.microsoft.ml.lightgbm" % "lightgbmlib" % "2.1.250"
// needed for wasb access, but it collides with the version that comes with Spark,
// so it gets installed manually for now (see "tools/config.sh")
Oops, something went wrong.

0 comments on commit d87516e

Please sign in to comment.