[SPARK-42870][CONNECT] Move toCatalystValue to connect-common #40485

Closed

Changes from all commits
@@ -22,6 +22,9 @@ import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
 import java.time._
 
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
 import com.google.protobuf.ByteString
 
 import org.apache.spark.connect.proto
@@ -142,4 +145,115 @@ object LiteralValueProtoConverter {
      case _ =>
        throw new UnsupportedOperationException(s"Unsupported component type $clz in arrays.")
    }
+
+  def toCatalystValue(literal: proto.Expression.Literal): Any = {
+    literal.getLiteralTypeCase match {
+      case proto.Expression.Literal.LiteralTypeCase.NULL => null
+
+      case proto.Expression.Literal.LiteralTypeCase.BINARY => literal.getBinary.toByteArray
+
+      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => literal.getBoolean
+
+      case proto.Expression.Literal.LiteralTypeCase.BYTE => literal.getByte.toByte
+
+      case proto.Expression.Literal.LiteralTypeCase.SHORT => literal.getShort.toShort
+
+      case proto.Expression.Literal.LiteralTypeCase.INTEGER => literal.getInteger
+
+      case proto.Expression.Literal.LiteralTypeCase.LONG => literal.getLong
+
+      case proto.Expression.Literal.LiteralTypeCase.FLOAT => literal.getFloat
+
+      case proto.Expression.Literal.LiteralTypeCase.DOUBLE => literal.getDouble
+
+      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+        Decimal(literal.getDecimal.getValue)
+
+      case proto.Expression.Literal.LiteralTypeCase.STRING => literal.getString
+
+      case proto.Expression.Literal.LiteralTypeCase.DATE =>
+        DateTimeUtils.toJavaDate(literal.getDate)
+
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+        DateTimeUtils.toJavaTimestamp(literal.getTimestamp)
+
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+        DateTimeUtils.microsToLocalDateTime(literal.getTimestampNtz)
+
+      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+        new CalendarInterval(
+          literal.getCalendarInterval.getMonths,
+          literal.getCalendarInterval.getDays,
+          literal.getCalendarInterval.getMicroseconds)
+
+      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+        IntervalUtils.monthsToPeriod(literal.getYearMonthInterval)
+
+      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+        IntervalUtils.microsToDuration(literal.getDayTimeInterval)
+
+      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+        toCatalystArray(literal.getArray)
+
+      case other =>
+        throw new UnsupportedOperationException(
+          s"Unsupported Literal Type: ${other.getNumber} (${other.name})")
+    }
+  }
+
+  def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = {
+    def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
+        tag: ClassTag[T]): Array[T] = {
+      val builder = mutable.ArrayBuilder.make[T]
+      val elementList = array.getElementsList
+      builder.sizeHint(elementList.size())
+      val iter = elementList.iterator()
+      while (iter.hasNext) {
+        builder += converter(iter.next())
+      }
+      builder.result()
+    }
+
+    val elementType = array.getElementType
+    if (elementType.hasShort) {
+      makeArrayData(v => v.getShort.toShort)
+    } else if (elementType.hasInteger) {
+      makeArrayData(v => v.getInteger)
+    } else if (elementType.hasLong) {
+      makeArrayData(v => v.getLong)
+    } else if (elementType.hasDouble) {
+      makeArrayData(v => v.getDouble)
+    } else if (elementType.hasByte) {
+      makeArrayData(v => v.getByte.toByte)
+    } else if (elementType.hasFloat) {
+      makeArrayData(v => v.getFloat)
+    } else if (elementType.hasBoolean) {
+      makeArrayData(v => v.getBoolean)
+    } else if (elementType.hasString) {
+      makeArrayData(v => v.getString)
+    } else if (elementType.hasBinary) {
+      makeArrayData(v => v.getBinary.toByteArray)
+    } else if (elementType.hasDate) {
+      makeArrayData(v => DateTimeUtils.toJavaDate(v.getDate))
+    } else if (elementType.hasTimestamp) {
+      makeArrayData(v => DateTimeUtils.toJavaTimestamp(v.getTimestamp))
+    } else if (elementType.hasTimestampNtz) {
+      makeArrayData(v => DateTimeUtils.microsToLocalDateTime(v.getTimestampNtz))
+    } else if (elementType.hasDayTimeInterval) {
+      makeArrayData(v => IntervalUtils.microsToDuration(v.getDayTimeInterval))
+    } else if (elementType.hasYearMonthInterval) {
+      makeArrayData(v => IntervalUtils.monthsToPeriod(v.getYearMonthInterval))
+    } else if (elementType.hasDecimal) {
+      makeArrayData(v => Decimal(v.getDecimal.getValue))
+    } else if (elementType.hasCalendarInterval) {
+      makeArrayData(v => {
+        val interval = v.getCalendarInterval
+        new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
+      })
+    } else if (elementType.hasArray) {
+      makeArrayData(v => toCatalystArray(v.getArray))
+    } else {
+      throw new UnsupportedOperationException(s"Unsupported Literal Type: $elementType)")
+    }
+  }
 }
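
For orientation, here is a minimal round-trip sketch of the relocated API. It assumes only what the diff shows: the existing toLiteralProto encoder in the same object (exercised by the test suite below) and the toCatalystValue decoder added above; the wrapper object and main method are illustrative.

// Round-trip a Scala value through the shared converter in connect-common.
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter

object LiteralRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Encode to a proto literal, then decode back to a Scala value.
    val lit: proto.Expression.Literal = LiteralValueProtoConverter.toLiteralProto("spark")
    assert(LiteralValueProtoConverter.toCatalystValue(lit) == "spark")

    // Literals can also be built directly with the protobuf builder API.
    val boolLit = proto.Expression.Literal.newBuilder().setBoolean(true).build()
    assert(LiteralValueProtoConverter.toCatalystValue(boolLit) == true)
  }
}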
@@ -17,13 +17,9 @@
 
 package org.apache.spark.sql.connect.planner
 
-import scala.collection.mutable
-import scala.reflect.ClassTag
-
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -103,7 +99,7 @@ object LiteralExpressionProtoConverter {
 
       case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
         expressions.Literal.create(
-          toArrayData(lit.getArray),
+          LiteralValueProtoConverter.toCatalystArray(lit.getArray),
           ArrayType(DataTypeProtoConverter.toCatalystType(lit.getArray.getElementType)))
 
       case _ =>
@@ -112,68 +108,4 @@
           s"(${lit.getLiteralTypeCase.name})")
     }
   }
-
-  def toCatalystValue(lit: proto.Expression.Literal): Any = {
-    lit.getLiteralTypeCase match {
-      case proto.Expression.Literal.LiteralTypeCase.STRING => lit.getString
-
-      case _ => toCatalystExpression(lit).value
-    }
-  }
-
-  private def toArrayData(array: proto.Expression.Literal.Array): Any = {
-    def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
-        tag: ClassTag[T]): Array[T] = {
-      val builder = mutable.ArrayBuilder.make[T]
-      val elementList = array.getElementsList
-      builder.sizeHint(elementList.size())
-      val iter = elementList.iterator()
-      while (iter.hasNext) {
-        builder += converter(iter.next())
-      }
-      builder.result()
-    }
-
-    val elementType = array.getElementType
-    if (elementType.hasShort) {
-      makeArrayData(v => v.getShort.toShort)
-    } else if (elementType.hasInteger) {
-      makeArrayData(v => v.getInteger)
-    } else if (elementType.hasLong) {
-      makeArrayData(v => v.getLong)
-    } else if (elementType.hasDouble) {
-      makeArrayData(v => v.getDouble)
-    } else if (elementType.hasByte) {
-      makeArrayData(v => v.getByte.toByte)
-    } else if (elementType.hasFloat) {
-      makeArrayData(v => v.getFloat)
-    } else if (elementType.hasBoolean) {
-      makeArrayData(v => v.getBoolean)
-    } else if (elementType.hasString) {
-      makeArrayData(v => v.getString)
-    } else if (elementType.hasBinary) {
-      makeArrayData(v => v.getBinary.toByteArray)
-    } else if (elementType.hasDate) {
-      makeArrayData(v => DateTimeUtils.toJavaDate(v.getDate))
-    } else if (elementType.hasTimestamp) {
-      makeArrayData(v => DateTimeUtils.toJavaTimestamp(v.getTimestamp))
-    } else if (elementType.hasTimestampNtz) {
-      makeArrayData(v => DateTimeUtils.microsToLocalDateTime(v.getTimestampNtz))
-    } else if (elementType.hasDayTimeInterval) {
-      makeArrayData(v => IntervalUtils.microsToDuration(v.getDayTimeInterval))
-    } else if (elementType.hasYearMonthInterval) {
-      makeArrayData(v => IntervalUtils.monthsToPeriod(v.getYearMonthInterval))
-    } else if (elementType.hasDecimal) {
-      makeArrayData(v => Decimal(v.getDecimal.getValue))
-    } else if (elementType.hasCalendarInterval) {
-      makeArrayData(v => {
-        val interval = v.getCalendarInterval
-        new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
-      })
-    } else if (elementType.hasArray) {
-      makeArrayData(v => toArrayData(v.getArray))
-    } else {
-      throw InvalidPlanInput(s"Unsupported Literal Type: $elementType)")
-    }
-  }
 }
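
With the removal above, the array path has a single implementation: the planner delegates to LiteralValueProtoConverter.toCatalystArray, which dispatches on the proto element type. A hedged sketch of decoding a hand-built array literal follows; the protobuf builder calls (DataType.Integer, Literal.Array, addElements) are assumptions about the generated API, not something this diff confirms.

// Decode a proto array literal into a Scala Array via the shared converter.
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter

object ArrayDecodeSketch {
  def main(args: Array[String]): Unit = {
    // Element type: integer, read by the hasInteger branch of toCatalystArray.
    val intType = proto.DataType.newBuilder()
      .setInteger(proto.DataType.Integer.newBuilder())
      .build()
    val arr = proto.Expression.Literal.Array.newBuilder()
      .setElementType(intType)
      .addElements(proto.Expression.Literal.newBuilder().setInteger(1))
      .addElements(proto.Expression.Literal.newBuilder().setInteger(2))
      .build()
    val decoded = LiteralValueProtoConverter.toCatalystArray(arr)
    assert(decoded.asInstanceOf[Array[Int]].sameElements(Array(1, 2)))
  }
}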
@@ -41,9 +41,8 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Project, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter, UdfPacket}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
-import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter.{toCatalystExpression, toCatalystValue}
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.SparkConnectStreamHandler
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -326,7 +325,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     } else {
       val valueMap = mutable.Map.empty[String, Any]
       cols.zip(values).foreach { case (col, value) =>
-        valueMap.update(col, toCatalystValue(value))
+        valueMap.update(col, LiteralValueProtoConverter.toCatalystValue(value))
       }
       dataset.na.fill(valueMap = valueMap.toMap).logicalPlan
     }
@@ -353,8 +352,8 @@
     val replacement = mutable.Map.empty[Any, Any]
     rel.getReplacementsList.asScala.foreach { replace =>
       replacement.update(
-        toCatalystValue(replace.getOldValue),
-        toCatalystValue(replace.getNewValue))
+        LiteralValueProtoConverter.toCatalystValue(replace.getOldValue),
+        LiteralValueProtoConverter.toCatalystValue(replace.getNewValue))
     }
 
     if (rel.getColsCount == 0) {
@@ -896,7 +895,7 @@ class SparkConnectPlanner(val session: SparkSession) {
    * Expression
    */
   private def transformLiteral(lit: proto.Expression.Literal): Literal = {
-    toCatalystExpression(lit)
+    LiteralExpressionProtoConverter.toCatalystExpression(lit)
   }
 
   private def transformLimit(limit: proto.Limit): LogicalPlan = {
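The planner call sites above all follow the same shape: decode each proto literal into a plain Scala value with the shared converter, then hand the result to the existing DataFrame API. A minimal sketch of the fill-map construction, with cols and values as hypothetical stand-ins for the fields read from the proto NA-fill message:

// Build the column-to-value map used by dataset.na.fill, as in the hunk above.
import scala.collection.mutable
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter

object FillMapSketch {
  def buildFillMap(
      cols: Seq[String],
      values: Seq[proto.Expression.Literal]): Map[String, Any] = {
    val valueMap = mutable.Map.empty[String, Any]
    cols.zip(values).foreach { case (col, value) =>
      valueMap.update(col, LiteralValueProtoConverter.toCatalystValue(value))
    }
    valueMap.toMap
  }
}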
@@ -19,15 +19,15 @@ package org.apache.spark.sql.connect.planner
 
 import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
 
-import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
-import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter.toCatalystValue
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
 
 class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:ignore funsuite
 
   test("basic proto value and catalyst value conversion") {
     val values = Array(null, true, 1.toByte, 1.toShort, 1, 1L, 1.1d, 1.1f, "spark")
     for (v <- values) {
-      assertResult(v)(toCatalystValue(toLiteralProto(v)))
+      assertResult(v)(
+        LiteralValueProtoConverter.toCatalystValue(LiteralValueProtoConverter.toLiteralProto(v)))
     }
   }
 }
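
If array coverage were wanted in this suite as well, one more case could look like the sketch below. It assumes toLiteralProto accepts primitive arrays (its component-type error handling suggests it does), which this diff does not itself confirm.

  test("array proto value and catalyst value conversion (sketch)") {
    val arr = Array(1, 2, 3)
    val lit = LiteralValueProtoConverter.toLiteralProto(arr)
    val decoded = LiteralValueProtoConverter.toCatalystArray(lit.getArray)
    assert(decoded.asInstanceOf[Array[Int]].sameElements(arr))
  }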