From c7e308b4bc52199a0a7b7f3c81c1bc2de78d55c0 Mon Sep 17 00:00:00 2001
From: Punya Biswal
Date: Mon, 20 Apr 2015 23:16:11 -0400
Subject: [PATCH] Support java iterable types in POJOs

---
 .../sql/catalyst/CatalystTypeConverters.scala | 20 ++++++++++++++++++++
 .../apache/spark/sql/JavaTypeInference.scala  |  4 ++++
 .../apache/spark/sql/JavaDataFrameSuite.java  | 16 +++++++++++++++-
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d4f9fdacda4fb..a13e2f36a1a1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
+import java.lang.{Iterable => JavaIterable}
 import java.util.{Map => JavaMap}
 
 import scala.collection.mutable.HashMap
@@ -49,6 +50,16 @@ object CatalystTypeConverters {
     case (s: Seq[_], arrayType: ArrayType) =>
       s.map(convertToCatalyst(_, arrayType.elementType))
 
+    case (jit: JavaIterable[_], arrayType: ArrayType) => {
+      val iter = jit.iterator
+      var listOfItems: List[Any] = List()
+      while (iter.hasNext) {
+        val item = iter.next()
+        listOfItems :+= convertToCatalyst(item, arrayType.elementType)
+      }
+      listOfItems
+    }
+
     case (s: Array[_], arrayType: ArrayType) =>
       s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
 
@@ -124,6 +135,15 @@ object CatalystTypeConverters {
           extractOption(item) match {
             case a: Array[_] => a.toSeq.map(elementConverter)
             case s: Seq[_] => s.map(elementConverter)
+            case i: JavaIterable[_] => {
+              val iter = i.iterator
+              var convertedIterable: List[Any] = List()
+              while (iter.hasNext) {
+                val item = iter.next()
+                convertedIterable :+= elementConverter(item)
+              }
+              convertedIterable
+            }
             case null => null
           }
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
index f2c3826aab485..db484c5f50074 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
@@ -75,6 +75,10 @@ private [sql] object JavaTypeInference {
       val (dataType, nullable) = inferDataType(typeToken.getComponentType)
       (ArrayType(dataType, nullable), true)
 
+    case _ if iterableType.isAssignableFrom(typeToken) =>
+      val (dataType, nullable) = inferDataType(elementType(typeToken))
+      (ArrayType(dataType, nullable), true)
+
     case _ if mapType.isAssignableFrom(typeToken) =>
       val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
       val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index c9ef89b852dbf..2d0354e1e58c8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -19,6 +19,7 @@
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 
 import com.google.common.collect.ImmutableMap;
@@ -112,6 +113,7 @@ public static class Bean implements Serializable {
     private double a = 0.0;
     private Integer[] b = new Integer[]{0, 1};
     private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
+    private List<String> d = Arrays.asList("floppy", "disk");
 
     public double getA() {
       return a;
@@ -124,6 +126,10 @@ public Integer[] getB() {
     public Map<String, int[]> getC() {
       return c;
     }
+
+    public List<String> getD() {
+      return d;
+    }
   }
 
   @Test
@@ -142,7 +148,10 @@ public void testCreateDataFrameFromJavaBeans() {
     Assert.assertEquals(
      new StructField("c", mapType, true, Metadata.empty()),
      schema.apply("c"));
-    Row first = df.select("a", "b", "c").first();
+    Assert.assertEquals(
+      new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
+      schema.apply("d"));
+    Row first = df.select("a", "b", "c", "d").first();
     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
     // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below,
     // verify that it has the expected length, and contains expected elements.
@@ -155,6 +164,11 @@ public void testCreateDataFrameFromJavaBeans() {
     Assert.assertArrayEquals(
       bean.getC().get("hello"),
       Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
+    Seq<String> d = first.getAs(3);
+    Assert.assertEquals(bean.getD().size(), d.length());
+    for (int i = 0; i < d.length(); i++) {
+      Assert.assertEquals(bean.getD().get(i), d.apply(i));
+    }
   }
 
 }