Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-27425][SQL] Add count_if function #24335

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ object FunctionRegistry {
expression[Average]("avg"),
expression[Corr]("corr"),
expression[Count]("count"),
expression[CountIf]("count_if"),
expression[CovPopulation]("covar_pop"),
expression[CovSample]("covar_samp"),
expression[First]("first"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate}
import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, LongType}

@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the number of `TRUE` values for the expression.
This function is equivalent to `count(CASE WHEN x THEN 1 END)`.
""",
examples = """
Examples:
> SELECT _FUNC_(col % 2 = 0) FROM VALUES (NULL), (0), (1), (2), (3) AS tab(col);
2
> SELECT _FUNC_(col IS NULL) FROM VALUES (NULL), (0), (1), (2), (3) AS tab(col);
1
""",
since = "3.0.0")
// Aggregate that counts the rows for which `predicate` is TRUE. It extends
// UnevaluableAggregate, so it is not evaluated directly: ReplaceExpressions rewrites
// CountIf(predicate) into Count(NullIf(predicate, false)).
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes {

  // SQL-facing function name; also interpolated into the type-check failure message below.
  override def prettyName: String = "count_if"

  // The predicate is the only child expression.
  override def children: Seq[Expression] = predicate :: Nil

  // The result is declared as a non-nullable LongType value.
  override def dataType: DataType = LongType

  override def nullable: Boolean = false

  // The single argument is expected to be boolean.
  override def inputTypes: Seq[AbstractDataType] = BooleanType :: Nil

  // Accepts only a boolean predicate; any other type is rejected with a message that names
  // the offending type.
  override def checkInputDataTypes(): TypeCheckResult = {
    val actualType = predicate.dataType
    if (actualType == BooleanType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"function $prettyName requires boolean type, not ${actualType.catalogString}")
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,19 @@ import org.apache.spark.sql.types._
* Finds all the expressions that are unevaluable and replace/rewrite them with semantically
* equivalent expressions that can be evaluated. Currently we replace two kinds of expressions:
* 1) [[RuntimeReplaceable]] expressions
* 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any, CountIf
* This is mainly used to provide compatibility with other databases.
* A few examples are:
* we use this to support "nvl" by replacing it with "coalesce".
* we use this to replace Every and Any with Min and Max respectively.
* we use this to replace CountIf(p) with Count(NullIf(p, false)).
*
* TODO: In the future, explore an option to replace aggregate functions similar to
* how RuntimeReplaceable does.
*/
object ReplaceExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e: RuntimeReplaceable => e.child
case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
case SomeAgg(arg) => Max(arg)
case AnyAgg(arg) => Max(arg)
case EveryAgg(arg) => Min(arg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -894,4 +894,28 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
error.message.contains("function min_by does not support ordering on type map<int,string>"))
}
}

// Verifies COUNT_IF with and without GROUP BY, and that a non-boolean argument is rejected.
test("count_if") {
  withTempView("tempView") {
    // Two groups ("a" and "b"), each holding one NULL `y` and three integers of mixed parity.
    Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)),
      ("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6)))
      .toDF("x", "y")
      .createOrReplaceTempView("tempView")

    // Whole-table aggregation: a NULL predicate counts nothing; 2, 4, 6 are even;
    // 1, 3, 5 are odd; two rows have a NULL `y`.
    checkAnswer(
      sql("SELECT COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
        "COUNT_IF(y IS NULL) FROM tempView"),
      Row(0L, 3L, 3L, 2L))

    // Per-group aggregation: "a" has one even and two odd values, "b" has two even and
    // one odd; each group has exactly one NULL `y`.
    checkAnswer(
      sql("SELECT x, COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
        "COUNT_IF(y IS NULL) FROM tempView GROUP BY x"),
      Row("a", 0L, 1L, 2L, 1L) :: Row("b", 0L, 2L, 1L, 1L) :: Nil)

    // `x` is a string column, so analysis must fail with the count_if type-check message.
    val error = intercept[AnalysisException] {
      sql("SELECT COUNT_IF(x) FROM tempView")
    }
    assert(error.message.contains("function count_if requires boolean type"))
  }
}
}