From 5567f20dc9de0a2fab2f67496ab7e6ec8dd51903 Mon Sep 17 00:00:00 2001
From: darionyaphet
Date: Fri, 23 Jun 2017 18:11:29 +0800
Subject: [PATCH] [SPARK-21191] DataFrame Row StructType check duplicate name

Fail fast with an IllegalArgumentException when the same name is supplied
twice for a schema:

* The StructType constructor requires all field names to be distinct and
  lists the repeated names in the error message.
* StructType.add(field) rejects a field whose name is already present; the
  four add(name, dataType, ...) overloads touched here now delegate to it
  instead of calling the constructor directly.
* DatasetHolder.toDF(colNames) rejects repeated column names and lists them
  in the error message.

RowSuite gets one test per entry point, each asserting on the message.
---
Note for reviewers: the constructor check applies to every StructType that
is ever built, not only to schemas written by users. Please confirm that no
existing code path legitimately constructs a schema with repeated field
names before merging. The comparison is an exact (case-sensitive) string
match; confirm that this is the intended rule.

 .../apache/spark/sql/types/StructType.scala   | 19 ++++++++++----
 .../org/apache/spark/sql/DatasetHolder.scala  | 13 ++++++++-
 .../scala/org/apache/spark/sql/RowSuite.scala | 26 +++++++++++++++++++
 3 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 54006e20a3eb6..a0a8e0c37488e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -97,6 +97,14 @@ import org.apache.spark.util.Utils
 @InterfaceStability.Stable
 case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
 
+  // Reject schemas in which two fields share a name; the message (built only on failure,
+  // since `require` takes it by name) lists each offending name once.
+  require(fields.map(_.name).distinct.length == fields.length, {
+    val names = fields.map(_.name)
+    val duplicates = names.diff(names.distinct).distinct
+    s"Struct fields have duplicate name(s): ${duplicates.mkString(", ")}"
+  })
+
   /** No-arg constructor for kryo. */
   def this() = this(Array.empty[StructField])
 
@@ -129,6 +137,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
    *}}}
    */
   def add(field: StructField): StructType = {
+    if (fieldNames.contains(field.name)) {
+      throw new IllegalArgumentException(s"${field.name} is duplicated")
+    }
     StructType(fields :+ field)
   }
 
@@ -141,7 +152,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
    *   .add("c", StringType)
    */
   def add(name: String, dataType: DataType): StructType = {
-    StructType(fields :+ StructField(name, dataType, nullable = true, Metadata.empty))
+    add(StructField(name, dataType, nullable = true, Metadata.empty))
   }
 
   /**
@@ -153,7 +164,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
    *   .add("c", StringType, true)
    */
   def add(name: String, dataType: DataType, nullable: Boolean): StructType = {
-    StructType(fields :+ StructField(name, dataType, nullable, Metadata.empty))
+    add(StructField(name, dataType, nullable, Metadata.empty))
   }
 
   /**
@@ -170,7 +181,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
       dataType: DataType,
       nullable: Boolean,
       metadata: Metadata): StructType = {
-    StructType(fields :+ StructField(name, dataType, nullable, metadata))
+    add(StructField(name, dataType, nullable, metadata))
   }
 
   /**
@@ -187,7 +198,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
       dataType: DataType,
       nullable: Boolean,
       comment: String): StructType = {
-    StructType(fields :+ StructField(name, dataType, nullable).withComment(comment))
+    add(StructField(name, dataType, nullable).withComment(comment))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
index 582d4a3670b8e..0654b1037d0fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
@@ -41,5 +41,16 @@ case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) {
   // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
   def toDF(): DataFrame = ds.toDF()
 
-  def toDF(colNames: String*): DataFrame = ds.toDF(colNames : _*)
+  /**
+   * Delegates to `ds.toDF(colNames)` after checking that no column name is repeated.
+   *
+   * @param colNames column names forwarded to `ds.toDF`; every name must be distinct
+   * @throws IllegalArgumentException if `colNames` contains the same name more than once
+   */
+  def toDF(colNames: String*): DataFrame = {
+    val duplicates = colNames.diff(colNames.distinct).distinct
+    require(duplicates.isEmpty,
+      s"Duplicate column name(s) passed to toDF: ${duplicates.mkString(", ")}")
+    ds.toDF(colNames : _*)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 7516be315dd2d..f5311da9248dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -54,6 +54,32 @@ class RowSuite extends SparkFunSuite with SharedSQLContext {
     assert(row.isNullAt(0))
   }
 
+  test("SPARK-21191: toDF rejects duplicate column names") {
+    val e = intercept[IllegalArgumentException] {
+      Seq((0, 3, 4), (1, 3, 3)).toDF("1", "1", "2")
+    }
+    assert(e.getMessage.contains("toDF: 1"))
+  }
+
+  test("SPARK-21191: StructType constructor rejects duplicate field names") {
+    val e = intercept[IllegalArgumentException] {
+      StructType(
+        StructField("a", IntegerType, nullable = true) ::
+          StructField("a", LongType, nullable = false) ::
+          StructField("c", BooleanType, nullable = false) :: Nil)
+    }
+    assert(e.getMessage.contains("duplicate name(s): a"))
+  }
+
+  test("SPARK-21191: StructType.add rejects a field whose name already exists") {
+    val e = intercept[IllegalArgumentException] {
+      new StructType()
+        .add(StructField("a", IntegerType, nullable = true))
+        .add(StructField("a", LongType, nullable = false))
+    }
+    assert(e.getMessage.contains("a is duplicated"))
+  }
+
   test("get values by field name on Row created via .toDF") {
     val row = Seq((1, Seq(1))).toDF("a", "b").first()
     assert(row.getAs[Int]("a") === 1)