From 8616924ebb90eaf5e88a6a843dc99744b3dbc2b8 Mon Sep 17 00:00:00 2001
From: "Santiago M. Mola"
Date: Thu, 28 May 2015 19:42:59 +0200
Subject: [PATCH] [SPARK-7886] Add built-in expressions to FunctionRegistry.

- ExpressionBuilders provides helpers that create a function builder for
  each Expression.
- Built-in functions are removed from SqlParser where possible and
  registered in FunctionRegistry instead.

TO DO:

- Decide between the reflection and macro implementations of the
  expression builder helpers.
- Fix Substring, whose constructor is not well suited for the helper.
- Apply the same changes to Hive.
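
Example (illustrative): the helper derives a registry name from the class
name and pairs it with a builder:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.ExpressionBuilders._

    // expression[Upper] yields ("upper", builder), where builder is a
    // Seq[Expression] => Expression that invokes Upper's constructor.
    val (name, builder) = expression[Upper]
    builder(Seq(Literal("a")))  // == Upper(Literal("a"))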
---
 .../apache/spark/sql/catalyst/SqlParser.scala | 20 +----
 .../expressions/DefaultExpressions.scala      | 23 ++++++
 .../expressions/ExpressionBuilders.scala      | 74 +++++++++++++++++++
 .../expressions/ExpressionMacros.scala        | 62 ++++++++++++++++
 .../org/apache/spark/sql/SQLContext.scala     |  6 +-
 5 files changed, 165 insertions(+), 20 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DefaultExpressions.scala
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionBuilders.scala
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionMacros.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index fc36b9f1f20d2..ed24aa4293f6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -277,25 +277,14 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
     )
 
   protected lazy val function: Parser[Expression] =
-    ( SUM   ~> "(" ~> expression             <~ ")" ^^ { case exp => Sum(exp) }
-    | SUM   ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) }
+    ( SUM   ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) }
     | COUNT ~  "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) }
-    | COUNT ~  "(" ~> expression <~ ")" ^^ { case exp => Count(exp) }
     | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^
       { case exps => CountDistinct(exps) }
     | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^
       { case exp => ApproxCountDistinct(exp) }
     | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
       { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) }
-    | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) }
-    | LAST  ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) }
-    | AVG   ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) }
-    | MIN   ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) }
-    | MAX   ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) }
-    | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) }
-    | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
-    | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
-      { case c ~ t ~ f => If(c, t, f) }
     | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
         (ELSE ~> expression).? <~ END ^^ {
           case casePart ~ altPart ~ elsePart =>
@@ -304,13 +293,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
             } ++ elsePart
             casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
         }
-    | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
-      { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
-    | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
-      { case s ~ p ~ l => Substring(s, p, l) }
-    | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) }
-    | SQRT  ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) }
-    | ABS   ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) }
     | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
       { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
     )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DefaultExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DefaultExpressions.scala
new file mode 100644
index 0000000000000..86be41988160e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DefaultExpressions.scala
@@ -0,0 +1,23 @@
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.ExpressionBuilders._
+
+object DefaultExpressions {
+
+  val expressions: Map[String, ExpressionBuilder] = Map(
+    expression[Sum],
+    expression[Count],
+    expression[First],
+    expression[Last],
+    expression[Average],
+    expression[Min],
+    expression[Max],
+    expression[Upper],
+    expression[Lower],
+    expression[If],
+    expression[Substring],
+    expression[Substring]("substr"),
+    expression[Coalesce],
+    expression[Sqrt],
+    expression[Abs]
+  )
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionBuilders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionBuilders.scala
new file mode 100644
index 0000000000000..13463a52f26c6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionBuilders.scala
@@ -0,0 +1,74 @@
+package org.apache.spark.sql.catalyst.expressions
+
+import java.util.Locale
+
+import scala.language.experimental.macros
+import scala.reflect.ClassTag
+import scala.util.{Failure, Success, Try}
+
+object ExpressionBuilders {
+
+  type ExpressionBuilder = Seq[Expression] => Expression
+
+  /** Converts a class name such as CountDistinct into a registry name, count_distinct. */
+  private def camelToUnderscores(name: String) =
+    "[A-Z\\d]".r.replaceAllIn(name.head.toLower + name.tail,
+      m => "_" + m.group(0).toLowerCase(Locale.ENGLISH))
+
+  def expression[T <: Expression](name: String)
+      (implicit ev: ClassTag[T]): (String, ExpressionBuilder) =
+    name -> expressionByReflection[T]
+
+  /* TODO: Substring needs a change so that it also offers a 2-arg constructor. */
+  def expression[T <: Expression](implicit ev: ClassTag[T]): (String, ExpressionBuilder) = {
+    val name = camelToUnderscores(ev.runtimeClass.getSimpleName)
+    /* Alternative, with macros: name -> ExpressionMacros.expressionBuilder[T] */
+    name -> expressionByReflection[T]
+  }
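+
+  /**
+   * Builds a T through runtime reflection: a constructor whose parameters
+   * are all of type Expression yields a fixed-arity builder, while a
+   * constructor taking a single Seq[Expression] (e.g. Coalesce) yields a
+   * variable-arity builder.
+   */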
+  private def expressionByReflection[T <: Expression](
+      implicit tag: ClassTag[T]): ExpressionBuilder = {
+    val constructors = tag.runtimeClass.getDeclaredConstructors.toSeq
+    (expressions: Seq[Expression]) => {
+      val arity = expressions.size
+      val validBuilders = constructors.flatMap { c =>
+        val parameterTypes = c.getParameterTypes
+        if (parameterTypes.length == arity &&
+            parameterTypes.forall(_ == classOf[Expression])) {
+          Some(expressionFixedArity[T](arity))
+        } else if (parameterTypes.length == 1 &&
+            parameterTypes(0) == classOf[Seq[Expression]]) {
+          Some(expressionVariableArity[T])
+        } else {
+          None
+        }
+      }
+      val builder = validBuilders.headOption.getOrElse(sys.error(
+        s"Invalid number of arguments for ${tag.runtimeClass.getSimpleName}: $arity"))
+      builder(expressions)
+    }
+  }
+
+  private def expressionVariableArity[T <: Expression](
+      implicit tag: ClassTag[T]): ExpressionBuilder = {
+    val clazz = tag.runtimeClass
+    val constructor = Try(clazz.getDeclaredConstructor(classOf[Seq[Expression]])) match {
+      case Failure(ex: NoSuchMethodException) =>
+        sys.error(s"Did not find a constructor taking Seq[Expression] for ${clazz.getCanonicalName}")
+      case Failure(ex) => throw ex
+      case Success(c) => c
+    }
+    (expressions: Seq[Expression]) => {
+      constructor.newInstance(expressions).asInstanceOf[Expression]
+    }
+  }
+
+  private def expressionFixedArity[T <: Expression](arity: Int)
+      (implicit tag: ClassTag[T]): ExpressionBuilder = {
+    val argTypes = Seq.fill(arity)(classOf[Expression])
+    val constructor = tag.runtimeClass.getDeclaredConstructor(argTypes: _*)
+    (expressions: Seq[Expression]) => {
+      if (expressions.size != arity) {
+        throw new IllegalArgumentException(
+          s"Invalid number of arguments: ${expressions.size} (must be equal to $arity)")
+      }
+      constructor.newInstance(expressions: _*).asInstanceOf[Expression]
+    }
+  }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionMacros.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionMacros.scala
new file mode 100644
index 0000000000000..8c63e71735afd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionMacros.scala
@@ -0,0 +1,62 @@
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.language.experimental.macros
+import scala.reflect.macros.Context
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.ExpressionBuilders.ExpressionBuilder
+
+private[catalyst] object ExpressionMacros {
+
+  def expressionBuilder[T <: Expression]: ExpressionBuilder =
+    macro ExpressionMacrosImpl.expressionImpl[T]
+
+}
+
+object ExpressionMacrosImpl {
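+
+  /**
+   * For a fixed-arity constructor such as If(predicate, trueValue,
+   * falseValue), the generated builder is roughly:
+   *
+   *   (expr: Seq[Expression]) => {
+   *     if (expr.size != 3) { sys.error("Expression takes exactly 3 arguments") }
+   *     new If(expr(0), expr(1), expr(2))
+   *   }
+   */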
+  @DeveloperApi
+  def expressionImpl[T <: Expression](c: Context): c.Expr[ExpressionBuilder] = {
+    import c.universe._
+    val tag = implicitly[c.WeakTypeTag[T]]
+    tag.tpe.declarations
+      .filter(_.isMethod)
+      .map(_.asMethod)
+      .filter(_.isPrimaryConstructor)
+      .flatMap { methodSymbol =>
+        methodSymbol.paramss.flatten match {
+          case Nil =>
+            Some(Block(
+              q"""if (expr.nonEmpty) { sys.error("Expression takes no arguments") }""",
+              Apply(Ident(methodSymbol), Nil)))
+          case param :: Nil if param.typeSignature <:< typeOf[Seq[Expression]] =>
+            Some(Apply(Ident(methodSymbol), Ident(newTermName("expr")) :: Nil))
+          case params if params.forall(_.typeSignature =:= typeOf[Expression]) =>
+            val args = params.indices.map { i =>
+              Apply(Select(Ident(newTermName("expr")), newTermName("apply")),
+                Literal(Constant(i)) :: Nil)
+            }
+            val errorMsg =
+              Literal(Constant(s"Expression takes exactly ${args.size} arguments"))
+            Some(Block(
+              q"""if (expr.size != ${args.size}) { sys.error($errorMsg) }""",
+              Apply(Ident(methodSymbol), args.toList)))
+          case _ =>
+            None
+        }
+      }
+      .headOption match {
+        case None =>
+          sys.error("Expression generator requires a constructor accepting " +
+            "Expression... or Seq[Expression]")
+        case Some(tree) =>
+          c.Expr(q"(expr: Seq[Expression]) => { $tree }")
+      }
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1ea596dddff02..c5f93b3ccc482 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -122,7 +122,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
   // TODO how to handle the temp function per user session?
   @transient
-  protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf)
+  protected[sql] lazy val functionRegistry: FunctionRegistry = {
+    val fr = new SimpleFunctionRegistry(conf)
+    DefaultExpressions.expressions.foreach { case (name, func) => fr.registerFunction(name, func) }
+    fr
+  }
 
   @transient
   protected[sql] lazy val analyzer: Analyzer =
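
Usage sketch (illustrative, not part of the patch; assumes an existing
SQLContext named sqlContext and the FunctionRegistry.lookupFunction API):

    // Built-ins registered from DefaultExpressions now resolve through the
    // same registry the analyzer consults for user-defined functions.
    val expr = sqlContext.functionRegistry.lookupFunction("upper", Seq(Literal("a")))
    // expr == Upper(Literal("a"))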