Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ClusterBySpec}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, QuotingUtils}
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, quoteNameParts, QuotingUtils}
import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -223,6 +223,8 @@ private[sql] object CatalogV2Implicits {

// Dot-joins the name parts, quoting only the parts that require quoting.
def quoted: String = parts.map(quoteIfNeeded).mkString(".")

// Quotes the full multi-part name via quoteNameParts — presumably quoting every
// part unconditionally, unlike `quoted`; confirm against QuotingUtils.
def fullyQuoted: String = quoteNameParts(parts)

// Dot-joins the name parts verbatim, with no quoting applied.
def original: String = parts.mkString(".")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,29 @@ private[sql] object CatalogV2Util {
.asTableCatalog
}

/** Converts DS v2 metadata columns to a [[StructType]], one field per column. */
def toStructType(cols: Seq[MetadataColumn]): StructType = {
  val fields = cols.map(toStructField)
  StructType(fields)
}

/**
 * Converts a single DS v2 metadata column to a [[StructField]], carrying over the
 * column's JSON metadata (when present) and its comment (when present).
 */
private def toStructField(col: MetadataColumn): StructField = {
  // metadataInJSON is a nullable Java-style accessor; null means "no metadata".
  val metadata = Option(col.metadataInJSON).map(Metadata.fromJson).getOrElse(Metadata.empty)
  val field = StructField(col.name, col.dataType, col.isNullable, metadata)
  // Replaced the mutable `var` + null-check with an Option fold: attach the
  // comment only when it is non-null, otherwise return the field unchanged.
  Option(col.comment).fold(field)(field.withComment)
}

def v2ColumnsToStructType(columns: Array[Column]): StructType = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Scala allows us to pass an Array when a Seq is required.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scala does allow this but it involves a copy operation and is prohibited by Spark checks.

method copyArrayToImmutableIndexedSeq in class LowPriorityImplicits2 is deprecated (since 2.13.0): implicit conversions from Array to immutable.IndexedSeq are implemented by copying; use `toIndexedSeq` explicitly if you want to copy, or use the more efficient non-copying ArraySeq.unsafeWrapArray

v2ColumnsToStructType(columns.toImmutableArraySeq)
}

