From 803b1961a34d4d9f4c8ebcbe5544dd23fbaa720a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 20 Sep 2017 04:52:08 +0000 Subject: [PATCH 1/3] Fix isCascadingTruncateTable for AggregatedDialect. --- .../apache/spark/sql/jdbc/AggregatedDialect.scala | 13 ++++++++++++- .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 13 +++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 7432a1538ce97..3c1861669e4ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -43,6 +43,17 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } override def isCascadingTruncateTable(): Option[Boolean] = { - dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) + // If any dialect claims cascading truncate, this dialect is also cascading truncate. + // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown. + val cascading = dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) + if (cascading.get) { + cascading + } else { + if (dialects.exists(_.isCascadingTruncateTable().isEmpty)) { + None + } else { + Some(false) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index fd12bb9e530b8..a6898547f2f54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -747,6 +747,19 @@ class JDBCSuite extends SparkFunSuite assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) assert(agg.isCascadingTruncateTable() === Some(true)) + + val agg2 = new AggregatedDialect(List(new JdbcDialect { + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + if (sqlType % 2 == 0) { + Some(LongType) + } else { + None + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + }, testH2Dialect)) + assert(agg2.isCascadingTruncateTable() === None) } test("DB2Dialect type mapping") { From d4fadbbab30c820352eb9bd54af1428b35afe0a4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 20 Sep 2017 06:47:51 +0000 Subject: [PATCH 2/3] Add a test case for AggregatedDialect.isCascadingTruncateTable. --- .../spark/sql/jdbc/AggregatedDialect.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 37 +++++++++++++------ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 3c1861669e4ec..ab660fe926192 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -46,7 +46,7 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect // If any dialect claims cascading truncate, this dialect is also cascading truncate. // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown. val cascading = dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) - if (cascading.get) { + if (cascading.getOrElse(false)) { cascading } else { if (dialects.exists(_.isCascadingTruncateTable().isEmpty)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index a6898547f2f54..3129b6e011f3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -747,19 +747,34 @@ class JDBCSuite extends SparkFunSuite assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) assert(agg.isCascadingTruncateTable() === Some(true)) + } - val agg2 = new AggregatedDialect(List(new JdbcDialect { - override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + test("Aggregated dialects: isCascadingTruncateTable") { + def genDialect(cascadingTruncateTable: Option[Boolean]): JdbcDialect = new JdbcDialect { + override def canHandle(url: String): Boolean = true override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = - if (sqlType % 2 == 0) { - Some(LongType) - } else { - None - } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) - }, testH2Dialect)) - assert(agg2.isCascadingTruncateTable() === None) + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = None + override def isCascadingTruncateTable(): Option[Boolean] = cascadingTruncateTable + } + + val dialectCombination = Seq( + List(genDialect(Some(true)), genDialect(Some(false)), genDialect(None)), + List(genDialect(Some(true)), genDialect(Some(true)), genDialect(None)), + List(genDialect(Some(false)), genDialect(Some(false)), genDialect(None)), + List(genDialect(Some(true)), genDialect(Some(true))), + List(genDialect(Some(false)), genDialect(Some(false))), + List(genDialect(None), genDialect(None)) + ) + + val expectedCascading = Seq(Some(true), Some(true), None, Some(true), Some(false), None) + + dialectCombination.zip(expectedCascading).foreach { case (dialects, cascading) => + val agg = new AggregatedDialect(dialects) + assert(agg.isCascadingTruncateTable() === cascading) + } } test("DB2Dialect type mapping") { From 7e5a57c3e4d9550d2ddd8a971293ace3984b5447 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 24 Sep 2017 00:37:49 +0000 Subject: [PATCH 3/3] Address comments. --- .../spark/sql/jdbc/AggregatedDialect.scala | 13 ++++------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 23 ++++++++----------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index ab660fe926192..1419d69f983ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -45,15 +45,10 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect override def isCascadingTruncateTable(): Option[Boolean] = { // If any dialect claims cascading truncate, this dialect is also cascading truncate. // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown. - val cascading = dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) - if (cascading.getOrElse(false)) { - cascading - } else { - if (dialects.exists(_.isCascadingTruncateTable().isEmpty)) { - None - } else { - Some(false) - } + dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) match { + case Some(true) => Some(true) + case _ if dialects.exists(_.isCascadingTruncateTable().isEmpty) => None + case _ => Some(false) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 3129b6e011f3b..34205e0b2bf08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -760,21 +760,18 @@ class JDBCSuite extends SparkFunSuite override def isCascadingTruncateTable(): Option[Boolean] = cascadingTruncateTable } - val dialectCombination = Seq( - List(genDialect(Some(true)), genDialect(Some(false)), genDialect(None)), - List(genDialect(Some(true)), genDialect(Some(true)), genDialect(None)), - List(genDialect(Some(false)), genDialect(Some(false)), genDialect(None)), - List(genDialect(Some(true)), genDialect(Some(true))), - List(genDialect(Some(false)), genDialect(Some(false))), - List(genDialect(None), genDialect(None)) - ) - - val expectedCascading = Seq(Some(true), Some(true), None, Some(true), Some(false), None) - - dialectCombination.zip(expectedCascading).foreach { case (dialects, cascading) => + def testDialects(cascadings: List[Option[Boolean]], expected: Option[Boolean]): Unit = { + val dialects = cascadings.map(genDialect(_)) val agg = new AggregatedDialect(dialects) - assert(agg.isCascadingTruncateTable() === cascading) + assert(agg.isCascadingTruncateTable() === expected) } + + testDialects(List(Some(true), Some(false), None), Some(true)) + testDialects(List(Some(true), Some(true), None), Some(true)) + testDialects(List(Some(false), Some(false), None), None) + testDialects(List(Some(true), Some(true)), Some(true)) + testDialects(List(Some(false), Some(false)), Some(false)) + testDialects(List(None, None), None) } test("DB2Dialect type mapping") {