Add a field containsNull to ArrayType to indicate if an array can contain null values or not. If an ArrayType is constructed by "ArrayType(elementType)" (the existing constructor), the value of containsNull is false.
yhuai committed Jul 12, 2014
1 parent 9168b83 commit dcaf22f
Showing 14 changed files with 64 additions and 39 deletions.
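In short: ArrayType now carries a containsNull flag, the one-argument form stays source-compatible through a companion apply that defaults the flag to false, and every pattern match on ArrayType gains a second binder. A minimal sketch of the resulting API, assuming only the catalyst types import path of this branch:

import org.apache.spark.sql.catalyst.types._

// The existing one-argument form still compiles; the companion object added
// in this commit defaults containsNull to false.
val noNulls = ArrayType(IntegerType)
assert(noNulls == ArrayType(IntegerType, false))

// Arrays whose elements may be null are now representable in a schema.
val mayHoldNulls = ArrayType(StringType, true)

// Call sites that match on ArrayType must bind or ignore the new field,
// which is the mechanical change repeated throughout this diff.
def elementOf(dt: DataType): Option[DataType] = dt match {
  case ArrayType(et, _) => Some(et)  // was: case ArrayType(et)
  case _ => None
}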
@@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
override def foldable = child.foldable && ordinal.foldable
override def references = children.flatMap(_.references).toSet
def dataType = child.dataType match {
- case ArrayType(dt) => dt
+ case ArrayType(dt, _) => dt
case MapType(_, vt) => vt
}
override lazy val resolved =
@@ -84,7 +84,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

private lazy val elementTypes = child.dataType match {
- case ArrayType(et) => et :: Nil
+ case ArrayType(et, _) => et :: Nil
case MapType(kt,vt) => kt :: vt :: Nil
}

@@ -102,7 +102,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)

override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
- case ArrayType(_) =>
+ case ArrayType(_, _) =>
val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
case MapType(_, _) =>
@@ -177,7 +177,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
case StructType(fields) =>
StructType(fields.map(f =>
StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable)))
- case ArrayType(elemType) => ArrayType(lowerCaseSchema(elemType))
+ case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull)
case otherType => otherType
}

@@ -45,7 +45,9 @@ object DataType extends RegexParsers {
"TimestampType" ^^^ TimestampType

protected lazy val arrayType: Parser[DataType] =
"ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType
"ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
}

protected lazy val mapType: Parser[DataType] =
"MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ {
@@ -241,9 +243,14 @@ case object FloatType extends FractionalType {
def simpleString: String = "float"
}

- case class ArrayType(elementType: DataType) extends DataType {
+ object ArrayType {
+   def apply(elementType: DataType): ArrayType = ArrayType(elementType, false)
+ }
+
+ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- element: ${elementType.simpleString}\n")
+ builder.append(
+   s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n")
elementType match {
case array: ArrayType =>
array.buildFormattedString(s"$prefix |", builder)
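With this grammar change, the serialized string form of an array type now carries the flag, e.g. "ArrayType(IntegerType,true)". A self-contained sketch of the updated rule for reference; boolVal is referenced by the hunk above but defined outside the shown context, so a plausible true/false parser is assumed here, as is the catalyst types import path:

import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.sql.catalyst.types._

object ArrayTypeParserSketch extends RegexParsers {
  // Assumed definition of boolVal; the real one lives elsewhere in DataType.
  lazy val boolVal: Parser[Boolean] = "true" ^^^ true | "false" ^^^ false
  // A reduced dataType rule, just enough to exercise arrayType.
  lazy val primitive: Parser[DataType] =
    "IntegerType" ^^^ IntegerType | "StringType" ^^^ StringType
  lazy val arrayType: Parser[DataType] =
    "ArrayType" ~> "(" ~> primitive ~ "," ~ boolVal <~ ")" ^^ {
      case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
    }
}

// ArrayTypeParserSketch.parseAll(ArrayTypeParserSketch.arrayType,
//   "ArrayType(IntegerType,true)").get == ArrayType(IntegerType, true)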
@@ -93,7 +93,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- def applySchema[A](rdd: RDD[A],schema: StructType, f: A => Row): SchemaRDD =
+ def applySchema[A](rdd: RDD[A], schema: StructType, f: A => Row): SchemaRDD =
applySchemaToPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f))

/**
@@ -382,7 +382,7 @@ class SchemaRDD(
case (obj, (name, dataType)) =>
dataType match {
case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct))
- case array @ ArrayType(struct: StructType) =>
+ case array @ ArrayType(struct: StructType, _) =>
val arrayValues = obj match {
case seq: Seq[Any] =>
seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala (29 changes: 16 additions & 13 deletions)
@@ -68,8 +68,8 @@ private[sql] object JsonRDD extends Logging {
val (topLevel, structLike) = values.partition(_.size == 1)
val topLevelFields = topLevel.filter {
name => resolved.get(prefix ++ name).get match {
- case ArrayType(StructType(Nil)) => false
- case ArrayType(_) => true
+ case ArrayType(StructType(Nil), _) => false
+ case ArrayType(_, _) => true
case struct: StructType => false
case _ => true
}
@@ -83,7 +83,8 @@ private[sql] object JsonRDD extends Logging {
val structType = makeStruct(nestedFields, prefix :+ name)
val dataType = resolved.get(prefix :+ name).get
dataType match {
- case array: ArrayType => Some(StructField(name, ArrayType(structType), nullable = true))
+ case array: ArrayType =>
+   Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true))
case struct: StructType => Some(StructField(name, structType, nullable = true))
// dataType is StringType means that we have resolved type conflicts involving
// primitive types and complex types. So, the type of name has been relaxed to
@@ -107,7 +108,7 @@ private[sql] object JsonRDD extends Logging {
case StructField(fieldName, dataType, nullable) => {
val newType = dataType match {
case NullType => StringType
- case ArrayType(NullType) => ArrayType(StringType)
+ case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
case struct: StructType => nullTypeToStringType(struct)
case other: DataType => other
}
@@ -148,8 +149,8 @@ private[sql] object JsonRDD extends Logging {
case StructField(name, _, _) => name
})
}
- case (ArrayType(elementType1), ArrayType(elementType2)) =>
-   ArrayType(compatibleType(elementType1, elementType2))
+ case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+   ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// TODO: We should use JsonObjectStringType to mark that values of field will be
// strings and every string is a Json object.
case (_, _) => StringType
@@ -176,12 +177,13 @@ private[sql] object JsonRDD extends Logging {
* treat the element as String.
*/
private def typeOfArray(l: Seq[Any]): ArrayType = {
+ val containsNull = l.exists(v => v == null)
val elements = l.flatMap(v => Option(v))
if (elements.isEmpty) {
// If this JSON array is empty, we use NullType as a placeholder.
// If this array is not empty in other JSON objects, we can resolve
// the type after we have passed through all JSON objects.
- ArrayType(NullType)
+ ArrayType(NullType, containsNull)
} else {
val elementType = elements.map {
e => e match {
@@ -193,7 +195,7 @@ private[sql] object JsonRDD extends Logging {
}
}.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2))

- ArrayType(elementType)
+ ArrayType(elementType, containsNull)
}
}

@@ -220,15 +222,16 @@ private[sql] object JsonRDD extends Logging {
case (key: String, array: List[Any]) => {
// The value associted with the key is an array.
typeOfArray(array) match {
- case ArrayType(StructType(Nil)) => {
+ case ArrayType(StructType(Nil), containsNull) => {
// The elements of this arrays are structs.
array.asInstanceOf[List[Map[String, Any]]].flatMap {
element => allKeysWithValueTypes(element)
}.map {
case (k, dataType) => (s"$key.$k", dataType)
- } :+ (key, ArrayType(StructType(Nil)))
+ } :+ (key, ArrayType(StructType(Nil), containsNull))
}
- case ArrayType(elementType) => (key, ArrayType(elementType)) :: Nil
+ case ArrayType(elementType, containsNull) =>
+   (key, ArrayType(elementType, containsNull)) :: Nil
}
}
case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil
@@ -340,7 +343,7 @@ private[sql] object JsonRDD extends Logging {
null
} else {
desiredType match {
- case ArrayType(elementType) =>
+ case ArrayType(elementType, _) =>
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
case StringType => toString(value)
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
@@ -363,7 +366,7 @@ private[sql] object JsonRDD extends Logging {
v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull)

// ArrayType(StructType)
- case (StructField(name, ArrayType(structType: StructType), _), i) =>
+ case (StructField(name, ArrayType(structType: StructType, _), _), i) =>
row.update(i,
json.get(name).flatMap(v => Option(v)).map(
v => v.asInstanceOf[Seq[Any]].map(
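Taken together, the JsonRDD changes apply two rules: an array's containsNull is observed directly from the data, and reconciling the array types of two records ORs the flags, so the merged schema admits nulls if any record contained one. A minimal illustration; the names below are illustrative stand-ins, not the private JsonRDD helpers:

import org.apache.spark.sql.catalyst.types._

// Rule 1: the flag reflects the data actually seen (cf. typeOfArray above).
def observedContainsNull(values: Seq[Any]): Boolean = values.exists(_ == null)

// Rule 2: merging widens monotonically (cf. the compatibleType case above).
def mergeArrayTypes(a: ArrayType, b: ArrayType, mergedElement: DataType): ArrayType =
  ArrayType(mergedElement, a.containsNull || b.containsNull)

assert(observedContainsNull(Seq(1, null, 3)))
assert(mergeArrayTypes(ArrayType(IntegerType, false), ArrayType(IntegerType, true),
  IntegerType) == ArrayType(IntegerType, true))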
@@ -75,11 +75,11 @@ private[sql] object CatalystConverter {
val fieldType: DataType = field.dataType
fieldType match {
// For native JVM types we use a converter with native arrays
- case ArrayType(elementType: NativeType) => {
+ case ArrayType(elementType: NativeType, false) => {
new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
}
// This is for other types of arrays, including those with nested fields
- case ArrayType(elementType: DataType) => {
+ case ArrayType(elementType: DataType, false) => {
new CatalystArrayConverter(elementType, fieldIndex, parent)
}
case StructType(fields: Seq[StructField]) => {
@@ -169,7 +169,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
if (value != null) {
schema match {
- case t @ ArrayType(_) => writeArray(
+ case t @ ArrayType(_, false) => writeArray(
t,
value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
case t @ MapType(_, _) => writeMap(
@@ -113,7 +113,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
case ParquetOriginalType.LIST => { // TODO: check enums!
assert(groupType.getFieldCount == 1)
val field = groupType.getFields.apply(0)
- new ArrayType(toDataType(field))
+ ArrayType(toDataType(field), false)
}
case ParquetOriginalType.MAP => {
assert(
@@ -127,7 +127,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
val valueType = toDataType(keyValueGroup.getFields.apply(1))
assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
- new MapType(keyType, valueType)
+ MapType(keyType, valueType)
}
case _ => {
// Note: the order of these checks is important!
@@ -137,18 +137,18 @@ private[parquet] object ParquetTypesConverter extends Logging {
assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
val valueType = toDataType(keyValueGroup.getFields.apply(1))
assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
- new MapType(keyType, valueType)
+ MapType(keyType, valueType)
} else if (correspondsToArray(groupType)) { // ArrayType
val elementType = toDataType(groupType.getFields.apply(0))
- new ArrayType(elementType)
+ ArrayType(elementType, false)
} else { // everything else: StructType
val fields = groupType
.getFields
.map(ptype => new StructField(
ptype.getName,
toDataType(ptype),
ptype.getRepetition != Repetition.REQUIRED))
- new StructType(fields)
+ StructType(fields)
}
}
}
@@ -168,7 +168,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
case StringType => Some(ParquetPrimitiveTypeName.BINARY)
case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN)
case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE)
- case ArrayType(ByteType) =>
+ case ArrayType(ByteType, false) =>
Some(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)
case FloatType => Some(ParquetPrimitiveTypeName.FLOAT)
case IntegerType => Some(ParquetPrimitiveTypeName.INT32)
@@ -231,7 +231,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
new ParquetPrimitiveType(repetition, primitiveType.get, name)
} else {
ctype match {
- case ArrayType(elementType) => {
+ case ArrayType(elementType, false) => {
val parquetElementType = fromDataType(
elementType,
CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME,
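Worth noting: the Parquet hunks above all match only containsNull = false, so as of this commit an array whose elements may be null has no corresponding converter or writer arm. A small sketch of the guard pattern these hunks share (the function name is illustrative, not a new API):

import org.apache.spark.sql.catalyst.types._

// Mirrors the `case ArrayType(_, false)` guards in the Parquet hunks above:
// only arrays declared null-free take the array code paths.
def parquetHandlesArray(dt: DataType): Boolean = dt match {
  case ArrayType(_, false) => true   // handled by the converters/writers above
  case ArrayType(_, true)  => false  // falls through; not handled here yet
  case _                   => false
}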
sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala (16 changes: 14 additions & 2 deletions)
@@ -127,6 +127,18 @@ class JsonSuite extends QueryTest {
checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType))
checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType))
checkDataType(ArrayType(IntegerType), StructType(Nil), StringType)
+ checkDataType(
+   ArrayType(IntegerType, true), ArrayType(IntegerType), ArrayType(IntegerType, true))
+ checkDataType(
+   ArrayType(IntegerType, true), ArrayType(IntegerType, false), ArrayType(IntegerType, true))
+ checkDataType(
+   ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true))
+ checkDataType(
+   ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false))
+ checkDataType(
+   ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false))
+ checkDataType(
+   ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType))

// StructType
checkDataType(StructType(Nil), StructType(Nil), StructType(Nil))
@@ -200,7 +212,7 @@ class JsonSuite extends QueryTest {
AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() ::
AttributeReference("arrayOfInteger", ArrayType(IntegerType), true)() ::
AttributeReference("arrayOfLong", ArrayType(LongType), true)() ::
AttributeReference("arrayOfNull", ArrayType(StringType), true)() ::
AttributeReference("arrayOfNull", ArrayType(StringType, true), true)() ::
AttributeReference("arrayOfString", ArrayType(StringType), true)() ::
AttributeReference("arrayOfStruct", ArrayType(
StructType(StructField("field1", BooleanType, true) ::
@@ -451,7 +463,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict)

val expectedSchema =
AttributeReference("array1", ArrayType(StringType), true)() ::
AttributeReference("array1", ArrayType(StringType, true), true)() ::
AttributeReference("array2", ArrayType(StructType(
StructField("field", LongType, true) :: Nil)), true)() :: Nil

@@ -257,7 +257,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
struct.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
- case (seq: Seq[_], ArrayType(typ)) =>
+ case (seq: Seq[_], ArrayType(typ, _)) =>
seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
case (map: Map[_,_], MapType(kType, vType)) =>
map.map {
@@ -274,7 +274,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
struct.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
- case (seq: Seq[_], ArrayType(typ)) =>
+ case (seq: Seq[_], ArrayType(typ, _)) =>
seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
case (map: Map[_,_], MapType(kType, vType)) =>
map.map {
@@ -199,7 +199,9 @@ object HiveMetastoreTypes extends RegexParsers {
"varchar\\((\\d+)\\)".r ^^^ StringType

protected lazy val arrayType: Parser[DataType] =
"array" ~> "<" ~> dataType <~ ">" ^^ ArrayType
"array" ~> "<" ~> dataType <~ ">" ^^ {
case tpe => ArrayType(tpe)
}

protected lazy val mapType: Parser[DataType] =
"map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
@@ -228,7 +230,7 @@ object HiveMetastoreTypes extends RegexParsers {
}

def toMetastoreType(dt: DataType): String = dt match {
- case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>"
+ case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
case StructType(fields) =>
s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
case MapType(keyType, valueType) =>
@@ -319,7 +319,8 @@ private[hive] trait HiveInspectors {
}

def toInspector(dataType: DataType): ObjectInspector = dataType match {
- case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
+ case ArrayType(tpe, _) =>
+   ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
case MapType(keyType, valueType) =>
ObjectInspectorFactory.getStandardMapObjectInspector(
toInspector(keyType), toInspector(valueType))
