[SPARK-39607][SQL][DSV2] Distribution and ordering support V2 function in writing
pan3793 committed Aug 20, 2022
1 parent c8508fa commit 49ccd54
Showing 8 changed files with 245 additions and 77 deletions.
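
For orientation, here is a sketch (not part of the commit) of the kind of connector Write this change targets: one whose required distribution and ordering are expressed with V2 function-backed transforms such as bucket, rather than plain column references. The class name DemoWrite and the column name "id" are illustrative assumptions.

import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, Expressions, SortDirection, SortOrder}
import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering}

// A write that asks Spark to cluster and sort incoming rows by bucket(8, id).
// Before this commit, function-backed transforms in requiredDistribution() and
// requiredOrdering() could not be resolved while preparing the write query.
class DemoWrite(batchWrite: BatchWrite) extends RequiresDistributionAndOrdering {

  override def requiredDistribution(): Distribution =
    Distributions.clustered(Array[V2Expression](Expressions.bucket(8, "id")))

  override def requiredOrdering(): Array[SortOrder] =
    Array(Expressions.sort(Expressions.bucket(8, "id"), SortDirection.ASCENDING))

  override def toBatch(): BatchWrite = batchWrite
}

The diff below makes DistributionAndOrderingUtils.prepareQuery resolve such transforms against the table's FunctionCatalog instead of failing analysis.
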
@@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.{Method, Modifier}
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@@ -47,8 +46,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -2336,33 +2334,7 @@ class Analyzer(override val catalogManager: CatalogManager)
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "IGNORE NULLS")
} else {
val declaredInputTypes = scalarFunc.inputTypes().toSeq
val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
case Some(m) if Modifier.isStatic(m.getModifiers) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
propagateNull = false, returnNullable = scalarFunc.isResultNullable,
isDeterministic = scalarFunc.isDeterministic)
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
returnNullable = scalarFunc.isResultNullable,
isDeterministic = scalarFunc.isDeterministic)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass do not override the default method in parent interface
// defined in Java, the method can still be found from
// `getDeclaredMethod`.
findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) match {
case Some(_) =>
ApplyFunctionExpression(scalarFunc, arguments)
case _ =>
failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" +
s" magic method nor override 'produceResult'")
}
}
V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
}
}

@@ -2377,23 +2349,6 @@ class Analyzer(override val catalogManager: CatalogManager)
val aggregator = V2Aggregator(aggFunc, arguments)
aggregator.toAggregateExpression(u.isDistinct, u.filter)
}

/**
* Check if the input `fn` implements the given `methodName` with parameter types specified
* via `argClasses`.
*/
private def findMethod(
fn: BoundFunction,
methodName: String,
argClasses: Seq[Class[_]]): Option[Method] = {
val cls = fn.getClass
try {
Some(cls.getDeclaredMethod(methodName, argClasses: _*))
} catch {
case _: NoSuchMethodException =>
None
}
}
}

/**
@@ -17,13 +17,17 @@

package org.apache.spark.sql.catalyst.expressions

import java.lang.reflect.{Method, Modifier}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
@@ -52,8 +56,11 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
/**
* Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst.
*/
def toCatalystOrdering(ordering: Array[V2SortOrder], query: LogicalPlan): Seq[SortOrder] = {
ordering.map(toCatalyst(_, query).asInstanceOf[SortOrder])
def toCatalystOrdering(
ordering: Array[V2SortOrder],
query: LogicalPlan,
funCatalogOpt: Option[FunctionCatalog] = None): Seq[SortOrder] = {
ordering.map(toCatalyst(_, query, funCatalogOpt).asInstanceOf[SortOrder])
}

def toCatalyst(
@@ -143,4 +150,53 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
case V2NullOrdering.NULLS_FIRST => NullsFirst
case V2NullOrdering.NULLS_LAST => NullsLast
}

def resolveScalarFunction(
scalarFunc: ScalarFunction[_],
arguments: Seq[Expression]): Expression = {
val declaredInputTypes = scalarFunc.inputTypes().toSeq
val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
case Some(m) if Modifier.isStatic(m.getModifiers) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
propagateNull = false, returnNullable = scalarFunc.isResultNullable,
isDeterministic = scalarFunc.isDeterministic)
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
returnNullable = scalarFunc.isResultNullable,
isDeterministic = scalarFunc.isDeterministic)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass do not override the default method in parent interface
// defined in Java, the method can still be found from
// `getDeclaredMethod`.
findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) match {
case Some(_) =>
ApplyFunctionExpression(scalarFunc, arguments)
case _ =>
throw new AnalysisException(s"ScalarFunction '${scalarFunc.name()}'" +
s" neither implement magic method nor override 'produceResult'")
}
}
}

/**
* Check if the input `fn` implements the given `methodName` with parameter types specified
* via `argClasses`.
*/
private def findMethod(
fn: BoundFunction,
methodName: String,
argClasses: Seq[Class[_]]): Option[Method] = {
val cls = fn.getClass
try {
Some(cls.getDeclaredMethod(methodName, argClasses: _*))
} catch {
case _: NoSuchMethodException =>
None
}
}
}
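
The resolveScalarFunction helper moved here from Analyzer so that write-side rules can reuse the same dispatch: a static magic method "invoke" becomes a StaticInvoke, an instance magic method becomes an Invoke, and a function that only overrides produceResult falls back to ApplyFunctionExpression. The following hypothetical catalog function (BucketByLong and demo.bucket are invented names, not part of this commit) would take the Invoke branch:

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructType}

object BucketByLong extends UnboundFunction {
  override def name(): String = "bucket"
  override def description(): String = "bucket(numBuckets, col): maps a long column to a bucket id"

  override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] {
    override def inputTypes(): Array[DataType] = Array(DataTypes.IntegerType, DataTypes.LongType)
    override def resultType(): DataType = DataTypes.IntegerType
    override def name(): String = "bucket"
    override def canonicalName(): String = "demo.bucket"

    // The magic method located by findMethod via getDeclaredMethod. It is an instance
    // method, so resolveScalarFunction wraps it in Invoke rather than StaticInvoke.
    def invoke(numBuckets: Int, value: Long): Int =
      (((value % numBuckets) + numBuckets) % numBuckets).toInt
  }
}

A static invoke method would instead take the StaticInvoke branch, and a function that only overrides produceResult(InternalRow) ends up as ApplyFunctionExpression.
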
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics}
import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -48,6 +48,10 @@ case class DataSourceV2Relation(

import DataSourceV2Implicits._

lazy val funCatalog: Option[FunctionCatalog] = catalog.collect {
case c: FunctionCatalog => c
}

override lazy val metadataOutput: Seq[AttributeReference] = table match {
case hasMeta: SupportsMetadataColumns =>
val resolve = conf.resolver
@@ -17,22 +17,33 @@

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.FunctionCatalog
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.connector.distributions._
import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}
import org.apache.spark.sql.errors.QueryCompilationErrors

object DistributionAndOrderingUtils {

def prepareQuery(write: Write, query: LogicalPlan): LogicalPlan = write match {
def prepareQuery(
write: Write,
query: LogicalPlan,
funCatalogOpt: Option[FunctionCatalog]): LogicalPlan = write match {
case write: RequiresDistributionAndOrdering =>
val numPartitions = write.requiredNumPartitions()

val distribution = write.requiredDistribution match {
case d: OrderedDistribution => toCatalystOrdering(d.ordering(), query)
case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, query)).toSeq
case d: OrderedDistribution =>
toCatalystOrdering(d.ordering(), query, funCatalogOpt)
.map(e => resolveTransformExpression(e).asInstanceOf[SortOrder])
case d: ClusteredDistribution =>
d.clustering.map(e => toCatalyst(e, query, funCatalogOpt))
.map(e => resolveTransformExpression(e)).toSeq
case _: UnspecifiedDistribution => Seq.empty[Expression]
}

@@ -53,16 +64,33 @@ object DistributionAndOrderingUtils {
query
}

val ordering = toCatalystOrdering(write.requiredOrdering, query)
val ordering = toCatalystOrdering(write.requiredOrdering, query, funCatalogOpt)
val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
Sort(ordering, global = false, queryWithDistribution)
Sort(
ordering.map(e => resolveTransformExpression(e).asInstanceOf[SortOrder]),
global = false,
queryWithDistribution)
} else {
queryWithDistribution
}

queryWithDistributionAndOrdering

// Apply typeCoercionRules since the converted expression from TransformExpression
// implemented ImplicitCastInputTypes
typeCoercionRules.foldLeft(queryWithDistributionAndOrdering)((plan, rule) => rule(plan))
case _ =>
query
}

private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
}

private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) {
AnsiTypeCoercion.typeCoercionRules
} else {
TypeCoercion.typeCoercionRules
}
}
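
A sketch of what resolveTransformExpression amounts to, reusing the hypothetical BucketByLong function from the earlier sketch (paste into a spark-shell that has it on the classpath): for a bucketed TransformExpression the bucket count is prepended as a literal argument before delegating to resolveScalarFunction, and the type-coercion rules folded over the plan afterwards insert any casts the resolved invocations need, for the reason given in the comment above.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, V2ExpressionUtils}
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{LongType, StructType}

// Bind the hypothetical function and resolve it the way resolveTransformExpression would
// for TransformExpression(bucketFunc, Seq(id), Some(8)): prepend Literal(8), then resolve.
val bucketFunc = BucketByLong
  .bind(new StructType().add("id", LongType))
  .asInstanceOf[ScalarFunction[Int]]
val id = AttributeReference("id", LongType)()

val resolved = V2ExpressionUtils.resolveScalarFunction(bucketFunc, Seq(Literal(8), id))
// resolved is an Invoke of the instance magic method, i.e. the catalyst form of bucket(8, id).
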
@@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.FunctionCatalog
import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning}
import org.apache.spark.util.collection.Utils.sequenceToOption
@@ -41,14 +40,9 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with SQLConfHelpe

private def partitioning(plan: LogicalPlan) = plan.transformDown {
case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, None, _) =>
val funCatalogOpt = relation.catalog.flatMap {
case c: FunctionCatalog => Some(c)
case _ => None
}

val catalystPartitioning = scan.outputPartitioning() match {
case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map(
V2ExpressionUtils.toCatalystOpt(_, relation, funCatalogOpt)))
V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog)))
case _: UnknownPartitioning => None
case p => throw new IllegalArgumentException("Unsupported data source V2 partitioning " +
"type: " + p.getClass.getSimpleName)
@@ -43,7 +43,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) =>
val writeBuilder = newWriteBuilder(r.table, options, query.schema)
val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
a.copy(write = Some(write), query = newQuery)

case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) =>
@@ -67,7 +67,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
}

val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
o.copy(write = Some(write), query = newQuery)

case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
@@ -79,7 +79,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case _ =>
throw QueryExecutionErrors.dynamicPartitionOverwriteUnsupportedByTableError(table)
}
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
o.copy(write = Some(write), query = newQuery)

case WriteToMicroBatchDataSource(
@@ -89,14 +89,15 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
val customMetrics = write.supportedCustomMetrics.toSeq
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
val funCatalogOpt = relation.flatMap(_.funCatalog)
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)

case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) =>
val rowSchema = StructType.fromAttributes(rd.dataInput)
val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
// project away any metadata columns that could be used for distribution and ordering
rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))

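
End to end, the V2Writes changes above thread the relation's FunctionCatalog into every write path, so a plain v2 append now honors function-backed write requirements. A hedged usage sketch, assuming a hypothetical catalog implementation com.example.DemoCatalog that also exposes a FunctionCatalog and whose table demo.db.events declares the DemoWrite requirements sketched earlier:

import org.apache.spark.sql.SparkSession

object AppendDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("dsv2-write-distribution-ordering")
      .master("local[2]")
      // Hypothetical connector class; any catalog that also implements FunctionCatalog works.
      .config("spark.sql.catalog.demo", "com.example.DemoCatalog")
      .getOrCreate()

    // With this commit, the bucket(8, id) clustering and ordering required by the write are
    // resolved through the catalog's functions, so rows are repartitioned and sorted by the
    // bucket value before being handed to the data source.
    spark.range(0, 1000L).toDF("id")
      .writeTo("demo.db.events")
      .append()

    spark.stop()
  }
}
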
