[CONNECT][ML][WIP] Spark connect ML for scala client #40479

Closed

Commits (41)
All commits are by WeichenXu123.

684a9e3  update (Mar 3, 2023)
97ab924  update (Mar 6, 2023)
86f2fad  update (Mar 6, 2023)
4f17d6c  update (Mar 6, 2023)
6582037  merge master & fix (Mar 7, 2023)
ba4f580  update (Mar 8, 2023)
941550e  update (Mar 8, 2023)
c4473a6  update (Mar 9, 2023)
606168d  update (Mar 9, 2023)
1003787  update (Mar 9, 2023)
f9f3542  update (Mar 9, 2023)
ed24307  fix (Mar 10, 2023)
c1f9162  merge master (Mar 10, 2023)
e178de3  update (Mar 10, 2023)
130bd1e  update (Mar 10, 2023)
eee1013  fix (Mar 11, 2023)
d72fba0  update (Mar 11, 2023)
36bc69b  update (Mar 13, 2023)
870c994  merge master (Mar 13, 2023)
e500be8  update (Mar 13, 2023)
7c44e5c  update (Mar 13, 2023)
e87aa53  merge master (Mar 13, 2023)
18876c2  update (Mar 13, 2023)
23744b8  model gc (Mar 13, 2023)
33be464  try_remote_ml_class (Mar 13, 2023)
bb43f01  format (Mar 14, 2023)
ec89d40  format (Mar 14, 2023)
2ebcf45  fix tests (Mar 14, 2023)
6e12a22  update (Mar 14, 2023)
9ae327b  doctests (Mar 14, 2023)
36e6d33  merge master (Mar 14, 2023)
e5278cb  Merge branch 'master' into spark-connect-ml-1 (Mar 14, 2023)
c80414d  add proto comments (Mar 15, 2023)
66c472b  address comments (Mar 15, 2023)
39b2f24  model_ref message (Mar 16, 2023)
2d4377d  move to ml.connect (Mar 16, 2023)
e316a82  Merge branch 'master' into spark-connect-ml-1 (Mar 16, 2023)
caebf75  fix (Mar 16, 2023)
1059624  update (Mar 19, 2023)
5650fd0  update (Mar 19, 2023)
d3dc34b  update (Mar 19, 2023)
@@ -0,0 +1,208 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.connect.proto
import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.ml.param.Params
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.connect.common.{InvalidPlanInput, LiteralValueProtoConverter}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.CalendarInterval

object ConnectUtils {

def getInstanceParamsProto(instance: Params): proto.MlParams = {
val builder = proto.MlParams.newBuilder()

for (param <- instance.params) {
// Put the default value first so that an explicitly set value takes precedence.
instance.getDefault(param).foreach { value =>
builder.putParams(
param.name,
LiteralValueProtoConverter.toLiteralProto(value)
)
}
instance.get(param).foreach { value =>
builder.putParams(
param.name,
LiteralValueProtoConverter.toLiteralProto(value)
)
}
}
builder.build()
}

def serializeResponseValue(data: Any): proto.MlCommandResponse = {
data match {
case v: Vector => serializeVector(v)
case v: Matrix => serializeMatrix(v)
case _: Byte | _: Short | _: Int | _: Long | _: Float | _: Double | _: Boolean | _: String |
_: Array[_] =>
proto.MlCommandResponse
.newBuilder()
.setLiteral(LiteralValueProtoConverter.toLiteralProto(data))
.build()
case other =>
throw new IllegalArgumentException(s"Unsupported response value type: ${other.getClass}")
}
}

def serializeVector(data: Vector): proto.MlCommandResponse = {
// TODO: Support sparse
val values = data.toArray
val denseBuilder = proto.Vector.Dense.newBuilder()
for (i <- 0 until values.length) {
denseBuilder.addValue(values(i))
}

proto.MlCommandResponse
.newBuilder()
.setVector(proto.Vector.newBuilder().setDense(denseBuilder))
.build()
}

def deserializeVector(protoValue: proto.Vector): Vector = {
// TODO: Support sparse
Vectors.dense(
protoValue.getDense.getValueList.stream().mapToDouble(_.doubleValue()).toArray
)
}

def deserializeMatrix(protoValue: proto.Matrix): Matrix = {
// TODO: Support sparse
val denseProto = protoValue.getDense
Matrices.dense(
denseProto.getNumRows,
denseProto.getNumCols,
denseProto.getValueList.stream().mapToDouble(_.doubleValue()).toArray
)
}

def serializeMatrix(data: Matrix): proto.MlCommandResponse = {
// TODO: Support sparse
// TODO: optimize transposed case
val denseBuilder = proto.Matrix.Dense.newBuilder()
val values = data.toArray
for (i <- 0 until values.length) {
denseBuilder.addValue(values(i))
}
denseBuilder.setNumCols(data.numCols)
denseBuilder.setNumRows(data.numRows)
denseBuilder.setIsTransposed(false)
proto.MlCommandResponse
.newBuilder()
.setMatrix(proto.Matrix.newBuilder().setDense(denseBuilder))
.build()
}

def deserializeResponseValue(protoValue: proto.MlCommandResponse): Any = {
protoValue.getMlCommandResponseTypeCase match {
case proto.MlCommandResponse.MlCommandResponseTypeCase.LITERAL =>
deserializeLiteral(protoValue.getLiteral)
case proto.MlCommandResponse.MlCommandResponseTypeCase.VECTOR =>
deserializeVector(protoValue.getVector)
case proto.MlCommandResponse.MlCommandResponseTypeCase.MATRIX =>
deserializeMatrix(protoValue.getMatrix)
case proto.MlCommandResponse.MlCommandResponseTypeCase.MODEL_REF =>
ModelRef.fromProto(protoValue.getModelRef)
case other =>
throw new IllegalArgumentException(s"Unsupported MlCommandResponse type: $other")
}
}

Review thread on deserializeLiteral:

WeichenXu123 (author): @zhengruifeng Could you help move this utility function to the "common" project?

Contributor: A util function that converts a proto.Expression.Literal to an Any value? Let me take a look.

def deserializeLiteral(protoValue: proto.Expression.Literal): Any = {
protoValue.getLiteralTypeCase match {
case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
protoValue.getInteger
case proto.Expression.Literal.LiteralTypeCase.LONG =>
protoValue.getLong
case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
protoValue.getFloat
case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
protoValue.getDouble
case proto.Expression.Literal.LiteralTypeCase.STRING =>
protoValue.getString
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
protoValue.getBoolean
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
toArrayData(protoValue.getArray)
case other =>
throw new IllegalArgumentException(s"Unsupported literal type: $other")
}
}

private def toArrayData(array: proto.Expression.Literal.Array): Any = {
def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
tag: ClassTag[T]): Array[T] = {
val builder = mutable.ArrayBuilder.make[T]
val elementList = array.getElementsList
builder.sizeHint(elementList.size())
val iter = elementList.iterator()
while (iter.hasNext) {
builder += converter(iter.next())
}
builder.result()
}

val elementType = array.getElementType
if (elementType.hasShort) {
makeArrayData(v => v.getShort.toShort)
} else if (elementType.hasInteger) {
makeArrayData(v => v.getInteger)
} else if (elementType.hasLong) {
makeArrayData(v => v.getLong)
} else if (elementType.hasDouble) {
makeArrayData(v => v.getDouble)
} else if (elementType.hasByte) {
makeArrayData(v => v.getByte.toByte)
} else if (elementType.hasFloat) {
makeArrayData(v => v.getFloat)
} else if (elementType.hasBoolean) {
makeArrayData(v => v.getBoolean)
} else if (elementType.hasString) {
makeArrayData(v => v.getString)
} else if (elementType.hasBinary) {
makeArrayData(v => v.getBinary.toByteArray)
} else if (elementType.hasDate) {
makeArrayData(v => DateTimeUtils.toJavaDate(v.getDate))
} else if (elementType.hasTimestamp) {
makeArrayData(v => DateTimeUtils.toJavaTimestamp(v.getTimestamp))
} else if (elementType.hasTimestampNtz) {
makeArrayData(v => DateTimeUtils.microsToLocalDateTime(v.getTimestampNtz))
} else if (elementType.hasDayTimeInterval) {
makeArrayData(v => IntervalUtils.microsToDuration(v.getDayTimeInterval))
} else if (elementType.hasYearMonthInterval) {
makeArrayData(v => IntervalUtils.monthsToPeriod(v.getYearMonthInterval))
} else if (elementType.hasDecimal) {
makeArrayData(v => Decimal(v.getDecimal.getValue))
} else if (elementType.hasCalendarInterval) {
makeArrayData(v => {
val interval = v.getCalendarInterval
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
})
} else if (elementType.hasArray) {
makeArrayData(v => toArrayData(v.getArray))
} else {
throw InvalidPlanInput(s"Unsupported literal array element type: $elementType")
}
}

}
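As a quick reference, here is a minimal sketch (not part of this diff) of how the vector helpers above are expected to round-trip a dense vector:

```scala
import org.apache.spark.ml.ConnectUtils
import org.apache.spark.ml.linalg.Vectors

// Serialize a dense vector into an MlCommandResponse and read it back.
val original = Vectors.dense(1.0, 2.0, 3.0)
val response = ConnectUtils.serializeVector(original)
val restored = ConnectUtils.deserializeVector(response.getVector)
assert(restored == original)
```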
@@ -20,8 +20,9 @@ package org.apache.spark.ml
import scala.annotation.varargs

import org.apache.spark.annotation.Since
import org.apache.spark.connect.proto
import org.apache.spark.ml.param.{ParamMap, ParamPair}
import org.apache.spark.sql.{Dataset, SparkSession}

/**
* Abstract class for estimators that fit models to data.
@@ -72,7 +73,26 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* Fits a model to the input data.
*/
@Since("3.5.0")
def fit(dataset: Dataset[_]): M = {
// WIP: the stage name is hardcoded and the input dataset is not yet attached to the Fit command.
val fitProto = proto.MlCommand.Fit.newBuilder()
.setEstimator(
proto.MlStage.newBuilder()
.setName("LogisticRegression")
.setType(proto.MlStage.StageType.ESTIMATOR)
.setUid(uid)
.setParams(ConnectUtils.getInstanceParamsProto(this))
).build()

val resp = SparkSession.active.executeMl(
proto.MlCommand.newBuilder().setFit(fitProto).build()
)

val modelRefProto = resp.getModelInfo.getModelRef
val model = createModel(ModelRef.fromProto(modelRefProto))
copyValues(model)
}

def createModel(modelRef: ModelRef): M

/**
* Fits multiple models to the input data with multiple sets of parameters. The default
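To show how the new hook is meant to be wired up, here is a hypothetical sketch (the names MyEstimator and MyModel are illustrative, not part of this PR) of a connect-side estimator that only supplies createModel, while the shared fit above issues the MlCommand.Fit RPC:

```scala
import org.apache.spark.ml.{Estimator, Model, ModelRef}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.StructType

// Hypothetical estimator: fit() in the base class talks to the server and then
// calls createModel with the server-side model reference it gets back.
class MyEstimator(override val uid: String) extends Estimator[MyModel] {
  override def createModel(modelRef: ModelRef): MyModel = {
    val model = new MyModel(uid)
    model.modelRef = modelRef
    model
  }

  override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = schema
}

// Hypothetical model: transform() and copy() come from the Model base class below.
class MyModel(override val uid: String) extends Model[MyModel] {
  override def transformSchema(schema: StructType): StructType = schema
}
```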
@@ -18,7 +18,24 @@
package org.apache.spark.ml

import org.apache.spark.annotation.Since
import org.apache.spark.connect.proto
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}


// TODO: Use a Cleaner to release the corresponding server-side model object when this reference is garbage collected.
case class ModelRef(refId: String)

object ModelRef {
def fromProto(protoValue: proto.ModelRef): ModelRef = {
ModelRef(protoValue.getId)
}

def toProto(modelRef: ModelRef): proto.ModelRef = {
proto.ModelRef.newBuilder().setId(modelRef.refId).build()
}
}
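For the TODO above, a minimal sketch of what a Cleaner-based cleanup could look like; releaseServerSideModel is a hypothetical RPC, not something defined in this PR:

```scala
import java.lang.ref.Cleaner

// Sketch only: the cleanup action captures just the refId string (never the model
// itself), so the client-side model object can still be garbage collected.
object ModelCleaner {
  private val cleaner = Cleaner.create()

  def track(model: Model[_], refId: String, releaseServerSideModel: String => Unit): Unit = {
    cleaner.register(model, () => releaseServerSideModel(refId))
  }
}
```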


/**
* A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
@@ -28,6 +45,8 @@ import org.apache.spark.ml.param.ParamMap
*/
abstract class Model[M <: Model[M]] extends Transformer {

@transient var modelRef: ModelRef = _

/**
* The parent estimator that produced this model.
* @note
@@ -49,5 +68,74 @@ abstract class Model[M <: Model[M]] extends Transformer {
def hasParent: Boolean = parent != null

@Since("3.5.0")
override def copy(extra: ParamMap): M = {
val cmdBuilder = proto.MlCommand.newBuilder()

cmdBuilder.getCopyModelBuilder
.setModelRef(ModelRef.toProto(modelRef))

val resp = SparkSession.active.executeMl(cmdBuilder.build())
val newRef = ConnectUtils.deserializeResponseValue(resp).asInstanceOf[ModelRef]
val newModel = defaultCopy(extra).asInstanceOf[M]
newModel.modelRef = newRef
newModel
}

def transform(dataset: Dataset[_]): DataFrame = {
dataset.sparkSession.newDataFrame { builder =>
builder.getMlRelationBuilder.getModelTransformBuilder
.setInput(dataset.plan.getRoot)
.setModelRef(ModelRef.toProto(modelRef))
.setParams(ConnectUtils.getInstanceParamsProto(this))
}
}

protected def getModelAttr(name: String): Any = {
val cmdBuilder = proto.MlCommand.newBuilder()

cmdBuilder.getFetchModelAttrBuilder
.setModelRef(ModelRef.toProto(modelRef))
.setName(name)

val resp = SparkSession.active.executeMl(cmdBuilder.build())
ConnectUtils.deserializeResponseValue(resp)
}

}
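As an illustration of getModelAttr, a hypothetical connect-side model could expose typed accessors over the server-side model's attributes (the class and attribute names here are examples only):

```scala
import org.apache.spark.ml.Model
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.StructType

// Sketch only: each accessor fetches one attribute of the server-side model.
class MyLinearModel(override val uid: String) extends Model[MyLinearModel] {
  def coefficients: Vector = getModelAttr("coefficients").asInstanceOf[Vector]

  def intercept: Double = getModelAttr("intercept").asInstanceOf[Double]

  override def transformSchema(schema: StructType): StructType = schema
}
```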

trait ModelSummary {

protected def model: Model[_]

protected def datasetOpt: Option[Dataset[_]]

protected def getModelSummaryAttr(name: String): Any = {
val cmdBuilder = proto.MlCommand.newBuilder()

val fetchCmdBuilder = cmdBuilder.getFetchModelSummaryAttrBuilder

fetchCmdBuilder
.setModelRef(ModelRef.toProto(model.modelRef))
.setName(name)
.setParams(ConnectUtils.getInstanceParamsProto(model))

datasetOpt.foreach(x => fetchCmdBuilder.setEvaluationDataset(x.plan.getRoot))

val resp = SparkSession.active.executeMl(cmdBuilder.build())
ConnectUtils.deserializeResponseValue(resp)
}

protected def getModelSummaryAttrDataFrame(name: String): DataFrame = {
SparkSession.active.newDataFrame { builder =>
builder.getMlRelationBuilder.getModelSummaryAttrBuilder
.setName(name)
.setModelRef(ModelRef.toProto(model.modelRef))
.setParams(ConnectUtils.getInstanceParamsProto(model))
datasetOpt.foreach { x =>
builder.getMlRelationBuilder.getModelSummaryAttrBuilder.setEvaluationDataset(
x.plan.getRoot
)
}
}
}
}
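Finally, a hypothetical sketch of how a concrete summary could be built on these helpers; the attribute names "accuracy" and "predictions" are illustrative only:

```scala
import org.apache.spark.ml.{Model, ModelSummary}
import org.apache.spark.sql.{DataFrame, Dataset}

// Sketch only: a concrete summary exposes typed accessors over the generic
// getModelSummaryAttr / getModelSummaryAttrDataFrame plumbing above.
class MyClassificationSummary(
    override protected val model: Model[_],
    override protected val datasetOpt: Option[Dataset[_]]) extends ModelSummary {

  def accuracy: Double = getModelSummaryAttr("accuracy").asInstanceOf[Double]

  def predictions: DataFrame = getModelSummaryAttrDataFrame("predictions")
}
```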