From 2a24bf92e6d36e876bad6a8b4e0ff12c407ebb8a Mon Sep 17 00:00:00 2001 From: Punya Biswal Date: Tue, 21 Apr 2015 14:50:02 -0700 Subject: [PATCH] [SPARK-6996][SQL] Support map types in java beans liancheng mengxr this is similar to #5146. Author: Punya Biswal Closes #5578 from punya/feature/SPARK-6996 and squashes the following commits: d56c3e0 [Punya Biswal] Fix imports c7e308b [Punya Biswal] Support java iterable types in POJOs 5e00685 [Punya Biswal] Support map types in java beans --- .../sql/catalyst/CatalystTypeConverters.scala | 20 ++++ .../apache/spark/sql/JavaTypeInference.scala | 110 ++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 52 +-------- .../apache/spark/sql/JavaDataFrameSuite.java | 57 +++++++-- 4 files changed, 180 insertions(+), 59 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4f9fdacda4fb..a13e2f36a1a1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst +import java.lang.{Iterable => JavaIterable} import java.util.{Map => JavaMap} import scala.collection.mutable.HashMap @@ -49,6 +50,16 @@ object CatalystTypeConverters { case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (jit: JavaIterable[_], arrayType: ArrayType) => { + val iter = jit.iterator + var listOfItems: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + listOfItems :+= convertToCatalyst(item, arrayType.elementType) + } + listOfItems + } + case (s: Array[_], arrayType: ArrayType) => s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) @@ -124,6 +135,15 @@ object CatalystTypeConverters { extractOption(item) match { case a: Array[_] => a.toSeq.map(elementConverter) case s: Seq[_] => s.map(elementConverter) + case i: JavaIterable[_] => { + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter(item) + } + convertedIterable + } case null => null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala new file mode 100644 index 0000000000000..db484c5f50074 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.beans.Introspector +import java.lang.{Iterable => JIterable} +import java.util.{Iterator => JIterator, Map => JMap} + +import com.google.common.reflect.TypeToken + +import org.apache.spark.sql.types._ + +import scala.language.existentials + +/** + * Type-inference utilities for POJOs and Java collections. + */ +private [sql] object JavaTypeInference { + + private val iterableType = TypeToken.of(classOf[JIterable[_]]) + private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType + private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType + private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType + private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType + + /** + * Infers the corresponding SQL data type of a Java type. + * @param typeToken Java type + * @return (SQL data type, nullable) + */ + private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. + typeToken.getRawType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) + + case _ if typeToken.isArray => + val (dataType, nullable) = inferDataType(typeToken.getComponentType) + (ArrayType(dataType, nullable), true) + + case _ if iterableType.isAssignableFrom(typeToken) => + val (dataType, nullable) = inferDataType(elementType(typeToken)) + (ArrayType(dataType, nullable), true) + + case _ if mapType.isAssignableFrom(typeToken) => + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) + val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyDataType, _) = inferDataType(keyType) + val (valueDataType, nullable) = inferDataType(valueType) + (MapType(keyDataType, valueDataType, nullable), true) + + case _ => + val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) + val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val fields = properties.map { property => + val returnType = typeToken.method(property.getReadMethod).getReturnType + val (dataType, nullable) = inferDataType(returnType) + new StructField(property.getName, dataType, nullable) + } + (new StructType(fields), true) + } + } + + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] + val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSupertype.resolveType(iteratorReturnType) + val itemType = iteratorType.resolveType(nextReturnType) + itemType + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f9f3eb2e03817..bcd20c06c6dca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -25,6 +25,8 @@ import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import com.google.common.reflect.TypeToken + import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD @@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - val (dataType, _) = inferDataType(beanClass) + val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass)) dataType.asInstanceOf[StructType].fields.map { f => AttributeReference(f.name, f.dataType, f.nullable)() } } - /** - * Infers the corresponding SQL data type of a Java class. - * @param clazz Java class - * @return (SQL data type, nullable) - */ - private def inferDataType(clazz: Class[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. - clazz match { - case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) - - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) - case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) - case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) - - case c: Class[_] if c.isArray => - val (dataType, nullable) = inferDataType(c.getComponentType) - (ArrayType(dataType, nullable), true) - - case _ => - val beanInfo = Introspector.getBeanInfo(clazz) - val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - val fields = properties.map { property => - val (dataType, nullable) = inferDataType(property.getPropertyType) - new StructField(property.getName, dataType, nullable) - } - (new StructType(fields), true) - } - } } + + diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 6d0fbe83c2f36..fc3ed4a708d46 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,23 +17,28 @@ package test.org.apache.spark.sql; -import java.io.Serializable; -import java.util.Arrays; - -import scala.collection.Seq; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.TestData$; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; +import org.junit.*; + +import scala.collection.JavaConversions; +import scala.collection.Seq; +import scala.collection.mutable.Buffer; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; +import java.util.Map; import static org.apache.spark.sql.functions.*; @@ -106,6 +111,8 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; private Integer[] b = new Integer[]{0, 1}; + private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); + private List d = Arrays.asList("floppy", "disk"); public double getA() { return a; @@ -114,6 +121,14 @@ public double getA() { public Integer[] getB() { return b; } + + public Map getC() { + return c; + } + + public List getD() { + return d; + } } @Test @@ -127,7 +142,15 @@ public void testCreateDataFrameFromJavaBeans() { Assert.assertEquals( new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), schema.apply("b")); - Row first = df.select("a", "b").first(); + ArrayType valueType = new ArrayType(DataTypes.IntegerType, false); + MapType mapType = new MapType(DataTypes.StringType, valueType, true); + Assert.assertEquals( + new StructField("c", mapType, true, Metadata.empty()), + schema.apply("c")); + Assert.assertEquals( + new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), + schema.apply("d")); + Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -136,5 +159,15 @@ public void testCreateDataFrameFromJavaBeans() { for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } + Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + Assert.assertArrayEquals( + bean.getC().get("hello"), + Ints.toArray(JavaConversions.asJavaList(outputBuffer))); + Seq d = first.getAs(3); + Assert.assertEquals(bean.getD().size(), d.length()); + for (int i = 0; i < d.length(); i++) { + Assert.assertEquals(bean.getD().get(i), d.apply(i)); + } } + }