[SPARK-6996][SQL] Support map types in java beans
liancheng mengxr this is similar to #5146.

Author: Punya Biswal <pbiswal@palantir.com>

Closes #5578 from punya/feature/SPARK-6996 and squashes the following commits:

d56c3e0 [Punya Biswal] Fix imports
c7e308b [Punya Biswal] Support java iterable types in POJOs
5e00685 [Punya Biswal] Support map types in java beans
Punya Biswal authored and marmbrus committed Apr 21, 2015
1 parent 6265cba commit 2a24bf9
Showing 4 changed files with 180 additions and 59 deletions.
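
For context, the net effect of the commit: a DataFrame can now be created from a Java bean whose properties are Maps or Iterables. A minimal hypothetical sketch (the Person bean is invented here; an existing JavaSparkContext jsc and SQLContext sqlContext are assumed, none of this is part of the patch):

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import com.google.common.collect.ImmutableMap;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.DataFrame;

// Invented bean for illustration: Map- and List-typed properties are now
// inferred as MapType and ArrayType instead of failing schema inference.
public static class Person implements Serializable {
  private Map<String, Integer> scores = ImmutableMap.of("math", 90);
  private List<String> tags = Arrays.asList("a", "b");

  public Map<String, Integer> getScores() { return scores; }
  public List<String> getTags() { return tags; }
}

// inside a test or driver method:
JavaRDD<Person> rdd = jsc.parallelize(Arrays.asList(new Person()));
DataFrame df = sqlContext.createDataFrame(rdd, Person.class);
df.printSchema();
// root
//  |-- scores: map (nullable = true)
//  |    |-- key: string
//  |    |-- value: integer (valueContainsNull = true)
//  |-- tags: array (nullable = true)
//  |    |-- element: string (containsNull = true)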
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst

import java.lang.{Iterable => JavaIterable}
import java.util.{Map => JavaMap}

import scala.collection.mutable.HashMap
@@ -49,6 +50,16 @@ object CatalystTypeConverters {
    case (s: Seq[_], arrayType: ArrayType) =>
      s.map(convertToCatalyst(_, arrayType.elementType))

    case (jit: JavaIterable[_], arrayType: ArrayType) => {
      val iter = jit.iterator
      var listOfItems: List[Any] = List()
      while (iter.hasNext) {
        val item = iter.next()
        listOfItems :+= convertToCatalyst(item, arrayType.elementType)
      }
      listOfItems
    }

    case (s: Array[_], arrayType: ArrayType) =>
      s.toSeq.map(convertToCatalyst(_, arrayType.elementType))

@@ -124,6 +135,15 @@ object CatalystTypeConverters {
    extractOption(item) match {
      case a: Array[_] => a.toSeq.map(elementConverter)
      case s: Seq[_] => s.map(elementConverter)
      case i: JavaIterable[_] => {
        val iter = i.iterator
        var convertedIterable: List[Any] = List()
        while (iter.hasNext) {
          val item = iter.next()
          convertedIterable :+= elementConverter(item)
        }
        convertedIterable
      }
      case null => null
    }
  }
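
The two new JavaIterable cases mean that any java.lang.Iterable, not only arrays and Scala Seqs, should now survive conversion into a Catalyst array. A hypothetical illustration (SetBean is invented; jsc and sqlContext as assumed above, not part of the patch):

import java.io.Serializable;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;

import org.apache.spark.sql.Row;

// Invented bean: a Set is neither an array nor a scala.Seq, so before this
// change its values had no conversion path into a Catalyst array.
public static class SetBean implements Serializable {
  private Set<String> labels = new LinkedHashSet<String>(Arrays.asList("x", "y"));

  public Set<String> getLabels() { return labels; }
}

// inside a test method:
Row row = sqlContext.createDataFrame(
    jsc.parallelize(Arrays.asList(new SetBean())), SetBean.class).first();
scala.collection.Seq<String> labels = row.getAs(0);  // hits the JavaIterable case above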
110 changes: 110 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.beans.Introspector
import java.lang.{Iterable => JIterable}
import java.util.{Iterator => JIterator, Map => JMap}

import com.google.common.reflect.TypeToken

import org.apache.spark.sql.types._

import scala.language.existentials

/**
* Type-inference utilities for POJOs and Java collections.
*/
private [sql] object JavaTypeInference {

  private val iterableType = TypeToken.of(classOf[JIterable[_]])
  private val mapType = TypeToken.of(classOf[JMap[_, _]])
  private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
  private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
  private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
  private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType

  /**
   * Infers the corresponding SQL data type of a Java type.
   * @param typeToken Java type
   * @return (SQL data type, nullable)
   */
  private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
    typeToken.getRawType match {
      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

      case _ if typeToken.isArray =>
        val (dataType, nullable) = inferDataType(typeToken.getComponentType)
        (ArrayType(dataType, nullable), true)

      case _ if iterableType.isAssignableFrom(typeToken) =>
        val (dataType, nullable) = inferDataType(elementType(typeToken))
        (ArrayType(dataType, nullable), true)

      case _ if mapType.isAssignableFrom(typeToken) =>
        val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
        val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
        val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
        val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
        val (keyDataType, _) = inferDataType(keyType)
        val (valueDataType, nullable) = inferDataType(valueType)
        (MapType(keyDataType, valueDataType, nullable), true)

      case _ =>
        val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
        val fields = properties.map { property =>
          val returnType = typeToken.method(property.getReadMethod).getReturnType
          val (dataType, nullable) = inferDataType(returnType)
          new StructField(property.getName, dataType, nullable)
        }
        (new StructType(fields), true)
    }
  }

  private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
    val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
    val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
    val itemType = iteratorType.resolveType(nextReturnType)
    itemType
  }
}
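
elementType recovers an Iterable's element type by resolving the generic return types of Iterable.iterator() and Iterator.next() against a concrete TypeToken. A standalone sketch of the same Guava mechanics (ElementTypeDemo is invented, not part of the patch):

import java.lang.reflect.Type;
import java.util.Iterator;
import java.util.List;

import com.google.common.reflect.TypeToken;

public class ElementTypeDemo {
  public static void main(String[] args) throws Exception {
    // The generic return types as declared on the interfaces: Iterator<T> and E.
    Type iteratorReturn = Iterable.class.getMethod("iterator").getGenericReturnType();
    Type nextReturn = Iterator.class.getMethod("next").getGenericReturnType();

    // Resolve them against a concrete List<String> to recover the element type.
    TypeToken<List<String>> token = new TypeToken<List<String>>() {};
    TypeToken<?> iterableType = token.getSupertype(Iterable.class);
    TypeToken<?> iteratorType = iterableType.resolveType(iteratorReturn);
    TypeToken<?> itemType = iteratorType.resolveType(nextReturn);

    System.out.println(itemType); // prints: java.lang.String
  }
}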
52 changes: 5 additions & 47 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,6 +25,8 @@ import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag

import com.google.common.reflect.TypeToken

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
@@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
   * Returns a Catalyst Schema for the given java bean class.
   */
  protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
    val (dataType, _) = inferDataType(beanClass)
    val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass))
    dataType.asInstanceOf[StructType].fields.map { f =>
      AttributeReference(f.name, f.dataType, f.nullable)()
    }
  }

  /**
   * Infers the corresponding SQL data type of a Java class.
   * @param clazz Java class
   * @return (SQL data type, nullable)
   */
  private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
    clazz match {
      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

      case c: Class[_] if c.isArray =>
        val (dataType, nullable) = inferDataType(c.getComponentType)
        (ArrayType(dataType, nullable), true)

      case _ =>
        val beanInfo = Introspector.getBeanInfo(clazz)
        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
        val fields = properties.map { property =>
          val (dataType, nullable) = inferDataType(property.getPropertyType)
          new StructField(property.getName, dataType, nullable)
        }
        (new StructType(fields), true)
    }
  }
}


sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -17,23 +17,28 @@

package test.org.apache.spark.sql;

import java.io.Serializable;
import java.util.Arrays;

import scala.collection.Seq;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.TestData$;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
import org.junit.*;

import scala.collection.JavaConversions;
import scala.collection.Seq;
import scala.collection.mutable.Buffer;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.apache.spark.sql.functions.*;

@@ -106,6 +111,8 @@ public void testShow() {
  public static class Bean implements Serializable {
    private double a = 0.0;
    private Integer[] b = new Integer[]{0, 1};
    private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
    private List<String> d = Arrays.asList("floppy", "disk");

    public double getA() {
      return a;
@@ -114,6 +121,14 @@ public double getA() {
    public Integer[] getB() {
      return b;
    }

    public Map<String, int[]> getC() {
      return c;
    }

    public List<String> getD() {
      return d;
    }
  }

@Test
Expand All @@ -127,7 +142,15 @@ public void testCreateDataFrameFromJavaBeans() {
    Assert.assertEquals(
      new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
      schema.apply("b"));
    Row first = df.select("a", "b").first();
    ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
    MapType mapType = new MapType(DataTypes.StringType, valueType, true);
    Assert.assertEquals(
      new StructField("c", mapType, true, Metadata.empty()),
      schema.apply("c"));
    Assert.assertEquals(
      new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
      schema.apply("d"));
    Row first = df.select("a", "b", "c", "d").first();
    Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
    // Now Java lists and maps are converted to Scala Seqs and Maps. Once we get a Seq below,
    // verify that it has the expected length, and contains expected elements.
@@ -136,5 +159,15 @@
    for (int i = 0; i < result.length(); i++) {
      Assert.assertEquals(bean.getB()[i], result.apply(i));
    }
    Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
    Assert.assertArrayEquals(
      bean.getC().get("hello"),
      Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
    Seq<String> d = first.getAs(3);
    Assert.assertEquals(bean.getD().size(), d.length());
    for (int i = 0; i < d.length(); i++) {
      Assert.assertEquals(bean.getD().get(i), d.apply(i));
    }
  }

}
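
One wrinkle visible in the assertions above: converted rows hand collection values back as Scala types (a Seq, and a Buffer inside the map), so plain Java callers go through scala.collection.JavaConversions, as the Buffer cast does. A hypothetical convenience wrapper, assuming only the Scala standard library (not part of the patch):

import java.util.List;

import scala.collection.JavaConversions;
import scala.collection.Seq;

// Invented helper: view a Scala Seq pulled out of a Row as a java.util.List.
public final class ScalaSeqs {
  private ScalaSeqs() {}

  public static <T> List<T> asJavaList(Seq<T> seq) {
    return JavaConversions.seqAsJavaList(seq);
  }
}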
