[SPARK-28818][SQL] Respect source column nullability in the arrays created by `freqItems()`

### What changes were proposed in this pull request?
This PR replaces the hard-coded non-nullability of the array elements returned by `freqItems()` with a nullability that reflects the original schema. Essentially [the functional change](https://github.com/apache/spark/pull/25575/files#diff-bf59bb9f3dc351f5bf6624e5edd2dcf4R122) to the schema generation is:
```
StructField(name + "_freqItems", ArrayType(dataType, false))
```
Becomes:
```
StructField(name + "_freqItems", ArrayType(dataType, originalField.nullable))
```

Respecting the original nullability prevents issues when Spark depends on `ArrayType`'s `containsNull` being accurate. The example that uncovered this was calling `collect()` on the resulting DataFrame (see the [ticket](https://issues.apache.org/jira/browse/SPARK-28818) for a full repro), though it's likely that there are several other places where this could cause a problem.
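
To make the failure mode concrete, here is a minimal sketch adapted from the regression test added below; the column names are illustrative and a `spark-shell`-style `spark` session is assumed:
```
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// A DataFrame whose second column is nullable and actually contains a null value.
val rows = spark.sparkContext.parallelize(Seq(Row("1", "a"), Row("2", null), Row("3", "b")))
val schema = StructType(Seq(
  StructField("non_null", StringType, nullable = false),
  StructField("nullable", StringType, nullable = true)
))
val df = spark.createDataFrame(rows, schema)

// Before this change the result schema declared ArrayType(StringType, containsNull = false)
// for both output columns, even though the frequent-item array for "nullable" can contain a
// null, so collecting the result could fail; with the fix, containsNull follows the source
// column's nullability.
df.stat.freqItems(df.columns).collect()
```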

I've also refactored a small amount of the surrounding code to remove some unnecessary steps and group together related operations.

### Why are the changes needed?
It fixes a bug that currently prevents users from calling `collect()` on the result of `df.stat.freqItems(...)`, and that could be causing other, as yet unknown, issues.

### Does this PR introduce any user-facing change?
Yes: after the change, the nullability of the source columns is respected in the element type of the arrays returned by `freqItems()`.
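
As a sketch of what that looks like to a user (reusing the hypothetical `df` from the example above):
```
val result = df.stat.freqItems(df.columns)

// Before the change both columns reported ArrayType(StringType, containsNull = false);
// after it, the element nullability mirrors the source column.
result.schema("non_null_freqItems").dataType  // ArrayType(StringType, containsNull = false)
result.schema("nullable_freqItems").dataType  // ArrayType(StringType, containsNull = true)
```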

### How was this patch tested?
I added a test that specifically checks that the nullability is carried through, and that explicitly calls `collect()` to catch the exact regression that was observed. I also ran the test against the old version of the code and it failed as expected.

Closes #25575 from MGHawes/mhawes/SPARK-28818.

Authored-by: Matt Hawes <mhawes@palantir.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Matt Hawes authored and HyukjinKwon committed Aug 29, 2019
1 parent 7452786 commit 137b20b
Showing 2 changed files with 35 additions and 10 deletions.
@@ -89,11 +89,6 @@ object FrequentItems extends Logging {
     // number of max items to keep counts for
     val sizeOfMap = (1 / support).toInt
     val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
-    val originalSchema = df.schema
-    val colInfo: Array[(String, DataType)] = cols.map { name =>
-      val index = originalSchema.fieldIndex(name)
-      (name, originalSchema.fields(index).dataType)
-    }.toArray

     val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
       seqOp = (counts, row) => {
@@ -117,10 +112,16 @@
     )
     val justItems = freqItems.map(m => m.baseMap.keys.toArray)
     val resultRow = Row(justItems : _*)
-    // append frequent Items to the column name for easy debugging
-    val outputCols = colInfo.map { v =>
-      StructField(v._1 + "_freqItems", ArrayType(v._2, false))
-    }
+
+    val originalSchema = df.schema
+    val outputCols = cols.map { name =>
+      val index = originalSchema.fieldIndex(name)
+      val originalField = originalSchema.fields(index)
+
+      // append frequent Items to the column name for easy debugging
+      StructField(name + "_freqItems", ArrayType(originalField.dataType, originalField.nullable))
+    }.toArray
+
     val schema = StructType(outputCols).toAttributes
     Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
   }
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.functions.{col, lit, struct}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}

 class DataFrameStatSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -366,6 +366,30 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
     }
   }

+  test("SPARK-28818: Respect original column nullability in `freqItems`") {
+    val rows = spark.sparkContext.parallelize(
+      Seq(Row("1", "a"), Row("2", null), Row("3", "b"))
+    )
+    val schema = StructType(Seq(
+      StructField("non_null", StringType, false),
+      StructField("nullable", StringType, true)
+    ))
+    val df = spark.createDataFrame(rows, schema)
+
+    val result = df.stat.freqItems(df.columns)
+
+    val nonNullableDataType = result.schema("non_null_freqItems").dataType.asInstanceOf[ArrayType]
+    val nullableDataType = result.schema("nullable_freqItems").dataType.asInstanceOf[ArrayType]
+
+    assert(nonNullableDataType.containsNull == false)
+    assert(nullableDataType.containsNull == true)
+    // Original bug was a NullPointerException exception caused by calling collect(), test for this
+    val resultRow = result.collect()(0)
+
+    assert(resultRow.get(0).asInstanceOf[Seq[String]].toSet == Set("1", "2", "3"))
+    assert(resultRow.get(1).asInstanceOf[Seq[String]].toSet == Set("a", "b", null))
+  }
+
   test("sampleBy") {
     val df = spark.range(0, 100).select((col("id") % 3).as("key"))
     val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
