Move "asJavaDataType" and "asScalaDataType" to DataTypeConversions.sc…
Browse files Browse the repository at this point in the history
…ala.
yhuai committed Jul 28, 2014
1 parent 1cb35fe commit 991f860
Showing 7 changed files with 162 additions and 44 deletions.
sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -19,19 +19,18 @@ package org.apache.spark.sql.api.java

import java.beans.Introspector

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructType => JStructType}
import org.apache.spark.sql.api.java.types.{StructField => JStructField}
import org.apache.spark.sql.api.java.types.{StructType => JStructType}
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
import org.apache.spark.sql.types.util.DataTypeConversions
import DataTypeConversions.asScalaDataType
import org.apache.spark.util.Utils

/**
@@ -107,7 +106,7 @@ class JavaSQLContext(val sqlContext: SQLContext) {
@DeveloperApi
def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = {
val scalaRowRDD = rowRDD.rdd.map(r => r.row)
val scalaSchema = sqlContext.asScalaDataType(schema).asInstanceOf[StructType]
val scalaSchema = asScalaDataType(schema).asInstanceOf[StructType]
val logicalPlan = SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))
new JavaSchemaRDD(sqlContext, logicalPlan)
}
@@ -156,7 +155,7 @@ class JavaSQLContext(val sqlContext: SQLContext) {
@Experimental
def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = {
val appliedScalaSchema =
Option(sqlContext.asScalaDataType(schema)).getOrElse(
Option(asScalaDataType(schema)).getOrElse(
JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType]
val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))
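
For context, a minimal Scala sketch of the applySchema flow above. It assumes a Spark 1.1-era classpath; the local master string, sample data, and the Row.create call are illustrative assumptions, not part of this commit.

  import java.util.Arrays

  import org.apache.spark.api.java.JavaSparkContext
  import org.apache.spark.sql.api.java.{JavaSQLContext, Row}
  import org.apache.spark.sql.api.java.types.DataType

  val javaCtx = new JavaSparkContext("local", "DataTypeConversionsExample")
  val javaSqlCtx = new JavaSQLContext(javaCtx)

  // Build the schema with the Java types API; applySchema converts it once
  // with asScalaDataType before constructing the logical plan.
  val schema = DataType.createStructType(Arrays.asList(
    DataType.createStructField("name", DataType.StringType, false),
    DataType.createStructField("age", DataType.IntegerType, true)))

  val rows = javaCtx.parallelize(Arrays.asList(
    Row.create("alice", Integer.valueOf(30)),
    Row.create("bob", null)))
  val peopleSchemaRDD = javaSqlCtx.applySchema(rows, schema)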
sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -23,8 +23,10 @@ import org.apache.spark.Partitioner
import org.apache.spark.api.java.{JavaRDDLike, JavaRDD}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.sql.api.java.types.StructType
import org.apache.spark.sql.types.util.DataTypeConversions
import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import DataTypeConversions._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

@@ -56,7 +58,7 @@ class JavaSchemaRDD(

/** Returns the schema of this JavaSchemaRDD (represented by a StructType). */
def schema: StructType =
sqlContext.asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType]
asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType]

// =======================================================================
// Base RDD functions that do NOT change schema
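
Continuing the sketch above: the schema accessor now returns the Java-side StructType, and getFields, getName, getDataType, and isNullable are the accessors this commit's own conversion code relies on.

  // schema converts the underlying Scala schema via asJavaDataType.
  val javaSchema = peopleSchemaRDD.schema
  javaSchema.getFields.foreach { f =>
    println(s"${f.getName}: ${f.getDataType} (nullable = ${f.isNullable})")
  }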
18 changes: 13 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -17,7 +17,10 @@

package org.apache.spark

import scala.collection.JavaConverters._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField}

/**
* Allows the execution of relational queries, including those expressed in SQL using Spark.
@@ -243,8 +246,7 @@ package object sql {
* The data type representing `Seq`s.
* An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and
* `containsNull: Boolean`. The field of `elementType` is used to specify the type of
* array elements. The field of `containsNull` is used to specify if the array can have
* any `null` value.
* array elements. The field of `containsNull` is used to specify if the array has `null` values.
*
* @group dataType
*/
@@ -271,10 +273,11 @@
/**
* :: DeveloperApi ::
*
* The data type representing `Map`s. A [[MapType]] object comprises two fields,
* `keyType: [[DataType]]` and `valueType: [[DataType]]`.
* The data type representing `Map`s. A [[MapType]] object comprises three fields,
* `keyType: [[DataType]]`, `valueType: [[DataType]]` and `valueContainsNull: Boolean`.
* The field of `keyType` is used to specify the type of keys in the map.
* The field of `valueType` is used to specify the type of values in the map.
* The field of `valueContainsNull` is used to specify if values of this map can be `null`.
*
* @group dataType
*/
@@ -284,10 +287,15 @@
/**
* :: DeveloperApi ::
*
* A [[MapType]] can be constructed by
* A [[MapType]] object can be constructed in two ways,
* {{{
* MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean)
* }}} and
* {{{
* MapType(keyType: DataType, valueType: DataType)
* }}}
* For `MapType(keyType: DataType, valueType: DataType)`,
* the field of `valueContainsNull` is set to `true`.
*
* @group dataType
*/
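
To make the default concrete, a small Scala sketch using only the constructors documented above (the new DataTypeSuite test below asserts exactly this equality):

  import org.apache.spark.sql._

  val defaulted = MapType(StringType, IntegerType)        // valueContainsNull defaults to true
  val explicit  = MapType(StringType, IntegerType, true)  // the same type, spelled out
  assert(defaulted == explicit)

  val strict = MapType(StringType, IntegerType, false)    // map values may not be null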
sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala (new file)
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.types.util

import org.apache.spark.sql._
import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField}

import scala.collection.JavaConverters._

protected[sql] object DataTypeConversions {

/**
* Returns the equivalent StructField in Java for the given StructField in Scala.
*/
def asJavaStructField(scalaStructField: StructField): JStructField = {
org.apache.spark.sql.api.java.types.DataType.createStructField(
scalaStructField.name,
asJavaDataType(scalaStructField.dataType),
scalaStructField.nullable)
}

/**
* Returns the equivalent DataType in Java for the given DataType in Scala.
*/
def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match {
case StringType =>
org.apache.spark.sql.api.java.types.DataType.StringType
case BinaryType =>
org.apache.spark.sql.api.java.types.DataType.BinaryType
case BooleanType =>
org.apache.spark.sql.api.java.types.DataType.BooleanType
case TimestampType =>
org.apache.spark.sql.api.java.types.DataType.TimestampType
case DecimalType =>
org.apache.spark.sql.api.java.types.DataType.DecimalType
case DoubleType =>
org.apache.spark.sql.api.java.types.DataType.DoubleType
case FloatType =>
org.apache.spark.sql.api.java.types.DataType.FloatType
case ByteType =>
org.apache.spark.sql.api.java.types.DataType.ByteType
case IntegerType =>
org.apache.spark.sql.api.java.types.DataType.IntegerType
case LongType =>
org.apache.spark.sql.api.java.types.DataType.LongType
case ShortType =>
org.apache.spark.sql.api.java.types.DataType.ShortType

case arrayType: ArrayType =>
org.apache.spark.sql.api.java.types.DataType.createArrayType(
asJavaDataType(arrayType.elementType), arrayType.containsNull)
case mapType: MapType =>
org.apache.spark.sql.api.java.types.DataType.createMapType(
asJavaDataType(mapType.keyType),
asJavaDataType(mapType.valueType),
mapType.valueContainsNull)
case structType: StructType =>
org.apache.spark.sql.api.java.types.DataType.createStructType(
structType.fields.map(asJavaStructField).asJava)
}

/**
* Returns the equivalent StructField in Scala for the given StructField in Java.
*/
def asScalaStructField(javaStructField: JStructField): StructField = {
StructField(
javaStructField.getName,
asScalaDataType(javaStructField.getDataType),
javaStructField.isNullable)
}

/**
* Returns the equivalent DataType in Scala for the given DataType in Java.
*/
def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match {
case stringType: org.apache.spark.sql.api.java.types.StringType =>
StringType
case binaryType: org.apache.spark.sql.api.java.types.BinaryType =>
BinaryType
case booleanType: org.apache.spark.sql.api.java.types.BooleanType =>
BooleanType
case timestampType: org.apache.spark.sql.api.java.types.TimestampType =>
TimestampType
case decimalType: org.apache.spark.sql.api.java.types.DecimalType =>
DecimalType
case doubleType: org.apache.spark.sql.api.java.types.DoubleType =>
DoubleType
case floatType: org.apache.spark.sql.api.java.types.FloatType =>
FloatType
case byteType: org.apache.spark.sql.api.java.types.ByteType =>
ByteType
case integerType: org.apache.spark.sql.api.java.types.IntegerType =>
IntegerType
case longType: org.apache.spark.sql.api.java.types.LongType =>
LongType
case shortType: org.apache.spark.sql.api.java.types.ShortType =>
ShortType

case arrayType: org.apache.spark.sql.api.java.types.ArrayType =>
ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull)
case mapType: org.apache.spark.sql.api.java.types.MapType =>
MapType(
asScalaDataType(mapType.getKeyType),
asScalaDataType(mapType.getValueType),
mapType.isValueContainsNull)
case structType: org.apache.spark.sql.api.java.types.StructType =>
StructType(structType.getFields.map(asScalaStructField))
}
}
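
Since DataTypeConversions is protected[sql], callers must live in the org.apache.spark.sql package; within it, a round trip through the Java types API is one call in each direction. A minimal sketch mirroring what the two conversion suites below exercise (the RoundTripExample object is hypothetical scaffolding):

  package org.apache.spark.sql

  import org.apache.spark.sql.types.util.DataTypeConversions._

  object RoundTripExample {
    def main(args: Array[String]): Unit = {
      val scalaType: DataType = StructType(
        StructField("key", StringType, false) ::
        StructField("values", ArrayType(DoubleType, false), true) :: Nil)

      val javaType = asJavaDataType(scalaType)    // Scala -> Java
      val roundTrip = asScalaDataType(javaType)   // Java -> Scala
      assert(roundTrip == scalaType)
    }
  }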
sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
@@ -21,40 +21,20 @@
import java.util.ArrayList;

import org.junit.Assert;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.types.util.DataTypeConversions;
import org.apache.spark.sql.api.java.types.DataType;
import org.apache.spark.sql.api.java.types.StructField;
import org.apache.spark.sql.test.TestSQLContext;
import org.junit.rules.ExpectedException;

public class JavaSideDataTypeConversionSuite {
private transient JavaSparkContext javaCtx;
private transient JavaSQLContext javaSqlCtx;

public void checkDataType(DataType javaDataType) {
org.apache.spark.sql.catalyst.types.DataType scalaDataType =
javaSqlCtx.sqlContext().asScalaDataType(javaDataType);
DataType actual = javaSqlCtx.sqlContext().asJavaDataType(scalaDataType);
DataTypeConversions.asScalaDataType(javaDataType);
DataType actual = DataTypeConversions.asJavaDataType(scalaDataType);
Assert.assertEquals(javaDataType, actual);
}

@Before
public void setUp() {
javaCtx = new JavaSparkContext(TestSQLContext.sparkContext());
javaSqlCtx = new JavaSQLContext(javaCtx);
}

@After
public void tearDown() {
javaCtx.stop();
javaCtx = null;
}

@Test
public void createDataTypes() {
// Simple DataTypes.
@@ -102,7 +82,7 @@ public void createDataTypes() {

// Complex MapType.
DataType complexJavaMapType =
DataType.createMapType(complexJavaStructType, complexJavaArrayType);
DataType.createMapType(complexJavaStructType, complexJavaArrayType, false);
checkDataType(complexJavaMapType);
}

sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -27,6 +27,12 @@ class DataTypeSuite extends FunSuite {
assert(ArrayType(StringType, false) === array)
}

test("construct an MapType") {
val map = MapType(StringType, IntegerType)

assert(MapType(StringType, IntegerType, true) === map)
}

test("extract fields from a StructType") {
val struct = StructType(
StructField("a", IntegerType, true) ::
sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
@@ -17,18 +17,17 @@

package org.apache.spark.sql.api.java

import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types.util.DataTypeConversions
import org.scalatest.FunSuite

import org.apache.spark.sql._
import DataTypeConversions._

class ScalaSideDataTypeConversionSuite extends FunSuite {
val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext)
val javaSqlCtx = new JavaSQLContext(javaCtx)

def checkDataType(scalaDataType: DataType) {
val javaDataType = javaSqlCtx.sqlContext.asJavaDataType(scalaDataType)
val actual = javaSqlCtx.sqlContext.asScalaDataType(javaDataType)
val javaDataType = asJavaDataType(scalaDataType)
val actual = asScalaDataType(javaDataType)
assert(scalaDataType === actual, s"Converted data type ${actual} " +
s"does not equal the expected data type ${scalaDataType}")
}
@@ -76,7 +75,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
checkDataType(complexScalaArrayType)

// Complex MapType.
val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType)
val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType, false)
checkDataType(complexScalaMapType)
}
}
