diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 986662c951144..0041ccf775573 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -281,6 +281,7 @@ object FunctionRegistry {
     expression[Average]("avg"),
     expression[Corr]("corr"),
     expression[Count]("count"),
+    expression[CountIf]("count_if"),
     expression[CovPopulation]("covar_pop"),
     expression[CovSample]("covar_samp"),
     expression[First]("first"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
new file mode 100644
index 0000000000000..d31355cd022fa
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, LongType}
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns the number of `TRUE` values for the expression.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col % 2 = 0) FROM VALUES (NULL), (0), (1), (2), (3) AS tab(col);
+       2
+      > SELECT _FUNC_(col IS NULL) FROM VALUES (NULL), (0), (1), (2), (3) AS tab(col);
+       1
+  """,
+  since = "3.0.0")
+case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes {
+  override def prettyName: String = "count_if"
+
+  override def children: Seq[Expression] = Seq(predicate)
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = LongType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
+
+  override def checkInputDataTypes(): TypeCheckResult = predicate.dataType match {
+    case BooleanType =>
+      TypeCheckResult.TypeCheckSuccess
+    case _ =>
+      TypeCheckResult.TypeCheckFailure(
+        s"function $prettyName requires boolean type, not ${predicate.dataType.catalogString}"
+      )
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index c213a21ebaa6a..69ba76827c781 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -34,18 +34,19 @@ import org.apache.spark.sql.types._
  * Finds all the expressions that are unevaluable and replace/rewrite them with semantically
  * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions:
  * 1) [[RuntimeReplaceable]] expressions
- * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any
+ * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any, CountIf
  * This is mainly used to provide compatibility with other databases.
  * Few examples are:
  *   we use this to support "nvl" by replacing it with "coalesce".
  *   we use this to replace Every and Any with Min and Max respectively.
  *
  * TODO: In future, explore an option to replace aggregate functions similar to
- * how RruntimeReplaceable does.
+ * how RuntimeReplaceable does.
  */
 object ReplaceExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case e: RuntimeReplaceable => e.child
+    case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
     case SomeAgg(arg) => Max(arg)
     case AnyAgg(arg) => Max(arg)
     case EveryAgg(arg) => Min(arg)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d89ecc22a7c00..e005a3e9a258e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -894,4 +894,44 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
         error.message.contains("function min_by does not support ordering on type map"))
     }
   }
+
+  test("count_if") {
+    withTempView("tempView") {
+      Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)),
+        ("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6)))
+        .toDF("x", "y")
+        .createOrReplaceTempView("tempView")
+
+      checkAnswer(
+        sql("SELECT COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
+          "COUNT_IF(y IS NULL) FROM tempView"),
+        Row(0L, 3L, 3L, 2L))
+
+      checkAnswer(
+        sql("SELECT x, COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
+          "COUNT_IF(y IS NULL) FROM tempView GROUP BY x"),
+        Row("a", 0L, 1L, 2L, 1L) :: Row("b", 0L, 2L, 1L, 1L) :: Nil)
+
+      checkAnswer(
+        sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 1"),
+        Row("a"))
+
+      checkAnswer(
+        sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 2"),
+        Row("b"))
+
+      checkAnswer(
+        sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y IS NULL) > 0"),
+        Row("a") :: Row("b") :: Nil)
+
+      checkAnswer(
+        sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"),
+        Nil)
+
+      val error = intercept[AnalysisException] {
+        sql("SELECT COUNT_IF(x) FROM tempView")
+      }
+      assert(error.message.contains("function count_if requires boolean type"))
+    }
+  }
 }
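
Not part of the patch itself: a small, self-contained sketch of how the new function can be exercised end to end, and of the rewrite that ReplaceExpressions performs, namely that COUNT_IF(p) becomes COUNT(NULLIF(p, FALSE)), so FALSE and NULL predicates are both ignored by COUNT. It assumes a Spark build that already includes this patch; the object name CountIfRewriteCheck, the local[*] SparkSession, and the app name are illustrative choices, not code from this change.

// Illustrative usage sketch (not part of this diff). Assumes a Spark build that
// includes the count_if patch; names below are made up for the example.
import org.apache.spark.sql.SparkSession

object CountIfRewriteCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("count_if-rewrite-check")
      .getOrCreate()
    import spark.implicits._

    // Same shape of data as the new test: one nullable Int column per group.
    Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)),
      ("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6)))
      .toDF("x", "y")
      .createOrReplaceTempView("tempView")

    // COUNT_IF(p) and COUNT(NULLIF(p, FALSE)) should agree: NULLIF maps FALSE to
    // NULL, COUNT skips NULLs, and a NULL predicate stays NULL, so only TRUE rows
    // are counted. That is exactly the rewrite ReplaceExpressions applies.
    spark.sql(
      """SELECT x,
        |       COUNT_IF(y % 2 = 0)             AS even_via_count_if,
        |       COUNT(NULLIF(y % 2 = 0, FALSE)) AS even_via_rewrite
        |FROM tempView
        |GROUP BY x
        |ORDER BY x""".stripMargin).show()
    // Expected for this data set: ("a", 1, 1) and ("b", 2, 2).

    spark.stop()
  }
}

Expressing CountIf as an UnevaluableAggregate plus an optimizer rewrite keeps the change small: no new aggregation buffer or codegen path is needed, because Count and NullIf already provide the required semantics.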