diff --git a/docs/content.zh/docs/dev/table/functions/ptfs.md b/docs/content.zh/docs/dev/table/functions/ptfs.md index 104274b50cefb..f3c94960533e8 100644 --- a/docs/content.zh/docs/dev/table/functions/ptfs.md +++ b/docs/content.zh/docs/dev/table/functions/ptfs.md @@ -1015,6 +1015,107 @@ not needed anymore via `Context#clearAllTimers()` or `TimeContext#clearTimer(Str {{< top >}} +Multiple Tables +--------------- + +A PTF can process multiple tables simultaneously. This enables a variety of use cases, including: + +- Implementing **custom joins** that efficiently manage state. +- Enriching the main table with information from dimension tables as **side inputs**. +- Sending **control events** to the keyed virtual processor during runtime. + +The `eval()` method can specify multiple table arguments to support multiple inputs. All table arguments must be declared +with set semantics and use consistent partitioning. In other words, the number of columns and their data types in the +`PARTITION BY` clause must match across all involved table arguments. + +Rows from either input are passed to the function one at a time. Thus, only one table argument is non-null at a time. Use +null checks to determine which input is currently being processed. + +{{< hint warning >}} +The system decides which input row is streamed through the virtual processor next. If not handled properly in the PTF, +this can lead to race conditions between inputs and, consequently, to non-deterministic results. It is recommended to +design the function in such a way that the join is either time-based (i.e., waiting for all rows to arrive up to a given +watermark) or condition-based, where the PTF buffers one or more input rows until a specific condition is met. 
+{{< /hint >}} + +### Example: Custom Join + +The following example illustrates how to implement a custom join between two tables: + +{{< tabs "2137eeed-3d13-455c-8e2f-5e164da9f844" >}} +{{< tab "Java" >}} +```java +TableEnvironment env = TableEnvironment.create(EnvironmentSettings.inStreamingMode()); + +env.executeSql("CREATE VIEW Visits(name) AS VALUES ('Bob'), ('Alice'), ('Bob')"); +env.executeSql("CREATE VIEW Purchases(customer, item) AS VALUES ('Alice', 'milk')"); + +env.createFunction("Greeting", GreetingWithLastPurchase.class); + +env + .executeSql("SELECT * FROM Greeting(TABLE Visits PARTITION BY name, TABLE Purchases PARTITION BY customer)") + .print(); + +// -------------------- +// Function declaration +// -------------------- + +// Function that greets a customer and suggests the last purchase made, if available. +public static class GreetingWithLastPurchase extends ProcessTableFunction { + + // Keep the last purchased item in state + public static class LastItemState { + public String lastItem; + } + + // The eval() method takes two @ArgumentHint(TABLE_AS_SET) arguments + public void eval( + @StateHint LastItemState state, + @ArgumentHint(TABLE_AS_SET) Row visit, + @ArgumentHint(TABLE_AS_SET) Row purchase) { + + // Process row from table Purchases + if (purchase != null) { + state.lastItem = purchase.getFieldAs("item"); + } + + // Process row from table Visits + else if (visit != null) { + if (state.lastItem == null) { + collect("Hello " + visit.getFieldAs("name") + ", let me know if I can help!"); + } else { + collect("Hello " + visit.getFieldAs("name") + ", here to buy " + state.lastItem + " again?"); + } + } + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +The result will look similar to: + +```text ++----+--------------------------------+--------------------------------+--------------------------------+ +| op | name | customer | EXPR$0 | ++----+--------------------------------+--------------------------------+--------------------------------+ +| +I | 
Bob | Bob | Hello Bob, let me know if I... | +| +I | Alice | Alice | Hello Alice, here to buy Pr... | +| +I | Bob | Bob | Hello Bob, let me know if I... | ++----+--------------------------------+--------------------------------+--------------------------------+ +``` + +### Efficiency and Design Principles + +A high number of input tables can negatively impact a single TaskManager or subtask. Network buffers must be allocated +for each input, resulting in increased memory consumption which is why the number of table arguments is limited to a +maximum of 20 tables. + +Unevenly distributed keys may overload a single virtual processor, leading to backpressure. It is important to select +appropriate partition keys. + +{{< top >}} + Query Evolution with UIDs ------------------------- @@ -1061,7 +1162,8 @@ END; {{< top >}} -## Pass-Through Columns +Pass-Through Columns +-------------------- Depending on the table semantics and whether an `on_time` argument has been defined, the system adds addition columns for every function output. @@ -1089,7 +1191,7 @@ With pass-through columns: | k | v | c1 | c2 | This allows the PTF to focus on the main aggregation without the need to manually forward input columns. -*Note*: Timers are not available when pass-through columns are enabled. +*Note*: Pass-through columns are only available for append-only PTFs taking a single table argument and don't use timers. {{< top >}} @@ -1610,9 +1712,6 @@ while the PTF exists once in the pipeline. The following example shows how a PTF can be used for joining. Additionally, it also showcases how a PTF can be used as a data generator for creating bounded tables with dummy data. -Because PTFs don't support multiple table arguments yet, we use `unionAll` to for passing multiple partitioned tables -into the PTF. Because a union requires a unified schema, the data generators transform the data into a `UnifiedEvent`. 
- {{< tabs "1637eeed-3d13-455c-8e2f-5e164da9f844" >}} {{< tab "Java" >}} ```java @@ -1625,11 +1724,11 @@ TableEnvironment env = TableEnvironment.create(EnvironmentSettings.inStreamingMo Table orders = env.fromCall(OrderGenerator.class); Table payments = env.fromCall(PaymentGenerator.class); -// Union orders and payments before -// partitioning and passing them into the Joiner function -Table joined = orders.unionAll(payments) - .partitionBy($("orderId")) - .process(Joiner.class); +// Partition orders and payments and pass them into the Joiner function +Table joined = env.fromCall( + Joiner.class, + orders.partitionBy($("id")).asArgument("order"), + payments.partitionBy($("orderId")).asArgument("payment")); joined.execute().print(); @@ -1637,24 +1736,8 @@ joined.execute().print(); // Data Generation // --------------------------- -// A unified event for all input tables. -// One of the sides is al empty. -public static class UnifiedEvent { - public int orderId; - public Order order; - public Payment payment; - - public static UnifiedEvent of(int orderId, Order order, Payment payment) { - UnifiedEvent unifiedEvent = new UnifiedEvent(); - unifiedEvent.orderId = orderId; - unifiedEvent.order = order; - unifiedEvent.payment = payment; - return unifiedEvent; - } -} - // A PTF that generates Orders -public static class OrderGenerator extends ProcessTableFunction { +public static class OrderGenerator extends ProcessTableFunction { public void eval() { Stream.of( Order.of("Bob", 1000001, 23.46, "USD"), @@ -1662,13 +1745,12 @@ public static class OrderGenerator extends ProcessTableFunction { Order.of("Alice", 1000601, 0.79, "EUR"), Order.of("Charly", 1000703, 100.60, "EUR") ) - .map(order -> UnifiedEvent.of(order.id, order, null)) .forEach(this::collect); } } // A PTF that generates Payments -public static class PaymentGenerator extends ProcessTableFunction { +public static class PaymentGenerator extends ProcessTableFunction { public void eval() { Stream.of( 
Payment.of(999997870, 1000001), @@ -1676,7 +1758,6 @@ public static class PaymentGenerator extends ProcessTableFunction Payment.of(999993331, 1000021), Payment.of(999994111, 1000601) ) - .map(payment -> UnifiedEvent.of(payment.orderId, null, payment)) .forEach(this::collect); } } @@ -1714,8 +1795,8 @@ public static class Payment { {{< /tab >}} {{< /tabs >}} -After generating the data and performing the union, the stateful Joiner buffers events until a matching pair is -found. Any duplicates in either of the input tables are ignored. +After generating the data, the stateful Joiner buffers events until a matching pair is found. Any duplicates in either +of the input tables are ignored. {{< tabs "1737eeed-3d13-455c-8e2f-5e164da9f844" >}} {{< tab "Java" >}} @@ -1727,7 +1808,8 @@ public static class Joiner extends ProcessTableFunction { public void eval( Context ctx, @StateHint(ttl = "1 hour") JoinResult seen, - @ArgumentHint(TABLE_AS_SET) UnifiedEvent input + @ArgumentHint(TABLE_AS_SET) Order order, + @ArgumentHint(TABLE_AS_SET) Payment payment ) { if (input.order != null) { if (seen.order != null) { @@ -1767,19 +1849,18 @@ The output could look similar to the following. Duplicate events for payment `99 for `Charly` could not be found. ```text -+----+-------------+--------------------------------+--------------------------------+ -| op | orderId | order | payment | -+----+-------------+--------------------------------+--------------------------------+ -| +I | 1000021 | (amount=6.99, currency=USD,... | (id=999993331, orderId=1000... | -| +I | 1000601 | (amount=0.79, currency=EUR,... | (id=999994111, orderId=1000... | -| +I | 1000001 | (amount=23.46, currency=USD... | (id=999997870, orderId=1000... 
| -+----+-------------+--------------------------------+--------------------------------+ ++----+-------------+-------------+--------------------------------+--------------------------------+ +| op | id | orderId | order | payment | ++----+-------------+-------------+--------------------------------+--------------------------------+ +| +I | 1000021 | 1000021 | (amount=6.99, currency=USD,... | (id=999993331, orderId=1000... | +| +I | 1000601 | 1000601 | (amount=0.79, currency=EUR,... | (id=999994111, orderId=1000... | +| +I | 1000001 | 1000001 | (amount=23.46, currency=USD... | (id=999997870, orderId=1000... | ++----+-------------+-------------+--------------------------------+--------------------------------+ ``` Limitations ----------- PTFs are in an early stage. The following limitations apply: -- Multiple table arguments are not supported. - PTFs cannot run in batch mode. - Broadcast state diff --git a/docs/content/docs/dev/table/functions/ptfs.md b/docs/content/docs/dev/table/functions/ptfs.md index 104274b50cefb..f3c94960533e8 100644 --- a/docs/content/docs/dev/table/functions/ptfs.md +++ b/docs/content/docs/dev/table/functions/ptfs.md @@ -1015,6 +1015,107 @@ not needed anymore via `Context#clearAllTimers()` or `TimeContext#clearTimer(Str {{< top >}} +Multiple Tables +--------------- + +A PTF can process multiple tables simultaneously. This enables a variety of use cases, including: + +- Implementing **custom joins** that efficiently manage state. +- Enriching the main table with information from dimension tables as **side inputs**. +- Sending **control events** to the keyed virtual processor during runtime. + +The `eval()` method can specify multiple table arguments to support multiple inputs. All table arguments must be declared +with set semantics and use consistent partitioning. In other words, the number of columns and their data types in the +`PARTITION BY` clause must match across all involved table arguments. 
+ +Rows from either input are passed to the function one at a time. Thus, only one table argument is non-null at a time. Use +null checks to determine which input is currently being processed. + +{{< hint warning >}} +The system decides which input row is streamed through the virtual processor next. If not handled properly in the PTF, +this can lead to race conditions between inputs and, consequently, to non-deterministic results. It is recommended to +design the function in such a way that the join is either time-based (i.e., waiting for all rows to arrive up to a given +watermark) or condition-based, where the PTF buffers one or more input rows until a specific condition is met. +{{< /hint >}} + +### Example: Custom Join + +The following example illustrates how to implement a custom join between two tables: + +{{< tabs "2137eeed-3d13-455c-8e2f-5e164da9f844" >}} +{{< tab "Java" >}} +```java +TableEnvironment env = TableEnvironment.create(EnvironmentSettings.inStreamingMode()); + +env.executeSql("CREATE VIEW Visits(name) AS VALUES ('Bob'), ('Alice'), ('Bob')"); +env.executeSql("CREATE VIEW Purchases(customer, item) AS VALUES ('Alice', 'milk')"); + +env.createFunction("Greeting", GreetingWithLastPurchase.class); + +env + .executeSql("SELECT * FROM Greeting(TABLE Visits PARTITION BY name, TABLE Purchases PARTITION BY customer)") + .print(); + +// -------------------- +// Function declaration +// -------------------- + +// Function that greets a customer and suggests the last purchase made, if available. 
+public static class GreetingWithLastPurchase extends ProcessTableFunction { + + // Keep the last purchased item in state + public static class LastItemState { + public String lastItem; + } + + // The eval() method takes two @ArgumentHint(TABLE_AS_SET) arguments + public void eval( + @StateHint LastItemState state, + @ArgumentHint(TABLE_AS_SET) Row visit, + @ArgumentHint(TABLE_AS_SET) Row purchase) { + + // Process row from table Purchases + if (purchase != null) { + state.lastItem = purchase.getFieldAs("item"); + } + + // Process row from table Visits + else if (visit != null) { + if (state.lastItem == null) { + collect("Hello " + visit.getFieldAs("name") + ", let me know if I can help!"); + } else { + collect("Hello " + visit.getFieldAs("name") + ", here to buy " + state.lastItem + " again?"); + } + } + } +} +``` +{{< /tab >}} +{{< /tabs >}} + +The result will look similar to: + +```text ++----+--------------------------------+--------------------------------+--------------------------------+ +| op | name | customer | EXPR$0 | ++----+--------------------------------+--------------------------------+--------------------------------+ +| +I | Bob | Bob | Hello Bob, let me know if I... | +| +I | Alice | Alice | Hello Alice, here to buy Pr... | +| +I | Bob | Bob | Hello Bob, let me know if I... | ++----+--------------------------------+--------------------------------+--------------------------------+ +``` + +### Efficiency and Design Principles + +A high number of input tables can negatively impact a single TaskManager or subtask. Network buffers must be allocated +for each input, resulting in increased memory consumption which is why the number of table arguments is limited to a +maximum of 20 tables. + +Unevenly distributed keys may overload a single virtual processor, leading to backpressure. It is important to select +appropriate partition keys. 
+ +{{< top >}} + Query Evolution with UIDs ------------------------- @@ -1061,7 +1162,8 @@ END; {{< top >}} -## Pass-Through Columns +Pass-Through Columns +-------------------- Depending on the table semantics and whether an `on_time` argument has been defined, the system adds addition columns for every function output. @@ -1089,7 +1191,7 @@ With pass-through columns: | k | v | c1 | c2 | This allows the PTF to focus on the main aggregation without the need to manually forward input columns. -*Note*: Timers are not available when pass-through columns are enabled. +*Note*: Pass-through columns are only available for append-only PTFs taking a single table argument and don't use timers. {{< top >}} @@ -1610,9 +1712,6 @@ while the PTF exists once in the pipeline. The following example shows how a PTF can be used for joining. Additionally, it also showcases how a PTF can be used as a data generator for creating bounded tables with dummy data. -Because PTFs don't support multiple table arguments yet, we use `unionAll` to for passing multiple partitioned tables -into the PTF. Because a union requires a unified schema, the data generators transform the data into a `UnifiedEvent`. 
- {{< tabs "1637eeed-3d13-455c-8e2f-5e164da9f844" >}} {{< tab "Java" >}} ```java @@ -1625,11 +1724,11 @@ TableEnvironment env = TableEnvironment.create(EnvironmentSettings.inStreamingMo Table orders = env.fromCall(OrderGenerator.class); Table payments = env.fromCall(PaymentGenerator.class); -// Union orders and payments before -// partitioning and passing them into the Joiner function -Table joined = orders.unionAll(payments) - .partitionBy($("orderId")) - .process(Joiner.class); +// Partition orders and payments and pass them into the Joiner function +Table joined = env.fromCall( + Joiner.class, + orders.partitionBy($("id")).asArgument("order"), + payments.partitionBy($("orderId")).asArgument("payment")); joined.execute().print(); @@ -1637,24 +1736,8 @@ joined.execute().print(); // Data Generation // --------------------------- -// A unified event for all input tables. -// One of the sides is al empty. -public static class UnifiedEvent { - public int orderId; - public Order order; - public Payment payment; - - public static UnifiedEvent of(int orderId, Order order, Payment payment) { - UnifiedEvent unifiedEvent = new UnifiedEvent(); - unifiedEvent.orderId = orderId; - unifiedEvent.order = order; - unifiedEvent.payment = payment; - return unifiedEvent; - } -} - // A PTF that generates Orders -public static class OrderGenerator extends ProcessTableFunction { +public static class OrderGenerator extends ProcessTableFunction { public void eval() { Stream.of( Order.of("Bob", 1000001, 23.46, "USD"), @@ -1662,13 +1745,12 @@ public static class OrderGenerator extends ProcessTableFunction { Order.of("Alice", 1000601, 0.79, "EUR"), Order.of("Charly", 1000703, 100.60, "EUR") ) - .map(order -> UnifiedEvent.of(order.id, order, null)) .forEach(this::collect); } } // A PTF that generates Payments -public static class PaymentGenerator extends ProcessTableFunction { +public static class PaymentGenerator extends ProcessTableFunction { public void eval() { Stream.of( 
Payment.of(999997870, 1000001), @@ -1676,7 +1758,6 @@ public static class PaymentGenerator extends ProcessTableFunction Payment.of(999993331, 1000021), Payment.of(999994111, 1000601) ) - .map(payment -> UnifiedEvent.of(payment.orderId, null, payment)) .forEach(this::collect); } } @@ -1714,8 +1795,8 @@ public static class Payment { {{< /tab >}} {{< /tabs >}} -After generating the data and performing the union, the stateful Joiner buffers events until a matching pair is -found. Any duplicates in either of the input tables are ignored. +After generating the data, the stateful Joiner buffers events until a matching pair is found. Any duplicates in either +of the input tables are ignored. {{< tabs "1737eeed-3d13-455c-8e2f-5e164da9f844" >}} {{< tab "Java" >}} @@ -1727,7 +1808,8 @@ public static class Joiner extends ProcessTableFunction { public void eval( Context ctx, @StateHint(ttl = "1 hour") JoinResult seen, - @ArgumentHint(TABLE_AS_SET) UnifiedEvent input + @ArgumentHint(TABLE_AS_SET) Order order, + @ArgumentHint(TABLE_AS_SET) Payment payment ) { if (input.order != null) { if (seen.order != null) { @@ -1767,19 +1849,18 @@ The output could look similar to the following. Duplicate events for payment `99 for `Charly` could not be found. ```text -+----+-------------+--------------------------------+--------------------------------+ -| op | orderId | order | payment | -+----+-------------+--------------------------------+--------------------------------+ -| +I | 1000021 | (amount=6.99, currency=USD,... | (id=999993331, orderId=1000... | -| +I | 1000601 | (amount=0.79, currency=EUR,... | (id=999994111, orderId=1000... | -| +I | 1000001 | (amount=23.46, currency=USD... | (id=999997870, orderId=1000... 
| -+----+-------------+--------------------------------+--------------------------------+ ++----+-------------+-------------+--------------------------------+--------------------------------+ +| op | id | orderId | order | payment | ++----+-------------+-------------+--------------------------------+--------------------------------+ +| +I | 1000021 | 1000021 | (amount=6.99, currency=USD,... | (id=999993331, orderId=1000... | +| +I | 1000601 | 1000601 | (amount=0.79, currency=EUR,... | (id=999994111, orderId=1000... | +| +I | 1000001 | 1000001 | (amount=23.46, currency=USD... | (id=999997870, orderId=1000... | ++----+-------------+-------------+--------------------------------+--------------------------------+ ``` Limitations ----------- PTFs are in an early stage. The following limitations apply: -- Multiple table arguments are not supported. - PTFs cannot run in batch mode. - Broadcast state diff --git a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html index ae2caa7aad89a..7c3af56716e40 100644 --- a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html +++ b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html @@ -77,6 +77,12 @@

Enum

When it is `TRY_RESOLVE`, the optimizer tries to resolve the correctness issue caused by 'Non-Deterministic Updates' (NDU) in a changelog pipeline. Changelog may contain kinds of message types: Insert (I), Delete (D), Update_Before (UB), Update_After (UA). There's no NDU problem in an insert only changelog pipeline. For updates, there are three main NDU problems:
1. Non-deterministic functions, include scalar, table, aggregate functions, both builtin and custom ones.
2. LookupJoin on an evolving source
3. Cdc-source carries metadata fields which are system columns, not belongs to the entity data itself.

For the first step, the optimizer automatically enables the materialization for No.2(LookupJoin) if needed, and gives the detailed error message for No.1(Non-deterministic functions) and No.3(Cdc-source with metadata) which is relatively easier to solve by changing the SQL.
Default value is `IGNORE`, the optimizer does no changes.

Possible values:
  • "TRY_RESOLVE"
  • "IGNORE"
+ +
table.optimizer.ptf.max-tables

Streaming + 20 + Integer + The maximum number of table arguments for a Process Table Function (PTF). In theory, a PTF can accept an arbitrary number of input tables. In practice, however, each input requires reserving network buffers, which impacts memory usage. For this reason, the number of input tables is limited to 20. +
table.optimizer.reuse-optimize-block-with-digest-enabled

Batch Streaming false diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/transformations/KeyedMultipleInputTransformation.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/transformations/KeyedMultipleInputTransformation.java index 9adb1280dd1c1..00da25b3ffa71 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/transformations/KeyedMultipleInputTransformation.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/transformations/KeyedMultipleInputTransformation.java @@ -45,6 +45,18 @@ public KeyedMultipleInputTransformation( updateManagedMemoryStateBackendUseCase(true); } + public KeyedMultipleInputTransformation( + String name, + StreamOperatorFactory operatorFactory, + TypeInformation outputType, + int parallelism, + boolean parallelismConfigured, + TypeInformation stateKeyType) { + super(name, operatorFactory, outputType, parallelism, parallelismConfigured); + this.stateKeyType = stateKeyType; + updateManagedMemoryStateBackendUseCase(true); + } + public KeyedMultipleInputTransformation addInput( Transformation input, KeySelector keySelector) { inputs.add(input); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java index 62d085d5cd6e3..7157725881908 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java @@ -362,6 +362,17 @@ public class OptimizerConfigOptions { + "it receives incremental accumulators and outputs incremental results). " + "In this way, we can reduce some state overhead and resources. 
Default is enabled."); + @Documentation.TableOption(execMode = Documentation.ExecMode.STREAMING) + public static final ConfigOption TABLE_OPTIMIZER_PTF_MAX_TABLES = + key("table.optimizer.ptf.max-tables") + .intType() + .defaultValue(20) + .withDescription( + "The maximum number of table arguments for a Process Table Function (PTF). In theory, a PTF " + + "can accept an arbitrary number of input tables. In practice, however, each input " + + "requires reserving network buffers, which impacts memory usage. For this reason, " + + "the number of input tables is limited to 20."); + /** Strategy for handling non-deterministic updates. */ @PublicEvolving public enum NonDeterministicUpdateStrategy { diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java index a24379f4001f8..f5c585087dd01 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java @@ -727,11 +727,6 @@ public int timeColumn() { return -1; } - @Override - public List coPartitionArgs() { - return List.of(); - } - @Override public Optional changelogMode() { return Optional.empty(); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java index 2e4829f8bb4e6..2468f6c749392 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java @@ -99,10 +99,8 @@ public enum ArgumentTrait { * 
With pass-through columns: | k | v | c1 | c2 | * * - *

In case of multiple table arguments, pass-through columns are added according to the - * declaration order in the PTF signature. - * - *

Timers are not available when pass-through columns are enabled. + *

Pass-through columns are only available for append-only PTFs taking a single table + * argument and don't use timers. * *

Note: This trait is valid for {@link #TABLE_AS_ROW} and {@link #TABLE_AS_SET} arguments. */ diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ChangelogFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ChangelogFunction.java index 744484db2d238..df182cb0f20b3 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ChangelogFunction.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ChangelogFunction.java @@ -32,8 +32,7 @@ * *

Note: This interface is intended for advanced use cases and should be implemented with care. * Emitting an incorrect changelog from the PTF may lead to undefined behavior in the overall query. - * Many features such as the `on_time` argument and pass-through columns are not available for - * updating PTFs. + * The `on_time` argument is unsupported for updating PTFs. * *

The resulting changelog mode can be influenced by: * diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableSemantics.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableSemantics.java index 6119b6ba791cf..4ff04789aa805 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableSemantics.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableSemantics.java @@ -24,7 +24,6 @@ import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.types.DataType; -import java.util.List; import java.util.Optional; /** @@ -106,15 +105,6 @@ public interface TableSemantics { */ int timeColumn(); - /** - * Returns information about which passed tables are co-partitioned with the passed table. - * Applies only to table arguments with set semantics. - * - * @return List of table argument names (not table names!) that are co-partitioned with the - * passed table. - */ - List coPartitionArgs(); - /** * Actual changelog mode for the passed table. By default, table arguments take only {@link * ChangelogMode#insertOnly()}. 
They are able to take tables of other changelog modes, if diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java index 0dbdc20ae82d0..aac34f546deb9 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java @@ -35,6 +35,7 @@ import org.apache.flink.table.types.logical.RowType.RowField; import org.apache.flink.table.types.logical.TimestampKind; import org.apache.flink.table.types.logical.TimestampType; +import org.apache.flink.table.types.logical.utils.LogicalTypeCasts; import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; import org.apache.flink.table.types.logical.utils.LogicalTypeUtils; @@ -43,11 +44,11 @@ import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Predicate; @@ -143,7 +144,7 @@ private static void checkScalarArgsOnly(List defaultArgs) { checkReservedArgs(declaredArgs); checkMultipleTableArgs(declaredArgs); - checkUpdatingPassThroughColumns(declaredArgs); + checkPassThroughColumns(declaredArgs); final List newStaticArgs = new ArrayList<>(declaredArgs); newStaticArgs.addAll(PROCESS_TABLE_FUNCTION_SYSTEM_ARGS); @@ -166,22 +167,31 @@ private static void checkReservedArgs(List staticArgs) { } private static void checkMultipleTableArgs(List staticArgs) { - if (staticArgs.stream().filter(arg -> arg.is(StaticArgumentTrait.TABLE)).count() > 1) { + if (staticArgs.stream().filter(arg -> 
arg.is(StaticArgumentTrait.TABLE)).count() <= 1) { + return; + } + if (staticArgs.stream().anyMatch(arg -> !arg.is(StaticArgumentTrait.TABLE_AS_SET))) { throw new ValidationException( - "Currently, only signatures with at most one table argument are supported."); + "All table arguments must use set semantics if multiple table arguments are declared."); } } - private static void checkUpdatingPassThroughColumns(List staticArgs) { + private static void checkPassThroughColumns(List staticArgs) { final Set traits = staticArgs.stream() .flatMap(arg -> arg.getTraits().stream()) .collect(Collectors.toSet()); - if (traits.contains(StaticArgumentTrait.SUPPORT_UPDATES) - && traits.contains(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) { + if (!traits.contains(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) { + return; + } + if (traits.contains(StaticArgumentTrait.SUPPORT_UPDATES)) { throw new ValidationException( "Signatures with updating inputs must not pass columns through."); } + if (staticArgs.stream().filter(arg -> arg.is(StaticArgumentTrait.TABLE)).count() > 1) { + throw new ValidationException( + "Pass-through columns are not supported if multiple table arguments are declared."); + } } private static InputTypeStrategy deriveSystemInputStrategy( @@ -323,37 +333,44 @@ private List deriveRowtimeField(CallContext callContext) { final Set usedOnTimeFields = new HashSet<>(); - final List onTimeColumns = - IntStream.range(0, staticArgs.size()) - .mapToObj( - pos -> { - final StaticArgument staticArg = staticArgs.get(pos); - if (!staticArg.is(StaticArgumentTrait.TABLE)) { - return null; - } - final RowType rowType = - LogicalTypeUtils.toRowType( - args.get(pos).getLogicalType()); - final int onTimeColumn = - findUniqueOnTimeColumn( - staticArg.getName(), rowType, onTimeFields); - if (onTimeColumn >= 0) { - usedOnTimeFields.add( - rowType.getFieldNames().get(onTimeColumn)); - return rowType.getTypeAt(onTimeColumn); - } - if (staticArg.is(StaticArgumentTrait.REQUIRE_ON_TIME)) { - throw 
new ValidationException( - String.format( - "Table argument '%s' requires a time attribute. " - + "Please provide one using the implicit `on_time` argument. " - + "For example: myFunction(..., on_time => DESCRIPTOR(`my_timestamp`)", - staticArg.getName())); - } - return null; - }) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + final List onTimeColumns = new ArrayList<>(); + final List missingOnTimeColumns = new ArrayList<>(); + IntStream.range(0, staticArgs.size()) + .forEach( + pos -> { + final StaticArgument staticArg = staticArgs.get(pos); + if (!staticArg.is(StaticArgumentTrait.TABLE)) { + return; + } + final RowType rowType = + LogicalTypeUtils.toRowType(args.get(pos).getLogicalType()); + final int onTimeColumn = + findUniqueOnTimeColumn( + staticArg.getName(), rowType, onTimeFields); + if (onTimeColumn >= 0) { + usedOnTimeFields.add(rowType.getFieldNames().get(onTimeColumn)); + onTimeColumns.add(rowType.getTypeAt(onTimeColumn)); + return; + } + if (staticArg.is(StaticArgumentTrait.REQUIRE_ON_TIME)) { + throw new ValidationException( + String.format( + "Table argument '%s' requires a time attribute. " + + "Please provide one using the implicit `on_time` argument. " + + "For example: myFunction(..., on_time => DESCRIPTOR(`my_timestamp`)", + staticArg.getName())); + } else { + missingOnTimeColumns.add(staticArg.getName()); + } + }); + + if (!onTimeColumns.isEmpty() && !missingOnTimeColumns.isEmpty()) { + throw new ValidationException( + "Invalid time attribute declaration. If multiple tables are declared, the `on_time` argument " + + "must reference a time column for each table argument or none. 
" + + "Missing time attributes for: " + + missingOnTimeColumns); + } final Set unusedOnTimeFields = new HashSet<>(onTimeFields); unusedOnTimeFields.removeAll(usedOnTimeFields); @@ -494,7 +511,7 @@ public Optional> inferInputTypes( } try { - checkTableArgTraits(staticArgs, callContext); + checkTableArgs(staticArgs, callContext); checkUidArg(callContext); } catch (ValidationException e) { return callContext.fail(throwOnFailure, e.getMessage()); @@ -527,8 +544,9 @@ private static void checkUidArg(CallContext callContext) { } } - private static void checkTableArgTraits( + private static void checkTableArgs( List staticArgs, CallContext callContext) { + final List tableSemantics = new ArrayList<>(); IntStream.range(0, staticArgs.size()) .forEach( pos -> { @@ -546,7 +564,46 @@ private static void checkTableArgTraits( } checkRowSemantics(staticArg, semantics); checkSetSemantics(staticArg, semantics); + tableSemantics.add(semantics); }); + checkCoPartitioning(tableSemantics); + } + + private static void checkCoPartitioning(List tableSemantics) { + if (tableSemantics.isEmpty()) { + return; + } + final List partitioningTypes = + tableSemantics.stream() + .map( + semantics -> { + final LogicalType tableType = + semantics.dataType().getLogicalType(); + final List fieldTypes = + LogicalTypeChecks.getFieldTypes(tableType); + final LogicalType[] partitionTypes = + Arrays.stream(semantics.partitionByColumns()) + .mapToObj(fieldTypes::get) + .toArray(LogicalType[]::new); + return (LogicalType) RowType.of(partitionTypes); + }) + .collect(Collectors.toList()); + final LogicalType commonType = + LogicalTypeMerging.findCommonType(partitioningTypes).orElse(null); + if (commonType == null + || partitioningTypes.stream() + .anyMatch( + partitioningType -> + !LogicalTypeCasts.supportsAvoidingCast( + partitioningType, commonType))) { + throw new ValidationException( + "Invalid PARTITION BY columns. 
The number of columns and their data types must match " + + "across all involved table arguments. Given partition key sets: " + + partitioningTypes.stream() + .map(LogicalType::getChildren) + .map(Object::toString) + .collect(Collectors.joining(", "))); + } } private static void checkRowSemantics(StaticArgument staticArg, TableSemantics semantics) { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/CallBindingCallContext.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/CallBindingCallContext.java index f5fb5c62f22cb..091fcae385d7a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/CallBindingCallContext.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/CallBindingCallContext.java @@ -256,11 +256,6 @@ public int timeColumn() { return -1; } - @Override - public List coPartitionArgs() { - return List.of(); - } - @Override public Optional changelogMode() { return Optional.empty(); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/OperatorBindingCallContext.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/OperatorBindingCallContext.java index dd1b9d456955c..cd226f9b417d4 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/OperatorBindingCallContext.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/OperatorBindingCallContext.java @@ -317,11 +317,6 @@ public int timeColumn() { return timeColumn; } - @Override - public List coPartitionArgs() { - return List.of(); - } - @Override public Optional changelogMode() { return Optional.ofNullable(changelogMode); diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java index 9b39f4d549b36..08ebe72f9e425 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java @@ -302,7 +302,7 @@ public RelNode visit(FunctionQueryOperation functionTable) { final RelDataType outputRelDataType = typeFactory.buildRelNodeRowType((RowType) outputType); - final List inputs = new ArrayList<>(); + final List inputStack = new ArrayList<>(); final List rexNodeArgs = resolvedArgs.stream() .map( @@ -329,16 +329,19 @@ public RelNode visit(FunctionQueryOperation functionTable) { final RexTableArgCall tableArgCall = new RexTableArgCall( rowType, - inputs.size(), + inputStack.size(), partitionKeys, new int[0]); - inputs.add(relBuilder.build()); + inputStack.add(relBuilder.build()); return tableArgCall; } return convertExprToRexNode(resolvedArg); }) .collect(Collectors.toList()); + // relBuilder.build() works in LIFO fashion, this restores the original input order + Collections.reverse(inputStack); + final BridgingSqlFunction sqlFunction = BridgingSqlFunction.of(relBuilder.getCluster(), contextFunction); @@ -349,7 +352,7 @@ public RelNode visit(FunctionQueryOperation functionTable) { final RelNode functionScan = LogicalTableFunctionScan.create( relBuilder.getCluster(), - inputs, + inputStack, call, null, outputRelDataType, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java index bb607b8d65cfe..4fa75ee2d6ea0 100644 --- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java @@ -20,10 +20,11 @@ import org.apache.flink.FlinkVersion; import org.apache.flink.api.dag.Transformation; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.ReadableConfig; -import org.apache.flink.streaming.api.transformations.OneInputTransformation; +import org.apache.flink.streaming.api.operators.ChainingStrategy; +import org.apache.flink.streaming.api.transformations.KeyedMultipleInputTransformation; import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.api.TableException; import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.data.RowData; import org.apache.flink.table.functions.ProcessTableFunction; @@ -166,10 +167,6 @@ protected Transformation translateToPlanInternal( getInputEdges().stream() .map(e -> (Transformation) e.translateToPlan(planner)) .collect(Collectors.toList()); - if (inputTransforms.size() != 1) { - throw new TableException("Process table function only supports exactly one input."); - } - final Transformation inputTransform = inputTransforms.get(0); final List> providedInputArgs = StreamPhysicalProcessTableFunction.getProvidedInputArgs(invocation); @@ -224,20 +221,12 @@ protected Transformation translateToPlanInternal( .map(t -> EqualiserCodeGenerator.generateRowEquals(ctx, t, "StateEquals")) .toArray(GeneratedRecordEqualiser[]::new); - final RuntimeTableSemantics singleTableSemantics; - if (runtimeTableSemantics.isEmpty()) { - // For constant function calls - singleTableSemantics = null; - } else { - singleTableSemantics = runtimeTableSemantics.get(0); - } - final RuntimeChangelogMode producedChangelogMode = RuntimeChangelogMode.serialize(outputChangelogMode); final 
ProcessTableOperatorFactory operatorFactory = new ProcessTableOperatorFactory( - singleTableSemantics, + runtimeTableSemantics, runtimeStateInfos, generatedRunner, stateHashCode, @@ -253,24 +242,17 @@ protected Transformation translateToPlanInternal( createTransformationName(config), createTransformationDescription(config)); - final OneInputTransformation transform = - ExecNodeUtil.createOneInputTransformation( - inputTransform, - metadata, - operatorFactory, - InternalTypeInfo.of(getOutputType()), - inputTransform.getParallelism(), - false); - - // For one input (but non-constant) functions with set semantics - if (singleTableSemantics != null && singleTableSemantics.hasSetSemantics()) { - final RowDataKeySelector selector = - KeySelectorUtil.getRowDataSelector( - planner.getFlinkContext().getClassLoader(), - singleTableSemantics.partitionByColumns(), - (InternalTypeInfo) inputTransform.getOutputType()); - transform.setStateKeySelector(selector); - transform.setStateKeyType(selector.getProducedType()); + final Transformation transform; + if (runtimeTableSemantics.stream().anyMatch(RuntimeTableSemantics::hasSetSemantics)) { + transform = + createKeyedTransformation( + inputTransforms, + metadata, + operatorFactory, + planner, + runtimeTableSemantics); + } else { + transform = createNonKeyedTransformation(inputTransforms, metadata, operatorFactory); } if (inputsContainSingleton()) { @@ -306,6 +288,57 @@ private RuntimeTableSemantics createRuntimeTableSemantics( timeColumn); } + private Transformation createKeyedTransformation( + List> inputTransforms, + TransformationMetadata metadata, + ProcessTableOperatorFactory operatorFactory, + PlannerBase planner, + List runtimeTableSemantics) { + assert runtimeTableSemantics.size() == inputTransforms.size(); + + final List> keySelectors = + runtimeTableSemantics.stream() + .map( + inputSemantics -> + KeySelectorUtil.getRowDataSelector( + planner.getFlinkContext().getClassLoader(), + inputSemantics.partitionByColumns(), + 
(InternalTypeInfo) + inputTransforms + .get(inputSemantics.getInputIndex()) + .getOutputType())) + .collect(Collectors.toList()); + + final KeyedMultipleInputTransformation transform = + ExecNodeUtil.createKeyedMultiInputTransformation( + inputTransforms, + keySelectors, + ((RowDataKeySelector) keySelectors.get(0)).getProducedType(), + metadata, + operatorFactory, + InternalTypeInfo.of(getOutputType()), + inputTransforms.get(0).getParallelism(), + false); + + transform.setChainingStrategy(ChainingStrategy.HEAD_WITH_SOURCES); + + return transform; + } + + private Transformation createNonKeyedTransformation( + List> inputTransforms, + TransformationMetadata metadata, + ProcessTableOperatorFactory operatorFactory) { + final Transformation inputTransform = inputTransforms.get(0); + return ExecNodeUtil.createOneInputTransformation( + inputTransform, + metadata, + operatorFactory, + InternalTypeInfo.of(getOutputType()), + inputTransform.getParallelism(), + false); + } + private static RuntimeStateInfo createRuntimeStateInfo( String name, StateInfo stateInfo, ExecNodeConfig config) { return new RuntimeStateInfo( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/ExecNodeUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/ExecNodeUtil.java index be18acc3d71e6..9d39cf8e02e21 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/ExecNodeUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/ExecNodeUtil.java @@ -21,11 +21,13 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.connector.source.Boundedness; import org.apache.flink.api.dag.Transformation; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.core.memory.ManagedMemoryUseCase; import 
org.apache.flink.streaming.api.operators.SimpleOperatorFactory; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.transformations.KeyedMultipleInputTransformation; import org.apache.flink.streaming.api.transformations.LegacySourceTransformation; import org.apache.flink.streaming.api.transformations.OneInputTransformation; import org.apache.flink.streaming.api.transformations.PartitionTransformation; @@ -38,6 +40,7 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.IntStream; /** An Utility class that helps translating {@link ExecNode} to {@link Transformation}. */ public class ExecNodeUtil { @@ -359,6 +362,33 @@ public static TwoInputTransformation createTwoInputTransf parallelismConfigured); } + /** Create a {@link KeyedMultipleInputTransformation}. */ + public static KeyedMultipleInputTransformation createKeyedMultiInputTransformation( + List> inputs, + List> keySelectors, + TypeInformation keyType, + TransformationMetadata transformationMeta, + StreamOperatorFactory operatorFactory, + TypeInformation outputType, + int parallelism, + boolean parallelismConfigured) { + final KeyedMultipleInputTransformation transformation = + new KeyedMultipleInputTransformation<>( + transformationMeta.getName(), + operatorFactory, + outputType, + parallelism, + parallelismConfigured, + keyType); + transformationMeta.fill(transformation); + IntStream.range(0, inputs.size()) + .forEach( + inputIdx -> + transformation.addInput( + inputs.get(inputIdx), keySelectors.get(inputIdx))); + return transformation; + } + /** Return description for multiple input node. 
*/ public static String getMultipleInputDescription( ExecNode rootNode, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java index daf455688c2b9..6e5acf7efbf66 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java @@ -17,8 +17,9 @@ package org.apache.flink.table.planner.plan.nodes.physical.stream; -import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.TableConfig; import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.api.config.OptimizerConfigOptions; import org.apache.flink.table.catalog.ContextResolvedFunction; import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.functions.FunctionDefinition; @@ -41,6 +42,8 @@ import org.apache.flink.table.types.inference.SystemTypeInference; import org.apache.flink.types.RowKind; +import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableSet; + import org.apache.calcite.linq4j.Ord; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; @@ -64,6 +67,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; @@ -100,6 +104,7 @@ public StreamPhysicalProcessTableFunction( this.rowType = rowType; this.scan = scan; this.uid = deriveUniqueIdentifier(scan); + verifyInputSize(ShortcutUtils.unwrapTableConfig(cluster), inputs.size()); } public StreamPhysicalProcessTableFunction( 
@@ -161,7 +166,8 @@ public ExecNode translateToExecNode() { .orElseThrow(IllegalStateException::new); final RexCall call = (RexCall) scan.getCall(); verifyTimeAttributes(getInputs(), call, inputChangelogModes, outputChangelogMode); - verifyPassThroughColumnsForUpdates(call, outputChangelogMode); + final List> providedInputArgs = getProvidedInputArgs(call); + verifyPassThroughColumnsForUpdates(providedInputArgs, outputChangelogMode); return new StreamExecProcessTableFunction( unwrapTableConfig(this), getInputs().stream().map(i -> InputProperty.DEFAULT).collect(Collectors.toList()), @@ -283,15 +289,25 @@ private static void verifyOnTimeForUpdates( } private static void verifyPassThroughColumnsForUpdates( - RexCall call, ChangelogMode outputChangelogMode) { - if (!outputChangelogMode.containsOnly(RowKind.INSERT) - && getProvidedInputArgs(call).stream() + List> providedInputArgs, ChangelogMode requiredChangelogMode) { + if (!requiredChangelogMode.containsOnly(RowKind.INSERT) + && providedInputArgs.stream() .anyMatch(arg -> arg.e.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH))) { throw new ValidationException( "Pass-through columns are not supported for PTFs that produce updates."); } } + private static void verifyInputSize(TableConfig tableConfig, int providedInputArgs) { + final int maxCount = tableConfig.get(OptimizerConfigOptions.TABLE_OPTIMIZER_PTF_MAX_TABLES); + if (providedInputArgs > maxCount) { + throw new ValidationException( + String.format( + "Unsupported table argument count. 
Currently, the number of input tables is limited to %s.", + maxCount)); + } + } + // -------------------------------------------------------------------------------------------- // Shared utilities // -------------------------------------------------------------------------------------------- @@ -414,26 +430,40 @@ public static List toInputTimeColumns(RexCall call) { .collect(Collectors.toList()); } - public static ImmutableBitSet toPartitionColumns(RexCall call) { + public static Set toPartitionColumns(RexCall call) { final List operands = call.getOperands(); final List> providedInputArgs = StreamPhysicalProcessTableFunction.getProvidedInputArgs(call); - if (providedInputArgs.size() > 1) { - throw new TableException("More than one table input is not supported yet."); - } - final List partitionColumns = new ArrayList<>(); + final Set partitionColumnsPerArg = new HashSet<>(); + int pos = 0; for (Ord providedInputArg : providedInputArgs) { final RexTableArgCall tableArgCall = (RexTableArgCall) operands.get(providedInputArg.i); if (providedInputArg.e.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) { - // Output preserved key positions - Arrays.stream(tableArgCall.getPartitionKeys()).forEach(partitionColumns::add); + // System type inference ensures that at most one table + // argument can pass columns through. In that case, the + // output preserves the position of partition columns. + // f(t(c1, c2, k1, k2, c3) PARTITION BY (k1, k2)) + // -> [c1, c2, k1, k2, c3, function out...] 
+ assert providedInputArgs.size() == 1; + final List partitionColumns = + Arrays.stream(tableArgCall.getPartitionKeys()) + .boxed() + .collect(Collectors.toList()); + partitionColumnsPerArg.add(ImmutableBitSet.of(partitionColumns)); } else { - // Output is prefixed with partition keys only - IntStream.range(0, tableArgCall.getPartitionKeys().length) - .forEach(partitionColumns::add); + final int partitionKeyCount = tableArgCall.getPartitionKeys().length; + // Output is prefixed with partition keys: + // f(t1 PARTITION BY (k1, k2), t2 PARTITION BY (k3, k4)) + // -> [k1, k2, k3, k4, function out...] + final List partitionColumns = + IntStream.range(pos, partitionKeyCount) + .boxed() + .collect(Collectors.toList()); + pos += partitionKeyCount; + partitionColumnsPerArg.add(ImmutableBitSet.of(partitionColumns)); } } - return ImmutableBitSet.of(partitionColumns); + return ImmutableSet.copyOf(partitionColumnsPerArg); } public static CallContext toCallContext( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala index 04f3af55ffd8c..4f1d8d65d419f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala @@ -34,7 +34,6 @@ import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.{verifyFunctionAwareOutputType, DefaultExpressionEvaluatorFactory} import org.apache.flink.table.planner.delegation.PlannerBase import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction -import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext import 
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalProcessTableFunction import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.runtime.dataview.DataViewUtils @@ -43,17 +42,15 @@ import org.apache.flink.table.runtime.dataview.StateMapView.KeyedStateMapViewWit import org.apache.flink.table.runtime.generated.{GeneratedProcessTableRunner, ProcessTableRunner} import org.apache.flink.table.types.DataType import org.apache.flink.table.types.extraction.ExtractionUtils -import org.apache.flink.table.types.inference.{StaticArgument, StaticArgumentTrait, SystemTypeInference, TypeInferenceUtil} +import org.apache.flink.table.types.inference.TypeInferenceUtil import org.apache.flink.table.types.inference.TypeInferenceUtil.StateInfo import org.apache.flink.table.types.logical.LogicalType import org.apache.flink.table.types.logical.utils.LogicalTypeChecks import org.apache.flink.types.Row -import org.apache.calcite.rex.{RexCall, RexCallBinding, RexNode, RexUtil} -import org.apache.calcite.sql.SqlKind +import org.apache.calcite.rex.{RexCall, RexNode} import java.util -import java.util.Collections import scala.collection.JavaConverters._ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala index 80a8c8608a391..2a082de05263f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala @@ -664,8 +664,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu if (isUpsert) { // Upsert PTFs use the partition keys as upsert keys, // thus the keys are unique - val partitionColumns = 
StreamPhysicalProcessTableFunction.toPartitionColumns(rel.getCall) - ImmutableSet.of(partitionColumns) + StreamPhysicalProcessTableFunction.toPartitionColumns(rel.getCall) } else { null } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala index 0560bcedc5838..e84edf8ad2e5a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala @@ -1529,7 +1529,6 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti val changelogMode = changelogFunction.getChangelogMode(changelogContext) if (!changelogMode.containsOnly(RowKind.INSERT)) { verifyPtfTableArgsForUpdates(call) - verifyPtfRequirementsForUpdates(call, requiredChangelogMode, changelogMode) } toTraitSet(changelogMode) case _ => @@ -1551,16 +1550,4 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti } } } - - private def verifyPtfRequirementsForUpdates( - call: RexCall, - required: ChangelogMode, - returned: ChangelogMode): Unit = { - if (!required.keyOnlyDeletes() && returned.keyOnlyDeletes()) { - throw new ValidationException( - s"Unsupported changelog mode returned from PTF '${call.getOperator.toString}'. " + - s"The system requires that deletions include all fields in DELETE messages. 
" + - s"Key-only deletes are not sufficient.") - } - } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionSemanticTests.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionSemanticTests.java index 7140a578f459d..341c3c8a9db0d 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionSemanticTests.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionSemanticTests.java @@ -82,6 +82,9 @@ public List programs() { ProcessTableFunctionTestPrograms.PROCESS_INVALID_TABLE_AS_ROW_TIMERS, ProcessTableFunctionTestPrograms.PROCESS_INVALID_PASS_THROUGH_TIMERS, ProcessTableFunctionTestPrograms.PROCESS_LIST_STATE, - ProcessTableFunctionTestPrograms.PROCESS_MAP_STATE); + ProcessTableFunctionTestPrograms.PROCESS_MAP_STATE, + ProcessTableFunctionTestPrograms.PROCESS_MULTI_INPUT, + ProcessTableFunctionTestPrograms.PROCESS_STATEFUL_MULTI_INPUT_WITH_TIMEOUT, + ProcessTableFunctionTestPrograms.PROCESS_UPDATING_MULTI_INPUT); } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestPrograms.java index 13c8f1bfb275d..59a861289d883 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestPrograms.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestPrograms.java @@ -34,6 +34,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.LateTimersFunction; import 
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ListStateFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.MapStateFunction; +import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.MultiInputFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.MultiStateFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.NamedTimersFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.OptionalOnTimeFunction; @@ -55,9 +56,11 @@ import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TableAsSetUpdatingArgFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TimeConversionsFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TimeToLiveStateFunction; +import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TimedJoinFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TypedTableAsRowFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TypedTableAsSetFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.UnnamedTimersFunction; +import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.UpdatingJoinFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.UpdatingRetractFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.UpdatingUpsertFunction; import org.apache.flink.table.test.program.SinkTestStep; @@ -74,11 +77,15 @@ import static org.apache.flink.table.api.Expressions.lit; import static 
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.BASE_SINK_SCHEMA; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.BASIC_VALUES; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.CITY_VALUES; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.KEYED_BASE_SINK_SCHEMA; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.KEYED_TIMED_BASE_SINK_SCHEMA; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.MULTI_BASE_SINK_SCHEMA; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.MULTI_VALUES; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.PASS_THROUGH_BASE_SINK_SCHEMA; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TIMED_BASE_SINK_SCHEMA; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TIMED_CITY_SOURCE; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TIMED_MULTI_BASE_SINK_SCHEMA; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TIMED_SOURCE; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TIMED_SOURCE_LATE_EVENTS; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.UPDATING_VALUES; @@ -1216,4 +1223,106 @@ public class ProcessTableFunctionTestPrograms { .build()) .runSql("INSERT INTO sink SELECT * FROM f(r => TABLE t PARTITION BY name)") .build(); + + public static final TableTestProgram PROCESS_MULTI_INPUT = + TableTestProgram.of("process-multi-input", "takes multiple tables") + .setupTemporarySystemFunction("f", 
MultiInputFunction.class) + .setupSql(MULTI_VALUES) + .setupSql(CITY_VALUES) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema(MULTI_BASE_SINK_SCHEMA) + .consumedValues( + "+I[Bob, Bob, {+I[Bob, 12], null}]", + "+I[Bob, Bob, {null, +I[Bob, London]}]", + "+I[Alice, Alice, {+I[Alice, 42], null}]", + "+I[Alice, Alice, {null, +I[Alice, Berlin]}]", + "+I[Bob, Bob, {+I[Bob, 99], null}]", + "+I[Charly, Charly, {null, +I[Charly, Paris]}]", + "+I[Bob, Bob, {+I[Bob, 100], null}]", + "+I[Alice, Alice, {+I[Alice, 400], null}]") + .build()) + .runSql( + "INSERT INTO sink SELECT * FROM f(in1 => TABLE t PARTITION BY name, in2 => TABLE city PARTITION BY name)") + .build(); + + public static final TableTestProgram PROCESS_STATEFUL_MULTI_INPUT_WITH_TIMEOUT = + TableTestProgram.of( + "process-stateful-multi-input-with-timeout", + "joins two tables and emits the left side after a timeout if there is no right side") + .setupTemporarySystemFunction("f", TimedJoinFunction.class) + .setupTableSource(TIMED_SOURCE) + .setupTableSource(TIMED_CITY_SOURCE) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema(TIMED_MULTI_BASE_SINK_SCHEMA) + .consumedValues( + "+I[Bob, Bob, 1 score in city London, 1970-01-01T00:00:00Z]", + "+I[Bob, Bob, 2 score in city London, 1970-01-01T00:00:00.002Z]", + "+I[Bob, Bob, 3 score in city London, 1970-01-01T00:00:00.003Z]", + "+I[Bob, Bob, 4 score in city London, 1970-01-01T00:00:00.004Z]", + "+I[Bob, Bob, 5 score in city London, 1970-01-01T00:00:00.005Z]", + "+I[Bob, Bob, 6 score in city London, 1970-01-01T00:00:00.006Z]", + "+I[Alice, Alice, no city found for score 1, 1970-01-01T00:00:01.001Z]") + .build()) + .runSql( + "INSERT INTO sink SELECT * FROM f(" + + "scoreTable => TABLE t PARTITION BY name, " + + "cityTable => TABLE city PARTITION BY name, " + + "on_time => DESCRIPTOR(ts))") + .build(); + + public static final TableTestProgram PROCESS_UPDATING_MULTI_INPUT = + TableTestProgram.of( + "process-updating-multi-input", + "joins 
two tables with input and output updates") + .setupTemporarySystemFunction("f", UpdatingJoinFunction.class) + .setupTableSource( + SourceTestStep.newBuilder("scores") + .addSchema( + "name STRING PRIMARY KEY NOT ENFORCED", + "score INT NOT NULL") + .addOption("changelog-mode", "I,UA,D") + .addOption("source.produces-delete-by-key", "true") + .producedValues( + Row.ofKind(RowKind.INSERT, "Bob", 5), + Row.ofKind(RowKind.INSERT, "Alice", 2), + Row.ofKind(RowKind.UPDATE_AFTER, "Bob", 3), + Row.ofKind(RowKind.DELETE, "Bob", null), + Row.ofKind(RowKind.INSERT, "Bob", 2), + Row.ofKind(RowKind.DELETE, "Alice", null)) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("city") + .addSchema( + "name STRING PRIMARY KEY NOT ENFORCED", + "city STRING NOT NULL") + .addOption("changelog-mode", "I,UA,D") + .addOption("source.produces-delete-by-key", "true") + .producedValues( + Row.ofKind(RowKind.INSERT, "Bob", "London"), + Row.ofKind(RowKind.INSERT, "Alice", "Zurich"), + Row.ofKind(RowKind.UPDATE_AFTER, "Bob", "Berlin")) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("sink") + .addSchema( + "`name` STRING PRIMARY KEY NOT ENFORCED", + "`out` STRING") + .addOption("sink-changelog-mode-enforced", "I,UA,D") + .addOption("sink.supports-delete-by-key", "true") + .consumedValues( + "+I[Bob, score 5 in city London]", + "+I[Alice, score 2 in city Zurich]", + "+U[Bob, score 3 in city London]", + "+U[Bob, score 3 in city Berlin]", + "-D[Bob, null]", + "+I[Bob, score 2 in city Berlin]", + "-D[Alice, null]") + .build()) + .runSql( + "INSERT INTO sink SELECT `name`, `out` FROM f(" + + "scoreTable => TABLE scores PARTITION BY name, " + + "cityTable => TABLE city PARTITION BY name)") + .build(); } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestUtils.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestUtils.java 
index 110297904eb5c..5082ebda06b83 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestUtils.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/ProcessTableFunctionTestUtils.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream; +import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.annotation.ArgumentTrait; import org.apache.flink.table.annotation.DataTypeHint; @@ -30,7 +31,7 @@ import org.apache.flink.table.functions.ProcessTableFunction; import org.apache.flink.table.functions.ScalarFunction; import org.apache.flink.table.functions.TableSemantics; -import org.apache.flink.table.runtime.operators.process.ProcessTableOperator; +import org.apache.flink.table.runtime.operators.process.AbstractProcessTableOperator.RunnerContext; import org.apache.flink.table.test.program.SourceTestStep; import org.apache.flink.types.ColumnList; import org.apache.flink.types.Row; @@ -67,6 +68,20 @@ public class ProcessTableFunctionTestUtils { "CREATE VIEW t AS SELECT * FROM " + "(VALUES ('Bob', 12), ('Alice', 42), ('Bob', 99), ('Bob', 100), ('Alice', 400)) AS T(name, score)"; + public static final String CITY_VALUES = + "CREATE VIEW city AS SELECT * FROM " + + "(VALUES ('Bob', 'London'), ('Alice', 'Berlin'), ('Charly', 'Paris')) AS T(name, city)"; + + public static final SourceTestStep TIMED_CITY_SOURCE = + SourceTestStep.newBuilder("city") + .addSchema( + "name STRING", + "city STRING", + "ts TIMESTAMP_LTZ(3)", + "WATERMARK FOR ts AS ts - INTERVAL '0.001' SECOND") + .producedValues(Row.of("Bob", "London", Instant.ofEpochMilli(0))) + .build(); + public static final String UPDATING_VALUES = "CREATE VIEW t AS SELECT name, COUNT(*) FROM " + "(VALUES ('Bob', 12), ('Alice', 42), ('Bob', 14)) AS T(name, score) " @@ -119,6 +134,18 @@ public 
class ProcessTableFunctionTestUtils { public static final List KEYED_BASE_SINK_SCHEMA = List.of("`name` STRING", "`out` STRING"); + /** Corresponds to {@link AppendProcessTableFunctionBase}. */ + public static final List MULTI_BASE_SINK_SCHEMA = + List.of("`name` STRING", "`name0` STRING", "`out` STRING"); + + /** Corresponds to {@link AppendProcessTableFunctionBase}. */ + public static final List TIMED_MULTI_BASE_SINK_SCHEMA = + List.of( + "`name` STRING", + "`name0` STRING", + "`out` STRING", + "`rowtime` TIMESTAMP_LTZ(3)"); + /** Corresponds to {@link AppendProcessTableFunctionBase}. */ public static final List PASS_THROUGH_BASE_SINK_SCHEMA = List.of("`name` STRING", "`score` INT", "`out` STRING"); @@ -377,8 +404,7 @@ public void eval( @StateHint(ttl = "0") Score s2, @StateHint Score s3, @ArgumentHint({TABLE_AS_SET, OPTIONAL_PARTITION_BY}) Row r) { - final ProcessTableOperator.RunnerContext internalContext = - (ProcessTableOperator.RunnerContext) ctx; + final RunnerContext internalContext = (RunnerContext) ctx; if (s0.getFieldAs("emitted") == null) { collect( String.format( @@ -751,19 +777,6 @@ public ChangelogMode getChangelogMode(ChangelogContext changelogContext) { } } - /** Testing function. */ - public static class UpdatingUpsertPartialDeletesFunction - extends ChangelogProcessTableFunctionBase { - public void eval(Context ctx, @ArgumentHint({TABLE_AS_SET, SUPPORT_UPDATES}) Row r) { - collectUpdate(ctx, r); - } - - @Override - public ChangelogMode getChangelogMode(ChangelogContext changelogContext) { - return ChangelogMode.upsert(); - } - } - /** Testing function. */ public static class UpdatingUpsertFullDeletesFunction extends ChangelogProcessTableFunctionBase { @@ -798,6 +811,103 @@ public void eval(@ArgumentHint(TABLE_AS_ROW) Row r) { } } + /** Testing function. 
*/ + public static class MultiInputFunction extends AppendProcessTableFunctionBase { + public void eval( + Context ctx, + @ArgumentHint(TABLE_AS_SET) Row in1, + @ArgumentHint({TABLE_AS_SET, OPTIONAL_PARTITION_BY}) Row in2) + throws Exception { + collectObjects(in1, in2); + } + } + + /** Testing function. */ + public static class TimedJoinFunction extends AppendProcessTableFunctionBase { + public void eval( + Context ctx, + @StateHint Tuple1 score, + @StateHint Tuple1 city, + @ArgumentHint({TABLE_AS_SET, REQUIRE_ON_TIME}) Row scoreTable, + @ArgumentHint({TABLE_AS_SET, REQUIRE_ON_TIME}) Row cityTable) + throws Exception { + final TimeContext timeCtx = ctx.timeContext(Instant.class); + if (scoreTable != null) { + score.f0 = scoreTable.getFieldAs("score"); + timeCtx.registerOnTime("timeout", timeCtx.time().plusMillis(1000)); + } + if (cityTable != null) { + city.f0 = cityTable.getFieldAs("city"); + } + if (score.f0 != null && city.f0 != null) { + collect(Row.of(score.f0 + " score in city " + city.f0)); + ctx.clearAllTimers(); + } + } + + public void onTimer(OnTimerContext ctx, Tuple1 score, Tuple1 city) { + collect(Row.of("no city found for score " + score.f0)); + score.f0 = null; + } + } + + /** + * Implements a custom join that acts like a kind of outer join and never produces deletions. + * Both the score and city can change at any time. The join will output an update if a matching + * pair can be found. 
+ */ + @DataTypeHint("ROW") + public static class UpdatingJoinFunction extends ProcessTableFunction + implements ChangelogFunction { + public void eval( + @StateHint Tuple1 score, + @StateHint Tuple1 city, + @ArgumentHint({TABLE_AS_SET, SUPPORT_UPDATES}) Row scoreTable, + @ArgumentHint({TABLE_AS_SET, SUPPORT_UPDATES}) Row cityTable) + throws Exception { + final boolean wasMatch = isMatch(score, city); + if (isDelete(scoreTable) || isDelete(cityTable)) { + if (wasMatch) { + collect(Row.ofKind(RowKind.DELETE, (Object) null)); + } + } + + if (scoreTable != null) { + apply(score, scoreTable.getFieldAs("score"), scoreTable.getKind()); + } + if (cityTable != null) { + apply(city, cityTable.getFieldAs("city"), cityTable.getKind()); + } + if (isMatch(score, city)) { + collect( + Row.ofKind( + wasMatch ? RowKind.UPDATE_AFTER : RowKind.INSERT, + "score " + score.f0 + " in city " + city.f0)); + } + } + + public boolean isDelete(Row r) { + return r != null && r.getKind() == RowKind.DELETE; + } + + public boolean isMatch(Tuple1 score, Tuple1 city) { + return score.f0 != null && city.f0 != null; + } + + @Override + public ChangelogMode getChangelogMode(ChangelogContext changelogContext) { + return ChangelogMode.upsert(); + } + + private static void apply(Tuple1 t, T o, RowKind op) { + if (op == RowKind.INSERT || op == RowKind.UPDATE_AFTER) { + t.f0 = o; + } else { + t.f0 = null; + } + } + } + // -------------------------------------------------------------------------------------------- // Helpers // -------------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java index 45fb5dbb9b7cf..185eec0ae8865 100644 --- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java @@ -25,9 +25,11 @@ import org.apache.flink.table.functions.ProcessTableFunction; import org.apache.flink.table.functions.TableFunction; import org.apache.flink.table.functions.UserDefinedFunction; +import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.AppendProcessTableFunctionBase; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.DescriptorFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.EmptyArgFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.InvalidUpdatingSemanticsFunction; +import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.MultiInputFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.RequiredTimeFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ScalarArgsFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TableAsRowFunction; @@ -37,7 +39,6 @@ import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TypedTableAsRowFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TypedTableAsSetFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.UpdatingUpsertFunction; -import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.UpdatingUpsertPartialDeletesFunction; import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.User; import 
org.apache.flink.table.planner.utils.TableTestBase; import org.apache.flink.table.planner.utils.TableTestUtil; @@ -285,9 +286,9 @@ private static Stream errorSpecs() { "Function signature must not declare system arguments. Reserved argument names are: [on_time, uid]"), ErrorSpec.ofSelect( "multiple table args", - MultiTableFunction.class, + InvalidMultiTableWithRowFunction.class, "SELECT * FROM f(r1 => TABLE t, r2 => TABLE t)", - "Currently, only signatures with at most one table argument are supported."), + "All table arguments must use set semantics if multiple table arguments are declared."), ErrorSpec.ofSelect( "row instead of table", TableAsRowFunction.class, @@ -375,12 +376,58 @@ private static Stream errorSpecs() { "SELECT * FROM f(r => TABLE t_watermarked PARTITION BY name, on_time => DESCRIPTOR(ts))", "Time operations using the `on_time` argument are currently not supported for " + "PTFs that consume or produce updates."), - ErrorSpec.ofInsertInto( - "PTF returns partial deletes but full deleted are required", - UpdatingUpsertPartialDeletesFunction.class, - "INSERT INTO t_full_delete_sink SELECT * FROM f(r => TABLE t PARTITION BY name)", - "Unsupported changelog mode returned from PTF 'f'. The system requires that deletions include " - + "all fields in DELETE messages. Key-only deletes are not sufficient.")); + ErrorSpec.ofSelect( + "no pass-through for multiple table args", + InvalidPassThroughTables.class, + "SELECT * FROM f(r1 => TABLE t, r2 => TABLE t)", + "Pass-through columns are not supported if multiple table arguments are declared."), + ErrorSpec.ofSelect( + "on_time must be declared for multiple table args", + MultiInputFunction.class, + "SELECT * FROM f(in1 => TABLE t PARTITION BY name, in2 => TABLE t_watermarked PARTITION BY name, " + + "on_time => DESCRIPTOR(ts))", + "Invalid time attribute declaration. If multiple tables are declared, " + + "the `on_time` argument must reference a time column for each table argument " + + "or none. 
Missing time attributes for: [in1]"), + ErrorSpec.ofSelect( + "different partition keys by data type", + MultiInputFunction.class, + "SELECT * FROM f(in1 => TABLE t PARTITION BY score, in2 => TABLE t PARTITION BY name)", + "Invalid PARTITION BY columns. The number of columns and their data types must match across all " + + "involved table arguments. Given partition key sets: [INT NOT NULL], [VARCHAR(5) NOT NULL]"), + ErrorSpec.ofSelect( + "different partition keys by column count", + MultiInputFunction.class, + "SELECT * FROM f(in1 => TABLE t PARTITION BY score, in2 => TABLE t)", + "Invalid PARTITION BY columns. The number of columns and their data types must match across all " + + "involved table arguments. Given partition key sets: [INT NOT NULL], []"), + ErrorSpec.ofSelect( + "maximum table arguments reached", + HighMultiInputFunction.class, + "SELECT * FROM f(" + + "in1 => TABLE t PARTITION BY score, " + + "in2 => TABLE t PARTITION BY score, " + + "in3 => TABLE t PARTITION BY score, " + + "in4 => TABLE t PARTITION BY score, " + + "in5 => TABLE t PARTITION BY score, " + + "in6 => TABLE t PARTITION BY score, " + + "in7 => TABLE t PARTITION BY score, " + + "in8 => TABLE t PARTITION BY score, " + + "in9 => TABLE t PARTITION BY score, " + + "in10 => TABLE t PARTITION BY score, " + + "in11 => TABLE t PARTITION BY score, " + + "in12 => TABLE t PARTITION BY score, " + + "in13 => TABLE t PARTITION BY score, " + + "in14 => TABLE t PARTITION BY score, " + + "in15 => TABLE t PARTITION BY score, " + + "in16 => TABLE t PARTITION BY score, " + + "in17 => TABLE t PARTITION BY score, " + + "in18 => TABLE t PARTITION BY score, " + + "in19 => TABLE t PARTITION BY score, " + + "in20 => TABLE t PARTITION BY score, " + + "in21 => TABLE t PARTITION BY score" + + ")", + "Unsupported table argument count. Currently, the number of input tables is limited to 20.")); } /** Testing function. 
*/ @@ -390,11 +437,11 @@ public void eval(@ArgumentHint({TABLE_AS_ROW, SUPPORT_UPDATES}) User u, Integer } /** Testing function. */ - public static class MultiTableFunction extends ProcessTableFunction { + public static class InvalidMultiTableWithRowFunction extends ProcessTableFunction { @SuppressWarnings("unused") public void eval( @ArgumentHint({TABLE_AS_SET, OPTIONAL_PARTITION_BY}) Row r1, - @ArgumentHint({TABLE_AS_SET, OPTIONAL_PARTITION_BY}) Row r2) {} + @ArgumentHint(TABLE_AS_ROW) Row r2) {} } /** Testing function. */ @@ -436,6 +483,42 @@ public static class OptionalUntypedTable extends ProcessTableFunction { public void eval(@ArgumentHint(value = TABLE_AS_ROW, isOptional = true) Row r) {} } + /** Testing function. */ + public static class InvalidPassThroughTables extends ProcessTableFunction { + @SuppressWarnings("unused") + public void eval( + @ArgumentHint({TABLE_AS_SET, PASS_COLUMNS_THROUGH}) Row r1, + @ArgumentHint({TABLE_AS_SET, PASS_COLUMNS_THROUGH}) Row r2) {} + } + + /** Testing function. 
*/ + public static class HighMultiInputFunction extends AppendProcessTableFunctionBase { + @SuppressWarnings("unused") + public void eval( + @ArgumentHint(TABLE_AS_SET) Row in1, + @ArgumentHint(TABLE_AS_SET) Row in2, + @ArgumentHint(TABLE_AS_SET) Row in3, + @ArgumentHint(TABLE_AS_SET) Row in4, + @ArgumentHint(TABLE_AS_SET) Row in5, + @ArgumentHint(TABLE_AS_SET) Row in6, + @ArgumentHint(TABLE_AS_SET) Row in7, + @ArgumentHint(TABLE_AS_SET) Row in8, + @ArgumentHint(TABLE_AS_SET) Row in9, + @ArgumentHint(TABLE_AS_SET) Row in10, + @ArgumentHint(TABLE_AS_SET) Row in11, + @ArgumentHint(TABLE_AS_SET) Row in12, + @ArgumentHint(TABLE_AS_SET) Row in13, + @ArgumentHint(TABLE_AS_SET) Row in14, + @ArgumentHint(TABLE_AS_SET) Row in15, + @ArgumentHint(TABLE_AS_SET) Row in16, + @ArgumentHint(TABLE_AS_SET) Row in17, + @ArgumentHint(TABLE_AS_SET) Row in18, + @ArgumentHint(TABLE_AS_SET) Row in19, + @ArgumentHint(TABLE_AS_SET) Row in20, + @ArgumentHint(TABLE_AS_SET) Row in21) + throws Exception {} + } + private static class ErrorSpec { private final String description; private final Class functionClass; diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/ProcessTableRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/ProcessTableRunner.java index f285f0c3d461a..eda3a6c244269 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/ProcessTableRunner.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/ProcessTableRunner.java @@ -25,11 +25,11 @@ import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.StringData; import org.apache.flink.table.functions.ProcessTableFunction; +import org.apache.flink.table.runtime.operators.process.AbstractProcessTableOperator; +import org.apache.flink.table.runtime.operators.process.AbstractProcessTableOperator.RunnerContext; +import 
org.apache.flink.table.runtime.operators.process.AbstractProcessTableOperator.RunnerOnTimerContext; import org.apache.flink.table.runtime.operators.process.PassAllCollector; import org.apache.flink.table.runtime.operators.process.PassThroughCollectorBase; -import org.apache.flink.table.runtime.operators.process.ProcessTableOperator; -import org.apache.flink.table.runtime.operators.process.ProcessTableOperator.RunnerContext; -import org.apache.flink.table.runtime.operators.process.ProcessTableOperator.RunnerOnTimerContext; import org.apache.flink.types.RowKind; import org.apache.flink.util.function.RunnableWithException; @@ -40,7 +40,7 @@ /** * Abstraction of code-generated calls to {@link ProcessTableFunction} to be used within {@link - * ProcessTableOperator}. + * AbstractProcessTableOperator}. */ @Internal public abstract class ProcessTableRunner extends AbstractRichFunction { @@ -105,7 +105,7 @@ public void initialize( } public void ingestTableEvent(int pos, RowData row, int timeColumn) { - evalCollector.setPrefix(row); + evalCollector.setPrefix(pos, row); if (timeColumn == -1) { rowtime = null; } else { @@ -120,7 +120,7 @@ public void ingestTableEvent(int pos, RowData row, int timeColumn) { } public void ingestTimerEvent(RowData key, @Nullable StringData name, long timerTime) { - onTimerCollector.setPrefix(key); + onTimerCollector.setPrefix(-1, key); if (emitRowtime) { onTimerCollector.setRowtime(timerTime); } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/AbstractProcessTableOperator.java similarity index 88% rename from flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java rename to flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/AbstractProcessTableOperator.java 
index cf7a6bda043b5..5432a826cc264 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessTableOperator.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/AbstractProcessTableOperator.java @@ -33,14 +33,12 @@ import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2; import org.apache.flink.streaming.api.operators.InternalTimer; import org.apache.flink.streaming.api.operators.InternalTimerService; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorParameters; import org.apache.flink.streaming.api.operators.Triggerable; import org.apache.flink.streaming.api.watermark.Watermark; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.api.TableRuntimeException; import org.apache.flink.table.api.dataview.ListView; import org.apache.flink.table.api.dataview.MapView; @@ -73,15 +71,17 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; +import java.util.stream.Collectors; -/** Operator for {@link ProcessTableFunction}. */ -public class ProcessTableOperator extends AbstractStreamOperator - implements OneInputStreamOperator, Triggerable { +/** Base class for operators for {@link ProcessTableFunction}. 
*/ +@Internal +public abstract class AbstractProcessTableOperator extends AbstractStreamOperatorV2 + implements Triggerable { + + protected final List tableSemantics; + protected final ProcessTableRunner processTableRunner; - private final @Nullable RuntimeTableSemantics tableSemantics; private final List stateInfos; - private final ProcessTableRunner processTableRunner; private final HashFunction[] stateHashCode; private final RecordEqualiser[] stateEquals; private final RuntimeChangelogMode producedChangelogMode; @@ -97,15 +97,16 @@ public class ProcessTableOperator extends AbstractStreamOperator private transient @Nullable InternalTimerService namedTimerService; private transient @Nullable InternalTimerService unnamedTimerService; - public ProcessTableOperator( + public AbstractProcessTableOperator( StreamOperatorParameters parameters, - @Nullable RuntimeTableSemantics tableSemantics, + List tableSemantics, List stateInfos, ProcessTableRunner processTableRunner, HashFunction[] stateHashCode, RecordEqualiser[] stateEquals, RuntimeChangelogMode producedChangelogMode) { - super(parameters); + // Operator always has at least one input (i.e. 
empty values) + super(parameters, Math.max(tableSemantics.size(), 1)); this.tableSemantics = tableSemantics; this.stateInfos = stateInfos; this.processTableRunner = processTableRunner; @@ -144,15 +145,6 @@ public void open() throws Exception { FunctionUtils.openFunction(processTableRunner, DefaultOpenContext.INSTANCE); } - @Override - public void processElement(StreamRecord element) throws Exception { - // Set table argument - if (tableSemantics != null) { - processTableRunner.ingestTableEvent(0, element.getValue(), tableSemantics.timeColumn()); - } - processTableRunner.processEval(); - } - @Override public void processWatermark(Watermark mark) throws Exception { super.processWatermark(mark); @@ -185,9 +177,9 @@ public class RunnerContext implements ProcessTableFunction.Context { } private Map createTableSemanticsMap() { - return Optional.ofNullable(tableSemantics) - .map(s -> Map.of(tableSemantics.getArgName(), tableSemantics)) - .orElse(Map.of()); + return tableSemantics.stream() + .map(inputSemantics -> Map.entry(inputSemantics.getArgName(), inputSemantics)) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } private Map createStateNameToPosMap() { @@ -289,7 +281,8 @@ public class RunnerOnTimerContext extends RunnerContext @SuppressWarnings({"unchecked", "rawtypes"}) private void setTimerServices() { if (shouldEnableTimers()) { - final KeyedStateStore keyedStateStore = getKeyedStateStore(); + final KeyedStateStore keyedStateStore = + getKeyedStateStore().orElseThrow(IllegalStateException::new); final MapStateDescriptor namedTimersDescriptor = new MapStateDescriptor<>( "internal-named-timers-map", @@ -322,14 +315,20 @@ private void setTimeContext() { } private void setCollectors() { - if (tableSemantics == null || tableSemantics.passColumnsThrough()) { - evalCollector = new PassAllCollector(output, changelogMode); + final int tableCount = tableSemantics.size(); + if (tableCount == 0 + || 
tableSemantics.stream().anyMatch(RuntimeTableSemantics::passColumnsThrough)) { + assert tableCount <= 1; + // Collect from table event with all input columns (potentially none) + evalCollector = new PassAllCollector(output, changelogMode, 1); } else { - evalCollector = - new PassPartitionKeysCollector( - output, changelogMode, tableSemantics.partitionByColumns()); + // Collect from table event with partition keys for each table + evalCollector = new PassPartitionKeysCollector(output, changelogMode, tableSemantics); } - onTimerCollector = new PassAllCollector(output, changelogMode); + + // Collect with partition keys for each table but from timer events which only contains the + // key, so passing all columns is the right strategy + onTimerCollector = new PassAllCollector(output, changelogMode, tableCount); } private void setStateDescriptors() { @@ -374,9 +373,10 @@ private void setStateDescriptors() { } private void setStateHandles() { - final KeyedStateStore keyedStateStore = getKeyedStateStore(); final State[] stateHandles = new State[stateDescriptors.length]; for (int i = 0; i < stateDescriptors.length; i++) { + final KeyedStateStore keyedStateStore = + getKeyedStateStore().orElseThrow(IllegalStateException::new); final StateDescriptor stateDescriptor = stateDescriptors[i]; final State stateHandle; if (stateDescriptor instanceof ValueStateDescriptor) { @@ -396,12 +396,13 @@ private void setStateHandles() { } private boolean shouldEmitRowtime() { - return tableSemantics != null && tableSemantics.timeColumn() != -1; + return !tableSemantics.isEmpty() + && tableSemantics.stream().allMatch(input -> input.timeColumn() != -1); } private boolean shouldEnableTimers() { - return tableSemantics != null - && tableSemantics.hasSetSemantics() - && !tableSemantics.passColumnsThrough(); + return !tableSemantics.isEmpty() + && tableSemantics.stream() + .allMatch(input -> input.hasSetSemantics() && !input.passColumnsThrough()); } } diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassAllCollector.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassAllCollector.java index d82b115a684ee..ac00126f1e810 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassAllCollector.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassAllCollector.java @@ -28,11 +28,15 @@ @Internal public class PassAllCollector extends PassThroughCollectorBase { - public PassAllCollector(Output> output, ChangelogMode changelogMode) { - super(output, changelogMode); + public PassAllCollector( + Output> output, + ChangelogMode changelogMode, + int prefixRepetition) { + super(output, changelogMode, prefixRepetition); } - public void setPrefix(RowData input) { + @Override + public void setPrefix(int pos, RowData input) { prefix = input; } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassPartitionKeysCollector.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassPartitionKeysCollector.java index e37b4f37a1490..f5a309de568d3 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassPartitionKeysCollector.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassPartitionKeysCollector.java @@ -25,19 +25,31 @@ import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.utils.ProjectedRowData; -/** Forwards partition keys of the given row. */ +import java.util.List; +import java.util.stream.IntStream; + +/** Forwards partition keys of the given input's row. 
*/ @Internal public class PassPartitionKeysCollector extends PassThroughCollectorBase { + private final ProjectedRowData[] partitionKeys; + public PassPartitionKeysCollector( Output> output, ChangelogMode changelogMode, - int[] partitionKeys) { - super(output, changelogMode); - prefix = ProjectedRowData.from(partitionKeys); + List tableSemantics) { + super(output, changelogMode, tableSemantics.size()); + partitionKeys = new ProjectedRowData[tableSemantics.size()]; + IntStream.range(0, tableSemantics.size()) + .forEach( + pos -> + partitionKeys[pos] = + ProjectedRowData.from( + tableSemantics.get(pos).partitionByColumns())); } - public void setPrefix(RowData input) { - ((ProjectedRowData) prefix).replaceRow(input); + @Override + public void setPrefix(int pos, RowData input) { + prefix = partitionKeys[pos].replaceRow(input); } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassThroughCollectorBase.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassThroughCollectorBase.java index e8cc1015f8531..5b032c3891f93 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassThroughCollectorBase.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/PassThroughCollectorBase.java @@ -34,25 +34,30 @@ @Internal public abstract class PassThroughCollectorBase extends StreamRecordCollector { - private final JoinedRowData withPrefix; + private final RepeatedRowData repeatedPrefix; + private final JoinedRowData withFunctionOutput; private final JoinedRowData withRowtime; private final ChangelogMode changelogMode; - private RowData rowtime; protected RowData prefix; + private RowData rowtime; + public PassThroughCollectorBase( - Output> output, ChangelogMode changelogMode) { + Output> output, + ChangelogMode changelogMode, + int prefixRepetition) { super(output); 
this.changelogMode = changelogMode; - // constructs a flattened row of [[prefix | function output] | rowtime] - withPrefix = new JoinedRowData(); + // constructs a flattened row of [[[prefix]{1,n} | function output] | rowtime] + repeatedPrefix = new RepeatedRowData(prefixRepetition); + withFunctionOutput = new JoinedRowData(); withRowtime = new JoinedRowData(); prefix = GenericRowData.of(); rowtime = GenericRowData.of(); } - public abstract void setPrefix(RowData input); + public abstract void setPrefix(int pos, RowData input); public void setRowtime(Long time) { rowtime = GenericRowData.of(TimestampData.fromEpochMillis(time)); @@ -60,8 +65,9 @@ public void setRowtime(Long time) { @Override public void collect(RowData functionOutput) { - withPrefix.replace(prefix, functionOutput); - withRowtime.replace(withPrefix, rowtime); + repeatedPrefix.replace(prefix); + withFunctionOutput.replace(repeatedPrefix, functionOutput); + withRowtime.replace(withFunctionOutput, rowtime); // Forward supported change flags. final RowKind kind = functionOutput.getRowKind(); if (!changelogMode.contains(kind)) { diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessRowTableOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessRowTableOperator.java new file mode 100644 index 0000000000000..075f589e926a6 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/ProcessRowTableOperator.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.operators.process;

import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.table.functions.TableSemantics;
import org.apache.flink.table.runtime.generated.HashFunction;
import org.apache.flink.table.runtime.generated.ProcessTableRunner;
import org.apache.flink.table.runtime.generated.RecordEqualiser;

import javax.annotation.Nullable;

import java.util.List;

/**
 * Implementation of {@link OneInputStreamOperator} for {@link ProcessTableFunction} with at most
 * one table with row semantics.
 *
 * <p>This class is required because {@link MultipleInputStreamOperator} has issues with chaining
 * when the transformation is not keyed.
 */
public class ProcessRowTableOperator extends AbstractProcessTableOperator
        implements OneInputStreamOperator<RowData, RowData> {

    // Semantics of the single row-semantics table argument, or null if the function
    // takes no table argument at all (e.g. a constant invocation).
    private final @Nullable TableSemantics inputSemantics;

    public ProcessRowTableOperator(
            StreamOperatorParameters<RowData> parameters,
            List<RuntimeTableSemantics> tableSemantics,
            List<RuntimeStateInfo> stateInfos,
            ProcessTableRunner processTableRunner,
            HashFunction[] stateHashCode,
            RecordEqualiser[] stateEquals,
            RuntimeChangelogMode producedChangelogMode) {
        super(
                parameters,
                tableSemantics,
                stateInfos,
                processTableRunner,
                stateHashCode,
                stateEquals,
                producedChangelogMode);
        if (tableSemantics.isEmpty()) {
            inputSemantics = null;
        } else {
            inputSemantics = tableSemantics.get(0);
        }
    }

    @Override
    public void setKeyContextElement1(StreamRecord<RowData> record) {
        // Not applicable: this operator is only used for non-keyed (row semantics) input.
    }

    @Override
    public void processElement(StreamRecord<RowData> element) throws Exception {
        if (inputSemantics != null) {
            // Input position is always 0 because at most one table argument exists.
            processTableRunner.ingestTableEvent(0, element.getValue(), inputSemantics.timeColumn());
        }
        processTableRunner.processEval();
    }

    @Override
    public void processWatermarkStatus(WatermarkStatus watermarkStatus) throws Exception {
        // Single logical input: forward with input count 1.
        super.processWatermarkStatus(watermarkStatus, 1);
    }

    @Override
    public void processLatencyMarker(LatencyMarker latencyMarker) throws Exception {
        super.reportOrForwardLatencyMarker(latencyMarker);
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.operators.process;

import org.apache.flink.streaming.api.operators.AbstractInput;
import org.apache.flink.streaming.api.operators.Input;
import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.table.functions.TableSemantics;
import org.apache.flink.table.runtime.generated.HashFunction;
import org.apache.flink.table.runtime.generated.ProcessTableRunner;
import org.apache.flink.table.runtime.generated.RecordEqualiser;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Implementation of {@link MultipleInputStreamOperator} for {@link ProcessTableFunction} with at
 * least one table with set semantics.
 */
public class ProcessSetTableOperator extends AbstractProcessTableOperator
        implements MultipleInputStreamOperator<RowData> {

    public ProcessSetTableOperator(
            StreamOperatorParameters<RowData> parameters,
            List<RuntimeTableSemantics> tableSemantics,
            List<RuntimeStateInfo> stateInfos,
            ProcessTableRunner processTableRunner,
            HashFunction[] stateHashCode,
            RecordEqualiser[] stateEquals,
            RuntimeChangelogMode producedChangelogMode) {
        super(
                parameters,
                tableSemantics,
                stateInfos,
                processTableRunner,
                stateHashCode,
                stateEquals,
                producedChangelogMode);
    }

    /**
     * Creates one {@link Input} per table argument.
     *
     * <p>Input ids are 1-based (hence {@code inputIdx + 1}); the 0-based {@code inputIdx} is what
     * the runner uses to tell the function which table argument produced the event.
     */
    @Override
    @SuppressWarnings("rawtypes")
    public List<Input> getInputs() {
        return IntStream.range(0, tableSemantics.size())
                .mapToObj(
                        inputIdx -> {
                            final TableSemantics inputSemantics = tableSemantics.get(inputIdx);
                            // Hoisted: the time column index is constant per input.
                            final int timeColumn = inputSemantics.timeColumn();
                            return new AbstractInput(this, inputIdx + 1) {
                                @Override
                                public void processElement(StreamRecord element)
                                        throws Exception {
                                    processTableRunner.ingestTableEvent(
                                            inputIdx, (RowData) element.getValue(), timeColumn);
                                    processTableRunner.processEval();
                                }
                            };
                        })
                .collect(Collectors.toList());
    }
}
import org.apache.flink.streaming.api.operators.StreamOperatorParameters; import org.apache.flink.table.data.RowData; @@ -30,18 +29,15 @@ import org.apache.flink.table.runtime.generated.ProcessTableRunner; import org.apache.flink.table.runtime.generated.RecordEqualiser; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.List; -/** The factory of {@link ProcessTableOperator}. */ -public class ProcessTableOperatorFactory extends AbstractStreamOperatorFactory - implements OneInputStreamOperatorFactory { +/** The factory for subclasses of {@link AbstractProcessTableOperator}. */ +public class ProcessTableOperatorFactory extends AbstractStreamOperatorFactory { private static final long serialVersionUID = 1L; - private final @Nullable RuntimeTableSemantics tableSemantics; + private final List tableSemantics; private final List stateInfos; private final GeneratedProcessTableRunner generatedProcessTableRunner; private final GeneratedHashFunction[] generatedStateHashCode; @@ -49,7 +45,7 @@ public class ProcessTableOperatorFactory extends AbstractStreamOperatorFactory tableSemantics, List stateInfos, GeneratedProcessTableRunner generatedProcessTableRunner, GeneratedHashFunction[] generatedStateHashCode, @@ -76,19 +72,34 @@ public StreamOperator createStreamOperator(StreamOperatorParameters parameters) Arrays.stream(generatedStateEquals) .map(g -> g.newInstance(classLoader)) .toArray(RecordEqualiser[]::new); - return new ProcessTableOperator( - parameters, - tableSemantics, - stateInfos, - runner, - stateHashCode, - stateEquals, - producedChangelogMode); + if (tableSemantics.stream().anyMatch(RuntimeTableSemantics::hasSetSemantics)) { + return new ProcessSetTableOperator( + parameters, + tableSemantics, + stateInfos, + runner, + stateHashCode, + stateEquals, + producedChangelogMode); + } else { + return new ProcessRowTableOperator( + parameters, + tableSemantics, + stateInfos, + runner, + stateHashCode, + stateEquals, + producedChangelogMode); + } 
} @Override @SuppressWarnings("rawtypes") public Class getStreamOperatorClass(ClassLoader classLoader) { - return ProcessTableOperator.class; + if (tableSemantics.stream().anyMatch(RuntimeTableSemantics::hasSetSemantics)) { + return ProcessSetTableOperator.class; + } else { + return ProcessRowTableOperator.class; + } } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RepeatedRowData.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RepeatedRowData.java new file mode 100644 index 0000000000000..3cdfa23a00b83 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RepeatedRowData.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.process; + +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.data.RawValueData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.TimestampData; +import org.apache.flink.types.RowKind; + +/** A row that repeats the columns of a given row by the given count. */ +public class RepeatedRowData implements RowData { + + private final int count; + private RowData row; + + public RepeatedRowData(int count) { + this.count = count; + } + + /** + * Replaces the {@link RowData} backing this {@link RepeatedRowData}. + * + *

This method replaces the backing rows in place and does not return a new object. This is + * done for performance reasons. + */ + public RepeatedRowData replace(RowData row) { + this.row = row; + return this; + } + + @Override + public int getArity() { + return row.getArity() * count; + } + + @Override + public RowKind getRowKind() { + return row.getRowKind(); + } + + @Override + public void setRowKind(RowKind kind) { + row.setRowKind(kind); + } + + @Override + public boolean isNullAt(int pos) { + return row.isNullAt(pos / count); + } + + @Override + public boolean getBoolean(int pos) { + return row.getBoolean(pos / count); + } + + @Override + public byte getByte(int pos) { + return row.getByte(pos / count); + } + + @Override + public short getShort(int pos) { + return row.getShort(pos / count); + } + + @Override + public int getInt(int pos) { + return row.getInt(pos / count); + } + + @Override + public long getLong(int pos) { + return row.getLong(pos / count); + } + + @Override + public float getFloat(int pos) { + return row.getFloat(pos / count); + } + + @Override + public double getDouble(int pos) { + return row.getDouble(pos / count); + } + + @Override + public StringData getString(int pos) { + return row.getString(pos / count); + } + + @Override + public DecimalData getDecimal(int pos, int precision, int scale) { + return row.getDecimal(pos / count, precision, scale); + } + + @Override + public TimestampData getTimestamp(int pos, int precision) { + return row.getTimestamp(pos / count, precision); + } + + @Override + public RawValueData getRawValue(int pos) { + return row.getRawValue(pos / count); + } + + @Override + public byte[] getBinary(int pos) { + return row.getBinary(pos / count); + } + + @Override + public ArrayData getArray(int pos) { + return row.getArray(pos / count); + } + + @Override + public MapData getMap(int pos) { + return row.getMap(pos / count); + } + + @Override + public RowData getRow(int pos, int numFields) { + return row.getRow(pos / 
count, numFields); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RuntimeTableSemantics.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RuntimeTableSemantics.java index 378f494b79e57..c5d4860d8f46b 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RuntimeTableSemantics.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/process/RuntimeTableSemantics.java @@ -24,7 +24,6 @@ import org.apache.flink.table.types.DataType; import java.io.Serializable; -import java.util.List; import java.util.Optional; /** @@ -108,11 +107,6 @@ public int timeColumn() { return timeColumn; } - @Override - public List coPartitionArgs() { - return List.of(); - } - @Override public Optional changelogMode() { return Optional.of(getChangelogMode());