From b6308dd0bb022194e74c9c24543cd56cde2ede2c Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Sun, 12 Dec 2021 19:32:08 -0800
Subject: [PATCH] Add sorted columns in BucketTransform

---
 .../catalog/CatalogV2Implicits.scala          |  9 ++---
 .../connector/expressions/expressions.scala   | 36 ++++++++++++++-----
 .../sql/errors/QueryCompilationErrors.scala   |  7 +---
 .../sql/connector/catalog/InMemoryTable.scala |  2 +-
 .../expressions/TransformExtractorSuite.scala |  4 +--
 .../datasources/v2/V2SessionCatalog.scala     |  4 +--
 .../sql/connector/DataSourceV2SQLSuite.scala  | 18 ++++++++++
 7 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
index 39642fd541706..185a1a2644e2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
@@ -38,12 +38,13 @@ private[sql] object CatalogV2Implicits {
 
   implicit class BucketSpecHelper(spec: BucketSpec) {
     def asTransform: BucketTransform = {
+      val references = spec.bucketColumnNames.map(col => reference(Seq(col)))
       if (spec.sortColumnNames.nonEmpty) {
-        throw QueryCompilationErrors.cannotConvertBucketWithSortColumnsToTransformError(spec)
+        val sortedCols = spec.sortColumnNames.map(col => reference(Seq(col)))
+        bucket(spec.numBuckets, references.toArray, sortedCols.toArray)
+      } else {
+        bucket(spec.numBuckets, references.toArray)
       }
-
-      val references = spec.bucketColumnNames.map(col => reference(Seq(col)))
-      bucket(spec.numBuckets, references.toArray)
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
index 2863d94d198b2..e52654ac69c96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
@@ -45,6 +45,12 @@ private[sql] object LogicalExpressions {
   def bucket(numBuckets: Int, references: Array[NamedReference]): BucketTransform =
     BucketTransform(literal(numBuckets, IntegerType), references)
 
+  def bucket(
+      numBuckets: Int,
+      references: Array[NamedReference],
+      sortedCols: Array[NamedReference]): BucketTransform =
+    BucketTransform(literal(numBuckets, IntegerType), references, sortedCols)
+
   def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference)
 
   def years(reference: NamedReference): YearsTransform = YearsTransform(reference)
@@ -97,7 +103,8 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R
 
 private[sql] final case class BucketTransform(
     numBuckets: Literal[Int],
-    columns: Seq[NamedReference]) extends RewritableTransform {
+    columns: Seq[NamedReference],
+    sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {
 
   override val name: String = "bucket"
 
@@ -107,7 +114,13 @@ private[sql] final case class BucketTransform(
 
   override def arguments: Array[Expression] = numBuckets +: columns.toArray
 
-  override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"
+  override def describe: String =
+    if (sortedColumns.nonEmpty) {
+      s"bucket(${arguments.map(_.describe).mkString(", ")}," +
+        s" ${sortedColumns.map(_.describe).mkString(", ")})"
+    } else {
+      s"bucket(${arguments.map(_.describe).mkString(", ")})"
+    }
 
   override def toString: String = describe
 
@@ -117,11 +130,12 @@ private[sql] final case class BucketTransform(
 }
 
 private[sql] object BucketTransform {
-  def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match {
+  def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] =
+    expr match {
     case transform: Transform =>
       transform match {
-        case BucketTransform(n, FieldReference(parts)) =>
-          Some((n, FieldReference(parts)))
+        case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) =>
+          Some((n, FieldReference(parts), FieldReference(sortCols)))
         case _ =>
           None
       }
@@ -129,11 +143,17 @@ private[sql] object BucketTransform {
       None
   }
 
-  def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match {
+  def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] =
+    transform match {
+    case NamedTransform("bucket", Seq(
+        Lit(value: Int, IntegerType),
+        Ref(partCols: Seq[String]),
+        Ref(sortCols: Seq[String]))) =>
+      Some((value, FieldReference(partCols), FieldReference(sortCols)))
     case NamedTransform("bucket", Seq(
         Lit(value: Int, IntegerType),
-        Ref(seq: Seq[String]))) =>
-      Some((value, FieldReference(seq)))
+        Ref(partCols: Seq[String]))) =>
+      Some((value, FieldReference(partCols), FieldReference(Seq.empty[String])))
     case _ =>
       None
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 920a748e97ca5..9b731a693008a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, ResolvedNamespace, ResolvedTable, ResolvedView, Star, TableAlreadyExistsException, UnresolvedRegex}
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, InvalidUDFClassException}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, InvalidUDFClassException}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.JoinType
@@ -1384,11 +1384,6 @@ object QueryCompilationErrors {
     new AnalysisException("Cannot use interval type in the table schema.")
   }
 
-  def cannotConvertBucketWithSortColumnsToTransformError(spec: BucketSpec): Throwable = {
-    new AnalysisException(
-      s"Cannot convert bucketing with sort columns to a transform: $spec")
-  }
-
   def cannotConvertTransformsToPartitionColumnsError(nonIdTransforms: Seq[Transform]): Throwable = {
     new AnalysisException("Transforms cannot be converted to partition columns: " +
       nonIdTransforms.map(_.describe).mkString(", "))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index fad6fe5fbe166..3880594108a4c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -160,7 +160,7 @@ class InMemoryTable(
           case (v, t) =>
             throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
         }
-      case BucketTransform(numBuckets, ref) =>
+      case BucketTransform(numBuckets, ref, _) =>
         val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row)
         val valueHashCode = if (value == null) 0 else value.hashCode
         ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
index fbd6a886d011b..340d225f80fdb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
@@ -139,7 +139,7 @@ class TransformExtractorSuite extends SparkFunSuite {
     }
 
     bucketTransform match {
-      case BucketTransform(numBuckets, FieldReference(seq)) =>
+      case BucketTransform(numBuckets, FieldReference(seq), _) =>
         assert(numBuckets === 16)
         assert(seq === Seq("a", "b"))
       case _ =>
@@ -147,7 +147,7 @@ class TransformExtractorSuite extends SparkFunSuite {
     }
 
     transform("unknown", ref("a", "b")) match {
-      case BucketTransform(_, _) =>
+      case BucketTransform(_, _, _) =>
         fail("Matched unknown transform")
       case _ =>
       // expected
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
index d6c69fa03d698..d5547c1f3c1e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
@@ -318,8 +318,8 @@ private[sql] object V2SessionCatalog {
       case IdentityTransform(FieldReference(Seq(col))) =>
         identityCols += col
 
-      case BucketTransform(numBuckets, FieldReference(Seq(col))) =>
-        bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil))
+      case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(sortCols)) =>
+        bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCols))
 
       case transform =>
         throw QueryExecutionErrors.unsupportedPartitionTransformError(transform)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 4c5a001ec076c..3481ef0336b4d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1574,6 +1574,24 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("create table using - with sorted bucket") {
+    val identifier = "testcat.table_name"
+    withTable(identifier) {
+      sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" +
+        s" CLUSTERED BY (b) SORTED BY (a) INTO 4 BUCKETS")
+      val table = getTableMetadata(identifier)
+      val describe = spark.sql(s"DESCRIBE $identifier")
+      val part1 = describe
.filter("col_name = 'Part 0'") + .select("data_type").head.getString(0) + assert(part1 === "c") + val part2 = describe + .filter("col_name = 'Part 1'") + .select("data_type").head.getString(0) + assert(part2 === "bucket(4, b, a)") + } + } + test("REFRESH TABLE: v2 table") { val t = "testcat.ns1.ns2.tbl" withTable(t) {