
[SPARK-30082][SQL][2.4] Do not replace Zeros when replacing NaNs

### What changes were proposed in this pull request?
Do not cast `NaN` to an `Integer`, `Long`, `Short` or `Byte` when building the replacement key. Casting `NaN` to those types yields `0`, which caused `0`s to be replaced as well, while only `NaN`s should be replaced.
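
For context, a minimal Catalyst-level sketch (not part of the patch; it assumes a Spark shell, i.e. `spark-catalyst` on the classpath) of why the old key expression matched integer zeros while the new one does not:
```
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.IntegerType

// Old behaviour: the replacement key was cast to the column's type,
// so for integral columns NaN silently became 0 and matched every 0.
println(Cast(Literal(Double.NaN), IntegerType).eval())  // 0

// New behaviour: the key stays a double NaN literal, which never
// compares equal to an integral 0, so only real NaNs are replaced.
println(Literal(Double.NaN).eval())                     // NaN
```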

### Why are the changes needed?
This Scala snippet:
```
println(Double.NaN.toLong)
```
prints `0`, which is problematic: if you run the following Spark code, `0`s get replaced as well:
```
>>> df = spark.createDataFrame([(1.0, 0), (0.0, 3), (float('nan'), 0)], ("index", "value"))
>>> df.show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  NaN|    0|
+-----+-----+
>>> df.replace(float('nan'), 2).show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    2|
|  0.0|    3|
|  2.0|    2|
+-----+-----+
```
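
The same silent truncation applies to every integral type named above, not just `Long`; a quick plain-Scala check (no Spark required):
```
println(Double.NaN.toInt)    // 0
println(Double.NaN.toLong)   // 0
println(Double.NaN.toShort)  // 0
println(Double.NaN.toByte)   // 0
println(Float.NaN.toInt)     // 0
```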

### Does this PR introduce any user-facing change?
Yes. After this PR, running the same code snippet returns the expected results:
```
>>> df = spark.createDataFrame([(1.0, 0), (0.0, 3), (float('nan'), 0)], ("index", "value"))
>>> df.show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  NaN|    0|
+-----+-----+

>>> df.replace(float('nan'), 2).show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  2.0|    0|
+-----+-----+
```

### How was this patch tested?

Added unit tests to verify that replacing `NaN` only affects columns of type `Float` and `Double`.

Closes #26749 from johnhany97/SPARK-30082-2.4.

Authored-by: John Ayad <johnhany97@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
johnhany97 authored and cloud-fan committed Dec 4, 2019
1 parent 76576b6 commit 663441f5324eee54fa698dc6f42f8f5ced8fce90
@@ -455,7 +455,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     val keyExpr = df.col(col.name).expr
     def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
     val branches = replacementMap.flatMap { case (source, target) =>
-      Seq(buildExpr(source), buildExpr(target))
+      Seq(Literal(source), buildExpr(target))
     }.toSeq
     new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
   }
@@ -36,6 +36,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("name", "age", "height")
   }
 
+  def createNaNDF(): DataFrame = {
+    Seq[(java.lang.Integer, java.lang.Long, java.lang.Short,
+      java.lang.Byte, java.lang.Float, java.lang.Double)](
+      (1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0),
+      (0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN)
+    ).toDF("int", "long", "short", "byte", "float", "double")
+  }
+
   test("drop") {
     val input = createDF()
     val rows = input.collect()
@@ -305,4 +313,40 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
       )).na.drop("name" :: Nil).select("name"),
       Row("Alice") :: Row("David") :: Nil)
   }
+
+  test("replace nan with float") {
+    checkAnswer(
+      createNaNDF().na.replace("*", Map(
+        Float.NaN -> 10.0f
+      )),
+      Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
+        Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
+  }
+
+  test("replace nan with double") {
+    checkAnswer(
+      createNaNDF().na.replace("*", Map(
+        Double.NaN -> 10.0
+      )),
+      Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
+        Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
+  }
+
+  test("replace float with nan") {
+    checkAnswer(
+      createNaNDF().na.replace("*", Map(
+        1.0f -> Float.NaN
+      )),
+      Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
+        Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
+  }
+
+  test("replace double with nan") {
+    checkAnswer(
+      createNaNDF().na.replace("*", Map(
+        1.0 -> Double.NaN
+      )),
+      Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
+        Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
+  }
 }
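
For reference, a rough spark-shell sketch of the behaviour these tests pin down (the DataFrame below is illustrative and not part of the patch; it mirrors the float/double columns of `createNaNDF`):
```
// In spark-shell (spark.implicits._ is already in scope).
val df = Seq(
  (1, 1L, 1.0f, 1.0),
  (0, 0L, Float.NaN, Double.NaN)
).toDF("int", "long", "float", "double")

// After the fix, only the float/double NaNs are replaced;
// the integral zeros are left untouched.
df.na.replace("*", Map(Double.NaN -> 10.0)).show()
// Expected output (roughly):
// +---+----+-----+------+
// |int|long|float|double|
// +---+----+-----+------+
// |  1|   1|  1.0|   1.0|
// |  0|   0| 10.0|  10.0|
// +---+----+-----+------+
```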
