From 03055268fd9d18f3cafce31ae901881d503c63ec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 14:39:55 -0700 Subject: [PATCH 1/4] convert inbound NaN as null --- python/pyspark/sql/tests.py | 4 ++++ .../sql/catalyst/expressions/UnsafeRow.java | 16 ++++++++++++---- .../sql/catalyst/CatalystTypeConverters.scala | 6 +++++- .../expressions/SpecificMutableRow.scala | 16 ++++++++++++---- .../expressions/codegen/GenerateProjection.scala | 9 ++++++++- .../sql/catalyst/expressions/literals.scala | 6 ++++-- .../spark/sql/catalyst/expressions/math.scala | 1 + .../spark/sql/catalyst/expressions/rows.scala | 15 +++++++++------ .../expressions/LiteralExpressionSuite.scala | 3 ++- .../apache/spark/sql/execution/pythonUDFs.scala | 6 ++++-- .../org/apache/spark/sql/json/JsonRDD.scala | 5 +++-- 11 files changed, 64 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 66827d48850d9..f555880eb4fc5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -713,6 +713,10 @@ def test_decimal(self): row = df2.first() self.assertEqual(row[0], Decimal("3.14159")) + def test_nan(self): + df = self.sqlCtx.createDataFrame([(float('nan'),)], ['a']) + self.assertEqual(None, df.first()[0]) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index edb7202245289..188e0f63ed2db 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -213,8 +213,12 @@ public void setLong(int ordinal, long value) { @Override public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); - setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + if (Double.isNaN(value)) { + setNullAt(ordinal); + } else { + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + } } @Override @@ -241,8 +245,12 @@ public void setByte(int ordinal, byte value) { @Override public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); - setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + if (Float.isNaN(value)) { + setNullAt(ordinal); + } else { + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + } } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 8f63d2120ad0e..d3e224e7c28c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -314,7 +314,7 @@ object CatalystTypeConverters { private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { final override def toScala(catalystValue: Any): Any = catalystValue - final override def toCatalystImpl(scalaValue: T): Any = scalaValue + override def toCatalystImpl(scalaValue: T): Any = scalaValue } private object BooleanConverter extends PrimitiveConverter[Boolean] { @@ -339,10 +339,14 @@ object CatalystTypeConverters { private object FloatConverter extends PrimitiveConverter[Float] { override def toScalaImpl(row: InternalRow, column: Int): Float = row.getFloat(column) + override def toCatalystImpl(scalaValue: Float): Any = + if (java.lang.Float.isNaN(scalaValue)) null else scalaValue } private object DoubleConverter extends PrimitiveConverter[Double] { override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column) + override def toCatalystImpl(scalaValue: Double): Any = + if (java.lang.Double.isNaN(scalaValue)) null else scalaValue } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index efa24710a5a67..8a902f213b368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -259,8 +259,12 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def setFloat(ordinal: Int, value: Float): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableFloat] - currentValue.isNull = false - currentValue.value = value + if (java.lang.Float.isNaN(value)) { + currentValue.isNull = true + } else { + currentValue.isNull = false + currentValue.value = value + } } override def getFloat(i: Int): Float = { @@ -279,8 +283,12 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def setDouble(ordinal: Int, value: Double): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableDouble] - currentValue.isNull = false - currentValue.value = value + if (java.lang.Double.isNaN(value)) { + currentValue.isNull = true + } else { + currentValue.isNull = false + currentValue.value = value + } } override def getDouble(i: Int): Double = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3c7ee9cc16599..4b0347fcc7d6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -98,7 +98,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificMutatorFunctions = ctx.primitiveTypes.map { jt => val cases = expressions.zipWithIndex.flatMap { case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: { c$i = value; return; }") + Some(jt match { + case "float" => + s"case $i: {if (Float.isNaN(value)) nullBits[$i] = true else c$i = value; return;}" + case "double" => + s"case $i: {if (Double.isNaN(value)) nullBits[$i] = true else c$i = value; return;}" + case _ => + s"case $i: {c$i = value; return;}" + }) case _ => None }.mkString("\n ") if (cases.length > 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3a7a7ae440036..7cc5e13024305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -30,8 +30,10 @@ object Literal { def apply(v: Any): Literal = v match { case i: Int => Literal(i, IntegerType) case l: Long => Literal(l, LongType) - case d: Double => Literal(d, DoubleType) - case f: Float => Literal(f, FloatType) + case d: Double => + if (java.lang.Double.isNaN(d)) Literal(null, DoubleType) else Literal(d, DoubleType) + case f: Float => + if (java.lang.Float.isNaN(f)) Literal(null, FloatType) else Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c31890e27fb54..171514bbcada4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -96,6 +96,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def toString: String = s"$name($left, $right)" override def dataType: DataType = DoubleType + override def nullable: Boolean = true protected override def nullSafeEval(input1: Any, input2: Any): Any = { val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8b472a529e5c9..6928d2a544034 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -32,17 +32,20 @@ abstract class MutableRow extends InternalRow { def update(i: Int, value: Any) // default implementation (slow) - def setInt(i: Int, value: Int): Unit = { update(i, value) } - def setLong(i: Int, value: Long): Unit = { update(i, value) } - def setDouble(i: Int, value: Double): Unit = { update(i, value) } def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } - def setShort(i: Int, value: Short): Unit = { update(i, value) } def setByte(i: Int, value: Byte): Unit = { update(i, value) } - def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setShort(i: Int, value: Short): Unit = { update(i, value) } + def setInt(i: Int, value: Int): Unit = { update(i, value) } + def setLong(i: Int, value: Long): Unit = { update(i, value) } + def setFloat(i: Int, value: Float): Unit = { + if (java.lang.Float.isNaN(value)) setNullAt(i) else update(i, value) + } + def setDouble(i: Int, value: Double): Unit = { + if (java.lang.Double.isNaN(value)) setNullAt(i) else update(i, value) + } def setString(i: Int, value: String): Unit = { update(i, UTF8String.fromString(value)) } - override def copy(): InternalRow = { val arr = new Array[Any](length) var i = 0 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index d924ff7a102f6..2dcac5bf6d04c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -64,7 +64,8 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) checkEvaluation(Literal(Float.MinValue), Float.MinValue) checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) - + checkEvaluation(Literal(Double.NaN), null) + checkEvaluation(Literal(Float.NaN), null) } test("string literals") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 1c8130b07c7fb..ed45e1b8de410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -181,9 +181,11 @@ object EvaluatePython { case (c: Int, LongType) => c.toLong case (c: Long, LongType) => c - case (c: Double, FloatType) => c.toFloat + case (c: Double, FloatType) => + if (java.lang.Double.isNaN(c)) null else c.toFloat - case (c: Double, DoubleType) => c + case (c: Double, DoubleType) => + if (java.lang.Double.isNaN(c)) null else c case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index b392a51bf7dce..10fb34dfea627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -333,11 +333,12 @@ private[sql] object JsonRDD extends Logging { } } - private def toDouble(value: Any): Double = { + private def toDouble(value: Any): java.lang.Double = { value match { case value: java.lang.Integer => value.asInstanceOf[Int].toDouble case value: java.lang.Long => value.asInstanceOf[Long].toDouble - case value: java.lang.Double => value.asInstanceOf[Double] + case value: java.lang.Double => + if (java.lang.Double.isNaN(value)) null else value.asInstanceOf[Double] } } From c03155798f5270fa02e9b8d170665b9c4ffd5d49 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 15:19:46 -0700 Subject: [PATCH 2/4] fix codegen --- .../sql/catalyst/expressions/codegen/GenerateProjection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 4b0347fcc7d6f..3425c527bef02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -100,9 +100,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case (e, i) if ctx.javaType(e.dataType) == jt => Some(jt match { case "float" => - s"case $i: {if (Float.isNaN(value)) nullBits[$i] = true else c$i = value; return;}" + s"case $i: {if (Float.isNaN(value)) nullBits[$i] = true; else c$i = value; return;}" case "double" => - s"case $i: {if (Double.isNaN(value)) nullBits[$i] = true else c$i = value; return;}" + s"case $i: {if (Double.isNaN(value)) nullBits[$i] = true; else c$i = value; return;}" case _ => s"case $i: {c$i = value; return;}" }) From c8f391df49abfef5fd53088dbba4292777163fc9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 16:55:21 -0700 Subject: [PATCH 3/4] fix test --- .../test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 7ba4ba73e0cc9..98e5b47705cfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -104,7 +104,7 @@ class DataFrameStatSuite extends SparkFunSuite { assert(ct2.collect().length === 4) val ct3 = df.stat.crosstab("3", "2") assert(ct3.schema.fields.length === 6) - assert(ct3.schema.fieldNames.contains("NaN")) + assert(ct3.schema.fieldNames.contains("null")) assert(ct3.schema.fieldNames.contains("Infinity")) assert(ct3.schema.fieldNames.contains("-Infinity")) assert(ct3.collect().length === 4) From c37388287c26a8f00f53a291b2881d8b26b8ea8b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 13 Jul 2015 10:46:35 -0700 Subject: [PATCH 4/4] disable a few tests that Hive generate NaN --- .../hive/execution/HiveCompatibilitySuite.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c884c399281a8..8275c6cb3d6ce 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -257,7 +257,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Spark SQL use Long for TimestampType, lose the precision under 1us "timestamp_1", "timestamp_2", - "timestamp_udf" + "timestamp_udf", + + // Spark SQL use NULL for NaN + "udf7", + "udf_acos", + "udf_asin", + "udf_xpath_double", + "udf_xpath_float" ) /** @@ -819,19 +826,16 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf2", "udf5", "udf6", - // "udf7", turn this on after we figure out null vs nan vs infinity "udf8", "udf9", "udf_10_trims", "udf_E", "udf_PI", "udf_abs", - // "udf_acos", turn this on after we figure out null vs nan vs infinity "udf_add", "udf_array", "udf_array_contains", "udf_ascii", - // "udf_asin", turn this on after we figure out null vs nan vs infinity "udf_atan", "udf_avg", "udf_bigint", @@ -958,8 +962,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_when", "udf_xpath", "udf_xpath_boolean", - "udf_xpath_double", - "udf_xpath_float", "udf_xpath_int", "udf_xpath_long", "udf_xpath_short",