Commit

initial commit
imback82 committed Jan 31, 2020
1 parent 4c3c1d6 commit 7ea21ef
Showing 2 changed files with 43 additions and 14 deletions.
@@ -41,7 +41,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 1.3.1
    */
-  def drop(): DataFrame = drop("any", df.columns)
+  def drop(): DataFrame = drop0("any", outputAttributes)
 
   /**
    * Returns a new `DataFrame` that drops rows containing null or NaN values.
@@ -51,7 +51,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 1.3.1
    */
-  def drop(how: String): DataFrame = drop(how, df.columns)
+  def drop(how: String): DataFrame = drop0(how, outputAttributes)
 
   /**
    * Returns a new `DataFrame` that drops rows containing any null or NaN values
@@ -90,11 +90,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def drop(how: String, cols: Seq[String]): DataFrame = {
-    how.toLowerCase(Locale.ROOT) match {
-      case "any" => drop(cols.size, cols)
-      case "all" => drop(1, cols)
-      case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
-    }
+    drop0(how, toAttributes(cols))
   }
 
   /**
@@ -120,10 +116,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
-    // Filtering condition:
-    // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
-    val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
-    df.filter(Column(predicate))
+    drop0(minNonNulls, toAttributes(cols))
   }
 
   /**
@@ -488,6 +481,23 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     df.queryExecution.analyzed.output
   }
 
+  private def drop0(how: String, cols: Seq[Attribute]): DataFrame = {
+    how.toLowerCase(Locale.ROOT) match {
+      case "any" => drop0(cols.size, cols)
+      case "all" => drop0(1, cols)
+      case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
+    }
+  }
+
+  private def drop0(minNonNulls: Int, cols: Seq[Attribute]): DataFrame = {
+    // Filtering condition:
+    // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
+    val predicate = AtLeastNNonNulls(
+      minNonNulls,
+      outputAttributes.filter{ col => cols.exists(_.semanticEquals(col)) })
+    df.filter(Column(predicate))
+  }
+
   /**
    * Returns a new `DataFrame` that replaces null or NaN values in the specified
    * columns. If a specified column is not a numeric, string or boolean column,
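Editor's note on the diff above: the new drop0 helpers keep the existing drop() semantics. "any" maps to minNonNulls = cols.size (every selected column must be non-null and non-NaN), while "all" maps to minNonNulls = 1 (a row is dropped only when every selected column is null or NaN). The following standalone Scala sketch illustrates those semantics; it is not part of this commit, and the local SparkSession setup, object name, and sample data are assumptions for illustration only.

// Illustrative only: shows the "any" / "all" / minNonNulls behavior of DataFrame.na.drop.
// Not part of this commit; assumes a local Spark installation.
import org.apache.spark.sql.SparkSession

object NaDropSemanticsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("na-drop-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq[(Option[Int], Option[String])](
      (Some(1), Some("a")),   // no nulls
      (Some(2), None),        // one null
      (None, None)            // all nulls
    ).toDF("id", "name")

    // "any": drop a row if any column is null/NaN (minNonNulls = number of columns = 2).
    df.na.drop("any").show()   // keeps only (1, "a")

    // "all": drop a row only if every column is null/NaN (minNonNulls = 1).
    df.na.drop("all").show()   // keeps (1, "a") and (2, null)

    // The equivalent explicit threshold: keep rows with at least 2 non-null values.
    df.na.drop(2).show()       // keeps only (1, "a")

    spark.stop()
  }
}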
@@ -240,13 +240,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("fill with col(*)") {
+  test("fill/drop with col(*)") {
     val df = createDF()
     // If columns are specified with "*", they are ignored.
     checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
+    checkAnswer(df.na.drop("any", Seq("*")), df.collect())
   }
 
-  test("fill with nested columns") {
+  test("fill/drop with nested columns") {
     val schema = new StructType()
       .add("c1", new StructType()
         .add("c1-1", StringType)
@@ -263,8 +264,9 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.select("c1.c1-1"),
       Row(null) :: Row("b1") :: Row(null) :: Nil)
 
-    // Nested columns are ignored for fill().
+    // Nested columns are ignored for fill() and drop().
     checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
+    checkAnswer(df.na.drop("any", Seq("c1.c1-1")), data)
   }
 
   test("replace") {
@@ -394,4 +396,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
       df.na.fill("hello"),
       Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil)
   }
+
+  test("SPARK-30065: duplicate names are allowed for drop() if column names are not specified.") {
+    val left = Seq(("1", null), ("3", "4"), ("5", "6")).toDF("col1", "col2")
+    val right = Seq(("1", "2"), ("3", null), ("5", "6")).toDF("col1", "col2")
+    val df = left.join(right, Seq("col1"))
+
+    // If column names are specified, the following fails due to ambiguity.
+    val exception = intercept[AnalysisException] {
+      df.na.drop("any", Seq("col2"))
+    }
+    assert(exception.getMessage.contains("Reference 'col2' is ambiguous"))
+
+    // If column names are not specified, drop() is applied to all the eligible rows.
+    checkAnswer(
+      df.na.drop("any"),
+      Row("5", "6", "6") :: Nil)
+  }
 }
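Editor's note on the new test above: it exercises the behavior this commit enables. After a join that produces duplicate column names, drop() still fails when the ambiguous name is passed explicitly, but succeeds when no column names are given, because the columns are then taken directly from the analyzed plan's output attributes. The standalone Scala sketch below mirrors the test data; the SparkSession setup and object wrapper are assumptions, not part of this commit.

// Illustrative sketch only (not part of this commit). Mirrors the SPARK-30065 test
// as a standalone program; assumes a local Spark installation.
import scala.util.Try
import org.apache.spark.sql.SparkSession

object DropWithDuplicateColumnNamesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("na-drop-duplicates").getOrCreate()
    import spark.implicits._

    val left = Seq(("1", null), ("3", "4"), ("5", "6")).toDF("col1", "col2")
    val right = Seq(("1", "2"), ("3", null), ("5", "6")).toDF("col1", "col2")
    val joined = left.join(right, Seq("col1"))   // output columns: col1, col2, col2

    // Naming the duplicated column is still ambiguous and fails during resolution.
    println(Try(joined.na.drop("any", Seq("col2"))).isFailure)   // prints: true

    // With no column names, drop() works on the join's output attributes directly,
    // so the duplicate names cause no ambiguity; only the fully non-null row survives.
    joined.na.drop("any").show()   // prints the single row ("5", "6", "6")

    spark.stop()
  }
}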
