Java API for applySchema.
yhuai committed Jul 24, 2014
1 parent 1c9f33c commit b9f3071
Showing 11 changed files with 348 additions and 37 deletions.
@@ -134,11 +134,20 @@ public static StructField createStructField(String name, DataType dataType, boolean nullable) {
}

/**
* Creates a StructType with the given StructFields ({@code fields}).
* Creates a StructType with the given list of StructFields ({@code fields}).
* @param fields
* @return
*/
public static StructType createStructType(List<StructField> fields) {
return createStructType(fields.toArray(new StructField[0]));
}

/**
* Creates a StructType with the given StructField array ({@code fields}).
* @param fields
* @return
*/
public static StructType createStructType(StructField[] fields) {
if (fields == null) {
throw new IllegalArgumentException("fields should not be null.");
}
@@ -151,11 +160,11 @@ public static StructType createStructType(List<StructField> fields) {

distinctNames.add(field.getName());
}
if (distinctNames.size() != fields.size()) {
throw new IllegalArgumentException(
"fields should have distinct names.");
if (distinctNames.size() != fields.length) {
throw new IllegalArgumentException("fields should have distinct names.");
}

return new StructType(fields);
}

}
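
For illustration, a minimal Java sketch of the two factory overloads above; the field names and types are made up for the example, and note that createStructType rejects null input and duplicate field names:

import java.util.Arrays;

import org.apache.spark.sql.api.java.types.DataType;
import org.apache.spark.sql.api.java.types.StructField;
import org.apache.spark.sql.api.java.types.StructType;

public class CreateStructTypeExample {
  public static void main(String[] args) {
    // Field names must be distinct, otherwise createStructType throws
    // IllegalArgumentException("fields should have distinct names.").
    StructField[] fields = new StructField[] {
      DataType.createStructField("name", DataType.StringType, false),
      DataType.createStructField("age", DataType.IntegerType, true)
    };

    // The array-based overload added in this commit.
    StructType fromArray = DataType.createStructType(fields);

    // The existing List-based overload now simply delegates to it.
    StructType fromList = DataType.createStructType(Arrays.asList(fields));
  }
}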
@@ -22,13 +22,13 @@

/**
* The data type representing Rows.
* A StructType object comprises a List of StructFields.
* A StructType object comprises an array of StructFields.
*/
public class StructType extends DataType {
private StructField[] fields;

protected StructType(List<StructField> fields) {
this.fields = fields.toArray(new StructField[0]);
protected StructType(StructField[] fields) {
this.fields = fields;
}

public StructField[] getFields() {
95 changes: 95 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql

import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration

import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.dsl.ExpressionConversions
@@ -420,4 +422,97 @@ class SQLContext(@transient val sparkContext: SparkContext)
new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
}

/**
* Returns the equivalent StructField in Java for the given StructField in Scala.
*/
protected def asJavaStructField(scalaStructField: StructField): JStructField = {
org.apache.spark.sql.api.java.types.DataType.createStructField(
scalaStructField.name,
asJavaDataType(scalaStructField.dataType),
scalaStructField.nullable)
}

/**
* Returns the equivalent DataType in Java for the given DataType in Scala.
*/
protected[sql] def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match {
case StringType =>
org.apache.spark.sql.api.java.types.DataType.StringType
case BinaryType =>
org.apache.spark.sql.api.java.types.DataType.BinaryType
case BooleanType =>
org.apache.spark.sql.api.java.types.DataType.BooleanType
case TimestampType =>
org.apache.spark.sql.api.java.types.DataType.TimestampType
case DecimalType =>
org.apache.spark.sql.api.java.types.DataType.DecimalType
case DoubleType =>
org.apache.spark.sql.api.java.types.DataType.DoubleType
case FloatType =>
org.apache.spark.sql.api.java.types.DataType.FloatType
case ByteType =>
org.apache.spark.sql.api.java.types.DataType.ByteType
case IntegerType =>
org.apache.spark.sql.api.java.types.DataType.IntegerType
case LongType =>
org.apache.spark.sql.api.java.types.DataType.LongType
case ShortType =>
org.apache.spark.sql.api.java.types.DataType.ShortType

case arrayType: ArrayType =>
org.apache.spark.sql.api.java.types.DataType.createArrayType(
asJavaDataType(arrayType.elementType), arrayType.containsNull)
case mapType: MapType =>
org.apache.spark.sql.api.java.types.DataType.createMapType(
asJavaDataType(mapType.keyType), asJavaDataType(mapType.valueType))
case structType: StructType =>
org.apache.spark.sql.api.java.types.DataType.createStructType(
structType.fields.map(asJavaStructField).asJava)
}

/**
* Returns the equivalent StructField in Scala for the given StructField in Java.
*/
protected def asScalaStructField(javaStructField: JStructField): StructField = {
StructField(
javaStructField.getName,
asScalaDataType(javaStructField.getDataType),
javaStructField.isNullable)
}

/**
* Returns the equivalent DataType in Scala for the given DataType in Java.
*/
protected[sql] def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match {
case stringType: org.apache.spark.sql.api.java.types.StringType =>
StringType
case binaryType: org.apache.spark.sql.api.java.types.BinaryType =>
BinaryType
case booleanType: org.apache.spark.sql.api.java.types.BooleanType =>
BooleanType
case timestampType: org.apache.spark.sql.api.java.types.TimestampType =>
TimestampType
case decimalType: org.apache.spark.sql.api.java.types.DecimalType =>
DecimalType
case doubleType: org.apache.spark.sql.api.java.types.DoubleType =>
DoubleType
case floatType: org.apache.spark.sql.api.java.types.FloatType =>
FloatType
case byteType: org.apache.spark.sql.api.java.types.ByteType =>
ByteType
case integerType: org.apache.spark.sql.api.java.types.IntegerType =>
IntegerType
case longType: org.apache.spark.sql.api.java.types.LongType =>
LongType
case shortType: org.apache.spark.sql.api.java.types.ShortType =>
ShortType

case arrayType: org.apache.spark.sql.api.java.types.ArrayType =>
ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull)
case mapType: org.apache.spark.sql.api.java.types.MapType =>
MapType(asScalaDataType(mapType.getKeyType), asScalaDataType(mapType.getValueType))
case structType: org.apache.spark.sql.api.java.types.StructType =>
StructType(structType.getFields.map(asScalaStructField))
}

}
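
To make the mapping above concrete from the Java side, here is a small sketch (with illustrative field names) building the complex types whose Scala equivalents are handled by asJavaDataType and asScalaDataType; only the two-argument createArrayType and createMapType factories shown in this diff are used, and the variables are typed as DataType to avoid assuming the factories' exact return types:

import org.apache.spark.sql.api.java.types.DataType;
import org.apache.spark.sql.api.java.types.StructField;
import org.apache.spark.sql.api.java.types.StructType;

public class ComplexTypeExample {
  public static void main(String[] args) {
    // Corresponds to ArrayType(StringType, containsNull = true) in Scala.
    DataType tags = DataType.createArrayType(DataType.StringType, true);

    // Corresponds to MapType(StringType, DoubleType) in Scala.
    DataType scores = DataType.createMapType(DataType.StringType, DataType.DoubleType);

    // Corresponds to a Scala StructType wrapping the two complex fields.
    StructType record = DataType.createStructType(new StructField[] {
      DataType.createStructField("tags", tags, true),
      DataType.createStructField("scores", scores, true)
    });
  }
}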
5 changes: 5 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -119,6 +119,11 @@ class SchemaRDD(
override protected def getDependencies: Seq[Dependency[_]] =
List(new OneToOneDependency(queryExecution.toRdd))

/** Returns the schema of this SchemaRDD (represented by a [[StructType]]).
*
* @group schema
*/
def schema: StructType = queryExecution.analyzed.schema

// =======================================================================
// Query DSL
@@ -123,17 +123,11 @@ private[sql] trait SchemaRDDLike {
def saveAsTable(tableName: String): Unit =
sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd

/** Returns the schema of this SchemaRDD (represented by a [[StructType]]).
*
* @group schema
*/
def schema: StructType = queryExecution.analyzed.schema

/** Returns the schema as a string in the tree format.
*
* @group schema
*/
def schemaString: String = schema.treeString
def schemaString: String = baseSchemaRDD.schema.treeString

/** Prints out the schema.
*
@@ -19,14 +19,17 @@ package org.apache.spark.sql.api.java

import java.beans.Introspector

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructType => JStructType}
import org.apache.spark.sql.api.java.types.{StructField => JStructField}
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
import org.apache.spark.util.Utils
@@ -95,6 +98,20 @@ class JavaSQLContext(val sqlContext: SQLContext) {
new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
}

/**
* :: DeveloperApi ::
* Creates a JavaSchemaRDD from an RDD containing Rows by applying a schema to this RDD.
* It is important to make sure that the structure of every Row of the provided RDD matches the
* provided schema. Otherwise, a runtime exception will be thrown.
*/
@DeveloperApi
def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = {
val scalaRowRDD = rowRDD.rdd.map(r => r.row)
val scalaSchema = sqlContext.asScalaDataType(schema).asInstanceOf[StructType]
val logicalPlan = SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))
new JavaSchemaRDD(sqlContext, logicalPlan)
}
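
A minimal end-to-end sketch of the new applySchema entry point from Java; the local JavaSparkContext, the Row.create factory calls, and the sample data are illustrative assumptions rather than part of this diff:

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;
import org.apache.spark.sql.api.java.types.DataType;
import org.apache.spark.sql.api.java.types.StructField;
import org.apache.spark.sql.api.java.types.StructType;

public class ApplySchemaExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "ApplySchemaExample");
    JavaSQLContext sqlCtx = new JavaSQLContext(sc);

    // Every Row must match the schema below; a mismatch only surfaces as a
    // runtime exception when the RDD is evaluated.
    JavaRDD<Row> rows = sc.parallelize(Arrays.asList(
      Row.create("Alice", 30),
      Row.create("Bob", 25)));

    StructType schema = DataType.createStructType(new StructField[] {
      DataType.createStructField("name", DataType.StringType, false),
      DataType.createStructField("age", DataType.IntegerType, true)
    });

    JavaSchemaRDD people = sqlCtx.applySchema(rows, schema);
    people.printSchema();

    sc.stop();
  }
}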

/**
* Loads a parquet file, returning the result as a [[JavaSchemaRDD]].
*/
@@ -104,26 +121,45 @@
ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration)))

/**
* Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]].
* Loads a JSON file (one object per line), returning the result as a JavaSchemaRDD.
* It goes through the entire dataset once to determine the schema.
*
* @group userf
*/
def jsonFile(path: String): JavaSchemaRDD =
jsonRDD(sqlContext.sparkContext.textFile(path))

/**
* :: Experimental ::
* Loads a JSON file (one object per line) and applies the given schema,
* returning the result as a JavaSchemaRDD.
*/
@Experimental
def jsonFile(path: String, schema: JStructType): JavaSchemaRDD =
jsonRDD(sqlContext.sparkContext.textFile(path), schema)

/**
* Loads an RDD[String] storing JSON objects (one object per record), returning the result as a
* [[JavaSchemaRDD]].
* JavaSchemaRDD.
* It goes through the entire dataset once to determine the schema.
*
* @group userf
*/
def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = {
val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))
val logicalPlan =
sqlContext.makeCustomRDDScan[String](json, schema, JsonRDD.jsonStringToRow(schema, _))
val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))
val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))
new JavaSchemaRDD(sqlContext, logicalPlan)
}

/**
* :: Experimental ::
* Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema,
* returning the result as a JavaSchemaRDD.
*/
@Experimental
def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = {
val appliedScalaSchema =
Option(sqlContext.asScalaDataType(schema)).getOrElse(
JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType]
val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))
new JavaSchemaRDD(sqlContext, logicalPlan)
}
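
A sketch of the schema-carrying jsonRDD overload; the JSON strings and field names are invented for the example, and jsonFile(path, schema) behaves the same way for files:

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.types.DataType;
import org.apache.spark.sql.api.java.types.StructField;
import org.apache.spark.sql.api.java.types.StructType;

public class JsonWithSchemaExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JsonWithSchemaExample");
    JavaSQLContext sqlCtx = new JavaSQLContext(sc);

    JavaRDD<String> json = sc.parallelize(Arrays.asList(
      "{\"name\":\"Alice\",\"city\":\"Berkeley\"}",
      "{\"name\":\"Bob\",\"city\":\"Amsterdam\"}"));

    // With an explicit, non-null schema the inference pass over the data is skipped.
    StructType schema = DataType.createStructType(new StructField[] {
      DataType.createStructField("name", DataType.StringType, true),
      DataType.createStructField("city", DataType.StringType, true)
    });

    JavaSchemaRDD people = sqlCtx.jsonRDD(json, schema);
    people.printSchema();

    sc.stop();
  }
}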

@@ -22,6 +22,7 @@ import java.util.{List => JList}
import org.apache.spark.Partitioner
import org.apache.spark.api.java.{JavaRDDLike, JavaRDD}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.sql.api.java.types.StructType
import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.rdd.RDD
@@ -53,6 +54,10 @@ class JavaSchemaRDD(

override def toString: String = baseSchemaRDD.toString

/** Returns the schema of this JavaSchemaRDD (represented by a StructType). */
def schema: StructType =
sqlContext.asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType]

// =======================================================================
// Base RDD functions that do NOT change schema
// =======================================================================
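
And a small sketch of reading the schema back on the Java side through the new schema method; the helper class and printed format are just for illustration:

import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.types.StructField;

public class SchemaInspection {
  // Prints one line per top-level field of the JavaSchemaRDD's schema.
  public static void describe(JavaSchemaRDD rdd) {
    for (StructField field : rdd.schema().getFields()) {
      System.out.println(field.getName() + " (nullable = " + field.isNullable() + ")");
    }
  }
}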