
change metadata type in StructField for Scala/Java
mengxr committed Oct 14, 2014
Parent: 60cc131. Commit: 1fcbf13.
Showing 11 changed files with 74 additions and 57 deletions.
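For orientation, here is a minimal sketch of the API surface this commit introduces, assembled only from calls that appear in the hunks below; the keys and field names are illustrative, not from the commit:

    import org.apache.spark.sql.catalyst.types.{IntegerType, StringType, StructField, StructType}
    import org.apache.spark.sql.catalyst.util.{Metadata, MetadataBuilder}

    // Metadata is an immutable, typed container built through MetadataBuilder,
    // replacing the former Map[String, Any] field on StructField.
    val metadata: Metadata = new MetadataBuilder()
      .putString("doc", "first name") // illustrative key/value
      .build()

    val schema = StructType(Seq(
      StructField("id", IntegerType, false),            // metadata defaults to Metadata.empty
      StructField("name", StringType, true, metadata))) // carries typed metadata

    // Typed getters replace untyped map lookups.
    assert(schema("name").metadata.getString("doc") == "first name")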
---------- changed file 1 of 11 ----------
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}
+import org.apache.spark.sql.catalyst.util.Metadata

 abstract class Expression extends TreeNode[Expression] {
   self: Product =>
@@ -44,7 +45,7 @@ abstract class Expression extends TreeNode[Expression] {
   def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))

   /** Returns the metadata when an expression is a reference to another expression with metadata. */
-  def metadata: Map[String, Any] = Map.empty
+  def metadata: Metadata = Metadata.empty

   /** Returns the result of evaluating this expression on a given input Row */
   def eval(input: Row = null): EvaluatedType
---------- changed file 2 of 11 ----------
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.util.Metadata

 object NamedExpression {
   private val curId = new java.util.concurrent.atomic.AtomicLong()
@@ -86,7 +87,7 @@ case class Alias(child: Expression, name: String)

   override def dataType = child.dataType
   override def nullable = child.nullable
-  override def metadata: Map[String, Any] = child.metadata
+  override def metadata: Metadata = child.metadata

   override def toAttribute = {
     if (resolved) {
@@ -118,7 +119,7 @@ case class AttributeReference(
     name: String,
     dataType: DataType,
     nullable: Boolean = true,
-    override val metadata: Map[String, Any] = Map.empty)(
+    override val metadata: Metadata = Metadata.empty)(
     val exprId: ExprId = NamedExpression.newExprId,
     val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
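These two overrides are what let column-level metadata survive projections and renames: AttributeReference carries its own Metadata, Alias forwards its child's, and every other Expression reports Metadata.empty. A small sketch, assuming Alias also takes a defaulted second parameter list like AttributeReference above:

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
    import org.apache.spark.sql.catalyst.types.StringType
    import org.apache.spark.sql.catalyst.util.MetadataBuilder

    val meta = new MetadataBuilder().putString("doc", "first name").build() // illustrative key
    val ref = AttributeReference("name", StringType, true, meta)()

    // Alias.metadata delegates to child.metadata, so renaming a column keeps its metadata.
    val renamed = Alias(ref, "firstName")()
    assert(renamed.metadata == meta)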
---------- changed file 3 of 11 ----------
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.types

 import java.sql.Timestamp

+import org.apache.spark.sql.catalyst.util.Metadata
+
 import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
@@ -74,7 +76,7 @@ object DataType {
         ("name", JString(name)),
         ("nullable", JBool(nullable)),
         ("type", dataType: JValue)) =>
-        StructField(name, parseDataType(dataType), nullable, metadata.values)
+        StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata))
     }

   @deprecated("Use DataType.fromJson instead")
@@ -386,7 +388,7 @@ case class StructField(
     name: String,
     dataType: DataType,
     nullable: Boolean,
-    metadata: Map[String, Any] = Map.empty) {
+    metadata: Metadata = Metadata.empty) {

   private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
     builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
@@ -402,7 +404,7 @@ case class StructField(
     ("name" -> name) ~
     ("type" -> dataType.jsonValue) ~
     ("nullable" -> nullable) ~
-    ("metadata" -> Extraction.decompose(metadata)(DefaultFormats))
+    ("metadata" -> metadata.jsonValue)
   }
 }
---------- changed file 4 of 11 ----------
@@ -65,11 +65,9 @@ sealed class Metadata private[util] (private[util] val map: Map[String, Any])
   def getMetadataArray(key: String): Array[Metadata] = get(key)

   /** Converts to its JSON representation. */
-  def toJson: String = {
-    compact(render(Metadata.toJValue(this)))
-  }
+  def json: String = compact(render(jsonValue))

-  override def toString: String = toJson
+  override def toString: String = json

   override def equals(obj: Any): Boolean = {
     obj match {
@@ -96,6 +94,8 @@
   private def get[T](key: String): T = {
     map(key).asInstanceOf[T]
   }
+
+  private[sql] def jsonValue: JValue = Metadata.toJsonValue(this)
 }

 object Metadata {
@@ -105,41 +105,40 @@ object Metadata {

   /** Creates a Metadata instance from JSON. */
   def fromJson(json: String): Metadata = {
-    val map = parse(json).values.asInstanceOf[Map[String, Any]]
-    fromMap(map.toMap)
+    fromJObject(parse(json).asInstanceOf[JObject])
   }

-  /** Creates a Metadata instance from Map[String, Any]. */
-  private def fromMap(map: Map[String, Any]): Metadata = {
+  /** Creates a Metadata instance from JSON AST. */
+  private[sql] def fromJObject(jObj: JObject): Metadata = {
     val builder = new MetadataBuilder
-    map.foreach {
-      case (key, value: BigInt) =>
+    jObj.obj.foreach {
+      case (key, JInt(value)) =>
         builder.putLong(key, value.toLong)
-      case (key, value: Double) =>
+      case (key, JDouble(value)) =>
         builder.putDouble(key, value)
-      case (key, value: Boolean) =>
+      case (key, JBool(value)) =>
         builder.putBoolean(key, value)
-      case (key, value: String) =>
+      case (key, JString(value)) =>
         builder.putString(key, value)
-      case (key, value: Map[_, _]) =>
-        builder.putMetadata(key, fromMap(value.asInstanceOf[Map[String, Any]]))
-      case (key, value: Seq[_]) =>
+      case (key, o: JObject) =>
+        builder.putMetadata(key, fromJObject(o))
+      case (key, JArray(value)) =>
         if (value.isEmpty) {
           // If it is an empty array, we cannot infer its element type. We put an empty Array[Long].
           builder.putLongArray(key, Array.empty)
         } else {
           value.head match {
-            case _: BigInt =>
-              builder.putLongArray(key, value.asInstanceOf[Seq[BigInt]].map(_.toLong).toArray)
-            case _: Double =>
-              builder.putDoubleArray(key, value.asInstanceOf[Seq[Double]].toArray)
-            case _: Boolean =>
-              builder.putBooleanArray(key, value.asInstanceOf[Seq[Boolean]].toArray)
-            case _: String =>
-              builder.putStringArray(key, value.asInstanceOf[Seq[String]].toSeq.toArray)
-            case _: Map[_, _] =>
+            case _: JInt =>
+              builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray)
+            case _: JDouble =>
+              builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray)
+            case _: JBool =>
+              builder.putBooleanArray(key, value.asInstanceOf[List[JBool]].map(_.value).toArray)
+            case _: JString =>
+              builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray)
+            case _: JObject =>
               builder.putMetadataArray(
-                key, value.asInstanceOf[Seq[Map[String, Any]]].map(fromMap).toArray)
+                key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray)
             case other =>
               throw new RuntimeException(s"Do not support array of type ${other.getClass}.")
           }
@@ -151,13 +150,13 @@
   }

   /** Converts to JSON AST. */
-  private def toJValue(obj: Any): JValue = {
+  private def toJsonValue(obj: Any): JValue = {
     obj match {
       case map: Map[_, _] =>
-        val fields = map.toList.map { case (k: String, v) => (k, toJValue(v))}
+        val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v))}
         JObject(fields)
       case arr: Array[_] =>
-        val values = arr.toList.map(toJValue)
+        val values = arr.toList.map(toJsonValue)
         JArray(values)
       case x: Long =>
         JInt(x)
@@ -168,7 +167,7 @@
       case x: String =>
         JString(x)
       case x: Metadata =>
-        toJValue(x.map)
+        toJsonValue(x.map)
       case other =>
         throw new RuntimeException(s"Do not support type ${other.getClass}.")
     }
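Splitting the old toJson into json and a private jsonValue lets StructField.jsonValue (earlier in this commit) embed metadata directly in the schema's JSON AST instead of splicing strings. A round-trip sketch using only methods from this file, with an illustrative key:

    import org.apache.spark.sql.catalyst.util.{Metadata, MetadataBuilder}

    val m = new MetadataBuilder().putString("doc", "age in years").build()

    // json renders the json4s AST from jsonValue; fromJson parses back through
    // fromJObject, matching on JInt/JDouble/JBool/JString/JObject/JArray.
    val serialized: String = m.json
    val restored: Metadata = Metadata.fromJson(serialized)

    // equals is overridden to compare the underlying map, so the round trip preserves equality.
    assert(restored == m)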
---------- changed file 5 of 11 ----------
@@ -71,7 +71,7 @@ class MetadataSuite extends FunSuite {
   }

   test("metadata json conversion") {
-    val json = metadata.toJson
+    val json = metadata.json
     withClue("toJson must produce a valid JSON string") {
       parse(json)
     }
---------- changed file 6 of 11 ----------
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.api.java;

+import org.apache.spark.sql.catalyst.util.Metadata;
+
 import java.util.*;

 /**
@@ -148,7 +150,7 @@ public static StructField createStructField(
       String name,
       DataType dataType,
       boolean nullable,
-      Map<String, Object> metadata) {
+      Metadata metadata) {
     if (name == null) {
       throw new IllegalArgumentException("name should not be null.");
     }
@@ -165,10 +167,10 @@
   /**
    * Creates a StructField with empty metadata.
    *
-   * @see #createStructField(String, DataType, boolean, java.util.Map)
+   * @see #createStructField(String, DataType, boolean, Metadata)
    */
   public static StructField createStructField(String name, DataType dataType, boolean nullable) {
-    return createStructField(name, dataType, nullable, new HashMap<String, Object>());
+    return createStructField(name, dataType, nullable, Metadata.empty());
   }

   /**
---------- changed file 7 of 11 ----------
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.api.java;

+import org.apache.spark.sql.catalyst.util.Metadata;
+
 import java.util.Map;

 /**
@@ -37,13 +39,13 @@ public class StructField {
   private String name;
   private DataType dataType;
   private boolean nullable;
-  private Map<String, Object> metadata;
+  private Metadata metadata;

   protected StructField(
       String name,
       DataType dataType,
       boolean nullable,
-      Map<String, Object> metadata) {
+      Metadata metadata) {
     this.name = name;
     this.dataType = dataType;
     this.nullable = nullable;
@@ -62,7 +64,7 @@ public boolean isNullable() {
     return nullable;
   }

-  public Map<String, Object> getMetadata() {
+  public Metadata getMetadata() {
     return metadata;
   }
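With the Java API now holding the same Metadata class instead of a java.util.Map<String, Object>, no conversion is needed at the language boundary (see the simplified DataTypeConversions below). A sketch of the factory methods, written in Scala; DataType.StringType and DataType.IntegerType are assumed static instances on the Java DataType facade:

    import org.apache.spark.sql.api.java.{DataType => JDataType}
    import org.apache.spark.sql.catalyst.util.{Metadata, MetadataBuilder}

    // The four-argument factory now takes Metadata directly.
    val withDoc = JDataType.createStructField(
      "name",
      JDataType.StringType, // assumed static field on the Java DataType facade
      true,
      new MetadataBuilder().putString("doc", "first name").build())

    // The three-argument overload fills in Metadata.empty().
    val plain = JDataType.createStructField("id", JDataType.IntegerType, false)
    assert(plain.getMetadata == Metadata.empty)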
---------- changed file 8 of 11 ----------
@@ -32,7 +32,7 @@ protected[sql] object DataTypeConversions {
       scalaStructField.name,
       asJavaDataType(scalaStructField.dataType),
       scalaStructField.nullable,
-      scalaStructField.metadata.asJava.asInstanceOf[java.util.Map[String, Object]])
+      scalaStructField.metadata)
   }

   /**
@@ -69,7 +69,7 @@
       javaStructField.getName,
       asScalaDataType(javaStructField.getDataType),
       javaStructField.isNullable,
-      javaStructField.getMetadata.asScala.toMap)
+      javaStructField.getMetadata)
   }

   /**
---------- changed file 9 of 11 ----------
@@ -17,6 +17,7 @@

 package org.apache.spark.sql

+import org.apache.spark.sql.catalyst.util.MetadataBuilder
 import org.scalatest.FunSuite

 import org.apache.spark.sql.catalyst.types.DataType
@@ -79,9 +80,12 @@ class DataTypeSuite extends FunSuite {
     checkDataTypeJsonRepr(ArrayType(StringType, false))
     checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
     checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+    val metadata = new MetadataBuilder()
+      .putString("name", "age")
+      .build()
     checkDataTypeJsonRepr(
       StructType(Seq(
         StructField("a", IntegerType, nullable = true),
         StructField("b", ArrayType(DoubleType), nullable = false),
-        StructField("c", DoubleType, nullable = false, metadata = Map("name" -> "age")))))
+        StructField("c", DoubleType, nullable = false, metadata))))
 }
---------- changed file 10 of 11 ----------
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (12 additions & 10 deletions)
@@ -17,16 +17,15 @@

 package org.apache.spark.sql

-import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.joins.BroadcastHashJoin
-import org.apache.spark.sql.test._
-import org.scalatest.BeforeAndAfterAll
 import java.util.TimeZone

-/* Implicits */
-import TestSQLContext._
-import TestData._
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.MetadataBuilder
+import org.apache.spark.sql.test.TestSQLContext._

 class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   // Make sure the tables are loaded.
@@ -684,11 +683,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     val schema = person.schema
     val docKey = "doc"
     val docValue = "first name"
+    val metadata = new MetadataBuilder()
+      .putString(docKey, docValue)
+      .build()
     val schemaWithMeta = new StructType(Seq(
-      schema("id"), schema("name").copy(metadata = Map(docKey -> docValue)), schema("age")))
+      schema("id"), schema("name").copy(metadata = metadata), schema("age")))
     val personWithMeta = applySchema(person, schemaWithMeta)
     def validateMetadata(rdd: SchemaRDD): Unit = {
-      assert(rdd.schema("name").metadata(docKey) === docValue)
+      assert(rdd.schema("name").metadata.getString(docKey) == docValue)
     }
     personWithMeta.registerTempTable("personWithMeta")
     validateMetadata(personWithMeta.select('name))
---------- changed file 11 of 11 ----------
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.api.java

+import org.apache.spark.sql.catalyst.util.MetadataBuilder
 import org.apache.spark.sql.types.util.DataTypeConversions
 import org.scalatest.FunSuite

@@ -66,12 +67,15 @@
     checkDataType(simpleScalaStructType)

     // Complex StructType.
+    val metadata = new MetadataBuilder()
+      .putString("name", "age")
+      .build()
     val complexScalaStructType = SStructType(
       SStructField("simpleArray", simpleScalaArrayType, true) ::
       SStructField("simpleMap", simpleScalaMapType, true) ::
       SStructField("simpleStruct", simpleScalaStructType, true) ::
       SStructField("boolean", org.apache.spark.sql.BooleanType, false) ::
-      SStructField("withMeta", org.apache.spark.sql.DoubleType, false, Map("name" -> "age")) :: Nil)
+      SStructField("withMeta", org.apache.spark.sql.DoubleType, false, metadata) :: Nil)
     checkDataType(complexScalaStructType)

     // Complex ArrayType.
