From 7d2f70d1ead43727a806305803daac7d4bdbb2e3 Mon Sep 17 00:00:00 2001 From: "xiyu.zk" Date: Thu, 21 May 2026 16:31:31 +0800 Subject: [PATCH] [spark] Support catalog-qualified CREATE TABLE LIKE --- ...stractPaimonSparkSqlExtensionsParser.scala | 174 ++++++++++++++++- .../PaimonSqlExtensions.g4 | 7 +- ...stractPaimonSparkSqlExtensionsParser.scala | 173 ++++++++++++++++- .../PaimonSqlExtensionsAstBuilder.scala | 53 ++++++ .../CatalogQualifiedCreateTableLikeTest.scala | 180 ++++++++++++++++++ 5 files changed, 583 insertions(+), 4 deletions(-) create mode 100644 paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CatalogQualifiedCreateTableLikeTest.scala diff --git a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala index 67f72d953a3c..3529944f37ef 100644 --- a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala +++ b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala @@ -52,12 +52,46 @@ import scala.collection.JavaConverters._ * @param delegate * The extension parser. */ +// Keep this class in the Spark 4.0 module so it is compiled against Spark 4.0's ParserInterface. abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterface) extends org.apache.spark.sql.catalyst.parser.ParserInterface with Logging { private lazy val substitutor = new VariableSubstitution() private lazy val astBuilder = new PaimonSqlExtensionsAstBuilder(delegate) + private val nonReservedIdentifierTokenTypes = Set( + PaimonSqlExtensionsParser.ALTER, + PaimonSqlExtensionsParser.AS, + PaimonSqlExtensionsParser.CALL, + PaimonSqlExtensionsParser.CREATE, + PaimonSqlExtensionsParser.DAYS, + PaimonSqlExtensionsParser.DELETE, + PaimonSqlExtensionsParser.EXISTS, + PaimonSqlExtensionsParser.HOURS, + PaimonSqlExtensionsParser.IF, + PaimonSqlExtensionsParser.LIKE, + PaimonSqlExtensionsParser.NOT, + PaimonSqlExtensionsParser.OF, + PaimonSqlExtensionsParser.OR, + PaimonSqlExtensionsParser.TABLE, + PaimonSqlExtensionsParser.REPLACE, + PaimonSqlExtensionsParser.RETAIN, + PaimonSqlExtensionsParser.VERSION, + PaimonSqlExtensionsParser.TAG, + PaimonSqlExtensionsParser.TRUE, + PaimonSqlExtensionsParser.FALSE, + PaimonSqlExtensionsParser.MAP, + PaimonSqlExtensionsParser.COPY, + PaimonSqlExtensionsParser.INTO, + PaimonSqlExtensionsParser.FROM, + PaimonSqlExtensionsParser.FILE_FORMAT, + PaimonSqlExtensionsParser.PATTERN, + PaimonSqlExtensionsParser.FORCE, + PaimonSqlExtensionsParser.ON_ERROR, + PaimonSqlExtensionsParser.ABORT_STATEMENT, + PaimonSqlExtensionsParser.OVERWRITE, + PaimonSqlExtensionsParser.CSV + ) /** Parses a string to a LogicalPlan. */ override def parsePlan(sqlText: String): LogicalPlan = { @@ -66,7 +100,14 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement())) .asInstanceOf[LogicalPlan] } else { - var plan = delegate.parsePlan(sqlText) + var plan = + try { + delegate.parsePlan(sqlText) + } catch { + case _: ParseException if maybeCatalogCreateTableLike(sqlTextAfterSubstitution) => + parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement())) + .asInstanceOf[LogicalPlan] + } val sparkSession = PaimonSparkSession.active parserRules(sparkSession).foreach( rule => { @@ -144,6 +185,137 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf normalized.startsWith("copy into") } + /** + * Cheap token-level check for `CREATE TABLE [IF NOT EXISTS] x.y[.z] LIKE ...` shape. Used as a + * gate for the Paimon parser fallback when the delegate parser rejects a catalog-qualified CREATE + * TABLE LIKE statement. + */ + private def maybeCatalogCreateTableLike(sqlText: String): Boolean = { + if (org.apache.spark.SPARK_VERSION < "3.4") { + return false + } + if (!startsWithCreateTable(sqlText)) { + return false + } + + tokenStream(sqlText) match { + case Some(tokens) => maybeCreateTableLike(tokens) + case None => false + } + } + + private def tokenStream(sqlText: String): Option[CommonTokenStream] = { + try { + val lexer = new PaimonSqlExtensionsLexer( + new UpperCaseCharStream(CharStreams.fromString(sqlText))) + lexer.removeErrorListeners() + lexer.addErrorListener(PaimonParseErrorListener) + + val tokens = new CommonTokenStream(lexer) + tokens.fill() + Some(tokens) + } catch { + case _: PaimonParseException => None + } + } + + private def maybeCreateTableLike(tokenStream: CommonTokenStream): Boolean = { + val tokens = tokenStream.getTokens.asScala + .filter(token => token.getChannel == Token.DEFAULT_CHANNEL) + .filterNot(token => token.getType == Token.EOF) + + if (tokens.length < 5) return false + if (tokens(0).getType != PaimonSqlExtensionsParser.CREATE) return false + if (tokens(1).getType != PaimonSqlExtensionsParser.TABLE) return false + + var idx = 2 + if ( + idx + 2 < tokens.length && + tokens(idx).getType == PaimonSqlExtensionsParser.IF && + tokens(idx + 1).getType == PaimonSqlExtensionsParser.NOT && + tokens(idx + 2).getType == PaimonSqlExtensionsParser.EXISTS + ) { + idx += 3 + } + + if (idx >= tokens.length || !isIdentifierToken(tokens(idx))) return false + idx += 1 + + while ( + idx + 1 < tokens.length && + tokens(idx).getText == "." && + isIdentifierToken(tokens(idx + 1)) + ) { + idx += 2 + } + + idx < tokens.length && tokens(idx).getType == PaimonSqlExtensionsParser.LIKE + } + + private def isIdentifierToken(token: Token): Boolean = { + token.getType == PaimonSqlExtensionsParser.IDENTIFIER || + token.getType == PaimonSqlExtensionsParser.BACKQUOTED_IDENTIFIER || + nonReservedIdentifierTokenTypes.contains(token.getType) + } + + private def startsWithCreateTable(sqlText: String): Boolean = { + val createIndex = skipWhitespaceAndComments(sqlText, 0) + if (!matchesWord(sqlText, createIndex, "create")) { + return false + } + + val tableIndex = skipWhitespaceAndComments(sqlText, createIndex + "create".length) + matchesWord(sqlText, tableIndex, "table") + } + + private def skipWhitespaceAndComments(sqlText: String, start: Int): Int = { + var index = start + var continue = true + + while (continue) { + while (index < sqlText.length && sqlText.charAt(index).isWhitespace) { + index += 1 + } + + if ( + index + 1 < sqlText.length && + sqlText.charAt(index) == '-' && + sqlText.charAt(index + 1) == '-' + ) { + index += 2 + while ( + index < sqlText.length && + sqlText.charAt(index) != '\n' && + sqlText.charAt(index) != '\r' + ) { + index += 1 + } + } else if ( + index + 1 < sqlText.length && + sqlText.charAt(index) == '/' && + sqlText.charAt(index + 1) == '*' + ) { + val close = sqlText.indexOf("*/", index + 2) + index = if (close >= 0) close + 2 else sqlText.length + } else { + continue = false + } + } + + index + } + + private def matchesWord(sqlText: String, index: Int, word: String): Boolean = { + index + word.length <= sqlText.length && + sqlText.regionMatches(true, index, word, 0, word.length) && + (index + word.length == sqlText.length || + !isIdentifierPart(sqlText.charAt(index + word.length))) + } + + private def isIdentifierPart(char: Char): Boolean = { + char.isLetterOrDigit || char == '_' + } + protected def parse[T](command: String)(toResult: PaimonSqlExtensionsParser => T): T = { val lexer = new PaimonSqlExtensionsLexer( new UpperCaseCharStream(CharStreams.fromString(command))) diff --git a/paimon-spark/paimon-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/PaimonSqlExtensions.g4 b/paimon-spark/paimon-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/PaimonSqlExtensions.g4 index 4255e530c122..12a5bc8c51b6 100644 --- a/paimon-spark/paimon-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/PaimonSqlExtensions.g4 +++ b/paimon-spark/paimon-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/PaimonSqlExtensions.g4 @@ -84,6 +84,8 @@ statement FROM multipartIdentifier fileFormatClause overwriteClause? #copyIntoLocation + | CREATE TABLE (IF NOT EXISTS)? target=multipartIdentifier + LIKE source=multipartIdentifier ( . )*? #createTableLike ; callArgument @@ -197,8 +199,8 @@ quotedIdentifier ; nonReserved - : ALTER | AS | CALL | CREATE | DAYS | DELETE | EXISTS | HOURS | IF | NOT | OF | OR | TABLE - | REPLACE | RETAIN | VERSION | TAG + : ALTER | AS | CALL | CREATE | DAYS | DELETE | EXISTS | HOURS | IF | LIKE + | NOT | OF | OR | TABLE | REPLACE | RETAIN | VERSION | TAG | TRUE | FALSE | MAP | COPY | INTO | FROM | FILE_FORMAT | PATTERN | FORCE | ON_ERROR | ABORT_STATEMENT | OVERWRITE @@ -214,6 +216,7 @@ DELETE: 'DELETE'; EXISTS: 'EXISTS'; HOURS: 'HOURS'; IF : 'IF'; +LIKE: 'LIKE'; MINUTES: 'MINUTES'; NOT: 'NOT'; OF: 'OF'; diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala index 67f72d953a3c..1e0c13a573fc 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala @@ -58,6 +58,39 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf private lazy val substitutor = new VariableSubstitution() private lazy val astBuilder = new PaimonSqlExtensionsAstBuilder(delegate) + private val nonReservedIdentifierTokenTypes = Set( + PaimonSqlExtensionsParser.ALTER, + PaimonSqlExtensionsParser.AS, + PaimonSqlExtensionsParser.CALL, + PaimonSqlExtensionsParser.CREATE, + PaimonSqlExtensionsParser.DAYS, + PaimonSqlExtensionsParser.DELETE, + PaimonSqlExtensionsParser.EXISTS, + PaimonSqlExtensionsParser.HOURS, + PaimonSqlExtensionsParser.IF, + PaimonSqlExtensionsParser.LIKE, + PaimonSqlExtensionsParser.NOT, + PaimonSqlExtensionsParser.OF, + PaimonSqlExtensionsParser.OR, + PaimonSqlExtensionsParser.TABLE, + PaimonSqlExtensionsParser.REPLACE, + PaimonSqlExtensionsParser.RETAIN, + PaimonSqlExtensionsParser.VERSION, + PaimonSqlExtensionsParser.TAG, + PaimonSqlExtensionsParser.TRUE, + PaimonSqlExtensionsParser.FALSE, + PaimonSqlExtensionsParser.MAP, + PaimonSqlExtensionsParser.COPY, + PaimonSqlExtensionsParser.INTO, + PaimonSqlExtensionsParser.FROM, + PaimonSqlExtensionsParser.FILE_FORMAT, + PaimonSqlExtensionsParser.PATTERN, + PaimonSqlExtensionsParser.FORCE, + PaimonSqlExtensionsParser.ON_ERROR, + PaimonSqlExtensionsParser.ABORT_STATEMENT, + PaimonSqlExtensionsParser.OVERWRITE, + PaimonSqlExtensionsParser.CSV + ) /** Parses a string to a LogicalPlan. */ override def parsePlan(sqlText: String): LogicalPlan = { @@ -66,7 +99,14 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement())) .asInstanceOf[LogicalPlan] } else { - var plan = delegate.parsePlan(sqlText) + var plan = + try { + delegate.parsePlan(sqlText) + } catch { + case _: ParseException if maybeCatalogCreateTableLike(sqlTextAfterSubstitution) => + parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement())) + .asInstanceOf[LogicalPlan] + } val sparkSession = PaimonSparkSession.active parserRules(sparkSession).foreach( rule => { @@ -144,6 +184,137 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf normalized.startsWith("copy into") } + /** + * Cheap token-level check for `CREATE TABLE [IF NOT EXISTS] x.y[.z] LIKE ...` shape. Used as a + * gate for the Paimon parser fallback when the delegate parser rejects a catalog-qualified CREATE + * TABLE LIKE statement. + */ + private def maybeCatalogCreateTableLike(sqlText: String): Boolean = { + if (org.apache.spark.SPARK_VERSION < "3.4") { + return false + } + if (!startsWithCreateTable(sqlText)) { + return false + } + + tokenStream(sqlText) match { + case Some(tokens) => maybeCreateTableLike(tokens) + case None => false + } + } + + private def tokenStream(sqlText: String): Option[CommonTokenStream] = { + try { + val lexer = new PaimonSqlExtensionsLexer( + new UpperCaseCharStream(CharStreams.fromString(sqlText))) + lexer.removeErrorListeners() + lexer.addErrorListener(PaimonParseErrorListener) + + val tokens = new CommonTokenStream(lexer) + tokens.fill() + Some(tokens) + } catch { + case _: PaimonParseException => None + } + } + + private def maybeCreateTableLike(tokenStream: CommonTokenStream): Boolean = { + val tokens = tokenStream.getTokens.asScala + .filter(token => token.getChannel == Token.DEFAULT_CHANNEL) + .filterNot(token => token.getType == Token.EOF) + + if (tokens.length < 5) return false + if (tokens(0).getType != PaimonSqlExtensionsParser.CREATE) return false + if (tokens(1).getType != PaimonSqlExtensionsParser.TABLE) return false + + var idx = 2 + if ( + idx + 2 < tokens.length && + tokens(idx).getType == PaimonSqlExtensionsParser.IF && + tokens(idx + 1).getType == PaimonSqlExtensionsParser.NOT && + tokens(idx + 2).getType == PaimonSqlExtensionsParser.EXISTS + ) { + idx += 3 + } + + if (idx >= tokens.length || !isIdentifierToken(tokens(idx))) return false + idx += 1 + + while ( + idx + 1 < tokens.length && + tokens(idx).getText == "." && + isIdentifierToken(tokens(idx + 1)) + ) { + idx += 2 + } + + idx < tokens.length && tokens(idx).getType == PaimonSqlExtensionsParser.LIKE + } + + private def isIdentifierToken(token: Token): Boolean = { + token.getType == PaimonSqlExtensionsParser.IDENTIFIER || + token.getType == PaimonSqlExtensionsParser.BACKQUOTED_IDENTIFIER || + nonReservedIdentifierTokenTypes.contains(token.getType) + } + + private def startsWithCreateTable(sqlText: String): Boolean = { + val createIndex = skipWhitespaceAndComments(sqlText, 0) + if (!matchesWord(sqlText, createIndex, "create")) { + return false + } + + val tableIndex = skipWhitespaceAndComments(sqlText, createIndex + "create".length) + matchesWord(sqlText, tableIndex, "table") + } + + private def skipWhitespaceAndComments(sqlText: String, start: Int): Int = { + var index = start + var continue = true + + while (continue) { + while (index < sqlText.length && sqlText.charAt(index).isWhitespace) { + index += 1 + } + + if ( + index + 1 < sqlText.length && + sqlText.charAt(index) == '-' && + sqlText.charAt(index + 1) == '-' + ) { + index += 2 + while ( + index < sqlText.length && + sqlText.charAt(index) != '\n' && + sqlText.charAt(index) != '\r' + ) { + index += 1 + } + } else if ( + index + 1 < sqlText.length && + sqlText.charAt(index) == '/' && + sqlText.charAt(index + 1) == '*' + ) { + val close = sqlText.indexOf("*/", index + 2) + index = if (close >= 0) close + 2 else sqlText.length + } else { + continue = false + } + } + + index + } + + private def matchesWord(sqlText: String, index: Int, word: String): Boolean = { + index + word.length <= sqlText.length && + sqlText.regionMatches(true, index, word, 0, word.length) && + (index + word.length == sqlText.length || + !isIdentifierPart(sqlText.charAt(index + word.length))) + } + + private def isIdentifierPart(char: Char): Boolean = { + char.isLetterOrDigit || char == '_' + } + protected def parse[T](command: String)(toResult: PaimonSqlExtensionsParser => T): T = { val lexer = new PaimonSqlExtensionsLexer( new UpperCaseCharStream(CharStreams.fromString(command))) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/PaimonSqlExtensionsAstBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/PaimonSqlExtensionsAstBuilder.scala index dd3f3c4d15a4..da716ced11c5 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/PaimonSqlExtensionsAstBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/PaimonSqlExtensionsAstBuilder.scala @@ -26,11 +26,13 @@ import org.antlr.v4.runtime._ import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.parser.extensions.PaimonParserUtils.withOrigin import org.apache.spark.sql.catalyst.parser.extensions.PaimonSqlExtensionsParser._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.command.{CreateTableLikeCommand => SparkCreateTableLikeCommand} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -99,6 +101,13 @@ class PaimonSqlExtensionsAstBuilder(delegate: ParserInterface) ShowTagsCommand(typedVisit[Seq[String]](ctx.multipartIdentifier)) } + /** Create a CREATE TABLE LIKE logical command. */ + override def visitCreateTableLike(ctx: CreateTableLikeContext): LogicalPlan = withOrigin(ctx) { + sparkCreateTableLikeCommand(ctx).copy( + targetTable = toTableIdentifier(typedVisit[Seq[String]](ctx.target)), + sourceTable = toTableIdentifier(typedVisit[Seq[String]](ctx.source))) + } + /** Create a CREATE OR REPLACE TAG logical command. */ override def visitCreateOrReplaceTag(ctx: CreateOrReplaceTagContext): CreateOrReplaceTagCommand = withOrigin(ctx) { @@ -253,6 +262,50 @@ class PaimonSqlExtensionsAstBuilder(delegate: ParserInterface) private def toSeq[T](list: java.util.List[T]) = toBuffer(list) + private def toTableIdentifier(identifier: Seq[String]): TableIdentifier = { + identifier match { + case Seq(table) => + TableIdentifier(table) + case Seq(database, table) => + TableIdentifier(table, Some(database)) + case parts => + TableIdentifier( + parts.last, + Some(parts.slice(1, parts.length - 1).mkString(".")), + Some(parts.head)) + } + } + + private def sparkCreateTableLikeCommand( + ctx: CreateTableLikeContext): SparkCreateTableLikeCommand = { + delegate.parsePlan(createSparkCreateTableLikeSql(ctx)) match { + case command: SparkCreateTableLikeCommand => command + case plan => + throw new UnsupportedOperationException( + s"Expected Spark CREATE TABLE LIKE command, but got ${plan.nodeName}.") + } + } + + private def createSparkCreateTableLikeSql(ctx: CreateTableLikeContext): String = { + val stream = ctx.getStart.getInputStream + val baseStart = ctx.getStart.getStartIndex + val baseStop = ctx.getStop.getStopIndex + val targetStart = ctx.target.getStart.getStartIndex + val targetStop = ctx.target.getStop.getStopIndex + val sourceStart = ctx.source.getStart.getStartIndex + val sourceStop = ctx.source.getStop.getStopIndex + + val prefix = stream.getText(Interval.of(baseStart, targetStart - 1)) + val middle = stream.getText(Interval.of(targetStop + 1, sourceStart - 1)) + val suffix = if (sourceStop < baseStop) { + stream.getText(Interval.of(sourceStop + 1, baseStop)) + } else { + "" + } + + prefix + "__paimon_create_like_target" + middle + "__paimon_create_like_source" + suffix + } + private def reconstructSqlString(ctx: ParserRuleContext): String = { toBuffer(ctx.children) .map { diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CatalogQualifiedCreateTableLikeTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CatalogQualifiedCreateTableLikeTest.scala new file mode 100644 index 000000000000..03e109919b53 --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CatalogQualifiedCreateTableLikeTest.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonSparkTestBase +import org.apache.paimon.spark.commands.PaimonCreateTableLikeCommand + +import org.junit.jupiter.api.Assertions + +import scala.collection.JavaConverters._ + +class CatalogQualifiedCreateTableLikeTest extends PaimonSparkTestBase { + + test("Create table like with catalog-qualified identifiers") { + assume(gteqSpark3_4) + withTable("source_tbl", "target_tbl", "target_from_qualified_source", "qualified_target") { + createSourceTable() + + sql(s"CREATE TABLE paimon.$dbName0.target_tbl LIKE paimon.$dbName0.source_tbl") + assertCreatedLike("target_tbl") + + sql(s"CREATE TABLE target_from_qualified_source LIKE paimon.$dbName0.source_tbl") + assertCreatedLike("target_from_qualified_source") + + sql(s"CREATE TABLE paimon.$dbName0.qualified_target LIKE source_tbl") + assertCreatedLike("qualified_target") + } + } + + test("Create table like if not exists with catalog-qualified identifiers") { + assume(gteqSpark3_4) + withTable("source_tbl", "target_tbl") { + createSourceTable() + sql(""" + |CREATE TABLE target_tbl ( + | id BIGINT, + | pt STRING + |) COMMENT 'target comment' + |PARTITIONED BY (pt) + |TBLPROPERTIES ( + | 'primary-key' = 'id,pt', + | 'bucket' = '3' + |) + |""".stripMargin) + + val targetSchema = spark.table("target_tbl").schema + val targetLocation = loadTable("target_tbl").location().toString + + sql(s""" + |CREATE TABLE IF NOT EXISTS paimon.$dbName0.target_tbl + |LIKE paimon.$dbName0.source_tbl + |""".stripMargin) + + val target = loadTable("target_tbl") + Assertions.assertEquals(targetSchema, spark.table("target_tbl").schema) + Assertions.assertFalse(spark.table("target_tbl").schema.fieldNames.contains("name")) + Assertions.assertEquals("target comment", target.comment().get()) + Assertions.assertEquals("3", target.options().get("bucket")) + Assertions.assertEquals(targetLocation, target.location().toString) + } + } + + test("Create table like clauses with catalog-qualified identifiers") { + assume(gteqSpark3_4) + withTable("source_tbl", "target_tbl") { + createSourceTable() + + sql(s""" + |CREATE TABLE paimon.$dbName0.target_tbl + |LIKE paimon.$dbName0.source_tbl + |USING paimon + |TBLPROPERTIES ( + | 'bucket' = '8', + | 'target-file-size' = '256MB' + |) + |""".stripMargin) + + val source = loadTable("source_tbl") + val target = loadTable("target_tbl") + Assertions.assertEquals(spark.table("source_tbl").schema, spark.table("target_tbl").schema) + Assertions.assertEquals("source comment", target.comment().get()) + Assertions.assertEquals(List("pt"), target.partitionKeys().asScala.toList) + Assertions.assertEquals(List("id", "pt"), target.primaryKeys().asScala.toList) + Assertions.assertEquals("8", target.options().get("bucket")) + Assertions.assertEquals("256MB", target.options().get("target-file-size")) + Assertions.assertNotEquals(source.location().toString, target.location().toString) + } + } + + test("Create table like stored as is unsupported with catalog-qualified identifiers") { + assume(gteqSpark3_4) + withTable("source_tbl", "target_tbl") { + sql("CREATE TABLE source_tbl (id INT)") + + val error = intercept[Exception] { + sql(s""" + |CREATE TABLE paimon.$dbName0.target_tbl + |LIKE paimon.$dbName0.source_tbl + |STORED AS PARQUET + |""".stripMargin) + }.getMessage + + Assertions.assertTrue( + error.contains("CREATE TABLE LIKE ... STORED AS is not supported for SparkCatalog.")) + } + } + + test("Create table like parser accepts non-reserved and nested identifiers") { + assume(gteqSpark3_4) + + val nonReservedIdentifierCommand = + parseCreateTableLikeCommand("CREATE TABLE paimon.test.tag LIKE paimon.test.source_tbl") + Assertions.assertEquals("tag", nonReservedIdentifierCommand.targetIdent.name()) + Assertions.assertEquals(Seq("test"), nonReservedIdentifierCommand.targetIdent.namespace().toSeq) + + val nestedIdentifierCommand = + parseCreateTableLikeCommand( + "CREATE TABLE paimon.test.extra.target_tbl LIKE paimon.test.extra.source_tbl") + Assertions.assertEquals("target_tbl", nestedIdentifierCommand.targetIdent.name()) + Assertions.assertEquals( + Seq("test.extra"), + nestedIdentifierCommand.targetIdent.namespace().toSeq) + Assertions.assertEquals("source_tbl", nestedIdentifierCommand.sourceIdent.name()) + Assertions.assertEquals( + Seq("test.extra"), + nestedIdentifierCommand.sourceIdent.namespace().toSeq) + } + + private def createSourceTable(): Unit = { + sql(""" + |CREATE TABLE source_tbl ( + | id BIGINT, + | name STRING COMMENT 'name column', + | pt STRING + |) COMMENT 'source comment' + |PARTITIONED BY (pt) + |TBLPROPERTIES ( + | 'primary-key' = 'id,pt', + | 'bucket' = '2', + | 'target-file-size' = '64MB' + |) + |""".stripMargin) + } + + private def assertCreatedLike(tableName: String): Unit = { + val target = loadTable(tableName) + + Assertions.assertEquals(spark.table("source_tbl").schema, spark.table(tableName).schema) + Assertions.assertEquals("source comment", target.comment().get()) + Assertions.assertEquals(List("pt"), target.partitionKeys().asScala.toList) + Assertions.assertEquals(List("id", "pt"), target.primaryKeys().asScala.toList) + Assertions.assertEquals("2", target.options().get("bucket")) + Assertions.assertEquals("64MB", target.options().get("target-file-size")) + } + + private def parseCreateTableLikeCommand(sqlText: String): PaimonCreateTableLikeCommand = { + spark.sessionState.sqlParser.parsePlan(sqlText) match { + case command: PaimonCreateTableLikeCommand => command + case plan => + throw new AssertionError( + s"Expected PaimonCreateTableLikeCommand, but got ${plan.nodeName}.") + } + } +}