From 5117a297c089de36723ac836b85502201c927dba Mon Sep 17 00:00:00 2001 From: shaoxuan-wang Date: Fri, 26 May 2017 10:52:31 +0800 Subject: [PATCH] [FLINK-6725][table] make requiresOver as a contracted method in udagg --- .../table/expressions/aggregations.scala | 3 +- .../table/functions/AggregateFunction.scala | 33 ++++++++++++------- .../utils/UserDefinedFunctionUtils.scala | 18 +++++++++- .../java/utils/UserDefinedAggFunctions.java | 2 -- 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala index 6d906b9ea2e7f..38c9c0dc3f3ca 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala @@ -258,11 +258,12 @@ case class AggFunctionCall( override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = { val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] + val requiresOver = getRequiresOverConfig(aggregateFunction) val sqlAgg = AggSqlFunction(aggregateFunction.getClass.getSimpleName, aggregateFunction, resultType, typeFactory, - aggregateFunction.requiresOver) + requiresOver) sqlAgg } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala index f90860b513d35..fcb6bf95c0a1c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala @@ -29,12 +29,17 @@ package org.apache.flink.table.functions * There are a few other methods that can be optional to have: * - retract, * - merge, - * - resetAccumulator, and - * - getAccumulatorType. + * - resetAccumulator, + * - getAccumulatorType, + * - getResultType, and + * - requiresOver. * * All these methods muse be declared publicly, not static and named exactly as the names * mentioned above. The methods createAccumulator and getValue are defined in the - * [[AggregateFunction]] functions, while other methods are explained below. + * [[AggregateFunction]] functions, while other methods are explained below. It should be + * also noted that the optional methods merge, resetAccumulator, getAccumulatorType, + * getResultType, and requiresOver cannot be overloaded. If provided, their inputs and outputs must + * be defined exactly same as the below. * * * {{{ @@ -72,7 +77,7 @@ package org.apache.flink.table.functions * custom merge method. * @param its an [[java.lang.Iterable]] pointed to a group of accumulators that will be * merged. - + * * def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit * }}} * @@ -82,7 +87,7 @@ package org.apache.flink.table.functions * dataset grouping aggregate. * * @param accumulator the accumulator which needs to be reset - + * * def resetAccumulator(accumulator: ACC): Unit * }}} * @@ -93,7 +98,7 @@ package org.apache.flink.table.functions * inferred from the instance returned by createAccumulator method. * * @return the type information for the accumulator. - + * * def getAccumulatorType: TypeInformation[_] * }}} * @@ -110,6 +115,17 @@ package org.apache.flink.table.functions * }}} * * + * {{{ + * Returns a boolean flag indicates if this aggregate can only be used in OVER clause. If this + * method is not provided, by default, Flink assumes the User Defined Aggregate can be used in + * grouping aggregate as well. + * + * @return the value indicates if this aggregate can only be used in OVER clause. + * + * def requiresOver: Boolean + * }}} + * + * * @tparam T the type of the aggregation result * @tparam ACC base class for aggregate Accumulator. The accumulator is used to keep the aggregated * values which are needed to compute an aggregation result. AggregateFunction @@ -135,9 +151,4 @@ abstract class AggregateFunction[T, ACC] extends UserDefinedFunction { * @return the aggregation result */ def getValue(accumulator: ACC): T - - /** - * whether this aggregate only used in OVER clause - */ - def requiresOver: Boolean = false } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index 1016574cd87da..c27f45b6cc1a1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -291,7 +291,8 @@ object UserDefinedFunctionUtils { //check if a qualified accumulate method exists before create Sql function checkAndExtractMethods(aggFunction, "accumulate") val resultType: TypeInformation[_] = getResultTypeOfAggregateFunction(aggFunction, typeInfo) - AggSqlFunction(name, aggFunction, resultType, typeFactory, aggFunction.requiresOver) + val requiresOver = getRequiresOverConfig(aggFunction) + AggSqlFunction(name, aggFunction, resultType, typeFactory, requiresOver) } // ---------------------------------------------------------------------------------------------- @@ -328,6 +329,21 @@ object UserDefinedFunctionUtils { } } + /** + * Returns the value (boolean) of AggregateFunction#requiresOver() that indicates if this + * AggregateFunction can only be used for over window aggregate. If method requiresOver is not + * provided, return false as default value. + */ + def getRequiresOverConfig(aggregateFunction: AggregateFunction[_, _]): Boolean = { + try { + val method: Method = aggregateFunction.getClass.getMethod("requiresOver") + method.invoke(aggregateFunction).asInstanceOf[Boolean] + } catch { + case _: NoSuchMethodException => false + case ite: Throwable => throw new TableException("Unexpected exception:", ite) + } + } + /** * Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses * [[TypeExtractor]] as default return type inference. diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java index a51a4af81993c..f273e6a5ea6ef 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedAggFunctions.java @@ -38,11 +38,9 @@ public Long getValue(Accumulator0 accumulator) { return 1L; } - //Overloaded accumulate method public void accumulate(Accumulator0 accumulator, long iValue, int iWeight) { } - @Override public boolean requiresOver() { return true; }