Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -256,26 +256,31 @@ class CodeGenerator(
* @return A GeneratedAggregationsFunction
*/
def generateAggregations(
name: String,
generator: CodeGenerator,
inputType: RelDataType,
aggregates: Array[AggregateFunction[_ <: Any]],
aggFields: Array[Array[Int]],
aggMapping: Array[Int],
fwdMapping: Array[(Int, Int)],
outputArity: Int)
name: String,
generator: CodeGenerator,
inputType: RelDataType,
aggregates: Array[AggregateFunction[_ <: Any]],
aggFields: Array[Array[Int]],
aggMapping: Array[Int],
outputArity: Int,
groupingKeys: Array[Int],
ctrlParams: AggCodeGenCtrlParams,
fwdMapping: Array[(Int, Int)] = Array(),
gkeyOutFields: Array[Int] = null,
gkeyOutMapping: Array[(Int, Int)] = null)
: GeneratedAggregationsFunction = {

def genSetAggregationResults(
accTypes: Array[String],
aggs: Array[String],
aggMapping: Array[Int]): String = {

val offset = if (ctrlParams.setResultsWithKeyOffset) groupingKeys.length else 0
val sig: String =
j"""
| public void setAggregationResults(
| org.apache.flink.types.Row accs,
| org.apache.flink.types.Row output)""".stripMargin
| public void setAggregationResults(
| org.apache.flink.types.Row accs,
| org.apache.flink.types.Row output)""".stripMargin

val setAggs: String = {
for (i <- aggs.indices) yield
Expand All @@ -284,11 +289,12 @@ class CodeGenerator(
| (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
|
| output.setField(
| ${aggMapping(i)},
| ${aggMapping(i) + offset},
| baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin
}.mkString("\n")

j"""$sig {
j"""
|$sig {
|$setAggs
| }""".stripMargin
}
Expand All @@ -298,6 +304,7 @@ class CodeGenerator(
aggs: Array[String],
parameters: Array[String]): String = {

val offset = if (ctrlParams.accumulateWithKeyOffset) groupingKeys.length else 0
val sig: String =
j"""
| public void accumulate(
Expand All @@ -308,7 +315,7 @@ class CodeGenerator(
for (i <- aggs.indices) yield
j"""
| ${aggs(i)}.accumulate(
| ((${accTypes(i)}) accs.getField($i)),
| ((${accTypes(i)}) accs.getField(${i + offset})),
| ${parameters(i)});""".stripMargin
}.mkString("\n")

Expand Down Expand Up @@ -344,20 +351,22 @@ class CodeGenerator(
def genCreateAccumulators(
aggs: Array[String]): String = {

val offset = if (ctrlParams.accumulateWithKeyOffset) groupingKeys.length else 0
val arity = if (ctrlParams.accumulateWithKeyOffset) outputArity else aggs.length
val sig: String =
j"""
| public org.apache.flink.types.Row createAccumulators()
| """.stripMargin
val init: String =
j"""
| org.apache.flink.types.Row accs =
| new org.apache.flink.types.Row(${aggs.length});"""
| new org.apache.flink.types.Row(${arity});"""
.stripMargin
val create: String = {
for (i <- aggs.indices) yield
j"""
| accs.setField(
| $i,
| ${offset + i},
| ${aggs(i)}.createAccumulator());"""
.stripMargin
}.mkString("\n")
Expand All @@ -380,8 +389,10 @@ class CodeGenerator(
j"""
| public void setForwardedFields(
| org.apache.flink.types.Row input,
| org.apache.flink.types.Row accs,
| org.apache.flink.types.Row output)
| """.stripMargin

val forward: String = {
for (i <- forwardMapping.indices) yield
j"""
Expand All @@ -391,9 +402,64 @@ class CodeGenerator(
.stripMargin
}.mkString("\n")

j"""$sig {
|$forward
| }""".stripMargin
var copyKeys: String = ""
if (gkeyOutMapping != null) {
copyKeys = {
for ((out, in) <- gkeyOutMapping) yield
j"""
| output.setField(
| $out,
| input.getField(${in}));"""
.stripMargin
}.mkString("\n")
} else if (gkeyOutFields != null) {
copyKeys = {
for (i <- gkeyOutFields.indices) yield
j"""
| output.setField(
| ${gkeyOutFields(i)},
| input.getField($i));"""
.stripMargin
}.mkString("\n")
} else {
copyKeys = {
for (i <- groupingKeys.indices) yield
j"""
| output.setField(
| $i,
| input.getField(${groupingKeys(i)}));"""
.stripMargin
}.mkString("\n")
}


val copyAccs: String = {
for (i <- aggregates.indices) yield
j"""
| output.setField(
| ${groupingKeys.length + i},
| accs.getField($i));"""
.stripMargin
}.mkString("\n")

if (forwardMapping.length > 0) {
// when forwardMapping is not empty, this method only forwards fields from the input row
// to the output row
j"""$sig {
|$forward
| }""".stripMargin
} else {
// when forwardMapping is empty, this method copies the grouping keys (from the input row,
// if provided) and the accumulators (from the accs row, if provided) to the output row
j"""$sig {
| if (input != null) {
| $copyKeys
| }
| if (accs != null) {
| $copyAccs
| }
| }""".stripMargin
}
}

def genCreateOutputRow(outputArity: Int): String = {
Expand All @@ -407,32 +473,29 @@ class CodeGenerator(
accTypes: Array[String],
aggs: Array[String]): String = {

val offset = if (ctrlParams.mergeWithKeyOffset) groupingKeys.length else 0
val sig: String =
j"""
| public org.apache.flink.types.Row mergeAccumulatorsPair(
| public void mergeAccumulatorsPair(
| org.apache.flink.types.Row a,
| org.apache.flink.types.Row b)
""".stripMargin
val merge: String = {
for (i <- aggs.indices) yield
j"""
| ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
| ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField($i);
| ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${i + offset});
| accList$i.set(0, aAcc$i);
| accList$i.set(1, bAcc$i);
| a.setField(
| $i,
| ${aggs(i)}.merge(accList$i));
""".stripMargin
}.mkString("\n")
val ret: String =
j"""
| return a;
""".stripMargin

j"""$sig {
j"""
|$sig {
|$merge
|$ret
| }""".stripMargin
}

Expand All @@ -458,6 +521,28 @@ class CodeGenerator(
}.mkString("\n")
}

/**
  * Generates the Java source of the `resetAccumulator(Row accs)` method, which calls
  * `resetAccumulator` on every aggregate function with its accumulator taken from `accs`.
  *
  * @param accTypes Java type names the accumulator fields are cast to, one per aggregate
  * @param aggs     Java expressions referencing the aggregate functions, one per aggregate
  * @return the generated method as a source string
  */
def genResetAccumulator(
    accTypes: Array[String],
    aggs: Array[String]): String = {

  // accumulator fields are shifted by the number of grouping keys when the
  // accumulate-with-key-offset flag is set
  val keyShift = if (ctrlParams.accumulateWithKeyOffset) groupingKeys.length else 0

  val header: String =
    j"""
       | public void resetAccumulator(
       | org.apache.flink.types.Row accs)""".stripMargin

  // one reset call per aggregate, joined by line breaks
  val resetCalls: String = aggs.indices.map { idx =>
    j"""
       | ${aggs(idx)}.resetAccumulator(
       | ((${accTypes(idx)}) accs.getField(${keyShift + idx})));""".stripMargin
  }.mkString("\n")

  j"""$header {
     |$resetCalls
     | }""".stripMargin
}

// get unique function name
val funcName = newName(name)
// register UDAGGs
Expand Down Expand Up @@ -499,6 +584,7 @@ class CodeGenerator(
funcCode += genSetForwardedFields(fwdMapping) + "\n"
funcCode += genCreateOutputRow(outputArity) + "\n"
funcCode += genMergeAccumulatorsPair(accTypes, aggs) + "\n"
funcCode += genResetAccumulator(accTypes, aggs) + "\n"
funcCode += "}"

GeneratedAggregationsFunction(funcName, funcCode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ case class GeneratedAggregationsFunction(
name: String,
code: String)

/**
  * Control flags for the code generation of aggregation functions. Each flag decides
  * whether field indices in the generated code are shifted by the number of grouping keys.
  *
  * @param setResultsWithKeyOffset if true, the generated `setAggregationResults` writes each
  *                                aggregate result to the output row at its mapped position
  *                                plus the number of grouping keys
  * @param mergeWithKeyOffset      if true, the generated `mergeAccumulatorsPair` reads the
  *                                accumulators of its second row at the aggregate index plus
  *                                the number of grouping keys
  * @param accumulateWithKeyOffset if true, the generated `accumulate`, `createAccumulators`
  *                                and `resetAccumulator` address accumulators at the
  *                                aggregate index plus the number of grouping keys, and the
  *                                accumulator row is created with the output arity
  */
case class AggCodeGenCtrlParams(
setResultsWithKeyOffset: Boolean,
mergeWithKeyOffset: Boolean,
accumulateWithKeyOffset: Boolean
)

/**
* Describes a generated [[InputFormat]].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.BatchTableEnvironment
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.plan.nodes.CommonAggregate
import org.apache.flink.table.runtime.aggregate.{AggregateUtil, DataSetPreAggFunction}
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
Expand Down Expand Up @@ -89,19 +90,25 @@ class DataSetAggregate(

override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = {

val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)

val generator = new CodeGenerator(
tableEnv.getConfig,
false,
inputDS.getType)

val (
preAgg: Option[DataSetPreAggFunction],
preAggType: Option[TypeInformation[Row]],
finalAgg: GroupReduceFunction[Row, Row]
) = AggregateUtil.createDataSetAggregateFunctions(
generator,
namedAggregates,
inputType,
rowRelDataType,
grouping,
inGroupingSet)

val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)

val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)

val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
Expand Down
Loading