[SPARK-40295][SQL] Allow v2 functions with literal args in write distribution/ordering

### What changes were proposed in this pull request?

This PR adapts `V2ExpressionUtils` to support arbitrary transforms with multiple args that are either references or literals.
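
For example (a minimal sketch mirroring the test added in this PR; `truncate` is a test catalog function), a transform whose arguments mix a column reference and a literal can now be converted:

```scala
import org.apache.spark.sql.connector.expressions.{Expressions, Transform}

// A v2 transform with two arguments: a column reference and an Int literal.
// V2ExpressionUtils now converts the literal via Literal.create(value, dataType)
// and binds the whole call to a TransformExpression, instead of supporting only
// transforms whose single argument is a column reference.
val truncateTransform: Transform =
  Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
```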

### Why are the changes needed?

After PR #36995, data sources can request a distribution and ordering that reference v2 functions. However, if a data source needs a transform with multiple input args, or a transform where not all args are references, Spark throws an exception.
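
For illustration, here is a sketch (using the public v2 expressions API; the names mirror the tests below) of the kind of distribution a data source could now request via `RequiresDistributionAndOrdering`:

```scala
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, Expressions}

// Clustering by truncate(data, 2): a v2 function call with a reference and a
// literal argument. Before this change, resolving such a transform in
// V2ExpressionUtils failed because only single-reference transforms were handled.
val requiredDistribution: Distribution =
  Distributions.clustered(Array[V2Expression](
    Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))))
```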

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR adapts the test added recently in PR #36995.

Closes #37749 from aokolnychyi/spark-40295.

Lead-authored-by: aokolnychyi <aokolnychyi@apple.com>
Co-authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
Signed-off-by: Chao Sun <sunchao@apple.com>
aokolnychyi authored and sunchao committed Sep 7, 2022
1 parent 32567e9 commit 127ccc2
Showing 7 changed files with 117 additions and 25 deletions.
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
-import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._

@@ -75,6 +75,8 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
query: LogicalPlan,
funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = {
expr match {
case l: V2Literal[_] =>
Some(Literal.create(l.value, l.dataType))
case t: Transform =>
toCatalystTransformOpt(t, query, funCatalogOpt)
case SortValue(child, direction, nullOrdering) =>
@@ -105,18 +107,13 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
TransformExpression(bound, resolvedRefs, Some(numBuckets))
}
}
-case NamedTransform(name, refs)
-if refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) =>
-val resolvedRefs = refs.map(_.asInstanceOf[NamedReference]).map { r =>
-resolveRef[NamedExpression](r, query)
-}
+case NamedTransform(name, args) =>
+val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt))
funCatalogOpt.flatMap { catalog =>
-loadV2FunctionOpt(catalog, name, resolvedRefs).map { bound =>
-TransformExpression(bound, resolvedRefs)
+loadV2FunctionOpt(catalog, name, catalystArgs).map { bound =>
+TransformExpression(bound, catalystArgs)
}
}
-case _ =>
-throw new AnalysisException(s"Transform $trans is not currently supported")
}

private def loadV2FunctionOpt(
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.physical

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
@@ -361,6 +362,25 @@ object KeyGroupedPartitioning {
partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
KeyGroupedPartitioning(expressions, partitionValues.size, Some(partitionValues))
}

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
def isSupportedTransform(transform: TransformExpression): Boolean = {
transform.children.size == 1 && isReference(transform.children.head)
}

@tailrec
def isReference(e: Expression): Boolean = e match {
case _: Attribute => true
case g: GetStructField => isReference(g.child)
case _ => false
}

expressions.forall {
case t: TransformExpression if isSupportedTransform(t) => true
case e: Expression if isReference(e) => true
case _ => false
}
}
}

/**
@@ -83,6 +83,7 @@ abstract class InMemoryBaseTable(
case _: HoursTransform =>
case _: BucketTransform =>
case _: SortedBucketTransform =>
case NamedTransform("truncate", Seq(_: NamedReference, _: Literal[_])) =>
case t if !allowUnsupportedTransforms =>
throw new IllegalArgumentException(s"Transform $t is not a supported transform")
}
@@ -177,6 +178,13 @@ abstract class InMemoryBaseTable(
var dataTypeHashCode = 0
valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
case NamedTransform("truncate", Seq(ref: NamedReference, length: Literal[_])) =>
extractor(ref.fieldNames, cleanedSchema, row) match {
case (str: UTF8String, StringType) =>
str.substring(0, length.value.asInstanceOf[Int])
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
}
}

@@ -91,11 +91,18 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
}

override def outputPartitioning: physical.Partitioning = {
-if (partitions.length == 1) SinglePartition
-else groupedPartitions.map { partitionValues =>
-KeyGroupedPartitioning(keyGroupedPartitioning.get,
-partitionValues.size, Some(partitionValues.map(_._1)))
-}.getOrElse(super.outputPartitioning)
+if (partitions.length == 1) {
+SinglePartition
+} else {
+keyGroupedPartitioning match {
+case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
+groupedPartitions.map { partitionValues =>
+KeyGroupedPartitioning(exprs, partitionValues.size, Some(partitionValues.map(_._1)))
+}.getOrElse(super.outputPartitioning)
+case _ =>
+super.outputPartitioning
+}
+}
}

@transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] =
@@ -20,7 +20,7 @@ import java.util.Collections

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.TransformExpression
+import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
@@ -38,6 +38,12 @@ import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types._

class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
private val functions = Seq(
UnboundYearsFunction,
UnboundDaysFunction,
UnboundBucketFunction,
UnboundTruncateFunction)

private var originalV2BucketingEnabled: Boolean = false
private var originalAutoBroadcastJoinThreshold: Long = -1

@@ -59,7 +65,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase
}

before {
-Seq(UnboundYearsFunction, UnboundDaysFunction, UnboundBucketFunction).foreach { f =>
+functions.foreach { f =>
catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
}
}
@@ -179,6 +185,25 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}

test("non-clustered distribution: V2 function with multiple args") {
val partitions: Array[Transform] = Array(
Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
)

// create a table with 3 partitions, partitioned by `truncate` transform
createTable(table, schema, partitions)
sql(s"INSERT INTO testcat.ns.$table VALUES " +
s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")

val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2)))))

checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
}

/**
* Check whether the query plan from `df` has the expected `distribution`, `ordering` and
* `partitioning`.
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast,
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
import org.apache.spark.sql.connector.catalog.Identifier
-import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, UnboundBucketFunction, UnboundStringSelfFunction}
+import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, TruncateFunction, UnboundBucketFunction, UnboundStringSelfFunction, UnboundTruncateFunction}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.expressions.LogicalExpressions._
@@ -45,7 +45,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
import testImplicits._

before {
-Seq(UnboundBucketFunction, UnboundStringSelfFunction).foreach { f =>
+Seq(UnboundBucketFunction, UnboundStringSelfFunction, UnboundTruncateFunction).foreach { f =>
catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
}
}
@@ -1041,19 +1041,36 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
distributionStrictlyRequired: Boolean = true,
dataSkewed: Boolean = false,
coalesce: Boolean = false): Unit = {

+val stringSelfTransform = ApplyTransform(
+"string_self",
+Seq(FieldReference("data")))
+val truncateTransform = ApplyTransform(
+"truncate",
+Seq(stringSelfTransform, LiteralValue(2, IntegerType)))

val tableOrdering = Array[SortOrder](
-sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST),
+sort(
+stringSelfTransform,
+SortDirection.DESCENDING,
+NullOrdering.NULLS_FIRST),
sort(
BucketTransform(LiteralValue(10, IntegerType), Seq(FieldReference("id"))),
SortDirection.DESCENDING,
NullOrdering.NULLS_FIRST)
)
-val tableDistribution = Distributions.clustered(Array(
-ApplyTransform("string_self", Seq(FieldReference("data")))))
+val tableDistribution = Distributions.clustered(Array(truncateTransform))

+val stringSelfExpr = ApplyFunctionExpression(
+StringSelfFunction,
+Seq(attr("data")))
+val truncateExpr = ApplyFunctionExpression(
+TruncateFunction,
+Seq(stringSelfExpr, Literal(2)))

val writeOrdering = Seq(
catalyst.expressions.SortOrder(
-attr("data"),
+stringSelfExpr,
catalyst.expressions.Descending,
catalyst.expressions.NullsFirst,
Seq.empty
@@ -1066,8 +1083,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
)
)

-val writePartitioningExprs = Seq(
-ApplyFunctionExpression(StringSelfFunction, Seq(attr("data"))))
+val writePartitioningExprs = Seq(truncateExpr)
val writePartitioning = if (!coalesce) {
clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
} else {
@@ -99,3 +99,22 @@ object StringSelfFunction extends ScalarFunction[UTF8String] {
input.getUTF8String(0)
}
}

object UnboundTruncateFunction extends UnboundFunction {
override def bind(inputType: StructType): BoundFunction = TruncateFunction
override def description(): String = name()
override def name(): String = "truncate"
}

object TruncateFunction extends ScalarFunction[UTF8String] {
override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
override def resultType(): DataType = StringType
override def name(): String = "truncate"
override def canonicalName(): String = name()
override def toString: String = name()
override def produceResult(input: InternalRow): UTF8String = {
val str = input.getUTF8String(0)
val length = input.getInt(1)
str.substring(0, length)
}
}
