Skip to content

Commit

Permalink
Support java iterable types in POJOs
Browse files Browse the repository at this point in the history
  • Loading branch information
Punya Biswal committed Apr 21, 2015
1 parent 5e00685 commit c7e308b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst

import java.lang.{Iterable => JavaIterable}
import java.util.{Map => JavaMap}

import scala.collection.mutable.HashMap
Expand Down Expand Up @@ -49,6 +50,16 @@ object CatalystTypeConverters {
case (s: Seq[_], arrayType: ArrayType) =>
s.map(convertToCatalyst(_, arrayType.elementType))

case (jit: JavaIterable[_], arrayType: ArrayType) => {
  // Java iterables are walked once and each element is converted with the array's
  // element type. A builder accumulates the results because appending to an immutable
  // List with `:+` copies the whole list on every element (quadratic overall).
  val iter = jit.iterator
  val convertedItems = List.newBuilder[Any]
  while (iter.hasNext) {
    convertedItems += convertToCatalyst(iter.next(), arrayType.elementType)
  }
  // Still a List[Any], in iteration order, exactly as before.
  convertedItems.result()
}

case (s: Array[_], arrayType: ArrayType) =>
s.toSeq.map(convertToCatalyst(_, arrayType.elementType))

Expand Down Expand Up @@ -124,6 +135,15 @@ object CatalystTypeConverters {
extractOption(item) match {
case a: Array[_] => a.toSeq.map(elementConverter)
case s: Seq[_] => s.map(elementConverter)
case i: JavaIterable[_] => {
  // Convert every element of the Java iterable with elementConverter, preserving
  // iteration order. A builder is used instead of repeated `:+` on an immutable List,
  // which would copy the list per element; this also avoids shadowing the outer `item`.
  val iter = i.iterator
  val convertedIterable = List.newBuilder[Any]
  while (iter.hasNext) {
    convertedIterable += elementConverter(iter.next())
  }
  convertedIterable.result()
}
case null => null
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ private [sql] object JavaTypeInference {
val (dataType, nullable) = inferDataType(typeToken.getComponentType)
(ArrayType(dataType, nullable), true)

// Any type token assignable to iterableType is inferred as an ArrayType whose element
// data type and element nullability come from inferDataType(elementType(typeToken)).
case _ if iterableType.isAssignableFrom(typeToken) =>
val (dataType, nullable) = inferDataType(elementType(typeToken))
// The second tuple component is always true here, matching the array case above.
(ArrayType(dataType, nullable), true)

case _ if mapType.isAssignableFrom(typeToken) =>
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -112,6 +113,7 @@ public static class Bean implements Serializable {
private double a = 0.0;
private Integer[] b = new Integer[]{0, 1};
private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
private List<String> d = Arrays.asList("floppy", "disk");

public double getA() {
return a;
Expand All @@ -124,6 +126,10 @@ public Integer[] getB() {
/** Returns the map stored in field {@code c}. */
public Map<String, int[]> getC() {
  return this.c;
}

/** Returns the list stored in field {@code d}. */
public List<String> getD() {
  return this.d;
}
}

@Test
Expand All @@ -142,7 +148,10 @@ public void testCreateDataFrameFromJavaBeans() {
Assert.assertEquals(
new StructField("c", mapType, true, Metadata.empty()),
schema.apply("c"));
Row first = df.select("a", "b", "c").first();
Assert.assertEquals(
new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
schema.apply("d"));
Row first = df.select("a", "b", "c", "d").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
// verify that it has the expected length, and contains expected elements.
Expand All @@ -155,6 +164,11 @@ public void testCreateDataFrameFromJavaBeans() {
Assert.assertArrayEquals(
bean.getC().get("hello"),
Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
Seq<String> d = first.getAs(3);
Assert.assertEquals(bean.getD().size(), d.length());
for (int i = 0; i < d.length(); i++) {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
}

}

0 comments on commit c7e308b

Please sign in to comment.