[SPARK-11827][SQL] Adding java.math.BigInteger support in Java type inference for POJOs and Java collections

Hello: Can you help review this PR? I am adding support for java.math.BigInteger in the Java bean code path. I saw that Spark internally converts BigInteger to BigDecimal in ColumnType.scala and CatalystRowConverter.scala; I follow the same approach and convert BigInteger to BigDecimal.

Author: Kevin Yu <qyu@us.ibm.com>

Closes #10125 from kevinyu98/working_on_spark-11827.

(cherry picked from commit 17591d9)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
kevinyu98 authored and cloud-fan committed May 20, 2016
1 parent c21c691 commit e6810e9
Showing 8 changed files with 76 additions and 6 deletions.
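In effect, a java.math.BigInteger (or scala.math.BigInt) value is carried through Catalyst as a Decimal with precision 38 and scale 0. A minimal sketch of the resulting round trip, using only the methods introduced in this diff and assuming a build that includes this commit:

```scala
import java.math.BigInteger

import org.apache.spark.sql.types.Decimal

// A BigInteger is stored as Decimal(38, 0), mirroring the existing
// BigDecimal handling.
val d = Decimal(new BigInteger("1234567"))
assert(d.precision == 38 && d.scale == 0)
assert(d.toJavaBigInteger == new BigInteger("1234567"))
assert(d.toScalaBigInt == BigInt(1234567))
```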
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst

import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable
@@ -326,6 +327,7 @@ object CatalystTypeConverters {
val decimal = scalaValue match {
case d: BigDecimal => Decimal(d)
case d: JavaBigDecimal => Decimal(d)
case d: JavaBigInteger => Decimal(d)
case d: Decimal => d
}
if (decimal.changePrecision(dataType.precision, dataType.scale)) {
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -89,6 +89,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

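With this mapping, a bean property of type java.math.BigInteger infers as decimal(38,0). A hedged usage sketch, assuming an existing SparkSession named `spark` and the Bean class from the JavaDataFrameSuite changes below:

```scala
import java.util.Arrays

// Bean is the test bean defined below; its getter getE() returns a
// java.math.BigInteger, which should infer as DecimalType(38, 0).
val df = spark.createDataFrame(Arrays.asList(new Bean()), classOf[Bean])
println(df.schema("e").dataType)  // DecimalType(38,0), i.e. DecimalType.BigIntDecimal
```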
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -259,6 +259,12 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[BigDecimal] =>
Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))

case t if t <:< localTypeOf[java.math.BigInteger] =>
Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))

case t if t <:< localTypeOf[scala.math.BigInt] =>
Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))

case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t

@@ -592,6 +598,20 @@
"apply",
inputObject :: Nil)

case t if t <:< localTypeOf[java.math.BigInteger] =>
StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
inputObject :: Nil)

case t if t <:< localTypeOf[scala.math.BigInt] =>
StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
inputObject :: Nil)

case t if t <:< localTypeOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
case t if t <:< localTypeOf[java.lang.Long] =>
@@ -736,6 +756,10 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.math.BigInteger] =>
Schema(DecimalType.BigIntDecimal, nullable = true)
case t if t <:< localTypeOf[scala.math.BigInt] =>
Schema(DecimalType.BigIntDecimal, nullable = true)
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
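The same inference applies to Scala types through ScalaReflection: both scala.math.BigInt and java.math.BigInteger fields in a case class map to decimal(38,0). A minimal sketch, assuming a SparkSession with spark.implicits._ in scope:

```scala
case class Account(id: Long, balance: scala.math.BigInt)

// balance is inferred as decimal(38,0) (DecimalType.BigIntDecimal).
val ds = Seq(Account(1L, BigInt(42))).toDS()
ds.printSchema()
// root
//  |-- id: long (nullable = false)
//  |-- balance: decimal(38,0) (nullable = true)
```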
sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.types

import java.math.{MathContext, RoundingMode}
import java.math.{BigInteger, MathContext, RoundingMode}

import org.apache.spark.annotation.DeveloperApi

@@ -128,6 +128,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
this
}

/**
* Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0.
*/
def set(bigintval: BigInteger): Decimal = {
try {
this.decimalVal = null
this.longVal = bigintval.longValueExact()
this._precision = DecimalType.MAX_PRECISION
this._scale = 0
this
}
catch {
case e: ArithmeticException =>
throw new IllegalArgumentException(s"BigInteger ${bigintval} too large for decimal")
}
}

/**
* Set this Decimal to the given Decimal value.
*/
@@ -155,6 +172,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}

def toScalaBigInt: BigInt = BigInt(toLong)

def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong)

def toUnscaledLong: Long = {
if (decimalVal.ne(null)) {
decimalVal.underlying().unscaledValue().longValue()
@@ -371,6 +392,10 @@ object Decimal {

def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value)

def apply(value: java.math.BigInteger): Decimal = new Decimal().set(value)

def apply(value: scala.math.BigInt): Decimal = new Decimal().set(value.bigInteger)

def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
new Decimal().set(value, precision, scale)

@@ -387,6 +412,8 @@
value match {
case j: java.math.BigDecimal => apply(j)
case d: BigDecimal => apply(d)
case k: scala.math.BigInt => apply(k)
case l: java.math.BigInteger => apply(l)
case d: Decimal => d
}
}
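Note that set(BigInteger) stores the value in the compact long representation via longValueExact(), and toScalaBigInt / toJavaBigInteger go through toLong, so this commit supports only BigInteger values within the Long range. A small sketch of that boundary:

```scala
import java.math.BigInteger

import org.apache.spark.sql.types.Decimal

// Long.MaxValue still fits the compact long representation.
val ok = Decimal(BigInteger.valueOf(Long.MaxValue))
assert(ok.toJavaBigInteger == BigInteger.valueOf(Long.MaxValue))

// A 20-digit value exceeds the Long range and is rejected by
// longValueExact() with an IllegalArgumentException:
// Decimal(new BigInteger("99999999999999999999"))
```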
sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType {
private[sql] val LongDecimal = DecimalType(20, 0)
private[sql] val FloatDecimal = DecimalType(14, 7)
private[sql] val DoubleDecimal = DecimalType(30, 15)
private[sql] val BigIntDecimal = DecimalType(38, 0)

private[sql] def forType(dataType: DataType): DecimalType = dataType match {
case ByteType => ByteDecimal
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.encoders

import java.math.BigInteger
import java.sql.{Date, Timestamp}
import java.util.Arrays

@@ -109,7 +110,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {

encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")

encodeDecodeTest(BigInt("23134123123"), "scala biginteger")
encodeDecodeTest(new BigInteger("23134123123"), "java BigInteger")
encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")

encodeDecodeTest("hello", "string")
sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -21,6 +21,8 @@
import java.net.URISyntaxException;
import java.net.URL;
import java.util.*;
import java.math.BigInteger;
import java.math.BigDecimal;

import scala.collection.JavaConverters;
import scala.collection.Seq;
@@ -130,6 +132,7 @@ public static class Bean implements Serializable {
private Integer[] b = { 0, 1 };
private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
private List<String> d = Arrays.asList("floppy", "disk");
private BigInteger e = new BigInteger("1234567");

public double getA() {
return a;
@@ -146,6 +149,8 @@ public Map<String, int[]> getC() {
public List<String> getD() {
return d;
}

public BigInteger getE() { return e; }
}

void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
@@ -163,7 +168,9 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
Assert.assertEquals(
new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
schema.apply("d"));
Row first = df.select("a", "b", "c", "d").first();
Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()),
schema.apply("e"));
Row first = df.select("a", "b", "c", "d", "e").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
// verify that it has the expected length, and contains expected elements.
@@ -182,6 +189,8 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
for (int i = 0; i < d.length(); i++) {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
// java.math.BigInteger is equivalent to Spark Decimal(38,0)
Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
}

@Test
sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -34,7 +34,9 @@ case class ReflectData(
decimalField: java.math.BigDecimal,
date: Date,
timestampField: Timestamp,
seqInt: Seq[Int])
seqInt: Seq[Int],
javaBigInt: java.math.BigInteger,
scalaBigInt: scala.math.BigInt)

case class NullReflectData(
intField: java.lang.Integer,
@@ -77,13 +79,15 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {

test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3),
new java.math.BigInteger("1"), scala.math.BigInt(1))
Seq(data).toDF().createOrReplaceTempView("reflectData")

assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
new Timestamp(12345), Seq(1, 2, 3)))
new Timestamp(12345), Seq(1, 2, 3), new java.math.BigDecimal(1),
new java.math.BigDecimal(1)))
}

test("query case class RDD with nulls") {
