[SPARK-29883][SQL] Implement a helper method for aliasing bool_and() and bool_or() #26712

Closed
wants to merge 12 commits into from
@@ -600,7 +600,8 @@ object FunctionRegistry {
} else {
// Otherwise, find a constructor method that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
-val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
+val f = constructors.find(e => e.getParameterTypes.toSeq == params
+|| e.getParameterTypes.head == classOf[String]).getOrElse {
Contributor

@cloud-fan commented on Dec 3, 2019

It seems less hacky to create a new expressionWithAlias method that contains only the necessary logic:

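def expressionWithAlias ... = {
  // Keep only the constructors whose first parameter is the alias name.
  val constructors = tag.runtimeClass.getConstructors
    .filter(c => c.getParameterTypes.head == classOf[String])
  assert(constructors.length == 1)
  try {
    // Pass the alias first, followed by the argument expressions.
    constructors.head.newInstance(name +: expressions: _*).asInstanceOf[Expression]
  } ...
}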

Contributor

Then we don't even need the MultiNamedExpression trait. We just need to register bool_and and bool_or with expressionWithAlias.
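
To make the suggestion concrete, here is a minimal, self-contained sketch of the idea, not the code that was merged. `AliasSketch`, `Column`, and the simplified `Expression` trait are hypothetical stand-ins so the snippet compiles without Spark; only the name `expressionWithAlias` and the `BoolAnd(funcName, arg)` shape come from this thread.

```scala
import scala.reflect.ClassTag

object AliasSketch {
  // Hypothetical stand-ins for Spark's Expression hierarchy, not the real classes.
  trait Expression { def nodeName: String }
  case class Column(name: String) extends Expression {
    def nodeName: String = name
  }
  // The alias is the first constructor parameter, mirroring BoolAnd(funcName, arg).
  case class BoolAnd(funcName: String, arg: Expression) extends Expression {
    def nodeName: String = funcName
  }

  // Returns a builder that passes `name` as the first constructor argument of T.
  def expressionWithAlias[T <: Expression](name: String)(
      implicit tag: ClassTag[T]): Seq[Expression] => Expression = {
    // Keep only the constructors whose first parameter is a String (the alias).
    val constructors = tag.runtimeClass.getConstructors
      .filter(_.getParameterTypes.headOption.contains(classOf[String]))
    assert(constructors.length == 1, "expected exactly one aliasing constructor")
    val constructor = constructors.head
    (expressions: Seq[Expression]) => {
      // Prepend the alias to the argument expressions before the reflective call.
      val args: Seq[AnyRef] = name +: expressions
      constructor.newInstance(args: _*).asInstanceOf[Expression]
    }
  }
}
```

With this shape, `AliasSketch.expressionWithAlias[AliasSketch.BoolAnd]("bool_and")` yields a builder whose result reports `bool_and` as its `nodeName`, and registering the same class under another name only changes the string passed in, so no marker trait is needed.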

Contributor Author

@cloud-fan updated as per your suggestions.

val validParametersCount = constructors
.filter(_.getParameterTypes.forall(_ == classOf[Expression]))
.map(_.getParameterCount).distinct.sorted
@@ -618,7 +619,13 @@ object FunctionRegistry {
}
throw new AnalysisException(invalidArgumentsMsg)
}
-Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
+Try{
+if (classOf[MultiNamedExpression].isAssignableFrom(f.getDeclaringClass)) {
+f.newInstance(name.toString, expressions.head).asInstanceOf[Expression]
+} else {
+f.newInstance(expressions : _*).asInstanceOf[Expression]
+}
+} match {
case Success(e) => e
case Failure(e) =>
// the exception is an invocation exception. To get a meaningful message, we need the
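
For context on the `Failure` branch: a constructor invoked through Java reflection reports its own failures wrapped in `java.lang.reflect.InvocationTargetException`, so the useful message sits on the cause. The sketch below illustrates that general behaviour only. `InvocationSketch` and `Picky` are hypothetical names, and the snippet does not reproduce the elided remainder of this hunk.

```scala
import scala.util.{Failure, Success, Try}

object InvocationSketch {
  // Hypothetical class whose constructor rejects an empty argument.
  class Picky(value: String) {
    require(value.nonEmpty, "value must not be empty")
  }

  // Instantiates Picky reflectively and returns the instance or a readable message.
  def build(value: String): Either[String, Picky] = {
    val constructor = classOf[Picky].getConstructor(classOf[String])
    Try(constructor.newInstance(value)) match {
      case Success(p) => Right(p)
      // Reflection wraps constructor failures, so read the message from the cause.
      case Failure(e) if e.getCause != null => Left(e.getCause.getMessage)
      case Failure(e) => Left(e.getMessage)
    }
  }
}
```

Calling `InvocationSketch.build("")` returns a `Left` carrying the message from `require` rather than the wrapper exception's message.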

@@ -40,6 +40,9 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
}
}

+trait MultiNamedExpression {
+}
+
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.",
examples = """
@@ -52,8 +55,9 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
false
""",
since = "3.0.0")
-case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
-override def nodeName: String = "bool_and"
+case class BoolAnd(funcName: String, arg: Expression)
+extends UnevaluableBooleanAggBase(arg) with MultiNamedExpression {
+override def nodeName: String = funcName
}

@ExpressionDescription(
@@ -68,6 +72,7 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
false
""",
since = "3.0.0")
-case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
-override def nodeName: String = "bool_or"
+case class BoolOr(funcName: String, arg: Expression)
+extends UnevaluableBooleanAggBase(arg) with MultiNamedExpression {
+override def nodeName: String = funcName
}

@@ -47,8 +47,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e: RuntimeReplaceable => e.child
case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
-case BoolOr(arg) => Max(arg)
-case BoolAnd(arg) => Min(arg)
+case BoolOr(_, arg) => Max(arg)
+case BoolAnd(_, arg) => Min(arg)
}
}
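
The rewrite above ignores the alias because the name only affects how the node is displayed. Since `false` orders before `true`, the minimum of a boolean column is `true` only when every value is `true`, and the maximum is `true` when any value is. A minimal sketch of the same pattern match on a toy expression tree follows. `ReplaceSketch`, `Expr`, and `Attr` are hypothetical stand-ins, not Catalyst classes.

```scala
object ReplaceSketch {
  // Hypothetical, simplified expression tree (not Spark's Catalyst classes).
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Min(child: Expr) extends Expr
  case class Max(child: Expr) extends Expr
  case class BoolAnd(funcName: String, arg: Expr) extends Expr
  case class BoolOr(funcName: String, arg: Expr) extends Expr

  // Rewrites the boolean aggregates to Min/Max, discarding the alias with `_`.
  def replace(e: Expr): Expr = e match {
    case BoolOr(_, arg)  => Max(replace(arg))
    case BoolAnd(_, arg) => Min(replace(arg))
    case Min(child)      => Min(replace(child))
    case Max(child)      => Max(replace(child))
    case attr: Attr      => attr
  }
}
```

Two nodes that differ only in `funcName` rewrite to the same tree, which is why the extra constructor parameter does not change the optimizer's output.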

@@ -144,8 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(Sum('stringField))
assertSuccess(Average('stringField))
assertSuccess(Min('arrayField))
-assertSuccess(new BoolAnd('booleanField))
-assertSuccess(new BoolOr('booleanField))
+assertSuccess(new BoolAnd("bool_and", 'booleanField))
+assertSuccess(new BoolOr("bool_or", 'booleanField))

assertError(Min('mapField), "min does not support ordering on type")
assertError(Max('mapField), "max does not support ordering on type")