Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-6075] - Support Limit/Top(Sort) for Stream SQL #3889

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
rexBuilder,
input.getRowType.getFieldList.map(_.getType))

val offset = if(sort.offset != null) sort.offset.accept(materializer) else null
val fetch = if(sort.fetch != null) sort.fetch.accept(materializer) else null
//val offset = if(sort.offset != null) sort.offset.accept(materializer) else null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be removed

//val fetch = if(sort.fetch != null) sort.fetch.accept(materializer) else null

LogicalSort.create(input, sort.collation, offset, fetch)
LogicalSort.create(input, sort.collation, sort.offset, sort.fetch)
}

override def visit(`match`: LogicalMatch): RelNode =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,32 @@
package org.apache.flink.table.plan.nodes

import org.apache.calcite.rex.{RexLiteral, RexNode}
import org.apache.calcite.rel.RelFieldCollation
import org.apache.calcite.rel.RelCollation
import org.apache.calcite.rel.RelFieldCollation.Direction
import org.apache.calcite.rel.`type`._
import scala.collection.JavaConverters._
import org.apache.flink.api.common.operators.Order

import org.apache.calcite.rel.{RelWriter, RelCollation, RelFieldCollation}

/**
* Trait represents a collection of sort methods to manipulate the parameters
*/

trait CommonSort {

private[flink] def offsetToString(offset: RexNode): String = {
private def offsetToString(offset: RexNode): String = {
val offsetToString = s"$offset"
offsetToString
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the extra blank line.


private[flink] def sortFieldsToString(
private def sortFieldsToString(
collationSort: RelCollation,
rowRelDataType: RelDataType): String = {
val fieldCollations = collationSort.getFieldCollations.asScala
.map(c => (c.getFieldIndex, directionToOrder(c.getDirection)))
.map(c => (c.getFieldIndex, directionToOrder(c.getDirection)))

val sortFieldsToString = fieldCollations
.map(col => s"${
rowRelDataType.getFieldNames.get(col._1)} ${col._2.getShortName}" ).mkString(", ")

sortFieldsToString
fieldCollations
.map(col => s"${rowRelDataType.getFieldNames.get(col._1)} ${col._2.getShortName}" )
.mkString(", ")
}

private[flink] def directionToOrder(direction: Direction) = {
Expand All @@ -60,32 +55,61 @@ trait CommonSort {
}
}

private[flink] def fetchToString(fetch: RexNode, offset: RexNode): String = {
private def fetchToString(fetch: RexNode, offset: RexNode): String = {
val limitEnd = getFetchLimitEnd(fetch, offset)
val fetchToString = if (limitEnd == Long.MaxValue) {

if (limitEnd == Long.MaxValue) {
"unlimited"
} else {
s"$limitEnd"
}
fetchToString
}

private[flink] def getFetchLimitEnd (fetch: RexNode, offset: RexNode): Long = {
val limitEnd: Long = if (fetch != null) {
if (fetch != null) {
RexLiteral.intValue(fetch) + getFetchLimitStart(offset)
} else {
Long.MaxValue
}
limitEnd
}

private[flink] def getFetchLimitStart (offset: RexNode): Long = {
val limitStart: Long = if (offset != null) {
if (offset != null) {
RexLiteral.intValue(offset)
} else {
0L
}
limitStart
} else {
0L
}
}

/**
  * Builds the one-line textual description of a sort node in the form
  * `Sort(by: (<fields>), offset: <offset>, fetch: <fetch>)`.
  *
  * The offset and fetch parts are only included when the corresponding
  * expression is non-null.
  *
  * @param rowRelDataType row type used to resolve the names of the sort fields
  * @param sortCollation  collation describing the sort fields and their directions
  * @param sortOffset     offset expression, may be null
  * @param sortFetch      fetch expression, may be null
  * @return the textual description of the sort
  */
private[flink] def sortToString(
    rowRelDataType: RelDataType,
    sortCollation: RelCollation,
    sortOffset: RexNode,
    sortFetch: RexNode): String = {

  // Use ${...} so the helper methods are actually invoked. "$$name(...)" inside an
  // s-string, or "$name(...)" inside a plain string, prints the call text verbatim.
  val offsetPart =
    if (sortOffset != null) s", offset: ${offsetToString(sortOffset)}" else ""
  val fetchPart =
    if (sortFetch != null) s", fetch: ${fetchToString(sortFetch, sortOffset)}" else ""

  // The separators belong to the optional parts, so the closing parenthesis is
  // always emitted and no dangling comma is left when offset/fetch are absent.
  s"Sort(by: (${sortFieldsToString(sortCollation, rowRelDataType)})$offsetPart$fetchPart)"
}

/**
  * Appends the sort attributes to the given plan writer.
  *
  * The "orderBy" item is always added. The "offset" and "fetch" items are added
  * through `itemIf`, conditioned on the respective expression being non-null.
  *
  * @param pw             writer to which the terms are appended
  * @param rowRelDataType row type used to resolve the names of the sort fields
  * @param sortCollation  collation describing the sort fields and their directions
  * @param sortOffset     offset expression, may be null
  * @param sortFetch      fetch expression, may be null
  * @return the writer returned by the last appended term
  */
private[flink] def sortExplainTerms(
    pw: RelWriter,
    rowRelDataType: RelDataType,
    sortCollation: RelCollation,
    sortOffset: RexNode,
    sortFetch: RexNode) : RelWriter = {

  // Each step is evaluated in the same order as a single fluent chain would be.
  val withOrderBy = pw.item("orderBy", sortFieldsToString(sortCollation, rowRelDataType))
  val withOffset =
    withOrderBy.itemIf("offset", offsetToString(sortOffset), sortOffset != null)

  withOffset.itemIf("fetch", fetchToString(sortFetch, sortOffset), sortFetch != null)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,17 @@ class DataSetSort(

private val fieldCollations = collations.getFieldCollations.asScala
.map(c => (c.getFieldIndex, directionToOrder(c.getDirection)))


override def toString: String = {
s"Sort(by: ($$sortFieldsToString(collations, getRowType))," +
" offset: $offsetToString(offset)," +
" fetch: $fetchToString(fetch, offset))"
sortToString(getRowType, collations, offset, fetch)
}


override def explainTerms(pw: RelWriter) : RelWriter = {
super.explainTerms(pw)
.item("orderBy", sortFieldsToString(collations, getRowType))
.item("offset", offsetToString(offset))
.item("fetch", fetchToString(fetch, offset))
sortExplainTerms(
super.explainTerms(pw),
getRowType,
collations,
offset,
fetch)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException}
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.plan.nodes.CommonSort
import org.apache.calcite.rel.core.Sort

/**
* Flink RelNode which matches along with Sort Rule.
Expand All @@ -72,37 +73,42 @@ class DataStreamSort(
sortOffset: RexNode,
sortFetch: RexNode,
description: String)
extends SingleRel(cluster, traitSet, inputNode)
extends Sort(cluster, traitSet, inputNode, sortCollation, sortOffset, sortFetch)
with CommonSort
with DataStreamRel {

override def deriveRowType(): RelDataType = schema.logicalType

override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
override def copy(
traitSet: RelTraitSet,
input: RelNode,
newCollation: RelCollation,
offset: RexNode,
fetch: RexNode): Sort = {

new DataStreamSort(
cluster,
traitSet,
inputs.get(0),
input,
inputSchema,
schema,
sortCollation,
sortOffset,
sortFetch,
newCollation,
offset,
fetch,
description)
}

override def toString: String = {
s"Sort(by: ($$sortFieldsToString(sortCollation, schema.logicalType))," +
" offset: $offsetToString(sortOffset)," +
" fetch: $fetchToString(sortFetch, sortOffset))"
sortToString(schema.logicalType, sortCollation, sortOffset, sortFetch)
}

override def explainTerms(pw: RelWriter) : RelWriter = {

super.explainTerms(pw)
.item("orderBy", sortFieldsToString(sortCollation, schema.logicalType))
.item("offset", offsetToString(sortOffset))
.item("fetch", fetchToString(sortFetch, sortOffset))
sortExplainTerms(
super.explainTerms(pw),
schema.logicalType,
sortCollation,
sortOffset,
sortFetch)
}

override def translateToPlan(
Expand Down Expand Up @@ -173,7 +179,7 @@ class DataStreamSort(
.asInstanceOf[DataStream[CRow]]
} else {
//if the order is done only on proctime we only need to forward the elements
inputDS.keyBy(new NullByteKeySelector[CRow])
inputDS
.map(new IdentityCRowMap())
.setParallelism(1).setMaxParallelism(1)
.returns(returnTypeInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,16 @@ class DataStreamSortRule
"DataStreamSortRule") {

override def matches(call: RelOptRuleCall): Boolean = {

val result = super.matches(call)

//need to identify time between others order fields. Time needs to be first sort element
// we can safely convert the object if the match rule succeeded
if(result) {
val calcSort: FlinkLogicalSort = call.rel(0).asInstanceOf[FlinkLogicalSort]
checkTimeOrder(calcSort)
}

result
val calcSort: FlinkLogicalSort = call.rel(0).asInstanceOf[FlinkLogicalSort]
checkTimeOrder(calcSort)
}

override def convert(rel: RelNode): RelNode = {
val calcSort: FlinkLogicalSort = rel.asInstanceOf[FlinkLogicalSort]
val sort: FlinkLogicalSort = rel.asInstanceOf[FlinkLogicalSort]
val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM)
val convInput: RelNode = RelOptRule.convert(calcSort.getInput(0), FlinkConventions.DATASTREAM)
val convInput: RelNode = RelOptRule.convert(sort.getInput(0), FlinkConventions.DATASTREAM)

val inputRowType = convInput.asInstanceOf[RelSubset].getOriginal.getRowType

Expand All @@ -72,38 +65,33 @@ class DataStreamSortRule
convInput,
new RowSchema(inputRowType),
new RowSchema(rel.getRowType),
calcSort.collation,
calcSort.offset,
calcSort.fetch,
sort.collation,
sort.offset,
sort.fetch,
description)

}


/**
* Function is used to check at verification time if the SQL syntax is supported
*/

def checkTimeOrder(calcSort: FlinkLogicalSort) = {
def checkTimeOrder(sort: FlinkLogicalSort): Boolean = {

val rowType = calcSort.getRowType
val sortCollation = calcSort.collation
//need to identify time between others order fields. Time needs to be first sort element
val rowType = sort.getRowType
val sortCollation = sort.collation
//need to identify time between others order fields. Time needs to be first sort element
val timeType = SortUtil.getTimeType(sortCollation, rowType)
//time ordering needs to be ascending
if (SortUtil.getTimeDirection(sortCollation) != Direction.ASCENDING) {
throw new TableException("SQL/Table supports only ascending time ordering")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not exit the optimizer with exceptions because that prevents the optimizer from finding an alternative plan.
Instead, return false in matches().

}
//enable to extend for other types of aggregates that will not be implemented in a window
timeType match {
case _ if FlinkTypeFactory.isProctimeIndicatorType(timeType) =>
case _ if FlinkTypeFactory.isRowtimeIndicatorType(timeType) =>
case _ =>
throw new TableException("SQL/Table needs to have sort on time as first sort element")

case _ if FlinkTypeFactory.isTimeIndicatorType(timeType) => true
case _ => false //enable optimizer to look for a different plan
}
}

}

object DataStreamSortRule {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,18 @@ class ProcTimeSortProcessFunction(

Preconditions.checkNotNull(rowComparator)

private var stateEventsBuffer: ListState[Row] = _
private var bufferedEvents: ListState[Row] = _
private val sortArray: ArrayList[Row] = new ArrayList[Row]

private var outputC: CRow = _

override def open(config: Configuration) {
val sortDescriptor = new ListStateDescriptor[Row]("sortState",
inputRowType.asInstanceOf[CRowTypeInfo].rowType)
stateEventsBuffer = getRuntimeContext.getListState(sortDescriptor)
bufferedEvents = getRuntimeContext.getListState(sortDescriptor)

if (outputC == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can simply do outputC = new CRow() here.

val arity:Integer = inputRowType.getArity
outputC = new CRow(Row.of(arity), true)
outputC = new CRow()
}

}
Expand All @@ -80,7 +79,7 @@ class ProcTimeSortProcessFunction(
val currentTime = ctx.timerService.currentProcessingTime
//buffer the incoming event

stateEventsBuffer.add(input)
bufferedEvents.add(input)

//deduplication of multiple registered timers is done automatically
ctx.timerService.registerProcessingTimeTimer(currentTime + 1)
Expand All @@ -92,28 +91,25 @@ class ProcTimeSortProcessFunction(
ctx: ProcessFunction[CRow, CRow]#OnTimerContext,
out: Collector[CRow]): Unit = {

val iter = stateEventsBuffer.get.iterator()
val iter = bufferedEvents.get.iterator()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove second newline

sortArray.clear()
while(iter.hasNext()) {
sortArray.add(iter.next())
}

//if we do not rely on java collections to do the sort we could implement
//an insertion sort as we get the elements from the state
Collections.sort(sortArray, rowComparator)

//we need to build the output and emit the events in order
var iElemenets = 0
while (iElemenets < sortArray.size) {
// do we need to recreate the object so as not to mess up references in previous results?
outputC.row = sortArray.get(iElemenets)
out.collect(outputC)
iElemenets += 1
}

//we need to clear the events accumulated in the last millisecond
stateEventsBuffer.clear()
bufferedEvents.clear()

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ class RowTimeSortProcessFunction(
lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)

if (outputC == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be done as outputC = new CRow()

val arity:Integer = inputRowType.getArity
outputC = new CRow(Row.of(arity), true)
outputC = new CRow()
}
}

Expand Down Expand Up @@ -126,7 +125,7 @@ class RowTimeSortProcessFunction(

if (null != inputs) {

Collections.sort(inputs,rowComparator)
Collections.sort(inputs, rowComparator)

//we need to build the output and emit the events in order
var dataListIndex = 0
Expand Down