diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index a14041d4ccb5d..93ffee327cefe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -295,12 +295,19 @@ object StructType { case class StructType(fields: Seq[StructField]) extends DataType { require(StructType.validateFields(fields), "Found fields with the same name.") + /** + * Returns all field names in a [[Seq]]. + */ + lazy val fieldNames: Seq[String] = fields.map(_.name) + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a field matching the given name, an `IllegalArgumentException` will be thrown. */ def apply(name: String): StructField = { - fields.find(f => f.name == name).orNull + fields.find(f => f.name == name).getOrElse( + throw new IllegalArgumentException(s"Field ${name} does not exist.")) } /** @@ -308,6 +315,11 @@ case class StructType(fields: Seq[StructField]) extends DataType { * If any given name has no matching field, an `IllegalArgumentException` will be thrown. 
*/ def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (!nonExistFields.isEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } StructType(fields.filter(f => names.contains(f.name))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala similarity index 83% rename from sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index c1e1b5333927d..c5bd7b391db41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.sql import org.scalatest.FunSuite -class SchemaSuite extends FunSuite { +class DataTypeSuite extends FunSuite { - test("constructing an ArrayType") { + test("construct an ArrayType") { val array = ArrayType(StringType) assert(ArrayType(StringType, false) === array) } - test("extracting fields from a StructType") { + test("extract fields from a StructType") { val struct = StructType( StructField("a", IntegerType, true) :: StructField("b", LongType, false) :: @@ -36,14 +36,17 @@ class SchemaSuite extends FunSuite { assert(StructField("b", LongType, false) === struct("b")) - assert(struct("e") === null) + intercept[IllegalArgumentException] { + struct("e") + } val expectedStruct = StructType( StructField("b", LongType, false) :: StructField("d", FloatType, true) :: Nil) assert(expectedStruct === struct(Set("b", "d"))) - // struct does not have a field called e. So e is ignored. - assert(expectedStruct === struct(Set("b", "d", "e"))) + intercept[IllegalArgumentException] { + struct(Set("b", "d", "e", "f")) + } } }