/**
* Converts DS v2 columns to StructType, which encodes column comment and default value to
* StructField metadata. This is mainly used to define the schema of v2 scan, w.r.t. the columns
* of the v2 table.
*/
def v2ColumnsToStructType(columns: Array[Column]): StructType = {
def v2ColumnsToStructType(columns: Seq[Column]): StructType = {
  // Map each v2 column to a StructField (comment/default value go into metadata).
  val fields = columns.map(v2ColumnToStructField)
  StructType(fields)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package org.apache.spark.sql.connector.catalog

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, MetadataColumnHelper}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.sql.util.SchemaValidationMode
import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES
import org.apache.spark.util.ArrayImplicits._

private[sql] object V2TableUtil extends SQLConfHelper {
Expand All @@ -40,10 +40,14 @@ private[sql] object V2TableUtil extends SQLConfHelper {
*
* @param table the current table metadata
* @param relation the relation with captured columns
* @param mode validation mode that defines what changes are acceptable
* @return validation errors, or empty sequence if valid
*/
def validateCapturedColumns(table: Table, relation: DataSourceV2Relation): Seq[String] = {
validateCapturedColumns(table, relation.table.columns.toImmutableArraySeq)
def validateCapturedColumns(
    table: Table,
    relation: DataSourceV2Relation,
    mode: SchemaValidationMode): Seq[String] = {
  // Delegate to the column-based overload using the columns the relation captured.
  val capturedCols = relation.table.columns.toImmutableArraySeq
  validateCapturedColumns(table, capturedCols, mode)
}

/**
Expand All @@ -56,56 +60,42 @@ private[sql] object V2TableUtil extends SQLConfHelper {
*
* @param table the current table metadata
* @param originCols the originally captured columns
* @param mode validation mode that defines what changes are acceptable
* @return validation errors, or empty sequence if valid
*/
def validateCapturedColumns(table: Table, originCols: Seq[Column]): Seq[String] = {
val errors = mutable.ArrayBuffer[String]()
val colsByNormalizedName = indexColumns(table.columns.toImmutableArraySeq)
val originColsByNormalizedName = indexColumns(originCols)

originColsByNormalizedName.foreach { case (normalizedName, originCol) =>
colsByNormalizedName.get(normalizedName) match {
case Some(col) =>
if (originCol.dataType != col.dataType || originCol.nullable != col.nullable) {
val oldType = formatType(originCol.dataType, originCol.nullable)
val newType = formatType(col.dataType, col.nullable)
errors += s"`${originCol.name}` type has changed from $oldType to $newType"
}
case None =>
errors += s"${formatColumn(originCol)} has been removed"
}
}

colsByNormalizedName.foreach { case (normalizedName, col) =>
if (!originColsByNormalizedName.contains(normalizedName)) {
errors += s"${formatColumn(col)} has been added"
}
}

errors.toSeq
def validateCapturedColumns(
    table: Table,
    originCols: Seq[Column],
    mode: SchemaValidationMode = PROHIBIT_CHANGES): Seq[String] = {
  // Compare the originally captured column schema against the table's current one;
  // the compatibility check reports the differences as human-readable errors.
  val capturedSchema = CatalogV2Util.v2ColumnsToStructType(originCols)
  val currentSchema = CatalogV2Util.v2ColumnsToStructType(table.columns)
  SchemaUtils.validateSchemaCompatibility(capturedSchema, currentSchema, resolver, mode)
}

/**
* Validates that captured metadata columns are consistent with the current table metadata.
*
* @param table the current table metadata
* @param relation the relation with captured metadata columns
* @param mode validation mode that defines what changes are acceptable
* @return validation errors, or empty sequence if valid
*/
def validateCapturedMetadataColumns(table: Table, relation: DataSourceV2Relation): Seq[String] = {
validateCapturedMetadataColumns(table, extractMetadataColumns(relation))
def validateCapturedMetadataColumns(
    table: Table,
    relation: DataSourceV2Relation,
    mode: SchemaValidationMode): Seq[String] = {
  // Delegate to the metadata-column overload using the columns the relation captured.
  val capturedMetaCols = extractMetadataColumns(relation)
  validateCapturedMetadataColumns(table, capturedMetaCols, mode)
}

// extracts original column info for all metadata attributes in relation
/**
* Extracts original column info for all metadata attributes in the relation.
*
* @param relation the relation with captured metadata columns
* @return metadata columns captured by the relation
*/
def extractMetadataColumns(relation: DataSourceV2Relation): Seq[MetadataColumn] = {
val metaAttrs = relation.output.filter(_.isMetadataCol)
if (metaAttrs.nonEmpty) {
val metaCols = metadataColumns(relation.table)
val normalizedMetaAttrNames = metaAttrs.map(attr => normalize(attr.name)).toSet
metaCols.filter(col => normalizedMetaAttrNames.contains(normalize(col.name)))
} else {
Seq.empty
}
val metaAttrNames = relation.output.filter(_.isMetadataCol).map(_.name)
filter(metaAttrNames, metadataColumns(relation.table))
}

/**
Expand All @@ -117,56 +107,23 @@ private[sql] object V2TableUtil extends SQLConfHelper {
*
* @param table the current table metadata
* @param originMetaCols the originally captured metadata columns
* @param mode validation mode that defines what changes are acceptable
* @return validation errors, or empty sequence if valid
*/
def validateCapturedMetadataColumns(
table: Table,
originMetaCols: Seq[MetadataColumn]): Seq[String] = {
val errors = mutable.ArrayBuffer[String]()
val metaCols = metadataColumns(table)
val metaColsByNormalizedName = indexMetadataColumns(metaCols)

originMetaCols.foreach { originMetaCol =>
val normalizedName = normalize(originMetaCol.name)
metaColsByNormalizedName.get(normalizedName) match {
case Some(metaCol) =>
if (originMetaCol.dataType != metaCol.dataType ||
originMetaCol.isNullable != metaCol.isNullable) {
val oldType = formatType(originMetaCol.dataType, originMetaCol.isNullable)
val newType = formatType(metaCol.dataType, metaCol.isNullable)
errors += s"`${originMetaCol.name}` type has changed from $oldType to $newType"
}
case None =>
errors += s"${formatMetadataColumn(originMetaCol)} has been removed"
}
}

errors.toSeq
}

private def formatColumn(col: Column): String = {
s"`${col.name}` ${formatType(col.dataType, col.nullable)}"
}

private def formatMetadataColumn(col: MetadataColumn): String = {
s"`${col.name}` ${formatType(col.dataType, col.isNullable)}"
}

private def formatType(dataType: DataType, nullable: Boolean): String = {
if (nullable) dataType.sql else s"${dataType.sql} NOT NULL"
}

private def indexColumns(cols: Seq[Column]): Map[String, Column] = {
index(cols)(_.name)
originMetaCols: Seq[MetadataColumn],
mode: SchemaValidationMode = PROHIBIT_CHANGES): Seq[String] = {
val originMetaColNames = originMetaCols.map(_.name)
val originMetaSchema = CatalogV2Util.toStructType(originMetaCols)
val metaCols = filter(originMetaColNames, metadataColumns(table))
val metaSchema = CatalogV2Util.toStructType(metaCols)
SchemaUtils.validateSchemaCompatibility(originMetaSchema, metaSchema, resolver, mode)
}

private def indexMetadataColumns(cols: Seq[MetadataColumn]): Map[String, MetadataColumn] = {
index(cols)(_.name)
}

private def index[C](cols: Seq[C])(extractName: C => String): Map[String, C] = {
SchemaUtils.checkColumnNameDuplication(cols.map(extractName), conf.caseSensitiveAnalysis)
cols.map(col => normalize(extractName(col)) -> col).toMap
// Keeps only the metadata columns whose normalized names appear in `colNames`.
private def filter(colNames: Seq[String], cols: Seq[MetadataColumn]): Seq[MetadataColumn] = {
  val wanted = colNames.iterator.map(normalize).toSet
  cols.filter(c => wanted(normalize(c.name)))
}

private def metadataColumns(table: Table): Seq[MetadataColumn] = table match {
Expand All @@ -177,4 +134,6 @@ private[sql] object V2TableUtil extends SQLConfHelper {
// Normalizes a column name for comparison: identity under case-sensitive
// analysis, lower-cased (ROOT locale, to avoid locale surprises) otherwise.
private def normalize(name: String): String = {
  if (conf.caseSensitiveAnalysis) {
    name
  } else {
    name.toLowerCase(Locale.ROOT)
  }
}

// Name-matching resolver taken from the active SQL conf; used by the schema
// compatibility checks above.
private def resolver: Resolver = conf.resolver
}
Loading