[CONNECT][ML][WIP] Spark connect ML for scala client #40479
Closed: WeichenXu123 wants to merge 41 commits into apache:master from WeichenXu123:spark-connect-ml-scala-1
Commits (41):
684a9e3 update (WeichenXu123)
97ab924 update (WeichenXu123)
86f2fad update (WeichenXu123)
4f17d6c update (WeichenXu123)
6582037 merge master & fix (WeichenXu123)
ba4f580 update (WeichenXu123)
941550e update (WeichenXu123)
c4473a6 update (WeichenXu123)
606168d update (WeichenXu123)
1003787 update (WeichenXu123)
f9f3542 update (WeichenXu123)
ed24307 fix (WeichenXu123)
c1f9162 merge master (WeichenXu123)
e178de3 update (WeichenXu123)
130bd1e update (WeichenXu123)
eee1013 fix (WeichenXu123)
d72fba0 update (WeichenXu123)
36bc69b update (WeichenXu123)
870c994 merge master (WeichenXu123)
e500be8 update (WeichenXu123)
7c44e5c update (WeichenXu123)
e87aa53 merge master (WeichenXu123)
18876c2 update (WeichenXu123)
23744b8 model gc (WeichenXu123)
33be464 try_remote_ml_class (WeichenXu123)
bb43f01 format (WeichenXu123)
ec89d40 format (WeichenXu123)
2ebcf45 fix tests (WeichenXu123)
6e12a22 update (WeichenXu123)
9ae327b doctests (WeichenXu123)
36e6d33 merge master (WeichenXu123)
e5278cb Merge branch 'master' into spark-connect-ml-1 (WeichenXu123)
c80414d add proto comments (WeichenXu123)
66c472b address comments (WeichenXu123)
39b2f24 model_ref message (WeichenXu123)
2d4377d move to ml.connect (WeichenXu123)
e316a82 Merge branch 'master' into spark-connect-ml-1 (WeichenXu123)
caebf75 fix (WeichenXu123)
1059624 update (WeichenXu123)
5650fd0 update (WeichenXu123)
d3dc34b update (WeichenXu123)
208 changes: 208 additions & 0 deletions
connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/ConnectUtils.scala
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml | ||
|
||
import scala.collection.mutable | ||
import scala.reflect.ClassTag | ||
|
||
import org.apache.spark.connect.proto | ||
import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors} | ||
import org.apache.spark.ml.param.Params | ||
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} | ||
import org.apache.spark.sql.connect.common.{InvalidPlanInput, LiteralValueProtoConverter} | ||
import org.apache.spark.sql.types.Decimal | ||
import org.apache.spark.unsafe.types.CalendarInterval | ||
|
||
object ConnectUtils { | ||
|
||
def getInstanceParamsProto(instance: Params): proto.MlParams = { | ||
val builder = proto.MlParams.newBuilder() | ||
|
||
for (param <- instance.params) { | ||
// Write the default first so that an explicitly set value overrides it. | ||
instance.getDefault(param).foreach { value => | ||
builder.putParams( | ||
param.name, | ||
LiteralValueProtoConverter.toLiteralProto(value) | ||
) | ||
} | ||
instance.get(param).foreach { value => | ||
builder.putParams( | ||
param.name, | ||
LiteralValueProtoConverter.toLiteralProto(value) | ||
) | ||
} | ||
} | ||
builder.build() | ||
} | ||
|
||
def serializeResponseValue(data: Any): proto.MlCommandResponse = { | ||
data match { | ||
case v: Vector => serializeVector(v) | ||
case v: Matrix => serializeMatrix(v) | ||
case _: Byte | _: Short | _: Int | _: Long | _: Float | _: Double | _: Boolean | _: String | | ||
_: Array[_] => | ||
proto.MlCommandResponse | ||
.newBuilder() | ||
.setLiteral(LiteralValueProtoConverter.toLiteralProto(data)) | ||
.build() | ||
case _ => | ||
throw new IllegalArgumentException(s"Unsupported response value: $data") | ||
} | ||
} | ||
|
||
def serializeVector(data: Vector): proto.MlCommandResponse = { | ||
// TODO: Support sparse | ||
val values = data.toArray | ||
val denseBuilder = proto.Vector.Dense.newBuilder() | ||
for (i <- 0 until values.length) { | ||
denseBuilder.addValue(values(i)) | ||
} | ||
|
||
proto.MlCommandResponse | ||
.newBuilder() | ||
.setVector(proto.Vector.newBuilder().setDense(denseBuilder)) | ||
.build() | ||
} | ||
|
||
def deserializeVector(protoValue: proto.Vector): Vector = { | ||
// TODO: Support sparse | ||
Vectors.dense( | ||
protoValue.getDense.getValueList.stream().mapToDouble(_.doubleValue()).toArray | ||
) | ||
} | ||
|
||
def deserializeMatrix(protoValue: proto.Matrix): Matrix = { | ||
// TODO: Support sparse | ||
val denseProto = protoValue.getDense | ||
Matrices.dense( | ||
denseProto.getNumRows, | ||
denseProto.getNumCols, | ||
denseProto.getValueList.stream().mapToDouble(_.doubleValue()).toArray | ||
) | ||
} | ||
|
||
def serializeMatrix(data: Matrix): proto.MlCommandResponse = { | ||
// TODO: Support sparse | ||
// TODO: optimize transposed case | ||
val denseBuilder = proto.Matrix.Dense.newBuilder() | ||
val values = data.toArray | ||
for (i <- 0 until values.length) { | ||
denseBuilder.addValue(values(i)) | ||
} | ||
denseBuilder.setNumCols(data.numCols) | ||
denseBuilder.setNumRows(data.numRows) | ||
denseBuilder.setIsTransposed(false) | ||
proto.MlCommandResponse | ||
.newBuilder() | ||
.setMatrix(proto.Matrix.newBuilder().setDense(denseBuilder)) | ||
.build() | ||
} | ||
|
||
def deserializeResponseValue(protoValue: proto.MlCommandResponse): Any = { | ||
protoValue.getMlCommandResponseTypeCase match { | ||
case proto.MlCommandResponse.MlCommandResponseTypeCase.LITERAL => | ||
deserializeLiteral(protoValue.getLiteral) | ||
case proto.MlCommandResponse.MlCommandResponseTypeCase.VECTOR => | ||
deserializeVector(protoValue.getVector) | ||
case proto.MlCommandResponse.MlCommandResponseTypeCase.MATRIX => | ||
deserializeMatrix(protoValue.getMatrix) | ||
case proto.MlCommandResponse.MlCommandResponseTypeCase.MODEL_REF => | ||
ModelRef.fromProto(protoValue.getModelRef) | ||
case _ => | ||
throw new IllegalArgumentException( | ||
s"Unsupported response type: ${protoValue.getMlCommandResponseTypeCase}") | ||
} | ||
} | ||
|
||
def deserializeLiteral(protoValue: proto.Expression.Literal): Any = { | ||
protoValue.getLiteralTypeCase match { | ||
case proto.Expression.Literal.LiteralTypeCase.INTEGER => | ||
protoValue.getInteger | ||
case proto.Expression.Literal.LiteralTypeCase.LONG => | ||
protoValue.getLong | ||
case proto.Expression.Literal.LiteralTypeCase.FLOAT => | ||
protoValue.getFloat | ||
case proto.Expression.Literal.LiteralTypeCase.DOUBLE => | ||
protoValue.getDouble | ||
case proto.Expression.Literal.LiteralTypeCase.STRING => | ||
protoValue.getString | ||
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => | ||
protoValue.getBoolean | ||
case proto.Expression.Literal.LiteralTypeCase.ARRAY => | ||
toArrayData(protoValue.getArray) | ||
case _ => | ||
throw new IllegalArgumentException( | ||
s"Unsupported literal type: ${protoValue.getLiteralTypeCase}") | ||
} | ||
} | ||
|
||
private def toArrayData(array: proto.Expression.Literal.Array): Any = { | ||
def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit | ||
tag: ClassTag[T]): Array[T] = { | ||
val builder = mutable.ArrayBuilder.make[T] | ||
val elementList = array.getElementsList | ||
builder.sizeHint(elementList.size()) | ||
val iter = elementList.iterator() | ||
while (iter.hasNext) { | ||
builder += converter(iter.next()) | ||
} | ||
builder.result() | ||
} | ||
|
||
val elementType = array.getElementType | ||
if (elementType.hasShort) { | ||
makeArrayData(v => v.getShort.toShort) | ||
} else if (elementType.hasInteger) { | ||
makeArrayData(v => v.getInteger) | ||
} else if (elementType.hasLong) { | ||
makeArrayData(v => v.getLong) | ||
} else if (elementType.hasDouble) { | ||
makeArrayData(v => v.getDouble) | ||
} else if (elementType.hasByte) { | ||
makeArrayData(v => v.getByte.toByte) | ||
} else if (elementType.hasFloat) { | ||
makeArrayData(v => v.getFloat) | ||
} else if (elementType.hasBoolean) { | ||
makeArrayData(v => v.getBoolean) | ||
} else if (elementType.hasString) { | ||
makeArrayData(v => v.getString) | ||
} else if (elementType.hasBinary) { | ||
makeArrayData(v => v.getBinary.toByteArray) | ||
} else if (elementType.hasDate) { | ||
makeArrayData(v => DateTimeUtils.toJavaDate(v.getDate)) | ||
} else if (elementType.hasTimestamp) { | ||
makeArrayData(v => DateTimeUtils.toJavaTimestamp(v.getTimestamp)) | ||
} else if (elementType.hasTimestampNtz) { | ||
makeArrayData(v => DateTimeUtils.microsToLocalDateTime(v.getTimestampNtz)) | ||
} else if (elementType.hasDayTimeInterval) { | ||
makeArrayData(v => IntervalUtils.microsToDuration(v.getDayTimeInterval)) | ||
} else if (elementType.hasYearMonthInterval) { | ||
makeArrayData(v => IntervalUtils.monthsToPeriod(v.getYearMonthInterval)) | ||
} else if (elementType.hasDecimal) { | ||
makeArrayData(v => Decimal(v.getDecimal.getValue)) | ||
} else if (elementType.hasCalendarInterval) { | ||
makeArrayData(v => { | ||
val interval = v.getCalendarInterval | ||
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds) | ||
}) | ||
} else if (elementType.hasArray) { | ||
makeArrayData(v => toArrayData(v.getArray)) | ||
} else { | ||
throw InvalidPlanInput(s"Unsupported Literal Type: $elementType") | ||
} | ||
} | ||
|
||
} |
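A note on getInstanceParamsProto: it writes a parameter's explicitly set value and its default into the same map under param.name, so whichever write happens last is the value that ends up in the message. The sketch below illustrates that ordering with plain Scala collections only. ParamState and ParamMergeSketch are hypothetical names made up for this illustration; they are not part of this PR or of Spark, and the real code works on Params and proto.MlParams instead.

```scala
// Hypothetical stand-in for one entry of a Params instance:
// a parameter may carry a default, an explicitly set value, both, or neither.
final case class ParamState(name: String, default: Option[Any], explicit: Option[Any])

object ParamMergeSketch {
  // Collect defaults first and explicit values second. When the pairs are
  // turned into a Map, a later pair replaces an earlier pair with the same
  // key, so explicit settings win over defaults.
  def merge(params: Seq[ParamState]): Map[String, Any] = {
    val defaults = params.flatMap(p => p.default.map(v => p.name -> v))
    val explicit = params.flatMap(p => p.explicit.map(v => p.name -> v))
    (defaults ++ explicit).toMap
  }
}
```

Reversing the two collections in merge would make defaults shadow user settings, which is the ordering question worth checking in the loop above.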
@zhengruifeng Could you help move this utility function to "common" project ?
A util function that converts a proto.Expression.Literal to an Any value? Let me take a look.
#40485
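For readers following the thread, here is a rough sketch of what a literal-to-Any conversion of the kind discussed above looks like. It is not the code from this PR or from #40485. Lit and its subclasses are a hypothetical, much simplified stand-in for proto.Expression.Literal, elementKind is an assumed string tag standing in for the proto array's element type, and LiteralSketch is an invented name. The part it shares with toArrayData in the diff is the idea of dispatching on the declared element type and building a typed array through a ClassTag.

```scala
import scala.reflect.ClassTag

// Hypothetical, simplified stand-in for proto.Expression.Literal.
sealed trait Lit
final case class IntLit(value: Int) extends Lit
final case class DoubleLit(value: Double) extends Lit
final case class BoolLit(value: Boolean) extends Lit
final case class StrLit(value: String) extends Lit
// elementKind is an assumed tag for the declared element type of the array.
final case class ArrLit(elementKind: String, elements: Seq[Lit]) extends Lit

object LiteralSketch {
  // Convert a single literal to a plain Scala value.
  def toAny(lit: Lit): Any = lit match {
    case IntLit(v) => v
    case DoubleLit(v) => v
    case BoolLit(v) => v
    case StrLit(v) => v
    case arr: ArrLit => toArray(arr)
  }

  // Build an Array[T]; the ClassTag lets primitives end up in primitive arrays.
  private def build[T: ClassTag](elements: Seq[Lit])(convert: Lit => T): Array[T] = {
    val builder = Array.newBuilder[T]
    builder.sizeHint(elements.size)
    elements.foreach(e => builder += convert(e))
    builder.result()
  }

  private def unexpected(lit: Lit): Nothing =
    throw new IllegalArgumentException(s"Unexpected array element: $lit")

  // Dispatch on the declared element type, then convert every element.
  def toArray(arr: ArrLit): Any = arr.elementKind match {
    case "int" => build[Int](arr.elements) {
      case IntLit(v) => v
      case other => unexpected(other)
    }
    case "double" => build[Double](arr.elements) {
      case DoubleLit(v) => v
      case other => unexpected(other)
    }
    case "boolean" => build[Boolean](arr.elements) {
      case BoolLit(v) => v
      case other => unexpected(other)
    }
    case "string" => build[String](arr.elements) {
      case StrLit(v) => v
      case other => unexpected(other)
    }
    case "array" => build[Any](arr.elements) {
      case nested: ArrLit => toArray(nested)
      case other => unexpected(other)
    }
    case other =>
      throw new IllegalArgumentException(s"Unsupported element type: $other")
  }
}
```

Keeping the conversion in one place, as the comment requesting a move to the "common" project suggests, would let client and server share a single mapping, so a mismatch such as reading a boolean through the integer accessor only has to be fixed once.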