Skip to content

Commit

Permalink
Update data type tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed Jul 23, 2014
1 parent 8da1a17 commit aa92e84
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -295,19 +295,31 @@ object StructType {
case class StructType(fields: Seq[StructField]) extends DataType {
require(StructType.validateFields(fields), "Found fields with the same name.")

/**
* Returns all field names in a [[Seq]].
*/
lazy val fieldNames: Seq[String] = fields.map(_.name)
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet

/**
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
* have a name matching the given name, `null` will be returned.
*/
def apply(name: String): StructField = {
fields.find(f => f.name == name).orNull
fields.find(f => f.name == name).getOrElse(
throw new IllegalArgumentException(s"Field ${name} does not exist."))
}

/**
* Returns a [[StructType]] containing [[StructField]]s of the given names.
* Those names which do not have matching fields will be ignored.
*/
def apply(names: Set[String]): StructType = {
val nonExistFields = names -- fieldNamesSet
if (!nonExistFields.isEmpty) {
throw new IllegalArgumentException(
s"Field ${nonExistFields.mkString(",")} does not exist.")
}
StructType(fields.filter(f => names.contains(f.name)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ package org.apache.spark.sql

import org.scalatest.FunSuite

class SchemaSuite extends FunSuite {
class DataTypeSuite extends FunSuite {

test("constructing an ArrayType") {
test("construct an ArrayType") {
val array = ArrayType(StringType)

assert(ArrayType(StringType, false) === array)
}

test("extracting fields from a StructType") {
test("extract fields from a StructType") {
val struct = StructType(
StructField("a", IntegerType, true) ::
StructField("b", LongType, false) ::
Expand All @@ -36,14 +36,17 @@ class SchemaSuite extends FunSuite {

assert(StructField("b", LongType, false) === struct("b"))

assert(struct("e") === null)
intercept[IllegalArgumentException] {
struct("e")
}

val expectedStruct = StructType(
StructField("b", LongType, false) ::
StructField("d", FloatType, true) :: Nil)

assert(expectedStruct === struct(Set("b", "d")))
// struct does not have a field called e. So e is ignored.
assert(expectedStruct === struct(Set("b", "d", "e")))
intercept[IllegalArgumentException] {
struct(Set("b", "d", "e", "f"))
}
}
}

0 comments on commit aa92e84

Please sign in to comment.