Skip to content

Commit

Permalink
[SPARK-29883][SQL] Implement a helper method for aliasing bool_and() …
Browse files Browse the repository at this point in the history
…and bool_or()

### What changes were proposed in this pull request?
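This PR introduces a helper method `expressionWithAlias` in object `FunctionRegistry`. It registers a function under a given name and passes that name to the expression's constructor, so the expression knows which alias it was invoked by. Currently, `expressionWithAlias` is used to register `BoolAnd` and `BoolOr`.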

### Why are the changes needed?
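The error message reports the wrong function name when an alias (`every`, `any`, `some`) is used for `BoolAnd` or `BoolOr`: it always shows `bool_and` / `bool_or` instead of the name the user wrote.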

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Tested manually.

For query,
`select every('true');`

Output before this PR,

> Error in query: cannot resolve 'bool_and('true')' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [string].; line 1 pos 7;

After this PR,

> Error in query: cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7;

Closes #26712 from amanomer/29883.

Authored-by: Aman Omer <amanomer1996@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
amanomer authored and cloud-fan committed Dec 9, 2019
1 parent a57bbf2 commit dcea7a4
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,11 @@ object FunctionRegistry {
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),
expression[BoolAnd]("every"),
expression[BoolAnd]("bool_and"),
expression[BoolOr]("any"),
expression[BoolOr]("some"),
expression[BoolOr]("bool_or"),
expressionWithAlias[BoolAnd]("every"),
expressionWithAlias[BoolAnd]("bool_and"),
expressionWithAlias[BoolOr]("any"),
expressionWithAlias[BoolOr]("some"),
expressionWithAlias[BoolOr]("bool_or"),

// string functions
expression[Ascii]("ascii"),
Expand Down Expand Up @@ -590,12 +590,12 @@ object FunctionRegistry {
val builder = (expressions: Seq[Expression]) => {
if (varargCtor.isDefined) {
// If there is an apply method that accepts Seq[Expression], use that one.
Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) =>
// the exception is an invocation exception. To get a meaningful message, we need the
// cause.
throw new AnalysisException(e.getCause.getMessage)
try {
varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
} catch {
// the exception is an invocation exception. To get a meaningful message, we need the
// cause.
case e: Exception => throw new AnalysisException(e.getCause.getMessage)
}
} else {
// Otherwise, find a constructor method that matches the number of arguments, and use that.
Expand All @@ -618,19 +618,55 @@ object FunctionRegistry {
}
throw new AnalysisException(invalidArgumentsMsg)
}
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) =>
// the exception is an invocation exception. To get a meaningful message, we need the
// cause.
throw new AnalysisException(e.getCause.getMessage)
try {
f.newInstance(expressions : _*).asInstanceOf[Expression]
} catch {
// the exception is an invocation exception. To get a meaningful message, we need the
// cause.
case e: Exception => throw new AnalysisException(e.getCause.getMessage)
}
}
}

(name, (expressionInfo[T](name), builder))
}

/**
 * Creates a function registry entry for an expression whose constructor takes the name it was
 * registered under as its first argument. This lets an alias (e.g. `every` for `bool_and`) be
 * reported under the name the user actually wrote.
 *
 * @param name the function name to register; it is also passed as the leading `String`
 *             constructor argument of `T`
 * @tparam T the expression class; exactly one of its public constructors must take a `String`
 *           as its first parameter
 * @return the `(name, (ExpressionInfo, FunctionBuilder))` entry for the registry
 * @throws AnalysisException (from the returned builder) if the argument count does not match
 *                           the constructor, or if the constructor itself fails
 */
private def expressionWithAlias[T <: Expression](name: String)
    (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
  // Only constructors whose first parameter is the function name are candidates. Using
  // `headOption` keeps a zero-argument constructor from failing the lookup with
  // NoSuchElementException.
  val constructors = tag.runtimeClass.getConstructors
    .filter(_.getParameterTypes.headOption.exists(_ == classOf[String]))
  assert(constructors.length == 1,
    s"${tag.runtimeClass.getName} must have exactly one constructor whose first parameter " +
      s"is a String, but found ${constructors.length}")
  val builder = (expressions: Seq[Expression]) => {
    // Expected signature: the function name followed by one Expression per SQL argument.
    val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression])
    val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
      // The leading String parameter is not a user-visible argument, hence the `- 1`.
      val validParametersCount = constructors
        .filter(_.getParameterTypes.tail.forall(_ == classOf[Expression]))
        .map(_.getParameterCount - 1).distinct.sorted
      val invalidArgumentsMsg = if (validParametersCount.length == 0) {
        s"Invalid arguments for function $name"
      } else {
        val expectedNumberOfParameters = if (validParametersCount.length == 1) {
          validParametersCount.head.toString
        } else {
          validParametersCount.init.mkString("one of ", ", ", " and ") +
            validParametersCount.last
        }
        s"Invalid number of arguments for function $name. " +
          s"Expected: $expectedNumberOfParameters; Found: ${expressions.size}"
      }
      throw new AnalysisException(invalidArgumentsMsg)
    }
    try {
      f.newInstance(name +: expressions: _*).asInstanceOf[Expression]
    } catch {
      // An exception thrown by the constructor arrives wrapped in an invocation exception, so
      // the meaningful message is on its cause. Other reflective failures carry no cause; fall
      // back to the exception itself instead of hitting a NullPointerException.
      case e: Exception =>
        val reason = Option(e.getCause).getOrElse(e)
        throw new AnalysisException(reason.getMessage)
    }
  }
  (name, (expressionInfo[T](name), builder))
}

/**
* Creates a function registry lookup entry for cast aliases (SPARK-16730).
* For example, if name is "int", and dataType is IntegerType, this means int(x) would become
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
false
""",
since = "3.0.0")
case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
override def nodeName: String = "bool_and"
case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) {
override def nodeName: String = funcName
}

@ExpressionDescription(
Expand All @@ -68,6 +68,6 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
false
""",
since = "3.0.0")
case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
override def nodeName: String = "bool_or"
case class BoolOr(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) {
override def nodeName: String = funcName
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e: RuntimeReplaceable => e.child
case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
case BoolOr(arg) => Max(arg)
case BoolAnd(arg) => Min(arg)
case BoolOr(_, arg) => Max(arg)
case BoolAnd(_, arg) => Min(arg)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(Sum('stringField))
assertSuccess(Average('stringField))
assertSuccess(Min('arrayField))
assertSuccess(new BoolAnd('booleanField))
assertSuccess(new BoolOr('booleanField))
assertSuccess(new BoolAnd("bool_and", 'booleanField))
assertSuccess(new BoolOr("bool_or", 'booleanField))

assertError(Min('mapField), "min does not support ordering on type")
assertError(Max('mapField), "max does not support ordering on type")
Expand Down
26 changes: 13 additions & 13 deletions sql/core/src/test/resources/sql-tests/results/group-by.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -293,31 +293,31 @@ struct<>
-- !query 31
SELECT every(v), some(v), any(v), bool_and(v), bool_or(v) FROM test_agg WHERE 1 = 0
-- !query 31 schema
struct<bool_and(v):boolean,bool_or(v):boolean,bool_or(v):boolean,bool_and(v):boolean,bool_or(v):boolean>
struct<every(v):boolean,some(v):boolean,any(v):boolean,bool_and(v):boolean,bool_or(v):boolean>
-- !query 31 output
NULL NULL NULL NULL NULL


-- !query 32
SELECT every(v), some(v), any(v), bool_and(v), bool_or(v) FROM test_agg WHERE k = 4
-- !query 32 schema
struct<bool_and(v):boolean,bool_or(v):boolean,bool_or(v):boolean,bool_and(v):boolean,bool_or(v):boolean>
struct<every(v):boolean,some(v):boolean,any(v):boolean,bool_and(v):boolean,bool_or(v):boolean>
-- !query 32 output
NULL NULL NULL NULL NULL


-- !query 33
SELECT every(v), some(v), any(v), bool_and(v), bool_or(v) FROM test_agg WHERE k = 5
-- !query 33 schema
struct<bool_and(v):boolean,bool_or(v):boolean,bool_or(v):boolean,bool_and(v):boolean,bool_or(v):boolean>
struct<every(v):boolean,some(v):boolean,any(v):boolean,bool_and(v):boolean,bool_or(v):boolean>
-- !query 33 output
false true true false true


-- !query 34
SELECT k, every(v), some(v), any(v), bool_and(v), bool_or(v) FROM test_agg GROUP BY k
-- !query 34 schema
struct<k:int,bool_and(v):boolean,bool_or(v):boolean,bool_or(v):boolean,bool_and(v):boolean,bool_or(v):boolean>
struct<k:int,every(v):boolean,some(v):boolean,any(v):boolean,bool_and(v):boolean,bool_or(v):boolean>
-- !query 34 output
1 false true true false true
2 true true true true true
Expand All @@ -329,7 +329,7 @@ struct<k:int,bool_and(v):boolean,bool_or(v):boolean,bool_or(v):boolean,bool_and(
-- !query 35
SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false
-- !query 35 schema
struct<k:int,bool_and(v):boolean>
struct<k:int,every(v):boolean>
-- !query 35 output
1 false
3 false
Expand All @@ -339,7 +339,7 @@ struct<k:int,bool_and(v):boolean>
-- !query 36
SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL
-- !query 36 schema
struct<k:int,bool_and(v):boolean>
struct<k:int,every(v):boolean>
-- !query 36 output
4 NULL

Expand Down Expand Up @@ -380,7 +380,7 @@ SELECT every(1)
struct<>
-- !query 39 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bool_and(1)' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [int].; line 1 pos 7
cannot resolve 'every(1)' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7


-- !query 40
Expand All @@ -389,7 +389,7 @@ SELECT some(1S)
struct<>
-- !query 40 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bool_or(1S)' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [smallint].; line 1 pos 7
cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7


-- !query 41
Expand All @@ -398,7 +398,7 @@ SELECT any(1L)
struct<>
-- !query 41 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bool_or(1L)' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [bigint].; line 1 pos 7
cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7


-- !query 42
Expand All @@ -407,7 +407,7 @@ SELECT every("true")
struct<>
-- !query 42 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bool_and('true')' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [string].; line 1 pos 7
cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7


-- !query 43
Expand All @@ -431,7 +431,7 @@ cannot resolve 'bool_or(1.0D)' due to data type mismatch: Input to function 'boo
-- !query 45
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
-- !query 45 schema
struct<k:int,v:boolean,bool_and(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
struct<k:int,v:boolean,every(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
-- !query 45 output
1 false false
1 true false
Expand All @@ -448,7 +448,7 @@ struct<k:int,v:boolean,bool_and(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIR
-- !query 46
SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
-- !query 46 schema
struct<k:int,v:boolean,bool_or(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
struct<k:int,v:boolean,some(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
-- !query 46 output
1 false false
1 true true
Expand All @@ -465,7 +465,7 @@ struct<k:int,v:boolean,bool_or(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRS
-- !query 47
SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
-- !query 47 schema
struct<k:int,v:boolean,bool_or(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
struct<k:int,v:boolean,any(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
-- !query 47 output
1 false false
1 true true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,31 +293,31 @@ struct<>
-- !query 31
SELECT udf(every(v)), udf(some(v)), any(v) FROM test_agg WHERE 1 = 0
-- !query 31 schema
struct<CAST(udf(cast(bool_and(v) as string)) AS BOOLEAN):boolean,CAST(udf(cast(bool_or(v) as string)) AS BOOLEAN):boolean,bool_or(v):boolean>
struct<CAST(udf(cast(every(every, v) as string)) AS BOOLEAN):boolean,CAST(udf(cast(some(some, v) as string)) AS BOOLEAN):boolean,any(v):boolean>
-- !query 31 output
NULL NULL NULL


-- !query 32
SELECT udf(every(udf(v))), some(v), any(v) FROM test_agg WHERE k = 4
-- !query 32 schema
struct<CAST(udf(cast(bool_and(cast(udf(cast(v as string)) as boolean)) as string)) AS BOOLEAN):boolean,bool_or(v):boolean,bool_or(v):boolean>
struct<CAST(udf(cast(every(every, cast(udf(cast(v as string)) as boolean)) as string)) AS BOOLEAN):boolean,some(v):boolean,any(v):boolean>
-- !query 32 output
NULL NULL NULL


-- !query 33
SELECT every(v), udf(some(v)), any(v) FROM test_agg WHERE k = 5
-- !query 33 schema
struct<bool_and(v):boolean,CAST(udf(cast(bool_or(v) as string)) AS BOOLEAN):boolean,bool_or(v):boolean>
struct<every(v):boolean,CAST(udf(cast(some(some, v) as string)) AS BOOLEAN):boolean,any(v):boolean>
-- !query 33 output
false true true


-- !query 34
SELECT udf(k), every(v), udf(some(v)), any(v) FROM test_agg GROUP BY udf(k)
-- !query 34 schema
struct<CAST(udf(cast(k as string)) AS INT):int,bool_and(v):boolean,CAST(udf(cast(bool_or(v) as string)) AS BOOLEAN):boolean,bool_or(v):boolean>
struct<CAST(udf(cast(k as string)) AS INT):int,every(v):boolean,CAST(udf(cast(some(some, v) as string)) AS BOOLEAN):boolean,any(v):boolean>
-- !query 34 output
1 false true true
2 true true true
Expand All @@ -329,7 +329,7 @@ struct<CAST(udf(cast(k as string)) AS INT):int,bool_and(v):boolean,CAST(udf(cast
-- !query 35
SELECT udf(k), every(v) FROM test_agg GROUP BY k HAVING every(v) = false
-- !query 35 schema
struct<CAST(udf(cast(k as string)) AS INT):int,bool_and(v):boolean>
struct<CAST(udf(cast(k as string)) AS INT):int,every(v):boolean>
-- !query 35 output
1 false
3 false
Expand All @@ -339,7 +339,7 @@ struct<CAST(udf(cast(k as string)) AS INT):int,bool_and(v):boolean>
-- !query 36
SELECT udf(k), udf(every(v)) FROM test_agg GROUP BY udf(k) HAVING every(v) IS NULL
-- !query 36 schema
struct<CAST(udf(cast(k as string)) AS INT):int,CAST(udf(cast(bool_and(v) as string)) AS BOOLEAN):boolean>
struct<CAST(udf(cast(k as string)) AS INT):int,CAST(udf(cast(every(every, v) as string)) AS BOOLEAN):boolean>
-- !query 36 output
4 NULL

Expand Down Expand Up @@ -380,7 +380,7 @@ SELECT every(udf(1))
struct<>
-- !query 39 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bool_and(CAST(udf(cast(1 as string)) AS INT))' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [int].; line 1 pos 7
cannot resolve 'every(CAST(udf(cast(1 as string)) AS INT))' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7


-- !query 40
Expand All @@ -389,7 +389,7 @@ SELECT some(udf(1S))
struct<>
-- !query 40 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bool_or(CAST(udf(cast(1 as string)) AS SMALLINT))' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [smallint].; line 1 pos 7
cannot resolve 'some(CAST(udf(cast(1 as string)) AS SMALLINT))' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7


-- !query 41
Expand All @@ -398,7 +398,7 @@ SELECT any(udf(1L))
struct<>
-- !query 41 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bool_or(CAST(udf(cast(1 as string)) AS BIGINT))' due to data type mismatch: Input to function 'bool_or' should have been boolean, but it's [bigint].; line 1 pos 7
cannot resolve 'any(CAST(udf(cast(1 as string)) AS BIGINT))' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7


-- !query 42
Expand All @@ -407,13 +407,13 @@ SELECT udf(every("true"))
struct<>
-- !query 42 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bool_and('true')' due to data type mismatch: Input to function 'bool_and' should have been boolean, but it's [string].; line 1 pos 11
cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 11


-- !query 43
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
-- !query 43 schema
struct<k:int,v:boolean,bool_and(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
struct<k:int,v:boolean,every(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
-- !query 43 output
1 false false
1 true false
Expand All @@ -430,7 +430,7 @@ struct<k:int,v:boolean,bool_and(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIR
-- !query 44
SELECT k, udf(udf(v)), some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
-- !query 44 schema
struct<k:int,CAST(udf(cast(cast(udf(cast(v as string)) as boolean) as string)) AS BOOLEAN):boolean,bool_or(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
struct<k:int,CAST(udf(cast(cast(udf(cast(v as string)) as boolean) as string)) AS BOOLEAN):boolean,some(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
-- !query 44 output
1 false false
1 true true
Expand All @@ -447,7 +447,7 @@ struct<k:int,CAST(udf(cast(cast(udf(cast(v as string)) as boolean) as string)) A
-- !query 45
SELECT udf(udf(k)), v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
-- !query 45 schema
struct<CAST(udf(cast(cast(udf(cast(k as string)) as int) as string)) AS INT):int,v:boolean,bool_or(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
struct<CAST(udf(cast(cast(udf(cast(k as string)) as int) as string)) AS INT):int,v:boolean,any(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
-- !query 45 output
1 false false
1 true true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ class ExplainSuite extends QueryTest with SharedSparkSession {
// plan should show the rewritten aggregate expression.
val df = sql("SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k")
checkKeywordsExistsInExplain(df,
"Aggregate [k#x], [k#x, min(v#x) AS bool_and(v)#x, max(v#x) AS bool_or(v)#x, " +
"max(v#x) AS bool_or(v)#x]")
"Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, " +
"max(v#x) AS any(v)#x]")
}
}

Expand Down

0 comments on commit dcea7a4

Please sign in to comment.