
Commit

Infer the Catalyst data type from an object and cast a data value to the expected type.
yhuai committed Jul 10, 2014
1 parent 3fa0df5 commit 90460ac
Showing 11 changed files with 386 additions and 144 deletions.
@@ -85,6 +85,26 @@ object ScalaReflection {
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
}

def typeOfObject: PartialFunction[Any, DataType] = {
// The data type can be determined without ambiguity.
case obj: BooleanType.JvmType => BooleanType
case obj: BinaryType.JvmType => BinaryType
case obj: StringType.JvmType => StringType
case obj: ByteType.JvmType => ByteType
case obj: ShortType.JvmType => ShortType
case obj: IntegerType.JvmType => IntegerType
case obj: LongType.JvmType => LongType
case obj: FloatType.JvmType => FloatType
case obj: DoubleType.JvmType => DoubleType
case obj: DecimalType.JvmType => DecimalType
case obj: TimestampType.JvmType => TimestampType
case null => NullType
// There is no obvious mapping from the type of the given object to a Catalyst data type.
// A user should provide his/her specific rules (in a user-defined PartialFunction) to infer
// the Catalyst data type for other types of objects and then compose the user-defined
// PartialFunction with this one.
}
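
As the comment above notes, typeOfObject covers only the unambiguous mappings and is meant to be composed with user-supplied rules. A minimal sketch of that pattern, not part of this diff (the java.math.BigInteger rule is an illustrative assumption, mirroring the test suite further down):

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types._

// Fall back to a user-defined rule when the built-in mapping does not apply.
val typeOfObjectWithBigInteger: PartialFunction[Any, DataType] =
  ScalaReflection.typeOfObject orElse {
    case _: java.math.BigInteger => DecimalType
  }

typeOfObjectWithBigInteger(1)                              // IntegerType
typeOfObjectWithBigInteger(new java.math.BigInteger("10")) // DecimalType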

implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {

/**
@@ -128,8 +128,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
def schema: StructType = StructType.fromAttributes(output)

/** Returns the output schema in the tree format. */
def formattedSchemaString: String = schema.formattedSchemaString
def schemaString: String = schema.schemaString

/** Prints out the schema in the tree format */
def printSchema(): Unit = println(formattedSchemaString)
def printSchema(): Unit = println(schemaString)
}
@@ -121,10 +121,12 @@ case object StringType extends NativeType with PrimitiveType {
private[sql] val ordering = implicitly[Ordering[JvmType]]
def simpleString: String = "string"
}

case object BinaryType extends DataType with PrimitiveType {
private[sql] type JvmType = Array[Byte]
def simpleString: String = "binary"
}

case object BooleanType extends NativeType with PrimitiveType {
private[sql] type JvmType = Boolean
@transient private[sql] lazy val tag = typeTag[JvmType]
@@ -292,7 +294,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {

def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())

def formattedSchemaString: String = {
def schemaString: String = {
val builder = new StringBuilder
builder.append("root\n")
val prefix = " |"
@@ -301,7 +303,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
builder.toString()
}

def printSchema(): Unit = println(formattedSchemaString)
def printSchema(): Unit = println(schemaString)

private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
fields.foreach(field => field.buildFormattedString(prefix, builder))
@@ -17,11 +17,11 @@

package org.apache.spark.sql.catalyst

import java.math.BigInteger
import java.sql.Timestamp

import org.scalatest.FunSuite

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

case class PrimitiveData(
@@ -148,4 +148,68 @@ class ScalaReflectionSuite extends FunSuite {
StructField("_2", StringType, nullable = true))),
nullable = true))
}

test("get data type of a value") {
// BooleanType
assert(BooleanType === typeOfObject(true))
assert(BooleanType === typeOfObject(false))

// BinaryType
assert(BinaryType === typeOfObject("string".getBytes))

// StringType
assert(StringType === typeOfObject("string"))

// ByteType
assert(ByteType === typeOfObject(127.toByte))

// ShortType
assert(ShortType === typeOfObject(32767.toShort))

// IntegerType
assert(IntegerType === typeOfObject(2147483647))

// LongType
assert(LongType === typeOfObject(9223372036854775807L))

// FloatType
assert(FloatType === typeOfObject(3.4028235E38.toFloat))

// DoubleType
assert(DoubleType === typeOfObject(1.7976931348623157E308))

// DecimalType
assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318")))

// TimestampType
assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-7-25 10:26:00")))

// NullType
assert(NullType === typeOfObject(null))

def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
case value: java.math.BigInteger => DecimalType
case value: java.math.BigDecimal => DecimalType
case _ => StringType
}

assert(DecimalType === typeOfObject1(
new BigInteger("92233720368547758070")))
assert(DecimalType === typeOfObject1(
new java.math.BigDecimal("1.7976931348623157E318")))
assert(StringType === typeOfObject1(BigInt("92233720368547758070")))

def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
case value: java.math.BigInteger => DecimalType
}

intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))

def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse {
case c: Seq[_] => ArrayType(typeOfObject3(c.head))
}

assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3)))
assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3))))
}
}
61 changes: 40 additions & 21 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -18,20 +18,19 @@
package org.apache.spark.sql

import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

import org.apache.hadoop.conf.Configuration

import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.dsl.ExpressionConversions
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.SparkStrategies
@@ -89,14 +88,31 @@ class SQLContext(@transient val sparkContext: SparkContext)
new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd)))

/**
* Creates a SchemaRDD from an RDD by applying a schema and providing a function to construct
* a Row from a RDD record.
* Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function
* that converts each record of the RDD to a [[Row]].
*
* @group userf
*/
def createSchemaRDD[A](rdd: RDD[A], schema: StructType, constructRow: A => Row) = {
new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rdd.map(constructRow))))
}
def applySchema[A](rdd: RDD[A], schema: StructType, f: A => Row): SchemaRDD =
applySchemaPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f))

/**
* Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function
* that will be applied to each partition of the RDD to convert RDD records to [[Row]]s.
*
* @group userf
*/
def applySchemaPartitions[A](
rdd: RDD[A],
schema: StructType,
f: Iterator[A] => Iterator[Row]): SchemaRDD =
new SchemaRDD(this, makeCustomRDDScan(rdd, schema, f))

protected[sql] def makeCustomRDDScan[A](
rdd: RDD[A],
schema: StructType,
f: Iterator[A] => Iterator[Row]): LogicalPlan =
SparkLogicalPlan(ExistingRdd(schema.toAttributes, rdd.mapPartitions(f)))
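
A hedged usage sketch of the new applySchema / applySchemaPartitions entry points, not part of this diff; the SparkContext sc, the Person case class, and the choice of GenericRow as the concrete Row implementation are illustrative assumptions:

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.catalyst.types._

case class Person(name: String, age: Int)
val people = sc.parallelize(Seq(Person("a", 1), Person("b", 2)))

val schema = StructType(
  StructField("name", StringType, nullable = true) ::
  StructField("age", IntegerType, nullable = false) :: Nil)

// Per-record conversion; delegates to applySchemaPartitions internally.
val schemaRDD = sqlContext.applySchema(
  people, schema, (p: Person) => new GenericRow(Array[Any](p.name, p.age)))

// Per-partition conversion, for when row construction benefits from per-partition setup.
val schemaRDD2 = sqlContext.applySchemaPartitions(
  people, schema,
  (iter: Iterator[Person]) => iter.map(p => new GenericRow(Array[Any](p.name, p.age))))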

/**
* Loads a Parquet file, returning the result as a [[SchemaRDD]].
@@ -136,8 +152,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
* :: Experimental ::
*/
@Experimental
def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD =
new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio))
def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))
applySchemaPartitions(json, schema, JsonRDD.jsonStringToRow(schema, _: Iterator[String]))
}
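
A brief usage sketch of the reworked jsonRDD path, not part of this diff (sc and the sample records are assumptions): the schema is inferred from a sample of the data, fields seen only as null are widened to StringType by JsonRDD.nullTypeToStringType, and the resulting schema is then applied partition by partition.

val json = sc.parallelize(
  """{"name":"alice","age":1}""" :: """{"name":"bob"}""" :: Nil)

// Infer the schema from half of the records, then parse all of them into Rows.
val people = sqlContext.jsonRDD(json, samplingRatio = 0.5)
people.printSchema() // prints the inferred schema via schemaString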

/**
* :: Experimental ::
@@ -352,28 +370,29 @@ class SQLContext(@transient val sparkContext: SparkContext)

/**
* Peek at the first row of the RDD and infer its schema.
* TODO: consolidate this with the type system developed in SPARK-2060.
*/
private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
import scala.collection.JavaConversions._
def typeFor(obj: Any): DataType = obj match {
case c: java.lang.String => StringType
case c: java.lang.Integer => IntegerType
case c: java.lang.Long => LongType
case c: java.lang.Double => DoubleType
case c: java.lang.Boolean => BooleanType
case c: java.util.List[_] => ArrayType(typeFor(c.head))
case c: java.util.Set[_] => ArrayType(typeFor(c.head))
def typeOfComplexValue: PartialFunction[Any, DataType] = {
case c: java.util.List[_] =>
ArrayType(ScalaReflection.typeOfObject(c.head))
case c: java.util.Set[_] =>
ArrayType(ScalaReflection.typeOfObject(c.head))
case c: java.util.Map[_, _] =>
val (key, value) = c.head
MapType(typeFor(key), typeFor(value))
MapType(
ScalaReflection.typeOfObject(key),
ScalaReflection.typeOfObject(value))
case c if c.getClass.isArray =>
val elem = c.asInstanceOf[Array[_]].head
ArrayType(typeFor(elem))
ArrayType(ScalaReflection.typeOfObject(elem))
case c => throw new Exception(s"Object of type $c cannot be used")
}

def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue

val schema = rdd.first().map { case (fieldName, obj) =>
AttributeReference(fieldName, typeFor(obj), true)()
AttributeReference(fieldName, typeOfObject(obj), true)()
}.toSeq

val rowRdd = rdd.mapPartitions { iter =>
@@ -127,8 +127,8 @@ private[sql] trait SchemaRDDLike {
def schema: StructType = queryExecution.analyzed.schema

/** Returns the output schema in the tree format. */
def formattedSchemaString: String = schema.formattedSchemaString
def schemaString: String = schema.schemaString

/** Prints out the schema in the tree format. */
def printSchema(): Unit = println(formattedSchemaString)
def printSchema(): Unit = println(schemaString)
}
@@ -119,8 +119,13 @@ class JavaSQLContext(val sqlContext: SQLContext) {
*
* @group userf
*/
def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD =
new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0))
def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = {
val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))
val logicalPlan =
sqlContext.makeCustomRDDScan[String](json, schema, JsonRDD.jsonStringToRow(schema, _))

new JavaSchemaRDD(sqlContext, logicalPlan)
}

/**
* Registers the given RDD as a temporary table in the catalog. Temporary tables exist only