diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 703e1ce9d5975..f28b5ede8a389 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -7,7 +7,7 @@ - Make sure that the pull request corresponds to a [JIRA issue](https://issues.apache.org/jira/projects/FLINK/issues). Exceptions are made for typos in JavaDoc or documentation files, which need no JIRA issue. - - Name the pull request in the form "[FLINK-1234] [component] Title of the pull request", where *FLINK-1234* should be replaced by the actual issue number. Skip *component* if you are unsure about which is the best component. + - Name the pull request in the form "[FLINK-XXXX] [component] Title of the pull request", where *FLINK-XXXX* should be replaced by the actual issue number. Skip *component* if you are unsure about which is the best component. Typo fixes that have no associated JIRA issue should be named following this pattern: `[hotfix] [docs] Fix typo in event time introduction` or `[hotfix] [javadocs] Expand JavaDoc for PuncuatedWatermarkGenerator`. - Fill out the template below to describe the changes contributed by the pull request. That will give reviewers the context they need to do the review. diff --git a/docs/_config.yml b/docs/_config.yml index 548278363a416..5a92bb99f8ac5 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -46,11 +46,12 @@ download_url: "http://flink.apache.org/downloads.html" # please use a protocol relative URL here baseurl: //ci.apache.org/projects/flink/flink-docs-release-1.4 -# Flag whether this is the latest stable version or not. If not, a warning -# will be printed pointing to the docs of the latest stable version. -is_latest: true +# Flag whether this is a stable version or not. Used for the quickstart page. is_stable: false +# Flag to indicate whether an outdated warning should be shown. +show_outdated_warning: false + previous_docs: 1.3: http://ci.apache.org/projects/flink/flink-docs-release-1.3 1.2: http://ci.apache.org/projects/flink/flink-docs-release-1.2 diff --git a/docs/_layouts/base.html b/docs/_layouts/base.html index d51451a3543f3..691670c47fdc0 100644 --- a/docs/_layouts/base.html +++ b/docs/_layouts/base.html @@ -54,12 +54,10 @@ - {% if site.is_stable %} - {% unless site.is_latest %} -
+ {% if site.show_outdated_warning %} +
This documentation is for an out-of-date version of Apache Flink. We recommend you use the latest stable version.
- {% endunless %} {% endif %} diff --git a/docs/_layouts/plain.html b/docs/_layouts/plain.html index 63a6681ed4afd..e991f78f82be3 100644 --- a/docs/_layouts/plain.html +++ b/docs/_layouts/plain.html @@ -53,5 +53,10 @@

{{ page.title }}{% if page.is_beta %} Beta{% endif %}

+{% if site.show_outdated_warning %} + +{% endif %} {{ content }} diff --git a/docs/dev/connectors/kafka.md b/docs/dev/connectors/kafka.md index 042ad11bc7705..f95c8c09ee477 100644 --- a/docs/dev/connectors/kafka.md +++ b/docs/dev/connectors/kafka.md @@ -475,8 +475,14 @@ are other constructor variants that allow providing the following: ### Kafka Producers and Fault Tolerance -With Flink's checkpointing enabled, the Flink Kafka Producer can provide -at-least-once delivery guarantees. +#### Kafka 0.8 + +Before 0.9 Kafka did not provide any mechanisms to guarantee at-least-once or exactly-once semantics. + +#### Kafka 0.9 and 0.10 + +With Flink's checkpointing enabled, the `FlinkKafkaProducer09` and `FlinkKafkaProducer010` +can provide at-least-once delivery guarantees. Besides enabling Flink's checkpointing, you should also configure the setter methods `setLogFailuresOnly(boolean)` and `setFlushOnCheckpoint(boolean)` appropriately, @@ -499,6 +505,19 @@ we recommend setting the number of retries to a higher value. **Note**: There is currently no transactional producer for Kafka, so Flink can not guarantee exactly-once delivery into a Kafka topic. +
+ Attention: Depending on your Kafka configuration, even after Kafka acknowledges + writes you can still experience data loss. In particular keep in mind the following Kafka settings: + + Default values for the above options can easily lead to data loss. Please refer to Kafka documentation + for more explanation. +
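For reference, here is a minimal at-least-once configuration sketch (not part of this diff; the topic name, properties, and the `env`/`stream` variables are illustrative) combining checkpointing with the setter methods mentioned above:

{% highlight java %}
// enable checkpointing so that the producer can take part in at-least-once delivery
env.enableCheckpointing(5000);

Properties producerProps = new Properties();
producerProps.setProperty("bootstrap.servers", "localhost:9092");
// raise the number of client-side retries, as recommended above
producerProps.setProperty("retries", "3");

FlinkKafkaProducer010<String> producer = new FlinkKafkaProducer010<>(
        "my-topic",                 // illustrative target topic
        new SimpleStringSchema(),
        producerProps);

// do not swallow write failures; fail and restart from the last checkpoint instead
producer.setLogFailuresOnly(false);
// flush pending records on every checkpoint to get at-least-once guarantees
producer.setFlushOnCheckpoint(true);

stream.addSink(producer);
{% endhighlight %}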
+ ## Using Kafka timestamps and Flink event time in Kafka 0.10 Since Apache Kafka 0.10+, Kafka's messages can carry [timestamps](https://cwiki.apache.org/confluence/display/KAFKA/KIP-32+-+Add+timestamps+to+Kafka+message), indicating diff --git a/docs/dev/connectors/kinesis.md b/docs/dev/connectors/kinesis.md index 5fbf24b8d3e7f..3ffe1c41e7922 100644 --- a/docs/dev/connectors/kinesis.md +++ b/docs/dev/connectors/kinesis.md @@ -256,23 +256,29 @@ consumer when calling this API can also be modified by using the other keys pref ## Kinesis Producer -The `FlinkKinesisProducer` is used for putting data from a Flink stream into a Kinesis stream. Note that the producer is not participating in -Flink's checkpointing and doesn't provide exactly-once processing guarantees. -Also, the Kinesis producer does not guarantee that records are written in order to the shards (See [here](https://github.com/awslabs/amazon-kinesis-producer/issues/23) and [here](http://docs.aws.amazon.com/kinesis/latest/APIReference/API_PutRecord.html#API_PutRecord_RequestSyntax) for more details). +The `FlinkKinesisProducer` uses [Kinesis Producer Library (KPL)](http://docs.aws.amazon.com/streams/latest/dev/developing-producers-with-kpl.html) to put data from a Flink stream into a Kinesis stream. + +Note that the producer is not participating in Flink's checkpointing and doesn't provide exactly-once processing guarantees. Also, the Kinesis producer does not guarantee that records are written in order to the shards (See [here](https://github.com/awslabs/amazon-kinesis-producer/issues/23) and [here](http://docs.aws.amazon.com/kinesis/latest/APIReference/API_PutRecord.html#API_PutRecord_RequestSyntax) for more details). In case of a failure or a resharding, data will be written again to Kinesis, leading to duplicates. This behavior is usually called "at-least-once" semantics. To put data into a Kinesis stream, make sure the stream is marked as "ACTIVE" in the AWS dashboard. -For the monitoring to work, the user accessing the stream needs access to the Cloud watch service. +For the monitoring to work, the user accessing the stream needs access to the CloudWatch service.
{% highlight java %} Properties producerConfig = new Properties(); -producerConfig.put(ProducerConfigConstants.AWS_REGION, "us-east-1"); -producerConfig.put(ProducerConfigConstants.AWS_ACCESS_KEY_ID, "aws_access_key_id"); -producerConfig.put(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, "aws_secret_access_key"); +// Required configs +producerConfig.put(AWSConfigConstants.AWS_REGION, "us-east-1"); +producerConfig.put(AWSConfigConstants.AWS_ACCESS_KEY_ID, "aws_access_key_id"); +producerConfig.put(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "aws_secret_access_key"); +// Optional configs +producerConfig.put("AggregationMaxCount", "4294967295"); +producerConfig.put("CollectionMaxCount", "1000"); +producerConfig.put("RecordTtl", "30000"); +producerConfig.put("RequestTimeout", "6000"); FlinkKinesisProducer kinesis = new FlinkKinesisProducer<>(new SimpleStringSchema(), producerConfig); kinesis.setFailOnError(true); @@ -286,9 +292,15 @@ simpleStringStream.addSink(kinesis);
{% highlight scala %} val producerConfig = new Properties(); -producerConfig.put(ProducerConfigConstants.AWS_REGION, "us-east-1"); -producerConfig.put(ProducerConfigConstants.AWS_ACCESS_KEY_ID, "aws_access_key_id"); -producerConfig.put(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, "aws_secret_access_key"); +// Required configs +producerConfig.put(AWSConfigConstants.AWS_REGION, "us-east-1"); +producerConfig.put(AWSConfigConstants.AWS_ACCESS_KEY_ID, "aws_access_key_id"); +producerConfig.put(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "aws_secret_access_key"); +// Optional KPL configs +producerConfig.put("AggregationMaxCount", "4294967295"); +producerConfig.put("CollectionMaxCount", "1000"); +producerConfig.put("RecordTtl", "30000"); +producerConfig.put("RequestTimeout", "6000"); val kinesis = new FlinkKinesisProducer[String](new SimpleStringSchema, producerConfig); kinesis.setFailOnError(true); @@ -301,15 +313,14 @@ simpleStringStream.addSink(kinesis);
-The above is a simple example of using the producer. Configuration for the producer with the mandatory configuration values is supplied with a `java.util.Properties` -instance as described above for the consumer. The example demonstrates producing a single Kinesis stream in the AWS region "us-east-1". +The above is a simple example of using the producer. To initialize `FlinkKinesisProducer`, users are required to pass in `AWS_REGION`, `AWS_ACCESS_KEY_ID`, and `AWS_SECRET_ACCESS_KEY` via a `java.util.Properties` instance. Users can also pass in KPL's configurations as optional parameters to customize the KPL underlying `FlinkKinesisProducer`. The full list of KPL configs and explanations can be found [here](https://github.com/awslabs/amazon-kinesis-producer/blob/master/java/amazon-kinesis-producer-sample/default_config.properties). The example demonstrates producing a single Kinesis stream in the AWS region "us-east-1". + +If users don't specify any KPL configs and values, `FlinkKinesisProducer` will use default config values of KPL, except `RateLimit`. `RateLimit` limits the maximum allowed put rate for a shard, as a percentage of the backend limits. KPL's default value is 150 but it makes KPL throw `RateLimitExceededException` too frequently and breaks Flink sink as a result. Thus `FlinkKinesisProducer` overrides KPL's default value to 100. Instead of a `SerializationSchema`, it also supports a `KinesisSerializationSchema`. The `KinesisSerializationSchema` allows to send the data to multiple streams. This is done using the `KinesisSerializationSchema.getTargetStream(T element)` method. Returning `null` there will instruct the producer to write the element to the default stream. Otherwise, the returned stream name is used. -Other optional configuration keys for the producer can be found in `ProducerConfigConstants`. - ## Using Non-AWS Kinesis Endpoints for Testing @@ -317,29 +328,29 @@ It is sometimes desirable to have Flink operate as a consumer or producer agains [Kinesalite](https://github.com/mhart/kinesalite); this is especially useful when performing functional testing of a Flink application. The AWS endpoint that would normally be inferred by the AWS region set in the Flink configuration must be overridden via a configuration property. -To override the AWS endpoint, taking the producer for example, set the `ProducerConfigConstants.AWS_ENDPOINT` property in the -Flink configuration, in addition to the `ProducerConfigConstants.AWS_REGION` required by Flink. Although the region is +To override the AWS endpoint, taking the producer for example, set the `AWSConfigConstants.AWS_ENDPOINT` property in the +Flink configuration, in addition to the `AWSConfigConstants.AWS_REGION` required by Flink. Although the region is required, it will not be used to determine the AWS endpoint URL. -The following example shows how one might supply the `ProducerConfigConstants.AWS_ENDPOINT` configuration property: +The following example shows how one might supply the `AWSConfigConstants.AWS_ENDPOINT` configuration property:
{% highlight java %} Properties producerConfig = new Properties(); -producerConfig.put(ProducerConfigConstants.AWS_REGION, "us-east-1"); -producerConfig.put(ProducerConfigConstants.AWS_ACCESS_KEY_ID, "aws_access_key_id"); -producerConfig.put(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, "aws_secret_access_key"); -producerConfig.put(ProducerConfigConstants.AWS_ENDPOINT, "http://localhost:4567"); +producerConfig.put(AWSConfigConstants.AWS_REGION, "us-east-1"); +producerConfig.put(AWSConfigConstants.AWS_ACCESS_KEY_ID, "aws_access_key_id"); +producerConfig.put(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "aws_secret_access_key"); +producerConfig.put(AWSConfigConstants.AWS_ENDPOINT, "http://localhost:4567"); {% endhighlight %}
{% highlight scala %} val producerConfig = new Properties(); -producerConfig.put(ProducerConfigConstants.AWS_REGION, "us-east-1"); -producerConfig.put(ProducerConfigConstants.AWS_ACCESS_KEY_ID, "aws_access_key_id"); -producerConfig.put(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, "aws_secret_access_key"); -producerConfig.put(ProducerConfigConstants.AWS_ENDPOINT, "http://localhost:4567"); +producerConfig.put(AWSConfigConstants.AWS_REGION, "us-east-1"); +producerConfig.put(AWSConfigConstants.AWS_ACCESS_KEY_ID, "aws_access_key_id"); +producerConfig.put(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "aws_secret_access_key"); +producerConfig.put(AWSConfigConstants.AWS_ENDPOINT, "http://localhost:4567"); {% endhighlight %}
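Returning to the producer section above: instead of a plain `SerializationSchema`, a `KinesisSerializationSchema` can route records to multiple streams via `getTargetStream()`. The following is a rough sketch only; the stream name and routing rule are made up for illustration, and `producerConfig` is the `Properties` instance from the earlier snippets.

{% highlight java %}
KinesisSerializationSchema<String> schema = new KinesisSerializationSchema<String>() {

    @Override
    public ByteBuffer serialize(String element) {
        return ByteBuffer.wrap(element.getBytes(StandardCharsets.UTF_8));
    }

    @Override
    public String getTargetStream(String element) {
        // returning null writes the element to the producer's default stream
        return element.startsWith("ERROR") ? "error-stream" : null;
    }
};

FlinkKinesisProducer<String> kinesis = new FlinkKinesisProducer<>(schema, producerConfig);
{% endhighlight %}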
diff --git a/docs/dev/libs/cep.md b/docs/dev/libs/cep.md index fef19678bd6c6..492f95fb8bd0a 100644 --- a/docs/dev/libs/cep.md +++ b/docs/dev/libs/cep.md @@ -163,8 +163,9 @@ In FlinkCEP, looping patterns can be specified using these methods: `pattern.one more occurrences of a given event (e.g. the `b+` mentioned previously); and `pattern.times(#ofTimes)`, for patterns that expect a specific number of occurrences of a given type of event, e.g. 4 `a`'s; and `pattern.times(#fromTimes, #toTimes)`, for patterns that expect a specific minimum number of occurrences and maximum number of occurrences of a given type of event, -e.g. 2-4 `a`s. All patterns, looping or not, can be made optional using the `pattern.optional()` method. For a pattern -named `start`, the following are valid quantifiers: +e.g. 2-4 `a`s. Looping patterns can be made greedy using the `pattern.greedy()` method and group pattern cannot be made greedy +currently. All patterns, looping or not, can be made optional using the `pattern.optional()` method. +For a pattern named `start`, the following are valid quantifiers:
@@ -178,14 +179,35 @@ named `start`, the following are valid quantifiers: // expecting 2, 3 or 4 occurrences start.times(2, 4); + // expecting 2, 3 or 4 occurrences and repeating as many as possible + start.times(2, 4).greedy(); + // expecting 0, 2, 3 or 4 occurrences start.times(2, 4).optional(); + // expecting 0, 2, 3 or 4 occurrences and repeating as many as possible + start.times(2, 4).optional().greedy(); + // expecting 1 or more occurrences start.oneOrMore(); + // expecting 1 or more occurrences and repeating as many as possible + start.oneOrMore().greedy(); + // expecting 0 or more occurrences start.oneOrMore().optional(); + + // expecting 0 or more occurrences and repeating as many as possible + start.oneOrMore().optional().greedy(); + + // expecting 2 or more occurrences + start.timesOrMore(2); + + // expecting 2 or more occurrences and repeating as many as possible + start.timesOrMore(2).greedy(); + + // expecting 0, 2 or more occurrences and repeating as many as possible + start.timesOrMore(2).optional().greedy(); {% endhighlight %}
@@ -200,14 +222,38 @@ named `start`, the following are valid quantifiers: // expecting 2, 3 or 4 occurrences start.times(2, 4); + // expecting 2, 3 or 4 occurrences and repeating as many as possible + start.times(2, 4).greedy(); + // expecting 0, 2, 3 or 4 occurrences start.times(2, 4).optional(); + // expecting 0, 2, 3 or 4 occurrences and repeating as many as possible + start.times(2, 4).optional().greedy(); + // expecting 1 or more occurrences start.oneOrMore() + // expecting 1 or more occurrences and repeating as many as possible + start.oneOrMore().greedy(); + // expecting 0 or more occurrences start.oneOrMore().optional() + + // expecting 0 or more occurrences and repeating as many as possible + start.oneOrMore().optional().greedy(); + + // expecting 2 or more occurrences + start.timesOrMore(2); + + // expecting 2 or more occurrences and repeating as many as possible + start.timesOrMore(2).greedy(); + + // expecting 0, 2 or more occurrences + start.timesOrMore(2).optional(); + + // expecting 0, 2 or more occurrences and repeating as many as possible + start.timesOrMore(2).optional().greedy(); {% endhighlight %}
@@ -476,6 +522,18 @@ pattern.subtype(SubEvent.class); pattern.oneOrMore(); {% endhighlight %} + + + timesOrMore(#times) + +

Specifies that this pattern expects at least #times occurrences + of a matching event.

+

By default a relaxed internal contiguity (between subsequent events) is used. For more info on + internal contiguity see consecutive.

+{% highlight java %} +pattern.timesOrMore(2); +{% endhighlight %} + times(#ofTimes) @@ -507,6 +565,16 @@ pattern.times(2, 4); aforementioned quantifiers.

{% highlight java %} pattern.oneOrMore().optional(); +{% endhighlight %} + + + + greedy() + +

Specifies that this pattern is greedy, i.e. it will repeat as many times as possible. This is only applicable + to quantifiers; group patterns cannot currently be made greedy.

+{% highlight java %} +pattern.oneOrMore().greedy(); {% endhighlight %} @@ -647,6 +715,18 @@ pattern.oneOrMore() {% endhighlight %} + + timesOrMore(#times) + +

Specifies that this pattern expects at least #times occurrences + of a matching event.

+

By default a relaxed internal contiguity (between subsequent events) is used. For more info on + internal contiguity see consecutive.

+{% highlight scala %} +pattern.timesOrMore(2) +{% endhighlight %} + + times(#ofTimes) @@ -677,6 +757,16 @@ pattern.times(2, 4); aforementioned quantifiers.

{% highlight scala %} pattern.oneOrMore().optional() +{% endhighlight %} + + + + greedy() + +

Specifies that this pattern is greedy, i.e. it will repeat as many times as possible. This is only applicable + to quantifiers; group patterns cannot currently be made greedy.

+{% highlight scala %} +pattern.oneOrMore().greedy() {% endhighlight %} @@ -1160,6 +1250,105 @@ pattern.within(Time.seconds(10))
+### After Match Skip Strategy + +For a given pattern, same event may be assigned to multiple successful matches. In order to control to how many matches an event will be assigned, we need to specify the skip strategy called `AfterMatchSkipStrategy`. +There're four types of skip strategies, listed as follows: + +* *NO_SKIP*: Every possible match will be emitted. +* *SKIP_PAST_LAST_EVENT*: Discards every partial match that contains event of the match. +* *SKIP_TO_FIRST*: Discards every partial match that contains event of the match preceding the first of *PatternName*. +* *SKIP_TO_LAST*: Discards every partial match that contains event of the match preceding the last of *PatternName*. + +Notice that when using *SKIP_TO_FIRST* and *SKIP_TO_LAST* skip strategy, a valid *PatternName* should also be specified. + +Let's take an example: For a given pattern `a b{2}` and a data stream `ab1, ab2, ab3, ab4, ab5, ab6`, the differences between these four skip strategies can be listed as follows: + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Skip Strategy | Result | Description
NO_SKIP + ab1 ab2 ab3
+ ab2 ab3 ab4
+ ab3 ab4 ab5
+ ab4 ab5 ab6
+
After finding the match ab1 ab2 ab3, the match process will not discard any results.
SKIP_PAST_LAST_EVENT + ab1 ab2 ab3
+ ab4 ab5 ab6
+
After finding the match ab1 ab2 ab3, the match process will discard all partial matches that have already started.
SKIP_TO_FIRST[b] + ab1 ab2 ab3
+ ab2 ab3 ab4
+ ab3 ab4 ab5
+ ab4 ab5 ab6
+
After finding the match ab1 ab2 ab3, the match process will discard all partial matches containing ab1, which is the only event that comes before the first b.
SKIP_TO_LAST[b] + ab1 ab2 ab3
+ ab3 ab4 ab5
+
After finding the match ab1 ab2 ab3, the match process will discard all partial matches containing ab1 and ab2, which are the events that come before the last b.
+ +To specify which skip strategy to use, just create an `AfterMatchSkipStrategy` by calling: + + + + + + + + + + + + + + + + + + + + + +
Function | Description
AfterMatchSkipStrategy.noSkip() | Create a NO_SKIP skip strategy
AfterMatchSkipStrategy.skipPastLastEvent() | Create a SKIP_PAST_LAST_EVENT skip strategy
AfterMatchSkipStrategy.skipToFirst(patternName) | Create a SKIP_TO_FIRST skip strategy with the referenced pattern name patternName
AfterMatchSkipStrategy.skipToLast(patternName) | Create a SKIP_TO_LAST skip strategy with the referenced pattern name patternName
+ +Then apply the skip strategy to a pattern by calling: + +
+
+{% highlight java %} +AfterMatchSkipStrategy skipStrategy = ... +Pattern.begin("patternName", skipStrategy); +{% endhighlight %} +
+
+{% highlight scala %} +val skipStrategy = ... +Pattern.begin("patternName", skipStrategy) +{% endhighlight %} +
+
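As a concrete illustration (a sketch only; the `Event` type and its `getName()` accessor follow the conventions of the other examples on this page), the `a b{2}` pattern from the table above could be declared with a skip strategy like this:

{% highlight java %}
AfterMatchSkipStrategy skipStrategy = AfterMatchSkipStrategy.skipToFirst("b");

Pattern<Event, ?> pattern = Pattern.<Event>begin("a", skipStrategy)
    .where(new SimpleCondition<Event>() {
        @Override
        public boolean filter(Event event) throws Exception {
            // every abN event in the example stream matches both "a" and "b"
            return event.getName().contains("a");
        }
    })
    .next("b")
    .where(new SimpleCondition<Event>() {
        @Override
        public boolean filter(Event event) throws Exception {
            return event.getName().contains("b");
        }
    })
    .times(2);
{% endhighlight %}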
+ ## Detecting Patterns After specifying the pattern sequence you are looking for, it is time to apply it to your input stream to detect @@ -1279,63 +1468,75 @@ and `flatSelect` API calls allow a timeout handler to be specified. This timeout partial event sequence. The timeout handler receives all the events that have been matched so far by the pattern, and the timestamp when the timeout was detected. +In order to treat partial patterns, the `select` and `flatSelect` API calls offer an overloaded version which takes as +parameters + + * `PatternTimeoutFunction`/`PatternFlatTimeoutFunction` + * [OutputTag]({{ site.baseurl }}/dev/stream/side_output.html) for the side output in which the timeouted matches will be returned + * and the known `PatternSelectFunction`/`PatternFlatSelectFunction`. +
-In order to treat partial patterns, the `select` and `flatSelect` API calls offer an overloaded version which takes as -the first parameter a `PatternTimeoutFunction`/`PatternFlatTimeoutFunction` and as second parameter the known -`PatternSelectFunction`/`PatternFlatSelectFunction`. The return type of the timeout function can be different from the -select function. The timeout event and the select event are wrapped in `Either.Left` and `Either.Right` respectively -so that the resulting data stream is of type `org.apache.flink.types.Either`. -{% highlight java %} +~~~java PatternStream patternStream = CEP.pattern(input, pattern); -DataStream> result = patternStream.select( +OutputTag outputTag = new OutputTag("side-output"){}; + +SingleOutputStreamOperator result = patternStream.select( new PatternTimeoutFunction() {...}, + outputTag, new PatternSelectFunction() {...} ); -DataStream> flatResult = patternStream.flatSelect( +DataStream timeoutResult = result.getSideOutput(outputTag); + +SingleOutputStreamOperator flatResult = patternStream.flatSelect( new PatternFlatTimeoutFunction() {...}, + outputTag, new PatternFlatSelectFunction() {...} ); -{% endhighlight %} + +DataStream timeoutFlatResult = flatResult.getSideOutput(outputTag); +~~~
-In order to treat partial patterns, the `select` API call offers an overloaded version which takes as the first parameter a timeout function and as second parameter a selection function. -The timeout function is called with a map of string-event pairs of the partial match which has timed out and a long indicating when the timeout occurred. -The string is defined by the name of the pattern to which the event has been matched. -The timeout function returns exactly one result per call. -The return type of the timeout function can be different from the select function. -The timeout event and the select event are wrapped in `Left` and `Right` respectively so that the resulting data stream is of type `Either`. -{% highlight scala %} +~~~scala val patternStream: PatternStream[Event] = CEP.pattern(input, pattern) -DataStream[Either[TimeoutEvent, ComplexEvent]] result = patternStream.select{ +val outputTag = OutputTag[String]("side-output") + +val result: SingleOutputStreamOperator[ComplexEvent] = patternStream.select(outputTag){ (pattern: Map[String, Iterable[Event]], timestamp: Long) => TimeoutEvent() } { pattern: Map[String, Iterable[Event]] => ComplexEvent() } -{% endhighlight %} + +val timeoutResult: DataStream = result.getSideOutput(outputTag); +~~~ The `flatSelect` API call offers the same overloaded version which takes as the first parameter a timeout function and as second parameter a selection function. In contrast to the `select` functions, the `flatSelect` functions are called with a `Collector`. The collector can be used to emit an arbitrary number of events. -{% highlight scala %} +~~~scala val patternStream: PatternStream[Event] = CEP.pattern(input, pattern) -DataStream[Either[TimeoutEvent, ComplexEvent]] result = patternStream.flatSelect{ +val outputTag = OutputTag[String]("side-output") + +val result: SingleOutputStreamOperator[ComplexEvent] = patternStream.flatSelect(outputTag){ (pattern: Map[String, Iterable[Event]], timestamp: Long, out: Collector[TimeoutEvent]) => out.collect(TimeoutEvent()) } { (pattern: mutable.Map[String, Iterable[Event]], out: Collector[ComplexEvent]) => out.collect(ComplexEvent()) } -{% endhighlight %} + +val timeoutResult: DataStream = result.getSideOutput(outputTag); +~~~
diff --git a/docs/dev/stream/operators/asyncio.md b/docs/dev/stream/operators/asyncio.md index 1ea0792a2f7be..c5bafa16cdc9e 100644 --- a/docs/dev/stream/operators/asyncio.md +++ b/docs/dev/stream/operators/asyncio.md @@ -74,7 +74,7 @@ Assuming one has an asynchronous client for the target database, three parts are with asynchronous I/O against the database: - An implementation of `AsyncFunction` that dispatches the requests - - A *callback* that takes the result of the operation and hands it to the `AsyncCollector` + - A *callback* that takes the result of the operation and hands it to the `ResultFuture` - Applying the async I/O operation on a DataStream as a transformation The following code example illustrates the basic pattern: @@ -104,16 +104,16 @@ class AsyncDatabaseRequest extends RichAsyncFunction> asyncCollector) throws Exception { + public void asyncInvoke(final String str, final ResultFuture> resultFuture) throws Exception { // issue the asynchronous request, receive a future for result Future resultFuture = client.query(str); // set the callback to be executed once the request by the client is complete - // the callback simply forwards the result to the collector + // the callback simply forwards the result to the result future resultFuture.thenAccept( (String result) -> { - asyncCollector.collect(Collections.singleton(new Tuple2<>(str, result))); + resultFuture.complete(Collections.singleton(new Tuple2<>(str, result))); }); } @@ -142,15 +142,15 @@ class AsyncDatabaseRequest extends AsyncFunction[String, (String, String)] { implicit lazy val executor: ExecutionContext = ExecutionContext.fromExecutor(Executors.directExecutor()) - override def asyncInvoke(str: String, asyncCollector: AsyncCollector[(String, String)]): Unit = { + override def asyncInvoke(str: String, resultFutre: ResultFuture[(String, String)]): Unit = { // issue the asynchronous request, receive a future for the result val resultFuture: Future[String] = client.query(str) // set the callback to be executed once the request by the client is complete - // the callback simply forwards the result to the collector + // the callback simply forwards the result to the result future resultFuture.onSuccess { - case result: String => asyncCollector.collect(Iterable((str, result))); + case result: String => resultFuture.complete(Iterable((str, result))); } } } @@ -166,8 +166,8 @@ val resultStream: DataStream[(String, String)] = -**Important note**: The `AsyncCollector` is completed with the first call of `AsyncCollector.collect`. -All subsequent `collect` calls will be ignored. +**Important note**: The `ResultFuture` is completed with the first call of `ResultFuture.complete`. +All subsequent `complete` calls will be ignored. The following two parameters control the asynchronous operations: @@ -229,7 +229,7 @@ asynchronous requests in checkpoints and restores/re-triggers the requests when For implementations with *Futures* that have an *Executor* (or *ExecutionContext* in Scala) for callbacks, we suggets to use a `DirectExecutor`, because the callback typically does minimal work, and a `DirectExecutor` avoids an additional thread-to-thread handover overhead. The callback typically only hands -the result to the `AsyncCollector`, which adds it to the output buffer. From there, the heavy logic that includes record emission and interaction +the result to the `ResultFuture`, which adds it to the output buffer. 
From there, the heavy logic that includes record emission and interaction with the checkpoint bookkeepting happens in a dedicated thread-pool anyways. A `DirectExecutor` can be obtained via `org.apache.flink.runtime.concurrent.Executors.directExecutor()` or diff --git a/docs/dev/stream/operators/windows.md b/docs/dev/stream/operators/windows.md index c2d557f444cfa..012d5313742bd 100644 --- a/docs/dev/stream/operators/windows.md +++ b/docs/dev/stream/operators/windows.md @@ -111,6 +111,11 @@ windows) assign elements to windows based on time, which can either be processin time. Please take a look at our section on [event time]({{ site.baseurl }}/dev/event_time.html) to learn about the difference between processing time and event time and how timestamps and watermarks are generated. +Time-based windows have a *start timestamp* (inclusive) and an *end timestamp* (exclusive) +that together describe the size of the window. In code, Flink uses `TimeWindow` when working with +time-based windows which has methods for querying the start- and end-timestamp and also an +additional method `maxTimestamp()` that returns the largest allowed timestamp for a given windows. + In the following, we show how Flink's pre-defined window assigners work and how they are used in a DataStream program. The following figures visualize the workings of each assigner. The purple circles represent elements of the stream, which are partitioned by some key (in this case *user 1*, *user 2* and *user 3*). @@ -460,118 +465,15 @@ The above example appends all input `Long` values to an initially empty `String` Attention `fold()` cannot be used with session windows or other mergeable windows. -### WindowFunction - The Generic Case - -A `WindowFunction` gets an `Iterable` containing all the elements of the window and provides -the most flexibility of all window functions. This comes -at the cost of performance and resource consumption, because elements cannot be incrementally -aggregated but instead need to be buffered internally until the window is considered ready for processing. - -The signature of a `WindowFunction` looks as follows: - -
-
-{% highlight java %} -public interface WindowFunction extends Function, Serializable { - - /** - * Evaluates the window and outputs none or several elements. - * - * @param key The key for which this window is evaluated. - * @param window The window that is being evaluated. - * @param input The elements in the window being evaluated. - * @param out A collector for emitting elements. - * - * @throws Exception The function may throw exceptions to fail the program and trigger recovery. - */ - void apply(KEY key, W window, Iterable input, Collector out) throws Exception; -} -{% endhighlight %} -
- -
-{% highlight scala %} -trait WindowFunction[IN, OUT, KEY, W <: Window] extends Function with Serializable { - - /** - * Evaluates the window and outputs none or several elements. - * - * @param key The key for which this window is evaluated. - * @param window The window that is being evaluated. - * @param input The elements in the window being evaluated. - * @param out A collector for emitting elements. - * @throws Exception The function may throw exceptions to fail the program and trigger recovery. - */ - def apply(key: KEY, window: W, input: Iterable[IN], out: Collector[OUT]) -} -{% endhighlight %} -
-
- -A `WindowFunction` can be defined and used like this: - -
-
-{% highlight java %} -DataStream> input = ...; - -input - .keyBy() - .window() - .apply(new MyWindowFunction()); - -/* ... */ - -public class MyWindowFunction implements WindowFunction, String, String, TimeWindow> { - - void apply(String key, TimeWindow window, Iterable> input, Collector out) { - long count = 0; - for (Tuple in: input) { - count++; - } - out.collect("Window: " + window + "count: " + count); - } -} - -{% endhighlight %} -
- -
-{% highlight scala %} -val input: DataStream[(String, Long)] = ... - -input - .keyBy() - .window() - .apply(new MyWindowFunction()) - -/* ... */ - -class MyWindowFunction extends WindowFunction[(String, Long), String, String, TimeWindow] { - - def apply(key: String, window: TimeWindow, input: Iterable[(String, Long)], out: Collector[String]): () = { - var count = 0L - for (in <- input) { - count = count + 1 - } - out.collect(s"Window $window count: $count") - } -} -{% endhighlight %} -
-
- -The example shows a `WindowFunction` to count the elements in a window. In addition, the window function adds information about the window to the output. - -Attention Note that using `WindowFunction` for simple aggregates such as count is quite inefficient. The next section shows how a `ReduceFunction` can be combined with a `WindowFunction` to get both incremental aggregation and the added information of a `WindowFunction`. - ### ProcessWindowFunction -In places where a `WindowFunction` can be used you can also use a `ProcessWindowFunction`. This -is very similar to `WindowFunction`, except that the interface allows to query more information -about the context in which the window evaluation happens. +A ProcessWindowFunction gets an Iterable containing all the elements of the window, and a Context +object with access to time and state information, which enables it to provide more flexibility than +other window functions. This comes at the cost of performance and resource consumption, because +elements cannot be incrementally aggregated but instead need to be buffered internally until the +window is considered ready for processing. -This is the `ProcessWindowFunction` interface: +The signature of `ProcessWindowFunction` looks as follows:
@@ -594,15 +496,35 @@ public abstract class ProcessWindowFunction impl Iterable elements, Collector out) throws Exception; - /** - * The context holding window metadata - */ - public abstract class Context { - /** - * @return The window that is being evaluated. - */ - public abstract W window(); - } + /** + * The context holding window metadata. + */ + public abstract class Context implements java.io.Serializable { + /** + * Returns the window that is being evaluated. + */ + public abstract W window(); + + /** Returns the current processing time. */ + public abstract long currentProcessingTime(); + + /** Returns the current event-time watermark. */ + public abstract long currentWatermark(); + + /** + * State accessor for per-key and per-window state. + * + *

NOTE:If you use per-window state you have to ensure that you clean it up + * by implementing {@link ProcessWindowFunction#clear(Context)}. + */ + public abstract KeyedStateStore windowState(); + + /** + * State accessor for per-key global state. + */ + public abstract KeyedStateStore globalState(); + } + } {% endhighlight %}

@@ -620,7 +542,6 @@ abstract class ProcessWindowFunction[IN, OUT, KEY, W <: Window] extends Function * @param out A collector for emitting elements. * @throws Exception The function may throw exceptions to fail the program and trigger recovery. */ - @throws[Exception] def process( key: KEY, context: Context, @@ -632,16 +553,42 @@ abstract class ProcessWindowFunction[IN, OUT, KEY, W <: Window] extends Function */ abstract class Context { /** - * @return The window that is being evaluated. + * Returns the window that is being evaluated. */ def window: W + + /** + * Returns the current processing time. + */ + def currentProcessingTime: Long + + /** + * Returns the current event-time watermark. + */ + def currentWatermark: Long + + /** + * State accessor for per-key and per-window state. + */ + def windowState: KeyedStateStore + + /** + * State accessor for per-key global state. + */ + def globalState: KeyedStateStore } + } {% endhighlight %}
-It can be used like this: +Note The `key` parameter is the key that is extracted +via the `KeySelector` that was specified for the `keyBy()` invocation. In case of tuple-index +keys or string-field references this key type is always `Tuple` and you have to manually cast +it to a tuple of the correct size to extract the key fields. + +A `ProcessWindowFunction` can be defined and used like this:
@@ -652,6 +599,20 @@ input
     .keyBy(<key selector>)
     .window(<window assigner>)
     .process(new MyProcessWindowFunction());
+
+/* ... */
+
+public class MyProcessWindowFunction extends ProcessWindowFunction<Tuple2<String, Long>, String, String, TimeWindow> {
+
+    @Override
+    public void process(String key, Context context, Iterable<Tuple2<String, Long>> input, Collector<String> out) {
+        long count = 0;
+        for (Tuple2<String, Long> in : input) {
+            count++;
+        }
+        out.collect("Window: " + context.window() + " count: " + count);
+    }
+}
+
 {% endhighlight %}
@@ -663,25 +624,42 @@ input
     .keyBy(<key selector>)
     .window(<window assigner>)
     .process(new MyProcessWindowFunction())
+
+/* ... */
+
+class MyProcessWindowFunction extends ProcessWindowFunction[(String, Long), String, String, TimeWindow] {
+
+  def process(key: String, context: Context, input: Iterable[(String, Long)], out: Collector[String]): Unit = {
+    var count = 0L
+    for (in <- input) {
+      count = count + 1
+    }
+    out.collect(s"Window ${context.window} count: $count")
+  }
+}
 {% endhighlight %}
-### WindowFunction with Incremental Aggregation +The example shows a `ProcessWindowFunction` that counts the elements in a window. In addition, the window function adds information about the window to the output. + +Attention Note that using `ProcessWindowFunction` for simple aggregates such as count is quite inefficient. The next section shows how a `ReduceFunction` can be combined with a `ProcessWindowFunction` to get both incremental aggregation and the added information of a `ProcessWindowFunction`. + +### ProcessWindowFunction with Incremental Aggregation -A `WindowFunction` can be combined with either a `ReduceFunction` or a `FoldFunction` to +A `ProcessWindowFunction` can be combined with either a `ReduceFunction` or a `FoldFunction` to incrementally aggregate elements as they arrive in the window. -When the window is closed, the `WindowFunction` will be provided with the aggregated result. +When the window is closed, the `ProcessWindowFunction` will be provided with the aggregated result. This allows to incrementally compute windows while having access to the -additional window meta information of the `WindowFunction`. +additional window meta information of the `ProcessWindowFunction`. -Note You can also `ProcessWindowFunction` instead of -`WindowFunction` for incremental window aggregation. +Note You can also the legacy `WindowFunction` instead of +`ProcessWindowFunction` for incremental window aggregation. #### Incremental Window Aggregation with FoldFunction The following example shows how an incremental `FoldFunction` can be combined with -a `WindowFunction` to extract the number of events in the window and return also +a `ProcessWindowFunction` to extract the number of events in the window and return also the key and end time of the window.
@@ -692,7 +670,7 @@ DataStream input = ...; input .keyBy() .timeWindow() - .fold(new Tuple3("",0L, 0), new MyFoldFunction(), new MyWindowFunction()) + .fold(new Tuple3("",0L, 0), new MyFoldFunction(), new MyProcessWindowFunction()) // Function definitions @@ -706,15 +684,15 @@ private static class MyFoldFunction } } -private static class MyWindowFunction - implements WindowFunction, Tuple3, String, TimeWindow> { +private static class MyProcessWindowFunction + implements ProcessWindowFunction, Tuple3, String, TimeWindow> { - public void apply(String key, - TimeWindow window, + public void process(String key, + Context context, Iterable> counts, Collector> out) { Integer count = counts.iterator().next().getField(2); - out.collect(new Tuple3(key, window.getEnd(),count)); + out.collect(new Tuple3(key, context.window().getEnd(),count)); } } @@ -759,7 +737,7 @@ DataStream input = ...; input .keyBy() .timeWindow() - .reduce(new MyReduceFunction(), new MyWindowFunction()); + .reduce(new MyReduceFunction(), new MyProcessWindowFunction()); // Function definitions @@ -770,11 +748,11 @@ private static class MyReduceFunction implements ReduceFunction { } } -private static class MyWindowFunction - implements WindowFunction, String, TimeWindow> { +private static class MyProcessWindowFunction + implements ProcessWindowFunction, String, TimeWindow> { public void apply(String key, - TimeWindow window, + Context context, Iterable minReadings, Collector> out) { SensorReading min = minReadings.iterator().next(); @@ -808,6 +786,80 @@ input
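One caveat about the snippet above: a `ProcessWindowFunction` implements `process()` with a `Context` argument rather than `apply()`. A corrected sketch of the `MyProcessWindowFunction` used together with `reduce()` (the `SensorReading` type follows the surrounding example) might look like this:

{% highlight java %}
private static class MyProcessWindowFunction
    extends ProcessWindowFunction<SensorReading, Tuple2<Long, SensorReading>, String, TimeWindow> {

    @Override
    public void process(String key,
                        Context context,
                        Iterable<SensorReading> minReadings,
                        Collector<Tuple2<Long, SensorReading>> out) {
        SensorReading min = minReadings.iterator().next();
        out.collect(new Tuple2<>(context.window().getStart(), min));
    }
}
{% endhighlight %}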
+### WindowFunction (Legacy) + +In some places where a `ProcessWindowFunction` can be used you can also use a `WindowFunction`. This +is an older version of `ProcessWindowFunction` that provides less contextual information and does +not have some advances features, such as per-window keyed state. This interface will be deprecated +at some point. + +The signature of a `WindowFunction` looks as follows: + +
+
+{% highlight java %} +public interface WindowFunction extends Function, Serializable { + + /** + * Evaluates the window and outputs none or several elements. + * + * @param key The key for which this window is evaluated. + * @param window The window that is being evaluated. + * @param input The elements in the window being evaluated. + * @param out A collector for emitting elements. + * + * @throws Exception The function may throw exceptions to fail the program and trigger recovery. + */ + void apply(KEY key, W window, Iterable input, Collector out) throws Exception; +} +{% endhighlight %} +
+ +
+{% highlight scala %} +trait WindowFunction[IN, OUT, KEY, W <: Window] extends Function with Serializable { + + /** + * Evaluates the window and outputs none or several elements. + * + * @param key The key for which this window is evaluated. + * @param window The window that is being evaluated. + * @param input The elements in the window being evaluated. + * @param out A collector for emitting elements. + * @throws Exception The function may throw exceptions to fail the program and trigger recovery. + */ + def apply(key: KEY, window: W, input: Iterable[IN], out: Collector[OUT]) +} +{% endhighlight %} +
+
+ +It can be used like this: + +
+
+{% highlight java %} +DataStream> input = ...; + +input + .keyBy() + .window() + .apply(new MyWindowFunction()); +{% endhighlight %} +
+ +
+{% highlight scala %} +val input: DataStream[(String, Long)] = ... + +input + .keyBy() + .window() + .apply(new MyWindowFunction()) +{% endhighlight %} +
+
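The `MyWindowFunction` referenced above is not spelled out in this hunk. A sketch of such a legacy `WindowFunction`, mirroring the counting example from the `ProcessWindowFunction` section, could look like this:

{% highlight java %}
public class MyWindowFunction implements WindowFunction<Tuple2<String, Long>, String, String, TimeWindow> {

    @Override
    public void apply(String key, TimeWindow window, Iterable<Tuple2<String, Long>> input, Collector<String> out) throws Exception {
        long count = 0;
        for (Tuple2<String, Long> in : input) {
            count++;
        }
        out.collect("Window: " + window + " count: " + count);
    }
}
{% endhighlight %}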
+ ## Triggers A `Trigger` determines when a window (as formed by the *window assigner*) is ready to be @@ -1028,6 +1080,80 @@ as they may "bridge" the gap between two pre-existing, unmerged windows. Attention You should be aware that the elements emitted by a late firing should be treated as updated results of a previous computation, i.e., your data stream will contain multiple results for the same computation. Depending on your application, you need to take these duplicated results into account or deduplicate them. +## Working with window results + +The result of a windowed operation is again a `DataStream`, no information about the windowed +operations is retained in the result elements so if you want to keep meta-information about the +window you have to manually encode that information in the result elements in your +`ProcessWindowFunction`. The only relevant information that is set on the result elements is the +element *timestamp*. This is set to the maximum allowed timestamp of the processed window, which +is *end timestamp - 1*, since the window-end timestamp is exclusive. Note that this is true for both +event-time windows and processing-time windows. i.e. after a windowed operations elements always +have a timestamp, but this can be an event-time timestamp or a processing-time timestamp. For +processing-time windows this has no special implications but for event-time windows this together +with how watermarks interact with windows enables +[consecutive windowed operations](#consecutive-windowed-operations) with the same window sizes. We +will cover this after taking a look how watermarks interact with windows. + +### Interaction of watermarks and windows + +Before continuing in this section you might want to take a look at our section about +[event time and watermarks]({{ site.baseurl }}/dev/event_time.html). + +When watermarks arrive at the window operator this triggers two things: + - the watermark triggers computation of all windows where the maximum timestamp (which is + *end-timestamp - 1*) is smaller than the new watermark + - the watermark is forwarded (as is) to downstream operations + +Intuitively, a watermark "flushes" out any windows that would be considered late in downstream +operations once they receive that watermark. + +### Consecutive windowed operations + +As mentioned before, the way the timestamp of windowed results is computed and how watermarks +interact with windows allows stringing together consecutive windowed operations. This can be useful +when you want to do two consecutive windowed operations where you want to use different keys but +still want elements from the same upstream window to end up in the same downstream window. Consider +this example: + +
+
+{% highlight java %}
+DataStream<Integer> input = ...;
+
+DataStream<Integer> resultsPerKey = input
+    .keyBy(<key selector>)
+    .window(TumblingEventTimeWindows.of(Time.seconds(5)))
+    .reduce(new Summer());
+
+DataStream<Integer> globalResults = resultsPerKey
+    .windowAll(TumblingEventTimeWindows.of(Time.seconds(5)))
+    .process(new TopKWindowFunction());
+
+{% endhighlight %}
+ +
+{% highlight scala %} +val input: DataStream[Int] = ... + +val resultsPerKey = input + .keyBy() + .window(TumblingEventTimeWindows.of(Time.seconds(5))) + .reduce(new Summer()) + +val globalResults = resultsPerKey + .windowAll(TumblingEventTimeWindows.of(Time.seconds(5))) + .process(new TopKWindowFunction()) +{% endhighlight %} +
+
+ +In this example, the results for time window `[0, 5)` from the first operation will also end up in +time window `[0, 5)` in the subsequent windowed operation. This allows calculating a sum per key +and then calculating the top-k elements within the same window in the second operation. +and then calculating the top-k elements within the same window in the second operation. + ## Useful state size considerations Windows can be defined over long periods of time (such as days, weeks, or months) and therefore accumulate very large state. There are a couple of rules to keep in mind when estimating the storage requirements of your windowing computation: diff --git a/docs/dev/stream/testing.md b/docs/dev/stream/testing.md new file mode 100644 index 0000000000000..44f5cfd5d9a12 --- /dev/null +++ b/docs/dev/stream/testing.md @@ -0,0 +1,263 @@ +--- +title: "Testing" +nav-parent_id: streaming +nav-id: testing +nav-pos: 99 +--- + + +This page briefly discusses how to test a Flink application in your IDE or a local environment. + +* This will be replaced by the TOC +{:toc} + +## Unit testing + +Usually, one can assume that Flink produces correct results outside of a user-defined `Function`. Therefore, it is recommended to test `Function` classes that contain the main business logic with unit tests as much as possible. + +For example if one implements the following `ReduceFunction`: + +
+
+{% highlight java %} +public class SumReduce implements ReduceFunction { + + @Override + public Long reduce(Long value1, Long value2) throws Exception { + return value1 + value2; + } +} +{% endhighlight %} +
+ +
+{% highlight scala %} +class SumReduce extends ReduceFunction[Long] { + + override def reduce(value1: java.lang.Long, value2: java.lang.Long): java.lang.Long = { + value1 + value2 + } +} +{% endhighlight %} +
+
+ +It is very easy to unit test it with your favorite framework by passing suitable arguments and verify the output: + +
+
+{% highlight java %}
+public class SumReduceTest {
+
+    @Test
+    public void testSum() throws Exception {
+        // instantiate your function
+        SumReduce sumReduce = new SumReduce();
+
+        // call the methods that you have implemented
+        assertEquals(42L, sumReduce.reduce(40L, 2L));
+    }
+}
+{% endhighlight %}
+ +
+{% highlight scala %}
+class SumReduceTest extends FlatSpec with Matchers {
+
+  "SumReduce" should "add values" in {
+    // instantiate your function
+    val sumReduce: SumReduce = new SumReduce()
+
+    // call the methods that you have implemented
+    sumReduce.reduce(40L, 2L) should be (42L)
+  }
+}
+{% endhighlight %}
+
+
+## Integration testing
+
+In order to end-to-end test Flink streaming pipelines, you can also write integration tests that are executed against a local Flink mini cluster.
+
+In order to do so add the test dependency `flink-test-utils`:
+
+{% highlight xml %}
+<dependency>
+  <groupId>org.apache.flink</groupId>
+  <artifactId>flink-test-utils{{ site.scala_version_suffix }}</artifactId>
+  <version>{{site.version }}</version>
+</dependency>
+{% endhighlight %}
+
+For example, if you want to test the following `MapFunction`:
+
+
+{% highlight java %} +public class MultiplyByTwo implements MapFunction { + + @Override + public Long map(Long value) throws Exception { + return value * 2; + } +} +{% endhighlight %} +
+ +
+{% highlight scala %} +class MultiplyByTwo extends MapFunction[Long, Long] { + + override def map(value: Long): Long = { + value * 2 + } +} +{% endhighlight %} +
+
+ +You could write the following integration test: + +
+
+{% highlight java %} +public class ExampleIntegrationTest extends StreamingMultipleProgramsTestBase { + + @Test + public void testMultiply() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + // configure your test environment + env.setParallelism(1); + + // values are collected in a static variable + CollectSink.values.clear(); + + // create a stream of custom elements and apply transformations + env.fromElements(1L, 21L, 22L) + .map(new MultiplyByTwo()) + .addSink(new CollectSink()); + + // execute + env.execute(); + + // verify your results + assertEquals(Lists.newArrayList(2L, 42L, 44L), CollectSink.values); + } + + // create a testing sink + private static class CollectSink implements SinkFunction { + + // must be static + public static final List values = new ArrayList<>(); + + @Override + public synchronized void invoke(Long value) throws Exception { + values.add(value); + } + } +} +{% endhighlight %} +
+ +
+{% highlight scala %} +class ExampleIntegrationTest extends StreamingMultipleProgramsTestBase { + + @Test + def testMultiply(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + + // configure your test environment + env.setParallelism(1) + + // values are collected in a static variable + CollectSink.values.clear() + + // create a stream of custom elements and apply transformations + env + .fromElements(1L, 21L, 22L) + .map(new MultiplyByTwo()) + .addSink(new CollectSink()) + + // execute + env.execute() + + // verify your results + assertEquals(Lists.newArrayList(2L, 42L, 44L), CollectSink.values) + } +} + +// create a testing sink +class CollectSink extends SinkFunction[Long] { + + override def invoke(value: java.lang.Long): Unit = { + synchronized { + values.add(value) + } + } +} + +object CollectSink { + + // must be static + val values: List[Long] = new ArrayList() +} +{% endhighlight %} +
+
+ +The static variable in `CollectSink` is used here because Flink serializes all operators before distributing them across a cluster. +Communicating with operators instantiated by a local Flink mini cluster via static variables is one way around this issue. +Alternatively, you could for example write the data to files in a temporary directory with your test sink. +You can also implement your own custom sources for emitting watermarks. + +## Testing checkpointing and state handling + +One way to test state handling is to enable checkpointing in integration tests. + +You can do that by configuring your `StreamExecutionEnvironment` in the test: + +
+
+{% highlight java %} +env.enableCheckpointing(500); +env.setRestartStrategy(RestartStrategies.fixedDelayRestart(3, 100)); +{% endhighlight %} +
+ +
+{% highlight scala %} +env.enableCheckpointing(500); +env.setRestartStrategy(RestartStrategies.fixedDelayRestart(3, 100)); +{% endhighlight %} +
+
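A hypothetical helper for this kind of test (not part of the original text; the name and the failure interval are illustrative) is the failure-injecting identity mapper discussed in the next paragraph:

{% highlight java %}
// Illustrative failure-injecting mapper for recovery tests.
public class FailingIdentityMapper<T> implements MapFunction<T, T> {

    private long lastFailureTime = System.currentTimeMillis();

    @Override
    public T map(T value) throws Exception {
        if (System.currentTimeMillis() - lastFailureTime > 1000) {
            lastFailureTime = System.currentTimeMillis();
            throw new RuntimeException("Artificial failure to exercise checkpoint recovery");
        }
        return value;
    }
}
{% endhighlight %}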
+ +And for example adding to your Flink application an identity mapper operator that will throw an exception +once every `1000ms`. However writing such test could be tricky because of time dependencies between the actions. + +Another approach is to write a unit test using the Flink internal testing utility `AbstractStreamOperatorTestHarness` from the `flink-streaming-java` module. + +For an example of how to do that please have a look at the `org.apache.flink.streaming.runtime.operators.windowing.WindowOperatorTest` also in the `flink-streaming-java` module. + +Be aware that `AbstractStreamOperatorTestHarness` is currently not a part of public API and can be subject to change. diff --git a/docs/dev/table/udfs.md b/docs/dev/table/udfs.md index 55f58b6032c68..6c9bc1af76d37 100644 --- a/docs/dev/table/udfs.md +++ b/docs/dev/table/udfs.md @@ -24,15 +24,18 @@ under the License. User-defined functions are an important feature, because they significantly extend the expressiveness of queries. -**TODO** - * This will be replaced by the TOC {:toc} Register User-Defined Functions ------------------------------- +In most cases, a user-defined function must be registered before it can be used in an query. It is not necessary to register functions for the Scala Table API. + +Functions are registered at the `TableEnvironment` by calling a `registerFunction()` method. When a user-defined function is registered, it is inserted into the function catalog of the `TableEnvironment` such that the Table API or SQL parser can recognize and properly translate it. + +Please find detailed examples of how to register and how to call each type of user-defined function +(`ScalarFunction`, `TableFunction`, and `AggregateFunction`) in the following sub-sessions. -**TODO** {% top %} @@ -97,8 +100,6 @@ tableEnv.sql("SELECT string, HASHCODE(string) FROM MyTable"); By default the result type of an evaluation method is determined by Flink's type extraction facilities. This is sufficient for basic types or simple POJOs but might be wrong for more complex, custom, or composite types. In these cases `TypeInformation` of the result type can be manually defined by overriding `ScalarFunction#getResultType()`. -Internally, the Table API and SQL code generation works with primitive values as much as possible. If a user-defined scalar function should not introduce much overhead through object creation/casting during runtime, it is recommended to declare parameters and result types as primitive types instead of their boxed classes. `Types.DATE` and `Types.TIME` can also be represented as `int`. `Types.TIMESTAMP` can be represented as `long`. - The following example shows an advanced example which takes the internal timestamp representation and also returns the internal timestamp representation as a long value. By overriding `ScalarFunction#getResultType()` we define that the returned long value should be interpreted as a `Types.TIMESTAMP` by the code generation.
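The timestamp example referenced above is not included in this hunk. A rough, hedged sketch of such a function, together with its registration at the `TableEnvironment`, might look like the following. The class name is illustrative, and `SqlTimeTypeInfo.TIMESTAMP` is used here as the timestamp type information; the exact constant may differ by Flink version.

{% highlight java %}
// Sketch only: operates on the internal long representation of a timestamp.
public static class TimestampModifier extends ScalarFunction {

    public long eval(long timestamp) {
        // e.g. truncate to full seconds
        return timestamp - (timestamp % 1000);
    }

    @Override
    public TypeInformation<?> getResultType(Class<?>[] signature) {
        // tell the code generation to interpret the returned long as a timestamp
        return SqlTimeTypeInfo.TIMESTAMP;
    }
}

// register the function so that the Table API / SQL parser can resolve it
tableEnv.registerFunction("timestampModifier", new TimestampModifier());
{% endhighlight %}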
@@ -264,10 +265,405 @@ class CustomTypeSplit extends TableFunction[Row] {
 {% endhighlight %}

{% top %}

+
 Aggregation Functions
---------------------
-**TODO**
+User-Defined Aggregate Functions (UDAGGs) aggregate a table (one or more rows with one or more attributes) to a scalar value.
+
+
+UDAGG mechanism +
+The above figure shows an example of an aggregation. Assume you have a table that contains data about beverages. The table consists of three columns (`id`, `name`, and `price`) and 5 rows. Imagine you need to find the highest price of all beverages in the table, i.e., perform a `max()` aggregation. You would need to check each of the 5 rows and the result would be a single numeric value.
+
+User-defined aggregation functions are implemented by extending the `AggregateFunction` class. An `AggregateFunction` works as follows. First, it needs an `accumulator`, which is the data structure that holds the intermediate result of the aggregation. An empty accumulator is created by calling the `createAccumulator()` method of the `AggregateFunction`. Subsequently, the `accumulate()` method of the function is called for each input row to update the accumulator. Once all rows have been processed, the `getValue()` method of the function is called to compute and return the final result.
+
+**The following methods are mandatory for each `AggregateFunction`:**
+
+- `createAccumulator()`
+- `accumulate()`
+- `getValue()`
+
+Flink’s type extraction facilities can fail to identify complex data types, e.g., if they are not basic types or simple POJOs. So similar to `ScalarFunction` and `TableFunction`, `AggregateFunction` provides methods to specify the `TypeInformation` of the result type (through `AggregateFunction#getResultType()`) and the type of the accumulator (through `AggregateFunction#getAccumulatorType()`).
+
+Besides the above methods, there are a few contracted methods that can be
+optionally implemented. While some of these methods allow the system to execute queries more efficiently, others are mandatory for certain use cases. For instance, the `merge()` method is mandatory if the aggregation function should be applied in the context of a session group window (the accumulators of two session windows need to be joined when a row is observed that "connects" them).
+
+**The following methods of `AggregateFunction` are required depending on the use case:**
+
+- `retract()` is required for aggregations on bounded `OVER` windows.
+- `merge()` is required for many batch aggregations and session window aggregations.
+- `resetAccumulator()` is required for many batch aggregations.
+
+All methods of `AggregateFunction` must be declared as `public`, not `static`, and named exactly as the names mentioned above. The methods `createAccumulator`, `getValue`, `getResultType`, and `getAccumulatorType` are defined in the `AggregateFunction` abstract class, while others are contracted methods. In order to define an aggregation function, one has to extend the base class `org.apache.flink.table.functions.AggregateFunction` and implement one (or more) `accumulate` methods.
+
+Detailed documentation for all methods of `AggregateFunction` is given below.
+
+
+{% highlight java %}
+/**
+  * Base class for aggregation functions.
+  *
+  * @param <T>   the type of the aggregation result
+  * @param <ACC> the type of the aggregation accumulator. The accumulator is used to keep the
+  *              aggregated values which are needed to compute an aggregation result.
+  *              AggregateFunction represents its state using the accumulator, thereby the state of
+  *              the AggregateFunction must be put into the accumulator.
+  */
+public abstract class AggregateFunction<T, ACC> extends UserDefinedFunction {
+
+  /**
+    * Creates and initializes the accumulator for this [[AggregateFunction]].
+    *
+    * @return the accumulator with the initial value
+    */
+  public ACC createAccumulator(); // MANDATORY
+
+  /**
+    * Processes the input values and updates the provided accumulator instance. The method
+    * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
+    * requires at least one accumulate() method.
+    *
+    * @param accumulator           the accumulator which contains the current aggregated results
+    * @param [user defined inputs] the input value (usually obtained from newly arrived data).
+    */
+  public void accumulate(ACC accumulator, [user defined inputs]); // MANDATORY
+
+  /**
+    * Retracts the input values from the accumulator instance. The current design assumes the
+    * inputs are the values that have been previously accumulated. The method retract can be
+    * overloaded with different custom types and arguments. This function must be implemented for
+    * bounded OVER aggregates over datastreams.
+    *
+    * @param accumulator           the accumulator which contains the current aggregated results
+    * @param [user defined inputs] the input value (usually obtained from newly arrived data).
+    */
+  public void retract(ACC accumulator, [user defined inputs]); // OPTIONAL
+
+  /**
+    * Merges a group of accumulator instances into one accumulator instance. This function must be
+    * implemented for datastream session window grouping aggregates and dataset grouping aggregates.
+    *
+    * @param accumulator the accumulator which will keep the merged aggregate results. It should
+    *                    be noted that the accumulator may contain the previous aggregated
+    *                    results. Therefore user should not replace or clean this instance in the
+    *                    custom merge method.
+    * @param its         an [[java.lang.Iterable]] pointing to a group of accumulators that will be
+    *                    merged.
+    */
+  public void merge(ACC accumulator, java.lang.Iterable<ACC> its); // OPTIONAL
+
+  /**
+    * Called every time when an aggregation result should be materialized.
+    * The returned value could be either an early and incomplete result
+    * (periodically emitted as data arrives) or the final result of the
+    * aggregation.
+    *
+    * @param accumulator the accumulator which contains the current
+    *                    aggregated results
+    * @return the aggregation result
+    */
+  public T getValue(ACC accumulator); // MANDATORY
+
+  /**
+    * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
+    * dataset grouping aggregates.
+    *
+    * @param accumulator the accumulator which needs to be reset
+    */
+  public void resetAccumulator(ACC accumulator); // OPTIONAL
+
+  /**
+    * Returns true if this AggregateFunction can only be applied in an OVER window.
+    *
+    * @return true if the AggregateFunction requires an OVER window, false otherwise.
+    */
+  public Boolean requiresOver = false; // PRE-DEFINED
+
+  /**
+    * Returns the TypeInformation of the AggregateFunction's result.
+    *
+    * @return The TypeInformation of the AggregateFunction's result or null if the result type
+    *         should be automatically inferred.
+    */
+  public TypeInformation<T> getResultType = null; // PRE-DEFINED
+
+  /**
+    * Returns the TypeInformation of the AggregateFunction's accumulator.
+    *
+    * @return The TypeInformation of the AggregateFunction's accumulator or null if the
+    *         accumulator type should be automatically inferred.
+    */
+  public TypeInformation<ACC> getAccumulatorType = null; // PRE-DEFINED
+}
+{% endhighlight %}
+
+ +
+{% highlight scala %} +/** + * Base class for aggregation functions. + * + * @tparam T the type of the aggregation result + * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the + * aggregated values which are needed to compute an aggregation result. + * AggregateFunction represents its state using accumulator, thereby the state of the + * AggregateFunction must be put into the accumulator. + */ +abstract class AggregateFunction[T, ACC] extends UserDefinedFunction { + /** + * Creates and init the Accumulator for this [[AggregateFunction]]. + * + * @return the accumulator with the initial value + */ + def createAccumulator(): ACC // MANDATORY + + /** + * Processes the input values and update the provided accumulator instance. The method + * accumulate can be overloaded with different custom types and arguments. An AggregateFunction + * requires at least one accumulate() method. + * + * @param accumulator the accumulator which contains the current aggregated results + * @param [user defined inputs] the input value (usually obtained from a new arrived data). + */ + def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY + + /** + * Retracts the input values from the accumulator instance. The current design assumes the + * inputs are the values that have been previously accumulated. The method retract can be + * overloaded with different custom types and arguments. This function must be implemented for + * datastream bounded over aggregate. + * + * @param accumulator the accumulator which contains the current aggregated results + * @param [user defined inputs] the input value (usually obtained from a new arrived data). + */ + def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL + + /** + * Merges a group of accumulator instances into one accumulator instance. This function must be + * implemented for datastream session window grouping aggregate and dataset grouping aggregate. + * + * @param accumulator the accumulator which will keep the merged aggregate results. It should + * be noted that the accumulator may contain the previous aggregated + * results. Therefore user should not replace or clean this instance in the + * custom merge method. + * @param its an [[java.lang.Iterable]] pointed to a group of accumulators that will be + * merged. + */ + def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL + + /** + * Called every time when an aggregation result should be materialized. + * The returned value could be either an early and incomplete result + * (periodically emitted as data arrive) or the final result of the + * aggregation. + * + * @param accumulator the accumulator which contains the current + * aggregated results + * @return the aggregation result + */ + def getValue(accumulator: ACC): T // MANDATORY + + h/** + * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for + * dataset grouping aggregate. + * + * @param accumulator the accumulator which needs to be reset + */ + def resetAccumulator(accumulator: ACC): Unit // OPTIONAL + + /** + * Returns true if this AggregateFunction can only be applied in an OVER window. + * + * @return true if the AggregateFunction requires an OVER window, false otherwise. + */ + def requiresOver: Boolean = false // PRE-DEFINED + + /** + * Returns the TypeInformation of the AggregateFunction's result. 
+ * + * @return The TypeInformation of the AggregateFunction's result or null if the result type + * should be automatically inferred. + */ + def getResultType: TypeInformation[T] = null // PRE-DEFINED + + /** + * Returns the TypeInformation of the AggregateFunction's accumulator. + * + * @return The TypeInformation of the AggregateFunction's accumulator or null if the + * accumulator type should be automatically inferred. + */ + def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED +} +{% endhighlight %} +
+
+
+
+The following example shows how to
+
+- define an `AggregateFunction` that calculates the weighted average on a given column,
+- register the function in the `TableEnvironment`, and
+- use the function in a query.
+
+To calculate a weighted average value, the accumulator needs to store the weighted sum and count of all the data that has been accumulated. In our example we define a class `WeightedAvgAccum` to be the accumulator. Accumulators are automatically backed up by Flink's checkpointing mechanism and restored in case of a failure to ensure exactly-once semantics.
+
+The `accumulate()` method of our `WeightedAvg` `AggregateFunction` has three inputs. The first one is the `WeightedAvgAccum` accumulator, the other two are user-defined inputs: the input value `iValue` and the weight of the input `iWeight`. Although the `retract()`, `merge()`, and `resetAccumulator()` methods are not mandatory for most aggregation types, we provide them below as examples. Please note that we used Java primitive types and defined `getResultType()` and `getAccumulatorType()` methods in the Scala example because Flink type extraction does not work very well for Scala types.
+
+
+{% highlight java %}
+/**
+ * Accumulator for WeightedAvg.
+ */
+public static class WeightedAvgAccum {
+    public long sum = 0;
+    public int count = 0;
+}
+
+/**
+ * Weighted Average user-defined aggregate function.
+ */
+public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {
+
+    @Override
+    public WeightedAvgAccum createAccumulator() {
+        return new WeightedAvgAccum();
+    }
+
+    @Override
+    public Long getValue(WeightedAvgAccum acc) {
+        if (acc.count == 0) {
+            return null;
+        } else {
+            return acc.sum / acc.count;
+        }
+    }
+
+    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
+        acc.sum += iValue * iWeight;
+        acc.count += iWeight;
+    }
+
+    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
+        acc.sum -= iValue * iWeight;
+        acc.count -= iWeight;
+    }
+
+    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
+        Iterator<WeightedAvgAccum> iter = it.iterator();
+        while (iter.hasNext()) {
+            WeightedAvgAccum a = iter.next();
+            acc.count += a.count;
+            acc.sum += a.sum;
+        }
+    }
+
+    public void resetAccumulator(WeightedAvgAccum acc) {
+        acc.count = 0;
+        acc.sum = 0L;
+    }
+}
+
+// register function
+StreamTableEnvironment tEnv = ...
+tEnv.registerFunction("wAvg", new WeightedAvg());
+
+// use function
+tEnv.sql("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");
+
+{% endhighlight %}
+
+ +
+{% highlight scala %} +import java.lang.{Long => JLong, Integer => JInteger} +import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1} +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.TupleTypeInfo +import org.apache.flink.table.functions.AggregateFunction + +/** + * Accumulator for WeightedAvg. + */ +class WeightedAvgAccum extends JTuple1[JLong, JInteger] { + sum = 0L + count = 0 +} + +/** + * Weighted Average user-defined aggregate function. + */ +class WeightedAvg extends AggregateFunction[JLong, CountAccumulator] { + + override def createAccumulator(): WeightedAvgAccum = { + new WeightedAvgAccum + } + + override def getValue(acc: WeightedAvgAccum): JLong = { + if (acc.count == 0) { + null + } else { + acc.sum / acc.count + } + } + + def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = { + acc.sum += iValue * iWeight + acc.count += iWeight + } + + def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = { + acc.sum -= iValue * iWeight + acc.count -= iWeight + } + + def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = { + val iter = it.iterator() + while (iter.hasNext) { + val a = iter.next() + acc.count += a.count + acc.sum += a.sum + } + } + + def resetAccumulator(acc: WeightedAvgAccum): Unit = { + acc.count = 0 + acc.sum = 0L + } + + override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = { + new TupleTypeInfo(classOf[WeightedAvgAccum], + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO) + } + + override def getResultType: TypeInformation[JLong] = + BasicTypeInfo.LONG_TYPE_INFO +} + +// register function +val tEnv: StreamTableEnvironment = ??? +tEnv.registerFunction("wAvg", new WeightedAvg()) + +// use function +tEnv.sql("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user") + +{% endhighlight %} +
+
+
+
+{% top %}
+
+Best Practices for Implementing UDFs
+------------------------------------
+
+The Table API and SQL code generation internally tries to work with primitive values as much as possible. A user-defined function can introduce much overhead through object creation, casting, and (un)boxing. Therefore, it is highly recommended to declare parameters and result types as primitive types instead of their boxed classes. `Types.DATE` and `Types.TIME` can also be represented as `int`. `Types.TIMESTAMP` can be represented as `long`.
+
+We recommend implementing user-defined functions in Java instead of Scala, as Scala types pose a challenge for Flink's type extractor.
 
 {% top %}
 
diff --git a/docs/fig/udagg-mechanism.png b/docs/fig/udagg-mechanism.png
new file mode 100644
index 0000000000000..043196fc19cd6
Binary files /dev/null and b/docs/fig/udagg-mechanism.png differ
diff --git a/docs/monitoring/debugging_event_time.md b/docs/monitoring/debugging_event_time.md
index edc7dd0fcb4d1..8355b62e7677d 100644
--- a/docs/monitoring/debugging_event_time.md
+++ b/docs/monitoring/debugging_event_time.md
@@ -31,11 +31,7 @@ Flink's [event time]({{ site.baseurl }}/dev/event_time.html) and watermark suppo
 out-of-order events. However, it's harder to understand what exactly is going on because the progress of time
 is tracked within the system.
 
-There are plans (see [FLINK-3427](https://issues.apache.org/jira/browse/FLINK-3427)) to show the current low watermark
-for each operator in the Flink web interface.
-
-Until this feature is implemented the current low watermark for each task can be accessed through the
-[metrics system]({{ site.baseurl }}/monitoring/metrics.html).
+The low watermark of each task can be accessed through the Flink web interface or the [metrics system]({{ site.baseurl }}/monitoring/metrics.html).
 
 Each Task in Flink exposes a metric called `currentLowWatermark` that represents the lowest watermark
 received by this task. This long value represents the "current event time".
diff --git a/docs/monitoring/metrics.md b/docs/monitoring/metrics.md
index b8f4acce4150d..b71f4cf3f07cd 100644
--- a/docs/monitoring/metrics.md
+++ b/docs/monitoring/metrics.md
@@ -229,7 +229,7 @@ public class MyMapper extends RichMapFunction {
 ## Scope
 
 Every metric is assigned an identifier under which it will be reported that is based on 3 components: the user-provided name when registering the metric, an optional user-defined scope and a system-provided scope.
-For example, if `A.B` is the sytem scope, `C.D` the user scope and `E` the name, then the identifier for the metric will be `A.B.C.D.E`.
+For example, if `A.B` is the system scope, `C.D` the user scope and `E` the name, then the identifier for the metric will be `A.B.C.D.E`.
 
 You can configure which delimiter to use for the identifier (default: `.`) by setting the `metrics.scope.delimiter` key in `conf/flink-conf.yaml`.
diff --git a/docs/ops/config.md b/docs/ops/config.md
index 4138b4d270ef3..e0b9d4db714f5 100644
--- a/docs/ops/config.md
+++ b/docs/ops/config.md
@@ -196,6 +196,13 @@ will be used under the directory specified by jobmanager.web.tmpdir.
 
 - `blob.storage.directory`: Directory for storing blobs (such as user JARs) on the TaskManagers.
 
+- `blob.service.cleanup.interval`: Cleanup interval (in seconds) of the blob caches (DEFAULT: 1 hour).
+Whenever a job is no longer referenced in the cache, we set a TTL and let the periodic cleanup task
+(executed every `blob.service.cleanup.interval` seconds) remove its blob files after this TTL has passed.
+This means that a blob will be retained for at most 2 * `blob.service.cleanup.interval` seconds after
+not being referenced anymore. Therefore, a recovery still has the chance to use existing files rather
+than having to download them again.
+
 - `blob.server.port`: Port definition for the blob server (serving user JARs) on the TaskManagers. By default the port is set to 0, which means that the operating system is picking an ephemeral port. Flink also accepts a list of ports ("50100,50101"), ranges ("50100-50200") or a combination of both. It is recommended to set a range of ports to avoid collisions when multiple JobManagers are running on the same machine.
 
 - `blob.service.ssl.enabled`: Flag to enable ssl for the blob client/server communication. This is applicable only when the global ssl flag security.ssl.enabled is set to true (DEFAULT: true).
diff --git a/docs/quickstart/scala_api_quickstart.md b/docs/quickstart/scala_api_quickstart.md
index abf6021a8f901..9e563ed25eb2c 100644
--- a/docs/quickstart/scala_api_quickstart.md
+++ b/docs/quickstart/scala_api_quickstart.md
@@ -41,25 +41,20 @@ These templates help you to set up the project structure and to create the initi
 
 ### Create Project
 
+You can scaffold a new project via either of the following two methods:
+
-
- {% highlight bash %} - $ g8 tillrohrmann/flink-project - {% endhighlight %} - This will create a Flink project in the specified project directory from the flink-project template. - If you haven't installed giter8, then please follow this installation guide. -
-
+
 {% highlight bash %}
-    $ git clone https://github.com/tillrohrmann/flink-project.git
+    $ sbt new tillrohrmann/flink-project.g8
 {% endhighlight %}
-    This will create the Flink project in the directory flink-project.
+    This will prompt you for a couple of parameters (project name, flink version...) and then create a Flink project from the flink-project template.
+    You need sbt >= 0.13.13 to execute this command. You can follow this installation guide to obtain it if necessary.
{% highlight bash %} diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java b/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java index 7bc26550037f2..c8a236e52dea6 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java +++ b/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java @@ -382,7 +382,7 @@ else if (prog.isUsingInteractiveMode()) { // invoke main method prog.invokeInteractiveModeForExecution(); if (lastJobExecutionResult == null && factory.getLastEnvCreated() == null) { - throw new ProgramMissingJobException(); + throw new ProgramMissingJobException("The program didn't contain a Flink job."); } if (isDetached()) { // in detached mode, we execute the whole user code to extract the Flink job, afterwards we run it here diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/ProgramMissingJobException.java b/flink-clients/src/main/java/org/apache/flink/client/program/ProgramMissingJobException.java index 43d608b43aa2f..c2b57178eafbd 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/program/ProgramMissingJobException.java +++ b/flink-clients/src/main/java/org/apache/flink/client/program/ProgramMissingJobException.java @@ -18,12 +18,18 @@ package org.apache.flink.client.program; +import org.apache.flink.util.FlinkException; + /** * Exception used to indicate that no job was executed during the invocation of a Flink program. */ -public class ProgramMissingJobException extends Exception { +public class ProgramMissingJobException extends FlinkException { /** * Serial version UID for serialization interoperability. */ private static final long serialVersionUID = -1964276369605091101L; + + public ProgramMissingJobException(String message) { + super(message); + } } diff --git a/flink-connectors/flink-avro/src/main/java/org/apache/flink/api/java/io/AvroOutputFormat.java b/flink-connectors/flink-avro/src/main/java/org/apache/flink/api/java/io/AvroOutputFormat.java index aed40bf4d85f0..5da8f75dc9a8b 100644 --- a/flink-connectors/flink-avro/src/main/java/org/apache/flink/api/java/io/AvroOutputFormat.java +++ b/flink-connectors/flink-avro/src/main/java/org/apache/flink/api/java/io/AvroOutputFormat.java @@ -25,6 +25,7 @@ import org.apache.avro.Schema; import org.apache.avro.file.CodecFactory; import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.io.DatumWriter; import org.apache.avro.reflect.ReflectData; import org.apache.avro.reflect.ReflectDatumWriter; @@ -134,6 +135,12 @@ public void open(int taskNumber, int numTasks) throws IOException { } catch (InstantiationException | IllegalAccessException e) { throw new RuntimeException(e.getMessage()); } + } else if (org.apache.avro.generic.GenericRecord.class.isAssignableFrom(avroValueType)) { + if (userDefinedSchema == null) { + throw new IllegalStateException("Schema must be set when using Generic Record"); + } + datumWriter = new GenericDatumWriter(userDefinedSchema); + schema = userDefinedSchema; } else { datumWriter = new ReflectDatumWriter(avroValueType); schema = ReflectData.get().getSchema(avroValueType); diff --git a/flink-connectors/flink-avro/src/test/java/org/apache/flink/api/java/io/AvroOutputFormatTest.java b/flink-connectors/flink-avro/src/test/java/org/apache/flink/api/java/io/AvroOutputFormatTest.java index 87334a74ef97d..71ebd785bbde1 100644 --- 
a/flink-connectors/flink-avro/src/test/java/org/apache/flink/api/java/io/AvroOutputFormatTest.java +++ b/flink-connectors/flink-avro/src/test/java/org/apache/flink/api/java/io/AvroOutputFormatTest.java @@ -24,6 +24,10 @@ import org.apache.flink.core.fs.Path; import org.apache.avro.Schema; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericRecord; import org.junit.Test; import org.mockito.internal.util.reflection.Whitebox; @@ -35,6 +39,7 @@ import java.io.ObjectOutputStream; import static org.apache.flink.api.java.io.AvroOutputFormat.Codec; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -151,4 +156,42 @@ private void output(final AvroOutputFormat outputFormat) throws IOExceptio } outputFormat.close(); } + + @Test + public void testGenericRecord() throws IOException { + final Path outputPath = new Path(File.createTempFile("avro-output-file", "generic.avro").getAbsolutePath()); + final AvroOutputFormat outputFormat = new AvroOutputFormat<>(outputPath, GenericRecord.class); + Schema schema = new Schema.Parser().parse("{\"type\":\"record\", \"name\":\"user\", \"fields\": [{\"name\":\"user_name\", \"type\":\"string\"}, {\"name\":\"favorite_number\", \"type\":\"int\"}, {\"name\":\"favorite_color\", \"type\":\"string\"}]}"); + outputFormat.setWriteMode(FileSystem.WriteMode.OVERWRITE); + outputFormat.setSchema(schema); + output(outputFormat, schema); + + GenericDatumReader reader = new GenericDatumReader<>(schema); + DataFileReader dataFileReader = new DataFileReader<>(new File(outputPath.getPath()), reader); + + while (dataFileReader.hasNext()) { + GenericRecord record = dataFileReader.next(); + assertEquals(record.get("user_name").toString(), "testUser"); + assertEquals(record.get("favorite_number"), 1); + assertEquals(record.get("favorite_color").toString(), "blue"); + } + + //cleanup + FileSystem fs = FileSystem.getLocalFileSystem(); + fs.delete(outputPath, false); + + } + + private void output(final AvroOutputFormat outputFormat, Schema schema) throws IOException { + outputFormat.configure(new Configuration()); + outputFormat.open(1, 1); + for (int i = 0; i < 100; i++) { + GenericRecord record = new GenericData.Record(schema); + record.put("user_name", "testUser"); + record.put("favorite_number", 1); + record.put("favorite_color", "blue"); + outputFormat.writeRecord(record); + } + outputFormat.close(); + } } diff --git a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/batch/connectors/cassandra/example/BatchExample.java b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/batch/connectors/cassandra/example/BatchExample.java index af21f2d4a5eab..20a020802630a 100644 --- a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/batch/connectors/cassandra/example/BatchExample.java +++ b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/batch/connectors/cassandra/example/BatchExample.java @@ -34,8 +34,9 @@ /** * This is an example showing the to use the Cassandra Input-/OutputFormats in the Batch API. * - *

The example assumes that a table exists in a local cassandra database, according to the following query: - * CREATE TABLE test.batches (number int, strings text, PRIMARY KEY(number, strings)); + *

 The example assumes that a table exists in a local cassandra database, according to the following queries:
+ * CREATE KEYSPACE IF NOT EXISTS test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'};
+ * CREATE TABLE IF NOT EXISTS test.batches (number int, strings text, PRIMARY KEY(number, strings));
  */
 public class BatchExample {
 	private static final String INSERT_QUERY = "INSERT INTO test.batches (number, strings) VALUES (?,?);";
diff --git a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraPojoSinkExample.java b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraPojoSinkExample.java
index a38b73b72c991..01cd6e8048f27 100644
--- a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraPojoSinkExample.java
+++ b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraPojoSinkExample.java
@@ -32,7 +32,8 @@
 *
 *

Pojo's have to be annotated with datastax annotations to work with this sink. * - *

The example assumes that a table exists in a local cassandra database, according to the following query: + *

 The example assumes that a table exists in a local cassandra database, according to the following queries:
+ * CREATE KEYSPACE IF NOT EXISTS test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'};
 * CREATE TABLE IF NOT EXISTS test.message(body txt PRIMARY KEY)
 */
 public class CassandraPojoSinkExample {
diff --git a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraTupleSinkExample.java b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraTupleSinkExample.java
index ce2326f4e569c..72013d5141a3f 100644
--- a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraTupleSinkExample.java
+++ b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraTupleSinkExample.java
@@ -31,7 +31,8 @@
 /**
 * This is an example showing the to use the Tuple Cassandra Sink in the Streaming API.
 *
- *

The example assumes that a table exists in a local cassandra database, according to the following query: + *

 The example assumes that a table exists in a local cassandra database, according to the following queries:
+ * CREATE KEYSPACE IF NOT EXISTS test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'};
 * CREATE TABLE IF NOT EXISTS test.writetuple(element1 text PRIMARY KEY, element2 int)
 */
 public class CassandraTupleSinkExample {
diff --git a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraTupleWriteAheadSinkExample.java b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraTupleWriteAheadSinkExample.java
index 38618feaf7771..8cab311be7f90 100644
--- a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraTupleWriteAheadSinkExample.java
+++ b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/example/CassandraTupleWriteAheadSinkExample.java
@@ -36,7 +36,8 @@
 /**
 * This is an example showing the to use the Cassandra Sink (with write-ahead log) in the Streaming API.
 *
- *

The example assumes that a table exists in a local cassandra database, according to the following query: + *

 The example assumes that a table exists in a local cassandra database, according to the following queries:
+ * CREATE KEYSPACE IF NOT EXISTS example WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'};
 * CREATE TABLE example.values (id text, count int, PRIMARY KEY(id));
 *
 *

Important things to note are that checkpointing is enabled, a StateBackend is set and the enableWriteAheadLog() call diff --git a/flink-connectors/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/streaming/connectors/elasticsearch/util/RetryRejectedExecutionFailureHandler.java b/flink-connectors/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/streaming/connectors/elasticsearch/util/RetryRejectedExecutionFailureHandler.java index 9380959934cbb..370625714e96a 100644 --- a/flink-connectors/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/streaming/connectors/elasticsearch/util/RetryRejectedExecutionFailureHandler.java +++ b/flink-connectors/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/streaming/connectors/elasticsearch/util/RetryRejectedExecutionFailureHandler.java @@ -36,7 +36,7 @@ public class RetryRejectedExecutionFailureHandler implements ActionRequestFailur @Override public void onFailure(ActionRequest action, Throwable failure, int restStatusCode, RequestIndexer indexer) throws Throwable { - if (ExceptionUtils.containsThrowable(failure, EsRejectedExecutionException.class)) { + if (ExceptionUtils.findThrowable(failure, EsRejectedExecutionException.class).isPresent()) { indexer.add(action); } else { // rethrow all other failures diff --git a/flink-connectors/flink-connector-elasticsearch2/pom.xml b/flink-connectors/flink-connector-elasticsearch2/pom.xml index 1f342bc2b2044..1dbe1144cfae2 100644 --- a/flink-connectors/flink-connector-elasticsearch2/pom.xml +++ b/flink-connectors/flink-connector-elasticsearch2/pom.xml @@ -91,4 +91,35 @@ under the License. + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-flink + + + + com.google.guava:guava + + + + + com.google + org.apache.flink.elasticsearch.shaded.com.google + + com.google.protobuf.** + com.google.inject.** + + + + + + + + + + diff --git a/flink-connectors/flink-connector-eventhubs/pom.xml b/flink-connectors/flink-connector-eventhubs/pom.xml new file mode 100644 index 0000000000000..8ee8bb45ed427 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/pom.xml @@ -0,0 +1,158 @@ + + + + + + flink-connectors + org.apache.flink + 1.4-SNAPSHOT + + 4.0.0 + + flink-connector-azureeventhubs_${scala.binary.version} + flink-connector-azureeventhubs + + jar + + + + + + com.microsoft.azure + azure-eventhubs + 0.14.0 + provided + + + + org.apache.flink + flink-connector-kafka-base_${scala.binary.version} + ${project.version} + + + + + org.apache.flink + flink-streaming-java_${scala.binary.version} + ${project.version} + provided + + + + org.apache.flink + flink-table_${scala.binary.version} + ${project.version} + provided + + true + + + + + + org.apache.flink + flink-streaming-java_${scala.binary.version} + ${project.version} + test + test-jar + + + + org.apache.flink + flink-tests_${scala.binary.version} + ${project.version} + test-jar + test + + + + org.apache.flink + flink-test-utils_${scala.binary.version} + ${project.version} + test + + + + org.apache.flink + flink-runtime_${scala.binary.version} + ${project.version} + test-jar + test + + + + org.apache.flink + flink-metrics-jmx + ${project.version} + test + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + **/KafkaTestEnvironmentImpl* + + + + + + + org.apache.maven.plugins + maven-source-plugin + + + attach-test-sources + + test-jar-no-fork + + + + **/KafkaTestEnvironmentImpl* + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + 1 + -Xms256m -Xmx1000m 
-Dlog4j.configuration=${log4j.configuration} -Dmvn.forkNumber=${surefire.forkNumber} -XX:-UseGCOverheadLimit + + + + + + + diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/FlinkEventHubConsumer.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/FlinkEventHubConsumer.java new file mode 100644 index 0000000000000..e5598597de73f --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/FlinkEventHubConsumer.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.ClosureCleaner; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.metrics.Counter; +import org.apache.flink.runtime.state.DefaultOperatorStateBackend; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks; +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; +import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.connectors.eventhubs.internals.EventFetcher; +import org.apache.flink.streaming.connectors.eventhubs.internals.EventhubPartition; +import org.apache.flink.streaming.util.serialization.DeserializationSchema; +import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema; +import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchemaWrapper; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.SerializedValue; + +import com.microsoft.azure.eventhubs.PartitionReceiver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +/** + * Created by jozh on 5/22/2017. 
+ * Flink eventhub connnector has implemented with same design of flink kafka connector + * This class is used to create datastream from event hub + */ + +public class FlinkEventHubConsumer extends RichParallelSourceFunction implements + CheckpointedFunction, + ResultTypeQueryable { + private static final long serialVersionUID = -3247976470793561346L; + protected static final Logger LOGGER = LoggerFactory.getLogger(FlinkEventHubConsumer.class); + protected static final String DEFAULTOFFSETSTATENAME = "flink.eventhub.offset"; + + protected final KeyedDeserializationSchema deserializer; + protected final Properties eventhubsProps; + protected final int partitionCount; + protected List> subscribedPartitions; + protected final String defaultEventhubInitOffset; + + private Map subscribedPartitionsToStartOffsets; + private SerializedValue> periodicWatermarkAssigner; + private SerializedValue> punctuatedWatermarkAssigner; + private transient ListState> offsetsStateForCheckpoint; + private transient volatile EventFetcher eventhubFetcher; + private transient volatile HashMap restoreToOffset; + private volatile boolean running = true; + + private Counter receivedCount; + + public FlinkEventHubConsumer(Properties eventhubsProps, DeserializationSchema deserializer){ + this(eventhubsProps, new KeyedDeserializationSchemaWrapper(deserializer)); + } + + public FlinkEventHubConsumer(Properties eventhubsProps, KeyedDeserializationSchema deserializer){ + Preconditions.checkNotNull(eventhubsProps); + Preconditions.checkNotNull(deserializer); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.policyname")); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.policykey")); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.namespace")); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.name")); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.partition.count")); + + this.eventhubsProps = eventhubsProps; + this.partitionCount = Integer.parseInt(eventhubsProps.getProperty("eventhubs.partition.count")); + this.deserializer = deserializer; + + String userDefinedOffset = eventhubsProps.getProperty("eventhubs.auto.offset"); + if (userDefinedOffset != null && userDefinedOffset.toLowerCase().compareTo("lastest") == 0){ + this.defaultEventhubInitOffset = PartitionReceiver.END_OF_STREAM; + } + else { + this.defaultEventhubInitOffset = PartitionReceiver.START_OF_STREAM; + } + + if (this.partitionCount <= 0){ + throw new IllegalArgumentException("eventhubs.partition.count must greater than 0"); + } + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + receivedCount = getRuntimeContext().getMetricGroup().addGroup(this.getClass().getName()).counter("received_event_count"); + + List eventhubPartitions = this.getAllEventhubPartitions(); + this.subscribedPartitionsToStartOffsets = new HashMap<>(eventhubPartitions.size()); + + if (this.restoreToOffset != null){ + for (EventhubPartition partition : eventhubPartitions){ + if (this.restoreToOffset.containsKey(partition)){ + this.subscribedPartitionsToStartOffsets.put(partition, restoreToOffset.get(partition)); + } + } + + LOGGER.info("Consumer subtask {} will start reading {} partitions with offsets in restored state: {}", + getRuntimeContext().getIndexOfThisSubtask(), + this.subscribedPartitionsToStartOffsets.size(), + this.subscribedPartitionsToStartOffsets); + } + else { + //If there is no restored state. 
Then all partitions to read from start, the offset is "-1". In the + //future eventhub supports specify offset, we modify here + //We assign partition to each subTask in round robin mode + int numParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); + int indexofThisSubtask = getRuntimeContext().getIndexOfThisSubtask(); + for (int i = 0; i < eventhubPartitions.size(); i++) { + if (i % numParallelSubtasks == indexofThisSubtask) { + this.subscribedPartitionsToStartOffsets.put(eventhubPartitions.get(i), defaultEventhubInitOffset); + } + } + + LOGGER.info("Consumer subtask {} will start reading {} partitions with offsets: {}", + getRuntimeContext().getIndexOfThisSubtask(), + this.subscribedPartitionsToStartOffsets.size(), + this.subscribedPartitionsToStartOffsets); + } + } + + @Override + public void run(SourceContext sourceContext) throws Exception { + if (this.subscribedPartitionsToStartOffsets == null || this.subscribedPartitionsToStartOffsets.size() == 0){ + throw new Exception("The partitions were not set for the consumer"); + } + + StreamingRuntimeContext runtimeContext = (StreamingRuntimeContext) getRuntimeContext(); + + if (!this.subscribedPartitionsToStartOffsets.isEmpty()){ + final EventFetcher fetcher = new EventFetcher(sourceContext, + subscribedPartitionsToStartOffsets, + deserializer, + periodicWatermarkAssigner, + punctuatedWatermarkAssigner, + runtimeContext.getProcessingTimeService(), + runtimeContext.getExecutionConfig().getAutoWatermarkInterval(), + runtimeContext.getUserCodeClassLoader(), + runtimeContext.getTaskNameWithSubtasks(), + eventhubsProps, + false, + receivedCount); + + this.eventhubFetcher = fetcher; + if (!this.running){ + return; + } + + this.eventhubFetcher.runFetchLoop(); + } + else { + sourceContext.emitWatermark(new Watermark(Long.MAX_VALUE)); + + final Object waitObj = new Object(); + while (this.running){ + try { + synchronized (waitObj){ + waitObj.wait(); + } + } + catch (InterruptedException ex){ + if (this.running){ + Thread.currentThread().interrupt(); + } + } + } + } + } + + @Override + public void close() throws Exception { + try { + this.cancel(); + } + finally { + super.close(); + } + } + + @Override + public void cancel() { + this.running = false; + + if (this.eventhubFetcher != null){ + this.eventhubFetcher.cancel(); + } + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + if (!this.running){ + LOGGER.info("Consumer subtask {}: snapshotState() is called on the closed source", getRuntimeContext().getIndexOfThisSubtask()); + return; + } + + this.offsetsStateForCheckpoint.clear(); + final EventFetcher fetcher = this.eventhubFetcher; + if (fetcher == null){ + for (Map.Entry subscribedPartition : this.subscribedPartitionsToStartOffsets.entrySet()){ + this.offsetsStateForCheckpoint.add(Tuple2.of(subscribedPartition.getKey(), subscribedPartition.getValue())); + } + } + else { + HashMap currentOffsets = fetcher.snapshotCurrentState(); + for (Map.Entry subscribedPartition : currentOffsets.entrySet()){ + this.offsetsStateForCheckpoint.add(Tuple2.of(subscribedPartition.getKey(), subscribedPartition.getValue())); + } + } + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + LOGGER.info("Consumer subtask {}:Start init eventhub offset state", getRuntimeContext().getIndexOfThisSubtask()); + OperatorStateStore stateStore = context.getOperatorStateStore(); + /* this.offsetsStateForCheckpoint = stateStore + .getListState(new 
ListStateDescriptor>(DEFAULT_OFFSET_STATE_NAME, TypeInformation.of(new TypeHint>(){}))); +*/ + this.offsetsStateForCheckpoint = stateStore.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + if (context.isRestored()){ + if (this.restoreToOffset == null){ + this.restoreToOffset = new HashMap<>(); + for (Tuple2 offsetState : this.offsetsStateForCheckpoint.get()){ + this.restoreToOffset.put(offsetState.f0, offsetState.f1); + } + + LOGGER.info("Consumer subtask {}:Eventhub offset state is restored from checkpoint", getRuntimeContext().getIndexOfThisSubtask()); + } + else if (this.restoreToOffset.isEmpty()){ + this.restoreToOffset = null; + } + } + else { + LOGGER.info("Consumer subtask {}:No restore state for flink-eventhub-consumer", getRuntimeContext().getIndexOfThisSubtask()); + } + } + + //deprecated for CheckpointedRestoring + public void restoreState(HashMap eventhubPartitionOffsets) throws Exception { + LOGGER.info("{} (taskIdx={}) restoring offsets from an older version.", + getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask()); + + this.restoreToOffset = eventhubPartitionOffsets; + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("{} (taskIdx={}) restored offsets from an older Flink version: {}", + getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask(), eventhubPartitionOffsets); + } + } + + @Override + public TypeInformation getProducedType() { + return this.deserializer.getProducedType(); + } + + public FlinkEventHubConsumer assignTimestampsAndWatermarks(AssignerWithPunctuatedWatermarks assigner) { + Preconditions.checkNotNull(assigner); + + if (this.periodicWatermarkAssigner != null) { + throw new IllegalStateException("A periodic watermark emitter has already been set."); + } + try { + ClosureCleaner.clean(assigner, true); + this.punctuatedWatermarkAssigner = new SerializedValue<>(assigner); + return this; + } catch (Exception e) { + throw new IllegalArgumentException("The given assigner is not serializable", e); + } + } + + public FlinkEventHubConsumer assignTimestampsAndWatermarks(AssignerWithPeriodicWatermarks assigner) { + Preconditions.checkNotNull(assigner); + + if (this.punctuatedWatermarkAssigner != null) { + throw new IllegalStateException("A punctuated watermark emitter has already been set."); + } + try { + ClosureCleaner.clean(assigner, true); + this.periodicWatermarkAssigner = new SerializedValue<>(assigner); + return this; + } catch (Exception e) { + throw new IllegalArgumentException("The given assigner is not serializable", e); + } + } + + private List getAllEventhubPartitions() { + List partitions = new ArrayList<>(); + for (int i = 0; i < this.partitionCount; i++){ + partitions.add(new EventhubPartition(this.eventhubsProps, i)); + } + + LOGGER.info("Consumer subtask {}:Create {} eventhub partitions info", getRuntimeContext().getIndexOfThisSubtask(), this.partitionCount); + return partitions; + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/FlinkEventHubProducer.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/FlinkEventHubProducer.java new file mode 100644 index 0000000000000..344e1f1a5ad59 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/FlinkEventHubProducer.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.metrics.Counter; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; +import org.apache.flink.streaming.connectors.eventhubs.internals.EventhubProducerThread; +import org.apache.flink.streaming.connectors.eventhubs.internals.ProducerCache; +import org.apache.flink.streaming.util.serialization.SerializationSchema; +import org.apache.flink.util.Preconditions; + +import com.microsoft.azure.eventhubs.EventData; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Properties; + +/** + * Created by jozh on 6/20/2017. + * Will support customize parttion in next version + */ +public class FlinkEventHubProducer extends RichSinkFunction implements CheckpointedFunction { + + private static final Logger logger = LoggerFactory.getLogger(FlinkEventHubProducer.class); + private static final long serialVersionUID = -7486455932880508035L; + private final SerializationSchema schema; + private final ProducerCache cache; + private final Properties eventhubsProps; + private EventhubProducerThread producerThread; + private Counter prepareSendCount; + private Counter commitSendCount; + + public FlinkEventHubProducer(SerializationSchema serializationSchema, Properties eventhubsProps){ + Preconditions.checkNotNull(serializationSchema); + Preconditions.checkNotNull(eventhubsProps); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.policyname")); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.policykey")); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.namespace")); + Preconditions.checkNotNull(eventhubsProps.getProperty("eventhubs.name")); + + this.schema = serializationSchema; + this.eventhubsProps = eventhubsProps; + + int capacity = eventhubsProps.getProperty("eventhubs.cache.capacity") == null + ? ProducerCache.DEFAULTCAPACITY : Integer.parseInt(eventhubsProps.getProperty("eventhubs.cache.capacity")); + + long timeout = eventhubsProps.getProperty("eventhubs.cache.timeout") == null + ? 
ProducerCache.DEFAULTTIMEOUTMILLISECOND : Long.parseLong(eventhubsProps.getProperty("eventhubs.cache.timeout")); + + this.cache = new ProducerCache(capacity, timeout); + + logger.info("Created eventhub producer for namespace: {}, name: {}", + eventhubsProps.getProperty("eventhubs.namespace"), + eventhubsProps.getProperty("eventhubs.name")); + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + return; + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + return; + } + + @Override + public void invoke(OUT value) throws Exception { + cache.checkErr(); + EventData event = new EventData(this.schema.serialize(value)); + cache.put(event); + prepareSendCount.inc(); + logger.debug("Insert a event input output cache"); + cache.checkErr(); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + prepareSendCount = getRuntimeContext().getMetricGroup().addGroup(this.getClass().getName()).counter("prepare_send_event_count"); + commitSendCount = getRuntimeContext().getMetricGroup().addGroup(this.getClass().getName()).counter("commit_send_event_count"); + String threadName = getEventhubProducerName(); + + logger.info("Eventhub producer thread {} starting", threadName); + producerThread = new EventhubProducerThread( + logger, + threadName, + cache, + eventhubsProps, + commitSendCount); + producerThread.start(); + logger.info("Eventhub producer thread {} started", threadName); + cache.checkErr(); + } + + @Override + public void close() throws Exception { + super.close(); + logger.info("Eventhub producer thread close on demand"); + producerThread.shutdown(); + cache.close(); + cache.checkErr(); + } + + protected String getEventhubProducerName(){ + return "Eventhub producer " + getRuntimeContext().getTaskNameWithSubtasks(); + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventFetcher.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventFetcher.java new file mode 100644 index 0000000000000..3212702da1be1 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventFetcher.java @@ -0,0 +1,388 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.metrics.Counter; +import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks; +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema; +import org.apache.flink.util.SerializedValue; + +import com.microsoft.azure.eventhubs.EventData; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Created by jozh on 6/14/2017. + * Flink eventhub connnector has implemented with same design of flink kafka connector. + * A fetcher that fetches data from Eventhub via the EventhubUtil. + * Eventhub offset is stored at flink checkpoint backend + * @param The type of elements produced by the fetcher. + */ +public class EventFetcher { + protected static final int NO_TIMESTAMPS_WATERMARKS = 0; + protected static final int PERIODIC_WATERMARKS = 1; + protected static final int PUNCTUATED_WATERMARKS = 2; + private static final Logger logger = LoggerFactory.getLogger(EventFetcher.class); + private volatile boolean running = true; + + private final KeyedDeserializationSchema deserializer; + private final Handover handover; + private final Properties eventhubProps; + private final EventhubConsumerThread consumerThread; + private final String taskNameWithSubtasks; + + + protected final SourceFunction.SourceContext sourceContext; + protected final Object checkpointLock; + private final Map subscribedPartitionStates; + protected final int timestampWatermarkMode; + protected final boolean useMetrics; + private volatile long maxWatermarkSoFar = Long.MIN_VALUE; + private Counter receivedCount; + + public EventFetcher( + SourceFunction.SourceContext sourceContext, + Map assignedPartitionsWithInitialOffsets, + KeyedDeserializationSchema deserializer, + SerializedValue> watermarksPeriodic, + SerializedValue> watermarksPunctuated, + ProcessingTimeService processTimerProvider, + long autoWatermarkInterval, + ClassLoader userCodeClassLoader, + String taskNameWithSubtasks, + Properties eventhubProps, + boolean useMetrics, + Counter receivedCount) throws Exception { + + this.sourceContext = checkNotNull(sourceContext); + this.deserializer = checkNotNull(deserializer); + this.eventhubProps = eventhubProps; + this.checkpointLock = sourceContext.getCheckpointLock(); + this.useMetrics = useMetrics; + this.receivedCount = receivedCount; + this.taskNameWithSubtasks = taskNameWithSubtasks; + this.timestampWatermarkMode = getTimestampWatermarkMode(watermarksPeriodic, watermarksPunctuated); + + this.subscribedPartitionStates = initializeSubscribedPartitionStates( + assignedPartitionsWithInitialOffsets, + timestampWatermarkMode, + watermarksPeriodic, watermarksPunctuated, + userCodeClassLoader); + + this.handover = new Handover(); + this.consumerThread = new EventhubConsumerThread(logger, + handover, + this.eventhubProps, + getFetcherName() + " for " + taskNameWithSubtasks, + 
this.subscribedPartitionStates.values().toArray(new EventhubPartitionState[this.subscribedPartitionStates.size()])); + + if (this.timestampWatermarkMode == PERIODIC_WATERMARKS) { + PeriodicWatermarkEmitter periodicEmitter = + new PeriodicWatermarkEmitter(this.subscribedPartitionStates, sourceContext, processTimerProvider, autoWatermarkInterval); + periodicEmitter.start(); + } + + } + + public HashMap snapshotCurrentState() { + // this method assumes that the checkpoint lock is held + logger.debug("snapshot current offset state for subtask {}", taskNameWithSubtasks); + assert Thread.holdsLock(checkpointLock); + + HashMap state = new HashMap<>(subscribedPartitionStates.size()); + for (Map.Entry partition : subscribedPartitionStates.entrySet()){ + state.put(partition.getKey(), partition.getValue().getOffset()); + } + + return state; + } + + public void runFetchLoop() throws Exception{ + try { + final Handover handover = this.handover; + consumerThread.start(); + logger.info("Eventhub consumer thread started for substask {}", taskNameWithSubtasks); + + logger.info("Start fetcher loop to get data from eventhub and emit to flink for subtask {}", taskNameWithSubtasks); + while (running){ + final Tuple2> eventsTuple = handover.pollNext(); + for (EventData event : eventsTuple.f1){ + final T value = deserializer.deserialize(null, + event.getBytes(), + event.getSystemProperties().getPartitionKey(), + eventsTuple.f0.getParitionId(), + event.getSystemProperties().getSequenceNumber()); + + if (deserializer.isEndOfStream(value)){ + running = false; + break; + } + emitRecord(value, subscribedPartitionStates.get(eventsTuple.f0), event.getSystemProperties().getOffset()); + receivedCount.inc(); + } + } + } + finally { + logger.warn("Stopping eventhub consumer thread of subtask {}, because something wrong when deserializing received event " + , taskNameWithSubtasks); + consumerThread.shutdown(); + } + + try { + consumerThread.join(); + logger.warn("Waiting eventhub consumer thread of subtask {} stopped", taskNameWithSubtasks); + } + catch (InterruptedException ex){ + Thread.currentThread().interrupt(); + } + + logger.info("EventFetcher of subtask {} stopped", taskNameWithSubtasks); + } + + public void cancel(){ + logger.info("EventFetcher of subtask {} canceled on demand", taskNameWithSubtasks); + running = false; + handover.close(); + consumerThread.shutdown(); + } + + protected void emitRecord(T record, EventhubPartitionState partitionState, String offset) throws Exception{ + if (record == null){ + synchronized (this.checkpointLock){ + partitionState.setOffset(offset); + } + return; + } + + if (timestampWatermarkMode == NO_TIMESTAMPS_WATERMARKS){ + synchronized (this.checkpointLock){ + sourceContext.collect(record); + partitionState.setOffset(offset); + } + } + else if (timestampWatermarkMode == PERIODIC_WATERMARKS){ + emitRecordWithTimestampAndPeriodicWatermark(record, partitionState, offset, Long.MIN_VALUE); + } + else { + emitRecordWithTimestampAndPunctuatedWatermark(record, partitionState, offset, Long.MIN_VALUE); + } + } + + protected void emitRecordWithTimestampAndPunctuatedWatermark( + T record, + EventhubPartitionState partitionState, + String offset, + long eventTimestamp) { + + final EventhubPartitionStateWithPeriodicWatermarks withWatermarksState = + (EventhubPartitionStateWithPeriodicWatermarks) partitionState; + + final long timestamp; + synchronized (withWatermarksState) { + timestamp = withWatermarksState.getTimestampForRecord(record, eventTimestamp); + } + + synchronized (checkpointLock) { + 
sourceContext.collectWithTimestamp(record, timestamp); + partitionState.setOffset(offset); + } + } + + protected void emitRecordWithTimestampAndPeriodicWatermark( + T record, + EventhubPartitionState partitionState, + String offset, + long eventTimestamp) { + + final EventhubPartitionStateWithPunctuatedWatermarks withWatermarksState = + (EventhubPartitionStateWithPunctuatedWatermarks) partitionState; + + final long timestamp = withWatermarksState.getTimestampForRecord(record, eventTimestamp); + final Watermark newWatermark = withWatermarksState.checkAndGetNewWatermark(record, timestamp); + + synchronized (checkpointLock) { + sourceContext.collectWithTimestamp(record, timestamp); + partitionState.setOffset(offset); + } + + if (newWatermark != null) { + updateMinPunctuatedWatermark(newWatermark); + } + } + + protected String getFetcherName() { + return "Eventhubs Fetcher"; + } + + private int getTimestampWatermarkMode(SerializedValue> watermarksPeriodic, + SerializedValue> watermarksPunctuated) + throws IllegalArgumentException { + if (watermarksPeriodic == null){ + if (watermarksPunctuated == null){ + return NO_TIMESTAMPS_WATERMARKS; + } + else { + return PUNCTUATED_WATERMARKS; + } + } + else { + if (watermarksPunctuated == null){ + return PERIODIC_WATERMARKS; + } + else { + throw new IllegalArgumentException("Cannot have both periodic and punctuated watermarks"); + } + } + } + + private Map initializeSubscribedPartitionStates( + Map assignedPartitionsWithInitialOffsets, + int timestampWatermarkMode, + SerializedValue> watermarksPeriodic, + SerializedValue> watermarksPunctuated, + ClassLoader userCodeClassLoader) throws IOException, ClassNotFoundException { + + if (timestampWatermarkMode != NO_TIMESTAMPS_WATERMARKS + && timestampWatermarkMode != PERIODIC_WATERMARKS + && timestampWatermarkMode != PUNCTUATED_WATERMARKS) { + throw new RuntimeException(); + } + + Map partitionsState = new HashMap<>(assignedPartitionsWithInitialOffsets.size()); + for (Map.Entry partition : assignedPartitionsWithInitialOffsets.entrySet()){ + switch (timestampWatermarkMode){ + case NO_TIMESTAMPS_WATERMARKS:{ + partitionsState.put(partition.getKey(), new EventhubPartitionState(partition.getKey(), partition.getValue())); + logger.info("NO_TIMESTAMPS_WATERMARKS: Assigned partition {}, offset is {}", partition.getKey(), partition.getValue()); + break; + } + + case PERIODIC_WATERMARKS:{ + AssignerWithPeriodicWatermarks assignerInstance = + watermarksPeriodic.deserializeValue(userCodeClassLoader); + partitionsState.put(partition.getKey(), + new EventhubPartitionStateWithPeriodicWatermarks(partition.getKey(), partition.getValue(), assignerInstance)); + logger.info("PERIODIC_WATERMARKS: Assigned partition {}, offset is {}", partition.getKey(), partition.getValue()); + break; + } + + case PUNCTUATED_WATERMARKS: { + AssignerWithPunctuatedWatermarks assignerInstance = + watermarksPunctuated.deserializeValue(userCodeClassLoader); + partitionsState.put(partition.getKey(), + new EventhubPartitionStateWithPunctuatedWatermarks(partition.getKey(), partition.getValue(), assignerInstance)); + logger.info("PUNCTUATED_WATERMARKS: Assigned partition {}, offset is {}", partition.getKey(), partition.getValue()); + break; + } + } + } + return partitionsState; + } + + private void updateMinPunctuatedWatermark(Watermark nextWatermark) { + if (nextWatermark.getTimestamp() > maxWatermarkSoFar) { + long newMin = Long.MAX_VALUE; + + for (Map.Entry partition : subscribedPartitionStates.entrySet()){ + final 
EventhubPartitionStateWithPunctuatedWatermarks withWatermarksState = + (EventhubPartitionStateWithPunctuatedWatermarks) partition.getValue(); + + newMin = Math.min(newMin, withWatermarksState.getCurrentPartitionWatermark()); + } + + // double-check locking pattern + if (newMin > maxWatermarkSoFar) { + synchronized (checkpointLock) { + if (newMin > maxWatermarkSoFar) { + maxWatermarkSoFar = newMin; + sourceContext.emitWatermark(new Watermark(newMin)); + } + } + } + } + } + + private static class PeriodicWatermarkEmitter implements ProcessingTimeCallback { + + private final Map allPartitions; + + private final SourceFunction.SourceContext emitter; + + private final ProcessingTimeService timerService; + + private final long interval; + + private long lastWatermarkTimestamp; + + //------------------------------------------------- + + PeriodicWatermarkEmitter( + Map allPartitions, + SourceFunction.SourceContext emitter, + ProcessingTimeService timerService, + long autoWatermarkInterval) { + this.allPartitions = checkNotNull(allPartitions); + this.emitter = checkNotNull(emitter); + this.timerService = checkNotNull(timerService); + this.interval = autoWatermarkInterval; + this.lastWatermarkTimestamp = Long.MIN_VALUE; + } + + public void start() { + timerService.registerTimer(timerService.getCurrentProcessingTime() + interval, this); + } + + @Override + public void onProcessingTime(long timestamp) throws Exception { + + long minAcrossAll = Long.MAX_VALUE; + for (Map.Entry partition : allPartitions.entrySet()){ + final long curr; + EventhubPartitionStateWithPeriodicWatermarks state = + (EventhubPartitionStateWithPeriodicWatermarks) partition.getValue(); + + synchronized (state) { + curr = state.getCurrentWatermarkTimestamp(); + } + + minAcrossAll = Math.min(minAcrossAll, curr); + } + + if (minAcrossAll > lastWatermarkTimestamp) { + lastWatermarkTimestamp = minAcrossAll; + emitter.emitWatermark(new Watermark(minAcrossAll)); + } + + timerService.registerTimer(timerService.getCurrentProcessingTime() + interval, this); + } + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubClientWrapper.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubClientWrapper.java new file mode 100644 index 0000000000000..39f22a263a9e7 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubClientWrapper.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.util.Preconditions; + +import com.microsoft.azure.eventhubs.EventData; +import com.microsoft.azure.eventhubs.EventHubClient; +import com.microsoft.azure.eventhubs.PartitionReceiver; +import com.microsoft.azure.servicebus.ConnectionStringBuilder; +import com.microsoft.azure.servicebus.ServiceBusException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.Serializable; +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.ExecutionException; + +/** + * Created by jozh on 6/14/2017. + * Flink eventhub connnector has implemented with same design of flink kafka connector + */ +public class EventhubClientWrapper implements Serializable { + private static final long serialVersionUID = -5319150387753930840L; + private static final Logger logger = LoggerFactory.getLogger(EventhubClientWrapper.class); + private EventHubClient eventHubClient; + private PartitionReceiver eventhubReceiver; + private ConnectionStringBuilder connectionString; + private String consumerGroup; + private Long receiverEpoch; + + private Duration receiverTimeout; + private EventhubOffsetType offsetType; + private String currentOffset; + private String partitionId; + + private final int minPrefetchCount = 10; + private int maxPrefetchCount = 999; + private int maxEventRate = 0; + private final Long defaultReceiverEpoch = -1L; + private final String defaultReceiverTimeout = "60000"; + + public void createReveiver(Properties eventhubParams, String partitionId) + throws IllegalArgumentException, URISyntaxException, IOException, ServiceBusException{ + int maxEventRate = Integer.parseInt(eventhubParams.getProperty("eventhubs.maxRate", "10")); + this.createReveiver(eventhubParams, partitionId, maxEventRate, PartitionReceiver.START_OF_STREAM); + } + + public void createReveiver(Properties eventhubParams, String partitionId, String offset) + throws IllegalArgumentException, URISyntaxException, IOException, ServiceBusException{ + int maxEventRate = Integer.parseInt(eventhubParams.getProperty("eventhubs.maxRate", "10")); + this.createReveiver(eventhubParams, partitionId, maxEventRate, offset); + } + + public void createReveiver(Properties eventhubParams, String partitionId, int maxEventRate) + throws IllegalArgumentException, URISyntaxException, IOException, ServiceBusException{ + this.createReveiver(eventhubParams, partitionId, maxEventRate, PartitionReceiver.START_OF_STREAM); + } + + /*Will not implement a standalone offset store here, will leverage flink state to save the offset of eventhub*/ + public void createReveiver(Properties eventhubParams, String partitionId, int maxEventRate, String offset) + throws IllegalArgumentException, URISyntaxException, IOException, ServiceBusException{ + if (eventhubParams.containsKey("eventhubs.uri") && eventhubParams.containsKey("eventhubs.namespace")) { + throw new IllegalArgumentException("Eventhubs URI and namespace cannot both be specified at the same time."); + } + + if (eventhubParams.containsKey("eventhubs.namespace")){ + this.connectionString = new ConnectionStringBuilder( + eventhubParams.getProperty("eventhubs.namespace"), + eventhubParams.getProperty("eventhubs.name"), + eventhubParams.getProperty("eventhubs.policyname"), + eventhubParams.getProperty("eventhubs.policykey")); + } + else if 
(eventhubParams.containsKey("eventhubs.uri")){ + this.connectionString = new ConnectionStringBuilder(new URI( + eventhubParams.getProperty("eventhubs.uri")), + eventhubParams.getProperty("eventhubs.name"), + eventhubParams.getProperty("eventhubs.policyname"), + eventhubParams.getProperty("eventhubs.policykey")); + } + else { + throw new IllegalArgumentException("Either Eventhubs URI or namespace nust be specified."); + } + + this.partitionId = Preconditions.checkNotNull(partitionId, "partitionId is no valid, cannot be null or empty"); + this.consumerGroup = eventhubParams.getProperty("eventhubs.consumergroup", EventHubClient.DEFAULT_CONSUMER_GROUP_NAME); + this.receiverEpoch = Long.parseLong(eventhubParams.getProperty("eventhubs.epoch", defaultReceiverEpoch.toString())); + this.receiverTimeout = Duration.ofMillis(Long.parseLong(eventhubParams.getProperty("eventhubs.receiver.timeout", defaultReceiverTimeout))); + this.offsetType = EventhubOffsetType.None; + this.currentOffset = PartitionReceiver.START_OF_STREAM; + + String previousOffset = offset; + + if (previousOffset.compareTo(PartitionReceiver.START_OF_STREAM) != 0 && previousOffset != null) { + + offsetType = EventhubOffsetType.PreviousCheckpoint; + currentOffset = previousOffset; + + } else if (eventhubParams.containsKey("eventhubs.filter.offset")) { + + offsetType = EventhubOffsetType.InputByteOffset; + currentOffset = eventhubParams.getProperty("eventhubs.filter.offset"); + + } else if (eventhubParams.containsKey("eventhubs.filter.enqueuetime")) { + + offsetType = EventhubOffsetType.InputTimeOffset; + currentOffset = eventhubParams.getProperty("eventhubs.filter.enqueuetime"); + } + + this.maxEventRate = maxEventRate; + + if (maxEventRate > 0 && maxEventRate < minPrefetchCount) { + maxPrefetchCount = minPrefetchCount; + } + else if (maxEventRate >= minPrefetchCount && maxEventRate < maxPrefetchCount) { + maxPrefetchCount = maxEventRate + 1; + } + else { + this.maxEventRate = maxPrefetchCount - 1; + } + + this.createReceiverInternal(); + } + + public Iterable receive () throws ExecutionException, InterruptedException { + return this.eventhubReceiver.receive(maxEventRate).get(); + } + + public void close(){ + logger.info("Close eventhub client on demand of partition {}", this.partitionId); + if (this.eventhubReceiver != null){ + try { + this.eventhubReceiver.closeSync(); + } + catch (ServiceBusException ex){ + logger.error("Close eventhub client of partition {} failed, reason: {}", this.partitionId, ex.getMessage()); + } + } + } + + private void createReceiverInternal() throws IOException, ServiceBusException{ + this.eventHubClient = EventHubClient.createFromConnectionStringSync(this.connectionString.toString()); + + switch (this.offsetType){ + case None: { + if (this.receiverEpoch > defaultReceiverEpoch){ + this.eventhubReceiver = this.eventHubClient.createEpochReceiverSync(consumerGroup, partitionId, currentOffset, receiverEpoch); + } + else { + this.eventhubReceiver = this.eventHubClient.createReceiverSync(consumerGroup, partitionId, currentOffset, false); + } + break; + } + case PreviousCheckpoint: { + if (this.receiverEpoch > defaultReceiverEpoch){ + this.eventhubReceiver = this.eventHubClient.createEpochReceiverSync(consumerGroup, partitionId, currentOffset, false, receiverEpoch); + } + else { + this.eventhubReceiver = this.eventHubClient.createReceiverSync(consumerGroup, partitionId, currentOffset, false); + } + break; + } + case InputByteOffset: { + if (this.receiverEpoch > defaultReceiverEpoch){ + this.eventhubReceiver = 
this.eventHubClient.createEpochReceiverSync(consumerGroup, partitionId, currentOffset, false, receiverEpoch); + } + else { + this.eventhubReceiver = this.eventHubClient.createReceiverSync(consumerGroup, partitionId, currentOffset, false); + } + break; + } + case InputTimeOffset: { + if (this.receiverEpoch > defaultReceiverEpoch){ + this.eventhubReceiver = this.eventHubClient.createEpochReceiverSync(consumerGroup, partitionId, Instant.ofEpochSecond(Long.parseLong(currentOffset)), receiverEpoch); + } + else { + this.eventhubReceiver = this.eventHubClient.createReceiverSync(consumerGroup, partitionId, Instant.ofEpochSecond(Long.parseLong(currentOffset))); + } + break; + } + } + + this.eventhubReceiver.setPrefetchCount(maxPrefetchCount); + this.eventhubReceiver.setReceiveTimeout(this.receiverTimeout); + logger.info("Successfully create eventhub receiver for partition {}, max_event_rate {}, max_prefetch_rate {}, receive_timeout {}, offset {}, ", + this.partitionId, + this.maxEventRate, + this.maxPrefetchCount, + this.receiverTimeout, + this.currentOffset); + } + + public Duration getReceiverTimeout() { + return receiverTimeout; + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubConsumerThread.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubConsumerThread.java new file mode 100644 index 0000000000000..a0f9f261d512e --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubConsumerThread.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.api.java.tuple.Tuple2; + +import com.microsoft.azure.eventhubs.EventData; +import org.slf4j.Logger; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +/** + * Created by jozh on 5/24/2017. 
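+ * <p>Consumer thread that owns one {@link EventhubClientWrapper} per assigned partition and forwards every received batch to the fetcher through the {@link Handover}; receive failures are reported via {@link Handover#reportError(Throwable)} and terminate the thread.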
+ * Flink eventhub connnector has implemented with same design of flink kafka connector + * Cause eventhub client can only access one partition at one time, so here we should have multiple eventhub clients + * In this worker thread, it will receive event from each partition in round robin mode, any partition failed to retrive + * events will lead thread exception, and leverage flink HA framework to start from begining again + */ +public class EventhubConsumerThread extends Thread { + private final Logger logger; + private final Handover handover; + private final Properties eventhubProps; + private final EventhubPartitionState[] subscribedPartitionStates; + private final Map clients; + private volatile boolean running; + + public EventhubConsumerThread( + Logger logger, + Handover handover, + Properties eventhubProps, + String threadName, + EventhubPartitionState[] subscribedPartitionStates) throws Exception{ + + super(threadName); + setDaemon(true); + + this.logger = logger; + this.handover = handover; + this.eventhubProps = eventhubProps; + this.subscribedPartitionStates = subscribedPartitionStates; + this.running = true; + + this.clients = new HashMap<>(this.subscribedPartitionStates.length); + for (int i = 0; i < this.subscribedPartitionStates.length; i++){ + EventhubClientWrapper client = new EventhubClientWrapper(); + this.clients.put(this.subscribedPartitionStates[i], client); + } + } + + public void shutdown(){ + logger.info("Shutdown eventhub consumer thread {} on demand", this.getName()); + running = false; + handover.wakeupProducer(); + } + + @Override + public void run() { + if (!running){ + logger.info("Eventhub consumer thread is set to STOP, thread {} exit", this.getName()); + return; + } + + try { + logger.info("Starting create {} eventhub clients on {}", this.subscribedPartitionStates.length, this.getName()); + for (Map.Entry client : clients.entrySet()){ + EventhubPartitionState state = client.getKey(); + client.getValue().createReveiver(this.eventhubProps, Integer.toString(state.getPartition().getParitionId()), state.getOffset()); + } + } + catch (Throwable t){ + logger.error("Create eventhub client of {}, error: {}", this.getName(), t); + handover.reportError(t); + clearReceiveClients(); + return; + } + + try { + int currentClientIndex = 0; + while (running){ + EventhubPartitionState partitionState = subscribedPartitionStates[currentClientIndex]; + EventhubClientWrapper client = clients.get(partitionState); + Iterable events = client.receive(); + if (events != null){ + handover.produce(Tuple2.of(partitionState.getPartition(), events)); + logger.debug("Received event from {} on {}", partitionState.getPartition().toString(), this.getName()); + } + else { + logger.warn("Receive events from {} timeout, timeout set to {}, thread {}", + partitionState.getPartition().toString(), + client.getReceiverTimeout(), + this.getName()); + } + + currentClientIndex++; + currentClientIndex = currentClientIndex % subscribedPartitionStates.length; + } + } + catch (Throwable t){ + logger.error("Receving events error, {}", t); + handover.reportError(t); + } + finally { + logger.info("Exit from eventhub consumer thread, {}", this.getName()); + handover.close(); + clearReceiveClients(); + } + + logger.info("EventhubConsumerThread {} quit", this.getName()); + } + + private void clearReceiveClients(){ + if (clients == null){ + return; + } + + for (Map.Entry client : clients.entrySet()){ + try { + client.getValue().close(); + logger.info("Eventhub client for partition {} closed", 
client.getKey().getPartition().getParitionId()); + } + catch (Throwable t){ + logger.warn("Error while close eventhub client for partition {}, error is {}", + client.getKey().getPartition().getParitionId(), + t.getMessage()); + } + } + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubOffsetType.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubOffsetType.java new file mode 100644 index 0000000000000..e745ded06afeb --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubOffsetType.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +/** + * Created by jozh on 5/22/2017. + * Flink eventhub connnector has implemented with same design of flink kafka connector + */ +public enum EventhubOffsetType { + None, + PreviousCheckpoint, + InputByteOffset, + InputTimeOffset +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartition.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartition.java new file mode 100644 index 0000000000000..9ae168ccf664c --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartition.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.Properties; + +/** + * Created by jozh on 5/23/2017. 
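+ * <p>Identifies a single Event Hub partition by namespace, event hub name and partition id, together with the policy credentials needed to connect to it; equality and the cached hash code are derived from namespace, name and partition id only.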
+ * Flink eventhub connnector has implemented with same design of flink kafka connector + */ + +public class EventhubPartition implements Serializable { + private static final long serialVersionUID = 134878919919793479L; + private final int cachedHash; + private final String policyName; + private final String policyKey; + private final String namespace; + private final String name; + + public int getParitionId() { + return paritionId; + } + + public String getPartitionName(){ + return namespace + "-" + name; + } + + private final int paritionId; + + public EventhubPartition(Properties props, int parition){ + this(props.getProperty("eventhubs.policyname"), + props.getProperty("eventhubs.policykey"), + props.getProperty("eventhubs.namespace"), + props.getProperty("eventhubs.name"), + parition); + } + + public EventhubPartition(String policyName, String policyKey, String namespace, String name, int paritionId){ + Preconditions.checkArgument(paritionId >= 0); + + this.policyName = Preconditions.checkNotNull(policyName); + this.policyKey = Preconditions.checkNotNull(policyKey); + this.name = Preconditions.checkNotNull(name); + this.namespace = Preconditions.checkNotNull(namespace); + this.paritionId = paritionId; + this.cachedHash = 31 * (this.namespace + this.name).hashCode() + paritionId; + } + + @Override + public String toString() { + return "EventhubPartition, namespace: " + this.namespace + + " name: " + this.name + + " partition: " + this.paritionId; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof EventhubPartition){ + return this.hashCode() == ((EventhubPartition) obj).hashCode(); + } + + return false; + } + + @Override + public int hashCode() { + return this.cachedHash; + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionState.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionState.java new file mode 100644 index 0000000000000..d1547a539bde6 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionState.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +/** + * Created by jozh on 5/23/2017. 
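+ * <p>Holds the current offset of one {@link EventhubPartition}; the fetcher updates the offset under the checkpoint lock after emitting records and reads it back when taking a checkpoint snapshot.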
+ * Flink eventhub connector is implemented with the same design as the Flink Kafka connector + */ + +public class EventhubPartitionState { + private final EventhubPartition partition; + private volatile String offset; + + public EventhubPartitionState(EventhubPartition partition, String offset){ + this.partition = partition; + this.offset = offset; + } + + public final String getOffset() { + return this.offset; + } + + public final void setOffset(String offset) { + this.offset = offset; + } + + public EventhubPartition getPartition() { + return this.partition; + } +} + diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionStateWithPeriodicWatermarks.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionStateWithPeriodicWatermarks.java new file mode 100644 index 0000000000000..1fab7ff074633 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionStateWithPeriodicWatermarks.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks; +import org.apache.flink.streaming.api.watermark.Watermark; + +/** + * Created by jozh on 6/16/2017.
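+ * <p>Partition state that additionally runs an {@link AssignerWithPeriodicWatermarks}: {@link #getCurrentWatermarkTimestamp()} never goes backwards, and the periodic watermark emitter takes the minimum of this value across all partitions before emitting a watermark.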
+ * Flink eventhub connnector has implemented with same design of flink kafka connector + */ + +public class EventhubPartitionStateWithPeriodicWatermarks extends EventhubPartitionState { + private final AssignerWithPeriodicWatermarks timestampsAndWatermarks; + private long partitionWatermark; + + public EventhubPartitionStateWithPeriodicWatermarks(EventhubPartition key, String value, AssignerWithPeriodicWatermarks timestampsAndWatermarks) { + super(key, value); + this.timestampsAndWatermarks = timestampsAndWatermarks; + this.partitionWatermark = Long.MIN_VALUE; + } + + public long getTimestampForRecord(T record, long kafkaEventTimestamp) { + return timestampsAndWatermarks.extractTimestamp(record, kafkaEventTimestamp); + } + + public long getCurrentWatermarkTimestamp() { + Watermark wm = timestampsAndWatermarks.getCurrentWatermark(); + if (wm != null) { + partitionWatermark = Math.max(partitionWatermark, wm.getTimestamp()); + } + return partitionWatermark; + } + + @Override + public String toString() { + return "EventhubPartitionStateWithPeriodicWatermarks: partition=" + getPartition() + + ", offset=" + getOffset() + ", watermark=" + partitionWatermark; + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionStateWithPunctuatedWatermarks.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionStateWithPunctuatedWatermarks.java new file mode 100644 index 0000000000000..ae3d07ea3cd91 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubPartitionStateWithPunctuatedWatermarks.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; +import org.apache.flink.streaming.api.watermark.Watermark; + +import javax.annotation.Nullable; + +/** + * Created by jozh on 6/16/2017. 
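+ * <p>Partition state that additionally runs an {@link AssignerWithPunctuatedWatermarks}: {@link #checkAndGetNewWatermark(Object, long)} returns a watermark only if it advances this partition's watermark, after which the fetcher re-computes the minimum watermark across all partitions.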
+ * Flink eventhub connnector has implemented with same design of flink kafka connector + */ + +public class EventhubPartitionStateWithPunctuatedWatermarks extends EventhubPartitionState { + private final AssignerWithPunctuatedWatermarks timestampsAndWatermarks; + private long partitionWatermark; + + public EventhubPartitionStateWithPunctuatedWatermarks(EventhubPartition key, String value, AssignerWithPunctuatedWatermarks timestampsAndWatermarks) { + super(key, value); + this.timestampsAndWatermarks = timestampsAndWatermarks; + this.partitionWatermark = Long.MIN_VALUE; + } + + public long getTimestampForRecord(T record, long kafkaEventTimestamp) { + return timestampsAndWatermarks.extractTimestamp(record, kafkaEventTimestamp); + } + + @Nullable + public Watermark checkAndGetNewWatermark(T record, long timestamp) { + Watermark mark = timestampsAndWatermarks.checkAndGetNextWatermark(record, timestamp); + if (mark != null && mark.getTimestamp() > partitionWatermark) { + partitionWatermark = mark.getTimestamp(); + return mark; + } + else { + return null; + } + } + + public long getCurrentPartitionWatermark() { + return partitionWatermark; + } + + @Override + public String toString() { + return "EventhubPartitionStateWithPunctuatedWatermarks: partition=" + getPartition() + + ", offset=" + getOffset() + ", watermark=" + partitionWatermark; + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubProducerThread.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubProducerThread.java new file mode 100644 index 0000000000000..b51236749386f --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/EventhubProducerThread.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.metrics.Counter; + +import com.microsoft.azure.eventhubs.EventData; +import com.microsoft.azure.eventhubs.EventHubClient; +import com.microsoft.azure.servicebus.ConnectionStringBuilder; +import com.microsoft.azure.servicebus.ServiceBusException; +import org.slf4j.Logger; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Properties; +/** + * Created by jozh on 6/20/2017. 
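+ * Producer thread that repeatedly drains batches of events from the {@link ProducerCache} and sends them to the
+ * configured Event Hub; send failures are reported back through {@link ProducerCache#reportError(Throwable)}.
+ * <p>Minimal usage sketch (the names {@code logger}, {@code eventhubProps}, {@code sendCounter} and {@code eventData}
+ * are placeholders supplied by the enclosing sink, not members of this class):
+ * <pre>{@code
+ * ProducerCache cache = new ProducerCache();
+ * EventhubProducerThread producerThread =
+ *     new EventhubProducerThread(logger, "eventhub-producer", cache, eventhubProps, sendCounter);
+ * producerThread.start();
+ * cache.put(eventData);       // the sink hands every serialized record to the cache
+ * producerThread.shutdown();  // on close(): stop the thread, then close the cache
+ * cache.close();
+ * }</pre>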
+ */ + +public class EventhubProducerThread extends Thread { + private final Logger logger; + private final ProducerCache producerCache; + private final Properties eventhubProps; + private final EventHubClient producer; + private volatile boolean running; + private Counter commitSendCount; + + public EventhubProducerThread( + Logger logger, + String threadName, + ProducerCache producerCache, + Properties eventhubProps, + Counter commitSendCount) throws IOException, ServiceBusException{ + + super(threadName); + setDaemon(true); + + this.logger = logger; + this.producerCache = producerCache; + this.eventhubProps = eventhubProps; + this.commitSendCount = commitSendCount; + + ConnectionStringBuilder connectionStringBuilder = new ConnectionStringBuilder( + eventhubProps.getProperty("eventhubs.namespace"), + eventhubProps.getProperty("eventhubs.name"), + eventhubProps.getProperty("eventhubs.policyname"), + eventhubProps.getProperty("eventhubs.policykey")); + this.producer = EventHubClient.createFromConnectionStringSync(connectionStringBuilder.toString()); + this.running = true; + } + + public void shutdown(){ + logger.info("Shutdown eventhub producer thread {} on demand", this.getName()); + running = false; + } + + @Override + public void run() { + if (!running){ + logger.info("Eventhub producer thread is set to STOP, thread {} exit", this.getName()); + return; + } + + try { + logger.info("Eventhub producer thread {} started", this.getName()); + while (running){ + final ArrayList events = producerCache.pollNextBatch(); + if (events != null && events.size() > 0){ + producer.sendSync(events); + commitSendCount.inc(events.size()); + logger.info("Eventhub producer thread send {} events success", events.size()); + } + else { + logger.debug("Eventhub producer thread received a null eventdata from producer cache"); + } + } + } + catch (Throwable t){ + logger.error("Sending events error, {}", t.toString()); + producerCache.reportError(t); + } + finally { + logger.info("Exit from eventhub producer thread, {}", this.getName()); + if (producer != null){ + try { + producer.closeSync(); + } + catch (Exception ex) { + logger.error("Close eventhubclient {} error {}", eventhubProps.getProperty("eventhubs.name"), ex.getMessage()); + producerCache.reportError(ex); + } + } + } + + logger.info("EventhubProducerThread {} quit", this.getName()); + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/Handover.java b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/Handover.java new file mode 100644 index 0000000000000..277b4bb7a1380 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/Handover.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.util.ExceptionUtils; + +import com.microsoft.azure.eventhubs.EventData; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.ThreadSafe; + +import java.io.Closeable; +import java.util.concurrent.ConcurrentLinkedQueue; + + +/** + * Created by jozh on 5/23/2017. + * Flink eventhub connnector has implemented with same design of flink kafka connector + */ +@ThreadSafe +public final class Handover implements Closeable { + private static final int MAX_EVENTS_BLOCK_IN_QUEUE = 1000; + private static final Logger logger = LoggerFactory.getLogger(Handover.class); + private ConcurrentLinkedQueue>> eventQueue = new ConcurrentLinkedQueue(); + + private volatile boolean allProducerWakeup = true; + + private Throwable error; + + public Tuple2> pollNext() throws Exception{ + logger.debug("###Begin to poll data from event cache queue"); + synchronized (eventQueue){ + while (eventQueue.isEmpty() && error == null){ + logger.debug("### No data in the msg queue, waiting... "); + eventQueue.wait(); + } + + logger.debug("### Get notified from consummer thread"); + Tuple2> events = eventQueue.poll(); + if (events != null && events.f0 != null && events.f1 != null){ + logger.debug("### Get event data from {}", events.f0.toString()); + int queueSize = eventQueue.size(); + if (queueSize < MAX_EVENTS_BLOCK_IN_QUEUE / 2){ + eventQueue.notifyAll(); + } + return events; + } + else { + ExceptionUtils.rethrowException(error, error.getMessage()); + return null; + } + } + } + + public void produce(final Tuple2> events) throws InterruptedException{ + if (events == null || events.f0 == null || events.f1 == null){ + logger.error("Received empty events from event producer"); + return; + } + + synchronized (eventQueue){ + while (eventQueue.size() > MAX_EVENTS_BLOCK_IN_QUEUE){ + logger.warn("Event queue is full, current size is {}", eventQueue.size()); + eventQueue.wait(); + } + + eventQueue.add(events); + eventQueue.notifyAll(); + logger.debug("Add received events into queue"); + } + } + + @Override + public void close() { + synchronized (eventQueue){ + logger.info("Close handover on demand"); + eventQueue.clear(); + if (error == null){ + error = new Throwable("Handover closed on command"); + } + + eventQueue.notifyAll(); + } + } + + public void reportError(Throwable t) { + if (t == null){ + return; + } + + synchronized (eventQueue){ + if (error == null){ + error = t; + } + eventQueue.clear(); + eventQueue.notifyAll(); + logger.info("Consumer thread report a error: {}", error.getMessage()); + } + } + + public void wakeupProducer() { + synchronized (eventQueue){ + logger.info("Wakeup producer on demand"); + eventQueue.clear(); + eventQueue.notifyAll(); + } + } +} diff --git a/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/ProducerCache.java 
b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/ProducerCache.java new file mode 100644 index 0000000000000..5b29c4ff95f71 --- /dev/null +++ b/flink-connectors/flink-connector-eventhubs/src/main/java/org/apache/flink/streaming/connectors/eventhubs/internals/ProducerCache.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.eventhubs.internals; + +import org.apache.flink.util.ExceptionUtils; + +import com.microsoft.azure.eventhubs.EventData; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Date; +import java.util.concurrent.ArrayBlockingQueue; + +/** + * Created by jozh on 6/20/2017. + */ +public final class ProducerCache implements Closeable, Serializable { + private static final Logger logger = LoggerFactory.getLogger(ProducerCache.class); + private static final long defaultCheckQueueStatusInterval = 50; + public static final int DEFAULTCAPACITY = 100; + public static final long DEFAULTTIMEOUTMILLISECOND = 100; + private final ArrayBlockingQueue cacheQueue; + private final int queueCapacity; + private final long pollTimeout; + private Date lastPollTime; + private Throwable error; + private volatile boolean closed; + + public ProducerCache(){ + this(DEFAULTCAPACITY, DEFAULTTIMEOUTMILLISECOND); + } + + public ProducerCache(int capacity){ + this(capacity, DEFAULTTIMEOUTMILLISECOND); + } + + public ProducerCache(int capacity, long timeout){ + this.queueCapacity = capacity; + this.pollTimeout = timeout; + this.cacheQueue = new ArrayBlockingQueue(this.queueCapacity); + this.lastPollTime = new Date(); + this.closed = false; + } + + public void put(EventData value) throws Exception{ + if (value == null){ + logger.error("Received empty events from event producer"); + return; + } + + synchronized (cacheQueue){ + while (cacheQueue.remainingCapacity() <= 0 && !closed){ + checkErr(); + logger.warn("Event queue is full, current size is {}", cacheQueue.size()); + cacheQueue.wait(defaultCheckQueueStatusInterval); + } + + if (closed){ + logger.info("Cache is closed, event is dropped."); + return; + } + + cacheQueue.add(value); + cacheQueue.notifyAll(); + + logger.debug("Add event into queue"); + } + } + + public ArrayList pollNextBatch() throws InterruptedException{ + logger.debug("###Begin to poll all data from event cache queue"); + + synchronized (cacheQueue){ + while (!isPollTimeout() && !closed && cacheQueue.remainingCapacity() > 0){ + cacheQueue.wait(defaultCheckQueueStatusInterval); + } + + final ArrayList result = new ArrayList<>(cacheQueue.size()); + for (EventData item : cacheQueue){ + 
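+ // Drain every cached event into the returned batch; the queue is cleared right after this copy.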
result.add(item); + } + cacheQueue.clear(); + cacheQueue.notifyAll(); + + lastPollTime = new Date(); + return result; + } + } + + public void reportError(Throwable t) { + if (t == null){ + return; + } + + synchronized (cacheQueue){ + if (error == null){ + error = t; + } + logger.info("Producer thread report a error: {}", t.toString()); + } + } + + @Override + public void close() { + synchronized (cacheQueue){ + logger.info("Close cache on demand"); + closed = true; + cacheQueue.notifyAll(); + } + } + + public void checkErr() throws Exception { + synchronized (cacheQueue){ + if (error != null){ + ExceptionUtils.rethrowException(error, error.getMessage()); + } + } + } + + private boolean isPollTimeout(){ + long pollInterval = (new Date()).getTime() - lastPollTime.getTime(); + return pollInterval > pollTimeout; + } +} diff --git a/flink-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java b/flink-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java index 3d3ea05cf3093..e5758e8920fdb 100644 --- a/flink-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java +++ b/flink-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java @@ -29,7 +29,6 @@ import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; import org.apache.flink.streaming.connectors.fs.bucketing.BucketingSink; import org.apache.flink.util.Preconditions; @@ -132,7 +131,7 @@ @Deprecated public class RollingSink extends RichSinkFunction implements InputTypeConfigurable, CheckpointedFunction, - CheckpointListener, CheckpointedRestoring { + CheckpointListener { private static final long serialVersionUID = 1L; @@ -758,25 +757,6 @@ private void handleRestoredBucketState(BucketState bucketState) { } } - // -------------------------------------------------------------------------------------------- - // Backwards compatibility with Flink 1.1 - // -------------------------------------------------------------------------------------------- - - @Override - public void restoreState(BucketState state) throws Exception { - LOG.info("{} (taskIdx={}) restored bucket state from an older Flink version: {}", - getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask(), state); - - try { - initFileSystem(); - } catch (IOException e) { - LOG.error("Error while creating FileSystem when restoring the state of the RollingSink.", e); - throw new RuntimeException("Error while creating FileSystem when restoring the state of the RollingSink.", e); - } - - handleRestoredBucketState(state); - } - // -------------------------------------------------------------------------------------------- // Setters for User configuration values // -------------------------------------------------------------------------------------------- diff --git a/flink-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSink.java b/flink-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSink.java index 70168b55164bb..cc924a4a2056f 100644 --- 
a/flink-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSink.java +++ b/flink-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSink.java @@ -30,7 +30,6 @@ import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.connectors.fs.Clock; @@ -154,8 +153,7 @@ */ public class BucketingSink extends RichSinkFunction - implements InputTypeConfigurable, CheckpointedFunction, CheckpointListener, - CheckpointedRestoring, ProcessingTimeCallback { + implements InputTypeConfigurable, CheckpointedFunction, CheckpointListener, ProcessingTimeCallback { private static final long serialVersionUID = 1L; @@ -872,25 +870,6 @@ private void handlePendingFilesForPreviousCheckpoints(Map> pe } } - // -------------------------------------------------------------------------------------------- - // Backwards compatibility with Flink 1.1 - // -------------------------------------------------------------------------------------------- - - @Override - public void restoreState(RollingSink.BucketState state) throws Exception { - LOG.info("{} (taskIdx={}) restored bucket state from the RollingSink an older Flink version: {}", - getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask(), state); - - try { - initFileSystem(); - } catch (IOException e) { - LOG.error("Error while creating FileSystem when restoring the state of the BucketingSink.", e); - throw new RuntimeException("Error while creating FileSystem when restoring the state of the BucketingSink.", e); - } - - handleRestoredRollingSinkState(state); - } - // -------------------------------------------------------------------------------------------- // Setters for User configuration values // -------------------------------------------------------------------------------------------- diff --git a/flink-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/RollingSinkMigrationTest.java b/flink-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/RollingSinkMigrationTest.java deleted file mode 100644 index e0413795b2aa7..0000000000000 --- a/flink-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/RollingSinkMigrationTest.java +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.connectors.fs.bucketing; - -import org.apache.flink.streaming.api.operators.StreamSink; -import org.apache.flink.streaming.connectors.fs.RollingSink; -import org.apache.flink.streaming.connectors.fs.StringWriter; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; -import org.apache.flink.util.OperatingSystem; - -import org.apache.commons.io.FileUtils; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; - -import java.io.File; -import java.io.IOException; -import java.net.URL; -import java.util.List; -import java.util.Map; - -/** - * Tests the migration from 1.1 snapshots. - */ -@Deprecated -public class RollingSinkMigrationTest { - - @ClassRule - public static TemporaryFolder tempFolder = new TemporaryFolder(); - - private static final String PART_PREFIX = "part"; - private static final String PENDING_SUFFIX = ".pending"; - private static final String IN_PROGRESS_SUFFIX = ".in-progress"; - private static final String VALID_LENGTH_SUFFIX = ".valid"; - - @BeforeClass - public static void verifyOS() { - Assume.assumeTrue("HDFS cluster cannot be started on Windows without extensions.", !OperatingSystem.isWindows()); - } - - @Test - public void testMigration() throws Exception { - - /* - * Code ran to get the snapshot: - * - * final File outDir = tempFolder.newFolder(); - - RollingSink sink = new RollingSink(outDir.getAbsolutePath()) - .setWriter(new StringWriter()) - .setBatchSize(5) - .setPartPrefix(PART_PREFIX) - .setInProgressPrefix("") - .setPendingPrefix("") - .setValidLengthPrefix("") - .setInProgressSuffix(IN_PROGRESS_SUFFIX) - .setPendingSuffix(PENDING_SUFFIX) - .setValidLengthSuffix(VALID_LENGTH_SUFFIX); - - OneInputStreamOperatorTestHarness testHarness1 = - new OneInputStreamOperatorTestHarness<>(new StreamSink<>(sink)); - - testHarness1.setup(); - testHarness1.open(); - - testHarness1.processElement(new StreamRecord<>("test1", 0L)); - testHarness1.processElement(new StreamRecord<>("test2", 0L)); - - checkFs(outDir, 1, 1, 0, 0); - - testHarness1.processElement(new StreamRecord<>("test3", 0L)); - testHarness1.processElement(new StreamRecord<>("test4", 0L)); - testHarness1.processElement(new StreamRecord<>("test5", 0L)); - - checkFs(outDir, 1, 4, 0, 0); - - StreamTaskState taskState = testHarness1.snapshot(0, 0); - testHarness1.snaphotToFile(taskState, "src/test/resources/rolling-sink-migration-test-flink1.1-snapshot"); - testHarness1.close(); - * */ - - final File outDir = tempFolder.newFolder(); - - RollingSink sink = new ValidatingRollingSink(outDir.getAbsolutePath()) - .setWriter(new StringWriter()) - .setBatchSize(5) - .setPartPrefix(PART_PREFIX) - .setInProgressPrefix("") - .setPendingPrefix("") - .setValidLengthPrefix("") - .setInProgressSuffix(IN_PROGRESS_SUFFIX) - .setPendingSuffix(PENDING_SUFFIX) - .setValidLengthSuffix(VALID_LENGTH_SUFFIX); - - OneInputStreamOperatorTestHarness testHarness1 = new OneInputStreamOperatorTestHarness<>( - new StreamSink<>(sink), 10, 1, 0); - testHarness1.setup(); - testHarness1.initializeStateFromLegacyCheckpoint(getResourceFilename("rolling-sink-migration-test-flink1.1-snapshot")); - testHarness1.open(); - - testHarness1.processElement(new StreamRecord<>("test1", 0L)); - 
testHarness1.processElement(new StreamRecord<>("test2", 0L)); - - checkFs(outDir, 1, 1, 0, 0); - - testHarness1.close(); - } - - private void checkFs(File outDir, int inprogress, int pending, int completed, int valid) throws IOException { - int inProg = 0; - int pend = 0; - int compl = 0; - int val = 0; - - for (File file: FileUtils.listFiles(outDir, null, true)) { - if (file.getAbsolutePath().endsWith("crc")) { - continue; - } - String path = file.getPath(); - if (path.endsWith(IN_PROGRESS_SUFFIX)) { - inProg++; - } else if (path.endsWith(PENDING_SUFFIX)) { - pend++; - } else if (path.endsWith(VALID_LENGTH_SUFFIX)) { - val++; - } else if (path.contains(PART_PREFIX)) { - compl++; - } - } - - Assert.assertEquals(inprogress, inProg); - Assert.assertEquals(pending, pend); - Assert.assertEquals(completed, compl); - Assert.assertEquals(valid, val); - } - - private static String getResourceFilename(String filename) { - ClassLoader cl = RollingSinkMigrationTest.class.getClassLoader(); - URL resource = cl.getResource(filename); - return resource.getFile(); - } - - static class ValidatingRollingSink extends RollingSink { - - private static final long serialVersionUID = -4263974081712009141L; - - ValidatingRollingSink(String basePath) { - super(basePath); - } - - @Override - public void restoreState(BucketState state) throws Exception { - - /** - * this validates that we read the state that was checkpointed by the previous version. We expect it to be: - * In-progress=/var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-4 - * validLength=6 - * pendingForNextCheckpoint=[] - * pendingForPrevCheckpoints={0=[ /var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-0, - * /var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-1, - * /var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-2, - * /var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-3]} - * */ - - String current = state.currentFile; - long validLength = state.currentFileValidLength; - - Assert.assertEquals("/var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-4", current); - Assert.assertEquals(6, validLength); - - List pendingFiles = state.pendingFiles; - Assert.assertTrue(pendingFiles.isEmpty()); - - final Map> pendingFilesPerCheckpoint = state.pendingFilesPerCheckpoint; - Assert.assertEquals(1, pendingFilesPerCheckpoint.size()); - - for (Map.Entry> entry: pendingFilesPerCheckpoint.entrySet()) { - long checkpoint = entry.getKey(); - List files = entry.getValue(); - - Assert.assertEquals(0L, checkpoint); - Assert.assertEquals(4, files.size()); - - for (int i = 0; i < 4; i++) { - Assert.assertEquals( - "/var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-" + i, - files.get(i)); - } - } - super.restoreState(state); - } - } -} diff --git a/flink-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/RollingToBucketingMigrationTest.java b/flink-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/RollingToBucketingMigrationTest.java deleted file mode 100644 index 
8a8dbd6bc9301..0000000000000 --- a/flink-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/RollingToBucketingMigrationTest.java +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.connectors.fs.bucketing; - -import org.apache.flink.streaming.api.operators.StreamSink; -import org.apache.flink.streaming.connectors.fs.RollingSink; -import org.apache.flink.streaming.connectors.fs.StringWriter; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; -import org.apache.flink.util.OperatingSystem; - -import org.apache.commons.io.FileUtils; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; - -import java.io.File; -import java.io.IOException; -import java.net.URL; -import java.util.List; -import java.util.Map; - -/** - * Tests the migration from {@link RollingSink} to {@link BucketingSink}. 
- */ -public class RollingToBucketingMigrationTest { - - @ClassRule - public static TemporaryFolder tempFolder = new TemporaryFolder(); - - private static final String PART_PREFIX = "part"; - private static final String PENDING_SUFFIX = ".pending"; - private static final String IN_PROGRESS_SUFFIX = ".in-progress"; - private static final String VALID_LENGTH_SUFFIX = ".valid"; - - @BeforeClass - public static void verifyOS() { - Assume.assumeTrue("HDFS cluster cannot be started on Windows without extensions.", !OperatingSystem.isWindows()); - } - - @Test - public void testMigration() throws Exception { - final File outDir = tempFolder.newFolder(); - - BucketingSink sink = new ValidatingBucketingSink(outDir.getAbsolutePath()) - .setWriter(new StringWriter()) - .setBatchSize(5) - .setPartPrefix(PART_PREFIX) - .setInProgressPrefix("") - .setPendingPrefix("") - .setValidLengthPrefix("") - .setInProgressSuffix(IN_PROGRESS_SUFFIX) - .setPendingSuffix(PENDING_SUFFIX) - .setValidLengthSuffix(VALID_LENGTH_SUFFIX); - - OneInputStreamOperatorTestHarness testHarness1 = new OneInputStreamOperatorTestHarness<>( - new StreamSink<>(sink), 10, 1, 0); - testHarness1.setup(); - testHarness1.initializeStateFromLegacyCheckpoint(getResourceFilename("rolling-sink-migration-test-flink1.1-snapshot")); - testHarness1.open(); - - testHarness1.processElement(new StreamRecord<>("test1", 0L)); - testHarness1.processElement(new StreamRecord<>("test2", 0L)); - - checkFs(outDir, 1, 1, 0, 0); - - testHarness1.close(); - } - - private static String getResourceFilename(String filename) { - ClassLoader cl = RollingToBucketingMigrationTest.class.getClassLoader(); - URL resource = cl.getResource(filename); - return resource.getFile(); - } - - private void checkFs(File outDir, int inprogress, int pending, int completed, int valid) throws IOException { - int inProg = 0; - int pend = 0; - int compl = 0; - int val = 0; - - for (File file: FileUtils.listFiles(outDir, null, true)) { - if (file.getAbsolutePath().endsWith("crc")) { - continue; - } - String path = file.getPath(); - if (path.endsWith(IN_PROGRESS_SUFFIX)) { - inProg++; - } else if (path.endsWith(PENDING_SUFFIX)) { - pend++; - } else if (path.endsWith(VALID_LENGTH_SUFFIX)) { - val++; - } else if (path.contains(PART_PREFIX)) { - compl++; - } - } - - Assert.assertEquals(inprogress, inProg); - Assert.assertEquals(pending, pend); - Assert.assertEquals(completed, compl); - Assert.assertEquals(valid, val); - } - - static class ValidatingBucketingSink extends BucketingSink { - - private static final long serialVersionUID = -4263974081712009141L; - - ValidatingBucketingSink(String basePath) { - super(basePath); - } - - @Override - public void restoreState(RollingSink.BucketState state) throws Exception { - - /** - * this validates that we read the state that was checkpointed by the previous version. 
We expect it to be: - * In-progress=/var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-4 - * validLength=6 - * pendingForNextCheckpoint=[] - * pendingForPrevCheckpoints={0=[ /var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-0, - * /var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-1, - * /var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-2, - * /var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-3]} - * */ - - String current = state.currentFile; - long validLength = state.currentFileValidLength; - - Assert.assertEquals("/var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-4", current); - Assert.assertEquals(6, validLength); - - List pendingFiles = state.pendingFiles; - Assert.assertTrue(pendingFiles.isEmpty()); - - final Map> pendingFilesPerCheckpoint = state.pendingFilesPerCheckpoint; - Assert.assertEquals(1, pendingFilesPerCheckpoint.size()); - - for (Map.Entry> entry: pendingFilesPerCheckpoint.entrySet()) { - long checkpoint = entry.getKey(); - List files = entry.getValue(); - - Assert.assertEquals(0L, checkpoint); - Assert.assertEquals(4, files.size()); - - for (int i = 0; i < 4; i++) { - Assert.assertEquals( - "/var/folders/z5/fxvg1j6s6mn94nsf8b1yc8s80000gn/T/junit2927527303216950257/junit5645682027227039270/2017-01-09--18/part-0-" + i, - files.get(i)); - } - } - - super.restoreState(state); - } - } -} diff --git a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java index f3c9e5e342044..3088b1552a65d 100644 --- a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java +++ b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java @@ -33,7 +33,6 @@ import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks; import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; @@ -78,8 +77,7 @@ public abstract class FlinkKafkaConsumerBase extends RichParallelSourceFunction implements CheckpointListener, ResultTypeQueryable, - CheckpointedFunction, - CheckpointedRestoring> { + CheckpointedFunction { private static final long serialVersionUID = -6272159445203409112L; @@ -766,22 +764,6 @@ public final void snapshotState(FunctionSnapshotContext context) throws Exceptio } } - @Override - public final void restoreState(HashMap restoredOffsets) { - LOG.info("{} (taskIdx={}) restoring offsets from an older version: {}", - getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask(), restoredOffsets); - - restoredFromOldState = true; - - if 
(restoredOffsets.size() > 0 && discoveryIntervalMillis != PARTITION_DISCOVERY_DISABLED) { - throw new IllegalArgumentException( - "Topic / partition discovery cannot be enabled if the job is restored from a savepoint from Flink 1.1.x."); - } - - restoredState = new TreeMap<>(new KafkaTopicPartition.Comparator()); - restoredState.putAll(restoredOffsets); - } - @Override public final void notifyCheckpointComplete(long checkpointId) throws Exception { if (!running) { diff --git a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscoverer.java b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscoverer.java index 39645be129f3d..b336fdc57061b 100644 --- a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscoverer.java +++ b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscoverer.java @@ -17,10 +17,10 @@ package org.apache.flink.streaming.connectors.kafka.internals; -import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; -import java.util.Map; +import java.util.Set; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -68,7 +68,7 @@ public abstract class AbstractPartitionDiscoverer { * to keep track of only the largest partition id because Kafka partition numbers are only * allowed to be increased and has incremental ids. */ - private Map topicsToLargestDiscoveredPartitionId; + private Set discoveredPartitions; public AbstractPartitionDiscoverer( KafkaTopicsDescriptor topicsDescriptor, @@ -78,7 +78,7 @@ public AbstractPartitionDiscoverer( this.topicsDescriptor = checkNotNull(topicsDescriptor); this.indexOfThisSubtask = indexOfThisSubtask; this.numParallelSubtasks = numParallelSubtasks; - this.topicsToLargestDiscoveredPartitionId = new HashMap<>(); + this.discoveredPartitions = new HashSet<>(); } /** @@ -149,10 +149,6 @@ public List discoverPartitions() throws WakeupException, Cl if (newDiscoveredPartitions == null || newDiscoveredPartitions.isEmpty()) { throw new RuntimeException("Unable to retrieve any partitions with KafkaTopicsDescriptor: " + topicsDescriptor); } else { - // sort so that we make sure the topicsToLargestDiscoveredPartitionId state is updated - // with incremental partition ids of the same topics (otherwise some partition ids may be skipped) - KafkaTopicPartition.sort(newDiscoveredPartitions); - Iterator iter = newDiscoveredPartitions.iterator(); KafkaTopicPartition nextPartition; while (iter.hasNext()) { @@ -196,7 +192,7 @@ public List discoverPartitions() throws WakeupException, Cl */ public boolean setAndCheckDiscoveredPartition(KafkaTopicPartition partition) { if (isUndiscoveredPartition(partition)) { - topicsToLargestDiscoveredPartitionId.put(partition.getTopic(), partition.getPartition()); + discoveredPartitions.add(partition); return KafkaTopicPartitionAssigner.assign(partition, numParallelSubtasks) == indexOfThisSubtask; } @@ -246,11 +242,6 @@ public static final class ClosedException extends Exception { } private boolean isUndiscoveredPartition(KafkaTopicPartition partition) { - return !topicsToLargestDiscoveredPartitionId.containsKey(partition.getTopic()) - || partition.getPartition() > topicsToLargestDiscoveredPartitionId.get(partition.getTopic()); - } - - public static boolean 
shouldAssignToThisSubtask(KafkaTopicPartition partition, int indexOfThisSubtask, int numParallelSubtasks) { - return Math.abs(partition.hashCode() % numParallelSubtasks) == indexOfThisSubtask; + return !discoveredPartitions.contains(partition); } } diff --git a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartition.java b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartition.java index 3500cd81edc3f..d35d5856f15ef 100644 --- a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartition.java +++ b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/KafkaTopicPartition.java @@ -19,7 +19,6 @@ import java.io.Serializable; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; @@ -131,8 +130,4 @@ public int compare(KafkaTopicPartition p1, KafkaTopicPartition p2) { } } } - - public static void sort(List partitions) { - Collections.sort(partitions, new Comparator()); - } } diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseMigrationTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseMigrationTest.java index e3f337ec671dd..84f0e388e1969 100644 --- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseMigrationTest.java +++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseMigrationTest.java @@ -92,7 +92,7 @@ public class FlinkKafkaConsumerBaseMigrationTest { @Parameterized.Parameters(name = "Migration Savepoint: {0}") public static Collection parameters () { - return Arrays.asList(MigrationVersion.v1_1, MigrationVersion.v1_2, MigrationVersion.v1_3); + return Arrays.asList(MigrationVersion.v1_2, MigrationVersion.v1_3); } public FlinkKafkaConsumerBaseMigrationTest(MigrationVersion testMigrateVersion) { @@ -322,7 +322,7 @@ public void testRestore() throws Exception { */ @Test public void testRestoreFailsWithNonEmptyPreFlink13StatesIfDiscoveryEnabled() throws Exception { - assumeTrue(testMigrateVersion == MigrationVersion.v1_1 || testMigrateVersion == MigrationVersion.v1_2); + assumeTrue(testMigrateVersion == MigrationVersion.v1_3 || testMigrateVersion == MigrationVersion.v1_2); final List partitions = new ArrayList<>(PARTITION_STATE.keySet()); diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaProducerTestBase.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaProducerTestBase.java index 4a611039af4bf..000de5268fe22 100644 --- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaProducerTestBase.java +++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaProducerTestBase.java @@ -45,13 +45,13 @@ import org.apache.flink.test.util.SuccessException; import org.apache.flink.util.Preconditions; -import com.google.common.collect.ImmutableSet; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.junit.Test; 
import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -285,7 +285,7 @@ public int partition(Integer record, byte[] key, byte[] value, String targetTopi properties, topic, partition, - ImmutableSet.copyOf(getIntegersSequence(BrokerRestartingMapper.numElementsBeforeSnapshot)), + Collections.unmodifiableSet(new HashSet<>(getIntegersSequence(BrokerRestartingMapper.numElementsBeforeSnapshot))), 30000L); deleteTestTopic(topic); diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java index 2633b951ae1f8..e9f1537ed1be1 100644 --- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java +++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java @@ -394,6 +394,41 @@ public void testDeterministicAssignmentWithDifferentFetchedPartitionOrdering() t } } + @Test + public void testNonContiguousPartitionIdDiscovery() throws Exception { + List mockGetAllPartitionsForTopicsReturn1 = Arrays.asList( + new KafkaTopicPartition("test-topic", 1), + new KafkaTopicPartition("test-topic", 4)); + + List mockGetAllPartitionsForTopicsReturn2 = Arrays.asList( + new KafkaTopicPartition("test-topic", 0), + new KafkaTopicPartition("test-topic", 1), + new KafkaTopicPartition("test-topic", 2), + new KafkaTopicPartition("test-topic", 3), + new KafkaTopicPartition("test-topic", 4)); + + TestPartitionDiscoverer partitionDiscoverer = new TestPartitionDiscoverer( + topicsDescriptor, + 0, + 1, + TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList("test-topic")), + // first metadata fetch has missing partitions that appears only in the second fetch; + // need to create new modifiable lists for each fetch, since internally Iterable.remove() is used. 
+ Arrays.asList(new ArrayList<>(mockGetAllPartitionsForTopicsReturn1), new ArrayList<>(mockGetAllPartitionsForTopicsReturn2))); + partitionDiscoverer.open(); + + List discoveredPartitions1 = partitionDiscoverer.discoverPartitions(); + assertEquals(2, discoveredPartitions1.size()); + assertTrue(discoveredPartitions1.contains(new KafkaTopicPartition("test-topic", 1))); + assertTrue(discoveredPartitions1.contains(new KafkaTopicPartition("test-topic", 4))); + + List discoveredPartitions2 = partitionDiscoverer.discoverPartitions(); + assertEquals(3, discoveredPartitions2.size()); + assertTrue(discoveredPartitions2.contains(new KafkaTopicPartition("test-topic", 0))); + assertTrue(discoveredPartitions2.contains(new KafkaTopicPartition("test-topic", 2))); + assertTrue(discoveredPartitions2.contains(new KafkaTopicPartition("test-topic", 3))); + } + private boolean contains(List partitions, int partition) { for (KafkaTopicPartition ktp : partitions) { if (ktp.getPartition() == partition) { diff --git a/flink-connectors/flink-connector-kinesis/pom.xml b/flink-connectors/flink-connector-kinesis/pom.xml index 41daaa7f6ca95..83934f64608be 100644 --- a/flink-connectors/flink-connector-kinesis/pom.xml +++ b/flink-connectors/flink-connector-kinesis/pom.xml @@ -33,9 +33,9 @@ under the License. flink-connector-kinesis_${scala.binary.version} flink-connector-kinesis - 1.10.71 - 1.6.2 - 0.10.2 + 1.11.171 + 1.8.1 + 0.12.5 jar diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java index d127f2b6fb710..a3681eca52fee 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java @@ -28,13 +28,11 @@ import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition; import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; @@ -44,6 +42,7 @@ import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper; import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil; import org.apache.flink.streaming.util.serialization.DeserializationSchema; +import org.apache.flink.util.InstantiationUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -72,8 +71,7 @@ */ public class FlinkKinesisConsumer extends RichParallelSourceFunction implements ResultTypeQueryable, - 
CheckpointedFunction, - CheckpointedRestoring> { + CheckpointedFunction { private static final long serialVersionUID = 4724006128720664870L; @@ -176,7 +174,12 @@ public FlinkKinesisConsumer(List streams, KinesisDeserializationSchema entry : lastStateSnapshot.entrySet()) { @@ -362,23 +365,6 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { } } - @Override - public void restoreState(HashMap restoredState) throws Exception { - LOG.info("Subtask {} restoring offsets from an older Flink version: {}", - getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore); - - if (restoredState.isEmpty()) { - sequenceNumsToRestore = null; - } else { - sequenceNumsToRestore = new HashMap<>(); - for (Map.Entry stateEntry : restoredState.entrySet()) { - sequenceNumsToRestore.put( - KinesisStreamShard.convertToStreamShardMetadata(stateEntry.getKey()), - stateEntry.getValue()); - } - } - } - /** This method is exposed for tests that need to mock the KinesisDataFetcher in the consumer. */ protected KinesisDataFetcher createFetcher( List streams, diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisProducer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisProducer.java index 04d7055ca8230..1f5e64c1fade3 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisProducer.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisProducer.java @@ -17,15 +17,13 @@ package org.apache.flink.streaming.connectors.kinesis; -import org.apache.flink.api.java.ClosureCleaner; import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; -import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisSerializationSchema; import org.apache.flink.streaming.connectors.kinesis.util.AWSUtil; import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil; import org.apache.flink.streaming.util.serialization.SerializationSchema; -import org.apache.flink.util.PropertiesUtil; +import org.apache.flink.util.InstantiationUtil; import com.amazonaws.services.kinesis.producer.Attempt; import com.amazonaws.services.kinesis.producer.KinesisProducer; @@ -35,14 +33,15 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; import java.util.List; -import java.util.Objects; import java.util.Properties; +import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -90,7 +89,7 @@ public class FlinkKinesisProducer extends RichSinkFunction { * This is a constructor supporting Flink's {@see SerializationSchema}. 
* * @param schema Serialization schema for the data type - * @param configProps The properties used to configure AWS credentials and AWS region + * @param configProps The properties used to configure KinesisProducer, including AWS credentials and AWS region */ public FlinkKinesisProducer(final SerializationSchema schema, Properties configProps) { @@ -115,15 +114,17 @@ public String getTargetStream(OUT element) { * This is a constructor supporting {@see KinesisSerializationSchema}. * * @param schema Kinesis serialization schema for the data type - * @param configProps The properties used to configure AWS credentials and AWS region + * @param configProps The properties used to configure KinesisProducer, including AWS credentials and AWS region */ public FlinkKinesisProducer(KinesisSerializationSchema schema, Properties configProps) { - this.configProps = checkNotNull(configProps, "configProps can not be null"); - - // check the configuration properties for any conflicting settings - KinesisConfigUtil.validateProducerConfiguration(this.configProps); - - ClosureCleaner.ensureSerializable(Objects.requireNonNull(schema)); + checkNotNull(configProps, "configProps can not be null"); + this.configProps = KinesisConfigUtil.replaceDeprecatedProducerKeys(configProps); + + checkNotNull(schema, "serialization schema cannot be null"); + checkArgument( + InstantiationUtil.isSerializable(schema), + "The provided serialization schema is not serializable: " + schema.getClass().getName() + ". " + + "Please check that it does not contain references to non-serializable instances."); this.schema = schema; } @@ -154,8 +155,12 @@ public void setDefaultPartition(String defaultPartition) { } public void setCustomPartitioner(KinesisPartitioner partitioner) { - Objects.requireNonNull(partitioner); - ClosureCleaner.ensureSerializable(partitioner); + checkNotNull(partitioner, "partitioner cannot be null"); + checkArgument( + InstantiationUtil.isSerializable(partitioner), + "The provided custom partitioner is not serializable: " + partitioner.getClass().getName() + ". 
" + + "Please check that it does not contain references to non-serializable instances."); + this.customPartitioner = partitioner; } @@ -165,18 +170,9 @@ public void setCustomPartitioner(KinesisPartitioner partitioner) { public void open(Configuration parameters) throws Exception { super.open(parameters); - KinesisProducerConfiguration producerConfig = new KinesisProducerConfiguration(); - - producerConfig.setRegion(configProps.getProperty(ProducerConfigConstants.AWS_REGION)); + // check and pass the configuration properties + KinesisProducerConfiguration producerConfig = KinesisConfigUtil.validateProducerConfiguration(configProps); producerConfig.setCredentialsProvider(AWSUtil.getCredentialsProvider(configProps)); - if (configProps.containsKey(ProducerConfigConstants.COLLECTION_MAX_COUNT)) { - producerConfig.setCollectionMaxCount(PropertiesUtil.getLong(configProps, - ProducerConfigConstants.COLLECTION_MAX_COUNT, producerConfig.getCollectionMaxCount(), LOG)); - } - if (configProps.containsKey(ProducerConfigConstants.AGGREGATION_MAX_COUNT)) { - producerConfig.setAggregationMaxCount(PropertiesUtil.getLong(configProps, - ProducerConfigConstants.AGGREGATION_MAX_COUNT, producerConfig.getAggregationMaxCount(), LOG)); - } producer = new KinesisProducer(producerConfig); callback = new FutureCallback() { diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/config/ProducerConfigConstants.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/config/ProducerConfigConstants.java index d131150b1697b..d66bb90f9c8a1 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/config/ProducerConfigConstants.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/config/ProducerConfigConstants.java @@ -21,13 +21,30 @@ /** * Optional producer specific configuration keys for {@link FlinkKinesisProducer}. + * + * @deprecated This class is deprecated in favor of the official AWS Kinesis producer configuration keys. + * See + * here for the full list of available configs. + * For configuring the region and credentials, please use the keys in {@link AWSConfigConstants}. */ +@Deprecated public class ProducerConfigConstants extends AWSConfigConstants { - /** Maximum number of items to pack into an PutRecords request. **/ + /** + * Deprecated key. + * + * @deprecated This is deprecated in favor of the official AWS Kinesis producer configuration keys. + * Please use {@code CollectionMaxCount} instead. + **/ + @Deprecated public static final String COLLECTION_MAX_COUNT = "aws.producer.collectionMaxCount"; - /** Maximum number of items to pack into an aggregated record. **/ + /** + * Deprecated key. + * + * @deprecated This is deprecated in favor of the official AWS Kinesis producer configuration keys. + * Please use {@code AggregationMaxCount} instead. 
+ **/ + @Deprecated public static final String AGGREGATION_MAX_COUNT = "aws.producer.aggregationMaxCount"; - } diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/examples/ProduceIntoKinesis.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/examples/ProduceIntoKinesis.java index ee031eb80b1b9..8d21c2caa1f91 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/examples/ProduceIntoKinesis.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/examples/ProduceIntoKinesis.java @@ -22,7 +22,7 @@ import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisProducer; -import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants; +import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants; import org.apache.flink.streaming.util.serialization.SimpleStringSchema; import org.apache.commons.lang3.RandomStringUtils; @@ -43,9 +43,9 @@ public static void main(String[] args) throws Exception { DataStream simpleStringStream = see.addSource(new EventsGenerator()); Properties kinesisProducerConfig = new Properties(); - kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_REGION, pt.getRequired("region")); - kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_ACCESS_KEY_ID, pt.getRequired("accessKey")); - kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, pt.getRequired("secretKey")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_REGION, pt.getRequired("region")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, pt.getRequired("accessKey")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, pt.getRequired("secretKey")); FlinkKinesisProducer kinesis = new FlinkKinesisProducer<>( new SimpleStringSchema(), kinesisProducerConfig); diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/KinesisConfigUtil.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/KinesisConfigUtil.java index 42f1af055ad82..997191c464f68 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/KinesisConfigUtil.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/KinesisConfigUtil.java @@ -26,6 +26,7 @@ import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants; import com.amazonaws.regions.Regions; +import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; import java.text.ParseException; import java.text.SimpleDateFormat; @@ -38,6 +39,22 @@ * Utilities for Flink Kinesis connector configuration. */ public class KinesisConfigUtil { + + /** Maximum number of items to pack into a PutRecords request. **/ + protected static final String COLLECTION_MAX_COUNT = "CollectionMaxCount"; + + /** Maximum number of items to pack into an aggregated record. 
**/ + protected static final String AGGREGATION_MAX_COUNT = "AggregationMaxCount"; + + /** Limits the maximum allowed put rate for a shard, as a percentage of the backend limits. + * The default value is set as 100% in Flink. KPL's default value is 150% but it makes KPL throw + * RateLimitExceededException too frequently and breaks Flink sink as a result. + **/ + private static final String RATE_LIMIT = "RateLimit"; + + /** Default value for RateLimit. **/ + private static final String DEFAULT_RATE_LIMIT = "100"; + /** * Validate configuration properties for {@link FlinkKinesisConsumer}. */ @@ -126,19 +143,40 @@ public static void validateConsumerConfiguration(Properties config) { } } + /** + * Replace deprecated configuration properties for {@link FlinkKinesisProducer}. + * This should be removed along with the deprecated keys. + */ + public static Properties replaceDeprecatedProducerKeys(Properties configProps) { + // Replace deprecated key + if (configProps.containsKey(ProducerConfigConstants.COLLECTION_MAX_COUNT)) { + configProps.setProperty(COLLECTION_MAX_COUNT, + configProps.getProperty(ProducerConfigConstants.COLLECTION_MAX_COUNT)); + configProps.remove(ProducerConfigConstants.COLLECTION_MAX_COUNT); + } + // Replace deprecated key + if (configProps.containsKey(ProducerConfigConstants.AGGREGATION_MAX_COUNT)) { + configProps.setProperty(AGGREGATION_MAX_COUNT, + configProps.getProperty(ProducerConfigConstants.AGGREGATION_MAX_COUNT)); + configProps.remove(ProducerConfigConstants.AGGREGATION_MAX_COUNT); + } + return configProps; + } + /** * Validate configuration properties for {@link FlinkKinesisProducer}. */ - public static void validateProducerConfiguration(Properties config) { + public static KinesisProducerConfiguration validateProducerConfiguration(Properties config) { checkNotNull(config, "config can not be null"); validateAwsConfiguration(config); - validateOptionalPositiveLongProperty(config, ProducerConfigConstants.COLLECTION_MAX_COUNT, - "Invalid value given for maximum number of items to pack into a PutRecords request. Must be a valid non-negative long value."); + // Override KPL default value if it's not specified by user + if (!config.containsKey(RATE_LIMIT)) { + config.setProperty(RATE_LIMIT, DEFAULT_RATE_LIMIT); + } - validateOptionalPositiveLongProperty(config, ProducerConfigConstants.AGGREGATION_MAX_COUNT, - "Invalid value given for maximum number of items to pack into an aggregated record. 
Must be a valid non-negative long value."); + return KinesisProducerConfiguration.fromProperties(config); } /** diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java index af84420e6e3eb..364560c40ddd3 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java @@ -18,100 +18,133 @@ package org.apache.flink.streaming.connectors.kinesis; import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.operators.StreamSource; -import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; +import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants; import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.StreamShardMetadata; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema; import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator; +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OperatorSnapshotUtil; +import org.apache.flink.streaming.util.migration.MigrationTestUtil; +import org.apache.flink.streaming.util.migration.MigrationVersion; -import com.amazonaws.services.kinesis.model.Shard; +import org.junit.Ignore; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; -import java.net.URL; +import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Properties; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for checking whether {@link FlinkKinesisConsumer} can restore from snapshots that were - * done using the Flink 1.1 {@code FlinkKinesisConsumer}. + * done using an older {@code FlinkKinesisConsumer}. + * + *
For regenerating the binary snapshot files run {@link #writeSnapshot()} on the corresponding + * Flink release-* branch. */ +@RunWith(Parameterized.class) public class FlinkKinesisConsumerMigrationTest { + /** + * TODO change this to the corresponding savepoint version to be written (e.g. {@link MigrationVersion#v1_3} for 1.3) + * TODO and remove all @Ignore annotations on the writeSnapshot() method to generate savepoints + */ + private final MigrationVersion flinkGenerateSavepointVersion = null; + + private static final HashMap TEST_STATE = new HashMap<>(); + static { + StreamShardMetadata shardMetadata = new StreamShardMetadata(); + shardMetadata.setStreamName("fakeStream1"); + shardMetadata.setShardId(KinesisShardIdGenerator.generateFromShardOrder(0)); + + TEST_STATE.put(shardMetadata, new SequenceNumber("987654321")); + } + + private final MigrationVersion testMigrateVersion; + + @Parameterized.Parameters(name = "Migration Savepoint: {0}") + public static Collection parameters () { + return Arrays.asList(MigrationVersion.v1_3); + } + + public FlinkKinesisConsumerMigrationTest(MigrationVersion testMigrateVersion) { + this.testMigrateVersion = testMigrateVersion; + } + + /** + * Manually run this to write binary snapshot data. + */ + @Ignore @Test - public void testRestoreFromFlink11WithEmptyState() throws Exception { - Properties testConfig = new Properties(); - testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1"); - testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC"); - testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); - testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + public void writeSnapshot() throws Exception { + writeSnapshot("src/test/resources/kinesis-consumer-migration-test-flink" + flinkGenerateSavepointVersion + "-snapshot", TEST_STATE); - final DummyFlinkKafkaConsumer consumerFunction = new DummyFlinkKafkaConsumer<>(testConfig); + // write empty state snapshot + writeSnapshot("src/test/resources/kinesis-consumer-migration-test-flink" + flinkGenerateSavepointVersion + "-empty-snapshot", new HashMap<>()); + } + + @Test + public void testRestoreWithEmptyState() throws Exception { + final DummyFlinkKinesisConsumer consumerFunction = new DummyFlinkKinesisConsumer<>(mock(KinesisDataFetcher.class)); - StreamSource> consumerOperator = new StreamSource<>(consumerFunction); + StreamSource> consumerOperator = new StreamSource<>(consumerFunction); final AbstractStreamOperatorTestHarness testHarness = new AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0); - testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); - testHarness.setup(); - // restore state from binary snapshot file using legacy method - testHarness.initializeStateFromLegacyCheckpoint( - getResourceFilename("kinesis-consumer-migration-test-flink1.1-empty-snapshot")); + MigrationTestUtil.restoreFromSnapshot( + testHarness, + "src/test/resources/kinesis-consumer-migration-test-flink" + testMigrateVersion + "-empty-snapshot", testMigrateVersion); testHarness.open(); // assert that no state was restored - assertEquals(null, consumerFunction.getRestoredState()); + assertTrue(consumerFunction.getRestoredState().isEmpty()); consumerOperator.close(); consumerOperator.cancel(); } @Test - public void testRestoreFromFlink11() throws Exception { - Properties testConfig = new Properties(); - testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1"); - 
testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC"); - testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); - testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + public void testRestore() throws Exception { + final DummyFlinkKinesisConsumer consumerFunction = new DummyFlinkKinesisConsumer<>(mock(KinesisDataFetcher.class)); - final DummyFlinkKafkaConsumer consumerFunction = new DummyFlinkKafkaConsumer<>(testConfig); - - StreamSource> consumerOperator = + StreamSource> consumerOperator = new StreamSource<>(consumerFunction); final AbstractStreamOperatorTestHarness testHarness = new AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0); - testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); - testHarness.setup(); - // restore state from binary snapshot file using legacy method - testHarness.initializeStateFromLegacyCheckpoint( - getResourceFilename("kinesis-consumer-migration-test-flink1.1-snapshot")); + MigrationTestUtil.restoreFromSnapshot( + testHarness, + "src/test/resources/kinesis-consumer-migration-test-flink" + testMigrateVersion + "-snapshot", testMigrateVersion); testHarness.open(); - // the expected state in "kafka-consumer-migration-test-flink1.1-snapshot" - final HashMap expectedState = new HashMap<>(); - expectedState.put(KinesisStreamShard.convertToStreamShardMetadata(new KinesisStreamShard("fakeStream1", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0)))), - new SequenceNumber("987654321")); - - // assert that state is correctly restored from legacy checkpoint + // assert that state is correctly restored assertNotEquals(null, consumerFunction.getRestoredState()); assertEquals(1, consumerFunction.getRestoredState().size()); - assertEquals(expectedState, consumerFunction.getRestoredState()); + assertEquals(TEST_STATE, consumerFunction.getRestoredState()); consumerOperator.close(); consumerOperator.cancel(); @@ -119,31 +152,87 @@ public void testRestoreFromFlink11() throws Exception { // ------------------------------------------------------------------------ - private static String getResourceFilename(String filename) { - ClassLoader cl = FlinkKinesisConsumerMigrationTest.class.getClassLoader(); - URL resource = cl.getResource(filename); - if (resource == null) { - throw new NullPointerException("Missing snapshot resource."); + @SuppressWarnings("unchecked") + private void writeSnapshot(String path, HashMap state) throws Exception { + final OneShotLatch latch = new OneShotLatch(); + + final KinesisDataFetcher fetcher = mock(KinesisDataFetcher.class); + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + latch.trigger(); + return null; + } + }).when(fetcher).runFetcher(); + when(fetcher.snapshotState()).thenReturn(state); + + final DummyFlinkKinesisConsumer consumer = new DummyFlinkKinesisConsumer<>(fetcher); + + StreamSource> consumerOperator = new StreamSource<>(consumer); + + final AbstractStreamOperatorTestHarness testHarness = + new AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0); + + testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); + + testHarness.setup(); + testHarness.open(); + + final AtomicReference error = new AtomicReference<>(); + + // run the source asynchronously + Thread runner = new Thread() { + @Override + public void run() { + try { + consumer.run(mock(SourceFunction.SourceContext.class)); + } catch (Throwable t) { + 
t.printStackTrace(); + error.set(t); + } + } + }; + runner.start(); + + if (!latch.isTriggered()) { + latch.await(); + } + + final OperatorStateHandles snapshot; + synchronized (testHarness.getCheckpointLock()) { + snapshot = testHarness.snapshot(0L, 0L); } - return resource.getFile(); + + OperatorSnapshotUtil.writeStateHandle(snapshot, path); + + consumerOperator.close(); + runner.join(); } - private static class DummyFlinkKafkaConsumer extends FlinkKinesisConsumer { - private static final long serialVersionUID = 1L; + private static class DummyFlinkKinesisConsumer extends FlinkKinesisConsumer { + + private KinesisDataFetcher mockFetcher; + + private static Properties dummyConfig = new Properties(); + static { + dummyConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); + dummyConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); + dummyConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + } - @SuppressWarnings("unchecked") - DummyFlinkKafkaConsumer(Properties properties) { - super("test", mock(KinesisDeserializationSchema.class), properties); + DummyFlinkKinesisConsumer(KinesisDataFetcher mockFetcher) { + super("dummy-topic", mock(KinesisDeserializationSchema.class), dummyConfig); + this.mockFetcher = mockFetcher; } @Override protected KinesisDataFetcher createFetcher( List streams, - SourceFunction.SourceContext sourceContext, + SourceContext sourceContext, RuntimeContext runtimeContext, Properties configProps, - KinesisDeserializationSchema deserializationSchema) { - return mock(KinesisDataFetcher.class); + KinesisDeserializationSchema deserializer) { + return mockFetcher; } } } diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java index a26e758e72687..4a007d5b2c009 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.runtime.PojoSerializer; @@ -32,7 +33,6 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; -import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants; import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; @@ -40,6 +40,7 @@ import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import org.apache.flink.streaming.connectors.kinesis.model.StreamShardMetadata; +import 
org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema; import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator; import org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer; import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil; @@ -48,6 +49,7 @@ import com.amazonaws.services.kinesis.model.HashKeyRange; import com.amazonaws.services.kinesis.model.SequenceNumberRange; import com.amazonaws.services.kinesis.model.Shard; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -60,6 +62,7 @@ import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; +import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -501,38 +504,6 @@ public void testUnparsableLongForShardDiscoveryIntervalMillisInConfig() { KinesisConfigUtil.validateConsumerConfiguration(testConfig); } - // ---------------------------------------------------------------------- - // FlinkKinesisConsumer.validateProducerConfiguration() tests - // ---------------------------------------------------------------------- - - @Test - public void testUnparsableLongForCollectionMaxCountInConfig() { - exception.expect(IllegalArgumentException.class); - exception.expectMessage("Invalid value given for maximum number of items to pack into a PutRecords request"); - - Properties testConfig = new Properties(); - testConfig.setProperty(ProducerConfigConstants.AWS_REGION, "us-east-1"); - testConfig.setProperty(ProducerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); - testConfig.setProperty(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); - testConfig.setProperty(ProducerConfigConstants.COLLECTION_MAX_COUNT, "unparsableLong"); - - KinesisConfigUtil.validateProducerConfiguration(testConfig); - } - - @Test - public void testUnparsableLongForAggregationMaxCountInConfig() { - exception.expect(IllegalArgumentException.class); - exception.expectMessage("Invalid value given for maximum number of items to pack into an aggregated record"); - - Properties testConfig = new Properties(); - testConfig.setProperty(ProducerConfigConstants.AWS_REGION, "us-east-1"); - testConfig.setProperty(ProducerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); - testConfig.setProperty(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); - testConfig.setProperty(ProducerConfigConstants.AGGREGATION_MAX_COUNT, "unparsableLong"); - - KinesisConfigUtil.validateProducerConfiguration(testConfig); - } - // ---------------------------------------------------------------------- // Tests related to state initialization // ---------------------------------------------------------------------- @@ -710,38 +681,6 @@ public void testFetcherShouldNotBeRestoringFromFailureIfNotRestoringFromCheckpoi consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); } - @Test - @SuppressWarnings("unchecked") - public void testFetcherShouldBeCorrectlySeededIfRestoringFromLegacyCheckpoint() throws Exception { - HashMap fakeRestoredState = getFakeRestoredStore("all"); - HashMap legacyFakeRestoredState = new HashMap<>(); - for (Map.Entry kv : fakeRestoredState.entrySet()) { - legacyFakeRestoredState.put(new KinesisStreamShard(kv.getKey().getStreamName(), kv.getKey().getShard()), kv.getValue()); - } - - KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); - List shards = new ArrayList<>(); - 
shards.addAll(fakeRestoredState.keySet()); - when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards); - PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher); - - // assume the given config is correct - PowerMockito.mockStatic(KinesisConfigUtil.class); - PowerMockito.doNothing().when(KinesisConfigUtil.class); - - TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer( - "fakeStream", new Properties(), 10, 2); - consumer.restoreState(legacyFakeRestoredState); - consumer.open(new Configuration()); - consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - - for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { - Mockito.verify(mockedFetcher).registerNewSubscribedShardState( - new KinesisStreamShardState(KinesisDataFetcher.convertToStreamShardMetadata(restoredShard.getKey()), - restoredShard.getKey(), restoredShard.getValue())); - } - } - @Test @SuppressWarnings("unchecked") public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception { @@ -1062,4 +1001,35 @@ private HashMap getFakeRestoredStore(String s return fakeRestoredState; } + + /** + * A non-serializable {@link KinesisDeserializationSchema} (because it is a nested class with reference + * to the enclosing class, which is not serializable) used for testing. + */ + private final class NonSerializableDeserializationSchema implements KinesisDeserializationSchema { + @Override + public String deserialize(byte[] recordValue, String partitionKey, String seqNum, long approxArrivalTimestamp, String stream, String shardId) throws IOException { + return new String(recordValue); + } + + @Override + public TypeInformation getProducedType() { + return BasicTypeInfo.STRING_TYPE_INFO; + } + } + + /** + * A static, serializable {@link KinesisDeserializationSchema}. + */ + private static final class SerializableDeserializationSchema implements KinesisDeserializationSchema { + @Override + public String deserialize(byte[] recordValue, String partitionKey, String seqNum, long approxArrivalTimestamp, String stream, String shardId) throws IOException { + return new String(recordValue); + } + + @Override + public TypeInformation getProducedType() { + return BasicTypeInfo.STRING_TYPE_INFO; + } + } } diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisProducerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisProducerTest.java new file mode 100644 index 0000000000000..ac03cfed0c898 --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisProducerTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kinesis; + +import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants; +import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisSerializationSchema; +import org.apache.flink.streaming.util.serialization.SimpleStringSchema; +import org.apache.flink.util.InstantiationUtil; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.nio.ByteBuffer; +import java.util.Properties; + +import static org.junit.Assert.assertTrue; + +/** + * Suite of {@link FlinkKinesisProducer} tests. + */ +public class FlinkKinesisProducerTest { + + @Rule + public ExpectedException exception = ExpectedException.none(); + + // ---------------------------------------------------------------------- + // Tests to verify serializability + // ---------------------------------------------------------------------- + + @Test + public void testCreateWithNonSerializableDeserializerFails() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("The provided serialization schema is not serializable"); + + Properties testConfig = new Properties(); + testConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); + testConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); + testConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + + new FlinkKinesisProducer<>(new NonSerializableSerializationSchema(), testConfig); + } + + @Test + public void testCreateWithSerializableDeserializer() { + Properties testConfig = new Properties(); + testConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); + testConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); + testConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + + new FlinkKinesisProducer<>(new SerializableSerializationSchema(), testConfig); + } + + @Test + public void testConfigureWithNonSerializableCustomPartitionerFails() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("The provided custom partitioner is not serializable"); + + Properties testConfig = new Properties(); + testConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); + testConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); + testConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + + new FlinkKinesisProducer<>(new SimpleStringSchema(), testConfig) + .setCustomPartitioner(new NonSerializableCustomPartitioner()); + } + + @Test + public void testConfigureWithSerializableCustomPartitioner() { + Properties testConfig = new Properties(); + testConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); + testConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); + testConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + + new FlinkKinesisProducer<>(new SimpleStringSchema(), testConfig) + .setCustomPartitioner(new SerializableCustomPartitioner()); + } + + @Test + public void testConsumerIsSerializable() { + Properties testConfig = new Properties(); + testConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); + testConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); + testConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + + FlinkKinesisProducer consumer = new 
FlinkKinesisProducer<>(new SimpleStringSchema(), testConfig); + assertTrue(InstantiationUtil.isSerializable(consumer)); + } + + // ---------------------------------------------------------------------- + // Utility test classes + // ---------------------------------------------------------------------- + + /** + * A non-serializable {@link KinesisSerializationSchema} (because it is a nested class with reference + * to the enclosing class, which is not serializable) used for testing. + */ + private final class NonSerializableSerializationSchema implements KinesisSerializationSchema { + @Override + public ByteBuffer serialize(String element) { + return ByteBuffer.wrap(element.getBytes()); + } + + @Override + public String getTargetStream(String element) { + return "test-stream"; + } + } + + /** + * A static, serializable {@link KinesisSerializationSchema}. + */ + private static final class SerializableSerializationSchema implements KinesisSerializationSchema { + @Override + public ByteBuffer serialize(String element) { + return ByteBuffer.wrap(element.getBytes()); + } + + @Override + public String getTargetStream(String element) { + return "test-stream"; + } + } + + /** + * A non-serializable {@link KinesisPartitioner} (because it is a nested class with reference + * to the enclosing class, which is not serializable) used for testing. + */ + private final class NonSerializableCustomPartitioner extends KinesisPartitioner { + @Override + public String getPartitionId(String element) { + return "test-partition"; + } + } + + /** + * A static, serializable {@link KinesisPartitioner}. + */ + private static final class SerializableCustomPartitioner extends KinesisPartitioner { + @Override + public String getPartitionId(String element) { + return "test-partition"; + } + } +} diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/manualtests/ManualConsumerProducerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/manualtests/ManualConsumerProducerTest.java index 2915e2f6da1b3..a7470dc166e2f 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/manualtests/ManualConsumerProducerTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/manualtests/ManualConsumerProducerTest.java @@ -25,8 +25,8 @@ import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer; import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisProducer; import org.apache.flink.streaming.connectors.kinesis.KinesisPartitioner; +import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; -import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants; import org.apache.flink.streaming.connectors.kinesis.examples.ProduceIntoKinesis; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisSerializationSchema; import org.apache.flink.streaming.util.serialization.SimpleStringSchema; @@ -56,9 +56,9 @@ public static void main(String[] args) throws Exception { DataStream simpleStringStream = see.addSource(new ProduceIntoKinesis.EventsGenerator()); Properties kinesisProducerConfig = new Properties(); - kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_REGION, pt.getRequired("region")); - 
kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_ACCESS_KEY_ID, pt.getRequired("accessKey")); - kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, pt.getRequired("secretKey")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_REGION, pt.getRequired("region")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, pt.getRequired("accessKey")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, pt.getRequired("secretKey")); FlinkKinesisProducer kinesis = new FlinkKinesisProducer<>( new KinesisSerializationSchema() { diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/manualtests/ManualProducerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/manualtests/ManualProducerTest.java index 8abf4bb2ef655..fb49169bc8110 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/manualtests/ManualProducerTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/manualtests/ManualProducerTest.java @@ -23,7 +23,7 @@ import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisProducer; import org.apache.flink.streaming.connectors.kinesis.KinesisPartitioner; -import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants; +import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants; import org.apache.flink.streaming.connectors.kinesis.examples.ProduceIntoKinesis; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisSerializationSchema; @@ -53,9 +53,9 @@ public static void main(String[] args) throws Exception { DataStream simpleStringStream = see.addSource(new ProduceIntoKinesis.EventsGenerator()); Properties kinesisProducerConfig = new Properties(); - kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_REGION, pt.getRequired("region")); - kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_ACCESS_KEY_ID, pt.getRequired("accessKey")); - kinesisProducerConfig.setProperty(ProducerConfigConstants.AWS_SECRET_ACCESS_KEY, pt.getRequired("secretKey")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_REGION, pt.getRequired("region")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, pt.getRequired("accessKey")); + kinesisProducerConfig.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, pt.getRequired("secretKey")); FlinkKinesisProducer kinesis = new FlinkKinesisProducer<>( new KinesisSerializationSchema() { diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/ExactlyOnceValidatingConsumerThread.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/ExactlyOnceValidatingConsumerThread.java index 75356efca8428..1336652226f13 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/ExactlyOnceValidatingConsumerThread.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/ExactlyOnceValidatingConsumerThread.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.functions.RichFlatMapFunction; import 
org.apache.flink.api.common.restartstrategy.RestartStrategies; import org.apache.flink.configuration.Configuration; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; +import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer; @@ -29,11 +29,14 @@ import org.apache.flink.streaming.util.serialization.SimpleStringSchema; import org.apache.flink.test.util.SuccessException; import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.BitSet; +import java.util.Collections; +import java.util.List; import java.util.Properties; import java.util.concurrent.atomic.AtomicReference; @@ -95,7 +98,7 @@ public void run() { return new Thread(exactlyOnceValidationConsumer); } - private static class ExactlyOnceValidatingMapper implements FlatMapFunction, Checkpointed { + private static class ExactlyOnceValidatingMapper implements FlatMapFunction, ListCheckpointed { private static final Logger LOG = LoggerFactory.getLogger(ExactlyOnceValidatingMapper.class); @@ -126,13 +129,18 @@ public void flatMap(String value, Collector out) throws Exception { } @Override - public BitSet snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return validator; + public List snapshotState(long checkpointId, long timestamp) throws Exception { + return Collections.singletonList(validator); } @Override - public void restoreState(BitSet state) throws Exception { - this.validator = state; + public void restoreState(List state) throws Exception { + // we expect either 1 or 0 elements + if (state.size() == 1) { + validator = state.get(0); + } else { + Preconditions.checkState(state.isEmpty()); + } } } diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/KinesisConfigUtilTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/KinesisConfigUtilTest.java new file mode 100644 index 0000000000000..3b000588f4b91 --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/KinesisConfigUtilTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.connectors.kinesis.util; + +import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer; +import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants; +import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for KinesisConfigUtil. + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({FlinkKinesisConsumer.class, KinesisConfigUtil.class}) +public class KinesisConfigUtilTest { + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Test + public void testUnparsableLongForProducerConfiguration() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Error trying to set field RateLimit with the value 'unparsableLong'"); + + Properties testConfig = new Properties(); + testConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); + testConfig.setProperty("RateLimit", "unparsableLong"); + + KinesisConfigUtil.validateProducerConfiguration(testConfig); + } + + @Test + public void testReplaceDeprecatedKeys() { + Properties testConfig = new Properties(); + testConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); + // these deprecated keys should be replaced + testConfig.setProperty(ProducerConfigConstants.AGGREGATION_MAX_COUNT, "1"); + testConfig.setProperty(ProducerConfigConstants.COLLECTION_MAX_COUNT, "2"); + Properties replacedConfig = KinesisConfigUtil.replaceDeprecatedProducerKeys(testConfig); + + assertEquals("1", replacedConfig.getProperty(KinesisConfigUtil.AGGREGATION_MAX_COUNT)); + assertEquals("2", replacedConfig.getProperty(KinesisConfigUtil.COLLECTION_MAX_COUNT)); + } +} diff --git a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-empty-snapshot b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-empty-snapshot deleted file mode 100644 index f4dd96d211342..0000000000000 Binary files a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-empty-snapshot and /dev/null differ diff --git a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot deleted file mode 100644 index b60402e848327..0000000000000 Binary files a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot and /dev/null differ diff --git a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.3-empty-snapshot b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.3-empty-snapshot new file mode 100644 index 0000000000000..aa981c0a61b83 Binary files /dev/null and b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.3-empty-snapshot differ diff --git a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.3-snapshot
b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.3-snapshot new file mode 100644 index 0000000000000..ddf8a4d8d0834 Binary files /dev/null and b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.3-snapshot differ diff --git a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java index 05ae8108b5a68..f180e786c2f28 100644 --- a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java +++ b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java @@ -51,8 +51,8 @@ import java.io.IOException; import java.util.ArrayDeque; -import java.util.List; import java.util.Random; +import java.util.Set; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -180,12 +180,12 @@ public void testCheckpointing() throws Exception { testHarnessCopy.initializeState(data); testHarnessCopy.open(); - ArrayDeque>> deque = sourceCopy.getRestoredState(); - List messageIds = deque.getLast().f1; + ArrayDeque>> deque = sourceCopy.getRestoredState(); + Set messageIds = deque.getLast().f1; assertEquals(numIds, messageIds.size()); if (messageIds.size() > 0) { - assertEquals(lastSnapshotId, (long) Long.valueOf(messageIds.get(messageIds.size() - 1))); + assertTrue(messageIds.contains(Long.toString(lastSnapshotId))); } // check if the messages are being acknowledged and the transaction committed @@ -339,7 +339,7 @@ public TypeInformation getProducedType() { private class RMQTestSource extends RMQSource { - private ArrayDeque>> restoredState; + private ArrayDeque>> restoredState; public RMQTestSource() { super(new RMQConnectionConfig.Builder().setHost("hostTest") @@ -353,7 +353,7 @@ public void initializeState(FunctionInitializationContext context) throws Except this.restoredState = this.pendingCheckpoints; } - public ArrayDeque>> getRestoredState() { + public ArrayDeque>> getRestoredState() { return this.restoredState; } diff --git a/flink-connectors/flink-connector-twitter/pom.xml b/flink-connectors/flink-connector-twitter/pom.xml index 0f1e44a5d5dcc..2d8d62a39593c 100644 --- a/flink-connectors/flink-connector-twitter/pom.xml +++ b/flink-connectors/flink-connector-twitter/pom.xml @@ -77,12 +77,23 @@ under the License. + com.google.guava:guava com.twitter:hbc-core com.twitter:joauth org.apache.httpcomponents:httpclient org.apache.httpcomponents:httpcore + + + com.google + org.apache.flink.twitter.shaded.com.google + + com.google.protobuf.** + com.google.inject.** + + + diff --git a/flink-connectors/flink-hbase/pom.xml b/flink-connectors/flink-hbase/pom.xml index b900d23ded2b3..e18fe39280c8b 100644 --- a/flink-connectors/flink-hbase/pom.xml +++ b/flink-connectors/flink-hbase/pom.xml @@ -127,14 +127,6 @@ under the License. flink-streaming-java_${scala.binary.version} ${project.version} provided - - - - - com.google.guava - guava - - diff --git a/flink-connectors/flink-hcatalog/pom.xml b/flink-connectors/flink-hcatalog/pom.xml index 10ca36d8acaf2..1e77d7df443f4 100644 --- a/flink-connectors/flink-hcatalog/pom.xml +++ b/flink-connectors/flink-hcatalog/pom.xml @@ -70,6 +70,32 @@ under the License. 
+ + org.apache.maven.plugins + maven-shade-plugin + + + shade-flink + + + + com.google.guava:guava + + + + + com.google + org.apache.flink.hcatalog.shaded.com.google + + com.google.protobuf.** + com.google.inject.** + + + + + + + net.alchim31.maven diff --git a/flink-connectors/flink-jdbc/src/main/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormat.java b/flink-connectors/flink-jdbc/src/main/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormat.java index b7ac7446a410d..7d088147ae829 100644 --- a/flink-connectors/flink-jdbc/src/main/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormat.java +++ b/flink-connectors/flink-jdbc/src/main/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormat.java @@ -144,7 +144,7 @@ public void openInputFormat() { dbConn = DriverManager.getConnection(dbURL, username, password); } statement = dbConn.prepareStatement(queryTemplate, resultSetType, resultSetConcurrency); - if (fetchSize > 0) { + if (fetchSize == Integer.MIN_VALUE || fetchSize > 0) { statement.setFetchSize(fetchSize); } } catch (SQLException se) { @@ -390,7 +390,8 @@ public JDBCInputFormatBuilder setRowTypeInfo(RowTypeInfo rowTypeInfo) { } public JDBCInputFormatBuilder setFetchSize(int fetchSize) { - Preconditions.checkArgument(fetchSize > 0, "Illegal value %s for fetchSize, has to be positive.", fetchSize); + Preconditions.checkArgument(fetchSize == Integer.MIN_VALUE || fetchSize > 0, + "Illegal value %s for fetchSize, has to be positive or Integer.MIN_VALUE.", fetchSize); format.fetchSize = fetchSize; return this; } diff --git a/flink-connectors/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormatTest.java b/flink-connectors/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormatTest.java index f7a86e5afd222..10e8c66a7ddfa 100644 --- a/flink-connectors/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormatTest.java +++ b/flink-connectors/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormatTest.java @@ -113,6 +113,17 @@ public void testInvalidFetchSize() { .finish(); } + @Test + public void testValidFetchSizeIntegerMin() { + jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() + .setDrivername(DRIVER_CLASS) + .setDBUrl(DB_URL) + .setQuery(SELECT_ALL_BOOKS) + .setRowTypeInfo(ROW_TYPE_INFO) + .setFetchSize(Integer.MIN_VALUE) + .finish(); + } + @Test public void testDefaultFetchSizeIsUsedIfNotConfiguredOtherwise() throws SQLException, ClassNotFoundException { jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() diff --git a/flink-connectors/pom.xml b/flink-connectors/pom.xml index bc3f82f686c44..2ed3b7974098b 100644 --- a/flink-connectors/pom.xml +++ b/flink-connectors/pom.xml @@ -54,6 +54,7 @@ under the License. 
flink-connector-nifi flink-connector-cassandra flink-connector-filesystem + flink-connector-eventhubs + + org.apache.flink + flink-shaded-guava + + org.apache.flink flink-test-utils_${scala.binary.version} diff --git a/flink-contrib/flink-storm-examples/src/test/java/org/apache/flink/storm/join/SingleJoinITCase.java b/flink-contrib/flink-storm-examples/src/test/java/org/apache/flink/storm/join/SingleJoinITCase.java index 83531bab38a2e..5d406dba72a7f 100644 --- a/flink-contrib/flink-storm-examples/src/test/java/org/apache/flink/storm/join/SingleJoinITCase.java +++ b/flink-contrib/flink-storm-examples/src/test/java/org/apache/flink/storm/join/SingleJoinITCase.java @@ -20,7 +20,7 @@ import org.apache.flink.streaming.util.StreamingProgramTestBase; -import com.google.common.base.Joiner; +import org.apache.flink.shaded.guava18.com.google.common.base.Joiner; /** * Test for the SingleJoin example. diff --git a/flink-core/src/main/java/org/apache/flink/configuration/BlobServerOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/BlobServerOptions.java index e27c29f2e0b46..019580a6705dd 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/BlobServerOptions.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/BlobServerOptions.java @@ -22,7 +22,7 @@ import static org.apache.flink.configuration.ConfigOptions.key; /** - * Configuration options for the BlobServer. + * Configuration options for the BlobServer and BlobCache. */ @PublicEvolving public class BlobServerOptions { @@ -73,4 +73,18 @@ public class BlobServerOptions { public static final ConfigOption SSL_ENABLED = key("blob.service.ssl.enabled") .defaultValue(true); + + /** + * Cleanup interval of the blob caches at the task managers (in seconds). + * + *
<p>
Whenever a job is not referenced at the cache anymore, we set a TTL and let the periodic + * cleanup task (executed every CLEANUP_INTERVAL seconds) remove its blob files after this TTL + * has passed. This means that a blob will be retained at most 2 * CLEANUP_INTERVAL + * seconds after not being referenced anymore. Therefore, a recovery still has the chance to use + * existing files rather than to download them again. + */ + public static final ConfigOption CLEANUP_INTERVAL = + key("blob.service.cleanup.interval") + .defaultValue(3_600L) // once per hour + .withDeprecatedKeys("library-cache-manager.cleanup.interval"); } diff --git a/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java b/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java index 4c6c62a51c8ca..4153e456e3160 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java @@ -178,7 +178,10 @@ public final class ConfigConstants { /** * The config parameter defining the cleanup interval of the library cache manager. + * + * @deprecated use {@link BlobServerOptions#CLEANUP_INTERVAL} instead */ + @Deprecated public static final String LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL = "library-cache-manager.cleanup.interval"; /** @@ -1253,8 +1256,12 @@ public final class ConfigConstants { /** * The default library cache manager cleanup interval in seconds + * + * @deprecated use {@link BlobServerOptions#CLEANUP_INTERVAL} instead */ - public static final long DEFAULT_LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL = 3600; + @Deprecated + public static final long DEFAULT_LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL = + BlobServerOptions.CLEANUP_INTERVAL.defaultValue(); /** * The default network port to connect to for communication with the job manager. diff --git a/flink-core/src/main/java/org/apache/flink/configuration/Configuration.java b/flink-core/src/main/java/org/apache/flink/configuration/Configuration.java index d6f1decf3f69f..dfcd04fb97eae 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/Configuration.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/Configuration.java @@ -79,7 +79,7 @@ public Configuration(Configuration other) { } // -------------------------------------------------------------------------------------------- - + /** * Returns the class associated with the given key as a string. * diff --git a/flink-core/src/main/java/org/apache/flink/configuration/GlobalConfiguration.java b/flink-core/src/main/java/org/apache/flink/configuration/GlobalConfiguration.java index ea9f8bfc97e0c..4569ebe0acf0d 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/GlobalConfiguration.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/GlobalConfiguration.java @@ -28,6 +28,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; + /** * Global configuration object for Flink. Similar to Java properties configuration * objects it includes key-value pairs which represent the framework's configuration. @@ -46,24 +48,6 @@ private GlobalConfiguration() {} // -------------------------------------------------------------------------------------------- - private static Configuration dynamicProperties = null; - - /** - * Set the process-wide dynamic properties to be merged with the loaded configuration. 
- */ - public static void setDynamicProperties(Configuration dynamicProperties) { - GlobalConfiguration.dynamicProperties = new Configuration(dynamicProperties); - } - - /** - * Get the dynamic properties. - */ - public static Configuration getDynamicProperties() { - return GlobalConfiguration.dynamicProperties; - } - - // -------------------------------------------------------------------------------------------- - /** * Loads the global configuration from the environment. Fails if an error occurs during loading. Returns an * empty configuration object if the environment variable is not set. In production this variable is set but @@ -76,18 +60,30 @@ public static Configuration loadConfiguration() { if (configDir == null) { return new Configuration(); } - return loadConfiguration(configDir); + return loadConfiguration(configDir, null); } /** * Loads the configuration files from the specified directory. *

* YAML files are supported as configuration files. - * + * * @param configDir * the directory which contains the configuration files */ public static Configuration loadConfiguration(final String configDir) { + return loadConfiguration(configDir, null); + } + + /** + * Loads the configuration files from the specified directory. If the dynamic properties + * configuration is not null, then it is added to the loaded configuration. + * + * @param configDir directory to load the configuration from + * @param dynamicProperties configuration file containing the dynamic properties. Null if none. + * @return The configuration loaded from the given configuration directory + */ + public static Configuration loadConfiguration(final String configDir, @Nullable final Configuration dynamicProperties) { if (configDir == null) { throw new IllegalArgumentException("Given configuration directory is null, cannot load configuration"); @@ -109,13 +105,29 @@ public static Configuration loadConfiguration(final String configDir) { "' (" + confDirFile.getAbsolutePath() + ") does not exist."); } - Configuration conf = loadYAMLResource(yamlConfigFile); + Configuration configuration = loadYAMLResource(yamlConfigFile); - if(dynamicProperties != null) { - conf.addAll(dynamicProperties); + if (dynamicProperties != null) { + configuration.addAll(dynamicProperties); + } + + return configuration; + } + + /** + * Loads the global configuration and adds the given dynamic properties + * configuration. + * + * @param dynamicProperties The given dynamic properties + * @return Returns the loaded global configuration with dynamic properties + */ + public static Configuration loadConfigurationWithDynamicProperties(Configuration dynamicProperties) { + final String configDir = System.getenv(ConfigConstants.ENV_FLINK_CONF_DIR); + if (configDir == null) { + return new Configuration(dynamicProperties); } - return conf; + return loadConfiguration(configDir, dynamicProperties); } /** diff --git a/flink-core/src/main/java/org/apache/flink/configuration/RestOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/RestOptions.java new file mode 100644 index 0000000000000..a2a20136ea40e --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/configuration/RestOptions.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.configuration; + +import org.apache.flink.annotation.Internal; + +import static org.apache.flink.configuration.ConfigOptions.key; + +/** + * Configuration parameters for REST communication. + */ +@Internal +public class RestOptions { + /** + * The address that the server binds itself to / the client connects to. 
+ */ + public static final ConfigOption REST_ADDRESS = + key("rest.address") + .defaultValue("localhost"); + + /** + * The port that the server listens on / the client connects to. + */ + public static final ConfigOption REST_PORT = + key("rest.port") + .defaultValue(9067); +} diff --git a/flink-core/src/main/java/org/apache/flink/migration/util/MigrationInstantiationUtil.java b/flink-core/src/main/java/org/apache/flink/migration/util/MigrationInstantiationUtil.java deleted file mode 100644 index 69e4e6daa668e..0000000000000 --- a/flink-core/src/main/java/org/apache/flink/migration/util/MigrationInstantiationUtil.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.util; - -import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.util.InstantiationUtil; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.ObjectInputStream; -import java.io.ObjectStreamClass; - -/** - * Utility class to deserialize legacy classes for migration. - */ -@PublicEvolving -public final class MigrationInstantiationUtil { - - public static class ClassLoaderObjectInputStream extends InstantiationUtil.ClassLoaderObjectInputStream { - - private static final String ARRAY_PREFIX = "[L"; - private static final String FLINK_BASE_PACKAGE = "org.apache.flink."; - private static final String FLINK_MIGRATION_PACKAGE = "org.apache.flink.migration."; - - public ClassLoaderObjectInputStream(InputStream in, ClassLoader classLoader) throws IOException { - super(in, classLoader); - } - - @Override - protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { - final String className = desc.getName(); - - // the flink package may be at position 0 (regular class) or position 2 (array) - final int flinkPackagePos; - if ((flinkPackagePos = className.indexOf(FLINK_BASE_PACKAGE)) == 0 || - (flinkPackagePos == 2 && className.startsWith(ARRAY_PREFIX))) - { - final String modClassName = flinkPackagePos == 0 ? - FLINK_MIGRATION_PACKAGE + className.substring(FLINK_BASE_PACKAGE.length()) : - ARRAY_PREFIX + FLINK_MIGRATION_PACKAGE + className.substring(2 + FLINK_BASE_PACKAGE.length()); - - try { - return classLoader != null ? 
- Class.forName(modClassName, false, classLoader) : - Class.forName(modClassName); - } - catch (ClassNotFoundException ignored) {} - } - - // either a non-Flink class, or not located in the migration package - return super.resolveClass(desc); - } - } - - public static T deserializeObject(byte[] bytes, ClassLoader cl) throws IOException, ClassNotFoundException { - return deserializeObject(new ByteArrayInputStream(bytes), cl); - } - - @SuppressWarnings("unchecked") - public static T deserializeObject(InputStream in, ClassLoader cl) throws IOException, ClassNotFoundException { - final ClassLoader old = Thread.currentThread().getContextClassLoader(); - try (ObjectInputStream oois = new ClassLoaderObjectInputStream(in, cl)) { - Thread.currentThread().setContextClassLoader(cl); - return (T) oois.readObject(); - } finally { - Thread.currentThread().setContextClassLoader(old); - } - } - - // -------------------------------------------------------------------------------------------- - - /** - * Private constructor to prevent instantiation. - */ - private MigrationInstantiationUtil() { - throw new IllegalAccessError(); - } - -} diff --git a/flink-core/src/main/java/org/apache/flink/migration/util/SerializedValue.java b/flink-core/src/main/java/org/apache/flink/migration/util/SerializedValue.java deleted file mode 100644 index 6fa29d3c554c7..0000000000000 --- a/flink-core/src/main/java/org/apache/flink/migration/util/SerializedValue.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.util; - -import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.util.InstantiationUtil; - -import java.io.IOException; -import java.util.Arrays; - -/** - * This class is used to transfer (via serialization) objects whose classes are not available - * in the system class loader. When those objects are deserialized without access to their - * special class loader, the deserialization fails with a {@code ClassNotFoundException}. - * - * To work around that issue, the SerializedValue serialized data immediately into a byte array. - * When send through RPC or another service that uses serialization, only the byte array is - * transferred. The object is deserialized later (upon access) and requires the accessor to - * provide the corresponding class loader. - * - * @param The type of the value held. - * @deprecated Only used internally when migrating from previous savepoint versions. 
- */ -@Deprecated -@PublicEvolving -public class SerializedValue implements java.io.Serializable { - - private static final long serialVersionUID = -3564011643393683761L; - - /** The serialized data */ - private final byte[] serializedData; - - private SerializedValue(byte[] serializedData) { - this.serializedData = serializedData; - } - - public SerializedValue(T value) throws IOException { - this.serializedData = value == null ? null : InstantiationUtil.serializeObject(value); - } - - @SuppressWarnings("unchecked") - public T deserializeValue(ClassLoader loader) throws IOException, ClassNotFoundException { - return serializedData == null ? null : (T) MigrationInstantiationUtil.deserializeObject(serializedData, loader); - } - - /** - * Returns the serialized value or null if no value is set. - * - * @return Serialized data. - */ - public byte[] getByteArray() { - return serializedData; - } - - public static SerializedValue fromBytes(byte[] serializedData) { - return new SerializedValue(serializedData); - } - - // -------------------------------------------------------------------------------------------- - - @Override - public int hashCode() { - return serializedData == null ? 0 : Arrays.hashCode(serializedData); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof SerializedValue) { - SerializedValue other = (SerializedValue) obj; - return this.serializedData == null ? other.serializedData == null : - (other.serializedData != null && Arrays.equals(this.serializedData, other.serializedData)); - } - else { - return false; - } - } - - @Override - public String toString() { - return "SerializedValue"; - } -} diff --git a/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java b/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java index 9c8907ba9b8fa..d141ecb1dae34 100644 --- a/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java +++ b/flink-core/src/main/java/org/apache/flink/util/ExceptionUtils.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.io.PrintWriter; import java.io.StringWriter; +import java.util.Optional; import java.util.concurrent.ExecutionException; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -276,27 +277,27 @@ else if (t instanceof Error) { } /** - * Checks whether a throwable chain contains a specific type of exception. + * Checks whether a throwable chain contains a specific type of exception and returns it. * * @param throwable the throwable chain to check. * @param searchType the type of exception to search for in the chain. - * @return True, if the searched type is nested in the throwable, false otherwise. 
+ * @return Optional throwable of the requested type if available, otherwise empty */ - public static boolean containsThrowable(Throwable throwable, Class searchType) { + public static Optional findThrowable(Throwable throwable, Class searchType) { if (throwable == null || searchType == null) { - return false; + return Optional.empty(); } Throwable t = throwable; while (t != null) { if (searchType.isAssignableFrom(t.getClass())) { - return true; + return Optional.of(t); } else { t = t.getCause(); } } - return false; + return Optional.empty(); } /** diff --git a/flink-core/src/main/java/org/apache/flink/util/LambdaUtil.java b/flink-core/src/main/java/org/apache/flink/util/LambdaUtil.java new file mode 100644 index 0000000000000..8ac0f0e597e00 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/util/LambdaUtil.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.util; + +/** + * This class offers utility functions for Java's lambda features. + */ +public final class LambdaUtil { + + private LambdaUtil() { + throw new AssertionError(); + } + + /** + * This method supplies all elements from the input to the consumer. Exceptions that happen on elements are + * suppressed until all elements are processed. If exceptions happened for one or more of the inputs, they are + * reported in a combining suppressed exception. + * + * @param inputs iterator for all inputs to the throwingConsumer. + * @param throwingConsumer this consumer will be called for all elements delivered by the input iterator. + * @param the type of input. + * @throws Exception collected exceptions that happened during the invocation of the consumer on the input elements. 
+ */ + public static void applyToAllWhileSuppressingExceptions( + Iterable inputs, + ThrowingConsumer throwingConsumer) throws Exception { + + if (inputs != null && throwingConsumer != null) { + Exception exception = null; + + for (T input : inputs) { + + if (input != null) { + try { + throwingConsumer.accept(input); + } catch (Exception ex) { + exception = ExceptionUtils.firstOrSuppressed(ex, exception); + } + } + } + + if (exception != null) { + throw exception; + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MigrationRestoreSnapshot.java b/flink-core/src/main/java/org/apache/flink/util/ThrowingConsumer.java similarity index 57% rename from flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MigrationRestoreSnapshot.java rename to flink-core/src/main/java/org/apache/flink/util/ThrowingConsumer.java index 4277b56fc5be1..a180a1290a036 100644 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MigrationRestoreSnapshot.java +++ b/flink-core/src/main/java/org/apache/flink/util/ThrowingConsumer.java @@ -16,20 +16,22 @@ * limitations under the License. */ -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; -import org.apache.flink.runtime.state.heap.StateTable; -import org.apache.flink.util.Migration; - -import java.io.IOException; +package org.apache.flink.util; /** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. + * This interface is basically Java's {@link java.util.function.Consumer} interface enhanced with the ability to throw + * an exception. + * + * @param type of the consumed elements. */ -@Deprecated -@Internal -public interface MigrationRestoreSnapshot extends Migration { - StateTable deserialize(String stateName, HeapKeyedStateBackend stateBackend) throws IOException; +@FunctionalInterface +public interface ThrowingConsumer { + + /** + * Performs this operation on the given argument. + * + * @param t the input argument + * @throws Exception on errors during consumption + */ + void accept(T t) throws Exception; } diff --git a/flink-dist/src/main/flink-bin/bin/config.sh b/flink-dist/src/main/flink-bin/bin/config.sh index 69f15622b74af..3055999314e8b 100755 --- a/flink-dist/src/main/flink-bin/bin/config.sh +++ b/flink-dist/src/main/flink-bin/bin/config.sh @@ -351,8 +351,22 @@ if [ -z "$HADOOP_CONF_DIR" ]; then fi fi +# try and set HADOOP_CONF_DIR to some common default if it's not set +if [ -z "$HADOOP_CONF_DIR" ]; then + if [ -d "/etc/hadoop/conf" ]; then + echo "Setting HADOOP_CONF_DIR=/etc/hadoop/conf because no HADOOP_CONF_DIR was set." + HADOOP_CONF_DIR="/etc/hadoop/conf" + fi +fi + INTERNAL_HADOOP_CLASSPATHS="${HADOOP_CLASSPATH}:${HADOOP_CONF_DIR}:${YARN_CONF_DIR}" +# check if the "hadoop" binary is available, if yes, use that to augment the CLASSPATH +if command -v hadoop >/dev/null 2>&1; then + echo "Using the result of 'hadoop classpath' to augment the Hadoop classpath: `hadoop classpath`" + INTERNAL_HADOOP_CLASSPATHS="${INTERNAL_HADOOP_CLASSPATHS}:`hadoop classpath`" +fi + if [ -n "${HBASE_CONF_DIR}" ]; then # Look for hbase command in HBASE_HOME or search PATH. 
if [ -n "${HBASE_HOME}" ]; then diff --git a/flink-dist/src/main/flink-bin/mesos-bin/mesos-appmaster-flip6-job.sh b/flink-dist/src/main/flink-bin/mesos-bin/mesos-appmaster-flip6-job.sh new file mode 100755 index 0000000000000..b21670a6452ea --- /dev/null +++ b/flink-dist/src/main/flink-bin/mesos-bin/mesos-appmaster-flip6-job.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +# get Flink config +. "$bin"/config.sh + +if [ "$FLINK_IDENT_STRING" = "" ]; then + FLINK_IDENT_STRING="$USER" +fi + +CC_CLASSPATH=`manglePathList $(constructFlinkClassPath):$INTERNAL_HADOOP_CLASSPATHS` + +log="${FLINK_LOG_DIR}/flink-${FLINK_IDENT_STRING}-mesos-appmaster-${HOSTNAME}.log" +log_setting="-Dlog.file="$log" -Dlog4j.configuration=file:"$FLINK_CONF_DIR"/log4j.properties -Dlogback.configurationFile=file:"$FLINK_CONF_DIR"/logback.xml" + +export FLINK_CONF_DIR +export FLINK_BIN_DIR +export FLINK_LIB_DIR + +exec $JAVA_RUN $JVM_ARGS -classpath "$CC_CLASSPATH" $log_setting org.apache.flink.mesos.entrypoint.MesosJobClusterEntrypoint "$@" + +rc=$? + +if [[ $rc -ne 0 ]]; then + echo "Error while starting the mesos application master. Please check ${log} for more details." +fi + +exit $rc diff --git a/flink-dist/src/main/flink-bin/mesos-bin/mesos-appmaster-flip6-session.sh b/flink-dist/src/main/flink-bin/mesos-bin/mesos-appmaster-flip6-session.sh new file mode 100755 index 0000000000000..b9e0f5375c640 --- /dev/null +++ b/flink-dist/src/main/flink-bin/mesos-bin/mesos-appmaster-flip6-session.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +# get Flink config +. 
"$bin"/config.sh + +if [ "$FLINK_IDENT_STRING" = "" ]; then + FLINK_IDENT_STRING="$USER" +fi + +CC_CLASSPATH=`manglePathList $(constructFlinkClassPath):$INTERNAL_HADOOP_CLASSPATHS` + +log="${FLINK_LOG_DIR}/flink-${FLINK_IDENT_STRING}-mesos-appmaster-${HOSTNAME}.log" +log_setting="-Dlog.file="$log" -Dlog4j.configuration=file:"$FLINK_CONF_DIR"/log4j.properties -Dlogback.configurationFile=file:"$FLINK_CONF_DIR"/logback.xml" + +export FLINK_CONF_DIR +export FLINK_BIN_DIR +export FLINK_LIB_DIR + +exec $JAVA_RUN $JVM_ARGS -classpath "$CC_CLASSPATH" $log_setting org.apache.flink.mesos.entrypoint.MesosSessionClusterEntrypoint "$@" + +rc=$? + +if [[ $rc -ne 0 ]]; then + echo "Error while starting the mesos application master. Please check ${log} for more details." +fi + +exit $rc diff --git a/flink-dist/src/main/flink-bin/mesos-bin/mesos-taskmanager-flip6.sh b/flink-dist/src/main/flink-bin/mesos-bin/mesos-taskmanager-flip6.sh new file mode 100755 index 0000000000000..f2514429fe7e8 --- /dev/null +++ b/flink-dist/src/main/flink-bin/mesos-bin/mesos-taskmanager-flip6.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +bin=`dirname "$0"` +bin=`cd "$bin"; pwd` + +# get Flink config +. "$bin"/config.sh + +CC_CLASSPATH=`manglePathList $(constructFlinkClassPath):$INTERNAL_HADOOP_CLASSPATHS` + +log=flink-taskmanager.log +log_setting="-Dlog.file="$log" -Dlog4j.configuration=file:"$FLINK_CONF_DIR"/log4j.properties -Dlogback.configurationFile=file:"$FLINK_CONF_DIR"/logback.xml" + +# Add precomputed memory JVM options +if [ -z "${FLINK_ENV_JAVA_OPTS_MEM}" ]; then + FLINK_ENV_JAVA_OPTS_MEM="" +fi +export FLINK_ENV_JAVA_OPTS="${FLINK_ENV_JAVA_OPTS} ${FLINK_ENV_JAVA_OPTS_MEM}" + +# Add TaskManager-specific JVM options +export FLINK_ENV_JAVA_OPTS="${FLINK_ENV_JAVA_OPTS} ${FLINK_ENV_JAVA_OPTS_TM}" + +export FLINK_CONF_DIR +export FLINK_BIN_DIR +export FLINK_LIB_DIR + +exec $JAVA_RUN $JVM_ARGS ${FLINK_ENV_JAVA_OPTS} -classpath "$CC_CLASSPATH" $log_setting org.apache.flink.mesos.entrypoint.MesosTaskExecutorRunner "$@" + diff --git a/flink-examples/flink-examples-streaming/pom.xml b/flink-examples/flink-examples-streaming/pom.xml index eba81d3851d43..dd32a1d2d3f90 100644 --- a/flink-examples/flink-examples-streaming/pom.xml +++ b/flink-examples/flink-examples-streaming/pom.xml @@ -78,6 +78,11 @@ under the License. 
test test-jar + + org.apache.flink + flink-connector-azureeventhubs_${scala.binary.version} + ${project.version} + diff --git a/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/async/AsyncIOExample.java b/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/async/AsyncIOExample.java index 748cb82fe4c86..95379e39d8d8e 100644 --- a/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/async/AsyncIOExample.java +++ b/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/async/AsyncIOExample.java @@ -29,8 +29,8 @@ import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.functions.async.ResultFuture; import org.apache.flink.streaming.api.functions.async.RichAsyncFunction; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.util.Collector; @@ -178,7 +178,7 @@ public void close() throws Exception { } @Override - public void asyncInvoke(final Integer input, final AsyncCollector collector) throws Exception { + public void asyncInvoke(final Integer input, final ResultFuture resultFuture) throws Exception { this.executorService.submit(new Runnable() { @Override public void run() { @@ -188,13 +188,13 @@ public void run() { Thread.sleep(sleep); if (random.nextFloat() < failRatio) { - collector.collect(new Exception("wahahahaha...")); + resultFuture.completeExceptionally(new Exception("wahahahaha...")); } else { - collector.collect( + resultFuture.complete( Collections.singletonList("key-" + (input % 10))); } } catch (InterruptedException e) { - collector.collect(new ArrayList(0)); + resultFuture.complete(new ArrayList(0)); } } }); diff --git a/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/eventhub/ReadFromEventhub.java b/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/eventhub/ReadFromEventhub.java new file mode 100644 index 0000000000000..d04e921669b5b --- /dev/null +++ b/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/eventhub/ReadFromEventhub.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.examples.eventhub; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.utils.ParameterTool; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.connectors.eventhubs.FlinkEventHubConsumer; +import org.apache.flink.streaming.util.serialization.SimpleStringSchema; + +/** + * An example that reads strings from Azure Event Hubs and prints them to standard output. + */ +public class ReadFromEventhub { + public static void main(String[] args) throws Exception { + // parse input arguments + final ParameterTool parameterTool = ParameterTool.fromArgs(args); + + if (parameterTool.getNumberOfParameters() < 4) { + System.out.println("Missing parameters!\nUsage: ReadFromEventhub --eventhubs.policykey " + + "--eventhubs.namespace --eventhubs.name --eventhubs.partition.count "); + return; + } + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.getConfig().disableSysoutLogging(); + env.getConfig().setRestartStrategy(RestartStrategies.fixedDelayRestart(4, 10000)); + env.enableCheckpointing(5000); // create a checkpoint every 5 seconds + env.getConfig().setGlobalJobParameters(parameterTool); // make parameters available in the web interface + + DataStream messageStream = env + .addSource(new FlinkEventHubConsumer( + parameterTool.getProperties(), + new SimpleStringSchema())); + + // print the event hub stream to standard out. + messageStream.print(); + + env.execute("Read from eventhub example"); + } +} diff --git a/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/eventhub/WriteToEventhub.java b/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/eventhub/WriteToEventhub.java new file mode 100644 index 0000000000000..397daf9f452ac --- /dev/null +++ b/flink-examples/flink-examples-streaming/src/main/java/org/apache/flink/streaming/examples/eventhub/WriteToEventhub.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.examples.eventhub; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.utils.ParameterTool; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.connectors.eventhubs.FlinkEventHubProducer; +import org.apache.flink.streaming.util.serialization.SimpleStringSchema; + +/** + * An example that writes generated strings into Azure Event Hubs.
+ */ +public class WriteToEventhub { + public static void main(String[] args) throws Exception { + ParameterTool parameterTool = ParameterTool.fromArgs(args); + if (parameterTool.getNumberOfParameters() < 4) { + System.out.println("Missing parameters!"); + System.out.println("Usage: WriteToEventhub --eventhubs.saskeyname " + + "--eventhubs.saskey --eventhubs.namespace --eventhubs.name "); + return; + } + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.getConfig().disableSysoutLogging(); + env.getConfig().setRestartStrategy(RestartStrategies.fixedDelayRestart(4, 10000)); + + // very simple data generator + DataStream messageStream = env.addSource(new SourceFunction() { + private static final long serialVersionUID = 6369260445318862378L; + public boolean running = true; + + @Override + public void run(SourceContext ctx) throws Exception { + long i = 0; + while (this.running) { + ctx.collect("Element - " + i++); + Thread.sleep(500); + } + } + + @Override + public void cancel() { + running = false; + } + }); + + // write data into Kafka + messageStream.addSink(new FlinkEventHubProducer(new SimpleStringSchema(), parameterTool.getProperties())); + + env.execute("Write into eventhub example"); + } +} diff --git a/flink-examples/flink-examples-streaming/src/main/scala/org/apache/flink/streaming/scala/examples/async/AsyncIOExample.scala b/flink-examples/flink-examples-streaming/src/main/scala/org/apache/flink/streaming/scala/examples/async/AsyncIOExample.scala index 69c4c0a295e05..5808aaaf8d7ad 100644 --- a/flink-examples/flink-examples-streaming/src/main/scala/org/apache/flink/streaming/scala/examples/async/AsyncIOExample.scala +++ b/flink-examples/flink-examples-streaming/src/main/scala/org/apache/flink/streaming/scala/examples/async/AsyncIOExample.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext import org.apache.flink.streaming.api.scala._ -import org.apache.flink.streaming.api.scala.async.AsyncCollector +import org.apache.flink.streaming.api.scala.async.ResultFuture import scala.concurrent.{ExecutionContext, Future} @@ -38,9 +38,9 @@ object AsyncIOExample { val input = env.addSource(new SimpleSource()) val asyncMapped = AsyncDataStream.orderedWait(input, timeout, TimeUnit.MILLISECONDS, 10) { - (input, collector: AsyncCollector[Int]) => + (input, collector: ResultFuture[Int]) => Future { - collector.collect(Seq(input)) + collector.complete(Seq(input)) } (ExecutionContext.global) } diff --git a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingMigrationTest.java b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingMigrationTest.java index 78c57edabf953..602ad3e1bb0ff 100644 --- a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingMigrationTest.java +++ b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingMigrationTest.java @@ -76,7 +76,6 @@ public class ContinuousFileProcessingMigrationTest { @Parameterized.Parameters(name = "Migration Savepoint / Mod Time: {0}") public static Collection> parameters () { return Arrays.asList( - Tuple2.of(MigrationVersion.v1_1, 1482144479339L), Tuple2.of(MigrationVersion.v1_2, 1493116191000L), Tuple2.of(MigrationVersion.v1_3, 1496532000000L)); } diff --git 
a/flink-libraries/flink-cep-scala/src/main/scala/org/apache/flink/cep/scala/PatternStream.scala b/flink-libraries/flink-cep-scala/src/main/scala/org/apache/flink/cep/scala/PatternStream.scala index d270ef7c65f4f..d1b07b3519e76 100644 --- a/flink-libraries/flink-cep-scala/src/main/scala/org/apache/flink/cep/scala/PatternStream.scala +++ b/flink-libraries/flink-cep-scala/src/main/scala/org/apache/flink/cep/scala/PatternStream.scala @@ -17,20 +17,17 @@ */ package org.apache.flink.cep.scala -import java.util.{List => JList, Map => JMap} +import java.util.{UUID, List => JList, Map => JMap} import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.cep.{EventComparator, PatternFlatSelectFunction, PatternFlatTimeoutFunction, PatternSelectFunction, PatternTimeoutFunction, PatternStream => JPatternStream} import org.apache.flink.cep.pattern.{Pattern => JPattern} +import org.apache.flink.cep.scala.pattern.Pattern +import org.apache.flink.cep.{EventComparator, PatternFlatSelectFunction, PatternFlatTimeoutFunction, PatternSelectFunction, PatternTimeoutFunction, PatternStream => JPatternStream} import org.apache.flink.streaming.api.scala.{asScalaStream, _} import org.apache.flink.util.Collector -import org.apache.flink.types.{Either => FEither} -import org.apache.flink.api.java.tuple.{Tuple2 => FTuple2} -import java.lang.{Long => JLong} import org.apache.flink.cep.operator.CEPOperatorUtils import org.apache.flink.cep.scala.pattern.Pattern - import scala.collection.Map /** @@ -84,37 +81,54 @@ class PatternStream[T](jPatternStream: JPatternStream[T]) { * pattern sequence. * @tparam L Type of the resulting timeout event * @tparam R Type of the resulting event + * @deprecated Use the version that returns timeouted events as a side-output * @return Data stream of either type which contains the resulting events and resulting timeout * events. */ + @deprecated def select[L: TypeInformation, R: TypeInformation]( patternTimeoutFunction: PatternTimeoutFunction[T, L], patternSelectFunction: PatternSelectFunction[T, R]) : DataStream[Either[L, R]] = { + val outputTag = OutputTag[L](UUID.randomUUID().toString) + val mainStream = select(outputTag, patternTimeoutFunction, patternSelectFunction) + mainStream.connect(mainStream.getSideOutput[L](outputTag)).map(r => Right(r), l => Left(l)) + } - val patternStream = CEPOperatorUtils.createTimeoutPatternStream( - jPatternStream.getInputStream, - jPatternStream.getPattern, - jPatternStream.getComparator) - + /** + * Applies a select function to the detected pattern sequence. For each pattern sequence the + * provided [[PatternSelectFunction]] is called. The pattern select function can produce + * exactly one resulting element. + * + * Additionally a timeout function is applied to partial event patterns which have timed out. For + * each partial pattern sequence the provided [[PatternTimeoutFunction]] is called. The pattern + * timeout function has to produce exactly one resulting timeout event. + * + * You can get the stream of timeouted matches using [[DataStream.getSideOutput()]] on the + * [[DataStream]] resulting from the windowed operation with the same [[OutputTag]]. + * + * @param outputTag [[OutputTag]] that identifies side output with timeouted patterns + * @param patternTimeoutFunction The pattern timeout function which is called for each partial + * pattern sequence which has timed out. + * @param patternSelectFunction The pattern select function which is called for each detected + * pattern sequence. 
+ * @tparam L Type of the resulting timeout event + * @tparam R Type of the resulting event + * @return Data stream which contains the resulting elements with the resulting timeout elements + * in a side output. + */ + def select[L: TypeInformation, R: TypeInformation]( + outputTag: OutputTag[L], + patternTimeoutFunction: PatternTimeoutFunction[T, L], + patternSelectFunction: PatternSelectFunction[T, R]) + : DataStream[R] = { val cleanedSelect = cleanClosure(patternSelectFunction) val cleanedTimeout = cleanClosure(patternTimeoutFunction) - implicit val eitherTypeInfo = createTypeInformation[Either[L, R]] - - asScalaStream(patternStream).map[Either[L, R]] { - input: FEither[FTuple2[JMap[String, JList[T]], JLong], JMap[String, JList[T]]] => - if (input.isLeft) { - val timeout = input.left() - val timeoutEvent = cleanedTimeout.timeout(timeout.f0, timeout.f1) - val t = Left[L, R](timeoutEvent) - t - } else { - val event = cleanedSelect.select(input.right()) - val t = Right[L, R](event) - t - } - } + asScalaStream( + jPatternStream + .select(outputTag, cleanedTimeout, implicitly[TypeInformation[R]], cleanedSelect) + ) } /** @@ -151,44 +165,58 @@ class PatternStream[T](jPatternStream: JPatternStream[T]) { * detected pattern sequence. * @tparam L Type of the resulting timeout event * @tparam R Type of the resulting event + * @deprecated Use the version that returns timeouted events as a side-output * @return Data stream of either type which contains the resulting events and the resulting * timeout events wrapped in a [[Either]] type. */ + @deprecated def flatSelect[L: TypeInformation, R: TypeInformation]( patternFlatTimeoutFunction: PatternFlatTimeoutFunction[T, L], patternFlatSelectFunction: PatternFlatSelectFunction[T, R]) : DataStream[Either[L, R]] = { - val patternStream = CEPOperatorUtils.createTimeoutPatternStream( - jPatternStream.getInputStream, - jPatternStream.getPattern, - jPatternStream.getComparator - ) - - val cleanedSelect = cleanClosure(patternFlatSelectFunction) - val cleanedTimeout = cleanClosure(patternFlatTimeoutFunction) - - implicit val eitherTypeInfo = createTypeInformation[Either[L, R]] - asScalaStream(patternStream).flatMap[Either[L, R]] { - (input: FEither[FTuple2[JMap[String, JList[T]], JLong], JMap[String, JList[T]]], - collector: Collector[Either[L, R]]) => - - if (input.isLeft()) { - val timeout = input.left() + val outputTag = OutputTag[L]("dummy-timeouted") + val mainStream = flatSelect(outputTag, patternFlatTimeoutFunction, patternFlatSelectFunction) + mainStream.connect(mainStream.getSideOutput[L](outputTag)).map(r => Right(r), l => Left(l)) + } - cleanedTimeout.timeout(timeout.f0, timeout.f1, new Collector[L]() { - override def collect(record: L): Unit = collector.collect(Left(record)) + /** + * Applies a flat select function to the detected pattern sequence. For each pattern sequence + * the provided [[PatternFlatSelectFunction]] is called. The pattern flat select function can + * produce an arbitrary number of resulting elements. + * + * Additionally a timeout function is applied to partial event patterns which have timed out. For + * each partial pattern sequence the provided [[PatternFlatTimeoutFunction]] is called. The + * pattern timeout function can produce an arbitrary number of resulting timeout events. + * + * You can get the stream of timeouted matches using [[DataStream.getSideOutput()]] on the + * [[DataStream]] resulting from the windowed operation with the same [[OutputTag]]. 
+ * + * @param outputTag [[OutputTag]] that identifies side output with timeouted patterns + * @param patternFlatTimeoutFunction The pattern flat timeout function which is called for each + * partially matched pattern sequence which has timed out. + * @param patternFlatSelectFunction The pattern flat select function which is called for each + * detected pattern sequence. + * @tparam L Type of the resulting timeout event + * @tparam R Type of the resulting event + * @return Data stream which contains the resulting elements with the resulting timeout elements + * in a side output. + */ + def flatSelect[L: TypeInformation, R: TypeInformation]( + outputTag: OutputTag[L], + patternFlatTimeoutFunction: PatternFlatTimeoutFunction[T, L], + patternFlatSelectFunction: PatternFlatSelectFunction[T, R]) + : DataStream[R] = { - override def close(): Unit = collector.close() - }) - } else { - cleanedSelect.flatSelect(input.right, new Collector[R]() { - override def collect(record: R): Unit = collector.collect(Right(record)) + val cleanedSelect = cleanClosure(patternFlatSelectFunction) + val cleanedTimeout = cleanClosure(patternFlatTimeoutFunction) - override def close(): Unit = collector.close() - }) - } - } + asScalaStream( + jPatternStream.flatSelect( + outputTag, + cleanedTimeout, + implicitly[TypeInformation[R]], + cleanedSelect)) } /** @@ -228,9 +256,11 @@ class PatternStream[T](jPatternStream: JPatternStream[T]) { * pattern sequence. * @tparam L Type of the resulting timeout event * @tparam R Type of the resulting event + * @deprecated Use the version that returns timeouted events as a side-output * @return Data stream of either type which contain the resulting events and resulting timeout * events. */ + @deprecated def select[L: TypeInformation, R: TypeInformation]( patternTimeoutFunction: (Map[String, Iterable[T]], Long) => L) ( patternSelectFunction: Map[String, Iterable[T]] => R) @@ -251,6 +281,48 @@ class PatternStream[T](jPatternStream: JPatternStream[T]) { select(patternTimeoutFun, patternSelectFun) } + /** + * Applies a select function to the detected pattern sequence. For each pattern sequence the + * provided [[PatternSelectFunction]] is called. The pattern select function can produce + * exactly one resulting element. + * + * Additionally a timeout function is applied to partial event patterns which have timed out. For + * each partial pattern sequence the provided [[PatternTimeoutFunction]] is called. The pattern + * timeout function has to produce exactly one resulting element. + * + * You can get the stream of timeouted matches using [[DataStream.getSideOutput()]] on the + * [[DataStream]] resulting from the windowed operation with the same [[OutputTag]]. + * + * @param outputTag [[OutputTag]] that identifies side output with timeouted patterns + * @param patternTimeoutFunction The pattern timeout function which is called for each partial + * pattern sequence which has timed out. + * @param patternSelectFunction The pattern select function which is called for each detected + * pattern sequence. + * @tparam L Type of the resulting timeout event + * @tparam R Type of the resulting event + * @return Data stream of either type which contain the resulting events and resulting timeout + * events. 
+ */ + def select[L: TypeInformation, R: TypeInformation](outputTag: OutputTag[L])( + patternTimeoutFunction: (Map[String, Iterable[T]], Long) => L) ( + patternSelectFunction: Map[String, Iterable[T]] => R) + : DataStream[R] = { + + val cleanSelectFun = cleanClosure(patternSelectFunction) + val cleanTimeoutFun = cleanClosure(patternTimeoutFunction) + + val patternSelectFun = new PatternSelectFunction[T, R] { + override def select(pattern: JMap[String, JList[T]]): R = + cleanSelectFun(mapToScala(pattern)) + } + val patternTimeoutFun = new PatternTimeoutFunction[T, L] { + override def timeout(pattern: JMap[String, JList[T]], timeoutTimestamp: Long): L = + cleanTimeoutFun(mapToScala(pattern), timeoutTimestamp) + } + + select(outputTag, patternTimeoutFun, patternSelectFun) + } + /** * Applies a flat select function to the detected pattern sequence. For each pattern sequence * the provided [[PatternFlatSelectFunction]] is called. The pattern flat select function @@ -292,9 +364,11 @@ class PatternStream[T](jPatternStream: JPatternStream[T]) { * detected pattern sequence. * @tparam L Type of the resulting timeout event * @tparam R Type of the resulting event + * @deprecated Use the version that returns timeouted events as a side-output * @return Data stream of either type which contains the resulting events and the resulting * timeout events wrapped in a [[Either]] type. */ + @deprecated def flatSelect[L: TypeInformation, R: TypeInformation]( patternFlatTimeoutFunction: (Map[String, Iterable[T]], Long, Collector[L]) => Unit) ( patternFlatSelectFunction: (Map[String, Iterable[T]], Collector[R]) => Unit) @@ -319,6 +393,53 @@ class PatternStream[T](jPatternStream: JPatternStream[T]) { flatSelect(patternFlatTimeoutFun, patternFlatSelectFun) } + + /** + * Applies a flat select function to the detected pattern sequence. For each pattern sequence + * the provided [[PatternFlatSelectFunction]] is called. The pattern flat select function can + * produce an arbitrary number of resulting elements. + * + * Additionally a timeout function is applied to partial event patterns which have timed out. For + * each partial pattern sequence the provided [[PatternFlatTimeoutFunction]] is called. The + * pattern timeout function can produce an arbitrary number of resulting timeout events. + * + * You can get the stream of timeouted matches using [[DataStream.getSideOutput()]] on the + * [[DataStream]] resulting from the windowed operation with the same [[OutputTag]]. + * + * @param outputTag [[OutputTag]] that identifies side output with timeouted patterns + * @param patternFlatTimeoutFunction The pattern flat timeout function which is called for each + * partially matched pattern sequence which has timed out. + * @param patternFlatSelectFunction The pattern flat select function which is called for each + * detected pattern sequence. + * @tparam L Type of the resulting timeout event + * @tparam R Type of the resulting event + * @return Data stream of either type which contains the resulting events and the resulting + * timeout events wrapped in a [[Either]] type. 
+ */ + def flatSelect[L: TypeInformation, R: TypeInformation](outputTag: OutputTag[L])( + patternFlatTimeoutFunction: (Map[String, Iterable[T]], Long, Collector[L]) => Unit) ( + patternFlatSelectFunction: (Map[String, Iterable[T]], Collector[R]) => Unit) + : DataStream[R] = { + + val cleanSelectFun = cleanClosure(patternFlatSelectFunction) + val cleanTimeoutFun = cleanClosure(patternFlatTimeoutFunction) + + val patternFlatSelectFun = new PatternFlatSelectFunction[T, R] { + override def flatSelect(pattern: JMap[String, JList[T]], out: Collector[R]): Unit = + cleanSelectFun(mapToScala(pattern), out) + } + + val patternFlatTimeoutFun = new PatternFlatTimeoutFunction[T, L] { + override def timeout( + pattern: JMap[String, JList[T]], + timeoutTimestamp: Long, out: Collector[L]) + : Unit = { + cleanTimeoutFun(mapToScala(pattern), timeoutTimestamp, out) + } + } + + flatSelect(outputTag, patternFlatTimeoutFun, patternFlatSelectFun) + } } object PatternStream { @@ -328,7 +449,7 @@ object PatternStream { * @tparam T Type of the events * @return A new pattern stream wrapping the pattern stream from Java APU */ - def apply[T](jPatternStream: JPatternStream[T]) = { + def apply[T](jPatternStream: JPatternStream[T]): PatternStream[T] = { new PatternStream[T](jPatternStream) } } diff --git a/flink-libraries/flink-cep-scala/src/main/scala/org/apache/flink/cep/scala/pattern/Pattern.scala b/flink-libraries/flink-cep-scala/src/main/scala/org/apache/flink/cep/scala/pattern/Pattern.scala index 5daebe0483567..42a95e82c286e 100644 --- a/flink-libraries/flink-cep-scala/src/main/scala/org/apache/flink/cep/scala/pattern/Pattern.scala +++ b/flink-libraries/flink-cep-scala/src/main/scala/org/apache/flink/cep/scala/pattern/Pattern.scala @@ -18,6 +18,7 @@ package org.apache.flink.cep.scala.pattern import org.apache.flink.cep +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy import org.apache.flink.cep.pattern.conditions.IterativeCondition.{Context => JContext} import org.apache.flink.cep.pattern.conditions.{IterativeCondition, SimpleCondition} import org.apache.flink.cep.pattern.{MalformedPatternException, Quantifier, Pattern => JPattern} @@ -344,7 +345,7 @@ class Pattern[T , F <: T](jPattern: JPattern[T, F]) { * {{{A1 A2 B}}} appears, this will generate patterns: * {{{A1 B}}} and {{{A1 A2 B}}}. See also {{{allowCombinations()}}}. * - * @return The same pattern with a [[Quantifier.oneOrMore()]] quantifier applied. + * @return The same pattern with a [[Quantifier.looping()]] quantifier applied. * @throws MalformedPatternException if the quantifier is not applicable to this pattern. */ def oneOrMore: Pattern[T, F] = { @@ -352,6 +353,18 @@ class Pattern[T , F <: T](jPattern: JPattern[T, F]) { this } + /** + * Specifies that this pattern is greedy. + * This means as many events as possible will be matched to this pattern. + * + * @return The same pattern with { @link Quantifier#greedy} set to true. + * @throws MalformedPatternException if the quantifier is not applicable to this pattern. + */ + def greedy: Pattern[T, F] = { + jPattern.greedy() + this + } + /** * Specifies exact number of times that this pattern should be matched. * @@ -378,7 +391,21 @@ class Pattern[T , F <: T](jPattern: JPattern[T, F]) { } /** - * Applicable only to [[Quantifier.oneOrMore()]] and [[Quantifier.times()]] patterns, + * Specifies that this pattern can occur the specified times at least. + * This means at least the specified times and at most infinite number of events can + * be matched to this pattern. 
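A brief sketch of how the quantifier methods surfaced here (timesOrMore and greedy, both delegating to the Java Pattern) might be combined; the Event type and its getName accessor are placeholders, not part of this patch.

import org.apache.flink.cep.pattern.Pattern;
import org.apache.flink.cep.pattern.conditions.SimpleCondition;

// Matches two or more consecutive "a" events, preferring the longest run,
// followed directly by a single "b" event. Event/getName are placeholder user types.
Pattern<Event, ?> pattern = Pattern.<Event>begin("as")
    .where(new SimpleCondition<Event>() {
        @Override
        public boolean filter(Event value) {
            return "a".equals(value.getName());
        }
    })
    .timesOrMore(2)
    .greedy()
    .next("b")
    .where(new SimpleCondition<Event>() {
        @Override
        public boolean filter(Event value) {
            return "b".equals(value.getName());
        }
    });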
+ * + * @return The same pattern with a { @link Quantifier#looping(ConsumingStrategy)} quantifier + * applied. + * @throws MalformedPatternException if the quantifier is not applicable to this pattern. + */ + def timesOrMore(times: Int): Pattern[T, F] = { + jPattern.timesOrMore(times) + this + } + + /** + * Applicable only to [[Quantifier.looping()]] and [[Quantifier.times()]] patterns, * this option allows more flexibility to the matching events. * * If {{{allowCombinations()}}} is not applied for a @@ -452,6 +479,12 @@ class Pattern[T , F <: T](jPattern: JPattern[T, F]) { def next(pattern: Pattern[T, F]): GroupPattern[T, F] = GroupPattern[T, F](jPattern.next(pattern.wrappedPattern)) + /** + * Get after match skip strategy. + * @return current after match skip strategy + */ + def getAfterMatchSkipStrategy: AfterMatchSkipStrategy = + jPattern.getAfterMatchSkipStrategy } object Pattern { @@ -476,6 +509,18 @@ object Pattern { */ def begin[X](name: String): Pattern[X, X] = Pattern(JPattern.begin(name)) + /** + * Starts a new pattern sequence. The provided name is the one of the initial pattern + * of the new sequence. Furthermore, the base type of the event sequence is set. + * + * @param name The name of starting pattern of the new pattern sequence + * @param afterMatchSkipStrategy The skip strategy to use after each match + * @tparam X Base type of the event pattern + * @return The first pattern of a pattern sequence + */ + def begin[X](name: String, afterMatchSkipStrategy: AfterMatchSkipStrategy): Pattern[X, X] = + Pattern(JPattern.begin(name, afterMatchSkipStrategy)) + /** * Starts a new pattern sequence. The provided pattern is the initial pattern * of the new sequence. @@ -485,4 +530,17 @@ object Pattern { */ def begin[T, F <: T](pattern: Pattern[T, F]): GroupPattern[T, F] = GroupPattern[T, F](JPattern.begin(pattern.wrappedPattern)) + + /** + * Starts a new pattern sequence. The provided pattern is the initial pattern + * of the new sequence. 
+ * + * @param pattern the pattern to begin with + * @param afterMatchSkipStrategy The skip strategy to use after each match + * @return The first pattern of a pattern sequence + */ + def begin[T, F <: T](pattern: Pattern[T, F], + afterMatchSkipStrategy: AfterMatchSkipStrategy): GroupPattern[T, F] = + GroupPattern(JPattern.begin(pattern.wrappedPattern, afterMatchSkipStrategy)) + } diff --git a/flink-libraries/flink-cep-scala/src/test/scala/org/apache/flink/cep/scala/PatternStreamScalaJavaAPIInteroperabilityTest.scala b/flink-libraries/flink-cep-scala/src/test/scala/org/apache/flink/cep/scala/PatternStreamScalaJavaAPIInteroperabilityTest.scala index e2161a02249a9..f3371c86c8419 100644 --- a/flink-libraries/flink-cep-scala/src/test/scala/org/apache/flink/cep/scala/PatternStreamScalaJavaAPIInteroperabilityTest.scala +++ b/flink-libraries/flink-cep-scala/src/test/scala/org/apache/flink/cep/scala/PatternStreamScalaJavaAPIInteroperabilityTest.scala @@ -21,15 +21,17 @@ import org.apache.flink.api.common.functions.util.ListCollector import org.apache.flink.cep.scala.pattern.Pattern import org.apache.flink.streaming.api.operators.{StreamFlatMap, StreamMap} import org.apache.flink.streaming.api.scala._ -import org.apache.flink.streaming.api.transformations.OneInputTransformation +import org.apache.flink.streaming.api.transformations.{OneInputTransformation, TwoInputTransformation} import org.apache.flink.util.{Collector, TestLogger} import org.apache.flink.types.{Either => FEither} import org.apache.flink.api.java.tuple.{Tuple2 => FTuple2} - import java.lang.{Long => JLong} import java.util.{Map => JMap} import java.util.{List => JList} +import org.apache.flink.cep.operator.{FlatSelectCepOperator, FlatSelectTimeoutCepOperator, SelectCepOperator} +import org.apache.flink.streaming.api.functions.co.CoMapFunction + import scala.collection.JavaConverters._ import scala.collection.Map import org.junit.Assert._ @@ -51,8 +53,8 @@ class PatternStreamScalaJavaAPIInteroperabilityTest extends TestLogger { assertEquals(param, pattern) param.get("begin").get(0) }) - val out = extractUserFunction[StreamMap[JMap[String, JList[(Int, Int)]], (Int, Int)]](result) - .getUserFunction.map(param.mapValues(_.asJava).asJava) + val out = extractUserFunction[SelectCepOperator[(Int, Int), Byte, (Int, Int)]](result) + .getUserFunction.select(param.mapValues(_.asJava).asJava) //verifies output parameter forwarding assertEquals(param.get("begin").get(0), out) } @@ -77,8 +79,8 @@ class PatternStreamScalaJavaAPIInteroperabilityTest extends TestLogger { out.collect(pattern.get("begin").get.head) }) - extractUserFunction[StreamFlatMap[java.util.Map[String, JList[List[Int]]], List[Int]]](result). - getUserFunction.flatMap(inParam.mapValues(_.asJava).asJava, outParam) + extractUserFunction[FlatSelectCepOperator[List[Int], Byte, List[Int]]](result). 
+ getUserFunction.flatSelect(inParam.mapValues(_.asJava).asJava, outParam) //verify output parameter forwarding and that flatMap function was actually called assertEquals(inList, outList.get(0)) } @@ -96,28 +98,26 @@ class PatternStreamScalaJavaAPIInteroperabilityTest extends TestLogger { val expectedOutput = List(Right("match"), Right("barfoo"), Left("timeout"), Left("barfoo")) .asJava - val result: DataStream[Either[String, String]] = pStream.flatSelect { - (pattern: Map[String, Iterable[String]], timestamp: Long, out: Collector[String]) => - out.collect("timeout") - out.collect(pattern("begin").head) + val outputTag = OutputTag[Either[String, String]]("timeouted") + val result: DataStream[Either[String, String]] = pStream.flatSelect(outputTag) { + (pattern: Map[String, Iterable[String]], timestamp: Long, + out: Collector[Either[String, String]]) => + out.collect(Left("timeout")) + out.collect(Left(pattern("begin").head)) } { - (pattern: Map[String, Iterable[String]], out: Collector[String]) => + (pattern: Map[String, Iterable[String]], out: Collector[Either[String, String]]) => //verifies input parameter forwarding assertEquals(inParam, pattern) - out.collect("match") - out.collect(pattern("begin").head) + out.collect(Right("match")) + out.collect(Right(pattern("begin").head)) } - val fun = extractUserFunction[ - StreamFlatMap[ - FEither[ - FTuple2[JMap[String, JList[String]], JLong], - JMap[String, JList[String]]], - Either[String, String]]](result) + val fun = extractUserFunction[FlatSelectTimeoutCepOperator[String, Either[String, String], + Either[String, String], Byte]]( + result).getUserFunction - fun.getUserFunction.flatMap(FEither.Right(inParam.mapValues(_.asJava).asJava), output) - fun.getUserFunction.flatMap(FEither.Left(FTuple2.of(inParam.mapValues(_.asJava).asJava, 42L)), - output) + fun.getFlatSelectFunction.flatSelect(inParam.mapValues(_.asJava).asJava, output) + fun.getFlatTimeoutFunction.timeout(inParam.mapValues(_.asJava).asJava, 42L, output) assertEquals(expectedOutput, outList) } @@ -129,4 +129,5 @@ class PatternStreamScalaJavaAPIInteroperabilityTest extends TestLogger { .getOperator .asInstanceOf[T] } + } diff --git a/flink-libraries/flink-cep/pom.xml b/flink-libraries/flink-cep/pom.xml index 35045c03d9c19..23978b2f72012 100644 --- a/flink-libraries/flink-cep/pom.xml +++ b/flink-libraries/flink-cep/pom.xml @@ -52,11 +52,10 @@ under the License. 
provided - - com.google.guava - guava - ${guava.version} - + + org.apache.flink + flink-shaded-guava + diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/PatternStream.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/PatternStream.java index 555d270bd4f11..79ca736de412f 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/PatternStream.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/PatternStream.java @@ -18,21 +18,20 @@ package org.apache.flink.cep; -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.EitherTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.cep.operator.CEPOperatorUtils; import org.apache.flink.cep.pattern.Pattern; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.types.Either; -import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; -import java.util.List; -import java.util.Map; +import java.util.UUID; /** * Stream abstraction for CEP pattern detection. A pattern stream is a stream which emits detected @@ -108,6 +107,16 @@ public SingleOutputStreamOperator select(final PatternSelectFunction F clean(F f) { + return inputStream.getExecutionEnvironment().clean(f); + } + /** * Applies a select function to the detected pattern sequence. For each pattern sequence the * provided {@link PatternSelectFunction} is called. The pattern select function can produce @@ -121,13 +130,94 @@ public SingleOutputStreamOperator select(final PatternSelectFunction SingleOutputStreamOperator select(final PatternSelectFunction patternSelectFunction, TypeInformation outTypeInfo) { - SingleOutputStreamOperator>> patternStream = - CEPOperatorUtils.createPatternStream(inputStream, pattern, comparator); + return CEPOperatorUtils.createPatternStream(inputStream, pattern, comparator, clean(patternSelectFunction), outTypeInfo); + } - return patternStream.map( - new PatternSelectMapper<>( - patternStream.getExecutionEnvironment().clean(patternSelectFunction))) - .returns(outTypeInfo); + /** + * Applies a select function to the detected pattern sequence. For each pattern sequence the + * provided {@link PatternSelectFunction} is called. The pattern select function can produce + * exactly one resulting element. + * + *

Applies a timeout function to a partial pattern sequence which has timed out. For each + * partial pattern sequence the provided {@link PatternTimeoutFunction} is called. The pattern + * timeout function can produce exactly one resulting element. + * + *

You can get the stream of timed-out data resulting from the + * {@link SingleOutputStreamOperator#getSideOutput(OutputTag)} on the + * {@link SingleOutputStreamOperator} resulting from the select operation + * with the same {@link OutputTag}. + * + * @param timeoutOutputTag {@link OutputTag} that identifies side output with timed out patterns + * @param patternTimeoutFunction The pattern timeout function which is called for each partial + * pattern sequence which has timed out. + * @param patternSelectFunction The pattern select function which is called for each detected + * pattern sequence. + * @param Type of the resulting timeout elements + * @param Type of the resulting elements + * @return {@link DataStream} which contains the resulting elements with the resulting timeout + * elements in a side output. + */ + public SingleOutputStreamOperator select( + final OutputTag timeoutOutputTag, + final PatternTimeoutFunction patternTimeoutFunction, + final PatternSelectFunction patternSelectFunction) { + + TypeInformation rightTypeInfo = TypeExtractor.getUnaryOperatorReturnType( + patternSelectFunction, + PatternSelectFunction.class, + 0, + 1, + new int[]{0, 1, 0}, + new int[]{}, + inputStream.getType(), + null, + false); + + return select( + timeoutOutputTag, + patternTimeoutFunction, + rightTypeInfo, + patternSelectFunction); + } + + /** + * Applies a select function to the detected pattern sequence. For each pattern sequence the + * provided {@link PatternSelectFunction} is called. The pattern select function can produce + * exactly one resulting element. + * + *

Applies a timeout function to a partial pattern sequence which has timed out. For each + * partial pattern sequence the provided {@link PatternTimeoutFunction} is called. The pattern + * timeout function can produce exactly one resulting element. + * + *

You can get the stream of timed-out data resulting from the + * {@link SingleOutputStreamOperator#getSideOutput(OutputTag)} on the + * {@link SingleOutputStreamOperator} resulting from the select operation + * with the same {@link OutputTag}. + * + * @param timeoutOutputTag {@link OutputTag} that identifies side output with timed out patterns + * @param patternTimeoutFunction The pattern timeout function which is called for each partial + * pattern sequence which has timed out. + * @param outTypeInfo Explicit specification of output type. + * @param patternSelectFunction The pattern select function which is called for each detected + * pattern sequence. + * @param Type of the resulting timeout elements + * @param Type of the resulting elements + * @return {@link DataStream} which contains the resulting elements with the resulting timeout + * elements in a side output. + */ + public SingleOutputStreamOperator select( + final OutputTag timeoutOutputTag, + final PatternTimeoutFunction patternTimeoutFunction, + final TypeInformation outTypeInfo, + final PatternSelectFunction patternSelectFunction) { + return CEPOperatorUtils.createTimeoutPatternStream( + inputStream, + pattern, + comparator, + clean(patternSelectFunction), + outTypeInfo, + timeoutOutputTag, + clean(patternTimeoutFunction)); } /** @@ -145,19 +235,21 @@ public SingleOutputStreamOperator select(final PatternSelectFunction Type of the resulting timeout elements * @param Type of the resulting elements + * + * @deprecated Use {@link PatternStream#select(OutputTag, PatternTimeoutFunction, PatternSelectFunction)} + * that returns timed out events as a side-output + * * @return {@link DataStream} which contains the resulting elements or the resulting timeout * elements wrapped in an {@link Either} type. 
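A sketch of the side-output usage that the deprecated Either-returning variant below now falls back to internally. PatternStream<Event> patternStream plus the Event, TimedOut and Match types are assumed placeholders; OutputTag, PatternTimeoutFunction, PatternSelectFunction and getSideOutput are the APIs added or used by this patch.

import org.apache.flink.cep.PatternSelectFunction;
import org.apache.flink.cep.PatternTimeoutFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.util.OutputTag;

import java.util.List;
import java.util.Map;

// Side output tag identifying the timed-out partial matches.
final OutputTag<TimedOut> timedOutTag = new OutputTag<TimedOut>("timed-out"){};

SingleOutputStreamOperator<Match> matches = patternStream.select(
    timedOutTag,
    new PatternTimeoutFunction<Event, TimedOut>() {
        @Override
        public TimedOut timeout(Map<String, List<Event>> pattern, long timeoutTimestamp) {
            return new TimedOut(pattern, timeoutTimestamp);  // placeholder constructor
        }
    },
    new PatternSelectFunction<Event, Match>() {
        @Override
        public Match select(Map<String, List<Event>> pattern) {
            return new Match(pattern);  // placeholder constructor
        }
    });

// Timed-out partial matches are retrieved from the side output of the main stream.
DataStream<TimedOut> timedOut = matches.getSideOutput(timedOutTag);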
*/ + @Deprecated public SingleOutputStreamOperator> select( final PatternTimeoutFunction patternTimeoutFunction, final PatternSelectFunction patternSelectFunction) { - SingleOutputStreamOperator>, Long>, Map>>> patternStream = - CEPOperatorUtils.createTimeoutPatternStream(inputStream, pattern, comparator); - - TypeInformation leftTypeInfo = TypeExtractor.getUnaryOperatorReturnType( - patternTimeoutFunction, - PatternTimeoutFunction.class, + TypeInformation rightTypeInfo = TypeExtractor.getUnaryOperatorReturnType( + patternSelectFunction, + PatternSelectFunction.class, 0, 1, new int[]{0, 1, 0}, @@ -166,9 +258,9 @@ public SingleOutputStreamOperator> select( null, false); - TypeInformation rightTypeInfo = TypeExtractor.getUnaryOperatorReturnType( - patternSelectFunction, - PatternSelectFunction.class, + TypeInformation leftTypeInfo = TypeExtractor.getUnaryOperatorReturnType( + patternTimeoutFunction, + PatternTimeoutFunction.class, 0, 1, new int[]{0, 1, 0}, @@ -177,14 +269,22 @@ public SingleOutputStreamOperator> select( null, false); + final OutputTag outputTag = new OutputTag(UUID.randomUUID().toString(), leftTypeInfo); + + final SingleOutputStreamOperator mainStream = CEPOperatorUtils.createTimeoutPatternStream( + inputStream, + pattern, + comparator, + clean(patternSelectFunction), + rightTypeInfo, + outputTag, + clean(patternTimeoutFunction)); + + final DataStream timedOutStream = mainStream.getSideOutput(outputTag); + TypeInformation> outTypeInfo = new EitherTypeInfo<>(leftTypeInfo, rightTypeInfo); - return patternStream.map( - new PatternSelectTimeoutMapper<>( - patternStream.getExecutionEnvironment().clean(patternSelectFunction), - patternStream.getExecutionEnvironment().clean(patternTimeoutFunction) - ) - ).returns(outTypeInfo); + return mainStream.connect(timedOutStream).map(new CoMapTimeout<>()).returns(outTypeInfo); } /** @@ -227,14 +327,99 @@ public SingleOutputStreamOperator flatSelect(final PatternFlatSelectFunct * @return {@link DataStream} which contains the resulting elements from the pattern flat select * function. */ - public SingleOutputStreamOperator flatSelect(final PatternFlatSelectFunction patternFlatSelectFunction, TypeInformation outTypeInfo) { - SingleOutputStreamOperator>> patternStream = - CEPOperatorUtils.createPatternStream(inputStream, pattern, comparator); - - return patternStream.flatMap( - new PatternFlatSelectMapper<>( - patternStream.getExecutionEnvironment().clean(patternFlatSelectFunction) - )).returns(outTypeInfo); + public SingleOutputStreamOperator flatSelect( + final PatternFlatSelectFunction patternFlatSelectFunction, + final TypeInformation outTypeInfo) { + return CEPOperatorUtils.createPatternStream( + inputStream, + pattern, + comparator, + clean(patternFlatSelectFunction), + outTypeInfo); + } + + /** + * Applies a flat select function to the detected pattern sequence. For each pattern sequence the + * provided {@link PatternFlatSelectFunction} is called. The pattern select function can produce + * exactly one resulting element. + * + *

Applies a timeout function to a partial pattern sequence which has timed out. For each + * partial pattern sequence the provided {@link PatternFlatTimeoutFunction} is called. The pattern + * timeout function can produce exactly one resulting element. + * + *

You can get the stream of timed-out data resulting from the + * {@link SingleOutputStreamOperator#getSideOutput(OutputTag)} on the + * {@link SingleOutputStreamOperator} resulting from the select operation + * with the same {@link OutputTag}. + * + * @param timeoutOutputTag {@link OutputTag} that identifies side output with timed out patterns + * @param patternFlatTimeoutFunction The pattern timeout function which is called for each partial + * pattern sequence which has timed out. + * @param patternFlatSelectFunction The pattern select function which is called for each detected + * pattern sequence. + * @param Type of the resulting timeout elements + * @param Type of the resulting elements + * @return {@link DataStream} which contains the resulting elements with the resulting timeout + * elements in a side output. + */ + public SingleOutputStreamOperator flatSelect( + final OutputTag timeoutOutputTag, + final PatternFlatTimeoutFunction patternFlatTimeoutFunction, + final PatternFlatSelectFunction patternFlatSelectFunction) { + + TypeInformation rightTypeInfo = TypeExtractor.getUnaryOperatorReturnType( + patternFlatSelectFunction, + PatternFlatSelectFunction.class, + 0, + 1, + new int[]{0, 1, 0}, + new int[]{1, 0}, + inputStream.getType(), + null, + false); + + return flatSelect(timeoutOutputTag, patternFlatTimeoutFunction, rightTypeInfo, patternFlatSelectFunction); + } + + /** + * Applies a flat select function to the detected pattern sequence. For each pattern sequence the + * provided {@link PatternFlatSelectFunction} is called. The pattern select function can produce + * exactly one resulting element. + * + *

Applies a timeout function to a partial pattern sequence which has timed out. For each + * partial pattern sequence the provided {@link PatternFlatTimeoutFunction} is called. The pattern + * timeout function can produce exactly one resulting element. + * + *

You can get the stream of timed-out data resulting from the + * {@link SingleOutputStreamOperator#getSideOutput(OutputTag)} on the + * {@link SingleOutputStreamOperator} resulting from the select operation + * with the same {@link OutputTag}. + * + * @param timeoutOutputTag {@link OutputTag} that identifies side output with timed out patterns + * @param patternFlatTimeoutFunction The pattern timeout function which is called for each partial + * pattern sequence which has timed out. + * @param patternFlatSelectFunction The pattern select function which is called for each detected + * pattern sequence. + * @param outTypeInfo Explicit specification of output type. + * @param Type of the resulting timeout elements + * @param Type of the resulting elements + * @return {@link DataStream} which contains the resulting elements with the resulting timeout + * elements in a side output. + */ + public SingleOutputStreamOperator flatSelect( + final OutputTag timeoutOutputTag, + final PatternFlatTimeoutFunction patternFlatTimeoutFunction, + final TypeInformation outTypeInfo, + final PatternFlatSelectFunction patternFlatSelectFunction) { + + return CEPOperatorUtils.createTimeoutPatternStream( + inputStream, + pattern, + comparator, + clean(patternFlatSelectFunction), + outTypeInfo, + timeoutOutputTag, + clean(patternFlatTimeoutFunction)); } /** @@ -252,17 +437,19 @@ public SingleOutputStreamOperator flatSelect(final PatternFlatSelectFunct * detected pattern sequence. * @param Type of the resulting timeout events * @param Type of the resulting events + * + * @deprecated Use {@link PatternStream#flatSelect(OutputTag, PatternFlatTimeoutFunction, PatternFlatSelectFunction)} + * that returns timed out events as a side-output + * * @return {@link DataStream} which contains the resulting events from the pattern flat select * function or the resulting timeout events from the pattern flat timeout function wrapped in an * {@link Either} type. */ + @Deprecated public SingleOutputStreamOperator> flatSelect( final PatternFlatTimeoutFunction patternFlatTimeoutFunction, final PatternFlatSelectFunction patternFlatSelectFunction) { - SingleOutputStreamOperator>, Long>, Map>>> patternStream = - CEPOperatorUtils.createTimeoutPatternStream(inputStream, pattern, comparator); - TypeInformation leftTypeInfo = TypeExtractor.getUnaryOperatorReturnType( patternFlatTimeoutFunction, PatternFlatTimeoutFunction.class, @@ -285,147 +472,40 @@ public SingleOutputStreamOperator> flatSelect( null, false); - TypeInformation> outTypeInfo = new EitherTypeInfo<>(leftTypeInfo, rightTypeInfo); + final OutputTag outputTag = new OutputTag(UUID.randomUUID().toString(), leftTypeInfo); - return patternStream.flatMap( - new PatternFlatSelectTimeoutWrapper<>( - patternStream.getExecutionEnvironment().clean(patternFlatSelectFunction), - patternStream.getExecutionEnvironment().clean(patternFlatTimeoutFunction) - ) - ).returns(outTypeInfo); - } + final SingleOutputStreamOperator mainStream = CEPOperatorUtils.createTimeoutPatternStream( + inputStream, + pattern, + comparator, + clean(patternFlatSelectFunction), + rightTypeInfo, + outputTag, + clean(patternFlatTimeoutFunction)); - /** - * Wrapper for a {@link PatternSelectFunction}. 
- * - * @param Type of the input elements - * @param Type of the resulting elements - */ - private static class PatternSelectMapper implements MapFunction>, R> { - private static final long serialVersionUID = 2273300432692943064L; - - private final PatternSelectFunction patternSelectFunction; - - public PatternSelectMapper(PatternSelectFunction patternSelectFunction) { - this.patternSelectFunction = patternSelectFunction; - } - - @Override - public R map(Map> value) throws Exception { - return patternSelectFunction.select(value); - } - } - - private static class PatternSelectTimeoutMapper implements MapFunction>, Long>, Map>>, Either> { - - private static final long serialVersionUID = 8259477556738887724L; - - private final PatternSelectFunction patternSelectFunction; - private final PatternTimeoutFunction patternTimeoutFunction; - - public PatternSelectTimeoutMapper( - PatternSelectFunction patternSelectFunction, - PatternTimeoutFunction patternTimeoutFunction) { - - this.patternSelectFunction = patternSelectFunction; - this.patternTimeoutFunction = patternTimeoutFunction; - } - - @Override - public Either map(Either>, Long>, Map>> value) throws Exception { - if (value.isLeft()) { - Tuple2>, Long> timeout = value.left(); - - return Either.Left(patternTimeoutFunction.timeout(timeout.f0, timeout.f1)); - } else { - return Either.Right(patternSelectFunction.select(value.right())); - } - } - } - - private static class PatternFlatSelectTimeoutWrapper implements FlatMapFunction>, Long>, Map>>, Either> { - - private static final long serialVersionUID = 7483674669662261667L; + final DataStream timedOutStream = mainStream.getSideOutput(outputTag); - private final PatternFlatSelectFunction patternFlatSelectFunction; - private final PatternFlatTimeoutFunction patternFlatTimeoutFunction; - - public PatternFlatSelectTimeoutWrapper( - PatternFlatSelectFunction patternFlatSelectFunction, - PatternFlatTimeoutFunction patternFlatTimeoutFunction) { - this.patternFlatSelectFunction = patternFlatSelectFunction; - this.patternFlatTimeoutFunction = patternFlatTimeoutFunction; - } - - @Override - public void flatMap(Either>, Long>, Map>> value, Collector> out) throws Exception { - if (value.isLeft()) { - Tuple2>, Long> timeout = value.left(); - - patternFlatTimeoutFunction.timeout(timeout.f0, timeout.f1, new LeftCollector<>(out)); - } else { - patternFlatSelectFunction.flatSelect(value.right(), new RightCollector(out)); - } - } - - private static class LeftCollector implements Collector { - - private final Collector> out; - - private LeftCollector(Collector> out) { - this.out = out; - } - - @Override - public void collect(L record) { - out.collect(Either.Left(record)); - } - - @Override - public void close() { - out.close(); - } - } - - private static class RightCollector implements Collector { - - private final Collector> out; - - private RightCollector(Collector> out) { - this.out = out; - } - - @Override - public void collect(R record) { - out.collect(Either.Right(record)); - } + TypeInformation> outTypeInfo = new EitherTypeInfo<>(leftTypeInfo, rightTypeInfo); - @Override - public void close() { - out.close(); - } - } + return mainStream.connect(timedOutStream).map(new CoMapTimeout<>()).returns(outTypeInfo); } /** - * Wrapper for a {@link PatternFlatSelectFunction}. - * - * @param Type of the input elements - * @param Type of the resulting elements + * Used for joining results from timeout side-output for API backward compatibility. 
*/ - private static class PatternFlatSelectMapper implements FlatMapFunction>, R> { + @Internal + public static class CoMapTimeout implements CoMapFunction> { - private static final long serialVersionUID = -8610796233077989108L; + private static final long serialVersionUID = 2059391566945212552L; - private final PatternFlatSelectFunction patternFlatSelectFunction; - - public PatternFlatSelectMapper(PatternFlatSelectFunction patternFlatSelectFunction) { - this.patternFlatSelectFunction = patternFlatSelectFunction; + @Override + public Either map1(R value) throws Exception { + return Either.Right(value); } @Override - public void flatMap(Map> value, Collector out) throws Exception { - patternFlatSelectFunction.flatSelect(value, out); + public Either map2(L value) throws Exception { + return Either.Left(value); } } } diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/AfterMatchSkipStrategy.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/AfterMatchSkipStrategy.java new file mode 100644 index 0000000000000..dcda441a45dc9 --- /dev/null +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/AfterMatchSkipStrategy.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cep.nfa; + +import java.io.Serializable; + + +/** + * Indicate the skip strategy after a match process. + * + *

For more info on possible skip strategies see {@link SkipStrategy}. + */ +public class AfterMatchSkipStrategy implements Serializable { + + private static final long serialVersionUID = -4048930333619068531L; + // default strategy + private SkipStrategy strategy = SkipStrategy.NO_SKIP; + + // pattern name to skip to + private String patternName = null; + + /** + * Discards every partial match that contains event of the match preceding the first of *PatternName*. + * @param patternName the pattern name to skip to + * @return the created AfterMatchSkipStrategy + */ + public static AfterMatchSkipStrategy skipToFirst(String patternName) { + return new AfterMatchSkipStrategy(SkipStrategy.SKIP_TO_FIRST, patternName); + } + + /** + * Discards every partial match that contains event of the match preceding the last of *PatternName*. + * @param patternName the pattern name to skip to + * @return the created AfterMatchSkipStrategy + */ + public static AfterMatchSkipStrategy skipToLast(String patternName) { + return new AfterMatchSkipStrategy(SkipStrategy.SKIP_TO_LAST, patternName); + } + + /** + * Discards every partial match that contains event of the match. + * @return the created AfterMatchSkipStrategy + */ + public static AfterMatchSkipStrategy skipPastLastEvent() { + return new AfterMatchSkipStrategy(SkipStrategy.SKIP_PAST_LAST_EVENT); + } + + /** + * Every possible match will be emitted. + * @return the created AfterMatchSkipStrategy + */ + public static AfterMatchSkipStrategy noSkip() { + return new AfterMatchSkipStrategy(SkipStrategy.NO_SKIP); + } + + private AfterMatchSkipStrategy(SkipStrategy strategy) { + this(strategy, null); + } + + private AfterMatchSkipStrategy(SkipStrategy strategy, String patternName) { + if (patternName == null && (strategy == SkipStrategy.SKIP_TO_FIRST || strategy == SkipStrategy.SKIP_TO_LAST)) { + throw new IllegalArgumentException("The patternName field can not be empty when SkipStrategy is " + strategy); + } + this.strategy = strategy; + this.patternName = patternName; + } + + /** + * Get the {@link SkipStrategy} enum. + * @return the skip strategy + */ + public SkipStrategy getStrategy() { + return strategy; + } + + /** + * Get the referenced pattern name of this strategy. + * @return the referenced pattern name. + */ + public String getPatternName() { + return patternName; + } + + @Override + public String toString() { + switch (strategy) { + case NO_SKIP: + case SKIP_PAST_LAST_EVENT: + return "AfterMatchStrategy{" + + strategy + + "}"; + case SKIP_TO_FIRST: + case SKIP_TO_LAST: + return "AfterMatchStrategy{" + + strategy + "[" + + patternName + "]" + + "}"; + } + return super.toString(); + } + + /** + * Skip Strategy Enum. + */ + public enum SkipStrategy{ + /** + * Every possible match will be emitted. + */ + NO_SKIP, + /** + * Discards every partial match that contains event of the match. + */ + SKIP_PAST_LAST_EVENT, + /** + * Discards every partial match that contains event of the match preceding the first of *PatternName*. + */ + SKIP_TO_FIRST, + /** + * Discards every partial match that contains event of the match preceding the last of *PatternName*. 
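A short sketch of attaching the new skip strategy to a pattern, assuming the begin(name, afterMatchSkipStrategy) overload referenced elsewhere in this patch and a placeholder Event type with a placeholder predicate.

import org.apache.flink.cep.nfa.AfterMatchSkipStrategy;
import org.apache.flink.cep.pattern.Pattern;
import org.apache.flink.cep.pattern.conditions.SimpleCondition;

// After every match, discard partial matches that reuse any event of that match.
AfterMatchSkipStrategy skipStrategy = AfterMatchSkipStrategy.skipPastLastEvent();

Pattern<Event, ?> pattern = Pattern.<Event>begin("start", skipStrategy)
    .where(new SimpleCondition<Event>() {
        @Override
        public boolean filter(Event value) {
            return value.getId() > 0;  // placeholder predicate on a placeholder type
        }
    })
    .oneOrMore();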
+ */ + SKIP_TO_LAST + } +} diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java index 2f6f02eb9bfb7..ff4967fc55bec 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java @@ -34,6 +34,7 @@ import org.apache.flink.cep.NonDuplicatingTypeSerializer; import org.apache.flink.cep.nfa.compiler.NFACompiler; import org.apache.flink.cep.nfa.compiler.NFAStateNameHandler; +import org.apache.flink.cep.operator.AbstractKeyedCEPPatternOperator; import org.apache.flink.cep.pattern.conditions.IterativeCondition; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataInputViewStreamWrapper; @@ -42,11 +43,6 @@ import org.apache.flink.streaming.api.windowing.time.Time; import org.apache.flink.util.Preconditions; -import com.google.common.base.Predicate; -import com.google.common.collect.Iterators; - -import javax.annotation.Nullable; - import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -54,12 +50,12 @@ import java.io.ObjectOutputStream; import java.io.OptionalDataException; import java.io.Serializable; -import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -71,7 +67,7 @@ /** * Non-deterministic finite automaton implementation. * - *

The {@link org.apache.flink.cep.operator.AbstractKeyedCEPPatternOperator CEP operator} + *

The {@link AbstractKeyedCEPPatternOperator CEP operator} * keeps one NFA per key, for keyed input streams, and a single global NFA for non-keyed ones. * When an event gets processed, it updates the NFA's internal state machine. * @@ -232,6 +228,26 @@ public void resetNFAChanged() { * activated) */ public Tuple2>>, Collection>, Long>>> process(final T event, final long timestamp) { + return process(event, timestamp, AfterMatchSkipStrategy.noSkip()); + } + + /** + * Processes the next input event. If some of the computations reach a final state then the + * resulting event sequences are returned. If computations time out and timeout handling is + * activated, then the timed out event patterns are returned. + * + *

If computations reach a stop state, the path forward is discarded and currently constructed path is returned + * with the element that resulted in the stop state. + * + * @param event The current event to be processed or null if only pruning shall be done + * @param timestamp The timestamp of the current event + * @param afterMatchSkipStrategy The skip strategy to use after per match + * @return Tuple of the collection of matched patterns (e.g. the result of computations which have + * reached a final state) and the collection of timed out patterns (if timeout handling is + * activated) + */ + public Tuple2>>, Collection>, Long>>> process(final T event, + final long timestamp, AfterMatchSkipStrategy afterMatchSkipStrategy) { final int numberComputationStates = computationStates.size(); final Collection>> result = new ArrayList<>(); final Collection>, Long>> timeoutResult = new ArrayList<>(); @@ -248,8 +264,8 @@ public Tuple2>>, Collection> timedoutPattern = extractCurrentMatches(computationState); - timeoutResult.add(Tuple2.of(timedoutPattern, timestamp)); + Map> timedOutPattern = extractCurrentMatches(computationState); + timeoutResult.add(Tuple2.of(timedOutPattern, timestamp)); } eventSharedBuffer.release( @@ -322,6 +338,8 @@ public Tuple2>>, Collection 0L) { long pruningTimestamp = timestamp - windowTime; @@ -340,6 +358,66 @@ public Tuple2>>, Collection> computationStates, + Collection>> matchedResult, AfterMatchSkipStrategy afterMatchSkipStrategy) { + Set discardEvents = new HashSet<>(); + switch(afterMatchSkipStrategy.getStrategy()) { + case SKIP_TO_LAST: + for (Map> resultMap: matchedResult) { + for (Map.Entry> keyMatches : resultMap.entrySet()) { + if (keyMatches.getKey().equals(afterMatchSkipStrategy.getPatternName())) { + discardEvents.addAll(keyMatches.getValue().subList(0, keyMatches.getValue().size() - 1)); + break; + } else { + discardEvents.addAll(keyMatches.getValue()); + } + } + } + break; + case SKIP_TO_FIRST: + for (Map> resultMap: matchedResult) { + for (Map.Entry> keyMatches : resultMap.entrySet()) { + if (keyMatches.getKey().equals(afterMatchSkipStrategy.getPatternName())) { + break; + } else { + discardEvents.addAll(keyMatches.getValue()); + } + } + } + break; + case SKIP_PAST_LAST_EVENT: + for (Map> resultMap: matchedResult) { + for (List eventList: resultMap.values()) { + discardEvents.addAll(eventList); + } + } + break; + } + if (!discardEvents.isEmpty()) { + List> discardStates = new ArrayList<>(); + for (ComputationState computationState : computationStates) { + Map> partialMatch = extractCurrentMatches(computationState); + for (List list: partialMatch.values()) { + for (T e: list) { + if (discardEvents.contains(e)) { + // discard the computation state. + eventSharedBuffer.release( + NFAStateNameHandler.getOriginalNameFromInternal( + computationState.getState().getName()), + computationState.getEvent(), + computationState.getTimestamp(), + computationState.getCounter() + ); + discardStates.add(computationState); + break; + } + } + } + } + computationStates.removeAll(discardStates); + } + } + @Override public boolean equals(Object obj) { if (obj instanceof NFA) { @@ -696,7 +774,7 @@ Map> extractCurrentMatches(final ComputationState computation // for a given computation state, we cannot have more than one matching patterns. 
Preconditions.checkState(paths.size() == 1); - Map> result = new HashMap<>(); + Map> result = new LinkedHashMap<>(); Map> path = paths.get(0); for (String key: path.keySet()) { List events = path.get(key); @@ -715,9 +793,7 @@ Map> extractCurrentMatches(final ComputationState computation return result; } - ////////////////////// Fault-Tolerance / Migration ////////////////////// - - private static final String BEGINNING_STATE_NAME = "$beginningState$"; + ////////////////////// Fault-Tolerance ////////////////////// private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { ois.defaultReadObject(); @@ -728,103 +804,15 @@ private void readObject(ObjectInputStream ois) throws IOException, ClassNotFound final List> readComputationStates = new ArrayList<>(numberComputationStates); - boolean afterMigration = false; for (int i = 0; i < numberComputationStates; i++) { ComputationState computationState = readComputationState(ois); - if (computationState.getState().getName().equals(BEGINNING_STATE_NAME)) { - afterMigration = true; - } - readComputationStates.add(computationState); } - if (afterMigration && !readComputationStates.isEmpty()) { - try { - //Backwards compatibility - this.computationStates.addAll(migrateNFA(readComputationStates)); - final Field newSharedBufferField = NFA.class.getDeclaredField("eventSharedBuffer"); - final Field sharedBufferField = NFA.class.getDeclaredField("sharedBuffer"); - sharedBufferField.setAccessible(true); - newSharedBufferField.setAccessible(true); - newSharedBufferField.set(this, SharedBuffer.migrateSharedBuffer(this.sharedBuffer)); - sharedBufferField.set(this, null); - sharedBufferField.setAccessible(false); - newSharedBufferField.setAccessible(false); - } catch (Exception e) { - throw new IllegalStateException("Could not migrate from earlier version", e); - } - } else { - this.computationStates.addAll(readComputationStates); - } - + this.computationStates.addAll(readComputationStates); nonDuplicatingTypeSerializer.clearReferences(); } - /** - * Needed for backward compatibility. First migrates the {@link State} graph see {@link NFACompiler#migrateGraph(State)}. - * Than recreates the {@link ComputationState}s with the new {@link State} graph. 
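The switch in discardComputationStatesAccordingToStrategy above decides which events of a completed match invalidate still-running partial matches; the HashMap to LinkedHashMap change in extractCurrentMatches matters here because the strategies rely on iterating the matched pattern names in pattern order. The following standalone sketch is not part of the patch; the local enum and the string events are stand-ins for AfterMatchSkipStrategy.SkipStrategy and the real matched events, but the three discard rules mirror the switch above on a concrete match:

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Standalone illustration of the discard rules applied per matched sequence.
public class SkipStrategyDemo {

    enum Strategy { SKIP_PAST_LAST_EVENT, SKIP_TO_FIRST, SKIP_TO_LAST }

    // Events of one completed match that should invalidate ongoing partial matches.
    static Set<String> eventsToDiscard(Map<String, List<String>> match, Strategy strategy, String patternName) {
        Set<String> discard = new LinkedHashSet<>();
        switch (strategy) {
            case SKIP_PAST_LAST_EVENT:
                // every event of the match is discarded
                match.values().forEach(discard::addAll);
                break;
            case SKIP_TO_FIRST:
                // discard everything matched before the first event of the named pattern
                for (Map.Entry<String, List<String>> entry : match.entrySet()) {
                    if (entry.getKey().equals(patternName)) {
                        break;
                    }
                    discard.addAll(entry.getValue());
                }
                break;
            case SKIP_TO_LAST:
                // discard everything before the last event of the named pattern
                for (Map.Entry<String, List<String>> entry : match.entrySet()) {
                    if (entry.getKey().equals(patternName)) {
                        discard.addAll(entry.getValue().subList(0, entry.getValue().size() - 1));
                        break;
                    }
                    discard.addAll(entry.getValue());
                }
                break;
        }
        return discard;
    }

    public static void main(String[] args) {
        // insertion order mirrors pattern order, which is why the patch switches to LinkedHashMap
        Map<String, List<String>> match = new LinkedHashMap<>();
        match.put("start", Arrays.asList("a1"));
        match.put("middle", Arrays.asList("b1", "b2"));
        match.put("end", Arrays.asList("c1"));

        System.out.println(eventsToDiscard(match, Strategy.SKIP_PAST_LAST_EVENT, null)); // [a1, b1, b2, c1]
        System.out.println(eventsToDiscard(match, Strategy.SKIP_TO_FIRST, "middle"));    // [a1]
        System.out.println(eventsToDiscard(match, Strategy.SKIP_TO_LAST, "middle"));     // [a1, b1]
    }
}

Any computation state whose partial match already contains one of these events is then released and removed, which is exactly what the loop over computationStates in the new method does.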
- * @param readStates computation states read from snapshot - * @return collection of migrated computation states - */ - private Collection> migrateNFA(Collection> readStates) { - final ArrayList> computationStates = new ArrayList<>(); - - final State startState = Iterators.find( - readStates.iterator(), - new Predicate>() { - @Override - public boolean apply(@Nullable ComputationState input) { - return input != null && input.getState().getName().equals(BEGINNING_STATE_NAME); - } - }).getState(); - - final Map> convertedStates = NFACompiler.migrateGraph(startState); - - for (ComputationState readState : readStates) { - if (!readState.isStartState()) { - final String previousName = readState.getState().getName(); - final String currentName = Iterators.find( - readState.getState().getStateTransitions().iterator(), - new Predicate>() { - @Override - public boolean apply(@Nullable StateTransition input) { - return input != null && input.getAction() == StateTransitionAction.TAKE; - } - }).getTargetState().getName(); - - final State previousState = convertedStates.get(previousName); - - computationStates.add(ComputationState.createState( - this, - convertedStates.get(currentName), - previousState, - readState.getEvent(), - 0, - readState.getTimestamp(), - readState.getVersion(), - readState.getStartTimestamp() - )); - } - } - - final String startName = Iterators.find(convertedStates.values().iterator(), new Predicate>() { - @Override - public boolean apply(@Nullable State input) { - return input != null && input.isStart(); - } - }).getName(); - - computationStates.add(ComputationState.createStartState( - this, - convertedStates.get(startName), - new DeweyNumber(this.startEventCounter))); - - this.states.clear(); - this.states.addAll(convertedStates.values()); - - return computationStates; - } - @SuppressWarnings("unchecked") private ComputationState readComputationState(ObjectInputStream ois) throws IOException, ClassNotFoundException { final State state = (State) ois.readObject(); @@ -1130,11 +1118,11 @@ public CompatibilityResult> ensureCompatibility(TypeSerializerConfigSnaps return CompatibilityResult.compatible(); } else { if (eventCompatResult.getConvertDeserializer() != null && - sharedBufCompatResult.getConvertDeserializer() != null) { + sharedBufCompatResult.getConvertDeserializer() != null) { return CompatibilityResult.requiresMigration( - new NFASerializer<>( - new TypeDeserializerAdapter<>(eventCompatResult.getConvertDeserializer()), - new TypeDeserializerAdapter<>(sharedBufCompatResult.getConvertDeserializer()))); + new NFASerializer<>( + new TypeDeserializerAdapter<>(eventCompatResult.getConvertDeserializer()), + new TypeDeserializerAdapter<>(sharedBufCompatResult.getConvertDeserializer()))); } } } diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/SharedBuffer.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/SharedBuffer.java index c6f69b9f10995..6bc5091eb8bd6 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/SharedBuffer.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/SharedBuffer.java @@ -18,7 +18,6 @@ package org.apache.flink.cep.nfa; -import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.CompatibilityResult; import org.apache.flink.api.common.typeutils.CompatibilityUtil; import org.apache.flink.api.common.typeutils.CompositeTypeSerializerConfigSnapshot; @@ -47,6 +46,7 @@ import java.util.HashMap; import java.util.HashSet; import 
java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -248,7 +248,7 @@ public List>> extractPatterns( // termination criterion if (currentEntry == null) { - final Map> completePath = new HashMap<>(); + final Map> completePath = new LinkedHashMap<>(); while (!currentPath.isEmpty()) { final SharedBufferEntry currentPathEntry = currentPath.pop(); @@ -335,47 +335,6 @@ private SharedBuffer( this.pages = pages; } - /** - * For backward compatibility only. Previously the key in {@link SharedBuffer} was {@link State}. - * Now it is {@link String}. - */ - @Internal - static SharedBuffer migrateSharedBuffer(SharedBuffer, T> buffer) { - - final Map> pageMap = new HashMap<>(); - final Map, T>, SharedBufferEntry> entries = new HashMap<>(); - - for (Map.Entry, SharedBufferPage, T>> page : buffer.pages.entrySet()) { - final SharedBufferPage newPage = new SharedBufferPage<>(page.getKey().getName()); - pageMap.put(newPage.getKey(), newPage); - - for (Map.Entry, SharedBufferEntry, T>> pageEntry : page.getValue().entries.entrySet()) { - final SharedBufferEntry newSharedBufferEntry = new SharedBufferEntry<>( - pageEntry.getKey(), - newPage); - newSharedBufferEntry.referenceCounter = pageEntry.getValue().referenceCounter; - entries.put(pageEntry.getValue(), newSharedBufferEntry); - newPage.entries.put(pageEntry.getKey(), newSharedBufferEntry); - } - } - - for (Map.Entry, SharedBufferPage, T>> page : buffer.pages.entrySet()) { - for (Map.Entry, SharedBufferEntry, T>> pageEntry : page.getValue().entries.entrySet()) { - final SharedBufferEntry newEntry = entries.get(pageEntry.getValue()); - for (SharedBufferEdge, T> edge : pageEntry.getValue().edges) { - final SharedBufferEntry targetNewEntry = entries.get(edge.getTarget()); - - final SharedBufferEdge newEdge = new SharedBufferEdge<>( - targetNewEntry, - edge.getVersion()); - newEntry.edges.add(newEdge); - } - } - } - - return new SharedBuffer<>(buffer.valueSerializer, pageMap); - } - private SharedBufferEntry get( final K key, final V value, @@ -1177,76 +1136,4 @@ public CompatibilityResult> ensureCompatibility(TypeSerialize return CompatibilityResult.requiresMigration(); } } - - ////////////////// Java Serialization methods for backwards compatibility ////////////////// - - private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { - DataInputViewStreamWrapper source = new DataInputViewStreamWrapper(ois); - ArrayList> entryList = new ArrayList<>(); - ois.defaultReadObject(); - - this.pages = new HashMap<>(); - - int numberPages = ois.readInt(); - - for (int i = 0; i < numberPages; i++) { - // key of the page - @SuppressWarnings("unchecked") - K key = (K) ois.readObject(); - - SharedBufferPage page = new SharedBufferPage<>(key); - - pages.put(key, page); - - int numberEntries = ois.readInt(); - - for (int j = 0; j < numberEntries; j++) { - // restore the SharedBufferEntries for the given page - V value = valueSerializer.deserialize(source); - long timestamp = ois.readLong(); - - ValueTimeWrapper valueTimeWrapper = new ValueTimeWrapper<>(value, timestamp, 0); - SharedBufferEntry sharedBufferEntry = new SharedBufferEntry(valueTimeWrapper, page); - - sharedBufferEntry.referenceCounter = ois.readInt(); - - page.entries.put(valueTimeWrapper, sharedBufferEntry); - - entryList.add(sharedBufferEntry); - } - } - - // read the edges of the shared buffer entries - int numberEdges = ois.readInt(); - - for (int j = 0; j < numberEdges; j++) { - int sourceIndex = 
ois.readInt(); - int targetIndex = ois.readInt(); - - if (sourceIndex >= entryList.size() || sourceIndex < 0) { - throw new RuntimeException("Could not find source entry with index " + sourceIndex + - ". This indicates a corrupted state."); - } else { - // We've already deserialized the shared buffer entry. Simply read its ID and - // retrieve the buffer entry from the list of entries - SharedBufferEntry sourceEntry = entryList.get(sourceIndex); - - final DeweyNumber version = (DeweyNumber) ois.readObject(); - final SharedBufferEntry target; - - if (targetIndex >= 0) { - if (targetIndex >= entryList.size()) { - throw new RuntimeException("Could not find target entry with index " + targetIndex + - ". This indicates a corrupted state."); - } else { - target = entryList.get(targetIndex); - } - } else { - target = null; - } - - sourceEntry.edges.add(new SharedBufferEdge(target, version)); - } - } - } } diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/StateTransition.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/StateTransition.java index bb61e091e7bad..e2fd900d936fb 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/StateTransition.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/StateTransition.java @@ -76,6 +76,10 @@ public IterativeCondition getCondition() { return newCondition; } + public void setCondition(IterativeCondition condition) { + this.newCondition = condition; + } + @Override public boolean equals(Object obj) { if (obj instanceof StateTransition) { diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/compiler/NFACompiler.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/compiler/NFACompiler.java index 62464d1d4b7ae..39e8d34acef4d 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/compiler/NFACompiler.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/compiler/NFACompiler.java @@ -18,9 +18,9 @@ package org.apache.flink.cep.nfa.compiler; -import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; import org.apache.flink.cep.nfa.NFA; import org.apache.flink.cep.nfa.State; import org.apache.flink.cep.nfa.StateTransition; @@ -36,11 +36,6 @@ import org.apache.flink.cep.pattern.conditions.NotCondition; import org.apache.flink.streaming.api.windowing.time.Time; -import com.google.common.base.Predicate; -import com.google.common.collect.Iterators; - -import javax.annotation.Nullable; - import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; @@ -117,9 +112,12 @@ static class NFAFactoryCompiler { private Map, Boolean> firstOfLoopMap = new HashMap<>(); private Pattern currentPattern; private Pattern followingPattern; + private final AfterMatchSkipStrategy afterMatchSkipStrategy; + private Map> originalStateMap = new HashMap<>(); NFAFactoryCompiler(final Pattern pattern) { this.currentPattern = pattern; + afterMatchSkipStrategy = pattern.getAfterMatchSkipStrategy(); } /** @@ -133,6 +131,8 @@ void compileFactory() { checkPatternNameUniqueness(); + checkPatternSkipStrategy(); + // we're traversing the pattern from the end to the beginning --> the first state is the final state State sinkState = createEndingState(); // add all the normal states @@ -141,6 +141,10 @@ void compileFactory() { 
createStartState(sinkState); } + AfterMatchSkipStrategy getAfterMatchSkipStrategy(){ + return afterMatchSkipStrategy; + } + List> getStates() { return states; } @@ -149,6 +153,25 @@ long getWindowTime() { return windowTime; } + /** + * Check pattern after match skip strategy. + */ + private void checkPatternSkipStrategy() { + if (afterMatchSkipStrategy.getStrategy() == AfterMatchSkipStrategy.SkipStrategy.SKIP_TO_FIRST || + afterMatchSkipStrategy.getStrategy() == AfterMatchSkipStrategy.SkipStrategy.SKIP_TO_LAST) { + Pattern pattern = currentPattern; + while (pattern.getPrevious() != null && !pattern.getName().equals(afterMatchSkipStrategy.getPatternName())) { + pattern = pattern.getPrevious(); + } + + // pattern name match check. + if (!pattern.getName().equals(afterMatchSkipStrategy.getPatternName())) { + throw new MalformedPatternException("The pattern name specified in AfterMatchSkipStrategy " + + "can not be found in the given Pattern"); + } + } + } + /** * Check if there are duplicate pattern names. If yes, it * throws a {@link MalformedPatternException}. @@ -291,13 +314,9 @@ private State convertPattern(final State sinkState) { final State looping = createLooping(sink); setCurrentGroupPatternFirstOfLoop(true); - if (!quantifier.hasProperty(Quantifier.QuantifierProperty.OPTIONAL)) { - lastSink = createInitMandatoryStateOfOneOrMore(looping); - } else { - lastSink = createInitOptionalStateOfZeroOrMore(looping, sinkState); - } + lastSink = createTimesState(looping, sinkState, currentPattern.getTimes()); } else if (quantifier.hasProperty(Quantifier.QuantifierProperty.TIMES)) { - lastSink = createTimesState(sinkState, currentPattern.getTimes()); + lastSink = createTimesState(sinkState, sinkState, currentPattern.getTimes()); } else { lastSink = createSingletonState(sinkState); } @@ -386,6 +405,21 @@ private State copyWithoutTransitiveNots(final State sinkState) { return copyOfSink; } + private State copy(final State state) { + final State copyOfState = createState( + NFAStateNameHandler.getOriginalNameFromInternal(state.getName()), + state.getStateType()); + for (StateTransition tStateTransition : state.getStateTransitions()) { + copyOfState.addStateTransition( + tStateTransition.getAction(), + tStateTransition.getTargetState().equals(tStateTransition.getSourceState()) + ? copyOfState + : tStateTransition.getTargetState(), + tStateTransition.getCondition()); + } + return copyOfState; + } + private void addStopStates(final State state) { for (Tuple2, String> notCondition: getCurrentNotCondition()) { final State stopState = createStopState(notCondition.f0, notCondition.f1); @@ -407,16 +441,35 @@ private void addStopStateToLooping(final State loopingState) { * same {@link IterativeCondition}. 
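checkPatternSkipStrategy above validates SKIP_TO_FIRST and SKIP_TO_LAST by walking the pattern sequence backwards through getPrevious() until it either finds the referenced pattern name or reaches the head, in which case it throws a MalformedPatternException. A minimal standalone sketch of that walk follows; the Node type is a made-up stand-in for Pattern#getPrevious()/getName(), and a plain IllegalArgumentException replaces the Flink-specific exception:

// Stand-in for the backwards walk in checkPatternSkipStrategy(); not Flink code.
public class SkipTargetCheck {

    static final class Node {
        final String name;
        final Node previous;

        Node(String name, Node previous) {
            this.name = name;
            this.previous = previous;
        }
    }

    // Throws if skipTargetName does not occur anywhere in the chain ending at `last`.
    static void check(Node last, String skipTargetName) {
        Node pattern = last;
        while (pattern.previous != null && !pattern.name.equals(skipTargetName)) {
            pattern = pattern.previous;
        }
        if (!pattern.name.equals(skipTargetName)) {
            throw new IllegalArgumentException(
                "The pattern name specified in AfterMatchSkipStrategy can not be found in the given Pattern");
        }
    }

    public static void main(String[] args) {
        Node start = new Node("start", null);
        Node middle = new Node("middle", start);
        Node end = new Node("end", middle);

        check(end, "middle");              // passes: "middle" is part of the sequence
        try {
            check(end, "unknown");         // rejected: no such pattern name
        } catch (IllegalArgumentException expected) {
            System.out.println(expected.getMessage());
        }
    }
}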
* * @param sinkState the state that the created state should point to + * @param proceedState state that the state being converted should proceed to * @param times number of times the state should be copied * @return the first state of the "complex" state, next state should point to it */ - private State createTimesState(final State sinkState, Times times) { + @SuppressWarnings("unchecked") + private State createTimesState(final State sinkState, final State proceedState, Times times) { State lastSink = sinkState; setCurrentGroupPatternFirstOfLoop(false); - final IterativeCondition takeCondition = getTakeCondition(currentPattern); - final IterativeCondition innerIgnoreCondition = getInnerIgnoreCondition(currentPattern); + final IterativeCondition untilCondition = (IterativeCondition) currentPattern.getUntilCondition(); + final IterativeCondition innerIgnoreCondition = extendWithUntilCondition( + getInnerIgnoreCondition(currentPattern), + untilCondition, + false); + final IterativeCondition takeCondition = extendWithUntilCondition( + getTakeCondition(currentPattern), + untilCondition, + true); + + if (currentPattern.getQuantifier().hasProperty(Quantifier.QuantifierProperty.GREEDY) && + times.getFrom() != times.getTo()) { + if (untilCondition != null) { + State sinkStateCopy = copy(sinkState); + originalStateMap.put(sinkState.getName(), sinkStateCopy); + } + updateWithGreedyCondition(sinkState, takeCondition); + } + for (int i = times.getFrom(); i < times.getTo(); i++) { - lastSink = createSingletonState(lastSink, sinkState, takeCondition, innerIgnoreCondition, true); + lastSink = createSingletonState(lastSink, proceedState, takeCondition, innerIgnoreCondition, true); addStopStateToLooping(lastSink); } for (int i = 0; i < times.getFrom() - 1; i++) { @@ -427,7 +480,7 @@ private State createTimesState(final State sinkState, Times times) { setCurrentGroupPatternFirstOfLoop(true); return createSingletonState( lastSink, - sinkState, + proceedState, takeCondition, getIgnoreCondition(currentPattern), currentPattern.getQuantifier().hasProperty(Quantifier.QuantifierProperty.OPTIONAL)); @@ -520,18 +573,32 @@ private State createSingletonState(final State sinkState, return createGroupPatternState((GroupPattern) currentPattern, sinkState, proceedState, isOptional); } - final IterativeCondition trueFunction = getTrueFunction(); - final State singletonState = createState(currentPattern.getName(), State.StateType.Normal); // if event is accepted then all notPatterns previous to the optional states are no longer valid final State sink = copyWithoutTransitiveNots(sinkState); singletonState.addTake(sink, takeCondition); + // if no element accepted the previous nots are still valid. + final IterativeCondition proceedCondition = getTrueFunction(); + // for the first state of a group pattern, its PROCEED edge should point to the following state of // that group pattern and the edge will be added at the end of creating the NFA for that group pattern if (isOptional && !headOfGroup(currentPattern)) { - // if no element accepted the previous nots are still valid. 
- singletonState.addProceed(proceedState, trueFunction); + if (currentPattern.getQuantifier().hasProperty(Quantifier.QuantifierProperty.GREEDY)) { + final IterativeCondition untilCondition = + (IterativeCondition) currentPattern.getUntilCondition(); + if (untilCondition != null) { + singletonState.addProceed( + originalStateMap.get(proceedState.getName()), + new AndCondition<>(proceedCondition, untilCondition)); + } + singletonState.addProceed(proceedState, + untilCondition != null + ? new AndCondition<>(proceedCondition, new NotCondition<>(untilCondition)) + : proceedCondition); + } else { + singletonState.addProceed(proceedState, proceedCondition); + } } if (ignoreCondition != null) { @@ -563,11 +630,12 @@ private State createGroupPatternState( final State sinkState, final State proceedState, final boolean isOptional) { - final IterativeCondition trueFunction = getTrueFunction(); + final IterativeCondition proceedCondition = getTrueFunction(); Pattern oldCurrentPattern = currentPattern; Pattern oldFollowingPattern = followingPattern; GroupPattern oldGroupPattern = currentGroupPattern; + State lastSink = sinkState; currentGroupPattern = groupPattern; currentPattern = groupPattern.getRawPattern(); @@ -576,7 +644,7 @@ private State createGroupPatternState( if (isOptional) { // for the first state of a group pattern, its PROCEED edge should point to // the following state of that group pattern - lastSink.addProceed(proceedState, trueFunction); + lastSink.addProceed(proceedState, proceedCondition); } currentPattern = oldCurrentPattern; followingPattern = oldFollowingPattern; @@ -594,19 +662,20 @@ private State createGroupPatternState( private State createLoopingGroupPatternState( final GroupPattern groupPattern, final State sinkState) { - final IterativeCondition trueFunction = getTrueFunction(); + final IterativeCondition proceedCondition = getTrueFunction(); Pattern oldCurrentPattern = currentPattern; Pattern oldFollowingPattern = followingPattern; GroupPattern oldGroupPattern = currentGroupPattern; + final State dummyState = createState(currentPattern.getName(), State.StateType.Normal); State lastSink = dummyState; currentGroupPattern = groupPattern; currentPattern = groupPattern.getRawPattern(); lastSink = createMiddleStates(lastSink); lastSink = convertPattern(lastSink); - lastSink.addProceed(sinkState, trueFunction); - dummyState.addProceed(lastSink, trueFunction); + lastSink.addProceed(sinkState, proceedCondition); + dummyState.addProceed(lastSink, proceedCondition); currentPattern = oldCurrentPattern; followingPattern = oldFollowingPattern; currentGroupPattern = oldGroupPattern; @@ -637,9 +706,23 @@ private State createLooping(final State sinkState) { untilCondition, true); - final IterativeCondition proceedCondition = getTrueFunction(); + IterativeCondition proceedCondition = getTrueFunction(); final State loopingState = createState(currentPattern.getName(), State.StateType.Normal); - loopingState.addProceed(sinkState, proceedCondition); + + if (currentPattern.getQuantifier().hasProperty(Quantifier.QuantifierProperty.GREEDY)) { + if (untilCondition != null) { + State sinkStateCopy = copy(sinkState); + loopingState.addProceed(sinkStateCopy, new AndCondition<>(proceedCondition, untilCondition)); + originalStateMap.put(sinkState.getName(), sinkStateCopy); + } + loopingState.addProceed(sinkState, + untilCondition != null + ? 
new AndCondition<>(proceedCondition, new NotCondition<>(untilCondition)) + : proceedCondition); + updateWithGreedyCondition(sinkState, getTakeCondition(currentPattern)); + } else { + loopingState.addProceed(sinkState, proceedCondition); + } loopingState.addTake(takeCondition); addStopStateToLooping(loopingState); @@ -655,46 +738,6 @@ private State createLooping(final State sinkState) { return loopingState; } - /** - * Patterns with quantifiers AT_LEAST_ONE_* are created as a pair of states: a singleton state and - * looping state. This method creates the first of the two. - * - * @param sinkState the state the newly created state should point to, it should be a looping state - * @return the newly created state - */ - @SuppressWarnings("unchecked") - private State createInitMandatoryStateOfOneOrMore(final State sinkState) { - final IterativeCondition takeCondition = extendWithUntilCondition( - getTakeCondition(currentPattern), - (IterativeCondition) currentPattern.getUntilCondition(), - true - ); - - final IterativeCondition ignoreCondition = getIgnoreCondition(currentPattern); - - return createSingletonState(sinkState, null, takeCondition, ignoreCondition, false); - } - - /** - * Creates a pair of states that enables relaxed strictness before a zeroOrMore looping state. - * - * @param loopingState the first state of zeroOrMore complex state - * @param lastSink the state that the looping one points to - * @return the newly created state - */ - @SuppressWarnings("unchecked") - private State createInitOptionalStateOfZeroOrMore(final State loopingState, final State lastSink) { - final IterativeCondition takeCondition = extendWithUntilCondition( - getTakeCondition(currentPattern), - (IterativeCondition) currentPattern.getUntilCondition(), - true - ); - - final IterativeCondition ignoreFunction = getIgnoreCondition(currentPattern); - - return createSingletonState(loopingState, lastSink, takeCondition, ignoreFunction, true); - } - /** * This method extends the given condition with stop(until) condition if necessary. * The until condition needs to be applied only if both of the given conditions are not null. @@ -825,114 +868,15 @@ private IterativeCondition getTrueFunction() { } return trueCondition; } - } - - /** - * Used for migrating CEP graphs prior to 1.3. It removes the dummy start, adds the dummy end, and translates all - * states to consuming ones by moving all TAKEs and IGNOREs to the next state. This method assumes each state - * has at most one TAKE and one IGNORE and name of each state is unique. No PROCEED transition is allowed! 
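The greedy handling added in createTimesState, createSingletonState and createLooping above comes down to boolean composition of conditions: updateWithGreedyCondition guards the sink's existing transitions with NOT(take), and when an until() condition is present the PROCEED edge is split into one edge to the original sink guarded by NOT(until) and one edge to a copied sink guarded by until. A plain-Java sketch of that composition, using java.util.function.Predicate as a stand-in for IterativeCondition/AndCondition/NotCondition and with invented numeric conditions:

import java.util.function.Predicate;

// Sketch of how GREEDY quantifiers compose conditions; not Flink code.
public class GreedyConditionSketch {

    public static void main(String[] args) {
        Predicate<Integer> take = v -> v > 0;                     // TAKE condition of the greedy state
        Predicate<Integer> until = v -> v > 100;                  // until() stop condition
        Predicate<Integer> proceed = v -> true;                   // the always-true PROCEED condition
        Predicate<Integer> existingSinkGuard = v -> v % 2 == 0;   // some condition already on a sink transition

        // updateWithGreedyCondition: the sink's transitions only fire when the
        // greedy state cannot consume the element itself.
        Predicate<Integer> guardedSinkGuard = existingSinkGuard.and(take.negate());

        // PROCEED to the original sink only while the until condition has not fired ...
        Predicate<Integer> proceedEdge = proceed.and(until.negate());
        // ... and PROCEED to the copied sink once it has.
        Predicate<Integer> proceedToCopiedSink = proceed.and(until);

        System.out.println(guardedSinkGuard.test(4));       // false: the greedy loop would still take 4
        System.out.println(guardedSinkGuard.test(-2));      // true: the loop cannot take -2, so the sink may
        System.out.println(proceedEdge.test(50));           // true
        System.out.println(proceedToCopiedSink.test(150));  // true
    }
}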
- * - * @param oldStartState dummy start state of old graph - * @param type of events - * @return map of new states, where key is the name of a state and value is the state itself - */ - @Internal - public static Map> migrateGraph(State oldStartState) { - State oldFirst = oldStartState; - State oldSecond = oldStartState.getStateTransitions().iterator().next().getTargetState(); - - StateTransition oldFirstToSecondTake = Iterators.find( - oldFirst.getStateTransitions().iterator(), - new Predicate>() { - @Override - public boolean apply(@Nullable StateTransition input) { - return input != null && input.getAction() == StateTransitionAction.TAKE; - } - - }); - - StateTransition oldFirstIgnore = Iterators.find( - oldFirst.getStateTransitions().iterator(), - new Predicate>() { - @Override - public boolean apply(@Nullable StateTransition input) { - return input != null && input.getAction() == StateTransitionAction.IGNORE; - } - - }, null); - - StateTransition oldSecondToThirdTake = Iterators.find( - oldSecond.getStateTransitions().iterator(), - new Predicate>() { - @Override - public boolean apply(@Nullable StateTransition input) { - return input != null && input.getAction() == StateTransitionAction.TAKE; - } - - }, null); - - final Map> convertedStates = new HashMap<>(); - State newSecond; - State newFirst = new State<>(oldSecond.getName(), State.StateType.Start); - convertedStates.put(newFirst.getName(), newFirst); - while (oldSecondToThirdTake != null) { - - newSecond = new State(oldSecondToThirdTake.getTargetState().getName(), State.StateType.Normal); - convertedStates.put(newSecond.getName(), newSecond); - newFirst.addTake(newSecond, oldFirstToSecondTake.getCondition()); - if (oldFirstIgnore != null) { - newFirst.addIgnore(oldFirstIgnore.getCondition()); + private void updateWithGreedyCondition( + State state, + IterativeCondition takeCondition) { + for (StateTransition stateTransition : state.getStateTransitions()) { + stateTransition.setCondition( + new AndCondition<>(stateTransition.getCondition(), new NotCondition<>(takeCondition))); } - - oldFirst = oldSecond; - - oldFirstToSecondTake = Iterators.find( - oldFirst.getStateTransitions().iterator(), - new Predicate>() { - @Override - public boolean apply(@Nullable StateTransition input) { - return input != null && input.getAction() == StateTransitionAction.TAKE; - } - - }); - - oldFirstIgnore = Iterators.find( - oldFirst.getStateTransitions().iterator(), - new Predicate>() { - @Override - public boolean apply(@Nullable StateTransition input) { - return input != null && input.getAction() == StateTransitionAction.IGNORE; - } - - }, null); - - oldSecond = oldSecondToThirdTake.getTargetState(); - - oldSecondToThirdTake = Iterators.find( - oldSecond.getStateTransitions().iterator(), - new Predicate>() { - @Override - public boolean apply(@Nullable StateTransition input) { - return input != null && input.getAction() == StateTransitionAction.TAKE; - } - - }, null); - - newFirst = newSecond; } - - final State endingState = new State<>(ENDING_STATE_NAME, State.StateType.Final); - - newFirst.addTake(endingState, oldFirstToSecondTake.getCondition()); - - if (oldFirstIgnore != null) { - newFirst.addIgnore(oldFirstIgnore.getCondition()); - } - - convertedStates.put(endingState.getName(), endingState); - - return convertedStates; } /** diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java 
b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java index 66663d2dae2f4..ae2d7e4c9ad38 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java @@ -19,51 +19,35 @@ package org.apache.flink.cep.operator; import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.state.MapState; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeutils.CompatibilityResult; -import org.apache.flink.api.common.typeutils.CompatibilityUtil; -import org.apache.flink.api.common.typeutils.TypeDeserializerAdapter; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer; -import org.apache.flink.api.common.typeutils.base.CollectionSerializerConfigSnapshot; import org.apache.flink.api.common.typeutils.base.ListSerializer; import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.cep.EventComparator; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; import org.apache.flink.cep.nfa.NFA; import org.apache.flink.cep.nfa.compiler.NFACompiler; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.migration.streaming.runtime.streamrecord.MultiplexingStreamRecordSerializer; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.CheckpointedRestoringOperator; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimer; import org.apache.flink.streaming.api.operators.InternalTimerService; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Triggerable; -import org.apache.flink.streaming.runtime.streamrecord.StreamElement; -import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.util.Migration; import org.apache.flink.util.Preconditions; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.Serializable; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.PriorityQueue; import java.util.stream.Stream; import java.util.stream.StreamSupport; @@ -79,21 +63,16 @@ * @param Type of the key on which the input stream is keyed * @param Type of the output elements */ -public abstract class AbstractKeyedCEPPatternOperator - extends AbstractStreamOperator - implements OneInputStreamOperator, 
Triggerable, CheckpointedRestoringOperator { +public abstract class AbstractKeyedCEPPatternOperator + extends AbstractUdfStreamOperator + implements OneInputStreamOperator, Triggerable { private static final long serialVersionUID = -4166778210774160757L; - private static final int INITIAL_PRIORITY_QUEUE_CAPACITY = 11; - private final boolean isProcessingTime; private final TypeSerializer inputSerializer; - // necessary to serialize the set of seen keys - private final TypeSerializer keySerializer; - /////////////// State ////////////// private static final String NFA_OPERATOR_STATE_NAME = "nfaOperatorStateName"; @@ -112,29 +91,29 @@ public abstract class AbstractKeyedCEPPatternOperator */ private long lastWatermark; - /** - * A flag used in the case of migration that indicates if - * we are restoring from an old keyed or non-keyed operator. - */ - private final boolean migratingFromOldKeyedOperator; - private final EventComparator comparator; + protected final AfterMatchSkipStrategy afterMatchSkipStrategy; + public AbstractKeyedCEPPatternOperator( final TypeSerializer inputSerializer, final boolean isProcessingTime, - final TypeSerializer keySerializer, final NFACompiler.NFAFactory nfaFactory, - final boolean migratingFromOldKeyedOperator, - final EventComparator comparator) { + final EventComparator comparator, + final AfterMatchSkipStrategy afterMatchSkipStrategy, + final F function) { + super(function); this.inputSerializer = Preconditions.checkNotNull(inputSerializer); this.isProcessingTime = Preconditions.checkNotNull(isProcessingTime); - this.keySerializer = Preconditions.checkNotNull(keySerializer); this.nfaFactory = Preconditions.checkNotNull(nfaFactory); - - this.migratingFromOldKeyedOperator = migratingFromOldKeyedOperator; this.comparator = comparator; + + if (afterMatchSkipStrategy == null) { + this.afterMatchSkipStrategy = AfterMatchSkipStrategy.noSkip(); + } else { + this.afterMatchSkipStrategy = afterMatchSkipStrategy; + } } @Override @@ -348,7 +327,18 @@ private PriorityQueue getSortedTimestamps() throws Exception { * @param event The current event to be processed * @param timestamp The timestamp of the event */ - protected abstract void processEvent(NFA nfa, IN event, long timestamp); + private void processEvent(NFA nfa, IN event, long timestamp) { + Tuple2>>, Collection>, Long>>> patterns = + nfa.process(event, timestamp); + + try { + processMatchedSequences(patterns.f0, timestamp); + processTimedOutSequences(patterns.f1, timestamp); + } catch (Exception e) { + //rethrow as Runtime, to be able to use processEvent in Stream. + throw new RuntimeException(e); + } + } /** * Advances the time for the given NFA to the given timestamp. 
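The concrete processEvent above wraps checked exceptions in a RuntimeException because, as its inline comment notes, it is now also invoked from java.util.stream lambdas, which cannot declare checked exceptions. A minimal self-contained illustration of that pattern, with invented names:

import java.util.stream.Stream;

// Why the rethrow is needed: lambdas passed to Stream operations may not throw checked exceptions.
public class CheckedInStreamDemo {

    // Stand-in for per-element work that declares a checked exception, like processEvent's callees.
    static void process(String element) throws Exception {
        if (element.isEmpty()) {
            throw new Exception("empty element");
        }
        System.out.println("processed " + element);
    }

    public static void main(String[] args) {
        Stream.of("a", "b").forEach(element -> {
            try {
                process(element);
            } catch (Exception e) {
                // same wrapping as in processEvent
                throw new RuntimeException(e);
            }
        });
    }
}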
This can lead to pruning and @@ -357,295 +347,15 @@ private PriorityQueue getSortedTimestamps() throws Exception { * @param nfa to advance the time for * @param timestamp to advance the time to */ - protected abstract void advanceTime(NFA nfa, long timestamp); - - ////////////////////// Backwards Compatibility ////////////////////// - - @Override - public void restoreState(FSDataInputStream in) throws Exception { - if (in instanceof Migration) { - // absorb the introduced byte from the migration stream - int hasUdfState = in.read(); - if (hasUdfState == 1) { - throw new Exception("Found UDF state but CEPOperator is not an UDF operator."); - } - } - - DataInputViewStreamWrapper inputView = new DataInputViewStreamWrapper(in); - timerService = getInternalTimerService( - "watermark-callbacks", - VoidNamespaceSerializer.INSTANCE, - this); - - // this is with the old serializer so that we can read the state. - ValueState> oldNfaOperatorState = getRuntimeContext().getState( - new ValueStateDescriptor<>("nfaOperatorState", new NFA.Serializer())); - - ValueState>> oldPriorityQueueOperatorState = - getRuntimeContext().getState( - new ValueStateDescriptor<>( - "priorityQueueStateName", - new PriorityQueueSerializer<>( - ((TypeSerializer) new StreamElementSerializer<>(inputSerializer)), - new PriorityQueueStreamRecordFactory() - ) - ) - ); - - if (migratingFromOldKeyedOperator) { - int numberEntries = inputView.readInt(); - for (int i = 0; i < numberEntries; i++) { - KEY key = keySerializer.deserialize(inputView); - setCurrentKey(key); - saveRegisterWatermarkTimer(); - - NFA nfa = oldNfaOperatorState.value(); - oldNfaOperatorState.clear(); - nfaOperatorState.update(nfa); - - PriorityQueue> priorityQueue = oldPriorityQueueOperatorState.value(); - if (priorityQueue != null && !priorityQueue.isEmpty()) { - Map> elementMap = new HashMap<>(); - for (StreamRecord record: priorityQueue) { - long timestamp = record.getTimestamp(); - IN element = record.getValue(); - - List elements = elementMap.get(timestamp); - if (elements == null) { - elements = new ArrayList<>(); - elementMap.put(timestamp, elements); - } - elements.add(element); - } - - // write the old state into the new one. - for (Map.Entry> entry: elementMap.entrySet()) { - elementQueueState.put(entry.getKey(), entry.getValue()); - } - - // clear the old state - oldPriorityQueueOperatorState.clear(); - } - } - } else { - - final ObjectInputStream ois = new ObjectInputStream(in); - - // retrieve the NFA - @SuppressWarnings("unchecked") - NFA nfa = (NFA) ois.readObject(); - - // retrieve the elements that were pending in the priority queue - MultiplexingStreamRecordSerializer recordSerializer = new MultiplexingStreamRecordSerializer<>(inputSerializer); - - Map> elementMap = new HashMap<>(); - int entries = ois.readInt(); - for (int i = 0; i < entries; i++) { - StreamElement streamElement = recordSerializer.deserialize(inputView); - StreamRecord record = streamElement.asRecord(); - - long timestamp = record.getTimestamp(); - IN element = record.getValue(); - - List elements = elementMap.get(timestamp); - if (elements == null) { - elements = new ArrayList<>(); - elementMap.put(timestamp, elements); - } - elements.add(element); - } - - // finally register the retrieved state with the new keyed state. - setCurrentKey((byte) 0); - nfaOperatorState.update(nfa); - - // write the priority queue to the new map state. 
- for (Map.Entry> entry: elementMap.entrySet()) { - elementQueueState.put(entry.getKey(), entry.getValue()); - } - - if (!isProcessingTime) { - // this is relevant only for event/ingestion time - setCurrentKey((byte) 0); - saveRegisterWatermarkTimer(); - } - ois.close(); - } + private void advanceTime(NFA nfa, long timestamp) throws Exception { + processEvent(nfa, null, timestamp); } - ////////////////////// Utility Classes ////////////////////// - - /** - * Custom type serializer implementation to serialize priority queues. - * - * @param Type of the priority queue's elements - */ - private static class PriorityQueueSerializer extends TypeSerializer> { - - private static final long serialVersionUID = -231980397616187715L; - - private final TypeSerializer elementSerializer; - private final PriorityQueueFactory factory; - - PriorityQueueSerializer(final TypeSerializer elementSerializer, final PriorityQueueFactory factory) { - this.elementSerializer = elementSerializer; - this.factory = factory; - } + protected abstract void processMatchedSequences(Iterable>> matchingSequences, long timestamp) throws Exception; - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public TypeSerializer> duplicate() { - return new PriorityQueueSerializer<>(elementSerializer.duplicate(), factory); - } - - @Override - public PriorityQueue createInstance() { - return factory.createPriorityQueue(); - } - - @Override - public PriorityQueue copy(PriorityQueue from) { - PriorityQueue result = factory.createPriorityQueue(); - - for (T element: from) { - result.offer(elementSerializer.copy(element)); - } - - return result; - } - - @Override - public PriorityQueue copy(PriorityQueue from, PriorityQueue reuse) { - reuse.clear(); - - for (T element: from) { - reuse.offer(elementSerializer.copy(element)); - } - - return reuse; - } - - @Override - public int getLength() { - return 0; - } - - @Override - public void serialize(PriorityQueue record, DataOutputView target) throws IOException { - target.writeInt(record.size()); - - for (T element: record) { - elementSerializer.serialize(element, target); - } - } - - @Override - public PriorityQueue deserialize(DataInputView source) throws IOException { - PriorityQueue result = factory.createPriorityQueue(); - - return deserialize(result, source); - } - - @Override - public PriorityQueue deserialize(PriorityQueue reuse, DataInputView source) throws IOException { - reuse.clear(); - - int numberEntries = source.readInt(); - - for (int i = 0; i < numberEntries; i++) { - reuse.offer(elementSerializer.deserialize(source)); - } - - return reuse; - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof PriorityQueueSerializer) { - @SuppressWarnings("unchecked") - PriorityQueueSerializer other = (PriorityQueueSerializer) obj; - - return factory.equals(other.factory) && elementSerializer.equals(other.elementSerializer); - } else { - return false; - } - } - - @Override - public boolean canEqual(Object obj) { - return obj instanceof PriorityQueueSerializer; - } - - @Override - public int hashCode() { - return Objects.hash(factory, elementSerializer); - } - - // -------------------------------------------------------------------------------------------- - // Serializer configuration snapshotting & compatibility - // -------------------------------------------------------------------------------------------- - - @Override - public 
TypeSerializerConfigSnapshot snapshotConfiguration() { - return new CollectionSerializerConfigSnapshot<>(elementSerializer); - } - - @Override - public CompatibilityResult> ensureCompatibility(TypeSerializerConfigSnapshot configSnapshot) { - if (configSnapshot instanceof CollectionSerializerConfigSnapshot) { - Tuple2, TypeSerializerConfigSnapshot> previousElemSerializerAndConfig = - ((CollectionSerializerConfigSnapshot) configSnapshot).getSingleNestedSerializerAndConfig(); - - CompatibilityResult compatResult = CompatibilityUtil.resolveCompatibilityResult( - previousElemSerializerAndConfig.f0, - UnloadableDummyTypeSerializer.class, - previousElemSerializerAndConfig.f1, - elementSerializer); - - if (!compatResult.isRequiresMigration()) { - return CompatibilityResult.compatible(); - } else if (compatResult.getConvertDeserializer() != null) { - return CompatibilityResult.requiresMigration( - new PriorityQueueSerializer<>( - new TypeDeserializerAdapter<>(compatResult.getConvertDeserializer()), factory)); - } - } - - return CompatibilityResult.requiresMigration(); - } - } - - private interface PriorityQueueFactory extends Serializable { - PriorityQueue createPriorityQueue(); - } - - private static class PriorityQueueStreamRecordFactory implements PriorityQueueFactory> { - - private static final long serialVersionUID = 1254766984454616593L; - - @Override - public PriorityQueue> createPriorityQueue() { - return new PriorityQueue>(INITIAL_PRIORITY_QUEUE_CAPACITY, new StreamRecordComparator()); - } - - @Override - public boolean equals(Object obj) { - return obj instanceof PriorityQueueStreamRecordFactory; - } - - @Override - public int hashCode() { - return getClass().hashCode(); - } + protected void processTimedOutSequences( + Iterable>, Long>> timedOutSequences, + long timestamp) throws Exception { } ////////////////////// Testing Methods ////////////////////// diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java index de2d8f8ce0c2c..7c0c55d4b2b89 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java @@ -18,28 +18,25 @@ package org.apache.flink.cep.operator; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.base.ByteSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.functions.NullByteKeySelector; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.typeutils.EitherTypeInfo; -import org.apache.flink.api.java.typeutils.TupleTypeInfo; -import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.cep.EventComparator; +import org.apache.flink.cep.PatternFlatSelectFunction; +import org.apache.flink.cep.PatternFlatTimeoutFunction; +import org.apache.flink.cep.PatternSelectFunction; import org.apache.flink.cep.PatternStream; +import org.apache.flink.cep.PatternTimeoutFunction; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; import org.apache.flink.cep.nfa.compiler.NFACompiler; import org.apache.flink.cep.pattern.Pattern; import org.apache.flink.streaming.api.TimeCharacteristic; import 
org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.KeyedStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; -import org.apache.flink.types.Either; - -import java.util.List; -import java.util.Map; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.util.OutputTag; /** * Utility methods for creating {@link PatternStream}. @@ -47,123 +44,268 @@ public class CEPOperatorUtils { /** - * Creates a data stream containing the fully matching event patterns of the NFA computation. + * Creates a data stream containing results of {@link PatternSelectFunction} to fully matching event patterns. * - * @param Type of the key - * @return Data stream containing fully matched event sequences stored in a {@link Map}. The - * events are indexed by their associated names of the pattern. + * @param inputStream stream of input events + * @param pattern pattern to be search for in the stream + * @param selectFunction function to be applied to matching event sequences + * @param outTypeInfo output TypeInformation of selectFunction + * @param type of input events + * @param type of output events + * @return Data stream containing fully matched event sequence with applied {@link PatternSelectFunction} */ - public static SingleOutputStreamOperator>> createPatternStream( - DataStream inputStream, - Pattern pattern, - EventComparator comparator) { - final TypeSerializer inputSerializer = inputStream.getType().createSerializer(inputStream.getExecutionConfig()); - - // check whether we use processing time - final boolean isProcessingTime = inputStream.getExecutionEnvironment().getStreamTimeCharacteristic() == TimeCharacteristic.ProcessingTime; - - // compile our pattern into a NFAFactory to instantiate NFAs later on - final NFACompiler.NFAFactory nfaFactory = NFACompiler.compileFactory(pattern, inputSerializer, false); - - final SingleOutputStreamOperator>> patternStream; + public static SingleOutputStreamOperator createPatternStream( + final DataStream inputStream, + final Pattern pattern, + final EventComparator comparator, + final PatternSelectFunction selectFunction, + final TypeInformation outTypeInfo) { + return createPatternStream(inputStream, pattern, outTypeInfo, false, comparator, new OperatorBuilder() { + @Override + public OneInputStreamOperator build( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator, + AfterMatchSkipStrategy skipStrategy) { + return new SelectCepOperator<>( + inputSerializer, + isProcessingTime, + nfaFactory, + comparator, + skipStrategy, + selectFunction + ); + } - if (inputStream instanceof KeyedStream) { - // We have to use the KeyedCEPPatternOperator which can deal with keyed input streams - KeyedStream keyedStream = (KeyedStream) inputStream; + @Override + public String getKeyedOperatorName() { + return "SelectCepOperator"; + } - TypeSerializer keySerializer = keyedStream.getKeyType().createSerializer(keyedStream.getExecutionConfig()); + @Override + public String getOperatorName() { + return "SelectCepOperator"; + } + }); + } - patternStream = keyedStream.transform( - "KeyedCEPPatternOperator", - (TypeInformation>>) (TypeInformation) TypeExtractor.getForClass(Map.class), - new KeyedCEPPatternOperator<>( + /** + * Creates a data stream containing results of {@link PatternFlatSelectFunction} to fully matching event patterns. 
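These utility methods are what the public PatternStream API delegates to; from a user's perspective the select path built here is normally reached through CEP.pattern(...).select(...). A hedged usage sketch follows; the String event type, the pattern names and the conditions are assumptions for illustration, and the CEP/PatternStream/SimpleCondition calls are the public API these utilities back rather than code from this patch:

import java.util.List;
import java.util.Map;

import org.apache.flink.cep.CEP;
import org.apache.flink.cep.PatternSelectFunction;
import org.apache.flink.cep.PatternStream;
import org.apache.flink.cep.pattern.Pattern;
import org.apache.flink.cep.pattern.conditions.SimpleCondition;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

// Hypothetical end-to-end wiring for the select path; names and conditions are illustrative.
public class SelectUsageSketch {

    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        DataStream<String> input = env.fromElements("start-1", "end-1");

        Pattern<String, String> pattern = Pattern.<String>begin("start")
            .where(new SimpleCondition<String>() {
                @Override
                public boolean filter(String value) {
                    return value.startsWith("start");
                }
            })
            .followedBy("end")
            .where(new SimpleCondition<String>() {
                @Override
                public boolean filter(String value) {
                    return value.startsWith("end");
                }
            });

        PatternStream<String> patternStream = CEP.pattern(input, pattern);

        DataStream<String> matches = patternStream.select(
            new PatternSelectFunction<String, String>() {
                @Override
                public String select(Map<String, List<String>> match) {
                    return match.get("start").get(0) + " -> " + match.get("end").get(0);
                }
            });

        matches.print();
        env.execute("cep-select-sketch");
    }
}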
+ * + * @param inputStream stream of input events + * @param pattern pattern to be search for in the stream + * @param selectFunction function to be applied to matching event sequences + * @param outTypeInfo output TypeInformation of selectFunction + * @param type of input events + * @param type of output events + * @return Data stream containing fully matched event sequence with applied {@link PatternFlatSelectFunction} + */ + public static SingleOutputStreamOperator createPatternStream( + final DataStream inputStream, + final Pattern pattern, + final EventComparator comparator, + final PatternFlatSelectFunction selectFunction, + final TypeInformation outTypeInfo) { + return createPatternStream(inputStream, pattern, outTypeInfo, false, comparator, new OperatorBuilder() { + @Override + public OneInputStreamOperator build( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator, + AfterMatchSkipStrategy skipStrategy) { + return new FlatSelectCepOperator<>( inputSerializer, isProcessingTime, - keySerializer, nfaFactory, - true, - comparator)); - } else { + comparator, + skipStrategy, + selectFunction + ); + } - KeySelector keySelector = new NullByteKeySelector<>(); - TypeSerializer keySerializer = ByteSerializer.INSTANCE; + @Override + public String getKeyedOperatorName() { + return "FlatSelectCepOperator"; + } - patternStream = inputStream.keyBy(keySelector).transform( - "CEPPatternOperator", - (TypeInformation>>) (TypeInformation) TypeExtractor.getForClass(Map.class), - new KeyedCEPPatternOperator<>( + @Override + public String getOperatorName() { + return "FlatSelectCepOperator"; + } + }); + } + + /** + * Creates a data stream containing results of {@link PatternFlatSelectFunction} to fully matching event patterns and + * also timed out partially matched with applied {@link PatternFlatTimeoutFunction} as a sideoutput. 
+ * + * @param inputStream stream of input events + * @param pattern pattern to be search for in the stream + * @param selectFunction function to be applied to matching event sequences + * @param outTypeInfo output TypeInformation of selectFunction + * @param outputTag {@link OutputTag} for a side-output with timed out matches + * @param timeoutFunction function to be applied to timed out event sequences + * @param type of input events + * @param type of fully matched events + * @param type of timed out events + * @return Data stream containing fully matched event sequence with applied {@link PatternFlatSelectFunction} that + * contains timed out patterns with applied {@link PatternFlatTimeoutFunction} as side-output + */ + public static SingleOutputStreamOperator createTimeoutPatternStream( + final DataStream inputStream, + final Pattern pattern, + final EventComparator comparator, + final PatternFlatSelectFunction selectFunction, + final TypeInformation outTypeInfo, + final OutputTag outputTag, + final PatternFlatTimeoutFunction timeoutFunction) { + return createPatternStream(inputStream, pattern, outTypeInfo, true, comparator, new OperatorBuilder() { + @Override + public OneInputStreamOperator build( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator, + AfterMatchSkipStrategy skipStrategy) { + return new FlatSelectTimeoutCepOperator<>( inputSerializer, isProcessingTime, - keySerializer, nfaFactory, - false, - comparator - )).forceNonParallel(); - } + comparator, + skipStrategy, + selectFunction, + timeoutFunction, + outputTag + ); + } - return patternStream; + @Override + public String getKeyedOperatorName() { + return "FlatSelectTimeoutCepOperator"; + } + + @Override + public String getOperatorName() { + return "FlatSelectTimeoutCepOperator"; + } + }); } /** - * Creates a data stream containing fully matching event patterns or partially matching event - * patterns which have timed out. The former are wrapped in a Either.Right and the latter in a - * Either.Left type. + * Creates a data stream containing results of {@link PatternSelectFunction} to fully matching event patterns and + * also timed out partially matched with applied {@link PatternTimeoutFunction} as a sideoutput. * - * @param Type of the key - * @return Data stream containing fully matched and partially matched event sequences wrapped in - * a {@link Either} instance. 
+ * @param inputStream stream of input events + * @param pattern pattern to be search for in the stream + * @param selectFunction function to be applied to matching event sequences + * @param outTypeInfo output TypeInformation of selectFunction + * @param outputTag {@link OutputTag} for a side-output with timed out matches + * @param timeoutFunction function to be applied to timed out event sequences + * @param type of input events + * @param type of fully matched events + * @param type of timed out events + * @return Data stream containing fully matched event sequence with applied {@link PatternSelectFunction} that + * contains timed out patterns with applied {@link PatternTimeoutFunction} as side-output */ - public static SingleOutputStreamOperator>, Long>, Map>>> createTimeoutPatternStream( - DataStream inputStream, Pattern pattern, EventComparator comparator) { + public static SingleOutputStreamOperator createTimeoutPatternStream( + final DataStream inputStream, + final Pattern pattern, + final EventComparator comparator, + final PatternSelectFunction selectFunction, + final TypeInformation outTypeInfo, + final OutputTag outputTag, + final PatternTimeoutFunction timeoutFunction) { + return createPatternStream(inputStream, pattern, outTypeInfo, true, comparator, new OperatorBuilder() { + @Override + public OneInputStreamOperator build( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator, + AfterMatchSkipStrategy skipStrategy) { + return new SelectTimeoutCepOperator<>( + inputSerializer, + isProcessingTime, + nfaFactory, + comparator, + skipStrategy, + selectFunction, + timeoutFunction, + outputTag + ); + } - final TypeSerializer inputSerializer = inputStream.getType().createSerializer(inputStream.getExecutionConfig()); + @Override + public String getKeyedOperatorName() { + return "SelectTimeoutCepOperator"; + } + + @Override + public String getOperatorName() { + return "SelectTimeoutCepOperator"; + } + }); + } + + private static SingleOutputStreamOperator createPatternStream( + final DataStream inputStream, + final Pattern pattern, + final TypeInformation outTypeInfo, + final boolean timeoutHandling, + final EventComparator comparator, + final OperatorBuilder operatorBuilder) { + final TypeSerializer inputSerializer = inputStream.getType().createSerializer(inputStream.getExecutionConfig()); // check whether we use processing time final boolean isProcessingTime = inputStream.getExecutionEnvironment().getStreamTimeCharacteristic() == TimeCharacteristic.ProcessingTime; // compile our pattern into a NFAFactory to instantiate NFAs later on - final NFACompiler.NFAFactory nfaFactory = NFACompiler.compileFactory(pattern, inputSerializer, true); + final NFACompiler.NFAFactory nfaFactory = NFACompiler.compileFactory(pattern, inputSerializer, timeoutHandling); - final SingleOutputStreamOperator>, Long>, Map>>> patternStream; - - final TypeInformation>> rightTypeInfo = (TypeInformation>>) (TypeInformation) TypeExtractor.getForClass(Map.class); - final TypeInformation>, Long>> leftTypeInfo = new TupleTypeInfo<>(rightTypeInfo, BasicTypeInfo.LONG_TYPE_INFO); - final TypeInformation>, Long>, Map>>> eitherTypeInformation = new EitherTypeInfo<>(leftTypeInfo, rightTypeInfo); + final SingleOutputStreamOperator patternStream; if (inputStream instanceof KeyedStream) { - // We have to use the KeyedCEPPatternOperator which can deal with keyed input streams - KeyedStream keyedStream = (KeyedStream) inputStream; - - TypeSerializer 
keySerializer = keyedStream.getKeyType().createSerializer(keyedStream.getExecutionConfig()); + KeyedStream keyedStream = (KeyedStream) inputStream; patternStream = keyedStream.transform( - "TimeoutKeyedCEPPatternOperator", - eitherTypeInformation, - new TimeoutKeyedCEPPatternOperator<>( + operatorBuilder.getKeyedOperatorName(), + outTypeInfo, + operatorBuilder.build( inputSerializer, isProcessingTime, - keySerializer, nfaFactory, - true, - comparator)); + comparator, + pattern.getAfterMatchSkipStrategy())); } else { - - KeySelector keySelector = new NullByteKeySelector<>(); - TypeSerializer keySerializer = ByteSerializer.INSTANCE; + KeySelector keySelector = new NullByteKeySelector<>(); patternStream = inputStream.keyBy(keySelector).transform( - "TimeoutCEPPatternOperator", - eitherTypeInformation, - new TimeoutKeyedCEPPatternOperator<>( + operatorBuilder.getOperatorName(), + outTypeInfo, + operatorBuilder.build( inputSerializer, isProcessingTime, - keySerializer, nfaFactory, - false, - comparator + comparator, + pattern.getAfterMatchSkipStrategy() )).forceNonParallel(); } return patternStream; } + + private interface OperatorBuilder { + OneInputStreamOperator build( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator, + AfterMatchSkipStrategy skipStrategy); + + String getKeyedOperatorName(); + + String getOperatorName(); + } } diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/FlatSelectCepOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/FlatSelectCepOperator.java new file mode 100644 index 0000000000000..5e493728299a7 --- /dev/null +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/FlatSelectCepOperator.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cep.operator; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.cep.EventComparator; +import org.apache.flink.cep.PatternFlatSelectFunction; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; +import org.apache.flink.cep.nfa.compiler.NFACompiler; +import org.apache.flink.streaming.api.operators.TimestampedCollector; + +import java.util.List; +import java.util.Map; + +/** + * Version of {@link AbstractKeyedCEPPatternOperator} that applies given {@link PatternFlatSelectFunction} to fully matched event patterns. 
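The getUserFunction().flatSelect(...) call in processMatchedSequences below hands every matched sequence to a user-supplied PatternFlatSelectFunction, which may emit any number of records per match through the collector. A hedged example of such a function; the String event type and the pattern names "start" and "end" are assumptions for illustration:

import java.util.List;
import java.util.Map;

import org.apache.flink.cep.PatternFlatSelectFunction;
import org.apache.flink.util.Collector;

// Hypothetical user function of the kind FlatSelectCepOperator invokes once per matched sequence.
public class PairAlertFlatSelect implements PatternFlatSelectFunction<String, String> {

    private static final long serialVersionUID = 1L;

    @Override
    public void flatSelect(Map<String, List<String>> match, Collector<String> out) {
        // one match may produce zero or more output records
        String first = match.get("start").get(0);
        for (String end : match.get("end")) {
            out.collect(first + " -> " + end);
        }
    }
}

Note that the operator calls setAbsoluteTimestamp(timestamp) on its TimestampedCollector before invoking the function, so every record emitted here carries the timestamp of the match.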
+ * + * @param Type of the input elements + * @param Type of the key on which the input stream is keyed + * @param Type of the output elements + */ +public class FlatSelectCepOperator + extends AbstractKeyedCEPPatternOperator> { + private static final long serialVersionUID = 5845993459551561518L; + + public FlatSelectCepOperator( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator, + AfterMatchSkipStrategy skipStrategy, + PatternFlatSelectFunction function) { + super(inputSerializer, isProcessingTime, nfaFactory, comparator, skipStrategy, function); + } + + private transient TimestampedCollector collector; + + @Override + public void open() throws Exception { + super.open(); + collector = new TimestampedCollector<>(output); + } + + @Override + protected void processMatchedSequences(Iterable>> matchingSequences, long timestamp) throws Exception { + for (Map> match : matchingSequences) { + collector.setAbsoluteTimestamp(timestamp); + getUserFunction().flatSelect(match, collector); + } + } +} diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/FlatSelectTimeoutCepOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/FlatSelectTimeoutCepOperator.java new file mode 100644 index 0000000000000..4423bb1dd40bd --- /dev/null +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/FlatSelectTimeoutCepOperator.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cep.operator; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.functions.Function; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.cep.EventComparator; +import org.apache.flink.cep.PatternFlatSelectFunction; +import org.apache.flink.cep.PatternFlatTimeoutFunction; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; +import org.apache.flink.cep.nfa.compiler.NFACompiler; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.util.OutputTag; + +import java.util.List; +import java.util.Map; + +/** + * Version of {@link AbstractKeyedCEPPatternOperator} that applies given {@link PatternFlatSelectFunction} to fully + * matched event patterns and {@link PatternFlatTimeoutFunction} to timed out ones. The timed out elements are returned + * as a side-output. 
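Editorial sketch: because timed-out matches now leave the operator through a side output instead of an Either-typed main output, the caller is expected to pull them from the returned stream. The sketch below shows how the createTimeoutPatternStream utility earlier in this patch might be driven; the enclosing class name CEPOperatorUtils, the wrapper method and the sink are assumptions, and the optional EventComparator is simply omitted.

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.cep.PatternSelectFunction;
import org.apache.flink.cep.PatternTimeoutFunction;
import org.apache.flink.cep.operator.CEPOperatorUtils;  // assumed enclosing class of createTimeoutPatternStream
import org.apache.flink.cep.pattern.Pattern;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.util.OutputTag;

public class TimeoutSideOutputSketch {

    public static <IN, OUT, T> SingleOutputStreamOperator<OUT> selectWithTimeouts(
            DataStream<IN> input,
            Pattern<IN, ?> pattern,
            PatternSelectFunction<IN, OUT> select,
            PatternTimeoutFunction<IN, T> timeout,
            TypeInformation<OUT> outTypeInfo,
            OutputTag<T> timedOutTag) {

        SingleOutputStreamOperator<OUT> matches = CEPOperatorUtils.createTimeoutPatternStream(
                input,
                pattern,
                null,            // EventComparator<IN>; assumed optional, none used in this sketch
                select,
                outTypeInfo,
                timedOutTag,
                timeout);

        // Timed-out partial matches are emitted by SelectTimeoutCepOperator under the given
        // tag and can be obtained from the main result stream as a side output.
        DataStream<T> timedOut = matches.getSideOutput(timedOutTag);
        timedOut.printToErr();  // illustrative sink only

        return matches;
    }
}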
+ * + * @param Type of the input elements + * @param Type of the key on which the input stream is keyed + * @param Type of the output elements + * @param Type of the timed out output elements + */ +public class FlatSelectTimeoutCepOperator extends + AbstractKeyedCEPPatternOperator> { + + private transient TimestampedCollector collector; + + private transient TimestampedSideOutputCollector sideOutputCollector; + + private OutputTag timedOutOutputTag; + + public FlatSelectTimeoutCepOperator( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator, + AfterMatchSkipStrategy skipStrategy, + PatternFlatSelectFunction flatSelectFunction, + PatternFlatTimeoutFunction flatTimeoutFunction, + OutputTag outputTag) { + super( + inputSerializer, + isProcessingTime, + nfaFactory, + comparator, + skipStrategy, + new FlatSelectWrapper<>(flatSelectFunction, flatTimeoutFunction)); + this.timedOutOutputTag = outputTag; + } + + @Override + public void open() throws Exception { + super.open(); + collector = new TimestampedCollector<>(output); + sideOutputCollector = new TimestampedSideOutputCollector<>(timedOutOutputTag, output); + } + + @Override + protected void processMatchedSequences( + Iterable>> matchingSequences, + long timestamp) throws Exception { + for (Map> match : matchingSequences) { + getUserFunction().getFlatSelectFunction().flatSelect(match, collector); + } + } + + @Override + protected void processTimedOutSequences( + Iterable>, Long>> timedOutSequences, long timestamp) throws Exception { + for (Tuple2>, Long> match : timedOutSequences) { + sideOutputCollector.setAbsoluteTimestamp(timestamp); + getUserFunction().getFlatTimeoutFunction().timeout(match.f0, match.f1, sideOutputCollector); + } + } + + /** + * Wrapper that enables storing {@link PatternFlatSelectFunction} and {@link PatternFlatTimeoutFunction} functions + * in one udf. + */ + @Internal + public static class FlatSelectWrapper implements Function { + + private static final long serialVersionUID = -8320546120157150202L; + + private PatternFlatSelectFunction flatSelectFunction; + private PatternFlatTimeoutFunction flatTimeoutFunction; + + @VisibleForTesting + public PatternFlatSelectFunction getFlatSelectFunction() { + return flatSelectFunction; + } + + @VisibleForTesting + public PatternFlatTimeoutFunction getFlatTimeoutFunction() { + return flatTimeoutFunction; + } + + public FlatSelectWrapper( + PatternFlatSelectFunction flatSelectFunction, + PatternFlatTimeoutFunction flatTimeoutFunction) { + this.flatSelectFunction = flatSelectFunction; + this.flatTimeoutFunction = flatTimeoutFunction; + } + } +} diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/KeyedCEPPatternOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/KeyedCEPPatternOperator.java deleted file mode 100644 index 22f9c14a33d48..0000000000000 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/KeyedCEPPatternOperator.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cep.operator; - -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.cep.EventComparator; -import org.apache.flink.cep.nfa.NFA; -import org.apache.flink.cep.nfa.compiler.NFACompiler; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; - -import java.util.Collection; -import java.util.Iterator; -import java.util.List; -import java.util.Map; - -/** - * CEP pattern operator which only returns fully matched event patterns stored in a {@link Map}. The - * events are indexed by the event names associated in the pattern specification. The operator works - * on keyed input data. - * - * @param Type of the input events - * @param Type of the key - */ -public class KeyedCEPPatternOperator extends AbstractKeyedCEPPatternOperator>> { - private static final long serialVersionUID = 5328573789532074581L; - - public KeyedCEPPatternOperator( - TypeSerializer inputSerializer, - boolean isProcessingTime, - TypeSerializer keySerializer, - NFACompiler.NFAFactory nfaFactory, - boolean migratingFromOldKeyedOperator, - EventComparator comparator) { - - super(inputSerializer, isProcessingTime, keySerializer, nfaFactory, migratingFromOldKeyedOperator, comparator); - } - - @Override - protected void processEvent(NFA nfa, IN event, long timestamp) { - Tuple2>>, Collection>, Long>>> patterns = - nfa.process(event, timestamp); - - emitMatchedSequences(patterns.f0, timestamp); - } - - @Override - protected void advanceTime(NFA nfa, long timestamp) { - Tuple2>>, Collection>, Long>>> patterns = - nfa.process(null, timestamp); - - emitMatchedSequences(patterns.f0, timestamp); - } - - private void emitMatchedSequences(Iterable>> matchedSequences, long timestamp) { - Iterator>> iterator = matchedSequences.iterator(); - - if (iterator.hasNext()) { - StreamRecord>> streamRecord = new StreamRecord<>(null, timestamp); - - do { - streamRecord.replace(iterator.next()); - output.collect(streamRecord); - } while (iterator.hasNext()); - } - } -} diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/SelectCepOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/SelectCepOperator.java new file mode 100644 index 0000000000000..cbb49e676f984 --- /dev/null +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/SelectCepOperator.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cep.operator; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.cep.EventComparator; +import org.apache.flink.cep.PatternSelectFunction; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; +import org.apache.flink.cep.nfa.compiler.NFACompiler; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import java.util.List; +import java.util.Map; + +/** + * Version of {@link AbstractKeyedCEPPatternOperator} that applies given {@link PatternSelectFunction} to fully matched event patterns. + * + * @param Type of the input elements + * @param Type of the key on which the input stream is keyed + * @param Type of the output elements + */ +public class SelectCepOperator + extends AbstractKeyedCEPPatternOperator> { + public SelectCepOperator( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator, + AfterMatchSkipStrategy skipStrategy, + PatternSelectFunction function) { + super(inputSerializer, isProcessingTime, nfaFactory, comparator, skipStrategy, function); + } + + @Override + protected void processMatchedSequences(Iterable>> matchingSequences, long timestamp) throws Exception { + for (Map> match : matchingSequences) { + output.collect(new StreamRecord<>(getUserFunction().select(match), timestamp)); + } + } +} diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/SelectTimeoutCepOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/SelectTimeoutCepOperator.java new file mode 100644 index 0000000000000..cb233a486ec5e --- /dev/null +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/SelectTimeoutCepOperator.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cep.operator; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.Function; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.cep.EventComparator; +import org.apache.flink.cep.PatternSelectFunction; +import org.apache.flink.cep.PatternTimeoutFunction; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; +import org.apache.flink.cep.nfa.compiler.NFACompiler; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.OutputTag; + +import java.util.List; +import java.util.Map; + +/** + * Version of {@link AbstractKeyedCEPPatternOperator} that applies given {@link PatternSelectFunction} to fully + * matched event patterns and {@link PatternTimeoutFunction} to timed out ones. The timed out elements are returned + * as a side-output. + * + * @param Type of the input elements + * @param Type of the key on which the input stream is keyed + * @param Type of the output elements + * @param Type of the timed out output elements + */ +public class SelectTimeoutCepOperator + extends AbstractKeyedCEPPatternOperator> { + + private OutputTag timedOutOutputTag; + + public SelectTimeoutCepOperator( + TypeSerializer inputSerializer, + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + final EventComparator comparator, + AfterMatchSkipStrategy skipStrategy, + PatternSelectFunction flatSelectFunction, + PatternTimeoutFunction flatTimeoutFunction, + OutputTag outputTag) { + super( + inputSerializer, + isProcessingTime, + nfaFactory, + comparator, + skipStrategy, + new SelectWrapper<>(flatSelectFunction, flatTimeoutFunction)); + this.timedOutOutputTag = outputTag; + } + + @Override + protected void processMatchedSequences(Iterable>> matchingSequences, long timestamp) throws Exception { + for (Map> match : matchingSequences) { + output.collect(new StreamRecord<>(getUserFunction().getFlatSelectFunction().select(match), timestamp)); + } + } + + @Override + protected void processTimedOutSequences( + Iterable>, Long>> timedOutSequences, long timestamp) throws Exception { + for (Tuple2>, Long> match : timedOutSequences) { + output.collect(timedOutOutputTag, + new StreamRecord<>( + getUserFunction().getFlatTimeoutFunction().timeout(match.f0, match.f1), + timestamp)); + } + } + + /** + * Wrapper that enables storing {@link PatternSelectFunction} and {@link PatternTimeoutFunction} in one udf. 
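Editorial sketch: the wrapped pair corresponds to the two callbacks above; the PatternSelectFunction produces the main-output element for each completed match, while the PatternTimeoutFunction turns a timed-out partial match and its timeout timestamp into the side-output element. A minimal example of such a pair follows, with the Event type, state names and message formats as illustrative assumptions.

import org.apache.flink.cep.PatternSelectFunction;
import org.apache.flink.cep.PatternTimeoutFunction;

import java.util.List;
import java.util.Map;

// Completed matches become a short description on the main output.
class MatchToAlert implements PatternSelectFunction<Event, String> {
    private static final long serialVersionUID = 1L;

    @Override
    public String select(Map<String, List<Event>> match) {
        return "matched, first 'start' event: " + match.get("start").get(0);
    }
}

// Partial matches that exceeded the within() window go to the side output,
// together with the timestamp at which they timed out.
class MatchToTimeoutAlert implements PatternTimeoutFunction<Event, String> {
    private static final long serialVersionUID = 1L;

    @Override
    public String timeout(Map<String, List<Event>> partialMatch, long timeoutTimestamp) {
        return "timed out at " + timeoutTimestamp + ", states seen: " + partialMatch.keySet();
    }
}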
+ * + * @param Type of the input elements + * @param Type of the output elements + * @param Type of the timed out output elements + */ + @Internal + public static class SelectWrapper implements Function { + + private static final long serialVersionUID = -8320546120157150202L; + + private PatternSelectFunction flatSelectFunction; + private PatternTimeoutFunction flatTimeoutFunction; + + PatternSelectFunction getFlatSelectFunction() { + return flatSelectFunction; + } + + PatternTimeoutFunction getFlatTimeoutFunction() { + return flatTimeoutFunction; + } + + public SelectWrapper( + PatternSelectFunction flatSelectFunction, + PatternTimeoutFunction flatTimeoutFunction) { + this.flatSelectFunction = flatSelectFunction; + this.flatTimeoutFunction = flatTimeoutFunction; + } + } +} diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/TimeoutKeyedCEPPatternOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/TimeoutKeyedCEPPatternOperator.java deleted file mode 100644 index ca58955fb2c2f..0000000000000 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/TimeoutKeyedCEPPatternOperator.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.cep.operator; - -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.cep.EventComparator; -import org.apache.flink.cep.nfa.NFA; -import org.apache.flink.cep.nfa.compiler.NFACompiler; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.types.Either; - -import java.util.Collection; -import java.util.List; -import java.util.Map; - -/** - * CEP pattern operator which returns fully and partially matched (timed-out) event patterns stored in a - * {@link Map}. The events are indexed by the event names associated in the pattern specification. The - * operator works on keyed input data. 
- * - * @param Type of the input events - * @param Type of the key - */ -public class TimeoutKeyedCEPPatternOperator extends AbstractKeyedCEPPatternOperator>, Long>, Map>>> { - private static final long serialVersionUID = 3570542177814518158L; - - public TimeoutKeyedCEPPatternOperator( - TypeSerializer inputSerializer, - boolean isProcessingTime, - TypeSerializer keySerializer, - NFACompiler.NFAFactory nfaFactory, - boolean migratingFromOldKeyedOperator, - EventComparator comparator) { - - super(inputSerializer, isProcessingTime, keySerializer, nfaFactory, migratingFromOldKeyedOperator, comparator); - } - - @Override - protected void processEvent(NFA nfa, IN event, long timestamp) { - Tuple2>>, Collection>, Long>>> patterns = - nfa.process(event, timestamp); - - emitMatchedSequences(patterns.f0, timestamp); - emitTimedOutSequences(patterns.f1, timestamp); - } - - @Override - protected void advanceTime(NFA nfa, long timestamp) { - Tuple2>>, Collection>, Long>>> patterns = - nfa.process(null, timestamp); - - emitMatchedSequences(patterns.f0, timestamp); - emitTimedOutSequences(patterns.f1, timestamp); - } - - private void emitTimedOutSequences(Iterable>, Long>> timedOutSequences, long timestamp) { - StreamRecord>, Long>, Map>>> streamRecord = - new StreamRecord<>(null, timestamp); - - for (Tuple2>, Long> partialPattern: timedOutSequences) { - streamRecord.replace(Either.Left(partialPattern)); - output.collect(streamRecord); - } - } - - protected void emitMatchedSequences(Iterable>> matchedSequences, long timestamp) { - StreamRecord>, Long>, Map>>> streamRecord = - new StreamRecord<>(null, timestamp); - - for (Map> matchedPattern : matchedSequences) { - streamRecord.replace(Either.Right(matchedPattern)); - output.collect(streamRecord); - } - } -} diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/TimestampedSideOutputCollector.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/TimestampedSideOutputCollector.java new file mode 100644 index 0000000000000..533654390125e --- /dev/null +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/TimestampedSideOutputCollector.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cep.operator; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +/** + * Wrapper around an {@link Output} for user functions that expect a {@link Collector}. 
+ * Before giving the {@link TimestampedSideOutputCollector} to a user function you must set
+ * the timestamp that should be attached to emitted elements. Most operators
+ * would set the timestamp of the incoming
+ * {@link org.apache.flink.streaming.runtime.streamrecord.StreamRecord} here.
+ *
+ * <p>
This version emits results into a SideOutput specified by given {@link OutputTag} + * + * @param The type of the elements that can be emitted. + */ +@Internal +public class TimestampedSideOutputCollector implements Collector { + + private final Output output; + + private final StreamRecord reuse; + + private final OutputTag outputTag; + + /** + * Creates a new {@link TimestampedSideOutputCollector} that wraps the given {@link Output} and collects + * results into sideoutput corresponding to {@link OutputTag}. + */ + public TimestampedSideOutputCollector(OutputTag outputTag, Output output) { + this.output = output; + this.outputTag = outputTag; + this.reuse = new StreamRecord(null); + } + + @Override + public void collect(T record) { + output.collect(outputTag, reuse.replace(record)); + } + + public void setTimestamp(StreamRecord timestampBase) { + if (timestampBase.hasTimestamp()) { + reuse.setTimestamp(timestampBase.getTimestamp()); + } else { + reuse.eraseTimestamp(); + } + } + + public void setAbsoluteTimestamp(long timestamp) { + reuse.setTimestamp(timestamp); + } + + public void eraseTimestamp() { + reuse.eraseTimestamp(); + } + + @Override + public void close() { + output.close(); + } +} diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/GroupPattern.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/GroupPattern.java index a20d37795c46e..fce408ce6e9ef 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/GroupPattern.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/GroupPattern.java @@ -18,6 +18,7 @@ package org.apache.flink.cep.pattern; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; import org.apache.flink.cep.pattern.conditions.IterativeCondition; /** @@ -31,16 +32,12 @@ public class GroupPattern extends Pattern { /** Group pattern representing the pattern definition of this group. 
*/ private final Pattern groupPattern; - GroupPattern(final Pattern previous, final Pattern groupPattern) { - super("GroupPattern", previous); - this.groupPattern = groupPattern; - } - GroupPattern( final Pattern previous, final Pattern groupPattern, - final Quantifier.ConsumingStrategy consumingStrategy) { - super("GroupPattern", previous, consumingStrategy); + final Quantifier.ConsumingStrategy consumingStrategy, + final AfterMatchSkipStrategy afterMatchSkipStrategy) { + super("GroupPattern", previous, consumingStrategy, afterMatchSkipStrategy); this.groupPattern = groupPattern; } diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/Pattern.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/Pattern.java index 2ffbc41a600eb..a276d9a5d3a46 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/Pattern.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/Pattern.java @@ -19,6 +19,7 @@ package org.apache.flink.cep.pattern; import org.apache.flink.api.java.ClosureCleaner; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; import org.apache.flink.cep.nfa.NFA; import org.apache.flink.cep.pattern.Quantifier.ConsumingStrategy; import org.apache.flink.cep.pattern.Quantifier.Times; @@ -70,18 +71,17 @@ public class Pattern { */ private Times times; - protected Pattern(final String name, final Pattern previous) { - this.name = name; - this.previous = previous; - } + private final AfterMatchSkipStrategy afterMatchSkipStrategy; protected Pattern( - final String name, - final Pattern previous, - final ConsumingStrategy consumingStrategy) { + final String name, + final Pattern previous, + final ConsumingStrategy consumingStrategy, + final AfterMatchSkipStrategy afterMatchSkipStrategy) { this.name = name; this.previous = previous; this.quantifier = Quantifier.one(consumingStrategy); + this.afterMatchSkipStrategy = afterMatchSkipStrategy; } public Pattern getPrevious() { @@ -121,7 +121,20 @@ public IterativeCondition getUntilCondition() { * @return The first pattern of a pattern sequence */ public static Pattern begin(final String name) { - return new Pattern(name, null); + return new Pattern<>(name, null, ConsumingStrategy.STRICT, AfterMatchSkipStrategy.noSkip()); + } + + /** + * Starts a new pattern sequence. The provided name is the one of the initial pattern + * of the new sequence. Furthermore, the base type of the event sequence is set. + * + * @param name The name of starting pattern of the new pattern sequence + * @param afterMatchSkipStrategy the {@link AfterMatchSkipStrategy.SkipStrategy} to use after each match. 
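Editorial sketch: a hedged usage example of this new begin overload, mirroring the AfterMatchSkipITCase tests added further down in this patch; the Event type, the state name and the condition are illustrative only.

import org.apache.flink.cep.nfa.AfterMatchSkipStrategy;
import org.apache.flink.cep.pattern.Pattern;
import org.apache.flink.cep.pattern.conditions.SimpleCondition;

public class SkipStrategySketch {

    // After a match of three "a" events, resume the search after the last event of that
    // match, so no event is shared between two matches (cf. testSkipPastLast below).
    public static Pattern<Event, ?> skipPastLast() {
        return Pattern.<Event>begin("start", AfterMatchSkipStrategy.skipPastLastEvent())
                .where(new SimpleCondition<Event>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public boolean filter(Event value) {
                        return value.getName().equals("a");
                    }
                })
                .times(3);
    }
}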
+ * @param Base type of the event pattern + * @return The first pattern of a pattern sequence + */ + public static Pattern begin(final String name, final AfterMatchSkipStrategy afterMatchSkipStrategy) { + return new Pattern(name, null, ConsumingStrategy.STRICT, afterMatchSkipStrategy); } /** @@ -241,7 +254,7 @@ public Pattern within(Time windowTime) { * @return A new pattern which is appended to this one */ public Pattern next(final String name) { - return new Pattern<>(name, this, ConsumingStrategy.STRICT); + return new Pattern<>(name, this, ConsumingStrategy.STRICT, afterMatchSkipStrategy); } /** @@ -258,7 +271,7 @@ public Pattern notNext(final String name) { "You can simulate such pattern with two independent patterns, one with and the other without " + "the optional part."); } - return new Pattern<>(name, this, ConsumingStrategy.NOT_NEXT); + return new Pattern<>(name, this, ConsumingStrategy.NOT_NEXT, afterMatchSkipStrategy); } /** @@ -270,7 +283,7 @@ public Pattern notNext(final String name) { * @return A new pattern which is appended to this one */ public Pattern followedBy(final String name) { - return new Pattern<>(name, this, ConsumingStrategy.SKIP_TILL_NEXT); + return new Pattern<>(name, this, ConsumingStrategy.SKIP_TILL_NEXT, afterMatchSkipStrategy); } /** @@ -289,7 +302,7 @@ public Pattern notFollowedBy(final String name) { "You can simulate such pattern with two independent patterns, one with and the other without " + "the optional part."); } - return new Pattern<>(name, this, ConsumingStrategy.NOT_FOLLOW); + return new Pattern<>(name, this, ConsumingStrategy.NOT_FOLLOW, afterMatchSkipStrategy); } /** @@ -301,7 +314,7 @@ public Pattern notFollowedBy(final String name) { * @return A new pattern which is appended to this one */ public Pattern followedByAny(final String name) { - return new Pattern<>(name, this, ConsumingStrategy.SKIP_TILL_ANY); + return new Pattern<>(name, this, ConsumingStrategy.SKIP_TILL_ANY, afterMatchSkipStrategy); } /** @@ -312,6 +325,7 @@ public Pattern followedByAny(final String name) { * @throws MalformedPatternException if the quantifier is not applicable to this pattern. */ public Pattern optional() { + checkIfPreviousPatternGreedy(); quantifier.optional(); return this; } @@ -326,13 +340,28 @@ public Pattern optional() { * {@code A1 A2 B} appears, this will generate patterns: * {@code A1 B} and {@code A1 A2 B}. See also {@link #allowCombinations()}. * - * @return The same pattern with a {@link Quantifier#oneOrMore(ConsumingStrategy)} quantifier applied. + * @return The same pattern with a {@link Quantifier#looping(ConsumingStrategy)} quantifier applied. * @throws MalformedPatternException if the quantifier is not applicable to this pattern. */ public Pattern oneOrMore() { checkIfNoNotPattern(); checkIfQuantifierApplied(); - this.quantifier = Quantifier.oneOrMore(quantifier.getConsumingStrategy()); + this.quantifier = Quantifier.looping(quantifier.getConsumingStrategy()); + this.times = Times.of(1); + return this; + } + + /** + * Specifies that this pattern is greedy. + * This means as many events as possible will be matched to this pattern. + * + * @return The same pattern with {@link Quantifier#greedy} set to true. + * @throws MalformedPatternException if the quantifier is not applicable to this pattern. 
+ */ + public Pattern greedy() { + checkIfNoNotPattern(); + checkIfNoGroupPattern(); + this.quantifier.greedy(); return this; } @@ -375,7 +404,23 @@ public Pattern times(int from, int to) { } /** - * Applicable only to {@link Quantifier#oneOrMore(ConsumingStrategy)} and + * Specifies that this pattern can occur the specified times at least. + * This means at least the specified times and at most infinite number of events can + * be matched to this pattern. + * + * @return The same pattern with a {@link Quantifier#looping(ConsumingStrategy)} quantifier applied. + * @throws MalformedPatternException if the quantifier is not applicable to this pattern. + */ + public Pattern timesOrMore(int times) { + checkIfNoNotPattern(); + checkIfQuantifierApplied(); + this.quantifier = Quantifier.looping(quantifier.getConsumingStrategy()); + this.times = Times.of(times); + return this; + } + + /** + * Applicable only to {@link Quantifier#looping(ConsumingStrategy)} and * {@link Quantifier#times(ConsumingStrategy)} patterns, this option allows more flexibility to the matching events. * *
+ * <p>
If {@code allowCombinations()} is not applied for a @@ -431,6 +476,19 @@ public Pattern consecutive() { return this; } + /** + * Starts a new pattern sequence. The provided pattern is the initial pattern + * of the new sequence. + * + * + * @param group the pattern to begin with + * @param afterMatchSkipStrategy the {@link AfterMatchSkipStrategy.SkipStrategy} to use after each match. + * @return The first pattern of a pattern sequence + */ + public static GroupPattern begin(final Pattern group, final AfterMatchSkipStrategy afterMatchSkipStrategy) { + return new GroupPattern<>(null, group, ConsumingStrategy.STRICT, afterMatchSkipStrategy); + } + /** * Starts a new pattern sequence. The provided pattern is the initial pattern * of the new sequence. @@ -439,7 +497,7 @@ public Pattern consecutive() { * @return the first pattern of a pattern sequence */ public static GroupPattern begin(Pattern group) { - return new GroupPattern<>(null, group); + return new GroupPattern<>(null, group, ConsumingStrategy.STRICT, AfterMatchSkipStrategy.noSkip()); } /** @@ -451,7 +509,7 @@ public static GroupPattern begin(Pattern group) { * @return A new pattern which is appended to this one */ public GroupPattern followedBy(Pattern group) { - return new GroupPattern<>(this, group, ConsumingStrategy.SKIP_TILL_NEXT); + return new GroupPattern<>(this, group, ConsumingStrategy.SKIP_TILL_NEXT, afterMatchSkipStrategy); } /** @@ -463,7 +521,7 @@ public GroupPattern followedBy(Pattern group) { * @return A new pattern which is appended to this one */ public GroupPattern followedByAny(Pattern group) { - return new GroupPattern<>(this, group, ConsumingStrategy.SKIP_TILL_ANY); + return new GroupPattern<>(this, group, ConsumingStrategy.SKIP_TILL_ANY, afterMatchSkipStrategy); } /** @@ -476,7 +534,7 @@ public GroupPattern followedByAny(Pattern group) { * @return A new pattern which is appended to this one */ public GroupPattern next(Pattern group) { - return new GroupPattern<>(this, group, ConsumingStrategy.STRICT); + return new GroupPattern<>(this, group, ConsumingStrategy.STRICT, afterMatchSkipStrategy); } private void checkIfNoNotPattern() { @@ -492,4 +550,23 @@ private void checkIfQuantifierApplied() { "Current quantifier is: " + quantifier); } } + + /** + * @return the pattern's {@link AfterMatchSkipStrategy.SkipStrategy} after match. 
+ */ + public AfterMatchSkipStrategy getAfterMatchSkipStrategy() { + return afterMatchSkipStrategy; + } + + private void checkIfNoGroupPattern() { + if (this instanceof GroupPattern) { + throw new MalformedPatternException("Option not applicable to group pattern"); + } + } + + private void checkIfPreviousPatternGreedy() { + if (previous != null && previous.getQuantifier().hasProperty(Quantifier.QuantifierProperty.GREEDY)) { + throw new MalformedPatternException("Optional pattern cannot be preceded by greedy pattern"); + } + } } diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/Quantifier.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/Quantifier.java index 9192a133dfbae..b55051d0609ad 100644 --- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/Quantifier.java +++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/pattern/Quantifier.java @@ -55,7 +55,7 @@ public static Quantifier one(final ConsumingStrategy consumingStrategy) { return new Quantifier(consumingStrategy, QuantifierProperty.SINGLE); } - public static Quantifier oneOrMore(final ConsumingStrategy consumingStrategy) { + public static Quantifier looping(final ConsumingStrategy consumingStrategy) { return new Quantifier(consumingStrategy, QuantifierProperty.LOOPING); } @@ -105,6 +105,15 @@ public void optional() { properties.add(Quantifier.QuantifierProperty.OPTIONAL); } + public void greedy() { + checkPattern(!(innerConsumingStrategy == ConsumingStrategy.SKIP_TILL_ANY), + "Option not applicable to FollowedByAny pattern"); + checkPattern(!hasProperty(Quantifier.QuantifierProperty.SINGLE), + "Option not applicable to singleton quantifier"); + + properties.add(QuantifierProperty.GREEDY); + } + @Override public boolean equals(Object o) { if (this == o) { @@ -130,7 +139,8 @@ public enum QuantifierProperty { SINGLE, LOOPING, TIMES, - OPTIONAL + OPTIONAL, + GREEDY } /** diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/AfterMatchSkipITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/AfterMatchSkipITCase.java new file mode 100644 index 0000000000000..f767d92459ebb --- /dev/null +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/AfterMatchSkipITCase.java @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
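Editorial sketch: to round off the quantifier changes above, here is a hedged example of the new greedy() and timesOrMore() options in use. It follows the shape of the GreedyITCase patterns below; the Event type, state names and the named(...) helper are illustrative, and the timesOrMore(2).greedy() combination is assumed to compose as the checks in greedy() suggest.

import org.apache.flink.cep.pattern.Pattern;
import org.apache.flink.cep.pattern.conditions.SimpleCondition;

public class GreedySketch {

    // "c a* d" with a greedy middle stage: as many "a" events as possible are folded into
    // the match, so only the longest version is emitted (cf. testGreedyZeroOrMore below).
    public static Pattern<Event, ?> greedyMiddle() {
        return Pattern.<Event>begin("start").where(named("c"))
                .followedBy("middle").where(named("a")).oneOrMore().optional().greedy()
                .followedBy("end").where(named("d"));
    }

    // At least two "a" events, matched greedily.
    public static Pattern<Event, ?> atLeastTwoGreedy() {
        return Pattern.<Event>begin("start").where(named("a")).timesOrMore(2).greedy();
    }

    private static SimpleCondition<Event> named(final String name) {
        return new SimpleCondition<Event>() {
            private static final long serialVersionUID = 1L;

            @Override
            public boolean filter(Event value) {
                return value.getName().equals(name);
            }
        };
    }
}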
+ */ + +package org.apache.flink.cep.nfa; + +import org.apache.flink.cep.Event; +import org.apache.flink.cep.nfa.compiler.NFACompiler; +import org.apache.flink.cep.pattern.Pattern; +import org.apache.flink.cep.pattern.conditions.SimpleCondition; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.TestLogger; + +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.cep.nfa.NFATestUtilities.compareMaps; +import static org.apache.flink.cep.nfa.NFATestUtilities.feedNFA; + +/** + * IT tests covering {@link AfterMatchSkipStrategy}. + */ +public class AfterMatchSkipITCase extends TestLogger{ + + @Test + public void testSkipToNext() { + List> streamEvents = new ArrayList<>(); + + Event a1 = new Event(1, "a", 0.0); + Event a2 = new Event(2, "a", 0.0); + Event a3 = new Event(3, "a", 0.0); + Event a4 = new Event(4, "a", 0.0); + Event a5 = new Event(5, "a", 0.0); + Event a6 = new Event(6, "a", 0.0); + + streamEvents.add(new StreamRecord(a1)); + streamEvents.add(new StreamRecord(a2)); + streamEvents.add(new StreamRecord(a3)); + streamEvents.add(new StreamRecord(a4)); + streamEvents.add(new StreamRecord(a5)); + streamEvents.add(new StreamRecord(a6)); + + Pattern pattern = Pattern.begin("start", AfterMatchSkipStrategy.noSkip()) + .where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).times(3); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(a1, a2, a3), + Lists.newArrayList(a2, a3, a4), + Lists.newArrayList(a3, a4, a5), + Lists.newArrayList(a4, a5, a6) + )); + } + + @Test + public void testSkipPastLast() { + List> streamEvents = new ArrayList<>(); + + Event a1 = new Event(1, "a", 0.0); + Event a2 = new Event(2, "a", 0.0); + Event a3 = new Event(3, "a", 0.0); + Event a4 = new Event(4, "a", 0.0); + Event a5 = new Event(5, "a", 0.0); + Event a6 = new Event(6, "a", 0.0); + + streamEvents.add(new StreamRecord(a1)); + streamEvents.add(new StreamRecord(a2)); + streamEvents.add(new StreamRecord(a3)); + streamEvents.add(new StreamRecord(a4)); + streamEvents.add(new StreamRecord(a5)); + streamEvents.add(new StreamRecord(a6)); + + Pattern pattern = Pattern.begin("start", AfterMatchSkipStrategy.skipPastLastEvent()) + .where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).times(3); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(a1, a2, a3), + Lists.newArrayList(a4, a5, a6) + )); + } + + @Test + public void testSkipToFirst() { + List> streamEvents = new ArrayList<>(); + + Event ab1 = new Event(1, "ab", 0.0); + Event ab2 = new Event(2, "ab", 0.0); + Event ab3 = new Event(3, "ab", 0.0); + Event ab4 = new Event(4, "ab", 0.0); + Event ab5 = new Event(5, "ab", 0.0); + Event ab6 = new Event(6, "ab", 0.0); + + streamEvents.add(new StreamRecord(ab1)); + streamEvents.add(new StreamRecord(ab2)); + streamEvents.add(new StreamRecord(ab3)); + streamEvents.add(new 
StreamRecord(ab4)); + streamEvents.add(new StreamRecord(ab5)); + streamEvents.add(new StreamRecord(ab6)); + + Pattern pattern = Pattern.begin("start", + AfterMatchSkipStrategy.skipToFirst("end")) + .where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("a"); + } + }).times(2).next("end").where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("b"); + } + }).times(2); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(ab1, ab2, ab3, ab4), + Lists.newArrayList(ab3, ab4, ab5, ab6) + )); + } + + @Test + public void testSkipToLast() { + List> streamEvents = new ArrayList<>(); + + Event ab1 = new Event(1, "ab", 0.0); + Event ab2 = new Event(2, "ab", 0.0); + Event ab3 = new Event(3, "ab", 0.0); + Event ab4 = new Event(4, "ab", 0.0); + Event ab5 = new Event(5, "ab", 0.0); + Event ab6 = new Event(6, "ab", 0.0); + Event ab7 = new Event(7, "ab", 0.0); + + streamEvents.add(new StreamRecord(ab1)); + streamEvents.add(new StreamRecord(ab2)); + streamEvents.add(new StreamRecord(ab3)); + streamEvents.add(new StreamRecord(ab4)); + streamEvents.add(new StreamRecord(ab5)); + streamEvents.add(new StreamRecord(ab6)); + streamEvents.add(new StreamRecord(ab7)); + + Pattern pattern = Pattern.begin("start", AfterMatchSkipStrategy.skipToLast("end")).where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("a"); + } + }).times(2).next("end").where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("b"); + } + }).times(2); + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(ab1, ab2, ab3, ab4), + Lists.newArrayList(ab4, ab5, ab6, ab7) + )); + } + + @Test + public void testSkipPastLast2() { + List> streamEvents = new ArrayList<>(); + + Event a1 = new Event(1, "a1", 0.0); + Event a2 = new Event(2, "a2", 0.0); + Event b1 = new Event(3, "b1", 0.0); + Event b2 = new Event(4, "b2", 0.0); + Event c1 = new Event(5, "c1", 0.0); + Event c2 = new Event(6, "c2", 0.0); + Event d1 = new Event(7, "d1", 0.0); + Event d2 = new Event(7, "d2", 0.0); + + streamEvents.add(new StreamRecord(a1)); + streamEvents.add(new StreamRecord(a2)); + streamEvents.add(new StreamRecord(b1)); + streamEvents.add(new StreamRecord(b2)); + streamEvents.add(new StreamRecord(c1)); + streamEvents.add(new StreamRecord(c2)); + streamEvents.add(new StreamRecord(d1)); + streamEvents.add(new StreamRecord(d2)); + + Pattern pattern = Pattern.begin("a", AfterMatchSkipStrategy.skipPastLastEvent()).where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("a"); + } + }).followedByAny("b").where( + new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("b"); + } + } + ).followedByAny("c").where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("c"); + 
} + }).followedByAny("d").where(new SimpleCondition() { + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("d"); + } + }); + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(a1, b1, c1, d1), + Lists.newArrayList(a1, b1, c2, d1), + Lists.newArrayList(a1, b2, c1, d1), + Lists.newArrayList(a1, b2, c2, d1), + Lists.newArrayList(a2, b1, c1, d1), + Lists.newArrayList(a2, b1, c2, d1), + Lists.newArrayList(a2, b2, c1, d1), + Lists.newArrayList(a2, b2, c2, d1) + )); + } + + @Test + public void testSkipPastLast3() { + List> streamEvents = new ArrayList<>(); + + Event a1 = new Event(1, "a1", 0.0); + Event c = new Event(2, "c", 0.0); + Event a2 = new Event(3, "a2", 0.0); + Event b2 = new Event(4, "b2", 0.0); + + streamEvents.add(new StreamRecord(a1)); + streamEvents.add(new StreamRecord(c)); + streamEvents.add(new StreamRecord(a2)); + streamEvents.add(new StreamRecord(b2)); + + Pattern pattern = Pattern.begin("a", AfterMatchSkipStrategy.skipPastLastEvent() + ).where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("a"); + } + }).next("b").where( + new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("b"); + } + } + ); + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(a2, b2) + )); + } + + @Test + public void testSkipToFirstWithOptionalMatch() { + List> streamEvents = new ArrayList<>(); + + Event ab1 = new Event(1, "ab1", 0.0); + Event c1 = new Event(2, "c1", 0.0); + Event ab2 = new Event(3, "ab2", 0.0); + Event c2 = new Event(4, "c2", 0.0); + + streamEvents.add(new StreamRecord(ab1)); + streamEvents.add(new StreamRecord(c1)); + streamEvents.add(new StreamRecord(ab2)); + streamEvents.add(new StreamRecord(c2)); + + Pattern pattern = Pattern.begin("x", AfterMatchSkipStrategy.skipToFirst("b") + ).where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("x"); + } + }).oneOrMore().optional().next("b").where( + new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("b"); + } + } + ).next("c").where(new SimpleCondition() { + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("c"); + } + }); + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(ab1, c1), + Lists.newArrayList(ab2, c2) + )); + } + + @Test + public void testSkipToFirstAtStartPosition() { + List> streamEvents = new ArrayList<>(); + + Event ab1 = new Event(1, "ab1", 0.0); + Event c1 = new Event(2, "c1", 0.0); + Event ab2 = new Event(3, "ab2", 0.0); + Event c2 = new Event(4, "c2", 0.0); + + streamEvents.add(new StreamRecord(ab1)); + streamEvents.add(new StreamRecord(c1)); + streamEvents.add(new StreamRecord(ab2)); + streamEvents.add(new StreamRecord(c2)); + + Pattern 
pattern = Pattern.begin("b", AfterMatchSkipStrategy.skipToFirst("b") + ).where( + new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("b"); + } + } + ).next("c").where(new SimpleCondition() { + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("c"); + } + }); + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(ab1, c1), + Lists.newArrayList(ab2, c2) + )); + } + + @Test + public void testSkipToFirstWithOneOrMore() { + List> streamEvents = new ArrayList<>(); + + Event a1 = new Event(1, "a1", 0.0); + Event b1 = new Event(2, "b1", 0.0); + Event a2 = new Event(3, "a2", 0.0); + Event b2 = new Event(4, "b2", 0.0); + Event b3 = new Event(5, "b3", 0.0); + Event a3 = new Event(3, "a3", 0.0); + Event b4 = new Event(4, "b4", 0.0); + + streamEvents.add(new StreamRecord(a1)); + streamEvents.add(new StreamRecord(b1)); + streamEvents.add(new StreamRecord(a2)); + streamEvents.add(new StreamRecord(b2)); + streamEvents.add(new StreamRecord(b3)); + streamEvents.add(new StreamRecord(a3)); + streamEvents.add(new StreamRecord(b4)); + + Pattern pattern = Pattern.begin("a", AfterMatchSkipStrategy.skipToFirst("b") + ).where( + new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("a"); + } + } + ).next("b").where(new SimpleCondition() { + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("b"); + } + }).oneOrMore().consecutive(); + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(a1, b1), + Lists.newArrayList(a2, b2), + Lists.newArrayList(a3, b4) + )); + } + + @Test + public void testSkipToLastWithOneOrMore() { + List> streamEvents = new ArrayList<>(); + + Event a1 = new Event(1, "a1", 0.0); + Event b1 = new Event(2, "b1", 0.0); + Event a2 = new Event(3, "a2", 0.0); + Event b2 = new Event(4, "b2", 0.0); + Event b3 = new Event(5, "b3", 0.0); + Event a3 = new Event(3, "a3", 0.0); + Event b4 = new Event(4, "b4", 0.0); + + streamEvents.add(new StreamRecord(a1)); + streamEvents.add(new StreamRecord(b1)); + streamEvents.add(new StreamRecord(a2)); + streamEvents.add(new StreamRecord(b2)); + streamEvents.add(new StreamRecord(b3)); + streamEvents.add(new StreamRecord(a3)); + streamEvents.add(new StreamRecord(b4)); + + Pattern pattern = Pattern.begin("a", AfterMatchSkipStrategy.skipToLast("b") + ).where( + new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("a"); + } + } + ).next("b").where(new SimpleCondition() { + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("b"); + } + }).oneOrMore().consecutive(); + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(streamEvents, nfa, pattern.getAfterMatchSkipStrategy()); + + compareMaps(resultingPatterns, Lists.newArrayList( + Lists.newArrayList(a1, b1), + Lists.newArrayList(a2, b2), + Lists.newArrayList(a3, b4) + )); + } +} diff --git 
a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/GreedyITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/GreedyITCase.java new file mode 100644 index 0000000000000..2c7f23c024f65 --- /dev/null +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/GreedyITCase.java @@ -0,0 +1,907 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cep.nfa; + +import org.apache.flink.cep.Event; +import org.apache.flink.cep.nfa.compiler.NFACompiler; +import org.apache.flink.cep.pattern.Pattern; +import org.apache.flink.cep.pattern.conditions.SimpleCondition; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.TestLogger; + +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.cep.nfa.NFATestUtilities.compareMaps; +import static org.apache.flink.cep.nfa.NFATestUtilities.feedNFA; + +/** + * IT tests covering {@link Pattern#greedy()}. 
+ */ +public class GreedyITCase extends TestLogger { + + @Test + public void testGreedyZeroOrMore() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event a3 = new Event(43, "a", 2.0); + Event d = new Event(44, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(a3, 4)); + inputEvents.add(new StreamRecord<>(d, 5)); + + // c a* d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, a3, d) + )); + } + + @Test + public void testGreedyZeroOrMoreInBetween() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event a3 = new Event(43, "a", 2.0); + Event d = new Event(44, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(new Event(1, "dummy", 1111), 2)); + inputEvents.add(new StreamRecord<>(a1, 3)); + inputEvents.add(new StreamRecord<>(new Event(1, "dummy", 1111), 4)); + inputEvents.add(new StreamRecord<>(a2, 5)); + inputEvents.add(new StreamRecord<>(new Event(1, "dummy", 1111), 6)); + inputEvents.add(new StreamRecord<>(a3, 7)); + inputEvents.add(new StreamRecord<>(d, 8)); + + // c a* d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, a3, d) + )); + } + + @Test + public void testGreedyZeroOrMoreWithDummyEventsAfterQuantifier() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 
= new Event(42, "a", 2.0); + Event d = new Event(44, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(new Event(43, "dummy", 2.0), 4)); + inputEvents.add(new StreamRecord<>(d, 5)); + + // c a* d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, d) + )); + } + + @Test + public void testGreedyZeroOrMoreWithDummyEventsBeforeQuantifier() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event d = new Event(44, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(new Event(43, "dummy", 2.0), 2)); + inputEvents.add(new StreamRecord<>(d, 5)); + + // c a* d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, d) + )); + } + + @Test + public void testGreedyUntilZeroOrMoreWithDummyEventsAfterQuantifier() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 3.0); + Event a3 = new Event(43, "a", 3.0); + Event d = new Event(45, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(a3, 4)); + inputEvents.add(new StreamRecord<>(new Event(44, "a", 4.0), 5)); + inputEvents.add(new StreamRecord<>(d, 6)); + + // c a* d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + 
return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().greedy().until(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getPrice() > 3.0; + } + }).followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, a3, d) + )); + } + + @Test + public void testGreedyUntilWithDummyEventsBeforeQuantifier() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 3.0); + Event a3 = new Event(43, "a", 3.0); + Event d = new Event(45, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(new Event(44, "a", 4.0), 2)); + inputEvents.add(new StreamRecord<>(a1, 3)); + inputEvents.add(new StreamRecord<>(a2, 4)); + inputEvents.add(new StreamRecord<>(a3, 5)); + inputEvents.add(new StreamRecord<>(d, 6)); + + // c a* d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().greedy().until(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getPrice() > 3.0; + } + }).followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, d) + )); + } + + @Test + public void testGreedyOneOrMore() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event a3 = new Event(43, "a", 2.0); + Event d = new Event(44, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(a3, 4)); + inputEvents.add(new StreamRecord<>(d, 5)); + + // c a+ d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) 
throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, a3, d) + )); + } + + @Test + public void testGreedyOneOrMoreInBetween() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event a3 = new Event(43, "a", 2.0); + Event d = new Event(44, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(new Event(1, "dummy", 1111), 2)); + inputEvents.add(new StreamRecord<>(a1, 3)); + inputEvents.add(new StreamRecord<>(new Event(1, "dummy", 1111), 4)); + inputEvents.add(new StreamRecord<>(a2, 5)); + inputEvents.add(new StreamRecord<>(new Event(1, "dummy", 1111), 6)); + inputEvents.add(new StreamRecord<>(a3, 7)); + inputEvents.add(new StreamRecord<>(d, 8)); + + // c a+ d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, a3, d) + )); + } + + @Test + public void testGreedyOneOrMoreWithDummyEventsAfterQuantifier() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event d = new Event(44, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(new Event(43, "dummy", 2.0), 4)); + inputEvents.add(new StreamRecord<>(d, 5)); + + // c a+ d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return 
value.getName().equals("a"); + } + }).oneOrMore().greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, d) + )); + } + + @Test + public void testGreedyOneOrMoreWithDummyEventsBeforeQuantifier() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event d = new Event(44, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(new Event(43, "dummy", 2.0), 2)); + inputEvents.add(new StreamRecord<>(d, 5)); + + // c a+ d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList()); + } + + @Test + public void testGreedyUntilOneOrMoreWithDummyEventsAfterQuantifier() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 3.0); + Event a3 = new Event(43, "a", 3.0); + Event d = new Event(45, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(a3, 4)); + inputEvents.add(new StreamRecord<>(new Event(44, "a", 4.0), 5)); + inputEvents.add(new StreamRecord<>(d, 6)); + + // c a+ d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().greedy().until(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getPrice() > 3.0; + } + }).followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final 
List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, a3, d) + )); + } + + @Test + public void testGreedyUntilOneOrMoreWithDummyEventsBeforeQuantifier() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 3.0); + Event a3 = new Event(43, "a", 3.0); + Event d = new Event(45, "d", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(new Event(44, "a", 4.0), 2)); + inputEvents.add(new StreamRecord<>(a1, 3)); + inputEvents.add(new StreamRecord<>(a2, 4)); + inputEvents.add(new StreamRecord<>(a3, 5)); + inputEvents.add(new StreamRecord<>(d, 6)); + + // c a+ d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().greedy().until(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getPrice() > 3.0; + } + }).followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList()); + } + + @Test + public void testGreedyZeroOrMoreBeforeGroupPattern() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(40, "a", 1.0); + Event a2 = new Event(40, "a", 1.0); + Event a3 = new Event(40, "a", 1.0); + Event d1 = new Event(40, "d", 1.0); + Event e1 = new Event(40, "e", 1.0); + Event d2 = new Event(40, "d", 1.0); + Event e2 = new Event(40, "e", 1.0); + Event f = new Event(44, "f", 3.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(new Event(43, "dummy", 2.0), 4)); + inputEvents.add(new StreamRecord<>(a3, 5)); + inputEvents.add(new StreamRecord<>(d1, 6)); + inputEvents.add(new StreamRecord<>(e1, 7)); + inputEvents.add(new StreamRecord<>(d2, 8)); + inputEvents.add(new StreamRecord<>(e2, 9)); + inputEvents.add(new StreamRecord<>(f, 10)); + + // c a* (d e){2} f + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().greedy().followedBy(Pattern.begin("middle1").where(new SimpleCondition() { + private static final long serialVersionUID = 
5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }).followedBy("middle2").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("e"); + } + })).times(2).followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("f"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, a3, d1, e1, d2, e2, f) + )); + } + + @Test + public void testEndWithZeroOrMoreGreedy() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event a3 = new Event(43, "a", 2.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(new Event(44, "dummy", 2.0), 4)); + inputEvents.add(new StreamRecord<>(a3, 5)); + + // c a* + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().greedy(); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c), + Lists.newArrayList(c, a1), + Lists.newArrayList(c, a1, a2), + Lists.newArrayList(c, a1, a2, a3) + )); + } + + @Test + public void testEndWithZeroOrMoreConsecutiveGreedy() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event a3 = new Event(43, "a", 2.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(new Event(44, "dummy", 2.0), 4)); + inputEvents.add(new StreamRecord<>(a3, 5)); + + // c a* + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).oneOrMore().optional().consecutive().greedy(); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + 
Lists.newArrayList(c), + Lists.newArrayList(c, a1), + Lists.newArrayList(c, a1, a2) + )); + } + + @Test + public void testEndWithGreedyTimesRange() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event a3 = new Event(43, "a", 2.0); + Event a4 = new Event(44, "a", 2.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(a3, 4)); + inputEvents.add(new StreamRecord<>(a4, 5)); + inputEvents.add(new StreamRecord<>(new Event(44, "dummy", 2.0), 6)); + + // c a{2, 5} + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).times(2, 5).greedy(); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2), + Lists.newArrayList(c, a1, a2, a3), + Lists.newArrayList(c, a1, a2, a3, a4) + )); + } + + @Test + public void testGreedyTimesRange() { + List> inputEvents = new ArrayList<>(); + + Event c = new Event(40, "c", 1.0); + Event a1 = new Event(41, "a", 2.0); + Event a2 = new Event(42, "a", 2.0); + Event a3 = new Event(43, "a", 2.0); + Event a4 = new Event(44, "a", 2.0); + Event d = new Event(45, "d", 2.0); + + inputEvents.add(new StreamRecord<>(c, 1)); + inputEvents.add(new StreamRecord<>(a1, 2)); + inputEvents.add(new StreamRecord<>(a2, 3)); + inputEvents.add(new StreamRecord<>(a3, 4)); + inputEvents.add(new StreamRecord<>(a4, 5)); + inputEvents.add(new StreamRecord<>(d, 6)); + + // c a{2, 5} d + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).times(2, 5).greedy().followedBy("end").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("d"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(c, a1, a2, a3, a4, d) + )); + } +} diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/GroupITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/GroupITCase.java index 226a0916ddf21..c2c7cdacf3eb8 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/GroupITCase.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/GroupITCase.java 
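Below is a minimal sketch of the "c a* d" greedy pattern that the GreedyITCase tests above construct, with the generic parameters written out. It assumes the same Flink CEP imports (Pattern, SimpleCondition, NFACompiler, Event) already used by the test classes in this diff; the variable names are illustrative and not part of the patch.

// Sketch of the greedy "c a* d" pattern exercised by GreedyITCase, assuming the
// Flink CEP API used throughout this patch.
Pattern<Event, ?> greedySketch = Pattern.<Event>begin("start")
	.where(new SimpleCondition<Event>() {
		@Override
		public boolean filter(Event value) {
			// match the leading "c" event
			return value.getName().equals("c");
		}
	})
	.followedBy("middle")
	.where(new SimpleCondition<Event>() {
		@Override
		public boolean filter(Event value) {
			// greedily consume as many consecutive "a" events as possible
			return value.getName().equals("a");
		}
	})
	.oneOrMore().optional().greedy()
	.followedBy("end")
	.where(new SimpleCondition<Event>() {
		@Override
		public boolean filter(Event value) {
			// close the match on the trailing "d" event
			return value.getName().equals("d");
		}
	});

// Compile the pattern into an NFA, exactly as the tests do, and feed it events
// via the NFATestUtilities helpers (feedNFA / compareMaps) shown above.
NFA<Event> greedyNfa = NFACompiler.compile(greedySketch, Event.createTypeSerializer(), false);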
@@ -26,7 +26,8 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.TestLogger; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Test; import java.util.ArrayList; diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/IterativeConditionsITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/IterativeConditionsITCase.java index 910907fa42f9b..80754b7680b64 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/IterativeConditionsITCase.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/IterativeConditionsITCase.java @@ -27,7 +27,8 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.TestLogger; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Test; import java.util.ArrayList; diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java index a83eb12874279..84278b15fd462 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java @@ -29,7 +29,8 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.TestLogger; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Assert; import org.junit.Test; @@ -2721,4 +2722,76 @@ public boolean filter(Event s) throws Exception { match.get("middle").toArray(), Lists.newArrayList(endEvent1, endEvent2, endEvent3).toArray()); } + + @Test + public void testNFAResultKeyOrdering() { + List> inputEvents = new ArrayList<>(); + + Event a1 = new Event(41, "a", 2.0); + Event b1 = new Event(41, "b", 3.0); + Event aa1 = new Event(41, "aa", 4.0); + Event bb1 = new Event(41, "bb", 5.0); + Event ab1 = new Event(41, "ab", 6.0); + + inputEvents.add(new StreamRecord<>(a1, 1)); + inputEvents.add(new StreamRecord<>(b1, 3)); + inputEvents.add(new StreamRecord<>(aa1, 4)); + inputEvents.add(new StreamRecord<>(bb1, 5)); + inputEvents.add(new StreamRecord<>(ab1, 6)); + + Pattern pattern = Pattern + .begin("a") + .where(new SimpleCondition() { + private static final long serialVersionUID = 6452194090480345053L; + + @Override + public boolean filter(Event s) throws Exception { + return s.getName().equals("a"); + } + }).next("b").where(new SimpleCondition() { + @Override + public boolean filter(Event s) throws Exception { + return s.getName().equals("b"); + } + }).next("aa").where(new SimpleCondition() { + @Override + public boolean filter(Event s) throws Exception { + return s.getName().equals("aa"); + } + }).next("bb").where(new SimpleCondition() { + @Override + public boolean filter(Event s) throws Exception { + return s.getName().equals("bb"); + } + }).next("ab").where(new SimpleCondition() { + @Override + public boolean filter(Event s) throws Exception { + return s.getName().equals("ab"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List>> resultingPatterns = new ArrayList<>(); + + for (StreamRecord inputEvent : inputEvents) { + Collection>> patterns = nfa.process( + inputEvent.getValue(), + inputEvent.getTimestamp()).f0; + + 
resultingPatterns.addAll(patterns); + } + + Assert.assertEquals(1L, resultingPatterns.size()); + + Map> match = resultingPatterns.get(0); + + List expectedOrder = Lists.newArrayList("a", "b", "aa", "bb", "ab"); + List resultOrder = new ArrayList<>(); + for (String key: match.keySet()) { + resultOrder.add(key); + } + + Assert.assertEquals(expectedOrder, resultOrder); + } } diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFATestUtilities.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFATestUtilities.java index 7bf0767755418..a9e17955ac8ba 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFATestUtilities.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFATestUtilities.java @@ -21,7 +21,6 @@ import org.apache.flink.cep.Event; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import com.google.common.primitives.Doubles; import org.junit.Assert; import java.util.ArrayList; @@ -37,12 +36,18 @@ public class NFATestUtilities { public static List> feedNFA(List> inputEvents, NFA nfa) { + return feedNFA(inputEvents, nfa, AfterMatchSkipStrategy.noSkip()); + } + + public static List> feedNFA(List> inputEvents, NFA nfa, + AfterMatchSkipStrategy afterMatchSkipStrategy) { List> resultingPatterns = new ArrayList<>(); for (StreamRecord inputEvent : inputEvents) { Collection>> patterns = nfa.process( inputEvent.getValue(), - inputEvent.getTimestamp()).f0; + inputEvent.getTimestamp(), + afterMatchSkipStrategy).f0; for (Map> p: patterns) { List res = new ArrayList<>(); @@ -96,7 +101,7 @@ private static class EventComparator implements Comparator { @Override public int compare(Event o1, Event o2) { int nameComp = o1.getName().compareTo(o2.getName()); - int priceComp = Doubles.compare(o1.getPrice(), o2.getPrice()); + int priceComp = Double.compare(o1.getPrice(), o2.getPrice()); int idComp = Integer.compare(o1.getId(), o2.getId()); if (nameComp == 0) { if (priceComp == 0) { diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NotPatternITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NotPatternITCase.java index 3b95eb4b6064b..9198ff8f07578 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NotPatternITCase.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NotPatternITCase.java @@ -25,7 +25,8 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.TestLogger; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Test; import java.util.ArrayList; diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/SameElementITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/SameElementITCase.java index 183cb6d13d714..357107fce9cc8 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/SameElementITCase.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/SameElementITCase.java @@ -26,8 +26,9 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.TestLogger; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Iterators; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Test; import 
java.util.ArrayList; diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/SharedBufferTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/SharedBufferTest.java index 3621bade5884e..dfbfa5fcd3f72 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/SharedBufferTest.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/SharedBufferTest.java @@ -33,6 +33,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -190,4 +191,49 @@ public void testClearingSharedBufferWithMultipleEdgesBetweenEntries() { //There should be still events[1] and events[2] in the buffer assertFalse(sharedBuffer.isEmpty()); } + + @Test + public void testSharedBufferExtractOrder() { + SharedBuffer sharedBuffer = new SharedBuffer<>(Event.createTypeSerializer()); + int numberEvents = 10; + Event[] events = new Event[numberEvents]; + final long timestamp = 1L; + + for (int i = 0; i < numberEvents; i++) { + events[i] = new Event(i + 1, "e" + (i + 1), i); + } + + Map> expectedResult = new LinkedHashMap<>(); + expectedResult.put("a", new ArrayList<>()); + expectedResult.get("a").add(events[1]); + expectedResult.put("b", new ArrayList<>()); + expectedResult.get("b").add(events[2]); + expectedResult.put("aa", new ArrayList<>()); + expectedResult.get("aa").add(events[3]); + expectedResult.put("bb", new ArrayList<>()); + expectedResult.get("bb").add(events[4]); + expectedResult.put("c", new ArrayList<>()); + expectedResult.get("c").add(events[5]); + + sharedBuffer.put("a", events[1], timestamp, DeweyNumber.fromString("1")); + sharedBuffer.put("b", events[2], timestamp, "a", events[1], timestamp, 0, DeweyNumber.fromString("1.0")); + sharedBuffer.put("aa", events[3], timestamp, "b", events[2], timestamp, 1, DeweyNumber.fromString("1.0.0")); + sharedBuffer.put("bb", events[4], timestamp, "aa", events[3], timestamp, 2, DeweyNumber.fromString("1.0.0.0")); + sharedBuffer.put("c", events[5], timestamp, "bb", events[4], timestamp, 3, DeweyNumber.fromString("1.0.0.0.0")); + + Collection>> patternsResult = sharedBuffer.extractPatterns("c", events[5], timestamp, 4, DeweyNumber.fromString("1.0.0.0.0")); + + List expectedOrder = new ArrayList<>(); + expectedOrder.add("a"); + expectedOrder.add("b"); + expectedOrder.add("aa"); + expectedOrder.add("bb"); + expectedOrder.add("c"); + + List resultOrder = new ArrayList<>(); + for (String key: patternsResult.iterator().next().keySet()){ + resultOrder.add(key); + } + assertEquals(expectedOrder, resultOrder); + } } diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/TimesOrMoreITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/TimesOrMoreITCase.java new file mode 100644 index 0000000000000..4e540ddb1151f --- /dev/null +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/TimesOrMoreITCase.java @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cep.nfa; + +import org.apache.flink.cep.Event; +import org.apache.flink.cep.nfa.compiler.NFACompiler; +import org.apache.flink.cep.pattern.Pattern; +import org.apache.flink.cep.pattern.conditions.SimpleCondition; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.TestLogger; + +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.cep.nfa.NFATestUtilities.compareMaps; +import static org.apache.flink.cep.nfa.NFATestUtilities.feedNFA; + +/** + * Tests for {@link Pattern#timesOrMore(int)}. + */ +public class TimesOrMoreITCase extends TestLogger { + @Test + public void testTimesOrMore() { + List> inputEvents = new ArrayList<>(); + + Event startEvent = new Event(40, "c", 1.0); + Event middleEvent1 = new Event(41, "a", 2.0); + Event middleEvent2 = new Event(42, "a", 3.0); + Event middleEvent3 = new Event(43, "a", 4.0); + Event end1 = new Event(44, "b", 5.0); + + inputEvents.add(new StreamRecord<>(startEvent, 1)); + inputEvents.add(new StreamRecord<>(middleEvent1, 2)); + inputEvents.add(new StreamRecord<>(middleEvent2, 3)); + inputEvents.add(new StreamRecord<>(middleEvent3, 4)); + inputEvents.add(new StreamRecord<>(end1, 6)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).allowCombinations().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + final List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(startEvent, middleEvent1, middleEvent2, middleEvent3, end1), + Lists.newArrayList(startEvent, middleEvent1, middleEvent2, end1), + Lists.newArrayList(startEvent, middleEvent1, middleEvent3, end1) + )); + } + + @Test + public void testTimesOrMoreNonStrict() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 2)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent1, 3)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 4)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 5)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new 
StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedByAny("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).allowCombinations().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreStrict() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 2)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent1, 3)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 4)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 5)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedByAny("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).times(2).consecutive().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreStrictOptional() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 2)); + inputEvents.add(new 
StreamRecord<>(ConsecutiveData.middleEvent1, 3)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 4)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 5)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedByAny("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).consecutive().optional().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreStrictOptional2() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent1, 3)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 5)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,}, b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).next("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).consecutive().optional().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreNonStrictOptional() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 
1.0), 2)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).optional().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreNonStrictOptional2() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 2)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent1, 3)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 4)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 5)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedByAny("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).allowCombinations().optional().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreNonStrictOptional3() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(new 
Event(23, "f", 1.0), 2)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent1, 3)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 4)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 5)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedByAny("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).optional().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreNonStrictWithNext() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent1, 2)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 3)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 4)); + inputEvents.add(new StreamRecord<>(new Event(23, "f", 1.0), 5)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).next("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).allowCombinations().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, 
ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent3, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreNotStrictWithFollowedBy() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent1, 2)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 4)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedBy("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.end) + )); + } + + @Test + public void testTimesOrMoreNotStrictWithFollowedByAny() { + List> inputEvents = new ArrayList<>(); + + inputEvents.add(new StreamRecord<>(ConsecutiveData.startEvent, 1)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent1, 2)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent2, 4)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.middleEvent3, 6)); + inputEvents.add(new StreamRecord<>(ConsecutiveData.end, 7)); + + // c a{2,} b + Pattern pattern = Pattern.begin("start").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("c"); + } + }).followedByAny("middle").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("a"); + } + }).timesOrMore(2).allowCombinations().followedBy("end1").where(new SimpleCondition() { + private static final long serialVersionUID = 5726188262756267490L; + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().equals("b"); + } + }); + + NFA nfa = NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + + List> resultingPatterns = feedNFA(inputEvents, nfa); + + compareMaps(resultingPatterns, Lists.>newArrayList( + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, 
ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent2, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent2, ConsecutiveData.middleEvent3, ConsecutiveData.end), + Lists.newArrayList(ConsecutiveData.startEvent, ConsecutiveData.middleEvent1, ConsecutiveData.middleEvent3, ConsecutiveData.end) + )); + } + + private static class ConsecutiveData { + private static final Event startEvent = new Event(40, "c", 1.0); + private static final Event middleEvent1 = new Event(41, "a", 2.0); + private static final Event middleEvent2 = new Event(42, "a", 3.0); + private static final Event middleEvent3 = new Event(43, "a", 4.0); + private static final Event end = new Event(44, "b", 5.0); + + private ConsecutiveData() { + } + } +} diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/TimesRangeITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/TimesRangeITCase.java index 37a953407cfd6..76ed26ab5965e 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/TimesRangeITCase.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/TimesRangeITCase.java @@ -25,7 +25,8 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.TestLogger; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Test; import java.util.ArrayList; diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/UntilConditionITCase.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/UntilConditionITCase.java index 639541d1da7ff..f88e5b21ab401 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/UntilConditionITCase.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/UntilConditionITCase.java @@ -25,7 +25,8 @@ import org.apache.flink.cep.pattern.conditions.SimpleCondition; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Test; import java.util.ArrayList; diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/compiler/NFACompilerTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/compiler/NFACompilerTest.java index 6d4329a873a3a..ec2cf47578cf6 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/compiler/NFACompilerTest.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/compiler/NFACompilerTest.java @@ -24,6 +24,7 @@ import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.cep.Event; import org.apache.flink.cep.SubEvent; +import org.apache.flink.cep.nfa.AfterMatchSkipStrategy; import org.apache.flink.cep.nfa.NFA; import org.apache.flink.cep.nfa.State; import org.apache.flink.cep.nfa.StateTransition; @@ -33,6 +34,8 @@ import org.apache.flink.cep.pattern.conditions.SimpleCondition; import org.apache.flink.util.TestLogger; +import org.apache.flink.shaded.guava18.com.google.common.collect.Sets; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -42,7 +45,6 @@ import java.util.Map; import java.util.Set; -import static com.google.common.collect.Sets.newHashSet; 
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -142,14 +144,14 @@ public void testNFACompilerWithSimplePattern() { State startState = stateMap.get("start"); assertTrue(startState.isStart()); final Set> startTransitions = unfoldTransitions(startState); - assertEquals(newHashSet( + assertEquals(Sets.newHashSet( Tuple2.of("middle", StateTransitionAction.TAKE) ), startTransitions); assertTrue(stateMap.containsKey("middle")); State middleState = stateMap.get("middle"); final Set> middleTransitions = unfoldTransitions(middleState); - assertEquals(newHashSet( + assertEquals(Sets.newHashSet( Tuple2.of("middle", StateTransitionAction.IGNORE), Tuple2.of("end", StateTransitionAction.TAKE) ), middleTransitions); @@ -157,7 +159,7 @@ public void testNFACompilerWithSimplePattern() { assertTrue(stateMap.containsKey("end")); State endState = stateMap.get("end"); final Set> endTransitions = unfoldTransitions(endState); - assertEquals(newHashSet( + assertEquals(Sets.newHashSet( Tuple2.of(NFACompiler.ENDING_STATE_NAME, StateTransitionAction.TAKE) ), endTransitions); @@ -187,6 +189,37 @@ public void testNoUnnecessaryStateCopiesCreated() { assertEquals(1, endStateCount); } + @Test + public void testSkipToNotExistsMatchingPattern() { + expectedException.expect(MalformedPatternException.class); + expectedException.expectMessage("The pattern name specified in AfterMatchSkipStrategy can not be found in the given Pattern"); + + Pattern invalidPattern = Pattern.begin("start", + AfterMatchSkipStrategy.skipToLast("midd")).where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("a"); + } + }).next("middle").where( + new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("d"); + } + } + ).oneOrMore().optional().next("end").where(new SimpleCondition() { + + @Override + public boolean filter(Event value) throws Exception { + return value.getName().contains("c"); + } + }); + + NFACompiler.compile(invalidPattern, Event.createTypeSerializer(), false); + } + private Set> unfoldTransitions(final State state) { final Set> transitions = new HashSet<>(); for (StateTransition transition : state.getStateTransitions()) { diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration11to13Test.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration11to13Test.java deleted file mode 100644 index 95987c2c225be..0000000000000 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration11to13Test.java +++ /dev/null @@ -1,385 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.cep.operator; - -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeutils.base.ByteSerializer; -import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.functions.NullByteKeySelector; -import org.apache.flink.cep.Event; -import org.apache.flink.cep.SubEvent; -import org.apache.flink.cep.nfa.NFA; -import org.apache.flink.cep.nfa.compiler.NFACompiler; -import org.apache.flink.cep.pattern.Pattern; -import org.apache.flink.cep.pattern.conditions.SimpleCondition; -import org.apache.flink.streaming.api.watermark.Watermark; -import org.apache.flink.streaming.api.windowing.time.Time; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; -import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; -import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; - -import org.junit.Test; - -import java.net.URL; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentLinkedQueue; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -/** - * Tests for migration from 1.1.x to 1.3.x. - */ -public class CEPMigration11to13Test { - - private static String getResourceFilename(String filename) { - ClassLoader cl = CEPMigration11to13Test.class.getClassLoader(); - URL resource = cl.getResource(filename); - if (resource == null) { - throw new NullPointerException("Missing snapshot resource."); - } - return resource.getFile(); - } - - @Test - public void testKeyedCEPOperatorMigratation() throws Exception { - - KeySelector keySelector = new KeySelector() { - private static final long serialVersionUID = -4873366487571254798L; - - @Override - public Integer getKey(Event value) throws Exception { - return value.getId(); - } - }; - - final Event startEvent = new Event(42, "start", 1.0); - final SubEvent middleEvent = new SubEvent(42, "foo", 1.0, 10.0); - final Event endEvent = new Event(42, "end", 1.0); - - // uncomment these lines for regenerating the snapshot on Flink 1.1 - /* - OneInputStreamOperatorTestHarness> harness = new OneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - keySelector, - IntSerializer.INSTANCE, - new NFAFactory())); - harness.configureForKeyedStream(keySelector, BasicTypeInfo.INT_TYPE_INFO); - harness.open(); - harness.processElement(new StreamRecord(startEvent, 1)); - harness.processElement(new StreamRecord(new Event(42, "foobar", 1.0), 2)); - harness.processElement(new StreamRecord(new SubEvent(42, "barfoo", 1.0, 5.0), 3)); - harness.processWatermark(new Watermark(2)); - - harness.processElement(new StreamRecord(middleEvent, 3)); - - // simulate snapshot/restore with empty element queue but NFA state - StreamTaskState snapshot = harness.snapshot(1, 1); - FileOutputStream out = new FileOutputStream( - "src/test/resources/cep-keyed-1_1-snapshot"); - ObjectOutputStream oos = new ObjectOutputStream(out); - oos.writeObject(snapshot); - out.close(); - harness.close(); - */ - - OneInputStreamOperatorTestHarness>> harness = - new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null), - keySelector, - BasicTypeInfo.INT_TYPE_INFO); - - try { - harness.setup(); 
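The deleted migration test here, and the remaining tests below, all follow the same harness snapshot/restore cycle. A condensed sketch of that cycle, assembled only from calls that appear in this diff; the operator factory and NFAFactory are the ones defined in the surrounding test classes.

// take a snapshot from a running harness, then restore it into a fresh harness instance
OperatorStateHandles snapshot = harness.snapshot(1L, 1L);
harness.close();

harness = new KeyedOneInputStreamOperatorTestHarness<>(
        getKeyedCepOpearator(false, new NFAFactory()),
        keySelector,
        BasicTypeInfo.INT_TYPE_INFO);
harness.setup();
harness.initializeState(snapshot);
harness.open();

// continue feeding elements and watermarks against the restored state
harness.processElement(new StreamRecord<>(endEvent, 25));
harness.processWatermark(new Watermark(50));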
- harness - .initializeStateFromLegacyCheckpoint(getResourceFilename("cep-keyed-1_1-snapshot")); - harness.open(); - - harness.processElement(new StreamRecord<>(new Event(42, "start", 1.0), 4)); - harness.processElement(new StreamRecord<>(endEvent, 5)); - - harness.processWatermark(new Watermark(20)); - - ConcurrentLinkedQueue result = harness.getOutput(); - - // watermark and the result - assertEquals(2, result.size()); - - Object resultObject = result.poll(); - assertTrue(resultObject instanceof StreamRecord); - StreamRecord resultRecord = (StreamRecord) resultObject; - assertTrue(resultRecord.getValue() instanceof Map); - - @SuppressWarnings("unchecked") - Map> patternMap = - (Map>) resultRecord.getValue(); - - assertEquals(startEvent, patternMap.get("start").get(0)); - assertEquals(middleEvent, patternMap.get("middle").get(0)); - assertEquals(endEvent, patternMap.get("end").get(0)); - - // and now go for a checkpoint with the new serializers - - final Event startEvent1 = new Event(42, "start", 2.0); - final SubEvent middleEvent1 = new SubEvent(42, "foo", 1.0, 11.0); - final Event endEvent1 = new Event(42, "end", 2.0); - - harness.processElement(new StreamRecord(startEvent1, 21)); - harness.processElement(new StreamRecord(middleEvent1, 23)); - - // simulate snapshot/restore with some elements in internal sorting queue - OperatorStateHandles snapshot = harness.snapshot(1L, 1L); - harness.close(); - - harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null), - keySelector, - BasicTypeInfo.INT_TYPE_INFO); - - harness.setup(); - harness.initializeState(snapshot); - harness.open(); - - harness.processElement(new StreamRecord<>(endEvent1, 25)); - - harness.processWatermark(new Watermark(50)); - - result = harness.getOutput(); - - // watermark and the result - assertEquals(2, result.size()); - - Object resultObject1 = result.poll(); - assertTrue(resultObject1 instanceof StreamRecord); - StreamRecord resultRecord1 = (StreamRecord) resultObject1; - assertTrue(resultRecord1.getValue() instanceof Map); - - @SuppressWarnings("unchecked") - Map> patternMap1 = - (Map>) resultRecord1.getValue(); - - assertEquals(startEvent1, patternMap1.get("start").get(0)); - assertEquals(middleEvent1, patternMap1.get("middle").get(0)); - assertEquals(endEvent1, patternMap1.get("end").get(0)); - } finally { - harness.close(); - } - } - - @Test - public void testNonKeyedCEPFunctionMigration() throws Exception { - - final Event startEvent = new Event(42, "start", 1.0); - final SubEvent middleEvent = new SubEvent(42, "foo", 1.0, 10.0); - final Event endEvent = new Event(42, "end", 1.0); - - // uncomment these lines for regenerating the snapshot on Flink 1.1 - /* - OneInputStreamOperatorTestHarness> harness = new OneInputStreamOperatorTestHarness<>( - new CEPPatternOperator<>( - Event.createTypeSerializer(), - false, - new NFAFactory())); - harness.open(); - harness.processElement(new StreamRecord(startEvent, 1)); - harness.processElement(new StreamRecord(new Event(42, "foobar", 1.0), 2)); - harness.processElement(new StreamRecord(new SubEvent(42, "barfoo", 1.0, 5.0), 3)); - harness.processWatermark(new Watermark(2)); - - harness.processElement(new StreamRecord(middleEvent, 3)); - - // simulate snapshot/restore with empty element queue but NFA state - StreamTaskState snapshot = harness.snapshot(1, 1); - FileOutputStream out = new FileOutputStream( - 
"src/test/resources/cep-non-keyed-1.1-snapshot"); - ObjectOutputStream oos = new ObjectOutputStream(out); - oos.writeObject(snapshot); - out.close(); - harness.close(); - */ - - NullByteKeySelector keySelector = new NullByteKeySelector(); - - OneInputStreamOperatorTestHarness>> harness = - new KeyedOneInputStreamOperatorTestHarness>>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - ByteSerializer.INSTANCE, - new NFAFactory(), - false, - null), - keySelector, - BasicTypeInfo.BYTE_TYPE_INFO); - - try { - harness.setup(); - harness.initializeStateFromLegacyCheckpoint( - getResourceFilename("cep-non-keyed-1.1-snapshot")); - harness.open(); - - harness.processElement(new StreamRecord<>(new Event(42, "start", 1.0), 4)); - harness.processElement(new StreamRecord<>(endEvent, 5)); - - harness.processWatermark(new Watermark(20)); - - ConcurrentLinkedQueue result = harness.getOutput(); - - // watermark and the result - assertEquals(2, result.size()); - - Object resultObject = result.poll(); - assertTrue(resultObject instanceof StreamRecord); - StreamRecord resultRecord = (StreamRecord) resultObject; - assertTrue(resultRecord.getValue() instanceof Map); - - @SuppressWarnings("unchecked") - Map> patternMap = - (Map>) resultRecord.getValue(); - - assertEquals(startEvent, patternMap.get("start").get(0)); - assertEquals(middleEvent, patternMap.get("middle").get(0)); - assertEquals(endEvent, patternMap.get("end").get(0)); - - // and now go for a checkpoint with the new serializers - - final Event startEvent1 = new Event(42, "start", 2.0); - final SubEvent middleEvent1 = new SubEvent(42, "foo", 1.0, 11.0); - final Event endEvent1 = new Event(42, "end", 2.0); - - harness.processElement(new StreamRecord(startEvent1, 21)); - harness.processElement(new StreamRecord(middleEvent1, 23)); - - // simulate snapshot/restore with some elements in internal sorting queue - OperatorStateHandles snapshot = harness.snapshot(1L, 1L); - harness.close(); - - harness = new KeyedOneInputStreamOperatorTestHarness>>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - ByteSerializer.INSTANCE, - new NFAFactory(), - false, - null), - keySelector, - BasicTypeInfo.BYTE_TYPE_INFO); - - harness.setup(); - harness.initializeState(snapshot); - harness.open(); - - harness.processElement(new StreamRecord<>(endEvent1, 25)); - - harness.processWatermark(new Watermark(50)); - - result = harness.getOutput(); - - // watermark and the result - assertEquals(2, result.size()); - - Object resultObject1 = result.poll(); - assertTrue(resultObject1 instanceof StreamRecord); - StreamRecord resultRecord1 = (StreamRecord) resultObject1; - assertTrue(resultRecord1.getValue() instanceof Map); - - @SuppressWarnings("unchecked") - Map> patternMap1 = - (Map>) resultRecord1.getValue(); - - assertEquals(startEvent1, patternMap1.get("start").get(0)); - assertEquals(middleEvent1, patternMap1.get("middle").get(0)); - assertEquals(endEvent1, patternMap1.get("end").get(0)); - } finally { - harness.close(); - } - } - - private static class NFAFactory implements NFACompiler.NFAFactory { - - private static final long serialVersionUID = 1173020762472766713L; - - private final boolean handleTimeout; - - private NFAFactory() { - this(false); - } - - private NFAFactory(boolean handleTimeout) { - this.handleTimeout = handleTimeout; - } - - @Override - public NFA createNFA() { - - Pattern pattern = Pattern.begin("start").where(new StartFilter()) - .followedBy("middle").subtype(SubEvent.class).where(new MiddleFilter()) - 
.followedBy("end").where(new EndFilter()) - // add a window timeout to test whether timestamps of elements in the - // priority queue in CEP operator are correctly checkpointed/restored - .within(Time.milliseconds(10L)); - - return NFACompiler.compile(pattern, Event.createTypeSerializer(), handleTimeout); - } - } - - private static class StartFilter extends SimpleCondition { - private static final long serialVersionUID = 5726188262756267490L; - - @Override - public boolean filter(Event value) throws Exception { - return value.getName().equals("start"); - } - } - - private static class MiddleFilter extends SimpleCondition { - private static final long serialVersionUID = 6215754202506583964L; - - @Override - public boolean filter(SubEvent value) throws Exception { - return value.getVolume() > 5.0; - } - } - - private static class EndFilter extends SimpleCondition { - private static final long serialVersionUID = 7056763917392056548L; - - @Override - public boolean filter(Event value) throws Exception { - return value.getName().equals("end"); - } - } -} diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigrationTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigrationTest.java index 0eeff09cbee9c..ed28f254a43f6 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigrationTest.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigrationTest.java @@ -19,7 +19,6 @@ package org.apache.flink.cep.operator; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.cep.Event; import org.apache.flink.cep.SubEvent; @@ -48,6 +47,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentLinkedQueue; +import static org.apache.flink.cep.operator.CepOperatorTestUtilities.getKeyedCepOpearator; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -71,7 +71,7 @@ public class CEPMigrationTest { @Parameterized.Parameters(name = "Migration Savepoint: {0}") public static Collection parameters () { - return Arrays.asList(MigrationVersion.v1_2, MigrationVersion.v1_3); + return Arrays.asList(MigrationVersion.v1_3); } public CEPMigrationTest(MigrationVersion migrateVersion) { @@ -100,13 +100,7 @@ public Integer getKey(Event value) throws Exception { OneInputStreamOperatorTestHarness>> harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null), + getKeyedCepOpearator(false, new NFAFactory()), keySelector, BasicTypeInfo.INT_TYPE_INFO); @@ -151,13 +145,7 @@ public Integer getKey(Event value) throws Exception { OneInputStreamOperatorTestHarness>> harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null), + getKeyedCepOpearator(false, new NFAFactory()), keySelector, BasicTypeInfo.INT_TYPE_INFO); @@ -221,13 +209,7 @@ public Integer getKey(Event value) throws Exception { harness.close(); harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null), + getKeyedCepOpearator(false, new NFAFactory()), keySelector, 
BasicTypeInfo.INT_TYPE_INFO); @@ -282,13 +264,7 @@ public Integer getKey(Event value) throws Exception { OneInputStreamOperatorTestHarness>> harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null), + getKeyedCepOpearator(false, new NFAFactory()), keySelector, BasicTypeInfo.INT_TYPE_INFO); @@ -331,13 +307,7 @@ public Integer getKey(Event value) throws Exception { OneInputStreamOperatorTestHarness>> harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null), + getKeyedCepOpearator(false, new NFAFactory()), keySelector, BasicTypeInfo.INT_TYPE_INFO); @@ -415,13 +385,7 @@ public Integer getKey(Event value) throws Exception { harness.close(); harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null), + getKeyedCepOpearator(false, new NFAFactory()), keySelector, BasicTypeInfo.INT_TYPE_INFO); @@ -475,13 +439,7 @@ public Integer getKey(Event value) throws Exception { OneInputStreamOperatorTestHarness>> harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new SinglePatternNFAFactory(), - true, - null), + getKeyedCepOpearator(false, new SinglePatternNFAFactory()), keySelector, BasicTypeInfo.INT_TYPE_INFO); @@ -515,13 +473,7 @@ public Integer getKey(Event value) throws Exception { OneInputStreamOperatorTestHarness>> harness = new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new SinglePatternNFAFactory(), - true, - null), + getKeyedCepOpearator(false, new SinglePatternNFAFactory()), keySelector, BasicTypeInfo.INT_TYPE_INFO); diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java index 46ad7bed2b46b..9eb60de6f0981 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java @@ -21,11 +21,12 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.cep.Event; +import org.apache.flink.cep.PatternSelectFunction; +import org.apache.flink.cep.PatternTimeoutFunction; import org.apache.flink.cep.SubEvent; import org.apache.flink.cep.nfa.NFA; import org.apache.flink.cep.nfa.compiler.NFACompiler; @@ -40,10 +41,11 @@ import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; -import org.apache.flink.types.Either; +import org.apache.flink.util.OutputTag; import org.apache.flink.util.TestLogger; -import 
com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.After; import org.junit.Assert; import org.junit.Rule; @@ -65,7 +67,7 @@ import static org.mockito.Mockito.validateMockitoUsage; /** - * Tests for {@link KeyedCEPPatternOperator} and {@link TimeoutKeyedCEPPatternOperator}. + * Tests for {@link AbstractKeyedCEPPatternOperator}. */ public class CEPOperatorTest extends TestLogger { @@ -239,8 +241,6 @@ public void testKeyedCEPOperatorCheckpointingWithRocksDB() throws Exception { */ @Test public void testKeyedAdvancingTimeWithoutElements() throws Exception { - final KeySelector keySelector = new TestKeySelector(); - final Event startEvent = new Event(42, "start", 1.0); final long watermarkTimestamp1 = 5L; final long watermarkTimestamp2 = 13L; @@ -248,21 +248,42 @@ public void testKeyedAdvancingTimeWithoutElements() throws Exception { final Map> expectedSequence = new HashMap<>(2); expectedSequence.put("start", Collections.singletonList(startEvent)); - OneInputStreamOperatorTestHarness>, Long>, Map>>> harness = new KeyedOneInputStreamOperatorTestHarness<>( - new TimeoutKeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFAFactory(true), - true, - null), - keySelector, - BasicTypeInfo.INT_TYPE_INFO); + final OutputTag>, Long>> timedOut = + new OutputTag>, Long>>("timedOut") {}; + final KeyedOneInputStreamOperatorTestHarness>> harness = + new KeyedOneInputStreamOperatorTestHarness<>( + new SelectTimeoutCepOperator<>( + Event.createTypeSerializer(), + false, + new NFAFactory(true), + null, + null, + new PatternSelectFunction>>() { + @Override + public Map> select(Map> pattern) throws Exception { + return pattern; + } + }, + new PatternTimeoutFunction>, Long>>() { + @Override + public Tuple2>, Long> timeout( + Map> pattern, + long timeoutTimestamp) throws Exception { + return Tuple2.of(pattern, timeoutTimestamp); + } + }, + timedOut + ), new KeySelector() { + @Override + public Integer getKey(Event value) throws Exception { + return value.getId(); + } + }, BasicTypeInfo.INT_TYPE_INFO); try { harness.setup( new KryoSerializer<>( - (Class>, Long>, Map>>>) (Object) Either.class, + (Class>>) (Object) Map.class, new ExecutionConfig())); harness.open(); @@ -271,8 +292,10 @@ public void testKeyedAdvancingTimeWithoutElements() throws Exception { harness.processWatermark(new Watermark(watermarkTimestamp2)); Queue result = harness.getOutput(); + Queue>, Long>>> sideOutput = harness.getSideOutput(timedOut); - assertEquals(3L, result.size()); + assertEquals(2L, result.size()); + assertEquals(1L, sideOutput.size()); Object watermark1 = result.poll(); @@ -280,19 +303,7 @@ public void testKeyedAdvancingTimeWithoutElements() throws Exception { assertEquals(watermarkTimestamp1, ((Watermark) watermark1).getTimestamp()); - Object resultObject = result.poll(); - - assertTrue(resultObject instanceof StreamRecord); - - StreamRecord>, Long>, Map>>> streamRecord = - (StreamRecord>, Long>, Map>>>) resultObject; - - assertTrue(streamRecord.getValue() instanceof Either.Left); - - Either.Left>, Long>, Map>> left = - (Either.Left>, Long>, Map>>) streamRecord.getValue(); - - Tuple2>, Long> leftResult = left.left(); + Tuple2>, Long> leftResult = sideOutput.poll().getValue(); assertEquals(watermarkTimestamp2, (long) leftResult.f1); assertEquals(expectedSequence, leftResult.f0); @@ -309,14 +320,12 @@ public void testKeyedAdvancingTimeWithoutElements() throws Exception { @Test public void 
testKeyedCEPOperatorNFAUpdate() throws Exception { - KeyedCEPPatternOperator operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - true, - IntSerializer.INSTANCE, - new SimpleNFAFactory(), + + SelectCepOperator>> operator = CepOperatorTestUtilities.getKeyedCepOpearator( true, - null); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + new SimpleNFAFactory()); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness( + operator); try { harness.open(); @@ -331,14 +340,8 @@ public void testKeyedCEPOperatorNFAUpdate() throws Exception { OperatorStateHandles snapshot = harness.snapshot(0L, 0L); harness.close(); - operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - true, - IntSerializer.INSTANCE, - new SimpleNFAFactory(), - true, - null); - harness = getCepTestHarness(operator); + operator = CepOperatorTestUtilities.getKeyedCepOpearator(true, new SimpleNFAFactory()); + harness = CepOperatorTestUtilities.getCepTestHarness(operator); harness.setup(); harness.initializeState(snapshot); @@ -348,14 +351,8 @@ public void testKeyedCEPOperatorNFAUpdate() throws Exception { OperatorStateHandles snapshot2 = harness.snapshot(0L, 0L); harness.close(); - operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - true, - IntSerializer.INSTANCE, - new SimpleNFAFactory(), - true, - null); - harness = getCepTestHarness(operator); + operator = CepOperatorTestUtilities.getKeyedCepOpearator(true, new SimpleNFAFactory()); + harness = CepOperatorTestUtilities.getCepTestHarness(operator); harness.setup(); harness.initializeState(snapshot2); @@ -383,14 +380,11 @@ public void testKeyedCEPOperatorNFAUpdateWithRocksDB() throws Exception { RocksDBStateBackend rocksDBStateBackend = new RocksDBStateBackend(new MemoryStateBackend()); rocksDBStateBackend.setDbStoragePath(rocksDbPath); - KeyedCEPPatternOperator operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - true, - IntSerializer.INSTANCE, - new SimpleNFAFactory(), + SelectCepOperator>> operator = CepOperatorTestUtilities.getKeyedCepOpearator( true, - null); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + new SimpleNFAFactory()); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness( + operator); try { harness.setStateBackend(rocksDBStateBackend); @@ -407,14 +401,8 @@ public void testKeyedCEPOperatorNFAUpdateWithRocksDB() throws Exception { OperatorStateHandles snapshot = harness.snapshot(0L, 0L); harness.close(); - operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - true, - IntSerializer.INSTANCE, - new SimpleNFAFactory(), - true, - null); - harness = getCepTestHarness(operator); + operator = CepOperatorTestUtilities.getKeyedCepOpearator(true, new SimpleNFAFactory()); + harness = CepOperatorTestUtilities.getCepTestHarness(operator); rocksDBStateBackend = new RocksDBStateBackend(new MemoryStateBackend()); rocksDBStateBackend.setDbStoragePath(rocksDbPath); @@ -427,14 +415,8 @@ public void testKeyedCEPOperatorNFAUpdateWithRocksDB() throws Exception { OperatorStateHandles snapshot2 = harness.snapshot(0L, 0L); harness.close(); - operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - true, - IntSerializer.INSTANCE, - new SimpleNFAFactory(), - true, - null); - harness = getCepTestHarness(operator); + operator = CepOperatorTestUtilities.getKeyedCepOpearator(true, new SimpleNFAFactory()); + harness = 
CepOperatorTestUtilities.getCepTestHarness(operator); rocksDBStateBackend = new RocksDBStateBackend(new MemoryStateBackend()); rocksDBStateBackend.setDbStoragePath(rocksDbPath); @@ -460,14 +442,10 @@ public void testKeyedCEPOperatorNFAUpdateWithRocksDB() throws Exception { @Test public void testKeyedCEPOperatorNFAUpdateTimes() throws Exception { - KeyedCEPPatternOperator operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - true, - IntSerializer.INSTANCE, - new SimpleNFAFactory(), + SelectCepOperator>> operator = CepOperatorTestUtilities.getKeyedCepOpearator( true, - null); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + new SimpleNFAFactory()); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness(operator); try { harness.open(); @@ -506,14 +484,11 @@ public void testKeyedCEPOperatorNFAUpdateTimesWithRocksDB() throws Exception { RocksDBStateBackend rocksDBStateBackend = new RocksDBStateBackend(new MemoryStateBackend()); rocksDBStateBackend.setDbStoragePath(rocksDbPath); - KeyedCEPPatternOperator operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), + SelectCepOperator>> operator = CepOperatorTestUtilities.getKeyedCepOpearator( true, - IntSerializer.INSTANCE, - new SimpleNFAFactory(), - true, - null); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + new SimpleNFAFactory()); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness( + operator); try { harness.setStateBackend(rocksDBStateBackend); @@ -560,8 +535,8 @@ public void testCEPOperatorCleanupEventTime() throws Exception { Event startEventK2 = new Event(43, "start", 1.0); - KeyedCEPPatternOperator operator = getKeyedCepOperator(false); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + SelectCepOperator>> operator = getKeyedCepOperator(false); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness(operator); try { harness.open(); @@ -605,8 +580,8 @@ public void testCEPOperatorCleanupEventTime() throws Exception { OperatorStateHandles snapshot = harness.snapshot(0L, 0L); harness.close(); - KeyedCEPPatternOperator operator2 = getKeyedCepOperator(false); - harness = getCepTestHarness(operator2); + SelectCepOperator>> operator2 = getKeyedCepOperator(false); + harness = CepOperatorTestUtilities.getCepTestHarness(operator2); harness.setup(); harness.initializeState(snapshot); harness.open(); @@ -656,14 +631,10 @@ public void testCEPOperatorCleanupEventTimeWithSameElements() throws Exception { Event middle1Event3 = new Event(41, "a", 4.0); Event middle2Event1 = new Event(41, "b", 5.0); - KeyedCEPPatternOperator operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new ComplexNFAFactory(), - true, - null); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + SelectCepOperator>> operator = CepOperatorTestUtilities.getKeyedCepOpearator( + false, + new ComplexNFAFactory()); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness(operator); try { harness.open(); @@ -755,8 +726,8 @@ public void testCEPOperatorCleanupProcessingTime() throws Exception { Event startEventK2 = new Event(43, "start", 1.0); - KeyedCEPPatternOperator operator = getKeyedCepOperator(true); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + SelectCepOperator>> operator = getKeyedCepOperator(true); 
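The CEPOperatorTest changes above drop the Either-typed output: matched sequences stay on the main output while timed-out partial matches move to a side output. A short sketch of reading both from the test harness, limited to the OutputTag and harness calls shown in this diff.

// the timed-out partial matches are delivered through an OutputTag side output
final OutputTag<Tuple2<Map<String, List<Event>>, Long>> timedOut =
        new OutputTag<Tuple2<Map<String, List<Event>>, Long>>("timedOut") {};

// matched sequences (and watermarks) arrive on the main output ...
Queue<Object> mainOutput = harness.getOutput();

// ... while timed-out partial matches arrive on the side output
Queue<StreamRecord<Tuple2<Map<String, List<Event>>, Long>>> timeouts = harness.getSideOutput(timedOut);
Tuple2<Map<String, List<Event>>, Long> firstTimeout = timeouts.poll().getValue();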
+ OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness(operator); try { harness.open(); @@ -783,8 +754,8 @@ public void testCEPOperatorCleanupProcessingTime() throws Exception { OperatorStateHandles snapshot = harness.snapshot(0L, 0L); harness.close(); - KeyedCEPPatternOperator operator2 = getKeyedCepOperator(true); - harness = getCepTestHarness(operator2); + SelectCepOperator>> operator2 = getKeyedCepOperator(true); + harness = CepOperatorTestUtilities.getCepTestHarness(operator2); harness.setup(); harness.initializeState(snapshot); harness.open(); @@ -875,23 +846,18 @@ public boolean filter(Event value) throws Exception { } }); - KeyedCEPPatternOperator operator = new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - false, - IntSerializer.INSTANCE, - new NFACompiler.NFAFactory() { - - private static final long serialVersionUID = 477082663248051994L; + SelectCepOperator>> operator = CepOperatorTestUtilities.getKeyedCepOpearator( + false, + new NFACompiler.NFAFactory() { + private static final long serialVersionUID = 477082663248051994L; - @Override - public NFA createNFA() { - return NFACompiler.compile(pattern, Event.createTypeSerializer(), false); - } - }, - true, - null); + @Override + public NFA createNFA() { + return NFACompiler.compile(pattern, Event.createTypeSerializer(), false); + } + }); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness(operator); try { harness.setStateBackend(rocksDBStateBackend); @@ -951,8 +917,8 @@ public void testCEPOperatorComparatorProcessTime() throws Exception { Event startEventK2 = new Event(43, "start", 1.0); - KeyedCEPPatternOperator operator = getKeyedCepOperatorWithComparator(true); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + SelectCepOperator>> operator = getKeyedCepOperatorWithComparator(true); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness(operator); try { harness.open(); @@ -979,8 +945,8 @@ public void testCEPOperatorComparatorProcessTime() throws Exception { OperatorStateHandles snapshot = harness.snapshot(0L, 0L); harness.close(); - KeyedCEPPatternOperator operator2 = getKeyedCepOperatorWithComparator(true); - harness = getCepTestHarness(operator2); + SelectCepOperator>> operator2 = getKeyedCepOperatorWithComparator(true); + harness = CepOperatorTestUtilities.getCepTestHarness(operator2); harness.setup(); harness.initializeState(snapshot); harness.open(); @@ -1008,8 +974,8 @@ public void testCEPOperatorComparatorEventTime() throws Exception { Event startEventK2 = new Event(43, "start", 1.0); - KeyedCEPPatternOperator operator = getKeyedCepOperatorWithComparator(false); - OneInputStreamOperatorTestHarness>> harness = getCepTestHarness(operator); + SelectCepOperator>> operator = getKeyedCepOperatorWithComparator(false); + OneInputStreamOperatorTestHarness>> harness = CepOperatorTestUtilities.getCepTestHarness(operator); try { harness.open(); @@ -1040,8 +1006,8 @@ public void testCEPOperatorComparatorEventTime() throws Exception { OperatorStateHandles snapshot = harness.snapshot(0L, 0L); harness.close(); - KeyedCEPPatternOperator operator2 = getKeyedCepOperatorWithComparator(false); - harness = getCepTestHarness(operator2); + SelectCepOperator>> operator2 = getKeyedCepOperatorWithComparator(false); + harness = CepOperatorTestUtilities.getCepTestHarness(operator2); harness.setup(); 
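Several hunks above and below replace hand-built KeyedCEPPatternOperator instances with the shared factory in CepOperatorTestUtilities, added later in this diff. The intended usage, with generic parameters reconstructed from context:

// build the operator and its keyed test harness via the shared utilities
SelectCepOperator<Event, Integer, Map<String, List<Event>>> operator =
        CepOperatorTestUtilities.getKeyedCepOpearator(true, new SimpleNFAFactory());

OneInputStreamOperatorTestHarness<Event, Map<String, List<Event>>> harness =
        CepOperatorTestUtilities.getCepTestHarness(operator);

harness.open();
// ... process elements, snapshot, restore, and assert, as in the tests above ...
harness.close();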
harness.initializeState(snapshot); harness.open(); @@ -1077,52 +1043,20 @@ private void verifyPattern(Object outputObject, Event start, SubEvent middle, Ev assertEquals(end, patternMap.get("end").get(0)); } - private OneInputStreamOperatorTestHarness>> getCepTestHarness(boolean isProcessingTime) throws Exception { - KeySelector keySelector = new TestKeySelector(); - - return new KeyedOneInputStreamOperatorTestHarness<>( - getKeyedCepOperator(isProcessingTime), - keySelector, - BasicTypeInfo.INT_TYPE_INFO); - } - - private OneInputStreamOperatorTestHarness>> getCepTestHarness( - KeyedCEPPatternOperator cepOperator) throws Exception { - KeySelector keySelector = new TestKeySelector(); - - return new KeyedOneInputStreamOperatorTestHarness<>( - cepOperator, - keySelector, - BasicTypeInfo.INT_TYPE_INFO); - } - - private KeyedCEPPatternOperator getKeyedCepOperator( + private SelectCepOperator>> getKeyedCepOperator( boolean isProcessingTime) { - - return new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - isProcessingTime, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - null); + return CepOperatorTestUtilities.getKeyedCepOpearator(isProcessingTime, new NFAFactory()); } - private KeyedCEPPatternOperator getKeyedCepOperatorWithComparator( + private SelectCepOperator>> getKeyedCepOperatorWithComparator( boolean isProcessingTime) { - return new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), - isProcessingTime, - IntSerializer.INSTANCE, - new NFAFactory(), - true, - new org.apache.flink.cep.EventComparator() { - @Override - public int compare(Event o1, Event o2) { - return Double.compare(o1.getPrice(), o2.getPrice()); - } - }); + return CepOperatorTestUtilities.getKeyedCepOpearator(isProcessingTime, new NFAFactory(), new org.apache.flink.cep.EventComparator() { + @Override + public int compare(Event o1, Event o2) { + return Double.compare(o1.getPrice(), o2.getPrice()); + } + }); } private void compareMaps(List> actual, List> expected) { @@ -1180,14 +1114,12 @@ public int compare(Event o1, Event o2) { } } - private static class TestKeySelector implements KeySelector { - - private static final long serialVersionUID = -4873366487571254798L; + private OneInputStreamOperatorTestHarness>> getCepTestHarness(boolean isProcessingTime) throws Exception { + return CepOperatorTestUtilities.getCepTestHarness(getKeyedCepOpearator(isProcessingTime)); + } - @Override - public Integer getKey(Event value) throws Exception { - return value.getId(); - } + private SelectCepOperator>> getKeyedCepOpearator(boolean isProcessingTime) { + return CepOperatorTestUtilities.getKeyedCepOpearator(isProcessingTime, new CEPOperatorTest.NFAFactory()); } private static class NFAFactory implements NFACompiler.NFAFactory { diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPRescalingTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPRescalingTest.java index 9fc43371d14ab..f5236c1715d97 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPRescalingTest.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPRescalingTest.java @@ -18,7 +18,6 @@ package org.apache.flink.cep.operator; -import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.cep.Event; @@ -42,6 +41,7 @@ import java.util.Map; import java.util.Queue; +import static 
org.apache.flink.cep.operator.CepOperatorTestUtilities.getKeyedCepOpearator; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -372,13 +372,9 @@ private KeyedOneInputStreamOperatorTestHarness keySelector = new TestKeySelector(); return new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedCEPPatternOperator<>( - Event.createTypeSerializer(), + getKeyedCepOpearator( false, - BasicTypeInfo.INT_TYPE_INFO.createSerializer(new ExecutionConfig()), - new NFAFactory(), - true, - null), + new NFAFactory()), keySelector, BasicTypeInfo.INT_TYPE_INFO, maxParallelism, diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CepOperatorTestUtilities.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CepOperatorTestUtilities.java new file mode 100644 index 0000000000000..feb020ac7b45a --- /dev/null +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CepOperatorTestUtilities.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cep.operator; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.cep.Event; +import org.apache.flink.cep.EventComparator; +import org.apache.flink.cep.PatternSelectFunction; +import org.apache.flink.cep.nfa.compiler.NFACompiler; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; + +import java.util.List; +import java.util.Map; + +/** + * Utility methods for creating test {@link AbstractKeyedCEPPatternOperator}. 
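The three-argument overload of getKeyedCepOpearator, declared in the new utility class around this point, additionally accepts an EventComparator so tests can control the secondary ordering of buffered events before they reach the NFA. A sketch reusing the price-based comparator that appears in this diff; the NFAFactory is the inner class defined in CEPOperatorTest.

SelectCepOperator<Event, Integer, Map<String, List<Event>>> operator =
        CepOperatorTestUtilities.getKeyedCepOpearator(
                false,
                new NFAFactory(),
                new org.apache.flink.cep.EventComparator<Event>() {
                    @Override
                    public int compare(Event o1, Event o2) {
                        // secondary ordering by price
                        return Double.compare(o1.getPrice(), o2.getPrice());
                    }
                });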
+ */ +public class CepOperatorTestUtilities { + + private static class TestKeySelector implements KeySelector { + + private static final long serialVersionUID = -4873366487571254798L; + + @Override + public Integer getKey(Event value) throws Exception { + return value.getId(); + } + } + + public static OneInputStreamOperatorTestHarness>> getCepTestHarness( + SelectCepOperator>> cepOperator) throws Exception { + KeySelector keySelector = new TestKeySelector(); + + return new KeyedOneInputStreamOperatorTestHarness<>( + cepOperator, + keySelector, + BasicTypeInfo.INT_TYPE_INFO); + } + + public static SelectCepOperator>> getKeyedCepOpearator( + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory) { + + return getKeyedCepOpearator(isProcessingTime, nfaFactory, null); + } + + public static SelectCepOperator>> getKeyedCepOpearator( + boolean isProcessingTime, + NFACompiler.NFAFactory nfaFactory, + EventComparator comparator) { + return new SelectCepOperator<>( + Event.createTypeSerializer(), + isProcessingTime, + nfaFactory, + comparator, + null, + new PatternSelectFunction>>() { + @Override + public Map> select(Map> pattern) throws Exception { + return pattern; + } + }); + } + + private CepOperatorTestUtilities() { + } +} diff --git a/flink-libraries/flink-table/pom.xml b/flink-libraries/flink-table/pom.xml index 0e943adf7679d..73629135148c6 100644 --- a/flink-libraries/flink-table/pom.xml +++ b/flink-libraries/flink-table/pom.xml @@ -218,8 +218,9 @@ under the License. org.apache.calcite:* org.apache.calcite.avatica:* net.hydromatic:* - org.reflections:* + org.reflections:* org.codehaus.janino:* + com.google.guava:guava diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala index 78667a2e9cac3..a9d60ddfbf820 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala @@ -153,28 +153,22 @@ abstract class BatchTableEnvironment( physicalTypeInfo: TypeInformation[IN], schema: RowSchema, requestedTypeInfo: TypeInformation[OUT], - functionName: String): - Option[MapFunction[IN, OUT]] = { - - if (requestedTypeInfo.getTypeClass == classOf[Row]) { - // Row to Row, no conversion needed - None - } else { - // some type that is neither Row or CRow - - val converterFunction = generateRowConverterFunction[OUT]( - physicalTypeInfo.asInstanceOf[TypeInformation[Row]], - schema, - requestedTypeInfo, - functionName - ) - - val mapFunction = new MapRunner[IN, OUT]( - converterFunction.name, - converterFunction.code, - converterFunction.returnType) + functionName: String) + : Option[MapFunction[IN, OUT]] = { + + val converterFunction = generateRowConverterFunction[OUT]( + physicalTypeInfo.asInstanceOf[TypeInformation[Row]], + schema, + requestedTypeInfo, + functionName + ) - Some(mapFunction) + // add a runner if we need conversion + converterFunction.map { func => + new MapRunner[IN, OUT]( + func.name, + func.code, + func.returnType) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala index 7328b2adc5f4f..8d8cebb0c471f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala @@ -23,35 +23,33 @@ import _root_.java.util.concurrent.atomic.AtomicInteger import org.apache.calcite.plan.RelOptUtil import org.apache.calcite.plan.hep.HepMatchOrder -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.{RelNode, RelVisitor} -import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode} -import org.apache.calcite.sql.SqlKind +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField, RelDataTypeFieldImpl, RelRecordType} import org.apache.calcite.sql2rel.RelDecorrelator import org.apache.calcite.tools.{RuleSet, RuleSets} import org.apache.flink.api.common.functions.MapFunction -import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} +import org.apache.flink.api.common.typeinfo.{AtomicType, SqlTimeTypeInfo, TypeInformation} import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} -import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo} +import org.apache.flink.api.java.typeutils.{PojoTypeInfo, RowTypeInfo, TupleTypeInfo} import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment -import org.apache.flink.table.calcite.RelTimeIndicatorConverter +import org.apache.flink.table.calcite.{FlinkTypeFactory, RelTimeIndicatorConverter} import org.apache.flink.table.explain.PlanJsonParser import org.apache.flink.table.expressions._ import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.datastream.{DataStreamRel, UpdateAsRetractionTrait, _} +import org.apache.flink.table.plan.nodes.datastream.{DataStreamRel, UpdateAsRetractionTrait} import org.apache.flink.table.plan.rules.FlinkRuleSets import org.apache.flink.table.plan.schema.{DataStreamTable, RowSchema, StreamTableSourceTable} import org.apache.flink.table.plan.util.UpdatingPlanChecker +import org.apache.flink.table.runtime.conversion._ import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} -import org.apache.flink.table.runtime.{CRowInputJavaTupleOutputMapRunner, CRowInputMapRunner, CRowInputScalaTupleOutputMapRunner} +import org.apache.flink.table.runtime.{CRowMapRunner, OutputRowtimeProcessFunction} import org.apache.flink.table.sinks.{AppendStreamTableSink, RetractStreamTableSink, TableSink, UpsertStreamTableSink} import org.apache.flink.table.sources.{DefinedRowtimeAttribute, StreamTableSource, TableSource} -import org.apache.flink.table.typeutils.TypeCheckUtils -import org.apache.flink.types.Row +import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TypeCheckUtils} import _root_.scala.collection.JavaConverters._ @@ -225,38 +223,33 @@ abstract class StreamTableEnvironment( /** * Creates a final converter that maps the internal row type to external type. * - * @param physicalTypeInfo the input of the sink + * @param inputTypeInfo the input of the sink * @param schema the input schema with correct field names (esp. for POJO field mapping) * @param requestedTypeInfo the output type of the sink * @param functionName name of the map function. Must not be unique but has to be a * valid Java class identifier. 
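These conversion mappers back the public Table-to-DataStream conversions. A hedged usage sketch: to the best of my knowledge toAppendStream and toRetractStream are the corresponding API methods in this release, but treat the method names as an assumption; tableEnv and table stand for an existing StreamTableEnvironment and Table.

// requested type Row: no converter function is generated, the internal CRow is only unwrapped
DataStream<Row> appendStream = tableEnv.toAppendStream(table, Row.class);

// with change flag: each record is emitted as (isAccumulate, row) so retractions are visible
DataStream<Tuple2<Boolean, Row>> retractStream = tableEnv.toRetractStream(table, Row.class);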
*/ - protected def getConversionMapper[IN, OUT]( - physicalTypeInfo: TypeInformation[IN], + protected def getConversionMapper[OUT]( + inputTypeInfo: TypeInformation[CRow], schema: RowSchema, requestedTypeInfo: TypeInformation[OUT], - functionName: String): - MapFunction[IN, OUT] = { - - if (requestedTypeInfo.getTypeClass == classOf[Row]) { - // CRow to Row, only needs to be unwrapped - new MapFunction[CRow, Row] { - override def map(value: CRow): Row = value.row - }.asInstanceOf[MapFunction[IN, OUT]] - } else { - // Some type that is neither CRow nor Row - val converterFunction = generateRowConverterFunction[OUT]( - physicalTypeInfo.asInstanceOf[CRowTypeInfo].rowType, - schema, - requestedTypeInfo, - functionName - ) + functionName: String) + : MapFunction[CRow, OUT] = { + + val converterFunction = generateRowConverterFunction[OUT]( + inputTypeInfo.asInstanceOf[CRowTypeInfo].rowType, + schema, + requestedTypeInfo, + functionName + ) + + converterFunction match { + + case Some(func) => + new CRowMapRunner[OUT](func.name, func.code, func.returnType) - new CRowInputMapRunner[OUT]( - converterFunction.name, - converterFunction.code, - converterFunction.returnType) - .asInstanceOf[MapFunction[IN, OUT]] + case _ => + new CRowToRowMapFunction().asInstanceOf[MapFunction[CRow, OUT]] } } @@ -270,74 +263,65 @@ abstract class StreamTableEnvironment( * valid Java class identifier. */ private def getConversionMapperWithChanges[OUT]( - physicalTypeInfo: TypeInformation[CRow], - schema: RowSchema, - requestedTypeInfo: TypeInformation[OUT], - functionName: String): - MapFunction[CRow, OUT] = { - - requestedTypeInfo match { - - // Scala tuple - case t: CaseClassTypeInfo[_] - if t.getTypeClass == classOf[(_, _)] && t.getTypeAt(0) == Types.BOOLEAN => - - val reqType = t.getTypeAt(1).asInstanceOf[TypeInformation[Any]] - if (reqType.getTypeClass == classOf[Row]) { - // Requested type is Row. Just rewrap CRow in Tuple2 - new MapFunction[CRow, (Boolean, Row)] { - override def map(cRow: CRow): (Boolean, Row) = { - (cRow.change, cRow.row) - } - }.asInstanceOf[MapFunction[CRow, OUT]] - } else { - // Use a map function to convert Row into requested type and wrap result in Tuple2 - val converterFunction = generateRowConverterFunction( - physicalTypeInfo.asInstanceOf[CRowTypeInfo].rowType, - schema, - reqType, - functionName - ) - - new CRowInputScalaTupleOutputMapRunner( - converterFunction.name, - converterFunction.code, - requestedTypeInfo.asInstanceOf[TypeInformation[(Boolean, Any)]]) - .asInstanceOf[MapFunction[CRow, OUT]] + physicalTypeInfo: TypeInformation[CRow], + schema: RowSchema, + requestedTypeInfo: TypeInformation[OUT], + functionName: String) + : MapFunction[CRow, OUT] = requestedTypeInfo match { - } + // Scala tuple + case t: CaseClassTypeInfo[_] + if t.getTypeClass == classOf[(_, _)] && t.getTypeAt(0) == Types.BOOLEAN => - // Java tuple - case t: TupleTypeInfo[_] - if t.getTypeClass == classOf[JTuple2[_, _]] && t.getTypeAt(0) == Types.BOOLEAN => - - val reqType = t.getTypeAt(1).asInstanceOf[TypeInformation[Any]] - if (reqType.getTypeClass == classOf[Row]) { - // Requested type is Row. 
Just rewrap CRow in Tuple2 - new MapFunction[CRow, JTuple2[JBool, Row]] { - val outT = new JTuple2(true.asInstanceOf[JBool], null.asInstanceOf[Row]) - override def map(cRow: CRow): JTuple2[JBool, Row] = { - outT.f0 = cRow.change - outT.f1 = cRow.row - outT - } - }.asInstanceOf[MapFunction[CRow, OUT]] - } else { - // Use a map function to convert Row into requested type and wrap result in Tuple2 - val converterFunction = generateRowConverterFunction( - physicalTypeInfo.asInstanceOf[CRowTypeInfo].rowType, - schema, - reqType, - functionName - ) - - new CRowInputJavaTupleOutputMapRunner( - converterFunction.name, - converterFunction.code, - requestedTypeInfo.asInstanceOf[TypeInformation[JTuple2[JBool, Any]]]) - .asInstanceOf[MapFunction[CRow, OUT]] - } - } + val reqType = t.getTypeAt[Any](1) + + // convert Row into requested type and wrap result in Tuple2 + val converterFunction = generateRowConverterFunction( + physicalTypeInfo.asInstanceOf[CRowTypeInfo].rowType, + schema, + reqType, + functionName + ) + + converterFunction match { + + case Some(func) => + new CRowToScalaTupleMapRunner( + func.name, + func.code, + requestedTypeInfo.asInstanceOf[TypeInformation[(Boolean, Any)]] + ).asInstanceOf[MapFunction[CRow, OUT]] + + case _ => + new CRowToScalaTupleMapFunction().asInstanceOf[MapFunction[CRow, OUT]] + } + + // Java tuple + case t: TupleTypeInfo[_] + if t.getTypeClass == classOf[JTuple2[_, _]] && t.getTypeAt(0) == Types.BOOLEAN => + + val reqType = t.getTypeAt[Any](1) + + // convert Row into requested type and wrap result in Tuple2 + val converterFunction = generateRowConverterFunction( + physicalTypeInfo.asInstanceOf[CRowTypeInfo].rowType, + schema, + reqType, + functionName + ) + + converterFunction match { + + case Some(func) => + new CRowToJavaTupleMapRunner( + func.name, + func.code, + requestedTypeInfo.asInstanceOf[TypeInformation[JTuple2[JBool, Any]]] + ).asInstanceOf[MapFunction[CRow, OUT]] + + case _ => + new CRowToJavaTupleMapFunction().asInstanceOf[MapFunction[CRow, OUT]] + } } /** @@ -356,9 +340,7 @@ abstract class StreamTableEnvironment( val dataStreamTable = new DataStreamTable[T]( dataStream, fieldIndexes, - fieldNames, - None, - None + fieldNames ) registerTableInternal(name, dataStreamTable) } @@ -393,12 +375,14 @@ abstract class StreamTableEnvironment( s"But is: ${execEnv.getStreamTimeCharacteristic}") } + // adjust field indexes and field names + val indexesWithIndicatorFields = adjustFieldIndexes(fieldIndexes, rowtime, proctime) + val namesWithIndicatorFields = adjustFieldNames(fieldNames, rowtime, proctime) + val dataStreamTable = new DataStreamTable[T]( dataStream, - fieldIndexes, - fieldNames, - rowtime, - proctime + indexesWithIndicatorFields, + namesWithIndicatorFields ) registerTableInternal(name, dataStreamTable) } @@ -501,6 +485,63 @@ abstract class StreamTableEnvironment( (rowtime, proctime) } + /** + * Injects markers for time indicator fields into the field indexes. + * + * @param fieldIndexes The field indexes into which the time indicators markers are injected. + * @param rowtime An optional rowtime indicator + * @param proctime An optional proctime indicator + * @return An adjusted array of field indexes. 
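For context on the tuple branches above: when a change flag is requested, the internal CRow is wrapped into a (Boolean, Row) pair, which is what the new CRowToScalaTupleMapFunction and CRowToJavaTupleMapFunction helpers do when no field conversion is needed. A conceptual Java sketch; the CRow accessor names are an assumption based on the removed Scala code reading cRow.change and cRow.row.

MapFunction<CRow, Tuple2<Boolean, Row>> wrapWithChangeFlag = new MapFunction<CRow, Tuple2<Boolean, Row>>() {
    @Override
    public Tuple2<Boolean, Row> map(CRow cRow) throws Exception {
        // true = insert/accumulate, false = retraction; accessor names assumed
        return Tuple2.of(cRow.change(), cRow.row());
    }
};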
+ */ + private def adjustFieldIndexes( + fieldIndexes: Array[Int], + rowtime: Option[(Int, String)], + proctime: Option[(Int, String)]): Array[Int] = { + + // inject rowtime field + val withRowtime = rowtime match { + case Some(rt) => fieldIndexes.patch(rt._1, Seq(TimeIndicatorTypeInfo.ROWTIME_MARKER), 0) + case _ => fieldIndexes + } + + // inject proctime field + val withProctime = proctime match { + case Some(pt) => withRowtime.patch(pt._1, Seq(TimeIndicatorTypeInfo.PROCTIME_MARKER), 0) + case _ => withRowtime + } + + withProctime + } + + /** + * Injects names of time indicator fields into the list of field names. + * + * @param fieldNames The array of field names into which the time indicator field names are + * injected. + * @param rowtime An optional rowtime indicator + * @param proctime An optional proctime indicator + * @return An adjusted array of field names. + */ + private def adjustFieldNames( + fieldNames: Array[String], + rowtime: Option[(Int, String)], + proctime: Option[(Int, String)]): Array[String] = { + + // inject rowtime field + val withRowtime = rowtime match { + case Some(rt) => fieldNames.patch(rt._1, Seq(rowtime.get._2), 0) + case _ => fieldNames + } + + // inject proctime field + val withProctime = proctime match { + case Some(pt) => withRowtime.patch(pt._1, Seq(proctime.get._2), 0) + case _ => withRowtime + } + + withProctime + } + /** * Returns the decoration rule set for this environment * including a custom RuleSet configuration. @@ -632,10 +673,21 @@ abstract class StreamTableEnvironment( val relNode = table.getRelNode val dataStreamPlan = optimize(relNode, updatesAsRetraction) - // we convert the logical row type to the output row type - val convertedOutputType = RelTimeIndicatorConverter.convertOutputType(relNode) - - translate(dataStreamPlan, convertedOutputType, queryConfig, withChangeFlag) + // zip original field names with optimized field types + val fieldTypes = relNode.getRowType.getFieldList.asScala + .zip(dataStreamPlan.getRowType.getFieldList.asScala) + // get name of original plan and type of optimized plan + .map(x => (x._1.getName, x._2.getType)) + // add field indexes + .zipWithIndex + // build new field types + .map(x => new RelDataTypeFieldImpl(x._1._1, x._2, x._1._2)) + + // build a record type from list of field types + val rowType = new RelRecordType( + fieldTypes.toList.asInstanceOf[List[RelDataTypeField]].asJava) + + translate(dataStreamPlan, rowType, queryConfig, withChangeFlag) } /** @@ -667,16 +719,42 @@ abstract class StreamTableEnvironment( // get CRow plan val plan: DataStream[CRow] = translateToCRow(logicalPlan, queryConfig) + val rowtimeFields = logicalType + .getFieldList.asScala + .filter(f => FlinkTypeFactory.isRowtimeIndicatorType(f.getType)) + + // convert the input type for the conversion mapper + // the input will be changed in the OutputRowtimeProcessFunction later + val convType = if (rowtimeFields.size > 1) { + throw new TableException( + s"Found more than one rowtime field: [${rowtimeFields.map(_.getName).mkString(", ")}] in " + + s"the table that should be converted to a DataStream.\n" + + s"Please select the rowtime field that should be used as event-time timestamp for the " + + s"DataStream by casting all other fields to TIMESTAMP.") + } else if (rowtimeFields.size == 1) { + val origRowType = plan.getType.asInstanceOf[CRowTypeInfo].rowType + val convFieldTypes = origRowType.getFieldTypes.map { t => + if (FlinkTypeFactory.isRowtimeIndicatorType(t)) { + SqlTimeTypeInfo.TIMESTAMP + } else { + t + } + } + 
CRowTypeInfo(new RowTypeInfo(convFieldTypes, origRowType.getFieldNames)) + } else { + plan.getType + } + // convert CRow to output type - val conversion = if (withChangeFlag) { + val conversion: MapFunction[CRow, A] = if (withChangeFlag) { getConversionMapperWithChanges( - plan.getType, + convType, new RowSchema(logicalType), tpe, "DataStreamSinkConversion") } else { getConversionMapper( - plan.getType, + convType, new RowSchema(logicalType), tpe, "DataStreamSinkConversion") @@ -684,13 +762,19 @@ abstract class StreamTableEnvironment( val rootParallelism = plan.getParallelism - conversion match { - case mapFunction: MapFunction[CRow, A] => - plan.map(mapFunction) - .returns(tpe) - .name(s"to: ${tpe.getTypeClass.getSimpleName}") - .setParallelism(rootParallelism) + val withRowtime = if (rowtimeFields.isEmpty) { + // no rowtime field to set + plan.map(conversion) + } else { + // set the only rowtime field as event-time timestamp for DataStream + // and convert it to SQL timestamp + plan.process(new OutputRowtimeProcessFunction[A](conversion, rowtimeFields.head.getIndex)) } + + withRowtime + .returns(tpe) + .name(s"to: ${tpe.getTypeClass.getSimpleName}") + .setParallelism(rootParallelism) } /** diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala index 3bca1568ae2bb..2e9e18f91a328 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala @@ -38,7 +38,7 @@ import org.apache.calcite.tools._ import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} import org.apache.flink.api.common.typeutils.CompositeType -import org.apache.flink.api.java.typeutils._ +import org.apache.flink.api.java.typeutils.{RowTypeInfo, _} import org.apache.flink.api.java.{ExecutionEnvironment => JavaBatchExecEnv} import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo import org.apache.flink.api.scala.{ExecutionEnvironment => ScalaBatchExecEnv} @@ -48,26 +48,23 @@ import org.apache.flink.table.api.java.{BatchTableEnvironment => JavaBatchTableE import org.apache.flink.table.api.scala.{BatchTableEnvironment => ScalaBatchTableEnv, StreamTableEnvironment => ScalaStreamTableEnv} import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkRelBuilder, FlinkTypeFactory, FlinkTypeSystem} import org.apache.flink.table.catalog.{ExternalCatalog, ExternalCatalogSchema} -import org.apache.flink.table.codegen.{FunctionCodeGenerator, ExpressionReducer, GeneratedFunction} -import org.apache.flink.table.expressions.{Alias, Expression, UnresolvedFieldReference} -import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ -import org.apache.flink.table.functions.AggregateFunction -import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, createScalarSqlFunction, createTableSqlFunctions} -import org.apache.flink.table.functions.{ScalarFunction, TableFunction} +import org.apache.flink.table.codegen.{ExpressionReducer, FunctionCodeGenerator, GeneratedFunction} +import org.apache.flink.table.expressions.{Alias, Expression, UnresolvedFieldReference, _} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, 
createScalarSqlFunction, createTableSqlFunctions, _} +import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction} import org.apache.flink.table.plan.cost.DataSetCostFactory import org.apache.flink.table.plan.logical.{CatalogNode, LogicalRelNode} import org.apache.flink.table.plan.rules.FlinkRuleSets import org.apache.flink.table.plan.schema.{RelTable, RowSchema} import org.apache.flink.table.sinks.TableSink import org.apache.flink.table.sources.{DefinedFieldNames, TableSource} +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo import org.apache.flink.table.validate.FunctionCatalog import org.apache.flink.types.Row -import org.apache.flink.api.java.typeutils.RowTypeInfo -import _root_.scala.collection.JavaConverters._ -import _root_.scala.collection.mutable.HashMap import _root_.scala.annotation.varargs +import _root_.scala.collection.JavaConverters._ +import _root_.scala.collection.mutable /** * The abstract base class for batch and stream TableEnvironments. @@ -108,10 +105,10 @@ abstract class TableEnvironment(val config: TableConfig) { private[flink] val attrNameCntr: AtomicInteger = new AtomicInteger(0) // registered external catalog names -> catalog - private val externalCatalogs = new HashMap[String, ExternalCatalog] + private val externalCatalogs = new mutable.HashMap[String, ExternalCatalog] /** Returns the table config to define the runtime behavior of the Table API. */ - def getConfig = config + def getConfig: TableConfig = config /** * Returns the operator table for this environment including a custom Calcite configuration. @@ -692,7 +689,7 @@ abstract class TableEnvironment(val config: TableConfig) { case _ => throw new TableException( "Field reference expression or alias on field expression expected.") } - case r: RowTypeInfo => { + case r: RowTypeInfo => exprs.zipWithIndex flatMap { case (UnresolvedFieldReference(name), idx) => Some((idx, name)) @@ -707,8 +704,7 @@ abstract class TableEnvironment(val config: TableConfig) { case _ => throw new TableException( "Field reference expression or alias on field expression expected.") } - - } + case tpe => throw new TableException( s"Source of type $tpe cannot be converted into Table.") } @@ -719,33 +715,47 @@ abstract class TableEnvironment(val config: TableConfig) { throw new TableException("Field name can not be '*'.") } - (fieldNames.toArray, fieldIndexes.toArray) + (fieldNames.toArray, fieldIndexes.toArray) // build fails in Scala 2.10 if not converted } protected def generateRowConverterFunction[OUT]( inputTypeInfo: TypeInformation[Row], schema: RowSchema, requestedTypeInfo: TypeInformation[OUT], - functionName: String): - GeneratedFunction[MapFunction[Row, OUT], OUT] = { + functionName: String) + : Option[GeneratedFunction[MapFunction[Row, OUT], OUT]] = { // validate that at least the field types of physical and logical type match // we do that here to make sure that plan translation was correct - if (schema.physicalTypeInfo != inputTypeInfo) { + if (schema.typeInfo != inputTypeInfo) { throw TableException( s"The field types of physical and logical row types do not match. " + - s"Physical type is [${schema.physicalTypeInfo}], Logical type is [${inputTypeInfo}]. " + + s"Physical type is [${schema.typeInfo}], Logical type is [$inputTypeInfo]. " + s"This is a bug and should not happen. 
Please file an issue.") } - val fieldTypes = schema.physicalFieldTypeInfo - val fieldNames = schema.physicalFieldNames + // generic row needs no conversion + if (requestedTypeInfo.isInstanceOf[GenericTypeInfo[_]] && + requestedTypeInfo.getTypeClass == classOf[Row]) { + return None + } + + val fieldTypes = schema.fieldTypeInfos + val fieldNames = schema.fieldNames - // validate requested type + // check for valid type info if (requestedTypeInfo.getArity != fieldTypes.length) { throw new TableException( - s"Arity[${fieldTypes.length}] of result[${fieldTypes}] does not match " + - s"the number[${requestedTypeInfo.getArity}] of requested type[${requestedTypeInfo}].") + s"Arity [${fieldTypes.length}] of result [$fieldTypes] does not match " + + s"the number[${requestedTypeInfo.getArity}] of requested type [$requestedTypeInfo].") + } + + // check requested types + + def validateFieldType(fieldType: TypeInformation[_]): Unit = fieldType match { + case _: TimeIndicatorTypeInfo => + throw new TableException("The time indicator type is an internal type only.") + case _ => // ok } requestedTypeInfo match { @@ -758,9 +768,10 @@ abstract class TableEnvironment(val config: TableConfig) { throw new TableException(s"POJO does not define field name: $fName") } val requestedTypeInfo = pt.getTypeAt(pojoIdx) + validateFieldType(requestedTypeInfo) if (fType != requestedTypeInfo) { throw new TableException(s"Result field does not match requested type. " + - s"requested: $requestedTypeInfo; Actual: $fType") + s"Requested: $requestedTypeInfo; Actual: $fType") } } @@ -769,6 +780,7 @@ abstract class TableEnvironment(val config: TableConfig) { fieldTypes.zipWithIndex foreach { case (fieldTypeInfo, i) => val requestedTypeInfo = tt.getTypeAt(i) + validateFieldType(requestedTypeInfo) if (fieldTypeInfo != requestedTypeInfo) { throw new TableException(s"Result field does not match requested type. " + s"Requested: $requestedTypeInfo; Actual: $fieldTypeInfo") @@ -781,10 +793,11 @@ abstract class TableEnvironment(val config: TableConfig) { throw new TableException(s"Requested result type is an atomic type but " + s"result[$fieldTypes] has more or less than a single field.") } - val fieldTypeInfo = fieldTypes.head - if (fieldTypeInfo != at) { + val requestedTypeInfo = fieldTypes.head + validateFieldType(requestedTypeInfo) + if (requestedTypeInfo != at) { throw new TableException(s"Result field does not match requested type. 
" + - s"Requested: $at; Actual: $fieldTypeInfo") + s"Requested: $at; Actual: $requestedTypeInfo") } case _ => @@ -809,11 +822,13 @@ abstract class TableEnvironment(val config: TableConfig) { |return ${conversion.resultTerm}; |""".stripMargin - generator.generateFunction( + val generated = generator.generateFunction( functionName, classOf[MapFunction[Row, OUT]], body, requestedTypeInfo) + + Some(generated) } } @@ -972,7 +987,7 @@ object TableEnvironment { validateType(inputType) inputType match { - case t: CompositeType[_] => 0.until(t.getArity).map(t.getTypeAt(_)).toArray + case t: CompositeType[_] => 0.until(t.getArity).map(i => t.getTypeAt(i)).toArray case a: AtomicType[_] => Array(a.asInstanceOf[TypeInformation[_]]) case tpe => throw new TableException(s"Currently only CompositeType and AtomicType are supported.") diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableSchema.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableSchema.scala index a67a07af106f0..6ee65f0a5c6da 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableSchema.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableSchema.scala @@ -28,13 +28,24 @@ class TableSchema( if (columnNames.length != columnTypes.length) { throw new TableException( - "Number of column indexes and column names must be equal.") + s"Number of field names and field types must be equal.\n" + + s"Number of names is ${columnNames.length}, number of types is ${columnTypes.length}.\n" + + s"List of field names: ${columnNames.mkString("[", ", ", "]")}.\n" + + s"List of field types: ${columnTypes.mkString("[", ", ", "]")}.") } // check uniqueness of field names if (columnNames.toSet.size != columnTypes.length) { + val duplicateFields = columnNames + // count occurences of field names + .groupBy(identity).mapValues(_.length) + // filter for occurences > 1 and map to field name + .filter(g => g._2 > 1).keys + throw new TableException( - "Table column names must be unique.") + s"Field names must be unique.\n" + + s"List of duplicate fields: ${duplicateFields.mkString("[", ", ", "]")}.\n" + + s"List of all fields: ${columnNames.mkString("[", ", ", "]")}.") } val columnNameToIndex: Map[String, Int] = columnNames.zipWithIndex.toMap diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataView.scala new file mode 100644 index 0000000000000..2214086e23fa5 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataView.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.api.dataview + +/** + * A [[DataView]] is a collection type that can be used in the accumulator of an + * [[org.apache.flink.table.functions.AggregateFunction]]. + * + * Depending on the context in which the [[org.apache.flink.table.functions.AggregateFunction]] is + * used, a [[DataView]] can be backed by a Java heap collection or a state backend. + */ +trait DataView extends Serializable { + + /** + * Clears the [[DataView]] and removes all data. + */ + def clear(): Unit + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataViewSpec.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataViewSpec.scala new file mode 100644 index 0000000000000..943fe033d7016 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataViewSpec.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.dataview + +import java.lang.reflect.Field + +import org.apache.flink.api.common.state.{ListStateDescriptor, MapStateDescriptor, State, StateDescriptor} +import org.apache.flink.table.dataview.{ListViewTypeInfo, MapViewTypeInfo} + +/** + * Data view specification. + * + * @tparam ACC type extends [[DataView]] + */ +trait DataViewSpec[ACC <: DataView] { + def stateId: String + def field: Field + def toStateDescriptor: StateDescriptor[_ <: State, _] +} + +case class ListViewSpec[T]( + stateId: String, + field: Field, + listViewTypeInfo: ListViewTypeInfo[T]) + extends DataViewSpec[ListView[T]] { + + override def toStateDescriptor: StateDescriptor[_ <: State, _] = + new ListStateDescriptor[T](stateId, listViewTypeInfo.elementType) +} + +case class MapViewSpec[K, V]( + stateId: String, + field: Field, + mapViewTypeInfo: MapViewTypeInfo[K, V]) + extends DataViewSpec[MapView[K, V]] { + + override def toStateDescriptor: StateDescriptor[_ <: State, _] = + new MapStateDescriptor[K, V](stateId, mapViewTypeInfo.keyType, mapViewTypeInfo.valueType) +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/ListView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/ListView.scala new file mode 100644 index 0000000000000..59b2426db6b11 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/ListView.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.dataview + +import java.lang.{Iterable => JIterable} +import java.util + +import org.apache.flink.api.common.typeinfo.{TypeInfo, TypeInformation} +import org.apache.flink.table.dataview.ListViewTypeInfoFactory + +/** + * A [[ListView]] provides List functionality for accumulators used by user-defined aggregate + * functions [[org.apache.flink.api.common.functions.AggregateFunction]]. + * + * A [[ListView]] can be backed by a Java ArrayList or a state backend, depending on the context in + * which the aggregate function is used. + * + * At runtime [[ListView]] will be replaced by a [[org.apache.flink.table.dataview.StateListView]] + * if it is backed by a state backend. + * + * Example of an accumulator type with a [[ListView]] and an aggregate function that uses it: + * {{{ + * + * public class MyAccum { + * public ListView list; + * public long count; + * } + * + * public class MyAgg extends AggregateFunction { + * + * @Override + * public MyAccum createAccumulator() { + * MyAccum accum = new MyAccum(); + * accum.list = new ListView<>(Types.STRING); + * accum.count = 0L; + * return accum; + * } + * + * public void accumulate(MyAccum accumulator, String id) { + * accumulator.list.add(id); + * ... ... + * accumulator.get() + * ... ... + * } + * + * @Override + * public Long getValue(MyAccum accumulator) { + * accumulator.list.add(id); + * ... ... + * accumulator.get() + * ... ... + * return accumulator.count; + * } + * } + * + * }}} + * + * @param elementTypeInfo element type information + * @tparam T element type + */ +@TypeInfo(classOf[ListViewTypeInfoFactory[_]]) +class ListView[T]( + @transient private[flink] val elementTypeInfo: TypeInformation[T], + private[flink] val list: util.List[T]) + extends DataView { + + /** + * Creates a list view for elements of the specified type. + * + * @param elementTypeInfo The type of the list view elements. + */ + def this(elementTypeInfo: TypeInformation[T]) { + this(elementTypeInfo, new util.ArrayList[T]()) + } + + /** + * Creates a list view. + */ + def this() = this(null) + + /** + * Returns an iterable of the list view. + * + * @throws Exception Thrown if the system cannot get data. + * @return The iterable of the list or { @code null} if the list is empty. + */ + @throws[Exception] + def get: JIterable[T] = { + if (!list.isEmpty) { + list + } else { + null + } + } + + /** + * Adds the given value to the list. + * + * @throws Exception Thrown if the system cannot add data. + * @param value The element to be appended to this list view. + */ + @throws[Exception] + def add(value: T): Unit = list.add(value) + + /** + * Adds all of the elements of the specified list to this list view. + * + * @throws Exception Thrown if the system cannot add all data. + * @param list The list with the elements that will be stored in this list view. 
+ */ + @throws[Exception] + def addAll(list: util.List[T]): Unit = this.list.addAll(list) + + /** + * Removes all of the elements from this list view. + */ + override def clear(): Unit = list.clear() + + override def equals(other: Any): Boolean = other match { + case that: ListView[T] => + list.equals(that.list) + case _ => false + } + + override def hashCode(): Int = list.hashCode() +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/MapView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/MapView.scala new file mode 100644 index 0000000000000..9206d6af3e4a1 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/MapView.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.dataview + +import java.lang.{Iterable => JIterable} +import java.util + +import org.apache.flink.api.common.typeinfo.{TypeInfo, TypeInformation} +import org.apache.flink.table.dataview.MapViewTypeInfoFactory + +/** + * A [[MapView]] provides Map functionality for accumulators used by user-defined aggregate + * functions [[org.apache.flink.table.functions.AggregateFunction]]. + * + * A [[MapView]] can be backed by a Java HashMap or a state backend, depending on the context in + * which the aggregation function is used. + * + * At runtime [[MapView]] will be replaced by a [[org.apache.flink.table.dataview.StateMapView]] + * if it is backed by a state backend. 
+ * + * Example of an accumulator type with a [[MapView]] and an aggregate function that uses it: + * {{{ + * + * public class MyAccum { + * public MapView map; + * public long count; + * } + * + * public class MyAgg extends AggregateFunction { + * + * @Override + * public MyAccum createAccumulator() { + * MyAccum accum = new MyAccum(); + * accum.map = new MapView<>(Types.STRING, Types.INT); + * accum.count = 0L; + * return accum; + * } + * + * public void accumulate(MyAccum accumulator, String id) { + * try { + * if (!accumulator.map.contains(id)) { + * accumulator.map.put(id, 1); + * accumulator.count++; + * } + * } catch (Exception e) { + * e.printStackTrace(); + * } + * } + * + * @Override + * public Long getValue(MyAccum accumulator) { + * return accumulator.count; + * } + * } + * + * }}} + * + * @param keyTypeInfo key type information + * @param valueTypeInfo value type information + * @tparam K key type + * @tparam V value type + */ +@TypeInfo(classOf[MapViewTypeInfoFactory[_, _]]) +class MapView[K, V]( + @transient private[flink] val keyTypeInfo: TypeInformation[K], + @transient private[flink] val valueTypeInfo: TypeInformation[V], + private[flink] val map: util.Map[K, V]) + extends DataView { + + /** + * Creates a MapView with the specified key and value types. + * + * @param keyTypeInfo The type of keys of the MapView. + * @param valueTypeInfo The type of the values of the MapView. + */ + def this(keyTypeInfo: TypeInformation[K], valueTypeInfo: TypeInformation[V]) { + this(keyTypeInfo, valueTypeInfo, new util.HashMap[K, V]()) + } + + /** + * Creates a MapView. + */ + def this() = this(null, null) + + /** + * Return the value for the specified key or { @code null } if the key is not in the map view. + * + * @param key The look up key. + * @return The value for the specified key. + * @throws Exception Thrown if the system cannot get data. + */ + @throws[Exception] + def get(key: K): V = map.get(key) + + /** + * Inserts a value for the given key into the map view. + * If the map view already contains a value for the key, the existing value is overwritten. + * + * @param key The key for which the value is inserted. + * @param value The value that is inserted for the key. + * @throws Exception Thrown if the system cannot put data. + */ + @throws[Exception] + def put(key: K, value: V): Unit = map.put(key, value) + + /** + * Inserts all mappings from the specified map to this map view. + * + * @param map The map whose entries are inserted into this map view. + * @throws Exception Thrown if the system cannot access the map. + */ + @throws[Exception] + def putAll(map: util.Map[K, V]): Unit = this.map.putAll(map) + + /** + * Deletes the value for the given key. + * + * @param key The key for which the value is deleted. + * @throws Exception Thrown if the system cannot access the map. + */ + @throws[Exception] + def remove(key: K): Unit = map.remove(key) + + /** + * Checks if the map view contains a value for a given key. + * + * @param key The key to check. + * @return True if there exists a value for the given key, false otherwise. + * @throws Exception Thrown if the system cannot access the map. + */ + @throws[Exception] + def contains(key: K): Boolean = map.containsKey(key) + + /** + * Returns all entries of the map view. + * + * @return An iterable of all the key-value pairs in the map view. + * @throws Exception Thrown if the system cannot access the map. 
+ */ + @throws[Exception] + def entries: JIterable[util.Map.Entry[K, V]] = map.entrySet() + + /** + * Returns all the keys in the map view. + * + * @return An iterable of all the keys in the map. + * @throws Exception Thrown if the system cannot access the map. + */ + @throws[Exception] + def keys: JIterable[K] = map.keySet() + + /** + * Returns all the values in the map view. + * + * @return An iterable of all the values in the map. + * @throws Exception Thrown if the system cannot access the map. + */ + @throws[Exception] + def values: JIterable[V] = map.values() + + /** + * Returns an iterator over all entries of the map view. + * + * @return An iterator over all the mappings in the map. + * @throws Exception Thrown if the system cannot access the map. + */ + @throws[Exception] + def iterator: util.Iterator[util.Map.Entry[K, V]] = map.entrySet().iterator() + + /** + * Removes all entries of this map. + */ + override def clear(): Unit = map.clear() + + override def equals(other: Any): Boolean = other match { + case that: MapView[K, V] => + map.equals(that.map) + case _ => false + } + + override def hashCode(): Int = map.hashCode() +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala index dbefe203e9601..637e8cc0fba19 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala @@ -172,45 +172,20 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp * * @param fieldNames field names * @param fieldTypes field types, every element is Flink's [[TypeInformation]] - * @param rowtime optional system field to indicate event-time; the index determines the index - * in the final record. If the index is smaller than the number of specified - * fields, it shifts all following fields. - * @param proctime optional system field to indicate processing-time; the index determines the - * index in the final record. If the index is smaller than the number of - * specified fields, it shifts all following fields. 
* @return a struct type with the input fieldNames, input fieldTypes, and system fields */ def buildLogicalRowType( fieldNames: Seq[String], - fieldTypes: Seq[TypeInformation[_]], - rowtime: Option[(Int, String)], - proctime: Option[(Int, String)]) + fieldTypes: Seq[TypeInformation[_]]) : RelDataType = { val logicalRowTypeBuilder = builder val fields = fieldNames.zip(fieldTypes) - - var totalNumberOfFields = fields.length - if (rowtime.isDefined) { - totalNumberOfFields += 1 - } - if (proctime.isDefined) { - totalNumberOfFields += 1 - } - - var addedTimeAttributes = 0 - for (i <- 0 until totalNumberOfFields) { - if (rowtime.isDefined && rowtime.get._1 == i) { - logicalRowTypeBuilder.add(rowtime.get._2, createRowtimeIndicatorType()) - addedTimeAttributes += 1 - } else if (proctime.isDefined && proctime.get._1 == i) { - logicalRowTypeBuilder.add(proctime.get._2, createProctimeIndicatorType()) - addedTimeAttributes += 1 - } else { - val field = fields(i - addedTimeAttributes) - logicalRowTypeBuilder.add(field._1, createTypeFromTypeInfo(field._2, isNullable = true)) - } - } + fields.foreach(f => { + // time indicators are not nullable + val nullable = !FlinkTypeFactory.isTimeIndicatorType(f._2) + logicalRowTypeBuilder.add(f._1, createTypeFromTypeInfo(f._2, nullable)) + }) logicalRowTypeBuilder.build } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala index eb1429158a73d..1f88737247aeb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.calcite -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl, RelRecordType} +import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core._ import org.apache.calcite.rel.logical._ import org.apache.calcite.rel.{RelNode, RelShuttle} @@ -26,8 +26,8 @@ import org.apache.calcite.rex._ import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo import org.apache.flink.table.api.{TableException, ValidationException} -import org.apache.flink.table.calcite.FlinkTypeFactory.isTimeIndicatorType -import org.apache.flink.table.functions.TimeMaterializationSqlFunction +import org.apache.flink.table.calcite.FlinkTypeFactory.{isRowtimeIndicatorType, _} +import org.apache.flink.table.functions.sql.ProctimeSqlFunction import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType @@ -242,9 +242,13 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle { case lp: LogicalProject => val projects = lp.getProjects.zipWithIndex.map { case (expr, idx) => if (isTimeIndicatorType(expr.getType) && refIndices.contains(idx)) { - rexBuilder.makeCall( - TimeMaterializationSqlFunction, - expr) + if (isRowtimeIndicatorType(expr.getType)) { + // cast rowtime indicator to regular timestamp + rexBuilder.makeAbstractCast(timestamp, expr) + } else { + // generate proctime access + rexBuilder.makeCall(ProctimeSqlFunction, expr) + } } else { expr } @@ -259,9 +263,17 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle { case _ => val projects = input.getRowType.getFieldList.map { field => 
if (isTimeIndicatorType(field.getType) && refIndices.contains(field.getIndex)) { - rexBuilder.makeCall( - TimeMaterializationSqlFunction, - new RexInputRef(field.getIndex, field.getType)) + if (isRowtimeIndicatorType(field.getType)) { + // cast rowtime indicator to regular timestamp + rexBuilder.makeAbstractCast( + timestamp, + new RexInputRef(field.getIndex, field.getType)) + } else { + // generate proctime access + rexBuilder.makeCall( + ProctimeSqlFunction, + new RexInputRef(field.getIndex, field.getType)) + } } else { new RexInputRef(field.getIndex, field.getType) } @@ -311,19 +323,19 @@ object RelTimeIndicatorConverter { var needsConversion = false - // materialize all remaining time indicators + // materialize remaining proctime indicators val projects = convertedRoot.getRowType.getFieldList.map(field => - if (isTimeIndicatorType(field.getType)) { + if (isProctimeIndicatorType(field.getType)) { needsConversion = true rexBuilder.makeCall( - TimeMaterializationSqlFunction, + ProctimeSqlFunction, new RexInputRef(field.getIndex, field.getType)) } else { new RexInputRef(field.getIndex, field.getType) } ) - // add final conversion + // add final conversion if necessary if (needsConversion) { LogicalProject.create( convertedRoot, @@ -334,27 +346,6 @@ object RelTimeIndicatorConverter { } } - def convertOutputType(rootRel: RelNode): RelDataType = { - - val timestamp = rootRel - .getCluster - .getRexBuilder - .getTypeFactory - .asInstanceOf[FlinkTypeFactory] - .createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP, isNullable = false) - - // convert all time indicators types to timestamps - val fields = rootRel.getRowType.getFieldList.map { field => - if (isTimeIndicatorType(field.getType)) { - new RelDataTypeFieldImpl(field.getName, field.getIndex, timestamp) - } else { - field - } - } - - new RelRecordType(fields) - } - /** * Materializes time indicator accesses in an expression. 
* @@ -415,7 +406,13 @@ class RexTimeIndicatorMaterializer( case _ => updatedCall.getOperands.map { o => if (isTimeIndicatorType(o.getType)) { - rexBuilder.makeCall(TimeMaterializationSqlFunction, o) + if (isRowtimeIndicatorType(o.getType)) { + // cast rowtime indicator to regular timestamp + rexBuilder.makeAbstractCast(timestamp, o) + } else { + // generate proctime access + rexBuilder.makeCall(ProctimeSqlFunction, o) + } } else { o } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala index 197449ca5e6ff..c74066f022e20 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala @@ -18,12 +18,12 @@ package org.apache.flink.table.catalog -import java.util.{Collections => JCollections, Collection => JCollection, LinkedHashSet => JLinkedHashSet, Set => JSet} +import java.util.{Collection => JCollection, Collections => JCollections, LinkedHashSet => JLinkedHashSet, Set => JSet} import org.apache.calcite.linq4j.tree.Expression import org.apache.calcite.schema._ import org.apache.flink.table.api.{CatalogNotExistException, TableNotExistException} -import org.slf4j.{Logger, LoggerFactory} +import org.apache.flink.table.util.Logging import scala.collection.JavaConverters._ @@ -38,9 +38,7 @@ import scala.collection.JavaConverters._ */ class ExternalCatalogSchema( catalogIdentifier: String, - catalog: ExternalCatalog) extends Schema { - - private val LOG: Logger = LoggerFactory.getLogger(this.getClass) + catalog: ExternalCatalog) extends Schema with Logging { /** * Looks up a sub-schema by the given sub-schema name in the external catalog. diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/catalog/ExternalTableSourceUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/catalog/ExternalTableSourceUtil.scala index ccc2e9ebd4d8c..6bacac1fd50cb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/catalog/ExternalTableSourceUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/catalog/ExternalTableSourceUtil.scala @@ -27,9 +27,9 @@ import org.apache.flink.table.api.{AmbiguousTableSourceConverterException, NoMat import org.apache.flink.table.plan.schema.{StreamTableSourceTable, TableSourceTable} import org.apache.flink.table.plan.stats.FlinkStatistic import org.apache.flink.table.sources.{StreamTableSource, TableSource} +import org.apache.flink.table.util.Logging import org.apache.flink.util.InstantiationUtil import org.reflections.Reflections -import org.slf4j.{Logger, LoggerFactory} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -37,13 +37,11 @@ import scala.collection.mutable /** * The utility class is used to convert ExternalCatalogTable to TableSourceTable. */ -object ExternalTableSourceUtil { +object ExternalTableSourceUtil extends Logging { // config file to specify scan package to search TableSourceConverter private val tableSourceConverterConfigFileName = "tableSourceConverter.properties" - private val LOG: Logger = LoggerFactory.getLogger(this.getClass) - // registered table type with the TableSourceConverter. // Key is table type name, Value is set of converter class. 
private val tableTypeToTableSourceConvertersClazz = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala index 680eb44efc056..22ce5ba4a317f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala @@ -17,16 +17,23 @@ */ package org.apache.flink.table.codegen -import java.lang.reflect.ParameterizedType +import java.lang.reflect.{Modifier, ParameterizedType} import java.lang.{Iterable => JIterable} +import org.apache.commons.codec.binary.Base64 +import org.apache.flink.api.common.state.{State, StateDescriptor} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.api.dataview._ import org.apache.flink.table.codegen.Indenter.toISC -import org.apache.flink.table.codegen.CodeGenUtils.newName +import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess} import org.apache.flink.table.functions.AggregateFunction +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString} import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, SingleElementIterable} +import org.apache.flink.util.InstantiationUtil + +import scala.collection.mutable /** * A code generator for generating [[GeneratedAggregations]]. @@ -41,13 +48,24 @@ class AggregationCodeGenerator( input: TypeInformation[_ <: Any]) extends CodeGenerator(config, nullableInput, input) { + // set of statements for cleanup dataview that will be added only once + // we use a LinkedHashSet to keep the insertion order + private val reusableCleanupStatements = mutable.LinkedHashSet[String]() + + /** + * @return code block of statements that need to be placed in the cleanup() method of + * [[GeneratedAggregations]] + */ + def reuseCleanupCode(): String = { + reusableCleanupStatements.mkString("", "\n", "\n") + } + /** * Generates a [[org.apache.flink.table.runtime.aggregate.GeneratedAggregations]] that can be * passed to a Java compiler. * * @param name Class name of the function. * Does not need to be unique but has to be a valid Java class identifier. 
- * @param generator The code generator instance * @param physicalInputTypes Physical input row types * @param aggregates All aggregate functions * @param aggFields Indexes of the input fields for all aggregate functions @@ -68,7 +86,6 @@ class AggregationCodeGenerator( */ def generateAggregations( name: String, - generator: CodeGenerator, physicalInputTypes: Seq[TypeInformation[_]], aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]], aggFields: Array[Array[Int]], @@ -80,34 +97,40 @@ class AggregationCodeGenerator( outputArity: Int, needRetract: Boolean, needMerge: Boolean, - needReset: Boolean) + needReset: Boolean, + accConfig: Option[Array[Seq[DataViewSpec[_]]]]) : GeneratedAggregationsFunction = { // get unique function name val funcName = newName(name) // register UDAGGs - val aggs = aggregates.map(a => generator.addReusableFunction(a)) + val aggs = aggregates.map(a => addReusableFunction(a, contextTerm)) + // get java types of accumulators val accTypeClasses = aggregates.map { a => a.getClass.getMethod("createAccumulator").getReturnType } val accTypes = accTypeClasses.map(_.getCanonicalName) - // get java classes of input fields - val javaClasses = physicalInputTypes.map(t => t.getTypeClass) // get parameter lists for aggregation functions - val parameters = aggFields.map { inFields => + val parametersCode = aggFields.map { inFields => val fields = for (f <- inFields) yield - s"(${javaClasses(f).getCanonicalName}) input.getField($f)" + s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) input.getField($f)" fields.mkString(", ") } - val methodSignaturesList = aggFields.map { - inFields => for (f <- inFields) yield javaClasses(f) + + // get method signatures + val classes = UserDefinedFunctionUtils.typeInfoToClass(physicalInputTypes) + val methodSignaturesList = aggFields.map { inFields => + inFields.map(classes(_)) } + // initialize and create data views + addReusableDataViews() + // check and validate the needed methods aggregates.zipWithIndex.map { - case (a, i) => { + case (a, i) => getUserDefinedMethod(a, "accumulate", Array(accTypeClasses(i)) ++ methodSignaturesList(i)) .getOrElse( throw new CodeGenException( @@ -159,6 +182,113 @@ class AggregationCodeGenerator( s"aggregate ${a.getClass.getCanonicalName}'.") ) } + } + + /** + * Create DataView Term, for example, acc1_map_dataview. + * + * @param aggIndex index of aggregate function + * @param fieldName field name of DataView + * @return term to access [[MapView]] or [[ListView]] + */ + def createDataViewTerm(aggIndex: Int, fieldName: String): String = { + s"acc${aggIndex}_${fieldName}_dataview" + } + + /** + * Adds a reusable [[org.apache.flink.table.api.dataview.DataView]] to the open, cleanup, + * close and member area of the generated function. 
+ * + */ + def addReusableDataViews(): Unit = { + if (accConfig.isDefined) { + val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get + .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor))) + .toMap[String, StateDescriptor[_ <: State, _]] + + for (i <- aggs.indices) yield { + for (spec <- accConfig.get(i)) yield { + val dataViewField = spec.field + val dataViewTypeTerm = dataViewField.getType.getCanonicalName + val desc = descMapping.getOrElse(spec.stateId, + throw new CodeGenException( + s"Can not find DataView in accumulator by id: ${spec.stateId}")) + + // define the DataView variables + val serializedData = serializeStateDescriptor(desc) + val dataViewFieldTerm = createDataViewTerm(i, dataViewField.getName) + val field = + s""" + | transient $dataViewTypeTerm $dataViewFieldTerm = null; + |""".stripMargin + reusableMemberStatements.add(field) + + // create DataViews + val descFieldTerm = s"${dataViewFieldTerm}_desc" + val descClassQualifier = classOf[StateDescriptor[_, _]].getCanonicalName + val descDeserializeCode = + s""" + | $descClassQualifier $descFieldTerm = ($descClassQualifier) + | org.apache.flink.util.InstantiationUtil.deserializeObject( + | org.apache.commons.codec.binary.Base64.decodeBase64("$serializedData"), + | $contextTerm.getUserCodeClassLoader()); + |""".stripMargin + val createDataView = if (dataViewField.getType == classOf[MapView[_, _]]) { + s""" + | $descDeserializeCode + | $dataViewFieldTerm = new org.apache.flink.table.dataview.StateMapView( + | $contextTerm.getMapState(( + | org.apache.flink.api.common.state.MapStateDescriptor)$descFieldTerm)); + |""".stripMargin + } else if (dataViewField.getType == classOf[ListView[_]]) { + s""" + | $descDeserializeCode + | $dataViewFieldTerm = new org.apache.flink.table.dataview.StateListView( + | $contextTerm.getListState(( + | org.apache.flink.api.common.state.ListStateDescriptor)$descFieldTerm)); + |""".stripMargin + } else { + throw new CodeGenException(s"Unsupported dataview type: $dataViewTypeTerm") + } + reusableOpenStatements.add(createDataView) + + // cleanup DataViews + val cleanup = + s""" + | $dataViewFieldTerm.clear(); + |""".stripMargin + reusableCleanupStatements.add(cleanup) + } + } + } + } + + /** + * Generate statements to set data view field when use state backend. 
+ * + * @param accTerm aggregation term + * @param aggIndex index of aggregation + * @return data view field set statements + */ + def genDataViewFieldSetter(accTerm: String, aggIndex: Int): String = { + if (accConfig.isDefined) { + val setters = for (spec <- accConfig.get(aggIndex)) yield { + val field = spec.field + val dataViewTerm = createDataViewTerm(aggIndex, field.getName) + val fieldSetter = if (Modifier.isPublic(field.getModifiers)) { + s"$accTerm.${field.getName} = $dataViewTerm;" + } else { + val fieldTerm = addReusablePrivateFieldAccess(field.getDeclaringClass, field.getName) + s"${reflectiveFieldWriteAccess(fieldTerm, field, accTerm, dataViewTerm)};" + } + + s""" + | $fieldSetter + """.stripMargin + } + setters.mkString("\n") + } else { + "" } } @@ -168,7 +298,7 @@ class AggregationCodeGenerator( j""" | public final void setAggregationResults( | org.apache.flink.types.Row accs, - | org.apache.flink.types.Row output)""".stripMargin + | org.apache.flink.types.Row output) throws Exception """.stripMargin val setAggs: String = { for (i <- aggs.indices) yield @@ -182,10 +312,11 @@ class AggregationCodeGenerator( j""" | org.apache.flink.table.functions.AggregateFunction baseClass$i = | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)}; - | + | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i); + | ${genDataViewFieldSetter(s"acc$i", i)} | output.setField( | ${aggMapping(i)}, - | baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin + | baseClass$i.getValue(acc$i));""".stripMargin } }.mkString("\n") @@ -201,14 +332,17 @@ class AggregationCodeGenerator( j""" | public final void accumulate( | org.apache.flink.types.Row accs, - | org.apache.flink.types.Row input)""".stripMargin + | org.apache.flink.types.Row input) throws Exception """.stripMargin val accumulate: String = { - for (i <- aggs.indices) yield + for (i <- aggs.indices) yield { j""" + | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i); + | ${genDataViewFieldSetter(s"acc$i", i)} | ${aggs(i)}.accumulate( - | ((${accTypes(i)}) accs.getField($i)), - | ${parameters(i)});""".stripMargin + | acc$i, + | ${parametersCode(i)});""".stripMargin + } }.mkString("\n") j"""$sig { @@ -222,14 +356,17 @@ class AggregationCodeGenerator( j""" | public final void retract( | org.apache.flink.types.Row accs, - | org.apache.flink.types.Row input)""".stripMargin + | org.apache.flink.types.Row input) throws Exception """.stripMargin val retract: String = { - for (i <- aggs.indices) yield + for (i <- aggs.indices) yield { j""" + | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i); + | ${genDataViewFieldSetter(s"acc$i", i)} | ${aggs(i)}.retract( - | ((${accTypes(i)}) accs.getField($i)), - | ${parameters(i)});""".stripMargin + | acc$i, + | ${parametersCode(i)});""".stripMargin + } }.mkString("\n") if (needRetract) { @@ -248,7 +385,7 @@ class AggregationCodeGenerator( val sig: String = j""" - | public final org.apache.flink.types.Row createAccumulators() + | public final org.apache.flink.types.Row createAccumulators() throws Exception | """.stripMargin val init: String = j""" @@ -256,12 +393,15 @@ class AggregationCodeGenerator( | new org.apache.flink.types.Row(${aggs.length});""" .stripMargin val create: String = { - for (i <- aggs.indices) yield + for (i <- aggs.indices) yield { j""" + | ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator(); + | ${genDataViewFieldSetter(s"acc$i", i)} | accs.setField( | $i, - | ${aggs(i)}.createAccumulator());""" - .stripMargin + | acc$i);""" + 
.stripMargin + } }.mkString("\n") val ret: String = j""" @@ -357,6 +497,10 @@ class AggregationCodeGenerator( """.stripMargin if (needMerge) { + if (accConfig.isDefined) { + throw new CodeGenException("DataView doesn't support merge when the backend uses " + + s"state when generate aggregation for $funcName.") + } j""" |$sig { |$merge @@ -386,13 +530,15 @@ class AggregationCodeGenerator( val sig: String = j""" | public final void resetAccumulator( - | org.apache.flink.types.Row accs)""".stripMargin + | org.apache.flink.types.Row accs) throws Exception """.stripMargin val reset: String = { - for (i <- aggs.indices) yield + for (i <- aggs.indices) yield { j""" - | ${aggs(i)}.resetAccumulator( - | ((${accTypes(i)}) accs.getField($i)));""".stripMargin + | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i); + | ${genDataViewFieldSetter(s"acc$i", i)} + | ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin + } }.mkString("\n") if (needReset) { @@ -405,6 +551,17 @@ class AggregationCodeGenerator( } } + val aggFuncCode = Seq( + genSetAggregationResults, + genAccumulate, + genRetract, + genCreateAccumulators, + genSetForwardedFields, + genSetConstantFlags, + genCreateOutputRow, + genMergeAccumulatorsPair, + genResetAccumulator).mkString("\n") + val generatedAggregationsClass = classOf[GeneratedAggregations].getCanonicalName var funcCode = j""" @@ -417,20 +574,29 @@ class AggregationCodeGenerator( | } | ${reuseConstructorCode(funcName)} | + | public final void open( + | org.apache.flink.api.common.functions.RuntimeContext $contextTerm) throws Exception { + | ${reuseOpenCode()} + | } + | + | $aggFuncCode + | + | public final void cleanup() throws Exception { + | ${reuseCleanupCode()} + | } + | + | public final void close() throws Exception { + | ${reuseCloseCode()} + | } + |} """.stripMargin - funcCode += genSetAggregationResults + "\n" - funcCode += genAccumulate + "\n" - funcCode += genRetract + "\n" - funcCode += genCreateAccumulators + "\n" - funcCode += genSetForwardedFields + "\n" - funcCode += genSetConstantFlags + "\n" - funcCode += genCreateOutputRow + "\n" - funcCode += genMergeAccumulatorsPair + "\n" - funcCode += genResetAccumulator + "\n" - funcCode += "}" - GeneratedAggregationsFunction(funcName, funcCode) } + @throws[Exception] + def serializeStateDescriptor(stateDescriptor: StateDescriptor[_, _]): String = { + val byteArray = InstantiationUtil.serializeObject(stateDescriptor) + Base64.encodeBase64URLSafeString(byteArray) + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala index 1d8c926233b2f..161f9a3b055f4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala @@ -28,7 +28,7 @@ import org.apache.flink.api.common.typeinfo.{FractionalTypeInfo, SqlTimeTypeInfo import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.{PojoTypeInfo, RowTypeInfo, TupleTypeInfo, TypeExtractor} import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo -import org.apache.flink.table.typeutils.{TimeIntervalTypeInfo, TypeCheckUtils} +import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TimeIntervalTypeInfo, TypeCheckUtils} object CodeGenUtils { @@ -90,6 +90,9 @@ object CodeGenUtils { case BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO => "boolean[]" case 
CHAR_PRIMITIVE_ARRAY_TYPE_INFO => "char[]" + // time indicators are represented as Long even if they seem to be Timestamp + case _: TimeIndicatorTypeInfo => "java.lang.Long" + case _ => tpe.getTypeClass.getCanonicalName } @@ -123,8 +126,10 @@ object CodeGenUtils { def qualifyEnum(enum: Enum[_]): String = enum.getClass.getCanonicalName + "." + enum.name() - def internalToTimePointCode(resultType: TypeInformation[_], resultTerm: String) = + def internalToTimePointCode(resultType: TypeInformation[_], resultTerm: String): String = resultType match { + case _: TimeIndicatorTypeInfo => + resultTerm // time indicators are not modified case SqlTimeTypeInfo.DATE => s"${qualifyMethod(BuiltInMethod.INTERNAL_TO_DATE.method)}($resultTerm)" case SqlTimeTypeInfo.TIME => @@ -133,7 +138,7 @@ object CodeGenUtils { s"${qualifyMethod(BuiltInMethod.INTERNAL_TO_TIMESTAMP.method)}($resultTerm)" } - def timePointToInternalCode(resultType: TypeInformation[_], resultTerm: String) = + def timePointToInternalCode(resultType: TypeInformation[_], resultTerm: String): String = resultType match { case SqlTimeTypeInfo.DATE => s"${qualifyMethod(BuiltInMethod.DATE_TO_INT.method)}($resultTerm)" @@ -157,43 +162,43 @@ object CodeGenUtils { // ---------------------------------------------------------------------------------------------- - def requireNumeric(genExpr: GeneratedExpression) = + def requireNumeric(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isNumeric(genExpr.resultType)) { throw new CodeGenException("Numeric expression type expected, but was " + s"'${genExpr.resultType}'.") } - def requireComparable(genExpr: GeneratedExpression) = + def requireComparable(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isComparable(genExpr.resultType)) { throw new CodeGenException(s"Comparable type expected, but was '${genExpr.resultType}'.") } - def requireString(genExpr: GeneratedExpression) = + def requireString(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isString(genExpr.resultType)) { throw new CodeGenException("String expression type expected.") } - def requireBoolean(genExpr: GeneratedExpression) = + def requireBoolean(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isBoolean(genExpr.resultType)) { throw new CodeGenException("Boolean expression type expected.") } - def requireTemporal(genExpr: GeneratedExpression) = + def requireTemporal(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isTemporal(genExpr.resultType)) { throw new CodeGenException("Temporal expression type expected.") } - def requireTimeInterval(genExpr: GeneratedExpression) = + def requireTimeInterval(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isTimeInterval(genExpr.resultType)) { throw new CodeGenException("Interval expression type expected.") } - def requireArray(genExpr: GeneratedExpression) = + def requireArray(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isArray(genExpr.resultType)) { throw new CodeGenException("Array expression type expected.") } - def requireInteger(genExpr: GeneratedExpression) = + def requireInteger(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isInteger(genExpr.resultType)) { throw new CodeGenException("Integer expression type expected.") } @@ -243,7 +248,7 @@ object CodeGenUtils { val fieldName = pt.getFieldNames()(index) getFieldAccessor(pt.getTypeClass, fieldName) - case _ => throw new CodeGenException(s"Unsupported composite type: '${compType}'") + case _ => throw new CodeGenException(s"Unsupported composite type: '$compType'") } } 
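The `AggregationCodeGenerator` changes above embed each `StateDescriptor` derived from a `DataViewSpec` into the generated source as a Base64 string literal (`serializeStateDescriptor`), and the generated `open()` method decodes and deserializes that literal again before asking the `RuntimeContext` for the state. The following minimal, self-contained sketch illustrates that round trip with the same utilities; the state id `acc0_map_dataview` and the use of the current thread's context class loader are illustrative assumptions only, since the generated code obtains the user-code class loader from its runtime context.

```scala
import org.apache.commons.codec.binary.Base64
import org.apache.flink.api.common.state.{MapStateDescriptor, StateDescriptor}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.util.InstantiationUtil

object StateDescriptorRoundTrip {

  def main(args: Array[String]): Unit = {
    // A descriptor as it could be produced by MapViewSpec#toStateDescriptor
    // (the state id "acc0_map_dataview" is a made-up example).
    val desc = new MapStateDescriptor[String, java.lang.Integer](
      "acc0_map_dataview",
      BasicTypeInfo.STRING_TYPE_INFO,
      BasicTypeInfo.INT_TYPE_INFO)

    // Code-generation time: serialize the descriptor and embed it as a
    // Base64 string literal in the generated Java source.
    val encoded: String =
      Base64.encodeBase64URLSafeString(InstantiationUtil.serializeObject(desc))

    // Runtime (generated open() method): decode the literal and restore the
    // descriptor with a class loader that can see user classes.
    val restored = InstantiationUtil.deserializeObject[StateDescriptor[_, _]](
      Base64.decodeBase64(encoded),
      Thread.currentThread().getContextClassLoader)

    println(restored.getName) // prints: acc0_map_dataview
  }
}
```

Base64 encoding turns the serialized descriptor into plain ASCII that can be placed directly inside a generated string literal; `Base64.decodeBase64` also accepts the URL-safe alphabet, so the decoding side in the generated code stays symmetric with `encodeBase64URLSafeString`.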
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 63fd058b980cb..bf6ee217b9683 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -26,6 +26,7 @@ import org.apache.calcite.sql.SqlOperator import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName} import org.apache.calcite.sql.fun.SqlStdOperatorTable._ +import org.apache.commons.lang3.StringEscapeUtils import org.apache.flink.api.common.functions._ import org.apache.flink.api.common.typeinfo._ import org.apache.flink.api.common.typeutils.CompositeType @@ -36,11 +37,13 @@ import org.apache.flink.table.api.TableConfig import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenUtils._ import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE} +import org.apache.flink.table.codegen.calls.FunctionGenerator import org.apache.flink.table.codegen.calls.ScalarOperators._ -import org.apache.flink.table.codegen.calls.{BuiltInMethods, FunctionGenerator} -import org.apache.flink.table.functions.sql.ScalarSqlFunctions +import org.apache.flink.table.functions.sql.{ProctimeSqlFunction, ScalarSqlFunctions} import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils -import org.apache.flink.table.functions.{FunctionContext, TimeMaterializationSqlFunction, UserDefinedFunction} + +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo +import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction} import org.apache.flink.table.typeutils.TypeCheckUtils._ import scala.collection.JavaConversions._ @@ -55,10 +58,11 @@ import scala.collection.mutable * @param nullableInput input(s) can be null. * @param input1 type information about the first input of the Function * @param input2 type information about the second input if the Function is binary - * @param input1FieldMapping additional mapping information for input1 - * (e.g. POJO types have no deterministic field order and some input fields might not be read) - * @param input2FieldMapping additional mapping information for input2 - * (e.g. POJO types have no deterministic field order and some input fields might not be read) + * @param input1FieldMapping additional mapping information for input1. + * POJO types have no deterministic field order and some input fields might not be read. + * The input1FieldMapping is also used to inject time indicator attributes. + * @param input2FieldMapping additional mapping information for input2. + * POJO types have no deterministic field order and some input fields might not be read. 
*/ abstract class CodeGenerator( config: TableConfig, @@ -105,31 +109,31 @@ abstract class CodeGenerator( // set of member statements that will be added only once // we use a LinkedHashSet to keep the insertion order - private val reusableMemberStatements = mutable.LinkedHashSet[String]() + protected val reusableMemberStatements = mutable.LinkedHashSet[String]() // set of constructor statements that will be added only once // we use a LinkedHashSet to keep the insertion order - private val reusableInitStatements = mutable.LinkedHashSet[String]() + protected val reusableInitStatements = mutable.LinkedHashSet[String]() // set of open statements for RichFunction that will be added only once // we use a LinkedHashSet to keep the insertion order - private val reusableOpenStatements = mutable.LinkedHashSet[String]() + protected val reusableOpenStatements = mutable.LinkedHashSet[String]() // set of close statements for RichFunction that will be added only once // we use a LinkedHashSet to keep the insertion order - private val reusableCloseStatements = mutable.LinkedHashSet[String]() + protected val reusableCloseStatements = mutable.LinkedHashSet[String]() // set of statements that will be added only once per record // we use a LinkedHashSet to keep the insertion order - private val reusablePerRecordStatements = mutable.LinkedHashSet[String]() + protected val reusablePerRecordStatements = mutable.LinkedHashSet[String]() // map of initial input unboxing expressions that will be added only once // (inputTerm, index) -> expr - private val reusableInputUnboxingExprs = mutable.Map[(String, Int), GeneratedExpression]() + protected val reusableInputUnboxingExprs = mutable.Map[(String, Int), GeneratedExpression]() // set of constructor statements that will be added only once // we use a LinkedHashSet to keep the insertion order - private val reusableConstructorStatements = mutable.LinkedHashSet[(String, String)]() + protected val reusableConstructorStatements = mutable.LinkedHashSet[(String, String)]() /** * @return code block of statements that need to be placed in the member area of the Function @@ -244,16 +248,23 @@ abstract class CodeGenerator( returnType: TypeInformation[_ <: Any], resultFieldNames: Seq[String]) : GeneratedExpression = { - val input1AccessExprs = input1Mapping.map { idx => - generateInputAccess(input1, input1Term, idx) + + val input1AccessExprs = input1Mapping.map { + case TimeIndicatorTypeInfo.ROWTIME_MARKER => + // attribute is a rowtime indicator. Access event-time timestamp in StreamRecord. + generateRowtimeAccess() + case TimeIndicatorTypeInfo.PROCTIME_MARKER => + // attribute is proctime indicator. + // We use a null literal and generate a timestamp when we need it. + generateNullLiteral(TimeIndicatorTypeInfo.PROCTIME_INDICATOR) + case idx => + // regular attribute. Access attribute in input data type. 
+ generateInputAccess(input1, input1Term, idx) } val input2AccessExprs = input2 match { case Some(ti) => - input2Mapping.map { idx => - generateInputAccess(ti, input2Term, idx) - }.toSeq - + input2Mapping.map(idx => generateInputAccess(ti, input2Term, idx)).toSeq case None => Seq() // add nothing } @@ -318,13 +329,13 @@ abstract class CodeGenerator( // initial type check if (returnType.getArity != fieldExprs.length) { throw new CodeGenException( - s"Arity[${returnType.getArity}] of result type[$returnType] does not match " + - s"number[${fieldExprs.length}] of expressions[$fieldExprs].") + s"Arity [${returnType.getArity}] of result type [$returnType] does not match " + + s"number [${fieldExprs.length}] of expressions [$fieldExprs].") } if (resultFieldNames.length != fieldExprs.length) { throw new CodeGenException( - s"Arity[${resultFieldNames.length}] of result field names[$resultFieldNames] does not " + - s"match number[${fieldExprs.length}] of expressions[$fieldExprs].") + s"Arity [${resultFieldNames.length}] of result field names [$resultFieldNames] does not " + + s"match number [${fieldExprs.length}] of expressions [$fieldExprs].") } // type check returnType match { @@ -332,8 +343,8 @@ abstract class CodeGenerator( fieldExprs.zipWithIndex foreach { case (fieldExpr, i) if fieldExpr.resultType != pt.getTypeAt(resultFieldNames(i)) => throw new CodeGenException( - s"Incompatible types of expression and result type. Expression[$fieldExpr] type is " + - s"[${fieldExpr.resultType}], result type is [${pt.getTypeAt(resultFieldNames(i))}]") + s"Incompatible types of expression and result type. Expression [$fieldExpr] type is" + + s" [${fieldExpr.resultType}], result type is [${pt.getTypeAt(resultFieldNames(i))}]") case _ => // ok } @@ -349,7 +360,7 @@ abstract class CodeGenerator( case at: AtomicType[_] if at != fieldExprs.head.resultType => throw new CodeGenException( - s"Incompatible types of expression and result type. Expression[${fieldExprs.head}] " + + s"Incompatible types of expression and result type. 
Expression [${fieldExprs.head}] " + s"type is [${fieldExprs.head.resultType}], result type is [$at]") case _ => // ok @@ -669,7 +680,8 @@ abstract class CodeGenerator( generateNonNullLiteral(resultType, decimalField) case VARCHAR | CHAR => - generateNonNullLiteral(resultType, "\"" + value.toString + "\"") + val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(value.toString) + generateNonNullLiteral(resultType, "\"" + escapedValue + "\"") case SYMBOL => generateSymbol(value.asInstanceOf[Enum[_]]) @@ -722,10 +734,8 @@ abstract class CodeGenerator( override def visitCall(call: RexCall): GeneratedExpression = { // special case: time materialization - if (call.getOperator == TimeMaterializationSqlFunction) { - return generateRecordTimestamp( - FlinkTypeFactory.isRowtimeIndicatorType(call.getOperands.get(0).getType) - ) + if (call.getOperator == ProctimeSqlFunction) { + return generateProctimeTimestamp() } val resultType = FlinkTypeFactory.toTypeInfo(call.getType) @@ -965,10 +975,10 @@ abstract class CodeGenerator( generateArrayElement(this, array) case ScalarSqlFunctions.CONCAT => - generateConcat(BuiltInMethods.CONCAT, operands) + generateConcat(this.nullCheck, operands) case ScalarSqlFunctions.CONCAT_WS => - generateConcat(BuiltInMethods.CONCAT_WS, operands) + generateConcatWs(operands) // advanced scalar functions case sqlOperator: SqlOperator => @@ -1201,21 +1211,27 @@ abstract class CodeGenerator( } val wrappedCode = if (nullCheck && !isReference(fieldType)) { + // assumes that fieldType is a boxed primitive. s""" - |$tmpTypeTerm $tmpTerm = $unboxedFieldCode; - |boolean $nullTerm = $tmpTerm == null; + |boolean $nullTerm = $fieldTerm == null; |$resultTypeTerm $resultTerm; |if ($nullTerm) { | $resultTerm = $defaultValue; |} |else { - | $resultTerm = $tmpTerm; + | $resultTerm = $fieldTerm; |} |""".stripMargin } else if (nullCheck) { s""" - |$resultTypeTerm $resultTerm = $unboxedFieldCode; |boolean $nullTerm = $fieldTerm == null; + |$resultTypeTerm $resultTerm; + |if ($nullTerm) { + | $resultTerm = $defaultValue; + |} + |else { + | $resultTerm = $unboxedFieldCode; + |} |""".stripMargin } else { s""" @@ -1268,27 +1284,31 @@ abstract class CodeGenerator( } } - private[flink] def generateRecordTimestamp(isEventTime: Boolean): GeneratedExpression = { + private[flink] def generateRowtimeAccess(): GeneratedExpression = { val resultTerm = newName("result") - val resultTypeTerm = primitiveTypeTermForTypeInfo(SqlTimeTypeInfo.TIMESTAMP) + val nullTerm = newName("isNull") - val resultCode = if (isEventTime) { + val accessCode = s""" - |$resultTypeTerm $resultTerm; - |if ($contextTerm.timestamp() == null) { + |Long $resultTerm = $contextTerm.timestamp(); + |if ($resultTerm == null) { | throw new RuntimeException("Rowtime timestamp is null. 
Please make sure that a proper " + | "TimestampAssigner is defined and the stream environment uses the EventTime time " + | "characteristic."); |} - |else { - | $resultTerm = $contextTerm.timestamp(); - |} - |""".stripMargin - } else { + |boolean $nullTerm = false; + """.stripMargin + + GeneratedExpression(resultTerm, nullTerm, accessCode, TimeIndicatorTypeInfo.ROWTIME_INDICATOR) + } + + private[flink] def generateProctimeTimestamp(): GeneratedExpression = { + val resultTerm = newName("result") + + val resultCode = s""" - |$resultTypeTerm $resultTerm = $contextTerm.timerService().currentProcessingTime(); + |long $resultTerm = $contextTerm.timerService().currentProcessingTime(); |""".stripMargin - } GeneratedExpression(resultTerm, NEVER_NULL, resultCode, SqlTimeTypeInfo.TIMESTAMP) } @@ -1439,9 +1459,10 @@ abstract class CodeGenerator( * Adds a reusable [[UserDefinedFunction]] to the member area of the generated [[Function]]. * * @param function [[UserDefinedFunction]] object to be instantiated during runtime + * @param contextTerm [[RuntimeContext]] term to access the [[RuntimeContext]] * @return member variable term */ - def addReusableFunction(function: UserDefinedFunction): String = { + def addReusableFunction(function: UserDefinedFunction, contextTerm: String = null): String = { val classQualifier = function.getClass.getCanonicalName val functionSerializedData = UserDefinedFunctionUtils.serialize(function) val fieldTerm = s"function_${function.functionIdentifier}" @@ -1461,10 +1482,15 @@ abstract class CodeGenerator( reusableInitStatements.add(functionDeserialization) - val openFunction = + val openFunction = if (contextTerm != null) { + s""" + |$fieldTerm.open(new ${classOf[FunctionContext].getCanonicalName}($contextTerm)); + """.stripMargin + } else { s""" |$fieldTerm.open(new ${classOf[FunctionContext].getCanonicalName}(getRuntimeContext())); """.stripMargin + } reusableOpenStatements.add(openFunction) val closeFunction = diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala index 1ab927d0020e2..7de7acaa9de45 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala @@ -17,8 +17,6 @@ */ package org.apache.flink.table.codegen.calls -import java.lang.reflect.Method - import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange} import org.apache.calcite.util.BuiltInMethod @@ -28,7 +26,7 @@ import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo} import org.apache.flink.table.codegen.CodeGenUtils._ import org.apache.flink.table.codegen.calls.CallGenerator.generateCallIfArgsNotNull import org.apache.flink.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression} -import org.apache.flink.table.typeutils.{TimeIntervalTypeInfo, TypeCoercion} +import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TimeIntervalTypeInfo, TypeCoercion} import org.apache.flink.table.typeutils.TypeCheckUtils._ object ScalarOperators { @@ -545,6 +543,11 @@ object ScalarOperators { operand: GeneratedExpression, targetType: TypeInformation[_]) : GeneratedExpression = (operand.resultType, targetType) match { + + // special case: cast from TimeIndicatorTypeInfo to SqlTimeTypeInfo 
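The new `generateRowtimeAccess()` and `generateProctimeTimestamp()` methods above read event time from the `StreamRecord` timestamp and processing time from the timer service. The following hand-written `ProcessFunction` is only a sketch of the equivalent runtime behaviour, not the generated code; the class name is made up.

```scala
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.util.Collector

// Illustration only: the Table API emits equivalent logic via code generation.
class TimeAccessSketch extends ProcessFunction[String, String] {

  override def processElement(
      value: String,
      ctx: ProcessFunction[String, String]#Context,
      out: Collector[String]): Unit = {

    // rowtime: the StreamRecord timestamp; fail if no TimestampAssigner has set it
    val rowtime: java.lang.Long = ctx.timestamp()
    if (rowtime == null) {
      throw new RuntimeException("Rowtime timestamp is null. Please make sure that a proper " +
        "TimestampAssigner is defined and the stream environment uses the EventTime time " +
        "characteristic.")
    }

    // proctime: materialized lazily from the timer service when it is actually needed
    val proctime = ctx.timerService().currentProcessingTime()

    out.collect(s"$value rowtime=$rowtime proctime=$proctime")
  }
}
```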
+ case (ti: TimeIndicatorTypeInfo, SqlTimeTypeInfo.TIMESTAMP) => + operand.copy(resultType = SqlTimeTypeInfo.TIMESTAMP) // just replace the TypeInformation + // identity casting case (fromTp, toTp) if fromTp == toTp => operand @@ -1026,14 +1029,48 @@ object ScalarOperators { } def generateConcat( - method: Method, - operands: Seq[GeneratedExpression]): GeneratedExpression = { + nullCheck: Boolean, + operands: Seq[GeneratedExpression]) + : GeneratedExpression = { - generateCallIfArgsNotNull(false, STRING_TYPE_INFO, operands) { - (terms) =>s"${qualifyMethod(method)}(${terms.mkString(", ")})" + generateCallIfArgsNotNull(nullCheck, STRING_TYPE_INFO, operands) { + (terms) =>s"${qualifyMethod(BuiltInMethods.CONCAT)}(${terms.mkString(", ")})" } } + def generateConcatWs(operands: Seq[GeneratedExpression]): GeneratedExpression = { + + val resultTerm = newName("result") + val nullTerm = newName("isNull") + val defaultValue = primitiveDefaultValue(Types.STRING) + + val tempTerms = operands.tail.map(_ => newName("temp")) + + val operatorCode = + s""" + |${operands.map(_.code).mkString("\n")} + | + |String $resultTerm; + |boolean $nullTerm; + |if (${operands.head.nullTerm}) { + | $nullTerm = true; + | $resultTerm = $defaultValue; + |} else { + | ${operands.tail.zip(tempTerms).map { + case (o: GeneratedExpression, t: String) => + s"String $t;\n" + + s" if (${o.nullTerm}) $t = null; else $t = ${o.resultTerm};" + }.mkString("\n") + } + | $nullTerm = false; + | $resultTerm = ${qualifyMethod(BuiltInMethods.CONCAT_WS)} + | (${operands.head.resultTerm}, ${tempTerms.mkString(", ")}); + |} + |""".stripMargin + + GeneratedExpression(resultTerm, nullTerm, operatorCode, Types.STRING) + } + def generateMapGet( codeGenerator: CodeGenerator, map: GeneratedExpression, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewSerializer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewSerializer.scala new file mode 100644 index 0000000000000..a450c2ce1e552 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewSerializer.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.dataview + +import org.apache.flink.api.common.typeutils._ +import org.apache.flink.api.common.typeutils.base.{CollectionSerializerConfigSnapshot, ListSerializer} +import org.apache.flink.core.memory.{DataInputView, DataOutputView} +import org.apache.flink.table.api.dataview.ListView + +/** + * A serializer for [[ListView]]. The serializer relies on an element + * serializer for the serialization of the list's elements. 
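The rewritten `generateConcatWs` above passes `NULL` arguments through to `BuiltInMethods.CONCAT_WS` instead of nulling the whole result; only a `NULL` separator makes the result `NULL`. A minimal sketch of the intended semantics in plain Scala (the helper name `concatWs` is made up, this is not the generated code):

```scala
// CONCAT_WS null semantics: a null separator yields null, null arguments are skipped.
def concatWs(separator: String, args: String*): String =
  if (separator == null) null
  else args.filter(_ != null).mkString(separator)

concatWs("-", "2017", null, "09")  // "2017-09"
concatWs(null, "a", "b")           // null
```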
+ * + * The serialization format for the list is as follows: four bytes for the length of the list, + * followed by the serialized representation of each element. + * + * @param listSerializer List serializer. + * @tparam T The type of element in the list. + */ +class ListViewSerializer[T](val listSerializer: ListSerializer[T]) + extends TypeSerializer[ListView[T]] { + + override def isImmutableType: Boolean = false + + override def duplicate(): TypeSerializer[ListView[T]] = { + new ListViewSerializer[T](listSerializer.duplicate().asInstanceOf[ListSerializer[T]]) + } + + override def createInstance(): ListView[T] = { + new ListView[T] + } + + override def copy(from: ListView[T]): ListView[T] = { + new ListView[T](null, listSerializer.copy(from.list)) + } + + override def copy(from: ListView[T], reuse: ListView[T]): ListView[T] = copy(from) + + override def getLength: Int = -1 + + override def serialize(record: ListView[T], target: DataOutputView): Unit = { + listSerializer.serialize(record.list, target) + } + + override def deserialize(source: DataInputView): ListView[T] = { + new ListView[T](null, listSerializer.deserialize(source)) + } + + override def deserialize(reuse: ListView[T], source: DataInputView): ListView[T] = + deserialize(source) + + override def copy(source: DataInputView, target: DataOutputView): Unit = + listSerializer.copy(source, target) + + override def canEqual(obj: scala.Any): Boolean = obj != null && obj.getClass == getClass + + override def hashCode(): Int = listSerializer.hashCode() + + override def equals(obj: Any): Boolean = canEqual(this) && + listSerializer.equals(obj.asInstanceOf[ListViewSerializer[_]].listSerializer) + + override def snapshotConfiguration(): TypeSerializerConfigSnapshot = + listSerializer.snapshotConfiguration() + + // copy and modified from ListSerializer.ensureCompatibility + override def ensureCompatibility( + configSnapshot: TypeSerializerConfigSnapshot): CompatibilityResult[ListView[T]] = { + + configSnapshot match { + case snapshot: CollectionSerializerConfigSnapshot[_] => + val previousListSerializerAndConfig = snapshot.getSingleNestedSerializerAndConfig + + val compatResult = CompatibilityUtil.resolveCompatibilityResult( + previousListSerializerAndConfig.f0, + classOf[UnloadableDummyTypeSerializer[_]], + previousListSerializerAndConfig.f1, + listSerializer.getElementSerializer) + + if (!compatResult.isRequiresMigration) { + CompatibilityResult.compatible[ListView[T]] + } else if (compatResult.getConvertDeserializer != null) { + CompatibilityResult.requiresMigration( + new ListViewSerializer[T]( + new ListSerializer[T]( + new TypeDeserializerAdapter[T](compatResult.getConvertDeserializer)) + ) + ) + } else { + CompatibilityResult.requiresMigration[ListView[T]] + } + + case _ => CompatibilityResult.requiresMigration[ListView[T]] + } + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfo.scala new file mode 100644 index 0000000000000..a10b6754a7c43 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfo.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
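A minimal round-trip sketch for the new `ListViewSerializer`, assuming the heap-backed `ListView` added elsewhere in this change exposes the `(elementTypeInfo, list)` constructor and the `list` field that the serializer itself uses:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.Arrays

import org.apache.flink.api.common.typeutils.base.{ListSerializer, StringSerializer}
import org.apache.flink.core.memory.{DataInputViewStreamWrapper, DataOutputViewStreamWrapper}
import org.apache.flink.table.api.dataview.ListView
import org.apache.flink.table.dataview.ListViewSerializer

// the serializer delegates the actual work to the nested ListSerializer
val serializer = new ListViewSerializer[String](
  new ListSerializer[String](StringSerializer.INSTANCE))

val view = new ListView[String](null, Arrays.asList("a", "b", "c"))

val out = new ByteArrayOutputStream()
serializer.serialize(view, new DataOutputViewStreamWrapper(out))

val copy = serializer.deserialize(
  new DataInputViewStreamWrapper(new ByteArrayInputStream(out.toByteArray)))
// copy.list now holds ["a", "b", "c"]
```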
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.dataview + +import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.TypeSerializer +import org.apache.flink.api.common.typeutils.base.ListSerializer +import org.apache.flink.table.api.dataview.ListView + +/** + * [[ListView]] type information. + * + * @param elementType element type information + * @tparam T element type + */ +class ListViewTypeInfo[T](val elementType: TypeInformation[T]) + extends TypeInformation[ListView[T]] { + + override def isBasicType: Boolean = false + + override def isTupleType: Boolean = false + + override def getArity: Int = 1 + + override def getTotalFields: Int = 1 + + override def getTypeClass: Class[ListView[T]] = classOf[ListView[T]] + + override def isKeyType: Boolean = false + + override def createSerializer(config: ExecutionConfig): TypeSerializer[ListView[T]] = { + val typeSer = elementType.createSerializer(config) + new ListViewSerializer[T](new ListSerializer[T](typeSer)) + } + + override def canEqual(obj: scala.Any): Boolean = obj != null && obj.getClass == getClass + + override def hashCode(): Int = 31 * elementType.hashCode + + override def equals(obj: Any): Boolean = canEqual(obj) && { + obj match { + case other: ListViewTypeInfo[T] => + elementType.equals(other.elementType) + case _ => false + } + } + + override def toString: String = s"ListView<$elementType>" +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfoFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfoFactory.scala new file mode 100644 index 0000000000000..eda6cb9893f26 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfoFactory.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.dataview + +import java.lang.reflect.Type +import java.util + +import org.apache.flink.api.common.typeinfo.{TypeInfoFactory, TypeInformation} +import org.apache.flink.api.java.typeutils.GenericTypeInfo +import org.apache.flink.table.api.dataview.ListView + +class ListViewTypeInfoFactory[T] extends TypeInfoFactory[ListView[T]] { + + override def createTypeInfo( + t: Type, + genericParameters: util.Map[String, TypeInformation[_]]): TypeInformation[ListView[T]] = { + + var elementType = genericParameters.get("T") + + if (elementType == null) { + // we might be able to get the elementType later from the ListView constructor + elementType = new GenericTypeInfo(classOf[Any]) + } + + new ListViewTypeInfo[T](elementType.asInstanceOf[TypeInformation[T]]) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewSerializer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewSerializer.scala new file mode 100644 index 0000000000000..c53f10c37e597 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewSerializer.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.dataview + +import org.apache.flink.api.common.typeutils._ +import org.apache.flink.api.common.typeutils.base.{MapSerializer, MapSerializerConfigSnapshot} +import org.apache.flink.core.memory.{DataInputView, DataOutputView} +import org.apache.flink.table.api.dataview.MapView + +/** + * A serializer for [[MapView]]. The serializer relies on a key serializer and a value + * serializer for the serialization of the map's key-value pairs. + * + * The serialization format for the map is as follows: four bytes for the length of the map, + * followed by the serialized representation of each key-value pair. To allow null values, + * each value is prefixed by a null marker. + * + * @param mapSerializer Map serializer. + * @tparam K The type of the keys in the map. + * @tparam V The type of the values in the map.
+ */ +class MapViewSerializer[K, V](val mapSerializer: MapSerializer[K, V]) + extends TypeSerializer[MapView[K, V]] { + + override def isImmutableType: Boolean = false + + override def duplicate(): TypeSerializer[MapView[K, V]] = + new MapViewSerializer[K, V]( + mapSerializer.duplicate().asInstanceOf[MapSerializer[K, V]]) + + override def createInstance(): MapView[K, V] = { + new MapView[K, V]() + } + + override def copy(from: MapView[K, V]): MapView[K, V] = { + new MapView[K, V](null, null, mapSerializer.copy(from.map)) + } + + override def copy(from: MapView[K, V], reuse: MapView[K, V]): MapView[K, V] = copy(from) + + override def getLength: Int = -1 // var length + + override def serialize(record: MapView[K, V], target: DataOutputView): Unit = { + mapSerializer.serialize(record.map, target) + } + + override def deserialize(source: DataInputView): MapView[K, V] = { + new MapView[K, V](null, null, mapSerializer.deserialize(source)) + } + + override def deserialize(reuse: MapView[K, V], source: DataInputView): MapView[K, V] = + deserialize(source) + + override def copy(source: DataInputView, target: DataOutputView): Unit = + mapSerializer.copy(source, target) + + override def canEqual(obj: Any): Boolean = obj != null && obj.getClass == getClass + + override def hashCode(): Int = mapSerializer.hashCode() + + override def equals(obj: Any): Boolean = canEqual(this) && + mapSerializer.equals(obj.asInstanceOf[MapViewSerializer[_, _]].mapSerializer) + + override def snapshotConfiguration(): TypeSerializerConfigSnapshot = + mapSerializer.snapshotConfiguration() + + // copy and modified from MapSerializer.ensureCompatibility + override def ensureCompatibility(configSnapshot: TypeSerializerConfigSnapshot) + : CompatibilityResult[MapView[K, V]] = { + + configSnapshot match { + case snapshot: MapSerializerConfigSnapshot[_, _] => + val previousKvSerializersAndConfigs = snapshot.getNestedSerializersAndConfigs + + val keyCompatResult = CompatibilityUtil.resolveCompatibilityResult( + previousKvSerializersAndConfigs.get(0).f0, + classOf[UnloadableDummyTypeSerializer[_]], + previousKvSerializersAndConfigs.get(0).f1, + mapSerializer.getKeySerializer) + + val valueCompatResult = CompatibilityUtil.resolveCompatibilityResult( + previousKvSerializersAndConfigs.get(1).f0, + classOf[UnloadableDummyTypeSerializer[_]], + previousKvSerializersAndConfigs.get(1).f1, + mapSerializer.getValueSerializer) + + if (!keyCompatResult.isRequiresMigration && !valueCompatResult.isRequiresMigration) { + CompatibilityResult.compatible[MapView[K, V]] + } else if (keyCompatResult.getConvertDeserializer != null + && valueCompatResult.getConvertDeserializer != null) { + CompatibilityResult.requiresMigration( + new MapViewSerializer[K, V]( + new MapSerializer[K, V]( + new TypeDeserializerAdapter[K](keyCompatResult.getConvertDeserializer), + new TypeDeserializerAdapter[V](valueCompatResult.getConvertDeserializer)) + ) + ) + } else { + CompatibilityResult.requiresMigration[MapView[K, V]] + } + + case _ => CompatibilityResult.requiresMigration[MapView[K, V]] + } + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfo.scala new file mode 100644 index 0000000000000..ec5c2226e4278 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfo.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.dataview + +import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.TypeSerializer +import org.apache.flink.api.common.typeutils.base.MapSerializer +import org.apache.flink.table.api.dataview.MapView + +/** + * [[MapView]] type information. + * + * @param keyType key type information + * @param valueType value type information + * @tparam K key type + * @tparam V value type + */ +class MapViewTypeInfo[K, V]( + val keyType: TypeInformation[K], + val valueType: TypeInformation[V]) + extends TypeInformation[MapView[K, V]] { + + override def isBasicType = false + + override def isTupleType = false + + override def getArity = 1 + + override def getTotalFields = 1 + + override def getTypeClass: Class[MapView[K, V]] = classOf[MapView[K, V]] + + override def isKeyType: Boolean = false + + override def createSerializer(config: ExecutionConfig): TypeSerializer[MapView[K, V]] = { + val keySer = keyType.createSerializer(config) + val valueSer = valueType.createSerializer(config) + new MapViewSerializer[K, V](new MapSerializer[K, V](keySer, valueSer)) + } + + override def canEqual(obj: scala.Any): Boolean = obj != null && obj.getClass == getClass + + override def hashCode(): Int = 31 * keyType.hashCode + valueType.hashCode + + override def equals(obj: Any): Boolean = canEqual(obj) && { + obj match { + case other: MapViewTypeInfo[_, _] => + keyType.equals(other.keyType) && + valueType.equals(other.valueType) + case _ => false + } + } + + override def toString: String = s"MapView<$keyType, $valueType>" +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfoFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfoFactory.scala new file mode 100644 index 0000000000000..33c3ffe2a2e38 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfoFactory.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
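The map-side type information works the same way; a short sketch with arbitrarily chosen key and value types:

```scala
import java.lang.{Long => JLong}

import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.table.dataview.MapViewTypeInfo

val mapViewType = new MapViewTypeInfo[String, JLong](
  BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO)

// yields a MapViewSerializer that wraps a MapSerializer[String, JLong]
val serializer = mapViewType.createSerializer(new ExecutionConfig)
```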
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.dataview + +import java.lang.reflect.Type +import java.util + +import org.apache.flink.api.common.typeinfo.{TypeInfoFactory, TypeInformation} +import org.apache.flink.api.java.typeutils.GenericTypeInfo +import org.apache.flink.table.api.dataview.MapView + +class MapViewTypeInfoFactory[K, V] extends TypeInfoFactory[MapView[K, V]] { + + override def createTypeInfo( + t: Type, + genericParameters: util.Map[String, TypeInformation[_]]): TypeInformation[MapView[K, V]] = { + + var keyType = genericParameters.get("K") + var valueType = genericParameters.get("V") + + if (keyType == null) { + // we might be able to get the keyType later from the MapView constructor + keyType = new GenericTypeInfo(classOf[Any]) + } + + if (valueType == null) { + // we might be able to get the valueType later from the MapView constructor + valueType = new GenericTypeInfo(classOf[Any]) + } + + new MapViewTypeInfo[K, V]( + keyType.asInstanceOf[TypeInformation[K]], + valueType.asInstanceOf[TypeInformation[V]]) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateListView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateListView.scala new file mode 100644 index 0000000000000..70756ca761414 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateListView.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.dataview + +import java.util +import java.lang.{Iterable => JIterable} + +import org.apache.flink.api.common.state._ +import org.apache.flink.table.api.dataview.ListView + +/** + * A [[ListView]] that is backed by a state backend. + * + * @param state list state + * @tparam T element type + */ +class StateListView[T](state: ListState[T]) extends ListView[T] { + + override def get: JIterable[T] = state.get() + + override def add(value: T): Unit = state.add(value) + + override def addAll(list: util.List[T]): Unit = { + val iterator = list.iterator() + while (iterator.hasNext) { + state.add(iterator.next()) + } + } + + override def clear(): Unit = state.clear() +} + diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateMapView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateMapView.scala new file mode 100644 index 0000000000000..22f5f0b23459f --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateMapView.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
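`StateListView` adapts a keyed `ListState` to the `ListView` interface. In the planner this wiring is done by generated code; the hand-written function below (made-up names, must run on a keyed stream) only shows the adapter in isolation:

```scala
import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.api.common.state.ListStateDescriptor
import org.apache.flink.configuration.Configuration
import org.apache.flink.table.api.dataview.ListView
import org.apache.flink.table.dataview.StateListView
import org.apache.flink.util.Collector

class CollectPerKey extends RichFlatMapFunction[String, String] {

  private var seen: ListView[String] = _

  override def open(parameters: Configuration): Unit = {
    val descriptor = new ListStateDescriptor[String]("seen-values", classOf[String])
    // adapt the keyed ListState to the ListView interface
    seen = new StateListView[String](getRuntimeContext.getListState(descriptor))
  }

  override def flatMap(value: String, out: Collector[String]): Unit = {
    seen.add(value)
    out.collect(value)
  }
}
```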
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.dataview + +import java.util +import java.lang.{Iterable => JIterable} + +import org.apache.flink.api.common.state.MapState +import org.apache.flink.table.api.dataview.MapView + +/** + * A [[MapView]] that is backed by a state backend. + * + * @param state map state + * @tparam K key type + * @tparam V value type + */ +class StateMapView[K, V](state: MapState[K, V]) extends MapView[K, V] { + + override def get(key: K): V = state.get(key) + + override def put(key: K, value: V): Unit = state.put(key, value) + + override def putAll(map: util.Map[K, V]): Unit = state.putAll(map) + + override def remove(key: K): Unit = state.remove(key) + + override def contains(key: K): Boolean = state.contains(key) + + override def entries: JIterable[util.Map.Entry[K, V]] = state.entries() + + override def keys: JIterable[K] = state.keys() + + override def values: JIterable[V] = state.values() + + override def iterator: util.Iterator[util.Map.Entry[K, V]] = state.iterator() + + override def clear(): Unit = state.clear() +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala index 8f50971b74432..d3f9497e1bb5b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala @@ -33,7 +33,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation * - merge, and * - resetAccumulator * - * All these methods muse be declared publicly, not static and named exactly as the names + * All these methods must be declared publicly, not static and named exactly as the names * mentioned above. The methods createAccumulator and getValue are defined in the * [[AggregateFunction]] functions, while other methods are explained below.
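The DataView classes above are intended for accumulators of user-defined aggregate functions, so that large collections can live in a state backend instead of being serialized with the accumulator on every access. A sketch under that assumption (function and field names are made up; the accumulator still has to satisfy Flink's POJO rules):

```scala
import java.lang.{Long => JLong}

import org.apache.flink.table.api.dataview.MapView
import org.apache.flink.table.functions.AggregateFunction

class CountDistinctAcc {
  var seen: MapView[String, JLong] = new MapView[String, JLong]()
  var count: JLong = 0L
}

class CountDistinct extends AggregateFunction[JLong, CountDistinctAcc] {

  override def createAccumulator(): CountDistinctAcc = new CountDistinctAcc

  override def getValue(acc: CountDistinctAcc): JLong = acc.count

  // resolved by convention, as described in the AggregateFunction comment above
  def accumulate(acc: CountDistinctAcc, value: String): Unit = {
    if (value != null && !acc.seen.contains(value)) {
      acc.seen.put(value, 1L)
      acc.count = acc.count + 1
    }
  }
}
```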
* diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeMaterializationSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/sql/ProctimeSqlFunction.scala similarity index 82% rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeMaterializationSqlFunction.scala rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/sql/ProctimeSqlFunction.scala index d87502650edb9..f30ad2fc0b35c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeMaterializationSqlFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/sql/ProctimeSqlFunction.scala @@ -15,19 +15,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.flink.table.functions +package org.apache.flink.table.functions.sql import org.apache.calcite.sql._ import org.apache.calcite.sql.`type`._ import org.apache.calcite.sql.validate.SqlMonotonicity /** - * Function that materializes a time attribute to the metadata timestamp. After materialization - * the result can be used in regular arithmetical calculations. + * Function that materializes a processing time attribute. + * After materialization the result can be used in regular arithmetical calculations. */ -object TimeMaterializationSqlFunction +object ProctimeSqlFunction extends SqlFunction( - "TIME_MATERIALIZATION", + "PROCTIME", SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.TIMESTAMP), InferTypes.RETURN_TYPE, @@ -38,4 +38,6 @@ object TimeMaterializationSqlFunction override def getMonotonicity(call: SqlOperatorBinding): SqlMonotonicity = SqlMonotonicity.INCREASING + + override def isDeterministic: Boolean = false } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala index 526ec47d03e4f..bb71d6362e9cf 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala @@ -95,7 +95,9 @@ object AggSqlFunction { val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo) .getOrElse( throw new ValidationException( - s"Operand types of ${signatureToString(operandTypeInfo)} could not be inferred.")) + s"Given parameters of function do not match any signature. 
\n" + + s"Actual: ${signatureToString(operandTypeInfo)} \n" + + s"Expected: ${signaturesToString(aggregateFunction, "accumulate")}")) val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1)) .map(typeFactory.createTypeFromTypeInfo(_, isNullable = true)) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala index 0776f7af26446..784bca74de77f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala @@ -45,7 +45,7 @@ class ScalarSqlFunction( extends SqlFunction( new SqlIdentifier(name, SqlParserPos.ZERO), createReturnTypeInference(name, scalarFunction, typeFactory), - createOperandTypeInference(scalarFunction, typeFactory), + createOperandTypeInference(name, scalarFunction, typeFactory), createOperandTypeChecker(name, scalarFunction), null, SqlFunctionCategory.USER_DEFINED_FUNCTION) { @@ -91,6 +91,7 @@ object ScalarSqlFunction { } private[flink] def createOperandTypeInference( + name: String, scalarFunction: ScalarFunction, typeFactory: FlinkTypeFactory) : SqlOperandTypeInference = { @@ -106,7 +107,11 @@ object ScalarSqlFunction { val operandTypeInfo = getOperandTypeInfo(callBinding) val foundSignature = getEvalMethodSignature(scalarFunction, operandTypeInfo) - .getOrElse(throw new ValidationException(s"Operand types of could not be inferred.")) + .getOrElse( + throw new ValidationException( + s"Given parameters of function '$name' do not match any signature. \n" + + s"Actual: ${signatureToString(operandTypeInfo)} \n" + + s"Expected: ${signaturesToString(scalarFunction, "eval")}")) val inferredTypes = scalarFunction .getParameterTypes(foundSignature) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index 47469d1954710..f53bcdeca7ae3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.functions.utils +import java.util import java.lang.{Integer => JInt, Long => JLong} import java.lang.reflect.{Method, Modifier} import java.sql.{Date, Time, Timestamp} @@ -29,7 +30,10 @@ import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.sql.{SqlCallBinding, SqlFunction} import org.apache.flink.api.common.functions.InvalidTypesException import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.typeutils.TypeExtractor +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, TypeExtractor} +import org.apache.flink.table.api.dataview._ +import org.apache.flink.table.dataview._ import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException} import org.apache.flink.table.expressions._ @@ -38,6 +42,8 @@ import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, Tabl import 
org.apache.flink.table.plan.schema.FlinkTableFunctionImpl import org.apache.flink.util.InstantiationUtil +import scala.collection.mutable + object UserDefinedFunctionUtils { /** @@ -175,11 +181,11 @@ object UserDefinedFunctionUtils { } }) { throw new ValidationException( - s"Scala-style variable arguments in '${methodName}' methods are not supported. Please " + + s"Scala-style variable arguments in '$methodName' methods are not supported. Please " + s"add a @scala.annotation.varargs annotation.") } else if (found.length > 1) { throw new ValidationException( - s"Found multiple '${methodName}' methods which match the signature.") + s"Found multiple '$methodName' methods which match the signature.") } found.headOption } @@ -218,7 +224,7 @@ object UserDefinedFunctionUtils { if (methods.isEmpty) { throw new ValidationException( s"Function class '${function.getClass.getCanonicalName}' does not implement at least " + - s"one method named '${methodName}' which is public, not abstract and " + + s"one method named '$methodName' which is public, not abstract and " + s"(in case of table functions) not static.") } @@ -306,6 +312,111 @@ object UserDefinedFunctionUtils { // Utilities for user-defined functions // ---------------------------------------------------------------------------------------------- + /** + * Remove StateView fields from accumulator type information. + * + * @param index index of aggregate function + * @param aggFun aggregate function + * @param accType accumulator type information, only support pojo type + * @param isStateBackedDataViews is data views use state backend + * @return mapping of accumulator type information and data view config which contains id, + * field name and state descriptor + */ + def removeStateViewFieldsFromAccTypeInfo( + index: Int, + aggFun: AggregateFunction[_, _], + accType: TypeInformation[_], + isStateBackedDataViews: Boolean) + : (TypeInformation[_], Option[Seq[DataViewSpec[_]]]) = { + + /** Recursively checks if composite type includes a data view type. 
*/ + def includesDataView(ct: CompositeType[_]): Boolean = { + (0 until ct.getArity).exists(i => + ct.getTypeAt(i) match { + case nestedCT: CompositeType[_] => includesDataView(nestedCT) + case t: TypeInformation[_] if t.getTypeClass == classOf[ListView[_]] => true + case t: TypeInformation[_] if t.getTypeClass == classOf[MapView[_, _]] => true + case _ => false + } + ) + } + + val acc = aggFun.createAccumulator() + accType match { + case pojoType: PojoTypeInfo[_] if pojoType.getArity > 0 => + val arity = pojoType.getArity + val newPojoFields = new util.ArrayList[PojoField]() + val accumulatorSpecs = new mutable.ArrayBuffer[DataViewSpec[_]] + for (i <- 0 until arity) { + val pojoField = pojoType.getPojoFieldAt(i) + val field = pojoField.getField + val fieldName = field.getName + field.setAccessible(true) + + pojoField.getTypeInformation match { + case ct: CompositeType[_] if includesDataView(ct) => + throw new TableException( + "MapView and ListView are only supported at the first level of accumulators of POJO type.") + case map: MapViewTypeInfo[_, _] => + val mapView = field.get(acc).asInstanceOf[MapView[_, _]] + if (mapView != null) { + val keyTypeInfo = mapView.keyTypeInfo + val valueTypeInfo = mapView.valueTypeInfo + val newTypeInfo = if (keyTypeInfo != null && valueTypeInfo != null) { + new MapViewTypeInfo(keyTypeInfo, valueTypeInfo) + } else { + map + } + + // create map view specs with unique id (used as state name) + var spec = MapViewSpec( + "agg" + index + "$" + fieldName, + field, + newTypeInfo) + + accumulatorSpecs += spec + if (!isStateBackedDataViews) { + // add data view field if it is not backed by a state backend. + // data view fields which are backed by state backend are not serialized. + newPojoFields.add(new PojoField(field, newTypeInfo)) + } + } + + case list: ListViewTypeInfo[_] => + val listView = field.get(acc).asInstanceOf[ListView[_]] + if (listView != null) { + val elementTypeInfo = listView.elementTypeInfo + val newTypeInfo = if (elementTypeInfo != null) { + new ListViewTypeInfo(elementTypeInfo) + } else { + list + } + + // create list view specs with unique id (used as state name) + var spec = ListViewSpec( + "agg" + index + "$" + fieldName, + field, + newTypeInfo) + + accumulatorSpecs += spec + if (!isStateBackedDataViews) { + // add data view field if it is not backed by a state backend. + // data view fields which are backed by state backend are not serialized. + newPojoFields.add(new PojoField(field, newTypeInfo)) + } + } + + case _ => newPojoFields.add(pojoField) + } + } + (new PojoTypeInfo(accType.getTypeClass, newPojoFields), Some(accumulatorSpecs)) + case ct: CompositeType[_] if includesDataView(ct) => + throw new TableException( + "MapView and ListView are only supported in accumulators of POJO type.") + case _ => (accType, None) + } + } + + /** + * Tries to infer the TypeInformation of an AggregateFunction's return type.
* diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala index 3e355ff41d20b..2f1871b7e1f47 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala @@ -42,8 +42,8 @@ trait CommonCalc { GeneratedFunction[T, Row] = { val projection = generator.generateResultExpression( - returnSchema.physicalTypeInfo, - returnSchema.physicalFieldNames, + returnSchema.typeInfo, + returnSchema.fieldNames, calcProjection) // only projection @@ -80,7 +80,7 @@ trait CommonCalc { ruleDescription, functionClass, body, - returnSchema.physicalTypeInfo) + returnSchema.typeInfo) } private[flink] def conditionToString( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala index 96aaf3e5e38dd..7c01fdeb9b178 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala @@ -53,12 +53,10 @@ trait CommonCorrelate { functionClass: Class[T]): GeneratedFunction[T, Row] = { - val physicalRexCall = inputSchema.mapRexNode(rexCall) - val functionGenerator = new FunctionCodeGenerator( config, false, - inputSchema.physicalTypeInfo, + inputSchema.typeInfo, Some(udtfTypeInfo), None, pojoFieldMapping) @@ -69,7 +67,7 @@ trait CommonCorrelate { .addReusableConstructor(classOf[TableFunctionCollector[_]]) .head - val call = functionGenerator.generateExpression(physicalRexCall) + val call = functionGenerator.generateExpression(rexCall) var body = s""" |${call.resultTerm}.setCollector($collectorTerm); @@ -90,8 +88,8 @@ trait CommonCorrelate { } val outerResultExpr = functionGenerator.generateResultExpression( input1AccessExprs ++ input2NullExprs, - returnSchema.physicalTypeInfo, - returnSchema.physicalFieldNames) + returnSchema.typeInfo, + returnSchema.fieldNames) body += s""" |boolean hasOutput = $collectorTerm.isCollected(); @@ -108,7 +106,7 @@ trait CommonCorrelate { ruleDescription, functionClass, body, - returnSchema.physicalTypeInfo) + returnSchema.typeInfo) } /** @@ -126,7 +124,7 @@ trait CommonCorrelate { val generator = new CollectorCodeGenerator( config, false, - inputSchema.physicalTypeInfo, + inputSchema.typeInfo, Some(udtfTypeInfo), None, pojoFieldMapping) @@ -135,8 +133,8 @@ trait CommonCorrelate { val crossResultExpr = generator.generateResultExpression( input1AccessExprs ++ input2AccessExprs, - returnSchema.physicalTypeInfo, - returnSchema.physicalFieldNames) + returnSchema.typeInfo, + returnSchema.fieldNames) val collectorCode = if (condition.isEmpty) { s""" @@ -148,7 +146,7 @@ trait CommonCorrelate { // adjust indicies of InputRefs to adhere to schema expected by generator val changeInputRefIndexShuttle = new RexShuttle { override def visitInputRef(inputRef: RexInputRef): RexNode = { - new RexInputRef(inputSchema.physicalArity + inputRef.getIndex, inputRef.getType) + new RexInputRef(inputSchema.arity + inputRef.getIndex, inputRef.getType) } } // Run generateExpression to add init statements (ScalarFunctions) of condition to generator. 
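The `ScalarSqlFunction` and `AggSqlFunction` changes earlier in this patch replace the terse "Operand types ... could not be inferred" error with a message that lists the actual and the expected signatures. A hypothetical UDF that would now produce the richer message when called with a wrong argument type:

```scala
import org.apache.flink.table.functions.ScalarFunction

// made-up function: only eval(Long) is defined
class HashCodeUdf extends ScalarFunction {
  def eval(value: Long): Int = value.hashCode()
}

// Calling the function with a String argument now fails with a ValidationException of the form:
//   Given parameters of function 'HASHCODE' do not match any signature.
//   Actual: (java.lang.String)
//   Expected: (long)
```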
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala index dc7a0d6d3579e..5872d8cd360c4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala @@ -39,9 +39,7 @@ abstract class PhysicalTableSourceScan( val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] flinkTypeFactory.buildLogicalRowType( TableEnvironment.getFieldNames(tableSource), - TableEnvironment.getFieldTypes(tableSource.getReturnType), - None, - None) + TableEnvironment.getFieldTypes(tableSource.getReturnType)) } override def explainTerms(pw: RelWriter): RelWriter = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala index fb291e407a5a7..74aac431ce273 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala @@ -42,9 +42,7 @@ class BatchTableSourceScan( val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] flinkTypeFactory.buildLogicalRowType( TableEnvironment.getFieldNames(tableSource), - TableEnvironment.getFieldTypes(tableSource.getReturnType), - None, - None) + TableEnvironment.getFieldTypes(tableSource.getReturnType)) } override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala index 1583e31589d16..acbf94dd6516a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala @@ -186,7 +186,8 @@ class DataSetJoin( |""".stripMargin } else { - val condition = generator.generateExpression(joinCondition) + val nonEquiPredicates = joinInfo.getRemaining(this.cluster.getRexBuilder) + val condition = generator.generateExpression(nonEquiPredicates) body = s""" |${condition.code} |if (${condition.resultTerm}) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala index 2e0033063e7c9..45e69028b8305 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala @@ -52,7 +52,7 @@ class DataStreamCalc( with CommonCalc with DataStreamRel { - override def deriveRowType(): RelDataType = schema.logicalType + override def deriveRowType(): RelDataType = schema.relDataType override def copy(traitSet: RelTraitSet, child: RelNode, program: RexProgram): Calc = { new DataStreamCalc( @@ -100,7 +100,7 @@ class DataStreamCalc( val 
condition = if (calcProgram.getCondition != null) { val materializedCondition = RelTimeIndicatorConverter.convertExpression( calcProgram.expandLocalRef(calcProgram.getCondition), - inputSchema.logicalType, + inputSchema.relDataType, cluster.getRexBuilder) Some(materializedCondition) } else { @@ -110,12 +110,8 @@ class DataStreamCalc( // filter out time attributes val projection = calcProgram.getProjectList.asScala .map(calcProgram.expandLocalRef) - // time indicator fields must not be part of the code generation - .filter(expr => !FlinkTypeFactory.isTimeIndicatorType(expr.getType)) - // update indices - .map(expr => inputSchema.mapRexNode(expr)) - val generator = new FunctionCodeGenerator(config, false, inputSchema.physicalTypeInfo) + val generator = new FunctionCodeGenerator(config, false, inputSchema.typeInfo) val genFunction = generateFunction( generator, @@ -132,7 +128,7 @@ class DataStreamCalc( val processFunc = new CRowProcessRunner( genFunction.name, genFunction.code, - CRowTypeInfo(schema.physicalTypeInfo)) + CRowTypeInfo(schema.typeInfo)) inputDataStream .process(processFunc) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala index b7165cd190183..18ab2a3354eb5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala @@ -50,7 +50,7 @@ class DataStreamCorrelate( with CommonCorrelate with DataStreamRel { - override def deriveRowType() = schema.logicalType + override def deriveRowType() = schema.relDataType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { new DataStreamCorrelate( @@ -78,7 +78,7 @@ class DataStreamCorrelate( super.explainTerms(pw) .item("invocation", scan.getCall) .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName) - .item("rowType", schema.logicalType) + .item("rowType", schema.relDataType) .item("joinType", joinType) .itemIf("condition", condition.orNull, condition.isDefined) } @@ -130,7 +130,7 @@ class DataStreamCorrelate( .process(processFunc) // preserve input parallelism to ensure that acc and retract messages remain in order .setParallelism(inputParallelism) - .name(correlateOpName(rexCall, sqlFunction, schema.logicalType)) + .name(correlateOpName(rexCall, sqlFunction, schema.relDataType)) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala index 12694fc30fc27..58c9d820ee59a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala @@ -24,13 +24,13 @@ import org.apache.flink.api.java.functions.NullByteKeySelector import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment} import org.apache.flink.table.codegen.AggregationCodeGenerator -import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.plan.nodes.CommonAggregate -import 
org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.plan.rules.datastream.DataStreamRetractionRules +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair +import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} -import org.slf4j.LoggerFactory +import org.apache.flink.table.util.Logging /** * @@ -55,11 +55,10 @@ class DataStreamGroupAggregate( groupings: Array[Int]) extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate - with DataStreamRel { - - private val LOG = LoggerFactory.getLogger(this.getClass) + with DataStreamRel + with Logging { - override def deriveRowType() = schema.logicalType + override def deriveRowType() = schema.relDataType override def needsUpdatesAsRetraction = true @@ -83,20 +82,20 @@ class DataStreamGroupAggregate( override def toString: String = { s"Aggregate(${ if (!groupings.isEmpty) { - s"groupBy: (${groupingToString(inputSchema.logicalType, groupings)}), " + s"groupBy: (${groupingToString(inputSchema.relDataType, groupings)}), " } else { "" } }select:(${aggregationToString( - inputSchema.logicalType, groupings, getRowType, namedAggregates, Nil)}))" + inputSchema.relDataType, groupings, getRowType, namedAggregates, Nil)}))" } override def explainTerms(pw: RelWriter): RelWriter = { super.explainTerms(pw) .itemIf("groupBy", groupingToString( - inputSchema.logicalType, groupings), !groupings.isEmpty) + inputSchema.relDataType, groupings), !groupings.isEmpty) .item("select", aggregationToString( - inputSchema.logicalType, groupings, getRowType, namedAggregates, Nil)) + inputSchema.relDataType, groupings, getRowType, namedAggregates, Nil)) } override def translateToPlan( @@ -112,37 +111,29 @@ class DataStreamGroupAggregate( val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig) - val physicalNamedAggregates = namedAggregates.map { namedAggregate => - new CalcitePair[AggregateCall, String]( - inputSchema.mapAggregateCall(namedAggregate.left), - namedAggregate.right) - } - - val outRowType = CRowTypeInfo(schema.physicalTypeInfo) + val outRowType = CRowTypeInfo(schema.typeInfo) val generator = new AggregationCodeGenerator( tableEnv.getConfig, false, - inputSchema.physicalTypeInfo) + inputSchema.typeInfo) val aggString = aggregationToString( - inputSchema.logicalType, + inputSchema.relDataType, groupings, getRowType, namedAggregates, Nil) - val keyedAggOpName = s"groupBy: (${groupingToString(inputSchema.logicalType, groupings)}), " + + val keyedAggOpName = s"groupBy: (${groupingToString(inputSchema.relDataType, groupings)}), " + s"select: ($aggString)" val nonKeyedAggOpName = s"select: ($aggString)" - val physicalGrouping = groupings.map(inputSchema.mapIndex) - val processFunction = AggregateUtil.createGroupAggregateFunction( generator, - physicalNamedAggregates, - inputSchema.logicalType, - inputSchema.physicalFieldTypeInfo, + namedAggregates, + inputSchema.relDataType, + inputSchema.fieldTypeInfos, groupings, queryConfig, DataStreamRetractionRules.isAccRetract(this), @@ -150,7 +141,7 @@ class DataStreamGroupAggregate( val result: DataStream[CRow] = // grouped / keyed aggregation - if (physicalGrouping.nonEmpty) { + if (groupings.nonEmpty) { inputDS .keyBy(groupings: _*) .process(processFunction) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala index c4ffdb1e900f2..b15350f9cfb3a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala @@ -29,20 +29,21 @@ import org.apache.flink.streaming.api.windowing.triggers.PurgingTrigger import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException} import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty -import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.AggregationCodeGenerator import org.apache.flink.table.expressions.ExpressionUtils._ +import org.apache.flink.table.expressions.ResolvedFieldReference import org.apache.flink.table.plan.logical._ import org.apache.flink.table.plan.nodes.CommonAggregate import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupWindowAggregate._ import org.apache.flink.table.plan.rules.datastream.DataStreamRetractionRules +import org.apache.flink.table.runtime.RowtimeProcessFunction import org.apache.flink.table.runtime.aggregate.AggregateUtil._ import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval import org.apache.flink.table.runtime.triggers.StateCleaningCountTrigger -import org.slf4j.LoggerFactory +import org.apache.flink.table.util.Logging class DataStreamGroupWindowAggregate( window: LogicalWindow, @@ -54,11 +55,12 @@ class DataStreamGroupWindowAggregate( schema: RowSchema, inputSchema: RowSchema, grouping: Array[Int]) - extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { + extends SingleRel(cluster, traitSet, inputNode) + with CommonAggregate + with DataStreamRel + with Logging { - private val LOG = LoggerFactory.getLogger(this.getClass) - - override def deriveRowType(): RelDataType = schema.logicalType + override def deriveRowType(): RelDataType = schema.relDataType override def needsUpdatesAsRetraction = true @@ -84,14 +86,14 @@ class DataStreamGroupWindowAggregate( override def toString: String = { s"Aggregate(${ if (!grouping.isEmpty) { - s"groupBy: (${groupingToString(inputSchema.logicalType, grouping)}), " + s"groupBy: (${groupingToString(inputSchema.relDataType, grouping)}), " } else { "" } }window: ($window), " + s"select: (${ aggregationToString( - inputSchema.logicalType, + inputSchema.relDataType, grouping, getRowType, namedAggregates, @@ -101,13 +103,13 @@ class DataStreamGroupWindowAggregate( override def explainTerms(pw: RelWriter): RelWriter = { super.explainTerms(pw) - .itemIf("groupBy", groupingToString(inputSchema.logicalType, grouping), !grouping.isEmpty) + .itemIf("groupBy", groupingToString(inputSchema.relDataType, grouping), !grouping.isEmpty) .item("window", window) .item( "select", aggregationToString( - inputSchema.logicalType, + inputSchema.relDataType, grouping, - schema.logicalType, + schema.relDataType, namedAggregates, namedProperties)) } @@ -118,14 +120,6 @@ class DataStreamGroupWindowAggregate( val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig) - val 
physicalNamedAggregates = namedAggregates.map { namedAggregate => - new CalcitePair[AggregateCall, String]( - inputSchema.mapAggregateCall(namedAggregate.left), - namedAggregate.right) - } - val physicalNamedProperties = namedProperties - .filter(np => !FlinkTypeFactory.isTimeIndicatorType(np.property.resultType)) - val inputIsAccRetract = DataStreamRetractionRules.isAccRetract(input) if (inputIsAccRetract) { @@ -148,16 +142,30 @@ class DataStreamGroupWindowAggregate( "state size. You may specify a retention time of 0 to not clean up the state.") } - val outRowType = CRowTypeInfo(schema.physicalTypeInfo) + val timestampedInput = if (isRowtimeAttribute(window.timeAttribute)) { + // copy the window rowtime attribute into the StreamRecord timestamp field + val timeAttribute = window.timeAttribute.asInstanceOf[ResolvedFieldReference].name + val timeIdx = inputSchema.fieldNames.indexOf(timeAttribute) + + inputDS + .process( + new RowtimeProcessFunction(timeIdx, CRowTypeInfo(inputSchema.typeInfo))) + .setParallelism(inputDS.getParallelism) + .name(s"time attribute: ($timeAttribute)") + } else { + inputDS + } + + val outRowType = CRowTypeInfo(schema.typeInfo) val aggString = aggregationToString( - inputSchema.logicalType, + inputSchema.relDataType, grouping, - schema.logicalType, + schema.relDataType, namedAggregates, namedProperties) - val keyedAggOpName = s"groupBy: (${groupingToString(inputSchema.logicalType, grouping)}), " + + val keyedAggOpName = s"groupBy: (${groupingToString(inputSchema.relDataType, grouping)}), " + s"window: ($window), " + s"select: ($aggString)" val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" @@ -165,24 +173,22 @@ class DataStreamGroupWindowAggregate( val generator = new AggregationCodeGenerator( tableEnv.getConfig, false, - inputSchema.physicalTypeInfo) + inputSchema.typeInfo) val needMerge = window match { case SessionGroupWindow(_, _, _) => true case _ => false } - val physicalGrouping = grouping.map(inputSchema.mapIndex) - // grouped / keyed aggregation - if (physicalGrouping.length > 0) { + if (grouping.length > 0) { val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( window, - physicalGrouping.length, - physicalNamedAggregates.size, - schema.physicalArity, - physicalNamedProperties) + grouping.length, + namedAggregates.size, + schema.arity, + namedProperties) - val keyedStream = inputDS.keyBy(physicalGrouping: _*) + val keyedStream = timestampedInput.keyBy(grouping: _*) val windowedStream = createKeyedWindowedStream(queryConfig, window, keyedStream) .asInstanceOf[WindowedStream[CRow, Tuple, DataStreamWindow]] @@ -190,11 +196,11 @@ class DataStreamGroupWindowAggregate( val (aggFunction, accumulatorRowType, aggResultRowType) = AggregateUtil.createDataStreamAggregateFunction( generator, - physicalNamedAggregates, - inputSchema.physicalType, - inputSchema.physicalFieldTypeInfo, - schema.physicalType, - physicalGrouping, + namedAggregates, + inputSchema.relDataType, + inputSchema.fieldTypeInfos, + schema.relDataType, + grouping, needMerge) windowedStream @@ -205,20 +211,20 @@ class DataStreamGroupWindowAggregate( else { val windowFunction = AggregateUtil.createAggregationAllWindowFunction( window, - schema.physicalArity, - physicalNamedProperties) + schema.arity, + namedProperties) val windowedStream = - createNonKeyedWindowedStream(queryConfig, window, inputDS) + createNonKeyedWindowedStream(queryConfig, window, timestampedInput) .asInstanceOf[AllWindowedStream[CRow, DataStreamWindow]] val (aggFunction, accumulatorRowType, 
aggResultRowType) = AggregateUtil.createDataStreamAggregateFunction( generator, - physicalNamedAggregates, - inputSchema.physicalType, - inputSchema.physicalFieldTypeInfo, - schema.physicalType, + namedAggregates, + inputSchema.relDataType, + inputSchema.fieldTypeInfos, + schema.relDataType, Array[Int](), needMerge) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala index 34a7fd8ca1956..62345252c44ee 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala @@ -20,23 +20,23 @@ package org.apache.flink.table.plan.nodes.datastream import java.util.{List => JList} import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.Window.Group import org.apache.calcite.rel.core.{AggregateCall, Window} import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} -import org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING +import org.apache.flink.api.java.functions.NullByteKeySelector import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException} import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.nodes.OverAggregate -import org.apache.flink.table.plan.schema.RowSchema -import org.apache.flink.table.runtime.aggregate._ -import org.apache.flink.api.java.functions.NullByteKeySelector import org.apache.flink.table.codegen.AggregationCodeGenerator +import org.apache.flink.table.plan.nodes.OverAggregate import org.apache.flink.table.plan.rules.datastream.DataStreamRetractionRules +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair +import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} -import org.slf4j.LoggerFactory +import org.apache.flink.table.util.Logging class DataStreamOverAggregate( logicWindow: Window, @@ -47,10 +47,10 @@ class DataStreamOverAggregate( inputSchema: RowSchema) extends SingleRel(cluster, traitSet, inputNode) with OverAggregate - with DataStreamRel { - private val LOG = LoggerFactory.getLogger(this.getClass) + with DataStreamRel + with Logging { - override def deriveRowType(): RelDataType = schema.logicalType + override def deriveRowType(): RelDataType = schema.relDataType override def needsUpdatesAsRetraction = true @@ -78,15 +78,15 @@ class DataStreamOverAggregate( super.explainTerms(pw) .itemIf("partitionBy", - partitionToString(schema.logicalType, partitionKeys), partitionKeys.nonEmpty) + partitionToString(schema.relDataType, partitionKeys), partitionKeys.nonEmpty) .item("orderBy", - orderingToString(schema.logicalType, overWindow.orderKeys.getFieldCollations)) + orderingToString(schema.relDataType, overWindow.orderKeys.getFieldCollations)) .itemIf("rows", windowRange(logicWindow, overWindow, inputNode), overWindow.isRows) .itemIf("range", windowRange(logicWindow, overWindow, inputNode), !overWindow.isRows) .item( "select", aggregationToString( - inputSchema.logicalType, 
- schema.logicalType, + inputSchema.relDataType, + schema.relDataType, namedAggregates)) } @@ -134,67 +134,44 @@ class DataStreamOverAggregate( val generator = new AggregationCodeGenerator( tableEnv.getConfig, false, - inputSchema.physicalTypeInfo) + inputSchema.typeInfo) - val timeType = schema.logicalType + val timeType = schema.relDataType .getFieldList .get(orderKey.getFieldIndex) .getType - timeType match { - case _ if FlinkTypeFactory.isProctimeIndicatorType(timeType) => - // proc-time OVER window - if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { - // unbounded OVER window - createUnboundedAndCurrentRowOverWindow( - queryConfig, - generator, - inputDS, - isRowTimeType = false, - isRowsClause = overWindow.isRows) - } else if ( - overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded && - overWindow.upperBound.isCurrentRow) { - - // bounded OVER window - createBoundedAndCurrentRowOverWindow( - queryConfig, - generator, - inputDS, - isRowTimeType = false, - isRowsClause = overWindow.isRows) - } else { - throw new TableException( - "OVER RANGE FOLLOWING windows are not supported yet.") - } - - case _ if FlinkTypeFactory.isRowtimeIndicatorType(timeType) => - // row-time OVER window - if (overWindow.lowerBound.isPreceding && - overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { - // unbounded OVER window - createUnboundedAndCurrentRowOverWindow( - queryConfig, - generator, - inputDS, - isRowTimeType = true, - isRowsClause = overWindow.isRows) - } else if (overWindow.lowerBound.isPreceding && overWindow.upperBound.isCurrentRow) { - // bounded OVER window - createBoundedAndCurrentRowOverWindow( - queryConfig, - generator, - inputDS, - isRowTimeType = true, - isRowsClause = overWindow.isRows) - } else { - throw new TableException( - "OVER RANGE FOLLOWING windows are not supported yet.") - } - - case _ => - throw new TableException( - s"OVER windows can only be applied on time attributes.") + // identify window rowtime attribute + val rowTimeIdx: Option[Int] = if (FlinkTypeFactory.isRowtimeIndicatorType(timeType)) { + Some(orderKey.getFieldIndex) + } else if (FlinkTypeFactory.isProctimeIndicatorType(timeType)) { + None + } else { + throw new TableException(s"OVER windows can only be applied on time attributes.") + } + + if (overWindow.lowerBound.isPreceding && overWindow.lowerBound.isUnbounded && + overWindow.upperBound.isCurrentRow) { + // unbounded OVER window + createUnboundedAndCurrentRowOverWindow( + queryConfig, + generator, + inputDS, + rowTimeIdx, + isRowsClause = overWindow.isRows) + } else if ( + overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded && + overWindow.upperBound.isCurrentRow) { + + // bounded OVER window + createBoundedAndCurrentRowOverWindow( + queryConfig, + generator, + inputDS, + rowTimeIdx, + isRowsClause = overWindow.isRows) + } else { + throw new TableException("OVER RANGE FOLLOWING windows are not supported yet.") } } @@ -202,31 +179,26 @@ class DataStreamOverAggregate( queryConfig: StreamQueryConfig, generator: AggregationCodeGenerator, inputDS: DataStream[CRow], - isRowTimeType: Boolean, + rowTimeIdx: Option[Int], isRowsClause: Boolean): DataStream[CRow] = { val overWindow: Group = logicWindow.groups.get(0) - val partitionKeys: Array[Int] = overWindow.keys.toArray.map(schema.mapIndex) + val partitionKeys: Array[Int] = overWindow.keys.toArray - val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates.map { - namedAggregate => - new 
CalcitePair[AggregateCall, String]( - schema.mapAggregateCall(namedAggregate.left), - namedAggregate.right) - } + val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates // get the output types - val returnTypeInfo = CRowTypeInfo(schema.physicalTypeInfo) + val returnTypeInfo = CRowTypeInfo(schema.typeInfo) val processFunction = AggregateUtil.createUnboundedOverProcessFunction( generator, namedAggregates, - inputSchema.physicalType, - inputSchema.physicalTypeInfo, - inputSchema.physicalFieldTypeInfo, + inputSchema.relDataType, + inputSchema.typeInfo, + inputSchema.fieldTypeInfos, queryConfig, - isRowTimeType, + rowTimeIdx, partitionKeys.nonEmpty, isRowsClause) @@ -254,34 +226,29 @@ class DataStreamOverAggregate( queryConfig: StreamQueryConfig, generator: AggregationCodeGenerator, inputDS: DataStream[CRow], - isRowTimeType: Boolean, + rowTimeIdx: Option[Int], isRowsClause: Boolean): DataStream[CRow] = { val overWindow: Group = logicWindow.groups.get(0) - val partitionKeys: Array[Int] = overWindow.keys.toArray.map(schema.mapIndex) - val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates.map { - namedAggregate => - new CalcitePair[AggregateCall, String]( - schema.mapAggregateCall(namedAggregate.left), - namedAggregate.right) - } + val partitionKeys: Array[Int] = overWindow.keys.toArray + val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates val precedingOffset = getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRowsClause) 1 else 0) // get the output types - val returnTypeInfo = CRowTypeInfo(schema.physicalTypeInfo) + val returnTypeInfo = CRowTypeInfo(schema.typeInfo) val processFunction = AggregateUtil.createBoundedOverProcessFunction( generator, namedAggregates, - inputSchema.physicalType, - inputSchema.physicalTypeInfo, - inputSchema.physicalFieldTypeInfo, + inputSchema.relDataType, + inputSchema.typeInfo, + inputSchema.fieldTypeInfos, precedingOffset, queryConfig, isRowsClause, - isRowTimeType + rowTimeIdx ) val result: DataStream[CRow] = // partitioned aggregation @@ -318,18 +285,18 @@ class DataStreamOverAggregate( s"over: (${ if (!partitionKeys.isEmpty) { - s"PARTITION BY: ${partitionToString(inputSchema.logicalType, partitionKeys)}, " + s"PARTITION BY: ${partitionToString(inputSchema.relDataType, partitionKeys)}, " } else { "" } - }ORDER BY: ${orderingToString(inputSchema.logicalType, + }ORDER BY: ${orderingToString(inputSchema.relDataType, overWindow.orderKeys.getFieldCollations)}, " + s"${if (overWindow.isRows) "ROWS" else "RANGE"}" + s"${windowRange(logicWindow, overWindow, inputNode)}, " + s"select: (${ aggregationToString( - inputSchema.logicalType, - schema.logicalType, + inputSchema.relDataType, + schema.relDataType, namedAggregates) }))" } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamScan.scala index 424c6a26633e5..9352efb5372dc 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamScan.scala @@ -43,7 +43,7 @@ class DataStreamScan( val dataStreamTable: DataStreamTable[Any] = getTable.unwrap(classOf[DataStreamTable[Any]]) - override def deriveRowType(): RelDataType = schema.logicalType + override def deriveRowType(): RelDataType 
= schema.relDataType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { new DataStreamScan( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamSort.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamSort.scala index a11e6c179993b..8f9942fd608c3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamSort.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamSort.scala @@ -53,7 +53,7 @@ class DataStreamSort( with CommonSort with DataStreamRel { - override def deriveRowType(): RelDataType = schema.logicalType + override def deriveRowType(): RelDataType = schema.relDataType override def copy( traitSet: RelTraitSet, @@ -75,13 +75,13 @@ class DataStreamSort( } override def toString: String = { - sortToString(schema.logicalType, sortCollation, sortOffset, sortFetch) + sortToString(schema.relDataType, sortCollation, sortOffset, sortFetch) } override def explainTerms(pw: RelWriter) : RelWriter = { sortExplainTerms( pw.input("input", getInput()), - schema.logicalType, + schema.relDataType, sortCollation, sortOffset, sortFetch) @@ -94,7 +94,7 @@ class DataStreamSort( val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig) // need to identify time between others order fields. Time needs to be first sort element - val timeType = SortUtil.getFirstSortField(sortCollation, schema.logicalType).getType + val timeType = SortUtil.getFirstSortField(sortCollation, schema.relDataType).getType // time ordering needs to be ascending if (SortUtil.getFirstSortDirection(sortCollation) != Direction.ASCENDING) { @@ -141,15 +141,15 @@ class DataStreamSort( inputDS: DataStream[CRow], execCfg: ExecutionConfig): DataStream[CRow] = { - val returnTypeInfo = CRowTypeInfo(schema.physicalTypeInfo) + val returnTypeInfo = CRowTypeInfo(schema.typeInfo) // if the order has secondary sorting fields in addition to the proctime if (sortCollation.getFieldCollations.size() > 1) { val processFunction = SortUtil.createProcTimeSortFunction( sortCollation, - inputSchema.logicalType, - inputSchema.physicalTypeInfo, + inputSchema.relDataType, + inputSchema.typeInfo, execCfg) inputDS.keyBy(new NullByteKeySelector[CRow]) @@ -173,12 +173,12 @@ class DataStreamSort( inputDS: DataStream[CRow], execCfg: ExecutionConfig): DataStream[CRow] = { - val returnTypeInfo = CRowTypeInfo(schema.physicalTypeInfo) + val returnTypeInfo = CRowTypeInfo(schema.typeInfo) val processFunction = SortUtil.createRowTimeSortFunction( sortCollation, - inputSchema.logicalType, - inputSchema.physicalTypeInfo, + inputSchema.relDataType, + inputSchema.typeInfo, execCfg) inputDS.keyBy(new NullByteKeySelector[CRow]) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamUnion.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamUnion.scala index 6f4980aa70dab..7258ec88e4b93 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamUnion.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamUnion.scala @@ -38,7 +38,7 @@ class DataStreamUnion( extends BiRel(cluster, traitSet, leftNode, rightNode) with DataStreamRel { - override def deriveRowType() = schema.logicalType + override 
def deriveRowType() = schema.relDataType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { new DataStreamUnion( @@ -55,7 +55,7 @@ class DataStreamUnion( } override def toString = { - s"Union All(union: (${schema.logicalFieldNames.mkString(", ")}))" + s"Union All(union: (${schema.fieldNames.mkString(", ")}))" } override def translateToPlan( @@ -68,6 +68,6 @@ class DataStreamUnion( } private def unionSelectionToString: String = { - schema.logicalFieldNames.mkString(", ") + schema.fieldNames.mkString(", ") } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamValues.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamValues.scala index 14766815c600d..1ef9107cbc488 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamValues.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamValues.scala @@ -41,10 +41,10 @@ class DataStreamValues( schema: RowSchema, tuples: ImmutableList[ImmutableList[RexLiteral]], ruleDescription: String) - extends Values(cluster, schema.logicalType, tuples, traitSet) + extends Values(cluster, schema.relDataType, tuples, traitSet) with DataStreamRel { - override def deriveRowType() = schema.logicalType + override def deriveRowType() = schema.relDataType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { new DataStreamValues( @@ -62,14 +62,14 @@ class DataStreamValues( val config = tableEnv.getConfig - val returnType = CRowTypeInfo(schema.physicalTypeInfo) + val returnType = CRowTypeInfo(schema.typeInfo) val generator = new InputFormatCodeGenerator(config) // generate code for every record val generatedRecords = getTuples.asScala.map { r => generator.generateResultExpression( - schema.physicalTypeInfo, - schema.physicalFieldNames, + schema.typeInfo, + schema.fieldNames, r.asScala) } @@ -77,7 +77,7 @@ class DataStreamValues( val generatedFunction = generator.generateValuesInputFormat( ruleDescription, generatedRecords.map(_.code), - schema.physicalTypeInfo) + schema.typeInfo) val inputFormat = new CRowValuesInputFormat( generatedFunction.name, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala index 987947c3d6ac3..f8015b354e6d3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala @@ -54,7 +54,7 @@ class DataStreamWindowJoin( with CommonJoin with DataStreamRel { - override def deriveRowType(): RelDataType = schema.logicalType + override def deriveRowType(): RelDataType = schema.relDataType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { new DataStreamWindowJoin( @@ -76,7 +76,7 @@ class DataStreamWindowJoin( override def toString: String = { joinToString( - schema.logicalType, + schema.relDataType, joinCondition, joinType, getExpressionString) @@ -85,7 +85,7 @@ class DataStreamWindowJoin( override def explainTerms(pw: RelWriter): RelWriter = { joinExplainTerms( super.explainTerms(pw), - schema.logicalType, + schema.relDataType, joinCondition, joinType, 
getExpressionString) @@ -117,8 +117,8 @@ class DataStreamWindowJoin( WindowJoinUtil.generateJoinFunction( config, joinType, - leftSchema.physicalTypeInfo, - rightSchema.physicalTypeInfo, + leftSchema.typeInfo, + rightSchema.typeInfo, schema, remainCondition, ruleDescription) @@ -160,13 +160,13 @@ class DataStreamWindowJoin( leftKeys: Array[Int], rightKeys: Array[Int]): DataStream[CRow] = { - val returnTypeInfo = CRowTypeInfo(schema.physicalTypeInfo) + val returnTypeInfo = CRowTypeInfo(schema.typeInfo) val procInnerJoinFunc = new ProcTimeWindowInnerJoin( leftLowerBound, leftUpperBound, - leftSchema.physicalTypeInfo, - rightSchema.physicalTypeInfo, + leftSchema.typeInfo, + rightSchema.typeInfo, joinFunctionName, joinFunctionCode) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamScan.scala index 25e72fa887ef0..4aca85633f1b8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamScan.scala @@ -18,14 +18,15 @@ package org.apache.flink.table.plan.nodes.datastream -import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.codegen.FunctionCodeGenerator import org.apache.flink.table.plan.nodes.CommonScan import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.types.Row import org.apache.flink.table.plan.schema.FlinkTable -import org.apache.flink.table.runtime.CRowOutputMapRunner +import org.apache.flink.table.runtime.CRowOutputProcessRunner import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import scala.collection.JavaConverters._ @@ -40,29 +41,42 @@ trait StreamScan extends CommonScan[CRow] with DataStreamRel { : DataStream[CRow] = { val inputType = input.getType - val internalType = CRowTypeInfo(schema.physicalTypeInfo) + val internalType = CRowTypeInfo(schema.typeInfo) // conversion if (needsConversion(input.getType, internalType)) { - val function = generatedConversionFunction( + val generator = new FunctionCodeGenerator( config, - classOf[MapFunction[Any, Row]], + false, inputType, - schema.physicalTypeInfo, - "DataStreamSourceConversion", - schema.physicalFieldNames, + None, Some(flinkTable.fieldIndexes)) - val mapFunc = new CRowOutputMapRunner( + val conversion = generator.generateConverterResultExpression( + schema.typeInfo, + schema.fieldNames) + + val body = + s""" + |${conversion.code} + |${generator.collectorTerm}.collect(${conversion.resultTerm}); + |""".stripMargin + + val function = generator.generateFunction( + "DataStreamSourceConversion", + classOf[ProcessFunction[Any, Row]], + body, + schema.typeInfo) + + val processFunc = new CRowOutputProcessRunner( function.name, function.code, internalType) val opName = s"from: (${getRowType.getFieldNames.asScala.toList.mkString(", ")})" - // TODO we need a ProcessFunction here - input.map(mapFunc).name(opName).returns(internalType) + input.process(processFunc).name(opName).returns(internalType) } // no conversion necessary, forward else { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala index 72ecac58243c4..663b2762eb9c2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala @@ -30,6 +30,7 @@ import org.apache.flink.table.sources._ import org.apache.flink.table.plan.schema.TableSourceTable import org.apache.flink.table.runtime.types.CRow import org.apache.flink.table.sources.{StreamTableSource, TableSource} +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo /** Flink RelNode to read data from an external source defined by a [[StreamTableSource]]. */ class StreamTableSourceScan( @@ -46,29 +47,29 @@ class StreamTableSourceScan( val fieldNames = TableEnvironment.getFieldNames(tableSource).toList val fieldTypes = TableEnvironment.getFieldTypes(tableSource.getReturnType).toList - val fieldCnt = fieldNames.length + val fields = fieldNames.zip(fieldTypes) - val rowtime = tableSource match { + val withRowtime = tableSource match { case timeSource: DefinedRowtimeAttribute if timeSource.getRowtimeAttribute != null => val rowtimeAttribute = timeSource.getRowtimeAttribute - Some((fieldCnt, rowtimeAttribute)) + fields :+ (rowtimeAttribute, TimeIndicatorTypeInfo.ROWTIME_INDICATOR) case _ => - None + fields } - val proctime = tableSource match { + val withProctime = tableSource match { case timeSource: DefinedProctimeAttribute if timeSource.getProctimeAttribute != null => val proctimeAttribute = timeSource.getProctimeAttribute - Some((fieldCnt + (if (rowtime.isDefined) 1 else 0), proctimeAttribute)) + withRowtime :+ (proctimeAttribute, TimeIndicatorTypeInfo.PROCTIME_INDICATOR) case _ => - None + withRowtime } + val (fieldNamesWithIndicators, fieldTypesWithIndicators) = withProctime.unzip + flinkTypeFactory.buildLogicalRowType( - fieldNames, - fieldTypes, - rowtime, - proctime) + fieldNamesWithIndicators, + fieldTypesWithIndicators) } override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableSourceScan.scala index 3ae949ef57516..470d006257077 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableSourceScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableSourceScan.scala @@ -30,6 +30,7 @@ import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.schema.TableSourceTable import org.apache.flink.table.sources.{DefinedProctimeAttribute, DefinedRowtimeAttribute, TableSource} +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo import scala.collection.JavaConverters._ @@ -51,29 +52,29 @@ class FlinkLogicalTableSourceScan( val fieldNames = TableEnvironment.getFieldNames(tableSource).toList val fieldTypes = TableEnvironment.getFieldTypes(tableSource.getReturnType).toList - val fieldCnt = fieldNames.length + val fields = fieldNames.zip(fieldTypes) - val rowtime = tableSource match { + val withRowtime = tableSource match { case timeSource: DefinedRowtimeAttribute if 
timeSource.getRowtimeAttribute != null => val rowtimeAttribute = timeSource.getRowtimeAttribute - Some((fieldCnt, rowtimeAttribute)) + fields :+ (rowtimeAttribute, TimeIndicatorTypeInfo.ROWTIME_INDICATOR) case _ => - None + fields } - val proctime = tableSource match { + val withProctime = tableSource match { case timeSource: DefinedProctimeAttribute if timeSource.getProctimeAttribute != null => val proctimeAttribute = timeSource.getProctimeAttribute - Some((fieldCnt + (if (rowtime.isDefined) 1 else 0), proctimeAttribute)) + withRowtime :+ (proctimeAttribute, TimeIndicatorTypeInfo.PROCTIME_INDICATOR) case _ => - None + withRowtime } + val (fieldNamesWithIndicators, fieldTypesWithIndicators) = withProctime.unzip + flinkTypeFactory.buildLogicalRowType( - fieldNames, - fieldTypes, - rowtime, - proctime) + fieldNamesWithIndicators, + fieldTypesWithIndicators) } override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala index 2075689b79dc4..7dfcbc523d309 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala @@ -87,7 +87,7 @@ class DataStreamWindowJoinRule val (windowBounds, remainCondition) = WindowJoinUtil.extractWindowBoundsFromPredicate( joinInfo.getRemaining(join.getCluster.getRexBuilder), - leftRowSchema.logicalArity, + leftRowSchema.arity, join.getRowType, join.getCluster.getRexBuilder, TableConfig.DEFAULT) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/DataStreamTable.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/DataStreamTable.scala index 70054b4a0d2c9..b7021e285a4a0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/DataStreamTable.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/DataStreamTable.scala @@ -18,27 +18,14 @@ package org.apache.flink.table.plan.schema -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} import org.apache.flink.streaming.api.datastream.DataStream -import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.stats.FlinkStatistic class DataStreamTable[T]( val dataStream: DataStream[T], override val fieldIndexes: Array[Int], override val fieldNames: Array[String], - val rowtime: Option[(Int, String)], - val proctime: Option[(Int, String)], override val statistic: FlinkStatistic = FlinkStatistic.UNKNOWN) extends FlinkTable[T](dataStream.getType, fieldIndexes, fieldNames, statistic) { - override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { - val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory] - - flinkTypeFactory.buildLogicalRowType( - fieldNames, - fieldTypes, - rowtime, - proctime) - } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala index 752b00e9747dd..c76532fceaa10 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala @@ -26,6 +26,7 @@ import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo abstract class FlinkTable[T]( val typeInfo: TypeInformation[T], @@ -36,25 +37,39 @@ abstract class FlinkTable[T]( if (fieldIndexes.length != fieldNames.length) { throw new TableException( - "Number of field indexes and field names must be equal.") + s"Number of field names and field indexes must be equal.\n" + + s"Number of names is ${fieldNames.length}, number of indexes is ${fieldIndexes.length}.\n" + + s"List of column names: ${fieldNames.mkString("[", ", ", "]")}.\n" + + s"List of column indexes: ${fieldIndexes.mkString("[", ", ", "]")}.") } // check uniqueness of field names if (fieldNames.length != fieldNames.toSet.size) { + val duplicateFields = fieldNames + // count occurrences of field names + .groupBy(identity).mapValues(_.length) + // filter for occurrences > 1 and map to field name + .filter(g => g._2 > 1).keys + throw new TableException( - "Table field names must be unique.") + s"Field names must be unique.\n" + + s"List of duplicate fields: ${duplicateFields.mkString("[", ", ", "]")}.\n" + + s"List of all fields: ${fieldNames.mkString("[", ", ", "]")}.") } val fieldTypes: Array[TypeInformation[_]] = typeInfo match { case cType: CompositeType[_] => // it is ok to leave out fields - if (fieldNames.length > cType.getArity) { + if (fieldIndexes.count(_ >= 0) > cType.getArity) { throw new TableException( s"Arity of type (" + cType.getFieldNames.deep + ") " + "must not be greater than number of field names " + fieldNames.deep + ".") } - fieldIndexes.map(cType.getTypeAt(_).asInstanceOf[TypeInformation[_]]) + fieldIndexes.map { + case TimeIndicatorTypeInfo.ROWTIME_MARKER => TimeIndicatorTypeInfo.ROWTIME_INDICATOR + case TimeIndicatorTypeInfo.PROCTIME_MARKER => TimeIndicatorTypeInfo.PROCTIME_INDICATOR + case i => cType.getTypeAt(i).asInstanceOf[TypeInformation[_]]} case aType: AtomicType[_] => if (fieldIndexes.length != 1 || fieldIndexes(0) != 0) { throw new TableException( @@ -65,7 +80,7 @@ abstract class FlinkTable[T]( override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory] - flinkTypeFactory.buildLogicalRowType(fieldNames, fieldTypes, None, None) + flinkTypeFactory.buildLogicalRowType(fieldNames, fieldTypes) } /** diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/RowSchema.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/RowSchema.scala index ccbe44d3af0ec..ad0f552b8f3a5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/RowSchema.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/RowSchema.scala @@ -18,14 +18,10 @@ package org.apache.flink.table.plan.schema -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField, RelRecordType} -import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode, RexShuttle} +import org.apache.calcite.rel.`type`.RelDataType import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.RowTypeInfo -import 
org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.functions.TimeMaterializationSqlFunction import org.apache.flink.types.Row import scala.collection.JavaConversions._ @@ -35,127 +31,35 @@ import scala.collection.JavaConversions._ */ class RowSchema(private val logicalRowType: RelDataType) { - private lazy val physicalRowFields: Seq[RelDataTypeField] = logicalRowType.getFieldList filter { - field => !FlinkTypeFactory.isTimeIndicatorType(field.getType) - } - - private lazy val physicalRowType: RelDataType = new RelRecordType(physicalRowFields) - - private lazy val physicalRowFieldTypes: Seq[TypeInformation[_]] = physicalRowFields map { f => - FlinkTypeFactory.toTypeInfo(f.getType) - } - - private lazy val physicalRowFieldNames: Seq[String] = physicalRowFields.map(_.getName) + private lazy val physicalRowFieldTypes: Seq[TypeInformation[_]] = + logicalRowType.getFieldList map { f => FlinkTypeFactory.toTypeInfo(f.getType) } private lazy val physicalRowTypeInfo: TypeInformation[Row] = new RowTypeInfo( - physicalRowFieldTypes.toArray, physicalRowFieldNames.toArray) - - private lazy val indexMapping: Array[Int] = generateIndexMapping - - private lazy val inputRefUpdater = new RexInputRefUpdater() - - private def generateIndexMapping: Array[Int] = { - val mapping = new Array[Int](logicalRowType.getFieldCount) - var countTimeIndicators = 0 - var i = 0 - while (i < logicalRowType.getFieldCount) { - val t = logicalRowType.getFieldList.get(i).getType - if (FlinkTypeFactory.isTimeIndicatorType(t)) { - countTimeIndicators += 1 - // no mapping - mapping(i) = -1 - } else { - mapping(i) = i - countTimeIndicators - } - i += 1 - } - mapping - } - - private class RexInputRefUpdater extends RexShuttle { - - override def visitInputRef(inputRef: RexInputRef): RexNode = { - new RexInputRef(mapIndex(inputRef.getIndex), inputRef.getType) - } - - override def visitCall(call: RexCall): RexNode = call.getOperator match { - // we leave time indicators unchanged yet - // the index becomes invalid but right now we are only - // interested in the type of the input reference - case TimeMaterializationSqlFunction => call - case _ => super.visitCall(call) - } - } - - /** - * Returns the arity of the logical record. - */ - def logicalArity: Int = logicalRowType.getFieldCount - - /** - * Returns the arity of the physical record. - */ - def physicalArity: Int = physicalTypeInfo.getArity - - /** - * Returns a logical [[RelDataType]] including logical fields (i.e. time indicators). - */ - def logicalType: RelDataType = logicalRowType - - /** - * Returns a physical [[RelDataType]] with no logical fields (i.e. time indicators). - */ - def physicalType: RelDataType = physicalRowType - - /** - * Returns a physical [[TypeInformation]] of row with no logical fields (i.e. time indicators). - */ - def physicalTypeInfo: TypeInformation[Row] = physicalRowTypeInfo - - /** - * Returns [[TypeInformation]] of the row's fields with no logical fields (i.e. time indicators). - */ - def physicalFieldTypeInfo: Seq[TypeInformation[_]] = physicalRowFieldTypes + physicalRowFieldTypes.toArray, fieldNames.toArray) /** - * Returns the logical fields names including logical fields (i.e. time indicators). + * Returns the arity of the schema. */ - def logicalFieldNames: Seq[String] = logicalRowType.getFieldNames + def arity: Int = logicalRowType.getFieldCount /** - * Returns the physical fields names with no logical fields (i.e. time indicators). 
+ * Returns the [[RelDataType]] of the schema */ - def physicalFieldNames: Seq[String] = physicalRowFieldNames + def relDataType: RelDataType = logicalRowType /** - * Converts logical indices to physical indices based on this schema. + * Returns the [[TypeInformation]] of the schema */ - def mapIndex(logicalIndex: Int): Int = { - val mappedIndex = indexMapping(logicalIndex) - if (mappedIndex < 0) { - throw new TableException("Invalid access to a logical field.") - } else { - mappedIndex - } - } + def typeInfo: TypeInformation[Row] = physicalRowTypeInfo /** - * Converts logical indices of a aggregate call to physical ones. + * Returns the [[TypeInformation]] of the fields of the schema */ - def mapAggregateCall(logicalAggCall: AggregateCall): AggregateCall = { - logicalAggCall.copy( - logicalAggCall.getArgList.map(mapIndex(_).asInstanceOf[Integer]), - if (logicalAggCall.filterArg < 0) { - logicalAggCall.filterArg - } else { - mapIndex(logicalAggCall.filterArg) - } - ) - } + def fieldTypeInfos: Seq[TypeInformation[_]] = physicalRowFieldTypes /** - * Converts logical field references of a [[RexNode]] to physical ones. + * Returns the field names */ - def mapRexNode(logicalRexNode: RexNode): RexNode = logicalRexNode.accept(inputRefUpdater) + def fieldNames: Seq[String] = logicalRowType.getFieldNames } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/StreamTableSourceTable.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/StreamTableSourceTable.scala index 408381dea9e81..dc1f31ab27db5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/StreamTableSourceTable.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/StreamTableSourceTable.scala @@ -23,6 +23,7 @@ import org.apache.flink.table.api.{TableEnvironment, TableException} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.stats.FlinkStatistic import org.apache.flink.table.sources.{DefinedProctimeAttribute, DefinedRowtimeAttribute, TableSource} +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo class StreamTableSourceTable[T]( override val tableSource: TableSource[T], @@ -36,41 +37,38 @@ class StreamTableSourceTable[T]( val fieldNames = TableEnvironment.getFieldNames(tableSource).toList val fieldTypes = TableEnvironment.getFieldTypes(tableSource.getReturnType).toList - val fieldCnt = fieldNames.length + val fields = fieldNames.zip(fieldTypes) - val rowtime = tableSource match { - case nullTimeSource : DefinedRowtimeAttribute - if nullTimeSource.getRowtimeAttribute == null => - None - case emptyStringTimeSource: DefinedRowtimeAttribute - if emptyStringTimeSource.getRowtimeAttribute.trim.equals("") => - throw TableException("The name of the rowtime attribute must not be empty.") - case timeSource: DefinedRowtimeAttribute => + val withRowtime = tableSource match { + case timeSource: DefinedRowtimeAttribute if timeSource.getRowtimeAttribute == null => + fields + case timeSource: DefinedRowtimeAttribute if timeSource.getRowtimeAttribute.trim.equals("") => + throw TableException("The name of the rowtime attribute must not be empty.") + case timeSource: DefinedRowtimeAttribute => val rowtimeAttribute = timeSource.getRowtimeAttribute - Some((fieldCnt, rowtimeAttribute)) + fields :+ (rowtimeAttribute, TimeIndicatorTypeInfo.ROWTIME_INDICATOR) case _ => - None + fields } - val proctime = tableSource match { - case nullTimeSource : 
DefinedProctimeAttribute - if nullTimeSource.getProctimeAttribute == null => - None - case emptyStringTimeSource: DefinedProctimeAttribute - if emptyStringTimeSource.getProctimeAttribute.trim.equals("") => - throw TableException("The name of the proctime attribute must not be empty.") - case timeSource: DefinedProctimeAttribute => + val withProctime = tableSource match { + case timeSource : DefinedProctimeAttribute if timeSource.getProctimeAttribute == null => + withRowtime + case timeSource: DefinedProctimeAttribute + if timeSource.getProctimeAttribute.trim.equals("") => + throw TableException("The name of the proctime attribute must not be empty.") + case timeSource: DefinedProctimeAttribute => val proctimeAttribute = timeSource.getProctimeAttribute - Some((fieldCnt + (if (rowtime.isDefined) 1 else 0), proctimeAttribute)) + withRowtime :+ (proctimeAttribute, TimeIndicatorTypeInfo.PROCTIME_INDICATOR) case _ => - None + withRowtime } + val (fieldNamesWithIndicators, fieldTypesWithIndicators) = withProctime.unzip + flinkTypeFactory.buildLogicalRowType( - fieldNames, - fieldTypes, - rowtime, - proctime) + fieldNamesWithIndicators, + fieldTypesWithIndicators) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala index bf9a6881addef..53bf8e777af33 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExtractor.scala @@ -20,10 +20,11 @@ package org.apache.flink.table.plan.util import org.apache.calcite.plan.RelOptUtil import org.apache.calcite.rex._ +import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.sql.{SqlFunction, SqlPostfixOperator} import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.expressions.{Expression, Literal, ResolvedFieldReference} +import org.apache.flink.table.expressions.{And, Expression, Literal, Or, ResolvedFieldReference} import org.apache.flink.table.validate.FunctionCatalog import org.apache.flink.util.Preconditions @@ -170,6 +171,10 @@ class RexNodeToExpressionConverter( None } else { call.getOperator match { + case SqlStdOperatorTable.OR => + Option(operands.reduceLeft(Or)) + case SqlStdOperatorTable.AND => + Option(operands.reduceLeft(And)) case function: SqlFunction => lookupFunction(replace(function.getName), operands) case postfix: SqlPostfixOperator => diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowCorrelateProcessRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowCorrelateProcessRunner.scala index 4f0a78550794a..2553d9cd67b1a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowCorrelateProcessRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowCorrelateProcessRunner.scala @@ -25,9 +25,9 @@ import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.{Logger, LoggerFactory} /** * A 
CorrelateProcessRunner with [[CRow]] input and [[CRow]] output. @@ -40,9 +40,8 @@ class CRowCorrelateProcessRunner( @transient var returnType: TypeInformation[CRow]) extends ProcessFunction[CRow, CRow] with ResultTypeQueryable[CRow] - with Compiler[Any] { - - val LOG: Logger = LoggerFactory.getLogger(this.getClass) + with Compiler[Any] + with Logging { private var function: ProcessFunction[Row, Row] = _ private var collector: TableFunctionCollector[_] = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowInputMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowMapRunner.scala similarity index 92% rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowInputMapRunner.scala rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowMapRunner.scala index 109c6e1ebe35c..54bac601bcc10 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowInputMapRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowMapRunner.scala @@ -24,21 +24,20 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row -import org.slf4j.LoggerFactory /** * MapRunner with [[CRow]] input. */ -class CRowInputMapRunner[OUT]( +class CRowMapRunner[OUT]( name: String, code: String, @transient var returnType: TypeInformation[OUT]) extends RichMapFunction[CRow, OUT] with ResultTypeQueryable[OUT] - with Compiler[MapFunction[Row, OUT]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with Compiler[MapFunction[Row, OUT]] + with Logging { private var function: MapFunction[Row, OUT] = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputProcessRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputProcessRunner.scala new file mode 100644 index 0000000000000..600b8987a28f4 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputProcessRunner.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ResultTypeQueryable +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.operators.TimestampedCollector +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row +import org.apache.flink.util.Collector + +/** + * ProcessRunner with [[CRow]] output. + */ +class CRowOutputProcessRunner( + name: String, + code: String, + @transient var returnType: TypeInformation[CRow]) + extends ProcessFunction[Any, CRow] + with ResultTypeQueryable[CRow] + with Compiler[ProcessFunction[Any, Row]] + with Logging { + + private var function: ProcessFunction[Any, Row] = _ + private var cRowWrapper: CRowWrappingCollector = _ + + override def open(parameters: Configuration): Unit = { + LOG.debug(s"Compiling ProcessFunction: $name \n\n Code:\n$code") + val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code) + LOG.debug("Instantiating ProcessFunction.") + function = clazz.newInstance() + + this.cRowWrapper = new CRowWrappingCollector() + this.cRowWrapper.setChange(true) + } + + override def processElement( + in: Any, + ctx: ProcessFunction[Any, CRow]#Context, + out: Collector[CRow]): Unit = { + + // remove timestamp from stream record + val tc = out.asInstanceOf[TimestampedCollector[_]] + tc.eraseTimestamp() + + cRowWrapper.out = out + function.processElement(in, ctx.asInstanceOf[ProcessFunction[Any, Row]#Context], cRowWrapper) + } + + override def getProducedType: TypeInformation[CRow] = returnType +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowProcessRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowProcessRunner.scala index cef62a517be2f..a7f3d7287baff 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowProcessRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowProcessRunner.scala @@ -25,9 +25,9 @@ import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * ProcessRunner with [[CRow]] input and [[CRow]] output. 
@@ -38,9 +38,8 @@ class CRowProcessRunner( @transient var returnType: TypeInformation[CRow]) extends ProcessFunction[CRow, CRow] with ResultTypeQueryable[CRow] - with Compiler[ProcessFunction[Row, Row]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with Compiler[ProcessFunction[Row, Row]] + with Logging { private var function: ProcessFunction[Row, Row] = _ private var cRowWrapper: CRowWrappingCollector = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala index 478b6b64cfcad..e2f5e6113361f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala @@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.util.Logging import org.apache.flink.util.Collector -import org.slf4j.{Logger, LoggerFactory} class CorrelateFlatMapRunner[IN, OUT]( flatMapName: String, @@ -35,9 +35,8 @@ class CorrelateFlatMapRunner[IN, OUT]( @transient var returnType: TypeInformation[OUT]) extends RichFlatMapFunction[IN, OUT] with ResultTypeQueryable[OUT] - with Compiler[Any] { - - val LOG: Logger = LoggerFactory.getLogger(this.getClass) + with Compiler[Any] + with Logging { private var function: FlatMapFunction[IN, OUT] = _ private var collector: TableFunctionCollector[_] = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatJoinRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatJoinRunner.scala index 67acc0b10ee1d..0bf65694367f9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatJoinRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatJoinRunner.scala @@ -21,10 +21,10 @@ package org.apache.flink.table.runtime import org.apache.flink.api.common.functions.{FlatJoinFunction, RichFlatJoinFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable -import org.apache.flink.table.codegen.Compiler import org.apache.flink.configuration.Configuration +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.util.Logging import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory class FlatJoinRunner[IN1, IN2, OUT]( name: String, @@ -32,9 +32,8 @@ class FlatJoinRunner[IN1, IN2, OUT]( @transient var returnType: TypeInformation[OUT]) extends RichFlatJoinFunction[IN1, IN2, OUT] with ResultTypeQueryable[OUT] - with Compiler[FlatJoinFunction[IN1, IN2, OUT]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with Compiler[FlatJoinFunction[IN1, IN2, OUT]] + with Logging { private var function: FlatJoinFunction[IN1, IN2, OUT] = null diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala index 938da59ea4aa4..6c1f80489851e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala @@ -22,11 +22,11 @@ import org.apache.flink.api.common.functions.util.FunctionUtils import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable -import org.apache.flink.table.codegen.Compiler import org.apache.flink.configuration.Configuration +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory class FlatMapRunner( name: String, @@ -34,9 +34,8 @@ class FlatMapRunner( @transient var returnType: TypeInformation[Row]) extends RichFlatMapFunction[Row, Row] with ResultTypeQueryable[Row] - with Compiler[FlatMapFunction[Row, Row]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with Compiler[FlatMapFunction[Row, Row]] + with Logging { private var function: FlatMapFunction[Row, Row] = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapRunner.scala index 14eeecfb451b5..00d18ecc00794 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapRunner.scala @@ -21,9 +21,9 @@ package org.apache.flink.table.runtime import org.apache.flink.api.common.functions.{MapFunction, RichMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable -import org.apache.flink.table.codegen.Compiler import org.apache.flink.configuration.Configuration -import org.slf4j.LoggerFactory +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.util.Logging class MapRunner[IN, OUT]( name: String, @@ -31,9 +31,8 @@ class MapRunner[IN, OUT]( @transient var returnType: TypeInformation[OUT]) extends RichMapFunction[IN, OUT] with ResultTypeQueryable[OUT] - with Compiler[MapFunction[IN, OUT]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with Compiler[MapFunction[IN, OUT]] + with Logging { private var function: MapFunction[IN, OUT] = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala index 00b7b8eee0ce1..5f5a2cc4a19a3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala @@ -21,9 +21,9 @@ package org.apache.flink.table.runtime import org.apache.flink.api.common.functions.{FlatJoinFunction, RichFlatMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable -import org.apache.flink.table.codegen.Compiler import org.apache.flink.configuration.Configuration -import org.slf4j.LoggerFactory +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.util.Logging abstract class MapSideJoinRunner[IN1, IN2, SINGLE_IN, MULTI_IN, OUT]( name: String, @@ -32,9 +32,8 @@ abstract class MapSideJoinRunner[IN1, IN2, SINGLE_IN, MULTI_IN, OUT]( broadcastSetName: String) extends 
RichFlatMapFunction[MULTI_IN, OUT] with ResultTypeQueryable[OUT] - with Compiler[FlatJoinFunction[IN1, IN2, OUT]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with Compiler[FlatJoinFunction[IN1, IN2, OUT]] + with Logging { protected var function: FlatJoinFunction[IN1, IN2, OUT] = _ protected var broadcastSet: Option[SINGLE_IN] = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/OutputRowtimeProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/OutputRowtimeProcessFunction.scala new file mode 100644 index 0000000000000..3eaeea308f915 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/OutputRowtimeProcessFunction.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime + +import org.apache.calcite.runtime.SqlFunctions +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.api.common.functions.util.FunctionUtils +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.operators.TimestampedCollector +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.util.Collector + +/** + * Wraps a ProcessFunction and sets a Timestamp field of a CRow as + * [[org.apache.flink.streaming.runtime.streamrecord.StreamRecord]] timestamp. + */ +class OutputRowtimeProcessFunction[OUT]( + function: MapFunction[CRow, OUT], + rowtimeIdx: Int) + extends ProcessFunction[CRow, OUT] { + + override def open(parameters: Configuration): Unit = { + FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext) + FunctionUtils.openFunction(function, parameters) + } + + override def processElement( + in: CRow, + ctx: ProcessFunction[CRow, OUT]#Context, + out: Collector[OUT]): Unit = { + + val timestamp = in.row.getField(rowtimeIdx).asInstanceOf[Long] + out.asInstanceOf[TimestampedCollector[_]].setAbsoluteTimestamp(timestamp) + + val convertedTimestamp = SqlFunctions.internalToTimestamp(timestamp) + in.row.setField(rowtimeIdx, convertedTimestamp) + + out.collect(function.map(in)) + } + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/RowtimeProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/RowtimeProcessFunction.scala new file mode 100644 index 0000000000000..e192b075afdcd --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/RowtimeProcessFunction.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ResultTypeQueryable +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.operators.TimestampedCollector +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.util.Collector + +/** + * ProcessFunction to copy a timestamp from a [[org.apache.flink.types.Row]] field into the + * [[org.apache.flink.streaming.runtime.streamrecord.StreamRecord]]. + */ +class RowtimeProcessFunction( + val rowtimeIdx: Int, + @transient var returnType: TypeInformation[CRow]) + extends ProcessFunction[CRow, CRow] + with ResultTypeQueryable[CRow] { + + override def processElement( + in: CRow, + ctx: ProcessFunction[CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + + val timestamp = in.row.getField(rowtimeIdx).asInstanceOf[Long] + out.asInstanceOf[TimestampedCollector[CRow]].setAbsoluteTimestamp(timestamp) + out.collect(in) + } + + override def getProducedType: TypeInformation[CRow] = returnType +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala index dd9c015c2d98e..d3bffda284df9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateAggFunction.scala @@ -21,8 +21,8 @@ package org.apache.flink.table.runtime.aggregate import org.apache.flink.api.common.functions.AggregateFunction import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row -import org.slf4j.LoggerFactory /** * Aggregate Function used for the aggregate operator in @@ -31,9 +31,8 @@ import org.slf4j.LoggerFactory * @param genAggregations Generated aggregate helper function */ class AggregateAggFunction(genAggregations: GeneratedAggregationsFunction) - extends AggregateFunction[CRow, Row, Row] with Compiler[GeneratedAggregations] { + extends AggregateFunction[CRow, Row, Row] with Compiler[GeneratedAggregations] with Logging { - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def createAccumulator(): Row = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index c9f98e31bd3be..58940d06abb05 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -32,10 +32,11 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction} import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} +import org.apache.flink.table.api.dataview.DataViewSpec import org.apache.flink.table.api.{StreamQueryConfig, TableException} import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.codegen.{AggregationCodeGenerator, CodeGenerator} +import org.apache.flink.table.codegen.AggregationCodeGenerator import org.apache.flink.table.expressions.ExpressionUtils.isTimeIntervalLiteral import org.apache.flink.table.expressions._ import org.apache.flink.table.functions.aggfunctions._ @@ -66,7 +67,7 @@ object AggregateUtil { * @param inputType Physical type of the row. * @param inputTypeInfo Physical type information of the row. * @param inputFieldTypeInfo Physical type information of the row's fields. - * @param isRowTimeType It is a tag that indicates whether the time type is rowTimeType + * @param rowTimeIdx The index of the rowtime field or None in case of processing time. * @param isPartitioned It is a tag that indicate whether the input is partitioned * @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause */ @@ -77,16 +78,17 @@ object AggregateUtil { inputTypeInfo: TypeInformation[Row], inputFieldTypeInfo: Seq[TypeInformation[_]], queryConfig: StreamQueryConfig, - isRowTimeType: Boolean, + rowTimeIdx: Option[Int], isPartitioned: Boolean, isRowsClause: Boolean) : ProcessFunction[CRow, CRow] = { - val (aggFields, aggregates, accTypes) = + val (aggFields, aggregates, accTypes, accSpecs) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetraction = false, + isStateBackedDataViews = true) val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*) @@ -96,7 +98,6 @@ object AggregateUtil { val genFunction = generator.generateAggregations( "UnboundedProcessingOverAggregateHelper", - generator, inputFieldTypeInfo, aggregates, aggFields, @@ -108,16 +109,18 @@ object AggregateUtil { outputArity, needRetract = false, needMerge = false, - needReset = false + needReset = false, + accConfig = Some(accSpecs) ) - if (isRowTimeType) { + if (rowTimeIdx.isDefined) { if (isRowsClause) { // ROWS unbounded over process function new RowTimeUnboundedRowsOver( genFunction, aggregationStateType, CRowTypeInfo(inputTypeInfo), + rowTimeIdx.get, queryConfig) } else { // RANGE unbounded over process function @@ -125,6 +128,7 @@ object AggregateUtil { genFunction, aggregationStateType, CRowTypeInfo(inputTypeInfo), + rowTimeIdx.get, queryConfig) } } else { @@ -159,11 +163,12 @@ object AggregateUtil { generateRetraction: Boolean, consumeRetraction: Boolean): ProcessFunction[CRow, CRow] = { - val (aggFields, aggregates, accTypes) = + val (aggFields, aggregates, accTypes, accSpecs) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputRowType, - consumeRetraction) + consumeRetraction, + isStateBackedDataViews = true) val aggMapping = aggregates.indices.map(_ + groupings.length).toArray @@ 
-173,7 +178,6 @@ object AggregateUtil { val genFunction = generator.generateAggregations( "NonWindowedAggregationHelper", - generator, inputFieldTypes, aggregates, aggFields, @@ -185,7 +189,8 @@ object AggregateUtil { outputArity, consumeRetraction, needMerge = false, - needReset = false + needReset = false, + accConfig = Some(accSpecs) ) new GroupAggProcessFunction( @@ -207,7 +212,7 @@ object AggregateUtil { * @param inputFieldTypeInfo Physical type information of the row's fields. * @param precedingOffset the preceding offset * @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause - * @param isRowTimeType It is a tag that indicates whether the time type is rowTimeType + * @param rowTimeIdx The index of the rowtime field or None in case of processing time. * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] */ private[flink] def createBoundedOverProcessFunction( @@ -219,15 +224,16 @@ object AggregateUtil { precedingOffset: Long, queryConfig: StreamQueryConfig, isRowsClause: Boolean, - isRowTimeType: Boolean) + rowTimeIdx: Option[Int]) : ProcessFunction[CRow, CRow] = { val needRetract = true - val (aggFields, aggregates, accTypes) = + val (aggFields, aggregates, accTypes, accSpecs) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetract) + needRetract, + isStateBackedDataViews = true) val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*) val inputRowType = CRowTypeInfo(inputTypeInfo) @@ -238,7 +244,6 @@ object AggregateUtil { val genFunction = generator.generateAggregations( "BoundedOverAggregateHelper", - generator, inputFieldTypeInfo, aggregates, aggFields, @@ -250,16 +255,18 @@ object AggregateUtil { outputArity, needRetract, needMerge = false, - needReset = true + needReset = false, + accConfig = Some(accSpecs) ) - if (isRowTimeType) { + if (rowTimeIdx.isDefined) { if (isRowsClause) { new RowTimeBoundedRowsOver( genFunction, aggregationStateType, inputRowType, precedingOffset, + rowTimeIdx.get, queryConfig) } else { new RowTimeBoundedRangeOver( @@ -267,6 +274,7 @@ object AggregateUtil { aggregationStateType, inputRowType, precedingOffset, + rowTimeIdx.get, queryConfig) } } else { @@ -322,7 +330,7 @@ object AggregateUtil { : MapFunction[Row, Row] = { val needRetract = false - val (aggFieldIndexes, aggregates, accTypes) = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, needRetract) @@ -368,7 +376,6 @@ object AggregateUtil { val genFunction = generator.generateAggregations( "DataSetAggregatePrepareMapHelper", - generator, inputFieldTypeInfo, aggregates, aggFieldIndexes, @@ -380,7 +387,8 @@ object AggregateUtil { outputArity, needRetract, needMerge = false, - needReset = true + needReset = true, + None ) new DataSetWindowAggMapFunction( @@ -428,7 +436,7 @@ object AggregateUtil { : RichGroupReduceFunction[Row, Row] = { val needRetract = false - val (aggFieldIndexes, aggregates, accTypes) = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), physicalInputRowType, needRetract) @@ -447,7 +455,6 @@ object AggregateUtil { // sliding time-window for partial aggregations val genFunction = generator.generateAggregations( "DataSetAggregatePrepareMapHelper", - generator, physicalInputTypes, aggregates, aggFieldIndexes, @@ -459,7 +466,8 @@ object AggregateUtil { keysAndAggregatesArity + 1, needRetract, 
needMerge = true, - needReset = true + needReset = true, + None ) new DataSetSlideTimeWindowAggReduceGroupFunction( genFunction, @@ -542,7 +550,7 @@ object AggregateUtil { : RichGroupReduceFunction[Row, Row] = { val needRetract = false - val (aggFieldIndexes, aggregates, _) = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates, _, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), physicalInputRowType, needRetract) @@ -551,7 +559,6 @@ object AggregateUtil { val genPreAggFunction = generator.generateAggregations( "GroupingWindowAggregateHelper", - generator, physicalInputTypes, aggregates, aggFieldIndexes, @@ -563,12 +570,12 @@ object AggregateUtil { outputType.getFieldCount, needRetract, needMerge = true, - needReset = true + needReset = true, + None ) val genFinalAggFunction = generator.generateAggregations( "GroupingWindowAggregateHelper", - generator, physicalInputTypes, aggregates, aggFieldIndexes, @@ -580,7 +587,8 @@ object AggregateUtil { outputType.getFieldCount, needRetract, needMerge = true, - needReset = true + needReset = true, + None ) val keysAndAggregatesArity = groupings.length + namedAggregates.length @@ -588,7 +596,7 @@ object AggregateUtil { window match { case TumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => // tumbling time window - val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + val (startPos, endPos, _) = computeWindowPropertyPos(properties) if (doAllSupportPartialMerge(aggregates)) { // for incremental aggregations new DataSetTumbleTimeWindowAggReduceCombineFunction( @@ -615,7 +623,7 @@ object AggregateUtil { asLong(size)) case SessionGroupWindow(_, _, gap) => - val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + val (startPos, endPos, _) = computeWindowPropertyPos(properties) new DataSetSessionWindowAggReduceGroupFunction( genFinalAggFunction, keysAndAggregatesArity, @@ -625,7 +633,7 @@ object AggregateUtil { isInputCombined) case SlidingGroupWindow(_, _, size, _) if isTimeInterval(size.resultType) => - val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + val (startPos, endPos, _) = computeWindowPropertyPos(properties) if (doAllSupportPartialMerge(aggregates)) { // for partial aggregations new DataSetSlideWindowAggReduceCombineFunction( @@ -689,7 +697,7 @@ object AggregateUtil { groupings: Array[Int]): MapPartitionFunction[Row, Row] = { val needRetract = false - val (aggFieldIndexes, aggregates, accTypes) = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), physicalInputRowType, needRetract) @@ -710,7 +718,6 @@ object AggregateUtil { val genFunction = generator.generateAggregations( "GroupingWindowAggregateHelper", - generator, physicalInputTypes, aggregates, aggFieldIndexes, @@ -722,7 +729,8 @@ object AggregateUtil { groupings.length + aggregates.length + 2, needRetract, needMerge = true, - needReset = true + needReset = true, + None ) new DataSetSessionWindowAggregatePreProcessor( @@ -763,7 +771,7 @@ object AggregateUtil { : GroupCombineFunction[Row, Row] = { val needRetract = false - val (aggFieldIndexes, aggregates, accTypes) = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), physicalInputRowType, needRetract) @@ -785,7 +793,6 @@ object AggregateUtil { val genFunction = generator.generateAggregations( "GroupingWindowAggregateHelper", - generator, physicalInputTypes, 
aggregates, aggFieldIndexes, @@ -797,7 +804,8 @@ object AggregateUtil { groupings.length + aggregates.length + 2, needRetract, needMerge = true, - needReset = true + needReset = true, + None ) new DataSetSessionWindowAggregatePreProcessor( @@ -830,7 +838,7 @@ object AggregateUtil { RichGroupReduceFunction[Row, Row]) = { val needRetract = false - val (aggInFields, aggregates, accTypes) = transformToAggregateFunctions( + val (aggInFields, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, needRetract) @@ -866,7 +874,6 @@ object AggregateUtil { val genPreAggFunction = generator.generateAggregations( "DataSetAggregatePrepareMapHelper", - generator, inputFieldTypeInfo, aggregates, aggInFields, @@ -878,7 +885,8 @@ object AggregateUtil { groupings.length + aggregates.length, needRetract, needMerge = false, - needReset = true + needReset = true, + None ) // compute mapping of forwarded grouping keys @@ -893,7 +901,6 @@ object AggregateUtil { val genFinalAggFunction = generator.generateAggregations( "DataSetAggregateFinalHelper", - generator, inputFieldTypeInfo, aggregates, aggInFields, @@ -905,7 +912,8 @@ object AggregateUtil { outputType.getFieldCount, needRetract, needMerge = true, - needReset = true + needReset = true, + None ) ( @@ -917,7 +925,6 @@ object AggregateUtil { else { val genFunction = generator.generateAggregations( "DataSetAggregateHelper", - generator, inputFieldTypeInfo, aggregates, aggInFields, @@ -929,7 +936,8 @@ object AggregateUtil { outputType.getFieldCount, needRetract, needMerge = false, - needReset = true + needReset = true, + None ) ( @@ -951,10 +959,11 @@ object AggregateUtil { : AllWindowFunction[Row, CRow, DataStreamWindow] = { if (isTimeWindow(window)) { - val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + val (startPos, endPos, timePos) = computeWindowPropertyPos(properties) new IncrementalAggregateAllTimeWindowFunction( startPos, endPos, + timePos, finalRowArity) .asInstanceOf[AllWindowFunction[Row, CRow, DataStreamWindow]] } else { @@ -975,12 +984,13 @@ object AggregateUtil { WindowFunction[Row, CRow, Tuple, DataStreamWindow] = { if (isTimeWindow(window)) { - val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) + val (startPos, endPos, timePos) = computeWindowPropertyPos(properties) new IncrementalAggregateTimeWindowFunction( numGroupingKeys, numAggregates, startPos, endPos, + timePos, finalRowArity) .asInstanceOf[WindowFunction[Row, CRow, Tuple, DataStreamWindow]] } else { @@ -1002,7 +1012,7 @@ object AggregateUtil { : (DataStreamAggFunction[CRow, Row, Row], RowTypeInfo, RowTypeInfo) = { val needRetract = false - val (aggFields, aggregates, accTypes) = + val (aggFields, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, @@ -1013,7 +1023,6 @@ object AggregateUtil { val genFunction = generator.generateAggregations( "GroupingWindowAggregateHelper", - generator, inputFieldTypeInfo, aggregates, aggFields, @@ -1025,7 +1034,8 @@ object AggregateUtil { outputArity, needRetract, needMerge, - needReset = false + needReset = false, + None ) val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType)) @@ -1136,32 +1146,42 @@ object AggregateUtil { } } - private[flink] def computeWindowStartEndPropertyPos( - properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int]) = { + private[flink] def computeWindowPropertyPos( + properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int], Option[Int]) = { - val 
propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) { - (p, x) => p match { + val propPos = properties.foldRight( + (None: Option[Int], None: Option[Int], None: Option[Int], 0)) { + case (p, (s, e, t, i)) => p match { case NamedWindowProperty(_, prop) => prop match { - case WindowStart(_) if x._1.isDefined => + case WindowStart(_) if s.isDefined => throw new TableException("Duplicate WindowStart property encountered. This is a bug.") case WindowStart(_) => - (Some(x._3), x._2, x._3 - 1) - case WindowEnd(_) if x._2.isDefined => + (Some(i), e, t, i - 1) + case WindowEnd(_) if e.isDefined => throw new TableException("Duplicate WindowEnd property encountered. This is a bug.") case WindowEnd(_) => - (x._1, Some(x._3), x._3 - 1) + (s, Some(i), t, i - 1) + case RowtimeAttribute(_) if t.isDefined => + throw new TableException( + "Duplicate Window rowtime property encountered. This is a bug.") + case RowtimeAttribute(_) => + (s, e, Some(i), i - 1) } } } - (propPos._1, propPos._2) + (propPos._1, propPos._2, propPos._3) } private def transformToAggregateFunctions( aggregateCalls: Seq[AggregateCall], inputType: RelDataType, - needRetraction: Boolean) - : (Array[Array[Int]], Array[TableAggregateFunction[_, _]], Array[TypeInformation[_]]) = { + needRetraction: Boolean, + isStateBackedDataViews: Boolean = false) + : (Array[Array[Int]], + Array[TableAggregateFunction[_, _]], + Array[TypeInformation[_]], + Array[Seq[DataViewSpec[_]]]) = { // store the aggregate fields of each aggregate function, by the same order of aggregates. val aggFieldIndexes = new Array[Array[Int]](aggregateCalls.size) @@ -1202,7 +1222,7 @@ object AggregateUtil { case DECIMAL => new DecimalSumWithRetractAggFunction case sqlType: SqlTypeName => - throw new TableException(s"Sum aggregate does no support type: '${sqlType}'") + throw new TableException(s"Sum aggregate does no support type: '$sqlType'") } } else { aggregates(index) = sqlTypeName match { @@ -1221,7 +1241,7 @@ object AggregateUtil { case DECIMAL => new DecimalSumAggFunction case sqlType: SqlTypeName => - throw new TableException(s"Sum aggregate does no support type: '${sqlType}'") + throw new TableException(s"Sum aggregate does no support type: '$sqlType'") } } @@ -1243,7 +1263,7 @@ object AggregateUtil { case DECIMAL => new DecimalSum0WithRetractAggFunction case sqlType: SqlTypeName => - throw new TableException(s"Sum0 aggregate does no support type: '${sqlType}'") + throw new TableException(s"Sum0 aggregate does no support type: '$sqlType'") } } else { aggregates(index) = sqlTypeName match { @@ -1262,7 +1282,7 @@ object AggregateUtil { case DECIMAL => new DecimalSum0AggFunction case sqlType: SqlTypeName => - throw new TableException(s"Sum0 aggregate does no support type: '${sqlType}'") + throw new TableException(s"Sum0 aggregate does no support type: '$sqlType'") } } @@ -1283,7 +1303,7 @@ object AggregateUtil { case DECIMAL => new DecimalAvgAggFunction case sqlType: SqlTypeName => - throw new TableException(s"Avg aggregate does no support type: '${sqlType}'") + throw new TableException(s"Avg aggregate does no support type: '$sqlType'") } case sqlMinMaxFunction: SqlMinMaxAggFunction => @@ -1310,7 +1330,7 @@ object AggregateUtil { new StringMinWithRetractAggFunction case sqlType: SqlTypeName => throw new TableException( - s"Min with retract aggregate does no support type: '${sqlType}'") + s"Min with retract aggregate does no support type: '$sqlType'") } } else { sqlTypeName match { @@ -1333,7 +1353,7 @@ object AggregateUtil { case VARCHAR | CHAR => new 
StringMinAggFunction case sqlType: SqlTypeName => - throw new TableException(s"Min aggregate does no support type: '${sqlType}'") + throw new TableException(s"Min aggregate does no support type: '$sqlType'") } } } else { @@ -1359,7 +1379,7 @@ object AggregateUtil { new StringMaxWithRetractAggFunction case sqlType: SqlTypeName => throw new TableException( - s"Max with retract aggregate does no support type: '${sqlType}'") + s"Max with retract aggregate does no support type: '$sqlType'") } } else { sqlTypeName match { @@ -1382,7 +1402,7 @@ object AggregateUtil { case VARCHAR | CHAR => new StringMaxAggFunction case sqlType: SqlTypeName => - throw new TableException(s"Max aggregate does no support type: '${sqlType}'") + throw new TableException(s"Max aggregate does no support type: '$sqlType'") } } } @@ -1399,14 +1419,28 @@ object AggregateUtil { } } + val accSpecs = new Array[Seq[DataViewSpec[_]]](aggregateCalls.size) + // create accumulator type information for every aggregate function aggregates.zipWithIndex.foreach { case (agg, index) => - if (null == accTypes(index)) { + if (accTypes(index) != null) { + val (accType, specs) = removeStateViewFieldsFromAccTypeInfo(index, + agg, + accTypes(index), + isStateBackedDataViews) + if (specs.isDefined) { + accSpecs(index) = specs.get + accTypes(index) = accType + } else { + accSpecs(index) = Seq() + } + } else { + accSpecs(index) = Seq() accTypes(index) = getAccumulatorTypeOfAggregateFunction(agg) } } - (aggFieldIndexes, aggregates, accTypes) + (aggFieldIndexes, aggregates, accTypes, accSpecs) } private def createRowTypeForKeysAndAggregates( @@ -1451,7 +1485,7 @@ object AggregateUtil { relDataType.head.getIndex } else { throw TableException( - s"Encountered more than one time attribute with the same name: '${relDataType}'") + s"Encountered more than one time attribute with the same name: '$relDataType'") } case e => throw TableException( "The time attribute of window in batch environment should be " + diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala index 5f459f98a4aa2..83e1b1315ad05 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala @@ -22,9 +22,9 @@ import java.lang.Iterable import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * [[RichGroupReduceFunction]] to compute aggregates that do not support pre-aggregation for batch @@ -35,12 +35,11 @@ import org.slf4j.LoggerFactory class DataSetAggFunction( private val genAggregations: GeneratedAggregationsFunction) extends RichGroupReduceFunction[Row, Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] with Logging { private var output: Row = _ private var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala index 9b81992cba24a..52762712b948c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala @@ -23,9 +23,9 @@ import java.lang.Iterable import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * [[RichGroupReduceFunction]] to compute the final result of a pre-aggregated aggregation @@ -36,12 +36,11 @@ import org.slf4j.LoggerFactory class DataSetFinalAggFunction( private val genAggregations: GeneratedAggregationsFunction) extends RichGroupReduceFunction[Row, Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] with Logging { private var output: Row = _ private var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala index 8febe3e820ff4..fc3366bd31691 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala @@ -22,9 +22,9 @@ import java.lang.Iterable import org.apache.flink.api.common.functions._ import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * [[GroupCombineFunction]] and [[MapPartitionFunction]] to compute pre-aggregates for batch @@ -36,12 +36,12 @@ class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction) extends AbstractRichFunction with GroupCombineFunction[Row, Row] with MapPartitionFunction[Row, Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var output: Row = _ private var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala index fabf200add416..d99ca31df9626 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala @@ -20,11 +20,11 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable import org.apache.flink.api.common.functions.RichGroupReduceFunction -import org.apache.flink.types.Row import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * It wraps the aggregate logic inside of @@ -54,7 +54,8 @@ class DataSetSessionWindowAggReduceGroupFunction( gap: Long, isInputCombined: Boolean) extends RichGroupReduceFunction[Row, Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var collector: RowTimeWindowPropertyCollector = _ private val intermediateRowWindowStartPos = keysAndAggregatesArity @@ -63,7 +64,6 @@ class DataSetSessionWindowAggReduceGroupFunction( private var output: Row = _ private var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -78,7 +78,10 @@ class DataSetSessionWindowAggReduceGroupFunction( output = function.createOutputRow() accumulators = function.createAccumulators() - collector = new RowTimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos) + collector = new RowTimeWindowPropertyCollector( + finalRowWindowStartPos, + finalRowWindowEndPos, + None) } /** diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala index 9bcac3031d52c..666bfee19348e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala @@ -22,11 +22,11 @@ import java.lang.Iterable import org.apache.flink.api.common.functions.{AbstractRichFunction, GroupCombineFunction, MapPartitionFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable -import org.apache.flink.types.Row import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * This wraps the aggregate logic inside of @@ -46,13 +46,13 @@ class DataSetSessionWindowAggregatePreProcessor( with MapPartitionFunction[Row,Row] with GroupCombineFunction[Row,Row] with ResultTypeQueryable[Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var output: Row = _ private val rowTimeFieldPos = keysAndAggregatesArity private var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala index b3a19a443328a..3af7969a37c06 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala @@ -25,9 +25,9 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.windowing.windows.TimeWindow import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * It is used for sliding windows on batch for time-windows. It takes a prepared input row (with @@ -53,7 +53,8 @@ class DataSetSlideTimeWindowAggReduceGroupFunction( extends RichGroupReduceFunction[Row, Row] with CombineFunction[Row, Row] with ResultTypeQueryable[Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private val timeFieldPos = returnType.getArity - 1 private val intermediateWindowStartPos = keysAndAggregatesArity @@ -61,7 +62,6 @@ class DataSetSlideTimeWindowAggReduceGroupFunction( protected var intermediateRow: Row = _ private var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala index 56ed08ade4a80..c64a52217f130 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala @@ -22,9 +22,9 @@ import java.lang.Iterable import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * It wraps the aggregate logic inside of @@ -45,7 +45,8 @@ class DataSetSlideWindowAggReduceGroupFunction( finalRowWindowEndPos: Option[Int], windowSize: Long) extends RichGroupReduceFunction[Row, Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var collector: RowTimeWindowPropertyCollector = _ protected val windowStartPos: Int = keysAndAggregatesArity @@ -53,7 +54,6 @@ class DataSetSlideWindowAggReduceGroupFunction( private var output: Row = _ protected var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) protected var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -68,7 +68,10 @@ class DataSetSlideWindowAggReduceGroupFunction( output = 
function.createOutputRow() accumulators = function.createAccumulators() - collector = new RowTimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos) + collector = new RowTimeWindowPropertyCollector( + finalRowWindowStartPos, + finalRowWindowEndPos, + None) } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala index 0e73f7b34baf2..22fe389a5b250 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala @@ -22,9 +22,9 @@ import java.lang.Iterable import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * It wraps the aggregate logic inside of @@ -38,12 +38,12 @@ class DataSetTumbleCountWindowAggReduceGroupFunction( private val genAggregations: GeneratedAggregationsFunction, private val windowSize: Long) extends RichGroupReduceFunction[Row, Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var output: Row = _ private var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala index 8af2c2e8a8a69..7ae4c173e78cd 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala @@ -22,9 +22,9 @@ import java.lang.Iterable import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * It wraps the aggregate logic inside of @@ -44,7 +44,8 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( windowEndPos: Option[Int], keysAndAggregatesArity: Int) extends RichGroupReduceFunction[Row, Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var collector: RowTimeWindowPropertyCollector = _ protected var aggregateBuffer: Row = new Row(keysAndAggregatesArity + 1) @@ -52,7 +53,6 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( private var output: Row = _ protected var accumulators: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) 
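The `RowTimeWindowPropertyCollector` constructor calls above now pass a third argument: an optional offset for the window's rowtime property (`None` for these batch reduce functions; the streaming `CRowTimeWindowPropertyCollector` further below receives a real offset). The collector implementation is not part of this diff; the following sketch only illustrates the idea, under the assumptions that the offsets are positions relative to the last field of the output row and that a window's rowtime is its end timestamp minus one millisecond (the window's max timestamp):

import java.sql.Timestamp

import org.apache.flink.types.Row
import org.apache.flink.util.Collector

// Illustrative sketch, not the flink-table implementation.
class WindowPropertySettingCollector(
    startOffset: Option[Int],
    endOffset: Option[Int],
    rowtimeOffset: Option[Int])
  extends Collector[Row] {

  var downstream: Collector[Row] = _   // wrapped collector
  var windowStart: Long = _            // set per window before collecting
  var windowEnd: Long = _

  override def collect(record: Row): Unit = {
    val lastPos = record.getArity - 1
    // write the requested window properties into the trailing fields of the row
    startOffset.foreach(o => record.setField(lastPos + o, new Timestamp(windowStart)))
    endOffset.foreach(o => record.setField(lastPos + o, new Timestamp(windowEnd)))
    rowtimeOffset.foreach(o => record.setField(lastPos + o, windowEnd - 1))
    downstream.collect(record)
  }

  override def close(): Unit = downstream.close()
}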
protected var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -67,7 +67,7 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( output = function.createOutputRow() accumulators = function.createAccumulators() - collector = new RowTimeWindowPropertyCollector(windowStartPos, windowEndPos) + collector = new RowTimeWindowPropertyCollector(windowStartPos, windowEndPos, None) } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala index d49ed0e13ee39..324784fc45386 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala @@ -26,8 +26,8 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.windowing.windows.TimeWindow import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row -import org.slf4j.LoggerFactory /** * This map function only works for windows on batch tables. @@ -44,12 +44,12 @@ class DataSetWindowAggMapFunction( @transient private val returnType: TypeInformation[Row]) extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var accs: Row = _ private var output: Row = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala index 5f48e091996e5..7b201142ce4fa 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.runtime.aggregate -import org.apache.flink.api.common.functions.Function +import org.apache.flink.api.common.functions.{Function, RuntimeContext} import org.apache.flink.types.Row /** @@ -26,6 +26,14 @@ import org.apache.flink.types.Row */ abstract class GeneratedAggregations extends Function { + /** + * Setup method for [[org.apache.flink.table.functions.AggregateFunction]]. + * It can be used for initialization work. By default, this method does nothing. + * + * @param ctx The runtime context. + */ + def open(ctx: RuntimeContext) + /** * Sets the results of the aggregations (partial or final) to the output row. * Final results are computed with the aggregation function. @@ -100,6 +108,17 @@ abstract class GeneratedAggregations extends Function { * aggregated results */ def resetAccumulator(accumulators: Row) + + /** + * Cleanup for the accumulators. + */ + def cleanup() + + /** + * Tear-down method for [[org.apache.flink.table.functions.AggregateFunction]]. + * It can be used for clean up work. 
By default, this method does nothing. + */ + def close() } class SingleElementIterable[T] extends java.lang.Iterable[T] { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala index 690a7c3d6cda9..df594608ea086 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala @@ -19,17 +19,16 @@ package org.apache.flink.table.runtime.aggregate import java.lang.{Long => JLong} +import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor} +import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction -import org.apache.flink.types.Row -import org.apache.flink.util.Collector -import org.apache.flink.api.common.state.ValueStateDescriptor -import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.api.common.state.ValueState import org.apache.flink.table.api.{StreamQueryConfig, Types} import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} -import org.slf4j.{Logger, LoggerFactory} import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row +import org.apache.flink.util.Collector /** * Aggregate Function used for the groupby (without window) aggregate @@ -43,9 +42,9 @@ class GroupAggProcessFunction( private val generateRetraction: Boolean, private val queryConfig: StreamQueryConfig) extends ProcessFunctionWithCleanupState[CRow, CRow](queryConfig) - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { - val LOG: Logger = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ private var newRow: CRow = _ @@ -65,6 +64,7 @@ class GroupAggProcessFunction( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.open(getRuntimeContext) newRow = new CRow(function.createOutputRow(), true) prevRow = new CRow(function.createOutputRow(), false) @@ -162,7 +162,11 @@ class GroupAggProcessFunction( if (needToCleanupState(timestamp)) { cleanupState(state, cntState) + function.cleanup() } } + override def close(): Unit = { + function.close() + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala index 711cc0505bf5b..3c2e8581343cf 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateAllTimeWindowFunction.scala @@ -29,13 +29,15 @@ import org.apache.flink.util.Collector * * Computes the final aggregate value from incrementally computed aggregates. 
* - * @param windowStartPos the start position of window - * @param windowEndPos the end position of window + * @param windowStartOffset the offset of the window start property + * @param windowEndOffset the offset of the window end property + * @param windowRowtimeOffset the offset of the window rowtime property * @param finalRowArity The arity of the final output row. */ class IncrementalAggregateAllTimeWindowFunction( - private val windowStartPos: Option[Int], - private val windowEndPos: Option[Int], + private val windowStartOffset: Option[Int], + private val windowEndOffset: Option[Int], + private val windowRowtimeOffset: Option[Int], private val finalRowArity: Int) extends IncrementalAggregateAllWindowFunction[TimeWindow]( finalRowArity) { @@ -43,7 +45,10 @@ class IncrementalAggregateAllTimeWindowFunction( private var collector: CRowTimeWindowPropertyCollector = _ override def open(parameters: Configuration): Unit = { - collector = new CRowTimeWindowPropertyCollector(windowStartPos, windowEndPos) + collector = new CRowTimeWindowPropertyCollector( + windowStartOffset, + windowEndOffset, + windowRowtimeOffset) super.open(parameters) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala index 809bbfdd5676a..19502302385f4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/IncrementalAggregateTimeWindowFunction.scala @@ -29,15 +29,19 @@ import org.apache.flink.util.Collector /** * Computes the final aggregate value from incrementally computed aggreagtes. * - * @param windowStartPos the start position of window - * @param windowEndPos the end position of window - * @param finalRowArity The arity of the final output row + * @param numGroupingKey the number of grouping keys + * @param numAggregates the number of aggregates + * @param windowStartOffset the offset of the window start property + * @param windowEndOffset the offset of the window end property + * @param windowRowtimeOffset the offset of the window rowtime property + * @param finalRowArity The arity of the final output row. 
*/ class IncrementalAggregateTimeWindowFunction( private val numGroupingKey: Int, private val numAggregates: Int, - private val windowStartPos: Option[Int], - private val windowEndPos: Option[Int], + private val windowStartOffset: Option[Int], + private val windowEndOffset: Option[Int], + private val windowRowtimeOffset: Option[Int], private val finalRowArity: Int) extends IncrementalAggregateWindowFunction[TimeWindow]( numGroupingKey, @@ -47,7 +51,10 @@ class IncrementalAggregateTimeWindowFunction( private var collector: CRowTimeWindowPropertyCollector = _ override def open(parameters: Configuration): Unit = { - collector = new CRowTimeWindowPropertyCollector(windowStartPos, windowEndPos) + collector = new CRowTimeWindowPropertyCollector( + windowStartOffset, + windowEndOffset, + windowRowtimeOffset) super.open(parameters) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala index 8f2ec98a51c03..1d947a029537a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala @@ -22,19 +22,17 @@ import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.apache.flink.api.common.state.ValueState -import org.apache.flink.api.common.state.ValueStateDescriptor -import org.apache.flink.api.common.state.MapState -import org.apache.flink.api.common.state.MapStateDescriptor +import org.apache.flink.api.common.state._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ListTypeInfo import java.util.{ArrayList, List => JList} import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.streaming.api.operators.TimestampedCollector import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} -import org.slf4j.LoggerFactory +import org.apache.flink.table.util.Logging /** * Process Function used for the aggregate in bounded proc-time OVER window @@ -52,13 +50,13 @@ class ProcTimeBoundedRangeOver( inputType: TypeInformation[CRow], queryConfig: StreamQueryConfig) extends ProcessFunctionWithCleanupState[CRow, CRow](queryConfig) - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var output: CRow = _ private var accumulatorState: ValueState[Row] = _ private var rowMapState: MapState[Long, JList[Row]] = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -70,6 +68,8 @@ class ProcTimeBoundedRangeOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.open(getRuntimeContext) + output = new CRow(function.createOutputRow(), true) // We keep the elements received in a MapState indexed based on their ingestion time @@ -120,9 +120,13 @@ class ProcTimeBoundedRangeOver( if (needToCleanupState(timestamp)) { // clean up and return cleanupState(rowMapState, accumulatorState) + function.cleanup() return } + 
// remove timestamp set outside of ProcessFunction. + out.asInstanceOf[TimestampedCollector[_]].eraseTimestamp() + // we consider the original timestamp of events // that have registered this time trigger 1 ms ago @@ -197,4 +201,7 @@ class ProcTimeBoundedRangeOver( accumulatorState.update(accumulators) } + override def close(): Unit = { + function.close() + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala index ccc4b461e291e..ccddaa5b10fc7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala @@ -18,25 +18,19 @@ package org.apache.flink.table.runtime.aggregate import java.util +import java.util.{List => JList} +import org.apache.flink.api.common.state.{MapState, MapStateDescriptor, ValueState, ValueStateDescriptor} +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo} import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction -import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} -import org.apache.flink.api.common.state.ValueStateDescriptor -import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.api.common.state.ValueState -import org.apache.flink.api.common.state.MapState -import org.apache.flink.api.common.state.MapStateDescriptor -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.typeutils.ListTypeInfo -import java.util.{List => JList} - -import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} -import org.slf4j.LoggerFactory +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row +import org.apache.flink.util.{Collector, Preconditions} /** * Process Function for ROW clause processing-time bounded OVER window @@ -53,7 +47,8 @@ class ProcTimeBoundedRowsOver( inputType: TypeInformation[CRow], queryConfig: StreamQueryConfig) extends ProcessFunctionWithCleanupState[CRow, CRow](queryConfig) - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { Preconditions.checkArgument(precedingOffset > 0) @@ -63,7 +58,6 @@ class ProcTimeBoundedRowsOver( private var counterState: ValueState[Long] = _ private var smallestTsState: ValueState[Long] = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -75,6 +69,7 @@ class ProcTimeBoundedRowsOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.open(getRuntimeContext) output = new CRow(function.createOutputRow(), true) // We keep the elements received in a Map state keyed @@ -194,6 +189,11 @@ class ProcTimeBoundedRowsOver( if (needToCleanupState(timestamp)) { cleanupState(rowMapState, accumulatorState, counterState, smallestTsState) + function.cleanup() } } + + override def close(): Unit = { + function.close() 
+ } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeSortProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeSortProcessFunction.scala index 2d0b14b3e3f83..1e12060a6bafd 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeSortProcessFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeSortProcessFunction.scala @@ -23,10 +23,11 @@ import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.types.Row import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import org.apache.flink.util.{Collector, Preconditions} - import java.util.ArrayList import java.util.Collections +import org.apache.flink.streaming.api.operators.TimestampedCollector + /** * ProcessFunction to sort on processing time and additional attributes. @@ -75,7 +76,10 @@ class ProcTimeSortProcessFunction( timestamp: Long, ctx: ProcessFunction[CRow, CRow]#OnTimerContext, out: Collector[CRow]): Unit = { - + + // remove timestamp set outside of ProcessFunction. + out.asInstanceOf[TimestampedCollector[_]].eraseTimestamp() + val iter = bufferedEvents.get.iterator() // insert all rows into the sort buffer diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala index 7a7b44d378b28..6e4c510578682 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala @@ -17,17 +17,16 @@ */ package org.apache.flink.table.runtime.aggregate +import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor} +import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction -import org.apache.flink.types.Row -import org.apache.flink.util.Collector -import org.apache.flink.api.common.state.ValueStateDescriptor -import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.api.common.state.ValueState import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.CRow -import org.slf4j.LoggerFactory +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row +import org.apache.flink.util.Collector /** * Process Function for processing-time unbounded OVER window @@ -40,11 +39,11 @@ class ProcTimeUnboundedOver( aggregationStateType: RowTypeInfo, queryConfig: StreamQueryConfig) extends ProcessFunctionWithCleanupState[CRow, CRow](queryConfig) - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { private var output: CRow = _ private var state: ValueState[Row] = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -56,6 +55,7 @@ class ProcTimeUnboundedOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.open(getRuntimeContext) output = new CRow(function.createOutputRow(), true) val 
stateDescriptor: ValueStateDescriptor[Row] = @@ -97,6 +97,11 @@ class ProcTimeUnboundedOver( if (needToCleanupState(timestamp)) { cleanupState(state) + function.cleanup() } } + + override def close(): Unit = { + function.close() + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala index 1a207bbb3a2b5..85c523ea4b4fc 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala @@ -24,12 +24,13 @@ import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo} import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.operators.TimestampedCollector import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.{Collector, Preconditions} -import org.slf4j.LoggerFactory /** * Process Function for RANGE clause event-time bounded OVER window @@ -44,9 +45,11 @@ class RowTimeBoundedRangeOver( aggregationStateType: RowTypeInfo, inputRowType: CRowTypeInfo, precedingOffset: Long, + rowTimeIdx: Int, queryConfig: StreamQueryConfig) extends ProcessFunctionWithCleanupState[CRow, CRow](queryConfig) - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { Preconditions.checkNotNull(aggregationStateType) Preconditions.checkNotNull(precedingOffset) @@ -64,7 +67,6 @@ class RowTimeBoundedRangeOver( // to this time stamp. private var dataState: MapState[Long, JList[Row]] = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -76,6 +78,7 @@ class RowTimeBoundedRangeOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.open(getRuntimeContext) output = new CRow(function.createOutputRow(), true) @@ -114,7 +117,7 @@ class RowTimeBoundedRangeOver( registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime()) // triggering timestamp for trigger calculation - val triggeringTs = ctx.timestamp + val triggeringTs = input.getField(rowTimeIdx).asInstanceOf[Long] val lastTriggeringTs = lastTriggeringTsState.value @@ -156,6 +159,7 @@ class RowTimeBoundedRangeOver( if (noRecordsToProcess) { // we clean the state cleanupState(dataState, accumulatorState, lastTriggeringTsState) + function.cleanup() } else { // There are records left to process because a watermark has not been received yet. // This would only happen if the input stream has stopped. So we don't need to clean up. @@ -166,6 +170,9 @@ class RowTimeBoundedRangeOver( return } + // remove timestamp set outside of ProcessFunction. 
+ out.asInstanceOf[TimestampedCollector[_]].eraseTimestamp() + // gets all window data from state for the calculation val inputs: JList[Row] = dataState.get(timestamp) @@ -237,6 +244,10 @@ class RowTimeBoundedRangeOver( // update cleanup timer registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime()) } + + override def close(): Unit = { + function.close() + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala index a4b1076c514d6..e120d6b0afda4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala @@ -25,12 +25,13 @@ import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo} import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.operators.TimestampedCollector import org.apache.flink.table.api.StreamQueryConfig -import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} -import org.slf4j.LoggerFactory +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row +import org.apache.flink.util.{Collector, Preconditions} /** * Process Function for ROWS clause event-time bounded OVER window @@ -45,9 +46,11 @@ class RowTimeBoundedRowsOver( aggregationStateType: RowTypeInfo, inputRowType: CRowTypeInfo, precedingOffset: Long, + rowTimeIdx: Int, queryConfig: StreamQueryConfig) extends ProcessFunctionWithCleanupState[CRow, CRow](queryConfig) - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { Preconditions.checkNotNull(aggregationStateType) Preconditions.checkNotNull(precedingOffset) @@ -69,7 +72,6 @@ class RowTimeBoundedRowsOver( // to this time stamp. private var dataState: MapState[Long, JList[Row]] = _ - val LOG = LoggerFactory.getLogger(this.getClass) private var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -81,6 +83,7 @@ class RowTimeBoundedRowsOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.open(getRuntimeContext) output = new CRow(function.createOutputRow(), true) @@ -123,7 +126,7 @@ class RowTimeBoundedRowsOver( registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime()) // triggering timestamp for trigger calculation - val triggeringTs = ctx.timestamp + val triggeringTs = input.getField(rowTimeIdx).asInstanceOf[Long] val lastTriggeringTs = lastTriggeringTsState.value // check if the data is expired, if not, save the data and register event time timer @@ -165,6 +168,7 @@ class RowTimeBoundedRowsOver( if (noRecordsToProcess) { // We clean the state cleanupState(dataState, accumulatorState, dataCountState, lastTriggeringTsState) + function.cleanup() } else { // There are records left to process because a watermark has not been received yet. // This would only happen if the input stream has stopped. So we don't need to clean up. 
@@ -175,6 +179,9 @@ class RowTimeBoundedRowsOver( return } + // remove timestamp set outside of ProcessFunction. + out.asInstanceOf[TimestampedCollector[_]].eraseTimestamp() + // gets all window data from state for the calculation val inputs: JList[Row] = dataState.get(timestamp) @@ -258,6 +265,10 @@ class RowTimeBoundedRowsOver( // update cleanup timer registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime()) } + + override def close(): Unit = { + function.close() + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeSortProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeSortProcessFunction.scala index 737f32c255aed..0d69355108986 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeSortProcessFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeSortProcessFunction.scala @@ -17,29 +17,28 @@ */ package org.apache.flink.table.runtime.aggregate -import org.apache.flink.api.common.state.ValueState -import org.apache.flink.api.common.state.ValueStateDescriptor -import org.apache.flink.api.common.state.MapState -import org.apache.flink.api.common.state.MapStateDescriptor +import java.util.{Collections, ArrayList => JArrayList, List => JList} + +import org.apache.flink.api.common.state.{MapState, MapStateDescriptor, ValueState, ValueStateDescriptor} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.typeutils.ListTypeInfo import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.operators.TimestampedCollector import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import org.apache.flink.types.Row import org.apache.flink.util.{Collector, Preconditions} -import java.util.Collections -import java.util.{List => JList, ArrayList => JArrayList} - /** * ProcessFunction to sort on event-time and possibly addtional secondary sort attributes. * * @param inputRowType The data type of the input data. + * @param rowtimeIdx The index of the rowtime field. * @param rowComparator A comparator to sort rows. */ class RowTimeSortProcessFunction( private val inputRowType: CRowTypeInfo, + private val rowtimeIdx: Int, private val rowComparator: Option[CollectionRowComparator]) extends ProcessFunction[CRow, CRow] { @@ -84,7 +83,7 @@ class RowTimeSortProcessFunction( val input = inputC.row // timestamp of the processed row - val rowtime = ctx.timestamp + val rowtime = input.getField(rowtimeIdx).asInstanceOf[Long] val lastTriggeringTs = lastTriggeringTsState.value @@ -105,13 +104,15 @@ class RowTimeSortProcessFunction( } } } - - + override def onTimer( timestamp: Long, ctx: ProcessFunction[CRow, CRow]#OnTimerContext, out: Collector[CRow]): Unit = { - + + // remove timestamp set outside of ProcessFunction. 
+ out.asInstanceOf[TimestampedCollector[_]].eraseTimestamp() + // gets all rows for the triggering timestamps val inputs: JList[Row] = dataState.get(timestamp) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala index f38ba93794e94..27d307b540b32 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala @@ -20,18 +20,18 @@ package org.apache.flink.table.runtime.aggregate import java.util import java.util.{List => JList} +import org.apache.flink.api.common.state._ import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.ListTypeInfo import org.apache.flink.configuration.Configuration -import org.apache.flink.types.Row import org.apache.flink.streaming.api.functions.ProcessFunction -import org.apache.flink.util.{Collector, Preconditions} -import org.apache.flink.api.common.state._ -import org.apache.flink.api.java.typeutils.ListTypeInfo import org.apache.flink.streaming.api.operators.TimestampedCollector import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} -import org.slf4j.LoggerFactory +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row +import org.apache.flink.util.Collector /** @@ -45,9 +45,11 @@ abstract class RowTimeUnboundedOver( genAggregations: GeneratedAggregationsFunction, intermediateType: TypeInformation[Row], inputType: TypeInformation[CRow], + rowTimeIdx: Int, queryConfig: StreamQueryConfig) extends ProcessFunctionWithCleanupState[CRow, CRow](queryConfig) - with Compiler[GeneratedAggregations] { + with Compiler[GeneratedAggregations] + with Logging { protected var output: CRow = _ // state to hold the accumulators of the aggregations @@ -57,7 +59,6 @@ abstract class RowTimeUnboundedOver( // list to sort timestamps to access rows in timestamp order private var sortedTimestamps: util.LinkedList[Long] = _ - val LOG = LoggerFactory.getLogger(this.getClass) protected var function: GeneratedAggregations = _ override def open(config: Configuration) { @@ -69,6 +70,7 @@ abstract class RowTimeUnboundedOver( genAggregations.code) LOG.debug("Instantiating AggregateHelper.") function = clazz.newInstance() + function.open(getRuntimeContext) output = new CRow(function.createOutputRow(), true) sortedTimestamps = new util.LinkedList[Long]() @@ -108,11 +110,11 @@ abstract class RowTimeUnboundedOver( // register state-cleanup timer registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime()) - val timestamp = ctx.timestamp() + val timestamp = input.getField(rowTimeIdx).asInstanceOf[Long] val curWatermark = ctx.timerService().currentWatermark() // discard late record - if (timestamp >= curWatermark) { + if (timestamp > curWatermark) { // ensure every key just registers one timer ctx.timerService.registerEventTimeTimer(curWatermark + 1) @@ -148,6 +150,7 @@ abstract class RowTimeUnboundedOver( if (noRecordsToProcess) { // we clean the state cleanupState(rowMapState, accumulatorState) + function.cleanup() } else { // There are records left to process because a watermark has not been 
received yet. // This would only happen if the input stream has stopped. So we don't need to clean up. @@ -158,8 +161,8 @@ abstract class RowTimeUnboundedOver( return } - Preconditions.checkArgument(out.isInstanceOf[TimestampedCollector[CRow]]) - val collector = out.asInstanceOf[TimestampedCollector[CRow]] + // remove timestamp set outside of ProcessFunction. + out.asInstanceOf[TimestampedCollector[_]].eraseTimestamp() val keyIterator = rowMapState.keys.iterator if (keyIterator.hasNext) { @@ -188,10 +191,9 @@ abstract class RowTimeUnboundedOver( while (!sortedTimestamps.isEmpty) { val curTimestamp = sortedTimestamps.removeFirst() val curRowList = rowMapState.get(curTimestamp) - collector.setAbsoluteTimestamp(curTimestamp) // process the same timestamp datas, the mechanism is different according ROWS or RANGE - processElementsWithSameTimestamp(curRowList, lastAccumulator, collector) + processElementsWithSameTimestamp(curRowList, lastAccumulator, out) rowMapState.remove(curTimestamp) } @@ -240,6 +242,9 @@ abstract class RowTimeUnboundedOver( lastAccumulator: Row, out: Collector[CRow]): Unit + override def close(): Unit = { + function.close() + } } /** @@ -250,11 +255,13 @@ class RowTimeUnboundedRowsOver( genAggregations: GeneratedAggregationsFunction, intermediateType: TypeInformation[Row], inputType: TypeInformation[CRow], + rowTimeIdx: Int, queryConfig: StreamQueryConfig) extends RowTimeUnboundedOver( genAggregations: GeneratedAggregationsFunction, intermediateType, inputType, + rowTimeIdx, queryConfig) { override def processElementsWithSameTimestamp( @@ -266,7 +273,6 @@ class RowTimeUnboundedRowsOver( while (i < curRowList.size) { val curRow = curRowList.get(i) - var j = 0 // copy forwarded fields to output row function.setForwardedFields(curRow, output.row) @@ -290,11 +296,13 @@ class RowTimeUnboundedRangeOver( genAggregations: GeneratedAggregationsFunction, intermediateType: TypeInformation[Row], inputType: TypeInformation[CRow], + rowTimeIdx: Int, queryConfig: StreamQueryConfig) extends RowTimeUnboundedOver( genAggregations: GeneratedAggregationsFunction, intermediateType, inputType, + rowTimeIdx, queryConfig) { override def processElementsWithSameTimestamp( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/SortUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/SortUtil.scala index 5f83f1d348f97..d62c7b975957e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/SortUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/SortUtil.scala @@ -60,6 +60,8 @@ object SortUtil { inputTypeInfo: TypeInformation[Row], execCfg: ExecutionConfig): ProcessFunction[CRow, CRow] = { + val rowtimeIdx = collationSort.getFieldCollations.get(0).getFieldIndex + val collectionRowComparator = if (collationSort.getFieldCollations.size() > 1) { val rowComp = createRowComparator( @@ -76,6 +78,7 @@ object SortUtil { new RowTimeSortProcessFunction( inputCRowType, + rowtimeIdx, collectionRowComparator) } @@ -139,7 +142,7 @@ object SortUtil { } new RowComparator( - new RowSchema(inputType).physicalArity, + new RowSchema(inputType).arity, sortFields.toArray, fieldComps.toArray, new Array[TypeSerializer[AnyRef]](0), // not required because we only compare objects. 
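The changes above share one pattern: the event-time OVER-window and sort ProcessFunctions now read the row's timestamp from a dedicated rowtime field (the new rowTimeIdx / rowtimeIdx parameter, derived in SortUtil from the first field collation) instead of calling ctx.timestamp(), and onTimer() first erases the timestamp that the enclosing operator set on the collector. The sketch below only illustrates that pattern; RowtimeExampleFunction and its members are hypothetical and not part of this patch, and the collector cast assumes the function is driven by an operator that hands in a TimestampedCollector, as in the operators above.

    import org.apache.flink.streaming.api.functions.ProcessFunction
    import org.apache.flink.streaming.api.operators.TimestampedCollector
    import org.apache.flink.table.runtime.types.CRow
    import org.apache.flink.util.Collector

    // Hypothetical, simplified example of the rowtime-index pattern; not part of this patch.
    class RowtimeExampleFunction(rowTimeIdx: Int) extends ProcessFunction[CRow, CRow] {

      override def processElement(
          inputC: CRow,
          ctx: ProcessFunction[CRow, CRow]#Context,
          out: Collector[CRow]): Unit = {
        // read the event timestamp from the rowtime field instead of ctx.timestamp()
        val rowtime = inputC.row.getField(rowTimeIdx).asInstanceOf[Long]
        // register an event-time timer that fires when the watermark passes this row's timestamp
        ctx.timerService().registerEventTimeTimer(rowtime)
      }

      override def onTimer(
          timestamp: Long,
          ctx: ProcessFunction[CRow, CRow]#OnTimerContext,
          out: Collector[CRow]): Unit = {
        // remove the timestamp set outside of the ProcessFunction before emitting results
        out.asInstanceOf[TimestampedCollector[_]].eraseTimestamp()
        // ... emit the rows buffered for `timestamp` here ...
      }
    }

Reading the rowtime from the row keeps the window logic independent of the StreamRecord timestamp, and eraseTimestamp() prevents emitted rows from carrying the firing timer's timestamp.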
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/TimeWindowPropertyCollector.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/TimeWindowPropertyCollector.scala index 0c8ae007a3f9c..16e4a0b480efd 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/TimeWindowPropertyCollector.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/TimeWindowPropertyCollector.scala @@ -29,7 +29,8 @@ import org.apache.flink.util.Collector */ abstract class TimeWindowPropertyCollector[T]( windowStartOffset: Option[Int], - windowEndOffset: Option[Int]) + windowEndOffset: Option[Int], + windowRowtimeOffset: Option[Int]) extends Collector[T] { var wrappedCollector: Collector[T] = _ @@ -55,20 +56,32 @@ abstract class TimeWindowPropertyCollector[T]( SqlFunctions.internalToTimestamp(windowEnd)) } + if (windowRowtimeOffset.isDefined) { + output.setField( + lastFieldPos + windowRowtimeOffset.get, + windowEnd - 1) + } + wrappedCollector.collect(record) } override def close(): Unit = wrappedCollector.close() } -class RowTimeWindowPropertyCollector(startOffset: Option[Int], endOffset: Option[Int]) - extends TimeWindowPropertyCollector[Row](startOffset, endOffset) { +class RowTimeWindowPropertyCollector( + startOffset: Option[Int], + endOffset: Option[Int], + rowtimeOffset: Option[Int]) + extends TimeWindowPropertyCollector[Row](startOffset, endOffset, rowtimeOffset) { override def getRow(record: Row): Row = record } -class CRowTimeWindowPropertyCollector(startOffset: Option[Int], endOffset: Option[Int]) - extends TimeWindowPropertyCollector[CRow](startOffset, endOffset) { +class CRowTimeWindowPropertyCollector( + startOffset: Option[Int], + endOffset: Option[Int], + rowtimeOffset: Option[Int]) + extends TimeWindowPropertyCollector[CRow](startOffset, endOffset, rowtimeOffset) { override def getRow(record: CRow): Row = record.row } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToJavaTupleMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToJavaTupleMapFunction.scala new file mode 100644 index 0000000000000..6b4f87e54d2a4 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToJavaTupleMapFunction.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.conversion + +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +import _root_.java.lang.{Boolean => JBool} + +/** + * Convert [[CRow]] to a [[JTuple2]] containing a [[Row]]. + */ +class CRowToJavaTupleMapFunction extends MapFunction[CRow, JTuple2[JBool, Row]] { + + val out: JTuple2[JBool, Row] = new JTuple2(true.asInstanceOf[JBool], null.asInstanceOf[Row]) + + override def map(cRow: CRow): JTuple2[JBool, Row] = { + out.f0 = cRow.change + out.f1 = cRow.row + out + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowInputTupleOutputMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToJavaTupleMapRunner.scala similarity index 64% rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowInputTupleOutputMapRunner.scala rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToJavaTupleMapRunner.scala index 7c964371827f9..a9966f8a562f4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowInputTupleOutputMapRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToJavaTupleMapRunner.scala @@ -16,32 +16,31 @@ * limitations under the License. */ -package org.apache.flink.table.runtime +package org.apache.flink.table.runtime.conversion import java.lang.{Boolean => JBool} import org.apache.flink.api.common.functions.{MapFunction, RichMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row -import org.slf4j.LoggerFactory -import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} /** - * Convert [[CRow]] to a [[JTuple2]] + * Convert [[CRow]] to a [[JTuple2]]. 
*/ -class CRowInputJavaTupleOutputMapRunner( +class CRowToJavaTupleMapRunner( name: String, code: String, @transient var returnType: TypeInformation[JTuple2[JBool, Any]]) extends RichMapFunction[CRow, Any] - with ResultTypeQueryable[JTuple2[JBool, Any]] - with Compiler[MapFunction[Row, Any]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with ResultTypeQueryable[JTuple2[JBool, Any]] + with Compiler[MapFunction[Row, Any]] + with Logging { private var function: MapFunction[Row, Any] = _ private var tupleWrapper: JTuple2[JBool, Any] = _ @@ -62,31 +61,3 @@ class CRowInputJavaTupleOutputMapRunner( override def getProducedType: TypeInformation[JTuple2[JBool, Any]] = returnType } - -/** - * Convert [[CRow]] to a [[Tuple2]] - */ -class CRowInputScalaTupleOutputMapRunner( - name: String, - code: String, - @transient var returnType: TypeInformation[(Boolean, Any)]) - extends RichMapFunction[CRow, (Boolean, Any)] - with ResultTypeQueryable[(Boolean, Any)] - with Compiler[MapFunction[Row, Any]] { - - val LOG = LoggerFactory.getLogger(this.getClass) - - private var function: MapFunction[Row, Any] = _ - - override def open(parameters: Configuration): Unit = { - LOG.debug(s"Compiling MapFunction: $name \n\n Code:\n$code") - val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code) - LOG.debug("Instantiating MapFunction.") - function = clazz.newInstance() - } - - override def map(in: CRow): (Boolean, Any) = - (in.change, function.map(in.row)) - - override def getProducedType: TypeInformation[(Boolean, Any)] = returnType -} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToRowMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToRowMapFunction.scala new file mode 100644 index 0000000000000..050f15f9a1170 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToRowMapFunction.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.conversion + +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +/** + * Maps a CRow to a Row. 
+ */ +class CRowToRowMapFunction extends MapFunction[CRow, Row] { + + override def map(value: CRow): Row = value.row + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToScalaTupleMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToScalaTupleMapFunction.scala new file mode 100644 index 0000000000000..6461cc4a5307f --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToScalaTupleMapFunction.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.conversion + +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +/** + * Convert [[CRow]] to a [[Tuple2]]. + */ +class CRowToScalaTupleMapFunction extends MapFunction[CRow, (Boolean, Row)] { + + override def map(cRow: CRow): (Boolean, Row) = { + (cRow.change, cRow.row) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToScalaTupleMapRunner.scala similarity index 69% rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputMapRunner.scala rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToScalaTupleMapRunner.scala index cb8f69556e351..5b122f3176e1d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputMapRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/conversion/CRowToScalaTupleMapRunner.scala @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.table.runtime +package org.apache.flink.table.runtime.conversion import org.apache.flink.api.common.functions.{MapFunction, RichMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation @@ -24,37 +24,32 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row -import org.slf4j.LoggerFactory /** - * MapRunner with [[CRow]] output. + * Convert [[CRow]] to a [[Tuple2]]. 
*/ -class CRowOutputMapRunner( - name: String, - code: String, - @transient var returnType: TypeInformation[CRow]) - extends RichMapFunction[Any, CRow] - with ResultTypeQueryable[CRow] - with Compiler[MapFunction[Any, Row]] { +class CRowToScalaTupleMapRunner( + name: String, + code: String, + @transient var returnType: TypeInformation[(Boolean, Any)]) + extends RichMapFunction[CRow, (Boolean, Any)] + with ResultTypeQueryable[(Boolean, Any)] + with Compiler[MapFunction[Row, Any]] + with Logging { - val LOG = LoggerFactory.getLogger(this.getClass) - - private var function: MapFunction[Any, Row] = _ - private var outCRow: CRow = _ + private var function: MapFunction[Row, Any] = _ override def open(parameters: Configuration): Unit = { LOG.debug(s"Compiling MapFunction: $name \n\n Code:\n$code") val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code) LOG.debug("Instantiating MapFunction.") function = clazz.newInstance() - outCRow = new CRow(null, true) } - override def map(in: Any): CRow = { - outCRow.row = function.map(in) - outCRow - } + override def map(in: CRow): (Boolean, Any) = + (in.change, function.map(in.row)) - override def getProducedType: TypeInformation[CRow] = returnType + override def getProducedType: TypeInformation[(Boolean, Any)] = returnType } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/io/CRowValuesInputFormat.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/io/CRowValuesInputFormat.scala index 1cb3a6e0283ce..fff5924931871 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/io/CRowValuesInputFormat.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/io/CRowValuesInputFormat.scala @@ -24,8 +24,8 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.core.io.GenericInputSplit import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row -import org.slf4j.LoggerFactory class CRowValuesInputFormat( name: String, @@ -34,9 +34,8 @@ class CRowValuesInputFormat( extends GenericInputFormat[CRow] with NonParallelInput with ResultTypeQueryable[CRow] - with Compiler[GenericInputFormat[Row]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with Compiler[GenericInputFormat[Row]] + with Logging { private var format: GenericInputFormat[Row] = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/io/ValuesInputFormat.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/io/ValuesInputFormat.scala index 43ce6056610cd..858146a450a83 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/io/ValuesInputFormat.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/io/ValuesInputFormat.scala @@ -21,10 +21,10 @@ package org.apache.flink.table.runtime.io import org.apache.flink.api.common.io.{GenericInputFormat, NonParallelInput} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable -import org.apache.flink.table.codegen.Compiler import org.apache.flink.core.io.GenericInputSplit +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row -import org.slf4j.LoggerFactory class ValuesInputFormat( name: String, @@ -33,9 +33,8 @@ class 
ValuesInputFormat( extends GenericInputFormat[Row] with NonParallelInput with ResultTypeQueryable[Row] - with Compiler[GenericInputFormat[Row]] { - - val LOG = LoggerFactory.getLogger(this.getClass) + with Compiler[GenericInputFormat[Row]] + with Logging { private var format: GenericInputFormat[Row] = _ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala index e62a18f05816d..824037630f285 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala @@ -30,9 +30,9 @@ import org.apache.flink.streaming.api.functions.co.CoProcessFunction import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.CRowWrappingCollector import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging import org.apache.flink.types.Row import org.apache.flink.util.Collector -import org.slf4j.LoggerFactory /** * A CoProcessFunction to support stream join stream, currently just support inner-join @@ -55,7 +55,8 @@ class ProcTimeWindowInnerJoin( private val genJoinFuncName: String, private val genJoinFuncCode: String) extends CoProcessFunction[CRow, CRow, CRow] - with Compiler[FlatJoinFunction[Row, Row, Row]]{ + with Compiler[FlatJoinFunction[Row, Row, Row]] + with Logging { private var cRowWrapper: CRowWrappingCollector = _ @@ -80,8 +81,6 @@ class ProcTimeWindowInnerJoin( private val leftStreamWinSize: Long = if (leftLowerBound <= 0) -leftLowerBound else -1 private val rightStreamWinSize: Long = if (leftUpperBound >= 0) leftUpperBound else -1 - val LOG = LoggerFactory.getLogger(this.getClass) - override def open(config: Configuration) { LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " + s"Code:\n$genJoinFuncCode") diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala index 379b8d2557226..b5661139d2efe 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala @@ -418,8 +418,8 @@ object WindowJoinUtil { Some(rightType)) val conversion = generator.generateConverterResultExpression( - returnType.physicalTypeInfo, - returnType.physicalType.getFieldNames.asScala) + returnType.typeInfo, + returnType.fieldNames) // if other condition is none, then output the result directly val body = otherCondition match { @@ -429,9 +429,8 @@ object WindowJoinUtil { |${generator.collectorTerm}.collect(${conversion.resultTerm}); |""".stripMargin case Some(remainCondition) => - // map logical field accesses to physical accesses - val physicalCondition = returnType.mapRexNode(remainCondition) - val genCond = generator.generateExpression(physicalCondition) + // generate code for remaining condition + val genCond = generator.generateExpression(remainCondition) s""" |${genCond.code} |if (${genCond.resultTerm}) { @@ -445,7 +444,7 @@ object WindowJoinUtil { ruleDescription, classOf[FlatJoinFunction[Row, Row, Row]], body, - returnType.physicalTypeInfo) + returnType.typeInfo) } } diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelay.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelay.scala new file mode 100644 index 0000000000000..f25de256e96ac --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelay.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.operators + +import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator +import org.apache.flink.streaming.api.watermark.Watermark + +/** + * A [[KeyedCoProcessOperator]] that supports holding back watermarks with a static delay. + */ +class KeyedCoProcessOperatorWithWatermarkDelay[KEY, IN1, IN2, OUT]( + private val flatMapper: CoProcessFunction[IN1, IN2, OUT], + private val watermarkDelay: Long = 0L) + extends KeyedCoProcessOperator[KEY, IN1, IN2, OUT](flatMapper) { + + /** emits watermark without delay */ + def emitWithoutDelay(mark: Watermark): Unit = output.emitWatermark(mark) + + /** emits watermark with delay */ + def emitWithDelay(mark: Watermark): Unit = { + output.emitWatermark(new Watermark(mark.getTimestamp - watermarkDelay)) + } + + if (watermarkDelay < 0) { + throw new IllegalArgumentException("The watermark delay should be non-negative.") + } + + // choose watermark emitter + val emitter: Watermark => Unit = if (watermarkDelay == 0) { + emitWithoutDelay + } else { + emitWithDelay + } + + @throws[Exception] + override def processWatermark(mark: Watermark) { + if (timeServiceManager != null) timeServiceManager.advanceWatermark(mark) + + emitter(mark) + } + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/operators/KeyedProcessOperatorWithWatermarkDelay.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/operators/KeyedProcessOperatorWithWatermarkDelay.scala new file mode 100644 index 0000000000000..74b4773005a65 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/operators/KeyedProcessOperatorWithWatermarkDelay.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators + +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.operators.KeyedProcessOperator +import org.apache.flink.streaming.api.watermark.Watermark + +/** + * A [[KeyedProcessOperator]] that supports holding back watermarks with a static delay. + */ +class KeyedProcessOperatorWithWatermarkDelay[KEY, IN, OUT]( + private val function: ProcessFunction[IN, OUT], + private var watermarkDelay: Long = 0L) + extends KeyedProcessOperator[KEY, IN, OUT](function) { + + /** emits watermark without delay */ + def emitWithoutDelay(mark: Watermark): Unit = output.emitWatermark(mark) + + /** emits watermark with delay */ + def emitWithDelay(mark: Watermark): Unit = { + output.emitWatermark(new Watermark(mark.getTimestamp - watermarkDelay)) + } + + if (watermarkDelay < 0) { + throw new IllegalArgumentException("The watermark delay should be non-negative.") + } + + // choose watermark emitter + val emitter: Watermark => Unit = if (watermarkDelay == 0) { + emitWithoutDelay + } else { + emitWithDelay + } + + @throws[Exception] + override def processWatermark(mark: Watermark) { + if (timeServiceManager != null) timeServiceManager.advanceWatermark(mark) + + emitter(mark) + } + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TimeIndicatorTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TimeIndicatorTypeInfo.scala index 083f1ebfd0ec7..824f3fb3ca4d5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TimeIndicatorTypeInfo.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/typeutils/TimeIndicatorTypeInfo.scala @@ -20,13 +20,14 @@ package org.apache.flink.table.typeutils import java.sql.Timestamp +import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo -import org.apache.flink.api.common.typeutils.TypeComparator -import org.apache.flink.api.common.typeutils.base.{SqlTimestampComparator, SqlTimestampSerializer} +import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer} +import org.apache.flink.api.common.typeutils.base.{LongSerializer, SqlTimestampComparator, SqlTimestampSerializer} /** * Type information for indicating event or processing time. However, it behaves like a - * regular SQL timestamp. + * regular SQL timestamp but is serialized as Long. 
*/ class TimeIndicatorTypeInfo(val isEventTime: Boolean) extends SqlTimeTypeInfo[Timestamp]( @@ -34,12 +35,21 @@ class TimeIndicatorTypeInfo(val isEventTime: Boolean) SqlTimestampSerializer.INSTANCE, classOf[SqlTimestampComparator].asInstanceOf[Class[TypeComparator[Timestamp]]]) { + // this replaces the effective serializer by a LongSerializer + // it is a hacky but efficient solution to keep the object creation overhead low but still + // be compatible with the corresponding SqlTimestampTypeInfo + override def createSerializer(executionConfig: ExecutionConfig): TypeSerializer[Timestamp] = + LongSerializer.INSTANCE.asInstanceOf[TypeSerializer[Timestamp]] + override def toString: String = s"TimeIndicatorTypeInfo(${if (isEventTime) "rowtime" else "proctime" })" } object TimeIndicatorTypeInfo { + val ROWTIME_MARKER: Int = -1 + val PROCTIME_MARKER: Int = -2 + val ROWTIME_INDICATOR = new TimeIndicatorTypeInfo(true) val PROCTIME_INDICATOR = new TimeIndicatorTypeInfo(false) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/Logging.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/Logging.scala new file mode 100644 index 0000000000000..b6be99e6710d9 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/util/Logging.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.util + +import org.slf4j.{Logger, LoggerFactory} + +/** + * Helper class to ensure the logger is never serialized. + */ +trait Logging { + @transient lazy val LOG: Logger = LoggerFactory.getLogger(getClass) +} diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java index 4d06bc2c88eec..14f812aee4f32 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/runtime/utils/JavaUserDefinedAggFunctions.java @@ -18,7 +18,10 @@ package org.apache.flink.table.runtime.utils; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.api.dataview.ListView; +import org.apache.flink.table.api.dataview.MapView; import org.apache.flink.table.functions.AggregateFunction; import java.util.Iterator; @@ -135,4 +138,200 @@ public void retract(WeightedAvgAccum accumulator, int iValue, int iWeight) { accumulator.count -= iWeight; } } + + /** + * CountDistinct accumulator. 
+	 */
+	public static class CountDistinctAccum {
+		public MapView<String, Integer> map;
+		public long count;
+	}
+
+	/**
+	 * CountDistinct aggregate.
+	 */
+	public static class CountDistinct extends AggregateFunction<Long, CountDistinctAccum> {
+
+		@Override
+		public CountDistinctAccum createAccumulator() {
+			CountDistinctAccum accum = new CountDistinctAccum();
+			accum.map = new MapView<>(Types.STRING, Types.INT);
+			accum.count = 0L;
+			return accum;
+		}
+
+		//Overloaded accumulate method
+		public void accumulate(CountDistinctAccum accumulator, String id) {
+			try {
+				Integer cnt = accumulator.map.get(id);
+				if (cnt != null) {
+					cnt += 1;
+					accumulator.map.put(id, cnt);
+				} else {
+					accumulator.map.put(id, 1);
+					accumulator.count += 1;
+				}
+			} catch (Exception e) {
+				e.printStackTrace();
+			}
+		}
+
+		//Overloaded accumulate method
+		public void accumulate(CountDistinctAccum accumulator, long id) {
+			try {
+				Integer cnt = accumulator.map.get(String.valueOf(id));
+				if (cnt != null) {
+					cnt += 1;
+					accumulator.map.put(String.valueOf(id), cnt);
+				} else {
+					accumulator.map.put(String.valueOf(id), 1);
+					accumulator.count += 1;
+				}
+			} catch (Exception e) {
+				e.printStackTrace();
+			}
+		}
+
+		@Override
+		public Long getValue(CountDistinctAccum accumulator) {
+			return accumulator.count;
+		}
+	}
+
+	/**
+	 * CountDistinct aggregate with merge.
+	 */
+	public static class CountDistinctWithMerge extends CountDistinct {
+
+		//Overloaded merge method
+		public void merge(CountDistinctAccum acc, Iterable<CountDistinctAccum> it) {
+			Iterator<CountDistinctAccum> iter = it.iterator();
+			while (iter.hasNext()) {
+				CountDistinctAccum mergeAcc = iter.next();
+				acc.count += mergeAcc.count;
+
+				try {
+					Iterator<String> itrMap = mergeAcc.map.keys().iterator();
+					while (itrMap.hasNext()) {
+						String key = itrMap.next();
+						Integer cnt = mergeAcc.map.get(key);
+						if (acc.map.contains(key)) {
+							acc.map.put(key, acc.map.get(key) + cnt);
+						} else {
+							acc.map.put(key, cnt);
+						}
+					}
+				} catch (Exception e) {
+					e.printStackTrace();
+				}
+			}
+		}
+	}
+
+	/**
+	 * CountDistinct aggregate with merge and reset.
+	 */
+	public static class CountDistinctWithMergeAndReset extends CountDistinctWithMerge {
+
+		//Overloaded retract method
+		public void resetAccumulator(CountDistinctAccum acc) {
+			acc.map.clear();
+			acc.count = 0;
+		}
+	}
+
+	/**
+	 * CountDistinct aggregate with retract.
+	 */
+	public static class CountDistinctWithRetractAndReset extends CountDistinct {
+
+		//Overloaded retract method
+		public void retract(CountDistinctAccum accumulator, long id) {
+			try {
+				Integer cnt = accumulator.map.get(String.valueOf(id));
+				if (cnt != null) {
+					cnt -= 1;
+					if (cnt <= 0) {
+						accumulator.map.remove(String.valueOf(id));
+						accumulator.count -= 1;
+					} else {
+						accumulator.map.put(String.valueOf(id), cnt);
+					}
+				}
+			} catch (Exception e) {
+				e.printStackTrace();
+			}
+		}
+
+		//Overloaded retract method
+		public void resetAccumulator(CountDistinctAccum acc) {
+			acc.map.clear();
+			acc.count = 0;
+		}
+	}
+
+	/**
+	 * Accumulator for test DataView.
+	 */
+	public static class DataViewTestAccum {
+		public MapView<String, Integer> map;
+		public MapView<String, Integer> map2; // for test not initialized
+		public long count;
+		private ListView<Long> list = new ListView<>(Types.LONG);
+
+		public ListView<Long> getList() {
+			return list;
+		}
+
+		public void setList(ListView<Long> list) {
+			this.list = list;
+		}
+	}
+
+	public static boolean isCloseCalled = false;
+
+	/**
+	 * Aggregate for test DataView.
+ */ + public static class DataViewTestAgg extends AggregateFunction { + @Override + public DataViewTestAccum createAccumulator() { + DataViewTestAccum accum = new DataViewTestAccum(); + accum.map = new MapView<>(Types.STRING, Types.INT); + accum.count = 0L; + return accum; + } + + // Overloaded accumulate method + public void accumulate(DataViewTestAccum accumulator, String a, long b) { + try { + if (!accumulator.map.contains(a)) { + accumulator.map.put(a, 1); + accumulator.count++; + } + + accumulator.list.add(b); + } catch (Exception e) { + e.printStackTrace(); + } + } + + @Override + public Long getValue(DataViewTestAccum accumulator) { + long sum = accumulator.count; + try { + for (Long value : accumulator.list.get()) { + sum += value; + } + } catch (Exception e) { + e.printStackTrace(); + } + return sum; + } + + @Override + public void close() { + isCloseCalled = true; + } + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/SortTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/SortTest.scala index a5a1319d0c96e..d20002a9077b4 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/SortTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/SortTest.scala @@ -41,7 +41,7 @@ class SortTest extends TableTestBase { unaryNode("DataStreamSort", streamTableNode(0), term("orderBy", "proctime ASC", "c ASC")), - term("select", "a", "TIME_MATERIALIZATION(proctime) AS proctime", "c")) + term("select", "a", "PROCTIME(proctime) AS proctime", "c")) streamUtil.verifySql(sqlQuery, expected) } @@ -57,7 +57,7 @@ class SortTest extends TableTestBase { unaryNode("DataStreamSort", streamTableNode(0), term("orderBy", "rowtime ASC, c ASC")), - term("select", "a", "TIME_MATERIALIZATION(rowtime) AS rowtime", "c")) + term("select", "a", "rowtime", "c")) streamUtil.verifySql(sqlQuery, expected) } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/TableSourceTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/TableSourceTest.scala index 5d4386ca66202..696706159b4c9 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/TableSourceTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/TableSourceTest.scala @@ -43,7 +43,7 @@ class TableSourceTest extends TableTestBase { unaryNode( "DataStreamCalc", "StreamTableSourceScan(table=[[rowTimeT]], fields=[id, val, name, addTime])", - term("select", "TIME_MATERIALIZATION(addTime) AS addTime", "id", "name", "val") + term("select", "addTime", "id", "name", "val") ) util.verifyTable(t, expected) } @@ -90,7 +90,7 @@ class TableSourceTest extends TableTestBase { unaryNode( "DataStreamCalc", "StreamTableSourceScan(table=[[procTimeT]], fields=[id, val, name, pTime])", - term("select", "TIME_MATERIALIZATION(pTime) AS pTime", "id", "name", "val") + term("select", "PROCTIME(pTime) AS pTime", "id", "name", "val") ) util.verifyTable(t, expected) } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/FlinkTableValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/FlinkTableValidationTest.scala new file mode 100644 index 0000000000000..a845f5c1f1bbf --- /dev/null +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/FlinkTableValidationTest.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.api.validation + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.TableException +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.utils.TableTestBase +import org.junit.Test + +class FlinkTableValidationTest extends TableTestBase { + + @Test + def testFieldNamesDuplicate() { + + thrown.expect(classOf[TableException]) + thrown.expectMessage("Field names must be unique.\n" + + "List of duplicate fields: [a].\n" + + "List of all fields: [a, a, b].") + + val util = batchTestUtil() + util.addTable[(Int, Int, String)]("MyTable", 'a, 'a, 'b) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/TableSchemaValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/TableSchemaValidationTest.scala index 1a7815aa1b257..c430e59efa0c1 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/TableSchemaValidationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/TableSchemaValidationTest.scala @@ -24,12 +24,35 @@ import org.junit.Test class TableSchemaValidationTest extends TableTestBase { - @Test(expected = classOf[TableException]) - def testInvalidSchema() { + @Test + def testColumnNameAndColumnTypeNotEqual() { + thrown.expect(classOf[TableException]) + thrown.expectMessage( + "Number of field names and field types must be equal.\n" + + "Number of names is 3, number of types is 2.\n" + + "List of field names: [a, b, c].\n" + + "List of field types: [Integer, String].") + val fieldNames = Array("a", "b", "c") val typeInfos: Array[TypeInformation[_]] = Array( BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO) new TableSchema(fieldNames, typeInfos) } + + @Test + def testColumnNamesDuplicate() { + thrown.expect(classOf[TableException]) + thrown.expectMessage( + "Field names must be unique.\n" + + "List of duplicate fields: [a].\n" + + "List of all fields: [a, a, c].") + + val fieldNames = Array("a", "a", "c") + val typeInfos: Array[TypeInformation[_]] = Array( + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO) + new TableSchema(fieldNames, typeInfos) + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/UserDefinedFunctionValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/UserDefinedFunctionValidationTest.scala new file mode 100644 index 0000000000000..aeb226eb35a80 --- 
/dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/validation/UserDefinedFunctionValidationTest.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.api.validation + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.expressions.utils.Func0 +import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.OverAgg0 +import org.apache.flink.table.utils.TableTestBase +import org.junit.Test + +class UserDefinedFunctionValidationTest extends TableTestBase { + + @Test + def testScalarFunctionOperandTypeCheck(): Unit = { + thrown.expect(classOf[ValidationException]) + thrown.expectMessage( + "Given parameters of function 'func' do not match any signature. \n" + + "Actual: (java.lang.String) \n" + + "Expected: (int)") + val util = streamTestUtil() + util.addTable[(Int, String)]("t", 'a, 'b) + util.tableEnv.registerFunction("func", Func0) + util.verifySql("select func(b) from t", "n/a") + } + + @Test + def testAggregateFunctionOperandTypeCheck(): Unit = { + thrown.expect(classOf[ValidationException]) + thrown.expectMessage( + "Given parameters of function do not match any signature. \n" + + "Actual: (java.lang.String, java.lang.Integer) \n" + + "Expected: (org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions" + + ".Accumulator0, long, int)") + + val util = streamTestUtil() + val agg = new OverAgg0 + util.addTable[(Int, String)]("t", 'a, 'b) + util.tableEnv.registerFunction("agg", agg) + util.verifySql("select agg(b, a) from t", "n/a") + } + +} + diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/ListViewSerializerTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/ListViewSerializerTest.scala new file mode 100644 index 0000000000000..3f70bcef21e0e --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/ListViewSerializerTest.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.dataview + +import java.lang.Long +import java.util.Random + +import org.apache.flink.api.common.typeutils.base.{ListSerializer, LongSerializer} +import org.apache.flink.api.common.typeutils.{SerializerTestBase, TypeSerializer} +import org.apache.flink.table.api.dataview.ListView + +/** + * A test for the [[ListViewSerializer]]. + */ +class ListViewSerializerTest extends SerializerTestBase[ListView[Long]] { + + override protected def createSerializer(): TypeSerializer[ListView[Long]] = { + val listSerializer = new ListSerializer[Long](LongSerializer.INSTANCE) + new ListViewSerializer[Long](listSerializer) + } + + override protected def getLength: Int = -1 + + override protected def getTypeClass: Class[ListView[Long]] = classOf[ListView[Long]] + + override protected def getTestData: Array[ListView[Long]] = { + val rnd = new Random(321332) + + // empty + val listview1 = new ListView[Long]() + + // single element + val listview2 = new ListView[Long]() + listview2.add(12345L) + + // multiple elements + val listview3 = new ListView[Long]() + var i = 0 + while (i < rnd.nextInt(200)) { + listview3.add(rnd.nextLong) + i += 1 + } + + Array[ListView[Long]](listview1, listview2, listview3) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/MapViewSerializerTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/MapViewSerializerTest.scala new file mode 100644 index 0000000000000..15f9b0240b282 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/dataview/MapViewSerializerTest.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.dataview + +import java.lang.Long +import java.util.Random + +import org.apache.flink.api.common.typeutils.base.{LongSerializer, MapSerializer, StringSerializer} +import org.apache.flink.api.common.typeutils.{SerializerTestBase, TypeSerializer} +import org.apache.flink.table.api.dataview.MapView + +/** + * A test for the [[MapViewSerializer]]. 
+ */ +class MapViewSerializerTest extends SerializerTestBase[MapView[Long, String]] { + + override protected def createSerializer(): TypeSerializer[MapView[Long, String]] = { + val mapSerializer = new MapSerializer[Long, String](LongSerializer.INSTANCE, + StringSerializer.INSTANCE) + new MapViewSerializer[Long, String](mapSerializer) + } + + override protected def getLength: Int = -1 + + override protected def getTypeClass: Class[MapView[Long, String]] = + classOf[MapView[Long, String]] + + override protected def getTestData: Array[MapView[Long, String]] = { + val rnd = new Random(321654) + + // empty + val mapview1 = new MapView[Long, String]() + + // single element + val mapview2 = new MapView[Long, String]() + mapview2.put(12345L, "12345L") + + // multiple elements + val mapview3 = new MapView[Long, String]() + var i = 0 + while (i < rnd.nextInt(200)) { + mapview3.put(rnd.nextLong, Long.toString(rnd.nextLong)) + i += 1 + } + + // null-value maps + val mapview4 = new MapView[Long, String]() + mapview4.put(999L, null) + + Array[MapView[Long, String]](mapview1, mapview2, mapview3, mapview4) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/TemporalTypesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/TemporalTypesTest.scala index 1d761c36b68d8..8fae11a8f4de5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/TemporalTypesTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/TemporalTypesTest.scala @@ -538,10 +538,31 @@ class TemporalTypesTest extends ExpressionTestBase { "1990-09-12 10:20:45.123") } + @Test + def testSelectNullValues(): Unit ={ + testAllApis( + 'f11, + "f11", + "f11", + "null" + ) + testAllApis( + 'f12, + "f12", + "f12", + "null" + ) + testAllApis( + 'f13, + "f13", + "f13", + "null" + ) + } // ---------------------------------------------------------------------------------------------- def testData: Row = { - val testData = new Row(11) + val testData = new Row(14) testData.setField(0, Date.valueOf("1990-10-14")) testData.setField(1, Time.valueOf("10:20:45")) testData.setField(2, Timestamp.valueOf("1990-10-14 10:20:45.123")) @@ -553,6 +574,10 @@ class TemporalTypesTest extends ExpressionTestBase { testData.setField(8, 1467012213000L) testData.setField(9, 24) testData.setField(10, 12000L) + // null selection test. 
+ testData.setField(11, null) + testData.setField(12, null) + testData.setField(13, null) testData } @@ -568,6 +593,9 @@ class TemporalTypesTest extends ExpressionTestBase { Types.INT, Types.LONG, Types.INTERVAL_MONTHS, - Types.INTERVAL_MILLIS).asInstanceOf[TypeInformation[Any]] + Types.INTERVAL_MILLIS, + Types.SQL_DATE, + Types.SQL_TIME, + Types.SQL_TIMESTAMP).asInstanceOf[TypeInformation[Any]] } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala index 9b3407e5c12ef..71ff70d1d04ce 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala @@ -118,6 +118,16 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "-1") } + @Test + def testDoubleQuoteParameters(): Unit = { + val hello = "\"\"" + testAllApis( + Func3(42, hello), + s"Func3(42, '$hello')", + s"Func3(42, '$hello')", + s"42 and $hello") + } + @Test def testResults(): Unit = { testAllApis( diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala index 840be171614b1..c2a01c68ad441 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/RexProgramExtractorTest.scala @@ -20,12 +20,13 @@ package org.apache.flink.table.plan import java.math.BigDecimal -import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder} +import org.apache.calcite.plan.RelOptUtil +import org.apache.calcite.rex._ import org.apache.calcite.sql.SqlPostfixOperator import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, INTEGER, VARCHAR} import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.flink.table.expressions._ -import org.apache.flink.table.plan.util.RexProgramExtractor +import org.apache.flink.table.plan.util.{RexNodeToExpressionConverter, RexProgramExtractor} import org.apache.flink.table.utils.InputTypeBuilder.inputOf import org.apache.flink.table.validate.FunctionCatalog import org.hamcrest.CoreMatchers.is @@ -33,6 +34,7 @@ import org.junit.Assert.{assertArrayEquals, assertEquals, assertThat} import org.junit.Test import scala.collection.JavaConverters._ +import scala.collection.mutable class RexProgramExtractorTest extends RexProgramTestBase { @@ -104,6 +106,8 @@ class RexProgramExtractorTest extends RexProgramTestBase { val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3) // 100 val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) + // 200 + val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(200L)) // a = amount < 100 val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t3)) @@ -113,15 +117,17 @@ class RexProgramExtractorTest extends RexProgramTestBase { val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t3)) // d = amount <= id val d = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)) + // e = price == 200 + val e = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t4)) // a AND b val and = 
builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(a, b).asJava)) - // (a AND b) or c - val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, List(and, c).asJava)) - // not d + // (a AND b) OR c OR e + val or = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.OR, List(and, c, e).asJava)) + // NOT d val not = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.NOT, List(d).asJava)) - // (a AND b) OR c) AND (NOT d) + // (a AND b) OR c OR e) AND (NOT d) builder.addCondition(builder.addExpr( rexBuilder.makeCall(SqlStdOperatorTable.AND, List(or, not).asJava))) @@ -134,13 +140,64 @@ class RexProgramExtractorTest extends RexProgramTestBase { functionCatalog) val expected: Array[Expression] = Array( - ExpressionParser.parseExpression("amount < 100 || price == 100"), - ExpressionParser.parseExpression("id > 100 || price == 100"), + ExpressionParser.parseExpression("amount < 100 || price == 100 || price === 200"), + ExpressionParser.parseExpression("id > 100 || price == 100 || price === 200"), ExpressionParser.parseExpression("!(amount <= id)")) assertExpressionArrayEquals(expected, convertedExpressions) assertEquals(0, unconvertedRexNodes.length) } + @Test + def testExtractANDExpressions(): Unit = { + val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) + val builder = new RexProgramBuilder(inputRowType, rexBuilder) + + // amount + val t0 = rexBuilder.makeInputRef(allFieldTypes.get(2), 2) + // id + val t1 = rexBuilder.makeInputRef(allFieldTypes.get(1), 1) + // price + val t2 = rexBuilder.makeInputRef(allFieldTypes.get(3), 3) + // 100 + val t3 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) + + // a = amount < 100 + val a = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t0, t3)) + // b = id > 100 + val b = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t3)) + // c = price == 100 + val c = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, t2, t3)) + // d = amount <= id + val d = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, t0, t1)) + + // a AND b AND c AND d + val and = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(a, b, c, d).asJava)) + + builder.addCondition(builder.addExpr(and)) + + val program = builder.getProgram + val relBuilder: RexBuilder = new RexBuilder(typeFactory) + + val expanded = program.expandLocalRef(program.getCondition) + + var convertedExpressions = new mutable.ArrayBuffer[Expression] + val unconvertedRexNodes = new mutable.ArrayBuffer[RexNode] + val inputNames = program.getInputRowType.getFieldNames.asScala.toArray + val converter = new RexNodeToExpressionConverter(inputNames, functionCatalog) + + expanded.accept(converter) match { + case Some(expression) => + convertedExpressions += expression + case None => unconvertedRexNodes += expanded + } + + val expected: Array[Expression] = Array( + ExpressionParser.parseExpression("amount < 100 && id > 100 && price === 100 && amount <= id")) + + assertExpressionArrayEquals(expected, convertedExpressions.toArray) + assertEquals(0, unconvertedRexNodes.length) + } + @Test def testExtractArithmeticConditions(): Unit = { val inputRowType = typeFactory.createStructType(allFieldTypes, allFieldNames) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala index b17debea45367..ab80c65c1327a 100644 
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala @@ -48,7 +48,7 @@ class TimeIndicatorConversionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", streamTableNode(0), - term("select", "FLOOR(TIME_MATERIALIZATION(rowtime)", "FLAG(DAY)) AS rowtime"), + term("select", "FLOOR(CAST(rowtime)", "FLAG(DAY)) AS rowtime"), term("where", ">(long, 0)") ) @@ -65,8 +65,8 @@ class TimeIndicatorConversionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", streamTableNode(0), - term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime", "long", "int", - "TIME_MATERIALIZATION(proctime) AS proctime") + term("select", "rowtime", "long", "int", + "PROCTIME(proctime) AS proctime") ) util.verifyTable(result, expected) @@ -84,7 +84,7 @@ class TimeIndicatorConversionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", streamTableNode(0), - term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime"), + term("select", "rowtime"), term("where", ">(rowtime, 1990-12-02 12:11:11)") ) @@ -107,7 +107,7 @@ class TimeIndicatorConversionTest extends TableTestBase { unaryNode( "DataStreamCalc", streamTableNode(0), - term("select", "long", "TIME_MATERIALIZATION(rowtime) AS rowtime") + term("select", "long", "CAST(rowtime) AS rowtime") ), term("groupBy", "rowtime"), term("select", "rowtime", "COUNT(long) AS TMP_0") @@ -134,7 +134,7 @@ class TimeIndicatorConversionTest extends TableTestBase { unaryNode( "DataStreamCalc", streamTableNode(0), - term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime", "long") + term("select", "CAST(rowtime) AS rowtime", "long") ), term("groupBy", "long"), term("select", "long", "MIN(rowtime) AS TMP_0") @@ -159,16 +159,13 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamCorrelate", streamTableNode(0), term("invocation", - s"${func.functionIdentifier}(TIME_MATERIALIZATION($$0), TIME_MATERIALIZATION($$3), '')"), + s"${func.functionIdentifier}(CAST($$0):TIMESTAMP(3) NOT NULL, PROCTIME($$3), '')"), term("function", func), term("rowType", "RecordType(TIME ATTRIBUTE(ROWTIME) rowtime, BIGINT long, INTEGER int, " + "TIME ATTRIBUTE(PROCTIME) proctime, VARCHAR(65536) s)"), term("joinType", "INNER") ), - term("select", - "TIME_MATERIALIZATION(rowtime) AS rowtime", - "TIME_MATERIALIZATION(proctime) AS proctime", - "s") + term("select", "rowtime", "PROCTIME(proctime) AS proctime", "s") ) util.verifyTable(result, expected) @@ -219,7 +216,7 @@ class TimeIndicatorConversionTest extends TableTestBase { streamTableNode(0), term("union all", "rowtime", "long", "int") ), - term("select", "TIME_MATERIALIZATION(rowtime) AS rowtime") + term("select", "rowtime") ) util.verifyTable(result, expected) @@ -287,7 +284,7 @@ class TimeIndicatorConversionTest extends TableTestBase { unaryNode( "DataStreamCalc", streamTableNode(0), - term("select", "TIME_MATERIALIZATION(proctime) AS proctime", "long") + term("select", "PROCTIME(proctime) AS proctime", "long") ), term("groupBy", "proctime"), term("select", "proctime", "COUNT(long) AS EXPR$0") @@ -312,7 +309,7 @@ class TimeIndicatorConversionTest extends TableTestBase { unaryNode( "DataStreamCalc", streamTableNode(0), - term("select", "long", "TIME_MATERIALIZATION(proctime) AS proctime") + term("select", "long", "PROCTIME(proctime) AS proctime") ), term("groupBy", "long"), term("select", "long", "MIN(proctime) AS EXPR$0") @@ 
-368,7 +365,7 @@ class TimeIndicatorConversionTest extends TableTestBase { unaryNode( "DataStreamCalc", streamTableNode(0), - term("select", "long", "rowtime", "TIME_MATERIALIZATION(rowtime) AS $f2") + term("select", "long", "rowtime", "CAST(rowtime) AS rowtime0") ), term("groupBy", "long"), term( @@ -377,7 +374,7 @@ class TimeIndicatorConversionTest extends TableTestBase { 'w$, 'rowtime, 100.millis)), - term("select", "long", "MIN($f2) AS EXPR$0") + term("select", "long", "MIN(rowtime0) AS EXPR$0") ), term("select", "EXPR$0", "long") ) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala index d563f9636ce3f..cf96d19966b6a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala @@ -23,7 +23,7 @@ import java.math.BigDecimal import org.apache.flink.api.scala._ import org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvgWithMergeAndReset +import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinctWithMergeAndReset, WeightedAvgWithMergeAndReset} import org.apache.flink.table.api.scala._ import org.apache.flink.table.functions.aggfunctions.CountAggFunction import org.apache.flink.table.runtime.utils.TableProgramsCollectionTestBase @@ -226,13 +226,14 @@ class AggregationsITCase( val tEnv = TableEnvironment.getTableEnvironment(env, config) val countFun = new CountAggFunction val wAvgFun = new WeightedAvgWithMergeAndReset + val countDistinct = new CountDistinctWithMergeAndReset val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) .groupBy('b) - .select('b, 'a.sum, countFun('c), wAvgFun('b, 'a), wAvgFun('a, 'a)) + .select('b, 'a.sum, countFun('c), wAvgFun('b, 'a), wAvgFun('a, 'a), countDistinct('c)) - val expected = "1,1,1,1,1\n" + "2,5,2,2,2\n" + "3,15,3,3,5\n" + "4,34,4,4,8\n" + - "5,65,5,5,13\n" + "6,111,6,6,18\n" + val expected = "1,1,1,1,1,1\n" + "2,5,2,2,2,2\n" + "3,15,3,3,5,3\n" + "4,34,4,4,8,4\n" + + "5,65,5,5,13,5\n" + "6,111,6,6,18,6\n" val results = t.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala index 80ff55e673cfb..67164b71b54c4 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala @@ -46,12 +46,10 @@ class HarnessTestBase { UserDefinedFunctionUtils.serialize(new IntSumWithRetractAggFunction) protected val MinMaxRowType = new RowTypeInfo(Array[TypeInformation[_]]( - INT_TYPE_INFO, LONG_TYPE_INFO, - INT_TYPE_INFO, STRING_TYPE_INFO, LONG_TYPE_INFO), - Array("a", "b", "c", "d", "e")) + Array("rowtime", "a", "b")) protected val SumRowType = new RowTypeInfo(Array[TypeInformation[_]]( LONG_TYPE_INFO, @@ -103,13 +101,13 @@ class HarnessTestBase { | | org.apache.flink.table.functions.AggregateFunction baseClass0 = | 
(org.apache.flink.table.functions.AggregateFunction) fmin; - | output.setField(5, baseClass0.getValue( + | output.setField(3, baseClass0.getValue( | (org.apache.flink.table.functions.aggfunctions.MinWithRetractAccumulator) | accs.getField(0))); | | org.apache.flink.table.functions.AggregateFunction baseClass1 = | (org.apache.flink.table.functions.AggregateFunction) fmax; - | output.setField(6, baseClass1.getValue( + | output.setField(4, baseClass1.getValue( | (org.apache.flink.table.functions.aggfunctions.MaxWithRetractAccumulator) | accs.getField(1))); | } @@ -121,12 +119,12 @@ class HarnessTestBase { | fmin.accumulate( | ((org.apache.flink.table.functions.aggfunctions.MinWithRetractAccumulator) | accs.getField(0)), - | (java.lang.Long) input.getField(4)); + | (java.lang.Long) input.getField(2)); | | fmax.accumulate( | ((org.apache.flink.table.functions.aggfunctions.MaxWithRetractAccumulator) | accs.getField(1)), - | (java.lang.Long) input.getField(4)); + | (java.lang.Long) input.getField(2)); | } | | public void retract( @@ -136,12 +134,12 @@ class HarnessTestBase { | fmin.retract( | ((org.apache.flink.table.functions.aggfunctions.MinWithRetractAccumulator) | accs.getField(0)), - | (java.lang.Long) input.getField(4)); + | (java.lang.Long) input.getField(2)); | | fmax.retract( | ((org.apache.flink.table.functions.aggfunctions.MaxWithRetractAccumulator) | accs.getField(1)), - | (java.lang.Long) input.getField(4)); + | (java.lang.Long) input.getField(2)); | } | | public org.apache.flink.types.Row createAccumulators() { @@ -166,14 +164,20 @@ class HarnessTestBase { | output.setField(0, input.getField(0)); | output.setField(1, input.getField(1)); | output.setField(2, input.getField(2)); - | output.setField(3, input.getField(3)); - | output.setField(4, input.getField(4)); | } | | public org.apache.flink.types.Row createOutputRow() { - | return new org.apache.flink.types.Row(7); + | return new org.apache.flink.types.Row(5); + | } + | + | public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) { | } | + | public void cleanup() { + | } + | + | public void close() { + | } |/******* This test does not use the following methods *******/ | public org.apache.flink.types.Row mergeAccumulatorsPair( | org.apache.flink.types.Row a, @@ -286,6 +290,15 @@ class HarnessTestBase { | public final void resetAccumulator( | org.apache.flink.types.Row accs) { | } + | + | public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) { + | } + | + | public void cleanup() { + | } + | + | public void close() { + | } |} |""".stripMargin @@ -326,7 +339,7 @@ object HarnessTestBase { /** * Return 0 for equal Rows and non zero for different rows */ - class RowResultSortComparator(indexCounter: Int) extends Comparator[Object] with Serializable { + class RowResultSortComparator() extends Comparator[Object] with Serializable { override def compare(o1: Object, o2: Object): Int = { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala index 6c24c5d9532fa..065b7bcc15e38 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala @@ -154,7 +154,7 @@ class JoinHarnessTest extends HarnessTestBase{ expectedOutput.add(new StreamRecord( CRow(Row.of(2: JInt, "bbb2", 2: JInt, 
"Hello2"), true), 25)) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } @@ -227,7 +227,7 @@ class JoinHarnessTest extends HarnessTestBase{ expectedOutput.add(new StreamRecord( CRow(Row.of(1: JInt, "aaa3", 1: JInt, "bbb12"), true), 12)) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala index 04214f9a178ef..dd14d7edaa561 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala @@ -91,7 +91,7 @@ class NonWindowHarnessTest extends HarnessTestBase { expectedOutput.add(new StreamRecord(CRow(Row.of(9L: JLong, 18: JInt), true), 1)) expectedOutput.add(new StreamRecord(CRow(Row.of(10L: JLong, 3: JInt), true), 1)) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } @@ -150,7 +150,7 @@ class NonWindowHarnessTest extends HarnessTestBase { expectedOutput.add(new StreamRecord(CRow(Row.of(10L: JLong, 2: JInt), false), 10)) expectedOutput.add(new StreamRecord(CRow(Row.of(10L: JLong, 5: JInt), true), 10)) - verify(expectedOutput, result, new RowResultSortComparator(0)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/OverWindowHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/OverWindowHarnessTest.scala index 8cad64f7280ae..def1972866a07 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/OverWindowHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/OverWindowHarnessTest.scala @@ -15,16 +15,17 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.flink.table.runtime.harness -import java.lang.{Integer => JInt, Long => JLong} +import java.lang.{Long => JLong} import java.util.concurrent.ConcurrentLinkedQueue import org.apache.flink.api.common.time.Time import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.streaming.api.operators.KeyedProcessOperator import org.apache.flink.streaming.runtime.streamrecord.StreamRecord -import org.apache.flink.table.api.StreamQueryConfig +import org.apache.flink.table.api.{StreamQueryConfig, Types} import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.runtime.harness.HarnessTestBase._ import org.apache.flink.table.runtime.types.CRow @@ -33,7 +34,7 @@ import org.junit.Test class OverWindowHarnessTest extends HarnessTestBase{ - protected var queryConfig = + protected var queryConfig: StreamQueryConfig = new StreamQueryConfig().withIdleStateRetentionTime(Time.seconds(2), Time.seconds(3)) @Test @@ -48,8 +49,10 @@ class OverWindowHarnessTest extends HarnessTestBase{ queryConfig)) val testHarness = - createHarnessTester(processFunction,new TupleRowKeySelector[Integer](0),BasicTypeInfo - .INT_TYPE_INFO) + createHarnessTester( + processFunction, + new TupleRowKeySelector[String](1), + Types.STRING) testHarness.open() @@ -57,91 +60,77 @@ class OverWindowHarnessTest extends HarnessTestBase{ testHarness.setProcessingTime(1) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 1L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "bbb", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 2L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 3L: JLong), change = true))) // register cleanup timer with 4100 testHarness.setProcessingTime(1100) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "bbb", 20L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 4L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 5L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 6L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "bbb", 30L: JLong), change = true))) // register cleanup timer with 6001 testHarness.setProcessingTime(3001) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 7L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 8L: JLong), change = true))) 
testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 9L: JLong), change = true))) // trigger cleanup timer and register cleanup timer with 9002 testHarness.setProcessingTime(6002) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "bbb", 40L: JLong), change = true))) val result = testHarness.getOutput val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong, 2L: JLong, 3L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 3L: JLong, 2L: JLong, 3L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong, 3L: JLong, 4L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 4L: JLong, 3L: JLong, 4L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong, 4L: JLong, 5L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 5L: JLong, 4L: JLong, 5L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong, 5L: JLong, 6L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "aaa", 6L: JLong, 5L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong, 20L: JLong, 30L: JLong), true), 1)) + CRow(Row.of(1L: JLong, "bbb", 30L: JLong, 20L: JLong, 30L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong, 6L: JLong, 7L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 7L: JLong, 6L: JLong, 7L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong, 7L: JLong, 8L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 8L: JLong, 7L: JLong, 8L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong, 8L: JLong, 9L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 9L: JLong, 8L: JLong, 9L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong, 10L: JLong, 10L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 10L: 
JLong, 10L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong, 40L: JLong, 40L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "bbb", 40L: JLong, 40L: JLong, 40L: JLong), change = true))) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } @@ -163,59 +152,59 @@ class OverWindowHarnessTest extends HarnessTestBase{ val testHarness = createHarnessTester( processFunction, - new TupleRowKeySelector[Integer](0), - BasicTypeInfo.INT_TYPE_INFO) + new TupleRowKeySelector[String](1), + Types.STRING) testHarness.open() // register cleanup timer with 3003 testHarness.setProcessingTime(3) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 1L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 10L: JLong), change = true))) testHarness.setProcessingTime(4) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 2L: JLong), change = true))) // trigger cleanup timer and register cleanup timer with 6003 testHarness.setProcessingTime(3003) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 3L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 20L: JLong), change = true))) testHarness.setProcessingTime(5) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 4L: JLong), change = true))) // register cleanup timer with 9002 testHarness.setProcessingTime(6002) testHarness.setProcessingTime(7002) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 5L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 6L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 30L: JLong), change = true))) // register cleanup timer with 14002 testHarness.setProcessingTime(11002) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 7L: JLong), change = true))) testHarness.setProcessingTime(11004) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 8L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 9L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - 
CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 40L: JLong), change = true))) testHarness.setProcessingTime(11006) @@ -225,49 +214,35 @@ class OverWindowHarnessTest extends HarnessTestBase{ // all elements at the same proc timestamp have the same value per key expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), true), 4)) + CRow(Row.of(0L: JLong, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), true), 4)) + CRow(Row.of(0L: JLong, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), true), 5)) + CRow(Row.of(0L: JLong, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong, 3L: JLong, 4L: JLong), true), 3004)) + CRow(Row.of(0L: JLong, "aaa", 3L: JLong, 3L: JLong, 4L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow(Row.of( - 2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong, 20L: JLong, 20L: JLong), true), 3004)) + CRow(Row.of(0L: JLong, "bbb", 20L: JLong, 20L: JLong, 20L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong, 4L: JLong, 4L: JLong), true), 6)) + CRow(Row.of(0L: JLong, "aaa", 4L: JLong, 4L: JLong, 4L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong, 5L: JLong, 6L: JLong), true), 7003)) + CRow(Row.of(0L: JLong, "aaa", 5L: JLong, 5L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong, 5L: JLong, 6L: JLong), true), 7003)) + CRow(Row.of(0L: JLong, "aaa", 6L: JLong, 5L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong, 30L: JLong, 30L: JLong), true), 7003)) + CRow(Row.of(0L: JLong, "bbb", 30L: JLong, 30L: JLong, 30L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong, 7L: JLong, 7L: JLong), true), 11003)) + CRow(Row.of(0L: JLong, "aaa", 7L: JLong, 7L: JLong, 7L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow(Row.of( - 1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong, 7L: JLong, 10L: JLong), true), 11005)) + CRow(Row.of(0L: JLong, "aaa", 8L: JLong, 7L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow(Row.of( - 1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong, 7L: JLong, 10L: JLong), true), 11005)) + CRow(Row.of(0L: JLong, "aaa", 9L: JLong, 7L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong, 7L: JLong, 10L: JLong), true), 11005)) + CRow(Row.of(0L: JLong, "aaa", 10L: JLong, 7L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong, 40L: JLong, 40L: JLong), true), 11005)) + CRow(Row.of(0L: JLong, "bbb", 40L: JLong, 40L: JLong, 40L: JLong), change = true))) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new 
RowResultSortComparator()) testHarness.close() } @@ -284,8 +259,8 @@ class OverWindowHarnessTest extends HarnessTestBase{ val testHarness = createHarnessTester( processFunction, - new TupleRowKeySelector[Integer](0), - BasicTypeInfo.INT_TYPE_INFO) + new TupleRowKeySelector[String](1), + Types.STRING) testHarness.open() @@ -293,85 +268,71 @@ class OverWindowHarnessTest extends HarnessTestBase{ testHarness.setProcessingTime(1003) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 1L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 2L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 3L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 20L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 4L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 5L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 6L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 30L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 7L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 8L: JLong), change = true))) // trigger cleanup timer and register cleanup timer with 8003 testHarness.setProcessingTime(5003) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong), true), 5003)) + CRow(Row.of(0L: JLong, "aaa", 9L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong), true), 5003)) + CRow(Row.of(0L: JLong, "aaa", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong), true), 5003)) + CRow(Row.of(0L: JLong, "bbb", 40L: JLong), change = true))) val result = testHarness.getOutput val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 10L: 
JLong, 10L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong, 1L: JLong, 4L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 4L: JLong, 1L: JLong, 4L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong, 1L: JLong, 5L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 5L: JLong, 1L: JLong, 5L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong, 1L: JLong, 6L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 6L: JLong, 1L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong, 10L: JLong, 30L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "bbb", 30L: JLong, 10L: JLong, 30L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong, 1L: JLong, 7L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 7L: JLong, 1L: JLong, 7L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong, 1L: JLong, 8L: JLong), true), 0)) + CRow(Row.of(0L: JLong, "aaa", 8L: JLong, 1L: JLong, 8L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong, 9L: JLong, 9L: JLong), true), 5003)) + CRow(Row.of(0L: JLong, "aaa", 9L: JLong, 9L: JLong, 9L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong, 9L: JLong, 10L: JLong), true), 5003)) + CRow(Row.of(0L: JLong, "aaa", 10L: JLong, 9L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong, 40L: JLong, 40L: JLong), true), 5003)) + CRow(Row.of(0L: JLong, "bbb", 40L: JLong, 40L: JLong, 40L: JLong), change = true))) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } @@ -387,63 +348,64 @@ class OverWindowHarnessTest extends HarnessTestBase{ minMaxAggregationStateType, minMaxCRowType, 4000, + 0, new StreamQueryConfig().withIdleStateRetentionTime(Time.seconds(1), Time.seconds(2)))) val testHarness = createHarnessTester( processFunction, - new TupleRowKeySelector[String](3), + new TupleRowKeySelector[String](1), BasicTypeInfo.STRING_TYPE_INFO) testHarness.open() testHarness.processWatermark(1) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 1L: JLong), change = true))) testHarness.processWatermark(2) testHarness.processElement(new StreamRecord( - 
CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong), true), 3)) + CRow(Row.of(3L: JLong, "bbb", 10L: JLong), change = true))) testHarness.processWatermark(4000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 2L: JLong), change = true))) testHarness.processWatermark(4001) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong), true), 4002)) + CRow(Row.of(4002L: JLong, "aaa", 3L: JLong), change = true))) testHarness.processWatermark(4002) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 0L: JLong, 0: JInt, "aaa", 4L: JLong), true), 4003)) + CRow(Row.of(4003L: JLong, "aaa", 4L: JLong), change = true))) testHarness.processWatermark(4800) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 11L: JLong, 1: JInt, "bbb", 25L: JLong), true), 4801)) + CRow(Row.of(4801L: JLong, "bbb", 25L: JLong), change = true))) testHarness.processWatermark(6500) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 5L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 6L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "bbb", 30L: JLong), change = true))) testHarness.processWatermark(7000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong), true), 7001)) + CRow(Row.of(7001L: JLong, "aaa", 7L: JLong), change = true))) testHarness.processWatermark(8000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong), true), 8001)) + CRow(Row.of(8001L: JLong, "aaa", 8L: JLong), change = true))) testHarness.processWatermark(12000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 9L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "bbb", 40L: JLong), change = true))) testHarness.processWatermark(19000) @@ -453,21 +415,22 @@ class OverWindowHarnessTest extends HarnessTestBase{ // check that state is removed after max retention time testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 0L: JLong, 0: JInt, "ccc", 1L: JLong), true), 20001)) // clean-up 3000 + CRow(Row.of(20001L: JLong, "ccc", 1L: JLong), change = true))) // clean-up 3000 testHarness.setProcessingTime(2500) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "ccc", 2L: JLong), true), 20002)) // clean-up 4500 + CRow(Row.of(20002L: JLong, "ccc", 2L: JLong), change = true))) // clean-up 4500 testHarness.processWatermark(20010) // compute output assert(testHarness.numKeyedStateEntries() > 0) // check that we have state testHarness.setProcessingTime(4499) assert(testHarness.numKeyedStateEntries() > 0) // check that we have state 
testHarness.setProcessingTime(4500) + val x = testHarness.numKeyedStateEntries() assert(testHarness.numKeyedStateEntries() == 0) // check that all state is gone // check that state is only removed if all data was processed testHarness.processElement(new StreamRecord( - CRow(Row.of(3: JInt, 0L: JLong, 0: JInt, "ccc", 3L: JLong), true), 20011)) // clean-up 6500 + CRow(Row.of(20011L: JLong, "ccc", 3L: JLong), change = true))) // clean-up 6500 assert(testHarness.numKeyedStateEntries() > 0) // check that we have state testHarness.setProcessingTime(6500) // clean-up attempt but rescheduled to 8500 @@ -487,59 +450,42 @@ class OverWindowHarnessTest extends HarnessTestBase{ // all elements at the same row-time have the same value per key expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), true), 2)) + CRow(Row.of(2L: JLong, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), true), 3)) + CRow(Row.of(3L: JLong, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), true), 4002)) + CRow(Row.of(4002L: JLong, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 0L: JLong, 0: JInt, "aaa", 4L: JLong, 2L: JLong, 4L: JLong), true), 4003)) + CRow(Row.of(4003L: JLong, "aaa", 4L: JLong, 2L: JLong, 4L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 11L: JLong, 1: JInt, "bbb", 25L: JLong, 25L: JLong, 25L: JLong), true), 4801)) + CRow(Row.of(4801L: JLong, "bbb", 25L: JLong, 25L: JLong, 25L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong, 2L: JLong, 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 5L: JLong, 2L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong, 2L: JLong, 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 6L: JLong, 2L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong, 2L: JLong, 7L: JLong), true), 7001)) + CRow(Row.of(7001L: JLong, "aaa", 7L: JLong, 2L: JLong, 7L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong, 2L: JLong, 8L: JLong), true), 8001)) + CRow(Row.of(8001L: JLong, "aaa", 8L: JLong, 2L: JLong, 8L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong, 25L: JLong, 30L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "bbb", 30L: JLong, 25L: JLong, 30L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong, 8L: JLong, 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 9L: JLong, 8L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong, 8L: 
JLong, 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 10L: JLong, 8L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong, 40L: JLong, 40L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "bbb", 40L: JLong, 40L: JLong, 40L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 0L: JLong, 0: JInt, "ccc", 1L: JLong, 1L: JLong, 1L: JLong), true), 20001)) + CRow(Row.of(20001L: JLong, "ccc", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "ccc", 2L: JLong, 1L: JLong, 2L: JLong), true), 20002)) + CRow(Row.of(20002L: JLong, "ccc", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(3: JInt, 0L: JLong, 0: JInt, "ccc", 3L: JLong, 3L: JLong, 3L: JLong), true), 20011)) + CRow(Row.of(20011L: JLong, "ccc", 3L: JLong, 3L: JLong, 3L: JLong), change = true))) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } @@ -552,59 +498,60 @@ class OverWindowHarnessTest extends HarnessTestBase{ minMaxAggregationStateType, minMaxCRowType, 3, + 0, new StreamQueryConfig().withIdleStateRetentionTime(Time.seconds(1), Time.seconds(2)))) val testHarness = createHarnessTester( processFunction, - new TupleRowKeySelector[String](3), + new TupleRowKeySelector[String](1), BasicTypeInfo.STRING_TYPE_INFO) testHarness.open() testHarness.processWatermark(800) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong), true), 801)) + CRow(Row.of(801L: JLong, "aaa", 1L: JLong), change = true))) testHarness.processWatermark(2500) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong), true), 2501)) + CRow(Row.of(2501L: JLong, "bbb", 10L: JLong), change = true))) testHarness.processWatermark(4000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 2L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 3L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "bbb", 20L: JLong), change = true))) testHarness.processWatermark(4800) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong), true), 4801)) + CRow(Row.of(4801L: JLong, "aaa", 4L: JLong), change = true))) testHarness.processWatermark(6500) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 5L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 6L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "bbb", 30L: JLong), change = true))) testHarness.processWatermark(7000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, 
"aaa", 7L: JLong), true), 7001)) + CRow(Row.of(7001L: JLong, "aaa", 7L: JLong), change = true))) testHarness.processWatermark(8000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong), true), 8001)) + CRow(Row.of(8001L: JLong, "aaa", 8L: JLong), change = true))) testHarness.processWatermark(12000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 9L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "bbb", 40L: JLong), change = true))) testHarness.processWatermark(19000) @@ -614,10 +561,10 @@ class OverWindowHarnessTest extends HarnessTestBase{ // check that state is removed after max retention time testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 0L: JLong, 0: JInt, "ccc", 1L: JLong), true), 20001)) // clean-up 3000 + CRow(Row.of(20001L: JLong, "ccc", 1L: JLong), change = true))) // clean-up 3000 testHarness.setProcessingTime(2500) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "ccc", 2L: JLong), true), 20002)) // clean-up 4500 + CRow(Row.of(20002L: JLong, "ccc", 2L: JLong), change = true))) // clean-up 4500 testHarness.processWatermark(20010) // compute output assert(testHarness.numKeyedStateEntries() > 0) // check that we have state @@ -628,7 +575,7 @@ class OverWindowHarnessTest extends HarnessTestBase{ // check that state is only removed if all data was processed testHarness.processElement(new StreamRecord( - CRow(Row.of(3: JInt, 0L: JLong, 0: JInt, "ccc", 3L: JLong), true), 20011)) // clean-up 6500 + CRow(Row.of(20011L: JLong, "ccc", 3L: JLong), change = true))) // clean-up 6500 assert(testHarness.numKeyedStateEntries() > 0) // check that we have state testHarness.setProcessingTime(6500) // clean-up attempt but rescheduled to 8500 @@ -648,59 +595,42 @@ class OverWindowHarnessTest extends HarnessTestBase{ val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), true), 801)) + CRow(Row.of(801L: JLong, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), true), 2501)) + CRow(Row.of(2501L: JLong, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, 
"aaa", 4L: JLong, 2L: JLong, 4L: JLong), true), 4801)) + CRow(Row.of(4801L: JLong, "aaa", 4L: JLong, 2L: JLong, 4L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong, 3L: JLong, 5L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 5L: JLong, 3L: JLong, 5L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong, 4L: JLong, 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 6L: JLong, 4L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong, 10L: JLong, 30L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "bbb", 30L: JLong, 10L: JLong, 30L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong, 5L: JLong, 7L: JLong), true), 7001)) + CRow(Row.of(7001L: JLong, "aaa", 7L: JLong, 5L: JLong, 7L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong, 6L: JLong, 8L: JLong), true), 8001)) + CRow(Row.of(8001L: JLong, "aaa", 8L: JLong, 6L: JLong, 8L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong, 7L: JLong, 9L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 9L: JLong, 7L: JLong, 9L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong, 8L: JLong, 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 10L: JLong, 8L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong, 20L: JLong, 40L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "bbb", 40L: JLong, 20L: JLong, 40L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 0L: JLong, 0: JInt, "ccc", 1L: JLong, 1L: JLong, 1L: JLong), true), 20001)) + CRow(Row.of(20001L: JLong, "ccc", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "ccc", 2L: JLong, 1L: JLong, 2L: JLong), true), 20002)) + CRow(Row.of(20002L: JLong, "ccc", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(3: JInt, 0L: JLong, 0: JInt, "ccc", 3L: JLong, 3L: JLong, 3L: JLong), true), 20011)) + CRow(Row.of(20011L: JLong, "ccc", 3L: JLong, 3L: JLong, 3L: JLong), change = true))) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } @@ -715,12 +645,13 @@ class OverWindowHarnessTest extends HarnessTestBase{ genMinMaxAggFunction, minMaxAggregationStateType, minMaxCRowType, + 0, new StreamQueryConfig().withIdleStateRetentionTime(Time.seconds(1), Time.seconds(2)))) val testHarness = createHarnessTester( processFunction, - new TupleRowKeySelector[String](3), + new TupleRowKeySelector[String](1), BasicTypeInfo.STRING_TYPE_INFO) testHarness.open() @@ -728,47 +659,47 @@ class OverWindowHarnessTest extends HarnessTestBase{ testHarness.setProcessingTime(1000) testHarness.processWatermark(800) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong), true), 801)) + CRow(Row.of(801L: JLong, "aaa", 1L: JLong), change = true))) 
testHarness.processWatermark(2500) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong), true), 2501)) + CRow(Row.of(2501L: JLong, "bbb", 10L: JLong), change = true))) testHarness.processWatermark(4000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 2L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 3L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "bbb", 20L: JLong), change = true))) testHarness.processWatermark(4800) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong), true), 4801)) + CRow(Row.of(4801L: JLong, "aaa", 4L: JLong), change = true))) testHarness.processWatermark(6500) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 5L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 6L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "bbb", 30L: JLong), change = true))) testHarness.processWatermark(7000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong), true), 7001)) + CRow(Row.of(7001L: JLong, "aaa", 7L: JLong), change = true))) testHarness.processWatermark(8000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong), true), 8001)) + CRow(Row.of(8001L: JLong, "aaa", 8L: JLong), change = true))) testHarness.processWatermark(12000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 9L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "bbb", 40L: JLong), change = true))) testHarness.processWatermark(19000) @@ -781,10 +712,13 @@ class OverWindowHarnessTest extends HarnessTestBase{ testHarness.processWatermark(20000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 0L: JLong, 0: JInt, "ccc", 1L: JLong), true), 20001)) // clean-up 5000 + CRow(Row.of(20000L: JLong, "ccc", 1L: JLong), change = true))) // test for late data + + testHarness.processElement(new StreamRecord( + CRow(Row.of(20001L: JLong, "ccc", 1L: JLong), change = true))) // clean-up 5000 testHarness.setProcessingTime(2500) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "ccc", 2L: JLong), true), 20002)) // clean-up 5000 + CRow(Row.of(20002L: JLong, "ccc", 2L: JLong), change = true))) // clean-up 5000 assert(testHarness.numKeyedStateEntries() > 0) testHarness.setProcessingTime(5000) // does not clean up, because data 
left. New timer 7000 @@ -802,56 +736,40 @@ class OverWindowHarnessTest extends HarnessTestBase{ // all elements at the same row-time have the same value per key expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), true), 801)) + CRow(Row.of(801L: JLong, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), true), 2501)) + CRow(Row.of(2501L: JLong, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong, 1L: JLong, 3L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 2L: JLong, 1L: JLong, 3L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong, 1L: JLong, 4L: JLong), true), 4801)) + CRow(Row.of(4801L: JLong, "aaa", 4L: JLong, 1L: JLong, 4L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong, 1L: JLong, 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 5L: JLong, 1L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong, 1L: JLong, 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 6L: JLong, 1L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong, 10L: JLong, 30L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "bbb", 30L: JLong, 10L: JLong, 30L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong, 1L: JLong, 7L: JLong), true), 7001)) + CRow(Row.of(7001L: JLong, "aaa", 7L: JLong, 1L: JLong, 7L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong, 1L: JLong, 8L: JLong), true), 8001)) + CRow(Row.of(8001L: JLong, "aaa", 8L: JLong, 1L: JLong, 8L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong, 1L: JLong, 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 9L: JLong, 1L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong, 1L: JLong, 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 10L: JLong, 1L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong, 10L: JLong, 40L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "bbb", 40L: JLong, 10L: JLong, 40L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 0L: JLong, 0: JInt, "ccc", 1L: JLong, 1L: JLong, 1L: JLong), true), 20001)) + CRow(Row.of(20001L: JLong, "ccc", 1L: JLong, 1L: JLong, 1L: JLong), 
change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "ccc", 2L: JLong, 1L: JLong, 2L: JLong), true), 20002)) + CRow(Row.of(20002L: JLong, "ccc", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } @@ -863,12 +781,13 @@ class OverWindowHarnessTest extends HarnessTestBase{ genMinMaxAggFunction, minMaxAggregationStateType, minMaxCRowType, + 0, new StreamQueryConfig().withIdleStateRetentionTime(Time.seconds(1), Time.seconds(2)))) val testHarness = createHarnessTester( processFunction, - new TupleRowKeySelector[String](3), + new TupleRowKeySelector[String](1), BasicTypeInfo.STRING_TYPE_INFO) testHarness.open() @@ -876,47 +795,47 @@ class OverWindowHarnessTest extends HarnessTestBase{ testHarness.setProcessingTime(1000) testHarness.processWatermark(800) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong), true), 801)) + CRow(Row.of(801L: JLong, "aaa", 1L: JLong), change = true))) testHarness.processWatermark(2500) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong), true), 2501)) + CRow(Row.of(2501L: JLong, "bbb", 10L: JLong), change = true))) testHarness.processWatermark(4000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 2L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 3L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "bbb", 20L: JLong), change = true))) testHarness.processWatermark(4800) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong), true), 4801)) + CRow(Row.of(4801L: JLong, "aaa", 4L: JLong), change = true))) testHarness.processWatermark(6500) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 5L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 6L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "bbb", 30L: JLong), change = true))) testHarness.processWatermark(7000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong), true), 7001)) + CRow(Row.of(7001L: JLong, "aaa", 7L: JLong), change = true))) testHarness.processWatermark(8000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong), true), 8001)) + CRow(Row.of(8001L: JLong, "aaa", 8L: JLong), change = true))) testHarness.processWatermark(12000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 9L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong), true), 
12001)) + CRow(Row.of(12001L: JLong, "aaa", 10L: JLong), change = true))) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "bbb", 40L: JLong), change = true))) testHarness.processWatermark(19000) @@ -929,10 +848,13 @@ class OverWindowHarnessTest extends HarnessTestBase{ testHarness.processWatermark(20000) testHarness.processElement(new StreamRecord( - CRow(Row.of(1: JInt, 0L: JLong, 0: JInt, "ccc", 1L: JLong), true), 20001)) // clean-up 5000 + CRow(Row.of(20000L: JLong, "ccc", 2L: JLong), change = true))) // test for late data + + testHarness.processElement(new StreamRecord( + CRow(Row.of(20001L: JLong, "ccc", 1L: JLong), change = true))) // clean-up 5000 testHarness.setProcessingTime(2500) testHarness.processElement(new StreamRecord( - CRow(Row.of(2: JInt, 0L: JLong, 0: JInt, "ccc", 2L: JLong), true), 20002)) // clean-up 5000 + CRow(Row.of(20002L: JLong, "ccc", 2L: JLong), change = true))) // clean-up 5000 assert(testHarness.numKeyedStateEntries() > 0) testHarness.setProcessingTime(5000) // does not clean up, because data left. New timer 7000 @@ -949,56 +871,40 @@ class OverWindowHarnessTest extends HarnessTestBase{ val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), true), 801)) + CRow(Row.of(801L: JLong, "aaa", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), true), 2501)) + CRow(Row.of(2501L: JLong, "bbb", 10L: JLong, 10L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "aaa", 3L: JLong, 1L: JLong, 3L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), true), 4001)) + CRow(Row.of(4001L: JLong, "bbb", 20L: JLong, 10L: JLong, 20L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 4L: JLong, 1L: JLong, 4L: JLong), true), 4801)) + CRow(Row.of(4801L: JLong, "aaa", 4L: JLong, 1L: JLong, 4L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 5L: JLong, 1L: JLong, 5L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 5L: JLong, 1L: JLong, 5L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 6L: JLong, 1L: JLong, 6L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "aaa", 6L: JLong, 1L: JLong, 6L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 30L: JLong, 10L: JLong, 30L: JLong), true), 6501)) + CRow(Row.of(6501L: JLong, "bbb", 30L: JLong, 10L: JLong, 30L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 7L: JLong, 1L: JLong, 7L: JLong), true), 7001)) + CRow(Row.of(7001L: JLong, "aaa", 7L: JLong, 1L: JLong, 7L: JLong), change = 
true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 8L: JLong, 1L: JLong, 8L: JLong), true), 8001)) + CRow(Row.of(8001L: JLong, "aaa", 8L: JLong, 1L: JLong, 8L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 9L: JLong, 1L: JLong, 9L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 9L: JLong, 1L: JLong, 9L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 10L: JLong, 1L: JLong, 10L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "aaa", 10L: JLong, 1L: JLong, 10L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "bbb", 40L: JLong, 10L: JLong, 40L: JLong), true), 12001)) + CRow(Row.of(12001L: JLong, "bbb", 40L: JLong, 10L: JLong, 40L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(1: JInt, 0L: JLong, 0: JInt, "ccc", 1L: JLong, 1L: JLong, 1L: JLong), true), 20001)) + CRow(Row.of(20001L: JLong, "ccc", 1L: JLong, 1L: JLong, 1L: JLong), change = true))) expectedOutput.add(new StreamRecord( - CRow( - Row.of(2: JInt, 0L: JLong, 0: JInt, "ccc", 2L: JLong, 1L: JLong, 2L: JLong), true), 20002)) + CRow(Row.of(20002L: JLong, "ccc", 2L: JLong, 1L: JLong, 2L: JLong), change = true))) - verify(expectedOutput, result, new RowResultSortComparator(6)) + verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/SortProcessFunctionHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/SortProcessFunctionHarnessTest.scala index 0451534d2b36a..9490039137822 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/SortProcessFunctionHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/SortProcessFunctionHarnessTest.scala @@ -35,6 +35,7 @@ import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, import org.apache.flink.table.runtime.aggregate.{CollectionRowComparator, ProcTimeSortProcessFunction, RowTimeSortProcessFunction} import org.apache.flink.table.runtime.harness.SortProcessFunctionHarnessTest.TupleRowSelector import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo import org.apache.flink.types.Row import org.junit.Test @@ -75,7 +76,7 @@ class SortProcessFunctionHarnessTest { inputCRowType, collectionRowComparator)) - val testHarness = new KeyedOneInputStreamOperatorTestHarness[Integer,CRow,CRow]( + val testHarness = new KeyedOneInputStreamOperatorTestHarness[Integer, CRow, CRow]( processFunction, new TupleRowSelector(0), BasicTypeInfo.INT_TYPE_INFO) @@ -86,77 +87,77 @@ class SortProcessFunctionHarnessTest { // timestamp is ignored in processing time testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 11L: JLong), true), 1001)) + Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 11L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 1: JInt, "aaa", 11L: JLong), true), 2002)) + Row.of(1: JInt, 12L: JLong, 1: JInt, "aaa", 11L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 2: JInt, "aaa", 11L: JLong), true), 2003)) + Row.of(1: JInt, 
12L: JLong, 2: JInt, "aaa", 11L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2004)) + Row.of(1: JInt, 12L: JLong, 0: JInt, "aaa", 11L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2006)) + Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 11L: JLong), true))) //move the timestamp to ensure the execution testHarness.setProcessingTime(1005) - + testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 1L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2007)) + Row.of(1: JInt, 1L: JLong, 0: JInt, "aaa", 11L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 3L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2007)) + Row.of(1: JInt, 3L: JLong, 0: JInt, "aaa", 11L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 2L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2007)) - + Row.of(1: JInt, 2L: JLong, 0: JInt, "aaa", 11L: JLong), true))) + testHarness.setProcessingTime(1008) - + val result = testHarness.getOutput - + val expectedOutput = new ConcurrentLinkedQueue[Object]() - + // all elements at the same proc timestamp have the same value // elements should be sorted ascending on field 1 and descending on field 2 // (10,0) (11,1) (12,2) (12,1) (12,0) // (1,0) (2,0) - + expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 11L: JLong),true), 4)) + Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 11L: JLong),true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 11L: JLong),true), 4)) + Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 11L: JLong),true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 2: JInt, "aaa", 11L: JLong),true), 4)) + Row.of(1: JInt, 12L: JLong, 2: JInt, "aaa", 11L: JLong),true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 1: JInt, "aaa", 11L: JLong),true), 4)) + Row.of(1: JInt, 12L: JLong, 1: JInt, "aaa", 11L: JLong),true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 0: JInt, "aaa", 11L: JLong),true), 4)) - + Row.of(1: JInt, 12L: JLong, 0: JInt, "aaa", 11L: JLong),true))) + expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 1L: JLong, 0: JInt, "aaa", 11L: JLong),true), 1006)) + Row.of(1: JInt, 1L: JLong, 0: JInt, "aaa", 11L: JLong),true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 2L: JLong, 0: JInt, "aaa", 11L: JLong),true), 1006)) + Row.of(1: JInt, 2L: JLong, 0: JInt, "aaa", 11L: JLong),true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 3L: JLong, 0: JInt, "aaa", 11L: JLong),true), 1006)) + Row.of(1: JInt, 3L: JLong, 0: JInt, "aaa", 11L: JLong),true))) TestHarnessUtil.assertOutputEquals("Output was not correctly sorted.", expectedOutput, result) - + testHarness.close() } - + @Test def testSortRowTimeHarnessPartitioned(): Unit = { - + val rT = new RowTypeInfo(Array[TypeInformation[_]]( INT_TYPE_INFO, LONG_TYPE_INFO, INT_TYPE_INFO, STRING_TYPE_INFO, - LONG_TYPE_INFO), + TimeIndicatorTypeInfo.ROWTIME_INDICATOR), Array("a", "b", "c", "d", "e")) val indexes = Array(1, 2) - + val fieldComps = Array[TypeComparator[AnyRef]]( LONG_TYPE_INFO.createComparator(true, null).asInstanceOf[TypeComparator[AnyRef]], INT_TYPE_INFO.createComparator(false, null).asInstanceOf[TypeComparator[AnyRef]] ) - val booleanOrders = 
Array(true, false) + val booleanOrders = Array(true, false) val rowComp = new RowComparator( rT.getTotalFields, @@ -164,21 +165,22 @@ class SortProcessFunctionHarnessTest { fieldComps, new Array[TypeSerializer[AnyRef]](0), //used only for serialized comparisons booleanOrders) - + val collectionRowComparator = new CollectionRowComparator(rowComp) - + val inputCRowType = CRowTypeInfo(rT) - + val processFunction = new KeyedProcessOperator[Integer,CRow,CRow]( new RowTimeSortProcessFunction( inputCRowType, + 4, Some(collectionRowComparator))) - + val testHarness = new KeyedOneInputStreamOperatorTestHarness[Integer, CRow, CRow]( - processFunction, - new TupleRowSelector(0), + processFunction, + new TupleRowSelector(0), BasicTypeInfo.INT_TYPE_INFO) - + testHarness.open() testHarness.setTimeCharacteristic(TimeCharacteristic.EventTime) @@ -186,71 +188,71 @@ class SortProcessFunctionHarnessTest { // timestamp is ignored in processing time testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 11L: JLong), true), 1001)) + Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1001L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 1: JInt, "aaa", 11L: JLong), true), 2002)) + Row.of(1: JInt, 12L: JLong, 1: JInt, "aaa", 2002L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 13L: JLong, 2: JInt, "aaa", 11L: JLong), true), 2002)) + Row.of(1: JInt, 13L: JLong, 2: JInt, "aaa", 2002L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 11L: JLong), true), 2002)) + Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 2002L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 14L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2002)) + Row.of(1: JInt, 14L: JLong, 0: JInt, "aaa", 2002L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 11L: JLong), true), 2004)) + Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 2004L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2006)) + Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 2006L: JLong), true))) // move watermark forward testHarness.processWatermark(2007) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 20L: JLong, 1: JInt, "aaa", 11L: JLong), true), 2008)) + Row.of(1: JInt, 20L: JLong, 1: JInt, "aaa", 2008L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 14L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2002)) // too late + Row.of(1: JInt, 14L: JLong, 0: JInt, "aaa", 2002L: JLong), true))) // too late testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 11L: JLong), true), 2019)) // too early + Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 2019L: JLong), true))) // too early testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 20L: JLong, 2: JInt, "aaa", 11L: JLong), true), 2008)) + Row.of(1: JInt, 20L: JLong, 2: JInt, "aaa", 2008L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2010)) + Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 2010L: JLong), true))) testHarness.processElement(new StreamRecord(new CRow( - Row.of(1: JInt, 19L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2008)) + Row.of(1: JInt, 
19L: JLong, 0: JInt, "aaa", 2008L: JLong), true))) // move watermark forward testHarness.processWatermark(2012) val result = testHarness.getOutput - + val expectedOutput = new ConcurrentLinkedQueue[Object]() - + // all elements at the same proc timestamp have the same value // elements should be sorted ascending on field 1 and descending on field 2 // (10,0) (11,1) (12,2) (12,1) (12,0) expectedOutput.add(new Watermark(3)) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 11L: JLong),true), 1001)) + Row.of(1: JInt, 11L: JLong, 1: JInt, "aaa", 1001L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 11L: JLong),true), 2002)) + Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 2002L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 1: JInt, "aaa", 11L: JLong),true), 2002)) + Row.of(1: JInt, 12L: JLong, 1: JInt, "aaa", 2002L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 13L: JLong, 2: JInt, "aaa", 11L: JLong),true), 2002)) + Row.of(1: JInt, 13L: JLong, 2: JInt, "aaa", 2002L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 14L: JLong, 0: JInt, "aaa", 11L: JLong),true), 2002)) + Row.of(1: JInt, 14L: JLong, 0: JInt, "aaa", 2002L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 11L: JLong),true), 2004)) + Row.of(1: JInt, 12L: JLong, 3: JInt, "aaa", 2004L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 11L: JLong),true), 2006)) + Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 2006L: JLong), true))) expectedOutput.add(new Watermark(2007)) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 19L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2008)) + Row.of(1: JInt, 19L: JLong, 0: JInt, "aaa", 2008L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 20L: JLong, 2: JInt, "aaa", 11L: JLong), true), 2008)) + Row.of(1: JInt, 20L: JLong, 2: JInt, "aaa", 2008L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 20L: JLong, 1: JInt, "aaa", 11L: JLong), true), 2008)) + Row.of(1: JInt, 20L: JLong, 1: JInt, "aaa", 2008L: JLong), true))) expectedOutput.add(new StreamRecord(new CRow( - Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 11L: JLong), true), 2010)) + Row.of(1: JInt, 10L: JLong, 0: JInt, "aaa", 2010L: JLong), true))) expectedOutput.add(new Watermark(2012)) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelayTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelayTest.scala new file mode 100644 index 0000000000000..243a034268af0 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelayTest.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.operators + +import java.util.concurrent.ConcurrentLinkedQueue + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.api.java.functions.KeySelector +import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.streaming.util.{KeyedTwoInputStreamOperatorTestHarness, TestHarnessUtil} +import org.apache.flink.util.{Collector, TestLogger} +import org.junit.Test + +/** + * Tests [[KeyedCoProcessOperatorWithWatermarkDelay]]. + */ +class KeyedCoProcessOperatorWithWatermarkDelayTest extends TestLogger { + + @Test + def testHoldingBackWatermarks(): Unit = { + val operator = new KeyedCoProcessOperatorWithWatermarkDelay[String, Integer, String, String]( + new EmptyCoProcessFunction, 100) + val testHarness = new KeyedTwoInputStreamOperatorTestHarness[String, Integer, String, String]( + operator, + new IntToStringKeySelector, new CoIdentityKeySelector[String], + BasicTypeInfo.STRING_TYPE_INFO) + + testHarness.setup() + testHarness.open() + testHarness.processWatermark1(new Watermark(101)) + testHarness.processWatermark2(new Watermark(202)) + testHarness.processWatermark1(new Watermark(103)) + testHarness.processWatermark2(new Watermark(204)) + + val expectedOutput = new ConcurrentLinkedQueue[AnyRef] + expectedOutput.add(new Watermark(1)) + expectedOutput.add(new Watermark(3)) + + TestHarnessUtil.assertOutputEquals( + "Output was not correct.", + expectedOutput, + testHarness.getOutput) + + testHarness.close() + } + + @Test(expected = classOf[IllegalArgumentException]) + def testDelayParameter(): Unit = { + new KeyedCoProcessOperatorWithWatermarkDelay[AnyRef, Integer, String, String]( + new EmptyCoProcessFunction, -1) + } +} + +private class EmptyCoProcessFunction extends CoProcessFunction[Integer, String, String] { + override def processElement1( + value: Integer, + ctx: CoProcessFunction[Integer, String, String]#Context, + out: Collector[String]): Unit = { + // do nothing + } + + override def processElement2( + value: String, + ctx: CoProcessFunction[Integer, String, String]#Context, + out: Collector[String]): Unit = { + //do nothing + } +} + + +private class IntToStringKeySelector extends KeySelector[Integer, String] { + override def getKey(value: Integer): String = String.valueOf(value) +} + +private class CoIdentityKeySelector[T] extends KeySelector[T, T] { + override def getKey(value: T): T = value +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedProcessOperatorWithWatermarkDelayTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedProcessOperatorWithWatermarkDelayTest.scala new file mode 100644 index 0000000000000..d419453726570 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedProcessOperatorWithWatermarkDelayTest.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators + +import java.util.concurrent.ConcurrentLinkedQueue + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.api.java.functions.KeySelector +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil} +import org.apache.flink.util.{Collector, TestLogger} +import org.junit.Test + +/** + * Tests [[KeyedProcessOperatorWithWatermarkDelay]]. + */ +class KeyedProcessOperatorWithWatermarkDelayTest extends TestLogger { + + @Test + def testHoldingBackWatermarks(): Unit = { + val operator = new KeyedProcessOperatorWithWatermarkDelay[Integer, Integer, String]( + new EmptyProcessFunction, 100) + val testHarness = new KeyedOneInputStreamOperatorTestHarness[Integer, Integer, String]( + operator, new IdentityKeySelector, BasicTypeInfo.INT_TYPE_INFO) + + testHarness.setup() + testHarness.open() + testHarness.processWatermark(new Watermark(101)) + testHarness.processWatermark(new Watermark(103)) + + val expectedOutput = new ConcurrentLinkedQueue[AnyRef] + expectedOutput.add(new Watermark(1)) + expectedOutput.add(new Watermark(3)) + + TestHarnessUtil.assertOutputEquals( + "Output was not correct.", + expectedOutput, + testHarness.getOutput) + + testHarness.close() + } + + @Test(expected = classOf[IllegalArgumentException]) + def testDelayParameter(): Unit = { + new KeyedProcessOperatorWithWatermarkDelay[Integer, Integer, String]( + new EmptyProcessFunction, -1) + } +} + +private class EmptyProcessFunction extends ProcessFunction[Integer, String] { + override def processElement( + value: Integer, + ctx: ProcessFunction[Integer, String]#Context, + out: Collector[String]): Unit = { + // do nothing + } +} + +private class IdentityKeySelector[T] extends KeySelector[T, T] { + override def getKey(value: T): T = value +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/TimeAttributesITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/TimeAttributesITCase.scala index 4c478de1596e8..24d8695819c22 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/TimeAttributesITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/TimeAttributesITCase.scala @@ -63,7 +63,8 @@ class TimeAttributesITCase extends StreamingMultipleProgramsTestBase { val stream = env .fromCollection(data) .assignTimestampsAndWatermarks(new TimestampWithEqualWatermark()) - val table = stream.toTable(tEnv, 'rowtime.rowtime, 'int, 'double, 'float, 'bigdec, 'string) + val table = stream.toTable( + tEnv, 'rowtime.rowtime, 'int, 'double, 'float, 'bigdec, 
'string, 'proctime.proctime) val t = table.select('rowtime.cast(Types.STRING)) @@ -123,6 +124,13 @@ class TimeAttributesITCase extends StreamingMultipleProgramsTestBase { tEnv, 'rowtime.rowtime, 'int, 'double, 'float, 'bigdec, 'string, 'proctime.proctime) val func = new TableFunc + // we test if this can be executed with any exceptions + table.join(func('proctime, 'proctime, 'string) as 's).toAppendStream[Row] + + // we test if this can be executed with any exceptions + table.join(func('rowtime, 'rowtime, 'string) as 's).toAppendStream[Row] + + // we can only test rowtime, not proctime val t = table.join(func('rowtime, 'proctime, 'string) as 's).select('rowtime, 's) val results = t.toAppendStream[Row] diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala index 744ac4602daba..eb3d37fb0375f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala @@ -24,7 +24,8 @@ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.table.api.scala._ import org.apache.flink.table.runtime.utils.StreamITCase.RetractingSink import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment} -import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase} +import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg} +import org.apache.flink.table.runtime.utils.{JavaUserDefinedAggFunctions, StreamITCase, StreamTestData, StreamingWithStateTestBase} import org.apache.flink.types.Row import org.junit.Assert.assertEquals import org.junit.Test @@ -154,4 +155,42 @@ class AggregateITCase extends StreamingWithStateTestBase { "12,3,5,1", "5,3,4,2") assertEquals(expected.sorted, StreamITCase.retractedResults.sorted) } + + @Test + def testGroupAggregateWithStateBackend(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "A")) + data.+=((2, 2L, "B")) + data.+=((3, 2L, "B")) + data.+=((4, 3L, "C")) + data.+=((5, 3L, "C")) + data.+=((6, 3L, "C")) + data.+=((7, 4L, "B")) + data.+=((8, 4L, "A")) + data.+=((9, 4L, "D")) + data.+=((10, 4L, "E")) + data.+=((11, 5L, "A")) + data.+=((12, 5L, "B")) + + val distinct = new CountDistinct + val testAgg = new DataViewTestAgg + val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('b, distinct('c), testAgg('c, 'b)) + + val results = t.toRetractStream[Row](queryConfig) + results.addSink(new StreamITCase.RetractingSink) + env.execute() + + val expected = List("1,1,2", "2,1,5", "3,1,10", "4,4,20", "5,2,12") + assertEquals(expected.sorted, StreamITCase.retractedResults.sorted) + + // verify agg close is called + assert(JavaUserDefinedAggFunctions.isCloseCalled) + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala index 1561da0417edb..f6e739efc8958 100644 --- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/GroupWindowITCase.scala @@ -29,7 +29,7 @@ import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase -import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge} +import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, CountDistinctWithMerge, WeightedAvg, WeightedAvgWithMerge} import org.apache.flink.table.functions.aggfunctions.CountAggFunction import org.apache.flink.table.runtime.stream.table.GroupWindowITCase._ import org.apache.flink.table.runtime.utils.StreamITCase @@ -75,19 +75,21 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase { val countFun = new CountAggFunction val weightAvgFun = new WeightedAvg + val countDistinct = new CountDistinct val windowedTable = table .window(Slide over 2.rows every 1.rows on 'proctime as 'w) .groupBy('w, 'string) .select('string, countFun('int), 'int.avg, - weightAvgFun('long, 'int), weightAvgFun('int, 'int)) + weightAvgFun('long, 'int), weightAvgFun('int, 'int), + countDistinct('long)) val results = windowedTable.toAppendStream[Row](queryConfig) results.addSink(new StreamITCase.StringSink[Row]) env.execute() - val expected = Seq("Hello world,1,3,8,3", "Hello world,2,3,12,3", "Hello,1,2,2,2", - "Hello,2,2,3,2", "Hi,1,1,1,1") + val expected = Seq("Hello world,1,3,8,3,1", "Hello world,2,3,12,3,2", "Hello,1,2,2,2,1", + "Hello,2,2,3,2,2", "Hi,1,1,1,1,1") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -112,6 +114,7 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase { val countFun = new CountAggFunction val weightAvgFun = new WeightedAvgWithMerge + val countDistinct = new CountDistinctWithMerge val stream = env .fromCollection(sessionWindowTestdata) @@ -122,13 +125,14 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase { .window(Session withGap 5.milli on 'rowtime as 'w) .groupBy('w, 'string) .select('string, countFun('int), 'int.avg, - weightAvgFun('long, 'int), weightAvgFun('int, 'int)) + weightAvgFun('long, 'int), weightAvgFun('int, 'int), + countDistinct('long)) val results = windowedTable.toAppendStream[Row] results.addSink(new StreamITCase.StringSink[Row]) env.execute() - val expected = Seq("Hello World,1,9,9,9", "Hello,1,16,16,16", "Hello,4,3,5,5") + val expected = Seq("Hello World,1,9,9,9,1", "Hello,1,16,16,16,1", "Hello,4,3,5,5,4") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -143,18 +147,21 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase { val table = stream.toTable(tEnv, 'long, 'int, 'string, 'proctime.proctime) val countFun = new CountAggFunction val weightAvgFun = new WeightedAvg + val countDistinct = new CountDistinct val windowedTable = table .window(Tumble over 2.rows on 'proctime as 'w) .groupBy('w) .select(countFun('string), 'int.avg, - weightAvgFun('long, 'int), weightAvgFun('int, 'int)) + weightAvgFun('long, 'int), weightAvgFun('int, 'int), + countDistinct('long) + ) val results = windowedTable.toAppendStream[Row](queryConfig) results.addSink(new StreamITCase.StringSink[Row]) env.execute() - val expected = 
Seq("2,1,1,1", "2,2,6,2") + val expected = Seq("2,1,1,1,2", "2,2,6,2,2") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -171,22 +178,24 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase { val table = stream.toTable(tEnv, 'long, 'int, 'string, 'rowtime.rowtime) val countFun = new CountAggFunction val weightAvgFun = new WeightedAvg + val countDistinct = new CountDistinct val windowedTable = table .window(Tumble over 5.milli on 'rowtime as 'w) .groupBy('w, 'string) .select('string, countFun('string), 'int.avg, weightAvgFun('long, 'int), - weightAvgFun('int, 'int), 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end) + weightAvgFun('int, 'int), 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end, + countDistinct('long)) val results = windowedTable.toAppendStream[Row] results.addSink(new StreamITCase.StringSink[Row]) env.execute() val expected = Seq( - "Hello world,1,3,8,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", - "Hello world,1,3,16,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02", - "Hello,2,2,3,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", - "Hi,1,1,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") + "Hello world,1,3,8,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01,1", + "Hello world,1,3,16,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 00:00:00.02,1", + "Hello,2,2,3,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005,2", + "Hi,1,1,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005,1") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -208,17 +217,18 @@ class GroupWindowITCase extends StreamingMultipleProgramsTestBase { val table = stream.toTable(tEnv, 'long, 'int, 'string, 'int2, 'int3, 'proctime.proctime) val weightAvgFun = new WeightedAvg + val countDistinct = new CountDistinct val windowedTable = table .window(Slide over 2.rows every 1.rows on 'proctime as 'w) .groupBy('w, 'int2, 'int3, 'string) - .select(weightAvgFun('long, 'int)) + .select(weightAvgFun('long, 'int), countDistinct('long)) val results = windowedTable.toAppendStream[Row] results.addSink(new StreamITCase.StringSink[Row]) env.execute() - val expected = Seq("12", "8", "2", "3", "1") + val expected = Seq("12,2", "8,1", "2,1", "3,2", "1,1") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala index 73484d2551db8..54971b2bc81fd 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/OverWindowITCase.scala @@ -25,7 +25,7 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceCont import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvg +import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, CountDistinctWithRetractAndReset, WeightedAvg} import org.apache.flink.table.runtime.utils.JavaUserDefinedScalarFunctions.JavaFunc0 import org.apache.flink.table.api.scala._ import org.apache.flink.table.functions.aggfunctions.CountAggFunction @@ -51,6 +51,7 @@ class OverWindowITCase 
extends StreamingWithStateTestBase { (6L, 6, "Hello"), (7L, 7, "Hello World"), (8L, 8, "Hello World"), + (8L, 8, "Hello World"), (20L, 20, "Hello World")) val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -62,20 +63,24 @@ class OverWindowITCase extends StreamingWithStateTestBase { val table = stream.toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime) val countFun = new CountAggFunction val weightAvgFun = new WeightedAvg + val countDist = new CountDistinct val windowedTable = table .window( Over partitionBy 'c orderBy 'proctime preceding UNBOUNDED_ROW as 'w) - .select('c, countFun('b) over 'w as 'mycount, weightAvgFun('a, 'b) over 'w as 'wAvg) - .select('c, 'mycount, 'wAvg) + .select('c, + countFun('b) over 'w as 'mycount, + weightAvgFun('a, 'b) over 'w as 'wAvg, + countDist('a) over 'w as 'countDist) + .select('c, 'mycount, 'wAvg, 'countDist) val results = windowedTable.toAppendStream[Row] results.addSink(new StreamITCase.StringSink[Row]) env.execute() val expected = Seq( - "Hello World,1,7", "Hello World,2,7", "Hello World,3,14", - "Hello,1,1", "Hello,2,1", "Hello,3,2", "Hello,4,3", "Hello,5,3", "Hello,6,4") + "Hello World,1,7,1", "Hello World,2,7,2", "Hello World,3,7,2", "Hello World,4,13,3", + "Hello,1,1,1", "Hello,2,1,2", "Hello,3,2,3", "Hello,4,3,4", "Hello,5,3,5", "Hello,6,4,6") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -112,6 +117,7 @@ class OverWindowITCase extends StreamingWithStateTestBase { val countFun = new CountAggFunction val weightAvgFun = new WeightedAvg val plusOne = new JavaFunc0 + val countDist = new CountDistinct val windowedTable = table .window(Over partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_RANGE following @@ -128,26 +134,27 @@ class OverWindowITCase extends StreamingWithStateTestBase { 'b.max over 'w, 'b.min over 'w, ('b.min over 'w).abs(), - weightAvgFun('b, 'a) over 'w) + weightAvgFun('b, 'a) over 'w, + countDist('c) over 'w as 'countDist) val result = windowedTable.toAppendStream[Row] result.addSink(new StreamITCase.StringSink[Row]) env.execute() val expected = mutable.MutableList( - "1,1,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", - "1,2,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", - "1,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", - "1,1,Hi,7,SUM:7,4,5,5,[1, 3],1,3,1,1,1", - "2,1,Hello,1,SUM:1,1,2,2,[1, 1],1,1,1,1,1", - "2,2,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", - "2,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", - "1,4,Hello world,11,SUM:11,5,6,6,[2, 4],2,4,1,1,2", - "1,5,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3", - "1,6,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3", - "1,7,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3", - "2,4,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3", - "2,5,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3" + "1,1,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2", + "1,2,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2", + "1,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2", + "1,1,Hi,7,SUM:7,4,5,5,[1, 3],1,3,1,1,1,3", + "2,1,Hello,1,SUM:1,1,2,2,[1, 1],1,1,1,1,1,1", + "2,2,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2", + "2,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2,2", + "1,4,Hello world,11,SUM:11,5,6,6,[2, 4],2,4,1,1,2,3", + "1,5,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3,3", + "1,6,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3,3", + "1,7,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3,3", + "2,4,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3,2", + "2,5,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3,2" ) assertEquals(expected.sorted, StreamITCase.testResults.sorted) @@ -179,32 +186,33 @@ class OverWindowITCase extends 
StreamingWithStateTestBase { env.setParallelism(1) StreamITCase.testResults = mutable.MutableList() + val countDist = new CountDistinctWithRetractAndReset val stream = env.fromCollection(data) val table = stream.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime) val windowedTable = table .window(Over partitionBy 'a orderBy 'proctime preceding 4.rows following CURRENT_ROW as 'w) - .select('a, 'c.sum over 'w, 'c.min over 'w) + .select('a, 'c.sum over 'w, 'c.min over 'w, countDist('e) over 'w) val result = windowedTable.toAppendStream[Row] result.addSink(new StreamITCase.StringSink[Row]) env.execute() val expected = mutable.MutableList( - "1,0,0", - "2,1,1", - "2,3,1", - "3,3,3", - "3,7,3", - "3,12,3", - "4,6,6", - "4,13,6", - "4,21,6", - "4,30,6", - "5,10,10", - "5,21,10", - "5,33,10", - "5,46,10", - "5,60,10") + "1,0,0,1", + "2,1,1,1", + "2,3,1,2", + "3,3,3,1", + "3,7,3,1", + "3,12,3,2", + "4,6,6,1", + "4,13,6,2", + "4,21,6,2", + "4,30,6,2", + "5,10,10,1", + "5,21,10,2", + "5,33,10,2", + "5,46,10,3", + "5,60,10,3") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -240,25 +248,27 @@ class OverWindowITCase extends StreamingWithStateTestBase { val tEnv = TableEnvironment.getTableEnvironment(env) StreamITCase.clear + val countDist = new CountDistinctWithRetractAndReset val table = env.addSource[(Long, Int, String)]( new RowTimeSourceFunction[(Long, Int, String)](data)) .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) val windowedTable = table .window(Over partitionBy 'c orderBy 'rowtime preceding 2.rows following CURRENT_ROW as 'w) - .select('c, 'a, 'a.count over 'w, 'a.sum over 'w) + .select('c, 'a, 'a.count over 'w, 'a.sum over 'w, countDist('a) over 'w) val result = windowedTable.toAppendStream[Row] result.addSink(new StreamITCase.StringSink[Row]) env.execute() val expected = mutable.MutableList( - "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3", - "Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6", - "Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12", - "Hello,6,3,15", - "Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21", - "Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35") + "Hello,1,1,1,1", "Hello,1,2,2,1", "Hello,1,3,3,1", + "Hello,2,3,4,2", "Hello,2,3,5,2", "Hello,2,3,6,1", + "Hello,3,3,7,2", "Hello,4,3,9,3", "Hello,5,3,12,3", + "Hello,6,3,15,3", + "Hello World,7,1,7,1", "Hello World,7,2,14,1", "Hello World,7,3,21,1", + "Hello World,7,3,21,1", "Hello World,8,3,22,2", "Hello World,20,3,35,3") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -302,6 +312,7 @@ class OverWindowITCase extends StreamingWithStateTestBase { val tEnv = TableEnvironment.getTableEnvironment(env) StreamITCase.clear + val countDist = new CountDistinctWithRetractAndReset val table = env.addSource[(Long, Int, String)]( new RowTimeSourceFunction[(Long, Int, String)](data)) .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) @@ -309,23 +320,24 @@ class OverWindowITCase extends StreamingWithStateTestBase { val windowedTable = table .window( Over partitionBy 'c orderBy 'rowtime preceding 1.seconds following CURRENT_RANGE as 'w) - .select('c, 'b, 'a.count over 'w, 'a.sum over 'w) + .select('c, 'b, 'a.count over 'w, 'a.sum over 'w, countDist('a) over 'w) val result = windowedTable.toAppendStream[Row] result.addSink(new StreamITCase.StringSink[Row]) env.execute() val expected = mutable.MutableList( - "Hello,1,1,1", "Hello,15,2,2", "Hello,16,3,3", - "Hello,2,6,9", "Hello,3,6,9", "Hello,2,6,9", - "Hello,3,4,9", - "Hello,4,2,7", - "Hello,5,2,9", - "Hello,6,2,11", "Hello,65,2,12", - 
"Hello,9,2,12", "Hello,9,2,12", "Hello,18,3,18", - "Hello World,7,1,7", "Hello World,17,3,21", "Hello World,77,3,21", "Hello World,18,1,7", - "Hello World,8,2,15", - "Hello World,20,1,20") + "Hello,1,1,1,1", "Hello,15,2,2,1", "Hello,16,3,3,1", + "Hello,2,6,9,2", "Hello,3,6,9,2", "Hello,2,6,9,2", + "Hello,3,4,9,2", + "Hello,4,2,7,2", + "Hello,5,2,9,2", + "Hello,6,2,11,2", "Hello,65,2,12,1", + "Hello,9,2,12,1", "Hello,9,2,12,1", "Hello,18,3,18,1", + "Hello World,7,1,7,1", "Hello World,17,3,21,1", + "Hello World,77,3,21,1", "Hello World,18,1,7,1", + "Hello World,8,2,15,2", + "Hello World,20,1,20,1") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala index 4121754336169..830359fbea17a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSinkITCase.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.runtime.stream.table import java.io.File import java.lang.{Boolean => JBool} +import java.sql.Timestamp import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.api.common.typeinfo.TypeInformation @@ -28,19 +29,22 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.streaming.api.functions.sink.SinkFunction import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.runtime.utils.StreamTestData +import org.apache.flink.table.api.{TableEnvironment, TableException, Types} +import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData} import org.apache.flink.table.sinks._ import org.apache.flink.test.util.TestBaseUtils import org.apache.flink.types.Row +import org.apache.flink.util.Collector import org.junit.Assert._ import org.junit.Test import scala.collection.mutable +import scala.collection.JavaConverters._ class TableSinkITCase extends StreamingMultipleProgramsTestBase { @@ -199,8 +203,6 @@ class TableSinkITCase extends StreamingMultipleProgramsTestBase { } - - @Test def testUpsertSinkOnAppendingTableWithFullKey1(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -349,6 +351,136 @@ class TableSinkITCase extends StreamingMultipleProgramsTestBase { assertEquals(expected, retracted) } + @Test + def testToAppendStreamRowtime(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env) + .assignAscendingTimestamps(_._1.toLong) + .toTable(tEnv, 'id, 'num, 'text, 'rowtime.rowtime) + + val r = t + .window(Tumble over 5.milli on 'rowtime as 'w) + .groupBy('num, 'w) + .select('num, 'w.rowtime, 'w.rowtime.cast(Types.LONG)) + + r.toAppendStream[Row] + .process(new ProcessFunction[Row, 
Row] { + override def processElement( + row: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + val rowTS: Long = row.getField(2).asInstanceOf[Long] + if (ctx.timestamp() == rowTS) { + out.collect(row) + } + } + }).addSink(new StreamITCase.StringSink[Row]) + + env.execute() + + val expected = List( + "1,1970-01-01 00:00:00.004,4", + "2,1970-01-01 00:00:00.004,4", + "3,1970-01-01 00:00:00.004,4", + "3,1970-01-01 00:00:00.009,9", + "4,1970-01-01 00:00:00.009,9", + "4,1970-01-01 00:00:00.014,14", + "5,1970-01-01 00:00:00.014,14", + "5,1970-01-01 00:00:00.019,19", + "6,1970-01-01 00:00:00.019,19", + "6,1970-01-01 00:00:00.024,24") + + assertEquals(expected, StreamITCase.testResults.sorted) + } + + @Test + def testToRetractStreamRowtime(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env) + .assignAscendingTimestamps(_._1.toLong) + .toTable(tEnv, 'id, 'num, 'text, 'rowtime.rowtime) + + val r = t + .window(Tumble over 5.milli on 'rowtime as 'w) + .groupBy('num, 'w) + .select('num, 'w.rowtime, 'w.rowtime.cast(Types.LONG)) + + r.toRetractStream[Row] + .process(new ProcessFunction[(Boolean, Row), Row] { + override def processElement( + row: (Boolean, Row), + ctx: ProcessFunction[(Boolean, Row), Row]#Context, + out: Collector[Row]): Unit = { + + val rowTs = row._2.getField(2).asInstanceOf[Long] + if (ctx.timestamp() == rowTs) { + out.collect(row._2) + } + } + }).addSink(new StreamITCase.StringSink[Row]) + + env.execute() + + val expected = List( + "1,1970-01-01 00:00:00.004,4", + "2,1970-01-01 00:00:00.004,4", + "3,1970-01-01 00:00:00.004,4", + "3,1970-01-01 00:00:00.009,9", + "4,1970-01-01 00:00:00.009,9", + "4,1970-01-01 00:00:00.014,14", + "5,1970-01-01 00:00:00.014,14", + "5,1970-01-01 00:00:00.019,19", + "6,1970-01-01 00:00:00.019,19", + "6,1970-01-01 00:00:00.024,24") + + assertEquals(expected, StreamITCase.testResults.sorted) + } + + @Test(expected = classOf[TableException]) + def testToAppendStreamMultiRowtime(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + + val t = StreamTestData.get3TupleDataStream(env) + .assignAscendingTimestamps(_._1.toLong) + .toTable(tEnv, 'id, 'num, 'text, 'rowtime.rowtime) + + val r = t + .window(Tumble over 5.milli on 'rowtime as 'w) + .groupBy('num, 'w) + .select('num, 'w.rowtime, 'w.rowtime as 'rowtime2) + + r.toAppendStream[Row] + } + + @Test(expected = classOf[TableException]) + def testToRetractStreamMultiRowtime(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + + val t = StreamTestData.get3TupleDataStream(env) + .assignAscendingTimestamps(_._1.toLong) + .toTable(tEnv, 'id, 'num, 'text, 'rowtime.rowtime) + + val r = t + .window(Tumble over 5.milli on 'rowtime as 'w) + .groupBy('num, 'w) + .select('num, 'w.rowtime, 'w.rowtime as 'rowtime2) + + r.toRetractStream[Row] + } + /** Converts a list of retraction messages into a list of final results. 
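Note on the aggregates used in the test changes above: the GroupWindowITCase and OverWindowITCase hunks exercise CountDistinct, CountDistinctWithMerge and CountDistinctWithRetractAndReset from JavaUserDefinedAggFunctions, a test utility class that is not part of this patch. As a rough, assumed sketch (invented names, not the actual utility), a set-backed distinct-count aggregate for the Table API can look like this:

import java.util.HashSet;
import java.util.Set;

import org.apache.flink.table.functions.AggregateFunction;

/** Minimal distinct-count aggregate: counts the distinct non-null input values. */
public class CountDistinctSketch extends AggregateFunction<Long, CountDistinctSketch.DistinctAcc> {

	/** Accumulator holding the values seen so far. */
	public static class DistinctAcc {
		public Set<Object> seen = new HashSet<>();
	}

	@Override
	public DistinctAcc createAccumulator() {
		return new DistinctAcc();
	}

	/** Invoked by the generated code for every input row (resolved by name). */
	public void accumulate(DistinctAcc acc, Object value) {
		if (value != null) {
			acc.seen.add(value);
		}
	}

	@Override
	public Long getValue(DistinctAcc acc) {
		return (long) acc.seen.size();
	}
}

The *WithMerge and *WithRetractAndReset variants used by the session-window and bounded over-window tests additionally provide merge(acc, Iterable<ACC>), retract(acc, value) and resetAccumulator(acc) methods following the same by-name contract.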
*/ private def restractResults(results: List[JTuple2[JBool, Row]]): List[String] = { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala index dcf2acd3766b2..fb99864db3ee0 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TestFilterableTableSource.scala @@ -89,6 +89,7 @@ class TestFilterableTableSource( iterator.remove() case (_, _) => } + case _ => } } diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosEntrypointUtils.java b/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosEntrypointUtils.java new file mode 100755 index 0000000000000..368d62d8621f0 --- /dev/null +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosEntrypointUtils.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.mesos.entrypoint; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.IllegalConfigurationException; +import org.apache.flink.mesos.configuration.MesosOptions; +import org.apache.flink.mesos.runtime.clusterframework.MesosTaskManagerParameters; +import org.apache.flink.mesos.util.MesosConfiguration; +import org.apache.flink.runtime.clusterframework.ContainerSpecification; +import org.apache.flink.runtime.clusterframework.overlays.CompositeContainerOverlay; +import org.apache.flink.runtime.clusterframework.overlays.FlinkDistributionOverlay; +import org.apache.flink.runtime.clusterframework.overlays.HadoopConfOverlay; +import org.apache.flink.runtime.clusterframework.overlays.HadoopUserOverlay; +import org.apache.flink.runtime.clusterframework.overlays.KeytabOverlay; +import org.apache.flink.runtime.clusterframework.overlays.Krb5ConfOverlay; +import org.apache.flink.runtime.clusterframework.overlays.SSLStoreOverlay; + +import org.apache.mesos.Protos; +import org.slf4j.Logger; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import scala.concurrent.duration.Duration; +import scala.concurrent.duration.FiniteDuration; + +/** + * Utils for Mesos entrypoints. + */ +public class MesosEntrypointUtils { + + /** + * Loads and validates the Mesos scheduler configuration. + * @param flinkConfig the global configuration. + * @param hostname the hostname to advertise to the Mesos master. 
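For context, here is a minimal usage sketch of the configuration helper documented above and defined just below. It is illustrative only: the master URL and hostname values are placeholders; only MesosOptions.MASTER_URL is mandatory (the method throws an IllegalConfigurationException without it), while the remaining framework settings fall back to whatever defaults MesosOptions declares.

import org.apache.flink.configuration.Configuration;
import org.apache.flink.mesos.configuration.MesosOptions;
import org.apache.flink.mesos.entrypoint.MesosEntrypointUtils;
import org.apache.flink.mesos.util.MesosConfiguration;

public class MesosSchedulerConfigExample {

	public static void main(String[] args) {
		Configuration flinkConfig = new Configuration();
		// The only required setting; the value is a placeholder.
		flinkConfig.setString(MesosOptions.MASTER_URL, "zk://zookeeper:2181/mesos");

		// Advertised hostname of the JobManager process (placeholder).
		MesosConfiguration mesosConfig =
			MesosEntrypointUtils.createMesosSchedulerConfiguration(flinkConfig, "jobmanager.example.com");

		// The resulting object bundles the master URL, framework info and optional credential
		// and is later handed to the MesosResourceManager.
		System.out.println(mesosConfig);
	}
}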
+ */ + public static MesosConfiguration createMesosSchedulerConfiguration(Configuration flinkConfig, String hostname) { + + Protos.FrameworkInfo.Builder frameworkInfo = Protos.FrameworkInfo.newBuilder() + .setHostname(hostname); + Protos.Credential.Builder credential = null; + + if (!flinkConfig.contains(MesosOptions.MASTER_URL)) { + throw new IllegalConfigurationException(MesosOptions.MASTER_URL.key() + " must be configured."); + } + String masterUrl = flinkConfig.getString(MesosOptions.MASTER_URL); + + Duration failoverTimeout = FiniteDuration.apply( + flinkConfig.getInteger( + MesosOptions.FAILOVER_TIMEOUT_SECONDS), + TimeUnit.SECONDS); + frameworkInfo.setFailoverTimeout(failoverTimeout.toSeconds()); + + frameworkInfo.setName(flinkConfig.getString( + MesosOptions.RESOURCEMANAGER_FRAMEWORK_NAME)); + + frameworkInfo.setRole(flinkConfig.getString( + MesosOptions.RESOURCEMANAGER_FRAMEWORK_ROLE)); + + frameworkInfo.setUser(flinkConfig.getString( + MesosOptions.RESOURCEMANAGER_FRAMEWORK_USER)); + + if (flinkConfig.contains(MesosOptions.RESOURCEMANAGER_FRAMEWORK_PRINCIPAL)) { + frameworkInfo.setPrincipal(flinkConfig.getString( + MesosOptions.RESOURCEMANAGER_FRAMEWORK_PRINCIPAL)); + + credential = Protos.Credential.newBuilder(); + credential.setPrincipal(frameworkInfo.getPrincipal()); + + // some environments use a side-channel to communicate the secret to Mesos, + // and thus don't set the 'secret' configuration setting + if (flinkConfig.contains(MesosOptions.RESOURCEMANAGER_FRAMEWORK_SECRET)) { + credential.setSecret(flinkConfig.getString( + MesosOptions.RESOURCEMANAGER_FRAMEWORK_SECRET)); + } + } + + MesosConfiguration mesos = + new MesosConfiguration(masterUrl, frameworkInfo, scala.Option.apply(credential)); + + return mesos; + } + + public static MesosTaskManagerParameters createTmParameters(Configuration configuration, Logger log) { + // TM configuration + final MesosTaskManagerParameters taskManagerParameters = MesosTaskManagerParameters.create(configuration); + + log.info("TaskManagers will be created with {} task slots", + taskManagerParameters.containeredParameters().numSlots()); + log.info("TaskManagers will be started with container size {} MB, JVM heap size {} MB, " + + "JVM direct memory limit {} MB, {} cpus", + taskManagerParameters.containeredParameters().taskManagerTotalMemoryMB(), + taskManagerParameters.containeredParameters().taskManagerHeapSizeMB(), + taskManagerParameters.containeredParameters().taskManagerDirectMemoryLimitMB(), + taskManagerParameters.cpus()); + + return taskManagerParameters; + } + + public static ContainerSpecification createContainerSpec(Configuration configuration, Configuration dynamicProperties) + throws Exception { + // generate a container spec which conveys the artifacts/vars needed to launch a TM + ContainerSpecification spec = new ContainerSpecification(); + + // propagate the AM dynamic configuration to the TM + spec.getDynamicConfiguration().addAll(dynamicProperties); + + applyOverlays(configuration, spec); + + return spec; + } + + /** + * Generate a container specification as a TaskManager template. + * + *
<p>
This code is extremely Mesos-specific and registers all the artifacts that the TaskManager + * needs (such as JAR file, config file, ...) and all environment variables into a container specification. + * The Mesos fetcher then ensures that those artifacts will be copied into the task's sandbox directory. + * A lightweight HTTP server serves the artifacts to the fetcher. + */ + public static void applyOverlays( + Configuration configuration, ContainerSpecification containerSpec) throws IOException { + + // create the overlays that will produce the specification + CompositeContainerOverlay overlay = new CompositeContainerOverlay( + FlinkDistributionOverlay.newBuilder().fromEnvironment(configuration).build(), + HadoopConfOverlay.newBuilder().fromEnvironment(configuration).build(), + HadoopUserOverlay.newBuilder().fromEnvironment(configuration).build(), + KeytabOverlay.newBuilder().fromEnvironment(configuration).build(), + Krb5ConfOverlay.newBuilder().fromEnvironment(configuration).build(), + SSLStoreOverlay.newBuilder().fromEnvironment(configuration).build() + ); + + // apply the overlays + overlay.configure(containerSpec); + } + +} diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosJobClusterEntrypoint.java b/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosJobClusterEntrypoint.java new file mode 100755 index 0000000000000..ba3b51db68d28 --- /dev/null +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosJobClusterEntrypoint.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.mesos.entrypoint; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.GlobalConfiguration; +import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.mesos.runtime.clusterframework.MesosResourceManager; +import org.apache.flink.mesos.runtime.clusterframework.MesosTaskManagerParameters; +import org.apache.flink.mesos.runtime.clusterframework.services.MesosServices; +import org.apache.flink.mesos.runtime.clusterframework.services.MesosServicesUtils; +import org.apache.flink.mesos.util.MesosConfiguration; +import org.apache.flink.runtime.blob.BlobServer; +import org.apache.flink.runtime.clusterframework.BootstrapTools; +import org.apache.flink.runtime.clusterframework.ContainerSpecification; +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.entrypoint.JobClusterEntrypoint; +import org.apache.flink.runtime.heartbeat.HeartbeatServices; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.metrics.MetricRegistry; +import org.apache.flink.runtime.resourcemanager.ResourceManager; +import org.apache.flink.runtime.resourcemanager.ResourceManagerConfiguration; +import org.apache.flink.runtime.resourcemanager.ResourceManagerRuntimeServices; +import org.apache.flink.runtime.resourcemanager.ResourceManagerRuntimeServicesConfiguration; +import org.apache.flink.runtime.rpc.FatalErrorHandler; +import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FlinkException; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.PosixParser; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Entry point for Mesos per-job clusters. 
+ */ +public class MesosJobClusterEntrypoint extends JobClusterEntrypoint { + + public static final String JOB_GRAPH_FILE_PATH = "flink.jobgraph.path"; + + // ------------------------------------------------------------------------ + // Command-line options + // ------------------------------------------------------------------------ + + private static final Options ALL_OPTIONS; + + static { + ALL_OPTIONS = + new Options() + .addOption(BootstrapTools.newDynamicPropertiesOption()); + } + + private final Configuration dynamicProperties; + + private MesosConfiguration schedulerConfiguration; + + private MesosServices mesosServices; + + private MesosTaskManagerParameters taskManagerParameters; + + private ContainerSpecification taskManagerContainerSpec; + + public MesosJobClusterEntrypoint(Configuration config, Configuration dynamicProperties) { + super(config); + + this.dynamicProperties = Preconditions.checkNotNull(dynamicProperties); + } + + @Override + protected void initializeServices(Configuration config) throws Exception { + super.initializeServices(config); + + final String hostname = config.getString(JobManagerOptions.ADDRESS); + + // Mesos configuration + schedulerConfiguration = MesosEntrypointUtils.createMesosSchedulerConfiguration(config, hostname); + + // services + mesosServices = MesosServicesUtils.createMesosServices(config, hostname); + + // TM configuration + taskManagerParameters = MesosEntrypointUtils.createTmParameters(config, LOG); + taskManagerContainerSpec = MesosEntrypointUtils.createContainerSpec(config, dynamicProperties); + } + + @Override + protected void startClusterComponents(Configuration configuration, RpcService rpcService, HighAvailabilityServices highAvailabilityServices, BlobServer blobServer, HeartbeatServices heartbeatServices, MetricRegistry metricRegistry) throws Exception { + super.startClusterComponents(configuration, rpcService, highAvailabilityServices, blobServer, heartbeatServices, metricRegistry); + } + + @Override + protected ResourceManager createResourceManager( + Configuration configuration, + ResourceID resourceId, + RpcService rpcService, + HighAvailabilityServices highAvailabilityServices, + HeartbeatServices heartbeatServices, + MetricRegistry metricRegistry, + FatalErrorHandler fatalErrorHandler) throws Exception { + final ResourceManagerConfiguration rmConfiguration = ResourceManagerConfiguration.fromConfiguration(configuration); + final ResourceManagerRuntimeServicesConfiguration rmServicesConfiguration = ResourceManagerRuntimeServicesConfiguration.fromConfiguration(configuration); + final ResourceManagerRuntimeServices rmRuntimeServices = ResourceManagerRuntimeServices.fromConfiguration( + rmServicesConfiguration, + highAvailabilityServices, + rpcService.getScheduledExecutor()); + + return new MesosResourceManager( + rpcService, + ResourceManager.RESOURCE_MANAGER_NAME, + resourceId, + rmConfiguration, + highAvailabilityServices, + heartbeatServices, + rmRuntimeServices.getSlotManager(), + metricRegistry, + rmRuntimeServices.getJobLeaderIdService(), + fatalErrorHandler, + configuration, + mesosServices, + schedulerConfiguration, + taskManagerParameters, + taskManagerContainerSpec + ); + } + + @Override + protected JobGraph retrieveJobGraph(Configuration configuration) throws FlinkException { + String jobGraphFile = configuration.getString(JOB_GRAPH_FILE_PATH, "job.graph"); + File fp = new File(jobGraphFile); + + try (FileInputStream input = new FileInputStream(fp); + ObjectInputStream obInput = new ObjectInputStream(input)) { + + 
return (JobGraph) obInput.readObject(); + } catch (FileNotFoundException e) { + throw new FlinkException("Could not find the JobGraph file.", e); + } catch (ClassNotFoundException | IOException e) { + throw new FlinkException("Could not load the JobGraph from file.", e); + } + } + + @Override + protected void stopClusterComponents(boolean cleanupHaData) throws Exception { + Throwable exception = null; + + try { + super.stopClusterComponents(cleanupHaData); + } catch (Throwable t) { + exception = ExceptionUtils.firstOrSuppressed(t, exception); + } + + if (mesosServices != null) { + try { + mesosServices.close(cleanupHaData); + } catch (Throwable t) { + exception = ExceptionUtils.firstOrSuppressed(t, exception); + } + } + + if (exception != null) { + throw new FlinkException("Could not properly shut down the Mesos job cluster entry point.", exception); + } + } + + public static void main(String[] args) { + // load configuration incl. dynamic properties + CommandLineParser parser = new PosixParser(); + CommandLine cmd; + try { + cmd = parser.parse(ALL_OPTIONS, args); + } + catch (Exception e){ + LOG.error("Could not parse the command-line options.", e); + System.exit(STARTUP_FAILURE_RETURN_CODE); + return; + } + + Configuration dynamicProperties = BootstrapTools.parseDynamicProperties(cmd); + Configuration configuration = GlobalConfiguration.loadConfigurationWithDynamicProperties(dynamicProperties); + + MesosJobClusterEntrypoint clusterEntrypoint = new MesosJobClusterEntrypoint(configuration, dynamicProperties); + + clusterEntrypoint.startCluster(); + } +} diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosSessionClusterEntrypoint.java b/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosSessionClusterEntrypoint.java new file mode 100755 index 0000000000000..0ee2680a774f2 --- /dev/null +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosSessionClusterEntrypoint.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.mesos.entrypoint; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.GlobalConfiguration; +import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.mesos.runtime.clusterframework.MesosResourceManager; +import org.apache.flink.mesos.runtime.clusterframework.MesosTaskManagerParameters; +import org.apache.flink.mesos.runtime.clusterframework.services.MesosServices; +import org.apache.flink.mesos.runtime.clusterframework.services.MesosServicesUtils; +import org.apache.flink.mesos.util.MesosConfiguration; +import org.apache.flink.runtime.blob.BlobServer; +import org.apache.flink.runtime.clusterframework.BootstrapTools; +import org.apache.flink.runtime.clusterframework.ContainerSpecification; +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.entrypoint.SessionClusterEntrypoint; +import org.apache.flink.runtime.heartbeat.HeartbeatServices; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.metrics.MetricRegistry; +import org.apache.flink.runtime.resourcemanager.ResourceManager; +import org.apache.flink.runtime.resourcemanager.ResourceManagerConfiguration; +import org.apache.flink.runtime.resourcemanager.ResourceManagerRuntimeServices; +import org.apache.flink.runtime.resourcemanager.ResourceManagerRuntimeServicesConfiguration; +import org.apache.flink.runtime.rpc.FatalErrorHandler; +import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.util.FlinkException; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.PosixParser; + +/** + * Entry point for Mesos session clusters. 
+ */ +public class MesosSessionClusterEntrypoint extends SessionClusterEntrypoint { + + // ------------------------------------------------------------------------ + // Command-line options + // ------------------------------------------------------------------------ + + private static final Options ALL_OPTIONS; + + static { + ALL_OPTIONS = + new Options() + .addOption(BootstrapTools.newDynamicPropertiesOption()); + } + + private final Configuration dynamicProperties; + + private MesosConfiguration mesosConfig; + + private MesosServices mesosServices; + + private MesosTaskManagerParameters taskManagerParameters; + + private ContainerSpecification taskManagerContainerSpec; + + public MesosSessionClusterEntrypoint(Configuration config, Configuration dynamicProperties) { + super(config); + + this.dynamicProperties = Preconditions.checkNotNull(dynamicProperties); + } + + @Override + protected void initializeServices(Configuration config) throws Exception { + super.initializeServices(config); + + final String hostname = config.getString(JobManagerOptions.ADDRESS); + + // Mesos configuration + mesosConfig = MesosEntrypointUtils.createMesosSchedulerConfiguration(config, hostname); + + // services + mesosServices = MesosServicesUtils.createMesosServices(config, hostname); + + // TM configuration + taskManagerParameters = MesosEntrypointUtils.createTmParameters(config, LOG); + taskManagerContainerSpec = MesosEntrypointUtils.createContainerSpec(config, dynamicProperties); + } + + @Override + protected void startClusterComponents(Configuration configuration, RpcService rpcService, HighAvailabilityServices highAvailabilityServices, BlobServer blobServer, HeartbeatServices heartbeatServices, MetricRegistry metricRegistry) throws Exception { + super.startClusterComponents(configuration, rpcService, highAvailabilityServices, blobServer, heartbeatServices, metricRegistry); + } + + @Override + protected ResourceManager createResourceManager( + Configuration configuration, + ResourceID resourceId, + RpcService rpcService, + HighAvailabilityServices highAvailabilityServices, + HeartbeatServices heartbeatServices, + MetricRegistry metricRegistry, + FatalErrorHandler fatalErrorHandler) throws Exception { + final ResourceManagerConfiguration rmConfiguration = ResourceManagerConfiguration.fromConfiguration(configuration); + final ResourceManagerRuntimeServicesConfiguration rmServicesConfiguration = ResourceManagerRuntimeServicesConfiguration.fromConfiguration(configuration); + final ResourceManagerRuntimeServices rmRuntimeServices = ResourceManagerRuntimeServices.fromConfiguration( + rmServicesConfiguration, + highAvailabilityServices, + rpcService.getScheduledExecutor()); + + return new MesosResourceManager( + rpcService, + ResourceManager.RESOURCE_MANAGER_NAME, + resourceId, + rmConfiguration, + highAvailabilityServices, + heartbeatServices, + rmRuntimeServices.getSlotManager(), + metricRegistry, + rmRuntimeServices.getJobLeaderIdService(), + fatalErrorHandler, + configuration, + mesosServices, + mesosConfig, + taskManagerParameters, + taskManagerContainerSpec + ); + } + + @Override + protected void stopClusterComponents(boolean cleanupHaData) throws Exception { + Throwable exception = null; + + try { + super.stopClusterComponents(cleanupHaData); + } catch (Throwable t) { + exception = t; + } + + if (mesosServices != null) { + try { + mesosServices.close(cleanupHaData); + } catch (Throwable t) { + exception = t; + } + } + + if (exception != null) { + throw new FlinkException("Could not properly shut down the 
Mesos session cluster entry point.", exception); + } + } + + public static void main(String[] args) { + // load configuration incl. dynamic properties + CommandLineParser parser = new PosixParser(); + CommandLine cmd; + try { + cmd = parser.parse(ALL_OPTIONS, args); + } + catch (Exception e){ + LOG.error("Could not parse the command-line options.", e); + System.exit(STARTUP_FAILURE_RETURN_CODE); + return; + } + + Configuration dynamicProperties = BootstrapTools.parseDynamicProperties(cmd); + Configuration configuration = GlobalConfiguration.loadConfigurationWithDynamicProperties(dynamicProperties); + + MesosSessionClusterEntrypoint clusterEntrypoint = new MesosSessionClusterEntrypoint(configuration, dynamicProperties); + + clusterEntrypoint.startCluster(); + } +} diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosTaskExecutorRunner.java b/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosTaskExecutorRunner.java new file mode 100644 index 0000000000000..897e26d1b79f0 --- /dev/null +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/entrypoint/MesosTaskExecutorRunner.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.mesos.entrypoint; + +import org.apache.flink.configuration.AkkaOptions; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.GlobalConfiguration; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.mesos.runtime.clusterframework.MesosConfigKeys; +import org.apache.flink.runtime.clusterframework.BootstrapTools; +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.security.SecurityUtils; +import org.apache.flink.runtime.taskexecutor.TaskManagerRunner; +import org.apache.flink.runtime.util.EnvironmentInformation; +import org.apache.flink.runtime.util.JvmShutdownSafeguard; +import org.apache.flink.runtime.util.SignalHandler; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.PosixParser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.Callable; + +/** + * The entry point for running a TaskManager in a Mesos container. 
+ */ +public class MesosTaskExecutorRunner { + + private static final Logger LOG = LoggerFactory.getLogger(MesosTaskExecutorRunner.class); + + private static final int INIT_ERROR_EXIT_CODE = 31; + + private static final Options ALL_OPTIONS; + + static { + ALL_OPTIONS = + new Options() + .addOption(BootstrapTools.newDynamicPropertiesOption()); + } + + public static void main(String[] args) throws Exception { + EnvironmentInformation.logEnvironmentInfo(LOG, MesosTaskExecutorRunner.class.getSimpleName(), args); + SignalHandler.register(LOG); + JvmShutdownSafeguard.installAsShutdownHook(LOG); + + // try to parse the command line arguments + CommandLineParser parser = new PosixParser(); + CommandLine cmd = parser.parse(ALL_OPTIONS, args); + + final Configuration configuration; + try { + Configuration dynamicProperties = BootstrapTools.parseDynamicProperties(cmd); + LOG.debug("Mesos dynamic properties: {}", dynamicProperties); + + configuration = GlobalConfiguration.loadConfigurationWithDynamicProperties(dynamicProperties); + } + catch (Throwable t) { + LOG.error("Failed to load the TaskManager configuration and dynamic properties.", t); + System.exit(INIT_ERROR_EXIT_CODE); + return; + } + + // read the environment variables + final Map envs = System.getenv(); + final String tmpDirs = envs.get(MesosConfigKeys.ENV_FLINK_TMP_DIR); + + // configure local directory + String flinkTempDirs = configuration.getString(ConfigConstants.TASK_MANAGER_TMP_DIR_KEY, null); + if (flinkTempDirs != null) { + LOG.info("Overriding Mesos temporary file directories with those " + + "specified in the Flink config: {}", flinkTempDirs); + } + else if (tmpDirs != null) { + LOG.info("Setting directories for temporary files to: {}", tmpDirs); + configuration.setString(ConfigConstants.TASK_MANAGER_TMP_DIR_KEY, tmpDirs); + } + + // configure the default filesystem + try { + FileSystem.setDefaultScheme(configuration); + } catch (IOException e) { + throw new IOException("Error while setting the default " + + "filesystem scheme from configuration.", e); + } + + // tell akka to die in case of an error + configuration.setBoolean(AkkaOptions.JVM_EXIT_ON_FATAL_ERROR, true); + + // Infer the resource identifier from the environment variable + String containerID = Preconditions.checkNotNull(envs.get(MesosConfigKeys.ENV_FLINK_CONTAINER_ID)); + final ResourceID resourceId = new ResourceID(containerID); + LOG.info("ResourceID assigned for this container: {}", resourceId); + + // Run the TM in the security context + SecurityUtils.SecurityConfiguration sc = new SecurityUtils.SecurityConfiguration(configuration); + SecurityUtils.install(sc); + + try { + SecurityUtils.getInstalledContext().runSecured(new Callable() { + @Override + public Integer call() throws Exception { + TaskManagerRunner.runTaskManager(configuration, resourceId); + + return 0; + } + }); + } + catch (Throwable t) { + LOG.error("Error while starting the TaskManager", t); + System.exit(INIT_ERROR_EXIT_CODE); + } + } +} diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/LaunchableMesosWorker.java b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/LaunchableMesosWorker.java index ce7bb9d6271fc..2c3250738027c 100644 --- a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/LaunchableMesosWorker.java +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/LaunchableMesosWorker.java @@ -23,6 +23,7 @@ import org.apache.flink.mesos.Utils; import 
org.apache.flink.mesos.scheduler.LaunchableTask; import org.apache.flink.mesos.util.MesosArtifactResolver; +import org.apache.flink.mesos.util.MesosArtifactServer; import org.apache.flink.mesos.util.MesosConfiguration; import org.apache.flink.runtime.clusterframework.ContainerSpecification; import org.apache.flink.runtime.clusterframework.ContaineredTaskManagerParameters; @@ -36,6 +37,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; @@ -261,12 +263,15 @@ public Protos.TaskInfo launch(Protos.SlaveID slaveId, TaskAssignmentResult assig env.addVariables(variable(MesosConfigKeys.ENV_FRAMEWORK_NAME, mesosConfiguration.frameworkInfo().getName())); // build the launch command w/ dynamic application properties - Option bootstrapCmdOption = params.bootstrapCommand(); - - final String bootstrapCommand = bootstrapCmdOption.isDefined() ? bootstrapCmdOption.get() + " && " : ""; - final String launchCommand = bootstrapCommand + "$FLINK_HOME/bin/mesos-taskmanager.sh " + ContainerSpecification.formatSystemProperties(dynamicProperties); - - cmd.setValue(launchCommand); + StringBuilder launchCommand = new StringBuilder(); + if (params.bootstrapCommand().isDefined()) { + launchCommand.append(params.bootstrapCommand().get()).append(" && "); + } + launchCommand + .append(params.command()) + .append(" ") + .append(ContainerSpecification.formatSystemProperties(dynamicProperties)); + cmd.setValue(launchCommand.toString()); // build the container info Protos.ContainerInfo.Builder containerInfo = Protos.ContainerInfo.newBuilder(); @@ -312,4 +317,17 @@ public String toString() { "taskRequest=" + taskRequest + '}'; } + + /** + * Configures an artifact server to serve the artifacts associated with a container specification. + * @param server the server to configure. + * @param container the container with artifacts to serve. + * @throws IOException if the artifacts cannot be accessed. 
+ */ + static void configureArtifactServer(MesosArtifactServer server, ContainerSpecification container) throws IOException { + // serve the artifacts associated with the container environment + for (ContainerSpecification.Artifact artifact : container.getArtifacts()) { + server.addPath(artifact.source, artifact.dest); + } + } } diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosApplicationMasterRunner.java b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosApplicationMasterRunner.java old mode 100644 new mode 100755 index 7891386675e56..c0a68559501a9 --- a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosApplicationMasterRunner.java +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosApplicationMasterRunner.java @@ -21,11 +21,10 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.GlobalConfiguration; -import org.apache.flink.configuration.IllegalConfigurationException; import org.apache.flink.configuration.JobManagerOptions; import org.apache.flink.configuration.WebOptions; import org.apache.flink.core.fs.FileSystem; -import org.apache.flink.mesos.configuration.MesosOptions; +import org.apache.flink.mesos.entrypoint.MesosEntrypointUtils; import org.apache.flink.mesos.runtime.clusterframework.services.MesosServices; import org.apache.flink.mesos.runtime.clusterframework.services.MesosServicesUtils; import org.apache.flink.mesos.runtime.clusterframework.store.MesosWorkerStore; @@ -34,13 +33,6 @@ import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.clusterframework.BootstrapTools; import org.apache.flink.runtime.clusterframework.ContainerSpecification; -import org.apache.flink.runtime.clusterframework.overlays.CompositeContainerOverlay; -import org.apache.flink.runtime.clusterframework.overlays.FlinkDistributionOverlay; -import org.apache.flink.runtime.clusterframework.overlays.HadoopConfOverlay; -import org.apache.flink.runtime.clusterframework.overlays.HadoopUserOverlay; -import org.apache.flink.runtime.clusterframework.overlays.KeytabOverlay; -import org.apache.flink.runtime.clusterframework.overlays.Krb5ConfOverlay; -import org.apache.flink.runtime.clusterframework.overlays.SSLStoreOverlay; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils; import org.apache.flink.runtime.jobmanager.JobManager; @@ -65,7 +57,6 @@ import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.Options; import org.apache.commons.cli.PosixParser; -import org.apache.mesos.Protos; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,7 +64,6 @@ import java.net.InetAddress; import java.net.URL; import java.util.Map; -import java.util.UUID; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -81,7 +71,6 @@ import java.util.concurrent.TimeUnit; import scala.Option; -import scala.concurrent.duration.Duration; import scala.concurrent.duration.FiniteDuration; import static org.apache.flink.util.Preconditions.checkState; @@ -163,8 +152,7 @@ protected int run(final String[] args) { CommandLine cmd = parser.parse(ALL_OPTIONS, args); final Configuration dynamicProperties = BootstrapTools.parseDynamicProperties(cmd); - GlobalConfiguration.setDynamicProperties(dynamicProperties); 
- final Configuration config = GlobalConfiguration.loadConfiguration(); + final Configuration config = GlobalConfiguration.loadConfigurationWithDynamicProperties(dynamicProperties); // configure the default filesystem try { @@ -222,7 +210,7 @@ protected int runPrivileged(Configuration config, Configuration dynamicPropertie LOG.info("App Master Hostname to use: {}", appMasterHostname); // Mesos configuration - final MesosConfiguration mesosConfig = createMesosConfig(config, appMasterHostname); + final MesosConfiguration mesosConfig = MesosEntrypointUtils.createMesosSchedulerConfiguration(config, appMasterHostname); // JM configuration int numberProcessors = Hardware.getNumberCPUCores(); @@ -235,19 +223,10 @@ protected int runPrivileged(Configuration config, Configuration dynamicPropertie numberProcessors, new ExecutorThreadFactory("mesos-jobmanager-io")); - mesosServices = MesosServicesUtils.createMesosServices(config); + mesosServices = MesosServicesUtils.createMesosServices(config, appMasterHostname); // TM configuration - final MesosTaskManagerParameters taskManagerParameters = MesosTaskManagerParameters.create(config); - - LOG.info("TaskManagers will be created with {} task slots", - taskManagerParameters.containeredParameters().numSlots()); - LOG.info("TaskManagers will be started with container size {} MB, JVM heap size {} MB, " + - "JVM direct memory limit {} MB, {} cpus", - taskManagerParameters.containeredParameters().taskManagerTotalMemoryMB(), - taskManagerParameters.containeredParameters().taskManagerHeapSizeMB(), - taskManagerParameters.containeredParameters().taskManagerDirectMemoryLimitMB(), - taskManagerParameters.cpus()); + final MesosTaskManagerParameters taskManagerParameters = MesosEntrypointUtils.createTmParameters(config, LOG); // JM endpoint, which should be explicitly configured based on acquired net resources final int listeningPort = config.getInteger(JobManagerOptions.PORT); @@ -268,9 +247,7 @@ protected int runPrivileged(Configuration config, Configuration dynamicPropertie // try to start the artifact server LOG.debug("Starting Artifact Server"); - final int artifactServerPort = config.getInteger(MesosOptions.ARTIFACT_SERVER_PORT); - final String artifactServerPrefix = UUID.randomUUID().toString(); - artifactServer = new MesosArtifactServer(artifactServerPrefix, akkaHostname, artifactServerPort, config); + artifactServer = mesosServices.getArtifactServer(); // ----------------- (3) Generate the configuration for the TaskManagers ------------------- @@ -287,10 +264,10 @@ protected int runPrivileged(Configuration config, Configuration dynamicPropertie taskManagerContainerSpec.getDynamicConfiguration().addAll(taskManagerConfig); // apply the overlays - applyOverlays(config, taskManagerContainerSpec); + MesosEntrypointUtils.applyOverlays(config, taskManagerContainerSpec); // configure the artifact server to serve the specified artifacts - configureArtifactServer(artifactServer, taskManagerContainerSpec); + LaunchableMesosWorker.configureArtifactServer(artifactServer, taskManagerContainerSpec); // ----------------- (4) start the actors ------------------- @@ -386,14 +363,6 @@ protected int runPrivileged(Configuration config, Configuration dynamicPropertie } } - if (artifactServer != null) { - try { - artifactServer.stop(); - } catch (Throwable ignored) { - LOG.error("Failed to stop the artifact server", ignored); - } - } - if (actorSystem != null) { try { actorSystem.shutdown(); @@ -444,12 +413,6 @@ protected int runPrivileged(Configuration config, Configuration 
dynamicPropertie } } - try { - artifactServer.stop(); - } catch (Throwable t) { - LOG.error("Failed to stop the artifact server", t); - } - if (highAvailabilityServices != null) { try { highAvailabilityServices.close(); @@ -490,85 +453,4 @@ protected Class getArchivistClass() { return MemoryArchivist.class; } - /** - * Loads and validates the ResourceManager Mesos configuration from the given Flink configuration. - */ - public static MesosConfiguration createMesosConfig(Configuration flinkConfig, String hostname) { - - Protos.FrameworkInfo.Builder frameworkInfo = Protos.FrameworkInfo.newBuilder() - .setHostname(hostname); - Protos.Credential.Builder credential = null; - - if (!flinkConfig.contains(MesosOptions.MASTER_URL)) { - throw new IllegalConfigurationException(MesosOptions.MASTER_URL.key() + " must be configured."); - } - String masterUrl = flinkConfig.getString(MesosOptions.MASTER_URL); - - Duration failoverTimeout = FiniteDuration.apply( - flinkConfig.getInteger( - MesosOptions.FAILOVER_TIMEOUT_SECONDS), - TimeUnit.SECONDS); - frameworkInfo.setFailoverTimeout(failoverTimeout.toSeconds()); - - frameworkInfo.setName(flinkConfig.getString( - MesosOptions.RESOURCEMANAGER_FRAMEWORK_NAME)); - - frameworkInfo.setRole(flinkConfig.getString( - MesosOptions.RESOURCEMANAGER_FRAMEWORK_ROLE)); - - frameworkInfo.setUser(flinkConfig.getString( - MesosOptions.RESOURCEMANAGER_FRAMEWORK_USER)); - - if (flinkConfig.contains(MesosOptions.RESOURCEMANAGER_FRAMEWORK_PRINCIPAL)) { - frameworkInfo.setPrincipal(flinkConfig.getString( - MesosOptions.RESOURCEMANAGER_FRAMEWORK_PRINCIPAL)); - - credential = Protos.Credential.newBuilder(); - credential.setPrincipal(frameworkInfo.getPrincipal()); - - // some environments use a side-channel to communicate the secret to Mesos, - // and thus don't set the 'secret' configuration setting - if (flinkConfig.contains(MesosOptions.RESOURCEMANAGER_FRAMEWORK_SECRET)) { - credential.setSecret(flinkConfig.getString( - MesosOptions.RESOURCEMANAGER_FRAMEWORK_SECRET)); - } - } - - MesosConfiguration mesos = - new MesosConfiguration(masterUrl, frameworkInfo, scala.Option.apply(credential)); - - return mesos; - } - - /** - * Generate a container specification as a TaskManager template. - * - *
<p>
This code is extremely Mesos-specific and registers all the artifacts that the TaskManager - * needs (such as JAR file, config file, ...) and all environment variables into a container specification. - * The Mesos fetcher then ensures that those artifacts will be copied into the task's sandbox directory. - * A lightweight HTTP server serves the artifacts to the fetcher. - */ - private static void applyOverlays( - Configuration globalConfiguration, ContainerSpecification containerSpec) throws IOException { - - // create the overlays that will produce the specification - CompositeContainerOverlay overlay = new CompositeContainerOverlay( - FlinkDistributionOverlay.newBuilder().fromEnvironment(globalConfiguration).build(), - HadoopConfOverlay.newBuilder().fromEnvironment(globalConfiguration).build(), - HadoopUserOverlay.newBuilder().fromEnvironment(globalConfiguration).build(), - KeytabOverlay.newBuilder().fromEnvironment(globalConfiguration).build(), - Krb5ConfOverlay.newBuilder().fromEnvironment(globalConfiguration).build(), - SSLStoreOverlay.newBuilder().fromEnvironment(globalConfiguration).build() - ); - - // apply the overlays - overlay.configure(containerSpec); - } - - private static void configureArtifactServer(MesosArtifactServer server, ContainerSpecification container) throws IOException { - // serve the artifacts associated with the container environment - for (ContainerSpecification.Artifact artifact : container.getArtifacts()) { - server.addPath(artifact.source, artifact.dest); - } - } } diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosFlinkResourceManager.java b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosFlinkResourceManager.java index 05d7e1f886f63..6335745004a18 100644 --- a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosFlinkResourceManager.java +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosFlinkResourceManager.java @@ -192,7 +192,7 @@ protected ActorRef createConnectionMonitor() { protected ActorRef createTaskRouter() { return context().actorOf( - Tasks.createActorProps(Tasks.class, config, schedulerDriver, TaskMonitor.class), + Tasks.createActorProps(Tasks.class, self(), config, schedulerDriver, TaskMonitor.class), "tasks"); } diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosResourceManager.java b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosResourceManager.java index 736af59514d2e..8a8f20842f959 100644 --- a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosResourceManager.java +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosResourceManager.java @@ -20,6 +20,7 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; +import org.apache.flink.mesos.runtime.clusterframework.services.MesosServices; import org.apache.flink.mesos.runtime.clusterframework.store.MesosWorkerStore; import org.apache.flink.mesos.scheduler.ConnectionMonitor; import org.apache.flink.mesos.scheduler.LaunchCoordinator; @@ -38,7 +39,7 @@ import org.apache.flink.mesos.scheduler.messages.ResourceOffers; import org.apache.flink.mesos.scheduler.messages.SlaveLost; import org.apache.flink.mesos.scheduler.messages.StatusUpdate; -import org.apache.flink.mesos.util.MesosArtifactResolver; +import org.apache.flink.mesos.util.MesosArtifactServer; import 
org.apache.flink.mesos.util.MesosConfiguration; import org.apache.flink.runtime.clusterframework.ApplicationStatus; import org.apache.flink.runtime.clusterframework.ContainerSpecification; @@ -48,7 +49,6 @@ import org.apache.flink.runtime.concurrent.FutureUtils; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; -import org.apache.flink.runtime.instance.InstanceID; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.resourcemanager.JobLeaderIdService; import org.apache.flink.runtime.resourcemanager.ResourceManager; @@ -75,6 +75,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -98,17 +99,20 @@ public class MesosResourceManager extends ResourceManager(taskManagerParameters.containeredParameters().taskManagerEnv())), taskManagerParameters.containerVolumes(), taskManagerParameters.constraints(), + taskManagerParameters.command(), taskManagerParameters.bootstrapCommand(), taskManagerParameters.getTaskManagerHostname() ); + LOG.debug("LaunchableMesosWorker parameters: {}", params); + LaunchableMesosWorker launchable = new LaunchableMesosWorker( - artifactResolver, + artifactServer, params, taskManagerContainerSpec, taskID, diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosTaskManagerParameters.java b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosTaskManagerParameters.java index f5a415e2ab243..3859913ecda3d 100644 --- a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosTaskManagerParameters.java +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosTaskManagerParameters.java @@ -74,6 +74,10 @@ public class MesosTaskManagerParameters { key("mesos.resourcemanager.tasks.hostname") .noDefaultValue(); + public static final ConfigOption MESOS_TM_CMD = + key("mesos.resourcemanager.tasks.taskmanager-cmd") + .defaultValue("$FLINK_HOME/bin/mesos-taskmanager.sh"); // internal + public static final ConfigOption MESOS_TM_BOOTSTRAP_CMD = key("mesos.resourcemanager.tasks.bootstrap-cmd") .noDefaultValue(); @@ -107,6 +111,8 @@ public class MesosTaskManagerParameters { private final List constraints; + private final String command; + private final Option bootstrapCommand; private final Option taskManagerHostname; @@ -118,6 +124,7 @@ public MesosTaskManagerParameters( ContaineredTaskManagerParameters containeredParameters, List containerVolumes, List constraints, + String command, Option bootstrapCommand, Option taskManagerHostname) { @@ -127,6 +134,7 @@ public MesosTaskManagerParameters( this.containeredParameters = Preconditions.checkNotNull(containeredParameters); this.containerVolumes = Preconditions.checkNotNull(containerVolumes); this.constraints = Preconditions.checkNotNull(constraints); + this.command = Preconditions.checkNotNull(command); this.bootstrapCommand = Preconditions.checkNotNull(bootstrapCommand); this.taskManagerHostname = Preconditions.checkNotNull(taskManagerHostname); } @@ -182,6 +190,13 @@ public Option getTaskManagerHostname() { return taskManagerHostname; } + /** + * Get the command. + */ + public String command() { + return command; + } + /** * Get the bootstrap command. 
*/ @@ -199,6 +214,7 @@ public String toString() { ", containerVolumes=" + containerVolumes + ", constraints=" + constraints + ", taskManagerHostName=" + taskManagerHostname + + ", command=" + command + ", bootstrapCommand=" + bootstrapCommand + '}'; } @@ -249,7 +265,8 @@ public static MesosTaskManagerParameters create(Configuration flinkConfig) { //obtain Task Manager Host Name from the configuration Option taskManagerHostname = Option.apply(flinkConfig.getString(MESOS_TM_HOSTNAME)); - //obtain bootstrap command from the configuration + //obtain command-line from the configuration + String tmCommand = flinkConfig.getString(MESOS_TM_CMD); Option tmBootstrapCommand = Option.apply(flinkConfig.getString(MESOS_TM_BOOTSTRAP_CMD)); return new MesosTaskManagerParameters( @@ -259,6 +276,7 @@ public static MesosTaskManagerParameters create(Configuration flinkConfig) { containeredParameters, containerVolumes, constraints, + tmCommand, tmBootstrapCommand, taskManagerHostname); } diff --git a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosTaskManagerRunner.java b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosTaskManagerRunner.java index e1b0efa6a3818..4236a4341d915 100644 --- a/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosTaskManagerRunner.java +++ b/flink-mesos/src/main/java/org/apache/flink/mesos/runtime/clusterframework/MesosTaskManagerRunner.java @@ -73,10 +73,9 @@ public static void runTaskManager(String[] args, final Class FlinkConfiguration} +import org.apache.flink.runtime.blob.BlobServer import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory import org.apache.flink.runtime.clusterframework.ContaineredJobManager import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager @@ -34,7 +35,7 @@ import org.apache.flink.runtime.metrics.{MetricRegistry => FlinkMetricRegistry} import scala.concurrent.duration._ -/** JobManager actor for execution on Mesos. . +/** JobManager actor for execution on Mesos. 
* * @param flinkConfiguration Configuration object for the actor * @param futureExecutor Execution context which is used to execute concurrent tasks in the @@ -43,7 +44,8 @@ import scala.concurrent.duration._ * @param instanceManager Instance manager to manage the registered * [[org.apache.flink.runtime.taskmanager.TaskManager]] * @param scheduler Scheduler to schedule Flink jobs - * @param libraryCacheManager Manager to manage uploaded jar files + * @param blobServer BLOB store for file uploads + * @param libraryCacheManager manages uploaded jar files and class paths * @param archive Archive for finished Flink jobs * @param restartStrategyFactory Restart strategy to be used in case of a job recovery * @param timeout Timeout for futures @@ -55,6 +57,7 @@ class MesosJobManager( ioExecutor: Executor, instanceManager: InstanceManager, scheduler: FlinkScheduler, + blobServer: BlobServer, libraryCacheManager: BlobLibraryCacheManager, archive: ActorRef, restartStrategyFactory: RestartStrategyFactory, @@ -70,6 +73,7 @@ class MesosJobManager( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, diff --git a/flink-mesos/src/main/scala/org/apache/flink/mesos/scheduler/Tasks.scala b/flink-mesos/src/main/scala/org/apache/flink/mesos/scheduler/Tasks.scala index 4f49c16134dce..54d1bd2179504 100644 --- a/flink-mesos/src/main/scala/org/apache/flink/mesos/scheduler/Tasks.scala +++ b/flink-mesos/src/main/scala/org/apache/flink/mesos/scheduler/Tasks.scala @@ -34,6 +34,7 @@ import scala.collection.mutable.{Map => MutableMap} * Routes messages between the scheduler and individual task monitor actors. */ class Tasks( + manager: ActorRef, flinkConfig: Configuration, schedulerDriver: SchedulerDriver, taskMonitorCreator: (ActorRefFactory,TaskGoalState) => ActorRef) extends Actor { @@ -92,11 +93,11 @@ class Tasks( } case msg: Reconcile => - context.parent.forward(msg) + manager.forward(msg) case msg: TaskTerminated => taskMap.remove(msg.taskID) - context.parent.forward(msg) + manager.forward(msg) } private def createTask(task: TaskGoalState): ActorRef = { @@ -113,6 +114,7 @@ object Tasks { */ def createActorProps[T <: Tasks, M <: TaskMonitor]( actorClass: Class[T], + manager: ActorRef, flinkConfig: Configuration, schedulerDriver: SchedulerDriver, taskMonitorClass: Class[M]): Props = { @@ -122,6 +124,6 @@ object Tasks { factory.actorOf(props) } - Props.create(actorClass, flinkConfig, schedulerDriver, taskMonitorCreator) + Props.create(actorClass, manager, flinkConfig, schedulerDriver, taskMonitorCreator) } } diff --git a/flink-mesos/src/test/java/org/apache/flink/mesos/runtime/clusterframework/MesosFlinkResourceManagerTest.java b/flink-mesos/src/test/java/org/apache/flink/mesos/runtime/clusterframework/MesosFlinkResourceManagerTest.java index 8bfb4d120c15d..ff324865274e7 100644 --- a/flink-mesos/src/test/java/org/apache/flink/mesos/runtime/clusterframework/MesosFlinkResourceManagerTest.java +++ b/flink-mesos/src/test/java/org/apache/flink/mesos/runtime/clusterframework/MesosFlinkResourceManagerTest.java @@ -251,6 +251,7 @@ public void initialize() { containeredParams, Collections.emptyList(), Collections.emptyList(), + "", Option.empty(), Option.empty()); diff --git a/flink-mesos/src/test/java/org/apache/flink/mesos/runtime/clusterframework/MesosResourceManagerTest.java b/flink-mesos/src/test/java/org/apache/flink/mesos/runtime/clusterframework/MesosResourceManagerTest.java index f9e35a9e6b2f2..dbd0746f3da77 100644 --- 
a/flink-mesos/src/test/java/org/apache/flink/mesos/runtime/clusterframework/MesosResourceManagerTest.java +++ b/flink-mesos/src/test/java/org/apache/flink/mesos/runtime/clusterframework/MesosResourceManagerTest.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; +import org.apache.flink.mesos.runtime.clusterframework.services.MesosServices; import org.apache.flink.mesos.runtime.clusterframework.store.MesosWorkerStore; import org.apache.flink.mesos.scheduler.ConnectionMonitor; import org.apache.flink.mesos.scheduler.LaunchCoordinator; @@ -32,7 +33,7 @@ import org.apache.flink.mesos.scheduler.messages.Registered; import org.apache.flink.mesos.scheduler.messages.ResourceOffers; import org.apache.flink.mesos.scheduler.messages.StatusUpdate; -import org.apache.flink.mesos.util.MesosArtifactResolver; +import org.apache.flink.mesos.util.MesosArtifactServer; import org.apache.flink.mesos.util.MesosConfiguration; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.clusterframework.ApplicationStatus; @@ -46,6 +47,7 @@ import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.highavailability.TestingHighAvailabilityServices; import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.jobmaster.JobMasterRegistrationSuccess; import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService; import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; @@ -53,12 +55,13 @@ import org.apache.flink.runtime.registration.RegistrationResponse; import org.apache.flink.runtime.resourcemanager.JobLeaderIdService; import org.apache.flink.runtime.resourcemanager.ResourceManagerConfiguration; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.resourcemanager.SlotRequest; import org.apache.flink.runtime.resourcemanager.slotmanager.ResourceManagerActions; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; import org.apache.flink.runtime.rpc.FatalErrorHandler; import org.apache.flink.runtime.rpc.RpcService; -import org.apache.flink.runtime.rpc.TestingSerialRpcService; +import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.taskexecutor.SlotReport; import org.apache.flink.runtime.taskexecutor.TaskExecutorGateway; import org.apache.flink.runtime.taskexecutor.TaskExecutorRegistrationSuccess; @@ -77,6 +80,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.slf4j.Logger; @@ -159,17 +163,15 @@ public TestingMesosResourceManager( FatalErrorHandler fatalErrorHandler, // Mesos specifics - ActorSystem actorSystem, Configuration flinkConfig, + MesosServices mesosServices, MesosConfiguration mesosConfig, - MesosWorkerStore workerStore, MesosTaskManagerParameters taskManagerParameters, - ContainerSpecification taskManagerContainerSpec, - MesosArtifactResolver artifactResolver) { + ContainerSpecification taskManagerContainerSpec) { super(rpcService, resourceManagerEndpointId, resourceId, resourceManagerConfiguration, highAvailabilityServices, heartbeatServices, slotManager, metricRegistry, - jobLeaderIdService, fatalErrorHandler, actorSystem, flinkConfig, mesosConfig, workerStore, - 
taskManagerParameters, taskManagerContainerSpec, artifactResolver); + jobLeaderIdService, fatalErrorHandler, flinkConfig, mesosServices, mesosConfig, + taskManagerParameters, taskManagerContainerSpec); } @Override @@ -205,14 +207,15 @@ protected void closeTaskManagerConnection(ResourceID resourceID, Exception cause static class Context implements AutoCloseable { // services - TestingSerialRpcService rpcService; + TestingRpcService rpcService; TestingFatalErrorHandler fatalErrorHandler; MockMesosResourceManagerRuntimeServices rmServices; + MockMesosServices mesosServices; // RM ResourceManagerConfiguration rmConfiguration; ResourceID rmResourceID; - static final String RM_ADDRESS = "/resourceManager"; + static final String RM_ADDRESS = "resourceManager"; TestingMesosResourceManager resourceManager; // domain objects for test purposes @@ -239,9 +242,10 @@ static class Context implements AutoCloseable { * Create mock RM dependencies. */ Context() throws Exception { - rpcService = new TestingSerialRpcService(); + rpcService = new TestingRpcService(); fatalErrorHandler = new TestingFatalErrorHandler(); rmServices = new MockMesosResourceManagerRuntimeServices(); + mesosServices = new MockMesosServices(); // TaskExecutor templating ContainerSpecification containerSpecification = new ContainerSpecification(); @@ -249,7 +253,7 @@ static class Context implements AutoCloseable { new ContaineredTaskManagerParameters(1024, 768, 256, 4, new HashMap()); MesosTaskManagerParameters tmParams = new MesosTaskManagerParameters( 1.0, MesosTaskManagerParameters.ContainerType.MESOS, Option.empty(), containeredParams, - Collections.emptyList(), Collections.emptyList(), Option.empty(), + Collections.emptyList(), Collections.emptyList(), "", Option.empty(), Option.empty()); // resource manager @@ -270,13 +274,11 @@ static class Context implements AutoCloseable { rmServices.jobLeaderIdService, fatalErrorHandler, // Mesos specifics - system, flinkConfig, + mesosServices, rmServices.mesosConfig, - rmServices.workerStore, tmParams, - containerSpecification, - rmServices.artifactResolver + containerSpecification ); // TaskExecutors @@ -300,6 +302,7 @@ class MockResourceManagerRuntimeServices { public final TestingLeaderElectionService rmLeaderElectionService; public final JobLeaderIdService jobLeaderIdService; public final SlotManager slotManager; + public final CompletableFuture slotManagerStarted; public ResourceManagerActions rmActions; public UUID rmLeaderSessionId; @@ -312,6 +315,7 @@ class MockResourceManagerRuntimeServices { heartbeatServices = new TestingHeartbeatServices(5L, 5L, scheduledExecutor); metricRegistry = mock(MetricRegistry.class); slotManager = mock(SlotManager.class); + slotManagerStarted = new CompletableFuture<>(); jobLeaderIdService = new JobLeaderIdService( highAvailabilityServices, rpcService.getScheduledExecutor(), @@ -321,16 +325,17 @@ class MockResourceManagerRuntimeServices { @Override public Object answer(InvocationOnMock invocation) throws Throwable { rmActions = invocation.getArgumentAt(2, ResourceManagerActions.class); + slotManagerStarted.complete(true); return null; } - }).when(slotManager).start(any(UUID.class), any(Executor.class), any(ResourceManagerActions.class)); + }).when(slotManager).start(any(ResourceManagerId.class), any(Executor.class), any(ResourceManagerActions.class)); when(slotManager.registerSlotRequest(any(SlotRequest.class))).thenReturn(true); } - public void grantLeadership() { + public void grantLeadership() throws Exception { rmLeaderSessionId = UUID.randomUUID(); 
- rmLeaderElectionService.isLeader(rmLeaderSessionId); + rmLeaderElectionService.isLeader(rmLeaderSessionId).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -338,7 +343,7 @@ class MockMesosResourceManagerRuntimeServices extends MockResourceManagerRuntime public SchedulerDriver schedulerDriver; public MesosConfiguration mesosConfig; public MesosWorkerStore workerStore; - public MesosArtifactResolver artifactResolver; + public MesosArtifactServer artifactServer; MockMesosResourceManagerRuntimeServices() throws Exception { schedulerDriver = mock(SchedulerDriver.class); @@ -351,7 +356,28 @@ class MockMesosResourceManagerRuntimeServices extends MockResourceManagerRuntime workerStore = mock(MesosWorkerStore.class); when(workerStore.getFrameworkID()).thenReturn(Option.empty()); - artifactResolver = mock(MesosArtifactResolver.class); + artifactServer = mock(MesosArtifactServer.class); + } + } + + class MockMesosServices implements MesosServices { + @Override + public MesosWorkerStore createMesosWorkerStore(Configuration configuration, Executor executor) throws Exception { + return rmServices.workerStore; + } + + @Override + public ActorSystem getLocalActorSystem() { + return system; + } + + @Override + public MesosArtifactServer getArtifactServer() { + return rmServices.artifactServer; + } + + @Override + public void close(boolean cleanup) throws Exception { } } @@ -360,7 +386,7 @@ class MockJobMaster { public final ResourceID resourceID; public final String address; public final JobMasterGateway gateway; - public final UUID leaderSessionID; + public final JobMasterId jobMasterId; public final TestingLeaderRetrievalService leaderRetrievalService; MockJobMaster(JobID jobID) { @@ -368,8 +394,8 @@ class MockJobMaster { this.resourceID = new ResourceID(jobID.toString()); this.address = "/" + jobID; this.gateway = mock(JobMasterGateway.class); - this.leaderSessionID = UUID.randomUUID(); - this.leaderRetrievalService = new TestingLeaderRetrievalService(this.address, this.leaderSessionID); + this.jobMasterId = JobMasterId.generate(); + this.leaderRetrievalService = new TestingLeaderRetrievalService(this.address, this.jobMasterId.toUUID()); } } @@ -417,7 +443,7 @@ public void startResourceManager() throws Exception { */ public void registerJobMaster(MockJobMaster jobMaster) throws Exception { CompletableFuture registration = resourceManager.registerJobManager( - rmServices.rmLeaderSessionId, jobMaster.leaderSessionID, jobMaster.resourceID, jobMaster.address, jobMaster.jobID, timeout); + jobMaster.jobMasterId, jobMaster.resourceID, jobMaster.address, jobMaster.jobID, timeout); assertTrue(registration.get() instanceof JobMasterRegistrationSuccess); } @@ -426,11 +452,12 @@ public void registerJobMaster(MockJobMaster jobMaster) throws Exception { */ public MesosWorkerStore.Worker allocateWorker(Protos.TaskID taskID, ResourceProfile resourceProfile) throws Exception { when(rmServices.workerStore.newTaskID()).thenReturn(taskID); + rmServices.slotManagerStarted.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); rmServices.rmActions.allocateResource(resourceProfile); MesosWorkerStore.Worker expected = MesosWorkerStore.Worker.newWorker(taskID, resourceProfile); // drain the probe messages - verify(rmServices.workerStore).putWorker(expected); + verify(rmServices.workerStore, Mockito.timeout(timeout.toMilliseconds())).putWorker(expected); assertThat(resourceManager.workersInNew, hasEntry(extractResourceID(taskID), expected)); 
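// ---------------------------------------------------------------------------------------------
// Editor's sketch (not part of the patch): the hunks above move MesosResourceManagerTest from
// TestingSerialRpcService to TestingRpcService, so the test now has to synchronize with an
// asynchronous main thread: it blocks on a CompletableFuture "gate" that the stubbed
// SlotManager#start answer completes, awaits the isLeader(...) future, and verifies with
// Mockito.timeout(...). The standalone class below only illustrates that pattern under stated
// assumptions; AsyncWorkerStore and AsyncVerificationSketch are hypothetical names, and only the
// Mockito/CompletableFuture APIs are real.
// ---------------------------------------------------------------------------------------------
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public class AsyncVerificationSketch {

	/** Hypothetical collaborator standing in for the mocked worker store / slot manager. */
	interface AsyncWorkerStore {
		void start(Runnable onStart);
		void putWorker(String worker);
	}

	public static void main(String[] args) throws Exception {
		AsyncWorkerStore store = mock(AsyncWorkerStore.class);
		CompletableFuture<Boolean> started = new CompletableFuture<>();

		// Complete the gate as a side effect of start(), mirroring how the test completes
		// slotManagerStarted inside the stubbed SlotManager#start answer.
		doAnswer(invocation -> {
			started.complete(true);
			return null;
		}).when(store).start(any(Runnable.class));

		// Simulate the component under test doing its work on another thread,
		// the way the RPC endpoint's main thread does after the switch to TestingRpcService.
		new Thread(() -> {
			store.start(() -> { });
			store.putWorker("worker-1");
		}).start();

		// Wait for the gate before triggering assertions, then verify with a timeout
		// instead of assuming the interaction already happened synchronously.
		started.get(10, TimeUnit.SECONDS);
		verify(store, timeout(10_000)).putWorker("worker-1");
	}
}
// ---------------------------------------------------------------------------------------------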
resourceManager.taskRouter.expectMsgClass(TaskMonitor.TaskGoalStateUpdated.class); resourceManager.launchCoordinator.expectMsgClass(LaunchCoordinator.Launch.class); @@ -501,12 +528,13 @@ public void testRequestNewWorkers() throws Exception { // allocate a worker when(rmServices.workerStore.newTaskID()).thenReturn(task1).thenThrow(new AssertionFailedError()); + rmServices.slotManagerStarted.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); rmServices.rmActions.allocateResource(resourceProfile1); // verify that a new worker was persisted, the internal state was updated, the task router was notified, // and the launch coordinator was asked to launch a task MesosWorkerStore.Worker expected = MesosWorkerStore.Worker.newWorker(task1, resourceProfile1); - verify(rmServices.workerStore).putWorker(expected); + verify(rmServices.workerStore, Mockito.timeout(timeout.toMilliseconds())).putWorker(expected); assertThat(resourceManager.workersInNew, hasEntry(extractResourceID(task1), expected)); resourceManager.taskRouter.expectMsgClass(TaskMonitor.TaskGoalStateUpdated.class); resourceManager.launchCoordinator.expectMsgClass(LaunchCoordinator.Launch.class); @@ -591,8 +619,8 @@ public void testWorkerStarted() throws Exception { // send registration message CompletableFuture successfulFuture = - resourceManager.registerTaskExecutor(rmServices.rmLeaderSessionId, task1Executor.address, task1Executor.resourceID, slotReport, timeout); - RegistrationResponse response = successfulFuture.get(5, TimeUnit.SECONDS); + resourceManager.registerTaskExecutor(task1Executor.address, task1Executor.resourceID, slotReport, timeout); + RegistrationResponse response = successfulFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); assertTrue(response instanceof TaskExecutorRegistrationSuccess); // verify the internal state @@ -627,6 +655,36 @@ public void testWorkerFailed() throws Exception { }}; } + /** + * Test planned stop of a launched worker. + */ + @Test + public void testStopWorker() throws Exception { + new Context() {{ + // set the initial persistent state with a launched worker + MesosWorkerStore.Worker worker1launched = MesosWorkerStore.Worker.newWorker(task1).launchWorker(slave1, slave1host); + when(rmServices.workerStore.getFrameworkID()).thenReturn(Option.apply(framework1)); + when(rmServices.workerStore.recoverWorkers()).thenReturn(singletonList(worker1launched)); + startResourceManager(); + + // drain the assign message + resourceManager.launchCoordinator.expectMsgClass(LaunchCoordinator.Assign.class); + + // tell the RM to stop the worker + resourceManager.stopWorker(extractResourceID(task1)); + + // verify that the instance state was updated + MesosWorkerStore.Worker worker1Released = worker1launched.releaseWorker(); + verify(rmServices.workerStore).putWorker(worker1Released); + assertThat(resourceManager.workersInLaunch.entrySet(), empty()); + assertThat(resourceManager.workersBeingReturned, hasEntry(extractResourceID(task1), worker1Released)); + + // verify that the monitor was notified + resourceManager.taskRouter.expectMsgClass(TaskMonitor.TaskGoalStateUpdated.class); + resourceManager.launchCoordinator.expectMsgClass(LaunchCoordinator.Unassign.class); + }}; + } + /** * Test application shutdown handling. 
*/ diff --git a/flink-mesos/src/test/scala/org/apache/flink/mesos/scheduler/TasksTest.scala b/flink-mesos/src/test/scala/org/apache/flink/mesos/scheduler/TasksTest.scala index fcf2977af48ff..b3d9a5fddf169 100644 --- a/flink-mesos/src/test/scala/org/apache/flink/mesos/scheduler/TasksTest.scala +++ b/flink-mesos/src/test/scala/org/apache/flink/mesos/scheduler/TasksTest.scala @@ -93,7 +93,7 @@ class TasksTest taskActorRef } TestActorRef[Tasks]( - Props(classOf[Tasks], config, schedulerDriver, taskActorCreator), + Props(classOf[Tasks], testActor, config, schedulerDriver, taskActorCreator), testActor, TestFSMUtils.randomName) } diff --git a/flink-metrics/flink-metrics-datadog/src/main/java/org/apache/flink/metrics/datadog/DatadogHttpReporter.java b/flink-metrics/flink-metrics-datadog/src/main/java/org/apache/flink/metrics/datadog/DatadogHttpReporter.java index a47b2bf13e94d..b7e1c24fa6034 100644 --- a/flink-metrics/flink-metrics-datadog/src/main/java/org/apache/flink/metrics/datadog/DatadogHttpReporter.java +++ b/flink-metrics/flink-metrics-datadog/src/main/java/org/apache/flink/metrics/datadog/DatadogHttpReporter.java @@ -31,6 +31,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -140,6 +141,8 @@ public void report() { try { client.send(request); + } catch (SocketTimeoutException e) { + LOGGER.warn("Failed reporting metrics to Datadog because of socket timeout.", e.getMessage()); } catch (Exception e) { LOGGER.warn("Failed reporting metrics to Datadog.", e); } diff --git a/flink-optimizer/pom.xml b/flink-optimizer/pom.xml index b94e11e7f8acb..a507edecd0c59 100644 --- a/flink-optimizer/pom.xml +++ b/flink-optimizer/pom.xml @@ -57,9 +57,8 @@ under the License. - com.google.guava - guava - ${guava.version} + org.apache.flink + flink-shaded-guava diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/SingleInputNode.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/SingleInputNode.java index 5691d194ec152..964e2d609b9fe 100644 --- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/SingleInputNode.java +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/SingleInputNode.java @@ -55,7 +55,7 @@ import org.apache.flink.runtime.operators.shipping.ShipStrategyType; import org.apache.flink.util.Visitor; -import com.google.common.collect.Sets; +import org.apache.flink.shaded.guava18.com.google.common.collect.Sets; /** * A node in the optimizer's program representation for an operation with a single input. 
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/TwoInputNode.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/TwoInputNode.java index a4199a80fdb72..48815dc9a1c3b 100644 --- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/TwoInputNode.java +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/dag/TwoInputNode.java @@ -58,7 +58,7 @@ import org.apache.flink.runtime.operators.shipping.ShipStrategyType; import org.apache.flink.util.Visitor; -import com.google.common.collect.Sets; +import org.apache.flink.shaded.guava18.com.google.common.collect.Sets; /** * A node in the optimizer plan that represents an operator with a two different inputs, such as Join, diff --git a/flink-quickstart/flink-quickstart-java/src/main/resources/archetype-resources/pom.xml b/flink-quickstart/flink-quickstart-java/src/main/resources/archetype-resources/pom.xml index 5da38cdef9503..c973a6a3406ce 100644 --- a/flink-quickstart/flink-quickstart-java/src/main/resources/archetype-resources/pom.xml +++ b/flink-quickstart/flink-quickstart-java/src/main/resources/archetype-resources/pom.xml @@ -92,7 +92,7 @@ under the License. ${flink.version} - org.slf4j diff --git a/flink-quickstart/flink-quickstart-scala/src/main/resources/archetype-resources/pom.xml b/flink-quickstart/flink-quickstart-scala/src/main/resources/archetype-resources/pom.xml index 67fe4c1d91da3..42d7cdb3a496a 100644 --- a/flink-quickstart/flink-quickstart-scala/src/main/resources/archetype-resources/pom.xml +++ b/flink-quickstart/flink-quickstart-scala/src/main/resources/archetype-resources/pom.xml @@ -93,7 +93,7 @@ under the License. ${flink.version} - org.slf4j diff --git a/flink-runtime-web/pom.xml b/flink-runtime-web/pom.xml index cf8f30784d09c..14c91c5fd8776 100644 --- a/flink-runtime-web/pom.xml +++ b/flink-runtime-web/pom.xml @@ -61,14 +61,13 @@ under the License. - org.javassist - javassist + org.apache.flink + flink-shaded-guava - com.google.guava - guava - ${guava.version} + org.javassist + javassist - - com.google.guava - guava - ${guava.version} - - org.scala-lang scala-library diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationNamespaceSerializerProxy.java b/flink-runtime/src/main/java/org/apache/flink/migration/MigrationNamespaceSerializerProxy.java deleted file mode 100644 index c4e23ca27ffe4..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationNamespaceSerializerProxy.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.migration; - -import org.apache.flink.api.common.typeutils.CompatibilityResult; -import org.apache.flink.api.common.typeutils.ParameterlessTypeSerializerConfig; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; - -import java.io.IOException; -import java.io.Serializable; - -/** - * The purpose of this class is the be filled in as a placeholder for the namespace serializer when migrating from - * Flink 1.1 savepoint (which did not include the namespace serializer) to Flink 1.2 (which always must include a - * (non-null) namespace serializer. This is then replaced as soon as the user is re-registering her state again for - * the first run under Flink 1.2 and provides again the real namespace serializer. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class MigrationNamespaceSerializerProxy extends TypeSerializer { - - public static final MigrationNamespaceSerializerProxy INSTANCE = new MigrationNamespaceSerializerProxy(); - - private static final long serialVersionUID = -707800010807094491L; - - private MigrationNamespaceSerializerProxy() { - } - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public TypeSerializer duplicate() { - return this; - } - - @Override - public Serializable createInstance() { - throw new UnsupportedOperationException( - "This is just a proxy used during migration until the real type serializer is provided by the user."); - } - - @Override - public Serializable copy(Serializable from) { - throw new UnsupportedOperationException( - "This is just a proxy used during migration until the real type serializer is provided by the user."); - } - - @Override - public Serializable copy(Serializable from, Serializable reuse) { - throw new UnsupportedOperationException( - "This is just a proxy used during migration until the real type serializer is provided by the user."); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public void serialize(Serializable record, DataOutputView target) throws IOException { - throw new UnsupportedOperationException( - "This is just a proxy used during migration until the real type serializer is provided by the user."); - } - - @Override - public Serializable deserialize(DataInputView source) throws IOException { - throw new UnsupportedOperationException( - "This is just a proxy used during migration until the real type serializer is provided by the user."); - } - - @Override - public Serializable deserialize(Serializable reuse, DataInputView source) throws IOException { - throw new UnsupportedOperationException( - "This is just a proxy used during migration until the real type serializer is provided by the user."); - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - throw new UnsupportedOperationException( - "This is just a proxy used during migration until the real type serializer is provided by the user."); - } - - @Override - public TypeSerializerConfigSnapshot snapshotConfiguration() { - return new ParameterlessTypeSerializerConfig(getClass().getCanonicalName()); - } - - @Override - public CompatibilityResult ensureCompatibility(TypeSerializerConfigSnapshot configSnapshot) { - // always 
assume compatibility since we're just a proxy for migration - return CompatibilityResult.compatible(); - } - - @Override - public boolean equals(Object obj) { - return obj instanceof MigrationNamespaceSerializerProxy; - } - - @Override - public boolean canEqual(Object obj) { - return true; - } - - @Override - public int hashCode() { - return 42; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/api/common/state/ListStateDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/migration/api/common/state/ListStateDescriptor.java deleted file mode 100644 index 5196d2dad6b96..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/api/common/state/ListStateDescriptor.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.api.common.state; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.StateBinder; -import org.apache.flink.api.common.state.StateDescriptor; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.common.typeutils.TypeSerializer; - -/** - * The old version of the {@link org.apache.flink.api.common.state.ListStateDescriptor}, retained for - * serialization backwards compatibility. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Internal -@Deprecated -@SuppressWarnings("deprecation") -public class ListStateDescriptor extends StateDescriptor, T> { - private static final long serialVersionUID = 1L; - - /** - * Creates a new {@code ListStateDescriptor} with the given name and list element type. - * - *

If this constructor fails (because it is not possible to describe the type via a class), - * consider using the {@link #ListStateDescriptor(String, TypeInformation)} constructor. - * - * @param name The (unique) name for the state. - * @param typeClass The type of the values in the state. - */ - public ListStateDescriptor(String name, Class typeClass) { - super(name, typeClass, null); - } - - /** - * Creates a new {@code ListStateDescriptor} with the given name and list element type. - * - * @param name The (unique) name for the state. - * @param typeInfo The type of the values in the state. - */ - public ListStateDescriptor(String name, TypeInformation typeInfo) { - super(name, typeInfo, null); - } - - /** - * Creates a new {@code ListStateDescriptor} with the given name and list element type. - * - * @param name The (unique) name for the state. - * @param typeSerializer The type serializer for the list values. - */ - public ListStateDescriptor(String name, TypeSerializer typeSerializer) { - super(name, typeSerializer, null); - } - - // ------------------------------------------------------------------------ - - @Override - public ListState bind(StateBinder stateBinder) throws Exception { - throw new IllegalStateException("Cannot bind states with a legacy state descriptor."); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - ListStateDescriptor that = (ListStateDescriptor) o; - - return serializer.equals(that.serializer) && name.equals(that.name); - - } - - @Override - public int hashCode() { - int result = serializer.hashCode(); - result = 31 * result + name.hashCode(); - return result; - } - - @Override - public String toString() { - return "ListStateDescriptor{" + - "serializer=" + serializer + - '}'; - } - - @Override - public org.apache.flink.api.common.state.StateDescriptor.Type getType() { - return org.apache.flink.api.common.state.StateDescriptor.Type.LIST; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/KeyGroupState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/KeyGroupState.java deleted file mode 100644 index 0b25e0855a563..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/KeyGroupState.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.migration.runtime.checkpoint; - -import org.apache.flink.migration.runtime.state.StateHandle; -import org.apache.flink.migration.util.SerializedValue; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.Serializable; - -/** - * Simple container class which contains the serialized state handle for a key group. - * - * The key group state handle is kept in serialized form because it can contain user code classes - * which might not be available on the JobManager. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class KeyGroupState implements Serializable { - private static final long serialVersionUID = -5926696455438467634L; - - private static final Logger LOG = LoggerFactory.getLogger(KeyGroupState.class); - - private final SerializedValue> keyGroupState; - - private final long stateSize; - - private final long duration; - - public KeyGroupState(SerializedValue> keyGroupState, long stateSize, long duration) { - this.keyGroupState = keyGroupState; - - this.stateSize = stateSize; - - this.duration = duration; - } - - public SerializedValue> getKeyGroupState() { - return keyGroupState; - } - - public long getDuration() { - return duration; - } - - public long getStateSize() { - return stateSize; - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof KeyGroupState) { - KeyGroupState other = (KeyGroupState) obj; - - return keyGroupState.equals(other.keyGroupState) && stateSize == other.stateSize && - duration == other.duration; - } else { - return false; - } - } - - @Override - public int hashCode() { - return (int) (this.stateSize ^ this.stateSize >>> 32) + - 31 * ((int) (this.duration ^ this.duration >>> 32) + - 31 * keyGroupState.hashCode()); - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/SubtaskState.java deleted file mode 100644 index d42d1467c705c..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/SubtaskState.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.checkpoint; - -import org.apache.flink.migration.runtime.state.StateHandle; -import org.apache.flink.migration.util.SerializedValue; - -import java.io.Serializable; - -import static org.apache.flink.util.Preconditions.checkNotNull; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. 
- */ -@Deprecated -@SuppressWarnings("deprecation") -public class SubtaskState implements Serializable { - - private static final long serialVersionUID = -2394696997971923995L; - - /** The state of the parallel operator */ - private final SerializedValue> state; - - /** - * The state size. This is also part of the deserialized state handle. - * We store it here in order to not deserialize the state handle when - * gathering stats. - */ - private final long stateSize; - - /** The duration of the acknowledged (ack timestamp - trigger timestamp). */ - private final long duration; - - public SubtaskState( - SerializedValue> state, - long stateSize, - long duration) { - - this.state = checkNotNull(state, "State"); - // Sanity check and don't fail checkpoint because of this. - this.stateSize = stateSize >= 0 ? stateSize : 0; - - this.duration = duration; - } - - // -------------------------------------------------------------------------------------------- - - public SerializedValue> getState() { - return state; - } - - public long getStateSize() { - return stateSize; - } - - public long getDuration() { - return duration; - } - - public void discard(ClassLoader userClassLoader) throws Exception { - - } - - // -------------------------------------------------------------------------------------------- - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - else if (o instanceof SubtaskState) { - SubtaskState that = (SubtaskState) o; - return this.state.equals(that.state) && stateSize == that.stateSize && - duration == that.duration; - } - else { - return false; - } - } - - @Override - public int hashCode() { - return (int) (this.stateSize ^ this.stateSize >>> 32) + - 31 * ((int) (this.duration ^ this.duration >>> 32) + - 31 * state.hashCode()); - } - - @Override - public String toString() { - return String.format("SubtaskState(Size: %d, Duration: %d, State: %s)", stateSize, duration, state); - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/TaskState.java deleted file mode 100644 index c0a7b2d39733b..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/TaskState.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.migration.runtime.checkpoint; - -import org.apache.flink.migration.runtime.state.StateHandle; -import org.apache.flink.migration.util.SerializedValue; -import org.apache.flink.runtime.jobgraph.JobVertexID; - -import java.io.Serializable; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.Set; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class TaskState implements Serializable { - - private static final long serialVersionUID = -4845578005863201810L; - - private final JobVertexID jobVertexID; - - /** Map of task states which can be accessed by their sub task index */ - private final Map subtaskStates; - - /** Map of key-value states which can be accessed by their key group index */ - private final Map kvStates; - - /** Parallelism of the operator when it was checkpointed */ - private final int parallelism; - - public TaskState(JobVertexID jobVertexID, int parallelism) { - this.jobVertexID = jobVertexID; - - this.subtaskStates = new HashMap<>(parallelism); - - this.kvStates = new HashMap<>(); - - this.parallelism = parallelism; - } - - public JobVertexID getJobVertexID() { - return jobVertexID; - } - - public void putState(int subtaskIndex, SubtaskState subtaskState) { - if (subtaskIndex < 0 || subtaskIndex >= parallelism) { - throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex + - " exceeds the maximum number of sub tasks " + subtaskStates.size()); - } else { - subtaskStates.put(subtaskIndex, subtaskState); - } - } - - public SubtaskState getState(int subtaskIndex) { - if (subtaskIndex < 0 || subtaskIndex >= parallelism) { - throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex + - " exceeds the maximum number of sub tasks " + subtaskStates.size()); - } else { - return subtaskStates.get(subtaskIndex); - } - } - - public Collection getStates() { - return subtaskStates.values(); - } - - public Map getSubtaskStatesById() { - return subtaskStates; - } - - public long getStateSize() { - long result = 0L; - - for (SubtaskState subtaskState : subtaskStates.values()) { - result += subtaskState.getStateSize(); - } - - for (KeyGroupState keyGroupState : kvStates.values()) { - result += keyGroupState.getStateSize(); - } - - return result; - } - - public int getNumberCollectedStates() { - return subtaskStates.size(); - } - - public int getParallelism() { - return parallelism; - } - - public void putKvState(int keyGroupId, KeyGroupState keyGroupState) { - kvStates.put(keyGroupId, keyGroupState); - } - - public KeyGroupState getKvState(int keyGroupId) { - return kvStates.get(keyGroupId); - } - - /** - * Retrieve the set of key-value state key groups specified by the given key group partition set. - * The key groups are returned as a map where the key group index maps to the serialized state - * handle of the key group. - * - * @param keyGroupPartition Set of key group indices - * @return Map of serialized key group state handles indexed by their key group index. 
- */ - public Map>> getUnwrappedKvStates(Set keyGroupPartition) { - HashMap>> result = new HashMap<>(keyGroupPartition.size()); - - for (Integer keyGroupId : keyGroupPartition) { - KeyGroupState keyGroupState = kvStates.get(keyGroupId); - - if (keyGroupState != null) { - result.put(keyGroupId, kvStates.get(keyGroupId).getKeyGroupState()); - } - } - - return result; - } - - public int getNumberCollectedKvStates() { - return kvStates.size(); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof TaskState) { - TaskState other = (TaskState) obj; - - return jobVertexID.equals(other.jobVertexID) && parallelism == other.parallelism && - subtaskStates.equals(other.subtaskStates) && kvStates.equals(other.kvStates); - } else { - return false; - } - } - - @Override - public int hashCode() { - return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates, kvStates); - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java deleted file mode 100644 index 7888d2fe6774d..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.checkpoint.savepoint; - -import org.apache.flink.migration.runtime.checkpoint.TaskState; -import org.apache.flink.runtime.checkpoint.MasterState; -import org.apache.flink.runtime.checkpoint.OperatorState; -import org.apache.flink.runtime.checkpoint.savepoint.Savepoint; -import org.apache.flink.util.Preconditions; - -import java.util.Collection; - -/** - * Savepoint version 0. - * - *

This format was introduced with Flink 1.1.0. - */ -@SuppressWarnings("deprecation") -public class SavepointV0 implements Savepoint { - - /** The savepoint version. */ - public static final int VERSION = 0; - - /** The checkpoint ID */ - private final long checkpointId; - - /** The task states */ - private final Collection taskStates; - - public SavepointV0(long checkpointId, Collection taskStates) { - this.checkpointId = checkpointId; - this.taskStates = Preconditions.checkNotNull(taskStates, "Task States"); - } - - @Override - public int getVersion() { - return VERSION; - } - - @Override - public long getCheckpointId() { - return checkpointId; - } - - @Override - public Collection getTaskStates() { - // since checkpoints are never deserialized into this format, - // this method should never be called - throw new UnsupportedOperationException(); - } - - @Override - public Collection getMasterStates() { - // since checkpoints are never deserialized into this format, - // this method should never be called - throw new UnsupportedOperationException(); - } - - @Override - public Collection getOperatorStates() { - return null; - } - - @Override - public void dispose() throws Exception { - //NOP - } - - - public Collection getOldTaskStates() { - return taskStates; - } - - @Override - public String toString() { - return "Savepoint(version=" + VERSION + ")"; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - - if (o == null || getClass() != o.getClass()) { - return false; - } - - SavepointV0 that = (SavepointV0) o; - return checkpointId == that.checkpointId && getTaskStates().equals(that.getTaskStates()); - } - - @Override - public int hashCode() { - int result = (int) (checkpointId ^ (checkpointId >>> 32)); - result = 31 * result + taskStates.hashCode(); - return result; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0Serializer.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0Serializer.java deleted file mode 100644 index d285906262565..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0Serializer.java +++ /dev/null @@ -1,425 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.migration.runtime.checkpoint.savepoint; - -import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.core.fs.Path; -import org.apache.flink.migration.runtime.checkpoint.KeyGroupState; -import org.apache.flink.migration.runtime.checkpoint.SubtaskState; -import org.apache.flink.migration.runtime.checkpoint.TaskState; -import org.apache.flink.migration.runtime.state.AbstractStateBackend; -import org.apache.flink.migration.runtime.state.KvStateSnapshot; -import org.apache.flink.migration.runtime.state.StateHandle; -import org.apache.flink.migration.runtime.state.filesystem.AbstractFileStateHandle; -import org.apache.flink.migration.runtime.state.memory.SerializedStateHandle; -import org.apache.flink.migration.state.MigrationKeyGroupStateHandle; -import org.apache.flink.migration.state.MigrationStreamStateHandle; -import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskState; -import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskStateList; -import org.apache.flink.migration.util.SerializedValue; -import org.apache.flink.runtime.checkpoint.savepoint.SavepointSerializer; -import org.apache.flink.runtime.checkpoint.savepoint.SavepointV2; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.CheckpointStreamFactory; -import org.apache.flink.runtime.state.KeyGroupRangeOffsets; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.MultiStreamStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.filesystem.FileStateHandle; -import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; -import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory; -import org.apache.flink.util.IOUtils; -import org.apache.flink.util.InstantiationUtil; -import org.apache.flink.util.Preconditions; - -import java.io.DataInputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - *

- *
In contrast to previous savepoint versions, this serializer makes sure - * that no default Java serialization is used for serialization. Therefore, we - * don't rely on any involved Java classes to stay the same. - */ -@SuppressWarnings("deprecation") -public class SavepointV0Serializer implements SavepointSerializer { - - public static final SavepointV0Serializer INSTANCE = new SavepointV0Serializer(); - private static final StreamStateHandle SIGNAL_0 = new ByteStreamStateHandle("SIGNAL_0", new byte[]{0}); - private static final StreamStateHandle SIGNAL_1 = new ByteStreamStateHandle("SIGNAL_1", new byte[]{1}); - - private static final int MAX_SIZE = 4 * 1024 * 1024; - - private SavepointV0Serializer() { - } - - - @Override - public void serialize(SavepointV2 savepoint, DataOutputStream dos) throws IOException { - throw new UnsupportedOperationException("This serializer is read-only and only exists for backwards compatibility"); - } - - @Override - public SavepointV2 deserialize(DataInputStream dis, ClassLoader userClassLoader) throws IOException { - - long checkpointId = dis.readLong(); - - // Task states - int numTaskStates = dis.readInt(); - List taskStates = new ArrayList<>(numTaskStates); - - for (int i = 0; i < numTaskStates; i++) { - JobVertexID jobVertexId = new JobVertexID(dis.readLong(), dis.readLong()); - int parallelism = dis.readInt(); - - // Add task state - TaskState taskState = new TaskState(jobVertexId, parallelism); - taskStates.add(taskState); - - // Sub task states - int numSubTaskStates = dis.readInt(); - for (int j = 0; j < numSubTaskStates; j++) { - int subtaskIndex = dis.readInt(); - - SerializedValue> serializedValue = readSerializedValueStateHandle(dis); - - long stateSize = dis.readLong(); - long duration = dis.readLong(); - - SubtaskState subtaskState = new SubtaskState( - serializedValue, - stateSize, - duration); - - taskState.putState(subtaskIndex, subtaskState); - } - - // Key group states - int numKvStates = dis.readInt(); - for (int j = 0; j < numKvStates; j++) { - int keyGroupIndex = dis.readInt(); - - SerializedValue> serializedValue = readSerializedValueStateHandle(dis); - - long stateSize = dis.readLong(); - long duration = dis.readLong(); - - KeyGroupState keyGroupState = new KeyGroupState( - serializedValue, - stateSize, - duration); - - taskState.putKvState(keyGroupIndex, keyGroupState); - } - } - - try { - - return convertSavepoint(taskStates, userClassLoader, checkpointId); - } catch (Exception e) { - - throw new IOException(e); - } - } - - private static SerializedValue> readSerializedValueStateHandle(DataInputStream dis) - throws IOException { - - int length = dis.readInt(); - - SerializedValue> serializedValue; - if (length == -1) { - serializedValue = new SerializedValue<>(null); - } else { - byte[] serializedData = new byte[length]; - dis.readFully(serializedData, 0, length); - serializedValue = SerializedValue.fromBytes(serializedData); - } - - return serializedValue; - } - - private SavepointV2 convertSavepoint( - List taskStates, - ClassLoader userClassLoader, - long checkpointID) throws Exception { - - List newTaskStates = new ArrayList<>(taskStates.size()); - - for (TaskState taskState : taskStates) { - newTaskStates.add(convertTaskState(taskState, userClassLoader, checkpointID)); - } - - return new SavepointV2(checkpointID, newTaskStates); - } - - private org.apache.flink.runtime.checkpoint.TaskState convertTaskState( - TaskState taskState, - ClassLoader userClassLoader, - long checkpointID) throws Exception { - - JobVertexID 
jobVertexID = taskState.getJobVertexID(); - int parallelism = taskState.getParallelism(); - int chainLength = determineOperatorChainLength(taskState, userClassLoader); - - org.apache.flink.runtime.checkpoint.TaskState newTaskState = - new org.apache.flink.runtime.checkpoint.TaskState( - jobVertexID, - parallelism, - parallelism, - chainLength); - - if (chainLength > 0) { - - Map subtaskStates = taskState.getSubtaskStatesById(); - - for (Map.Entry subtaskState : subtaskStates.entrySet()) { - int parallelInstanceIdx = subtaskState.getKey(); - newTaskState.putState(parallelInstanceIdx, convertSubtaskState( - subtaskState.getValue(), - parallelInstanceIdx, - userClassLoader, - checkpointID)); - } - } - - return newTaskState; - } - - private org.apache.flink.runtime.checkpoint.SubtaskState convertSubtaskState( - SubtaskState subtaskState, - int parallelInstanceIdx, - ClassLoader userClassLoader, - long checkpointID) throws Exception { - - SerializedValue> serializedValue = subtaskState.getState(); - - StreamTaskStateList stateList = (StreamTaskStateList) serializedValue.deserializeValue(userClassLoader); - StreamTaskState[] streamTaskStates = stateList.getState(userClassLoader); - - List newChainStateList = Arrays.asList(new StreamStateHandle[streamTaskStates.length]); - KeyGroupsStateHandle newKeyedState = null; - - for (int chainIdx = 0; chainIdx < streamTaskStates.length; ++chainIdx) { - - StreamTaskState streamTaskState = streamTaskStates[chainIdx]; - if (streamTaskState == null) { - continue; - } - - newChainStateList.set(chainIdx, convertOperatorAndFunctionState(streamTaskState)); - HashMap> oldKeyedState = streamTaskState.getKvStates(); - - if (null != oldKeyedState) { - Preconditions.checkState(null == newKeyedState, "Found more than one keyed state in chain"); - newKeyedState = convertKeyedBackendState(oldKeyedState, parallelInstanceIdx, checkpointID); - } - } - - ChainedStateHandle newChainedState = new ChainedStateHandle<>(newChainStateList); - ChainedStateHandle nopChain = - new ChainedStateHandle<>(Arrays.asList(new OperatorStateHandle[newChainedState.getLength()])); - - return new org.apache.flink.runtime.checkpoint.SubtaskState( - newChainedState, - nopChain, - nopChain, - newKeyedState, - null); - } - - /** - * This is public so that we can use it when restoring a legacy snapshot - * in {@code AbstractStreamOperatorTestHarness}. - */ - public static StreamStateHandle convertOperatorAndFunctionState(StreamTaskState streamTaskState) throws Exception { - - List mergeStateHandles = new ArrayList<>(4); - - StateHandle functionState = streamTaskState.getFunctionState(); - StateHandle operatorState = streamTaskState.getOperatorState(); - - if (null != functionState) { - mergeStateHandles.add(SIGNAL_1); - mergeStateHandles.add(convertStateHandle(functionState)); - } else { - mergeStateHandles.add(SIGNAL_0); - } - - if (null != operatorState) { - mergeStateHandles.add(convertStateHandle(operatorState)); - } - - return new MigrationStreamStateHandle(new MultiStreamStateHandle(mergeStateHandles)); - } - - /** - * This is public so that we can use it when restoring a legacy snapshot - * in {@code AbstractStreamOperatorTestHarness}. 
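Aside for readers of the removed converter, not part of the patch itself: the merged stream built by convertOperatorAndFunctionState(...) above is a one-byte signal followed by optional payloads. A minimal plain-Java sketch of that layout, where the stream name, class name, and comments are illustrative assumptions rather than Flink API:

import java.io.IOException;
import java.io.InputStream;

// Illustrative sketch only. `in` stands for the concatenated stream behind the
// MultiStreamStateHandle returned above; nothing here is a real Flink type.
final class MergedOperatorStateLayout {
    static void skim(InputStream in) throws IOException {
        int signal = in.read();   // first byte: 1 (SIGNAL_1) = function state follows, 0 (SIGNAL_0) = none
        if (signal == 1) {
            // the converted function state bytes come next
        }
        // any remaining bytes are the converted operator state (absent if it was null)
    }
}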
- */ - public static KeyGroupsStateHandle convertKeyedBackendState( - HashMap> oldKeyedState, - int parallelInstanceIdx, - long checkpointID) throws Exception { - - if (null != oldKeyedState) { - - CheckpointStreamFactory checkpointStreamFactory = new MemCheckpointStreamFactory(MAX_SIZE); - - CheckpointStreamFactory.CheckpointStateOutputStream keyedStateOut = - checkpointStreamFactory.createCheckpointStateOutputStream(checkpointID, 0L); - - try { - final long offset = keyedStateOut.getPos(); - - InstantiationUtil.serializeObject(keyedStateOut, oldKeyedState); - StreamStateHandle streamStateHandle = keyedStateOut.closeAndGetHandle(); - keyedStateOut = null; // makes IOUtils.closeQuietly(...) ignore this - - if (null != streamStateHandle) { - KeyGroupRangeOffsets keyGroupRangeOffsets = - new KeyGroupRangeOffsets(parallelInstanceIdx, parallelInstanceIdx, new long[]{offset}); - - return new MigrationKeyGroupStateHandle(keyGroupRangeOffsets, streamStateHandle); - } - } finally { - IOUtils.closeQuietly(keyedStateOut); - } - } - return null; - } - - private int determineOperatorChainLength( - TaskState taskState, - ClassLoader userClassLoader) throws IOException, ClassNotFoundException { - - Collection subtaskStates = taskState.getStates(); - - if (subtaskStates == null || subtaskStates.isEmpty()) { - return 0; - } - - SubtaskState firstSubtaskState = subtaskStates.iterator().next(); - Object toCastTaskStateList = firstSubtaskState.getState().deserializeValue(userClassLoader); - - if (toCastTaskStateList instanceof StreamTaskStateList) { - StreamTaskStateList taskStateList = (StreamTaskStateList) toCastTaskStateList; - StreamTaskState[] streamTaskStates = taskStateList.getState(userClassLoader); - - return streamTaskStates.length; - } - return 0; - } - - /** - * This is public so that we can use it when restoring a legacy snapshot - * in {@code AbstractStreamOperatorTestHarness}. 
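Aside, not part of the patch: the legacy binary layout that deserialize(...) above consumes and serializeOld(...) further below produces, condensed into a JDK-only skip-through reader. The reads mirror those visible in the removed code; the class and method names are made up for illustration.

import java.io.DataInputStream;
import java.io.IOException;

// Illustrative sketch of the Flink 1.1 savepoint layout handled by SavepointV0Serializer.
final class LegacySavepointLayout {
    static void skim(DataInputStream dis) throws IOException {
        dis.readLong();                          // checkpoint id
        int numTaskStates = dis.readInt();       // number of task states
        for (int t = 0; t < numTaskStates; t++) {
            dis.readLong();                      // JobVertexID, lower part
            dis.readLong();                      // JobVertexID, upper part
            dis.readInt();                       // parallelism
            int numSubtasks = dis.readInt();     // sub task states
            for (int s = 0; s < numSubtasks; s++) {
                dis.readInt();                   // subtask index
                skipSerializedValue(dis);        // length-prefixed serialized state handle
                dis.readLong();                  // state size
                dis.readLong();                  // duration
            }
            int numKvStates = dis.readInt();     // key group states
            for (int k = 0; k < numKvStates; k++) {
                dis.readInt();                   // key group index
                skipSerializedValue(dis);
                dis.readLong();                  // state size
                dis.readLong();                  // duration
            }
        }
    }

    private static void skipSerializedValue(DataInputStream dis) throws IOException {
        int length = dis.readInt();              // -1 encodes a null value
        if (length > 0) {
            dis.skipBytes(length);               // fine for a sketch; a robust reader would loop
        }
    }
}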
- */ - public static StreamStateHandle convertStateHandle(StateHandle oldStateHandle) throws Exception { - if (oldStateHandle instanceof AbstractFileStateHandle) { - Path path = ((AbstractFileStateHandle) oldStateHandle).getFilePath(); - return new FileStateHandle(path, oldStateHandle.getStateSize()); - } else if (oldStateHandle instanceof SerializedStateHandle) { - byte[] data = ((SerializedStateHandle) oldStateHandle).getSerializedData(); - return new ByteStreamStateHandle(String.valueOf(System.identityHashCode(data)), data); - } else if (oldStateHandle instanceof org.apache.flink.migration.runtime.state.memory.ByteStreamStateHandle) { - byte[] data = - ((org.apache.flink.migration.runtime.state.memory.ByteStreamStateHandle) oldStateHandle).getData(); - return new ByteStreamStateHandle(String.valueOf(System.identityHashCode(data)), data); - } else if (oldStateHandle instanceof AbstractStateBackend.DataInputViewHandle) { - return convertStateHandle( - ((AbstractStateBackend.DataInputViewHandle) oldStateHandle).getStreamStateHandle()); - } - throw new IllegalArgumentException("Unknown state handle type: " + oldStateHandle); - } - - @VisibleForTesting - public void serializeOld(SavepointV0 savepoint, DataOutputStream dos) throws IOException { - dos.writeLong(savepoint.getCheckpointId()); - - Collection taskStates = savepoint.getOldTaskStates(); - dos.writeInt(taskStates.size()); - - for (org.apache.flink.migration.runtime.checkpoint.TaskState taskState : savepoint.getOldTaskStates()) { - // Vertex ID - dos.writeLong(taskState.getJobVertexID().getLowerPart()); - dos.writeLong(taskState.getJobVertexID().getUpperPart()); - - // Parallelism - int parallelism = taskState.getParallelism(); - dos.writeInt(parallelism); - - // Sub task states - dos.writeInt(taskState.getNumberCollectedStates()); - - for (int i = 0; i < parallelism; i++) { - SubtaskState subtaskState = taskState.getState(i); - - if (subtaskState != null) { - dos.writeInt(i); - - SerializedValue serializedValue = subtaskState.getState(); - if (serializedValue == null) { - dos.writeInt(-1); // null - } else { - byte[] serialized = serializedValue.getByteArray(); - dos.writeInt(serialized.length); - dos.write(serialized, 0, serialized.length); - } - - dos.writeLong(subtaskState.getStateSize()); - dos.writeLong(subtaskState.getDuration()); - } - } - - // Key group states - dos.writeInt(taskState.getNumberCollectedKvStates()); - - for (int i = 0; i < parallelism; i++) { - KeyGroupState keyGroupState = taskState.getKvState(i); - - if (keyGroupState != null) { - dos.write(i); - - SerializedValue serializedValue = keyGroupState.getKeyGroupState(); - if (serializedValue == null) { - dos.writeInt(-1); // null - } else { - byte[] serialized = serializedValue.getByteArray(); - dos.writeInt(serialized.length); - dos.write(serialized, 0, serialized.length); - } - - dos.writeLong(keyGroupState.getStateSize()); - dos.writeLong(keyGroupState.getDuration()); - } - } - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/AbstractCloseableHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/AbstractCloseableHandle.java deleted file mode 100644 index 775b304748007..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/AbstractCloseableHandle.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state; - -import java.io.Closeable; -import java.io.IOException; -import java.io.Serializable; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; - -/** - * A simple base for closable handles. - * - * Offers to register a stream (or other closable object) that close calls are delegated to if - * the handle is closed or was already closed. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public abstract class AbstractCloseableHandle implements Closeable, Serializable { - - /** Serial Version UID must be constant to maintain format compatibility */ - private static final long serialVersionUID = 1L; - - /** To atomically update the "closable" field without needing to add a member class like "AtomicBoolean */ - private static final AtomicIntegerFieldUpdater CLOSER = - AtomicIntegerFieldUpdater.newUpdater(AbstractCloseableHandle.class, "isClosed"); - - // ------------------------------------------------------------------------ - - /** The closeable to close if this handle is closed late */ - private transient volatile Closeable toClose; - - /** Flag to remember if this handle was already closed */ - @SuppressWarnings("unused") // this field is actually updated, but via the "CLOSER" updater - private transient volatile int isClosed; - - // ------------------------------------------------------------------------ - - protected final void registerCloseable(Closeable toClose) throws IOException { - if (toClose == null) { - return; - } - - // NOTE: The order of operations matters here: - // (1) first setting the closeable - // (2) checking the flag. - // Because the order in the {@link #close()} method is the opposite, and - // both variables are volatile (reordering barriers), we can be sure that - // one of the methods always notices the effect of a concurrent call to the - // other method. - - this.toClose = toClose; - - // check if we were closed early - if (this.isClosed != 0) { - toClose.close(); - throw new IOException("handle is closed"); - } - } - - /** - * Closes the handle. - * - *

If a "Closeable" has been registered via {@link #registerCloseable(Closeable)}, - * then this will be closes. - * - *

If any "Closeable" will be registered via {@link #registerCloseable(Closeable)} in the future, - * it will immediately be closed and that method will throw an exception. - * - * @throws IOException Exceptions occurring while closing an already registered {@code Closeable} - * are forwarded. - * - * @see #registerCloseable(Closeable) - */ - @Override - public final void close() throws IOException { - // NOTE: The order of operations matters here: - // (1) first setting the closed flag - // (2) checking whether there is already a closeable - // Because the order in the {@link #registerCloseable(Closeable)} method is the opposite, and - // both variables are volatile (reordering barriers), we can be sure that - // one of the methods always notices the effect of a concurrent call to the - // other method. - - if (CLOSER.compareAndSet(this, 0, 1)) { - final Closeable toClose = this.toClose; - if (toClose != null) { - this.toClose = null; - toClose.close(); - } - } - } - - /** - * Checks whether this handle has been closed. - * - * @return True is the handle is closed, false otherwise. - */ - public boolean isClosed() { - return isClosed != 0; - } - - /** - * This method checks whether the handle is closed and throws an exception if it is closed. - * If the handle is not closed, this method does nothing. - * - * @throws IOException Thrown, if the handle has been closed. - */ - public void ensureNotClosed() throws IOException { - if (isClosed != 0) { - throw new IOException("handle is closed"); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/AbstractStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/AbstractStateBackend.java deleted file mode 100644 index 7c53c406270c2..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/AbstractStateBackend.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state; - -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; - -import java.io.IOException; -import java.io.Serializable; - -/** - * A state backend defines how state is stored and snapshotted during checkpoints. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public abstract class AbstractStateBackend implements Serializable { - - private static final long serialVersionUID = 4620413814639220247L; - - /** - * Simple state handle that resolved a {@link DataInputView} from a StreamStateHandle. 
- */ - public static final class DataInputViewHandle implements StateHandle { - - private static final long serialVersionUID = 2891559813513532079L; - - private final StreamStateHandle stream; - - private DataInputViewHandle(StreamStateHandle stream) { - this.stream = stream; - } - - public StreamStateHandle getStreamStateHandle() { - return stream; - } - - @Override - public DataInputView getState(ClassLoader userCodeClassLoader) throws Exception { - return new DataInputViewStreamWrapper(stream.getState(userCodeClassLoader)); - } - - @Override - public void discardState() throws Exception { - throw new UnsupportedOperationException(); - } - - @Override - public long getStateSize() throws Exception { - return stream.getStateSize(); - } - - @Override - public void close() throws IOException { - throw new UnsupportedOperationException(); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StateHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StateHandle.java deleted file mode 100644 index fd3917f2b735f..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StateHandle.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state; - -/** - * StateHandle is a general handle interface meant to abstract operator state fetching. - * A StateHandle implementation can for example include the state itself in cases where the state - * is lightweight or fetching it lazily from some external storage when the state is too large. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public interface StateHandle extends StateObject { - - /** - * This retrieves and return the state represented by the handle. - * - * @param userCodeClassLoader Class loader for deserializing user code specific classes - * - * @return The state represented by the handle. - * @throws Exception Thrown, if the state cannot be fetched. - */ - T getState(ClassLoader userCodeClassLoader) throws Exception; -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StateObject.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StateObject.java deleted file mode 100644 index 59bc0ca948e16..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StateObject.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state; - -/** - * Base of all types that represent checkpointed state. Specializations are for - * example {@link StateHandle StateHandles} (directly resolve to state) and - * {@link KvStateSnapshot key/value state snapshots}. - * - *

State objects define how to:
- * <ul>
- *     <li>Discard State: The {@link #discardState()} method defines how state is permanently
- * disposed/deleted. After that method call, state may not be recoverable any more.</li>
- *     <li>Close the current state access: The {@link #close()} method defines how to
- * stop the current access or recovery to the state. Called for example when an operation is
- * canceled during recovery.</li>
- * </ul>
- * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public interface StateObject extends java.io.Closeable, java.io.Serializable { - - /** - * Discards the state referred to by this handle, to free up resources in - * the persistent storage. This method is called when the handle will not be - * used any more. - */ - void discardState() throws Exception; - - /** - * Returns the size of the state in bytes. - * - *

If the the size is not known, return {@code 0}. - * - * @return Size of the state in bytes. - * @throws Exception If the operation fails during size retrieval. - */ - long getStateSize() throws Exception; -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/AbstractFileStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/AbstractFileStateHandle.java deleted file mode 100644 index a522a95b2c552..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/AbstractFileStateHandle.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.core.fs.FileSystem; -import org.apache.flink.core.fs.Path; -import org.apache.flink.migration.runtime.state.AbstractCloseableHandle; -import org.apache.flink.migration.runtime.state.StateObject; -import org.apache.flink.util.FileUtils; - -import java.io.IOException; - -import static org.apache.flink.util.Preconditions.checkNotNull; - -/** - * Base class for state that is stored in a file. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public abstract class AbstractFileStateHandle extends AbstractCloseableHandle implements StateObject { - - private static final long serialVersionUID = 350284443258002355L; - - /** The path to the file in the filesystem, fully describing the file system */ - private final Path filePath; - - /** Cached file system handle */ - private transient FileSystem fs; - - /** - * Creates a new file state for the given file path. - * - * @param filePath The path to the file that stores the state. - */ - protected AbstractFileStateHandle(Path filePath) { - this.filePath = checkNotNull(filePath); - } - - /** - * Gets the path where this handle's state is stored. - * @return The path where this handle's state is stored. - */ - public Path getFilePath() { - return filePath; - } - - /** - * Discard the state by deleting the file that stores the state. If the parent directory - * of the state is empty after deleting the state file, it is also deleted. - * - * @throws Exception Thrown, if the file deletion (not the directory deletion) fails. - */ - @Override - public void discardState() throws Exception { - getFileSystem().delete(filePath, false); - - try { - FileUtils.deletePathIfEmpty(getFileSystem(), filePath.getParent()); - } catch (Exception ignored) {} - } - - /** - * Gets the file system that stores the file state. - * @return The file system that stores the file state. - * @throws IOException Thrown if the file system cannot be accessed. 
- */ - protected FileSystem getFileSystem() throws IOException { - if (fs == null) { - fs = FileSystem.get(filePath.toUri()); - } - return fs; - } - - /** - * Returns the file size in bytes. - * - * @return The file size in bytes. - * @throws IOException Thrown if the file system cannot be accessed. - */ - protected long getFileSize() throws IOException { - return getFileSystem().getFileStatus(filePath).getLen(); - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/AbstractFsStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/AbstractFsStateSnapshot.java deleted file mode 100644 index 7099c617df905..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/AbstractFsStateSnapshot.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.api.common.state.State; -import org.apache.flink.api.common.state.StateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FileSystem; -import org.apache.flink.core.fs.Path; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.migration.runtime.state.KvStateSnapshot; -import org.apache.flink.migration.runtime.state.memory.AbstractMigrationRestoreStrategy; -import org.apache.flink.migration.runtime.state.memory.MigrationRestoreSnapshot; -import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; -import org.apache.flink.runtime.state.heap.StateTable; - -import java.io.IOException; - -/** - * A snapshot of a heap key/value state stored in a file. - * - * @param The type of the key in the snapshot state. - * @param The type of the namespace in the snapshot state. - * @param The type of the state value. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. 
- */ -@Deprecated -@SuppressWarnings("deprecation") -public abstract class AbstractFsStateSnapshot> - extends AbstractFileStateHandle implements KvStateSnapshot, MigrationRestoreSnapshot { - - private static final long serialVersionUID = 1L; - - /** Key Serializer */ - protected final TypeSerializer keySerializer; - - /** Namespace Serializer */ - protected final TypeSerializer namespaceSerializer; - - /** Serializer for the state value */ - protected final TypeSerializer stateSerializer; - - /** StateDescriptor, for sanity checks */ - protected final SD stateDesc; - - public AbstractFsStateSnapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer, - SD stateDesc, - Path filePath) { - super(filePath); - this.stateDesc = stateDesc; - this.keySerializer = keySerializer; - this.stateSerializer = stateSerializer; - this.namespaceSerializer = namespaceSerializer; - - } - - @Override - public long getStateSize() throws IOException { - return getFileSize(); - } - - public TypeSerializer getKeySerializer() { - return keySerializer; - } - - public TypeSerializer getNamespaceSerializer() { - return namespaceSerializer; - } - - public TypeSerializer getStateSerializer() { - return stateSerializer; - } - - public SD getStateDesc() { - return stateDesc; - } - - @Override - @SuppressWarnings("unchecked") - public StateTable deserialize( - String stateName, - HeapKeyedStateBackend stateBackend) throws IOException { - - final FileSystem fs = getFilePath().getFileSystem(); - try (FSDataInputStream inStream = fs.open(getFilePath())) { - final DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(inStream); - AbstractMigrationRestoreStrategy restoreStrategy = - new AbstractMigrationRestoreStrategy(keySerializer, namespaceSerializer, stateSerializer) { - @Override - protected DataInputView openDataInputView() throws IOException { - return inView; - } - }; - return restoreStrategy.deserialize(stateName, stateBackend); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FileSerializableStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FileSerializableStateHandle.java deleted file mode 100644 index b4a3a730097d1..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FileSerializableStateHandle.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.Path; -import org.apache.flink.migration.runtime.state.StateHandle; -import org.apache.flink.migration.util.MigrationInstantiationUtil; - -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.Serializable; - -/** - * A state handle that points to state stored in a file via Java Serialization. - * - * @param The type of state pointed to by the state handle. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class FileSerializableStateHandle extends AbstractFileStateHandle implements StateHandle { - - private static final long serialVersionUID = -657631394290213622L; - - /** - * Creates a new FileSerializableStateHandle pointing to state at the given file path. - * - * @param filePath The path to the file containing the checkpointed state. - */ - public FileSerializableStateHandle(Path filePath) { - super(filePath); - } - - @Override - @SuppressWarnings("unchecked") - public T getState(ClassLoader classLoader) throws Exception { - ensureNotClosed(); - - try (FSDataInputStream inStream = getFileSystem().open(getFilePath())) { - // make sure any deserialization can be aborted - registerCloseable(inStream); - - ObjectInputStream ois = new MigrationInstantiationUtil.ClassLoaderObjectInputStream(inStream, classLoader); - return (T) ois.readObject(); - } - } - - /** - * Returns the file size in bytes. - * - * @return The file size in bytes. - * @throws IOException Thrown if the file system cannot be accessed. - */ - @Override - public long getStateSize() throws IOException { - return getFileSize(); - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FileStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FileStreamStateHandle.java deleted file mode 100644 index 7444be1c9e754..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FileStreamStateHandle.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.core.fs.Path; -import org.apache.flink.migration.runtime.state.StateHandle; -import org.apache.flink.migration.runtime.state.StreamStateHandle; - -import java.io.IOException; -import java.io.InputStream; -import java.io.Serializable; - -/** - * A state handle that points to state in a file system, accessible as an input stream. 
- * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class FileStreamStateHandle extends AbstractFileStateHandle implements StreamStateHandle { - - private static final long serialVersionUID = -6826990484549987311L; - - /** - * Creates a new FileStreamStateHandle pointing to state at the given file path. - * - * @param filePath The path to the file containing the checkpointed state. - */ - public FileStreamStateHandle(Path filePath) { - super(filePath); - } - - @Override - public InputStream getState(ClassLoader userCodeClassLoader) throws Exception { - ensureNotClosed(); - - InputStream inStream = getFileSystem().open(getFilePath()); - // make sure the state handle is cancelable - registerCloseable(inStream); - - return inStream; - } - - /** - * Returns the file size in bytes. - * - * @return The file size in bytes. - * @throws IOException Thrown if the file system cannot be accessed. - */ - @Override - public long getStateSize() throws IOException { - return getFileSize(); - } - - @Override - public StateHandle toSerializableHandle() { - FileSerializableStateHandle handle = new FileSerializableStateHandle<>(getFilePath()); - - // forward closed status - if (isClosed()) { - try { - handle.close(); - } catch (IOException e) { - // should not happen on a fresh handle, but forward anyways - throw new RuntimeException(e); - } - } - - return handle; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsFoldingState.java deleted file mode 100644 index ec89ab8931d13..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsFoldingState.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.api.common.state.FoldingState; -import org.apache.flink.api.common.state.FoldingStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.fs.Path; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. 
- */ -@Deprecated -@SuppressWarnings("deprecation") -public class FsFoldingState { - public static class Snapshot extends AbstractFsStateSnapshot, FoldingStateDescriptor> { - private static final long serialVersionUID = 1L; - - public Snapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer, - FoldingStateDescriptor stateDescs, - Path filePath) { - super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, filePath); - } - } - -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsListState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsListState.java deleted file mode 100644 index 71404abee9014..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsListState.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.fs.Path; - -import java.util.ArrayList; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class FsListState { - - public static class Snapshot extends AbstractFsStateSnapshot, ListState, ListStateDescriptor> { - private static final long serialVersionUID = 1L; - - public Snapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer> stateSerializer, - ListStateDescriptor stateDescs, - Path filePath) { - super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, filePath); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsReducingState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsReducingState.java deleted file mode 100644 index 153f88c53921c..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsReducingState.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.api.common.state.ReducingState; -import org.apache.flink.api.common.state.ReducingStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.fs.Path; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class FsReducingState { - - public static class Snapshot extends AbstractFsStateSnapshot, ReducingStateDescriptor> { - private static final long serialVersionUID = 1L; - - public Snapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer, - ReducingStateDescriptor stateDescs, - Path filePath) { - super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, filePath); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsStateBackend.java deleted file mode 100644 index d17751028db9d..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsStateBackend.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.core.fs.Path; -import org.apache.flink.migration.runtime.state.AbstractStateBackend; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. 
- */ -@Deprecated -@SuppressWarnings("deprecation") -public class FsStateBackend extends AbstractStateBackend { - - private static final long serialVersionUID = -8191916350224044011L; - - private static final Logger LOG = LoggerFactory.getLogger(FsStateBackend.class); - - /** By default, state smaller than 1024 bytes will not be written to files, but - * will be stored directly with the metadata */ - public static final int DEFAULT_FILE_STATE_THRESHOLD = 1024; - - /** Maximum size of state that is stored with the metadata, rather than in files */ - public static final int MAX_FILE_STATE_THRESHOLD = 1024 * 1024; - - /** Default size for the write buffer */ - private static final int DEFAULT_WRITE_BUFFER_SIZE = 4096; - - - /** The path to the directory for the checkpoint data, including the file system - * description via scheme and optional authority */ - private final Path basePath = null; - - /** State below this size will be stored as part of the metadata, rather than in files */ - private final int fileStateThreshold = 0; -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsValueState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsValueState.java deleted file mode 100644 index d2ae48d23c40c..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/filesystem/FsValueState.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.filesystem; - -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.fs.Path; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. 
- */ -@Deprecated -@SuppressWarnings("deprecation") -public class FsValueState { - - public static class Snapshot extends AbstractFsStateSnapshot, ValueStateDescriptor> { - private static final long serialVersionUID = 1L; - - public Snapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer, - ValueStateDescriptor stateDescs, - Path filePath) { - super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, filePath); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/AbstractMemStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/AbstractMemStateSnapshot.java deleted file mode 100644 index aadfe4eb0c0b9..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/AbstractMemStateSnapshot.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.api.common.state.State; -import org.apache.flink.api.common.state.StateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.migration.runtime.state.KvStateSnapshot; -import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; -import org.apache.flink.runtime.state.heap.StateTable; -import org.apache.flink.runtime.util.DataInputDeserializer; - -import java.io.IOException; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public abstract class AbstractMemStateSnapshot> - implements KvStateSnapshot, MigrationRestoreSnapshot { - - private static final long serialVersionUID = 1L; - - /** Key Serializer */ - protected final TypeSerializer keySerializer; - - /** Namespace Serializer */ - protected final TypeSerializer namespaceSerializer; - - /** Serializer for the state value */ - protected final TypeSerializer stateSerializer; - - /** StateDescriptor, for sanity checks */ - protected final SD stateDesc; - - /** The serialized data of the state key/value pairs */ - private final byte[] data; - - private transient boolean closed; - - /** - * Creates a new heap memory state snapshot. - * - * @param keySerializer The serializer for the keys. - * @param namespaceSerializer The serializer for the namespace. 
- * @param stateSerializer The serializer for the elements in the state HashMap - * @param stateDesc The state identifier - * @param data The serialized data of the state key/value pairs - */ - public AbstractMemStateSnapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer, - SD stateDesc, - byte[] data) { - this.keySerializer = keySerializer; - this.namespaceSerializer = namespaceSerializer; - this.stateSerializer = stateSerializer; - this.stateDesc = stateDesc; - this.data = data; - } - - @Override - @SuppressWarnings("unchecked") - public StateTable deserialize( - String stateName, - HeapKeyedStateBackend stateBackend) throws IOException { - - final DataInputDeserializer inView = new DataInputDeserializer(data, 0, data.length); - AbstractMigrationRestoreStrategy restoreStrategy = - new AbstractMigrationRestoreStrategy(keySerializer, namespaceSerializer, stateSerializer) { - @Override - protected DataInputView openDataInputView() throws IOException { - return inView; - } - }; - return restoreStrategy.deserialize(stateName, stateBackend); - } - - /** - * Discarding the heap state is a no-op. - */ - @Override - public void discardState() {} - - @Override - public long getStateSize() { - return data.length; - } - - @Override - public void close() { - closed = true; - } - - public TypeSerializer getKeySerializer() { - return keySerializer; - } - - public TypeSerializer getNamespaceSerializer() { - return namespaceSerializer; - } - - public TypeSerializer getStateSerializer() { - return stateSerializer; - } - - public byte[] getData() { - return data; - } - - @Override - public String toString() { - return "AbstractMemStateSnapshot{" + - "keySerializer=" + keySerializer + - ", namespaceSerializer=" + namespaceSerializer + - ", stateSerializer=" + stateSerializer + - ", stateDesc=" + stateDesc + - '}'; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/AbstractMigrationRestoreStrategy.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/AbstractMigrationRestoreStrategy.java deleted file mode 100644 index f58070e564b92..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/AbstractMigrationRestoreStrategy.java +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.api.common.state.StateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.base.VoidSerializer; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo; -import org.apache.flink.runtime.state.VoidNamespace; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; -import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; -import org.apache.flink.runtime.state.heap.StateTable; -import org.apache.flink.util.Preconditions; - -import java.io.IOException; - -/** - * This class outlines the general strategy to restore from migration states. - * - * @param type of key. - * @param type of namespace. - * @param type of state. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -public abstract class AbstractMigrationRestoreStrategy implements MigrationRestoreSnapshot { - - /** - * Key Serializer - */ - protected final TypeSerializer keySerializer; - - /** - * Namespace Serializer - */ - protected final TypeSerializer namespaceSerializer; - - /** - * Serializer for the state value - */ - protected final TypeSerializer stateSerializer; - - public AbstractMigrationRestoreStrategy( - TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer) { - - this.keySerializer = Preconditions.checkNotNull(keySerializer); - this.namespaceSerializer = Preconditions.checkNotNull(namespaceSerializer); - this.stateSerializer = Preconditions.checkNotNull(stateSerializer); - } - - @Override - public StateTable deserialize(String stateName, HeapKeyedStateBackend stateBackend) throws IOException { - - Preconditions.checkNotNull(stateName, "State name is null. Cannot deserialize snapshot."); - Preconditions.checkNotNull(stateBackend, "State backend is null. Cannot deserialize snapshot."); - - final KeyGroupRange keyGroupRange = stateBackend.getKeyGroupRange(); - Preconditions.checkState(1 == keyGroupRange.getNumberOfKeyGroups(), - "Unexpected number of key-groups for restoring from Flink 1.1"); - - TypeSerializer patchedNamespaceSerializer = this.namespaceSerializer; - - if (patchedNamespaceSerializer instanceof VoidSerializer) { - patchedNamespaceSerializer = (TypeSerializer) VoidNamespaceSerializer.INSTANCE; - } - - RegisteredKeyedBackendStateMetaInfo registeredKeyedBackendStateMetaInfo = - new RegisteredKeyedBackendStateMetaInfo<>( - StateDescriptor.Type.UNKNOWN, - stateName, - patchedNamespaceSerializer, - stateSerializer); - - final StateTable stateTable = stateBackend.newStateTable(registeredKeyedBackendStateMetaInfo); - final DataInputView inView = openDataInputView(); - final int keyGroup = keyGroupRange.getStartKeyGroup(); - final int numNamespaces = inView.readInt(); - - for (int i = 0; i < numNamespaces; i++) { - N namespace = namespaceSerializer.deserialize(inView); - if (null == namespace) { - namespace = (N) VoidNamespace.INSTANCE; - } - final int numKV = inView.readInt(); - for (int j = 0; j < numKV; j++) { - K key = keySerializer.deserialize(inView); - S value = stateSerializer.deserialize(inView); - stateTable.put(key, keyGroup, namespace, value); - } - } - return stateTable; - } - - /** - * Different state handles require different code to end up with a {@link DataInputView}. 
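Aside, not part of the patch: the per-key-group layout that the restore loop in AbstractMigrationRestoreStrategy#deserialize above expects, mirrored as a tiny writer. JDK UTF strings stand in for Flink's TypeSerializers here, so this is only an assumption-laden sketch of the shape of the data, not of the actual encoding.

import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Map;

// Illustrative sketch: number of namespaces, then per namespace the namespace itself,
// a count of key/value pairs, and finally the pairs.
final class LegacyKeyedStateLayout {
    static void write(DataOutputStream out, Map<String, Map<String, String>> byNamespace) throws IOException {
        out.writeInt(byNamespace.size());                         // numNamespaces
        for (Map.Entry<String, Map<String, String>> ns : byNamespace.entrySet()) {
            out.writeUTF(ns.getKey());                            // namespace (serializer-specific in Flink)
            out.writeInt(ns.getValue().size());                   // numKV
            for (Map.Entry<String, String> kv : ns.getValue().entrySet()) {
                out.writeUTF(kv.getKey());                        // key
                out.writeUTF(kv.getValue());                      // state value
            }
        }
    }
}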
- */ - protected abstract DataInputView openDataInputView() throws IOException; -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/ByteStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/ByteStreamStateHandle.java deleted file mode 100644 index c7fbab63659b0..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/ByteStreamStateHandle.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.migration.runtime.state.AbstractCloseableHandle; -import org.apache.flink.migration.runtime.state.StateHandle; -import org.apache.flink.migration.runtime.state.StreamStateHandle; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.Serializable; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public final class ByteStreamStateHandle extends AbstractCloseableHandle implements StreamStateHandle { - - private static final long serialVersionUID = -5280226231200217594L; - - /** the state data */ - private final byte[] data; - - /** - * Creates a new ByteStreamStateHandle containing the given data. - * - * @param data The state data. 
- */ - public ByteStreamStateHandle(byte[] data) { - this.data = data; - } - - @Override - public InputStream getState(ClassLoader userCodeClassLoader) throws Exception { - ensureNotClosed(); - - ByteArrayInputStream stream = new ByteArrayInputStream(data); - registerCloseable(stream); - - return stream; - } - - @Override - public void discardState() {} - - @Override - public long getStateSize() { - return data.length; - } - - @Override - public StateHandle toSerializableHandle() { - SerializedStateHandle serializableHandle = new SerializedStateHandle(data); - - // forward the closed status - if (isClosed()) { - try { - serializableHandle.close(); - } catch (IOException e) { - // should not happen on a fresh handle, but forward anyways - throw new RuntimeException(e); - } - } - - return serializableHandle; - } - - public byte[] getData() { - return data; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemFoldingState.java deleted file mode 100644 index ad820e4d9234b..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemFoldingState.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.api.common.state.FoldingState; -import org.apache.flink.api.common.state.FoldingStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class MemFoldingState { - - public static class Snapshot extends AbstractMemStateSnapshot, FoldingStateDescriptor> { - private static final long serialVersionUID = 1L; - - public Snapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer, - FoldingStateDescriptor stateDescs, byte[] data) { - super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, data); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemListState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemListState.java deleted file mode 100644 index d76cda09f9639..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemListState.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; - -import java.util.ArrayList; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class MemListState { - - public static class Snapshot extends AbstractMemStateSnapshot, ListState, ListStateDescriptor> { - private static final long serialVersionUID = 1L; - - public Snapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer> stateSerializer, - ListStateDescriptor stateDescs, byte[] data) { - super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, data); - } - } - -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemReducingState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemReducingState.java deleted file mode 100644 index c39111c5e5767..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemReducingState.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.api.common.state.ReducingState; -import org.apache.flink.api.common.state.ReducingStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; - -/** - * Heap-backed partitioned {@link ReducingState} that is - * snapshotted into a serialized memory copy. - * - * @param The type of the key. - * @param The type of the namespace. - * @param The type of the values in the list state. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. 
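Each of the Mem*State holders deleted here follows the same pattern: a nested Snapshot class that merely binds AbstractMemStateSnapshot to the concrete state and descriptor types. For reference, the fully parameterized shape of the list variant is roughly:

// Rough shape of the removed MemListState holder; type parameters as in the original sources.
public class MemListState<K, N, V> {

	public static class Snapshot<K, N, V> extends AbstractMemStateSnapshot<K, N, ArrayList<V>, ListState<V>, ListStateDescriptor<V>> {

		private static final long serialVersionUID = 1L;

		public Snapshot(
				TypeSerializer<K> keySerializer,
				TypeSerializer<N> namespaceSerializer,
				TypeSerializer<ArrayList<V>> stateSerializer,
				ListStateDescriptor<V> stateDescs,
				byte[] data) {
			super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, data);
		}
	}
}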
- */ -@Deprecated -@SuppressWarnings("deprecation") -public class MemReducingState { - - public static class Snapshot extends AbstractMemStateSnapshot, ReducingStateDescriptor> { - private static final long serialVersionUID = 1L; - - public Snapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer, - ReducingStateDescriptor stateDescs, byte[] data) { - super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, data); - } - }} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemValueState.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemValueState.java deleted file mode 100644 index 940d4895569d7..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/MemValueState.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; - -/** - * Heap-backed key/value state that is snapshotted into a serialized memory copy. - * - * @param The type of the key. - * @param The type of the namespace. - * @param The type of the value. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@SuppressWarnings("deprecation") -public class MemValueState { - - public static class Snapshot extends AbstractMemStateSnapshot, ValueStateDescriptor> { - private static final long serialVersionUID = 1L; - - public Snapshot(TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, - TypeSerializer stateSerializer, - ValueStateDescriptor stateDescs, byte[] data) { - super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, data); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/SerializedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/SerializedStateHandle.java deleted file mode 100644 index 49d772e3019e1..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/memory/SerializedStateHandle.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.runtime.state.memory; - -import org.apache.flink.migration.runtime.state.AbstractCloseableHandle; -import org.apache.flink.migration.runtime.state.StateHandle; -import org.apache.flink.migration.util.MigrationInstantiationUtil; -import org.apache.flink.util.InstantiationUtil; - -import java.io.IOException; -import java.io.Serializable; - -/** - * A state handle that represents its state in serialized form as bytes. - * - * @param The type of state represented by this state handle. - */ -@SuppressWarnings("deprecation") -public class SerializedStateHandle extends AbstractCloseableHandle implements StateHandle { - - private static final long serialVersionUID = 4145685722538475769L; - - /** The serialized data */ - private final byte[] serializedData; - - /** - * Creates a new serialized state handle, eagerly serializing the given state object. - * - * @param value The state object. - * @throws IOException Thrown, if the serialization fails. - */ - public SerializedStateHandle(T value) throws IOException { - this.serializedData = value == null ? null : InstantiationUtil.serializeObject(value); - } - - /** - * Creates a new serialized state handle, based in the given already serialized data. - * - * @param serializedData The serialized data. - */ - public SerializedStateHandle(byte[] serializedData) { - this.serializedData = serializedData; - } - - @Override - public T getState(ClassLoader classLoader) throws Exception { - if (classLoader == null) { - throw new NullPointerException(); - } - - ensureNotClosed(); - return serializedData == null ? null : MigrationInstantiationUtil.deserializeObject(serializedData, classLoader); - } - - /** - * Gets the size of the serialized state. - * @return The size of the serialized state. - */ - public int getSizeOfSerializedState() { - return serializedData.length; - } - - /** - * Discarding heap-memory backed state is a no-op, so this method does nothing. - */ - @Override - public void discardState() {} - - @Override - public long getStateSize() { - return serializedData.length; - } - - public byte[] getSerializedData() { - return serializedData; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/state/MigrationKeyGroupStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/state/MigrationKeyGroupStateHandle.java deleted file mode 100644 index 3f1ff552d9c58..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/state/MigrationKeyGroupStateHandle.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.state; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.runtime.state.KeyGroupRangeOffsets; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.util.Migration; - -/** - * This class is just a KeyGroupsStateHandle that is tagged as migration, to figure out which restore logic to apply, - * e.g. when restoring backend data from a state handle. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Internal -@Deprecated -public class MigrationKeyGroupStateHandle extends KeyGroupsStateHandle implements Migration { - - private static final long serialVersionUID = -8554427169776881697L; - - /** - * @param groupRangeOffsets range of key-group ids that in the state of this handle - * @param streamStateHandle handle to the actual state of the key-groups - */ - public MigrationKeyGroupStateHandle(KeyGroupRangeOffsets groupRangeOffsets, StreamStateHandle streamStateHandle) { - super(groupRangeOffsets, streamStateHandle); - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/state/MigrationStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/migration/state/MigrationStreamStateHandle.java deleted file mode 100644 index 220191605202c..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/state/MigrationStreamStateHandle.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.state; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FSDataInputStreamWrapper; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.util.Migration; - -import java.io.IOException; - -/** - * This class is just a StreamStateHandle that is tagged as migration, to figure out which restore logic to apply, e.g. - * when restoring backend data from a state handle. - * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. 
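These Migration* wrappers exist only so that restore code can tell legacy savepoint data apart from current-format data by checking for the Migration tag. A sketch of that dispatch, with hypothetical helper names:

// Illustrative dispatch on the Migration marker; both helper methods are hypothetical.
void restoreKeyedState(KeyGroupsStateHandle handle) throws Exception {
	if (handle instanceof Migration) {
		restoreFromLegacySavepoint(handle);   // data written by Flink 1.1
	} else {
		restoreFromCurrentFormat(handle);
	}
}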
- */ -@Internal -@Deprecated -public class MigrationStreamStateHandle implements StreamStateHandle, Migration { - - private static final long serialVersionUID = -2332113722532150112L; - private final StreamStateHandle delegate; - - public MigrationStreamStateHandle(StreamStateHandle delegate) { - this.delegate = delegate; - } - - @Override - public FSDataInputStream openInputStream() throws IOException { - return new MigrationFSInputStream(delegate.openInputStream()); - } - - @Override - public void discardState() throws Exception { - delegate.discardState(); - } - - @Override - public long getStateSize() { - return delegate.getStateSize(); - } - - static class MigrationFSInputStream extends FSDataInputStreamWrapper implements Migration { - - public MigrationFSInputStream(FSDataInputStream inputStream) { - super(inputStream); - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/streaming/runtime/tasks/StreamTaskState.java b/flink-runtime/src/main/java/org/apache/flink/migration/streaming/runtime/tasks/StreamTaskState.java deleted file mode 100644 index b044ffbb3d753..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/streaming/runtime/tasks/StreamTaskState.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.streaming.runtime.tasks; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.migration.runtime.state.KvStateSnapshot; -import org.apache.flink.migration.runtime.state.StateHandle; - -import java.io.Closeable; -import java.io.IOException; -import java.io.Serializable; -import java.util.HashMap; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. 
- */ -@Deprecated -@Internal -@SuppressWarnings("deprecation") -public class StreamTaskState implements Serializable, Closeable { - - private static final long serialVersionUID = 1L; - - private StateHandle operatorState; - - private StateHandle functionState; - - private HashMap> kvStates; - - // ------------------------------------------------------------------------ - - public StateHandle getOperatorState() { - return operatorState; - } - - public void setOperatorState(StateHandle operatorState) { - this.operatorState = operatorState; - } - - public StateHandle getFunctionState() { - return functionState; - } - - public void setFunctionState(StateHandle functionState) { - this.functionState = functionState; - } - - public HashMap> getKvStates() { - return kvStates; - } - - public void setKvStates(HashMap> kvStates) { - this.kvStates = kvStates; - } - - // ------------------------------------------------------------------------ - - /** - * Checks if this state object actually contains any state, or if all of the state - * fields are null. - * - * @return True, if all state is null, false if at least one state is not null. - */ - public boolean isEmpty() { - return operatorState == null & functionState == null & kvStates == null; - } - - @Override - public void close() throws IOException { - - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/streaming/runtime/tasks/StreamTaskStateList.java b/flink-runtime/src/main/java/org/apache/flink/migration/streaming/runtime/tasks/StreamTaskStateList.java deleted file mode 100644 index 7643039c1f852..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/migration/streaming/runtime/tasks/StreamTaskStateList.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.streaming.runtime.tasks; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.migration.runtime.state.KvStateSnapshot; -import org.apache.flink.migration.runtime.state.StateHandle; - -import java.io.IOException; -import java.util.HashMap; - -/** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. - */ -@Deprecated -@Internal -@SuppressWarnings("deprecation") -public class StreamTaskStateList implements StateHandle { - - private static final long serialVersionUID = 1L; - - /** The states for all operator. 
*/ - private final StreamTaskState[] states; - - public StreamTaskStateList(StreamTaskState[] states) throws Exception { - this.states = states; - } - - public boolean isEmpty() { - for (StreamTaskState state : states) { - if (state != null) { - return false; - } - } - return true; - } - - @Override - public StreamTaskState[] getState(ClassLoader userCodeClassLoader) { - return states; - } - - @Override - public void discardState() throws Exception { - } - - @Override - public long getStateSize() throws Exception { - long sumStateSize = 0; - - if (states != null) { - for (StreamTaskState state : states) { - if (state != null) { - StateHandle operatorState = state.getOperatorState(); - StateHandle functionState = state.getFunctionState(); - HashMap> kvStates = state.getKvStates(); - - if (operatorState != null) { - sumStateSize += operatorState.getStateSize(); - } - - if (functionState != null) { - sumStateSize += functionState.getStateSize(); - } - - if (kvStates != null) { - for (KvStateSnapshot kvState : kvStates.values()) { - if (kvState != null) { - sumStateSize += kvState.getStateSize(); - } - } - } - } - } - } - - // State size as sum of all state sizes - return sumStateSize; - } - - @Override - public void close() throws IOException { - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/JobException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/JobException.java index d5a5bb968acb4..d923af96f77e4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/JobException.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/JobException.java @@ -18,10 +18,12 @@ package org.apache.flink.runtime; +import org.apache.flink.util.FlinkException; + /** * Indicates that a job has failed. */ -public class JobException extends Exception { +public class JobException extends FlinkException { private static final long serialVersionUID = 1275864691743020176L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/StoppingException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/StoppingException.java index 6bb71ce6cfac5..3644219abfcef 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/StoppingException.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/StoppingException.java @@ -18,10 +18,12 @@ package org.apache.flink.runtime; +import org.apache.flink.util.FlinkException; + /** * Indicates that a job is not stoppable. 
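With this change JobException, and StoppingException just below, extend FlinkException instead of plain Exception, so call sites can handle Flink-specific failures through a single supertype. A hypothetical call site; the gateway call itself is illustrative and not part of this patch:

// One catch clause now covers JobException, StoppingException and other FlinkExceptions.
try {
	jobManagerGateway.stopJob(jobId);   // hypothetical call that may throw StoppingException
} catch (FlinkException e) {
	LOG.warn("Stopping job {} failed.", jobId, e);
}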
*/ -public class StoppingException extends Exception { +public class StoppingException extends FlinkException { private static final long serialVersionUID = -721315728140810694L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobCache.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobCache.java index 32bd8fd83ed32..c50a8887f8558 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobCache.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobCache.java @@ -18,19 +18,26 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.util.FileUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; -import java.net.URL; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Timer; +import java.util.TimerTask; import java.util.concurrent.atomic.AtomicBoolean; import static org.apache.flink.util.Preconditions.checkArgument; @@ -39,13 +46,13 @@ /** * The BLOB cache implements a local cache for content-addressable BLOBs. * - *
<p>
When requesting BLOBs through the {@link BlobCache#getURL} methods, the + *
<p>
When requesting BLOBs through the {@link BlobCache#getFile} methods, the * BLOB cache will first attempt to serve the file from its local cache. Only if * the local cache does not contain the desired BLOB, the BLOB cache will try to * download it from a distributed file system (if available) or the BLOB * server.

*/ -public final class BlobCache implements BlobService { +public class BlobCache extends TimerTask implements BlobService { /** The log object used for debugging. */ private static final Logger LOG = LoggerFactory.getLogger(BlobCache.class); @@ -69,6 +76,32 @@ public final class BlobCache implements BlobService { /** Configuration for the blob client like ssl parameters required to connect to the blob server */ private final Configuration blobClientConfig; + // -------------------------------------------------------------------------------------------- + + /** + * Job reference counters with a time-to-live (TTL). + */ + private static class RefCount { + /** + * Number of references to a job. + */ + public int references = 0; + + /** + * Timestamp in milliseconds when any job data should be cleaned up (no cleanup for + * non-positive values). + */ + public long keepUntil = -1; + } + + /** Map to store the number of references to a specific job */ + private final Map jobRefCounters = new HashMap<>(); + + /** Time interval (ms) to run the cleanup task; also used as the default TTL. */ + private final long cleanupInterval; + + private final Timer cleanupTimer; + /** * Instantiates a new BLOB cache. * @@ -92,7 +125,7 @@ public BlobCache( // configure and create the storage directory String storageDirectory = blobClientConfig.getString(BlobServerOptions.STORAGE_DIRECTORY); - this.storageDir = BlobUtils.initStorageDirectory(storageDirectory); + this.storageDir = BlobUtils.initLocalStorageDirectory(storageDirectory); LOG.info("Created BLOB cache storage directory " + storageDir); // configure the number of fetch retries @@ -106,49 +139,149 @@ public BlobCache( this.numFetchRetries = 0; } + // Initializing the clean up task + this.cleanupTimer = new Timer(true); + + cleanupInterval = blobClientConfig.getLong(BlobServerOptions.CLEANUP_INTERVAL) * 1000; + this.cleanupTimer.schedule(this, cleanupInterval, cleanupInterval); + // Add shutdown hook to delete storage directory shutdownHook = BlobUtils.addShutdownHook(this, LOG); } /** - * Returns the URL for the BLOB with the given key. The method will first attempt to serve - * the BLOB from its local cache. If the BLOB is not in the cache, the method will try to download it - * from this cache's BLOB server. + * Registers use of job-related BLOBs. + *
<p>
+ * Using any other method to access BLOBs, e.g. {@link #getFile}, is only valid within calls + * to {@link #registerJob(JobID)} and {@link #releaseJob(JobID)}. + * + * @param jobId + * ID of the job this blob belongs to * - * @param requiredBlob The key of the desired BLOB. - * @return URL referring to the local storage location of the BLOB. - * @throws IOException Thrown if an I/O error occurs while downloading the BLOBs from the BLOB server. + * @see #releaseJob(JobID) */ - public URL getURL(final BlobKey requiredBlob) throws IOException { + public void registerJob(JobID jobId) { + synchronized (jobRefCounters) { + RefCount ref = jobRefCounters.get(jobId); + if (ref == null) { + ref = new RefCount(); + jobRefCounters.put(jobId, ref); + } + ++ref.references; + } + } + + /** + * Unregisters use of job-related BLOBs and allow them to be released. + * + * @param jobId + * ID of the job this blob belongs to + * + * @see #registerJob(JobID) + */ + public void releaseJob(JobID jobId) { + synchronized (jobRefCounters) { + RefCount ref = jobRefCounters.get(jobId); + + if (ref == null) { + LOG.warn("improper use of releaseJob() without a matching number of registerJob() calls"); + return; + } + + --ref.references; + if (ref.references == 0) { + ref.keepUntil = System.currentTimeMillis() + cleanupInterval; + } + } + } + + /** + * Returns local copy of the (job-unrelated) file for the BLOB with the given key. + *
<p>
+ * The method will first attempt to serve the BLOB from its local cache. If the BLOB is not in + * the cache, the method will try to download it from this cache's BLOB server. + * + * @param key + * The key of the desired BLOB. + * + * @return file referring to the local storage location of the BLOB. + * + * @throws IOException + * Thrown if an I/O error occurs while downloading the BLOBs from the BLOB server. + */ + @Override + public File getFile(BlobKey key) throws IOException { + return getFileInternal(null, key); + } + + /** + * Returns local copy of the file for the BLOB with the given key. + *
<p>
+ * The method will first attempt to serve the BLOB from its local cache. If the BLOB is not in + * the cache, the method will try to download it from this cache's BLOB server. + * + * @param jobId + * ID of the job this blob belongs to + * @param key + * The key of the desired BLOB. + * + * @return file referring to the local storage location of the BLOB. + * + * @throws IOException + * Thrown if an I/O error occurs while downloading the BLOBs from the BLOB server. + */ + @Override + public File getFile(JobID jobId, BlobKey key) throws IOException { + checkNotNull(jobId); + return getFileInternal(jobId, key); + } + + /** + * Returns local copy of the file for the BLOB with the given key. + *
<p>
+ * The method will first attempt to serve the BLOB from its local cache. If the BLOB is not in + * the cache, the method will try to download it from this cache's BLOB server. + * + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * @param requiredBlob + * The key of the desired BLOB. + * + * @return file referring to the local storage location of the BLOB. + * + * @throws IOException + * Thrown if an I/O error occurs while downloading the BLOBs from the BLOB server. + */ + private File getFileInternal(@Nullable JobID jobId, BlobKey requiredBlob) throws IOException { checkArgument(requiredBlob != null, "BLOB key cannot be null."); - final File localJarFile = BlobUtils.getStorageLocation(storageDir, requiredBlob); + final File localJarFile = BlobUtils.getStorageLocation(storageDir, jobId, requiredBlob); if (localJarFile.exists()) { - return localJarFile.toURI().toURL(); + return localJarFile; } // first try the distributed blob store (if available) try { - blobView.get(requiredBlob, localJarFile); + blobView.get(jobId, requiredBlob, localJarFile); } catch (Exception e) { LOG.info("Failed to copy from blob store. Downloading from BLOB server instead.", e); } if (localJarFile.exists()) { - return localJarFile.toURI().toURL(); + return localJarFile; } // fallback: download from the BlobServer final byte[] buf = new byte[BlobServerProtocol.BUFFER_SIZE]; - LOG.info("Downloading {} from {}", requiredBlob, serverAddress); + LOG.info("Downloading {}/{} from {}", jobId, requiredBlob, serverAddress); // loop over retries int attempt = 0; while (true) { try ( final BlobClient bc = new BlobClient(serverAddress, blobClientConfig); - final InputStream is = bc.get(requiredBlob); + final InputStream is = bc.getInternal(jobId, requiredBlob); final OutputStream os = new FileOutputStream(localJarFile) ) { while (true) { @@ -160,10 +293,10 @@ public URL getURL(final BlobKey requiredBlob) throws IOException { } // success, we finished - return localJarFile.toURI().toURL(); + return localJarFile; } catch (Throwable t) { - String message = "Failed to fetch BLOB " + requiredBlob + " from " + serverAddress + + String message = "Failed to fetch BLOB " + jobId + "/" + requiredBlob + " from " + serverAddress + " and store it under " + localJarFile.getAbsolutePath(); if (attempt < numFetchRetries) { if (LOG.isDebugEnabled()) { @@ -179,40 +312,110 @@ public URL getURL(final BlobKey requiredBlob) throws IOException { // retry ++attempt; - LOG.info("Downloading {} from {} (retry {})", requiredBlob, serverAddress, attempt); + LOG.info("Downloading {}/{} from {} (retry {})", jobId, requiredBlob, serverAddress, attempt); } } // end loop over retries } /** - * Deletes the file associated with the given key from the BLOB cache. - * @param key referring to the file to be deleted + * Deletes the (job-unrelated) file associated with the blob key in this BLOB cache. + * + * @param key + * blob key associated with the file to be deleted + * + * @throws IOException */ - public void delete(BlobKey key) throws IOException{ - final File localFile = BlobUtils.getStorageLocation(storageDir, key); + @Override + public void delete(BlobKey key) throws IOException { + deleteInternal(null, key); + } + /** + * Deletes the file associated with the blob key in this BLOB cache. 
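Taken together, the new reference counting and the job-scoped getFile/delete variants are meant to be used in the bracketed fashion the registerJob Javadoc describes; once the last reference is released, the job's files survive only until the TTL derived from BlobServerOptions.CLEANUP_INTERVAL expires. A usage sketch with placeholder variables:

// Sketch of the register/get/release pattern; blobCache, jobId and blobKey are placeholders.
blobCache.registerJob(jobId);                            // pin the job's BLOBs in the local cache
try {
	File localJar = blobCache.getFile(jobId, blobKey);   // local cache -> HA blob store -> BLOB server
	// ... use the file ...
} finally {
	blobCache.releaseJob(jobId);                         // count reaches 0: kept until the TTL expires
}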
+ * + * @param jobId + * ID of the job this blob belongs to + * @param key + * blob key associated with the file to be deleted + * + * @throws IOException + */ + @Override + public void delete(JobID jobId, BlobKey key) throws IOException { + checkNotNull(jobId); + deleteInternal(jobId, key); + } + + /** + * Deletes the file associated with the blob key in this BLOB cache. + * + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * @param key + * blob key associated with the file to be deleted + * + * @throws IOException + */ + private void deleteInternal(@Nullable JobID jobId, BlobKey key) throws IOException{ + final File localFile = BlobUtils.getStorageLocation(storageDir, jobId, key); if (!localFile.delete() && localFile.exists()) { LOG.warn("Failed to delete locally cached BLOB {} at {}", key, localFile.getAbsolutePath()); } } /** - * Deletes the file associated with the given key from the BLOB cache and + * Deletes the (job-unrelated) file associated with the given key from the BLOB cache and * BLOB server. * - * @param key referring to the file to be deleted + * @param key + * referring to the file to be deleted + * * @throws IOException - * thrown if an I/O error occurs while transferring the request to - * the BLOB server or if the BLOB server cannot delete the file + * thrown if an I/O error occurs while transferring the request to the BLOB server or if the + * BLOB server cannot delete the file */ public void deleteGlobal(BlobKey key) throws IOException { + deleteGlobalInternal(null, key); + } + + /** + * Deletes the file associated with the given key from the BLOB cache and BLOB server. + * + * @param jobId + * ID of the job this blob belongs to + * @param key + * referring to the file to be deleted + * + * @throws IOException + * thrown if an I/O error occurs while transferring the request to the BLOB server or if the + * BLOB server cannot delete the file + */ + public void deleteGlobal(JobID jobId, BlobKey key) throws IOException { + checkNotNull(jobId); + deleteGlobalInternal(jobId, key); + } + + /** + * Deletes the file associated with the given key from the BLOB cache and + * BLOB server. + * + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * @param key + * referring to the file to be deleted + * + * @throws IOException + * thrown if an I/O error occurs while transferring the request to the BLOB server or if the + * BLOB server cannot delete the file + */ + private void deleteGlobalInternal(@Nullable JobID jobId, BlobKey key) throws IOException { // delete locally - delete(key); + deleteInternal(jobId, key); // then delete on the BLOB server // (don't use the distributed storage directly - this way the blob // server is aware of the delete operation, too) try (BlobClient bc = createClient()) { - bc.delete(key); + bc.deleteInternal(jobId, key); } } @@ -221,8 +424,40 @@ public int getPort() { return serverAddress.getPort(); } + /** + * Cleans up BLOBs which are not referenced anymore. 
+ */ + @Override + public void run() { + synchronized (jobRefCounters) { + Iterator> entryIter = jobRefCounters.entrySet().iterator(); + final long currentTimeMillis = System.currentTimeMillis(); + + while (entryIter.hasNext()) { + Map.Entry entry = entryIter.next(); + RefCount ref = entry.getValue(); + + if (ref.references <= 0 && ref.keepUntil > 0 && currentTimeMillis >= ref.keepUntil) { + JobID jobId = entry.getKey(); + + final File localFile = + new File(BlobUtils.getStorageLocationPath(storageDir.getAbsolutePath(), jobId)); + try { + FileUtils.deleteDirectory(localFile); + // let's only remove this directory from cleanup if the cleanup was successful + entryIter.remove(); + } catch (Throwable t) { + LOG.warn("Failed to locally delete job directory " + localFile.getAbsolutePath(), t); + } + } + } + } + } + @Override public void close() throws IOException { + cleanupTimer.cancel(); + if (shutdownRequested.compareAndSet(false, true)) { LOG.info("Shutting down BlobCache"); @@ -249,8 +484,19 @@ public BlobClient createClient() throws IOException { return new BlobClient(serverAddress, blobClientConfig); } - public File getStorageDir() { - return this.storageDir; + /** + * Returns a file handle to the file associated with the given blob key on the blob + * server. + * + *
<p>
This is only called from the {@link BlobServerConnection} + * + * @param jobId ID of the job this blob belongs to (or null if job-unrelated) + * @param key identifying the file + * @return file handle to the file + */ + @VisibleForTesting + public File getStorageLocation(JobID jobId, BlobKey key) { + return BlobUtils.getStorageLocation(storageDir, jobId, key); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java index 0882ec3905820..8f1487ae28c5b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; @@ -29,6 +30,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocket; @@ -46,7 +48,8 @@ import java.util.List; import static org.apache.flink.runtime.blob.BlobServerProtocol.BUFFER_SIZE; -import static org.apache.flink.runtime.blob.BlobServerProtocol.CONTENT_ADDRESSABLE; +import static org.apache.flink.runtime.blob.BlobServerProtocol.CONTENT_FOR_JOB; +import static org.apache.flink.runtime.blob.BlobServerProtocol.CONTENT_NO_JOB; import static org.apache.flink.runtime.blob.BlobServerProtocol.DELETE_OPERATION; import static org.apache.flink.runtime.blob.BlobServerProtocol.GET_OPERATION; import static org.apache.flink.runtime.blob.BlobServerProtocol.PUT_OPERATION; @@ -55,7 +58,7 @@ import static org.apache.flink.runtime.blob.BlobUtils.readFully; import static org.apache.flink.runtime.blob.BlobUtils.readLength; import static org.apache.flink.runtime.blob.BlobUtils.writeLength; -import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; /** * The BLOB client can communicate with the BLOB server and either upload (PUT), download (GET), @@ -75,6 +78,7 @@ public final class BlobClient implements Closeable { * the network address of the BLOB server * @param clientConfig * additional configuration like SSL parameters required to connect to the blob server + * * @throws IOException * thrown if the connection to the BLOB server could not be established */ @@ -130,22 +134,65 @@ public boolean isClosed() { // -------------------------------------------------------------------------------------------- /** - * Downloads the BLOB identified by the given BLOB key from the BLOB server. If no such BLOB exists on the server, a - * {@link FileNotFoundException} is thrown. - * + * Downloads the (job-unrelated) BLOB identified by the given BLOB key from the BLOB server. + * * @param blobKey - * the BLOB key identifying the BLOB to download + * blob key associated with the requested file + * * @return an input stream to read the retrieved data from + * + * @throws FileNotFoundException + * if there is no such file; * @throws IOException - * thrown if an I/O error occurs during the download + * if an I/O error occurs during the download */ public InputStream get(BlobKey blobKey) throws IOException { + return getInternal(null, blobKey); + } + + /** + * Downloads the BLOB identified by the given BLOB key from the BLOB server. 
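On the client side, the new get(JobID, BlobKey) variant streams a job-scoped BLOB straight from the server. A minimal sketch, assuming the usual java.nio.file imports; the address, configuration, IDs and target path are placeholders:

// Illustrative job-scoped download; blobServerAddress, clientConfig, jobId and blobKey are placeholders.
try (BlobClient client = new BlobClient(blobServerAddress, clientConfig);
		InputStream in = client.get(jobId, blobKey)) {
	Files.copy(in, Paths.get("/tmp/blob.bin"), StandardCopyOption.REPLACE_EXISTING);
}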
+ * + * @param jobId + * ID of the job this blob belongs to + * @param blobKey + * blob key associated with the requested file + * + * @return an input stream to read the retrieved data from + * + * @throws FileNotFoundException + * if there is no such file; + * @throws IOException + * if an I/O error occurs during the download + */ + public InputStream get(JobID jobId, BlobKey blobKey) throws IOException { + checkNotNull(jobId); + return getInternal(jobId, blobKey); + } + + /** + * Downloads the BLOB identified by the given BLOB key from the BLOB server. + * + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * @param blobKey + * blob key associated with the requested file + * + * @return an input stream to read the retrieved data from + * + * @throws FileNotFoundException + * if there is no such file; + * @throws IOException + * if an I/O error occurs during the download + */ + InputStream getInternal(@Nullable JobID jobId, BlobKey blobKey) throws IOException { if (this.socket.isClosed()) { throw new IllegalStateException("BLOB Client is not connected. " + "Client has been shut down or encountered an error before."); } if (LOG.isDebugEnabled()) { - LOG.debug(String.format("GET content addressable BLOB %s from %s", blobKey, socket.getLocalSocketAddress())); + LOG.debug("GET BLOB {}/{} from {}.", jobId, blobKey, + socket.getLocalSocketAddress()); } try { @@ -153,8 +200,8 @@ public InputStream get(BlobKey blobKey) throws IOException { InputStream is = this.socket.getInputStream(); // Send GET header - sendGetHeader(os, null, blobKey); - receiveAndCheckResponse(is); + sendGetHeader(os, jobId, blobKey); + receiveAndCheckGetResponse(is); return new BlobInputStream(is, blobKey); } @@ -169,29 +216,40 @@ public InputStream get(BlobKey blobKey) throws IOException { * * @param outputStream * the output stream to write the header data to - * @param jobID - * the job ID identifying the BLOB to download or null to indicate the BLOB key should be used - * to identify the BLOB on the server instead + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) * @param blobKey - * the BLOB key to identify the BLOB to download if either the job ID or the regular key are - * null + * blob key associated with the requested file + * * @throws IOException * thrown if an I/O error occurs while writing the header data to the output stream */ - private void sendGetHeader(OutputStream outputStream, JobID jobID, BlobKey blobKey) throws IOException { - checkArgument(jobID == null); + private static void sendGetHeader(OutputStream outputStream, @Nullable JobID jobId, BlobKey blobKey) throws IOException { + checkNotNull(blobKey); // Signal type of operation outputStream.write(GET_OPERATION); - // Check if GET should be done in content-addressable manner - if (jobID == null) { - outputStream.write(CONTENT_ADDRESSABLE); - blobKey.writeToOutputStream(outputStream); + // Send job ID and key + if (jobId == null) { + outputStream.write(CONTENT_NO_JOB); + } else { + outputStream.write(CONTENT_FOR_JOB); + outputStream.write(jobId.getBytes()); } + blobKey.writeToOutputStream(outputStream); } - private void receiveAndCheckResponse(InputStream is) throws IOException { + /** + * Reads the response from the input stream and throws in case of errors + * + * @param is + * stream to read from + * + * @throws IOException + * if the response is an error or reading the response failed + */ + private static void receiveAndCheckGetResponse(InputStream is) throws IOException { int 
response = is.read(); if (response < 0) { throw new EOFException("Premature end of response"); @@ -211,82 +269,111 @@ else if (response != RETURN_OKAY) { // -------------------------------------------------------------------------------------------- /** - * Uploads the data of the given byte array to the BLOB server in a content-addressable manner. + * Uploads the data of the given byte array for the given job to the BLOB server. * + * @param jobId + * the ID of the job the BLOB belongs to (or null if job-unrelated) * @param value - * the buffer to upload + * the buffer to upload + * * @return the computed BLOB key identifying the BLOB on the server + * * @throws IOException - * thrown if an I/O error occurs while uploading the data to the BLOB server + * thrown if an I/O error occurs while uploading the data to the BLOB server */ - public BlobKey put(byte[] value) throws IOException { - return put(value, 0, value.length); + @VisibleForTesting + public BlobKey put(@Nullable JobID jobId, byte[] value) throws IOException { + return put(jobId, value, 0, value.length); } /** - * Uploads data from the given byte array to the BLOB server in a content-addressable manner. + * Uploads data from the given byte array for the given job to the BLOB server. * + * @param jobId + * the ID of the job the BLOB belongs to (or null if job-unrelated) * @param value - * the buffer to upload data from + * the buffer to upload data from * @param offset - * the read offset within the buffer + * the read offset within the buffer * @param len - * the number of bytes to upload from the buffer + * the number of bytes to upload from the buffer + * * @return the computed BLOB key identifying the BLOB on the server + * * @throws IOException - * thrown if an I/O error occurs while uploading the data to the BLOB server + * thrown if an I/O error occurs while uploading the data to the BLOB server */ - public BlobKey put(byte[] value, int offset, int len) throws IOException { - return putBuffer(null, value, offset, len); + @VisibleForTesting + public BlobKey put(@Nullable JobID jobId, byte[] value, int offset, int len) throws IOException { + return putBuffer(jobId, value, offset, len); } /** - * Uploads the data from the given input stream to the BLOB server in a content-addressable manner. + * Uploads the (job-unrelated) data from the given input stream to the BLOB server. * * @param inputStream - * the input stream to read the data from + * the input stream to read the data from + * * @return the computed BLOB key identifying the BLOB on the server + * * @throws IOException - * thrown if an I/O error occurs while reading the data from the input stream or uploading the data to the - * BLOB server + * thrown if an I/O error occurs while reading the data from the input stream or uploading the + * data to the BLOB server */ public BlobKey put(InputStream inputStream) throws IOException { return putInputStream(null, inputStream); } + /** + * Uploads the data from the given input stream for the given job to the BLOB server. 
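The matching upload path: put(JobID, InputStream) streams the data to the server in BUFFER_SIZE chunks, and the server answers with the content hash that becomes the BlobKey. A short sketch; client, jobId and the jar path are placeholders:

// Illustrative job-scoped upload; the returned key is the server-computed content hash.
try (InputStream in = Files.newInputStream(Paths.get("/tmp/userlib.jar"))) {
	BlobKey key = client.put(jobId, in);
	// the key can later be handed to BlobCache#getFile(jobId, key) on the task managers
}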
+ * + * @param jobId + * ID of the job this blob belongs to + * @param inputStream + * the input stream to read the data from + * + * @return the computed BLOB key identifying the BLOB on the server + * + * @throws IOException + * thrown if an I/O error occurs while reading the data from the input stream or uploading the + * data to the BLOB server + */ + public BlobKey put(JobID jobId, InputStream inputStream) throws IOException { + checkNotNull(jobId); + return putInputStream(jobId, inputStream); + } + /** * Uploads data from the given byte buffer to the BLOB server. * * @param jobId - * the ID of the job the BLOB belongs to or null to store the BLOB in a content-addressable - * manner + * the ID of the job the BLOB belongs to (or null if job-unrelated) * @param value - * the buffer to read the data from + * the buffer to read the data from * @param offset - * the read offset within the buffer + * the read offset within the buffer * @param len - * the number of bytes to read from the buffer - * @return the computed BLOB key if the BLOB has been stored in a content-addressable manner, null - * otherwise + * the number of bytes to read from the buffer + * + * @return the computed BLOB key of the uploaded BLOB + * * @throws IOException - * thrown if an I/O error occurs while uploading the data to the BLOB server + * thrown if an I/O error occurs while uploading the data to the BLOB server */ - private BlobKey putBuffer(JobID jobId, byte[] value, int offset, int len) throws IOException { + private BlobKey putBuffer(@Nullable JobID jobId, byte[] value, int offset, int len) throws IOException { if (this.socket.isClosed()) { throw new IllegalStateException("BLOB Client is not connected. " + "Client has been shut down or encountered an error before."); } + checkNotNull(value); if (LOG.isDebugEnabled()) { - if (jobId == null) { - LOG.debug(String.format("PUT content addressable BLOB buffer (%d bytes) to %s", - len, socket.getLocalSocketAddress())); - } + LOG.debug("PUT BLOB buffer (" + len + " bytes) to " + socket.getLocalSocketAddress() + "."); } try { final OutputStream os = this.socket.getOutputStream(); - final MessageDigest md = jobId == null ? BlobUtils.createMessageDigest() : null; + final MessageDigest md = BlobUtils.createMessageDigest(); // Send the PUT header sendPutHeader(os, jobId); @@ -295,15 +382,15 @@ private BlobKey putBuffer(JobID jobId, byte[] value, int offset, int len) throws int remainingBytes = len; while (remainingBytes > 0) { + // want a common code path for byte[] and InputStream at the BlobServer + // -> since for InputStream we don't know a total size beforehand, send lengths iteratively final int bytesToSend = Math.min(BUFFER_SIZE, remainingBytes); writeLength(bytesToSend, os); os.write(value, offset, bytesToSend); - // Update the message digest if necessary - if (md != null) { - md.update(value, offset, bytesToSend); - } + // Update the message digest + md.update(value, offset, bytesToSend); remainingBytes -= bytesToSend; offset += bytesToSend; @@ -313,7 +400,7 @@ private BlobKey putBuffer(JobID jobId, byte[] value, int offset, int len) throws // Receive blob key and compare final InputStream is = this.socket.getInputStream(); - return receivePutResponseAndCompare(is, md); + return receiveAndCheckPutResponse(is, md); } catch (Throwable t) { BlobUtils.closeSilently(socket, LOG); @@ -325,37 +412,36 @@ private BlobKey putBuffer(JobID jobId, byte[] value, int offset, int len) throws * Uploads data from the given input stream to the BLOB server. 
* * @param jobId - * the ID of the job the BLOB belongs to or null to store the BLOB in a content-addressable - * manner + * the ID of the job the BLOB belongs to (or null if job-unrelated) * @param inputStream - * the input stream to read the data from - * @return he computed BLOB key if the BLOB has been stored in a content-addressable manner, null - * otherwise + * the input stream to read the data from + * + * @return the computed BLOB key of the uploaded BLOB + * * @throws IOException - * thrown if an I/O error occurs while uploading the data to the BLOB server + * thrown if an I/O error occurs while uploading the data to the BLOB server */ - private BlobKey putInputStream(JobID jobId, InputStream inputStream) throws IOException { + private BlobKey putInputStream(@Nullable JobID jobId, InputStream inputStream) throws IOException { if (this.socket.isClosed()) { throw new IllegalStateException("BLOB Client is not connected. " + "Client has been shut down or encountered an error before."); } + checkNotNull(inputStream); if (LOG.isDebugEnabled()) { - if (jobId == null) { - LOG.debug(String.format("PUT content addressable BLOB stream to %s", - socket.getLocalSocketAddress())); - } + LOG.debug("PUT BLOB stream to {}.", socket.getLocalSocketAddress()); } try { final OutputStream os = this.socket.getOutputStream(); - final MessageDigest md = jobId == null ? BlobUtils.createMessageDigest() : null; + final MessageDigest md = BlobUtils.createMessageDigest(); final byte[] xferBuf = new byte[BUFFER_SIZE]; // Send the PUT header sendPutHeader(os, jobId); while (true) { + // since we don't know a total size here, send lengths iteratively final int read = inputStream.read(xferBuf); if (read < 0) { // we are done. send a -1 and be done @@ -365,15 +451,13 @@ private BlobKey putInputStream(JobID jobId, InputStream inputStream) throws IOEx if (read > 0) { writeLength(read, os); os.write(xferBuf, 0, read); - if (md != null) { - md.update(xferBuf, 0, read); - } + md.update(xferBuf, 0, read); } } // Receive blob key and compare final InputStream is = this.socket.getInputStream(); - return receivePutResponseAndCompare(is, md); + return receiveAndCheckPutResponse(is, md); } catch (Throwable t) { BlobUtils.closeSilently(socket, LOG); @@ -381,16 +465,25 @@ private BlobKey putInputStream(JobID jobId, InputStream inputStream) throws IOEx } } - private BlobKey receivePutResponseAndCompare(InputStream is, MessageDigest md) throws IOException { + /** + * Reads the response from the input stream and throws in case of errors + * + * @param is + * stream to read from + * @param md + * message digest to check the response against + * + * @throws IOException + * if the response is an error, the message digest does not match or reading the response + * failed + */ + private static BlobKey receiveAndCheckPutResponse(InputStream is, MessageDigest md) + throws IOException { int response = is.read(); if (response < 0) { throw new EOFException("Premature end of response"); } else if (response == RETURN_OKAY) { - if (md == null) { - // not content addressable - return null; - } BlobKey remoteKey = BlobKey.readFromInputStream(is); BlobKey localKey = new BlobKey(md.digest()); @@ -412,24 +505,24 @@ else if (response == RETURN_ERROR) { /** * Constructs and writes the header data for a PUT request to the given output stream. - * NOTE: If the jobId and key are null, we send the data to the content addressable section. 
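For reference, the request framing written by the client (sendGetHeader above, sendPutHeader just below) reduces to a one-byte operation code, a one-byte job-scope marker, the optional job ID and, for GET, the blob key; PUT then streams length-prefixed chunks. Sketched from the framing code in this class:

// Rough wire layout, for reference only (derived from the framing code in this class):
//
//   GET request:
//     operation byte              GET_OPERATION
//     scope byte                  CONTENT_NO_JOB or CONTENT_FOR_JOB
//     jobId.getBytes()            only when CONTENT_FOR_JOB
//     blob key                    via BlobKey#writeToOutputStream
//     response                    status byte checked by receiveAndCheckGetResponse
//
//   PUT request:
//     operation byte              PUT_OPERATION
//     scope byte + jobId bytes    as above
//     payload                     length-prefixed chunks of up to BUFFER_SIZE, ended by a -1 length
//     response                    RETURN_OKAY plus the BlobKey, compared against the local digest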
* * @param outputStream - * the output stream to write the PUT header data to - * @param jobID - * the ID of job the BLOB belongs to or null to indicate the upload of a - * content-addressable BLOB + * the output stream to write the PUT header data to + * @param jobId + * the ID of job the BLOB belongs to (or null if job-unrelated) + * * @throws IOException - * thrown if an I/O error occurs while writing the header data to the output stream + * thrown if an I/O error occurs while writing the header data to the output stream */ - private void sendPutHeader(OutputStream outputStream, JobID jobID) throws IOException { - checkArgument(jobID == null); - + private static void sendPutHeader(OutputStream outputStream, @Nullable JobID jobId) throws IOException { // Signal type of operation outputStream.write(PUT_OPERATION); - - // Check if PUT should be done in content-addressable manner - outputStream.write(CONTENT_ADDRESSABLE); + if (jobId == null) { + outputStream.write(CONTENT_NO_JOB); + } else { + outputStream.write(CONTENT_FOR_JOB); + outputStream.write(jobId.getBytes()); + } } // -------------------------------------------------------------------------------------------- @@ -437,16 +530,50 @@ private void sendPutHeader(OutputStream outputStream, JobID jobID) throws IOExce // -------------------------------------------------------------------------------------------- /** - * Deletes the BLOB identified by the given BLOB key from the BLOB server. + * Deletes the (job-unrelated) BLOB identified by the given BLOB key from the BLOB server. + * + * @param key + * the key to identify the BLOB * - * @param blobKey - * the key to identify the BLOB * @throws IOException - * thrown if an I/O error occurs while transferring the request to - * the BLOB server or if the BLOB server cannot delete the file + * thrown if an I/O error occurs while transferring the request to the BLOB server or if the + * BLOB server cannot delete the file */ - public void delete(BlobKey blobKey) throws IOException { - checkArgument(blobKey != null, "BLOB key must not be null."); + public void delete(BlobKey key) throws IOException { + deleteInternal(null, key); + } + + /** + * Deletes the BLOB identified by the given BLOB key and job ID from the BLOB server. + * + * @param jobId + * the ID of job the BLOB belongs to + * @param key + * the key to identify the BLOB + * + * @throws IOException + * thrown if an I/O error occurs while transferring the request to the BLOB server or if the + * BLOB server cannot delete the file + */ + public void delete(JobID jobId, BlobKey key) throws IOException { + checkNotNull(jobId); + deleteInternal(jobId, key); + } + + /** + * Deletes the BLOB identified by the given BLOB key and job ID from the BLOB server. 
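As an aside for reviewers, a minimal usage sketch of the job-scoped client API introduced above. The BlobClient construction and closing are assumed to follow the existing class; the server address, configuration and JAR file are placeholders, and the helper name is hypothetical:

	// Sketch only, not part of the patch: store and later delete a job-scoped BLOB.
	void putAndDeleteExample(InetSocketAddress serverAddress, Configuration clientConfig,
			JobID jobId, File jarFile) throws IOException {
		try (BlobClient client = new BlobClient(serverAddress, clientConfig);
				InputStream in = new FileInputStream(jarFile)) {
			// the returned key is the content hash, checked against the server's response
			BlobKey key = client.put(jobId, in);
			// job-scoped delete as introduced above
			client.delete(jobId, key);
		}
	}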
+ * + * @param jobId + * the ID of job the BLOB belongs to (or null if job-unrelated) + * @param key + * the key to identify the BLOB + * + * @throws IOException + * thrown if an I/O error occurs while transferring the request to the BLOB server or if the + * BLOB server cannot delete the file + */ + public void deleteInternal(@Nullable JobID jobId, BlobKey key) throws IOException { + checkNotNull(key); try { final OutputStream outputStream = this.socket.getOutputStream(); @@ -456,20 +583,16 @@ public void delete(BlobKey blobKey) throws IOException { outputStream.write(DELETE_OPERATION); // delete blob key - outputStream.write(CONTENT_ADDRESSABLE); - blobKey.writeToOutputStream(outputStream); - - int response = inputStream.read(); - if (response < 0) { - throw new EOFException("Premature end of response"); - } - if (response == RETURN_ERROR) { - Throwable cause = readExceptionFromStream(inputStream); - throw new IOException("Server side error: " + cause.getMessage(), cause); - } - else if (response != RETURN_OKAY) { - throw new IOException("Unrecognized response"); + if (jobId == null) { + outputStream.write(CONTENT_NO_JOB); + } else { + outputStream.write(CONTENT_FOR_JOB); + outputStream.write(jobId.getBytes()); } + key.writeToOutputStream(outputStream); + + // the response is the same as for a GET request + receiveAndCheckGetResponse(inputStream); } catch (Throwable t) { BlobUtils.closeSilently(socket, LOG); @@ -480,15 +603,20 @@ else if (response != RETURN_OKAY) { /** * Uploads the JAR files to a {@link BlobServer} at the given address. * - * @param serverAddress Server address of the {@link BlobServer} - * @param clientConfig Any additional configuration for the blob client - * @param jars List of JAR files to upload - * @throws IOException Thrown if the upload fails + * @param serverAddress + * Server address of the {@link BlobServer} + * @param clientConfig + * Any additional configuration for the blob client + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * @param jars + * List of JAR files to upload + * + * @throws IOException + * if the upload fails */ - public static List uploadJarFiles( - InetSocketAddress serverAddress, - Configuration clientConfig, - List jars) throws IOException { + public static List uploadJarFiles(InetSocketAddress serverAddress, + Configuration clientConfig, JobID jobId, List jars) throws IOException {checkNotNull(jobId); if (jars.isEmpty()) { return Collections.emptyList(); } else { @@ -500,7 +628,7 @@ public static List uploadJarFiles( FSDataInputStream is = null; try { is = fs.open(jar); - final BlobKey key = blobClient.put(is); + final BlobKey key = blobClient.putInputStream(jobId, is); blobKeys.add(key); } finally { if (is != null) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java index ecb452701dcda..bfcf881dbea75 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; @@ -28,20 +30,20 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import 
javax.annotation.Nullable; import javax.net.ssl.SSLContext; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; -import java.net.URL; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; @@ -59,7 +61,7 @@ public class BlobServer extends Thread implements BlobService { private static final Logger LOG = LoggerFactory.getLogger(BlobServer.class); /** Counter to generate unique names for temporary files. */ - private final AtomicInteger tempFileCounter = new AtomicInteger(0); + private final AtomicLong tempFileCounter = new AtomicLong(0); /** The server socket listening for incoming connections. */ private final ServerSocket serverSocket; @@ -111,7 +113,7 @@ public BlobServer(Configuration config, BlobStore blobStore) throws IOException // configure and create the storage directory String storageDirectory = config.getString(BlobServerOptions.STORAGE_DIRECTORY); - this.storageDir = BlobUtils.initStorageDirectory(storageDirectory); + this.storageDir = BlobUtils.initLocalStorageDirectory(storageDirectory); LOG.info("Created BLOB server storage directory {}", storageDir); // configure the maximum number of concurrent connections @@ -190,11 +192,13 @@ public ServerSocket createSocket(int port) throws IOException { * *

This is only called from the {@link BlobServerConnection} * + * @param jobId ID of the job this blob belongs to (or null if job-unrelated) * @param key identifying the file * @return file handle to the file */ - File getStorageLocation(BlobKey key) { - return BlobUtils.getStorageLocation(storageDir, key); + @VisibleForTesting + public File getStorageLocation(JobID jobId, BlobKey key) { + return BlobUtils.getStorageLocation(storageDir, jobId, key); } /** @@ -334,34 +338,85 @@ public BlobClient createClient() throws IOException { } /** - * Method which retrieves the URL of a file associated with a blob key. The blob server looks - * the blob key up in its local storage. If the file exists, then the URL is returned. If the - * file does not exist, then a FileNotFoundException is thrown. + * Retrieves the local path of a (job-unrelated) file associated with a blob key. + *

+ * The blob server looks the blob key up in its local storage. If the file exists, it is + * returned. If the file does not exist, it is retrieved from the HA blob store (if available) + * or a {@link FileNotFoundException} is thrown. + * + * @param key + * blob key associated with the requested file + * + * @return file referring to the local storage location of the BLOB * - * @param requiredBlob blob key associated with the requested file - * @return URL of the file * @throws IOException + * Thrown if the file retrieval failed. */ @Override - public URL getURL(BlobKey requiredBlob) throws IOException { + public File getFile(BlobKey key) throws IOException { + return getFileInternal(null, key); + } + + /** + * Retrieves the local path of a file associated with a job and a blob key. + *

+ * The blob server looks the blob key up in its local storage. If the file exists, it is + * returned. If the file does not exist, it is retrieved from the HA blob store (if available) + * or a {@link FileNotFoundException} is thrown. + * + * @param jobId + * ID of the job this blob belongs to + * @param key + * blob key associated with the requested file + * + * @return file referring to the local storage location of the BLOB + * + * @throws IOException + * Thrown if the file retrieval failed. + */ + @Override + public File getFile(JobID jobId, BlobKey key) throws IOException { + checkNotNull(jobId); + return getFileInternal(jobId, key); + } + + /** + * Retrieves the local path of a file associated with a job and a blob key. + *

+ * The blob server looks the blob key up in its local storage. If the file exists, it is + * returned. If the file does not exist, it is retrieved from the HA blob store (if available) + * or a {@link FileNotFoundException} is thrown. + * + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * @param requiredBlob + * blob key associated with the requested file + * + * @return file referring to the local storage location of the BLOB + * + * @throws IOException + * Thrown if the file retrieval failed. + */ + private File getFileInternal(@Nullable JobID jobId, BlobKey requiredBlob) throws IOException { checkArgument(requiredBlob != null, "BLOB key cannot be null."); - final File localFile = BlobUtils.getStorageLocation(storageDir, requiredBlob); + final File localFile = BlobUtils.getStorageLocation(storageDir, jobId, requiredBlob); if (localFile.exists()) { - return localFile.toURI().toURL(); + return localFile; } else { try { // Try the blob store - blobStore.get(requiredBlob, localFile); + blobStore.get(jobId, requiredBlob, localFile); } catch (Exception e) { - throw new IOException("Failed to copy from blob store.", e); + throw new IOException( + "Failed to copy BLOB " + requiredBlob + " from blob store to " + localFile, e); } if (localFile.exists()) { - return localFile.toURI().toURL(); + return localFile; } else { throw new FileNotFoundException("Local file " + localFile + " does not exist " + @@ -371,29 +426,94 @@ public URL getURL(BlobKey requiredBlob) throws IOException { } /** - * This method deletes the file associated to the blob key if it exists in the local storage - * of the blob server. + * Deletes the (job-unrelated) file associated with the blob key in both the local storage as + * well as in the HA store of the blob server. + * + * @param key + * blob key associated with the file to be deleted * - * @param key associated with the file to be deleted * @throws IOException */ @Override public void delete(BlobKey key) throws IOException { - final File localFile = BlobUtils.getStorageLocation(storageDir, key); + deleteInternal(null, key); + } + + /** + * Deletes the file associated with the blob key in both the local storage as well as in the HA + * store of the blob server. + * + * @param jobId + * ID of the job this blob belongs to + * @param key + * blob key associated with the file to be deleted + * + * @throws IOException + */ + @Override + public void delete(JobID jobId, BlobKey key) throws IOException { + checkNotNull(jobId); + deleteInternal(jobId, key); + } + + /** + * Deletes the file associated with the blob key in both the local storage as well as in the HA + * store of the blob server. + * + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * @param key + * blob key associated with the file to be deleted + * + * @throws IOException + */ + void deleteInternal(@Nullable JobID jobId, BlobKey key) throws IOException { + final File localFile = BlobUtils.getStorageLocation(storageDir, jobId, key); readWriteLock.writeLock().lock(); try { if (!localFile.delete() && localFile.exists()) { - LOG.warn("Failed to delete locally BLOB " + key + " at " + localFile.getAbsolutePath()); + LOG.warn("Failed to locally delete BLOB " + key + " at " + localFile.getAbsolutePath()); + } + + blobStore.delete(jobId, key); + } finally { + readWriteLock.writeLock().unlock(); + } + } + + /** + * Removes all BLOBs from local and HA store belonging to the given job ID. 
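For reviewers, a short sketch of how the new server-side accessors fit together over a job's life cycle. The blobServer, jobId and blobKey variables are assumed to exist; the behavior comments restate the Javadoc above:

	// Sketch only: job-scoped retrieval and cleanup via the methods introduced here.
	File localCopy = blobServer.getFile(jobId, blobKey);   // falls back to the HA blob store if the file is not local yet
	// ... serve the file to tasks while the job is running ...
	blobServer.cleanupJob(jobId);                          // removes the job's BLOBs from local and HA storage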
+ * + * @param jobId + * ID of the job this blob belongs to + */ + public void cleanupJob(JobID jobId) { + checkNotNull(jobId); + + final File jobDir = + new File(BlobUtils.getStorageLocationPath(storageDir.getAbsolutePath(), jobId)); + + readWriteLock.writeLock().lock(); + + try { + // delete locally + try { + FileUtils.deleteDirectory(jobDir); + } catch (IOException e) { + LOG.warn("Failed to locally delete BLOB storage directory at " + + jobDir.getAbsolutePath(), e); } - blobStore.delete(key); + // delete in HA store + blobStore.deleteAll(jobId); } finally { readWriteLock.writeLock().unlock(); } } + /** * Returns the port on which the server is listening. * diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerConnection.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerConnection.java index 181211d74eabe..7f617f91d0ca1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerConnection.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerConnection.java @@ -39,7 +39,8 @@ import java.util.concurrent.locks.ReadWriteLock; import static org.apache.flink.runtime.blob.BlobServerProtocol.BUFFER_SIZE; -import static org.apache.flink.runtime.blob.BlobServerProtocol.CONTENT_ADDRESSABLE; +import static org.apache.flink.runtime.blob.BlobServerProtocol.CONTENT_FOR_JOB; +import static org.apache.flink.runtime.blob.BlobServerProtocol.CONTENT_NO_JOB; import static org.apache.flink.runtime.blob.BlobServerProtocol.DELETE_OPERATION; import static org.apache.flink.runtime.blob.BlobServerProtocol.GET_OPERATION; import static org.apache.flink.runtime.blob.BlobServerProtocol.PUT_OPERATION; @@ -49,6 +50,7 @@ import static org.apache.flink.runtime.blob.BlobUtils.readFully; import static org.apache.flink.runtime.blob.BlobUtils.readLength; import static org.apache.flink.runtime.blob.BlobUtils.writeLength; +import static org.apache.flink.util.Preconditions.checkNotNull; /** * A BLOB connection handles a series of requests from a particular BLOB client. @@ -83,12 +85,8 @@ class BlobServerConnection extends Thread { super("BLOB connection for " + clientSocket.getRemoteSocketAddress()); setDaemon(true); - if (blobServer == null) { - throw new NullPointerException(); - } - this.clientSocket = clientSocket; - this.blobServer = blobServer; + this.blobServer = checkNotNull(blobServer); this.blobStore = blobServer.getBlobStore(); ReadWriteLock readWriteLock = blobServer.getReadWriteLock(); @@ -141,14 +139,7 @@ public void run() { LOG.error("Error while executing BLOB connection.", t); } finally { - try { - if (clientSocket != null) { - clientSocket.close(); - } - } catch (Throwable t) { - LOG.debug("Exception while closing BLOB server connection socket.", t); - } - + closeSilently(clientSocket, LOG); blobServer.unregisterConnection(this); } } @@ -167,15 +158,16 @@ public void close() { /** * Handles an incoming GET request from a BLOB client. 
- * + * * @param inputStream - * the input stream to read incoming data from + * the input stream to read incoming data from * @param outputStream - * the output stream to send data back to the client + * the output stream to send data back to the client * @param buf - * an auxiliary buffer for data serialization/deserialization + * an auxiliary buffer for data serialization/deserialization + * * @throws IOException - * thrown if an I/O error occurs while reading/writing data from/to the respective streams + * thrown if an I/O error occurs while reading/writing data from/to the respective streams */ private void get(InputStream inputStream, OutputStream outputStream, byte[] buf) throws IOException { /* @@ -187,25 +179,36 @@ private void get(InputStream inputStream, OutputStream outputStream, byte[] buf) * so a local cache makes more sense. */ - File blobFile; - int contentAddressable = -1; - JobID jobId = null; - BlobKey blobKey = null; + final File blobFile; + final JobID jobId; + final BlobKey blobKey; try { - contentAddressable = inputStream.read(); + final int mode = inputStream.read(); - if (contentAddressable < 0) { + if (mode < 0) { throw new EOFException("Premature end of GET request"); } - if (contentAddressable == CONTENT_ADDRESSABLE) { - blobKey = BlobKey.readFromInputStream(inputStream); - blobFile = blobServer.getStorageLocation(blobKey); + + // Receive the job ID and key + if (mode == CONTENT_NO_JOB) { + jobId = null; + } else if (mode == CONTENT_FOR_JOB) { + byte[] jidBytes = new byte[JobID.SIZE]; + readFully(inputStream, jidBytes, 0, JobID.SIZE, "JobID"); + jobId = JobID.fromByteArray(jidBytes); + } else { + throw new IOException("Unknown type of BLOB addressing: " + mode + '.'); } - else { - throw new IOException("Unknown type of BLOB addressing: " + contentAddressable + '.'); + blobKey = BlobKey.readFromInputStream(inputStream); + + if (LOG.isDebugEnabled()) { + LOG.debug("Received GET request for BLOB {}/{} from {}.", jobId, + blobKey, clientSocket.getInetAddress()); } + blobFile = blobServer.getStorageLocation(jobId, blobKey); + // up to here, an error can give a good message } catch (Throwable t) { @@ -214,7 +217,7 @@ private void get(InputStream inputStream, OutputStream outputStream, byte[] buf) writeErrorToStream(outputStream, t); } catch (IOException e) { - // since we are in an exception case, it means not much that we could not send the error + // since we are in an exception case, it means that we could not send the error // ignore this } clientSocket.close(); @@ -224,6 +227,7 @@ private void get(InputStream inputStream, OutputStream outputStream, byte[] buf) readLock.lock(); try { + // copy the file to local store if it does not exist yet try { if (!blobFile.exists()) { // first we have to release the read lock in order to acquire the write lock @@ -232,9 +236,9 @@ private void get(InputStream inputStream, OutputStream outputStream, byte[] buf) try { if (blobFile.exists()) { - LOG.debug("Blob file {} has downloaded from the BlobStore by a different connection.", blobFile); + LOG.debug("Blob file {} has been downloaded from the (distributed) blob store by a different connection.", blobFile); } else { - blobStore.get(blobKey, blobFile); + blobStore.get(jobId, blobKey, blobFile); } } finally { writeLock.unlock(); @@ -248,6 +252,7 @@ private void get(InputStream inputStream, OutputStream outputStream, byte[] buf) } } + // enforce a 2GB max for now (otherwise the protocol's length field needs to be increased) if (blobFile.length() > Integer.MAX_VALUE) { throw new 
IOException("BLOB size exceeds the maximum size (2 GB)."); } @@ -259,7 +264,7 @@ private void get(InputStream inputStream, OutputStream outputStream, byte[] buf) writeErrorToStream(outputStream, t); } catch (IOException e) { - // since we are in an exception case, it means not much that we could not send the error + // since we are in an exception case, it means that we could not send the error // ignore this } clientSocket.close(); @@ -294,59 +299,48 @@ private void get(InputStream inputStream, OutputStream outputStream, byte[] buf) /** * Handles an incoming PUT request from a BLOB client. - * - * @param inputStream The input stream to read incoming data from. - * @param outputStream The output stream to send data back to the client. - * @param buf An auxiliary buffer for data serialization/deserialization. + * + * @param inputStream + * The input stream to read incoming data from + * @param outputStream + * The output stream to send data back to the client + * @param buf + * An auxiliary buffer for data serialization/deserialization + * + * @throws IOException + * thrown if an I/O error occurs while reading/writing data from/to the respective streams */ private void put(InputStream inputStream, OutputStream outputStream, byte[] buf) throws IOException { - JobID jobID = null; - MessageDigest md = null; - File incomingFile = null; - FileOutputStream fos = null; try { - final int contentAddressable = inputStream.read(); - if (contentAddressable < 0) { + final int mode = inputStream.read(); + + if (mode < 0) { throw new EOFException("Premature end of PUT request"); } - if (contentAddressable == CONTENT_ADDRESSABLE) { - md = BlobUtils.createMessageDigest(); - } - else { + // Receive the job ID and key + final JobID jobId; + if (mode == CONTENT_NO_JOB) { + jobId = null; + } else if (mode == CONTENT_FOR_JOB) { + byte[] jidBytes = new byte[JobID.SIZE]; + readFully(inputStream, jidBytes, 0, JobID.SIZE, "JobID"); + jobId = JobID.fromByteArray(jidBytes); + } else { throw new IOException("Unknown type of BLOB addressing."); } if (LOG.isDebugEnabled()) { - LOG.debug("Received PUT request for content addressable BLOB"); + LOG.debug("Received PUT request for BLOB of job {} from {}.", jobId, + clientSocket.getInetAddress()); } incomingFile = blobServer.createTemporaryFilename(); - fos = new FileOutputStream(incomingFile); - - while (true) { - final int bytesExpected = readLength(inputStream); - if (bytesExpected == -1) { - // done - break; - } - if (bytesExpected > BUFFER_SIZE) { - throw new IOException("Unexpected number of incoming bytes: " + bytesExpected); - } - - readFully(inputStream, buf, 0, bytesExpected, "buffer"); - fos.write(buf, 0, bytesExpected); - - if (md != null) { - md.update(buf, 0, bytesExpected); - } - } - fos.close(); + BlobKey blobKey = readFileFully(inputStream, incomingFile, buf); - BlobKey blobKey = new BlobKey(md.digest()); - File storageFile = blobServer.getStorageLocation(blobKey); + File storageFile = blobServer.getStorageLocation(jobId, blobKey); writeLock.lock(); @@ -369,13 +363,15 @@ private void put(InputStream inputStream, OutputStream outputStream, byte[] buf) // only the one moving the incoming file to its final destination is allowed to upload the // file to the blob store - blobStore.put(storageFile, blobKey); + blobStore.put(storageFile, jobId, blobKey); + } else { + LOG.warn("File upload for an existing file with key {} for job {}. This may indicate a duplicate upload or a hash collision.
Ignoring newest upload.", blobKey, jobId); } } catch(IOException ioe) { // we failed to either create the local storage file or to upload it --> try to delete the local file // while still having the write lock - if (storageFile.exists() && !storageFile.delete()) { - LOG.warn("Could not delete the storage file."); + if (!storageFile.delete() && storageFile.exists()) { + LOG.warn("Could not delete the storage file with key {} and job {}.", blobKey, jobId); } throw ioe; @@ -403,43 +399,89 @@ private void put(InputStream inputStream, OutputStream outputStream, byte[] buf) clientSocket.close(); } finally { - if (fos != null) { - try { - fos.close(); - } catch (Throwable t) { - LOG.warn("Cannot close stream to BLOB staging file", t); - } - } if (incomingFile != null) { - if (!incomingFile.delete()) { + if (!incomingFile.delete() && incomingFile.exists()) { LOG.warn("Cannot delete BLOB server staging file " + incomingFile.getAbsolutePath()); } } } } + /** + * Reads a full file from inputStream into incomingFile returning its checksum. + * + * @param inputStream + * stream to read from + * @param incomingFile + * file to write to + * @param buf + * An auxiliary buffer for data serialization/deserialization + * + * @return the received file's content hash as a BLOB key + * + * @throws IOException + * thrown if an I/O error occurs while reading/writing data from/to the respective streams + */ + private static BlobKey readFileFully( + final InputStream inputStream, final File incomingFile, final byte[] buf) + throws IOException { + MessageDigest md = BlobUtils.createMessageDigest(); + + try (FileOutputStream fos = new FileOutputStream(incomingFile)) { + while (true) { + final int bytesExpected = readLength(inputStream); + if (bytesExpected == -1) { + // done + break; + } + if (bytesExpected > BUFFER_SIZE) { + throw new IOException( + "Unexpected number of incoming bytes: " + bytesExpected); + } + + readFully(inputStream, buf, 0, bytesExpected, "buffer"); + fos.write(buf, 0, bytesExpected); + + md.update(buf, 0, bytesExpected); + } + return new BlobKey(md.digest()); + } + } + /** * Handles an incoming DELETE request from a BLOB client. - * - * @param inputStream The input stream to read the request from. - * @param outputStream The output stream to write the response to. - * @throws java.io.IOException Thrown if an I/O error occurs while reading the request data from the input stream. + * + * @param inputStream + * The input stream to read the request from. + * @param outputStream + * The output stream to write the response to. + * + * @throws IOException + * Thrown if an I/O error occurs while reading the request data from the input stream. 
*/ private void delete(InputStream inputStream, OutputStream outputStream) throws IOException { try { - int type = inputStream.read(); - if (type < 0) { + final int mode = inputStream.read(); + + if (mode < 0) { throw new EOFException("Premature end of DELETE request"); } - if (type == CONTENT_ADDRESSABLE) { - BlobKey key = BlobKey.readFromInputStream(inputStream); - blobServer.delete(key); - } - else { - throw new IOException("Unrecognized addressing type: " + type); + // Receive the job ID and key + final JobID jobId; + if (mode == CONTENT_NO_JOB) { + jobId = null; + } else if (mode == CONTENT_FOR_JOB) { + byte[] jidBytes = new byte[JobID.SIZE]; + readFully(inputStream, jidBytes, 0, JobID.SIZE, "JobID"); + jobId = JobID.fromByteArray(jidBytes); + } else { + throw new IOException("Unknown type of BLOB addressing."); } + BlobKey key = BlobKey.readFromInputStream(inputStream); + + blobServer.deleteInternal(jobId, key); outputStream.write(RETURN_OKAY); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerProtocol.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerProtocol.java index d8ac83373b573..681fc8150be3f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerProtocol.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerProtocol.java @@ -42,12 +42,20 @@ public class BlobServerProtocol { static final byte RETURN_ERROR = 1; /** - * Internal code to identify a reference via content hash as the key. + * Internal code to identify a job-unrelated reference via content hash as the key. *

* Note: previously, there was also NAME_ADDRESSABLE (code 1) and * JOB_ID_SCOPE (code 2). */ - static final byte CONTENT_ADDRESSABLE = 0; + static final byte CONTENT_NO_JOB = 0; + + /** + * Internal code to identify a job-related reference via content hash as the key. + *

+ * Note: previously, there was also NAME_ADDRESSABLE (code 1) and + * JOB_ID_SCOPE (code 2). + */ + static final byte CONTENT_FOR_JOB = 3; // -------------------------------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobService.java index c1447c849eeb3..0db5a589593b8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobService.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobService.java @@ -18,9 +18,11 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.api.common.JobID; + import java.io.Closeable; +import java.io.File; import java.io.IOException; -import java.net.URL; /** * A simple store and retrieve binary large objects (BLOBs). @@ -28,29 +30,49 @@ public interface BlobService extends Closeable { /** - * Returns the URL of the file associated with the provided blob key. + * Returns the path to a local copy of the (job-unrelated) file associated with the provided + * blob key. * * @param key blob key associated with the requested file - * @return The URL to the file. + * @return The path to the file. * @throws java.io.FileNotFoundException when the path does not exist; * @throws IOException if any other error occurs when retrieving the file */ - URL getURL(BlobKey key) throws IOException; + File getFile(BlobKey key) throws IOException; + /** + * Returns the path to a local copy of the file associated with the provided job ID and blob key. + * + * @param jobId ID of the job this blob belongs to + * @param key blob key associated with the requested file + * @return The path to the file. + * @throws java.io.FileNotFoundException when the path does not exist; + * @throws IOException if any other error occurs when retrieving the file + */ + File getFile(JobID jobId, BlobKey key) throws IOException; /** - * Deletes the file associated with the provided blob key. + * Deletes the (job-unrelated) file associated with the provided blob key. * * @param key associated with the file to be deleted * @throws IOException */ void delete(BlobKey key) throws IOException; + /** + * Deletes the file associated with the provided job ID and blob key. + * + * @param jobId ID of the job this blob belongs to + * @param key associated with the file to be deleted + * @throws IOException + */ + void delete(JobID jobId, BlobKey key) throws IOException; + /** * Returns the port of the blob service. * @return the port of the blob service. */ int getPort(); - + BlobClient createClient() throws IOException; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobStore.java index 1e8b73a43a1d8..d2ea8caaf2616 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobStore.java @@ -32,19 +32,21 @@ public interface BlobStore extends BlobView { * Copies the local file to the blob store. * * @param localFile The file to copy + * @param jobId ID of the job this blob belongs to (or null if job-unrelated) * @param blobKey The ID for the file in the blob store * @throws IOException If the copy fails */ - void put(File localFile, BlobKey blobKey) throws IOException; + void put(File localFile, JobID jobId, BlobKey blobKey) throws IOException; /** * Tries to delete a blob from storage. * *

NOTE: This also tries to delete any created directories if empty.
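To summarize the addressing scheme shared by the PUT, GET and DELETE paths above, a hedged sketch follows. The helper name is hypothetical; the individual calls mirror sendPutHeader and the connection handlers shown earlier:

	// Sketch only: how a request addresses a BLOB after this change. PUT does not send the key
	// (both sides compute it from the content); GET and DELETE send it right after the header.
	static void writeBlobAddress(OutputStream out, byte operation,
			@Nullable JobID jobId, @Nullable BlobKey key) throws IOException {
		out.write(operation);              // PUT_OPERATION, GET_OPERATION or DELETE_OPERATION
		if (jobId == null) {
			out.write(CONTENT_NO_JOB);     // job-unrelated BLOB
		} else {
			out.write(CONTENT_FOR_JOB);    // job-related BLOB ...
			out.write(jobId.getBytes());   // ... followed by the serialized job ID
		}
		if (key != null) {
			key.writeToOutputStream(out);  // content hash identifying the BLOB (GET/DELETE only)
		}
	}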

* + * @param jobId ID of the job this blob belongs to (or null if job-unrelated) * @param blobKey The blob ID */ - void delete(BlobKey blobKey); + void delete(JobID jobId, BlobKey blobKey); /** * Tries to delete all blobs for the given job from storage. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobUtils.java index e8f3fe575e187..dabd1bfd764cb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobUtils.java @@ -26,8 +26,11 @@ import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; import org.apache.flink.util.StringUtils; + import org.slf4j.Logger; +import javax.annotation.Nullable; +import java.io.Closeable; import java.io.EOFException; import java.io.File; import java.io.IOException; @@ -61,6 +64,11 @@ public class BlobUtils { */ private static final String JOB_DIR_PREFIX = "job_"; + /** + * The prefix of all job-unrelated directories created by the BLOB server. + */ + private static final String NO_JOB_DIR_PREFIX = "no_job"; + /** * Creates a BlobStore based on the parameters set in the configuration. * @@ -116,26 +124,29 @@ private static BlobStoreService createFileSystemBlobStore(Configuration configur } /** - * Creates a storage directory for a blob service. + * Creates a local storage directory for a blob service under the given parent directory. + * + * @param basePath + * base path, i.e. parent directory, of the storage directory to use (if null or + * empty, the path in java.io.tmpdir will be used) * - * @return the storage directory used by a BLOB service + * @return a new local storage directory * * @throws IOException - * thrown if the (local or distributed) file storage cannot be created or - * is not usable + * thrown if the local file storage cannot be created or is not usable */ - static File initStorageDirectory(String storageDirectory) throws - IOException { + static File initLocalStorageDirectory(String basePath) throws IOException { File baseDir; - if (StringUtils.isNullOrWhitespaceOnly(storageDirectory)) { + if (StringUtils.isNullOrWhitespaceOnly(basePath)) { baseDir = new File(System.getProperty("java.io.tmpdir")); } else { - baseDir = new File(storageDirectory); + baseDir = new File(basePath); } File storageDir; + // NOTE: although we will be using UUIDs, there may be collisions final int MAX_ATTEMPTS = 10; for(int attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { storageDir = new File(baseDir, String.format( @@ -143,7 +154,7 @@ static File initStorageDirectory(String storageDirectory) throws // Create the storage dir if it doesn't exist. Only return it when the operation was // successful. - if (!storageDir.exists() && storageDir.mkdirs()) { + if (storageDir.mkdirs()) { return storageDir; } } @@ -153,46 +164,106 @@ static File initStorageDirectory(String storageDirectory) throws } /** - * Returns the BLOB service's directory for incoming files. The directory is created if it did - * not exist so far. + * Returns the BLOB service's directory for incoming (job-unrelated) files. The directory is + * created if it does not exist yet. 
+ * + * @param storageDir + * storage directory used by the BLOB service * - * @return the BLOB server's directory for incoming files + * @return the BLOB service's directory for incoming files */ static File getIncomingDirectory(File storageDir) { final File incomingDir = new File(storageDir, "incoming"); - if (!incomingDir.mkdirs() && !incomingDir.exists()) { - throw new RuntimeException("Cannot create directory for incoming files " + incomingDir.getAbsolutePath()); - } + mkdirTolerateExisting(incomingDir); return incomingDir; } /** - * Returns the BLOB service's directory for cached files. The directory is created if it did - * not exist so far. + * Makes sure a given directory exists by creating it if necessary. * - * @return the BLOB server's directory for cached files + * @param dir + * directory to create */ - private static File getCacheDirectory(File storageDir) { - final File cacheDirectory = new File(storageDir, "cache"); - - if (!cacheDirectory.mkdirs() && !cacheDirectory.exists()) { - throw new RuntimeException("Could not create cache directory '" + cacheDirectory.getAbsolutePath() + "'."); + private static void mkdirTolerateExisting(final File dir) { + // note: thread-safe create should try to mkdir first and then ignore the case that the + // directory already existed + if (!dir.mkdirs() && !dir.exists()) { + throw new RuntimeException( + "Cannot create directory '" + dir.getAbsolutePath() + "'."); } - - return cacheDirectory; } /** * Returns the (designated) physical storage location of the BLOB with the given key. * + * @param storageDir + * storage directory used by the BLOB service * @param key - * the key identifying the BLOB + * the key identifying the BLOB + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * * @return the (designated) physical storage location of the BLOB */ - static File getStorageLocation(File storageDir, BlobKey key) { - return new File(getCacheDirectory(storageDir), BLOB_FILE_PREFIX + key.toString()); + static File getStorageLocation( + File storageDir, @Nullable JobID jobId, BlobKey key) { + File file = new File(getStorageLocationPath(storageDir.getAbsolutePath(), jobId, key)); + + mkdirTolerateExisting(file.getParentFile()); + + return file; + } + + /** + * Returns the BLOB server's storage directory for BLOBs belonging to the job with the given ID + * without creating the directory. + * + * @param storageDir + * storage directory used by the BLOB service + * @param jobId + * the ID of the job to return the storage directory for + * + * @return the storage directory for BLOBs belonging to the job with the given ID + */ + static String getStorageLocationPath(String storageDir, @Nullable JobID jobId) { + if (jobId == null) { + // format: $base/no_job + return String.format("%s/%s", storageDir, NO_JOB_DIR_PREFIX); + } else { + // format: $base/job_$jobId + return String.format("%s/%s%s", storageDir, JOB_DIR_PREFIX, jobId.toString()); + } + } + + /** + * Returns the path for the given blob key. + *

+ * The returned path can be used with the (local or HA) BLOB store file system back-end for + * recovery purposes and follows the same scheme as {@link #getStorageLocation(File, JobID, + * BlobKey)}. + * + * @param storageDir + * storage directory used by the BLOB service + * @param key + * the key identifying the BLOB + * @param jobId + * ID of the job this blob belongs to (or null if job-unrelated) + * + * @return the path to the given BLOB + */ + static String getStorageLocationPath( + String storageDir, @Nullable JobID jobId, BlobKey key) { + if (jobId == null) { + // format: $base/no_job/blob_$key + return String.format("%s/%s/%s%s", + storageDir, NO_JOB_DIR_PREFIX, BLOB_FILE_PREFIX, key.toString()); + } else { + // format: $base/job_$jobId/blob_$key + return String.format("%s/%s%s/%s%s", + storageDir, JOB_DIR_PREFIX, jobId.toString(), BLOB_FILE_PREFIX, key.toString()); + } } /** @@ -211,7 +282,7 @@ static MessageDigest createMessageDigest() { /** * Adds a shutdown hook to the JVM and returns the Thread, which has been registered. */ - static Thread addShutdownHook(final BlobService service, final Logger logger) { + static Thread addShutdownHook(final Closeable service, final Logger logger) { checkNotNull(service); checkNotNull(logger); @@ -325,35 +396,11 @@ static void closeSilently(Socket socket, Logger LOG) { try { socket.close(); } catch (Throwable t) { - if (LOG.isDebugEnabled()) { - LOG.debug("Error while closing resource after BLOB transfer.", t); - } + LOG.debug("Exception while closing BLOB server connection socket.", t); } } } - /** - * Returns the path for the given blob key. - * - *
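To make the resulting layout concrete, a sketch of the paths produced by the two package-private overloads above. The base directory is hypothetical, jobId is a fresh ID and someKey is an assumed existing BlobKey; the scheme is shared by the local storage directory and the HA file system back-end:

	// Sketch only, hypothetical values:
	String base = "/tmp/flink-blobs";
	JobID jobId = new JobID();

	String jobDir = BlobUtils.getStorageLocationPath(base, jobId);
	// -> /tmp/flink-blobs/job_<jobId>
	String jobBlob = BlobUtils.getStorageLocationPath(base, jobId, someKey);
	// -> /tmp/flink-blobs/job_<jobId>/blob_<someKey>
	String noJobBlob = BlobUtils.getStorageLocationPath(base, null, someKey);
	// -> /tmp/flink-blobs/no_job/blob_<someKey>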

The returned path can be used with the state backend for recovery purposes. - * - *

This follows the same scheme as {@link #getStorageLocation(File, BlobKey)} - * and is used for HA. - */ - static String getRecoveryPath(String basePath, BlobKey blobKey) { - // format: $base/cache/blob_$key - return String.format("%s/cache/%s%s", basePath, BLOB_FILE_PREFIX, blobKey.toString()); - } - - /** - * Returns the path for the given job ID. - * - *

The returned path can be used with the state backend for recovery purposes. - */ - static String getRecoveryPath(String basePath, JobID jobId) { - return String.format("%s/%s%s", basePath, JOB_DIR_PREFIX, jobId.toString()); - } - /** * Private constructor to prevent instantiation. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobView.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobView.java index 2e2e4a77841b6..8916d953d7eff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobView.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobView.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.api.common.JobID; + import java.io.File; import java.io.IOException; @@ -29,9 +31,10 @@ public interface BlobView { /** * Copies a blob to a local file. * + * @param jobId ID of the job this blob belongs to (or null if job-unrelated) * @param blobKey The blob ID * @param localFile The local file to copy to * @throws IOException If the copy fails */ - void get(BlobKey blobKey, File localFile) throws IOException; + void get(JobID jobId, BlobKey blobKey, File localFile) throws IOException; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/FileSystemBlobStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/FileSystemBlobStore.java index 5f8058bee5e77..062fd8273bce5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/FileSystemBlobStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/FileSystemBlobStore.java @@ -18,13 +18,13 @@ package org.apache.flink.runtime.blob; -import com.google.common.io.Files; - import org.apache.flink.api.common.JobID; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.Path; import org.apache.flink.util.IOUtils; +import org.apache.flink.shaded.guava18.com.google.common.io.Files; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -64,8 +64,8 @@ public FileSystemBlobStore(FileSystem fileSystem, String storagePath) throws IOE // - Put ------------------------------------------------------------------ @Override - public void put(File localFile, BlobKey blobKey) throws IOException { - put(localFile, BlobUtils.getRecoveryPath(basePath, blobKey)); + public void put(File localFile, JobID jobId, BlobKey blobKey) throws IOException { + put(localFile, BlobUtils.getStorageLocationPath(basePath, jobId, blobKey)); } private void put(File fromFile, String toBlobPath) throws IOException { @@ -78,8 +78,8 @@ private void put(File fromFile, String toBlobPath) throws IOException { // - Get ------------------------------------------------------------------ @Override - public void get(BlobKey blobKey, File localFile) throws IOException { - get(BlobUtils.getRecoveryPath(basePath, blobKey), localFile); + public void get(JobID jobId, BlobKey blobKey, File localFile) throws IOException { + get(BlobUtils.getStorageLocationPath(basePath, jobId, blobKey), localFile); } private void get(String fromBlobPath, File toFile) throws IOException { @@ -112,13 +112,13 @@ private void get(String fromBlobPath, File toFile) throws IOException { // - Delete --------------------------------------------------------------- @Override - public void delete(BlobKey blobKey) { - delete(BlobUtils.getRecoveryPath(basePath, blobKey)); + public void delete(JobID jobId, BlobKey blobKey) { + delete(BlobUtils.getStorageLocationPath(basePath, jobId, blobKey)); } @Override public void deleteAll(JobID 
jobId) { - delete(BlobUtils.getRecoveryPath(basePath, jobId)); + delete(BlobUtils.getStorageLocationPath(basePath, jobId)); } private void delete(String blobPath) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/VoidBlobStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/VoidBlobStore.java index 6e2bb53ff08d7..95be5697920db 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/VoidBlobStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/VoidBlobStore.java @@ -29,16 +29,15 @@ public class VoidBlobStore implements BlobStoreService { @Override - public void put(File localFile, BlobKey blobKey) throws IOException { + public void put(File localFile, JobID jobId, BlobKey blobKey) throws IOException { } - @Override - public void get(BlobKey blobKey, File localFile) throws IOException { + public void get(JobID jobId, BlobKey blobKey, File localFile) throws IOException { } @Override - public void delete(BlobKey blobKey) { + public void delete(JobID jobId, BlobKey blobKey) { } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 6f4186745e95d..c98d3aa1083c2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -40,7 +40,7 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.state.SharedStateRegistry; -import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.runtime.state.SharedStateRegistryFactory; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; import org.apache.flink.util.Preconditions; import org.apache.flink.util.StringUtils; @@ -175,8 +175,11 @@ public class CheckpointCoordinator { @Nullable private CheckpointStatsTracker statsTracker; + /** A factory for SharedStateRegistry objects */ + private final SharedStateRegistryFactory sharedStateRegistryFactory; + /** Registry that tracks state which is shared across (incremental) checkpoints */ - private final SharedStateRegistry sharedStateRegistry; + private SharedStateRegistry sharedStateRegistry; // -------------------------------------------------------------------------------------------- @@ -193,7 +196,8 @@ public CheckpointCoordinator( CheckpointIDCounter checkpointIDCounter, CompletedCheckpointStore completedCheckpointStore, @Nullable String checkpointDirectory, - Executor executor) { + Executor executor, + SharedStateRegistryFactory sharedStateRegistryFactory) { // sanity checks checkArgument(baseInterval > 0, "Checkpoint timeout must be larger than zero"); @@ -231,7 +235,8 @@ public CheckpointCoordinator( this.completedCheckpointStore = checkNotNull(completedCheckpointStore); this.checkpointDirectory = checkpointDirectory; this.executor = checkNotNull(executor); - this.sharedStateRegistry = new SharedStateRegistry(executor); + this.sharedStateRegistryFactory = checkNotNull(sharedStateRegistryFactory); + this.sharedStateRegistry = sharedStateRegistryFactory.create(executor); this.recentPendingCheckpoints = new ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS); this.masterHooks = new HashMap<>(); @@ -1016,7 +1021,7 @@ int getNumScheduledTasks() { * Restores the latest checkpointed state. 
* * @param tasks Map of job vertices to restore. State for these vertices is - * restored via {@link Execution#setInitialState(TaskStateHandles)}. + * restored via {@link Execution#setInitialState(TaskStateSnapshot)}. * @param errorIfNoCheckpoint Fail if no completed checkpoint is available to * restore from. * @param allowNonRestoredState Allow checkpoint state that cannot be mapped @@ -1044,10 +1049,23 @@ public boolean restoreLatestCheckpointedState( throw new IllegalStateException("CheckpointCoordinator is shut down"); } - // Recover the checkpoints - completedCheckpointStore.recover(sharedStateRegistry); + // We create a new shared state registry object, so that all pending async disposal requests from previous + // runs will go against the old object (where they can do no harm). + // This must happen under the checkpoint lock. + sharedStateRegistry.close(); + sharedStateRegistry = sharedStateRegistryFactory.create(executor); + + // Recover the checkpoints. TODO: this could be done only when there is a new leader, not on each recovery + completedCheckpointStore.recover(); + + // Now, we re-register all (shared) states from the checkpoint store with the new registry + for (CompletedCheckpoint completedCheckpoint : completedCheckpointStore.getAllCheckpoints()) { + completedCheckpoint.registerSharedStatesAfterRestored(sharedStateRegistry); + } + + LOG.debug("Status of the shared state registry after restore: {}.", sharedStateRegistry); - // restore from the latest checkpoint + // Restore from the latest checkpoint CompletedCheckpoint latest = completedCheckpointStore.getLatestCheckpoint(); if (latest == null) { @@ -1102,7 +1120,7 @@ public boolean restoreLatestCheckpointedState( * mapped to any job vertex in tasks. * @param tasks Map of job vertices to restore. State for these * vertices is restored via - * {@link Execution#setInitialState(TaskStateHandles)}. + * {@link Execution#setInitialState(TaskStateSnapshot)}. * @param userClassLoader The class loader to resolve serialized classes in * legacy savepoint versions.
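A note on the new factory indirection used above: the coordinator now obtains its SharedStateRegistry from the injected factory, so recovery can drop the old registry (pending async disposals keep running against it) and continue with a fresh one. A minimal factory is just a lambda over the constructor that the removed line used directly; SharedStateRegistryFactory is assumed to be a single-method interface taking the executor, matching the call sharedStateRegistryFactory.create(executor) in the constructor:

	// Sketch only: the simplest factory, equivalent to the previous hard-wired construction.
	SharedStateRegistryFactory factory = executor -> new SharedStateRegistry(executor);
	// or, equivalently: SharedStateRegistryFactory factory = SharedStateRegistry::new;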
*/ @@ -1121,7 +1139,6 @@ public boolean restoreSavepoint( CompletedCheckpoint savepoint = SavepointLoader.loadAndValidateSavepoint( job, tasks, savepointPath, userClassLoader, allowNonRestored); - savepoint.registerSharedStatesAfterRestored(sharedStateRegistry); completedCheckpointStore.addCheckpoint(savepoint); // Reset the checkpoint ID counter @@ -1256,7 +1273,7 @@ private void discardSubtaskState( final JobID jobId, final ExecutionAttemptID executionAttemptID, final long checkpointId, - final SubtaskState subtaskState) { + final TaskStateSnapshot subtaskState) { if (subtaskState != null) { executor.execute(new Runnable() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java index 43d66ee719604..22244f6cb8d51 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java @@ -29,7 +29,7 @@ void acknowledgeCheckpoint( final ExecutionAttemptID executionAttemptID, final long checkpointId, final CheckpointMetrics checkpointMetrics, - final SubtaskState subtaskState); + final TaskStateSnapshot subtaskState); void declineCheckpoint( JobID jobID, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointProperties.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointProperties.java index 6df7e71c0afc5..1233b6ec419fc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointProperties.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointProperties.java @@ -39,6 +39,7 @@ public class CheckpointProperties implements Serializable { private final boolean forced; private final boolean externalize; + private final boolean savepoint; private final boolean discardSubsumed; private final boolean discardFinished; @@ -49,6 +50,7 @@ public class CheckpointProperties implements Serializable { CheckpointProperties( boolean forced, boolean externalize, + boolean savepoint, boolean discardSubsumed, boolean discardFinished, boolean discardCancelled, @@ -57,6 +59,7 @@ public class CheckpointProperties implements Serializable { this.forced = forced; this.externalize = externalize; + this.savepoint = savepoint; this.discardSubsumed = discardSubsumed; this.discardFinished = discardFinished; this.discardCancelled = discardCancelled; @@ -183,7 +186,7 @@ boolean discardOnJobSuspended() { * @return true if the properties describe a savepoint, false otherwise. */ public boolean isSavepoint() { - return this == STANDARD_SAVEPOINT; + return savepoint; } // ------------------------------------------------------------------------ @@ -201,6 +204,7 @@ public boolean equals(Object o) { CheckpointProperties that = (CheckpointProperties) o; return forced == that.forced && externalize == that.externalize && + savepoint == that.savepoint && discardSubsumed == that.discardSubsumed && discardFinished == that.discardFinished && discardCancelled == that.discardCancelled && @@ -212,6 +216,7 @@ public boolean equals(Object o) { public int hashCode() { int result = (forced ? 1 : 0); result = 31 * result + (externalize ? 1 : 0); + result = 31 * result + (savepoint ? 1 : 0); result = 31 * result + (discardSubsumed ? 1 : 0); result = 31 * result + (discardFinished ? 1 : 0); result = 31 * result + (discardCancelled ? 
1 : 0); @@ -224,7 +229,8 @@ public int hashCode() { public String toString() { return "CheckpointProperties{" + "forced=" + forced + - ", externalize=" + externalizeCheckpoint() + + ", externalized=" + externalizeCheckpoint() + + ", savepoint=" + savepoint + ", discardSubsumed=" + discardSubsumed + ", discardFinished=" + discardFinished + ", discardCancelled=" + discardCancelled + @@ -236,6 +242,7 @@ public String toString() { // ------------------------------------------------------------------------ private static final CheckpointProperties STANDARD_SAVEPOINT = new CheckpointProperties( + true, true, true, false, @@ -245,6 +252,7 @@ public String toString() { false); private static final CheckpointProperties STANDARD_CHECKPOINT = new CheckpointProperties( + false, false, false, true, @@ -256,6 +264,7 @@ public String toString() { private static final CheckpointProperties EXTERNALIZED_CHECKPOINT_RETAINED = new CheckpointProperties( false, true, + false, true, true, false, // Retain on cancellation @@ -265,6 +274,7 @@ public String toString() { private static final CheckpointProperties EXTERNALIZED_CHECKPOINT_DELETED = new CheckpointProperties( false, true, + false, true, true, true, // Delete on cancellation diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java index 7c3edee081669..d3f61e448acc9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java @@ -209,6 +209,8 @@ public boolean discardOnShutdown(JobStatus jobStatus) throws Exception { private void doDiscard() throws Exception { + LOG.trace("Executing discard procedure for {}.", this); + try { // collect exceptions and continue cleanup Exception exception = null; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java index 45d407e91a045..82193b5f08da9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.runtime.jobgraph.JobStatus; -import org.apache.flink.runtime.state.SharedStateRegistry; import java.util.List; @@ -33,10 +32,8 @@ public interface CompletedCheckpointStore { * *

After a call to this method, {@link #getLatestCheckpoint()} returns the latest * available checkpoint. - * - * @param sharedStateRegistry the shared state registry to register recovered states. */ - void recover(SharedStateRegistry sharedStateRegistry) throws Exception; + void recover() throws Exception; /** * Adds a {@link CompletedCheckpoint} instance to the list of completed checkpoints. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java index b15302835bb38..a5f908d7aa964 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java @@ -30,8 +30,8 @@ import java.util.Objects; /** - * Simple container class which contains the raw/managed/legacy operator state and key-group state handles for the sub - * tasks of an operator. + * Simple container class which contains the raw/managed operator state and key-group state handles from all sub + * tasks of an operator and therefore represents the complete state of a logical operator. */ public class OperatorState implements CompositeStateHandle { @@ -102,15 +102,6 @@ public int getMaxParallelism() { return maxParallelism; } - public boolean hasNonPartitionedState() { - for (OperatorSubtaskState sts : operatorSubtaskStates.values()) { - if (sts != null && sts.getLegacyOperatorState() != null) { - return true; - } - } - return false; - } - @Override public void discardState() throws Exception { for (OperatorSubtaskState operatorSubtaskState : operatorSubtaskStates.values()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java index e2ae632a26b1b..3df9c4fc5248d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java @@ -24,14 +24,34 @@ import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StateUtil; -import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.Arrays; +import javax.annotation.Nonnull; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; /** - * Container for the state of one parallel subtask of an operator. This is part of the {@link OperatorState}. + * This class encapsulates the state for one parallel instance of an operator. The complete state of a (logical) + * operator (e.g. a flatmap operator) consists of the union of all {@link OperatorSubtaskState}s from all + * parallel tasks that physically execute parallelized, physical instances of the operator. + * + *

The full state of the logical operator is represented by {@link OperatorState} which consists of + * {@link OperatorSubtaskState}s. + * + *

Typically, we expect all collections in this class to be of size 0 or 1, because at most one state handle is + * produced per state type (e.g. managed-keyed, raw-operator, ...). In particular, this holds when taking a snapshot. + * The state handles are kept in collections because this class is also reused when restoring state. + * Under normal circumstances, the expected size of each collection is still 0 or 1; the exception is scale-down. In + * scale-down, one operator subtask can become responsible for the state of multiple previous subtasks, and the collections + * then store all the state handles that are needed to build up the new subtask state. + * + *
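To illustrate the scale-down case described above, here is a minimal sketch (not part of this patch) that merges the handle collections of two previous subtasks into the state of one new subtask. It assumes the collections are typed as Collection&lt;OperatorStateHandle&gt; and Collection&lt;KeyedStateHandle&gt;, and it only uses the collection-based constructor and getters introduced by this change.

```java
// Sketch only, not part of this patch: building the OperatorSubtaskState of one new
// subtask from the states of two previous subtasks after scale-down. Assumes the
// collections are Collection<OperatorStateHandle> / Collection<KeyedStateHandle>.
import java.util.ArrayList;
import java.util.List;

import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;

final class ScaleDownSketch {

    static OperatorSubtaskState mergeOnScaleDown(
            OperatorSubtaskState previousA,
            OperatorSubtaskState previousB) {

        // Each collection is normally of size 0 or 1; after scale-down one new subtask
        // may carry the handles of several previous subtasks.
        List<OperatorStateHandle> managedOperator = new ArrayList<>(previousA.getManagedOperatorState());
        managedOperator.addAll(previousB.getManagedOperatorState());

        List<OperatorStateHandle> rawOperator = new ArrayList<>(previousA.getRawOperatorState());
        rawOperator.addAll(previousB.getRawOperatorState());

        List<KeyedStateHandle> managedKeyed = new ArrayList<>(previousA.getManagedKeyedState());
        managedKeyed.addAll(previousB.getManagedKeyedState());

        List<KeyedStateHandle> rawKeyed = new ArrayList<>(previousA.getRawKeyedState());
        rawKeyed.addAll(previousB.getRawKeyedState());

        return new OperatorSubtaskState(managedOperator, rawOperator, managedKeyed, rawKeyed);
    }
}
```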

There is no collection for legacy state because it is not rescalable. */ public class OperatorSubtaskState implements CompositeStateHandle { @@ -39,34 +59,29 @@ public class OperatorSubtaskState implements CompositeStateHandle { private static final long serialVersionUID = -2394696997971923995L; - /** - * Legacy (non-repartitionable) operator state. - * - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. - */ - @Deprecated - private final StreamStateHandle legacyOperatorState; - /** * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}. */ - private final OperatorStateHandle managedOperatorState; + @Nonnull + private final Collection managedOperatorState; /** * Snapshot written using {@link org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream}. */ - private final OperatorStateHandle rawOperatorState; + @Nonnull + private final Collection rawOperatorState; /** * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}. */ - private final KeyedStateHandle managedKeyedState; + @Nonnull + private final Collection managedKeyedState; /** * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}. */ - private final KeyedStateHandle rawKeyedState; + @Nonnull + private final Collection rawKeyedState; /** * The state size. This is also part of the deserialized state handle. @@ -75,31 +90,69 @@ public class OperatorSubtaskState implements CompositeStateHandle { */ private final long stateSize; + /** + * Empty state. + */ + public OperatorSubtaskState() { + this( + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList()); + } + public OperatorSubtaskState( - StreamStateHandle legacyOperatorState, - OperatorStateHandle managedOperatorState, - OperatorStateHandle rawOperatorState, - KeyedStateHandle managedKeyedState, - KeyedStateHandle rawKeyedState) { + Collection managedOperatorState, + Collection rawOperatorState, + Collection managedKeyedState, + Collection rawKeyedState) { - this.legacyOperatorState = legacyOperatorState; - this.managedOperatorState = managedOperatorState; - this.rawOperatorState = rawOperatorState; - this.managedKeyedState = managedKeyedState; - this.rawKeyedState = rawKeyedState; + this.managedOperatorState = Preconditions.checkNotNull(managedOperatorState); + this.rawOperatorState = Preconditions.checkNotNull(rawOperatorState); + this.managedKeyedState = Preconditions.checkNotNull(managedKeyedState); + this.rawKeyedState = Preconditions.checkNotNull(rawKeyedState); try { - long calculateStateSize = getSizeNullSafe(legacyOperatorState); - calculateStateSize += getSizeNullSafe(managedOperatorState); - calculateStateSize += getSizeNullSafe(rawOperatorState); - calculateStateSize += getSizeNullSafe(managedKeyedState); - calculateStateSize += getSizeNullSafe(rawKeyedState); + long calculateStateSize = sumAllSizes(managedOperatorState); + calculateStateSize += sumAllSizes(rawOperatorState); + calculateStateSize += sumAllSizes(managedKeyedState); + calculateStateSize += sumAllSizes(rawKeyedState); stateSize = calculateStateSize; } catch (Exception e) { throw new RuntimeException("Failed to get state size.", e); } } + /** + * For convenience because the size of the collections is typically 0 or 1. Null values are translated into empty + * Collections (except for legacy state). 
+ */ + public OperatorSubtaskState( + OperatorStateHandle managedOperatorState, + OperatorStateHandle rawOperatorState, + KeyedStateHandle managedKeyedState, + KeyedStateHandle rawKeyedState) { + + this( + singletonOrEmptyOnNull(managedOperatorState), + singletonOrEmptyOnNull(rawOperatorState), + singletonOrEmptyOnNull(managedKeyedState), + singletonOrEmptyOnNull(rawKeyedState)); + } + + private static Collection singletonOrEmptyOnNull(T element) { + return element != null ? Collections.singletonList(element) : Collections.emptyList(); + } + + private static long sumAllSizes(Collection stateObject) throws Exception { + long size = 0L; + for (StateObject object : stateObject) { + size += getSizeNullSafe(object); + } + + return size; + } + private static long getSizeNullSafe(StateObject stateObject) throws Exception { return stateObject != null ? stateObject.getStateSize() : 0L; } @@ -107,40 +160,51 @@ private static long getSizeNullSafe(StateObject stateObject) throws Exception { // -------------------------------------------------------------------------------------------- /** - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. + * Returns a handle to the managed operator state. */ - @Deprecated - public StreamStateHandle getLegacyOperatorState() { - return legacyOperatorState; - } - - public OperatorStateHandle getManagedOperatorState() { + @Nonnull + public Collection getManagedOperatorState() { return managedOperatorState; } - public OperatorStateHandle getRawOperatorState() { + /** + * Returns a handle to the raw operator state. + */ + @Nonnull + public Collection getRawOperatorState() { return rawOperatorState; } - public KeyedStateHandle getManagedKeyedState() { + /** + * Returns a handle to the managed keyed state. + */ + @Nonnull + public Collection getManagedKeyedState() { return managedKeyedState; } - public KeyedStateHandle getRawKeyedState() { + /** + * Returns a handle to the raw keyed state. 
+ */ + @Nonnull + public Collection getRawKeyedState() { return rawKeyedState; } @Override public void discardState() { try { - StateUtil.bestEffortDiscardAllStateObjects( - Arrays.asList( - legacyOperatorState, - managedOperatorState, - rawOperatorState, - managedKeyedState, - rawKeyedState)); + List toDispose = + new ArrayList<>( + managedOperatorState.size() + + rawOperatorState.size() + + managedKeyedState.size() + + rawKeyedState.size()); + toDispose.addAll(managedOperatorState); + toDispose.addAll(rawOperatorState); + toDispose.addAll(managedKeyedState); + toDispose.addAll(rawKeyedState); + StateUtil.bestEffortDiscardAllStateObjects(toDispose); } catch (Exception e) { LOG.warn("Error while discarding operator states.", e); } @@ -148,12 +212,17 @@ public void discardState() { @Override public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { - if (managedKeyedState != null) { - managedKeyedState.registerSharedStates(sharedStateRegistry); - } + registerSharedState(sharedStateRegistry, managedKeyedState); + registerSharedState(sharedStateRegistry, rawKeyedState); + } - if (rawKeyedState != null) { - rawKeyedState.registerSharedStates(sharedStateRegistry); + private static void registerSharedState( + SharedStateRegistry sharedStateRegistry, + Iterable stateHandles) { + for (KeyedStateHandle stateHandle : stateHandles) { + if (stateHandle != null) { + stateHandle.registerSharedStates(sharedStateRegistry); + } } } @@ -175,56 +244,55 @@ public boolean equals(Object o) { OperatorSubtaskState that = (OperatorSubtaskState) o; - if (stateSize != that.stateSize) { - return false; - } - - if (legacyOperatorState != null ? - !legacyOperatorState.equals(that.legacyOperatorState) - : that.legacyOperatorState != null) { + if (getStateSize() != that.getStateSize()) { return false; } - if (managedOperatorState != null ? - !managedOperatorState.equals(that.managedOperatorState) - : that.managedOperatorState != null) { + if (!getManagedOperatorState().equals(that.getManagedOperatorState())) { return false; } - if (rawOperatorState != null ? - !rawOperatorState.equals(that.rawOperatorState) - : that.rawOperatorState != null) { + if (!getRawOperatorState().equals(that.getRawOperatorState())) { return false; } - if (managedKeyedState != null ? - !managedKeyedState.equals(that.managedKeyedState) - : that.managedKeyedState != null) { + if (!getManagedKeyedState().equals(that.getManagedKeyedState())) { return false; } - return rawKeyedState != null ? - rawKeyedState.equals(that.rawKeyedState) - : that.rawKeyedState == null; - + return getRawKeyedState().equals(that.getRawKeyedState()); } @Override public int hashCode() { - int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0; - result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0); - result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0); - result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0); - result = 31 * result + (rawKeyedState != null ? 
rawKeyedState.hashCode() : 0); - result = 31 * result + (int) (stateSize ^ (stateSize >>> 32)); + int result = getManagedOperatorState().hashCode(); + result = 31 * result + getRawOperatorState().hashCode(); + result = 31 * result + getManagedKeyedState().hashCode(); + result = 31 * result + getRawKeyedState().hashCode(); + result = 31 * result + (int) (getStateSize() ^ (getStateSize() >>> 32)); return result; } @Override public String toString() { return "SubtaskState{" + - "legacyState=" + legacyOperatorState + - ", operatorStateFromBackend=" + managedOperatorState + + "operatorStateFromBackend=" + managedOperatorState + ", operatorStateFromStream=" + rawOperatorState + ", keyedStateFromBackend=" + managedKeyedState + ", keyedStateFromStream=" + rawKeyedState + ", stateSize=" + stateSize + '}'; } + + public boolean hasState() { + return hasState(managedOperatorState) + || hasState(rawOperatorState) + || hasState(managedKeyedState) + || hasState(rawKeyedState); + } + + private boolean hasState(Iterable states) { + for (StateObject state : states) { + if (state != null) { + return true; + } + } + return false; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java index 3472fc20abfd8..16231dd0a9663 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java @@ -25,19 +25,18 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.OperatorID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyedStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; + import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -353,13 +352,13 @@ private CompletedCheckpoint finalizeInternal( * Acknowledges the task with the given execution attempt id and the given subtask state. 
* * @param executionAttemptId of the acknowledged task - * @param subtaskState of the acknowledged task + * @param operatorSubtaskStates of the acknowledged task * @param metrics Checkpoint metrics for the stats * @return TaskAcknowledgeResult of the operation */ public TaskAcknowledgeResult acknowledgeTask( ExecutionAttemptID executionAttemptId, - SubtaskState subtaskState, + TaskStateSnapshot operatorSubtaskStates, CheckpointMetrics metrics) { synchronized (lock) { @@ -383,21 +382,19 @@ public TaskAcknowledgeResult acknowledgeTask( int subtaskIndex = vertex.getParallelSubtaskIndex(); long ackTimestamp = System.currentTimeMillis(); - long stateSize = 0; - if (subtaskState != null) { - stateSize = subtaskState.getStateSize(); - - @SuppressWarnings("deprecation") - ChainedStateHandle nonPartitionedState = - subtaskState.getLegacyOperatorState(); - ChainedStateHandle partitioneableState = - subtaskState.getManagedOperatorState(); - ChainedStateHandle rawOperatorState = - subtaskState.getRawOperatorState(); - - // break task state apart into separate operator states - for (int x = 0; x < operatorIDs.size(); x++) { - OperatorID operatorID = operatorIDs.get(x); + long stateSize = 0L; + + if (operatorSubtaskStates != null) { + for (OperatorID operatorID : operatorIDs) { + + OperatorSubtaskState operatorSubtaskState = + operatorSubtaskStates.getSubtaskStateByOperatorID(operatorID); + + // if no real operatorSubtaskState was reported, we insert an empty state + if (operatorSubtaskState == null) { + operatorSubtaskState = new OperatorSubtaskState(); + } + OperatorState operatorState = operatorStates.get(operatorID); if (operatorState == null) { @@ -408,23 +405,8 @@ public TaskAcknowledgeResult acknowledgeTask( operatorStates.put(operatorID, operatorState); } - KeyedStateHandle managedKeyedState = null; - KeyedStateHandle rawKeyedState = null; - - // only the head operator retains the keyed state - if (x == operatorIDs.size() - 1) { - managedKeyedState = subtaskState.getManagedKeyedState(); - rawKeyedState = subtaskState.getRawKeyedState(); - } - - OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( - nonPartitionedState != null ? nonPartitionedState.get(x) : null, - partitioneableState != null ? partitioneableState.get(x) : null, - rawOperatorState != null ? 
rawOperatorState.get(x) : null, - managedKeyedState, - rawKeyedState); - operatorState.putState(subtaskIndex, operatorSubtaskState); + stateSize += operatorSubtaskState.getStateSize(); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java index 046096fc85c08..4513ef80b32b1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java @@ -89,6 +89,10 @@ private GroupByStateNameResults groupByStateName( for (OperatorStateHandle psh : previousParallelSubtaskStates) { + if (psh == null) { + continue; + } + for (Map.Entry e : psh.getStateNameToPartitionOffsets().entrySet()) { OperatorStateHandle.StateMetaInfo metaInfo = e.getValue(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java index fbb0198a7f17c..63e7468ebca74 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java @@ -20,7 +20,7 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; -import org.apache.flink.runtime.state.SharedStateRegistry; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,7 +57,7 @@ public StandaloneCompletedCheckpointStore(int maxNumberOfCheckpointsToRetain) { } @Override - public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { + public void recover() throws Exception { // Nothing to do } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 5712ea1d43827..cc9f9cd1bd10c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -23,15 +23,13 @@ import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -163,8 +161,6 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List subNonPartitionableState = new ArrayList<>(); - Tuple2, Collection> subKeyedState = null; List> subManagedOperatorState = new ArrayList<>(); @@ -175,17 +171,9 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List(subNonPartitionableState), - 
subManagedOperatorState, - subRawOperatorState, - subKeyedState != null ? subKeyedState.f0 : null, - subKeyedState != null ? subKeyedState.f1 : null); + for (int i = 0; i < operatorIDs.size(); ++i) { - currentExecutionAttempt.setInitialState(taskStateHandles); + OperatorID operatorID = operatorIDs.get(i); + + Collection rawKeyed = Collections.emptyList(); + Collection managedKeyed = Collections.emptyList(); + + // keyed state case + if (subKeyedState != null) { + managedKeyed = subKeyedState.f0; + rawKeyed = subKeyedState.f1; + } + + OperatorSubtaskState operatorSubtaskState = + new OperatorSubtaskState( + subManagedOperatorState.get(i), + subRawOperatorState.get(i), + managedKeyed, + rawKeyed + ); + + taskState.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState); + } + + currentExecutionAttempt.setInitialState(taskState); } } } + private static boolean isHeadOperator(int opIdx, List operatorIDs) { + return opIdx == operatorIDs.size() - 1; + } public void checkParallelismPreconditions(List operatorStates, ExecutionJobVertex executionJobVertex) { @@ -239,18 +246,18 @@ private void reAssignSubPartitionableState( List> subManagedOperatorState, List> subRawOperatorState) { - if (newMangedOperatorStates.get(operatorIndex) != null) { - subManagedOperatorState.add(newMangedOperatorStates.get(operatorIndex).get(subTaskIndex)); + if (newMangedOperatorStates.get(operatorIndex) != null && !newMangedOperatorStates.get(operatorIndex).isEmpty()) { + Collection operatorStateHandles = newMangedOperatorStates.get(operatorIndex).get(subTaskIndex); + subManagedOperatorState.add(operatorStateHandles != null ? operatorStateHandles : Collections.emptyList()); } else { - subManagedOperatorState.add(null); + subManagedOperatorState.add(Collections.emptyList()); } - if (newRawOperatorStates.get(operatorIndex) != null) { - subRawOperatorState.add(newRawOperatorStates.get(operatorIndex).get(subTaskIndex)); + if (newRawOperatorStates.get(operatorIndex) != null && !newRawOperatorStates.get(operatorIndex).isEmpty()) { + Collection operatorStateHandles = newRawOperatorStates.get(operatorIndex).get(subTaskIndex); + subRawOperatorState.add(operatorStateHandles != null ? operatorStateHandles : Collections.emptyList()); } else { - subRawOperatorState.add(null); + subRawOperatorState.add(Collections.emptyList()); } - - } private Tuple2, Collection> reAssignSubKeyedStates( @@ -265,24 +272,22 @@ private Tuple2, Collection> reAss if (newParallelism == oldParallelism) { if (operatorState.getState(subTaskIndex) != null) { - KeyedStateHandle oldSubManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState(); - KeyedStateHandle oldSubRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState(); - subManagedKeyedState = oldSubManagedKeyedState != null ? Collections.singletonList( - oldSubManagedKeyedState) : null; - subRawKeyedState = oldSubRawKeyedState != null ? 
Collections.singletonList( - oldSubRawKeyedState) : null; + subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState(); + subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState(); } else { - subManagedKeyedState = null; - subRawKeyedState = null; + subManagedKeyedState = Collections.emptyList(); + subRawKeyedState = Collections.emptyList(); } } else { subManagedKeyedState = getManagedKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex)); subRawKeyedState = getRawKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex)); } - if (subManagedKeyedState == null && subRawKeyedState == null) { + + if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) { return null; + } else { + return new Tuple2<>(subManagedKeyedState, subRawKeyedState); } - return new Tuple2<>(subManagedKeyedState, subRawKeyedState); } @@ -295,30 +300,12 @@ private boolean allElementsAreNull(List nonPartitionableStates) { return true; } - - private void reAssignSubNonPartitionedStates( - OperatorState operatorState, - int subTaskIndex, - int newParallelism, - int oldParallelism, - List subNonPartitionableState) { - if (oldParallelism == newParallelism) { - if (operatorState.getState(subTaskIndex) != null) { - subNonPartitionableState.add(operatorState.getState(subTaskIndex).getLegacyOperatorState()); - } else { - subNonPartitionableState.add(null); - } - } else { - subNonPartitionableState.add(null); - } - } - private void reDistributePartitionableStates( List operatorStates, int newParallelism, List>> newManagedOperatorStates, List>> newRawOperatorStates) { - //collect the old partitionalbe state + //collect the old partitionable state List> oldManagedOperatorStates = new ArrayList<>(); List> oldRawOperatorStates = new ArrayList<>(); @@ -351,19 +338,16 @@ private void collectPartionableStates( for (int i = 0; i < operatorState.getParallelism(); i++) { OperatorSubtaskState operatorSubtaskState = operatorState.getState(i); if (operatorSubtaskState != null) { - if (operatorSubtaskState.getManagedOperatorState() != null) { - if (managedOperatorState == null) { - managedOperatorState = new ArrayList<>(); - } - managedOperatorState.add(operatorSubtaskState.getManagedOperatorState()); + + if (managedOperatorState == null) { + managedOperatorState = new ArrayList<>(); } + managedOperatorState.addAll(operatorSubtaskState.getManagedOperatorState()); - if (operatorSubtaskState.getRawOperatorState() != null) { - if (rawOperatorState == null) { - rawOperatorState = new ArrayList<>(); - } - rawOperatorState.add(operatorSubtaskState.getRawOperatorState()); + if (rawOperatorState == null) { + rawOperatorState = new ArrayList<>(); } + rawOperatorState.addAll(operatorSubtaskState.getRawOperatorState()); } } @@ -382,21 +366,19 @@ private void collectPartionableStates( * @return all managedKeyedStateHandles which have intersection with given KeyGroupRange */ public static List getManagedKeyedStateHandles( - OperatorState operatorState, - KeyGroupRange subtaskKeyGroupRange) { + OperatorState operatorState, + KeyGroupRange subtaskKeyGroupRange) { - List subtaskKeyedStateHandles = null; + List subtaskKeyedStateHandles = new ArrayList<>(); for (int i = 0; i < operatorState.getParallelism(); i++) { - if (operatorState.getState(i) != null && operatorState.getState(i).getManagedKeyedState() != null) { - KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getManagedKeyedState().getIntersection(subtaskKeyGroupRange); + if (operatorState.getState(i) != 
null) { - if (intersectedKeyedStateHandle != null) { - if (subtaskKeyedStateHandles == null) { - subtaskKeyedStateHandles = new ArrayList<>(); - } - subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); - } + Collection keyedStateHandles = operatorState.getState(i).getManagedKeyedState(); + extractIntersectingState( + keyedStateHandles, + subtaskKeyGroupRange, + subtaskKeyedStateHandles); } } @@ -415,22 +397,40 @@ public static List getRawKeyedStateHandles( OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) { - List subtaskKeyedStateHandles = null; + List extractedKeyedStateHandles = new ArrayList<>(); for (int i = 0; i < operatorState.getParallelism(); i++) { - if (operatorState.getState(i) != null && operatorState.getState(i).getRawKeyedState() != null) { - KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getRawKeyedState().getIntersection(subtaskKeyGroupRange); + if (operatorState.getState(i) != null) { + Collection rawKeyedState = operatorState.getState(i).getRawKeyedState(); + extractIntersectingState( + rawKeyedState, + subtaskKeyGroupRange, + extractedKeyedStateHandles); + } + } + + return extractedKeyedStateHandles; + } + + /** + * Extracts certain key group ranges from the given state handles and adds them to the collector. + */ + private static void extractIntersectingState( + Collection originalSubtaskStateHandles, + KeyGroupRange rangeToExtract, + List extractedStateCollector) { + + for (KeyedStateHandle keyedStateHandle : originalSubtaskStateHandles) { + + if (keyedStateHandle != null) { + + KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(rangeToExtract); if (intersectedKeyedStateHandle != null) { - if (subtaskKeyedStateHandles == null) { - subtaskKeyedStateHandles = new ArrayList<>(); - } - subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); + extractedStateCollector.add(intersectedKeyedStateHandle); } } } - - return subtaskKeyedStateHandles; } /** @@ -492,19 +492,6 @@ private static void checkParallelismPreconditions(OperatorState operatorState, E "is currently not supported."); } } - - //----------------------------------------parallelism preconditions----------------------------------------- - - final int oldParallelism = operatorState.getParallelism(); - final int newParallelism = executionJobVertex.getParallelism(); - - if (operatorState.hasNonPartitionedState() && (oldParallelism != newParallelism)) { - throw new IllegalStateException("Cannot restore the latest checkpoint because " + - "the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " + - "state and its parallelism changed. 
The operator " + executionJobVertex.getJobVertexId() + - " has parallelism " + newParallelism + " whereas the corresponding " + - "state object has a parallelism of " + oldParallelism); - } } /** @@ -554,7 +541,7 @@ public static List> applyRepartitioner( int newParallelism) { if (chainOpParallelStates == null) { - return null; + return Collections.emptyList(); } //We only redistribute if the parallelism of the operator changed from previous executions @@ -567,20 +554,23 @@ public static List> applyRepartitioner( List> repackStream = new ArrayList<>(newParallelism); for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) { - Map partitionOffsets = + if (operatorStateHandle != null) { + Map partitionOffsets = operatorStateHandle.getStateNameToPartitionOffsets(); - for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) { - // if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning - if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) { - return opStateRepartitioner.repartitionState( + for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) { + + // if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning + if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) { + return opStateRepartitioner.repartitionState( chainOpParallelStates, newParallelism); + } } - } - repackStream.add(Collections.singletonList(operatorStateHandle)); + repackStream.add(Collections.singletonList(operatorStateHandle)); + } } return repackStream; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java index 20d675b686b94..281693bc90304 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java @@ -25,14 +25,12 @@ import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StateUtil; -import org.apache.flink.runtime.state.StreamStateHandle; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Arrays; -import static org.apache.flink.util.Preconditions.checkNotNull; - /** * Container for the chained state of one parallel subtask of an operator/task. This is part of the * {@link TaskState}. @@ -43,15 +41,6 @@ public class SubtaskState implements CompositeStateHandle { private static final long serialVersionUID = -2394696997971923995L; - /** - * Legacy (non-repartitionable) operator state. - * - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. - */ - @Deprecated - private final ChainedStateHandle legacyOperatorState; - /** * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}. 
*/ @@ -80,21 +69,18 @@ public class SubtaskState implements CompositeStateHandle { private final long stateSize; public SubtaskState( - ChainedStateHandle legacyOperatorState, ChainedStateHandle managedOperatorState, ChainedStateHandle rawOperatorState, KeyedStateHandle managedKeyedState, KeyedStateHandle rawKeyedState) { - this.legacyOperatorState = checkNotNull(legacyOperatorState, "State"); this.managedOperatorState = managedOperatorState; this.rawOperatorState = rawOperatorState; this.managedKeyedState = managedKeyedState; this.rawKeyedState = rawKeyedState; try { - long calculateStateSize = getSizeNullSafe(legacyOperatorState); - calculateStateSize += getSizeNullSafe(managedOperatorState); + long calculateStateSize = getSizeNullSafe(managedOperatorState); calculateStateSize += getSizeNullSafe(rawOperatorState); calculateStateSize += getSizeNullSafe(managedKeyedState); calculateStateSize += getSizeNullSafe(rawKeyedState); @@ -110,15 +96,6 @@ private static final long getSizeNullSafe(StateObject stateObject) throws Except // -------------------------------------------------------------------------------------------- - /** - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. - */ - @Deprecated - public ChainedStateHandle getLegacyOperatorState() { - return legacyOperatorState; - } - public ChainedStateHandle getManagedOperatorState() { return managedOperatorState; } @@ -140,7 +117,6 @@ public void discardState() { try { StateUtil.bestEffortDiscardAllStateObjects( Arrays.asList( - legacyOperatorState, managedOperatorState, rawOperatorState, managedKeyedState, @@ -183,11 +159,6 @@ public boolean equals(Object o) { return false; } - if (legacyOperatorState != null ? - !legacyOperatorState.equals(that.legacyOperatorState) - : that.legacyOperatorState != null) { - return false; - } if (managedOperatorState != null ? !managedOperatorState.equals(that.managedOperatorState) : that.managedOperatorState != null) { @@ -211,8 +182,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0; - result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0); + int result = (managedOperatorState != null ? managedOperatorState.hashCode() : 0); result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0); result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0); result = 31 * result + (rawKeyedState != null ? 
rawKeyedState.hashCode() : 0); @@ -223,8 +193,7 @@ public int hashCode() { @Override public String toString() { return "SubtaskState{" + - "chainedStateHandle=" + legacyOperatorState + - ", operatorStateFromBackend=" + managedOperatorState + + "operatorStateFromBackend=" + managedOperatorState + ", operatorStateFromStream=" + rawOperatorState + ", keyedStateFromBackend=" + managedKeyedState + ", keyedStateFromStream=" + rawKeyedState + diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java index ed847a43449d7..0f3bedbc72639 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java @@ -48,7 +48,6 @@ public class TaskState implements CompositeStateHandle { /** handles to non-partitioned states, subtaskindex -> subtaskstate */ private final Map subtaskStates; - /** parallelism of the operator when it was checkpointed */ private final int parallelism; @@ -117,15 +116,6 @@ public int getChainLength() { return chainLength; } - public boolean hasNonPartitionedState() { - for(SubtaskState sts : subtaskStates.values()) { - if (sts != null && !sts.getLegacyOperatorState().isEmpty()) { - return true; - } - } - return false; - } - @Override public void discardState() throws Exception { for (SubtaskState subtaskState : subtaskStates.values()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java new file mode 100644 index 0000000000000..c416f3f641c10 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.CompositeStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; +import org.apache.flink.runtime.state.StateUtil; +import org.apache.flink.util.Preconditions; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * This class encapsulates state handles to the snapshots of all operator instances executed within one task. A task + * can run multiple operator instances as a result of operator chaining, and all operator instances from the chain can + * register their state under their operator id. Each operator instance is a physical execution responsible for + * processing a partition of the data that goes through a logical operator. 
This partitioning is done to parallelize + * the execution of logical operators, e.g. by distributing a map function. + * + *

One instance of this class contains the information that one task will send to acknowledge a checkpoint request from + * the checkpoint coordinator. Tasks run operator instances in parallel, so the union of all + * {@link TaskStateSnapshot}s that are collected by the checkpoint coordinator from all tasks represents the whole + * state of a job at the time of the checkpoint. + * + *
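As a rough illustration of this mapping, the sketch below (not part of this patch) shows how the state reported by a task running a chain of two operators could be assembled. The operator ids and subtask states are placeholders; only methods that TaskStateSnapshot defines in this patch are used.

```java
// Sketch only, not part of this patch: one TaskStateSnapshot per task, with one
// OperatorSubtaskState registered per operator in the task's chain.
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.jobgraph.OperatorID;

final class TaskStateSnapshotSketch {

    static TaskStateSnapshot reportTaskState(
            OperatorID headOperatorId,
            OperatorSubtaskState headState,
            OperatorID chainedOperatorId,
            OperatorSubtaskState chainedState) {

        TaskStateSnapshot snapshot = new TaskStateSnapshot();

        // One entry per operator instance running in this task (operator chaining).
        snapshot.putSubtaskStateByOperatorID(headOperatorId, headState);
        snapshot.putSubtaskStateByOperatorID(chainedOperatorId, chainedState);

        // The coordinator side can look up the state of a single operator by its id
        // and query the total reported state size.
        OperatorSubtaskState lookedUp = snapshot.getSubtaskStateByOperatorID(headOperatorId);
        long totalSize = snapshot.getStateSize();

        return snapshot;
    }
}
```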

This class should be called TaskState once the old class with this name that we keep for backwards + * compatibility goes away. + */ +public class TaskStateSnapshot implements CompositeStateHandle { + + private static final long serialVersionUID = 1L; + + /** Mapping from an operator id to the state of one subtask of this operator */ + private final Map subtaskStatesByOperatorID; + + public TaskStateSnapshot() { + this(10); + } + + public TaskStateSnapshot(int size) { + this(new HashMap(size)); + } + + public TaskStateSnapshot(Map subtaskStatesByOperatorID) { + this.subtaskStatesByOperatorID = Preconditions.checkNotNull(subtaskStatesByOperatorID); + } + + /** + * Returns the subtask state for the given operator id (or null if not contained). + */ + public OperatorSubtaskState getSubtaskStateByOperatorID(OperatorID operatorID) { + return subtaskStatesByOperatorID.get(operatorID); + } + + /** + * Maps the given operator id to the given subtask state. Returns the subtask state of a previous mapping, if such + * a mapping existed or null otherwise. + */ + public OperatorSubtaskState putSubtaskStateByOperatorID(OperatorID operatorID, OperatorSubtaskState state) { + return subtaskStatesByOperatorID.put(operatorID, Preconditions.checkNotNull(state)); + } + + /** + * Returns the set of all mappings from operator id to the corresponding subtask state. + */ + public Set> getSubtaskStateMappings() { + return subtaskStatesByOperatorID.entrySet(); + } + + @Override + public void discardState() throws Exception { + StateUtil.bestEffortDiscardAllStateObjects(subtaskStatesByOperatorID.values()); + } + + @Override + public long getStateSize() { + long size = 0L; + + for (OperatorSubtaskState subtaskState : subtaskStatesByOperatorID.values()) { + if (subtaskState != null) { + size += subtaskState.getStateSize(); + } + } + + return size; + } + + @Override + public void registerSharedStates(SharedStateRegistry stateRegistry) { + for (OperatorSubtaskState operatorSubtaskState : subtaskStatesByOperatorID.values()) { + if (operatorSubtaskState != null) { + operatorSubtaskState.registerSharedStates(stateRegistry); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TaskStateSnapshot that = (TaskStateSnapshot) o; + + return subtaskStatesByOperatorID.equals(that.subtaskStatesByOperatorID); + } + + @Override + public int hashCode() { + return subtaskStatesByOperatorID.hashCode(); + } + + @Override + public String toString() { + return "TaskOperatorSubtaskStates{" + + "subtaskStatesByOperatorID=" + subtaskStatesByOperatorID + + '}'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java index c4cb6bca3a58d..88dd0d4a14f5e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java @@ -18,20 +18,21 @@ package org.apache.flink.runtime.checkpoint; -import org.apache.curator.framework.CuratorFramework; -import org.apache.curator.utils.ZKPaths; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; import org.apache.flink.runtime.state.RetrievableStateHandle; 
-import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore; import org.apache.flink.util.FlinkException; + +import org.apache.curator.framework.CuratorFramework; +import org.apache.curator.utils.ZKPaths; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; + import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; @@ -138,14 +139,13 @@ public boolean requiresExternalizedCheckpoints() { * that the history of checkpoints is consistent. */ @Override - public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { + public void recover() throws Exception { LOG.info("Recovering checkpoints from ZooKeeper."); // Clear local handles in order to prevent duplicates on // recovery. The local handles should reflect the state // of ZooKeeper. completedCheckpoints.clear(); - sharedStateRegistry.clear(); // Get all there is first List, String>> initialCheckpoints; @@ -170,8 +170,6 @@ public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { try { completedCheckpoint = retrieveCompletedCheckpoint(checkpointStateHandle); if (completedCheckpoint != null) { - // Re-register all shared states in the checkpoint. - completedCheckpoint.registerSharedStatesAfterRestored(sharedStateRegistry); completedCheckpoints.add(completedCheckpoint); } } catch (Exception e) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointSerializers.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointSerializers.java index c1fcf4f0d4547..12e9c5bd63782 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointSerializers.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointSerializers.java @@ -18,8 +18,7 @@ package org.apache.flink.runtime.checkpoint.savepoint; -import org.apache.flink.migration.runtime.checkpoint.savepoint.SavepointV0; -import org.apache.flink.migration.runtime.checkpoint.savepoint.SavepointV0Serializer; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.util.Preconditions; import java.util.HashMap; @@ -30,15 +29,20 @@ */ public class SavepointSerializers { + /** If this flag is true, restoring a savepoint fails if it contains legacy state (<= Flink 1.1 format) */ + static boolean FAIL_WHEN_LEGACY_STATE_DETECTED = true; private static final Map> SERIALIZERS = new HashMap<>(2); static { - SERIALIZERS.put(SavepointV0.VERSION, SavepointV0Serializer.INSTANCE); SERIALIZERS.put(SavepointV1.VERSION, SavepointV1Serializer.INSTANCE); SERIALIZERS.put(SavepointV2.VERSION, SavepointV2Serializer.INSTANCE); } + private SavepointSerializers() { + throw new AssertionError(); + } + // ------------------------------------------------------------------------ /** @@ -77,4 +81,12 @@ public static SavepointSerializer getSerializer(int version) { } } + /** + * This is only visible as a temporary solution to keep the stateful job migration it cases working from binary + * savepoints that still contain legacy state (<= Flink 1.1). 
+ */ + @VisibleForTesting + public static void setFailWhenLegacyStateDetected(boolean fail) { + FAIL_WHEN_LEGACY_STATE_DETECTED = fail; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStore.java index 7beb1b8dd6ace..586df57a27568 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStore.java @@ -36,6 +36,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; + import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java index f67d54ca15149..c26c983fb93a2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java @@ -33,6 +33,7 @@ import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; +import org.apache.flink.util.Preconditions; import java.io.DataInputStream; import java.io.DataOutputStream; @@ -59,7 +60,6 @@ public class SavepointV1Serializer implements SavepointSerializer { private static final byte KEY_GROUPS_HANDLE = 3; private static final byte PARTITIONABLE_OPERATOR_STATE_HANDLE = 4; - public static final SavepointV1Serializer INSTANCE = new SavepointV1Serializer(); private SavepointV1Serializer() { @@ -130,20 +130,15 @@ public void serializeOld(SavepointV1 savepoint, DataOutputStream dos) throws IOE private static void serializeSubtaskState(SubtaskState subtaskState, DataOutputStream dos) throws IOException { - dos.writeLong(-1); - - ChainedStateHandle nonPartitionableState = subtaskState.getLegacyOperatorState(); + //backwards compatibility, do not remove + dos.writeLong(-1L); - int len = nonPartitionableState != null ? nonPartitionableState.getLength() : 0; - dos.writeInt(len); - for (int i = 0; i < len; ++i) { - StreamStateHandle stateHandle = nonPartitionableState.get(i); - serializeStreamStateHandle(stateHandle, dos); - } + //backwards compatibility (number of legacy state handles), do not remove + dos.writeInt(0); ChainedStateHandle operatorStateBackend = subtaskState.getManagedOperatorState(); - len = operatorStateBackend != null ? operatorStateBackend.getLength() : 0; + int len = operatorStateBackend != null ? 
operatorStateBackend.getLength() : 0; dos.writeInt(len); for (int i = 0; i < len; ++i) { OperatorStateHandle stateHandle = operatorStateBackend.get(i); @@ -171,12 +166,19 @@ private static SubtaskState deserializeSubtaskState(DataInputStream dis) throws long ignoredDuration = dis.readLong(); int len = dis.readInt(); - List nonPartitionableState = new ArrayList<>(len); - for (int i = 0; i < len; ++i) { - StreamStateHandle streamStateHandle = deserializeStreamStateHandle(dis); - nonPartitionableState.add(streamStateHandle); - } + if (SavepointSerializers.FAIL_WHEN_LEGACY_STATE_DETECTED) { + Preconditions.checkState(len == 0, + "Legacy state (from Flink <= 1.1, created through the 'Checkpointed' interface) is " + + "no longer supported starting from Flink 1.4. Please rewrite your job to use " + + "'CheckpointedFunction' instead!"); + + } else { + for (int i = 0; i < len; ++i) { + // absorb bytes from stream and ignore result + deserializeStreamStateHandle(dis); + } + } len = dis.readInt(); List operatorStateBackend = new ArrayList<>(len); @@ -196,9 +198,6 @@ private static SubtaskState deserializeSubtaskState(DataInputStream dis) throws KeyedStateHandle keyedStateStream = deserializeKeyedStateHandle(dis); - ChainedStateHandle nonPartitionableStateChain = - new ChainedStateHandle<>(nonPartitionableState); - ChainedStateHandle operatorStateBackendChain = new ChainedStateHandle<>(operatorStateBackend); @@ -206,7 +205,6 @@ private static SubtaskState deserializeSubtaskState(DataInputStream dis) throws new ChainedStateHandle<>(operatorStateStream); return new SubtaskState( - nonPartitionableStateChain, operatorStateBackendChain, operatorStateStreamChain, keyedStateBackend, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java index bd364a28958b9..9e406dfe44c64 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java @@ -29,7 +29,6 @@ import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.Preconditions; import java.util.Collection; @@ -207,9 +206,6 @@ public static Savepoint convertToOperatorStateSavepointV2( continue; } - @SuppressWarnings("deprecation") - ChainedStateHandle nonPartitionedState = - subtaskState.getLegacyOperatorState(); ChainedStateHandle partitioneableState = subtaskState.getManagedOperatorState(); ChainedStateHandle rawOperatorState = @@ -240,7 +236,6 @@ public static Savepoint convertToOperatorStateSavepointV2( } OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( - nonPartitionedState != null ? nonPartitionedState.get(operatorIndex) : null, partitioneableState != null ? partitioneableState.get(operatorIndex) : null, rawOperatorState != null ? 
rawOperatorState.get(operatorIndex) : null, managedKeyedState, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java index 4cbbfcfba8b5b..5636a52aba71e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java @@ -33,6 +33,7 @@ import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; +import org.apache.flink.util.Preconditions; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -240,19 +241,26 @@ private MasterState deserializeMasterState(DataInputStream dis) throws IOExcepti // task state (de)serialization methods // ------------------------------------------------------------------------ + private static T extractSingleton(Collection collection) { + if (collection == null || collection.isEmpty()) { + return null; + } + + if (collection.size() == 1) { + return collection.iterator().next(); + } else { + throw new IllegalStateException("Expected singleton collection, but found size: " + collection.size()); + } + } + private static void serializeSubtaskState(OperatorSubtaskState subtaskState, DataOutputStream dos) throws IOException { dos.writeLong(-1); - StreamStateHandle nonPartitionableState = subtaskState.getLegacyOperatorState(); - - int len = nonPartitionableState != null ? 1 : 0; + int len = 0; dos.writeInt(len); - if (len == 1) { - serializeStreamStateHandle(nonPartitionableState, dos); - } - OperatorStateHandle operatorStateBackend = subtaskState.getManagedOperatorState(); + OperatorStateHandle operatorStateBackend = extractSingleton(subtaskState.getManagedOperatorState()); len = operatorStateBackend != null ? 1 : 0; dos.writeInt(len); @@ -260,7 +268,7 @@ private static void serializeSubtaskState(OperatorSubtaskState subtaskState, Dat serializeOperatorStateHandle(operatorStateBackend, dos); } - OperatorStateHandle operatorStateFromStream = subtaskState.getRawOperatorState(); + OperatorStateHandle operatorStateFromStream = extractSingleton(subtaskState.getRawOperatorState()); len = operatorStateFromStream != null ? 1 : 0; dos.writeInt(len); @@ -268,19 +276,31 @@ private static void serializeSubtaskState(OperatorSubtaskState subtaskState, Dat serializeOperatorStateHandle(operatorStateFromStream, dos); } - KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState(); + KeyedStateHandle keyedStateBackend = extractSingleton(subtaskState.getManagedKeyedState()); serializeKeyedStateHandle(keyedStateBackend, dos); - KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState(); + KeyedStateHandle keyedStateStream = extractSingleton(subtaskState.getRawKeyedState()); serializeKeyedStateHandle(keyedStateStream, dos); } private static OperatorSubtaskState deserializeSubtaskState(DataInputStream dis) throws IOException { - // Duration field has been removed from SubtaskState + // Duration field has been removed from SubtaskState, do not remove long ignoredDuration = dis.readLong(); + // for compatibility, do not remove int len = dis.readInt(); - StreamStateHandle nonPartitionableState = len == 0 ? 
null : deserializeStreamStateHandle(dis); + + if (SavepointSerializers.FAIL_WHEN_LEGACY_STATE_DETECTED) { + Preconditions.checkState(len == 0, + "Legacy state (from Flink <= 1.1, created through the 'Checkpointed' interface) is " + + "no longer supported starting from Flink 1.4. Please rewrite your job to use " + + "'CheckpointedFunction' instead!"); + } else { + for (int i = 0; i < len; ++i) { + // absorb bytes from stream and ignore result + deserializeStreamStateHandle(dis); + } + } len = dis.readInt(); OperatorStateHandle operatorStateBackend = len == 0 ? null : deserializeOperatorStateHandle(dis); @@ -293,7 +313,6 @@ private static OperatorSubtaskState deserializeSubtaskState(DataInputStream dis) KeyedStateHandle keyedStateStream = deserializeKeyedStateHandle(dis); return new OperatorSubtaskState( - nonPartitionableState, operatorStateBackend, operatorStateStream, keyedStateBackend, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClient.java index 19f0e2c635bf3..425461cd58bb7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClient.java @@ -234,7 +234,7 @@ public static ClassLoader retrieveClassLoader( int pos = 0; for (BlobKey blobKey : props.requiredJarFiles()) { try { - allURLs[pos++] = blobClient.getURL(blobKey); + allURLs[pos++] = blobClient.getFile(jobID, blobKey).toURI().toURL(); } catch (Exception e) { try { blobClient.close(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorConnectionTimeoutException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorConnectionTimeoutException.java index 72a56585f98ca..74a4e1f9cf783 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorConnectionTimeoutException.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorConnectionTimeoutException.java @@ -18,11 +18,13 @@ package org.apache.flink.runtime.client; +import org.apache.flink.util.FlinkException; + /** * Exception which is thrown when the {@link JobClientActor} wants to submit a job to * the job manager but has not found one after a given timeout interval. */ -public class JobClientActorConnectionTimeoutException extends Exception { +public class JobClientActorConnectionTimeoutException extends FlinkException { private static final long serialVersionUID = 2287747430528388637L; public JobClientActorConnectionTimeoutException(String msg) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorRegistrationTimeoutException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorRegistrationTimeoutException.java index e57d1b4f40e19..499c9e40902a5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorRegistrationTimeoutException.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorRegistrationTimeoutException.java @@ -18,11 +18,13 @@ package org.apache.flink.runtime.client; +import org.apache.flink.util.FlinkException; + /** * Exception which is thrown by the {@link JobClientActor} if it has not heard back from the job * manager after it has attempted to register for a job within a given timeout interval. 
*/ -public class JobClientActorRegistrationTimeoutException extends Exception { +public class JobClientActorRegistrationTimeoutException extends FlinkException { private static final long serialVersionUID = 8762463142030454853L; public JobClientActorRegistrationTimeoutException(String msg) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorSubmissionTimeoutException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorSubmissionTimeoutException.java index 2d394621abd53..a56e38990164c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorSubmissionTimeoutException.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobClientActorSubmissionTimeoutException.java @@ -18,11 +18,13 @@ package org.apache.flink.runtime.client; +import org.apache.flink.util.FlinkException; + /** * Exception which is thrown by the {@link JobClientActor} if it has not heard back from the job * manager after it has submitted a job to it within a given timeout interval. */ -public class JobClientActorSubmissionTimeoutException extends Exception { +public class JobClientActorSubmissionTimeoutException extends FlinkException { private static final long serialVersionUID = 8762463142030454853L; public JobClientActorSubmissionTimeoutException(String msg) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobExecutionException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobExecutionException.java index 7c6a4afe64eff..47eaaf53eda60 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobExecutionException.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/client/JobExecutionException.java @@ -19,13 +19,14 @@ package org.apache.flink.runtime.client; import org.apache.flink.api.common.JobID; +import org.apache.flink.util.FlinkException; /** * This exception is the base exception for all exceptions that denote any failure during * the execution of a job. The JobExecutionException and its subclasses are thrown by * the {@link JobClient}. 
*/ -public class JobExecutionException extends Exception { +public class JobExecutionException extends FlinkException { private static final long serialVersionUID = 2818087325120827525L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/clusterframework/BootstrapTools.java b/flink-runtime/src/main/java/org/apache/flink/runtime/clusterframework/BootstrapTools.java index d24a3d07cdb6e..b86054f5ca7d9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/clusterframework/BootstrapTools.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/clusterframework/BootstrapTools.java @@ -371,7 +371,12 @@ public static String getTaskManagerShellCommand( Class mainClass) { final Map startCommandValues = new HashMap<>(); - startCommandValues.put("java", "$JAVA_HOME/bin/java"); + if (System.getProperty("os.name").toLowerCase().startsWith("windows")){ + startCommandValues.put("java", "%JAVA_HOME%/bin/java"); + } + else { + startCommandValues.put("java", "$JAVA_HOME/bin/java"); + } ArrayList params = new ArrayList<>(); params.add(String.format("-Xms%dm", tmParams.taskManagerHeapSizeMB())); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/FutureUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/FutureUtils.java index 5c6439d3800f6..b982c8e86f6c7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/FutureUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/concurrent/FutureUtils.java @@ -26,13 +26,15 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; +import java.util.function.Supplier; import scala.concurrent.Future; import scala.concurrent.duration.FiniteDuration; @@ -48,6 +50,7 @@ public class FutureUtils { // retrying operations // ------------------------------------------------------------------------ + /** * Retry the given operation the given number of times in case of a failure. * @@ -58,35 +61,135 @@ public class FutureUtils { * @return Future containing either the result of the operation or a {@link RetryException} */ public static CompletableFuture retry( - final Callable> operation, - final int retries, - final Executor executor) { + final Supplier> operation, + final int retries, + final Executor executor) { + + final CompletableFuture resultFuture = new CompletableFuture<>(); - CompletableFuture operationResultFuture; + retryOperation(resultFuture, operation, retries, executor); + + return resultFuture; + } - try { - operationResultFuture = operation.call(); - } catch (Exception e) { - return FutureUtils.completedExceptionally(new RetryException("Could not execute the provided operation.", e)); + /** + * Helper method which retries the provided operation in case of a failure. 
+ * + * @param resultFuture to complete + * @param operation to retry + * @param retries until giving up + * @param executor to run the futures + * @param type of the future's result + */ + private static void retryOperation( + final CompletableFuture resultFuture, + final Supplier> operation, + final int retries, + final Executor executor) { + + if (!resultFuture.isDone()) { + final CompletableFuture operationFuture = operation.get(); + + operationFuture.whenCompleteAsync( + (t, throwable) -> { + if (throwable != null) { + if (throwable instanceof CancellationException) { + resultFuture.completeExceptionally(new RetryException("Operation future was cancelled.", throwable)); + } else { + if (retries > 0) { + retryOperation( + resultFuture, + operation, + retries - 1, + executor); + } else { + resultFuture.completeExceptionally(new RetryException("Could not complete the operation. Number of retries " + + "has been exhausted.", throwable)); + } + } + } else { + resultFuture.complete(t); + } + }, + executor); + + resultFuture.whenComplete( + (t, throwable) -> operationFuture.cancel(false)); } + } + + /** + * Retry the given operation with the given delay in between failures. + * + * @param operation to retry + * @param retries number of retries + * @param retryDelay delay between retries + * @param scheduledExecutor executor to be used for the retry operation + * @param type of the result + * @return Future which retries the given operation a given amount of times and delays the retry in case of failures + */ + public static CompletableFuture retryWithDelay( + final Supplier> operation, + final int retries, + final Time retryDelay, + final ScheduledExecutor scheduledExecutor) { + + final CompletableFuture resultFuture = new CompletableFuture<>(); + + retryOperationWithDelay( + resultFuture, + operation, + retries, + retryDelay, + scheduledExecutor); + + return resultFuture; + } - return operationResultFuture.handleAsync( - (t, throwable) -> { - if (throwable != null) { - if (retries > 0) { - return retry(operation, retries - 1, executor); + private static void retryOperationWithDelay( + final CompletableFuture resultFuture, + final Supplier> operation, + final int retries, + final Time retryDelay, + final ScheduledExecutor scheduledExecutor) { + + if (!resultFuture.isDone()) { + final CompletableFuture operationResultFuture = operation.get(); + + operationResultFuture.whenCompleteAsync( + (t, throwable) -> { + if (throwable != null) { + if (throwable instanceof CancellationException) { + resultFuture.completeExceptionally(new RetryException("Operation future was cancelled.", throwable)); + } else { + if (retries > 0) { + final ScheduledFuture scheduledFuture = scheduledExecutor.schedule( + () -> retryOperationWithDelay(resultFuture, operation, retries - 1, retryDelay, scheduledExecutor), + retryDelay.toMilliseconds(), + TimeUnit.MILLISECONDS); + + resultFuture.whenComplete( + (innerT, innerThrowable) -> scheduledFuture.cancel(false)); + } else { + resultFuture.completeExceptionally(new RetryException("Could not complete the operation. Number of retries " + + "has been exhausted.", throwable)); + } + } } else { - return FutureUtils.completedExceptionally(new RetryException("Could not complete the operation. 
Number of retries " + - "has been exhausted.", throwable)); + resultFuture.complete(t); } - } else { - return CompletableFuture.completedFuture(t); - } - }, - executor) - .thenCompose(value -> value); + }, + scheduledExecutor); + + resultFuture.whenComplete( + (t, throwable) -> operationResultFuture.cancel(false)); + } } + /** + * Exception with which the returned future is completed if the {@link #retry(Supplier, int, Executor)} + * operation fails. + */ public static class RetryException extends Exception { private static final long serialVersionUID = 3613470781274141862L; @@ -109,14 +212,14 @@ public RetryException(Throwable cause) { // ------------------------------------------------------------------------ /** - * Creates a future that is complete once multiple other futures completed. + * Creates a future that is complete once multiple other futures completed. * The future fails (completes exceptionally) once one of the futures in the * conjunction fails. Upon successful completion, the future returns the * collection of the futures' results. * *
<p>
The ConjunctFuture gives access to how many Futures in the conjunction have already - * completed successfully, via {@link ConjunctFuture#getNumFuturesCompleted()}. - * + * completed successfully, via {@link ConjunctFuture#getNumFuturesCompleted()}. + * * @param futures The futures that make up the conjunction. No null entries are allowed. * @return The ConjunctFuture that completes once all given futures are complete (or one fails). */ @@ -158,7 +261,7 @@ public static ConjunctFuture waitForAll(CollectionThe advantage of using the ConjunctFuture over chaining all the futures (such as via * {@link CompletableFuture#thenCombine(CompletionStage, BiFunction)} )}) is that ConjunctFuture * also tracks how many of the Futures are already complete. @@ -183,16 +286,16 @@ public abstract static class ConjunctFuture extends CompletableFuture { */ private static class ResultConjunctFuture extends ConjunctFuture> { - /** The total number of futures in the conjunction */ + /** The total number of futures in the conjunction. */ private final int numTotal; - /** The next free index in the results arrays */ + /** The next free index in the results arrays. */ private final AtomicInteger nextIndex = new AtomicInteger(0); - /** The number of futures in the conjunction that are already complete */ + /** The number of futures in the conjunction that are already complete. */ private final AtomicInteger numCompleted = new AtomicInteger(0); - /** The set of collected results so far */ + /** The set of collected results so far. */ private volatile T[] results; /** The function that is attached to all futures in the conjunction. Once a future @@ -215,7 +318,7 @@ final void handleCompletedFuture(T value, Throwable throwable) { @SuppressWarnings("unchecked") ResultConjunctFuture(int numTotal) { this.numTotal = numTotal; - results = (T[])new Object[numTotal]; + results = (T[]) new Object[numTotal]; } @Override @@ -235,13 +338,13 @@ public int getNumFuturesCompleted() { */ private static final class WaitingConjunctFuture extends ConjunctFuture { - /** Number of completed futures */ + /** Number of completed futures. */ private final AtomicInteger numCompleted = new AtomicInteger(0); - /** Total number of futures to wait on */ + /** Total number of futures to wait on. */ private final int numTotal; - /** Method which increments the atomic completion counter and completes or fails the WaitingFutureImpl */ + /** Method which increments the atomic completion counter and completes or fails the WaitingFutureImpl. 
*/ private void handleCompletedFuture(Object ignored, Throwable throwable) { if (throwable == null) { if (numTotal == numCompleted.incrementAndGet()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java index 0578b787290d9..1fa5eb5b51484 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java @@ -18,11 +18,11 @@ package org.apache.flink.runtime.deployment; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; import org.apache.flink.runtime.executiongraph.TaskInformation; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; @@ -64,7 +64,7 @@ public final class TaskDeploymentDescriptor implements Serializable { private final int targetSlotNumber; /** State handles for the sub task. */ - private final TaskStateHandles taskStateHandles; + private final TaskStateSnapshot taskStateHandles; public TaskDeploymentDescriptor( SerializedValue serializedJobInformation, @@ -74,7 +74,7 @@ public TaskDeploymentDescriptor( int subtaskIndex, int attemptNumber, int targetSlotNumber, - TaskStateHandles taskStateHandles, + TaskStateSnapshot taskStateHandles, Collection resultPartitionDeploymentDescriptors, Collection inputGateDeploymentDescriptors) { @@ -153,7 +153,7 @@ public Collection getInputGates() { return inputGates; } - public TaskStateHandles getTaskStateHandles() { + public TaskStateSnapshot getTaskStateHandles() { return taskStateHandles; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java index 9fc1fc4098179..8977415a8f07d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java @@ -23,7 +23,6 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.blob.BlobServer; -import org.apache.flink.runtime.blob.BlobService; import org.apache.flink.runtime.client.JobSubmissionException; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.concurrent.FutureUtils; @@ -35,11 +34,15 @@ import org.apache.flink.runtime.jobmanager.SubmittedJobGraph; import org.apache.flink.runtime.jobmanager.SubmittedJobGraphStore; import org.apache.flink.runtime.jobmaster.JobManagerRunner; +import org.apache.flink.runtime.jobmaster.JobManagerServices; +import org.apache.flink.runtime.leaderelection.LeaderContender; +import org.apache.flink.runtime.leaderelection.LeaderElectionService; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.rpc.FatalErrorHandler; -import org.apache.flink.runtime.rpc.RpcEndpoint; +import org.apache.flink.runtime.rpc.FencedRpcEndpoint; import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.runtime.rpc.RpcUtils; import 
org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FlinkException; import org.apache.flink.util.Preconditions; @@ -48,6 +51,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.Map; +import java.util.UUID; import java.util.concurrent.CompletableFuture; /** @@ -56,7 +60,7 @@ * the jobs and to recover them in case of a master failure. Furthermore, it knows * about the state of the Flink session cluster. */ -public abstract class Dispatcher extends RpcEndpoint implements DispatcherGateway { +public abstract class Dispatcher extends FencedRpcEndpoint implements DispatcherGateway, LeaderContender { public static final String DISPATCHER_NAME = "dispatcher"; @@ -66,7 +70,7 @@ public abstract class Dispatcher extends RpcEndpoint implements DispatcherGatewa private final RunningJobsRegistry runningJobsRegistry; private final HighAvailabilityServices highAvailabilityServices; - private final BlobServer blobServer; + private final JobManagerServices jobManagerServices; private final HeartbeatServices heartbeatServices; private final MetricRegistry metricRegistry; @@ -74,6 +78,8 @@ public abstract class Dispatcher extends RpcEndpoint implements DispatcherGatewa private final Map jobManagerRunners; + private final LeaderElectionService leaderElectionService; + protected Dispatcher( RpcService rpcService, String endpointId, @@ -83,11 +89,13 @@ protected Dispatcher( HeartbeatServices heartbeatServices, MetricRegistry metricRegistry, FatalErrorHandler fatalErrorHandler) throws Exception { - super(rpcService, endpointId); + super(rpcService, endpointId, DispatcherId.generate()); this.configuration = Preconditions.checkNotNull(configuration); this.highAvailabilityServices = Preconditions.checkNotNull(highAvailabilityServices); - this.blobServer = Preconditions.checkNotNull(blobServer); + this.jobManagerServices = JobManagerServices.fromConfiguration( + configuration, + Preconditions.checkNotNull(blobServer)); this.heartbeatServices = Preconditions.checkNotNull(heartbeatServices); this.metricRegistry = Preconditions.checkNotNull(metricRegistry); this.fatalErrorHandler = Preconditions.checkNotNull(fatalErrorHandler); @@ -96,6 +104,8 @@ protected Dispatcher( this.runningJobsRegistry = highAvailabilityServices.getRunningJobsRegistry(); jobManagerRunners = new HashMap<>(16); + + leaderElectionService = highAvailabilityServices.getDispatcherLeaderElectionService(); } //------------------------------------------------------ @@ -104,13 +114,15 @@ protected Dispatcher( @Override public void postStop() throws Exception { - Exception exception = null; - // stop all currently running JobManagerRunners - for (JobManagerRunner jobManagerRunner : jobManagerRunners.values()) { - jobManagerRunner.shutdown(); - } + Throwable exception = null; - jobManagerRunners.clear(); + clearState(); + + try { + jobManagerServices.shutdown(); + } catch (Throwable t) { + exception = ExceptionUtils.firstOrSuppressed(t, exception); + } try { submittedJobGraphStore.stop(); @@ -118,6 +130,12 @@ public void postStop() throws Exception { exception = ExceptionUtils.firstOrSuppressed(e, exception); } + try { + leaderElectionService.stop(); + } catch (Exception e) { + exception = ExceptionUtils.firstOrSuppressed(e, exception); + } + try { super.postStop(); } catch (Exception e) { @@ -129,12 +147,20 @@ public void postStop() throws Exception { } } + @Override + public void start() throws Exception { + super.start(); + + leaderElectionService.start(this); + } + 
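
The retry helpers added to FutureUtils above all follow the same recursive shape: invoke the operation, complete a shared result future on success, otherwise either schedule the next attempt after the configured delay or fail once the retry budget is exhausted, and cancel the in-flight attempt as soon as the result future completes from the outside. A minimal, self-contained sketch of that pattern using only JDK types (the class and method names below are illustrative, not the Flink API) might look like this:

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

// Illustrative sketch of the retry-with-delay pattern used by the FutureUtils changes above.
public final class RetrySketch {

    // Retries 'operation' up to 'retries' times, waiting 'delayMillis' between attempts.
    public static <T> CompletableFuture<T> retryWithDelay(
            Supplier<CompletableFuture<T>> operation,
            int retries,
            long delayMillis,
            ScheduledExecutorService scheduler) {

        CompletableFuture<T> result = new CompletableFuture<>();
        retryInternal(result, operation, retries, delayMillis, scheduler);
        return result;
    }

    private static <T> void retryInternal(
            CompletableFuture<T> result,
            Supplier<CompletableFuture<T>> operation,
            int retries,
            long delayMillis,
            ScheduledExecutorService scheduler) {

        if (result.isDone()) {
            return;
        }

        CompletableFuture<T> attempt = operation.get();

        attempt.whenComplete((value, failure) -> {
            if (failure == null) {
                result.complete(value);
            } else if (retries > 0) {
                // schedule the next attempt instead of retrying immediately
                scheduler.schedule(
                        () -> retryInternal(result, operation, retries - 1, delayMillis, scheduler),
                        delayMillis, TimeUnit.MILLISECONDS);
            } else {
                result.completeExceptionally(failure);
            }
        });

        // if the caller completes or cancels the result future, cancel the in-flight attempt
        result.whenComplete((value, failure) -> attempt.cancel(false));
    }

    public static void main(String[] args) throws Exception {
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        CompletableFuture<String> f = retryWithDelay(
                () -> CompletableFuture.completedFuture("done"), 3, 100, scheduler);
        System.out.println(f.get());
        scheduler.shutdown();
    }
}

Sharing a single result future is what lets a caller stop the whole retry loop with one call: both the pending attempt and any scheduled retry observe its completion and stand down.
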
//------------------------------------------------------ // RPCs //------------------------------------------------------ @Override public CompletableFuture submitJob(JobGraph jobGraph, Time timeout) { + final JobID jobId = jobGraph.getJobID(); log.info("Submitting job {} ({}).", jobGraph.getJobID(), jobGraph.getName()); @@ -167,8 +193,8 @@ public CompletableFuture submitJob(JobGraph jobGraph, Time timeout) configuration, getRpcService(), highAvailabilityServices, - blobServer, heartbeatServices, + jobManagerServices, metricRegistry, new DispatcherOnCompleteActions(jobGraph.getJobID()), fatalErrorHandler); @@ -225,18 +251,140 @@ private void removeJob(JobID jobId, boolean cleanupHA) throws Exception { // TODO: remove job related files from blob server } + /** + * Clears the state of the dispatcher. + * + *
<p>
The state are all currently running jobs. + */ + private void clearState() throws Exception { + Exception exception = null; + + // stop all currently running JobManager since they run in the same process + for (JobManagerRunner jobManagerRunner : jobManagerRunners.values()) { + try { + jobManagerRunner.shutdown(); + } catch (Exception e) { + exception = ExceptionUtils.firstOrSuppressed(e, exception); + } + } + + jobManagerRunners.clear(); + + if (exception != null) { + throw exception; + } + } + + /** + * Recovers all jobs persisted via the submitted job graph store. + */ + private void recoverJobs() { + log.info("Recovering all persisted jobs."); + + getRpcService().execute( + () -> { + final Collection jobIds; + + try { + jobIds = submittedJobGraphStore.getJobIds(); + } catch (Exception e) { + log.error("Could not recover job ids from the submitted job graph store. Aborting recovery.", e); + return; + } + + for (JobID jobId : jobIds) { + try { + SubmittedJobGraph submittedJobGraph = submittedJobGraphStore.recoverJobGraph(jobId); + + runAsync(() -> submitJob(submittedJobGraph.getJobGraph(), RpcUtils.INF_TIMEOUT)); + } catch (Exception e) { + log.error("Could not recover the job graph for " + jobId + '.', e); + } + } + }); + } + + private void onFatalError(Throwable throwable) { + log.error("Fatal error occurred in dispatcher {}.", getAddress(), throwable); + fatalErrorHandler.onFatalError(throwable); + } + protected abstract JobManagerRunner createJobManagerRunner( ResourceID resourceId, JobGraph jobGraph, Configuration configuration, RpcService rpcService, HighAvailabilityServices highAvailabilityServices, - BlobService blobService, HeartbeatServices heartbeatServices, + JobManagerServices jobManagerServices, MetricRegistry metricRegistry, OnCompletionActions onCompleteActions, FatalErrorHandler fatalErrorHandler) throws Exception; + //------------------------------------------------------ + // Leader contender + //------------------------------------------------------ + + /** + * Callback method when current resourceManager is granted leadership. + * + * @param newLeaderSessionID unique leadershipID + */ + @Override + public void grantLeadership(final UUID newLeaderSessionID) { + runAsyncWithoutFencing( + () -> { + final DispatcherId dispatcherId = new DispatcherId(newLeaderSessionID); + + log.info("Dispatcher {} was granted leadership with fencing token {}", getAddress(), dispatcherId); + + // clear the state if we've been the leader before + if (getFencingToken() != null) { + try { + clearState(); + } catch (Exception e) { + log.warn("Could not properly clear the Dispatcher state while granting leadership.", e); + } + } + + setFencingToken(dispatcherId); + + // confirming the leader session ID might be blocking, + getRpcService().execute( + () -> leaderElectionService.confirmLeaderSessionID(newLeaderSessionID)); + + recoverJobs(); + }); + } + + /** + * Callback method when current resourceManager loses leadership. + */ + @Override + public void revokeLeadership() { + runAsyncWithoutFencing( + () -> { + log.info("Dispatcher {} was revoked leadership.", getAddress()); + try { + clearState(); + } catch (Exception e) { + log.warn("Could not properly clear the Dispatcher state while revoking leadership.", e); + } + + setFencingToken(DispatcherId.generate()); + }); + } + + /** + * Handles error occurring in the leader election service. 
+ * + * @param exception Exception being thrown in the leader election service + */ + @Override + public void handleError(final Exception exception) { + onFatalError(new DispatcherException("Received an error from the LeaderElectionService.", exception)); + } + //------------------------------------------------------ // Utility classes //------------------------------------------------------ @@ -253,48 +401,40 @@ private DispatcherOnCompleteActions(JobID jobId) { public void jobFinished(JobExecutionResult result) { log.info("Job {} finished.", jobId); - runAsync(new Runnable() { - @Override - public void run() { + runAsync(() -> { try { removeJob(jobId, true); } catch (Exception e) { log.warn("Could not properly remove job {} from the dispatcher.", jobId, e); } - } - }); + }); } @Override public void jobFailed(Throwable cause) { log.info("Job {} failed.", jobId); - runAsync(new Runnable() { - @Override - public void run() { + runAsync(() -> { try { removeJob(jobId, true); } catch (Exception e) { log.warn("Could not properly remove job {} from the dispatcher.", jobId, e); } - } - }); + }); } @Override public void jobFinishedByOther() { log.info("Job {} was finished by other JobManager.", jobId); - runAsync(new Runnable() { - @Override - public void run() { + runAsync( + () -> { try { removeJob(jobId, false); } catch (Exception e) { log.warn("Could not properly remove job {} from the dispatcher.", jobId, e); } - } - }); + }); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherException.java new file mode 100644 index 0000000000000..cf4a49300a5ca --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherException.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.dispatcher; + +import org.apache.flink.util.FlinkException; + +/** + * Base class for {@link Dispatcher} related exceptions. 
+ */ +public class DispatcherException extends FlinkException { + private static final long serialVersionUID = 3781733042984381286L; + + public DispatcherException(String message) { + super(message); + } + + public DispatcherException(Throwable cause) { + super(cause); + } + + public DispatcherException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherGateway.java index 33b8a42a364d4..09254c3ee4756 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherGateway.java @@ -22,7 +22,7 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.messages.Acknowledge; -import org.apache.flink.runtime.rpc.RpcGateway; +import org.apache.flink.runtime.rpc.FencedRpcGateway; import org.apache.flink.runtime.rpc.RpcTimeout; import java.util.Collection; @@ -31,7 +31,7 @@ /** * Gateway for the Dispatcher component. */ -public interface DispatcherGateway extends RpcGateway { +public interface DispatcherGateway extends FencedRpcGateway { /** * Submit a job to the dispatcher. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherId.java b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherId.java new file mode 100644 index 0000000000000..e5630904ce3ed --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/DispatcherId.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.dispatcher; + +import org.apache.flink.util.AbstractID; + +import java.util.UUID; + +/** + * Fencing token of the {@link Dispatcher}. 
+ */ +public class DispatcherId extends AbstractID { + + private static final long serialVersionUID = -1654056277003743966L; + + public DispatcherId(byte[] bytes) { + super(bytes); + } + + public DispatcherId(long lowerPart, long upperPart) { + super(lowerPart, upperPart); + } + + public DispatcherId(AbstractID id) { + super(id); + } + + public DispatcherId() {} + + public DispatcherId(UUID uuid) { + this(uuid.getLeastSignificantBits(), uuid.getMostSignificantBits()); + } + + public UUID toUUID() { + return new UUID(getUpperPart(), getLowerPart()); + } + + public static DispatcherId generate() { + return new DispatcherId(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/StandaloneDispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/StandaloneDispatcher.java index 54d698ef6c2d9..d6d82b1bd42b9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/StandaloneDispatcher.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/StandaloneDispatcher.java @@ -20,13 +20,13 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.blob.BlobServer; -import org.apache.flink.runtime.blob.BlobService; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobmanager.OnCompletionActions; import org.apache.flink.runtime.jobmaster.JobManagerRunner; +import org.apache.flink.runtime.jobmaster.JobManagerServices; import org.apache.flink.runtime.jobmaster.JobMaster; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.rpc.FatalErrorHandler; @@ -65,8 +65,8 @@ protected JobManagerRunner createJobManagerRunner( Configuration configuration, RpcService rpcService, HighAvailabilityServices highAvailabilityServices, - BlobService blobService, HeartbeatServices heartbeatServices, + JobManagerServices jobManagerServices, MetricRegistry metricRegistry, OnCompletionActions onCompleteActions, FatalErrorHandler fatalErrorHandler) throws Exception { @@ -77,8 +77,8 @@ protected JobManagerRunner createJobManagerRunner( configuration, rpcService, highAvailabilityServices, - blobService, heartbeatServices, + jobManagerServices, metricRegistry, onCompleteActions, fatalErrorHandler); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java old mode 100644 new mode 100755 index 2538f209c0ca0..861355fd3490b --- a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.GlobalConfiguration; import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.core.fs.FileSystem; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.blob.BlobServer; import org.apache.flink.runtime.clusterframework.BootstrapTools; @@ -46,7 +47,9 @@ import javax.annotation.concurrent.GuardedBy; +import java.io.IOException; import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import 
scala.concurrent.duration.FiniteDuration; @@ -69,6 +72,8 @@ public abstract class ClusterEntrypoint implements FatalErrorHandler { private final Configuration configuration; + private final CompletableFuture terminationFuture; + @GuardedBy("lock") private MetricRegistry metricRegistry = null; @@ -86,12 +91,19 @@ public abstract class ClusterEntrypoint implements FatalErrorHandler { protected ClusterEntrypoint(Configuration configuration) { this.configuration = Preconditions.checkNotNull(configuration); + this.terminationFuture = new CompletableFuture<>(); + } + + public CompletableFuture getTerminationFuture() { + return terminationFuture; } protected void startCluster() { LOG.info("Starting {}.", getClass().getSimpleName()); try { + installDefaultFileSystem(configuration); + SecurityContext securityContext = installSecurityContext(configuration); securityContext.runSecured(new Callable() { @@ -115,6 +127,17 @@ public Void call() throws Exception { } } + protected void installDefaultFileSystem(Configuration configuration) throws Exception { + LOG.info("Install default filesystem."); + + try { + FileSystem.setDefaultScheme(configuration); + } catch (IOException e) { + throw new IOException("Error while setting the default " + + "filesystem scheme from configuration.", e); + } + } + protected SecurityContext installSecurityContext(Configuration configuration) throws Exception { LOG.info("Install security context."); @@ -184,9 +207,18 @@ protected MetricRegistry createMetricRegistry(Configuration configuration) { } protected void shutDown(boolean cleanupHaData) throws FlinkException { + LOG.info("Stopping {}.", getClass().getSimpleName()); + Throwable exception = null; synchronized (lock) { + + try { + stopClusterComponents(cleanupHaData); + } catch (Throwable t) { + exception = ExceptionUtils.firstOrSuppressed(t, exception); + } + if (metricRegistry != null) { try { metricRegistry.shutdown(); @@ -222,6 +254,8 @@ protected void shutDown(boolean cleanupHaData) throws FlinkException { exception = ExceptionUtils.firstOrSuppressed(t, exception); } } + + terminationFuture.complete(true); } if (exception != null) { @@ -244,6 +278,9 @@ protected abstract void startClusterComponents( HeartbeatServices heartbeatServices, MetricRegistry metricRegistry) throws Exception; + protected void stopClusterComponents(boolean cleanupHaData) throws Exception { + } + protected static ClusterConfiguration parseArguments(String[] args) { ParameterTool parameterTool = ParameterTool.fromArgs(args); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/JobClusterEntrypoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/JobClusterEntrypoint.java index 87281865e9684..124c6c62fcaba 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/JobClusterEntrypoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/JobClusterEntrypoint.java @@ -22,13 +22,13 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.blob.BlobServer; -import org.apache.flink.runtime.blob.BlobService; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobmanager.OnCompletionActions; import org.apache.flink.runtime.jobmaster.JobManagerRunner; +import 
org.apache.flink.runtime.jobmaster.JobManagerServices; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.resourcemanager.ResourceManager; import org.apache.flink.runtime.rpc.FatalErrorHandler; @@ -44,6 +44,8 @@ public abstract class JobClusterEntrypoint extends ClusterEntrypoint { private ResourceManager resourceManager; + private JobManagerServices jobManagerServices; + private JobManagerRunner jobManagerRunner; public JobClusterEntrypoint(Configuration configuration) { @@ -68,12 +70,14 @@ protected void startClusterComponents( metricRegistry, this); + jobManagerServices = JobManagerServices.fromConfiguration(configuration, blobServer); + jobManagerRunner = createJobManagerRunner( configuration, ResourceID.generate(), rpcService, highAvailabilityServices, - blobServer, + jobManagerServices, heartbeatServices, metricRegistry, this); @@ -90,7 +94,7 @@ protected JobManagerRunner createJobManagerRunner( ResourceID resourceId, RpcService rpcService, HighAvailabilityServices highAvailabilityServices, - BlobService blobService, + JobManagerServices jobManagerServices, HeartbeatServices heartbeatServices, MetricRegistry metricRegistry, FatalErrorHandler fatalErrorHandler) throws Exception { @@ -103,15 +107,15 @@ protected JobManagerRunner createJobManagerRunner( configuration, rpcService, highAvailabilityServices, - blobService, heartbeatServices, + jobManagerServices, metricRegistry, new TerminatingOnCompleteActions(jobGraph.getJobID()), fatalErrorHandler); } @Override - protected void shutDown(boolean cleanupHaData) throws FlinkException { + protected void stopClusterComponents(boolean cleanupHaData) throws Exception { Throwable exception = null; if (jobManagerRunner != null) { @@ -122,22 +126,24 @@ protected void shutDown(boolean cleanupHaData) throws FlinkException { } } - if (resourceManager != null) { + if (jobManagerServices != null) { try { - resourceManager.shutDown(); + jobManagerServices.shutdown(); } catch (Throwable t) { exception = ExceptionUtils.firstOrSuppressed(t, exception); } } - try { - super.shutDown(cleanupHaData); - } catch (Throwable t) { - exception = ExceptionUtils.firstOrSuppressed(t, exception); + if (resourceManager != null) { + try { + resourceManager.shutDown(); + } catch (Throwable t) { + exception = ExceptionUtils.firstOrSuppressed(t, exception); + } } if (exception != null) { - throw new FlinkException("Could not properly shut down the session cluster entry point.", exception); + throw new FlinkException("Could not properly shut down the job cluster entry point.", exception); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/SessionClusterEntrypoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/SessionClusterEntrypoint.java index 4013e8313c3d9..cea1688f13d3e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/SessionClusterEntrypoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/SessionClusterEntrypoint.java @@ -80,7 +80,7 @@ protected void startClusterComponents( } @Override - protected void shutDown(boolean cleanupHaData) throws FlinkException { + protected void stopClusterComponents(boolean cleanupHaData) throws Exception { Throwable exception = null; if (dispatcher != null) { @@ -99,12 +99,6 @@ protected void shutDown(boolean cleanupHaData) throws FlinkException { } } - try { - super.shutDown(cleanupHaData); - } catch (Throwable t) { - exception = ExceptionUtils.firstOrSuppressed(t, exception); - } - if (exception 
!= null) { throw new FlinkException("Could not properly shut down the session cluster entry point.", exception); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index 9e9f7c4c719c4..203ee8547cf42 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -18,8 +18,6 @@ package org.apache.flink.runtime.execution; -import java.util.Map; -import java.util.concurrent.Future; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.TaskInfo; @@ -28,7 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; @@ -41,6 +39,9 @@ import org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; +import java.util.Map; +import java.util.concurrent.Future; + /** * The Environment gives the code executed in a task access to the task's properties * (such as name, parallelism), the configurations, the data stream readers and writers, @@ -175,7 +176,7 @@ public interface Environment { * @param checkpointMetrics metrics for this checkpoint * @param subtaskState All state handles for the checkpointed state */ - void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState); + void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState); /** * Declines a checkpoint. This tells the checkpoint coordinator that this task will diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheManager.java index 0387725db0a0b..c8fc4e4c3a53b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheManager.java @@ -23,9 +23,12 @@ import org.apache.flink.runtime.blob.BlobService; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.util.ExceptionUtils; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; + import java.io.IOException; import java.net.URL; import java.util.Arrays; @@ -33,72 +36,52 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.Map; import java.util.Set; -import java.util.Timer; -import java.util.TimerTask; +import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkNotNull; /** - * For each job graph that is submitted to the system the library cache manager maintains - * a set of libraries (typically JAR files) which the job requires to run. 
The library cache manager - * caches library files in order to avoid unnecessary retransmission of data. It is based on a singleton - * programming pattern, so there exists at most one library manager at a time. - *
<p>
- * All files registered via {@link #registerJob(JobID, Collection, Collection)} are reference-counted - * and are removed by a timer-based cleanup task if their reference counter is zero. + * Provides facilities to download a set of libraries (typically JAR files) for a job from a + * {@link BlobService} and create a class loader with references to them. */ -public final class BlobLibraryCacheManager extends TimerTask implements LibraryCacheManager { +public class BlobLibraryCacheManager implements LibraryCacheManager { + + private static final Logger LOG = LoggerFactory.getLogger(BlobLibraryCacheManager.class); + + private static final ExecutionAttemptID JOB_ATTEMPT_ID = new ExecutionAttemptID(-1, -1); - private static Logger LOG = LoggerFactory.getLogger(BlobLibraryCacheManager.class); - - private static ExecutionAttemptID JOB_ATTEMPT_ID = new ExecutionAttemptID(-1, -1); - // -------------------------------------------------------------------------------------------- - + /** The global lock to synchronize operations */ private final Object lockObject = new Object(); /** Registered entries per job */ - private final Map cacheEntries = new HashMap(); - - /** Map to store the number of reference to a specific file */ - private final Map blobKeyReferenceCounters = new HashMap(); + private final Map cacheEntries = new HashMap<>(); /** The blob service to download libraries */ private final BlobService blobService; - - private final Timer cleanupTimer; - + // -------------------------------------------------------------------------------------------- - /** - * Creates the blob library cache manager. - * - * @param blobService blob file retrieval service to use - * @param cleanupInterval cleanup interval in milliseconds - */ - public BlobLibraryCacheManager(BlobService blobService, long cleanupInterval) { + public BlobLibraryCacheManager(BlobService blobService) { this.blobService = checkNotNull(blobService); - - // Initializing the clean up task - this.cleanupTimer = new Timer(true); - this.cleanupTimer.schedule(this, cleanupInterval, cleanupInterval); } - // -------------------------------------------------------------------------------------------- - @Override public void registerJob(JobID id, Collection requiredJarFiles, Collection requiredClasspaths) - throws IOException { + throws IOException { registerTask(id, JOB_ATTEMPT_ID, requiredJarFiles, requiredClasspaths); } - + @Override - public void registerTask(JobID jobId, ExecutionAttemptID task, Collection requiredJarFiles, - Collection requiredClasspaths) throws IOException { + public void registerTask( + JobID jobId, + ExecutionAttemptID task, + @Nullable Collection requiredJarFiles, + @Nullable Collection requiredClasspaths) throws IOException { + checkNotNull(jobId, "The JobId must not be null."); checkNotNull(task, "The task execution id must not be null."); @@ -113,43 +96,31 @@ public void registerTask(JobID jobId, ExecutionAttemptID task, Collection> entryIter = blobKeyReferenceCounters.entrySet().iterator(); - - while (entryIter.hasNext()) { - Map.Entry entry = entryIter.next(); - BlobKey key = entry.getKey(); - int references = entry.getValue(); - - try { - if (references <= 0) { - blobService.delete(key); - entryIter.remove(); - } - } catch (Throwable t) { - LOG.warn("Could not delete file with blob key" + key, t); - } - } - } - } - - public int getNumberOfReferenceHolders(JobID jobId) { + int getNumberOfReferenceHolders(JobID jobId) { synchronized (lockObject) { LibraryCacheEntry entry = cacheEntries.get(jobId); return entry == 
null ? 0 : entry.getNumberOfReferenceHolders(); } } - - int getNumberOfCachedLibraries() { - return blobKeyReferenceCounters.size(); - } - - private URL registerReferenceToBlobKeyAndGetURL(BlobKey key) throws IOException { - // it is important that we fetch the URL before increasing the counter. - // in case the URL cannot be created (failed to fetch the BLOB), we have no stale counter - try { - URL url = blobService.getURL(key); - Integer references = blobKeyReferenceCounters.get(key); - int newReferences = references == null ? 1 : references + 1; - blobKeyReferenceCounters.put(key, newReferences); - - return url; - } - catch (IOException e) { - throw new IOException("Cannot get library with hash " + key, e); - } + /** + * Returns the number of registered jobs that this library cache manager handles. + * + * @return number of jobs (irrespective of the actual number of tasks per job) + */ + int getNumberOfManagedJobs() { + // no synchronisation necessary + return cacheEntries.size(); } - - private void unregisterReferenceToBlobKey(BlobKey key) { - Integer references = blobKeyReferenceCounters.get(key); - if (references != null) { - int newReferences = Math.max(references - 1, 0); - blobKeyReferenceCounters.put(key, newReferences); - } - else { - // make sure we have an entry in any case, that the cleanup timer removes any - // present libraries - blobKeyReferenceCounters.put(key, 0); + + @Override + public void shutdown() { + synchronized (lockObject) { + for (LibraryCacheEntry entry : cacheEntries.values()) { + entry.releaseClassLoader(); + } } } - // -------------------------------------------------------------------------------------------- /** * An entry in the per-job library cache. Tracks which execution attempts * still reference the libraries. Once none reference it any more, the - * libraries can be cleaned up. + * class loaders can be cleaned up. */ private static class LibraryCacheEntry { - + private final FlinkUserCodeClassLoader classLoader; - + private final Set referenceHolders; - + /** + * Set of BLOB keys used for a previous job/task registration. + * + *
<p>
The purpose of this is to make sure, future registrations do not differ in content as + * this is a contract of the {@link BlobLibraryCacheManager}. + */ private final Set libraries; - - - public LibraryCacheEntry(Collection libraries, URL[] libraryURLs, ExecutionAttemptID initialReference) { + + /** + * Set of class path URLs used for a previous job/task registration. + * + *
<p>
The purpose of this is to make sure, future registrations do not differ in content as + * this is a contract of the {@link BlobLibraryCacheManager}. + */ + private final Set classPaths; + + /** + * Creates a cache entry for a flink class loader with the given libraryURLs. + * + * @param requiredLibraries + * BLOB keys required by the class loader (stored for ensuring consistency among different + * job/task registrations) + * @param requiredClasspaths + * class paths required by the class loader (stored for ensuring consistency among + * different job/task registrations) + * @param libraryURLs + * complete list of URLs to use for the class loader (includes references to the + * requiredLibraries and requiredClasspaths) + * @param initialReference + * reference holder ID + */ + LibraryCacheEntry( + Collection requiredLibraries, + Collection requiredClasspaths, + URL[] libraryURLs, + ExecutionAttemptID initialReference) { + this.classLoader = new FlinkUserCodeClassLoader(libraryURLs); - this.libraries = new HashSet<>(libraries); + // NOTE: do not store the class paths, i.e. URLs, into a set for performance reasons + // see http://findbugs.sourceforge.net/bugDescriptions.html#DMI_COLLECTION_OF_URLS + // -> alternatively, compare their string representation + this.classPaths = new HashSet<>(requiredClasspaths.size()); + for (URL url : requiredClasspaths) { + classPaths.add(url.toString()); + } + this.libraries = new HashSet<>(requiredLibraries); this.referenceHolders = new HashSet<>(); this.referenceHolders.add(initialReference); } - - + public ClassLoader getClassLoader() { return classLoader; } - + public Set getLibraries() { return libraries; } - - public void register(ExecutionAttemptID task, Collection keys) { - if (!libraries.containsAll(keys)) { + + public void register( + ExecutionAttemptID task, Collection requiredLibraries, + Collection requiredClasspaths) { + + // Make sure the previous registration referred to the same libraries and class paths. + // NOTE: the original collections may contain duplicates and may not already be Set + // collections with fast checks whether an item is contained in it. 
+ + // lazy construction of a new set for faster comparisons + if (libraries.size() != requiredLibraries.size() || + !new HashSet<>(requiredLibraries).containsAll(libraries)) { + throw new IllegalStateException( - "The library registration references a different set of libraries than previous registrations for this job."); + "The library registration references a different set of library BLOBs than" + + " previous registrations for this job:\nold:" + libraries.toString() + + "\nnew:" + requiredLibraries.toString()); } - + + // lazy construction of a new set with String representations of the URLs + if (classPaths.size() != requiredClasspaths.size() || + !requiredClasspaths.stream().map(URL::toString).collect(Collectors.toSet()) + .containsAll(classPaths)) { + + throw new IllegalStateException( + "The library registration references a different set of library BLOBs than" + + " previous registrations for this job:\nold:" + + classPaths.toString() + + "\nnew:" + requiredClasspaths.toString()); + } + this.referenceHolders.add(task); } - + public boolean unregister(ExecutionAttemptID task) { referenceHolders.remove(task); return referenceHolders.isEmpty(); } - - public int getNumberOfReferenceHolders() { + + int getNumberOfReferenceHolders() { return referenceHolders.size(); } @@ -343,5 +319,4 @@ void releaseClassLoader() { } } } - } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/librarycache/FallbackLibraryCacheManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/librarycache/FallbackLibraryCacheManager.java index 8e14e5867263a..41eeb1826b7fa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/librarycache/FallbackLibraryCacheManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/librarycache/FallbackLibraryCacheManager.java @@ -28,7 +28,7 @@ import java.util.Collection; public class FallbackLibraryCacheManager implements LibraryCacheManager { - + private static Logger LOG = LoggerFactory.getLogger(FallbackLibraryCacheManager.class); @Override @@ -40,10 +40,10 @@ public ClassLoader getClassLoader(JobID id) { public void registerJob(JobID id, Collection requiredJarFiles, Collection requiredClasspaths) { LOG.warn("FallbackLibraryCacheManager cannot download files associated with blob keys."); } - + @Override public void registerTask(JobID id, ExecutionAttemptID execution, Collection requiredJarFiles, - Collection requiredClasspaths) { + Collection requiredClasspaths) { LOG.warn("FallbackLibraryCacheManager cannot download files associated with blob keys."); } @@ -51,7 +51,7 @@ public void registerTask(JobID id, ExecutionAttemptID execution, Collection requiredJarFiles, Collection requiredClasspaths) - throws IOException; - + throws IOException; + /** - * Registers a job task execution with its required jar files and classpaths. The jar files are identified by their blob keys. + * Registers a job task execution with its required jar files and classpaths. The jar files are + * identified by their blob keys and downloaded for use by a {@link ClassLoader}. 
* * @param id job ID * @param requiredJarFiles collection of blob keys identifying the required jar files * @param requiredClasspaths collection of classpaths that are added to the user code class loader - * @throws IOException + * + * @throws IOException if any error occurs when retrieving the required jar files * * @see #unregisterTask(JobID, ExecutionAttemptID) counterpart of this method */ void registerTask(JobID id, ExecutionAttemptID execution, Collection requiredJarFiles, - Collection requiredClasspaths) throws IOException; + Collection requiredClasspaths) throws IOException; /** * Unregisters a job task execution from the library cache manager. @@ -88,9 +93,7 @@ void registerTask(JobID id, ExecutionAttemptID execution, Collection re void unregisterJob(JobID id); /** - * Shutdown method - * - * @throws IOException + * Shutdown method which may release created class loaders. */ - void shutdown() throws IOException; + void shutdown(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java index bd5bc7f5da6cc..2074820b41dc3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.accumulators.StringifiedAccumulatorResult; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.concurrent.FutureUtils; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; @@ -41,7 +42,6 @@ import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.messages.StackTraceSampleResponse; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.util.ExceptionUtils; @@ -133,7 +133,7 @@ public class Execution implements AccessExecution, Archiveable hook : masterHooks) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java index 5ee7a9f97aa6c..e6d49d261def3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java @@ -26,6 +26,7 @@ import org.apache.flink.api.common.accumulators.AccumulatorHelper; import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.JobManagerOptions; import org.apache.flink.core.io.InputSplit; import org.apache.flink.core.io.InputSplitAssigner; import org.apache.flink.core.io.InputSplitSource; @@ -39,7 +40,6 @@ import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.configuration.JobManagerOptions; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup; import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; 
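
The reworked BlobLibraryCacheManager above keeps exactly one user-code class loader per job and tracks which execution attempts still reference it; the loader may only be released once the last reference holder has unregistered. A stripped-down sketch of that bookkeeping with plain JDK types (URLClassLoader stands in here for Flink's FlinkUserCodeClassLoader, and the string IDs are placeholders) might look like this:

import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Illustrative sketch of per-job, reference-counted class loader caching.
final class LibraryCacheSketch {

    // One entry per job: the class loader plus the tasks (reference holders) using it.
    private static final class Entry {
        final URLClassLoader classLoader;
        final Set<String> referenceHolders = new HashSet<>();

        Entry(URL[] libraryUrls) {
            this.classLoader = new URLClassLoader(libraryUrls);
        }
    }

    private final Map<String, Entry> entriesPerJob = new HashMap<>();

    // Registers a task of the given job; creates the class loader on first use.
    synchronized ClassLoader register(String jobId, String taskId, URL[] libraryUrls) {
        Entry entry = entriesPerJob.computeIfAbsent(jobId, id -> new Entry(libraryUrls));
        entry.referenceHolders.add(taskId);
        return entry.classLoader;
    }

    // Unregisters a task; releases the class loader once nobody references it any more.
    synchronized void unregister(String jobId, String taskId) throws IOException {
        Entry entry = entriesPerJob.get(jobId);
        if (entry != null) {
            entry.referenceHolders.remove(taskId);
            if (entry.referenceHolders.isEmpty()) {
                entriesPerJob.remove(jobId);
                entry.classLoader.close();
            }
        }
    }
}

The real entry additionally remembers the BLOB keys and class paths of the first registration and rejects later registrations for the same job that differ, comparing class path URLs by their string form because URL equality checks can trigger host name resolution.
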
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java index 0ff71e799e884..9aac133060bab 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java @@ -22,7 +22,9 @@ import org.apache.flink.api.common.Archiveable; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; +import org.apache.flink.configuration.JobManagerOptions; import org.apache.flink.runtime.JobException; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.PartialInputChannelDeploymentDescriptor; @@ -38,11 +40,9 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.configuration.JobManagerOptions; import org.apache.flink.runtime.jobmanager.scheduler.CoLocationConstraint; import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.util.EvictingBoundedList; import org.apache.flink.util.ExceptionUtils; @@ -457,7 +457,7 @@ public Iterable getPreferredLocations() { */ public Iterable getPreferredLocationsBasedOnState() { TaskManagerLocation priorLocation; - if (currentExecution.getTaskStateHandles() != null && (priorLocation = getLatestPriorLocation()) != null) { + if (currentExecution.getTaskStateSnapshot() != null && (priorLocation = getLatestPriorLocation()) != null) { return Collections.singleton(priorLocation); } else { @@ -719,7 +719,7 @@ void notifyStateTransition(Execution execution, ExecutionState newState, Throwab TaskDeploymentDescriptor createDeploymentDescriptor( ExecutionAttemptID executionId, SimpleSlot targetSlot, - TaskStateHandles taskStateHandles, + TaskStateSnapshot taskStateHandles, int attemptNumber) throws ExecutionGraphException { // Produced intermediate results diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/HighAvailabilityServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/HighAvailabilityServices.java index b44905e3d94c5..defe5cce331db 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/HighAvailabilityServices.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/HighAvailabilityServices.java @@ -72,6 +72,12 @@ public interface HighAvailabilityServices extends AutoCloseable { */ LeaderRetrievalService getResourceManagerLeaderRetriever(); + /** + * Gets the leader retriever for the dispatcher. This leader retrieval service + * is not always accessible. 
+ */ + LeaderRetrievalService getDispatcherLeaderRetriever(); + /** * Gets the leader retriever for the job JobMaster which is responsible for the given job * @@ -99,6 +105,13 @@ public interface HighAvailabilityServices extends AutoCloseable { */ LeaderElectionService getResourceManagerLeaderElectionService(); + /** + * Gets the leader election service for the cluster's dispatcher. + * + * @return Leader election service for the dispatcher leader election + */ + LeaderElectionService getDispatcherLeaderElectionService(); + /** * Gets the leader election service for the given job. * diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/HighAvailabilityServicesUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/HighAvailabilityServicesUtils.java index 2ebfd20245662..7a89ed8fdf898 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/HighAvailabilityServicesUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/HighAvailabilityServicesUtils.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.JobManagerOptions; import org.apache.flink.runtime.blob.BlobStoreService; import org.apache.flink.runtime.blob.BlobUtils; +import org.apache.flink.runtime.dispatcher.Dispatcher; import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServices; import org.apache.flink.runtime.highavailability.nonha.standalone.StandaloneHaServices; import org.apache.flink.runtime.highavailability.zookeeper.ZooKeeperHaServices; @@ -87,8 +88,17 @@ public static HighAvailabilityServices createHighAvailabilityServices( ResourceManager.RESOURCE_MANAGER_NAME, addressResolution, configuration); + final String dispatcherRpcUrl = AkkaRpcServiceUtils.getRpcUrl( + hostnamePort.f0, + hostnamePort.f1, + Dispatcher.DISPATCHER_NAME, + addressResolution, + configuration); - return new StandaloneHaServices(resourceManagerRpcUrl, jobManagerRpcUrl); + return new StandaloneHaServices( + resourceManagerRpcUrl, + dispatcherRpcUrl, + jobManagerRpcUrl); case ZOOKEEPER: BlobStoreService blobStoreService = BlobUtils.createBlobStoreFromConfig(configuration); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/nonha/embedded/EmbeddedHaServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/nonha/embedded/EmbeddedHaServices.java index 76eb681e3e8e0..4c30f87fba7fc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/nonha/embedded/EmbeddedHaServices.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/nonha/embedded/EmbeddedHaServices.java @@ -45,11 +45,14 @@ public class EmbeddedHaServices extends AbstractNonHaServices { private final EmbeddedLeaderService resourceManagerLeaderService; + private final EmbeddedLeaderService dispatcherLeaderService; + private final HashMap jobManagerLeaderServices; public EmbeddedHaServices(Executor executor) { this.executor = Preconditions.checkNotNull(executor); this.resourceManagerLeaderService = new EmbeddedLeaderService(executor); + this.dispatcherLeaderService = new EmbeddedLeaderService(executor); this.jobManagerLeaderServices = new HashMap<>(); } @@ -62,11 +65,21 @@ public LeaderRetrievalService getResourceManagerLeaderRetriever() { return resourceManagerLeaderService.createLeaderRetrievalService(); } + @Override + public LeaderRetrievalService getDispatcherLeaderRetriever() { + return dispatcherLeaderService.createLeaderRetrievalService(); + } + 
@Override public LeaderElectionService getResourceManagerLeaderElectionService() { return resourceManagerLeaderService.createLeaderElectionService(); } + @Override + public LeaderElectionService getDispatcherLeaderElectionService() { + return dispatcherLeaderService.createLeaderElectionService(); + } + @Override public LeaderRetrievalService getJobManagerLeaderRetriever(JobID jobID) { checkNotNull(jobID); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/nonha/standalone/StandaloneHaServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/nonha/standalone/StandaloneHaServices.java index b3c6ee51b37bd..617b3512df611 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/nonha/standalone/StandaloneHaServices.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/nonha/standalone/StandaloneHaServices.java @@ -45,6 +45,9 @@ public class StandaloneHaServices extends AbstractNonHaServices { /** The fix address of the ResourceManager */ private final String resourceManagerAddress; + /** The fix address of the Dispatcher */ + private final String dispatcherAddress; + /** The fix address of the JobManager */ private final String jobManagerAddress; @@ -53,8 +56,12 @@ public class StandaloneHaServices extends AbstractNonHaServices { * * @param resourceManagerAddress The fix address of the ResourceManager */ - public StandaloneHaServices(String resourceManagerAddress, String jobManagerAddress) { + public StandaloneHaServices( + String resourceManagerAddress, + String dispatcherAddress, + String jobManagerAddress) { this.resourceManagerAddress = checkNotNull(resourceManagerAddress, "resourceManagerAddress"); + this.dispatcherAddress = checkNotNull(dispatcherAddress, "dispatcherAddress"); this.jobManagerAddress = checkNotNull(jobManagerAddress, "jobManagerAddress"); } @@ -72,6 +79,15 @@ public LeaderRetrievalService getResourceManagerLeaderRetriever() { } + @Override + public LeaderRetrievalService getDispatcherLeaderRetriever() { + synchronized (lock) { + checkNotShutdown(); + + return new StandaloneLeaderRetrievalService(dispatcherAddress, DEFAULT_LEADER_ID); + } + } + @Override public LeaderElectionService getResourceManagerLeaderElectionService() { synchronized (lock) { @@ -81,6 +97,15 @@ public LeaderElectionService getResourceManagerLeaderElectionService() { } } + @Override + public LeaderElectionService getDispatcherLeaderElectionService() { + synchronized (lock) { + checkNotShutdown(); + + return new StandaloneLeaderElectionService(); + } + } + @Override public LeaderRetrievalService getJobManagerLeaderRetriever(JobID jobID) { synchronized (lock) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/zookeeper/ZooKeeperHaServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/zookeeper/ZooKeeperHaServices.java index 9dabfa218284e..04ab6d3cc96e9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/zookeeper/ZooKeeperHaServices.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/highavailability/zookeeper/ZooKeeperHaServices.java @@ -82,6 +82,8 @@ public class ZooKeeperHaServices implements HighAvailabilityServices { private static final String RESOURCE_MANAGER_LEADER_PATH = "/resource_manager_lock"; + private static final String DISPATCHER_LEADER_PATH = "/dispatcher_lock"; + private static final String JOB_MANAGER_LEADER_PATH = "/job_manager_lock"; // 
------------------------------------------------------------------------ @@ -124,6 +126,11 @@ public LeaderRetrievalService getResourceManagerLeaderRetriever() { return ZooKeeperUtils.createLeaderRetrievalService(client, configuration, RESOURCE_MANAGER_LEADER_PATH); } + @Override + public LeaderRetrievalService getDispatcherLeaderRetriever() { + return ZooKeeperUtils.createLeaderRetrievalService(client, configuration, DISPATCHER_LEADER_PATH); + } + @Override public LeaderRetrievalService getJobManagerLeaderRetriever(JobID jobID) { return ZooKeeperUtils.createLeaderRetrievalService(client, configuration, getPathForJobManager(jobID)); @@ -139,6 +146,11 @@ public LeaderElectionService getResourceManagerLeaderElectionService() { return ZooKeeperUtils.createLeaderElectionService(client, configuration, RESOURCE_MANAGER_LEADER_PATH); } + @Override + public LeaderElectionService getDispatcherLeaderElectionService() { + return ZooKeeperUtils.createLeaderElectionService(client, configuration, DISPATCHER_LEADER_PATH); + } + @Override public LeaderElectionService getJobManagerLeaderElectionService(JobID jobID) { return ZooKeeperUtils.createLeaderElectionService(client, configuration, getPathForJobManager(jobID)); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/instance/SlotPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/instance/SlotPool.java index de2b3e5f2aafb..6397043562c8a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/instance/SlotPool.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/instance/SlotPool.java @@ -32,6 +32,7 @@ import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot; import org.apache.flink.runtime.jobmanager.slots.SlotAndLocality; import org.apache.flink.runtime.jobmanager.slots.SlotOwner; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; import org.apache.flink.runtime.resourcemanager.SlotRequest; @@ -53,7 +54,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -118,11 +118,8 @@ public class SlotPool extends RpcEndpoint implements SlotPoolGateway { private final Clock clock; - /** the leader id of job manager */ - private UUID jobManagerLeaderId; - - /** The leader id of resource manager */ - private UUID resourceManagerLeaderId; + /** the fencing token of the job manager */ + private JobMasterId jobMasterId; /** The gateway to communicate with resource manager */ private ResourceManagerGateway resourceManagerGateway; @@ -158,6 +155,10 @@ public SlotPool( this.waitingForResourceManager = new HashMap<>(); this.providerAndOwner = new ProviderAndOwner(getSelfGateway(SlotPoolGateway.class), slotRequestTimeout); + + this.jobMasterId = null; + this.resourceManagerGateway = null; + this.jobManagerAddress = null; } // ------------------------------------------------------------------------ @@ -172,11 +173,11 @@ public void start() { /** * Start the slot pool to accept RPC calls. * - * @param newJobManagerLeaderId The necessary leader id for running the job. + * @param jobMasterId The necessary leader id for running the job. 
* @param newJobManagerAddress for the slot requests which are sent to the resource manager */ - public void start(UUID newJobManagerLeaderId, String newJobManagerAddress) throws Exception { - this.jobManagerLeaderId = checkNotNull(newJobManagerLeaderId); + public void start(JobMasterId jobMasterId, String newJobManagerAddress) throws Exception { + this.jobMasterId = checkNotNull(jobMasterId); this.jobManagerAddress = checkNotNull(newJobManagerAddress); // TODO - start should not throw an exception @@ -198,8 +199,7 @@ public void suspend() { stop(); // do not accept any requests - jobManagerLeaderId = null; - resourceManagerLeaderId = null; + jobMasterId = null; resourceManagerGateway = null; // Clear (but not release!) the available slots. The TaskManagers should re-register them @@ -240,8 +240,7 @@ public SlotProvider getSlotProvider() { // ------------------------------------------------------------------------ @Override - public void connectToResourceManager(UUID resourceManagerLeaderId, ResourceManagerGateway resourceManagerGateway) { - this.resourceManagerLeaderId = checkNotNull(resourceManagerLeaderId); + public void connectToResourceManager(ResourceManagerGateway resourceManagerGateway) { this.resourceManagerGateway = checkNotNull(resourceManagerGateway); // work on all slots waiting for this connection @@ -255,7 +254,6 @@ public void connectToResourceManager(UUID resourceManagerLeaderId, ResourceManag @Override public void disconnectResourceManager() { - this.resourceManagerLeaderId = null; this.resourceManagerGateway = null; } @@ -319,7 +317,7 @@ private void requestSlotFromResourceManager( pendingRequests.put(allocationID, new PendingRequest(allocationID, future, resources)); CompletableFuture rmResponse = resourceManagerGateway.requestSlot( - jobManagerLeaderId, resourceManagerLeaderId, + jobMasterId, new SlotRequest(jobId, allocationID, resources, jobManagerAddress), resourceManagerRequestsTimeout); @@ -613,7 +611,7 @@ public void registerTaskManager(final ResourceID resourceID) { * @param resourceID The id of the TaskManager */ @Override - public void releaseTaskManager(final ResourceID resourceID) { + public CompletableFuture releaseTaskManager(final ResourceID resourceID) { if (registeredTaskManagers.remove(resourceID)) { availableSlots.removeAllForTaskManager(resourceID); @@ -622,6 +620,8 @@ public void releaseTaskManager(final ResourceID resourceID) { slot.releaseSlot(); } } + + return CompletableFuture.completedFuture(Acknowledge.get()); } // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/instance/SlotPoolGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/instance/SlotPoolGateway.java index 8d4f2a514f51c..06c4b120cca3d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/instance/SlotPoolGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/instance/SlotPoolGateway.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.clusterframework.types.ResourceProfile; import org.apache.flink.runtime.jobmanager.scheduler.ScheduledUnit; import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot; +import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; import org.apache.flink.runtime.rpc.RpcGateway; import org.apache.flink.runtime.rpc.RpcTimeout; @@ -32,7 +33,6 @@ import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import java.util.Collection; -import 
java.util.UUID; import java.util.concurrent.CompletableFuture; /** @@ -54,10 +54,9 @@ public interface SlotPoolGateway extends RpcGateway { * Connects the SlotPool to the given ResourceManager. After this method is called, the * SlotPool will be able to request resources from the given ResourceManager. * - * @param resourceManagerLeaderId The leader session ID of the resource manager. * @param resourceManagerGateway The RPC gateway for the resource manager. */ - void connectToResourceManager(UUID resourceManagerLeaderId, ResourceManagerGateway resourceManagerGateway); + void connectToResourceManager(ResourceManagerGateway resourceManagerGateway); /** * Disconnects the slot pool from its current Resource Manager. After this call, the pool will not @@ -74,7 +73,7 @@ public interface SlotPoolGateway extends RpcGateway { void registerTaskManager(ResourceID resourceID); - void releaseTaskManager(ResourceID resourceID); + CompletableFuture releaseTaskManager(ResourceID resourceID); CompletableFuture offerSlot(AllocatedSlot slot); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/TaskEventDispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/TaskEventDispatcher.java index eddba8db46d42..8816e32cb27e6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/TaskEventDispatcher.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/TaskEventDispatcher.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network; -import com.google.common.collect.Maps; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; @@ -26,6 +25,8 @@ import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.util.event.EventListener; +import org.apache.flink.shaded.guava18.com.google.common.collect.Maps; + import java.util.Map; /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java index d2dc46b5520b6..9ef170a47795f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java @@ -18,11 +18,12 @@ package org.apache.flink.runtime.io.network.api; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.util.event.EventListener; +import org.apache.flink.shaded.guava18.com.google.common.collect.HashMultimap; +import org.apache.flink.shaded.guava18.com.google.common.collect.Multimap; + /** * The event handler manages {@link EventListener} instances and allows to * to publish events to them. 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyServer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyServer.java index e8727d20441b4..4036e2924d202 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyServer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyServer.java @@ -20,6 +20,7 @@ import org.apache.flink.runtime.util.FatalExitExceptionHandler; +import org.apache.flink.shaded.guava18.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.flink.shaded.netty4.io.netty.bootstrap.ServerBootstrap; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; @@ -32,7 +33,6 @@ import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioServerSocketChannel; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java index 7db7ac4fe996e..e3097ba1e2eb4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java @@ -31,10 +31,10 @@ import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.util.event.EventListener; +import org.apache.flink.shaded.guava18.com.google.common.collect.Maps; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; -import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java index 881eae8312ed0..ff0f1307dbfc7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java @@ -26,13 +26,13 @@ import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; +import org.apache.flink.shaded.guava18.com.google.common.collect.Sets; import org.apache.flink.shaded.netty4.io.netty.channel.Channel; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; -import com.google.common.collect.Sets; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionManager.java index f681548263491..92fb2a03201db 100644 
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionManager.java @@ -18,11 +18,13 @@ package org.apache.flink.runtime.io.network.partition; -import com.google.common.collect.HashBasedTable; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Table; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; + +import org.apache.flink.shaded.guava18.com.google.common.collect.HashBasedTable; +import org.apache.flink.shaded.guava18.com.google.common.collect.ImmutableList; +import org.apache.flink.shaded.guava18.com.google.common.collect.Table; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index 55c78af07d487..87443d261cf2f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -18,11 +18,12 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import com.google.common.collect.Maps; -import com.google.common.collect.Sets; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.shaded.guava18.com.google.common.collect.Maps; +import org.apache.flink.shaded.guava18.com.google.common.collect.Sets; + import java.io.IOException; import java.util.ArrayDeque; import java.util.Map; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java index 1c68515f83965..c12687506524c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java @@ -535,7 +535,8 @@ public void uploadUserJars( InetSocketAddress blobServerAddress, Configuration blobClientConfig) throws IOException { if (!userJars.isEmpty()) { - List blobKeys = BlobClient.uploadJarFiles(blobServerAddress, blobClientConfig, userJars); + // TODO: make use of job-related BLOBs after adapting the BlobLibraryCacheManager + List blobKeys = BlobClient.uploadJarFiles(blobServerAddress, blobClientConfig, jobID, userJars); for (BlobKey blobKey : blobKeys) { if (!userJarBlobKeys.contains(blobKey)) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java index 0930011896353..00db01ffd2e04 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java @@ -21,7 +21,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; /** * This interface must be implemented by any invokable that has 
recoverable state and participates @@ -35,7 +35,7 @@ public interface StatefulTask { * * @param taskStateHandles All state handle for the task. */ - void setInitialState(TaskStateHandles taskStateHandles) throws Exception; + void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception; /** * This method is called to trigger a checkpoint, asynchronously by the checkpoint @@ -43,8 +43,8 @@ public interface StatefulTask { * *

This method is called for tasks that start the checkpoints by injecting the initial barriers, * i.e., the source tasks. In contrast, checkpoints on downstream operators, which are the result of - * receiving checkpoint barriers, invoke the {@link #triggerCheckpointOnBarrier(CheckpointMetaData, CheckpointMetrics)} - * method. + * receiving checkpoint barriers, invoke the + * {@link #triggerCheckpointOnBarrier(CheckpointMetaData, CheckpointOptions, CheckpointMetrics)} method. * * @param checkpointMetaData Meta data for about this checkpoint * @param checkpointOptions Options for performing this checkpoint diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerException.java index bc2759d0ed0c1..1650c83290d3c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerException.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerException.java @@ -18,10 +18,12 @@ package org.apache.flink.runtime.jobmaster; +import org.apache.flink.util.FlinkException; + /** * Base exception thrown by the {@link JobMaster}. */ -public class JobManagerException extends Exception { +public class JobManagerException extends FlinkException { private static final long serialVersionUID = -7290962952242188064L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerRunner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerRunner.java index 5838cf27461b8..6f5a082bf40cd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerRunner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerRunner.java @@ -20,10 +20,11 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.JobExecutionResult; +import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.blob.BlobService; import org.apache.flink.runtime.client.JobExecutionException; import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.concurrent.FlinkFutureException; import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; @@ -33,17 +34,19 @@ import org.apache.flink.runtime.jobmanager.OnCompletionActions; import org.apache.flink.runtime.leaderelection.LeaderContender; import org.apache.flink.runtime.leaderelection.LeaderElectionService; +import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.metrics.MetricRegistry; -import org.apache.flink.runtime.metrics.MetricRegistryConfiguration; import org.apache.flink.runtime.metrics.groups.JobManagerMetricGroup; import org.apache.flink.runtime.rpc.FatalErrorHandler; import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -58,22 +61,22 @@ public class JobManagerRunner implements LeaderContender, OnCompletionActions, F // 
------------------------------------------------------------------------ - /** Lock to ensure that this runner can deal with leader election event and job completion notifies simultaneously */ + /** Lock to ensure that this runner can deal with leader election event and job completion notifies simultaneously. */ private final Object lock = new Object(); - /** The job graph needs to run */ + /** The job graph needs to run. */ private final JobGraph jobGraph; - /** The listener to notify once the job completes - either successfully or unsuccessfully */ + /** The listener to notify once the job completes - either successfully or unsuccessfully. */ private final OnCompletionActions toNotifyOnComplete; - /** The handler to call in case of fatal (unrecoverable) errors */ + /** The handler to call in case of fatal (unrecoverable) errors. */ private final FatalErrorHandler errorHandler; - /** Used to check whether a job needs to be run */ + /** Used to check whether a job needs to be run. */ private final RunningJobsRegistry runningJobsRegistry; - /** Leader election for this job */ + /** Leader election for this job. */ private final LeaderElectionService leaderElectionService; private final JobManagerServices jobManagerServices; @@ -82,66 +85,20 @@ public class JobManagerRunner implements LeaderContender, OnCompletionActions, F private final JobManagerMetricGroup jobManagerMetricGroup; - /** flag marking the runner as shut down */ + private final Time timeout; + + /** flag marking the runner as shut down. */ private volatile boolean shutdown; // ------------------------------------------------------------------------ - public JobManagerRunner( - final ResourceID resourceId, - final JobGraph jobGraph, - final Configuration configuration, - final RpcService rpcService, - final HighAvailabilityServices haServices, - final BlobService blobService, - final HeartbeatServices heartbeatServices, - final OnCompletionActions toNotifyOnComplete, - final FatalErrorHandler errorHandler) throws Exception { - this( - resourceId, - jobGraph, - configuration, - rpcService, - haServices, - blobService, - heartbeatServices, - new MetricRegistry(MetricRegistryConfiguration.fromConfiguration(configuration)), - toNotifyOnComplete, - errorHandler); - } - - public JobManagerRunner( - final ResourceID resourceId, - final JobGraph jobGraph, - final Configuration configuration, - final RpcService rpcService, - final HighAvailabilityServices haServices, - final BlobService blobService, - final HeartbeatServices heartbeatServices, - final MetricRegistry metricRegistry, - final OnCompletionActions toNotifyOnComplete, - final FatalErrorHandler errorHandler) throws Exception { - this( - resourceId, - jobGraph, - configuration, - rpcService, - haServices, - heartbeatServices, - JobManagerServices.fromConfiguration(configuration, blobService), - metricRegistry, - toNotifyOnComplete, - errorHandler); - } - /** - * - *

Exceptions that occur while creating the JobManager or JobManagerRunner are directly + * Exceptions that occur while creating the JobManager or JobManagerRunner are directly * thrown and not reported to the given {@code FatalErrorHandler}. - * + * *

This JobManagerRunner assumes that it owns the given {@code JobManagerServices}. * It will shut them down on error and on calls to {@link #shutdown()}. - * + * * @throws Exception Thrown if the runner cannot be set up, because either one of the * required services could not be started, ot the Job could not be initialized. */ @@ -199,6 +156,7 @@ public JobManagerRunner( haServices, heartbeatServices, jobManagerServices.executorService, + jobManagerServices.blobServer, jobManagerServices.libraryCacheManager, jobManagerServices.restartStrategyFactory, jobManagerServices.rpcAskTimeout, @@ -206,15 +164,11 @@ public JobManagerRunner( this, this, userCodeLoader); + + this.timeout = jobManagerServices.rpcAskTimeout; } catch (Throwable t) { // clean up everything - try { - jobManagerServices.shutdown(); - } catch (Throwable tt) { - log.error("Error while shutting down JobManager services", tt); - } - if (jobManagerMetrics != null) { jobManagerMetrics.close(); } @@ -237,40 +191,37 @@ public void start() throws Exception { } } - public void shutdown() { - shutdownInternally(); + public void shutdown() throws Exception { + shutdownInternally().get(); } - private void shutdownInternally() { + private CompletableFuture shutdownInternally() { synchronized (lock) { shutdown = true; - if (leaderElectionService != null) { - try { - leaderElectionService.stop(); - } catch (Throwable t) { - log.error("Could not properly shutdown the leader election service", t); - } - } - - try { - jobManager.shutDown(); - } catch (Throwable t) { - log.error("Error shutting down JobManager", t); - } - - try { - jobManagerServices.shutdown(); - } catch (Throwable t) { - log.error("Error shutting down JobManager services", t); - } - - // make all registered metrics go away - try { - jobManagerMetricGroup.close(); - } catch (Throwable t) { - log.error("Error while unregistering metrics", t); - } + jobManager.shutDown(); + + return jobManager.getTerminationFuture() + .thenAccept( + ignored -> { + Throwable exception = null; + try { + leaderElectionService.stop(); + } catch (Throwable t) { + exception = ExceptionUtils.firstOrSuppressed(t, exception); + } + + // make all registered metrics go away + try { + jobManagerMetricGroup.close(); + } catch (Throwable t) { + exception = ExceptionUtils.firstOrSuppressed(t, exception); + } + + if (exception != null) { + throw new FlinkFutureException("Could not properly shut down the JobManagerRunner.", exception); + } + }); } } @@ -279,7 +230,7 @@ private void shutdownInternally() { //---------------------------------------------------------------------------------------------- /** - * Job completion notification triggered by JobManager + * Job completion notification triggered by JobManager. */ @Override public void jobFinished(JobExecutionResult result) { @@ -295,7 +246,7 @@ public void jobFinished(JobExecutionResult result) { } /** - * Job completion notification triggered by JobManager + * Job completion notification triggered by JobManager. */ @Override public void jobFailed(Throwable cause) { @@ -311,7 +262,7 @@ public void jobFailed(Throwable cause) { } /** - * Job completion notification triggered by self + * Job completion notification triggered by self. */ @Override public void jobFinishedByOther() { @@ -326,7 +277,7 @@ public void jobFinishedByOther() { } /** - * Job completion notification triggered by JobManager or self + * Job completion notification triggered by JobManager or self. 
*/ @Override public void onFatalError(Throwable exception) { @@ -353,7 +304,7 @@ public void onFatalError(Throwable exception) { /** * Marks this runner's job as not running. Other JobManager will not recover the job * after this call. - * + * *

This method never throws an exception. */ private void unregisterJobFromHighAvailability() { @@ -407,14 +358,22 @@ public void grantLeadership(final UUID leaderSessionID) { // This will eventually be noticed, but can not be ruled out from the beginning. if (leaderElectionService.hasLeadership()) { try { - // Now set the running status is after getting leader ship and + // Now set the running status is after getting leader ship and // set finished status after job in terminated status. // So if finding the job is running, it means someone has already run the job, need recover. if (schedulingStatus == JobSchedulingStatus.PENDING) { runningJobsRegistry.setJobRunning(jobGraph.getJobID()); } - jobManager.start(leaderSessionID); + CompletableFuture startingFuture = jobManager.start(new JobMasterId(leaderSessionID), timeout); + + startingFuture.whenCompleteAsync( + (Acknowledge ack, Throwable throwable) -> { + if (throwable != null) { + onFatalError(new Exception("Could not start the job manager.", throwable)); + } + }, + jobManagerServices.executorService); } catch (Exception e) { onFatalError(new Exception("Could not start the job manager.", e)); } @@ -433,7 +392,15 @@ public void revokeLeadership() { log.info("JobManager for job {} ({}) was revoked leadership at {}.", jobGraph.getName(), jobGraph.getJobID(), getAddress()); - jobManager.getSelfGateway(JobMasterGateway.class).suspendExecution(new Exception("JobManager is no longer the leader.")); + CompletableFuture suspendFuture = jobManager.suspend(new Exception("JobManager is no longer the leader."), timeout); + + suspendFuture.whenCompleteAsync( + (Acknowledge ack, Throwable throwable) -> { + if (throwable != null) { + onFatalError(new Exception("Could not start the job manager.", throwable)); + } + }, + jobManagerServices.executorService); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerServices.java index e14f5aff0a824..57aeaff5c7844 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerServices.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobManagerServices.java @@ -19,11 +19,10 @@ package org.apache.flink.runtime.jobmaster; import org.apache.flink.api.common.time.Time; -import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; import org.apache.flink.runtime.akka.AkkaUtils; -import org.apache.flink.runtime.blob.BlobService; +import org.apache.flink.runtime.blob.BlobServer; import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager; import org.apache.flink.runtime.executiongraph.restart.RestartStrategyFactory; import org.apache.flink.runtime.util.ExecutorThreadFactory; @@ -45,6 +44,7 @@ public class JobManagerServices { public final ScheduledExecutorService executorService; + public final BlobServer blobServer; public final BlobLibraryCacheManager libraryCacheManager; public final RestartStrategyFactory restartStrategyFactory; @@ -53,11 +53,13 @@ public class JobManagerServices { public JobManagerServices( ScheduledExecutorService executorService, + BlobServer blobServer, BlobLibraryCacheManager libraryCacheManager, RestartStrategyFactory restartStrategyFactory, Time rpcAskTimeout) { this.executorService = checkNotNull(executorService); + this.blobServer = checkNotNull(blobServer); this.libraryCacheManager = 
checkNotNull(libraryCacheManager); this.restartStrategyFactory = checkNotNull(restartStrategyFactory); this.rpcAskTimeout = checkNotNull(rpcAskTimeout); @@ -80,8 +82,9 @@ public void shutdown() throws Exception { firstException = t; } + libraryCacheManager.shutdown(); try { - libraryCacheManager.shutdown(); + blobServer.close(); } catch (Throwable t) { if (firstException == null) { @@ -103,16 +106,12 @@ public void shutdown() throws Exception { public static JobManagerServices fromConfiguration( Configuration config, - BlobService blobService) throws Exception { + BlobServer blobServer) throws Exception { Preconditions.checkNotNull(config); - Preconditions.checkNotNull(blobService); + Preconditions.checkNotNull(blobServer); - final long cleanupInterval = config.getLong( - ConfigConstants.LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL, - ConfigConstants.DEFAULT_LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL) * 1000; - - final BlobLibraryCacheManager libraryCacheManager = new BlobLibraryCacheManager(blobService, cleanupInterval); + final BlobLibraryCacheManager libraryCacheManager = new BlobLibraryCacheManager(blobServer); final FiniteDuration timeout; try { @@ -127,6 +126,7 @@ public static JobManagerServices fromConfiguration( return new JobManagerServices( futureExecutor, + blobServer, libraryCacheManager, RestartStrategyFactory.createRestartStrategyFactory(config), Time.of(timeout.length(), timeout.unit())); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java index 31036f6c8663d..80d6e4f496b41 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java @@ -28,10 +28,11 @@ import org.apache.flink.core.io.InputSplitAssigner; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; +import org.apache.flink.runtime.blob.BlobServer; import org.apache.flink.runtime.checkpoint.CheckpointCoordinator; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.client.JobExecutionException; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.clusterframework.types.ResourceID; @@ -52,7 +53,6 @@ import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.heartbeat.HeartbeatTarget; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; -import org.apache.flink.runtime.highavailability.LeaderIdMismatchException; import org.apache.flink.runtime.instance.Slot; import org.apache.flink.runtime.instance.SlotPool; import org.apache.flink.runtime.instance.SlotPoolGateway; @@ -80,8 +80,9 @@ import org.apache.flink.runtime.registration.RegistrationResponse; import org.apache.flink.runtime.registration.RetryingRegistration; import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.rpc.FatalErrorHandler; -import org.apache.flink.runtime.rpc.RpcEndpoint; +import org.apache.flink.runtime.rpc.FencedRpcEndpoint; import org.apache.flink.runtime.rpc.RpcService; import 
org.apache.flink.runtime.rpc.akka.AkkaRpcServiceUtils; import org.apache.flink.runtime.state.KeyGroupRange; @@ -89,13 +90,15 @@ import org.apache.flink.runtime.taskexecutor.slot.SlotOffer; import org.apache.flink.runtime.taskmanager.TaskExecutionState; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; -import org.apache.flink.util.SerializedThrowable; +import org.apache.flink.util.FlinkException; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; +import org.apache.flink.util.SerializedThrowable; import org.slf4j.Logger; import javax.annotation.Nullable; + import java.io.IOException; import java.util.ArrayList; import java.util.Collection; @@ -107,7 +110,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -122,15 +124,12 @@ * given task * */ -public class JobMaster extends RpcEndpoint implements JobMasterGateway { +public class JobMaster extends FencedRpcEndpoint implements JobMasterGateway { /** Default names for Flink's distributed components */ public static final String JOB_MANAGER_NAME = "jobmanager"; public static final String ARCHIVE_NAME = "archive"; - private static final AtomicReferenceFieldUpdater LEADER_ID_UPDATER = - AtomicReferenceFieldUpdater.newUpdater(JobMaster.class, UUID.class, "leaderSessionID"); - // ------------------------------------------------------------------------ private final JobMasterGateway selfGateway; @@ -148,7 +147,10 @@ public class JobMaster extends RpcEndpoint implements JobMasterGateway { /** Service to contend for and retrieve the leadership of JM and RM */ private final HighAvailabilityServices highAvailabilityServices; - /** Blob cache manager used across jobs */ + /** Blob server used across jobs */ + private final BlobServer blobServer; + + /** Blob library cache manager used across jobs */ private final BlobLibraryCacheManager libraryCacheManager; /** The metrics for the JobManager itself */ @@ -179,8 +181,6 @@ public class JobMaster extends RpcEndpoint implements JobMasterGateway { private final SlotPoolGateway slotPoolGateway; - private volatile UUID leaderSessionID; - // --------- ResourceManager -------- /** Leader retriever service used to locate ResourceManager's address */ @@ -203,6 +203,7 @@ public JobMaster( HighAvailabilityServices highAvailabilityService, HeartbeatServices heartbeatServices, ScheduledExecutorService executor, + BlobServer blobServer, BlobLibraryCacheManager libraryCacheManager, RestartStrategyFactory restartStrategyFactory, Time rpcAskTimeout, @@ -211,7 +212,7 @@ public JobMaster( FatalErrorHandler errorHandler, ClassLoader userCodeLoader) throws Exception { - super(rpcService, AkkaRpcServiceUtils.createRandomName(JobMaster.JOB_MANAGER_NAME)); + super(rpcService, AkkaRpcServiceUtils.createRandomName(JobMaster.JOB_MANAGER_NAME), JobMasterId.INITIAL_JOB_MASTER_ID); selfGateway = getSelfGateway(JobMasterGateway.class); @@ -220,6 +221,7 @@ public JobMaster( this.configuration = checkNotNull(configuration); this.rpcTimeout = rpcAskTimeout; this.highAvailabilityServices = checkNotNull(highAvailabilityService); + this.blobServer = checkNotNull(blobServer); this.libraryCacheManager = checkNotNull(libraryCacheManager); this.executor = checkNotNull(executor); this.jobCompletionActions = checkNotNull(jobCompletionActions); @@ -303,19 +305,36 @@ public 
void start() { /** * Start the rpc service and begin to run the job. * - * @param leaderSessionID The necessary leader id for running the job. + * @param newJobMasterId The necessary fencing token to run the job + * @param timeout for the operation + * @return Future acknowledge if the job could be started. Otherwise the future contains an exception */ - public void start(final UUID leaderSessionID) throws Exception { - if (LEADER_ID_UPDATER.compareAndSet(this, null, leaderSessionID)) { - // make sure we receive RPC and async calls - super.start(); + public CompletableFuture start(final JobMasterId newJobMasterId, final Time timeout) throws Exception { + // make sure we receive RPC and async calls + super.start(); - log.info("JobManager started as leader {} for job {}", leaderSessionID, jobGraph.getJobID()); - selfGateway.startJobExecution(); - } - else { - log.warn("Job already started with leader ID {}, ignoring this start request.", leaderSessionID); - } + return callAsyncWithoutFencing(() -> startJobExecution(newJobMasterId), timeout); + } + + /** + * Suspending job, all the running tasks will be cancelled, and communication with other components + * will be disposed. + * + *

Typically the job is suspended because its leadership has been revoked; one can restart this job by + * calling the {@link #start(JobMasterId, Time)} method once the leadership is regained. + * + *

This method is executed asynchronously + * + * @param cause The reason of why this job been suspended. + * @param timeout for this operation + * @return Future acknowledge indicating that the job has been suspended. Otherwise the future contains an exception + */ + public CompletableFuture suspend(final Throwable cause, final Time timeout) { + CompletableFuture suspendFuture = callAsyncWithoutFencing(() -> suspendExecution(cause), timeout); + + stop(); + + return suspendFuture; } /** @@ -336,98 +355,6 @@ public void postStop() throws Exception { // RPC methods //---------------------------------------------------------------------------------------------- - //-- job starting and stopping ----------------------------------------------------------------- - - @Override - public void startJobExecution() { - // double check that the leader status did not change - if (leaderSessionID == null) { - log.info("Aborting job startup - JobManager lost leader status"); - return; - } - - log.info("Starting execution of job {} ({})", jobGraph.getName(), jobGraph.getJobID()); - - try { - // start the slot pool make sure the slot pool now accepts messages for this leader - log.debug("Staring SlotPool component"); - slotPool.start(leaderSessionID, getAddress()); - } catch (Exception e) { - log.error("Faild to start job {} ({})", jobGraph.getName(), jobGraph.getJobID(), e); - - handleFatalError(new Exception("Could not start job execution: Failed to start the slot pool.", e)); - } - - try { - // job is ready to go, try to establish connection with resource manager - // - activate leader retrieval for the resource manager - // - on notification of the leader, the connection will be established and - // the slot pool will start requesting slots - resourceManagerLeaderRetriever.start(new ResourceManagerLeaderListener()); - } - catch (Throwable t) { - log.error("Failed to start job {} ({})", jobGraph.getName(), jobGraph.getJobID(), t); - - handleFatalError(new Exception( - "Could not start job execution: Failed to start leader service for Resource Manager", t)); - - return; - } - - // start scheduling job in another thread - executor.execute(new Runnable() { - @Override - public void run() { - try { - executionGraph.scheduleForExecution(); - } - catch (Throwable t) { - executionGraph.failGlobal(t); - } - } - }); - } - - /** - * Suspending job, all the running tasks will be cancelled, and communication with other components - * will be disposed. - * - *

Mostly job is suspended because of the leadership has been revoked, one can be restart this job by - * calling the {@link #start(UUID)} method once we take the leadership back again. - * - * @param cause The reason of why this job been suspended. - */ - @Override - public void suspendExecution(final Throwable cause) { - if (leaderSessionID == null) { - log.debug("Job has already been suspended or shutdown."); - return; - } - - // not leader any more - should not accept any leader messages any more - leaderSessionID = null; - - try { - resourceManagerLeaderRetriever.stop(); - } catch (Throwable t) { - log.warn("Failed to stop resource manager leader retriever when suspending.", t); - } - - // tell the execution graph (JobManager is still processing messages here) - executionGraph.suspend(cause); - - // receive no more messages until started again, should be called before we clear self leader id - stop(); - - // the slot pool stops receiving messages and clears its pooled slots - slotPoolGateway.suspend(); - - // disconnect from resource manager: - closeResourceManagerConnection(new Exception("Execution was suspended.", cause)); - } - - //---------------------------------------------------------------------------------------------- - /** * Updates the task execution state for a given task. * @@ -436,17 +363,10 @@ public void suspendExecution(final Throwable cause) { */ @Override public CompletableFuture updateTaskExecutionState( - final UUID leaderSessionID, final TaskExecutionState taskExecutionState) { checkNotNull(taskExecutionState, "taskExecutionState"); - try { - validateLeaderSessionId(leaderSessionID); - } catch (LeaderIdMismatchException e) { - return FutureUtils.completedExceptionally(e); - } - if (executionGraph.updateState(taskExecutionState)) { return CompletableFuture.completedFuture(Acknowledge.get()); } else { @@ -458,16 +378,9 @@ public CompletableFuture updateTaskExecutionState( @Override public CompletableFuture requestNextInputSplit( - final UUID leaderSessionID, final JobVertexID vertexID, final ExecutionAttemptID executionAttempt) { - try { - validateLeaderSessionId(leaderSessionID); - } catch (LeaderIdMismatchException e) { - return FutureUtils.completedExceptionally(e); - } - final Execution execution = executionGraph.getRegisteredExecutions().get(executionAttempt); if (execution == null) { // can happen when JobManager had already unregistered this execution upon on task failure, @@ -514,16 +427,9 @@ public CompletableFuture requestNextInputSplit( @Override public CompletableFuture requestPartitionState( - final UUID leaderSessionID, final IntermediateDataSetID intermediateResultId, final ResultPartitionID resultPartitionId) { - try { - validateLeaderSessionId(leaderSessionID); - } catch (LeaderIdMismatchException e) { - return FutureUtils.completedExceptionally(e); - } - final Execution execution = executionGraph.getRegisteredExecutions().get(resultPartitionId.getProducerId()); if (execution != null) { return CompletableFuture.completedFuture(execution.getState()); @@ -553,12 +459,9 @@ public CompletableFuture requestPartitionState( @Override public CompletableFuture scheduleOrUpdateConsumers( - final UUID leaderSessionID, final ResultPartitionID partitionID, final Time timeout) { try { - validateLeaderSessionId(leaderSessionID); - executionGraph.scheduleOrUpdateConsumers(partitionID); return CompletableFuture.completedFuture(Acknowledge.get()); } catch (Exception e) { @@ -586,7 +489,7 @@ public void acknowledgeCheckpoint( final ExecutionAttemptID 
executionAttemptID, final long checkpointId, final CheckpointMetrics checkpointMetrics, - final SubtaskState checkpointState) { + final TaskStateSnapshot checkpointState) { final CheckpointCoordinator checkpointCoordinator = executionGraph.getCheckpointCoordinator(); final AcknowledgeCheckpoint ackMessage = @@ -697,7 +600,7 @@ public void notifyKvStateUnregistered( @Override public CompletableFuture requestClassloadingProps() { return CompletableFuture.completedFuture( - new ClassloadingProps(libraryCacheManager.getBlobServerPort(), + new ClassloadingProps(blobServer.getPort(), executionGraph.getRequiredJarFiles(), executionGraph.getRequiredClasspaths())); } @@ -706,15 +609,8 @@ public CompletableFuture requestClassloadingProps() { public CompletableFuture> offerSlots( final ResourceID taskManagerId, final Iterable slots, - final UUID leaderId, final Time timeout) { - try { - validateLeaderSessionId(leaderId); - } catch (LeaderIdMismatchException e) { - return FutureUtils.completedExceptionally(e); - } - Tuple2 taskManager = registeredTaskManagers.get(taskManagerId); if (taskManager == null) { @@ -727,7 +623,7 @@ public CompletableFuture> offerSlots( final ArrayList> slotsAndOffers = new ArrayList<>(); - final RpcTaskManagerGateway rpcTaskManagerGateway = new RpcTaskManagerGateway(taskExecutorGateway, leaderId); + final RpcTaskManagerGateway rpcTaskManagerGateway = new RpcTaskManagerGateway(taskExecutorGateway, getFencingToken()); for (SlotOffer slotOffer : slots) { final AllocatedSlot slot = new AllocatedSlot( @@ -748,15 +644,8 @@ public CompletableFuture> offerSlots( public void failSlot( final ResourceID taskManagerId, final AllocationID allocationId, - final UUID leaderId, final Exception cause) { - try { - validateLeaderSessionId(leaderSessionID); - } catch (LeaderIdMismatchException e) { - log.warn("Cannot fail slot " + allocationId + '.', e); - } - if (registeredTaskManagers.containsKey(taskManagerId)) { slotPoolGateway.failAllocation(allocationId, cause); } else { @@ -769,22 +658,13 @@ public void failSlot( public CompletableFuture registerTaskManager( final String taskManagerRpcAddress, final TaskManagerLocation taskManagerLocation, - final UUID leaderId, final Time timeout) { - if (!Objects.equals(leaderSessionID, leaderId)) { - log.warn("Discard registration from TaskExecutor {} at ({}) because the expected " + - "leader session ID {} did not equal the received leader session ID {}.", - taskManagerLocation.getResourceID(), taskManagerRpcAddress, leaderSessionID, leaderId); - return FutureUtils.completedExceptionally( - new Exception("Leader id not match, expected: " + - leaderSessionID + ", actual: " + leaderId)); - } final ResourceID taskManagerId = taskManagerLocation.getResourceID(); if (registeredTaskManagers.containsKey(taskManagerId)) { final RegistrationResponse response = new JMTMRegistrationSuccess( - resourceId, libraryCacheManager.getBlobServerPort()); + resourceId, blobServer.getPort()); return CompletableFuture.completedFuture(response); } else { return getRpcService() @@ -795,13 +675,6 @@ public CompletableFuture registerTaskManager( return new RegistrationResponse.Decline(throwable.getMessage()); } - if (!Objects.equals(leaderSessionID, leaderId)) { - log.warn("Discard registration from TaskExecutor {} at ({}) because the expected " + - "leader session ID {} did not equal the received leader session ID {}.", - taskManagerId, taskManagerRpcAddress, leaderSessionID, leaderId); - return new RegistrationResponse.Decline("Invalid leader session id"); - } - 
slotPoolGateway.registerTaskManager(taskManagerId); registeredTaskManagers.put(taskManagerId, Tuple2.of(taskManagerLocation, taskExecutorGateway)); @@ -818,7 +691,7 @@ public void requestHeartbeat(ResourceID resourceID, Void payload) { } }); - return new JMTMRegistrationSuccess(resourceId, libraryCacheManager.getBlobServerPort()); + return new JMTMRegistrationSuccess(resourceId, blobServer.getPort()); }, getMainThreadExecutor()); } @@ -826,18 +699,11 @@ public void requestHeartbeat(ResourceID resourceID, Void payload) { @Override public void disconnectResourceManager( - final UUID jobManagerLeaderId, - final UUID resourceManagerLeaderId, + final ResourceManagerId resourceManagerId, final Exception cause) { - try { - validateLeaderSessionId(jobManagerLeaderId); - } catch (LeaderIdMismatchException e) { - log.warn("Cannot disconnect resource manager " + resourceManagerLeaderId + '.', e); - } - if (resourceManagerConnection != null - && resourceManagerConnection.getTargetLeaderId().equals(resourceManagerLeaderId)) { + && resourceManagerConnection.getTargetLeaderId().equals(resourceManagerId)) { closeResourceManagerConnection(cause); } } @@ -856,21 +722,111 @@ public void heartbeatFromResourceManager(final ResourceID resourceID) { // Internal methods //---------------------------------------------------------------------------------------------- - private void handleFatalError(final Throwable cause) { - runAsync(new Runnable() { - @Override - public void run() { - log.error("Fatal error occurred on JobManager, cause: {}", cause.getMessage(), cause); + //-- job starting and stopping ----------------------------------------------------------------- + + private Acknowledge startJobExecution(JobMasterId newJobMasterId) throws Exception { + validateRunsInMainThread(); + + Preconditions.checkNotNull(newJobMasterId, "The new JobMasterId must not be null."); + + if (Objects.equals(getFencingToken(), newJobMasterId)) { + log.info("Already started the job execution with JobMasterId {}.", newJobMasterId); + + return Acknowledge.get(); + } + + if (!Objects.equals(getFencingToken(), JobMasterId.INITIAL_JOB_MASTER_ID)) { + log.info("Restarting old job with JobMasterId {}. 
The new JobMasterId is {}.", getFencingToken(), newJobMasterId); + + // first we have to suspend the current execution + suspendExecution(new FlinkException("Old job with JobMasterId " + getFencingToken() + + " is restarted with a new JobMasterId " + newJobMasterId + '.')); + } + // set new leader id + setFencingToken(newJobMasterId); + + log.info("Starting execution of job {} ({})", jobGraph.getName(), jobGraph.getJobID()); + + try { + // start the slot pool make sure the slot pool now accepts messages for this leader + log.debug("Staring SlotPool component"); + slotPool.start(getFencingToken(), getAddress()); + + // job is ready to go, try to establish connection with resource manager + // - activate leader retrieval for the resource manager + // - on notification of the leader, the connection will be established and + // the slot pool will start requesting slots + resourceManagerLeaderRetriever.start(new ResourceManagerLeaderListener()); + } + catch (Throwable t) { + log.error("Failed to start job {} ({})", jobGraph.getName(), jobGraph.getJobID(), t); + + throw new Exception("Could not start job execution: Failed to start JobMaster services.", t); + } + + // start scheduling job in another thread + executor.execute( + () -> { try { - shutDown(); - } catch (Exception e) { - cause.addSuppressed(e); + executionGraph.scheduleForExecution(); } + catch (Throwable t) { + executionGraph.failGlobal(t); + } + }); - errorHandler.onFatalError(cause); - } - }); + return Acknowledge.get(); + } + + /** + * Suspending job, all the running tasks will be cancelled, and communication with other components + * will be disposed. + * + *

Mostly job is suspended because of the leadership has been revoked, one can be restart this job by + * calling the {@link #start(JobMasterId, Time)} method once we take the leadership back again. + * + * @param cause The reason of why this job been suspended. + */ + private Acknowledge suspendExecution(final Throwable cause) { + validateRunsInMainThread(); + + if (getFencingToken() == null) { + log.debug("Job has already been suspended or shutdown."); + return Acknowledge.get(); + } + + // not leader anymore --> set the JobMasterId to the initial id + setFencingToken(JobMasterId.INITIAL_JOB_MASTER_ID); + + try { + resourceManagerLeaderRetriever.stop(); + } catch (Throwable t) { + log.warn("Failed to stop resource manager leader retriever when suspending.", t); + } + + // tell the execution graph (JobManager is still processing messages here) + executionGraph.suspend(cause); + + // the slot pool stops receiving messages and clears its pooled slots + slotPoolGateway.suspend(); + + // disconnect from resource manager: + closeResourceManagerConnection(new Exception("Execution was suspended.", cause)); + + return Acknowledge.get(); + } + + //---------------------------------------------------------------------------------------------- + + private void handleFatalError(final Throwable cause) { + + try { + log.error("Fatal error occurred on JobManager.", cause); + } catch (Throwable ignore) {} + + // The fatal error handler implementation should make sure that this call is non-blocking + errorHandler.onFatalError(cause); } private void jobStatusChanged(final JobStatus newJobStatus, long timestamp, final Throwable error) { @@ -890,7 +846,7 @@ private void jobStatusChanged(final JobStatus newJobStatus, long timestamp, fina Map accumulatorResults = executionGraph.getAccumulators(); JobExecutionResult result = new JobExecutionResult(jobID, 0L, accumulatorResults); - jobCompletionActions.jobFinished(result); + executor.execute(() -> jobCompletionActions.jobFinished(result)); } catch (Exception e) { log.error("Cannot fetch final accumulators for job {} ({})", jobName, jobID, e); @@ -900,7 +856,7 @@ private void jobStatusChanged(final JobStatus newJobStatus, long timestamp, fina "The job is registered as 'FINISHED (successful), but this notification describes " + "a failure, since the resulting accumulators could not be fetched.", e); - jobCompletionActions.jobFailed(exception); + executor.execute(() ->jobCompletionActions.jobFailed(exception)); } break; @@ -908,7 +864,7 @@ private void jobStatusChanged(final JobStatus newJobStatus, long timestamp, fina final JobExecutionException exception = new JobExecutionException( jobID, "Job was cancelled.", new Exception("The job was cancelled")); - jobCompletionActions.jobFailed(exception); + executor.execute(() -> jobCompletionActions.jobFailed(exception)); break; } @@ -916,7 +872,7 @@ private void jobStatusChanged(final JobStatus newJobStatus, long timestamp, fina final Throwable unpackedError = SerializedThrowable.get(error, userCodeLoader); final JobExecutionException exception = new JobExecutionException( jobID, "Job execution failed.", unpackedError); - jobCompletionActions.jobFailed(exception); + executor.execute(() -> jobCompletionActions.jobFailed(exception)); break; } @@ -927,11 +883,11 @@ private void jobStatusChanged(final JobStatus newJobStatus, long timestamp, fina } } - private void notifyOfNewResourceManagerLeader(final String resourceManagerAddress, final UUID resourceManagerLeaderId) { + private void notifyOfNewResourceManagerLeader(final 
String resourceManagerAddress, final ResourceManagerId resourceManagerId) { if (resourceManagerConnection != null) { if (resourceManagerAddress != null) { - if (resourceManagerAddress.equals(resourceManagerConnection.getTargetAddress()) - && resourceManagerLeaderId.equals(resourceManagerConnection.getTargetLeaderId())) { + if (Objects.equals(resourceManagerAddress, resourceManagerConnection.getTargetAddress()) + && Objects.equals(resourceManagerId, resourceManagerConnection.getTargetLeaderId())) { // both address and leader id are not changed, we can keep the old connection return; } @@ -955,9 +911,9 @@ private void notifyOfNewResourceManagerLeader(final String resourceManagerAddres jobGraph.getJobID(), resourceId, getAddress(), - leaderSessionID, + getFencingToken(), resourceManagerAddress, - resourceManagerLeaderId, + resourceManagerId, executor); resourceManagerConnection.start(); @@ -965,17 +921,17 @@ private void notifyOfNewResourceManagerLeader(final String resourceManagerAddres } private void establishResourceManagerConnection(final JobMasterRegistrationSuccess success) { - final UUID resourceManagerLeaderId = success.getResourceManagerLeaderId(); + final ResourceManagerId resourceManagerId = success.getResourceManagerId(); // verify the response with current connection if (resourceManagerConnection != null - && resourceManagerConnection.getTargetLeaderId().equals(resourceManagerLeaderId)) { + && Objects.equals(resourceManagerConnection.getTargetLeaderId(), resourceManagerId)) { - log.info("JobManager successfully registered at ResourceManager, leader id: {}.", resourceManagerLeaderId); + log.info("JobManager successfully registered at ResourceManager, leader id: {}.", resourceManagerId); final ResourceManagerGateway resourceManagerGateway = resourceManagerConnection.getTargetGateway(); - slotPoolGateway.connectToResourceManager(resourceManagerLeaderId, resourceManagerGateway); + slotPoolGateway.connectToResourceManager(resourceManagerGateway); resourceManagerHeartbeatManager.monitorTarget(success.getResourceManagerResourceId(), new HeartbeatTarget() { @Override @@ -988,6 +944,9 @@ public void requestHeartbeat(ResourceID resourceID, Void payload) { // request heartbeat will never be called on the job manager side } }); + } else { + log.debug("Ignoring resource manager connection to {} because its a duplicate or outdated.", resourceManagerId); + } } @@ -1007,12 +966,6 @@ private void closeResourceManagerConnection(Exception cause) { slotPoolGateway.disconnectResourceManager(); } - private void validateLeaderSessionId(UUID leaderSessionID) throws LeaderIdMismatchException { - if (this.leaderSessionID == null || !this.leaderSessionID.equals(leaderSessionID)) { - throw new LeaderIdMismatchException(this.leaderSessionID, leaderSessionID); - } - } - //---------------------------------------------------------------------------------------------- // Utility classes //---------------------------------------------------------------------------------------------- @@ -1021,12 +974,10 @@ private class ResourceManagerLeaderListener implements LeaderRetrievalListener { @Override public void notifyLeaderAddress(final String leaderAddress, final UUID leaderSessionID) { - runAsync(new Runnable() { - @Override - public void run() { - notifyOfNewResourceManagerLeader(leaderAddress, leaderSessionID); - } - }); + runAsync( + () -> notifyOfNewResourceManagerLeader( + leaderAddress, + leaderSessionID != null ? 
new ResourceManagerId(leaderSessionID) : null)); } @Override @@ -1038,7 +989,7 @@ public void handleError(final Exception exception) { //---------------------------------------------------------------------------------------------- private class ResourceManagerConnection - extends RegisteredRpcConnection + extends RegisteredRpcConnection { private final JobID jobID; @@ -1046,7 +997,7 @@ private class ResourceManagerConnection private final String jobManagerRpcAddress; - private final UUID jobManagerLeaderID; + private final JobMasterId jobMasterId; private ResourceID resourceManagerResourceID; @@ -1055,33 +1006,32 @@ private class ResourceManagerConnection final JobID jobID, final ResourceID jobManagerResourceID, final String jobManagerRpcAddress, - final UUID jobManagerLeaderID, + final JobMasterId jobMasterId, final String resourceManagerAddress, - final UUID resourceManagerLeaderID, + final ResourceManagerId resourceManagerId, final Executor executor) { - super(log, resourceManagerAddress, resourceManagerLeaderID, executor); + super(log, resourceManagerAddress, resourceManagerId, executor); this.jobID = checkNotNull(jobID); this.jobManagerResourceID = checkNotNull(jobManagerResourceID); this.jobManagerRpcAddress = checkNotNull(jobManagerRpcAddress); - this.jobManagerLeaderID = checkNotNull(jobManagerLeaderID); + this.jobMasterId = checkNotNull(jobMasterId); } @Override - protected RetryingRegistration generateRegistration() { - return new RetryingRegistration( + protected RetryingRegistration generateRegistration() { + return new RetryingRegistration( log, getRpcService(), "ResourceManager", ResourceManagerGateway.class, getTargetAddress(), getTargetLeaderId()) { @Override protected CompletableFuture invokeRegistration( - ResourceManagerGateway gateway, UUID leaderId, long timeoutMillis) throws Exception + ResourceManagerGateway gateway, ResourceManagerId fencingToken, long timeoutMillis) throws Exception { Time timeout = Time.milliseconds(timeoutMillis); return gateway.registerJobManager( - leaderId, - jobManagerLeaderID, + jobMasterId, jobManagerResourceID, jobManagerRpcAddress, jobID, @@ -1127,12 +1077,7 @@ public void jobStatusChanges( final Throwable error) { // run in rpc thread to avoid concurrency - runAsync(new Runnable() { - @Override - public void run() { - jobStatusChanged(newJobStatus, timestamp, error); - } - }); + runAsync(() -> jobStatusChanged(newJobStatus, timestamp, error)); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterGateway.java index b396cd68df0b1..965d88d272e3c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterGateway.java @@ -36,6 +36,8 @@ import org.apache.flink.runtime.query.KvStateLocation; import org.apache.flink.runtime.query.KvStateServerAddress; import org.apache.flink.runtime.registration.RegistrationResponse; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; +import org.apache.flink.runtime.rpc.FencedRpcGateway; import org.apache.flink.runtime.rpc.RpcTimeout; import org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.runtime.state.KeyGroupRange; @@ -44,46 +46,31 @@ import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import java.util.Collection; -import java.util.UUID; import java.util.concurrent.CompletableFuture; /** * {@link JobMaster} rpc 
gateway interface */ -public interface JobMasterGateway extends CheckpointCoordinatorGateway { - - // ------------------------------------------------------------------------ - // Job start and stop methods - // ------------------------------------------------------------------------ - - void startJobExecution(); - - void suspendExecution(Throwable cause); - - // ------------------------------------------------------------------------ +public interface JobMasterGateway extends CheckpointCoordinatorGateway, FencedRpcGateway { /** * Updates the task execution state for a given task. * - * @param leaderSessionID The leader id of JobManager * @param taskExecutionState New task execution state for a given task * @return Future flag of the task execution state update result */ CompletableFuture updateTaskExecutionState( - final UUID leaderSessionID, final TaskExecutionState taskExecutionState); /** * Requesting next input split for the {@link ExecutionJobVertex}. The next input split is sent back to the sender * as a {@link SerializedInputSplit} message. * - * @param leaderSessionID The leader id of JobManager * @param vertexID The job vertex id * @param executionAttempt The execution attempt id * @return The future of the input split. If there is no further input split, will return an empty object. */ CompletableFuture requestNextInputSplit( - final UUID leaderSessionID, final JobVertexID vertexID, final ExecutionAttemptID executionAttempt); @@ -91,13 +78,11 @@ CompletableFuture requestNextInputSplit( * Requests the current state of the partition. * The state of a partition is currently bound to the state of the producing execution. * - * @param leaderSessionID The leader id of JobManager * @param intermediateResultId The execution attempt ID of the task requesting the partition state. * @param partitionId The partition ID of the partition to request the state of. * @return The future of the partition state */ CompletableFuture requestPartitionState( - final UUID leaderSessionID, final IntermediateDataSetID intermediateResultId, final ResultPartitionID partitionId); @@ -110,13 +95,11 @@ CompletableFuture requestPartitionState( *

* The JobManager then can decide when to schedule the partition consumers of the given session. * - * @param leaderSessionID The leader id of JobManager * @param partitionID The partition which has already produced data * @param timeout before the rpc call fails * @return Future acknowledge of the schedule or update operation */ CompletableFuture scheduleOrUpdateConsumers( - final UUID leaderSessionID, final ResultPartitionID partitionID, @RpcTimeout final Time timeout); @@ -132,13 +115,11 @@ CompletableFuture scheduleOrUpdateConsumers( /** * Disconnects the resource manager from the job manager because of the given cause. * - * @param jobManagerLeaderId identifying the job manager leader id - * @param resourceManagerLeaderId identifying the resource manager leader id + * @param resourceManagerId identifying the resource manager leader id * @param cause of the disconnect */ void disconnectResourceManager( - final UUID jobManagerLeaderId, - final UUID resourceManagerLeaderId, + final ResourceManagerId resourceManagerId, final Exception cause); /** @@ -183,14 +164,12 @@ void notifyKvStateUnregistered( * * @param taskManagerId identifying the task manager * @param slots to offer to the job manager - * @param leaderId identifying the job leader * @param timeout for the rpc call * @return Future set of accepted slots. */ CompletableFuture> offerSlots( final ResourceID taskManagerId, final Iterable slots, - final UUID leaderId, @RpcTimeout final Time timeout); /** @@ -198,12 +177,10 @@ CompletableFuture> offerSlots( * * @param taskManagerId identifying the task manager * @param allocationId identifying the slot to fail - * @param leaderId identifying the job leader * @param cause of the failing */ void failSlot(final ResourceID taskManagerId, final AllocationID allocationId, - final UUID leaderId, final Exception cause); /** @@ -211,14 +188,12 @@ void failSlot(final ResourceID taskManagerId, * * @param taskManagerRpcAddress the rpc address of the task manager * @param taskManagerLocation location of the task manager - * @param leaderId identifying the job leader * @param timeout for the rpc call * @return Future registration response indicating whether the registration was successful or not */ CompletableFuture registerTaskManager( final String taskManagerRpcAddress, final TaskManagerLocation taskManagerLocation, - final UUID leaderId, @RpcTimeout final Time timeout); /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterId.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterId.java new file mode 100644 index 0000000000000..ffd53b31b6864 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterId.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.jobmaster; + +import org.apache.flink.util.AbstractID; + +import java.util.UUID; + +/** + * The {@link JobMaster} fencing token. + */ +public class JobMasterId extends AbstractID { + + private static final long serialVersionUID = -933276753644003754L; + + public static final JobMasterId INITIAL_JOB_MASTER_ID = new JobMasterId(0L, 0L); + + public JobMasterId(byte[] bytes) { + super(bytes); + } + + public JobMasterId(long lowerPart, long upperPart) { + super(lowerPart, upperPart); + } + + public JobMasterId(AbstractID id) { + super(id); + } + + public JobMasterId() { + } + + public JobMasterId(UUID uuid) { + this(uuid.getLeastSignificantBits(), uuid.getMostSignificantBits()); + } + + public UUID toUUID() { + return new UUID(getUpperPart(), getLowerPart()); + } + + public static JobMasterId generate() { + return new JobMasterId(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterRegistrationSuccess.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterRegistrationSuccess.java index a7a622465eccb..94ecfd2e92574 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterRegistrationSuccess.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMasterRegistrationSuccess.java @@ -20,8 +20,7 @@ import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.registration.RegistrationResponse; - -import java.util.UUID; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -34,16 +33,16 @@ public class JobMasterRegistrationSuccess extends RegistrationResponse.Success { private final long heartbeatInterval; - private final UUID resourceManagerLeaderId; + private final ResourceManagerId resourceManagerId; private final ResourceID resourceManagerResourceId; public JobMasterRegistrationSuccess( final long heartbeatInterval, - final UUID resourceManagerLeaderId, + final ResourceManagerId resourceManagerId, final ResourceID resourceManagerResourceId) { this.heartbeatInterval = heartbeatInterval; - this.resourceManagerLeaderId = checkNotNull(resourceManagerLeaderId); + this.resourceManagerId = checkNotNull(resourceManagerId); this.resourceManagerResourceId = checkNotNull(resourceManagerResourceId); } @@ -56,8 +55,8 @@ public long getHeartbeatInterval() { return heartbeatInterval; } - public UUID getResourceManagerLeaderId() { - return resourceManagerLeaderId; + public ResourceManagerId getResourceManagerId() { + return resourceManagerId; } public ResourceID getResourceManagerResourceId() { @@ -68,7 +67,7 @@ public ResourceID getResourceManagerResourceId() { public String toString() { return "JobMasterRegistrationSuccess{" + "heartbeatInterval=" + heartbeatInterval + - ", resourceManagerLeaderId=" + resourceManagerLeaderId + + ", resourceManagerLeaderId=" + resourceManagerId + ", resourceManagerResourceId=" + resourceManagerResourceId + '}'; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/RpcTaskManagerGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/RpcTaskManagerGateway.java index e93c907d36b6b..8967aae9a2f8b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/RpcTaskManagerGateway.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/RpcTaskManagerGateway.java @@ -33,7 +33,6 @@ import org.apache.flink.runtime.taskexecutor.TaskExecutorGateway; import org.apache.flink.util.Preconditions; -import java.util.UUID; import java.util.concurrent.CompletableFuture; /** @@ -43,11 +42,11 @@ public class RpcTaskManagerGateway implements TaskManagerGateway { private final TaskExecutorGateway taskExecutorGateway; - private final UUID leaderId; + private final JobMasterId jobMasterId; - public RpcTaskManagerGateway(TaskExecutorGateway taskExecutorGateway, UUID leaderId) { + public RpcTaskManagerGateway(TaskExecutorGateway taskExecutorGateway, JobMasterId jobMasterId) { this.taskExecutorGateway = Preconditions.checkNotNull(taskExecutorGateway); - this.leaderId = Preconditions.checkNotNull(leaderId); + this.jobMasterId = Preconditions.checkNotNull(jobMasterId); } @Override @@ -87,7 +86,7 @@ public CompletableFuture requestStackTraceSample( @Override public CompletableFuture submitTask(TaskDeploymentDescriptor tdd, Time timeout) { - return taskExecutorGateway.submitTask(tdd, leaderId, timeout); + return taskExecutorGateway.submitTask(tdd, jobMasterId, timeout); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java index 9721c2cd6f306..65e3019951fcc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java @@ -21,7 +21,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; /** @@ -36,7 +36,7 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements private static final long serialVersionUID = -7606214777192401493L; - private final SubtaskState subtaskState; + private final TaskStateSnapshot subtaskState; private final CheckpointMetrics checkpointMetrics; @@ -47,7 +47,7 @@ public AcknowledgeCheckpoint( ExecutionAttemptID taskExecutionId, long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState subtaskState) { + TaskStateSnapshot subtaskState) { super(job, taskExecutionId, checkpointId); @@ -64,7 +64,7 @@ public AcknowledgeCheckpoint(JobID jobId, ExecutionAttemptID taskExecutionId, lo // properties // ------------------------------------------------------------------------ - public SubtaskState getSubtaskState() { + public TaskStateSnapshot getSubtaskState() { return subtaskState; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniCluster.java b/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniCluster.java index 2e36e9e8428d7..2fe0587cd4bcc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniCluster.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniCluster.java @@ -38,6 +38,7 @@ import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.metrics.MetricRegistryConfiguration; import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; +import 
org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.resourcemanager.ResourceManagerRunner; import org.apache.flink.runtime.rpc.FatalErrorHandler; import org.apache.flink.runtime.rpc.RpcService; @@ -95,15 +96,13 @@ public class MiniCluster { @GuardedBy("lock") private ResourceManagerRunner[] resourceManagerRunners; - @GuardedBy("lock") - private TaskExecutor[] taskManagers; + private volatile TaskExecutor[] taskManagers; @GuardedBy("lock") private MiniClusterJobDispatcher jobDispatcher; /** Flag marking the mini cluster as started/running */ - @GuardedBy("lock") - private boolean running; + private volatile boolean running; // ------------------------------------------------------------------------ @@ -150,6 +149,8 @@ public MiniCluster(Configuration config) { @Deprecated public MiniCluster(Configuration config, boolean singleRpcService) { this(createConfig(config, singleRpcService)); + + running = false; } // ------------------------------------------------------------------------ @@ -352,6 +353,8 @@ private void shutdownInternally() throws Exception { if (tm != null) { try { tm.shutDown(); + // wait for the TaskManager to properly terminate + tm.getTerminationFuture().get(); } catch (Throwable t) { exception = firstOrSuppressed(t, exception); } @@ -419,14 +422,14 @@ public void waitUntilTaskManagerRegistrationsComplete() throws Exception { final LeaderAddressAndId addressAndId = addressAndIdFuture.get(); final ResourceManagerGateway resourceManager = - commonRpcService.connect(addressAndId.leaderAddress(), ResourceManagerGateway.class).get(); + commonRpcService.connect(addressAndId.leaderAddress(), new ResourceManagerId(addressAndId.leaderId()), ResourceManagerGateway.class).get(); final int numTaskManagersToWaitFor = taskManagers.length; // poll and wait until enough TaskManagers are available while (true) { int numTaskManagersAvailable = - resourceManager.getNumberOfRegisteredTaskManagers(addressAndId.leaderId()).get(); + resourceManager.getNumberOfRegisteredTaskManagers().get(); if (numTaskManagersAvailable >= numTaskManagersToWaitFor) { break; @@ -645,17 +648,18 @@ private TerminatingFatalErrorHandler(int index) { @Override public void onFatalError(Throwable exception) { - LOG.error("TaskManager #{} failed.", index, exception); + // first check if we are still running + if (running) { + LOG.error("TaskManager #{} failed.", index, exception); - try { - synchronized (lock) { - // note: if not running (after shutdown) taskManagers may be null! 
- if (running && taskManagers[index] != null) { - taskManagers[index].shutDown(); - } + // let's check if there are still TaskManagers because there could be a concurrent + // shut down operation taking place + TaskExecutor[] currentTaskManagers = taskManagers; + + if (currentTaskManagers != null) { + // the shutDown is asynchronous + currentTaskManagers[index].shutDown(); } - } catch (Exception e) { - LOG.error("TaskManager #{} could not be properly terminated.", index, e); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniClusterJobDispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniClusterJobDispatcher.java index 2bb94f2dec142..60d9a6692609a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniClusterJobDispatcher.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/minicluster/MiniClusterJobDispatcher.java @@ -33,6 +33,8 @@ import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.rpc.FatalErrorHandler; import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FlinkException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -156,7 +158,7 @@ public MiniClusterJobDispatcher( * Shuts down the mini cluster dispatcher. If a job is currently running, that job will be * terminally failed. */ - public void shutdown() { + public void shutdown() throws Exception { synchronized (lock) { if (!shutdown) { shutdown = true; @@ -166,14 +168,31 @@ public void shutdown() { // in this shutdown code we copy the references to the stack first, // to avoid concurrent modification + Throwable exception = null; + JobManagerRunner[] runners = this.runners; if (runners != null) { this.runners = null; for (JobManagerRunner runner : runners) { - runner.shutdown(); + try { + runner.shutdown(); + } catch (Throwable e) { + exception = ExceptionUtils.firstOrSuppressed(e, exception); + } } } + + // shut down the JobManagerServices + try { + jobManagerServices.shutdown(); + } catch (Throwable throwable) { + exception = ExceptionUtils.firstOrSuppressed(throwable, exception); + } + + if (exception != null) { + throw new FlinkException("Could not properly terminate all JobManagerRunners.", exception); + } } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClient.java index 60099d2a35b1a..1a84e831959fe 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClient.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.util.Preconditions; +import org.apache.flink.shaded.guava18.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.flink.shaded.netty4.io.netty.bootstrap.Bootstrap; import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; import org.apache.flink.shaded.netty4.io.netty.channel.Channel; @@ -39,7 +40,6 @@ import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler; import akka.dispatch.Futures; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import java.nio.channels.ClosedChannelException; import java.util.ArrayDeque; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServer.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServer.java index 2889e2e21aedf..7cf2148273afe 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServer.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.query.netty.message.KvStateRequest; import org.apache.flink.util.Preconditions; +import org.apache.flink.shaded.guava18.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.flink.shaded.netty4.io.netty.bootstrap.ServerBootstrap; import org.apache.flink.shaded.netty4.io.netty.channel.Channel; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; @@ -35,7 +36,6 @@ import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder; import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/registration/RegisteredRpcConnection.java b/flink-runtime/src/main/java/org/apache/flink/runtime/registration/RegisteredRpcConnection.java index da46e1c10d963..a585f0dff6881 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/registration/RegisteredRpcConnection.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/registration/RegisteredRpcConnection.java @@ -23,7 +23,7 @@ import org.slf4j.Logger; -import java.util.UUID; +import java.io.Serializable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -39,16 +39,17 @@ * The RPC connection can be closed, for example when the target where it tries to register * at looses leader status. * - * @param The type of the gateway to connect to. - * @param The type of the successful registration responses. + * @param The type of the fencing token + * @param The type of the gateway to connect to. + * @param The type of the successful registration responses. */ -public abstract class RegisteredRpcConnection { +public abstract class RegisteredRpcConnection { /** The logger for all log messages of this class. */ protected final Logger log; - /** The target component leaderID, for example the ResourceManager leaderID. */ - private final UUID targetLeaderId; + /** The fencing token fo the remote component. */ + private final F fencingToken; /** The target component Address, for example the ResourceManager Address. */ private final String targetAddress; @@ -57,20 +58,20 @@ public abstract class RegisteredRpcConnection pendingRegistration; + private RetryingRegistration pendingRegistration; /** The gateway to register, it's null until the registration is completed. */ - private volatile Gateway targetGateway; + private volatile G targetGateway; /** Flag indicating that the RPC connection is closed. 
*/ private volatile boolean closed; // ------------------------------------------------------------------------ - public RegisteredRpcConnection(Logger log, String targetAddress, UUID targetLeaderId, Executor executor) { + public RegisteredRpcConnection(Logger log, String targetAddress, F fencingToken, Executor executor) { this.log = checkNotNull(log); this.targetAddress = checkNotNull(targetAddress); - this.targetLeaderId = checkNotNull(targetLeaderId); + this.fencingToken = checkNotNull(fencingToken); this.executor = checkNotNull(executor); } @@ -86,10 +87,10 @@ public void start() { pendingRegistration = checkNotNull(generateRegistration()); pendingRegistration.startRegistration(); - CompletableFuture> future = pendingRegistration.getFuture(); + CompletableFuture> future = pendingRegistration.getFuture(); future.whenCompleteAsync( - (Tuple2 result, Throwable failure) -> { + (Tuple2 result, Throwable failure) -> { // this future should only ever fail if there is a bug, not if the registration is declined if (failure != null) { onRegistrationFailure(failure); @@ -103,12 +104,12 @@ public void start() { /** * This method generate a specific Registration, for example TaskExecutor Registration at the ResourceManager. */ - protected abstract RetryingRegistration generateRegistration(); + protected abstract RetryingRegistration generateRegistration(); /** * This method handle the Registration Response. */ - protected abstract void onRegistrationSuccess(Success success); + protected abstract void onRegistrationSuccess(S success); /** * This method handle the Registration failure. @@ -135,8 +136,8 @@ public boolean isClosed() { // Properties // ------------------------------------------------------------------------ - public UUID getTargetLeaderId() { - return targetLeaderId; + public F getTargetLeaderId() { + return fencingToken; } public String getTargetAddress() { @@ -146,7 +147,7 @@ public String getTargetAddress() { /** * Gets the RegisteredGateway. This returns null until the registration is completed. 
*/ - public Gateway getTargetGateway() { + public G getTargetGateway() { return targetGateway; } @@ -158,7 +159,7 @@ public boolean isConnected() { @Override public String toString() { - String connectionInfo = "(ADDRESS: " + targetAddress + " LEADERID: " + targetLeaderId + ")"; + String connectionInfo = "(ADDRESS: " + targetAddress + " FENCINGTOKEN: " + fencingToken + ")"; if (isConnected()) { connectionInfo = "RPC connection to " + targetGateway.getClass().getSimpleName() + " " + connectionInfo; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/registration/RetryingRegistration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/registration/RetryingRegistration.java index 6a18ffd78b5ed..ce4a798b0d35c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/registration/RetryingRegistration.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/registration/RetryingRegistration.java @@ -19,12 +19,13 @@ package org.apache.flink.runtime.registration; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.rpc.FencedRpcGateway; import org.apache.flink.runtime.rpc.RpcGateway; import org.apache.flink.runtime.rpc.RpcService; import org.slf4j.Logger; -import java.util.UUID; +import java.io.Serializable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -43,10 +44,11 @@ * The registration can be canceled, for example when the target where it tries to register * at looses leader status. * - * @param The type of the gateway to connect to. - * @param The type of the successful registration responses. + * @param The type of the fencing token + * @param The type of the gateway to connect to. + * @param The type of the successful registration responses. 
*/ -public abstract class RetryingRegistration { +public abstract class RetryingRegistration { // ------------------------------------------------------------------------ // default configuration values @@ -74,13 +76,13 @@ public abstract class RetryingRegistration targetType; + private final Class targetType; private final String targetAddress; - private final UUID leaderId; + private final F fencingToken; - private final CompletableFuture> completionFuture; + private final CompletableFuture> completionFuture; private final long initialRegistrationTimeout; @@ -98,10 +100,10 @@ public RetryingRegistration( Logger log, RpcService rpcService, String targetName, - Class targetType, + Class targetType, String targetAddress, - UUID leaderId) { - this(log, rpcService, targetName, targetType, targetAddress, leaderId, + F fencingToken) { + this(log, rpcService, targetName, targetType, targetAddress, fencingToken, INITIAL_REGISTRATION_TIMEOUT_MILLIS, MAX_REGISTRATION_TIMEOUT_MILLIS, ERROR_REGISTRATION_DELAY_MILLIS, REFUSED_REGISTRATION_DELAY_MILLIS); } @@ -110,9 +112,9 @@ public RetryingRegistration( Logger log, RpcService rpcService, String targetName, - Class targetType, + Class targetType, String targetAddress, - UUID leaderId, + F fencingToken, long initialRegistrationTimeout, long maxRegistrationTimeout, long delayOnError, @@ -128,7 +130,7 @@ public RetryingRegistration( this.targetName = checkNotNull(targetName); this.targetType = checkNotNull(targetType); this.targetAddress = checkNotNull(targetAddress); - this.leaderId = checkNotNull(leaderId); + this.fencingToken = checkNotNull(fencingToken); this.initialRegistrationTimeout = initialRegistrationTimeout; this.maxRegistrationTimeout = maxRegistrationTimeout; this.delayOnError = delayOnError; @@ -141,7 +143,7 @@ public RetryingRegistration( // completion and cancellation // ------------------------------------------------------------------------ - public CompletableFuture> getFuture() { + public CompletableFuture> getFuture() { return completionFuture; } @@ -165,7 +167,7 @@ public boolean isCanceled() { // ------------------------------------------------------------------------ protected abstract CompletableFuture invokeRegistration( - Gateway gateway, UUID leaderId, long timeoutMillis) throws Exception; + G gateway, F fencingToken, long timeoutMillis) throws Exception; /** * This method resolves the target address to a callable gateway and starts the @@ -175,11 +177,20 @@ protected abstract CompletableFuture invokeRegistration( public void startRegistration() { try { // trigger resolution of the resource manager address to a callable gateway - CompletableFuture resourceManagerFuture = rpcService.connect(targetAddress, targetType); + final CompletableFuture resourceManagerFuture; + + if (FencedRpcGateway.class.isAssignableFrom(targetType)) { + resourceManagerFuture = (CompletableFuture) rpcService.connect( + targetAddress, + fencingToken, + targetType.asSubclass(FencedRpcGateway.class)); + } else { + resourceManagerFuture = rpcService.connect(targetAddress, targetType); + } // upon success, start the registration attempts CompletableFuture resourceManagerAcceptFuture = resourceManagerFuture.thenAcceptAsync( - (Gateway result) -> { + (G result) -> { log.info("Resolved {} address, beginning registration", targetName); register(result, 1, initialRegistrationTimeout); }, @@ -206,7 +217,7 @@ public void startRegistration() { * depending on the result. 
*/ @SuppressWarnings("unchecked") - private void register(final Gateway gateway, final int attempt, final long timeoutMillis) { + private void register(final G gateway, final int attempt, final long timeoutMillis) { // eager check for canceling to avoid some unnecessary work if (canceled) { return; @@ -214,7 +225,7 @@ private void register(final Gateway gateway, final int attempt, final long timeo try { log.info("Registration at {} attempt {} (timeout={}ms)", targetName, attempt, timeoutMillis); - CompletableFuture registrationFuture = invokeRegistration(gateway, leaderId, timeoutMillis); + CompletableFuture registrationFuture = invokeRegistration(gateway, fencingToken, timeoutMillis); // if the registration was successful, let the TaskExecutor know CompletableFuture registrationAcceptFuture = registrationFuture.thenAcceptAsync( @@ -222,7 +233,7 @@ private void register(final Gateway gateway, final int attempt, final long timeo if (!isCanceled()) { if (result instanceof RegistrationResponse.Success) { // registration successful! - Success success = (Success) result; + S success = (S) result; completionFuture.complete(Tuple2.of(gateway, success)); } else { @@ -274,7 +285,7 @@ private void register(final Gateway gateway, final int attempt, final long timeo } } - private void registerLater(final Gateway gateway, final int attempt, final long timeoutMillis, long delay) { + private void registerLater(final G gateway, final int attempt, final long timeoutMillis, long delay) { rpcService.scheduleRunnable(new Runnable() { @Override public void run() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdActions.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdActions.java index 4ca62090a098b..565cd82918739 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdActions.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdActions.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.resourcemanager; import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobmaster.JobMasterId; import java.util.UUID; @@ -31,9 +32,9 @@ public interface JobLeaderIdActions { * Callback when a monitored job leader lost its leadership. * * @param jobId identifying the job whose leader lost leadership - * @param oldJobLeaderId of the job manager which lost leadership + * @param oldJobMasterId of the job manager which lost leadership */ - void jobLeaderLostLeadership(JobID jobId, UUID oldJobLeaderId); + void jobLeaderLostLeadership(JobID jobId, JobMasterId oldJobMasterId); /** * Notify a job timeout. The job is identified by the given JobID. 
In order to check diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdService.java index aaa72d9fc78ad..da0a7fd5b5b73 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdService.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdService.java @@ -22,14 +22,17 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.runtime.concurrent.ScheduledExecutor; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalListener; import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; + import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -49,17 +52,17 @@ public class JobLeaderIdService { private static final Logger LOG = LoggerFactory.getLogger(JobLeaderIdService.class); - /** High availability services to use by this service */ + /** High availability services to use by this service. */ private final HighAvailabilityServices highAvailabilityServices; private final ScheduledExecutor scheduledExecutor; private final Time jobTimeout; - /** Map of currently monitored jobs */ + /** Map of currently monitored jobs. */ private final Map jobLeaderIdListeners; - /** Actions to call when the job leader changes */ + /** Actions to call when the job leader changes. */ private JobLeaderIdActions jobLeaderIdActions; public JobLeaderIdService( @@ -178,14 +181,14 @@ public boolean containsJob(JobID jobId) { return jobLeaderIdListeners.containsKey(jobId); } - public CompletableFuture getLeaderId(JobID jobId) throws Exception { + public CompletableFuture getLeaderId(JobID jobId) throws Exception { if (!jobLeaderIdListeners.containsKey(jobId)) { addJob(jobId); } JobLeaderIdListener listener = jobLeaderIdListeners.get(jobId); - return listener.getLeaderIdFuture(); + return listener.getLeaderIdFuture().thenApply((UUID id) -> id != null ? new JobMasterId(id) : null); } public boolean isValidTimeout(JobID jobId, UUID timeoutId) { @@ -216,15 +219,14 @@ private final class JobLeaderIdListener implements LeaderRetrievalListener { private volatile CompletableFuture leaderIdFuture; private volatile boolean running = true; - /** Null if no timeout has been scheduled; otherwise non null */ + /** Null if no timeout has been scheduled; otherwise non null. */ @Nullable private volatile ScheduledFuture timeoutFuture; - /** Null if no timeout has been scheduled; otherwise non null */ + /** Null if no timeout has been scheduled; otherwise non null. 
*/ @Nullable private volatile UUID timeoutId; - private JobLeaderIdListener( JobID jobId, JobLeaderIdActions listenerJobLeaderIdActions, @@ -279,7 +281,7 @@ public void notifyLeaderAddress(String leaderAddress, UUID leaderSessionId) { if (previousJobLeaderId != null && !previousJobLeaderId.equals(leaderSessionId)) { // we had a previous job leader, so notify about his lost leadership - listenerJobLeaderIdActions.jobLeaderLostLeadership(jobId, previousJobLeaderId); + listenerJobLeaderIdActions.jobLeaderLostLeadership(jobId, new JobMasterId(previousJobLeaderId)); if (null == leaderSessionId) { // No current leader active ==> Set a timeout for the job diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManager.java index c2b0590e49b81..87cf7d10fbcaa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManager.java @@ -33,29 +33,26 @@ import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.heartbeat.HeartbeatTarget; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; -import org.apache.flink.runtime.highavailability.LeaderIdMismatchException; import org.apache.flink.runtime.instance.InstanceID; +import org.apache.flink.runtime.jobmaster.JobMaster; +import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.jobmaster.JobMasterRegistrationSuccess; import org.apache.flink.runtime.leaderelection.LeaderContender; import org.apache.flink.runtime.leaderelection.LeaderElectionService; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.metrics.MetricRegistry; +import org.apache.flink.runtime.registration.RegistrationResponse; import org.apache.flink.runtime.resourcemanager.exceptions.ResourceManagerException; import org.apache.flink.runtime.resourcemanager.registration.JobManagerRegistration; import org.apache.flink.runtime.resourcemanager.registration.WorkerRegistration; -import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; import org.apache.flink.runtime.resourcemanager.slotmanager.ResourceManagerActions; +import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManagerException; import org.apache.flink.runtime.rpc.FatalErrorHandler; -import org.apache.flink.runtime.rpc.RpcEndpoint; +import org.apache.flink.runtime.rpc.FencedRpcEndpoint; import org.apache.flink.runtime.rpc.RpcService; -import org.apache.flink.runtime.jobmaster.JobMaster; -import org.apache.flink.runtime.jobmaster.JobMasterGateway; -import org.apache.flink.runtime.registration.RegistrationResponse; - -import org.apache.flink.runtime.rpc.exceptions.LeaderSessionIDException; import org.apache.flink.runtime.taskexecutor.SlotReport; -import org.apache.flink.runtime.taskexecutor.TaskExecutor; import org.apache.flink.runtime.taskexecutor.TaskExecutorGateway; import org.apache.flink.runtime.taskexecutor.TaskExecutorRegistrationSuccess; import org.apache.flink.util.ExceptionUtils; @@ -76,22 +73,22 @@ * ResourceManager implementation. The resource manager is responsible for resource de-/allocation * and bookkeeping. 
 *
- * It offers the following methods as part of its rpc interface to interact with him remotely:
+ * <p>It offers the following methods as part of its rpc interface to interact with him remotely:
 * <ul>
- *     <li>{@link #registerJobManager(UUID, UUID, ResourceID, String, JobID, Time)} registers a {@link JobMaster} at the resource manager</li>
- *     <li>{@link #requestSlot(UUID, UUID, SlotRequest, Time)} requests a slot from the resource manager</li>
+ *     <li>{@link #registerJobManager(JobMasterId, ResourceID, String, JobID, Time)} registers a {@link JobMaster} at the resource manager</li>
+ *     <li>{@link #requestSlot(JobMasterId, SlotRequest, Time)} requests a slot from the resource manager</li>
 * </ul>
*/ public abstract class ResourceManager - extends RpcEndpoint + extends FencedRpcEndpoint implements ResourceManagerGateway, LeaderContender { public static final String RESOURCE_MANAGER_NAME = "resourcemanager"; - /** Unique id of the resource manager */ + /** Unique id of the resource manager. */ private final ResourceID resourceId; - /** Configuration of the resource manager */ + /** Configuration of the resource manager. */ private final ResourceManagerConfiguration resourceManagerConfiguration; /** All currently registered JobMasterGateways scoped by JobID. */ @@ -100,7 +97,7 @@ public abstract class ResourceManager /** All currently registered JobMasterGateways scoped by ResourceID. */ private final Map jmResourceIdRegistrations; - /** Service to retrieve the job leader ids */ + /** Service to retrieve the job leader ids. */ private final JobLeaderIdService jobLeaderIdService; /** All currently registered TaskExecutors with there framework specific worker information. */ @@ -115,21 +112,18 @@ public abstract class ResourceManager /** The heartbeat manager with job managers. */ private final HeartbeatManager jobManagerHeartbeatManager; - /** Registry to use for metrics */ + /** Registry to use for metrics. */ private final MetricRegistry metricRegistry; - /** Fatal error handler */ + /** Fatal error handler. */ private final FatalErrorHandler fatalErrorHandler; - /** The slot manager maintains the available slots */ + /** The slot manager maintains the available slots. */ private final SlotManager slotManager; /** The service to elect a ResourceManager leader. */ private LeaderElectionService leaderElectionService; - /** ResourceManager's leader session id which is updated on leader election. */ - private volatile UUID leaderSessionId; - /** All registered listeners for status updates of the ResourceManager. 
*/ private ConcurrentMap infoMessageListeners; @@ -145,7 +139,7 @@ public ResourceManager( JobLeaderIdService jobLeaderIdService, FatalErrorHandler fatalErrorHandler) { - super(rpcService, resourceManagerEndpointId); + super(rpcService, resourceManagerEndpointId, ResourceManagerId.generate()); this.resourceId = checkNotNull(resourceId); this.resourceManagerConfiguration = checkNotNull(resourceManagerConfiguration); @@ -170,7 +164,6 @@ public ResourceManager( this.jobManagerRegistrations = new HashMap<>(4); this.jmResourceIdRegistrations = new HashMap<>(4); this.taskExecutors = new HashMap<>(8); - this.leaderSessionId = null; infoMessageListeners = new ConcurrentHashMap<>(8); } @@ -247,147 +240,108 @@ public void postStop() throws Exception { @Override public CompletableFuture registerJobManager( - final UUID resourceManagerLeaderId, - final UUID jobManagerLeaderId, + final JobMasterId jobMasterId, final ResourceID jobManagerResourceId, final String jobManagerAddress, final JobID jobId, final Time timeout) { - checkNotNull(resourceManagerLeaderId); - checkNotNull(jobManagerLeaderId); + checkNotNull(jobMasterId); checkNotNull(jobManagerResourceId); checkNotNull(jobManagerAddress); checkNotNull(jobId); - if (isValid(resourceManagerLeaderId)) { - if (!jobLeaderIdService.containsJob(jobId)) { - try { - jobLeaderIdService.addJob(jobId); - } catch (Exception e) { - ResourceManagerException exception = new ResourceManagerException("Could not add the job " + - jobId + " to the job id leader service.", e); + if (!jobLeaderIdService.containsJob(jobId)) { + try { + jobLeaderIdService.addJob(jobId); + } catch (Exception e) { + ResourceManagerException exception = new ResourceManagerException("Could not add the job " + + jobId + " to the job id leader service.", e); - onFatalErrorAsync(exception); + onFatalError(exception); - log.error("Could not add job {} to job leader id service.", jobId, e); - return FutureUtils.completedExceptionally(exception); - } + log.error("Could not add job {} to job leader id service.", jobId, e); + return FutureUtils.completedExceptionally(exception); } + } - log.info("Registering job manager {}@{} for job {}.", jobManagerLeaderId, jobManagerAddress, jobId); - - CompletableFuture jobLeaderIdFuture; + log.info("Registering job manager {}@{} for job {}.", jobMasterId, jobManagerAddress, jobId); - try { - jobLeaderIdFuture = jobLeaderIdService.getLeaderId(jobId); - } catch (Exception e) { - // we cannot check the job leader id so let's fail - // TODO: Maybe it's also ok to skip this check in case that we cannot check the leader id - ResourceManagerException exception = new ResourceManagerException("Cannot obtain the " + - "job leader id future to verify the correct job leader.", e); + CompletableFuture jobMasterIdFuture; - onFatalErrorAsync(exception); + try { + jobMasterIdFuture = jobLeaderIdService.getLeaderId(jobId); + } catch (Exception e) { + // we cannot check the job leader id so let's fail + // TODO: Maybe it's also ok to skip this check in case that we cannot check the leader id + ResourceManagerException exception = new ResourceManagerException("Cannot obtain the " + + "job leader id future to verify the correct job leader.", e); - log.debug("Could not obtain the job leader id future to verify the correct job leader."); - return FutureUtils.completedExceptionally(exception); - } + onFatalError(exception); - CompletableFuture jobMasterGatewayFuture = getRpcService().connect(jobManagerAddress, JobMasterGateway.class); - - CompletableFuture registrationResponseFuture = 
jobMasterGatewayFuture.thenCombineAsync( - jobLeaderIdFuture, - (JobMasterGateway jobMasterGateway, UUID jobLeaderId) -> { - if (isValid(resourceManagerLeaderId)) { - if (Objects.equals(jobLeaderId, jobManagerLeaderId)) { - return registerJobMasterInternal( - jobMasterGateway, - jobLeaderId, - jobId, - jobManagerAddress, - jobManagerResourceId); - } else { - log.debug("The job manager leader id {} did not match the job " + - "leader id {}.", jobManagerLeaderId, jobLeaderId); - return new RegistrationResponse.Decline("Job manager leader id did not match."); - } - } else { - log.debug("The resource manager leader id changed {}. Discarding job " + - "manager registration from {}.", getLeaderSessionId(), jobManagerAddress); - return new RegistrationResponse.Decline("Resource manager leader id changed."); - } - }, - getMainThreadExecutor()); + log.debug("Could not obtain the job leader id future to verify the correct job leader."); + return FutureUtils.completedExceptionally(exception); + } - // handle exceptions which might have occurred in one of the futures inputs of combine - return registrationResponseFuture.handleAsync( - (RegistrationResponse registrationResponse, Throwable throwable) -> { - if (throwable != null) { - if (log.isDebugEnabled()) { - log.debug("Registration of job manager {}@{} failed.", jobManagerLeaderId, jobManagerAddress, throwable); - } else { - log.info("Registration of job manager {}@{} failed.", jobManagerLeaderId, jobManagerAddress); - } + CompletableFuture jobMasterGatewayFuture = getRpcService().connect(jobManagerAddress, jobMasterId, JobMasterGateway.class); - return new RegistrationResponse.Decline(throwable.getMessage()); + CompletableFuture registrationResponseFuture = jobMasterGatewayFuture.thenCombineAsync( + jobMasterIdFuture, + (JobMasterGateway jobMasterGateway, JobMasterId currentJobMasterId) -> { + if (Objects.equals(currentJobMasterId, jobMasterId)) { + return registerJobMasterInternal( + jobMasterGateway, + jobId, + jobManagerAddress, + jobManagerResourceId); + } else { + log.debug("The current JobMaster leader id {} did not match the received " + + "JobMaster id {}.", jobMasterId, currentJobMasterId); + return new RegistrationResponse.Decline("Job manager leader id did not match."); + } + }, + getMainThreadExecutor()); + + // handle exceptions which might have occurred in one of the futures inputs of combine + return registrationResponseFuture.handleAsync( + (RegistrationResponse registrationResponse, Throwable throwable) -> { + if (throwable != null) { + if (log.isDebugEnabled()) { + log.debug("Registration of job manager {}@{} failed.", jobMasterId, jobManagerAddress, throwable); } else { - return registrationResponse; + log.info("Registration of job manager {}@{} failed.", jobMasterId, jobManagerAddress); } - }, - getRpcService().getExecutor()); - } else { - log.debug("Discard register job manager message from {}, because the leader id " + - "{} did not match the expected leader id {}.", jobManagerAddress, - resourceManagerLeaderId, leaderSessionId); - return CompletableFuture.completedFuture( - new RegistrationResponse.Decline("Resource manager leader id did not match.")); - } + return new RegistrationResponse.Decline(throwable.getMessage()); + } else { + return registrationResponse; + } + }, + getRpcService().getExecutor()); } - /** - * Register a {@link TaskExecutor} at the resource manager - * - * @param resourceManagerLeaderId The fencing token for the ResourceManager leader - * @param taskExecutorAddress The address of the TaskExecutor that 
registers - * @param taskExecutorResourceId The resource ID of the TaskExecutor that registers - * - * @return The response by the ResourceManager. - */ @Override public CompletableFuture registerTaskExecutor( - final UUID resourceManagerLeaderId, final String taskExecutorAddress, final ResourceID taskExecutorResourceId, final SlotReport slotReport, final Time timeout) { - if (Objects.equals(leaderSessionId, resourceManagerLeaderId)) { - CompletableFuture taskExecutorGatewayFuture = getRpcService().connect(taskExecutorAddress, TaskExecutorGateway.class); + CompletableFuture taskExecutorGatewayFuture = getRpcService().connect(taskExecutorAddress, TaskExecutorGateway.class); - return taskExecutorGatewayFuture.handleAsync( - (TaskExecutorGateway taskExecutorGateway, Throwable throwable) -> { - if (throwable != null) { - return new RegistrationResponse.Decline(throwable.getMessage()); - } else { - return registerTaskExecutorInternal( - taskExecutorGateway, - taskExecutorAddress, - taskExecutorResourceId, - slotReport); - } - }, - getMainThreadExecutor()); - } else { - log.warn("Discard registration from TaskExecutor {} at ({}) because the expected leader session ID {} did " + - "not equal the received leader session ID {}", - taskExecutorResourceId, taskExecutorAddress, leaderSessionId, resourceManagerLeaderId); - - return CompletableFuture.completedFuture( - new RegistrationResponse.Decline("Discard registration because the leader id " + - resourceManagerLeaderId + " does not match the expected leader id " + - leaderSessionId + '.')); - } + return taskExecutorGatewayFuture.handleAsync( + (TaskExecutorGateway taskExecutorGateway, Throwable throwable) -> { + if (throwable != null) { + return new RegistrationResponse.Decline(throwable.getMessage()); + } else { + return registerTaskExecutorInternal( + taskExecutorGateway, + taskExecutorAddress, + taskExecutorResourceId, + slotReport); + } + }, + getMainThreadExecutor()); } @Override @@ -410,28 +364,17 @@ public void disconnectJobManager(final JobID jobId, final Exception cause) { closeJobManagerConnection(jobId, cause); } - /** - * Requests a slot from the resource manager. 
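For orientation, a minimal caller-side sketch of the slimmed-down registration RPC shown above: the TaskExecutor no longer quotes the ResourceManager's leader session ID, because fencing is now handled by the gateway itself. The gateway reference, address, resource ID and slot report are assumed to come from the caller's context, and the timeout value is illustrative only.

import java.util.concurrent.CompletableFuture;

import org.apache.flink.api.common.time.Time;
import org.apache.flink.runtime.clusterframework.types.ResourceID;
import org.apache.flink.runtime.registration.RegistrationResponse;
import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway;
import org.apache.flink.runtime.taskexecutor.SlotReport;

public class TaskExecutorRegistrationExample {

    // Sketch only: registration without a resourceManagerLeaderId argument.
    static CompletableFuture<RegistrationResponse> register(
            ResourceManagerGateway resourceManagerGateway,
            String taskExecutorAddress,
            ResourceID taskExecutorResourceId,
            SlotReport slotReport) {

        return resourceManagerGateway.registerTaskExecutor(
            taskExecutorAddress,
            taskExecutorResourceId,
            slotReport,
            Time.seconds(10)); // illustrative timeout
    }
}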
- * - * @param slotRequest Slot request - * @return Slot assignment - */ @Override public CompletableFuture requestSlot( - UUID jobMasterLeaderID, - UUID resourceManagerLeaderID, + JobMasterId jobMasterId, SlotRequest slotRequest, final Time timeout) { - if (!Objects.equals(resourceManagerLeaderID, leaderSessionId)) { - return FutureUtils.completedExceptionally(new LeaderSessionIDException(resourceManagerLeaderID, leaderSessionId)); - } - JobID jobId = slotRequest.getJobId(); JobManagerRegistration jobManagerRegistration = jobManagerRegistrations.get(jobId); if (null != jobManagerRegistration) { - if (Objects.equals(jobMasterLeaderID, jobManagerRegistration.getLeaderID())) { + if (Objects.equals(jobMasterId, jobManagerRegistration.getJobMasterId())) { log.info("Request slot with profile {} for job {} with allocation id {}.", slotRequest.getResourceProfile(), slotRequest.getJobId(), @@ -445,7 +388,8 @@ public CompletableFuture requestSlot( return CompletableFuture.completedFuture(Acknowledge.get()); } else { - return FutureUtils.completedExceptionally(new LeaderSessionIDException(jobMasterLeaderID, jobManagerRegistration.getLeaderID())); + return FutureUtils.completedExceptionally(new ResourceManagerException("The job leader's id " + + jobManagerRegistration.getJobMasterId() + " does not match the received id " + jobMasterId + '.')); } } else { @@ -453,51 +397,38 @@ public CompletableFuture requestSlot( } } - /** - * Notification from a TaskExecutor that a slot has become available - * @param resourceManagerLeaderId TaskExecutor's resource manager leader id - * @param instanceID TaskExecutor's instance id - * @param slotId The slot id of the available slot - */ @Override public void notifySlotAvailable( - final UUID resourceManagerLeaderId, final InstanceID instanceID, final SlotID slotId, final AllocationID allocationId) { - if (Objects.equals(resourceManagerLeaderId, leaderSessionId)) { - final ResourceID resourceId = slotId.getResourceID(); - WorkerRegistration registration = taskExecutors.get(resourceId); + final ResourceID resourceId = slotId.getResourceID(); + WorkerRegistration registration = taskExecutors.get(resourceId); - if (registration != null) { - InstanceID registrationId = registration.getInstanceID(); + if (registration != null) { + InstanceID registrationId = registration.getInstanceID(); - if (Objects.equals(registrationId, instanceID)) { - slotManager.freeSlot(slotId, allocationId); - } else { - log.debug("Invalid registration id for slot available message. This indicates an" + - " outdated request."); - } + if (Objects.equals(registrationId, instanceID)) { + slotManager.freeSlot(slotId, allocationId); } else { - log.debug("Could not find registration for resource id {}. Discarding the slot available" + - "message {}.", resourceId, slotId); + log.debug("Invalid registration id for slot available message. This indicates an" + + " outdated request."); } } else { - log.debug("Discarding notify slot available message for slot {}, because the " + - "leader id {} did not match the expected leader id {}.", slotId, - resourceManagerLeaderId, leaderSessionId); + log.debug("Could not find registration for resource id {}. Discarding the slot available" + + "message {}.", resourceId, slotId); } } /** - * Registers an info message listener + * Registers an info message listener. 
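The slot request path follows the same pattern; below is a hedged sketch of how a JobMaster-side caller would use the reduced signature. All parameters are assumed to be available in the caller's context, and a fencing-token mismatch now surfaces as a ResourceManagerException on the returned future rather than a LeaderSessionIDException.

import java.util.concurrent.CompletableFuture;

import org.apache.flink.api.common.time.Time;
import org.apache.flink.runtime.jobmaster.JobMasterId;
import org.apache.flink.runtime.messages.Acknowledge;
import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway;
import org.apache.flink.runtime.resourcemanager.SlotRequest;

public class SlotRequestExample {

    // Sketch only: the JobMaster identifies itself solely via its JobMasterId fencing token.
    static CompletableFuture<Acknowledge> requestSlot(
            ResourceManagerGateway resourceManagerGateway,
            JobMasterId jobMasterId,
            SlotRequest slotRequest) {

        return resourceManagerGateway.requestSlot(jobMasterId, slotRequest, Time.seconds(10));
    }
}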
* * @param address address of infoMessage listener to register to this resource manager */ @Override public void registerInfoMessageListener(final String address) { - if(infoMessageListeners.containsKey(address)) { + if (infoMessageListeners.containsKey(address)) { log.warn("Receive a duplicate registration from info message listener on ({})", address); } else { CompletableFuture infoMessageListenerRpcGatewayFuture = getRpcService() @@ -517,7 +448,7 @@ public void registerInfoMessageListener(final String address) { } /** - * Unregisters an info message listener + * Unregisters an info message listener. * * @param address of the info message listener to unregister from this resource manager * @@ -528,7 +459,7 @@ public void unRegisterInfoMessageListener(final String address) { } /** - * Cleanup application and shut down cluster + * Cleanup application and shut down cluster. * * @param finalStatus of the Flink application * @param optionalDiagnostics for the Flink application @@ -545,27 +476,8 @@ public void shutDownCluster(final ApplicationStatus finalStatus, final String op } @Override - public CompletableFuture getNumberOfRegisteredTaskManagers(UUID requestLeaderSessionId) { - if (Objects.equals(leaderSessionId, requestLeaderSessionId)) { - return CompletableFuture.completedFuture(taskExecutors.size()); - } - else { - return FutureUtils.completedExceptionally(new LeaderIdMismatchException(leaderSessionId, requestLeaderSessionId)); - } - } - - // ------------------------------------------------------------------------ - // Testing methods - // ------------------------------------------------------------------------ - - /** - * Gets the leader session id of current resourceManager. - * - * @return return the leaderSessionId of current resourceManager, this returns null until the current resourceManager is granted leadership. - */ - @VisibleForTesting - UUID getLeaderSessionId() { - return leaderSessionId; + public CompletableFuture getNumberOfRegisteredTaskManagers() { + return CompletableFuture.completedFuture(taskExecutors.size()); } // ------------------------------------------------------------------------ @@ -576,7 +488,6 @@ UUID getLeaderSessionId() { * Registers a new JobMaster. 
* * @param jobMasterGateway to communicate with the registering JobMaster - * @param jobLeaderId leader id of the JobMaster * @param jobId of the job for which the JobMaster is responsible * @param jobManagerAddress address of the JobMaster * @param jobManagerResourceId ResourceID of the JobMaster @@ -584,16 +495,15 @@ UUID getLeaderSessionId() { */ private RegistrationResponse registerJobMasterInternal( final JobMasterGateway jobMasterGateway, - UUID jobLeaderId, JobID jobId, String jobManagerAddress, ResourceID jobManagerResourceId) { if (jobManagerRegistrations.containsKey(jobId)) { JobManagerRegistration oldJobManagerRegistration = jobManagerRegistrations.get(jobId); - if (oldJobManagerRegistration.getLeaderID().equals(jobLeaderId)) { + if (Objects.equals(oldJobManagerRegistration.getJobMasterId(), jobMasterGateway.getFencingToken())) { // same registration - log.debug("Job manager {}@{} was already registered.", jobLeaderId, jobManagerAddress); + log.debug("Job manager {}@{} was already registered.", jobMasterGateway.getFencingToken(), jobManagerAddress); } else { // tell old job manager that he is no longer the job leader disconnectJobManager( @@ -603,7 +513,6 @@ private RegistrationResponse registerJobMasterInternal( JobManagerRegistration jobManagerRegistration = new JobManagerRegistration( jobId, jobManagerResourceId, - jobLeaderId, jobMasterGateway); jobManagerRegistrations.put(jobId, jobManagerRegistration); jmResourceIdRegistrations.put(jobManagerResourceId, jobManagerRegistration); @@ -613,13 +522,12 @@ private RegistrationResponse registerJobMasterInternal( JobManagerRegistration jobManagerRegistration = new JobManagerRegistration( jobId, jobManagerResourceId, - jobLeaderId, jobMasterGateway); jobManagerRegistrations.put(jobId, jobManagerRegistration); jmResourceIdRegistrations.put(jobManagerResourceId, jobManagerRegistration); } - log.info("Registered job manager {}@{} for job {}.", jobLeaderId, jobManagerAddress, jobId); + log.info("Registered job manager {}@{} for job {}.", jobMasterGateway.getFencingToken(), jobManagerAddress, jobId); jobManagerHeartbeatManager.monitorTarget(jobManagerResourceId, new HeartbeatTarget() { @Override @@ -635,7 +543,7 @@ public void requestHeartbeat(ResourceID resourceID, Void payload) { return new JobMasterRegistrationSuccess( resourceManagerConfiguration.getHeartbeatInterval().toMilliseconds(), - getLeaderSessionId(), + getFencingToken(), resourceId); } @@ -706,8 +614,6 @@ private void clearState() { } catch (Exception e) { onFatalError(new ResourceManagerException("Could not properly clear the job leader id service.", e)); } - - leaderSessionId = null; } /** @@ -723,10 +629,10 @@ protected void closeJobManagerConnection(JobID jobId, Exception cause) { if (jobManagerRegistration != null) { final ResourceID jobManagerResourceId = jobManagerRegistration.getJobManagerResourceID(); final JobMasterGateway jobMasterGateway = jobManagerRegistration.getJobManagerGateway(); - final UUID jobManagerLeaderId = jobManagerRegistration.getLeaderID(); + final JobMasterId jobMasterId = jobManagerRegistration.getJobMasterId(); log.info("Disconnect job manager {}@{} for job {} from the resource manager.", - jobManagerLeaderId, + jobMasterId, jobMasterGateway.getAddress(), jobId); @@ -735,7 +641,7 @@ protected void closeJobManagerConnection(JobID jobId, Exception cause) { jmResourceIdRegistrations.remove(jobManagerResourceId); // tell the job manager about the disconnect - jobMasterGateway.disconnectResourceManager(jobManagerLeaderId, getLeaderSessionId(), 
cause); + jobMasterGateway.disconnectResourceManager(getFencingToken(), cause); } else { log.debug("There was no registered job manager for job {}.", jobId); } @@ -765,17 +671,6 @@ protected void closeTaskManagerConnection(final ResourceID resourceID, final Exc } } - /** - * Checks whether the given resource manager leader id is matching the current leader id and - * not null. - * - * @param resourceManagerLeaderId to check - * @return True if the given leader id matches the actual leader id and is not null; otherwise false - */ - protected boolean isValid(UUID resourceManagerLeaderId) { - return Objects.equals(resourceManagerLeaderId, leaderSessionId); - } - protected void removeJob(JobID jobId) { try { jobLeaderIdService.removeJob(jobId); @@ -788,17 +683,17 @@ protected void removeJob(JobID jobId) { } } - protected void jobLeaderLostLeadership(JobID jobId, UUID oldJobLeaderId) { + protected void jobLeaderLostLeadership(JobID jobId, JobMasterId oldJobMasterId) { if (jobManagerRegistrations.containsKey(jobId)) { JobManagerRegistration jobManagerRegistration = jobManagerRegistrations.get(jobId); - if (Objects.equals(jobManagerRegistration.getLeaderID(), oldJobLeaderId)) { + if (Objects.equals(jobManagerRegistration.getJobMasterId(), oldJobMasterId)) { disconnectJobManager(jobId, new Exception("Job leader lost leadership.")); } else { log.debug("Discarding job leader lost leadership, because a new job leader was found for job {}. ", jobId); } } else { - log.debug("Discard job leader lost leadership for outdated leader {} for job {}.", oldJobLeaderId, jobId); + log.debug("Discard job leader lost leadership for outdated leader {} for job {}.", oldJobMasterId, jobId); } } @@ -825,28 +720,15 @@ public void run() { /** * Notifies the ResourceManager that a fatal error has occurred and it cannot proceed. - * This method should be used when asynchronous threads want to notify the - * ResourceManager of a fatal error. - * - * @param t The exception describing the fatal error - */ - protected void onFatalErrorAsync(final Throwable t) { - runAsync(new Runnable() { - @Override - public void run() { - onFatalError(t); - } - }); - } - - /** - * Notifies the ResourceManager that a fatal error has occurred and it cannot proceed. - * This method must only be called from within the ResourceManager's main thread. * * @param t The exception describing the fatal error */ protected void onFatalError(Throwable t) { - log.error("Fatal error occurred.", t); + try { + log.error("Fatal error occurred in ResourceManager.", t); + } catch (Throwable ignored) {} + + // The fatal error handler implementation should make sure that this call is non-blocking fatalErrorHandler.onFatalError(t); } @@ -855,35 +737,32 @@ protected void onFatalError(Throwable t) { // ------------------------------------------------------------------------ /** - * Callback method when current resourceManager is granted leadership + * Callback method when current resourceManager is granted leadership. 
* * @param newLeaderSessionID unique leadershipID */ @Override public void grantLeadership(final UUID newLeaderSessionID) { - runAsync(new Runnable() { - @Override - public void run() { - log.info("ResourceManager {} was granted leadership with leader session ID {}", getAddress(), newLeaderSessionID); + runAsyncWithoutFencing( + () -> { + final ResourceManagerId newResourceManagerId = new ResourceManagerId(newLeaderSessionID); + + log.info("ResourceManager {} was granted leadership with fencing token {}", getAddress(), newResourceManagerId); // clear the state if we've been the leader before - if (leaderSessionId != null) { + if (getFencingToken() != null) { clearState(); } - leaderSessionId = newLeaderSessionID; + setFencingToken(newResourceManagerId); - slotManager.start(leaderSessionId, getMainThreadExecutor(), new ResourceManagerActionsImpl()); + slotManager.start(getFencingToken(), getMainThreadExecutor(), new ResourceManagerActionsImpl()); - getRpcService().execute(new Runnable() { - @Override - public void run() { + getRpcService().execute( + () -> // confirming the leader session ID might be blocking, - leaderElectionService.confirmLeaderSessionID(newLeaderSessionID); - } - }); - } - }); + leaderElectionService.confirmLeaderSessionID(newLeaderSessionID)); + }); } /** @@ -891,28 +770,28 @@ public void run() { */ @Override public void revokeLeadership() { - runAsync(new Runnable() { - @Override - public void run() { - log.info("ResourceManager {} was revoked leadership.", getAddress()); + runAsyncWithoutFencing( + () -> { + final ResourceManagerId newResourceManagerId = ResourceManagerId.generate(); + + log.info("ResourceManager {} was revoked leadership. Setting fencing token to {}.", getAddress(), newResourceManagerId); clearState(); - slotManager.suspend(); + setFencingToken(newResourceManagerId); - leaderSessionId = null; - } - }); + slotManager.suspend(); + }); } /** - * Handles error occurring in the leader election service + * Handles error occurring in the leader election service. * * @param exception Exception being thrown in the leader election service */ @Override public void handleError(final Exception exception) { - onFatalErrorAsync(new ResourceManagerException("Received an error from the LeaderElectionService.", exception)); + onFatalError(new ResourceManagerException("Received an error from the LeaderElectionService.", exception)); } // ------------------------------------------------------------------------ @@ -930,7 +809,7 @@ public void handleError(final Exception exception) { * The framework specific code for shutting down the application. This should report the * application's final status and shut down the resource manager cleanly. * - * This method also needs to make sure all pending containers that are not registered + *
<p>
This method also needs to make sure all pending containers that are not registered * yet are returned. * * @param finalStatus The application status to report. @@ -941,12 +820,18 @@ public void handleError(final Exception exception) { /** * Allocates a resource using the resource profile. + * * @param resourceProfile The resource description */ @VisibleForTesting public abstract void startNewWorker(ResourceProfile resourceProfile); - public abstract void stopWorker(InstanceID instanceId); + /** + * Deallocates a resource. + * + * @param resourceID The resource ID + */ + public abstract void stopWorker(ResourceID resourceID); /** * Callback when a worker was started. @@ -962,12 +847,36 @@ private class ResourceManagerActionsImpl implements ResourceManagerActions { @Override public void releaseResource(InstanceID instanceId) { - stopWorker(instanceId); + runAsync(new Runnable() { + @Override + public void run() { + ResourceID resourceID = null; + + for (Map.Entry> entry : taskExecutors.entrySet()) { + if (entry.getValue().getInstanceID().equals(instanceId)) { + resourceID = entry.getKey(); + break; + } + } + + if (resourceID != null) { + stopWorker(resourceID); + } + else { + log.warn("Ignoring request to release TaskManager with instance ID {} (not found).", instanceId); + } + } + }); } @Override public void allocateResource(ResourceProfile resourceProfile) throws ResourceManagerException { - startNewWorker(resourceProfile); + runAsync(new Runnable() { + @Override + public void run() { + startNewWorker(resourceProfile); + } + }); } @Override @@ -979,11 +888,11 @@ public void notifyAllocationFailure(JobID jobId, AllocationID allocationId, Exce private class JobLeaderIdActionsImpl implements JobLeaderIdActions { @Override - public void jobLeaderLostLeadership(final JobID jobId, final UUID oldJobLeaderId) { + public void jobLeaderLostLeadership(final JobID jobId, final JobMasterId oldJobMasterId) { runAsync(new Runnable() { @Override public void run() { - ResourceManager.this.jobLeaderLostLeadership(jobId, oldJobLeaderId); + ResourceManager.this.jobLeaderLostLeadership(jobId, oldJobMasterId); } }); } @@ -1002,7 +911,7 @@ public void run() { @Override public void handleError(Throwable error) { - onFatalErrorAsync(error); + onFatalError(error); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerGateway.java index 1ba68932b45a5..a957716aeb66e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerGateway.java @@ -25,82 +25,75 @@ import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.clusterframework.types.SlotID; import org.apache.flink.runtime.instance.InstanceID; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.messages.Acknowledge; -import org.apache.flink.runtime.rpc.RpcGateway; +import org.apache.flink.runtime.rpc.FencedRpcGateway; import org.apache.flink.runtime.rpc.RpcTimeout; import org.apache.flink.runtime.jobmaster.JobMaster; import org.apache.flink.runtime.registration.RegistrationResponse; import org.apache.flink.runtime.taskexecutor.SlotReport; +import org.apache.flink.runtime.taskexecutor.TaskExecutor; -import java.util.UUID; import java.util.concurrent.CompletableFuture; /** * The {@link ResourceManager}'s RPC 
gateway interface. */ -public interface ResourceManagerGateway extends RpcGateway { +public interface ResourceManagerGateway extends FencedRpcGateway { /** * Register a {@link JobMaster} at the resource manager. * - * @param resourceManagerLeaderId The fencing token for the ResourceManager leader - * @param jobMasterLeaderId The fencing token for the JobMaster leader - * @param jobMasterResourceId The resource ID of the JobMaster that registers - * @param jobMasterAddress The address of the JobMaster that registers - * @param jobID The Job ID of the JobMaster that registers - * @param timeout Timeout for the future to complete + * @param jobMasterId The fencing token for the JobMaster leader + * @param jobMasterResourceId The resource ID of the JobMaster that registers + * @param jobMasterAddress The address of the JobMaster that registers + * @param jobId The Job ID of the JobMaster that registers + * @param timeout Timeout for the future to complete * @return Future registration response */ CompletableFuture registerJobManager( - UUID resourceManagerLeaderId, - UUID jobMasterLeaderId, + JobMasterId jobMasterId, ResourceID jobMasterResourceId, String jobMasterAddress, - JobID jobID, + JobID jobId, @RpcTimeout Time timeout); /** * Requests a slot from the resource manager. * - * @param resourceManagerLeaderID leader if of the ResourceMaster - * @param jobMasterLeaderID leader if of the JobMaster + * @param jobMasterId id of the JobMaster * @param slotRequest The slot to request * @return The confirmation that the slot gets allocated */ CompletableFuture requestSlot( - UUID resourceManagerLeaderID, - UUID jobMasterLeaderID, + JobMasterId jobMasterId, SlotRequest slotRequest, @RpcTimeout Time timeout); /** - * Register a {@link org.apache.flink.runtime.taskexecutor.TaskExecutor} at the resource manager. + * Register a {@link TaskExecutor} at the resource manager. * - * @param resourceManagerLeaderId The fencing token for the ResourceManager leader - * @param taskExecutorAddress The address of the TaskExecutor that registers - * @param resourceID The resource ID of the TaskExecutor that registers - * @param slotReport The slot report containing free and allocated task slots - * @param timeout The timeout for the response. + * @param taskExecutorAddress The address of the TaskExecutor that registers + * @param resourceId The resource ID of the TaskExecutor that registers + * @param slotReport The slot report containing free and allocated task slots + * @param timeout The timeout for the response. * * @return The future to the response by the ResourceManager. */ CompletableFuture registerTaskExecutor( - UUID resourceManagerLeaderId, String taskExecutorAddress, - ResourceID resourceID, + ResourceID resourceId, SlotReport slotReport, @RpcTimeout Time timeout); /** * Sent by the TaskExecutor to notify the ResourceManager that a slot has become available. * - * @param resourceManagerLeaderId The ResourceManager leader id * @param instanceId TaskExecutor's instance id * @param slotID The SlotID of the freed slot * @param oldAllocationId to which the slot has been allocated */ void notifySlotAvailable( - UUID resourceManagerLeaderId, InstanceID instanceId, SlotID slotID, AllocationID oldAllocationId); @@ -130,10 +123,9 @@ void notifySlotAvailable( /** * Gets the currently registered number of TaskManagers. * - * @param leaderSessionId The leader session ID with which to address the ResourceManager. * @return The future to the number of registered TaskManagers. 
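As a small illustration of the fenced gateway, querying the number of registered TaskManagers no longer requires threading a leader session ID through the call; the gateway reference is assumed to exist in the caller's context.

import java.util.concurrent.CompletableFuture;

import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway;

public class TaskManagerCountExample {

    // Sketch only: the fencing token travels implicitly with the FencedRpcGateway,
    // so no leader session ID parameter is needed any more.
    static CompletableFuture<Integer> queryTaskManagerCount(ResourceManagerGateway gateway) {
        return gateway.getNumberOfRegisteredTaskManagers();
    }
}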
*/ - CompletableFuture getNumberOfRegisteredTaskManagers(UUID leaderSessionId); + CompletableFuture getNumberOfRegisteredTaskManagers(); /** * Sends the heartbeat to resource manager from task manager diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerId.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerId.java new file mode 100644 index 0000000000000..3594e88e12249 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerId.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.resourcemanager; + +import org.apache.flink.util.AbstractID; + +import java.util.UUID; + +/** + * Fencing token for the {@link ResourceManager}. + */ +public class ResourceManagerId extends AbstractID { + + private static final long serialVersionUID = -6042820142662137374L; + + public ResourceManagerId(byte[] bytes) { + super(bytes); + } + + public ResourceManagerId(long lowerPart, long upperPart) { + super(lowerPart, upperPart); + } + + public ResourceManagerId(AbstractID id) { + super(id); + } + + public ResourceManagerId() { + } + + public ResourceManagerId(UUID uuid) { + this(uuid.getLeastSignificantBits(), uuid.getMostSignificantBits()); + } + + public UUID toUUID() { + return new UUID(getUpperPart(), getLowerPart()); + } + + public static ResourceManagerId generate() { + return new ResourceManagerId(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerRunner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerRunner.java index d0c411ceea6bb..ed6e18c60adff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerRunner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/ResourceManagerRunner.java @@ -20,16 +20,18 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.concurrent.FlinkFutureException; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.rpc.FatalErrorHandler; import org.apache.flink.runtime.rpc.RpcService; -import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.CompletableFuture; + /** * Simple {@link StandaloneResourceManager} runner. 
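To make the fencing-token mechanics concrete, here is a short sketch of how the ResourceManagerId introduced above relates to the leader session UUIDs handed out by the leader election service; the values are illustrative.

import java.util.UUID;

import org.apache.flink.runtime.resourcemanager.ResourceManagerId;

public class ResourceManagerIdExample {

    public static void main(String[] args) {
        // grantLeadership wraps the granted leader session ID into a fencing token
        UUID leaderSessionId = UUID.randomUUID();
        ResourceManagerId fencingToken = new ResourceManagerId(leaderSessionId);

        // the token can be mapped back, e.g. when confirming the leader session ID
        UUID roundTrip = fencingToken.toUUID();
        System.out.println(leaderSessionId.equals(roundTrip)); // true

        // revokeLeadership installs a freshly generated token so stale RPCs are fenced off
        ResourceManagerId newToken = ResourceManagerId.generate();
        System.out.println(newToken.equals(fencingToken)); // false (with overwhelming probability)
    }
}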
It instantiates the resource manager's services * and handles fatal errors by shutting the resource manager down. @@ -91,27 +93,23 @@ public void start() throws Exception { } public void shutDown() throws Exception { - shutDownInternally(); + // wait for the completion + shutDownInternally().get(); } - private void shutDownInternally() throws Exception { - Exception exception = null; + private CompletableFuture shutDownInternally() { synchronized (lock) { - try { - resourceManager.shutDown(); - } catch (Exception e) { - exception = ExceptionUtils.firstOrSuppressed(e, exception); - } - - try { - resourceManagerRuntimeServices.shutDown(); - } catch (Exception e) { - exception = ExceptionUtils.firstOrSuppressed(e, exception); - } - - if (exception != null) { - ExceptionUtils.rethrow(exception, "Error while shutting down the resource manager runner."); - } + resourceManager.shutDown(); + + return resourceManager.getTerminationFuture() + .thenAccept( + ignored -> { + try { + resourceManagerRuntimeServices.shutDown(); + } catch (Exception e) { + throw new FlinkFutureException("Could not properly shut down the resource manager runtime services.", e); + } + }); } } @@ -123,10 +121,13 @@ private void shutDownInternally() throws Exception { public void onFatalError(Throwable exception) { LOG.error("Encountered fatal error.", exception); - try { - shutDownInternally(); - } catch (Exception e) { - LOG.error("Could not properly shut down the resource manager.", e); - } + CompletableFuture shutdownFuture = shutDownInternally(); + + shutdownFuture.whenComplete( + (Void ignored, Throwable throwable) -> { + if (throwable != null) { + LOG.error("Could not properly shut down the resource manager runner.", throwable); + } + }); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/StandaloneResourceManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/StandaloneResourceManager.java index a921a29183ba8..ac2c7453070e9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/StandaloneResourceManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/StandaloneResourceManager.java @@ -23,7 +23,6 @@ import org.apache.flink.runtime.clusterframework.types.ResourceProfile; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; -import org.apache.flink.runtime.instance.InstanceID; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.resourcemanager.exceptions.ResourceManagerException; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; @@ -76,7 +75,7 @@ public void startNewWorker(ResourceProfile resourceProfile) { } @Override - public void stopWorker(InstanceID instanceId) { + public void stopWorker(ResourceID resourceID) { } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/registration/JobManagerRegistration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/registration/JobManagerRegistration.java index df3a39fad2856..dca2db674edfe 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/registration/JobManagerRegistration.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/registration/JobManagerRegistration.java @@ -21,10 +21,9 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.clusterframework.types.ResourceID; import 
org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.util.Preconditions; -import java.util.UUID; - /** * Container for JobManager related registration information, such as the leader id or the job id. */ @@ -33,18 +32,14 @@ public class JobManagerRegistration { private final ResourceID jobManagerResourceID; - private final UUID leaderID; - private final JobMasterGateway jobManagerGateway; public JobManagerRegistration( JobID jobID, ResourceID jobManagerResourceID, - UUID leaderID, JobMasterGateway jobManagerGateway) { this.jobID = Preconditions.checkNotNull(jobID); this.jobManagerResourceID = Preconditions.checkNotNull(jobManagerResourceID); - this.leaderID = Preconditions.checkNotNull(leaderID); this.jobManagerGateway = Preconditions.checkNotNull(jobManagerGateway); } @@ -56,8 +51,8 @@ public ResourceID getJobManagerResourceID() { return jobManagerResourceID; } - public UUID getLeaderID() { - return leaderID; + public JobMasterId getJobMasterId() { + return jobManagerGateway.getFencingToken(); } public JobMasterGateway getJobManagerGateway() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotManager.java index 3bda409f1a40a..d8eb47c5ce50b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotManager.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.concurrent.ScheduledExecutor; import org.apache.flink.runtime.instance.InstanceID; import org.apache.flink.runtime.messages.Acknowledge; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.resourcemanager.SlotRequest; import org.apache.flink.runtime.resourcemanager.exceptions.ResourceManagerException; import org.apache.flink.runtime.resourcemanager.registration.TaskExecutorConnection; @@ -35,6 +36,7 @@ import org.apache.flink.runtime.taskexecutor.TaskExecutorGateway; import org.apache.flink.runtime.taskexecutor.exceptions.SlotAllocationException; import org.apache.flink.runtime.taskexecutor.exceptions.SlotOccupiedException; +import org.apache.flink.util.AbstractID; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,7 +47,6 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; -import java.util.UUID; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -64,7 +65,7 @@ * slots are currently not used) and pending slot requests time out triggering their release and * failure, respectively. 
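A brief sketch of the reworked JobManagerRegistration shown above: the stored leader UUID is gone, and the JobMasterId is read straight from the gateway's fencing token. The JobMasterGateway instance is assumed to come from an already established RPC connection.

import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.clusterframework.types.ResourceID;
import org.apache.flink.runtime.jobmaster.JobMasterGateway;
import org.apache.flink.runtime.jobmaster.JobMasterId;
import org.apache.flink.runtime.resourcemanager.registration.JobManagerRegistration;

public class JobManagerRegistrationExample {

    // Sketch only: the registration delegates its id to the gateway's fencing token.
    static JobMasterId trackJobMaster(JobID jobId, ResourceID jobManagerResourceId, JobMasterGateway gateway) {
        JobManagerRegistration registration = new JobManagerRegistration(jobId, jobManagerResourceId, gateway);
        return registration.getJobMasterId(); // equals gateway.getFencingToken()
    }
}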
*/ -public class SlotManager implements AutoCloseable { +public class SlotManager implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(SlotManager.class); /** Scheduled executor for timeouts */ @@ -94,8 +95,8 @@ public class SlotManager implements AutoCloseable { /** Map of pending/unfulfilled slot allocation requests */ private final HashMap pendingSlotRequests; - /** Leader id of the containing component */ - private UUID leaderId; + /** ResourceManager's id */ + private ResourceManagerId resourceManagerId; /** Executor for future callbacks which have to be "synchronized" */ private Executor mainThreadExecutor; @@ -126,7 +127,7 @@ public SlotManager( fulfilledSlotRequests = new HashMap<>(16); pendingSlotRequests = new HashMap<>(16); - leaderId = null; + resourceManagerId = null; resourceManagerActions = null; mainThreadExecutor = null; taskManagerTimeoutCheck = null; @@ -142,13 +143,14 @@ public SlotManager( /** * Starts the slot manager with the given leader id and resource manager actions. * - * @param newLeaderId to use for communication with the task managers + * @param newResourceManagerId to use for communication with the task managers + * @param newMainThreadExecutor to use to run code in the ResourceManager's main thread * @param newResourceManagerActions to use for resource (de-)allocations */ - public void start(UUID newLeaderId, Executor newMainThreadExecutor, ResourceManagerActions newResourceManagerActions) { + public void start(ResourceManagerId newResourceManagerId, Executor newMainThreadExecutor, ResourceManagerActions newResourceManagerActions) { LOG.info("Starting the SlotManager."); - leaderId = Preconditions.checkNotNull(newLeaderId); + this.resourceManagerId = Preconditions.checkNotNull(newResourceManagerId); mainThreadExecutor = Preconditions.checkNotNull(newMainThreadExecutor); resourceManagerActions = Preconditions.checkNotNull(newResourceManagerActions); @@ -204,7 +206,7 @@ public void suspend() { unregisterTaskManager(registeredTaskManager); } - leaderId = null; + resourceManagerId = null; resourceManagerActions = null; started = false; } @@ -643,7 +645,7 @@ private void allocateSlot(TaskManagerSlot taskManagerSlot, PendingSlotRequest pe pendingSlotRequest.getJobId(), allocationId, pendingSlotRequest.getTargetAddress(), - leaderId, + resourceManagerId, taskManagerRequestTimeout); requestFuture.whenComplete( @@ -836,10 +838,13 @@ private void checkTaskManagerTimeouts() { while (taskManagerRegistrationIterator.hasNext()) { TaskManagerRegistration taskManagerRegistration = taskManagerRegistrationIterator.next().getValue(); + LOG.debug("Evaluating TaskManager {} for idleness.", taskManagerRegistration.getInstanceId()); if (anySlotUsed(taskManagerRegistration.getSlots())) { taskManagerRegistration.markUsed(); } else if (currentTime - taskManagerRegistration.getIdleSince() >= taskManagerTimeout.toMilliseconds()) { + LOG.info("Removing idle TaskManager {} from the SlotManager.", taskManagerRegistration.getInstanceId()); + taskManagerRegistrationIterator.remove(); internalUnregisterTaskManager(taskManagerRegistration); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/slotmanager/TaskManagerRegistration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/slotmanager/TaskManagerRegistration.java index 7d3764c792c7d..f19f9bf15efff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/slotmanager/TaskManagerRegistration.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/resourcemanager/slotmanager/TaskManagerRegistration.java @@ -68,7 +68,9 @@ public boolean isIdle() { } public void markIdle() { - idleSince = System.currentTimeMillis(); + if (!isIdle()) { + idleSince = System.currentTimeMillis(); + } } public void markUsed() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/HttpMethodWrapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/HttpMethodWrapper.java new file mode 100644 index 0000000000000..8987d7560879c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/HttpMethodWrapper.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest; + +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpMethod; + +/** + * This class wraps netty's {@link HttpMethod}s into an enum, allowing us to use them in switches. + */ +public enum HttpMethodWrapper { + GET(HttpMethod.GET), + POST(HttpMethod.POST); + + private HttpMethod nettyHttpMethod; + + HttpMethodWrapper(HttpMethod nettyHttpMethod) { + this.nettyHttpMethod = nettyHttpMethod; + } + + public HttpMethod getNettyHttpMethod() { + return nettyHttpMethod; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java new file mode 100644 index 0000000000000..ea266be789508 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rest; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.runtime.rest.handler.PipelineErrorHandler; +import org.apache.flink.runtime.rest.messages.ErrorResponseBody; +import org.apache.flink.runtime.rest.messages.MessageHeaders; +import org.apache.flink.runtime.rest.messages.MessageParameters; +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.rest.util.RestClientException; +import org.apache.flink.runtime.rest.util.RestMapperUtils; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.netty4.io.netty.bootstrap.Bootstrap; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufInputStream; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.SimpleChannelInboundHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.DefaultFullHttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpResponse; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpClientCodec; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpHeaders; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpObjectAggregator; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponse; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpVersion; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLEngine; + +import java.io.IOException; +import java.io.InputStream; +import java.io.StringWriter; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; + +/** + * This client is the counter-part to the {@link RestServerEndpoint}. 
+ */ +public class RestClient { + private static final Logger LOG = LoggerFactory.getLogger(RestClient.class); + + private static final ObjectMapper objectMapper = RestMapperUtils.getStrictObjectMapper(); + + // used to open connections to a rest server endpoint + private final Executor executor; + + private Bootstrap bootstrap; + + public RestClient(RestClientConfiguration configuration, Executor executor) { + Preconditions.checkNotNull(configuration); + this.executor = Preconditions.checkNotNull(executor); + + SSLEngine sslEngine = configuration.getSslEngine(); + ChannelInitializer initializer = new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel socketChannel) throws Exception { + // SSL should be the first handler in the pipeline + if (sslEngine != null) { + socketChannel.pipeline().addLast("ssl", new SslHandler(sslEngine)); + } + + socketChannel.pipeline() + .addLast(new HttpClientCodec()) + .addLast(new HttpObjectAggregator(1024 * 1024)) + .addLast(new ClientHandler()) + .addLast(new PipelineErrorHandler(LOG)); + } + }; + NioEventLoopGroup group = new NioEventLoopGroup(1); + + bootstrap = new Bootstrap(); + bootstrap + .group(group) + .channel(NioSocketChannel.class) + .handler(initializer); + + LOG.info("Rest client endpoint started."); + } + + public void shutdown(Time timeout) { + LOG.info("Shutting down rest endpoint."); + CompletableFuture groupFuture = new CompletableFuture<>(); + if (bootstrap != null) { + if (bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(0, timeout.toMilliseconds(), TimeUnit.MILLISECONDS) + .addListener(ignored -> groupFuture.complete(null)); + } + } + + try { + groupFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + LOG.info("Rest endpoint shutdown complete."); + } catch (Exception e) { + LOG.warn("Rest endpoint shutdown failed.", e); + } + } + + public , U extends MessageParameters, R extends RequestBody, P extends ResponseBody> CompletableFuture
<P>
sendRequest(String targetAddress, int targetPort, M messageHeaders, U messageParameters, R request) throws IOException { + Preconditions.checkNotNull(targetAddress); + Preconditions.checkArgument(0 <= targetPort && targetPort < 65536, "The target port " + targetPort + " is not in the range (0, 65536]."); + Preconditions.checkNotNull(messageHeaders); + Preconditions.checkNotNull(request); + Preconditions.checkNotNull(messageParameters); + Preconditions.checkState(messageParameters.isResolved(), "Message parameters were not resolved."); + + String targetUrl = MessageParameters.resolveUrl(messageHeaders.getTargetRestEndpointURL(), messageParameters); + + LOG.debug("Sending request of class {} to {}", request.getClass(), targetUrl); + // serialize payload + StringWriter sw = new StringWriter(); + objectMapper.writeValue(sw, request); + ByteBuf payload = Unpooled.wrappedBuffer(sw.toString().getBytes(ConfigConstants.DEFAULT_CHARSET)); + + // create request and set headers + FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, messageHeaders.getHttpMethod().getNettyHttpMethod(), targetUrl, payload); + httpRequest.headers() + .add(HttpHeaders.Names.CONTENT_LENGTH, payload.capacity()) + .add(HttpHeaders.Names.CONTENT_TYPE, "application/json; charset=" + ConfigConstants.DEFAULT_CHARSET.name()) + .set(HttpHeaders.Names.HOST, targetAddress + ':' + targetPort) + .set(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.CLOSE); + + return submitRequest(targetAddress, targetPort, httpRequest, messageHeaders.getResponseClass()); + } + + private
<P extends ResponseBody>
CompletableFuture
<P>
submitRequest(String targetAddress, int targetPort, FullHttpRequest httpRequest, Class
<P>
responseClass) { + return CompletableFuture.supplyAsync(() -> bootstrap.connect(targetAddress, targetPort), executor) + .thenApply((channel) -> { + try { + return channel.sync(); + } catch (InterruptedException e) { + throw new FlinkRuntimeException(e); + } + }) + .thenApply((ChannelFuture::channel)) + .thenCompose(channel -> { + ClientHandler handler = channel.pipeline().get(ClientHandler.class); + CompletableFuture future = handler.getJsonFuture(); + channel.writeAndFlush(httpRequest); + return future; + }).thenComposeAsync( + (JsonResponse rawResponse) -> parseResponse(rawResponse, responseClass), + executor + ); + } + + private static
<P extends ResponseBody>
CompletableFuture
<P>
parseResponse(JsonResponse rawResponse, Class
<P>
responseClass) { + CompletableFuture
<P>
responseFuture = new CompletableFuture<>(); + try { + P response = objectMapper.treeToValue(rawResponse.getJson(), responseClass); + responseFuture.complete(response); + } catch (JsonProcessingException jpe) { + // the received response did not matched the expected response type + + // lets see if it is an ErrorResponse instead + try { + ErrorResponseBody error = objectMapper.treeToValue(rawResponse.getJson(), ErrorResponseBody.class); + responseFuture.completeExceptionally(new RestClientException(error.errors.toString(), rawResponse.getHttpResponseStatus())); + } catch (JsonProcessingException jpe2) { + // if this fails it is either the expected type or response type was wrong, most likely caused + // by a client/search MessageHeaders mismatch + LOG.error("Received response was neither of the expected type ({}) nor an error. Response={}", responseClass, rawResponse, jpe2); + responseFuture.completeExceptionally( + new RestClientException( + "Response was neither of the expected type(" + responseClass + ") nor an error.", + jpe2, + rawResponse.getHttpResponseStatus())); + } + } + return responseFuture; + } + + private static class ClientHandler extends SimpleChannelInboundHandler { + + private final CompletableFuture jsonFuture = new CompletableFuture<>(); + + CompletableFuture getJsonFuture() { + return jsonFuture; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof FullHttpResponse) { + readRawResponse((FullHttpResponse) msg); + } else { + LOG.error("Implementation error: Received a response that wasn't a FullHttpResponse."); + if (msg instanceof HttpResponse) { + jsonFuture.completeExceptionally( + new RestClientException( + "Implementation error: Received a response that wasn't a FullHttpResponse.", + ((HttpResponse) msg).getStatus())); + } else { + jsonFuture.completeExceptionally( + new RestClientException( + "Implementation error: Received a response that wasn't a FullHttpResponse.", + HttpResponseStatus.INTERNAL_SERVER_ERROR)); + } + + } + ctx.close(); + } + + private void readRawResponse(FullHttpResponse msg) { + ByteBuf content = msg.content(); + + JsonNode rawResponse; + try { + InputStream in = new ByteBufInputStream(content); + rawResponse = objectMapper.readTree(in); + LOG.debug("Received response {}.", rawResponse); + } catch (JsonParseException je) { + LOG.error("Response was not valid JSON.", je); + jsonFuture.completeExceptionally(new RestClientException("Response was not valid JSON.", je, msg.getStatus())); + return; + } catch (IOException ioe) { + LOG.error("Response could not be read.", ioe); + jsonFuture.completeExceptionally(new RestClientException("Response could not be read.", ioe, msg.getStatus())); + return; + } + jsonFuture.complete(new JsonResponse(rawResponse, msg.getStatus())); + } + } + + private static final class JsonResponse { + private final JsonNode json; + private final HttpResponseStatus httpResponseStatus; + + private JsonResponse(JsonNode json, HttpResponseStatus httpResponseStatus) { + this.json = Preconditions.checkNotNull(json); + this.httpResponseStatus = Preconditions.checkNotNull(httpResponseStatus); + } + + public JsonNode getJson() { + return json; + } + + public HttpResponseStatus getHttpResponseStatus() { + return httpResponseStatus; + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClientConfiguration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClientConfiguration.java new file mode 100644 index 
0000000000000..7bf030718261c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClientConfiguration.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.runtime.net.SSLUtils; +import org.apache.flink.util.ConfigurationException; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; + +/** + * A configuration object for {@link RestClient}s. + */ +public final class RestClientConfiguration { + + @Nullable + private final SSLEngine sslEngine; + + private RestClientConfiguration(@Nullable SSLEngine sslEngine) { + this.sslEngine = sslEngine; + } + + /** + * Returns the {@link SSLEngine} that the REST client endpoint should use. + * + * @return SSLEngine that the REST client endpoint should use, or null if SSL was disabled + */ + + public SSLEngine getSslEngine() { + return sslEngine; + } + + /** + * Creates and returns a new {@link RestClientConfiguration} from the given {@link Configuration}. + * + * @param config configuration from which the REST client endpoint configuration should be created from + * @return REST client endpoint configuration + * @throws ConfigurationException if SSL was configured incorrectly + */ + + public static RestClientConfiguration fromConfiguration(Configuration config) throws ConfigurationException { + Preconditions.checkNotNull(config); + + SSLEngine sslEngine = null; + boolean enableSSL = config.getBoolean(SecurityOptions.SSL_ENABLED); + if (enableSSL) { + try { + SSLContext sslContext = SSLUtils.createSSLServerContext(config); + if (sslContext != null) { + sslEngine = sslContext.createSSLEngine(); + SSLUtils.setSSLVerAndCipherSuites(sslEngine, config); + sslEngine.setUseClientMode(false); + } + } catch (Exception e) { + throw new ConfigurationException("Failed to initialize SSLContext for the web frontend", e); + } + } + + return new RestClientConfiguration(sslEngine); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java new file mode 100644 index 0000000000000..4a3ba89123afc --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
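Putting the RestClient and RestClientConfiguration pieces above together, a hedged end-to-end sketch of constructing and shutting down a client follows. The executor choice and timeout are illustrative, and an actual sendRequest call would additionally need concrete MessageHeaders, MessageParameters and RequestBody implementations that are not part of this excerpt.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.flink.api.common.time.Time;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.rest.RestClient;
import org.apache.flink.runtime.rest.RestClientConfiguration;
import org.apache.flink.util.ConfigurationException;

public class RestClientExample {

    public static void main(String[] args) throws ConfigurationException {
        // build the client configuration from a (possibly SSL-enabled) Flink Configuration
        Configuration config = new Configuration();
        RestClientConfiguration clientConfig = RestClientConfiguration.fromConfiguration(config);

        // the executor is used for connection setup and response parsing callbacks
        ExecutorService executor = Executors.newSingleThreadExecutor();
        RestClient client = new RestClient(clientConfig, executor);

        // client.sendRequest(host, port, messageHeaders, messageParameters, requestBody)
        // would go here, using concrete message classes for a particular handler.

        client.shutdown(Time.seconds(10)); // illustrative timeout
        executor.shutdown();
    }
}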
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.runtime.rest.handler.AbstractRestHandler; +import org.apache.flink.runtime.rest.handler.PipelineErrorHandler; +import org.apache.flink.runtime.rest.handler.RouterHandler; +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.netty4.io.netty.bootstrap.ServerBootstrap; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioServerSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpObjectAggregator; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpServerCodec; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Handler; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Router; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLEngine; + +import java.net.InetSocketAddress; +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +/** + * An abstract class for netty-based REST server endpoints. + */ +public abstract class RestServerEndpoint { + protected final Logger log = LoggerFactory.getLogger(getClass()); + + private final String configuredAddress; + private final int configuredPort; + private final SSLEngine sslEngine; + + private ServerBootstrap bootstrap; + private Channel serverChannel; + + public RestServerEndpoint(RestServerEndpointConfiguration configuration) { + Preconditions.checkNotNull(configuration); + this.configuredAddress = configuration.getEndpointBindAddress(); + this.configuredPort = configuration.getEndpointBindPort(); + this.sslEngine = configuration.getSslEngine(); + } + + /** + * This method is called at the beginning of {@link #start()} to setup all handlers that the REST server endpoint + * implementation requires. + */ + protected abstract Collection> initializeHandlers(); + + /** + * Starts this REST server endpoint. 
+ */ + public void start() { + log.info("Starting rest endpoint."); + + final Router router = new Router(); + + initializeHandlers().forEach(handler -> registerHandler(router, handler)); + + ChannelInitializer initializer = new ChannelInitializer() { + + @Override + protected void initChannel(SocketChannel ch) { + Handler handler = new RouterHandler(router); + + // SSL should be the first handler in the pipeline + if (sslEngine != null) { + ch.pipeline().addLast("ssl", new SslHandler(sslEngine)); + } + + ch.pipeline() + .addLast(new HttpServerCodec()) + .addLast(new HttpObjectAggregator(1024 * 1024 * 10)) + .addLast(handler.name(), handler) + .addLast(new PipelineErrorHandler(log)); + } + }; + + NioEventLoopGroup bossGroup = new NioEventLoopGroup(1); + NioEventLoopGroup workerGroup = new NioEventLoopGroup(); + + bootstrap = new ServerBootstrap(); + bootstrap + .group(bossGroup, workerGroup) + .channel(NioServerSocketChannel.class) + .childHandler(initializer); + + final ChannelFuture channel; + if (configuredAddress == null) { + channel = bootstrap.bind(configuredPort); + } else { + channel = bootstrap.bind(configuredAddress, configuredPort); + } + serverChannel = channel.syncUninterruptibly().channel(); + + InetSocketAddress bindAddress = (InetSocketAddress) serverChannel.localAddress(); + String address = bindAddress.getAddress().getHostAddress(); + int port = bindAddress.getPort(); + + log.info("Rest endpoint listening at {}" + ':' + "{}", address, port); + } + + /** + * Returns the address on which this endpoint is accepting requests. + * + * @return address on which this endpoint is accepting requests + */ + public InetSocketAddress getServerAddress() { + Channel server = this.serverChannel; + if (server != null) { + try { + return ((InetSocketAddress) server.localAddress()); + } catch (Exception e) { + log.error("Cannot access local server address", e); + } + } + + return null; + } + + /** + * Stops this REST server endpoint. 
+ */ + public void shutdown(Time timeout) { + log.info("Shutting down rest endpoint."); + + CompletableFuture channelFuture = new CompletableFuture<>(); + if (this.serverChannel != null) { + this.serverChannel.close().addListener(ignored -> channelFuture.complete(null)); + } + CompletableFuture groupFuture = new CompletableFuture<>(); + CompletableFuture childGroupFuture = new CompletableFuture<>(); + + channelFuture.thenRun(() -> { + if (bootstrap != null) { + if (bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(0, timeout.toMilliseconds(), TimeUnit.MILLISECONDS) + .addListener(ignored -> groupFuture.complete(null)); + } + if (bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully(0, timeout.toMilliseconds(), TimeUnit.MILLISECONDS) + .addListener(ignored -> childGroupFuture.complete(null)); + } + } else { + // complete the group futures since there is nothing to stop + groupFuture.complete(null); + childGroupFuture.complete(null); + } + }); + + try { + CompletableFuture.allOf(groupFuture, childGroupFuture).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + log.info("Rest endpoint shutdown complete."); + } catch (Exception e) { + log.warn("Rest endpoint shutdown failed.", e); + } + } + + private static void registerHandler(Router router, AbstractRestHandler handler) { + switch (handler.getMessageHeaders().getHttpMethod()) { + case GET: + router.GET(handler.getMessageHeaders().getTargetRestEndpointURL(), handler); + break; + case POST: + router.POST(handler.getMessageHeaders().getTargetRestEndpointURL(), handler); + break; + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpointConfiguration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpointConfiguration.java new file mode 100644 index 0000000000000..f342a0160b87c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpointConfiguration.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.runtime.net.SSLUtils; +import org.apache.flink.util.ConfigurationException; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; + +/** + * A configuration object for {@link RestServerEndpoint}s. 
+ */ +public final class RestServerEndpointConfiguration { + + @Nullable + private final String restBindAddress; + private final int restBindPort; + @Nullable + private final SSLEngine sslEngine; + + private RestServerEndpointConfiguration(@Nullable String restBindAddress, int restBindPort, @Nullable SSLEngine sslEngine) { + this.restBindAddress = restBindAddress; + + Preconditions.checkArgument(0 <= restBindPort && restBindPort < 65536, "The bing rest port " + restBindPort + " is out of range (0, 65536["); + this.restBindPort = restBindPort; + this.sslEngine = sslEngine; + } + + /** + * Returns the address that the REST server endpoint should bind itself to. + * + * @return address that the REST server endpoint should bind itself to + */ + public String getEndpointBindAddress() { + return restBindAddress; + } + + /** + * Returns the port that the REST server endpoint should listen on. + * + * @return port that the REST server endpoint should listen on + */ + public int getEndpointBindPort() { + return restBindPort; + } + + /** + * Returns the {@link SSLEngine} that the REST server endpoint should use. + * + * @return SSLEngine that the REST server endpoint should use, or null if SSL was disabled + */ + public SSLEngine getSslEngine() { + return sslEngine; + } + + /** + * Creates and returns a new {@link RestServerEndpointConfiguration} from the given {@link Configuration}. + * + * @param config configuration from which the REST server endpoint configuration should be created from + * @return REST server endpoint configuration + * @throws ConfigurationException if SSL was configured incorrectly + */ + public static RestServerEndpointConfiguration fromConfiguration(Configuration config) throws ConfigurationException { + Preconditions.checkNotNull(config); + String address = config.getString(RestOptions.REST_ADDRESS); + + int port = config.getInteger(RestOptions.REST_PORT); + + SSLEngine sslEngine = null; + boolean enableSSL = config.getBoolean(SecurityOptions.SSL_ENABLED); + if (enableSSL) { + try { + SSLContext sslContext = SSLUtils.createSSLServerContext(config); + if (sslContext != null) { + sslEngine = sslContext.createSSLEngine(); + SSLUtils.setSSLVerAndCipherSuites(sslEngine, config); + sslEngine.setUseClientMode(false); + } + } catch (Exception e) { + throw new ConfigurationException("Failed to initialize SSLContext for REST server endpoint.", e); + } + } + + return new RestServerEndpointConfiguration(address, port, sslEngine); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractRestHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractRestHandler.java new file mode 100644 index 0000000000000..2f2f9aa034e16 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/AbstractRestHandler.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.runtime.concurrent.FutureUtils; +import org.apache.flink.runtime.rest.messages.ErrorResponseBody; +import org.apache.flink.runtime.rest.messages.MessageHeaders; +import org.apache.flink.runtime.rest.messages.MessageParameters; +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.rest.util.RestMapperUtils; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufInputStream; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.SimpleChannelInboundHandler; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.DefaultHttpResponse; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.FullHttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpHeaders; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponse; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.LastHttpContent; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Routed; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.io.StringWriter; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpHeaders.Names.CONNECTION; +import static org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE; +import static org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpVersion.HTTP_1_1; + +/** + * Super class for netty-based handlers that work with {@link RequestBody}s and {@link ResponseBody}s. + * + *

<p>Subclasses must be thread-safe. + * + * @param <R> type of incoming requests + * @param <P>
type of outgoing responses + */ +@ChannelHandler.Sharable +public abstract class AbstractRestHandler extends SimpleChannelInboundHandler { + protected final Logger log = LoggerFactory.getLogger(getClass()); + + private static final ObjectMapper mapper = RestMapperUtils.getStrictObjectMapper(); + + private final MessageHeaders messageHeaders; + + protected AbstractRestHandler(MessageHeaders messageHeaders) { + this.messageHeaders = messageHeaders; + } + + public MessageHeaders getMessageHeaders() { + return messageHeaders; + } + + @Override + protected void channelRead0(final ChannelHandlerContext ctx, Routed routed) throws Exception { + if (log.isDebugEnabled()) { + log.debug("Received request " + routed.request().getUri() + '.'); + } + + final HttpRequest httpRequest = routed.request(); + + try { + if (!(httpRequest instanceof FullHttpRequest)) { + // The RestServerEndpoint defines a HttpObjectAggregator in the pipeline that always returns + // FullHttpRequests. + log.error("Implementation error: Received a request that wasn't a FullHttpRequest."); + sendErrorResponse(new ErrorResponseBody("Bad request received."), HttpResponseStatus.BAD_REQUEST, ctx, httpRequest); + return; + } + + ByteBuf msgContent = ((FullHttpRequest) httpRequest).content(); + + R request; + if (msgContent.capacity() == 0) { + try { + request = mapper.readValue("{}", messageHeaders.getRequestClass()); + } catch (JsonParseException | JsonMappingException je) { + log.error("Implementation error: Get request bodies must have a no-argument constructor.", je); + sendErrorResponse(new ErrorResponseBody("Internal server error."), HttpResponseStatus.INTERNAL_SERVER_ERROR, ctx, httpRequest); + return; + } + } else { + try { + ByteBufInputStream in = new ByteBufInputStream(msgContent); + request = mapper.readValue(in, messageHeaders.getRequestClass()); + } catch (JsonParseException | JsonMappingException je) { + log.error("Failed to read request.", je); + sendErrorResponse(new ErrorResponseBody(String.format("Request did not match expected format %s.", messageHeaders.getRequestClass().getSimpleName())), HttpResponseStatus.BAD_REQUEST, ctx, httpRequest); + return; + } + } + + CompletableFuture

response; + try { + HandlerRequest handlerRequest = new HandlerRequest<>(request, messageHeaders.getUnresolvedMessageParameters(), routed.pathParams(), routed.queryParams()); + response = handleRequest(handlerRequest); + } catch (Exception e) { + response = FutureUtils.completedExceptionally(e); + } + + response.whenComplete((P resp, Throwable error) -> { + if (error != null) { + if (error instanceof RestHandlerException) { + RestHandlerException rhe = (RestHandlerException) error; + sendErrorResponse(new ErrorResponseBody(rhe.getErrorMessage()), rhe.getHttpResponseStatus(), ctx, httpRequest); + } else { + log.error("Implementation error: Unhandled exception.", error); + sendErrorResponse(new ErrorResponseBody("Internal server error."), HttpResponseStatus.INTERNAL_SERVER_ERROR, ctx, httpRequest); + } + } else { + sendResponse(messageHeaders.getResponseStatusCode(), resp, ctx, httpRequest); + } + }); + } catch (Exception e) { + log.error("Request processing failed.", e); + sendErrorResponse(new ErrorResponseBody("Internal server error."), HttpResponseStatus.INTERNAL_SERVER_ERROR, ctx, httpRequest); + } + } + + /** + * This method is called for every incoming request and returns a {@link CompletableFuture} containing a the response. + * + *

<p>Implementations may decide whether to throw {@link RestHandlerException}s or fail the returned + * {@link CompletableFuture} with a {@link RestHandlerException}. + * + * <p>Failing the future with another exception type or throwing unchecked exceptions is regarded as an + * implementation error as it does not allow us to provide a meaningful HTTP status code. In this case a + * {@link HttpResponseStatus#INTERNAL_SERVER_ERROR} will be returned. + * + * @param request request that should be handled + * @return future containing a handler response + * @throws RestHandlerException if the handling failed + */ + protected abstract CompletableFuture<P> handleRequest(@Nonnull HandlerRequest request) throws RestHandlerException; + + private static <P extends ResponseBody>
void sendResponse(HttpResponseStatus statusCode, P response, ChannelHandlerContext ctx, HttpRequest httpRequest) { + StringWriter sw = new StringWriter(); + try { + mapper.writeValue(sw, response); + } catch (IOException ioe) { + sendErrorResponse(new ErrorResponseBody("Internal server error. Could not map response to JSON."), HttpResponseStatus.INTERNAL_SERVER_ERROR, ctx, httpRequest); + return; + } + sendResponse(ctx, httpRequest, statusCode, sw.toString()); + } + + static void sendErrorResponse(ErrorResponseBody error, HttpResponseStatus statusCode, ChannelHandlerContext ctx, HttpRequest httpRequest) { + + StringWriter sw = new StringWriter(); + try { + mapper.writeValue(sw, error); + } catch (IOException e) { + // this should never happen + sendResponse(ctx, httpRequest, HttpResponseStatus.INTERNAL_SERVER_ERROR, "Internal server error. Could not map error response to JSON."); + } + sendResponse(ctx, httpRequest, statusCode, sw.toString()); + } + + private static void sendResponse(@Nonnull ChannelHandlerContext ctx, @Nonnull HttpRequest httpRequest, @Nonnull HttpResponseStatus statusCode, @Nonnull String message) { + HttpResponse response = new DefaultHttpResponse(HTTP_1_1, statusCode); + + response.headers().set(CONTENT_TYPE, "application/json"); + + if (HttpHeaders.isKeepAlive(httpRequest)) { + response.headers().set(CONNECTION, HttpHeaders.Values.KEEP_ALIVE); + } + + byte[] buf = message.getBytes(ConfigConstants.DEFAULT_CHARSET); + ByteBuf b = Unpooled.copiedBuffer(buf); + HttpHeaders.setContentLength(response, buf.length); + + // write the initial line and the header. + ctx.write(response); + + ctx.write(b); + + ChannelFuture lastContentFuture = ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT); + + // close the connection, if no keep-alive is needed + if (!HttpHeaders.isKeepAlive(httpRequest)) { + lastContentFuture.addListener(ChannelFutureListener.CLOSE); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java new file mode 100644 index 0000000000000..6a9bce99aa75a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/HandlerRequest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.runtime.rest.messages.MessageParameters; +import org.apache.flink.runtime.rest.messages.MessagePathParameter; +import org.apache.flink.runtime.rest.messages.MessageQueryParameter; +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.util.Preconditions; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.StringJoiner; + +/** + * Simple container for the request to a handler, that contains the {@link RequestBody} and path/query parameters. + * + * @param type of the contained request body + * @param type of the contained message parameters + */ +public class HandlerRequest { + + private final R requestBody; + private final Map>, MessagePathParameter> pathParameters = new HashMap<>(2); + private final Map>, MessageQueryParameter> queryParameters = new HashMap<>(2); + + public HandlerRequest(R requestBody, M messageParameters, Map receivedPathParameters, Map> receivedQueryParameters) { + this.requestBody = Preconditions.checkNotNull(requestBody); + Preconditions.checkNotNull(messageParameters); + Preconditions.checkNotNull(receivedQueryParameters); + Preconditions.checkNotNull(receivedPathParameters); + + for (MessagePathParameter pathParameter : messageParameters.getPathParameters()) { + String value = receivedPathParameters.get(pathParameter.getKey()); + if (value != null) { + pathParameter.resolveFromString(value); + + @SuppressWarnings("unchecked") + Class> clazz = (Class>) pathParameter.getClass(); + pathParameters.put(clazz, pathParameter); + } + } + + for (MessageQueryParameter queryParameter : messageParameters.getQueryParameters()) { + List values = receivedQueryParameters.get(queryParameter.getKey()); + if (values != null && !values.isEmpty()) { + StringJoiner joiner = new StringJoiner(","); + values.forEach(joiner::add); + queryParameter.resolveFromString(joiner.toString()); + + @SuppressWarnings("unchecked") + Class> clazz = (Class>) queryParameter.getClass(); + queryParameters.put(clazz, queryParameter); + } + + } + } + + /** + * Returns the request body. + * + * @return request body + */ + public R getRequestBody() { + return requestBody; + } + + /** + * Returns the value of the {@link MessagePathParameter} for the given class. + * + * @param parameterClass class of the parameter + * @param the value type that the parameter contains + * @param type of the path parameter + * @return path parameter value for the given class + * @throws IllegalStateException if no value is defined for the given parameter class + */ + public > X getPathParameter(Class parameterClass) { + @SuppressWarnings("unchecked") + PP pathParameter = (PP) pathParameters.get(parameterClass); + Preconditions.checkState(pathParameter != null, "No parameter could be found for the given class."); + return pathParameter.getValue(); + } + + /** + * Returns the value of the {@link MessageQueryParameter} for the given class. 
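+ * <p>A minimal usage sketch (the {@code JobStateQueryParameter} name is only illustrative and assumes a
+ * {@code MessageQueryParameter<String>} subclass registered under the key {@code "state"}):
+ * <pre>{@code
+ * List<String> states = handlerRequest.getQueryParameter(JobStateQueryParameter.class);
+ * if (states.isEmpty()) {
+ *     // the query parameter was not set on the request
+ * }
+ * }</pre>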
+ * + * @param parameterClass class of the parameter + * @param the value type that the parameter contains + * @param type of the query parameter + * @return query parameter value for the given class, or an empty list if no parameter value exists for the given class + */ + public > List getQueryParameter(Class parameterClass) { + @SuppressWarnings("unchecked") + QP queryParameter = (QP) queryParameters.get(parameterClass); + if (queryParameter == null) { + return Collections.emptyList(); + } else { + return queryParameter.getValue(); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/PipelineErrorHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/PipelineErrorHandler.java new file mode 100644 index 0000000000000..14e643cdbcbca --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/PipelineErrorHandler.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.runtime.rest.messages.ErrorResponseBody; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.SimpleChannelInboundHandler; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; + +import org.slf4j.Logger; + +/** + * This is the last handler in the pipeline. It logs all error messages. + */ +@ChannelHandler.Sharable +public class PipelineErrorHandler extends SimpleChannelInboundHandler { + + /** The logger to which the handler writes the log statements. */ + private final Logger logger; + + public PipelineErrorHandler(Logger logger) { + this.logger = logger; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, HttpRequest message) { + // we can't deal with this message. No one in the pipeline handled it. Log it. 
+ logger.warn("Unknown message received: {}", message); + AbstractRestHandler.sendErrorResponse(new ErrorResponseBody("Bad request received."), HttpResponseStatus.BAD_REQUEST, ctx, message); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + logger.warn("Unhandled exception", cause); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RestHandlerException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RestHandlerException.java new file mode 100644 index 0000000000000..4cbb542ada7d7 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RestHandlerException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; + +/** + * An exception that is thrown if the failure of a REST operation was detected by a handler. + */ +public class RestHandlerException extends Exception { + private static final long serialVersionUID = -1358206297964070876L; + + private final String errorMessage; + private final int responseCode; + + public RestHandlerException(String errorMessage, HttpResponseStatus httpResponseStatus) { + this.errorMessage = errorMessage; + this.responseCode = httpResponseStatus.code(); + } + + public String getErrorMessage() { + return errorMessage; + } + + public HttpResponseStatus getHttpResponseStatus() { + return HttpResponseStatus.valueOf(responseCode); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RouterHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RouterHandler.java new file mode 100644 index 0000000000000..72b779bd4835e --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/handler/RouterHandler.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rest.handler; + +import org.apache.flink.runtime.rest.messages.ErrorResponseBody; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpRequest; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Handler; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Router; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class is an extension of {@link Handler} that replaces the standard error response to be identical with those + * sent by the {@link AbstractRestHandler}. + */ +public class RouterHandler extends Handler { + private static final Logger LOG = LoggerFactory.getLogger(RouterHandler.class); + + public RouterHandler(Router router) { + super(router); + } + + @Override + protected void respondNotFound(ChannelHandlerContext ctx, HttpRequest request) { + AbstractRestHandler.sendErrorResponse(new ErrorResponseBody("Not found."), HttpResponseStatus.NOT_FOUND, ctx, request); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ErrorResponseBody.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ErrorResponseBody.java new file mode 100644 index 0000000000000..0a7d69e32ab65 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ErrorResponseBody.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collections; +import java.util.List; + +/** + * Generic response body for communicating errors on the server. 
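+ * <p>For illustration, {@code new ErrorResponseBody("Not found.")} is expected to be serialized by jackson-databind
+ * into a JSON body of the form
+ * <pre>
+ * {"errors": ["Not found."]}
+ * </pre>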
+ */ +public final class ErrorResponseBody implements ResponseBody { + + static final String FIELD_NAME_ERRORS = "errors"; + + @JsonProperty(FIELD_NAME_ERRORS) + public final List errors; + + public ErrorResponseBody(String error) { + this(Collections.singletonList(error)); + } + + @JsonCreator + public ErrorResponseBody( + @JsonProperty(FIELD_NAME_ERRORS) List errors) { + + this.errors = errors; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageHeaders.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageHeaders.java new file mode 100644 index 0000000000000..254c231b00697 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageHeaders.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +import org.apache.flink.runtime.rest.HttpMethodWrapper; + +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; + +/** + * This class links {@link RequestBody}s to {@link ResponseBody}s types and contains meta-data required for their http headers. + * + *

<p>Implementations must be state-less. + * + * @param <R> request message type + * @param <P> response message type + * @param <M> message parameters type + */ +public interface MessageHeaders<R extends RequestBody, P extends ResponseBody, M extends MessageParameters> { + + /** + * Returns the class of the request message. + * + * @return class of the request message + */ + Class<R> getRequestClass(); + + /** + * Returns the {@link HttpMethodWrapper} to be used for the request. + * + * @return http method to be used for the request + */ + HttpMethodWrapper getHttpMethod(); + + /** + * Returns the generalized endpoint url that this request should be sent to, for example {@code /job/:jobid}. + * + * @return endpoint url that this request should be sent to + */ + String getTargetRestEndpointURL(); + + /** + * Returns the class of the response message. + * + * @return class of the response message + */ + Class<P>
getResponseClass(); + + /** + * Returns the http status code for the response. + * + * @return http status code of the response + */ + HttpResponseStatus getResponseStatusCode(); + + /** + * Returns a new {@link MessageParameters} object. + * + * @return new message parameters object + */ + M getUnresolvedMessageParameters(); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameter.java new file mode 100644 index 0000000000000..e681e38dec77d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameter.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +import org.apache.flink.util.Preconditions; + +/** + * This class represents a single path/query parameter that can be used for a request. Every parameter has an associated + * key, and a one-time settable value. + * + *

<p>Parameters are either mandatory or optional, indicating whether the parameter must be resolved for the request. + * + * <p>
All parameters support symmetric conversion from their actual type and string via {@link #convertFromString(String)} + * and {@link #convertToString(Object)}. The conversion from {@code X} to string is required on the client to assemble the + * URL, whereas the conversion from string to {@code X} is required on the server to provide properly typed parameters + * to the handlers. + * + * @see MessagePathParameter + * @see MessageQueryParameter + */ +public abstract class MessageParameter { + private boolean resolved = false; + + private final MessageParameterRequisiteness requisiteness; + + private final String key; + private X value; + + MessageParameter(String key, MessageParameterRequisiteness requisiteness) { + this.key = Preconditions.checkNotNull(key); + this.requisiteness = Preconditions.checkNotNull(requisiteness); + } + + /** + * Returns whether this parameter has been resolved. + * + * @return true, if this parameter was resolved, false otherwise + */ + public final boolean isResolved() { + return resolved; + } + + /** + * Resolves this parameter for the given value. + * + * @param value value to resolve this parameter with + */ + public final void resolve(X value) { + Preconditions.checkState(!resolved, "This parameter was already resolved."); + this.value = Preconditions.checkNotNull(value); + this.resolved = true; + } + + /** + * Resolves this parameter for the given string value representation. + * + * @param value string representation of value to resolve this parameter with + */ + public final void resolveFromString(String value) { + resolve(convertFromString(value)); + } + + /** + * Converts the given string to a valid value of this parameter. + * + * @param value string representation of parameter value + * @return parameter value + */ + protected abstract X convertFromString(String value); + + /** + * Converts the given value to its string representation. + * + * @param value parameter value + * @return string representation of typed value + */ + protected abstract String convertToString(X value); + + /** + * Returns the key of this parameter, e.g. "jobid". + * + * @return key of this parameter + */ + public final String getKey() { + return key; + } + + /** + * Returns the resolved value of this parameter, or {@code null} if it isn't resolved yet. + * + * @return resolved value, or null if it wasn't resolved yet + */ + public final X getValue() { + return value; + } + + /** + * Returns the resolved value of this parameter as a string, or {@code null} if it isn't resolved yet. + * + * @return resolved value, or null if it wasn't resolved yet + */ + final String getValueAsString() { + return value == null + ? null + : convertToString(value); + } + + /** + * Returns whether this parameter must be resolved for the request. + * + * @return true if the parameter is mandatory, false otherwise + */ + public final boolean isMandatory() { + return requisiteness == MessageParameterRequisiteness.MANDATORY; + } + + /** + * Enum for indicating whether a parameter is mandatory or optional. 
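+ * <p>For illustration, an optional query parameter would pass {@code MessageParameterRequisiteness.OPTIONAL} to the
+ * {@link MessageParameter} constructor, whereas {@link MessagePathParameter}s are always {@code MANDATORY}.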
+ */ + protected enum MessageParameterRequisiteness { + MANDATORY, + OPTIONAL + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java new file mode 100644 index 0000000000000..96243c1f93ad8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageParameters.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +import org.apache.flink.util.Preconditions; + +import java.util.Collection; + +/** + * This class defines the path/query {@link MessageParameter}s that can be used for a request. + */ +public abstract class MessageParameters { + + /** + * Returns the collection of {@link MessagePathParameter} that the request supports. The collection should not be + * modifiable. + * + * @return collection of all supported message path parameters + */ + public abstract Collection> getPathParameters(); + + /** + * Returns the collection of {@link MessageQueryParameter} that the request supports. The collection should not be + * modifiable. + * + * @return collection of all supported message query parameters + */ + public abstract Collection> getQueryParameters(); + + /** + * Returns whether all mandatory parameters have been resolved. + * + * @return true, if all mandatory parameters have been resolved, false otherwise + */ + public final boolean isResolved() { + return getPathParameters().stream().filter(MessageParameter::isMandatory).allMatch(MessageParameter::isResolved) + && getQueryParameters().stream().filter(MessageParameter::isMandatory).allMatch(MessageParameter::isResolved); + } + + /** + * Resolves the given URL (e.g "jobs/:jobid") using the given path/query parameters. + * + *
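+ * <p>A usage sketch (the {@code JobMessageParameters} class and its {@code jobIdParameter} field are only
+ * illustrative; {@code jobId} is assumed to be a value of the matching parameter type):
+ * <pre>{@code
+ * JobMessageParameters parameters = new JobMessageParameters();
+ * parameters.jobIdParameter.resolve(jobId);
+ * String url = MessageParameters.resolveUrl("/jobs/:jobid", parameters);
+ * // url now contains the string representation of jobId in place of ":jobid"
+ * }</pre>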

<p>This method will fail with an {@link IllegalStateException} if any mandatory parameter was not resolved. + * + * <p>
Unresolved optional parameters will be ignored. + * + * @param genericUrl URL to resolve + * @param parameters message parameters parameters + * @return resolved url, e.g "/jobs/1234?state=running" + * @throws IllegalStateException if any mandatory parameter was not resolved + */ + public static String resolveUrl(String genericUrl, MessageParameters parameters) { + Preconditions.checkState(parameters.isResolved(), "Not all mandatory message parameters were resolved."); + StringBuilder path = new StringBuilder(genericUrl); + StringBuilder queryParameters = new StringBuilder(); + + for (MessageParameter pathParameter : parameters.getPathParameters()) { + if (pathParameter.isResolved()) { + int start = path.indexOf(':' + pathParameter.getKey()); + + final String pathValue = Preconditions.checkNotNull(pathParameter.getValueAsString()); + + // only replace path parameters if they are present + if (start != -1) { + path.replace(start, start + pathParameter.getKey().length() + 1, pathValue); + } + } + } + boolean isFirstQueryParameter = true; + for (MessageQueryParameter queryParameter : parameters.getQueryParameters()) { + if (parameters.isResolved()) { + if (isFirstQueryParameter) { + queryParameters.append('?'); + isFirstQueryParameter = false; + } else { + queryParameters.append('&'); + } + queryParameters.append(queryParameter.getKey()); + queryParameters.append('='); + queryParameters.append(queryParameter.getValueAsString()); + } + } + path.append(queryParameters); + + return path.toString(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/Shutdown.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessagePathParameter.java similarity index 65% rename from flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/Shutdown.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessagePathParameter.java index c596d1248bf03..5cbab072215b4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/Shutdown.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessagePathParameter.java @@ -16,21 +16,14 @@ * limitations under the License. */ -package org.apache.flink.runtime.rpc.akka.messages; - -import org.apache.flink.runtime.rpc.akka.AkkaRpcService; +package org.apache.flink.runtime.rest.messages; /** - * Shut down message used to trigger the shut down of an AkkaRpcActor. This - * message is only intended for internal use by the {@link AkkaRpcService}. + * This class represents path parameters of a request. For example, the URL "/jobs/:jobid" has a + * "jobid" path parameter that is later replaced with an actual value. 
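+ * <p>A minimal subclass sketch (the {@code JobIDPathParameter} name is only illustrative and assumes a
+ * {@code JobID} type that offers string conversion helpers):
+ * <pre>{@code
+ * public class JobIDPathParameter extends MessagePathParameter<JobID> {
+ *     public JobIDPathParameter() {
+ *         super("jobid");
+ *     }
+ *
+ *     protected JobID convertFromString(String value) {
+ *         return JobID.fromHexString(value);
+ *     }
+ *
+ *     protected String convertToString(JobID value) {
+ *         return value.toString();
+ *     }
+ * }
+ * }</pre>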
*/ -public final class Shutdown { - - private static Shutdown instance = new Shutdown(); - - public static Shutdown getInstance() { - return instance; +public abstract class MessagePathParameter extends MessageParameter { + protected MessagePathParameter(String key) { + super(key, MessageParameterRequisiteness.MANDATORY); } - - private Shutdown() {} } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageQueryParameter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageQueryParameter.java new file mode 100644 index 0000000000000..506a14b8da505 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/MessageQueryParameter.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +import java.util.ArrayList; +import java.util.List; + +/** + * This class represents query parameters of a request. For example, the URL "/jobs?state=running" has a + * "state" query parameter, with "running" being its value string representation. + * + *
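+ * <p>A minimal subclass sketch (the {@code JobStateQueryParameter} name is only illustrative):
+ * <pre>{@code
+ * public class JobStateQueryParameter extends MessageQueryParameter<String> {
+ *     public JobStateQueryParameter() {
+ *         super("state", MessageParameterRequisiteness.OPTIONAL);
+ *     }
+ *
+ *     public String convertValueFromString(String value) {
+ *         return value;
+ *     }
+ *
+ *     public String convertStringToValue(String value) {
+ *         return value;
+ *     }
+ * }
+ * }</pre>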

Query parameters may both occur multiple times or be of the form "key=value1,value2,value3". If a query parameter + * is specified multiple times the individual values are concatenated with {@code ,} and passed as a single value to + * {@link #convertToString(List)}. + */ +public abstract class MessageQueryParameter extends MessageParameter> { + protected MessageQueryParameter(String key, MessageParameterRequisiteness requisiteness) { + super(key, requisiteness); + } + + @Override + public List convertFromString(String values) { + String[] splitValues = values.split(","); + List list = new ArrayList<>(); + for (String value : splitValues) { + list.add(convertValueFromString(value)); + } + return list; + } + + /** + * Converts the given string to a valid value of this parameter. + * + * @param value string representation of parameter value + * @return parameter value + */ + public abstract X convertValueFromString(String value); + + @Override + public String convertToString(List values) { + StringBuilder sb = new StringBuilder(); + boolean first = true; + for (X value : values) { + if (first) { + sb.append(convertStringToValue(value)); + first = false; + } else { + sb.append(","); + sb.append(convertStringToValue(value)); + } + } + return sb.toString(); + } + + /** + * Converts the given value to its string representation. + * + * @param value parameter value + * @return string representation of typed value + */ + public abstract String convertStringToValue(X value); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java new file mode 100644 index 0000000000000..ca55b17532ba3 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/RequestBody.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +/** + * Marker interface for all requests of the REST API. This class represents the http body of a request. + * + *
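+ * <p>A minimal implementation sketch (the {@code StopJobRequestBody} name and its single field are illustrative):
+ * <pre>{@code
+ * public class StopJobRequestBody implements RequestBody {
+ *     @JsonProperty("mode") public final String mode;
+ *
+ *     @JsonCreator
+ *     public StopJobRequestBody(@JsonProperty("mode") String mode) {
+ *         this.mode = mode;
+ *     }
+ * }
+ * }</pre>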

<p>Subclass instances are converted to JSON using jackson-databind. Subclasses must have a constructor that accepts + * all fields of the JSON request and that is annotated with {@code @JsonCreator}. + * + * <p>All fields that should be part of the JSON request must be accessible either by being public or by having a getter. + * + * <p>
When adding methods that are prefixed with {@code get} make sure to annotate them with {@code @JsonIgnore}. + */ +public interface RequestBody { +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java new file mode 100644 index 0000000000000..d4e94d1d6abdc --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/ResponseBody.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +/** + * Marker interface for all responses of the REST API. This class represents the http body of a response. + * + *

<p>Subclass instances are converted to JSON using jackson-databind. Subclasses must have a constructor that accepts + * all fields of the JSON response and that is annotated with {@code @JsonCreator}. + * + * <p>All fields that should be part of the JSON response must be accessible either by being public or by having a getter. + * + * <p>
When adding methods that are prefixed with {@code get} make sure to annotate them with {@code @JsonIgnore}. + */ +public interface ResponseBody { +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/util/RestClientException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/util/RestClientException.java new file mode 100644 index 0000000000000..2333614bfc16f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/util/RestClientException.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.util; + +import org.apache.flink.util.FlinkException; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; + +/** + * An exception that is thrown if the failure of a REST operation was detected on the client. + */ +public class RestClientException extends FlinkException { + + private static final long serialVersionUID = 937914622022344423L; + + private final int responseCode; + + public RestClientException(String message, HttpResponseStatus responseStatus) { + super(message); + + Preconditions.checkNotNull(responseStatus); + responseCode = responseStatus.code(); + } + + public RestClientException(String message, Throwable cause, HttpResponseStatus responseStatus) { + super(message, cause); + + responseCode = responseStatus.code(); + } + + public HttpResponseStatus getHttpResponseStatus() { + return HttpResponseStatus.valueOf(responseCode); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/util/RestMapperUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/util/RestMapperUtils.java new file mode 100644 index 0000000000000..647a7087741b1 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/util/RestMapperUtils.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rest.util; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; + +/** + * This class contains utilities for mapping requests and responses to/from JSON. + */ +public class RestMapperUtils { + private static final ObjectMapper objectMapper; + + static { + objectMapper = new ObjectMapper(); + objectMapper.enable( + DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES, + DeserializationFeature.FAIL_ON_NULL_FOR_PRIMITIVES, + DeserializationFeature.FAIL_ON_READING_DUP_TREE_KEY, + DeserializationFeature.FAIL_ON_MISSING_CREATOR_PROPERTIES); + objectMapper.disable( + SerializationFeature.FAIL_ON_EMPTY_BEANS); + } + + /** + * Returns a preconfigured {@link ObjectMapper}. + * + * @return preconfigured object mapper + */ + public static ObjectMapper getStrictObjectMapper() { + return objectMapper; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FatalErrorHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FatalErrorHandler.java index 7721117a240b9..dbccaa87d8732 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FatalErrorHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FatalErrorHandler.java @@ -18,7 +18,18 @@ package org.apache.flink.runtime.rpc; +/** + * Handler for fatal errors. + */ public interface FatalErrorHandler { + /** + * Being called when a fatal error occurs. + * + *

IMPORTANT: This call should never be blocking since it might be called from within + * the main thread of an {@link RpcEndpoint}. + * + * @param exception cause + */ void onFatalError(Throwable exception); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedMainThreadExecutable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedMainThreadExecutable.java new file mode 100644 index 0000000000000..16cacc84dd237 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedMainThreadExecutable.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc; + +import org.apache.flink.api.common.time.Time; + +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; + +/** + * Extended {@link MainThreadExecutable} interface which allows to run unfenced runnables + * in the main thread. + */ +public interface FencedMainThreadExecutable extends MainThreadExecutable { + + /** + * Run the given runnable in the main thread without attaching a fencing token. + * + * @param runnable to run in the main thread without validating the fencing token. + */ + void runAsyncWithoutFencing(Runnable runnable); + + /** + * Run the given callable in the main thread without attaching a fencing token. + * + * @param callable to run in the main thread without validating the fencing token. + * @param timeout for the operation + * @param type of the callable result + * @return Future containing the callable result + */ + CompletableFuture callAsyncWithoutFencing(Callable callable, Time timeout); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedRpcEndpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedRpcEndpoint.java new file mode 100644 index 0000000000000..81bae2924e674 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedRpcEndpoint.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.UUID; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; + +/** + * Base class for fenced {@link RpcEndpoint}. A fenced rpc endpoint expects all rpc messages + * being enriched with fencing tokens. Furthermore, the rpc endpoint has its own fencing token + * assigned. The rpc is then only executed if the attached fencing token equals the endpoint's own + * token. + * + * @param type of the fencing token + */ +public class FencedRpcEndpoint extends RpcEndpoint { + + private volatile F fencingToken; + private volatile MainThreadExecutor fencedMainThreadExecutor; + + protected FencedRpcEndpoint(RpcService rpcService, String endpointId, F initialFencingToken) { + super(rpcService, endpointId); + + this.fencingToken = Preconditions.checkNotNull(initialFencingToken); + this.fencedMainThreadExecutor = new MainThreadExecutor( + getRpcService().fenceRpcServer( + rpcServer, + initialFencingToken)); + } + + protected FencedRpcEndpoint(RpcService rpcService, F initialFencingToken) { + this(rpcService, UUID.randomUUID().toString(), initialFencingToken); + } + + public F getFencingToken() { + return fencingToken; + } + + protected void setFencingToken(F newFencingToken) { + // this method should only be called from within the main thread + validateRunsInMainThread(); + + this.fencingToken = newFencingToken; + + // setting a new fencing token entails that we need a new MainThreadExecutor + // which is bound to the new fencing token + MainThreadExecutable mainThreadExecutable = getRpcService().fenceRpcServer( + rpcServer, + newFencingToken); + + this.fencedMainThreadExecutor = new MainThreadExecutor(mainThreadExecutable); + } + + /** + * Returns a main thread executor which is bound to the currently valid fencing token. + * This means that runnables which are executed with this executor fail after the fencing + * token has changed. This allows to scope operations by the fencing token. + * + * @return MainThreadExecutor bound to the current fencing token + */ + @Override + protected MainThreadExecutor getMainThreadExecutor() { + return fencedMainThreadExecutor; + } + + /** + * Run the given runnable in the main thread of the RpcEndpoint without checking the fencing + * token. This allows to run operations outside of the fencing token scope. + * + * @param runnable to execute in the main thread of the rpc endpoint without checking the fencing token. + */ + protected void runAsyncWithoutFencing(Runnable runnable) { + if (rpcServer instanceof FencedMainThreadExecutable) { + ((FencedMainThreadExecutable) rpcServer).runAsyncWithoutFencing(runnable); + } else { + throw new RuntimeException("FencedRpcEndpoint has not been started with a FencedMainThreadExecutable RpcServer."); + } + } + + /** + * Run the given callable in the main thread of the RpcEndpoint without checking the fencing + * token. This allows to run operations outside of the fencing token scope. + * + * @param callable to run in the main thread of the rpc endpoint without checkint the fencing token. + * @param timeout for the operation. + * @return Future containing the callable result. 
+ */ + protected CompletableFuture callAsyncWithoutFencing(Callable callable, Time timeout) { + if (rpcServer instanceof FencedMainThreadExecutable) { + return ((FencedMainThreadExecutable) rpcServer).callAsyncWithoutFencing(callable, timeout); + } else { + throw new RuntimeException("FencedRpcEndpoint has not been started with a FencedMainThreadExecutable RpcServer."); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/KvStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedRpcGateway.java similarity index 66% rename from flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/KvStateSnapshot.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedRpcGateway.java index 687d41536a82e..fab638f912c1b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/KvStateSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/FencedRpcGateway.java @@ -16,17 +16,22 @@ * limitations under the License. */ -package org.apache.flink.migration.runtime.state; +package org.apache.flink.runtime.rpc; -import org.apache.flink.api.common.state.State; -import org.apache.flink.api.common.state.StateDescriptor; +import java.io.Serializable; /** - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. + * Fenced {@link RpcGateway}. This gateway allows to have access to the associated + * fencing token. + * + * @param type of the fencing token */ -@Deprecated -@SuppressWarnings("deprecation") -public interface KvStateSnapshot> - extends StateObject { +public interface FencedRpcGateway extends RpcGateway { + /** + * Get the current fencing token. + * + * @return current fencing token + */ + F getFencingToken(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcEndpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcEndpoint.java index 980ae48dfe991..563674add8685 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcEndpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcEndpoint.java @@ -24,6 +24,7 @@ import org.slf4j.LoggerFactory; import javax.annotation.Nonnull; + import java.util.UUID; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; @@ -66,14 +67,14 @@ public abstract class RpcEndpoint implements RpcGateway { private final String endpointId; /** Interface to access the underlying rpc server */ - private final RpcServer rpcServer; + protected final RpcServer rpcServer; + + /** A reference to the endpoint's main thread, if the current method is called by the main thread */ + final AtomicReference currentMainThread = new AtomicReference<>(null); /** The main thread executor to be used to execute future callbacks in the main thread * of the executing rpc server. */ - private final Executor mainThreadExecutor; - - /** A reference to the endpoint's main thread, if the current method is called by the main thread */ - final AtomicReference currentMainThread = new AtomicReference<>(null); + private final MainThreadExecutor mainThreadExecutor; /** * Initializes the RPC endpoint. 
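// Illustrative sketch (not part of this change): how the fenced pieces above are meant to be
// combined. CounterGateway, CounterEndpoint and the requestCount RPC are invented names;
// UUID is just one possible Serializable fencing token type.

import org.apache.flink.api.common.time.Time;
import org.apache.flink.runtime.rpc.FencedRpcEndpoint;
import org.apache.flink.runtime.rpc.FencedRpcGateway;
import org.apache.flink.runtime.rpc.RpcService;
import org.apache.flink.runtime.rpc.RpcTimeout;

import java.util.UUID;
import java.util.concurrent.CompletableFuture;

public interface CounterGateway extends FencedRpcGateway<UUID> {

	CompletableFuture<Integer> requestCount(@RpcTimeout Time timeout);
}

class CounterEndpoint extends FencedRpcEndpoint<UUID> implements CounterGateway {

	private int count;

	CounterEndpoint(RpcService rpcService, UUID initialFencingToken) {
		super(rpcService, initialFencingToken);
	}

	@Override
	public CompletableFuture<Integer> requestCount(Time timeout) {
		// runs in the endpoint's main thread; the invocation only gets this far if the
		// fencing token attached to the message matched the endpoint's current token
		return CompletableFuture.completedFuture(count);
	}
}

// Callers pass the expected fencing token when connecting, so that every RPC sent through the
// returned gateway is automatically enriched with it:
//
//   CompletableFuture<CounterGateway> gatewayFuture =
//       rpcService.connect(address, fencingToken, CounterGateway.class);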
@@ -208,7 +209,7 @@ public String getHostname() { * * @return Main thread execution context */ - protected Executor getMainThreadExecutor() { + protected MainThreadExecutor getMainThreadExecutor() { return mainThreadExecutor; } @@ -310,7 +311,7 @@ public void validateRunsInMainThread() { /** * Executor which executes runnables in the main thread context. */ - private static class MainThreadExecutor implements Executor { + protected static class MainThreadExecutor implements Executor { private final MainThreadExecutable gateway; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcService.java index 3b5a5e2c886f8..9b2e318888e63 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcService.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcService.java @@ -21,6 +21,7 @@ import org.apache.flink.runtime.concurrent.ScheduledExecutor; import org.apache.flink.runtime.rpc.exceptions.RpcConnectionException; +import java.io.Serializable; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -61,7 +62,27 @@ public interface RpcService { * @return Future containing the rpc gateway or an {@link RpcConnectionException} if the * connection attempt failed */ - CompletableFuture connect(String address, Class clazz); + CompletableFuture connect( + String address, + Class clazz); + + /** + * Connect to ta remote fenced rpc server under the provided address. Returns a fenced rpc gateway + * which can be used to communicate with the rpc server. If the connection failed, then the + * returned future is failed with a {@link RpcConnectionException}. + * + * @param address Address of the remote rpc server + * @param fencingToken Fencing token to be used when communicating with the server + * @param clazz Class of the rpc gateway to return + * @param Type of the fencing token + * @param Type of the rpc gateway to return + * @return Future containing the fenced rpc gateway or an {@link RpcConnectionException} if the + * connection attempt failed + */ + > CompletableFuture connect( + String address, + F fencingToken, + Class clazz); /** * Start a rpc server which forwards the remote procedure calls to the provided rpc endpoint. @@ -72,6 +93,21 @@ public interface RpcService { */ RpcServer startServer(C rpcEndpoint); + + /** + * Fence the given RpcServer with the given fencing token. + * + *
<p>
Fencing the RpcServer means that we fix the fencing token to the provided value. + * All RPCs will then be enriched with this fencing token. This expects that the receiving + * RPC endpoint extends {@link FencedRpcEndpoint}. + * + * @param rpcServer to fence with the given fencing token + * @param fencingToken to fence the RpcServer with + * @param type of the fencing token + * @return Fenced RpcServer + */ + RpcServer fenceRpcServer(RpcServer rpcServer, F fencingToken); + /** * Stop the underlying rpc server of the provided self gateway. * diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcUtils.java index 9738970a61f5d..a644efd9853d7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcUtils.java @@ -18,13 +18,21 @@ package org.apache.flink.runtime.rpc; +import org.apache.flink.api.common.time.Time; + import java.util.HashSet; import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; /** * Utility functions for Flink's RPC implementation */ public class RpcUtils { + + public static final Time INF_TIMEOUT = Time.milliseconds(Long.MAX_VALUE); + /** * Extracts all {@link RpcGateway} interfaces implemented by the given clazz. * @@ -47,6 +55,20 @@ public static Set> extractImplementedRpcGateways(Cla return interfaces; } + /** + * Shuts the given {@link RpcEndpoint} down and awaits its termination. + * + * @param rpcEndpoint to terminate + * @param timeout for this operation + * @throws ExecutionException if a problem occurs + * @throws InterruptedException if the operation has been interrupted + * @throws TimeoutException if a timeout occurred + */ + public static void terminateRpcEndpoint(RpcEndpoint rpcEndpoint, Time timeout) throws ExecutionException, InterruptedException, TimeoutException { + rpcEndpoint.shutDown(); + rpcEndpoint.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } + // We don't want this class to be instantiable private RpcUtils() {} } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java index 0521f2e5717ea..fc785cb7c8d33 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java @@ -22,17 +22,18 @@ import akka.pattern.Patterns; import org.apache.flink.api.common.time.Time; import org.apache.flink.runtime.concurrent.FutureUtils; +import org.apache.flink.runtime.rpc.FencedRpcGateway; import org.apache.flink.runtime.rpc.MainThreadExecutable; import org.apache.flink.runtime.rpc.RpcServer; import org.apache.flink.runtime.rpc.RpcGateway; import org.apache.flink.runtime.rpc.RpcTimeout; import org.apache.flink.runtime.rpc.StartStoppable; -import org.apache.flink.runtime.rpc.akka.messages.CallAsync; -import org.apache.flink.runtime.rpc.akka.messages.LocalRpcInvocation; +import org.apache.flink.runtime.rpc.messages.CallAsync; +import org.apache.flink.runtime.rpc.messages.LocalRpcInvocation; import org.apache.flink.runtime.rpc.akka.messages.Processing; -import org.apache.flink.runtime.rpc.akka.messages.RemoteRpcInvocation; -import org.apache.flink.runtime.rpc.akka.messages.RpcInvocation; -import 
org.apache.flink.runtime.rpc.akka.messages.RunAsync; +import org.apache.flink.runtime.rpc.messages.RemoteRpcInvocation; +import org.apache.flink.runtime.rpc.messages.RpcInvocation; +import org.apache.flink.runtime.rpc.messages.RunAsync; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,6 +47,7 @@ import java.util.Objects; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeoutException; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkArgument; @@ -72,7 +74,7 @@ class AkkaInvocationHandler implements InvocationHandler, AkkaGateway, RpcServer private final ActorRef rpcEndpoint; // whether the actor ref is local and thus no message serialization is needed - private final boolean isLocal; + protected final boolean isLocal; // default timeout for asks private final Time timeout; @@ -112,53 +114,13 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl declaringClass.equals(MainThreadExecutable.class) || declaringClass.equals(RpcServer.class)) { result = method.invoke(this, args); + } else if (declaringClass.equals(FencedRpcGateway.class)) { + throw new UnsupportedOperationException("AkkaInvocationHandler does not support the call FencedRpcGateway#" + + method.getName() + ". This indicates that you retrieved a FencedRpcGateway without specifying a " + + "fencing token. Please use RpcService#connect(RpcService, F, Time) with F being the fencing token to " + + "retrieve a properly FencedRpcGateway."); } else { - String methodName = method.getName(); - Class[] parameterTypes = method.getParameterTypes(); - Annotation[][] parameterAnnotations = method.getParameterAnnotations(); - Time futureTimeout = extractRpcTimeout(parameterAnnotations, args, timeout); - - RpcInvocation rpcInvocation; - - if (isLocal) { - rpcInvocation = new LocalRpcInvocation( - methodName, - parameterTypes, - args); - } else { - try { - RemoteRpcInvocation remoteRpcInvocation = new RemoteRpcInvocation( - methodName, - parameterTypes, - args); - - if (remoteRpcInvocation.getSize() > maximumFramesize) { - throw new IOException("The rpc invocation size exceeds the maximum akka framesize."); - } else { - rpcInvocation = remoteRpcInvocation; - } - } catch (IOException e) { - LOG.warn("Could not create remote rpc invocation message. Failing rpc invocation because...", e); - throw e; - } - } - - Class returnType = method.getReturnType(); - - if (Objects.equals(returnType, Void.TYPE)) { - rpcEndpoint.tell(rpcInvocation, ActorRef.noSender()); - - result = null; - } else if (Objects.equals(returnType,CompletableFuture.class)) { - // execute an asynchronous call - result = FutureUtils.toJava(Patterns.ask(rpcEndpoint, rpcInvocation, futureTimeout.toMilliseconds())); - } else { - // execute a synchronous call - CompletableFuture futureResult = FutureUtils.toJava( - Patterns.ask(rpcEndpoint, rpcInvocation, futureTimeout.toMilliseconds())); - - result = futureResult.get(futureTimeout.getSize(), futureTimeout.getUnit()); - } + result = invokeRpc(method, args); } return result; @@ -171,7 +133,7 @@ public ActorRef getRpcEndpoint() { @Override public void runAsync(Runnable runnable) { - scheduleRunAsync(runnable, 0); + scheduleRunAsync(runnable, 0L); } @Override @@ -181,7 +143,7 @@ public void scheduleRunAsync(Runnable runnable, long delayMillis) { if (isLocal) { long atTimeNanos = delayMillis == 0 ? 
0 : System.nanoTime() + (delayMillis * 1_000_000); - rpcEndpoint.tell(new RunAsync(runnable, atTimeNanos), ActorRef.noSender()); + tell(new RunAsync(runnable, atTimeNanos)); } else { throw new RuntimeException("Trying to send a Runnable to a remote actor at " + rpcEndpoint.path() + ". This is not supported."); @@ -192,9 +154,9 @@ public void scheduleRunAsync(Runnable runnable, long delayMillis) { public CompletableFuture callAsync(Callable callable, Time callTimeout) { if(isLocal) { @SuppressWarnings("unchecked") - scala.concurrent.Future resultFuture = (scala.concurrent.Future) Patterns.ask(rpcEndpoint, new CallAsync(callable), callTimeout.toMilliseconds()); + CompletableFuture resultFuture = (CompletableFuture) ask(new CallAsync(callable), callTimeout); - return FutureUtils.toJava(resultFuture); + return resultFuture; } else { throw new RuntimeException("Trying to send a Callable to a remote actor at " + rpcEndpoint.path() + ". This is not supported."); @@ -211,6 +173,88 @@ public void stop() { rpcEndpoint.tell(Processing.STOP, ActorRef.noSender()); } + // ------------------------------------------------------------------------ + // Private methods + // ------------------------------------------------------------------------ + + /** + * Invokes a RPC method by sending the RPC invocation details to the rpc endpoint. + * + * @param method to call + * @param args of the method call + * @return result of the RPC + * @throws Exception if the RPC invocation fails + */ + private Object invokeRpc(Method method, Object[] args) throws Exception { + String methodName = method.getName(); + Class[] parameterTypes = method.getParameterTypes(); + Annotation[][] parameterAnnotations = method.getParameterAnnotations(); + Time futureTimeout = extractRpcTimeout(parameterAnnotations, args, timeout); + + final RpcInvocation rpcInvocation = createRpcInvocationMessage(methodName, parameterTypes, args); + + Class returnType = method.getReturnType(); + + final Object result; + + if (Objects.equals(returnType, Void.TYPE)) { + tell(rpcInvocation); + + result = null; + } else if (Objects.equals(returnType,CompletableFuture.class)) { + // execute an asynchronous call + result = ask(rpcInvocation, futureTimeout); + } else { + // execute a synchronous call + CompletableFuture futureResult = ask(rpcInvocation, futureTimeout); + + result = futureResult.get(futureTimeout.getSize(), futureTimeout.getUnit()); + } + + return result; + } + + /** + * Create the RpcInvocation message for the given RPC. + * + * @param methodName of the RPC + * @param parameterTypes of the RPC + * @param args of the RPC + * @return RpcInvocation message which encapsulates the RPC details + * @throws IOException if we cannot serialize the RPC invocation parameters + */ + protected RpcInvocation createRpcInvocationMessage( + final String methodName, + final Class[] parameterTypes, + final Object[] args) throws IOException { + final RpcInvocation rpcInvocation; + + if (isLocal) { + rpcInvocation = new LocalRpcInvocation( + methodName, + parameterTypes, + args); + } else { + try { + RemoteRpcInvocation remoteRpcInvocation = new RemoteRpcInvocation( + methodName, + parameterTypes, + args); + + if (remoteRpcInvocation.getSize() > maximumFramesize) { + throw new IOException("The rpc invocation size exceeds the maximum akka framesize."); + } else { + rpcInvocation = remoteRpcInvocation; + } + } catch (IOException e) { + LOG.warn("Could not create remote rpc invocation message. 
Failing rpc invocation because...", e); + throw e; + } + } + + return rpcInvocation; + } + // ------------------------------------------------------------------------ // Helper methods // ------------------------------------------------------------------------ @@ -262,6 +306,28 @@ private static boolean isRpcTimeout(Annotation[] annotations) { return false; } + /** + * Sends the message to the RPC endpoint. + * + * @param message to send to the RPC endpoint. + */ + protected void tell(Object message) { + rpcEndpoint.tell(message, ActorRef.noSender()); + } + + /** + * Sends the message to the RPC endpoint and returns a future containing + * its response. + * + * @param message to send to the RPC endpoint + * @param timeout time to wait until the response future is failed with a {@link TimeoutException} + * @return Response future + */ + protected CompletableFuture ask(Object message, Time timeout) { + return FutureUtils.toJava( + Patterns.ask(rpcEndpoint, message, timeout.toMilliseconds())); + } + @Override public String getAddress() { return address; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java index f557447dee9db..f6c2e8be6faf3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java @@ -18,30 +18,26 @@ package org.apache.flink.runtime.rpc.akka; -import akka.actor.ActorRef; -import akka.actor.Status; -import akka.actor.UntypedActor; -import akka.japi.Procedure; -import akka.pattern.Patterns; import org.apache.flink.runtime.rpc.MainThreadValidatorUtil; import org.apache.flink.runtime.rpc.RpcEndpoint; import org.apache.flink.runtime.rpc.RpcGateway; import org.apache.flink.runtime.rpc.akka.exceptions.AkkaRpcException; -import org.apache.flink.runtime.rpc.akka.messages.CallAsync; -import org.apache.flink.runtime.rpc.akka.messages.LocalRpcInvocation; +import org.apache.flink.runtime.rpc.akka.exceptions.AkkaUnknownMessageException; import org.apache.flink.runtime.rpc.akka.messages.Processing; -import org.apache.flink.runtime.rpc.akka.messages.RpcInvocation; -import org.apache.flink.runtime.rpc.akka.messages.RunAsync; - -import org.apache.flink.runtime.rpc.akka.messages.Shutdown; import org.apache.flink.runtime.rpc.exceptions.RpcConnectionException; +import org.apache.flink.runtime.rpc.messages.CallAsync; +import org.apache.flink.runtime.rpc.messages.LocalRpcInvocation; +import org.apache.flink.runtime.rpc.messages.RpcInvocation; +import org.apache.flink.runtime.rpc.messages.RunAsync; import org.apache.flink.util.ExceptionUtils; + +import akka.actor.ActorRef; +import akka.actor.Status; +import akka.actor.UntypedActor; +import akka.pattern.Patterns; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.concurrent.duration.FiniteDuration; -import scala.concurrent.impl.Promise; - import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -49,6 +45,9 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import scala.concurrent.duration.FiniteDuration; +import scala.concurrent.impl.Promise; + import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -70,88 +69,94 @@ */ class AkkaRpcActor extends UntypedActor { - private static final Logger LOG = LoggerFactory.getLogger(AkkaRpcActor.class); + protected final Logger log = 
LoggerFactory.getLogger(getClass()); - /** the endpoint to invoke the methods on */ - private final T rpcEndpoint; + /** the endpoint to invoke the methods on. */ + protected final T rpcEndpoint; - /** the helper that tracks whether calls come from the main thread */ + /** the helper that tracks whether calls come from the main thread. */ private final MainThreadValidatorUtil mainThreadValidator; private final CompletableFuture terminationFuture; - /** Throwable which might have been thrown by the postStop method */ - private Throwable shutdownThrowable; - AkkaRpcActor(final T rpcEndpoint, final CompletableFuture terminationFuture) { this.rpcEndpoint = checkNotNull(rpcEndpoint, "rpc endpoint"); this.mainThreadValidator = new MainThreadValidatorUtil(rpcEndpoint); this.terminationFuture = checkNotNull(terminationFuture); - - this.shutdownThrowable = null; } @Override public void postStop() throws Exception { - super.postStop(); + mainThreadValidator.enterMainThread(); - // IMPORTANT: This only works if we don't use a restarting supervisor strategy. Otherwise - // we would complete the future and let the actor system restart the actor with a completed - // future. - // Complete the termination future so that others know that we've stopped. + try { + Throwable shutdownThrowable = null; - if (shutdownThrowable != null) { - terminationFuture.completeExceptionally(shutdownThrowable); - } else { - terminationFuture.complete(null); + try { + rpcEndpoint.postStop(); + } catch (Throwable throwable) { + shutdownThrowable = throwable; + } + + super.postStop(); + + // IMPORTANT: This only works if we don't use a restarting supervisor strategy. Otherwise + // we would complete the future and let the actor system restart the actor with a completed + // future. + // Complete the termination future so that others know that we've stopped. + + if (shutdownThrowable != null) { + terminationFuture.completeExceptionally(shutdownThrowable); + } else { + terminationFuture.complete(null); + } + } finally { + mainThreadValidator.exitMainThread(); } } @Override public void onReceive(final Object message) { if (message.equals(Processing.START)) { - getContext().become(new Procedure() { - @Override - public void apply(Object msg) throws Exception { + getContext().become( + (Object msg) -> { if (msg.equals(Processing.STOP)) { getContext().unbecome(); } else { - handleMessage(msg); + mainThreadValidator.enterMainThread(); + + try { + handleMessage(msg); + } finally { + mainThreadValidator.exitMainThread(); + } } - } - }); + }); } else { - LOG.info("The rpc endpoint {} has not been started yet. Discarding message {} until processing is started.", + log.info("The rpc endpoint {} has not been started yet. 
Discarding message {} until processing is started.", rpcEndpoint.getClass().getName(), message.getClass().getName()); - if (!getSender().equals(ActorRef.noSender())) { - // fail a possible future if we have a sender - getSender().tell(new Status.Failure(new AkkaRpcException("Discard message, because " + - "the rpc endpoint has not been started yet.")), getSelf()); - } + sendErrorIfSender(new AkkaRpcException("Discard message, because " + + "the rpc endpoint has not been started yet.")); } } - private void handleMessage(Object message) { - mainThreadValidator.enterMainThread(); - try { - if (message instanceof RunAsync) { - handleRunAsync((RunAsync) message); - } else if (message instanceof CallAsync) { - handleCallAsync((CallAsync) message); - } else if (message instanceof RpcInvocation) { - handleRpcInvocation((RpcInvocation) message); - } else if (message instanceof Shutdown) { - triggerShutdown(); - } else { - LOG.warn( - "Received message of unknown type {} with value {}. Dropping this message!", - message.getClass().getName(), - message); - } - } finally { - mainThreadValidator.exitMainThread(); + protected void handleMessage(Object message) { + if (message instanceof RunAsync) { + handleRunAsync((RunAsync) message); + } else if (message instanceof CallAsync) { + handleCallAsync((CallAsync) message); + } else if (message instanceof RpcInvocation) { + handleRpcInvocation((RpcInvocation) message); + } else { + log.warn( + "Received message of unknown type {} with value {}. Dropping this message!", + message.getClass().getName(), + message); + + sendErrorIfSender(new AkkaUnknownMessageException("Received unknown message " + message + + " of type " + message.getClass().getSimpleName() + '.')); } } @@ -170,18 +175,18 @@ private void handleRpcInvocation(RpcInvocation rpcInvocation) { Class[] parameterTypes = rpcInvocation.getParameterTypes(); rpcMethod = lookupRpcMethod(methodName, parameterTypes); - } catch(ClassNotFoundException e) { - LOG.error("Could not load method arguments.", e); + } catch (ClassNotFoundException e) { + log.error("Could not load method arguments.", e); RpcConnectionException rpcException = new RpcConnectionException("Could not load method arguments.", e); getSender().tell(new Status.Failure(rpcException), getSelf()); } catch (IOException e) { - LOG.error("Could not deserialize rpc invocation message.", e); + log.error("Could not deserialize rpc invocation message.", e); RpcConnectionException rpcException = new RpcConnectionException("Could not deserialize rpc invocation message.", e); getSender().tell(new Status.Failure(rpcException), getSelf()); } catch (final NoSuchMethodException e) { - LOG.error("Could not find rpc method for rpc invocation.", e); + log.error("Could not find rpc method for rpc invocation.", e); RpcConnectionException rpcException = new RpcConnectionException("Could not find rpc method for rpc invocation.", e); getSender().tell(new Status.Failure(rpcException), getSelf()); @@ -189,6 +194,9 @@ private void handleRpcInvocation(RpcInvocation rpcInvocation) { if (rpcMethod != null) { try { + // this supports declaration of anonymous classes + rpcMethod.setAccessible(true); + if (rpcMethod.getReturnType().equals(Void.TYPE)) { // No return value to send back rpcMethod.invoke(rpcEndpoint, rpcInvocation.getArgs()); @@ -199,7 +207,7 @@ private void handleRpcInvocation(RpcInvocation rpcInvocation) { result = rpcMethod.invoke(rpcEndpoint, rpcInvocation.getArgs()); } catch (InvocationTargetException e) { - LOG.trace("Reporting back error thrown in remote 
procedure {}", rpcMethod, e); + log.trace("Reporting back error thrown in remote procedure {}", rpcMethod, e); // tell the sender about the failure getSender().tell(new Status.Failure(e.getTargetException()), getSelf()); @@ -226,7 +234,7 @@ private void handleRpcInvocation(RpcInvocation rpcInvocation) { } } } catch (Throwable e) { - LOG.error("Error while executing remote procedure call {}.", rpcMethod, e); + log.error("Error while executing remote procedure call {}.", rpcMethod, e); // tell the sender about the failure getSender().tell(new Status.Failure(e), getSelf()); } @@ -246,9 +254,9 @@ private void handleCallAsync(CallAsync callAsync) { "prior to sending the message. The " + callAsync.getClass().getName() + " is only supported with local communication."; - LOG.warn(result); + log.warn(result); - getSender().tell(new Status.Failure(new Exception(result)), getSelf()); + getSender().tell(new Status.Failure(new AkkaRpcException(result)), getSelf()); } else { try { Object result = callAsync.getCallable().call(); @@ -268,14 +276,14 @@ private void handleCallAsync(CallAsync callAsync) { */ private void handleRunAsync(RunAsync runAsync) { if (runAsync.getRunnable() == null) { - LOG.warn("Received a {} message with an empty runnable field. This indicates " + + log.warn("Received a {} message with an empty runnable field. This indicates " + "that this message has been serialized prior to sending the message. The " + "{} is only supported with local communication.", runAsync.getClass().getName(), runAsync.getClass().getName()); } else { - final long timeToRun = runAsync.getTimeNanos(); + final long timeToRun = runAsync.getTimeNanos(); final long delayNanos; if (timeToRun == 0 || (delayNanos = timeToRun - System.nanoTime()) <= 0) { @@ -283,12 +291,12 @@ private void handleRunAsync(RunAsync runAsync) { try { runAsync.getRunnable().run(); } catch (Throwable t) { - LOG.error("Caught exception while executing runnable in main thread.", t); + log.error("Caught exception while executing runnable in main thread.", t); ExceptionUtils.rethrowIfFatalErrorOrOOM(t); } } else { - // schedule for later. send a new message after the delay, which will then be immediately executed + // schedule for later. send a new message after the delay, which will then be immediately executed FiniteDuration delay = new FiniteDuration(delayNanos, TimeUnit.NANOSECONDS); RunAsync message = new RunAsync(runAsync.getRunnable(), timeToRun); @@ -298,17 +306,6 @@ private void handleRunAsync(RunAsync runAsync) { } } - private void triggerShutdown() { - try { - rpcEndpoint.postStop(); - } catch (Throwable throwable) { - shutdownThrowable = throwable; - } - - // now stop the actor which will stop processing of any further messages - getContext().system().stop(getSelf()); - } - /** * Look up the rpc method on the given {@link RpcEndpoint} instance. * @@ -321,4 +318,15 @@ private void triggerShutdown() { private Method lookupRpcMethod(final String methodName, final Class[] parameterTypes) throws NoSuchMethodException { return rpcEndpoint.getClass().getMethod(methodName, parameterTypes); } + + /** + * Send throwable to sender if the sender is specified. 
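// Illustrative caller-side sketch (not part of this change): failures that the actor reports
// back as a Status.Failure — including the errors sent through sendErrorIfSender — surface on
// the proxy side as exceptionally completed futures. CounterGateway#requestCount is the
// invented RPC from the sketch further above; the slf4j Logger is passed in for brevity.

static void logCount(CounterGateway gateway, Logger log) {
	CompletableFuture<Integer> countFuture = gateway.requestCount(Time.seconds(10L));

	countFuture.whenComplete((count, failure) -> {
		if (failure != null) {
			// e.g. an RpcConnectionException if the invocation could not be resolved, or
			// the exception thrown by the remote procedure itself
			log.error("RPC invocation failed.", failure);
		} else {
			log.info("Remote count is {}.", count);
		}
	});
}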
+ * + * @param throwable to send to the sender + */ + protected void sendErrorIfSender(Throwable throwable) { + if (!getSender().equals(ActorRef.noSender())) { + getSender().tell(new Status.Failure(throwable), getSelf()); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java index ab851f6c8065b..07b334d8d604e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java @@ -30,15 +30,18 @@ import akka.dispatch.Mapper; import akka.pattern.Patterns; import org.apache.flink.api.common.time.Time; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.concurrent.FutureUtils; import org.apache.flink.runtime.concurrent.ScheduledExecutor; +import org.apache.flink.runtime.rpc.FencedMainThreadExecutable; +import org.apache.flink.runtime.rpc.FencedRpcEndpoint; +import org.apache.flink.runtime.rpc.FencedRpcGateway; import org.apache.flink.runtime.rpc.RpcEndpoint; import org.apache.flink.runtime.rpc.RpcGateway; import org.apache.flink.runtime.rpc.RpcService; import org.apache.flink.runtime.rpc.RpcServer; import org.apache.flink.runtime.rpc.RpcUtils; -import org.apache.flink.runtime.rpc.akka.messages.Shutdown; import org.apache.flink.runtime.rpc.exceptions.RpcConnectionException; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -49,6 +52,8 @@ import javax.annotation.Nonnull; import javax.annotation.concurrent.ThreadSafe; + +import java.io.Serializable; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Proxy; import java.util.HashSet; @@ -61,6 +66,7 @@ import java.util.concurrent.RunnableScheduledFuture; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -131,60 +137,44 @@ public int getPort() { // this method does not mutate state and is thus thread-safe @Override - public CompletableFuture connect(final String address, final Class clazz) { - checkState(!stopped, "RpcService is stopped"); - - LOG.debug("Try to connect to remote RPC endpoint with address {}. 
Returning a {} gateway.", - address, clazz.getName()); - - final ActorSelection actorSel = actorSystem.actorSelection(address); - - final scala.concurrent.Future identify = Patterns.ask(actorSel, new Identify(42), timeout.toMilliseconds()); - final scala.concurrent.Future resultFuture = identify.map(new Mapper(){ - @Override - public C checkedApply(Object obj) throws Exception { - - ActorIdentity actorIdentity = (ActorIdentity) obj; - - if (actorIdentity.getRef() == null) { - throw new RpcConnectionException("Could not connect to rpc endpoint under address " + address + '.'); - } else { - ActorRef actorRef = actorIdentity.getRef(); - - final String address = AkkaUtils.getAkkaURL(actorSystem, actorRef); - final String hostname; - Option host = actorRef.path().address().host(); - if (host.isEmpty()) { - hostname = "localhost"; - } else { - hostname = host.get(); - } - - InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler( - address, - hostname, - actorRef, - timeout, - maximumFramesize, - null); - - // Rather than using the System ClassLoader directly, we derive the ClassLoader - // from this class . That works better in cases where Flink runs embedded and all Flink - // code is loaded dynamically (for example from an OSGI bundle) through a custom ClassLoader - ClassLoader classLoader = AkkaRpcService.this.getClass().getClassLoader(); + public CompletableFuture connect( + final String address, + final Class clazz) { - @SuppressWarnings("unchecked") - C proxy = (C) Proxy.newProxyInstance( - classLoader, - new Class[]{clazz}, - akkaInvocationHandler); - - return proxy; - } - } - }, actorSystem.dispatcher()); + return connectInternal( + address, + clazz, + (ActorRef actorRef) -> { + Tuple2 addressHostname = extractAddressHostname(actorRef); + + return new AkkaInvocationHandler( + addressHostname.f0, + addressHostname.f1, + actorRef, + timeout, + maximumFramesize, + null); + }); + } - return FutureUtils.toJava(resultFuture); + // this method does not mutate state and is thus thread-safe + @Override + public > CompletableFuture connect(String address, F fencingToken, Class clazz) { + return connectInternal( + address, + clazz, + (ActorRef actorRef) -> { + Tuple2 addressHostname = extractAddressHostname(actorRef); + + return new FencedAkkaInvocationHandler<>( + addressHostname.f0, + addressHostname.f1, + actorRef, + timeout, + maximumFramesize, + null, + () -> fencingToken); + }); } @Override @@ -192,7 +182,14 @@ public RpcServer startServer(C rpcEndpoint) checkNotNull(rpcEndpoint, "rpc endpoint"); CompletableFuture terminationFuture = new CompletableFuture<>(); - Props akkaRpcActorProps = Props.create(AkkaRpcActor.class, rpcEndpoint, terminationFuture); + final Props akkaRpcActorProps; + + if (rpcEndpoint instanceof FencedRpcEndpoint) { + akkaRpcActorProps = Props.create(FencedAkkaRpcActor.class, rpcEndpoint, terminationFuture); + } else { + akkaRpcActorProps = Props.create(AkkaRpcActor.class, rpcEndpoint, terminationFuture); + } + ActorRef actorRef; synchronized (lock) { @@ -212,24 +209,40 @@ public RpcServer startServer(C rpcEndpoint) hostname = host.get(); } - InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler( - address, - hostname, - actorRef, - timeout, - maximumFramesize, - terminationFuture); + Set> implementedRpcGateways = new HashSet<>(RpcUtils.extractImplementedRpcGateways(rpcEndpoint.getClass())); + + implementedRpcGateways.add(RpcServer.class); + implementedRpcGateways.add(AkkaGateway.class); + + final InvocationHandler akkaInvocationHandler; + + if 
(rpcEndpoint instanceof FencedRpcEndpoint) { + // a FencedRpcEndpoint needs a FencedAkkaInvocationHandler + akkaInvocationHandler = new FencedAkkaInvocationHandler<>( + address, + hostname, + actorRef, + timeout, + maximumFramesize, + terminationFuture, + ((FencedRpcEndpoint) rpcEndpoint)::getFencingToken); + + implementedRpcGateways.add(FencedMainThreadExecutable.class); + } else { + akkaInvocationHandler = new AkkaInvocationHandler( + address, + hostname, + actorRef, + timeout, + maximumFramesize, + terminationFuture); + } // Rather than using the System ClassLoader directly, we derive the ClassLoader // from this class . That works better in cases where Flink runs embedded and all Flink // code is loaded dynamically (for example from an OSGI bundle) through a custom ClassLoader ClassLoader classLoader = getClass().getClassLoader(); - Set> implementedRpcGateways = RpcUtils.extractImplementedRpcGateways(rpcEndpoint.getClass()); - - implementedRpcGateways.add(RpcServer.class); - implementedRpcGateways.add(AkkaGateway.class); - @SuppressWarnings("unchecked") RpcServer server = (RpcServer) Proxy.newProxyInstance( classLoader, @@ -239,6 +252,33 @@ public RpcServer startServer(C rpcEndpoint) return server; } + @Override + public RpcServer fenceRpcServer(RpcServer rpcServer, F fencingToken) { + if (rpcServer instanceof AkkaGateway) { + + InvocationHandler fencedInvocationHandler = new FencedAkkaInvocationHandler<>( + rpcServer.getAddress(), + rpcServer.getHostname(), + ((AkkaGateway) rpcServer).getRpcEndpoint(), + timeout, + maximumFramesize, + null, + () -> fencingToken); + + // Rather than using the System ClassLoader directly, we derive the ClassLoader + // from this class . That works better in cases where Flink runs embedded and all Flink + // code is loaded dynamically (for example from an OSGI bundle) through a custom ClassLoader + ClassLoader classLoader = getClass().getClassLoader(); + + return (RpcServer) Proxy.newProxyInstance( + classLoader, + new Class[]{RpcServer.class, AkkaGateway.class}, + fencedInvocationHandler); + } else { + throw new RuntimeException("The given RpcServer must implement the AkkaGateway in order to fence it."); + } + } + @Override public void stopServer(RpcServer selfGateway) { if (selfGateway instanceof AkkaGateway) { @@ -256,7 +296,7 @@ public void stopServer(RpcServer selfGateway) { if (fromThisService) { ActorRef selfActorRef = akkaClient.getRpcEndpoint(); LOG.info("Trigger shut down of RPC endpoint {}.", selfActorRef.path()); - selfActorRef.tell(Shutdown.getInstance(), ActorRef.noSender()); + actorSystem.stop(selfActorRef); } else { LOG.debug("RPC endpoint {} already stopped or from different RPC service"); } @@ -273,11 +313,14 @@ public void stopService() { } stopped = true; + actorSystem.shutdown(); actors.clear(); } actorSystem.awaitTermination(); + + LOG.info("Stopped Akka RPC service."); } @Override @@ -317,6 +360,67 @@ public CompletableFuture execute(Callable callable) { return FutureUtils.toJava(scalaFuture); } + // --------------------------------------------------------------------------------------- + // Private helper methods + // --------------------------------------------------------------------------------------- + + private Tuple2 extractAddressHostname(ActorRef actorRef) { + final String actorAddress = AkkaUtils.getAkkaURL(actorSystem, actorRef); + final String hostname; + Option host = actorRef.path().address().host(); + if (host.isEmpty()) { + hostname = "localhost"; + } else { + hostname = host.get(); + } + + return 
Tuple2.of(actorAddress, hostname); + } + + private CompletableFuture connectInternal( + final String address, + final Class clazz, + Function invocationHandlerFactory) { + checkState(!stopped, "RpcService is stopped"); + + LOG.debug("Try to connect to remote RPC endpoint with address {}. Returning a {} gateway.", + address, clazz.getName()); + + final ActorSelection actorSel = actorSystem.actorSelection(address); + + final Future identify = Patterns.ask(actorSel, new Identify(42), timeout.toMilliseconds()); + final Future resultFuture = identify.map(new Mapper(){ + @Override + public C checkedApply(Object obj) throws Exception { + + ActorIdentity actorIdentity = (ActorIdentity) obj; + + if (actorIdentity.getRef() == null) { + throw new RpcConnectionException("Could not connect to rpc endpoint under address " + address + '.'); + } else { + ActorRef actorRef = actorIdentity.getRef(); + + InvocationHandler invocationHandler = invocationHandlerFactory.apply(actorRef); + + // Rather than using the System ClassLoader directly, we derive the ClassLoader + // from this class . That works better in cases where Flink runs embedded and all Flink + // code is loaded dynamically (for example from an OSGI bundle) through a custom ClassLoader + ClassLoader classLoader = AkkaRpcService.this.getClass().getClassLoader(); + + @SuppressWarnings("unchecked") + C proxy = (C) Proxy.newProxyInstance( + classLoader, + new Class[]{clazz}, + invocationHandler); + + return proxy; + } + } + }, actorSystem.dispatcher()); + + return FutureUtils.toJava(resultFuture); + } + /** * Helper class to expose the internal scheduling logic via a {@link ScheduledExecutor}. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/FencedAkkaInvocationHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/FencedAkkaInvocationHandler.java new file mode 100644 index 0000000000000..9d2c2950d98f9 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/FencedAkkaInvocationHandler.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rpc.akka; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.runtime.concurrent.FutureUtils; +import org.apache.flink.runtime.rpc.FencedMainThreadExecutable; +import org.apache.flink.runtime.rpc.FencedRpcEndpoint; +import org.apache.flink.runtime.rpc.FencedRpcGateway; +import org.apache.flink.runtime.rpc.messages.CallAsync; +import org.apache.flink.runtime.rpc.messages.FencedMessage; +import org.apache.flink.runtime.rpc.messages.LocalFencedMessage; +import org.apache.flink.runtime.rpc.messages.RemoteFencedMessage; +import org.apache.flink.runtime.rpc.messages.RunAsync; +import org.apache.flink.runtime.rpc.messages.UnfencedMessage; +import org.apache.flink.util.Preconditions; + +import akka.actor.ActorRef; +import akka.pattern.Patterns; + +import javax.annotation.Nullable; + +import java.io.Serializable; +import java.lang.reflect.Method; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Fenced extension of the {@link AkkaInvocationHandler}. This invocation handler will be used in combination + * with the {@link FencedRpcEndpoint}. The fencing is done by wrapping all messages in a {@link FencedMessage}. + * + * @param type of the fencing token + */ +public class FencedAkkaInvocationHandler extends AkkaInvocationHandler implements FencedMainThreadExecutable, FencedRpcGateway { + + private final Supplier fencingTokenSupplier; + + public FencedAkkaInvocationHandler( + String address, + String hostname, + ActorRef rpcEndpoint, + Time timeout, + long maximumFramesize, + @Nullable CompletableFuture terminationFuture, + Supplier fencingTokenSupplier) { + super(address, hostname, rpcEndpoint, timeout, maximumFramesize, terminationFuture); + + this.fencingTokenSupplier = Preconditions.checkNotNull(fencingTokenSupplier); + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + Class declaringClass = method.getDeclaringClass(); + + if (declaringClass.equals(FencedMainThreadExecutable.class) || + declaringClass.equals(FencedRpcGateway.class)) { + return method.invoke(this, args); + } else { + return super.invoke(proxy, method, args); + } + } + + @Override + public void runAsyncWithoutFencing(Runnable runnable) { + checkNotNull(runnable, "runnable"); + + if (isLocal) { + getRpcEndpoint().tell( + new UnfencedMessage<>(new RunAsync(runnable, 0L)), ActorRef.noSender()); + } else { + throw new RuntimeException("Trying to send a Runnable to a remote actor at " + + getRpcEndpoint().path() + ". This is not supported."); + } + } + + @Override + public CompletableFuture callAsyncWithoutFencing(Callable callable, Time timeout) { + checkNotNull(callable, "callable"); + checkNotNull(timeout, "timeout"); + + if (isLocal) { + @SuppressWarnings("unchecked") + CompletableFuture resultFuture = (CompletableFuture) FutureUtils.toJava( + Patterns.ask( + getRpcEndpoint(), + new UnfencedMessage<>(new CallAsync(callable)), + timeout.toMilliseconds())); + + return resultFuture; + } else { + throw new RuntimeException("Trying to send a Runnable to a remote actor at " + + getRpcEndpoint().path() + ". 
This is not supported."); + } + } + + @Override + public void tell(Object message) { + super.tell(fenceMessage(message)); + } + + @Override + public CompletableFuture ask(Object message, Time timeout) { + return super.ask(fenceMessage(message), timeout); + } + + @Override + public F getFencingToken() { + return fencingTokenSupplier.get(); + } + + private
<P>
FencedMessage fenceMessage(P message) { + if (isLocal) { + return new LocalFencedMessage<>(fencingTokenSupplier.get(), message); + } else { + if (message instanceof Serializable) { + @SuppressWarnings("unchecked") + FencedMessage result = (FencedMessage) new RemoteFencedMessage<>(fencingTokenSupplier.get(), (Serializable) message); + + return result; + } else { + throw new RuntimeException("Trying to send a non-serializable message " + message + " to a remote " + + "RpcEndpoint. Please make sure that the message implements java.io.Serializable."); + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/FencedAkkaRpcActor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/FencedAkkaRpcActor.java new file mode 100644 index 0000000000000..b10f7deccec3c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/FencedAkkaRpcActor.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.akka; + +import org.apache.flink.runtime.rpc.FencedRpcEndpoint; +import org.apache.flink.runtime.rpc.akka.exceptions.AkkaUnknownMessageException; +import org.apache.flink.runtime.rpc.exceptions.FencingTokenMismatchException; +import org.apache.flink.runtime.rpc.RpcGateway; +import org.apache.flink.runtime.rpc.messages.FencedMessage; +import org.apache.flink.runtime.rpc.messages.UnfencedMessage; + +import java.io.Serializable; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; + +/** + * Fenced extension of the {@link AkkaRpcActor}. This actor will be started for {@link FencedRpcEndpoint} and is + * responsible for filtering out invalid messages with respect to the current fencing token. 
+ * + * @param type of the fencing token + * @param type of the RpcEndpoint + */ +public class FencedAkkaRpcActor & RpcGateway> extends AkkaRpcActor { + + public FencedAkkaRpcActor(T rpcEndpoint, CompletableFuture terminationFuture) { + super(rpcEndpoint, terminationFuture); + } + + @Override + protected void handleMessage(Object message) { + if (message instanceof FencedMessage) { + @SuppressWarnings("unchecked") + FencedMessage fencedMessage = ((FencedMessage) message); + + F fencingToken = fencedMessage.getFencingToken(); + + if (Objects.equals(rpcEndpoint.getFencingToken(), fencingToken)) { + super.handleMessage(fencedMessage.getPayload()); + } else { + if (log.isDebugEnabled()) { + log.debug("Fencing token mismatch: Ignoring message {} because the fencing token {} did " + + "not match the expected fencing token {}.", message, fencingToken, rpcEndpoint.getFencingToken()); + } + + sendErrorIfSender(new FencingTokenMismatchException("Expected fencing token " + rpcEndpoint.getFencingToken() + ", actual fencing token " + fencingToken)); + } + } else if (message instanceof UnfencedMessage) { + super.handleMessage(((UnfencedMessage) message).getPayload()); + } else { + if (log.isDebugEnabled()) { + log.debug("Unknown message type: Ignoring message {} because it is neither of type {} nor {}.", + message, FencedMessage.class.getSimpleName(), UnfencedMessage.class.getSimpleName()); + } + + sendErrorIfSender(new AkkaUnknownMessageException("Unknown message type: Ignoring message " + message + + " of type " + message.getClass().getSimpleName() + " because it is neither of type " + + FencedMessage.class.getSimpleName() + " nor " + UnfencedMessage.class.getSimpleName() + '.')); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/exceptions/AkkaUnknownMessageException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/exceptions/AkkaUnknownMessageException.java new file mode 100644 index 0000000000000..7504761591c14 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/exceptions/AkkaUnknownMessageException.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.akka.exceptions; + +/** + * Exception which indicates that the AkkaRpcActor has received an + * unknown message type. 
+ */ +public class AkkaUnknownMessageException extends AkkaRpcException { + + private static final long serialVersionUID = 1691338049911020814L; + + public AkkaUnknownMessageException(String message) { + super(message); + } + + public AkkaUnknownMessageException(String message, Throwable cause) { + super(message, cause); + } + + public AkkaUnknownMessageException(Throwable cause) { + super(cause); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/Processing.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/Processing.java index 5c7df5dcd6021..030ff60bed6e2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/Processing.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/Processing.java @@ -21,7 +21,7 @@ /** * Controls the processing behaviour of the {@link org.apache.flink.runtime.rpc.akka.AkkaRpcActor} */ -public enum Processing { +public enum Processing { START, // Unstashes all stashed messages and starts processing incoming messages STOP // Stop processing messages and stashes all incoming messages } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/exceptions/FencingTokenMismatchException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/exceptions/FencingTokenMismatchException.java new file mode 100644 index 0000000000000..9a59101a0e813 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/exceptions/FencingTokenMismatchException.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.exceptions; + +import org.apache.flink.runtime.rpc.FencedRpcEndpoint; +import org.apache.flink.runtime.rpc.exceptions.RpcException; + +/** + * Exception which is thrown if the fencing tokens of a {@link FencedRpcEndpoint} do + * not match. 
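// Illustrative sketch (not part of this change): from the caller's point of view a fencing
// token mismatch surfaces as an exceptionally completed future, typically because the gateway
// was obtained with a token that is no longer current. CounterGateway is the invented gateway
// from the sketch further above; leaderAddress and staleFencingToken are assumed values.

static void queryWithPossiblyStaleToken(RpcService rpcService, String leaderAddress, UUID staleFencingToken) {
	CompletableFuture<Integer> countFuture = rpcService
		.connect(leaderAddress, staleFencingToken, CounterGateway.class)
		.thenCompose(gateway -> gateway.requestCount(Time.seconds(10L)));

	countFuture.exceptionally(failure -> {
		Throwable cause = failure instanceof CompletionException ? failure.getCause() : failure;

		if (cause instanceof FencingTokenMismatchException) {
			// the endpoint is fenced with a different token by now (e.g. a new leader
			// session id); re-resolve the current leader and its token before retrying
		}

		return null;
	});
}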
+ */ +public class FencingTokenMismatchException extends RpcException { + private static final long serialVersionUID = -500634972988881467L; + + public FencingTokenMismatchException(String message) { + super(message); + } + + public FencingTokenMismatchException(String message, Throwable cause) { + super(message, cause); + } + + public FencingTokenMismatchException(Throwable cause) { + super(cause); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/exceptions/LeaderSessionIDException.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/exceptions/LeaderSessionIDException.java deleted file mode 100644 index d3ba9a97e8576..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/exceptions/LeaderSessionIDException.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.rpc.exceptions; - -import java.util.UUID; - -import static org.apache.flink.util.Preconditions.checkNotNull; - -/** - * An exception specifying that the received leader session ID is not the same as expected. - */ -public class LeaderSessionIDException extends Exception { - - private static final long serialVersionUID = -3276145308053264636L; - - /** expected leader session id */ - private final UUID expectedLeaderSessionID; - - /** actual leader session id */ - private final UUID actualLeaderSessionID; - - public LeaderSessionIDException(UUID expectedLeaderSessionID, UUID actualLeaderSessionID) { - super("Unmatched leader session ID : expected " + expectedLeaderSessionID + ", actual " + actualLeaderSessionID); - this.expectedLeaderSessionID = checkNotNull(expectedLeaderSessionID); - this.actualLeaderSessionID = checkNotNull(actualLeaderSessionID); - } - - /** - * Get expected leader session id - * - * @return expect leader session id - */ - public UUID getExpectedLeaderSessionID() { - return expectedLeaderSessionID; - } - - /** - * Get actual leader session id - * - * @return actual leader session id - */ - public UUID getActualLeaderSessionID() { - return actualLeaderSessionID; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/CallAsync.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/CallAsync.java similarity index 96% rename from flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/CallAsync.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/CallAsync.java index 79b7825e8a3ee..9aa7d70d3155b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/CallAsync.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/CallAsync.java @@ -16,7 +16,7 @@ * limitations under the License. 
*/ -package org.apache.flink.runtime.rpc.akka.messages; +package org.apache.flink.runtime.rpc.messages; import org.apache.flink.util.Preconditions; diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/FencedMessage.java similarity index 57% rename from flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StreamStateHandle.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/FencedMessage.java index bfc57bca97921..b67e564f593a4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/state/StreamStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/FencedMessage.java @@ -16,25 +16,19 @@ * limitations under the License. */ -package org.apache.flink.migration.runtime.state; +package org.apache.flink.runtime.rpc.messages; -import java.io.InputStream; import java.io.Serializable; /** - * A state handle that produces an input stream when resolved. + * Interface for fenced messages. * - * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. + * @param type of the fencing token + * @param
<P>
type of the payload */ -@Deprecated -@SuppressWarnings("deprecation") -public interface StreamStateHandle extends StateHandle { +public interface FencedMessage { - /** - * Converts this stream state handle into a state handle that de-serializes - * the stream into an object using Java's serialization mechanism. - * - * @return The state handle that automatically de-serializes. - */ - StateHandle toSerializableHandle(); + F getFencingToken(); + + P getPayload(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/LocalFencedMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/LocalFencedMessage.java new file mode 100644 index 0000000000000..248106558d3b8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/LocalFencedMessage.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.messages; + +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; + +/** + * Local {@link FencedMessage} implementation. This message is used when the communication + * is local and thus does not require its payload to be serializable. + * + * @param type of the fencing token + * @param
<P>
type of the payload + */ +public class LocalFencedMessage implements FencedMessage { + + private final F fencingToken; + private final P payload; + + public LocalFencedMessage(F fencingToken, P payload) { + this.fencingToken = Preconditions.checkNotNull(fencingToken); + this.payload = Preconditions.checkNotNull(payload); + } + + @Override + public F getFencingToken() { + return fencingToken; + } + + @Override + public P getPayload() { + return payload; + } + + @Override + public String toString() { + return "LocalFencedMessage(" + fencingToken + ", " + payload + ')'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/LocalRpcInvocation.java similarity index 72% rename from flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/LocalRpcInvocation.java index 97c10d71bf141..0bd06c338bcf6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/LocalRpcInvocation.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.runtime.rpc.akka.messages; +package org.apache.flink.runtime.rpc.messages; import org.apache.flink.util.Preconditions; @@ -31,10 +31,14 @@ public final class LocalRpcInvocation implements RpcInvocation { private final Class[] parameterTypes; private final Object[] args; + private transient String toString; + public LocalRpcInvocation(String methodName, Class[] parameterTypes, Object[] args) { this.methodName = Preconditions.checkNotNull(methodName); this.parameterTypes = Preconditions.checkNotNull(parameterTypes); this.args = args; + + toString = null; } @Override @@ -51,4 +55,25 @@ public Class[] getParameterTypes() { public Object[] getArgs() { return args; } + + @Override + public String toString() { + if (toString == null) { + StringBuilder paramTypeStringBuilder = new StringBuilder(parameterTypes.length * 5); + + if (parameterTypes.length > 0) { + paramTypeStringBuilder.append(parameterTypes[0].getSimpleName()); + + for (int i = 1; i < parameterTypes.length; i++) { + paramTypeStringBuilder + .append(", ") + .append(parameterTypes[i].getSimpleName()); + } + } + + toString = "LocalRpcInvocation(" + methodName + '(' + paramTypeStringBuilder + "))"; + } + + return toString; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RemoteFencedMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RemoteFencedMessage.java new file mode 100644 index 0000000000000..5cf9b98d6202a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RemoteFencedMessage.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.messages; + +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; + +/** + * Remote {@link FencedMessage} implementation. This message is used when the communication + * is remote and thus requires its payload to be serializable. + * + * @param type of the fencing token + * @param
<P>
type of the payload + */ +public class RemoteFencedMessage implements FencedMessage, Serializable { + private static final long serialVersionUID = 4043136067468477742L; + + private final F fencingToken; + private final P payload; + + public RemoteFencedMessage(F fencingToken, P payload) { + this.fencingToken = Preconditions.checkNotNull(fencingToken); + this.payload = Preconditions.checkNotNull(payload); + } + + @Override + public F getFencingToken() { + return fencingToken; + } + + @Override + public P getPayload() { + return payload; + } + + @Override + public String toString() { + return "RemoteFencedMessage(" + fencingToken + ", " + payload + ')'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RemoteRpcInvocation.java similarity index 88% rename from flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RemoteRpcInvocation.java index bc26a29715c91..779d5dd82c159 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RemoteRpcInvocation.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.runtime.rpc.akka.messages; +package org.apache.flink.runtime.rpc.messages; import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; @@ -43,6 +43,8 @@ public class RemoteRpcInvocation implements RpcInvocation, Serializable { // Transient field which is lazily initialized upon first access to the invocation data private transient RemoteRpcInvocation.MethodInvocation methodInvocation; + private transient String toString; + public RemoteRpcInvocation( final String methodName, final Class[] parameterTypes, @@ -73,6 +75,35 @@ public Object[] getArgs() throws IOException, ClassNotFoundException { return methodInvocation.getArgs(); } + @Override + public String toString() { + if (toString == null) { + + try { + Class[] parameterTypes = getParameterTypes(); + String methodName = getMethodName(); + + StringBuilder paramTypeStringBuilder = new StringBuilder(parameterTypes.length * 5); + + if (parameterTypes.length > 0) { + paramTypeStringBuilder.append(parameterTypes[0].getSimpleName()); + + for (int i = 1; i < parameterTypes.length; i++) { + paramTypeStringBuilder + .append(", ") + .append(parameterTypes[i].getSimpleName()); + } + } + + toString = "RemoteRpcInvocation(" + methodName + '(' + paramTypeStringBuilder + "))"; + } catch (IOException | ClassNotFoundException e) { + toString = "Could not deserialize RemoteRpcInvocation: " + e.getMessage(); + } + } + + return toString; + } + /** * Size (#bytes of the serialized data) of the rpc invocation message. 
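Both invocation classes above cache their string representation in a transient field, so the description is built lazily, at most once per instance, and is never shipped over the wire. A stripped-down sketch of the same pattern, with hypothetical names:

    import java.io.Serializable;

    // Hypothetical class illustrating only the lazily cached, transient toString.
    class InvocationDescription implements Serializable {

        private static final long serialVersionUID = 1L;

        private final String methodName;
        private final Class<?>[] parameterTypes;

        // Not serialized; rebuilt on the first toString() call after deserialization.
        private transient String cachedToString;

        InvocationDescription(String methodName, Class<?>... parameterTypes) {
            this.methodName = methodName;
            this.parameterTypes = parameterTypes;
        }

        @Override
        public String toString() {
            if (cachedToString == null) {
                StringBuilder params = new StringBuilder();
                for (int i = 0; i < parameterTypes.length; i++) {
                    if (i > 0) {
                        params.append(", ");
                    }
                    params.append(parameterTypes[i].getSimpleName());
                }
                cachedToString = "RpcInvocation(" + methodName + '(' + params + "))";
            }
            return cachedToString;
        }
    }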
* diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RpcInvocation.java similarity index 97% rename from flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RpcInvocation.java index b174c99a4d37c..4e9f6299fec67 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RpcInvocation.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.runtime.rpc.akka.messages; +package org.apache.flink.runtime.rpc.messages; import java.io.IOException; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RunAsync.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RunAsync.java similarity index 95% rename from flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RunAsync.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RunAsync.java index 4b8a0b4c6542a..2f6d867ea833a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RunAsync.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/RunAsync.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.runtime.rpc.akka.messages; +package org.apache.flink.runtime.rpc.messages; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkArgument; @@ -26,7 +26,7 @@ */ public final class RunAsync { - /** The runnable to be executed. Transient, so it gets lost upon serialization */ + /** The runnable to be executed. Transient, so it gets lost upon serialization */ private final Runnable runnable; /** The delay after which the runnable should be called */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/UnfencedMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/UnfencedMessage.java new file mode 100644 index 0000000000000..27867c4d0735f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/messages/UnfencedMessage.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.rpc.messages; + +import org.apache.flink.runtime.rpc.FencedMainThreadExecutable; +import org.apache.flink.util.Preconditions; + +/** + * Wrapper class indicating a message which is not required to match the fencing token + * as it is used by the {@link FencedMainThreadExecutable} to run code in the main thread without + * a valid fencing token. This is required for operations which are not scoped by the current + * fencing token (e.g. leadership grants). + * + *

<p>IMPORTANT: This message is only intended to be sent locally. + * + * @param <P> type of the payload + */ +public class UnfencedMessage<P>
{ + private final P payload; + + public UnfencedMessage(P payload) { + this.payload = Preconditions.checkNotNull(payload); + } + + public P getPayload() { + return payload; + } + + @Override + public String toString() { + return "UnfencedMessage(" + payload + ')'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/security/SecurityUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/security/SecurityUtils.java index 9e6f40258a09f..bdaaed6650d8e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/security/SecurityUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/security/SecurityUtils.java @@ -28,7 +28,8 @@ import org.apache.flink.runtime.security.modules.SecurityModule; import org.apache.flink.runtime.security.modules.ZooKeeperModule; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.security.UserGroupInformation; import org.slf4j.Logger; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java index 0085890df28f0..0268b102be885 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java @@ -65,27 +65,27 @@ public class IncrementalKeyedStateHandle implements KeyedStateHandle { private final UUID backendIdentifier; /** - * The key-group range covered by this state handle + * The key-group range covered by this state handle. */ private final KeyGroupRange keyGroupRange; /** - * The checkpoint Id + * The checkpoint Id. */ private final long checkpointId; /** - * Shared state in the incremental checkpoint. This i + * Shared state in the incremental checkpoint. */ private final Map sharedState; /** - * Private state in the incremental checkpoint + * Private state in the incremental checkpoint. */ private final Map privateState; /** - * Primary meta data state of the incremental checkpoint + * Primary meta data state of the incremental checkpoint. */ private final StreamStateHandle metaStateHandle; @@ -143,16 +143,21 @@ public UUID getBackendIdentifier() { @Override public KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange) { - if (this.keyGroupRange.getIntersection(keyGroupRange) != KeyGroupRange.EMPTY_KEY_GROUP_RANGE) { - return this; - } else { - return null; - } + return KeyGroupRange.EMPTY_KEY_GROUP_RANGE.equals(this.keyGroupRange.getIntersection(keyGroupRange)) ? + null : this; } @Override public void discardState() throws Exception { + SharedStateRegistry registry = this.sharedStateRegistry; + final boolean isRegistered = (registry != null); + + LOG.trace("Discarding IncrementalKeyedStateHandle (registered = {}) for checkpoint {} from backend with id {}.", + isRegistered, + checkpointId, + backendIdentifier); + try { metaStateHandle.discardState(); } catch (Exception e) { @@ -168,19 +173,20 @@ public void discardState() throws Exception { // If this was not registered, we can delete the shared state. We can simply apply this // to all handles, because all handles that have not been created for the first time for this // are only placeholders at this point (disposing them is a NOP). 
- if (sharedStateRegistry == null) { - try { - StateUtil.bestEffortDiscardAllStateObjects(sharedState.values()); - } catch (Exception e) { - LOG.warn("Could not properly discard new sst file states.", e); - } - } else { + if (isRegistered) { // If this was registered, we only unregister all our referenced shared states // from the registry. for (StateHandleID stateHandleID : sharedState.keySet()) { - sharedStateRegistry.unregisterReference( + registry.unregisterReference( createSharedStateRegistryKeyFromFileName(stateHandleID)); } + } else { + // Otherwise, we assume to own those handles and dispose them directly. + try { + StateUtil.bestEffortDiscardAllStateObjects(sharedState.values()); + } catch (Exception e) { + LOG.warn("Could not properly discard new sst file states.", e); + } } } @@ -202,10 +208,21 @@ public long getStateSize() { @Override public void registerSharedStates(SharedStateRegistry stateRegistry) { - Preconditions.checkState(sharedStateRegistry == null, "The state handle has already registered its shared states."); + // This is a quick check to avoid that we register twice with the same registry. However, the code allows to + // register again with a different registry. The implication is that ownership is transferred to this new + // registry. This should only happen in case of a restart, when the CheckpointCoordinator creates a new + // SharedStateRegistry for the current attempt and the old registry becomes meaningless. We also assume that + // an old registry object from a previous run is due to be GCed and will never be used for registration again. + Preconditions.checkState( + sharedStateRegistry != stateRegistry, + "The state handle has already registered its shared states to the given registry."); sharedStateRegistry = Preconditions.checkNotNull(stateRegistry); + LOG.trace("Registering IncrementalKeyedStateHandle for checkpoint {} from backend with id {}.", + checkpointId, + backendIdentifier); + for (Map.Entry sharedStateHandle : sharedState.entrySet()) { SharedStateRegistryKey registryKey = createSharedStateRegistryKeyFromFileName(sharedStateHandle.getKey()); @@ -284,5 +301,18 @@ public int hashCode() { result = 31 * result + getMetaStateHandle().hashCode(); return result; } + + @Override + public String toString() { + return "IncrementalKeyedStateHandle{" + + "backendIdentifier=" + backendIdentifier + + ", keyGroupRange=" + keyGroupRange + + ", checkpointId=" + checkpointId + + ", sharedState=" + sharedState + + ", privateState=" + privateState + + ", metaStateHandle=" + metaStateHandle + + ", registered=" + (sharedStateRegistry != null) + + '}'; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java index 8e38ad4750d86..8092f6c72de8c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java @@ -141,7 +141,7 @@ public int hashCode() { public String toString() { return "KeyGroupsStateHandle{" + "groupRangeOffsets=" + groupRangeOffsets + - ", data=" + stateHandle + + ", stateHandle=" + stateHandle + '}'; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java index b95daceece56a..1960c1c95f431 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java @@ -38,7 +38,7 @@ public class MultiStreamStateHandle implements StreamStateHandle { private final List stateHandles; private final long stateSize; - public MultiStreamStateHandle(List stateHandles) throws IOException { + public MultiStreamStateHandle(List stateHandles) { this.stateHandles = Preconditions.checkNotNull(stateHandles); long calculateSize = 0L; for(StreamStateHandle stateHandle : stateHandles) { @@ -62,6 +62,14 @@ public long getStateSize() { return stateSize; } + @Override + public String toString() { + return "MultiStreamStateHandle{" + + "stateHandles=" + stateHandles + + ", stateSize=" + stateSize + + '}'; + } + static final class MultiFSDataInputStream extends AbstractMultiFSDataInputStream { private final TreeMap stateHandleMap; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SerializedCheckpointData.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SerializedCheckpointData.java index 16ad3fde602b1..394791b9014e2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SerializedCheckpointData.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SerializedCheckpointData.java @@ -25,8 +25,8 @@ import java.io.IOException; import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.List; +import java.util.HashSet; +import java.util.Set; /** * This class represents serialized checkpoint data for a collection of elements. @@ -95,7 +95,7 @@ public int getNumIds() { * * @throws IOException Thrown, if the serialization fails. */ - public static SerializedCheckpointData[] fromDeque(ArrayDeque>> checkpoints, + public static SerializedCheckpointData[] fromDeque(ArrayDeque>> checkpoints, TypeSerializer serializer) throws IOException { return fromDeque(checkpoints, serializer, new DataOutputSerializer(128)); } @@ -111,15 +111,15 @@ public static SerializedCheckpointData[] fromDeque(ArrayDeque SerializedCheckpointData[] fromDeque(ArrayDeque>> checkpoints, + public static SerializedCheckpointData[] fromDeque(ArrayDeque>> checkpoints, TypeSerializer serializer, DataOutputSerializer outputBuffer) throws IOException { SerializedCheckpointData[] serializedCheckpoints = new SerializedCheckpointData[checkpoints.size()]; int pos = 0; - for (Tuple2> checkpoint : checkpoints) { + for (Tuple2> checkpoint : checkpoints) { outputBuffer.clear(); - List checkpointIds = checkpoint.f1; + Set checkpointIds = checkpoint.f1; for (T id : checkpointIds) { serializer.serialize(id, outputBuffer); @@ -146,10 +146,9 @@ public static SerializedCheckpointData[] fromDeque(ArrayDeque ArrayDeque>> toDeque( - SerializedCheckpointData[] data, TypeSerializer serializer) throws IOException - { - ArrayDeque>> deque = new ArrayDeque<>(data.length); + public static ArrayDeque>> toDeque( + SerializedCheckpointData[] data, TypeSerializer serializer) throws IOException { + ArrayDeque>> deque = new ArrayDeque<>(data.length); DataInputDeserializer deser = null; for (SerializedCheckpointData checkpoint : data) { @@ -161,14 +160,14 @@ public static ArrayDeque>> toDeque( deser.setBuffer(serializedData, 0, serializedData.length); } - final List ids = new ArrayList<>(checkpoint.getNumIds()); + final Set ids = new HashSet<>(checkpoint.getNumIds()); final int numIds = checkpoint.getNumIds(); for (int i = 0; i < numIds; i++) { ids.add(serializer.deserialize(deser)); } - 
deque.addLast(new Tuple2>(checkpoint.checkpointId, ids)); + deque.addLast(new Tuple2>(checkpoint.checkpointId, ids)); } return deque; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java index e0ca873668423..347f30c63ecbc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java @@ -38,13 +38,24 @@ * maintain the reference count of {@link StreamStateHandle}s by a key that (logically) identifies * them. */ -public class SharedStateRegistry { +public class SharedStateRegistry implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(SharedStateRegistry.class); + /** A singleton object for the default implementation of a {@link SharedStateRegistryFactory} */ + public static final SharedStateRegistryFactory DEFAULT_FACTORY = new SharedStateRegistryFactory() { + @Override + public SharedStateRegistry create(Executor deleteExecutor) { + return new SharedStateRegistry(deleteExecutor); + } + }; + /** All registered state objects by an artificial key */ private final Map registeredStates; + /** This flag indicates whether or not the registry is open or if close() was called */ + private boolean open; + /** Executor for async state deletion */ private final Executor asyncDisposalExecutor; @@ -56,6 +67,7 @@ public SharedStateRegistry() { public SharedStateRegistry(Executor asyncDisposalExecutor) { this.registeredStates = new HashMap<>(); this.asyncDisposalExecutor = Preconditions.checkNotNull(asyncDisposalExecutor); + this.open = true; } /** @@ -82,6 +94,9 @@ public Result registerReference(SharedStateRegistryKey registrationKey, StreamSt SharedStateRegistry.SharedStateEntry entry; synchronized (registeredStates) { + + Preconditions.checkState(open, "Attempt to register state to closed SharedStateRegistry."); + entry = registeredStates.get(registrationKey); if (entry == null) { @@ -96,6 +111,11 @@ public Result registerReference(SharedStateRegistryKey registrationKey, StreamSt // delete if this is a real duplicate if (!Objects.equals(state, entry.stateHandle)) { scheduledStateDeletion = state; + LOG.trace("Identified duplicate state registration under key {}. New state {} was determined to " + + "be an unnecessary copy of existing state {} and will be dropped.", + registrationKey, + state, + entry.stateHandle); } entry.increaseReferenceCount(); } @@ -112,7 +132,8 @@ public Result registerReference(SharedStateRegistryKey registrationKey, StreamSt * * @param registrationKey the shared state for which we release a reference. * @return the result of the request, consisting of the reference count after this operation - * and the state handle, or null if the state handle was deleted through this request. + * and the state handle, or null if the state handle was deleted through this request. Returns null if the registry + * was previously closed. 
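The registry above keeps one entry per registration key, increments a reference count for duplicate registrations (scheduling the redundant copy for deletion), and rejects new registrations once close() has been called. A simplified, self-contained sketch of that reference-counting idea (made-up types, no asynchronous deletion):

    import java.util.HashMap;
    import java.util.Map;

    // Illustration of reference-counted registration; not the Flink classes.
    final class RefCountingRegistrySketch {

        private final Map<String, Integer> referenceCounts = new HashMap<>();
        private boolean open = true;

        synchronized void register(String key) {
            if (!open) {
                throw new IllegalStateException("Attempt to register state to a closed registry.");
            }
            referenceCounts.merge(key, 1, Integer::sum);
        }

        /** Returns true if the last reference was released and the underlying state may be deleted. */
        synchronized boolean unregister(String key) {
            Integer count = referenceCounts.get(key);
            if (count == null) {
                throw new IllegalStateException("Cannot unregister unknown key " + key + '.');
            }
            if (count == 1) {
                referenceCounts.remove(key);
                return true; // caller may now schedule deletion of the shared state
            }
            referenceCounts.put(key, count - 1);
            return false;
        }

        synchronized void close() {
            open = false;
        }
    }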
*/ public Result unregisterReference(SharedStateRegistryKey registrationKey) { @@ -123,6 +144,7 @@ public Result unregisterReference(SharedStateRegistryKey registrationKey) { SharedStateRegistry.SharedStateEntry entry; synchronized (registeredStates) { + entry = registeredStates.get(registrationKey); Preconditions.checkState(entry != null, @@ -164,10 +186,18 @@ public void registerAll(Iterable stateHandles) { } } + @Override + public String toString() { + synchronized (registeredStates) { + return "SharedStateRegistry{" + + "registeredStates=" + registeredStates + + '}'; + } + } + private void scheduleAsyncDelete(StreamStateHandle streamStateHandle) { // We do the small optimization to not issue discards for placeholders, which are NOPs. if (streamStateHandle != null && !isPlaceholder(streamStateHandle)) { - LOG.trace("Scheduled delete of state handle {}.", streamStateHandle); asyncDisposalExecutor.execute( new SharedStateRegistry.AsyncDisposalRunnable(streamStateHandle)); @@ -178,6 +208,13 @@ private boolean isPlaceholder(StreamStateHandle stateHandle) { return stateHandle instanceof PlaceholderStreamStateHandle; } + @Override + public void close() { + synchronized (registeredStates) { + open = false; + } + } + /** * An entry in the registry, tracking the handle and the corresponding reference count. */ @@ -279,13 +316,4 @@ public void run() { } } } - - /** - * Clears the registry. - */ - public void clear() { - synchronized (registeredStates) { - registeredStates.clear(); - } - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistryFactory.java similarity index 59% rename from flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistryFactory.java index a6055a83ac64f..05c98258934ea 100644 --- a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistryFactory.java @@ -16,23 +16,20 @@ * limitations under the License. */ -package org.apache.flink.migration; +package org.apache.flink.runtime.state; -import org.apache.flink.migration.state.MigrationKeyGroupStateHandle; -import org.apache.flink.runtime.state.KeyedStateHandle; - -import java.util.Collection; +import java.util.concurrent.Executor; /** - * Utility functions for migration. + * Simple factory to produce {@link SharedStateRegistry} objects. */ -public class MigrationUtil { - - @SuppressWarnings("deprecation") - public static boolean isOldSavepointKeyedState(Collection keyedStateHandles) { - return (keyedStateHandles != null) - && (keyedStateHandles.size() == 1) - && (keyedStateHandles.iterator().next() instanceof MigrationKeyGroupStateHandle); - } +public interface SharedStateRegistryFactory { + /** + * Factory method for {@link SharedStateRegistry}. + * + * @param deleteExecutor executor used to run (async) deletes. 
+ * @return a SharedStateRegistry object + */ + SharedStateRegistry create(Executor deleteExecutor); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java index d82af7217a7cd..031d7c717284c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.state; -import org.apache.commons.io.IOUtils; import org.apache.flink.api.common.state.KeyedStateStore; import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.api.java.tuple.Tuple2; @@ -26,6 +25,8 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.util.Preconditions; +import org.apache.commons.io.IOUtils; + import java.io.IOException; import java.util.ArrayList; import java.util.Collection; @@ -139,6 +140,7 @@ public void close() { } private static Collection transform(Collection keyedStateHandles) { + if (keyedStateHandles == null) { return null; } @@ -146,13 +148,14 @@ private static Collection transform(Collection keyGroupsStateHandles = new ArrayList<>(); for (KeyedStateHandle keyedStateHandle : keyedStateHandles) { - if (! (keyedStateHandle instanceof KeyGroupsStateHandle)) { + + if (keyedStateHandle instanceof KeyGroupsStateHandle) { + keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle); + } else if (keyedStateHandle != null) { throw new IllegalStateException("Unexpected state handle type, " + "expected: " + KeyGroupsStateHandle.class + ", but found: " + keyedStateHandle.getClass() + "."); } - - keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle); } return keyGroupsStateHandles; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java index 6f231e42c5151..09d195add1fde 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java @@ -18,8 +18,8 @@ package org.apache.flink.runtime.state; -import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FutureUtil; +import org.apache.flink.util.LambdaUtil; import java.util.concurrent.RunnableFuture; @@ -49,27 +49,8 @@ public static long getStateSize(StateObject handle) { * @throws Exception exception that is a collection of all suppressed exceptions that were caught during iteration */ public static void bestEffortDiscardAllStateObjects( - Iterable handlesToDiscard) throws Exception { - - if (handlesToDiscard != null) { - Exception exception = null; - - for (StateObject state : handlesToDiscard) { - - if (state != null) { - try { - state.discardState(); - } - catch (Exception ex) { - exception = ExceptionUtils.firstOrSuppressed(ex, exception); - } - } - } - - if (exception != null) { - throw exception; - } - } + Iterable handlesToDiscard) throws Exception { + LambdaUtil.applyToAllWhileSuppressingExceptions(handlesToDiscard, StateObject::discardState); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java deleted file mode 100644 index 2fde5485049f9..0000000000000 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.state; - -import org.apache.flink.runtime.checkpoint.SubtaskState; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; - -/** - * This class encapsulates all state handles for a task. - */ -public class TaskStateHandles implements Serializable { - - public static final TaskStateHandles EMPTY = new TaskStateHandles(); - - private static final long serialVersionUID = 267686583583579359L; - - /** - * State handle with the (non-partitionable) legacy operator state - * - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. - */ - @Deprecated - private final ChainedStateHandle legacyOperatorState; - - /** Collection of handles which represent the managed keyed state of the head operator */ - private final Collection managedKeyedState; - - /** Collection of handles which represent the raw/streamed keyed state of the head operator */ - private final Collection rawKeyedState; - - /** Outer list represents the operator chain, each collection holds handles for managed state of a single operator */ - private final List> managedOperatorState; - - /** Outer list represents the operator chain, each collection holds handles for raw/streamed state of a single operator */ - private final List> rawOperatorState; - - public TaskStateHandles() { - this(null, null, null, null, null); - } - - public TaskStateHandles(SubtaskState checkpointStateHandles) { - this(checkpointStateHandles.getLegacyOperatorState(), - transform(checkpointStateHandles.getManagedOperatorState()), - transform(checkpointStateHandles.getRawOperatorState()), - transform(checkpointStateHandles.getManagedKeyedState()), - transform(checkpointStateHandles.getRawKeyedState())); - } - - public TaskStateHandles( - ChainedStateHandle legacyOperatorState, - List> managedOperatorState, - List> rawOperatorState, - Collection managedKeyedState, - Collection rawKeyedState) { - - this.legacyOperatorState = legacyOperatorState; - this.managedKeyedState = managedKeyedState; - this.rawKeyedState = rawKeyedState; - this.managedOperatorState = managedOperatorState; - this.rawOperatorState = rawOperatorState; - } - - /** - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. 
- */ - @Deprecated - public ChainedStateHandle getLegacyOperatorState() { - return legacyOperatorState; - } - - public Collection getManagedKeyedState() { - return managedKeyedState; - } - - public Collection getRawKeyedState() { - return rawKeyedState; - } - - public List> getRawOperatorState() { - return rawOperatorState; - } - - public List> getManagedOperatorState() { - return managedOperatorState; - } - - private static List> transform(ChainedStateHandle in) { - if (null == in) { - return Collections.emptyList(); - } - List> out = new ArrayList<>(in.getLength()); - for (int i = 0; i < in.getLength(); ++i) { - OperatorStateHandle osh = in.get(i); - out.add(osh != null ? Collections.singletonList(osh) : null); - } - return out; - } - - private static List transform(T in) { - return in == null ? Collections.emptyList() : Collections.singletonList(in); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - TaskStateHandles that = (TaskStateHandles) o; - - if (legacyOperatorState != null ? - !legacyOperatorState.equals(that.legacyOperatorState) - : that.legacyOperatorState != null) { - return false; - } - if (managedKeyedState != null ? - !managedKeyedState.equals(that.managedKeyedState) - : that.managedKeyedState != null) { - return false; - } - if (rawKeyedState != null ? - !rawKeyedState.equals(that.rawKeyedState) - : that.rawKeyedState != null) { - return false; - } - - if (rawOperatorState != null ? - !rawOperatorState.equals(that.rawOperatorState) - : that.rawOperatorState != null) { - return false; - } - return managedOperatorState != null ? - managedOperatorState.equals(that.managedOperatorState) - : that.managedOperatorState == null; - } - - @Override - public int hashCode() { - int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0; - result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0); - result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0); - result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0); - result = 31 * result + (rawOperatorState != null ? 
rawOperatorState.hashCode() : 0); - return result; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespaceSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespaceSerializer.java index 2800899280a21..8b58891e36ba6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespaceSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespaceSerializer.java @@ -21,7 +21,6 @@ import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.migration.MigrationNamespaceSerializerProxy; import java.io.IOException; @@ -90,11 +89,4 @@ public void copy(DataInputView source, DataOutputView target) throws IOException public boolean canEqual(Object obj) { return obj instanceof VoidNamespaceSerializer; } - - @Override - protected boolean isCompatibleSerializationFormatIdentifier(String identifier) { - // we might be replacing a migration namespace serializer, in which case we just assume compatibility - return super.isCompatibleSerializationFormatIdentifier(identifier) - || identifier.equals(MigrationNamespaceSerializerProxy.class.getCanonicalName()); - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index d1c0466e4c665..e235b96969355 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -35,11 +35,6 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.migration.MigrationNamespaceSerializerProxy; -import org.apache.flink.migration.MigrationUtil; -import org.apache.flink.migration.runtime.state.KvStateSnapshot; -import org.apache.flink.migration.runtime.state.memory.MigrationRestoreSnapshot; -import org.apache.flink.migration.state.MigrationKeyGroupStateHandle; import org.apache.flink.runtime.checkpoint.AbstractAsyncSnapshotIOCallable; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback; @@ -65,7 +60,6 @@ import org.apache.flink.runtime.state.internal.InternalMapState; import org.apache.flink.runtime.state.internal.InternalReducingState; import org.apache.flink.runtime.state.internal.InternalValueState; -import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.apache.flink.util.StateMigrationException; @@ -190,7 +184,7 @@ private StateTable tryRegisterStateTable( // check compatibility results to determine if state migration is required CompatibilityResult namespaceCompatibility = CompatibilityUtil.resolveCompatibilityResult( restoredMetaInfo.getNamespaceSerializer(), - MigrationNamespaceSerializerProxy.class, + null, restoredMetaInfo.getNamespaceSerializerConfigSnapshot(), newMetaInfo.getNamespaceSerializer()); @@ -405,11 +399,7 @@ public void restore(Collection restoredState) throws Exception LOG.debug("Restoring snapshot from state handles: {}.", restoredState); } - if (MigrationUtil.isOldSavepointKeyedState(restoredState)) { - 
restoreOldSavepointKeyedState(restoredState); - } else { - restorePartitionedState(restoredState); - } + restorePartitionedState(restoredState); } @SuppressWarnings({"unchecked"}) @@ -559,55 +549,6 @@ public String toString() { return "HeapKeyedStateBackend"; } - /** - * @deprecated Used for backwards compatibility with previous savepoint versions. - */ - @SuppressWarnings({"unchecked", "rawtypes", "DeprecatedIsStillUsed"}) - @Deprecated - private void restoreOldSavepointKeyedState( - Collection stateHandles) throws IOException, ClassNotFoundException { - - if (stateHandles.isEmpty()) { - return; - } - - Preconditions.checkState(1 == stateHandles.size(), "Only one element expected here."); - - KeyedStateHandle keyedStateHandle = stateHandles.iterator().next(); - if (!(keyedStateHandle instanceof MigrationKeyGroupStateHandle)) { - throw new IllegalStateException("Unexpected state handle type, " + - "expected: " + MigrationKeyGroupStateHandle.class + - ", but found " + keyedStateHandle.getClass()); - } - - MigrationKeyGroupStateHandle keyGroupStateHandle = (MigrationKeyGroupStateHandle) keyedStateHandle; - - HashMap> namedStates; - try (FSDataInputStream inputStream = keyGroupStateHandle.openInputStream()) { - namedStates = InstantiationUtil.deserializeObject(inputStream, userCodeClassLoader); - } - - for (Map.Entry> nameToState : namedStates.entrySet()) { - - final String stateName = nameToState.getKey(); - final KvStateSnapshot genericSnapshot = nameToState.getValue(); - - if (genericSnapshot instanceof MigrationRestoreSnapshot) { - MigrationRestoreSnapshot stateSnapshot = (MigrationRestoreSnapshot) genericSnapshot; - final StateTable rawResultMap = - stateSnapshot.deserialize(stateName, this); - - // mimic a restored kv state meta info - restoredKvStateMetaInfos.put(stateName, rawResultMap.getMetaInfo().snapshot()); - - // add named state to the backend - stateTables.put(stateName, rawResultMap); - } else { - throw new IllegalStateException("Unknown state: " + genericSnapshot); - } - } - } - /** * Returns the total number of state entries across all keys/namespaces. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java index 9ba9d35ff8939..3a43d4ffb979b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java @@ -95,6 +95,7 @@ public int hashCode() { public String toString() { return "ByteStreamStateHandle{" + "handleName='" + handleName + '\'' + + ", dataBytes=" + data.length + '}'; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobLeaderListener.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobLeaderListener.java index f02a8c23694e7..65012a090af1d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobLeaderListener.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobLeaderListener.java @@ -21,8 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.jobmaster.JMTMRegistrationSuccess; import org.apache.flink.runtime.jobmaster.JobMasterGateway; - -import java.util.UUID; +import org.apache.flink.runtime.jobmaster.JobMasterId; /** * Listener for the {@link JobLeaderService}. 
The listener is notified whenever a job manager @@ -38,18 +37,17 @@ public interface JobLeaderListener { * * @param jobId identifying the job for which the job manager has gained leadership * @param jobManagerGateway to the job leader - * @param jobLeaderId new leader id of the job leader * @param registrationMessage containing further registration information */ - void jobManagerGainedLeadership(JobID jobId, JobMasterGateway jobManagerGateway, UUID jobLeaderId, JMTMRegistrationSuccess registrationMessage); + void jobManagerGainedLeadership(JobID jobId, JobMasterGateway jobManagerGateway, JMTMRegistrationSuccess registrationMessage); /** * Callback if the job leader for the job with the given job id lost its leadership. * * @param jobId identifying the job whose leader has lost leadership - * @param jobLeaderId old leader id + * @param jobMasterId old JobMasterId */ - void jobManagerLostLeadership(JobID jobId, UUID jobLeaderId); + void jobManagerLostLeadership(JobID jobId, JobMasterId jobMasterId); /** * Callback for errors which might occur in the {@link JobLeaderService}. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobLeaderService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobLeaderService.java index 2ebf3c1f2dc24..20dcfa9321162 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobLeaderService.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobLeaderService.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.jobmaster.JMTMRegistrationSuccess; import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalListener; import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; import org.apache.flink.runtime.registration.RegisteredRpcConnection; @@ -32,11 +33,13 @@ import org.apache.flink.runtime.rpc.RpcService; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -53,25 +56,25 @@ public class JobLeaderService { private static final Logger LOG = LoggerFactory.getLogger(JobLeaderService.class); - /** Self's location, used for the job manager connection */ + /** Self's location, used for the job manager connection. */ private final TaskManagerLocation ownLocation; - /** The leader retrieval service and listener for each registered job */ + /** The leader retrieval service and listener for each registered job. */ private final Map> jobLeaderServices; - /** Internal state of the service */ + /** Internal state of the service. */ private volatile JobLeaderService.State state; - /** Address of the owner of this service. This address is used for the job manager connection */ + /** Address of the owner of this service. This address is used for the job manager connection. */ private String ownerAddress; - /** Rpc service to use for establishing connections */ + /** Rpc service to use for establishing connections. 
*/ private RpcService rpcService; - /** High availability services to create the leader retrieval services from */ + /** High availability services to create the leader retrieval services from. */ private HighAvailabilityServices highAvailabilityServices; - /** Job leader listener listening for job leader changes */ + /** Job leader listener listening for job leader changes. */ private JobLeaderListener jobLeaderListener; public JobLeaderService(TaskManagerLocation location) { @@ -206,24 +209,24 @@ public void addJob(final JobID jobId, final String defaultTargetAddress) throws */ private final class JobManagerLeaderListener implements LeaderRetrievalListener { - /** Job id identifying the job to look for a leader */ + /** Job id identifying the job to look for a leader. */ private final JobID jobId; - /** Rpc connection to the job leader */ - private RegisteredRpcConnection rpcConnection; + /** Rpc connection to the job leader. */ + private RegisteredRpcConnection rpcConnection; - /** State of the listener */ + /** State of the listener. */ private volatile boolean stopped; - /** Leader id of the current job leader */ - private volatile UUID currentLeaderId; + /** Leader id of the current job leader. */ + private volatile JobMasterId currentJobMasterId; private JobManagerLeaderListener(JobID jobId) { this.jobId = Preconditions.checkNotNull(jobId); stopped = false; rpcConnection = null; - currentLeaderId = null; + currentJobMasterId = null; } public void stop() { @@ -240,8 +243,10 @@ public void notifyLeaderAddress(final String leaderAddress, final UUID leaderId) LOG.debug("{}'s leader retrieval listener reported a new leader for job {}. " + "However, the service is no longer running.", JobLeaderService.class.getSimpleName(), jobId); } else { + final JobMasterId jobMasterId = leaderId != null ? new JobMasterId(leaderId) : null; + LOG.debug("New leader information for job {}. Address: {}, leader id: {}.", - jobId, leaderAddress, leaderId); + jobId, leaderAddress, jobMasterId); if (leaderAddress == null || leaderAddress.isEmpty()) { // the leader lost leadership but there is no other leader yet. @@ -249,28 +254,28 @@ public void notifyLeaderAddress(final String leaderAddress, final UUID leaderId) rpcConnection.close(); } - jobLeaderListener.jobManagerLostLeadership(jobId, currentLeaderId); + jobLeaderListener.jobManagerLostLeadership(jobId, currentJobMasterId); - currentLeaderId = leaderId; + currentJobMasterId = jobMasterId; } else { - currentLeaderId = leaderId; + currentJobMasterId = jobMasterId; if (rpcConnection != null) { // check if we are already trying to connect to this leader - if (!leaderId.equals(rpcConnection.getTargetLeaderId())) { + if (!Objects.equals(jobMasterId, rpcConnection.getTargetLeaderId())) { rpcConnection.close(); rpcConnection = new JobManagerRegisteredRpcConnection( LOG, leaderAddress, - leaderId, + jobMasterId, rpcService.getExecutor()); } } else { rpcConnection = new JobManagerRegisteredRpcConnection( LOG, leaderAddress, - leaderId, + jobMasterId, rpcService.getExecutor()); } @@ -299,18 +304,18 @@ public void handleError(Exception exception) { /** * Rpc connection for the job manager <--> task manager connection. 
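The leader listener above wraps the raw leader session UUID into a JobMasterId and compares identities with Objects.equals, so the "no current leader" case (a null id) needs no special handling. A small stand-alone sketch of the idea, using a hypothetical wrapper type:

    import java.util.Objects;
    import java.util.UUID;

    // Hypothetical stand-in for a typed id such as JobMasterId; illustration only.
    final class LeaderIdSketch {

        static final class LeaderId {
            private final UUID value;

            LeaderId(UUID value) {
                this.value = Objects.requireNonNull(value);
            }

            @Override
            public boolean equals(Object o) {
                return o instanceof LeaderId && value.equals(((LeaderId) o).value);
            }

            @Override
            public int hashCode() {
                return value.hashCode();
            }
        }

        public static void main(String[] args) {
            UUID rawLeaderId = null; // e.g. leadership was revoked and no new leader is known yet
            LeaderId current = rawLeaderId != null ? new LeaderId(rawLeaderId) : null;
            LeaderId target = new LeaderId(UUID.randomUUID());

            // Objects.equals tolerates the null 'current' id instead of throwing an NPE.
            System.out.println(Objects.equals(current, target)); // false -> (re)connect to the new leader
        }
    }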
*/ - private final class JobManagerRegisteredRpcConnection extends RegisteredRpcConnection { + private final class JobManagerRegisteredRpcConnection extends RegisteredRpcConnection { JobManagerRegisteredRpcConnection( Logger log, String targetAddress, - UUID targetLeaderId, + JobMasterId jobMasterId, Executor executor) { - super(log, targetAddress, targetLeaderId, executor); + super(log, targetAddress, jobMasterId, executor); } @Override - protected RetryingRegistration generateRegistration() { + protected RetryingRegistration generateRegistration() { return new JobLeaderService.JobManagerRetryingRegistration( LOG, rpcService, @@ -325,10 +330,10 @@ protected RetryingRegistration genera @Override protected void onRegistrationSuccess(JMTMRegistrationSuccess success) { // filter out old registration attempts - if (getTargetLeaderId().equals(currentLeaderId)) { + if (Objects.equals(getTargetLeaderId(), currentJobMasterId)) { log.info("Successful registration at job manager {} for job {}.", getTargetAddress(), jobId); - jobLeaderListener.jobManagerGainedLeadership(jobId, getTargetGateway(), getTargetLeaderId(), success); + jobLeaderListener.jobManagerGainedLeadership(jobId, getTargetGateway(), success); } else { log.debug("Encountered obsolete JobManager registration success from {} with leader session ID {}.", getTargetAddress(), getTargetLeaderId()); } @@ -337,8 +342,8 @@ protected void onRegistrationSuccess(JMTMRegistrationSuccess success) { @Override protected void onRegistrationFailure(Throwable failure) { // filter out old registration attempts - if (getTargetLeaderId().equals(currentLeaderId)) { - log.info("Failed to register at job manager {} for job {}.", getTargetAddress(), jobId); + if (Objects.equals(getTargetLeaderId(), currentJobMasterId)) { + log.info("Failed to register at job manager {} for job {}.", getTargetAddress(), jobId); jobLeaderListener.handleError(failure); } else { log.debug("Obsolete JobManager registration failure from {} with leader session ID {}.", getTargetAddress(), getTargetLeaderId(), failure); @@ -351,7 +356,7 @@ protected void onRegistrationFailure(Throwable failure) { * Retrying registration for the job manager <--> task manager connection. 
*/ private static final class JobManagerRetryingRegistration - extends RetryingRegistration + extends RetryingRegistration { private final String taskManagerRpcAddress; @@ -364,11 +369,10 @@ private static final class JobManagerRetryingRegistration String targetName, Class targetType, String targetAddress, - UUID leaderId, + JobMasterId jobMasterId, String taskManagerRpcAddress, - TaskManagerLocation taskManagerLocation) - { - super(log, rpcService, targetName, targetType, targetAddress, leaderId); + TaskManagerLocation taskManagerLocation) { + super(log, rpcService, targetName, targetType, targetAddress, jobMasterId); this.taskManagerRpcAddress = taskManagerRpcAddress; this.taskManagerLocation = Preconditions.checkNotNull(taskManagerLocation); @@ -376,15 +380,15 @@ private static final class JobManagerRetryingRegistration @Override protected CompletableFuture invokeRegistration( - JobMasterGateway gateway, UUID leaderId, long timeoutMillis) throws Exception - { - return gateway.registerTaskManager(taskManagerRpcAddress, taskManagerLocation, - leaderId, Time.milliseconds(timeoutMillis)); + JobMasterGateway gateway, + JobMasterId jobMasterId, + long timeoutMillis) throws Exception { + return gateway.registerTaskManager(taskManagerRpcAddress, taskManagerLocation, Time.milliseconds(timeoutMillis)); } } /** - * Internal state of the service + * Internal state of the service. */ private enum State { CREATED, STARTED, STOPPED diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobManagerConnection.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobManagerConnection.java index 98c7bf11d8f22..2c05388a673be 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobManagerConnection.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobManagerConnection.java @@ -19,17 +19,17 @@ package org.apache.flink.runtime.taskexecutor; import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.TaskManagerActions; import org.apache.flink.util.Preconditions; -import java.util.UUID; - /** * Container class for JobManager specific communication utils used by the {@link TaskExecutor}. 
*/ @@ -41,9 +41,6 @@ public class JobManagerConnection { // The unique id used for identifying the job manager private final ResourceID resourceID; - // Job master leader session id - private final UUID leaderId; - // Gateway to the job master private final JobMasterGateway jobMasterGateway; @@ -53,6 +50,9 @@ public class JobManagerConnection { // Checkpoint responder for the specific job manager private final CheckpointResponder checkpointResponder; + // BLOB cache connected to the BLOB server at the specific job manager + private final BlobCache blobCache; + // Library cache manager connected to the specific job manager private final LibraryCacheManager libraryCacheManager; @@ -63,21 +63,20 @@ public class JobManagerConnection { private final PartitionProducerStateChecker partitionStateChecker; public JobManagerConnection( - JobID jobID, - ResourceID resourceID, - JobMasterGateway jobMasterGateway, - UUID leaderId, - TaskManagerActions taskManagerActions, - CheckpointResponder checkpointResponder, - LibraryCacheManager libraryCacheManager, - ResultPartitionConsumableNotifier resultPartitionConsumableNotifier, - PartitionProducerStateChecker partitionStateChecker) { + JobID jobID, + ResourceID resourceID, + JobMasterGateway jobMasterGateway, + TaskManagerActions taskManagerActions, + CheckpointResponder checkpointResponder, + BlobCache blobCache, LibraryCacheManager libraryCacheManager, + ResultPartitionConsumableNotifier resultPartitionConsumableNotifier, + PartitionProducerStateChecker partitionStateChecker) { this.jobID = Preconditions.checkNotNull(jobID); this.resourceID = Preconditions.checkNotNull(resourceID); - this.leaderId = Preconditions.checkNotNull(leaderId); this.jobMasterGateway = Preconditions.checkNotNull(jobMasterGateway); this.taskManagerActions = Preconditions.checkNotNull(taskManagerActions); this.checkpointResponder = Preconditions.checkNotNull(checkpointResponder); + this.blobCache = Preconditions.checkNotNull(blobCache); this.libraryCacheManager = Preconditions.checkNotNull(libraryCacheManager); this.resultPartitionConsumableNotifier = Preconditions.checkNotNull(resultPartitionConsumableNotifier); this.partitionStateChecker = Preconditions.checkNotNull(partitionStateChecker); @@ -91,8 +90,8 @@ public ResourceID getResourceID() { return resourceID; } - public UUID getLeaderId() { - return leaderId; + public JobMasterId getJobMasterId() { + return jobMasterGateway.getFencingToken(); } public JobMasterGateway getJobManagerGateway() { @@ -111,6 +110,15 @@ public LibraryCacheManager getLibraryCacheManager() { return libraryCacheManager; } + /** + * Gets the BLOB cache connected to the respective BLOB server instance at the job manager. 
+ * + * @return BLOB cache + */ + public BlobCache getBlobCache() { + return blobCache; + } + public ResultPartitionConsumableNotifier getResultPartitionConsumableNotifier() { return resultPartitionConsumableNotifier; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java index 4abcdf4365d3b..b6a0637b0062c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java @@ -51,14 +51,16 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobmaster.JMTMRegistrationSuccess; import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalListener; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.metrics.MetricRegistry; -import org.apache.flink.runtime.registration.RegistrationConnectionListener; -import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; import org.apache.flink.runtime.metrics.groups.TaskManagerMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; +import org.apache.flink.runtime.registration.RegistrationConnectionListener; +import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.rpc.FatalErrorHandler; import org.apache.flink.runtime.rpc.RpcEndpoint; import org.apache.flink.runtime.rpc.RpcService; @@ -85,6 +87,7 @@ import org.apache.flink.runtime.taskmanager.TaskManagerActions; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FlinkException; import org.apache.flink.util.Preconditions; import java.io.IOException; @@ -231,7 +234,7 @@ public void start() throws Exception { try { haServices.getResourceManagerLeaderRetriever().start(new ResourceManagerLeaderListener()); } catch (Exception e) { - onFatalErrorAsync(e); + onFatalError(e); } // tell the task slot table who's responsible for the task slot actions @@ -248,7 +251,7 @@ public void start() throws Exception { public void postStop() throws Exception { log.info("Stopping TaskManager {}.", getAddress()); - Exception exception = null; + Throwable throwable = null; taskSlotTable.stop(); @@ -256,6 +259,14 @@ public void postStop() throws Exception { resourceManagerConnection.close(); } + for (JobManagerConnection jobManagerConnection : jobManagerConnections.values()) { + try { + disassociateFromJobManager(jobManagerConnection, new FlinkException("The TaskExecutor is shutting down.")); + } catch (Throwable t) { + throwable = ExceptionUtils.firstOrSuppressed(t, throwable); + } + } + jobManagerHeartbeatManager.stop(); resourceManagerHeartbeatManager.stop(); @@ -270,12 +281,12 @@ public void postStop() throws Exception { try { super.postStop(); - } catch (Exception e) { - exception = ExceptionUtils.firstOrSuppressed(e, exception); + } catch (Throwable e) { + throwable = ExceptionUtils.firstOrSuppressed(e, throwable); } - if (exception != null) { - ExceptionUtils.rethrowException(exception, "Error while shutting the TaskExecutor down."); + if (throwable != null) { + 
ExceptionUtils.rethrowException(throwable, "Error while shutting the TaskExecutor down."); } log.info("Stopped TaskManager {}.", getAddress()); @@ -292,7 +303,7 @@ public void postStop() throws Exception { @Override public CompletableFuture submitTask( TaskDeploymentDescriptor tdd, - UUID jobManagerLeaderId, + JobMasterId jobMasterId, Time timeout) { try { @@ -317,10 +328,10 @@ public CompletableFuture submitTask( throw new TaskSubmissionException(message); } - if (!Objects.equals(jobManagerConnection.getLeaderId(), jobManagerLeaderId)) { + if (!Objects.equals(jobManagerConnection.getJobMasterId(), jobMasterId)) { final String message = "Rejecting the task submission because the job manager leader id " + - jobManagerLeaderId + " does not match the expected job manager leader id " + - jobManagerConnection.getLeaderId() + '.'; + jobMasterId + " does not match the expected job manager leader id " + + jobManagerConnection.getJobMasterId() + '.'; log.debug(message); throw new TaskSubmissionException(message); @@ -343,15 +354,14 @@ public CompletableFuture submitTask( tdd.getAttemptNumber()); InputSplitProvider inputSplitProvider = new RpcInputSplitProvider( - jobManagerConnection.getLeaderId(), jobManagerConnection.getJobManagerGateway(), - jobInformation.getJobId(), taskInformation.getJobVertexId(), tdd.getExecutionAttemptId(), taskManagerConfiguration.getTimeout()); TaskManagerActions taskManagerActions = jobManagerConnection.getTaskManagerActions(); CheckpointResponder checkpointResponder = jobManagerConnection.getCheckpointResponder(); + BlobCache blobCache = jobManagerConnection.getBlobCache(); LibraryCacheManager libraryCache = jobManagerConnection.getLibraryCacheManager(); ResultPartitionConsumableNotifier resultPartitionConsumableNotifier = jobManagerConnection.getResultPartitionConsumableNotifier(); PartitionProducerStateChecker partitionStateChecker = jobManagerConnection.getPartitionStateChecker(); @@ -374,6 +384,7 @@ public CompletableFuture submitTask( taskManagerActions, inputSplitProvider, checkpointResponder, + blobCache, libraryCache, fileCache, taskManagerConfiguration, @@ -574,29 +585,18 @@ public CompletableFuture confirmCheckpoint( // Slot allocation RPCs // ---------------------------------------------------------------------- - /** - * Requests a slot from the TaskManager - * - * @param slotId identifying the requested slot - * @param jobId identifying the job for which the request is issued - * @param allocationId id for the request - * @param targetAddress of the job manager requesting the slot - * @param rmLeaderId current leader id of the ResourceManager - * @throws SlotAllocationException if the slot allocation fails - * @return answer to the slot request - */ @Override public CompletableFuture requestSlot( final SlotID slotId, final JobID jobId, final AllocationID allocationId, final String targetAddress, - final UUID rmLeaderId, + final ResourceManagerId resourceManagerId, final Time timeout) { // TODO: Filter invalid requests from the resource manager by using the instance/registration Id log.info("Receive slot request {} for job {} from resource manager with leader id {}.", - allocationId, jobId, rmLeaderId); + allocationId, jobId, resourceManagerId); try { if (resourceManagerConnection == null) { @@ -605,8 +605,8 @@ public CompletableFuture requestSlot( throw new SlotAllocationException(message); } - if (!resourceManagerConnection.getTargetLeaderId().equals(rmLeaderId)) { - final String message = "The leader id " + rmLeaderId + + if 
(!Objects.equals(resourceManagerConnection.getTargetLeaderId(), resourceManagerId)) { + final String message = "The leader id " + resourceManagerId + " does not match with the leader id of the connected resource manager " + resourceManagerConnection.getTargetLeaderId() + '.'; @@ -681,7 +681,7 @@ public void disconnectResourceManager(Exception cause) { // Internal resource manager connection methods // ------------------------------------------------------------------------ - private void notifyOfNewResourceManagerLeader(String newLeaderAddress, UUID newLeaderId) { + private void notifyOfNewResourceManagerLeader(String newLeaderAddress, ResourceManagerId newResourceManagerId) { if (resourceManagerConnection != null) { if (newLeaderAddress != null) { // the resource manager switched to a new leader @@ -712,7 +712,7 @@ private void notifyOfNewResourceManagerLeader(String newLeaderAddress, UUID newL getResourceID(), taskSlotTable.createSlotReport(getResourceID()), newLeaderAddress, - newLeaderId, + newResourceManagerId, getMainThreadExecutor(), new ResourceManagerRegistrationListener()); resourceManagerConnection.start(); @@ -765,7 +765,7 @@ private void offerSlotsToJobManager(final JobID jobId) { final JobMasterGateway jobMasterGateway = jobManagerConnection.getJobManagerGateway(); final Iterator reservedSlotsIterator = taskSlotTable.getAllocatedSlots(jobId); - final UUID leaderId = jobManagerConnection.getLeaderId(); + final JobMasterId jobMasterId = jobManagerConnection.getJobMasterId(); final Collection reservedSlots = new HashSet<>(2); @@ -776,13 +776,11 @@ private void offerSlotsToJobManager(final JobID jobId) { // the slot is either free or releasing at the moment final String message = "Could not mark slot " + jobId + " active."; log.debug(message); - jobMasterGateway.failSlot(getResourceID(), offer.getAllocationId(), - leaderId, new Exception(message)); + jobMasterGateway.failSlot(getResourceID(), offer.getAllocationId(), new Exception(message)); } } catch (SlotNotFoundException e) { final String message = "Could not mark slot " + jobId + " active."; - jobMasterGateway.failSlot(getResourceID(), offer.getAllocationId(), - leaderId, new Exception(message)); + jobMasterGateway.failSlot(getResourceID(), offer.getAllocationId(), new Exception(message)); continue; } reservedSlots.add(offer); @@ -791,7 +789,6 @@ private void offerSlotsToJobManager(final JobID jobId) { CompletableFuture> acceptedSlotsFuture = jobMasterGateway.offerSlots( getResourceID(), reservedSlots, - leaderId, taskManagerConfiguration.getTimeout()); acceptedSlotsFuture.whenCompleteAsync( @@ -812,7 +809,7 @@ private void offerSlotsToJobManager(final JobID jobId) { } } else { // check if the response is still valid - if (isJobManagerConnectionValid(jobId, leaderId)) { + if (isJobManagerConnectionValid(jobId, jobMasterId)) { // mark accepted slots active for (SlotOffer acceptedSlot : acceptedSlots) { reservedSlots.remove(acceptedSlot); @@ -838,22 +835,27 @@ private void offerSlotsToJobManager(final JobID jobId) { } } - private void establishJobManagerConnection(JobID jobId, final JobMasterGateway jobMasterGateway, UUID jobManagerLeaderId, JMTMRegistrationSuccess registrationSuccess) { - log.info("Establish JobManager connection for job {}.", jobId); + private void establishJobManagerConnection(JobID jobId, final JobMasterGateway jobMasterGateway, JMTMRegistrationSuccess registrationSuccess) { if (jobManagerTable.contains(jobId)) { JobManagerConnection oldJobManagerConnection = jobManagerTable.get(jobId); - if 
(!oldJobManagerConnection.getLeaderId().equals(jobManagerLeaderId)) { + + if (Objects.equals(oldJobManagerConnection.getJobMasterId(), jobMasterGateway.getFencingToken())) { + // we already are connected to the given job manager + log.debug("Ignore JobManager gained leadership message for {} because we are already connected to it.", jobMasterGateway.getFencingToken()); + return; + } else { closeJobManagerConnection(jobId, new Exception("Found new job leader for job id " + jobId + '.')); } } + log.info("Establish JobManager connection for job {}.", jobId); + ResourceID jobManagerResourceID = registrationSuccess.getResourceID(); JobManagerConnection newJobManagerConnection = associateWithJobManager( jobId, jobManagerResourceID, jobMasterGateway, - jobManagerLeaderId, registrationSuccess.getBlobPort()); jobManagerConnections.put(jobManagerResourceID, newJobManagerConnection); jobManagerTable.put(jobId, newJobManagerConnection); @@ -920,29 +922,26 @@ private JobManagerConnection associateWithJobManager( JobID jobID, ResourceID resourceID, JobMasterGateway jobMasterGateway, - UUID jobManagerLeaderId, int blobPort) { Preconditions.checkNotNull(jobID); Preconditions.checkNotNull(resourceID); - Preconditions.checkNotNull(jobManagerLeaderId); Preconditions.checkNotNull(jobMasterGateway); Preconditions.checkArgument(blobPort > 0 || blobPort < MAX_BLOB_PORT, "Blob server port is out of range."); - TaskManagerActions taskManagerActions = new TaskManagerActionsImpl(jobManagerLeaderId, jobMasterGateway); + TaskManagerActions taskManagerActions = new TaskManagerActionsImpl(jobMasterGateway); CheckpointResponder checkpointResponder = new RpcCheckpointResponder(jobMasterGateway); InetSocketAddress blobServerAddress = new InetSocketAddress(jobMasterGateway.getHostname(), blobPort); final LibraryCacheManager libraryCacheManager; + final BlobCache blobCache; try { - final BlobCache blobCache = new BlobCache( + blobCache = new BlobCache( blobServerAddress, taskManagerConfiguration.getConfiguration(), haServices.createBlobStore()); - libraryCacheManager = new BlobLibraryCacheManager( - blobCache, - taskManagerConfiguration.getCleanupInterval()); + libraryCacheManager = new BlobLibraryCacheManager(blobCache); } catch (IOException e) { // Can't pass the IOException up - we need a RuntimeException anyway // two levels up where this is run asynchronously. 
Also, we don't @@ -953,20 +952,19 @@ private JobManagerConnection associateWithJobManager( } ResultPartitionConsumableNotifier resultPartitionConsumableNotifier = new RpcResultPartitionConsumableNotifier( - jobManagerLeaderId, jobMasterGateway, getRpcService().getExecutor(), taskManagerConfiguration.getTimeout()); - PartitionProducerStateChecker partitionStateChecker = new RpcPartitionStateChecker(jobManagerLeaderId, jobMasterGateway); + PartitionProducerStateChecker partitionStateChecker = new RpcPartitionStateChecker(jobMasterGateway); return new JobManagerConnection( jobID, resourceID, jobMasterGateway, - jobManagerLeaderId, taskManagerActions, checkpointResponder, + blobCache, libraryCacheManager, resultPartitionConsumableNotifier, partitionStateChecker); @@ -977,6 +975,7 @@ private void disassociateFromJobManager(JobManagerConnection jobManagerConnectio JobMasterGateway jobManagerGateway = jobManagerConnection.getJobManagerGateway(); jobManagerGateway.disconnectTaskManager(getResourceID(), cause); jobManagerConnection.getLibraryCacheManager().shutdown(); + jobManagerConnection.getBlobCache().close(); } // ------------------------------------------------------------------------ @@ -998,13 +997,12 @@ private void failTask(final ExecutionAttemptID executionAttemptID, final Throwab } private void updateTaskExecutionState( - final UUID jobMasterLeaderId, final JobMasterGateway jobMasterGateway, final TaskExecutionState taskExecutionState) { final ExecutionAttemptID executionAttemptID = taskExecutionState.getID(); - CompletableFuture futureAcknowledge = jobMasterGateway.updateTaskExecutionState(jobMasterLeaderId, taskExecutionState); + CompletableFuture futureAcknowledge = jobMasterGateway.updateTaskExecutionState(taskExecutionState); futureAcknowledge.whenCompleteAsync( (ack, throwable) -> { @@ -1016,7 +1014,6 @@ private void updateTaskExecutionState( } private void unregisterTaskAndNotifyFinalState( - final UUID jobMasterLeaderId, final JobMasterGateway jobMasterGateway, final ExecutionAttemptID executionAttemptID) { @@ -1036,7 +1033,6 @@ private void unregisterTaskAndNotifyFinalState( AccumulatorSnapshot accumulatorSnapshot = task.getAccumulatorRegistry().getSnapshot(); updateTaskExecutionState( - jobMasterLeaderId, jobMasterGateway, new TaskExecutionState( task.getJobID(), @@ -1061,7 +1057,6 @@ private void freeSlot(AllocationID allocationId, Throwable cause) { ResourceManagerGateway resourceManagerGateway = resourceManagerConnection.getTargetGateway(); resourceManagerGateway.notifySlotAvailable( - resourceManagerConnection.getTargetLeaderId(), resourceManagerConnection.getRegistrationId(), new SlotID(getResourceID(), freedSlotIndex), allocationId); @@ -1094,10 +1089,10 @@ private boolean isConnectedToResourceManager() { return (resourceManagerConnection != null && resourceManagerConnection.isConnected()); } - private boolean isJobManagerConnectionValid(JobID jobId, UUID leaderId) { + private boolean isJobManagerConnectionValid(JobID jobId, JobMasterId jobMasterId) { JobManagerConnection jmConnection = jobManagerTable.get(jobId); - return jmConnection != null && Objects.equals(jmConnection.getLeaderId(), leaderId); + return jmConnection != null && Objects.equals(jmConnection.getJobMasterId(), jobMasterId); } // ------------------------------------------------------------------------ @@ -1114,35 +1109,16 @@ public ResourceID getResourceID() { /** * Notifies the TaskExecutor that a fatal error has occurred and it cannot proceed. 
- * This method should be used when asynchronous threads want to notify the - * TaskExecutor of a fatal error. - * - * @param t The exception describing the fatal error - */ - void onFatalErrorAsync(final Throwable t) { - runAsync(new Runnable() { - @Override - public void run() { - onFatalError(t); - } - }); - } - - /** - * Notifies the TaskExecutor that a fatal error has occurred and it cannot proceed. - * This method must only be called from within the TaskExecutor's main thread. * * @param t The exception describing the fatal error */ void onFatalError(final Throwable t) { - log.error("Fatal error occurred.", t); - // this could potentially be a blocking call -> call asynchronously: - getRpcService().execute(new Runnable() { - @Override - public void run() { - fatalErrorHandler.onFatalError(t); - } - }); + try { + log.error("Fatal error occurred in TaskExecutor.", t); + } catch (Throwable ignored) {} + + // The fatal error handler implementation should make sure that this call is non-blocking + fatalErrorHandler.onFatalError(t); } // ------------------------------------------------------------------------ @@ -1164,23 +1140,21 @@ HeartbeatManager getResourceManagerHeartbeatManager() { // ------------------------------------------------------------------------ /** - * The listener for leader changes of the resource manager + * The listener for leader changes of the resource manager. */ private final class ResourceManagerLeaderListener implements LeaderRetrievalListener { @Override public void notifyLeaderAddress(final String leaderAddress, final UUID leaderSessionID) { - runAsync(new Runnable() { - @Override - public void run() { - notifyOfNewResourceManagerLeader(leaderAddress, leaderSessionID); - } - }); + runAsync( + () -> notifyOfNewResourceManagerLeader( + leaderAddress, + leaderSessionID != null ? 
new ResourceManagerId(leaderSessionID) : null)); } @Override public void handleError(Exception exception) { - onFatalErrorAsync(exception); + onFatalError(exception); } } @@ -1190,23 +1164,18 @@ private final class JobLeaderListenerImpl implements JobLeaderListener { public void jobManagerGainedLeadership( final JobID jobId, final JobMasterGateway jobManagerGateway, - final UUID jobLeaderId, final JMTMRegistrationSuccess registrationMessage) { - runAsync(new Runnable() { - @Override - public void run() { + runAsync( + () -> establishJobManagerConnection( jobId, jobManagerGateway, - jobLeaderId, - registrationMessage); - } - }); + registrationMessage)); } @Override - public void jobManagerLostLeadership(final JobID jobId, final UUID jobLeaderId) { - log.info("JobManager for job {} with leader id {} lost leadership.", jobId, jobLeaderId); + public void jobManagerLostLeadership(final JobID jobId, final JobMasterId jobMasterId) { + log.info("JobManager for job {} with leader id {} lost leadership.", jobId, jobMasterId); runAsync(new Runnable() { @Override @@ -1220,7 +1189,7 @@ public void run() { @Override public void handleError(Throwable throwable) { - onFatalErrorAsync(throwable); + onFatalError(throwable); } } @@ -1242,16 +1211,14 @@ public void run() { @Override public void onRegistrationFailure(Throwable failure) { - onFatalErrorAsync(failure); + onFatalError(failure); } } private final class TaskManagerActionsImpl implements TaskManagerActions { - private final UUID jobMasterLeaderId; private final JobMasterGateway jobMasterGateway; - private TaskManagerActionsImpl(UUID jobMasterLeaderId, JobMasterGateway jobMasterGateway) { - this.jobMasterLeaderId = Preconditions.checkNotNull(jobMasterLeaderId); + private TaskManagerActionsImpl(JobMasterGateway jobMasterGateway) { this.jobMasterGateway = Preconditions.checkNotNull(jobMasterGateway); } @@ -1260,14 +1227,18 @@ public void notifyFinalState(final ExecutionAttemptID executionAttemptID) { runAsync(new Runnable() { @Override public void run() { - unregisterTaskAndNotifyFinalState(jobMasterLeaderId, jobMasterGateway, executionAttemptID); + unregisterTaskAndNotifyFinalState(jobMasterGateway, executionAttemptID); } }); } @Override public void notifyFatalError(String message, Throwable cause) { - log.error(message, cause); + try { + log.error(message, cause); + } catch (Throwable ignored) {} + + // The fatal error handler implementation should make sure that this call is non-blocking fatalErrorHandler.onFatalError(cause); } @@ -1283,7 +1254,7 @@ public void run() { @Override public void updateTaskExecutionState(final TaskExecutionState taskExecutionState) { - TaskExecutor.this.updateTaskExecutionState(jobMasterLeaderId, jobMasterGateway, taskExecutionState); + TaskExecutor.this.updateTaskExecutionState(jobMasterGateway, taskExecutionState); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutorGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutorGateway.java index 80841545f1059..ee0f69d234d4e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutorGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutorGateway.java @@ -27,12 +27,13 @@ import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.PartitionInfo; +import org.apache.flink.runtime.jobmaster.JobMasterId; import 
org.apache.flink.runtime.messages.Acknowledge; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.rpc.RpcGateway; import org.apache.flink.runtime.rpc.RpcTimeout; import org.apache.flink.runtime.taskmanager.Task; -import java.util.UUID; import java.util.concurrent.CompletableFuture; /** @@ -44,8 +45,11 @@ public interface TaskExecutorGateway extends RpcGateway { * Requests a slot from the TaskManager * * @param slotId slot id for the request + * @param jobId for which to request a slot * @param allocationId id for the request - * @param resourceManagerLeaderId current leader id of the ResourceManager + * @param targetAddress to which to offer the requested slots + * @param resourceManagerId current leader id of the ResourceManager + * @param timeout for the operation * @return answer to the slot request */ CompletableFuture requestSlot( @@ -53,20 +57,20 @@ CompletableFuture requestSlot( JobID jobId, AllocationID allocationId, String targetAddress, - UUID resourceManagerLeaderId, + ResourceManagerId resourceManagerId, @RpcTimeout Time timeout); /** * Submit a {@link Task} to the {@link TaskExecutor}. * * @param tdd describing the task to submit - * @param leaderId of the job leader + * @param jobMasterId identifying the submitting JobMaster * @param timeout of the submit operation * @return Future acknowledge of the successful operation */ CompletableFuture submitTask( TaskDeploymentDescriptor tdd, - UUID leaderId, + JobMasterId jobMasterId, @RpcTimeout Time timeout); /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutorToResourceManagerConnection.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutorToResourceManagerConnection.java index 4084d67a60055..c3d35326399a5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutorToResourceManagerConnection.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutorToResourceManagerConnection.java @@ -23,6 +23,7 @@ import org.apache.flink.runtime.instance.InstanceID; import org.apache.flink.runtime.registration.RegisteredRpcConnection; import org.apache.flink.runtime.registration.RegistrationConnectionListener; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.rpc.RpcService; import org.apache.flink.runtime.registration.RegistrationResponse; import org.apache.flink.runtime.registration.RetryingRegistration; @@ -31,7 +32,6 @@ import org.apache.flink.util.Preconditions; import org.slf4j.Logger; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -41,7 +41,7 @@ * The connection between a TaskExecutor and the ResourceManager. 
*/ public class TaskExecutorToResourceManagerConnection - extends RegisteredRpcConnection { + extends RegisteredRpcConnection { private final RpcService rpcService; @@ -64,11 +64,11 @@ public TaskExecutorToResourceManagerConnection( ResourceID taskManagerResourceId, SlotReport slotReport, String resourceManagerAddress, - UUID resourceManagerLeaderId, + ResourceManagerId resourceManagerId, Executor executor, RegistrationConnectionListener registrationListener) { - super(log, resourceManagerAddress, resourceManagerLeaderId, executor); + super(log, resourceManagerAddress, resourceManagerId, executor); this.rpcService = Preconditions.checkNotNull(rpcService); this.taskManagerAddress = Preconditions.checkNotNull(taskManagerAddress); @@ -79,7 +79,7 @@ public TaskExecutorToResourceManagerConnection( @Override - protected RetryingRegistration generateRegistration() { + protected RetryingRegistration generateRegistration() { return new TaskExecutorToResourceManagerConnection.ResourceManagerRegistration( log, rpcService, @@ -127,7 +127,7 @@ public ResourceID getResourceManagerId() { // ------------------------------------------------------------------------ private static class ResourceManagerRegistration - extends RetryingRegistration { + extends RetryingRegistration { private final String taskExecutorAddress; @@ -139,12 +139,12 @@ private static class ResourceManagerRegistration Logger log, RpcService rpcService, String targetAddress, - UUID leaderId, + ResourceManagerId resourceManagerId, String taskExecutorAddress, ResourceID resourceID, SlotReport slotReport) { - super(log, rpcService, "ResourceManager", ResourceManagerGateway.class, targetAddress, leaderId); + super(log, rpcService, "ResourceManager", ResourceManagerGateway.class, targetAddress, resourceManagerId); this.taskExecutorAddress = checkNotNull(taskExecutorAddress); this.resourceID = checkNotNull(resourceID); this.slotReport = checkNotNull(slotReport); @@ -152,10 +152,10 @@ private static class ResourceManagerRegistration @Override protected CompletableFuture invokeRegistration( - ResourceManagerGateway resourceManager, UUID leaderId, long timeoutMillis) throws Exception { + ResourceManagerGateway resourceManager, ResourceManagerId fencingToken, long timeoutMillis) throws Exception { Time timeout = Time.milliseconds(timeoutMillis); - return resourceManager.registerTaskExecutor(leaderId, taskExecutorAddress, resourceID, slotReport, timeout); + return resourceManager.registerTaskExecutor(taskExecutorAddress, resourceID, slotReport, timeout); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerConfiguration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerConfiguration.java index ea9f5767b0157..7c7693bb9a1b0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerConfiguration.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerConfiguration.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.AkkaOptions; +import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.TaskManagerOptions; @@ -53,8 +54,6 @@ public class TaskManagerConfiguration implements TaskManagerRuntimeInfo { private final Time maxRegistrationPause; private final Time refusedRegistrationPause; - private final long cleanupInterval; 
- private final UnmodifiableConfiguration configuration; private final boolean exitJvmOnOutOfMemory; @@ -78,7 +77,6 @@ public TaskManagerConfiguration( this.initialRegistrationPause = Preconditions.checkNotNull(initialRegistrationPause); this.maxRegistrationPause = Preconditions.checkNotNull(maxRegistrationPause); this.refusedRegistrationPause = Preconditions.checkNotNull(refusedRegistrationPause); - this.cleanupInterval = Preconditions.checkNotNull(cleanupInterval); this.configuration = new UnmodifiableConfiguration(Preconditions.checkNotNull(configuration)); this.exitJvmOnOutOfMemory = exitJvmOnOutOfMemory; } @@ -107,10 +105,6 @@ public Time getRefusedRegistrationPause() { return refusedRegistrationPause; } - public long getCleanupInterval() { - return cleanupInterval; - } - @Override public Configuration getConfiguration() { return configuration; @@ -153,9 +147,7 @@ public static TaskManagerConfiguration fromConfiguration(Configuration configura LOG.info("Messages have a max timeout of " + timeout); - final long cleanupInterval = configuration.getLong( - ConfigConstants.LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL, - ConfigConstants.DEFAULT_LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL) * 1000; + final long cleanupInterval = configuration.getLong(BlobServerOptions.CLEANUP_INTERVAL) * 1000; final Time finiteRegistrationDuration; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java index bf6016126af1d..aba8bda191825 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorGateway; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.util.Preconditions; @@ -40,7 +40,7 @@ public void acknowledgeCheckpoint( ExecutionAttemptID executionAttemptID, long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState subtaskState) { + TaskStateSnapshot subtaskState) { checkpointCoordinatorGateway.acknowledgeCheckpoint( jobID, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcInputSplitProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcInputSplitProvider.java index a919c7878bbab..baa403bc3722b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcInputSplitProvider.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcInputSplitProvider.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.taskexecutor.rpc; -import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.core.io.InputSplit; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; @@ -30,27 +29,20 @@ import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; -import java.util.UUID; import java.util.concurrent.CompletableFuture; public class RpcInputSplitProvider implements InputSplitProvider { - private 
final UUID jobMasterLeaderId; private final JobMasterGateway jobMasterGateway; - private final JobID jobID; private final JobVertexID jobVertexID; private final ExecutionAttemptID executionAttemptID; private final Time timeout; public RpcInputSplitProvider( - UUID jobMasterLeaderId, JobMasterGateway jobMasterGateway, - JobID jobID, JobVertexID jobVertexID, ExecutionAttemptID executionAttemptID, Time timeout) { - this.jobMasterLeaderId = Preconditions.checkNotNull(jobMasterLeaderId); this.jobMasterGateway = Preconditions.checkNotNull(jobMasterGateway); - this.jobID = Preconditions.checkNotNull(jobID); this.jobVertexID = Preconditions.checkNotNull(jobVertexID); this.executionAttemptID = Preconditions.checkNotNull(executionAttemptID); this.timeout = Preconditions.checkNotNull(timeout); @@ -62,7 +54,8 @@ public InputSplit getNextInputSplit(ClassLoader userCodeClassLoader) throws Inpu Preconditions.checkNotNull(userCodeClassLoader); CompletableFuture futureInputSplit = jobMasterGateway.requestNextInputSplit( - jobMasterLeaderId, jobVertexID, executionAttemptID); + jobVertexID, + executionAttemptID); try { SerializedInputSplit serializedInputSplit = futureInputSplit.get(timeout.getSize(), timeout.getUnit()); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcPartitionStateChecker.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcPartitionStateChecker.java index 26e1b0efa781c..f3eb717166a92 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcPartitionStateChecker.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcPartitionStateChecker.java @@ -26,16 +26,13 @@ import org.apache.flink.runtime.jobmaster.JobMasterGateway; import org.apache.flink.util.Preconditions; -import java.util.UUID; import java.util.concurrent.CompletableFuture; public class RpcPartitionStateChecker implements PartitionProducerStateChecker { - private final UUID jobMasterLeaderId; private final JobMasterGateway jobMasterGateway; - public RpcPartitionStateChecker(UUID jobMasterLeaderId, JobMasterGateway jobMasterGateway) { - this.jobMasterLeaderId = Preconditions.checkNotNull(jobMasterLeaderId); + public RpcPartitionStateChecker(JobMasterGateway jobMasterGateway) { this.jobMasterGateway = Preconditions.checkNotNull(jobMasterGateway); } @@ -45,6 +42,6 @@ public CompletableFuture requestPartitionProducerState( IntermediateDataSetID resultId, ResultPartitionID partitionId) { - return jobMasterGateway.requestPartitionState(jobMasterLeaderId, resultId, partitionId); + return jobMasterGateway.requestPartitionState(resultId, partitionId); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcResultPartitionConsumableNotifier.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcResultPartitionConsumableNotifier.java index d8985620bb828..82a6fbccbe863 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcResultPartitionConsumableNotifier.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcResultPartitionConsumableNotifier.java @@ -29,7 +29,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -37,25 +36,21 @@ public class RpcResultPartitionConsumableNotifier implements ResultPartitionCons private static final Logger LOG = 
LoggerFactory.getLogger(RpcResultPartitionConsumableNotifier.class); - private final UUID jobMasterLeaderId; private final JobMasterGateway jobMasterGateway; private final Executor executor; private final Time timeout; public RpcResultPartitionConsumableNotifier( - UUID jobMasterLeaderId, JobMasterGateway jobMasterGateway, Executor executor, Time timeout) { - this.jobMasterLeaderId = Preconditions.checkNotNull(jobMasterLeaderId); this.jobMasterGateway = Preconditions.checkNotNull(jobMasterGateway); this.executor = Preconditions.checkNotNull(executor); this.timeout = Preconditions.checkNotNull(timeout); } @Override public void notifyPartitionConsumable(JobID jobId, ResultPartitionID partitionId, final TaskActions taskActions) { - CompletableFuture acknowledgeFuture = jobMasterGateway.scheduleOrUpdateConsumers( - jobMasterLeaderId, partitionId, timeout); + CompletableFuture acknowledgeFuture = jobMasterGateway.scheduleOrUpdateConsumers(partitionId, timeout); acknowledgeFuture.whenCompleteAsync( (Acknowledge ack, Throwable throwable) -> { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/slot/TaskSlot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/slot/TaskSlot.java index e12c15b5c411c..6f5230cec0206 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/slot/TaskSlot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/slot/TaskSlot.java @@ -299,4 +299,10 @@ public SlotOffer generateSlotOffer() { return new SlotOffer(allocationId, index, resourceProfile); } + + @Override + public String toString() { + return "TaskSlot(index:" + index + ", state:" + state + ", resource profile: " + resourceProfile + + ", allocationId: " + (allocationId != null ? allocationId.toString() : "none") + ", jobId: " + (jobId != null ? 
jobId.toString() : "none") + ')'; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/slot/TaskSlotTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/slot/TaskSlotTable.java index 5c51c7c359b01..799f639cb5615 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/slot/TaskSlotTable.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/slot/TaskSlotTable.java @@ -282,16 +282,14 @@ public int freeSlot(AllocationID allocationId) throws SlotNotFoundException { public int freeSlot(AllocationID allocationId, Throwable cause) throws SlotNotFoundException { checkInit(); - if (LOG.isDebugEnabled()) { - LOG.debug("Free slot {}.", allocationId, cause); - } else { - LOG.info("Free slot {}.", allocationId); - } - TaskSlot taskSlot = getTaskSlot(allocationId); if (taskSlot != null) { - LOG.info("Free slot {}.", allocationId, cause); + if (LOG.isDebugEnabled()) { + LOG.debug("Free slot {}.", taskSlot, cause); + } else { + LOG.info("Free slot {}.", taskSlot); + } final JobID jobId = taskSlot.getJobId(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java index ad0df7151c20a..e9f600d672abc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java @@ -20,7 +20,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; @@ -44,7 +44,7 @@ public void acknowledgeCheckpoint( ExecutionAttemptID executionAttemptID, long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + TaskStateSnapshot checkpointStateHandles) { AcknowledgeCheckpoint message = new AcknowledgeCheckpoint( jobID, executionAttemptID, checkpointId, checkpointMetrics, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java index cc66a3f283160..b3584a6dfc987 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java @@ -20,7 +20,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; /** @@ -47,7 +47,7 @@ void acknowledgeCheckpoint( ExecutionAttemptID executionAttemptID, long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState subtaskState); + TaskStateSnapshot subtaskState); /** * Declines the given checkpoint. 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java index 788a59090d309..92b58868d666f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java @@ -26,7 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -245,7 +245,7 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin public void acknowledgeCheckpoint( long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + TaskStateSnapshot checkpointStateHandles) { checkpointResponder.acknowledgeCheckpoint( jobId, executionId, checkpointId, checkpointMetrics, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 596d36519f3ba..d62896054d40c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -30,10 +30,12 @@ import org.apache.flink.core.fs.SafetyNetCloseableRegistry; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.accumulators.AccumulatorRegistry; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotCheckpointingException; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotReadyException; import org.apache.flink.runtime.clusterframework.types.AllocationID; @@ -67,16 +69,17 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.WrappingRuntimeException; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nonnull; import javax.annotation.Nullable; + import java.io.IOException; import java.net.URL; import java.util.Collection; @@ -201,7 +204,10 @@ public class Task implements Runnable, TaskActions { /** All listener that want to be notified about changes in the task's execution state */ private final List taskExecutionStateListeners; - /** The library cache, from which the task can request its required JAR files */ + /** The BLOB cache, from which the task can request BLOB files */ + private final BlobCache blobCache; + + /** The 
library cache, from which the task can request its class loader */ private final LibraryCacheManager libraryCache; /** The cache for user-defined files that the invokable requires */ @@ -250,7 +256,7 @@ public class Task implements Runnable, TaskActions { * The handles to the states that the task was initialized with. Will be set * to null after the initialization, to be memory friendly. */ - private volatile TaskStateHandles taskStateHandles; + private volatile TaskStateSnapshot taskStateHandles; /** Initialized from the Flink configuration. May also be set at the ExecutionConfig */ private long taskCancellationInterval; @@ -272,7 +278,7 @@ public Task( Collection resultPartitionDeploymentDescriptors, Collection inputGateDeploymentDescriptors, int targetSlotNumber, - TaskStateHandles taskStateHandles, + TaskStateSnapshot taskStateHandles, MemoryManager memManager, IOManager ioManager, NetworkEnvironment networkEnvironment, @@ -280,6 +286,7 @@ public Task( TaskManagerActions taskManagerActions, InputSplitProvider inputSplitProvider, CheckpointResponder checkpointResponder, + BlobCache blobCache, LibraryCacheManager libraryCache, FileCache fileCache, TaskManagerRuntimeInfo taskManagerConfig, @@ -328,6 +335,7 @@ public Task( this.checkpointResponder = Preconditions.checkNotNull(checkpointResponder); this.taskManagerActions = checkNotNull(taskManagerActions); + this.blobCache = Preconditions.checkNotNull(blobCache); this.libraryCache = Preconditions.checkNotNull(libraryCache); this.fileCache = Preconditions.checkNotNull(fileCache); this.network = Preconditions.checkNotNull(networkEnvironment); @@ -566,6 +574,8 @@ else if (current == ExecutionState.CANCELING) { LOG.info("Creating FileSystem stream leak safety net for task {}", this); FileSystemSafetyNet.initializeSafetyNetForThread(); + blobCache.registerJob(jobId); + // first of all, get a user-code classloader // this may involve downloading the job's JAR files and/or classes LOG.info("Loading JAR files for task {}.", this); @@ -825,6 +835,7 @@ else if (transitionState(current, ExecutionState.FAILED, t)) { // remove all of the tasks library resources libraryCache.unregisterTask(jobId, executionId); + blobCache.releaseJob(jobId); // remove all files in the distributed cache removeCachedFiles(distributedCacheEntries, fileCache); @@ -860,7 +871,7 @@ private ClassLoader createUserCodeClassloader(LibraryCacheManager libraryCache) // triggers the download of all missing jar files from the job manager libraryCache.registerTask(jobId, executionId, requiredJarFiles, requiredClasspaths); - LOG.debug("Register task {} at library cache manager took {} milliseconds", + LOG.debug("Getting user code class loader for task {} at library cache manager took {} milliseconds", executionId, System.currentTimeMillis() - startDownloadTime); ClassLoader userCodeClassLoader = libraryCache.getClassLoader(jobId); diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/clusterframework/ContaineredJobManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/clusterframework/ContaineredJobManager.scala index cd7b363924a2d..61c61b4c0ee4c 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/clusterframework/ContaineredJobManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/clusterframework/ContaineredJobManager.scala @@ -18,11 +18,12 @@ package org.apache.flink.runtime.clusterframework -import java.util.concurrent.{ScheduledExecutorService, Executor} +import java.util.concurrent.{Executor, ScheduledExecutorService} 
import akka.actor.ActorRef import org.apache.flink.api.common.JobID import org.apache.flink.configuration.Configuration +import org.apache.flink.runtime.blob.BlobServer import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory import org.apache.flink.runtime.clusterframework.messages._ import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager @@ -51,6 +52,7 @@ import scala.language.postfixOps * @param instanceManager Instance manager to manage the registered * [[org.apache.flink.runtime.taskmanager.TaskManager]] * @param scheduler Scheduler to schedule Flink jobs + * @param blobServer Server instance to store BLOBs for the individual tasks * @param libraryCacheManager Manager to manage uploaded jar files * @param archive Archive for finished Flink jobs * @param restartStrategyFactory Restart strategy to be used in case of a job recovery @@ -63,6 +65,7 @@ abstract class ContaineredJobManager( ioExecutor: Executor, instanceManager: InstanceManager, scheduler: FlinkScheduler, + blobServer: BlobServer, libraryCacheManager: BlobLibraryCacheManager, archive: ActorRef, restartStrategyFactory: RestartStrategyFactory, @@ -78,6 +81,7 @@ abstract class ContaineredJobManager( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala index 1616a7b7b6663..f0073db4dca53 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala @@ -126,6 +126,7 @@ class JobManager( protected val ioExecutor: Executor, protected val instanceManager: InstanceManager, protected val scheduler: FlinkScheduler, + protected val blobServer: BlobServer, protected val libraryCacheManager: BlobLibraryCacheManager, protected val archive: ActorRef, protected val restartStrategyFactory: RestartStrategyFactory, @@ -272,11 +273,12 @@ class JobManager( instanceManager.shutdown() scheduler.shutdown() + libraryCacheManager.shutdown() try { - libraryCacheManager.shutdown() + blobServer.close() } catch { - case e: IOException => log.error("Could not properly shutdown the library cache manager.", e) + case e: IOException => log.error("Could not properly shutdown the blob server.", e) } // failsafe shutdown of the metrics registry @@ -422,7 +424,7 @@ class JobManager( taskManager ! decorateMessage( AlreadyRegistered( instanceID, - libraryCacheManager.getBlobServerPort)) + blobServer.getPort)) } else { try { val actorGateway = new AkkaActorGateway(taskManager, leaderSessionID.orNull) @@ -437,7 +439,7 @@ class JobManager( taskManagerMap.put(taskManager, instanceID) taskManager ! decorateMessage( - AcknowledgeRegistration(instanceID, libraryCacheManager.getBlobServerPort)) + AcknowledgeRegistration(instanceID, blobServer.getPort)) // to be notified when the taskManager is no longer reachable context.watch(taskManager) @@ -839,6 +841,7 @@ class JobManager( try { log.info(s"Disposing savepoint at '$savepointPath'.") //TODO user code class loader ? + // (has not been used so far and new savepoints can simply be deleted by file) val savepoint = SavepointStore.loadSavepoint( savepointPath, Thread.currentThread().getContextClassLoader) @@ -1060,7 +1063,7 @@ class JobManager( case Some((graph, jobInfo)) => sender() ! 
decorateMessage( ClassloadingProps( - libraryCacheManager.getBlobServerPort, + blobServer.getPort, graph.getRequiredJarFiles, graph.getRequiredClasspaths)) case None => @@ -1068,7 +1071,7 @@ class JobManager( } case RequestBlobManagerPort => - sender ! decorateMessage(libraryCacheManager.getBlobServerPort) + sender ! decorateMessage(blobServer.getPort) case RequestArchive => sender ! decorateMessage(ResponseArchive(archive)) @@ -1254,8 +1257,8 @@ class JobManager( // because this makes sure that the uploaded jar files are removed in case of // unsuccessful try { - libraryCacheManager.registerJob(jobGraph.getJobID, jobGraph.getUserJarBlobKeys, - jobGraph.getClasspaths) + libraryCacheManager.registerJob( + jobGraph.getJobID, jobGraph.getUserJarBlobKeys, jobGraph.getClasspaths) } catch { case t: Throwable => @@ -1344,6 +1347,7 @@ class JobManager( log.error(s"Failed to submit job $jobId ($jobName)", t) libraryCacheManager.unregisterJob(jobId) + blobServer.cleanupJob(jobId) currentJobs.remove(jobId) if (executionGraph != null) { @@ -1785,12 +1789,10 @@ class JobManager( case None => None } - try { - libraryCacheManager.unregisterJob(jobID) - } catch { - case t: Throwable => - log.error(s"Could not properly unregister job $jobID from the library cache.", t) - } + // remove all job-related BLOBs from local and HA store + libraryCacheManager.unregisterJob(jobID) + blobServer.cleanupJob(jobID) + jobManagerMetricGroup.foreach(_.removeJob(jobID)) futureOption @@ -2230,7 +2232,7 @@ object JobManager { new AkkaJobManagerRetriever(jobManagerSystem, timeout), new AkkaQueryServiceRetriever(jobManagerSystem, timeout), timeout, - jobManagerSystem.dispatcher) + futureExecutor) Option(webServer) } @@ -2463,6 +2465,7 @@ object JobManager { blobStore: BlobStore) : (InstanceManager, FlinkScheduler, + BlobServer, BlobLibraryCacheManager, RestartStrategyFactory, FiniteDuration, // timeout @@ -2474,10 +2477,6 @@ object JobManager { val timeout: FiniteDuration = AkkaUtils.getTimeout(configuration) - val cleanupInterval = configuration.getLong( - ConfigConstants.LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL, - ConfigConstants.DEFAULT_LIBRARY_CACHE_MANAGER_CLEANUP_INTERVAL) * 1000 - val restartStrategy = RestartStrategyFactory.createRestartStrategyFactory(configuration) val archiveCount = configuration.getInteger(WebOptions.ARCHIVE_COUNT) @@ -2508,21 +2507,21 @@ object JobManager { blobServer = new BlobServer(configuration, blobStore) instanceManager = new InstanceManager() scheduler = new FlinkScheduler(ExecutionContext.fromExecutor(futureExecutor)) - libraryCacheManager = new BlobLibraryCacheManager(blobServer, cleanupInterval) + libraryCacheManager = new BlobLibraryCacheManager(blobServer) instanceManager.addInstanceListener(scheduler) } catch { case t: Throwable => - if (libraryCacheManager != null) { - libraryCacheManager.shutdown() - } if (scheduler != null) { scheduler.shutdown() } if (instanceManager != null) { instanceManager.shutdown() } + if (libraryCacheManager != null) { + libraryCacheManager.shutdown() + } if (blobServer != null) { blobServer.close() } @@ -2554,6 +2553,7 @@ object JobManager { (instanceManager, scheduler, + blobServer, libraryCacheManager, restartStrategy, timeout, @@ -2627,6 +2627,7 @@ object JobManager { val (instanceManager, scheduler, + blobServer, libraryCacheManager, restartStrategy, timeout, @@ -2654,6 +2655,7 @@ object JobManager { ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategy, @@ -2693,6 +2695,7 @@ object JobManager { ioExecutor: 
Executor, instanceManager: InstanceManager, scheduler: FlinkScheduler, + blobServer: BlobServer, libraryCacheManager: LibraryCacheManager, archive: ActorRef, restartStrategyFactory: RestartStrategyFactory, @@ -2710,6 +2713,7 @@ object JobManager { ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/LocalFlinkMiniCluster.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/LocalFlinkMiniCluster.scala index 0ae00a93852cf..dcf9dd0b80e62 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/LocalFlinkMiniCluster.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/LocalFlinkMiniCluster.scala @@ -26,6 +26,7 @@ import org.apache.flink.api.common.JobID import org.apache.flink.api.common.io.FileOutputFormat import org.apache.flink.configuration.{ConfigConstants, Configuration, JobManagerOptions, QueryableStateOptions, ResourceManagerOptions, TaskManagerOptions} import org.apache.flink.core.fs.Path +import org.apache.flink.runtime.blob.BlobServer import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory import org.apache.flink.runtime.clusterframework.FlinkResourceManager import org.apache.flink.runtime.clusterframework.standalone.StandaloneResourceManager @@ -133,6 +134,7 @@ class LocalFlinkMiniCluster( val (instanceManager, scheduler, + blobServer, libraryCacheManager, restartStrategyFactory, timeout, @@ -164,6 +166,7 @@ class LocalFlinkMiniCluster( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, @@ -279,6 +282,7 @@ class LocalFlinkMiniCluster( ioExecutor: Executor, instanceManager: InstanceManager, scheduler: Scheduler, + blobServer: BlobServer, libraryCacheManager: BlobLibraryCacheManager, archive: ActorRef, restartStrategyFactory: RestartStrategyFactory, @@ -297,6 +301,7 @@ class LocalFlinkMiniCluster( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala index 0c419eb460f13..431adb6f8b122 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala @@ -35,7 +35,7 @@ import org.apache.flink.configuration._ import org.apache.flink.core.fs.FileSystem import org.apache.flink.runtime.accumulators.AccumulatorSnapshot import org.apache.flink.runtime.akka.{AkkaUtils, DefaultQuarantineHandler, QuarantineMonitor} -import org.apache.flink.runtime.blob.{BlobCache, BlobClient, BlobService} +import org.apache.flink.runtime.blob.{BlobCache, BlobClient} import org.apache.flink.runtime.broadcast.BroadcastVariableManager import org.apache.flink.runtime.clusterframework.messages.StopCluster import org.apache.flink.runtime.clusterframework.types.ResourceID @@ -160,7 +160,7 @@ class TaskManager( * registered at the job manager */ private val waitForRegistration = scala.collection.mutable.Set[ActorRef]() - private var blobService: Option[BlobService] = None + private var blobCache: Option[BlobCache] = None private var libraryCacheManager: Option[LibraryCacheManager] = None /* The current leading JobManager Actor associated with */ @@ -333,11 +333,11 @@ class TaskManager( 
killTaskManagerFatal(message, cause) case RequestTaskManagerLog(requestType : LogTypeRequest) => - blobService match { + blobCache match { case Some(_) => handleRequestTaskManagerLog(sender(), requestType, currentJobManager.get) case None => - sender() ! akka.actor.Status.Failure(new IOException("BlobService not " + + sender() ! akka.actor.Status.Failure(new IOException("BlobCache not " + "available. Cannot upload TaskManager logs.")) } @@ -840,7 +840,7 @@ class TaskManager( if (file.exists()) { val fis = new FileInputStream(file); Future { - val client: BlobClient = blobService.get.createClient() + val client: BlobClient = blobCache.get.createClient() client.put(fis); }(context.dispatcher) .onComplete { @@ -915,7 +915,7 @@ class TaskManager( "starting network stack and library cache.") // sanity check that the JobManager dependent components are not set up currently - if (connectionUtils.isDefined || blobService.isDefined) { + if (connectionUtils.isDefined || blobCache.isDefined) { throw new IllegalStateException("JobManager-specific components are already initialized.") } @@ -968,9 +968,9 @@ class TaskManager( address, config.getConfiguration(), highAvailabilityServices.createBlobStore()) - blobService = Option(blobcache) + blobCache = Option(blobcache) libraryCacheManager = Some( - new BlobLibraryCacheManager(blobcache, config.getCleanupInterval())) + new BlobLibraryCacheManager(blobcache)) } catch { case e: Exception => @@ -1047,18 +1047,11 @@ class TaskManager( // shut down BLOB and library cache libraryCacheManager foreach { - manager => - try { - manager.shutdown() - } catch { - case ioe: IOException => log.error( - "Could not properly shutdown library cache manager.", - ioe) - } + manager => manager.shutdown() } libraryCacheManager = None - blobService foreach { + blobCache foreach { service => try { service.close() @@ -1066,7 +1059,7 @@ class TaskManager( case ioe: IOException => log.error("Could not properly shutdown blob service.", ioe) } } - blobService = None + blobCache = None // disassociate the slot environment connectionUtils = None @@ -1130,6 +1123,10 @@ class TaskManager( case Some(manager) => manager case None => throw new IllegalStateException("There is no valid library cache manager.") } + val blobCache = this.blobCache match { + case Some(manager) => manager + case None => throw new IllegalStateException("There is no valid BLOB cache.") + } val slot = tdd.getTargetSlotNumber if (slot < 0 || slot >= numberOfSlots) { @@ -1200,6 +1197,7 @@ class TaskManager( taskManagerConnection, inputSplitProvider, checkpointResponder, + blobCache, libCache, fileCache, config, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheCleanupTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheCleanupTest.java new file mode 100644 index 0000000000000..afd365b6fc892 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheCleanupTest.java @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.blob; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.BlobServerOptions; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.util.TestLogger; + +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * A few tests for the deferred ref-counting based cleanup inside the {@link BlobCache}. + */ +public class BlobCacheCleanupTest extends TestLogger { + + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + /** + * Tests that {@link BlobCache} cleans up after calling {@link BlobCache#releaseJob(JobID)}. + */ + @Test + public void testJobCleanup() throws IOException, InterruptedException { + + JobID jobId = new JobID(); + List keys = new ArrayList(); + BlobServer server = null; + BlobCache cache = null; + + final byte[] buf = new byte[128]; + + try { + Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + temporaryFolder.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L); + + server = new BlobServer(config, new VoidBlobStore()); + InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.getPort()); + + // upload blobs + try (BlobClient bc = new BlobClient(serverAddress, config)) { + keys.add(bc.put(jobId, buf)); + buf[0] += 1; + keys.add(bc.put(jobId, buf)); + } + + cache = new BlobCache(serverAddress, config, new VoidBlobStore()); + + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(0, jobId, cache); + + // register once + cache.registerJob(jobId); + + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(0, jobId, cache); + + for (BlobKey key : keys) { + cache.getFile(jobId, key); + } + + // register again (let's say, from another thread or so) + cache.registerJob(jobId); + for (BlobKey key : keys) { + cache.getFile(jobId, key); + } + + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + // after releasing once, nothing should change + cache.releaseJob(jobId); + + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + // after releasing the second time, the job is up for deferred cleanup + cache.releaseJob(jobId); + + // because we cannot guarantee that there are not thread races in the build system, we + // loop for a certain while until the references disappear + { + long deadline = System.currentTimeMillis() + 30_000L; + do { + Thread.sleep(100); + } + while (checkFilesExist(jobId, keys, cache, false) != 0 && + System.currentTimeMillis() < deadline); + } + + // the blob cache 
should no longer contain the files + // this fails if we exited via a timeout + checkFileCountForJob(0, jobId, cache); + // server should be unaffected + checkFileCountForJob(2, jobId, server); + } + finally { + if (cache != null) { + cache.close(); + } + + if (server != null) { + server.close(); + } + // now everything should be cleaned up + checkFileCountForJob(0, jobId, server); + } + } + + /** + * Tests that {@link BlobCache} cleans up after calling {@link BlobCache#releaseJob(JobID)} + * but only after preserving the file for a bit longer. + */ + @Test + @Ignore("manual test due to stalling: ensures a BLOB is retained first and only deleted after the (long) timeout ") + public void testJobDeferredCleanup() throws IOException, InterruptedException { + // file should be deleted between 5 and 10s after last job release + long cleanupInterval = 5L; + + JobID jobId = new JobID(); + List keys = new ArrayList(); + BlobServer server = null; + BlobCache cache = null; + + final byte[] buf = new byte[128]; + + try { + Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + temporaryFolder.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, cleanupInterval); + + server = new BlobServer(config, new VoidBlobStore()); + InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.getPort()); + + // upload blobs + try (BlobClient bc = new BlobClient(serverAddress, config)) { + keys.add(bc.put(jobId, buf)); + buf[0] += 1; + keys.add(bc.put(jobId, buf)); + } + + cache = new BlobCache(serverAddress, config, new VoidBlobStore()); + + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(0, jobId, cache); + + // register once + cache.registerJob(jobId); + + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(0, jobId, cache); + + for (BlobKey key : keys) { + cache.getFile(jobId, key); + } + + // register again (let's say, from another thread or so) + cache.registerJob(jobId); + for (BlobKey key : keys) { + cache.getFile(jobId, key); + } + + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + // after releasing once, nothing should change + cache.releaseJob(jobId); + + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + // after releasing the second time, the job is up for deferred cleanup + cache.releaseJob(jobId); + + // files should still be accessible for now + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, cache); + + Thread.sleep(cleanupInterval / 5); + // still accessible... 
+ assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, cache); + + Thread.sleep((cleanupInterval * 4) / 5); + + // files are up for cleanup now...wait for it: + // because we cannot guarantee that there are not thread races in the build system, we + // loop for a certain while until the references disappear + { + long deadline = System.currentTimeMillis() + 30_000L; + do { + Thread.sleep(100); + } + while (checkFilesExist(jobId, keys, cache, false) != 0 && + System.currentTimeMillis() < deadline); + } + + // the blob cache should no longer contain the files + // this fails if we exited via a timeout + checkFileCountForJob(0, jobId, cache); + // server should be unaffected + checkFileCountForJob(2, jobId, server); + } + finally { + if (cache != null) { + cache.close(); + } + + if (server != null) { + server.close(); + } + // now everything should be cleaned up + checkFileCountForJob(0, jobId, server); + } + } + + /** + * Checks how many of the files given by blob keys are accessible. + * + * @param jobId + * ID of a job + * @param keys + * blob keys to check + * @param blobService + * BLOB store to use + * @param doThrow + * whether exceptions should be ignored (false), or thrown (true) + * + * @return number of files we were able to retrieve via {@link BlobService#getFile} + */ + public static int checkFilesExist( + JobID jobId, Collection keys, BlobService blobService, boolean doThrow) + throws IOException { + + int numFiles = 0; + + for (BlobKey key : keys) { + final File blobFile; + if (blobService instanceof BlobServer) { + BlobServer server = (BlobServer) blobService; + blobFile = server.getStorageLocation(jobId, key); + } else { + BlobCache cache = (BlobCache) blobService; + blobFile = cache.getStorageLocation(jobId, key); + } + if (blobFile.exists()) { + ++numFiles; + } else if (doThrow) { + throw new IOException("File " + blobFile + " does not exist."); + } + } + + return numFiles; + } + + /** + * Checks that the number of files stored in the given BLOB service for the given job matches the expected count. 
+ * + * @param expectedCount + * number of expected files in the blob service for the given job + * @param jobId + * ID of a job + * @param blobService + * BLOB store to use + */ + public static void checkFileCountForJob( + int expectedCount, JobID jobId, BlobService blobService) + throws IOException { + + final File jobDir; + if (blobService instanceof BlobServer) { + BlobServer server = (BlobServer) blobService; + jobDir = server.getStorageLocation(jobId, new BlobKey()).getParentFile(); + } else { + BlobCache cache = (BlobCache) blobService; + jobDir = cache.getStorageLocation(jobId, new BlobKey()).getParentFile(); + } + File[] blobsForJob = jobDir.listFiles(); + if (blobsForJob == null) { + if (expectedCount != 0) { + throw new IOException("File " + jobDir + " does not exist."); + } + } else { + assertEquals("Too many/few files in job dir: " + + Arrays.asList(blobsForJob).toString(), expectedCount, + blobsForJob.length); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheRetriesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheRetriesTest.java index 366b592a4b322..0060ccbf08fda 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheRetriesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheRetriesTest.java @@ -18,13 +18,17 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.HighAvailabilityOptions; +import org.apache.flink.util.TestLogger; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; @@ -35,13 +39,14 @@ /** * Unit tests for the blob cache retrying the connection to the server. */ -public class BlobCacheRetriesTest { +public class BlobCacheRetriesTest extends TestLogger { @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); /** - * A test where the connection fails twice and then the get operation succeeds. + * A test where the connection fails twice and then the get operation succeeds + * (job-unrelated blob). */ @Test public void testBlobFetchRetries() throws IOException { @@ -49,15 +54,41 @@ public void testBlobFetchRetries() throws IOException { config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); - testBlobFetchRetries(config, new VoidBlobStore()); + testBlobFetchRetries(config, new VoidBlobStore(), null); } /** * A test where the connection fails twice and then the get operation succeeds - * (with high availability set). + * (job-related blob). + */ + @Test + public void testBlobForJobFetchRetries() throws IOException { + final Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + temporaryFolder.newFolder().getAbsolutePath()); + + testBlobFetchRetries(config, new VoidBlobStore(), new JobID()); + } + + /** + * A test where the connection fails twice and then the get operation succeeds + * (with high availability set, job-unrelated blob). 
+ */ + @Test + public void testBlobNoJobFetchRetriesHa() throws IOException { + testBlobFetchRetriesHa(null); + } + + /** + * A test where the connection fails twice and then the get operation succeeds + * (with high availability set, job-related blob). */ @Test public void testBlobFetchRetriesHa() throws IOException { + testBlobFetchRetriesHa(new JobID()); + } + + private void testBlobFetchRetriesHa(final JobID jobId) throws IOException { final Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); @@ -70,7 +101,7 @@ public void testBlobFetchRetriesHa() throws IOException { try { blobStoreService = BlobUtils.createBlobStoreFromConfig(config); - testBlobFetchRetries(config, blobStoreService); + testBlobFetchRetries(config, blobStoreService, jobId); } finally { if (blobStoreService != null) { blobStoreService.closeAndCleanupAllData(); @@ -86,7 +117,9 @@ public void testBlobFetchRetriesHa() throws IOException { * configuration to use (the BlobCache will get some additional settings * set compared to this one) */ - private void testBlobFetchRetries(final Configuration config, final BlobStore blobStore) throws IOException { + private static void testBlobFetchRetries( + final Configuration config, final BlobStore blobStore, final JobID jobId) + throws IOException { final byte[] data = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}; BlobServer server = null; @@ -104,7 +137,7 @@ private void testBlobFetchRetries(final Configuration config, final BlobStore bl try { blobClient = new BlobClient(serverAddress, config); - key = blobClient.put(data); + key = blobClient.put(jobId, data); } finally { if (blobClient != null) { @@ -115,16 +148,13 @@ private void testBlobFetchRetries(final Configuration config, final BlobStore bl cache = new BlobCache(serverAddress, config, new VoidBlobStore()); // trigger a download - it should fail the first two times, but retry, and succeed eventually - URL url = cache.getURL(key); - InputStream is = url.openStream(); - try { + File file = jobId == null ? cache.getFile(key) : cache.getFile(jobId, key); + URL url = file.toURI().toURL(); + try (InputStream is = url.openStream()) { byte[] received = new byte[data.length]; assertEquals(data.length, is.read(received)); assertArrayEquals(data, received); } - finally { - is.close(); - } } finally { if (cache != null) { cache.close(); @@ -136,23 +166,50 @@ private void testBlobFetchRetries(final Configuration config, final BlobStore bl } /** - * A test where the connection fails too often and eventually fails the GET request. + * A test where the connection fails too often and eventually fails the GET request + * (job-unrelated blob). + */ + @Test + public void testBlobNoJobFetchWithTooManyFailures() throws IOException { + final Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + temporaryFolder.newFolder().getAbsolutePath()); + + testBlobFetchWithTooManyFailures(config, new VoidBlobStore(), null); + } + + /** + * A test where the connection fails too often and eventually fails the GET request (job-related + * blob). 
*/ @Test - public void testBlobFetchWithTooManyFailures() throws IOException { + public void testBlobForJobFetchWithTooManyFailures() throws IOException { final Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); - testBlobFetchWithTooManyFailures(config, new VoidBlobStore()); + testBlobFetchWithTooManyFailures(config, new VoidBlobStore(), new JobID()); + } + + /** + * A test where the connection fails too often and eventually fails the GET request + * (with high availability set, job-unrelated blob). + */ + @Test + public void testBlobNoJobFetchWithTooManyFailuresHa() throws IOException { + testBlobFetchWithTooManyFailuresHa(null); + } /** * A test where the connection fails twice and then the get operation succeeds - * (with high availability set). + * (with high availability set, job-related blob). */ @Test - public void testBlobFetchWithTooManyFailuresHa() throws IOException { + public void testBlobForJobFetchWithTooManyFailuresHa() throws IOException { + testBlobFetchWithTooManyFailuresHa(new JobID()); + } + + private void testBlobFetchWithTooManyFailuresHa(final JobID jobId) throws IOException { final Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); @@ -165,7 +222,7 @@ public void testBlobFetchWithTooManyFailuresHa() throws IOException { try { blobStoreService = BlobUtils.createBlobStoreFromConfig(config); - testBlobFetchWithTooManyFailures(config, blobStoreService); + testBlobFetchWithTooManyFailures(config, blobStoreService, jobId); } finally { if (blobStoreService != null) { blobStoreService.closeAndCleanupAllData(); @@ -181,7 +238,9 @@ public void testBlobFetchWithTooManyFailuresHa() throws IOException { * configuration to use (the BlobCache will get some additional settings * set compared to this one) */ - private void testBlobFetchWithTooManyFailures(final Configuration config, final BlobStore blobStore) throws IOException { + private static void testBlobFetchWithTooManyFailures( + final Configuration config, final BlobStore blobStore, final JobID jobId) + throws IOException { final byte[] data = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 }; BlobServer server = null; @@ -199,7 +258,7 @@ private void testBlobFetchWithTooManyFailures(final Configuration config, final try { blobClient = new BlobClient(serverAddress, config); - key = blobClient.put(data); + key = blobClient.put(jobId, data); } finally { if (blobClient != null) { @@ -211,7 +270,11 @@ private void testBlobFetchWithTooManyFailures(final Configuration config, final // trigger a download - it should fail eventually try { - cache.getURL(key); + if (jobId == null) { + cache.getFile(key); + } else { + cache.getFile(jobId, key); + } fail("This should fail"); } catch (IOException e) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheSuccessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheSuccessTest.java index 51be1b044e21e..1216be2daf093 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheSuccessTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobCacheSuccessTest.java @@ -18,10 +18,12 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import 
org.apache.flink.configuration.HighAvailabilityOptions; import org.apache.flink.util.Preconditions; +import org.apache.flink.util.TestLogger; import org.junit.Rule; import org.junit.Test; @@ -30,82 +32,133 @@ import java.io.File; import java.io.IOException; import java.net.InetSocketAddress; -import java.net.URISyntaxException; -import java.net.URL; import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; /** * This class contains unit tests for the {@link BlobCache}. */ -public class BlobCacheSuccessTest { +public class BlobCacheSuccessTest extends TestLogger { @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); /** - * BlobCache with no HA. BLOBs need to be downloaded form a working + * BlobCache with no HA, job-unrelated BLOBs. BLOBs need to be downloaded from a working * BlobServer. */ @Test - public void testBlobCache() throws IOException { + public void testBlobNoJobCache() throws IOException { Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); - uploadFileGetTest(config, false, false); + uploadFileGetTest(config, null, false, false); + } + + /** + * BlobCache with no HA, job-related BLOBs. BLOBs need to be downloaded from a working + * BlobServer. + */ + @Test + public void testBlobForJobCache() throws IOException { + Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + temporaryFolder.newFolder().getAbsolutePath()); + + uploadFileGetTest(config, new JobID(), false, false); } /** * BlobCache is configured in HA mode and the cache can download files from * the file system directly and does not need to download BLOBs from the - * BlobServer. + * BlobServer. Using job-unrelated BLOBs. */ @Test - public void testBlobCacheHa() throws IOException { + public void testBlobNoJobCacheHa() throws IOException { + testBlobCacheHa(null); + } + + /** + * BlobCache is configured in HA mode and the cache can download files from + * the file system directly and does not need to download BLOBs from the + * BlobServer. Using job-related BLOBs. + */ + @Test + public void testBlobForJobCacheHa() throws IOException { + testBlobCacheHa(new JobID()); + } + + private void testBlobCacheHa(final JobID jobId) throws IOException { Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); config.setString(HighAvailabilityOptions.HA_MODE, "ZOOKEEPER"); config.setString(HighAvailabilityOptions.HA_STORAGE_PATH, temporaryFolder.newFolder().getPath()); - uploadFileGetTest(config, true, true); + uploadFileGetTest(config, jobId, true, true); } /** * BlobCache is configured in HA mode and the cache can download files from * the file system directly and does not need to download BLOBs from the - * BlobServer. + * BlobServer. Using job-unrelated BLOBs. + */ + @Test + public void testBlobNoJobCacheHa2() throws IOException { + testBlobCacheHa2(null); + } + + /** + * BlobCache is configured in HA mode and the cache can download files from + * the file system directly and does not need to download BLOBs from the + * BlobServer. Using job-related BLOBs. 
*/ @Test - public void testBlobCacheHa2() throws IOException { + public void testBlobForJobCacheHa2() throws IOException { + testBlobCacheHa2(new JobID()); + } + + private void testBlobCacheHa2(JobID jobId) throws IOException { Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); config.setString(HighAvailabilityOptions.HA_MODE, "ZOOKEEPER"); config.setString(HighAvailabilityOptions.HA_STORAGE_PATH, temporaryFolder.newFolder().getPath()); - uploadFileGetTest(config, false, true); + uploadFileGetTest(config, jobId, false, true); + } + + /** + * BlobCache is configured in HA mode but the cache itself cannot access the + * file system and thus needs to download BLOBs from the BlobServer. Using job-unrelated BLOBs. + */ + @Test + public void testBlobNoJobCacheHaFallback() throws IOException { + testBlobCacheHaFallback(null); } /** * BlobCache is configured in HA mode but the cache itself cannot access the - * file system and thus needs to download BLOBs from the BlobServer. + * file system and thus needs to download BLOBs from the BlobServer. Using job-related BLOBs. */ @Test - public void testBlobCacheHaFallback() throws IOException { + public void testBlobForJobCacheHaFallback() throws IOException { + testBlobCacheHaFallback(new JobID()); + } + + private void testBlobCacheHaFallback(final JobID jobId) throws IOException { Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); config.setString(HighAvailabilityOptions.HA_MODE, "ZOOKEEPER"); config.setString(HighAvailabilityOptions.HA_STORAGE_PATH, temporaryFolder.newFolder().getPath()); - uploadFileGetTest(config, false, false); + uploadFileGetTest(config, jobId, false, false); } /** @@ -122,7 +175,7 @@ public void testBlobCacheHaFallback() throws IOException { * whether the cache should have access to a shared HA_STORAGE_PATH (only useful with * HA mode) */ - private void uploadFileGetTest(final Configuration config, boolean shutdownServerAfterUpload, + private void uploadFileGetTest(final Configuration config, JobID jobId, boolean shutdownServerAfterUpload, boolean cacheHasAccessToFs) throws IOException { Preconditions.checkArgument(!shutdownServerAfterUpload || cacheHasAccessToFs); @@ -157,9 +210,9 @@ private void uploadFileGetTest(final Configuration config, boolean shutdownServe blobClient = new BlobClient(serverAddress, config); - blobKeys.add(blobClient.put(buf)); + blobKeys.add(blobClient.put(jobId, buf)); buf[0] = 1; // Make sure the BLOB key changes - blobKeys.add(blobClient.put(buf)); + blobKeys.add(blobClient.put(jobId, buf)); } finally { if (blobClient != null) { blobClient.close(); @@ -175,7 +228,11 @@ private void uploadFileGetTest(final Configuration config, boolean shutdownServe blobCache = new BlobCache(serverAddress, cacheConfig, blobStoreService); for (BlobKey blobKey : blobKeys) { - blobCache.getURL(blobKey); + if (jobId == null) { + blobCache.getFile(blobKey); + } else { + blobCache.getFile(jobId, blobKey); + } } if (blobServer != null) { @@ -184,28 +241,24 @@ private void uploadFileGetTest(final Configuration config, boolean shutdownServe blobServer = null; } - final URL[] urls = new URL[blobKeys.size()]; + final File[] files = new File[blobKeys.size()]; for(int i = 0; i < blobKeys.size(); i++){ - urls[i] = blobCache.getURL(blobKeys.get(i)); + if (jobId == null) { + files[i] = blobCache.getFile(blobKeys.get(i)); + } else { + files[i] = 
blobCache.getFile(jobId, blobKeys.get(i)); + } } // Verify the result - assertEquals(blobKeys.size(), urls.length); - - for (final URL url : urls) { - - assertNotNull(url); + assertEquals(blobKeys.size(), files.length); - try { - final File cachedFile = new File(url.toURI()); + for (final File file : files) { + assertNotNull(file); - assertTrue(cachedFile.exists()); - assertEquals(buf.length, cachedFile.length()); - - } catch (URISyntaxException e) { - fail(e.getMessage()); - } + assertTrue(file.exists()); + assertEquals(buf.length, file.length()); } } finally { if (blobServer != null) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientTest.java index 2932f41daea2c..6d6bfd51f5dc0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientTest.java @@ -38,6 +38,9 @@ import java.util.Collections; import java.util.List; +import org.apache.flink.api.common.JobID; +import org.apache.flink.util.TestLogger; + import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -45,7 +48,7 @@ /** * This class contains unit tests for the {@link BlobClient}. */ -public class BlobClientTest { +public class BlobClientTest extends TestLogger { /** The buffer size used during the tests in bytes. */ private static final int TEST_BUFFER_SIZE = 17 * 1000; @@ -138,30 +141,35 @@ private static BlobKey prepareTestFile(File file) throws IOException { * the specified buffer. * * @param inputStream - * the input stream returned from the GET operation + * the input stream returned from the GET operation (will be closed by this method) * @param buf * the buffer to compare the input stream's data to * @throws IOException * thrown if an I/O error occurs while reading the input stream */ - private static void validateGet(final InputStream inputStream, final byte[] buf) throws IOException { - byte[] receivedBuffer = new byte[buf.length]; + static void validateGetAndClose(final InputStream inputStream, final byte[] buf) throws IOException { + try { + byte[] receivedBuffer = new byte[buf.length]; - int bytesReceived = 0; + int bytesReceived = 0; - while (true) { + while (true) { - final int read = inputStream.read(receivedBuffer, bytesReceived, receivedBuffer.length - bytesReceived); - if (read < 0) { - throw new EOFException(); - } - bytesReceived += read; + final int read = inputStream + .read(receivedBuffer, bytesReceived, receivedBuffer.length - bytesReceived); + if (read < 0) { + throw new EOFException(); + } + bytesReceived += read; - if (bytesReceived == receivedBuffer.length) { - assertEquals(-1, inputStream.read()); - assertArrayEquals(buf, receivedBuffer); - return; + if (bytesReceived == receivedBuffer.length) { + assertEquals(-1, inputStream.read()); + assertArrayEquals(buf, receivedBuffer); + return; + } } + } finally { + inputStream.close(); } } @@ -170,13 +178,13 @@ private static void validateGet(final InputStream inputStream, final byte[] buf) * the specified file. 
* * @param inputStream - * the input stream returned from the GET operation + * the input stream returned from the GET operation (will be closed by this method) * @param file * the file to compare the input stream's data to * @throws IOException * thrown if an I/O error occurs while reading the input stream or the file */ - private static void validateGet(final InputStream inputStream, final File file) throws IOException { + private static void validateGetAndClose(final InputStream inputStream, final File file) throws IOException { InputStream inputStream2 = null; try { @@ -199,6 +207,7 @@ private static void validateGet(final InputStream inputStream, final File file) if (inputStream2 != null) { inputStream2.close(); } + inputStream.close(); } } @@ -207,7 +216,7 @@ private static void validateGet(final InputStream inputStream, final File file) * Tests the PUT/GET operations for content-addressable buffers. */ @Test - public void testContentAddressableBuffer() { + public void testContentAddressableBuffer() throws IOException { BlobClient client = null; @@ -220,26 +229,34 @@ public void testContentAddressableBuffer() { InetSocketAddress serverAddress = new InetSocketAddress("localhost", getBlobServer().getPort()); client = new BlobClient(serverAddress, getBlobClientConfig()); + JobID jobId = new JobID(); + // Store the data - BlobKey receivedKey = client.put(testBuffer); + BlobKey receivedKey = client.put(null, testBuffer); + assertEquals(origKey, receivedKey); + // try again with a job-related BLOB: + receivedKey = client.put(jobId, testBuffer); assertEquals(origKey, receivedKey); // Retrieve the data - InputStream is = client.get(receivedKey); - validateGet(is, testBuffer); + validateGetAndClose(client.get(receivedKey), testBuffer); + validateGetAndClose(client.get(jobId, receivedKey), testBuffer); // Check reaction to invalid keys - try { - client.get(new BlobKey()); + try (InputStream ignored = client.get(new BlobKey())) { + fail("Expected IOException did not occur"); + } + catch (IOException fnfe) { + // expected + } + // new client needed (closed from failure above) + client = new BlobClient(serverAddress, getBlobClientConfig()); + try (InputStream ignored = client.get(jobId, new BlobKey())) { fail("Expected IOException did not occur"); } catch (IOException fnfe) { // expected } - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); } finally { if (client != null) { @@ -262,7 +279,7 @@ protected BlobServer getBlobServer() { * Tests the PUT/GET operations for content-addressable streams. 
*/ @Test - public void testContentAddressableStream() { + public void testContentAddressableStream() throws IOException { BlobClient client = null; InputStream is = null; @@ -276,21 +293,23 @@ public void testContentAddressableStream() { InetSocketAddress serverAddress = new InetSocketAddress("localhost", getBlobServer().getPort()); client = new BlobClient(serverAddress, getBlobClientConfig()); + JobID jobId = new JobID(); + // Store the data is = new FileInputStream(testFile); BlobKey receivedKey = client.put(is); assertEquals(origKey, receivedKey); + // try again with a job-related BLOB: + is = new FileInputStream(testFile); + receivedKey = client.put(jobId, is); + assertEquals(origKey, receivedKey); is.close(); is = null; // Retrieve the data - is = client.get(receivedKey); - validateGet(is, testFile); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + validateGetAndClose(client.get(receivedKey), testFile); + validateGetAndClose(client.get(jobId, receivedKey), testFile); } finally { if (is != null) { @@ -307,7 +326,7 @@ public void testContentAddressableStream() { } /** - * Tests the static {@link BlobClient#uploadJarFiles(InetSocketAddress, Configuration, List)} helper. + * Tests the static {@link BlobClient#uploadJarFiles(InetSocketAddress, Configuration, JobID, List)} helper. */ @Test public void testUploadJarFilesHelper() throws Exception { @@ -315,7 +334,7 @@ public void testUploadJarFilesHelper() throws Exception { } /** - * Tests the static {@link BlobClient#uploadJarFiles(InetSocketAddress, Configuration, List)} helper. + * Tests the static {@link BlobClient#uploadJarFiles(InetSocketAddress, Configuration, JobID, List)}} helper. */ static void uploadJarFile(BlobServer blobServer, Configuration blobClientConfig) throws Exception { final File testFile = File.createTempFile("testfile", ".dat"); @@ -324,14 +343,21 @@ static void uploadJarFile(BlobServer blobServer, Configuration blobClientConfig) InetSocketAddress serverAddress = new InetSocketAddress("localhost", blobServer.getPort()); + uploadJarFile(serverAddress, blobClientConfig, testFile); + uploadJarFile(serverAddress, blobClientConfig, testFile); + } + + private static void uploadJarFile( + final InetSocketAddress serverAddress, final Configuration blobClientConfig, + final File testFile) throws IOException { + JobID jobId = new JobID(); List blobKeys = BlobClient.uploadJarFiles(serverAddress, blobClientConfig, - Collections.singletonList(new Path(testFile.toURI()))); + jobId, Collections.singletonList(new Path(testFile.toURI()))); assertEquals(1, blobKeys.size()); try (BlobClient blobClient = new BlobClient(serverAddress, blobClientConfig)) { - InputStream is = blobClient.get(blobKeys.get(0)); - validateGet(is, testFile); + validateGetAndClose(blobClient.get(jobId, blobKeys.get(0)), testFile); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobKeyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobKeyTest.java index 4071a1caf5544..43bc6228248f2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobKeyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobKeyTest.java @@ -29,12 +29,14 @@ import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.util.StringUtils; +import org.apache.flink.util.TestLogger; + import org.junit.Test; /** * This class contains unit tests for the {@link BlobKey} class. 
*/ -public final class BlobKeyTest { +public final class BlobKeyTest extends TestLogger { /** * The first key array to be used during the unit tests. */ @@ -106,4 +108,4 @@ public void testStreams() throws Exception { assertEquals(k1, k2); } -} \ No newline at end of file +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobRecoveryITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobRecoveryITCase.java index 3c7711d86bc68..81304f45ad4b9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobRecoveryITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobRecoveryITCase.java @@ -94,11 +94,17 @@ public static void testBlobServerRecovery(final Configuration config, final Blob BlobKey[] keys = new BlobKey[2]; - // Put data - keys[0] = client.put(expected); // Request 1 - keys[1] = client.put(expected, 32, 256); // Request 2 + // Put job-unrelated data + keys[0] = client.put(null, expected); // Request 1 + keys[1] = client.put(null, expected, 32, 256); // Request 2 + // Put job-related data, verify that the checksums match JobID[] jobId = new JobID[] { new JobID(), new JobID() }; + BlobKey key; + key = client.put(jobId[0], expected); // Request 3 + assertEquals(keys[0], key); + key = client.put(jobId[1], expected, 32, 256); // Request 4 + assertEquals(keys[1], key); // check that the storage directory exists final Path blobServerPath = new Path(storagePath, "blob"); @@ -130,9 +136,31 @@ public static void testBlobServerRecovery(final Configuration config, final Blob } } + // Verify request 3 + try (InputStream is = client.get(jobId[0], keys[0])) { + byte[] actual = new byte[expected.length]; + BlobUtils.readFully(is, actual, 0, expected.length, null); + + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], actual[i]); + } + } + + // Verify request 4 + try (InputStream is = client.get(jobId[1], keys[1])) { + byte[] actual = new byte[256]; + BlobUtils.readFully(is, actual, 0, 256, null); + + for (int i = 32, j = 0; i < 256; i++, j++) { + assertEquals(expected[i], actual[j]); + } + } + // Remove again client.delete(keys[0]); client.delete(keys[1]); + client.delete(jobId[0], keys[0]); + client.delete(jobId[1], keys[1]); // Verify everything is clean assertTrue("HA storage directory does not exist", fs.exists(new Path(storagePath))); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerDeleteTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerDeleteTest.java index 5db956830e37a..6bb5ab57416ef 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerDeleteTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerDeleteTest.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.concurrent.FlinkFutureException; @@ -30,6 +31,7 @@ import java.io.File; import java.io.IOException; +import java.io.InputStream; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.List; @@ -39,6 +41,10 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import static org.apache.flink.runtime.blob.BlobCacheCleanupTest.checkFileCountForJob; +import static org.apache.flink.runtime.blob.BlobCacheCleanupTest.checkFilesExist; +import static 
org.apache.flink.runtime.blob.BlobClientTest.validateGetAndClose; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; @@ -58,7 +64,7 @@ public class BlobServerDeleteTest extends TestLogger { public TemporaryFolder temporaryFolder = new TemporaryFolder(); @Test - public void testDeleteSingleByBlobKey() { + public void testDeleteSingleByBlobKey() throws IOException { BlobServer server = null; BlobClient client = null; BlobStore blobStore = new VoidBlobStore(); @@ -75,58 +81,86 @@ public void testDeleteSingleByBlobKey() { byte[] data = new byte[2000000]; rnd.nextBytes(data); - // put content addressable (like libraries) - BlobKey key = client.put(data); - assertNotNull(key); + // put job-unrelated (like libraries) + BlobKey key1 = client.put(null, data); + assertNotNull(key1); - // second item + // second job-unrelated item data[0] ^= 1; - BlobKey key2 = client.put(data); + BlobKey key2 = client.put(null, data); assertNotNull(key2); - assertNotEquals(key, key2); + assertNotEquals(key1, key2); + + // put job-related with same key1 as non-job-related + data[0] ^= 1; // back to the original data + final JobID jobId = new JobID(); + BlobKey key1b = client.put(jobId, data); + assertNotNull(key1b); + assertEquals(key1, key1b); // issue a DELETE request via the client - client.delete(key); + client.delete(key1); client.close(); client = new BlobClient(serverAddress, config); - try { - client.get(key); + try (InputStream ignored = client.get(key1)) { fail("BLOB should have been deleted"); } catch (IOException e) { // expected } + ensureClientIsClosed(client); + + client = new BlobClient(serverAddress, config); try { - client.put(new byte[1]); - fail("client should be closed after erroneous operation"); + // NOTE: the server will stall in its send operation until either the data is fully + // read or the socket is closed, e.g. 
via a client.close() call + validateGetAndClose(client.get(jobId, key1), data); } - catch (IllegalStateException e) { - // expected + catch (IOException e) { + fail("Deleting a job-unrelated BLOB should not affect a job-related BLOB with the same key"); } + client.close(); // delete a file directly on the server server.delete(key2); try { - server.getURL(key2); + server.getFile(key2); fail("BLOB should have been deleted"); } catch (IOException e) { // expected } } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } finally { cleanup(server, client); } } + private static void ensureClientIsClosed(final BlobClient client) throws IOException { + try { + client.put(null, new byte[1]); + fail("client should be closed after erroneous operation"); + } + catch (IllegalStateException e) { + // expected + } finally { + client.close(); + } + } + + @Test + public void testDeleteAlreadyDeletedNoJob() throws IOException { + testDeleteAlreadyDeleted(null); + } + @Test - public void testDeleteAlreadyDeletedByBlobKey() { + public void testDeleteAlreadyDeletedForJob() throws IOException { + testDeleteAlreadyDeleted(new JobID()); + } + + private void testDeleteAlreadyDeleted(final JobID jobId) throws IOException { BlobServer server = null; BlobClient client = null; BlobStore blobStore = new VoidBlobStore(); @@ -143,35 +177,52 @@ public void testDeleteAlreadyDeletedByBlobKey() { byte[] data = new byte[2000000]; rnd.nextBytes(data); - // put content addressable (like libraries) - BlobKey key = client.put(data); + // put file + BlobKey key = client.put(jobId, data); assertNotNull(key); - File blobFile = server.getStorageLocation(key); + File blobFile = server.getStorageLocation(jobId, key); assertTrue(blobFile.delete()); // issue a DELETE request via the client try { - client.delete(key); + deleteHelper(client, jobId, key); } catch (IOException e) { fail("DELETE operation should not fail if file is already deleted"); } // issue a DELETE request on the server - server.delete(key); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + if (jobId == null) { + server.delete(key); + } else { + server.delete(jobId, key); + } } finally { cleanup(server, client); } } + private static void deleteHelper(BlobClient client, JobID jobId, BlobKey key) throws IOException { + if (jobId == null) { + client.delete(key); + } else { + client.delete(jobId, key); + } + } + @Test - public void testDeleteByBlobKeyFails() { + public void testDeleteFailsNoJob() throws IOException { + testDeleteFails(null); + } + + @Test + public void testDeleteFailsForJob() throws IOException { + testDeleteFails(new JobID()); + } + + private void testDeleteFails(final JobID jobId) throws IOException { assumeTrue(!OperatingSystem.isWindows()); //setWritable doesn't work on Windows. 
BlobServer server = null; @@ -193,35 +244,115 @@ public void testDeleteByBlobKeyFails() { rnd.nextBytes(data); // put content addressable (like libraries) - BlobKey key = client.put(data); + BlobKey key = client.put(jobId, data); assertNotNull(key); - blobFile = server.getStorageLocation(key); + blobFile = server.getStorageLocation(jobId, key); directory = blobFile.getParentFile(); assertTrue(blobFile.setWritable(false, false)); assertTrue(directory.setWritable(false, false)); // issue a DELETE request via the client - client.delete(key); + deleteHelper(client, jobId, key); // issue a DELETE request on the server - server.delete(key); + if (jobId == null) { + server.delete(key); + } else { + server.delete(jobId, key); + } // the file should still be there - server.getURL(key); - } catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + if (jobId == null) { + server.getFile(key); + } else { + server.getFile(jobId, key); + } } finally { if (blobFile != null && directory != null) { + //noinspection ResultOfMethodCallIgnored blobFile.setWritable(true, false); + //noinspection ResultOfMethodCallIgnored directory.setWritable(true, false); } cleanup(server, client); } } + /** + * Tests that {@link BlobServer} cleans up after calling {@link BlobServer#cleanupJob(JobID)}. + */ + @Test + public void testJobCleanup() throws IOException, InterruptedException { + + JobID jobId1 = new JobID(); + List keys1 = new ArrayList(); + JobID jobId2 = new JobID(); + List keys2 = new ArrayList(); + BlobServer server = null; + + final byte[] buf = new byte[128]; + + try { + Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + temporaryFolder.newFolder().getAbsolutePath()); + + server = new BlobServer(config, new VoidBlobStore()); + InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.getPort()); + BlobClient bc = new BlobClient(serverAddress, config); + + keys1.add(bc.put(jobId1, buf)); + keys2.add(bc.put(jobId2, buf)); + assertEquals(keys2.get(0), keys1.get(0)); + + buf[0] += 1; + keys1.add(bc.put(jobId1, buf)); + + bc.close(); + + assertEquals(2, checkFilesExist(jobId1, keys1, server, true)); + checkFileCountForJob(2, jobId1, server); + assertEquals(1, checkFilesExist(jobId2, keys2, server, true)); + checkFileCountForJob(1, jobId2, server); + + server.cleanupJob(jobId1); + + checkFileCountForJob(0, jobId1, server); + assertEquals(1, checkFilesExist(jobId2, keys2, server, true)); + checkFileCountForJob(1, jobId2, server); + + server.cleanupJob(jobId2); + + checkFileCountForJob(0, jobId1, server); + checkFileCountForJob(0, jobId2, server); + + // calling a second time should not fail + server.cleanupJob(jobId2); + } + finally { + if (server != null) { + server.close(); + } + } + } + + /** + * FLINK-6020 + * + * Tests that concurrent delete operations don't interfere with each other. + * + * Note: The test checks that there cannot be two threads which have checked whether a given blob file exist + * and then one of them fails deleting it. Without the introduced lock, this situation should rarely happen + * and make this test fail. Thus, if this test should become "unstable", then the delete atomicity is most likely + * broken. + */ + @Test + public void testConcurrentDeleteOperationsNoJob() throws IOException, ExecutionException, InterruptedException { + testConcurrentDeleteOperations(null); + } + /** * FLINK-6020 * @@ -233,10 +364,14 @@ public void testDeleteByBlobKeyFails() { * broken. 
*/ @Test - public void testConcurrentDeleteOperations() throws IOException, ExecutionException, InterruptedException { + public void testConcurrentDeleteOperationsForJob() throws IOException, ExecutionException, InterruptedException { + testConcurrentDeleteOperations(new JobID()); + } + + private void testConcurrentDeleteOperations(final JobID jobId) + throws IOException, InterruptedException, ExecutionException { final Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); - final BlobStore blobStore = mock(BlobStore.class); final int concurrentDeleteOperations = 3; @@ -251,16 +386,16 @@ public void testConcurrentDeleteOperations() throws IOException, ExecutionExcept final BlobKey blobKey; try (BlobClient client = blobServer.createClient()) { - blobKey = client.put(data); + blobKey = client.put(jobId, data); } - assertTrue(blobServer.getStorageLocation(blobKey).exists()); + assertTrue(blobServer.getStorageLocation(jobId, blobKey).exists()); for (int i = 0; i < concurrentDeleteOperations; i++) { CompletableFuture deleteFuture = CompletableFuture.supplyAsync( () -> { try (BlobClient blobClient = blobServer.createClient()) { - blobClient.delete(blobKey); + deleteHelper(blobClient, jobId, blobKey); } catch (IOException e) { throw new FlinkFutureException("Could not delete the given blob key " + blobKey + '.', e); } @@ -278,13 +413,13 @@ public void testConcurrentDeleteOperations() throws IOException, ExecutionExcept // in case of no lock, one of the delete operations should eventually fail waitFuture.get(); - assertFalse(blobServer.getStorageLocation(blobKey).exists()); + assertFalse(blobServer.getStorageLocation(jobId, blobKey).exists()); } finally { executor.shutdownNow(); } } - private void cleanup(BlobServer server, BlobClient client) { + private static void cleanup(BlobServer server, BlobClient client) { if (client != null) { try { client.close(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerGetTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerGetTest.java index bd27d702a9d18..7ccf075937085 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerGetTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerGetTest.java @@ -20,6 +20,7 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.concurrent.FlinkFutureException; @@ -48,6 +49,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import static org.apache.flink.runtime.blob.BlobClientTest.validateGetAndClose; import static org.junit.Assert.*; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; @@ -69,7 +71,28 @@ public class BlobServerGetTest extends TestLogger { public TemporaryFolder temporaryFolder = new TemporaryFolder(); @Test - public void testGetFailsDuringLookup() throws IOException { + public void testGetFailsDuringLookup1() throws IOException { + testGetFailsDuringLookup(null, new JobID()); + } + + @Test + public void testGetFailsDuringLookup2() throws IOException { + testGetFailsDuringLookup(new JobID(), new JobID()); + } + + @Test + public void testGetFailsDuringLookup3() throws IOException { + testGetFailsDuringLookup(new JobID(), null); + } + + 
/** + * Checks the correct result if a GET operation fails during the lookup of the file. + * + * @param jobId1 first job ID or null if job-unrelated + * @param jobId2 second job ID different to jobId1 + */ + private void testGetFailsDuringLookup(final JobID jobId1, final JobID jobId2) + throws IOException { BlobServer server = null; BlobClient client = null; @@ -86,20 +109,29 @@ public void testGetFailsDuringLookup() throws IOException { rnd.nextBytes(data); // put content addressable (like libraries) - BlobKey key = client.put(data); + BlobKey key = client.put(jobId1, data); assertNotNull(key); - // delete all files to make sure that GET requests fail - File blobFile = server.getStorageLocation(key); + // delete file to make sure that GET requests fail + File blobFile = server.getStorageLocation(jobId1, key); assertTrue(blobFile.delete()); // issue a GET request that fails - try { - client.get(key); - fail("This should not succeed."); - } catch (IOException e) { - // expected - } + client = verifyDeleted(client, jobId1, key, serverAddress, config); + + BlobKey key2 = client.put(jobId2, data); + assertNotNull(key); + assertEquals(key, key2); + // request for jobId2 should succeed + validateGetAndClose(getFileHelper(client, jobId2, key), data); + // request for jobId1 should still fail + client = verifyDeleted(client, jobId1, key, serverAddress, config); + + // same checks as for jobId1 but for jobId2 should also work: + blobFile = server.getStorageLocation(jobId2, key); + assertTrue(blobFile.delete()); + client = verifyDeleted(client, jobId2, key, serverAddress, config); + } finally { if (client != null) { client.close(); @@ -110,8 +142,50 @@ public void testGetFailsDuringLookup() throws IOException { } } + /** + * Checks that the given blob does not exist anymore. + * + * @param client + * BLOB client to use for connecting to the BLOB server + * @param jobId + * job ID or null if job-unrelated + * @param key + * key identifying the BLOB to request + * @param serverAddress + * BLOB server address + * @param config + * client config + * + * @return a new client (since the old one is being closed on failure) + */ + private static BlobClient verifyDeleted( + BlobClient client, JobID jobId, BlobKey key, + InetSocketAddress serverAddress, Configuration config) throws IOException { + try (InputStream ignored = getFileHelper(client, jobId, key)) { + fail("This should not succeed."); + } catch (IOException e) { + // expected + } + // need a new client (the old one was closed due to the failure) + return new BlobClient(serverAddress, config); + } + + @Test + public void testGetFailsDuringStreamingNoJob() throws IOException { + testGetFailsDuringStreaming(null); + } + @Test - public void testGetFailsDuringStreaming() throws IOException { + public void testGetFailsDuringStreamingForJob() throws IOException { + testGetFailsDuringStreaming(new JobID()); + } + + /** + * Checks the correct result if a GET operation fails during the file download.
+ * + * @param jobId job ID or null if job-unrelated + */ + private void testGetFailsDuringStreaming(final JobID jobId) throws IOException { BlobServer server = null; BlobClient client = null; @@ -128,11 +202,11 @@ public void testGetFailsDuringStreaming() throws IOException { rnd.nextBytes(data); // put content addressable (like libraries) - BlobKey key = client.put(data); + BlobKey key = client.put(jobId, data); assertNotNull(key); // issue a GET request that succeeds - InputStream is = client.get(key); + InputStream is = getFileHelper(client, jobId, key); byte[] receiveBuffer = new byte[data.length]; int firstChunkLen = 50000; @@ -153,6 +227,7 @@ public void testGetFailsDuringStreaming() throws IOException { catch (IOException e) { // expected } + is.close(); } finally { if (client != null) { client.close(); @@ -169,8 +244,22 @@ public void testGetFailsDuringStreaming() throws IOException { * Tests that concurrent get operations don't concurrently access the BlobStore to download a blob. */ @Test - public void testConcurrentGetOperations() throws IOException, ExecutionException, InterruptedException { + public void testConcurrentGetOperationsNoJob() throws IOException, ExecutionException, InterruptedException { + testConcurrentGetOperations(null); + } + + /** + * FLINK-6020 + * + * Tests that concurrent get operations don't concurrently access the BlobStore to download a blob. + */ + @Test + public void testConcurrentGetOperationsForJob() throws IOException, ExecutionException, InterruptedException { + testConcurrentGetOperations(new JobID()); + } + private void testConcurrentGetOperations(final JobID jobId) + throws IOException, InterruptedException, ExecutionException { final Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); @@ -191,14 +280,14 @@ public void testConcurrentGetOperations() throws IOException, ExecutionException new Answer() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { - File targetFile = (File) invocation.getArguments()[1]; + File targetFile = (File) invocation.getArguments()[2]; FileUtils.copyInputStreamToFile(bais, targetFile); return null; } } - ).when(blobStore).get(any(BlobKey.class), any(File.class)); + ).when(blobStore).get(any(JobID.class), any(BlobKey.class), any(File.class)); final ExecutorService executor = Executors.newFixedThreadPool(numberConcurrentGetOperations); @@ -207,7 +296,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { CompletableFuture getOperation = CompletableFuture.supplyAsync( () -> { try (BlobClient blobClient = blobServer.createClient(); - InputStream inputStream = blobClient.get(blobKey)) { + InputStream inputStream = getFileHelper(blobClient, jobId, blobKey)) { byte[] buffer = new byte[data.length]; IOUtils.readFully(inputStream, buffer); @@ -241,9 +330,18 @@ public Object answer(InvocationOnMock invocation) throws Throwable { } // verify that we downloaded the requested blob exactly once from the BlobStore - verify(blobStore, times(1)).get(eq(blobKey), any(File.class)); + verify(blobStore, times(1)).get(eq(jobId), eq(blobKey), any(File.class)); } finally { executor.shutdownNow(); } } + + static InputStream getFileHelper(BlobClient blobClient, JobID jobId, BlobKey blobKey) + throws IOException { + if (jobId == null) { + return blobClient.get(blobKey); + } else { + return blobClient.get(jobId, blobKey); + } + } } diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerPutTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerPutTest.java index c4791672a73e7..2b8e2d27c0443 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerPutTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobServerPutTest.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.blob; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.CheckedThread; @@ -45,7 +46,8 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import static org.junit.Assert.assertArrayEquals; +import static org.apache.flink.runtime.blob.BlobClientTest.validateGetAndClose; +import static org.apache.flink.runtime.blob.BlobServerGetTest.getFileHelper; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -71,28 +73,43 @@ public class BlobServerPutTest extends TestLogger { // --- concurrency tests for utility methods which could fail during the put operation --- /** - * Checked thread that calls {@link BlobServer#getStorageLocation(BlobKey)} + * Checked thread that calls {@link BlobServer#getStorageLocation(JobID, BlobKey)}. */ public static class ContentAddressableGetStorageLocation extends CheckedThread { private final BlobServer server; + private final JobID jobId; private final BlobKey key; - public ContentAddressableGetStorageLocation(BlobServer server, BlobKey key) { + public ContentAddressableGetStorageLocation(BlobServer server, JobID jobId, BlobKey key) { this.server = server; + this.jobId = jobId; this.key = key; } @Override public void go() throws Exception { - server.getStorageLocation(key); + server.getStorageLocation(jobId, key); } } /** - * Tests concurrent calls to {@link BlobServer#getStorageLocation(BlobKey)}. + * Tests concurrent calls to {@link BlobServer#getStorageLocation(JobID, BlobKey)}. */ @Test - public void testServerContentAddressableGetStorageLocationConcurrent() throws Exception { + public void testServerContentAddressableGetStorageLocationConcurrentNoJob() throws Exception { + testServerContentAddressableGetStorageLocationConcurrent(null); + } + + /** + * Tests concurrent calls to {@link BlobServer#getStorageLocation(JobID, BlobKey)}. 
+ */ + @Test + public void testServerContentAddressableGetStorageLocationConcurrentForJob() throws Exception { + testServerContentAddressableGetStorageLocationConcurrent(new JobID()); + } + + private void testServerContentAddressableGetStorageLocationConcurrent(final JobID jobId) + throws Exception { final Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); @@ -101,9 +118,9 @@ public void testServerContentAddressableGetStorageLocationConcurrent() throws Ex try { BlobKey key = new BlobKey(); CheckedThread[] threads = new CheckedThread[] { - new ContentAddressableGetStorageLocation(server, key), - new ContentAddressableGetStorageLocation(server, key), - new ContentAddressableGetStorageLocation(server, key) + new ContentAddressableGetStorageLocation(server, jobId, key), + new ContentAddressableGetStorageLocation(server, jobId, key), + new ContentAddressableGetStorageLocation(server, jobId, key) }; checkedThreadSimpleTest(threads); } finally { @@ -134,7 +151,27 @@ protected void checkedThreadSimpleTest(CheckedThread[] threads) // -------------------------------------------------------------------------------------------- @Test - public void testPutBufferSuccessful() throws IOException { + public void testPutBufferSuccessfulGet1() throws IOException { + testPutBufferSuccessfulGet(null, null); + } + + @Test + public void testPutBufferSuccessfulGet2() throws IOException { + testPutBufferSuccessfulGet(null, new JobID()); + } + + @Test + public void testPutBufferSuccessfulGet3() throws IOException { + testPutBufferSuccessfulGet(new JobID(), new JobID()); + } + + @Test + public void testPutBufferSuccessfulGet4() throws IOException { + testPutBufferSuccessfulGet(new JobID(), null); + } + + private void testPutBufferSuccessfulGet(final JobID jobId1, final JobID jobId2) + throws IOException { BlobServer server = null; BlobClient client = null; @@ -150,17 +187,63 @@ public void testPutBufferSuccessful() throws IOException { byte[] data = new byte[2000000]; rnd.nextBytes(data); - // put content addressable (like libraries) - BlobKey key1 = client.put(data); - assertNotNull(key1); + // put data for jobId1 and verify + BlobKey key1a = client.put(jobId1, data); + assertNotNull(key1a); + + BlobKey key1b = client.put(jobId1, data, 10, 44); + assertNotNull(key1b); + + testPutBufferSuccessfulGet(jobId1, key1a, key1b, data, serverAddress, config); + + // now put data for jobId2 and verify that both are ok + BlobKey key2a = client.put(jobId2, data); + assertNotNull(key2a); + assertEquals(key1a, key2a); + + BlobKey key2b = client.put(jobId2, data, 10, 44); + assertNotNull(key2b); + assertEquals(key1b, key2b); - BlobKey key2 = client.put(data, 10, 44); - assertNotNull(key2); - // --- GET the data and check that it is equal --- + testPutBufferSuccessfulGet(jobId1, key1a, key1b, data, serverAddress, config); + testPutBufferSuccessfulGet(jobId2, key2a, key2b, data, serverAddress, config); - // one get request on the same client - InputStream is1 = client.get(key2); + + } finally { + if (client != null) { + client.close(); + } + if (server != null) { + server.close(); + } + } + } + + /** + * GET the data stored at the two keys and check that it is equal to data. 
+ * + * @param jobId + * job ID or null if job-unrelated + * @param key1 + * first key for 44 bytes starting at byte 10 of data in the BLOB + * @param key2 + * second key for the complete data in the BLOB + * @param data + * expected data + * @param serverAddress + * BlobServer address to connect to + * @param config + * client configuration + */ + private static void testPutBufferSuccessfulGet( + JobID jobId, BlobKey key1, BlobKey key2, byte[] data, + InetSocketAddress serverAddress, Configuration config) throws IOException { + + BlobClient client = new BlobClient(serverAddress, config); + + // one get request on the same client + try (InputStream is1 = getFileHelper(client, jobId, key2)) { byte[] result1 = new byte[44]; BlobUtils.readFully(is1, result1, 0, result1.length, null); is1.close(); @@ -169,28 +252,27 @@ public void testPutBufferSuccessful() throws IOException { assertEquals(data[j], result1[i]); } - // close the client and create a new one for the remaining requests + // close the client and create a new one for the remaining request client.close(); client = new BlobClient(serverAddress, config); - InputStream is2 = client.get(key1); - byte[] result2 = new byte[data.length]; - BlobUtils.readFully(is2, result2, 0, result2.length, null); - is2.close(); - assertArrayEquals(data, result2); + validateGetAndClose(getFileHelper(client, jobId, key1), data); } finally { - if (client != null) { - client.close(); - } - if (server != null) { - server.close(); - } + client.close(); } } + @Test + public void testPutStreamSuccessfulNoJob() throws IOException { + testPutStreamSuccessful(null); + } @Test - public void testPutStreamSuccessful() throws IOException { + public void testPutStreamSuccessfulForJob() throws IOException { + testPutStreamSuccessful(new JobID()); + } + + private void testPutStreamSuccessful(final JobID jobId) throws IOException { BlobServer server = null; BlobClient client = null; @@ -208,7 +290,12 @@ public void testPutStreamSuccessful() throws IOException { // put content addressable (like libraries) { - BlobKey key1 = client.put(new ByteArrayInputStream(data)); + BlobKey key1; + if (jobId == null) { + key1 = client.put(new ByteArrayInputStream(data)); + } else { + key1 = client.put(jobId, new ByteArrayInputStream(data)); + } assertNotNull(key1); } } finally { @@ -226,7 +313,16 @@ public void testPutStreamSuccessful() throws IOException { } @Test - public void testPutChunkedStreamSuccessful() throws IOException { + public void testPutChunkedStreamSuccessfulNoJob() throws IOException { + testPutChunkedStreamSuccessful(null); + } + + @Test + public void testPutChunkedStreamSuccessfulForJob() throws IOException { + testPutChunkedStreamSuccessful(new JobID()); + } + + private void testPutChunkedStreamSuccessful(final JobID jobId) throws IOException { BlobServer server = null; BlobClient client = null; @@ -244,7 +340,12 @@ public void testPutChunkedStreamSuccessful() throws IOException { // put content addressable (like libraries) { - BlobKey key1 = client.put(new ChunkedInputStream(data, 19)); + BlobKey key1; + if (jobId == null) { + key1 = client.put(new ChunkedInputStream(data, 19)); + } else { + key1 = client.put(jobId, new ChunkedInputStream(data, 19)); + } assertNotNull(key1); } } finally { @@ -258,7 +359,16 @@ public void testPutChunkedStreamSuccessful() throws IOException { } @Test - public void testPutBufferFails() throws IOException { + public void testPutBufferFailsNoJob() throws IOException { + testPutBufferFails(null); + } + + @Test + public void 
testPutBufferFailsForJob() throws IOException { + testPutBufferFails(new JobID()); + } + + private void testPutBufferFails(final JobID jobId) throws IOException { assumeTrue(!OperatingSystem.isWindows()); //setWritable doesn't work on Windows. BlobServer server = null; @@ -285,7 +395,7 @@ public void testPutBufferFails() throws IOException { // put content addressable (like libraries) try { - client.put(data); + client.put(jobId, data); fail("This should fail."); } catch (IOException e) { @@ -293,7 +403,7 @@ public void testPutBufferFails() throws IOException { } try { - client.put(data); + client.put(jobId, data); fail("Client should be closed"); } catch (IllegalStateException e) { @@ -320,7 +430,22 @@ public void testPutBufferFails() throws IOException { * Tests that concurrent put operations will only upload the file once to the {@link BlobStore}. */ @Test - public void testConcurrentPutOperations() throws IOException, ExecutionException, InterruptedException { + public void testConcurrentPutOperationsNoJob() throws IOException, ExecutionException, InterruptedException { + testConcurrentPutOperations(null); + } + + /** + * FLINK-6020 + * + * Tests that concurrent put operations will only upload the file once to the {@link BlobStore}. + */ + @Test + public void testConcurrentPutOperationsForJob() throws IOException, ExecutionException, InterruptedException { + testConcurrentPutOperations(new JobID()); + } + + private void testConcurrentPutOperations(final JobID jobId) + throws IOException, InterruptedException, ExecutionException { final Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); @@ -331,7 +456,7 @@ public void testConcurrentPutOperations() throws IOException, ExecutionException final CountDownLatch countDownLatch = new CountDownLatch(concurrentPutOperations); final byte[] data = new byte[dataSize]; - ArrayList> allFutures = new ArrayList(concurrentPutOperations); + ArrayList> allFutures = new ArrayList<>(concurrentPutOperations); ExecutorService executor = Executors.newFixedThreadPool(concurrentPutOperations); @@ -342,7 +467,13 @@ public void testConcurrentPutOperations() throws IOException, ExecutionException CompletableFuture putFuture = CompletableFuture.supplyAsync( () -> { try (BlobClient blobClient = blobServer.createClient()) { - return blobClient.put(new BlockingInputStream(countDownLatch, data)); + if (jobId == null) { + return blobClient + .put(new BlockingInputStream(countDownLatch, data)); + } else { + return blobClient + .put(jobId, new BlockingInputStream(countDownLatch, data)); + } } catch (IOException e) { throw new FlinkFutureException("Could not upload blob.", e); } @@ -369,7 +500,7 @@ public void testConcurrentPutOperations() throws IOException, ExecutionException } // check that we only uploaded the file once to the blob store - verify(blobStore, times(1)).put(any(File.class), eq(blobKey)); + verify(blobStore, times(1)).put(any(File.class), eq(jobId), eq(blobKey)); } finally { executor.shutdownNow(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobUtilsTest.java index 2987c3976bea5..a6ac44790eeaa 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobUtilsTest.java @@ -18,11 +18,10 @@ package org.apache.flink.runtime.blob; -import static 
org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeTrue; -import static org.mockito.Mockito.mock; - +import org.apache.flink.api.common.JobID; import org.apache.flink.util.OperatingSystem; +import org.apache.flink.util.TestLogger; + import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -32,7 +31,11 @@ import java.io.File; import java.io.IOException; -public class BlobUtilsTest { +import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; +import static org.mockito.Mockito.mock; + +public class BlobUtilsTest extends TestLogger { private final static String CANNOT_CREATE_THIS = "cannot-create-this"; @@ -62,12 +65,18 @@ public void after() { public void testExceptionOnCreateStorageDirectoryFailure() throws IOException { // Should throw an Exception - BlobUtils.initStorageDirectory(new File(blobUtilsTestDirectory, CANNOT_CREATE_THIS).getAbsolutePath()); + BlobUtils.initLocalStorageDirectory(new File(blobUtilsTestDirectory, CANNOT_CREATE_THIS).getAbsolutePath()); + } + + @Test(expected = Exception.class) + public void testExceptionOnCreateCacheDirectoryFailureNoJob() { + // Should throw an Exception + BlobUtils.getStorageLocation(new File(blobUtilsTestDirectory, CANNOT_CREATE_THIS), null, mock(BlobKey.class)); } @Test(expected = Exception.class) - public void testExceptionOnCreateCacheDirectoryFailure() { + public void testExceptionOnCreateCacheDirectoryFailureForJob() { // Should throw an Exception - BlobUtils.getStorageLocation(new File(blobUtilsTestDirectory, CANNOT_CREATE_THIS), mock(BlobKey.class)); + BlobUtils.getStorageLocation(new File(blobUtilsTestDirectory, CANNOT_CREATE_THIS), new JobID(), mock(BlobKey.class)); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java index d293eea1c757e..edc29feb7c397 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java @@ -18,14 +18,6 @@ package org.apache.flink.runtime.checkpoint; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -import java.io.File; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.savepoint.SavepointLoader; import org.apache.flink.runtime.concurrent.Executors; @@ -37,11 +29,22 @@ import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.filesystem.FileStateHandle; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + /** * CheckpointCoordinator tests for externalized checkpoints. 
* @@ -91,7 +94,8 @@ public void testTriggerAndConfirmSimpleExternalizedCheckpoint() new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), checkpointDir.getAbsolutePath(), - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java index 344b34093d9b7..7c95a34720e5b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java @@ -23,14 +23,14 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.TestLogger; + import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -42,8 +42,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Matchers.anyInt; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -78,7 +78,8 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { new StandaloneCheckpointIDCounter(), new FailingCompletedCheckpointStore(), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); coord.triggerCheckpoint(triggerTimestamp, false); @@ -89,31 +90,25 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { assertFalse(pendingCheckpoint.isDiscarded()); final long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next(); - - SubtaskState subtaskState = mock(SubtaskState.class); - StreamStateHandle legacyHandle = mock(StreamStateHandle.class); - ChainedStateHandle chainedLegacyHandle = mock(ChainedStateHandle.class); - when(chainedLegacyHandle.get(anyInt())).thenReturn(legacyHandle); - when(subtaskState.getLegacyOperatorState()).thenReturn(chainedLegacyHandle); + KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class); + KeyedStateHandle rawKeyedHandle = mock(KeyedStateHandle.class); + OperatorStateHandle managedOpHandle = mock(OperatorStateHandle.class); + OperatorStateHandle rawOpHandle = mock(OperatorStateHandle.class); - OperatorStateHandle managedHandle = mock(OperatorStateHandle.class); - ChainedStateHandle chainedManagedHandle = mock(ChainedStateHandle.class); - when(chainedManagedHandle.get(anyInt())).thenReturn(managedHandle); - 
when(subtaskState.getManagedOperatorState()).thenReturn(chainedManagedHandle); + final OperatorSubtaskState operatorSubtaskState = spy(new OperatorSubtaskState( + managedOpHandle, + rawOpHandle, + managedKeyedHandle, + rawKeyedHandle)); - OperatorStateHandle rawHandle = mock(OperatorStateHandle.class); - ChainedStateHandle chainedRawHandle = mock(ChainedStateHandle.class); - when(chainedRawHandle.get(anyInt())).thenReturn(rawHandle); - when(subtaskState.getRawOperatorState()).thenReturn(chainedRawHandle); + TaskStateSnapshot subtaskState = spy(new TaskStateSnapshot()); + subtaskState.putSubtaskStateByOperatorID(new OperatorID(), operatorSubtaskState); + + when(subtaskState.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(vertex.getJobvertexId()))).thenReturn(operatorSubtaskState); - KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class); - when(subtaskState.getRawKeyedState()).thenReturn(managedKeyedHandle); - KeyedStateHandle managedRawHandle = mock(KeyedStateHandle.class); - when(subtaskState.getManagedKeyedState()).thenReturn(managedRawHandle); - AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState); - + try { coord.receiveAcknowledgeMessage(acknowledgeMessage); fail("Expected a checkpoint exception because the completed checkpoint store could not " + @@ -126,17 +121,17 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { assertTrue(pendingCheckpoint.isDiscarded()); // make sure that the subtask state has been discarded after we could not complete it. - verify(subtaskState.getLegacyOperatorState().get(0)).discardState(); - verify(subtaskState.getManagedOperatorState().get(0)).discardState(); - verify(subtaskState.getRawOperatorState().get(0)).discardState(); - verify(subtaskState.getManagedKeyedState()).discardState(); - verify(subtaskState.getRawKeyedState()).discardState(); + verify(operatorSubtaskState).discardState(); + verify(operatorSubtaskState.getManagedOperatorState().iterator().next()).discardState(); + verify(operatorSubtaskState.getRawOperatorState().iterator().next()).discardState(); + verify(operatorSubtaskState.getManagedKeyedState().iterator().next()).discardState(); + verify(operatorSubtaskState.getRawKeyedState().iterator().next()).discardState(); } private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore { @Override - public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { + public void recover() throws Exception { throw new UnsupportedOperationException("Not implemented."); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java index e23f6a2f7d8f7..2f860e0da2e6c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java @@ -28,9 +28,9 @@ import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.junit.Test; - import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -46,14 +46,12 @@ import 
java.util.concurrent.Executor; import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex; - import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; - import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isNull; import static org.mockito.Mockito.any; @@ -404,7 +402,8 @@ private static CheckpointCoordinator instantiateCheckpointCoordinator(JobID jid, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(10), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); } private static T mockGeneric(Class clazz) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index e78152aef9fff..4193c2c66e6a5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -18,8 +18,6 @@ package org.apache.flink.runtime.checkpoint; -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.api.java.tuple.Tuple2; @@ -38,29 +36,36 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.IncrementalKeyedStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.PlaceholderStreamStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; +import org.apache.flink.runtime.state.SharedStateRegistryFactory; +import org.apache.flink.runtime.state.StateHandleID; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.testutils.CommonTestUtils; import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; +import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.apache.flink.util.TestLogger; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import org.mockito.verification.VerificationMode; import java.io.IOException; import java.io.Serializable; @@ -85,13 +90,11 @@ import 
static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -100,7 +103,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; /** * Tests for the checkpoint coordinator. @@ -140,7 +142,8 @@ public void testCheckpointAbortsIfTriggerTasksAreNotExecuted() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); @@ -200,7 +203,8 @@ public void testCheckpointAbortsIfTriggerTasksAreFinished() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); @@ -251,7 +255,8 @@ public void testCheckpointAbortsIfAckTasksAreNotExecuted() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // nothing should be happening assertEquals(0, coord.getNumberOfPendingCheckpoints()); @@ -303,7 +308,8 @@ public void testTriggerAndDeclineCheckpointSimple() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -407,7 +413,8 @@ public void testTriggerAndDeclineCheckpointComplex() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -526,7 +533,8 @@ public void testTriggerAndConfirmSimpleCheckpoint() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -553,31 +561,29 @@ public void testTriggerAndConfirmSimpleCheckpoint() { assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); - OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); - OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); - - Map operatorStates = checkpoint.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); - operatorStates.put(opID2, new 
SpyInjectingOperatorState( - opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism())); - // check that the vertices received the trigger checkpoint message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); } + OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); + OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); + TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class); + TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class); + when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1); + when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2); + // acknowledge from one of the tasks - AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); + AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); - OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); - verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); + verify(taskOperatorSubtaskStates2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the same task again (should not matter) coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); @@ -586,8 +592,7 @@ public void testTriggerAndConfirmSimpleCheckpoint() { verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the other task. 
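// (this acknowledgement completes the ack set for the pending checkpoint: with both vertices
// acknowledged it is converted into a completed checkpoint, as asserted right below)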
- coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -626,9 +631,7 @@ public void testTriggerAndConfirmSimpleCheckpoint() { long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); - subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); - subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -698,7 +701,8 @@ public void testMultipleConcurrentCheckpoints() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -828,7 +832,8 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(10), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -850,18 +855,20 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId()); - Map operatorStates1 = pending1.getOperatorStates(); + TaskStateSnapshot taskOperatorSubtaskStates1_1 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates1_2 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates1_3 = spy(new TaskStateSnapshot()); - operatorStates1.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates1.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); - operatorStates1.put(opID3, new SpyInjectingOperatorState( - opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); + OperatorSubtaskState subtaskState1_1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState1_2 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState1_3 = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, subtaskState1_1); + taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, subtaskState1_2); + taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, subtaskState1_3); // acknowledge one of the three tasks - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new 
CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1_2 = operatorStates1.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_2)); + // start the second checkpoint // trigger the first checkpoint. this should succeed assertTrue(coord.triggerCheckpoint(timestamp2, false)); @@ -878,14 +885,17 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { } long checkpointId2 = pending2.getCheckpointId(); - Map operatorStates2 = pending2.getOperatorStates(); + TaskStateSnapshot taskOperatorSubtaskStates2_1 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates2_2 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates2_3 = spy(new TaskStateSnapshot()); + + OperatorSubtaskState subtaskState2_1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2_2 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2_3 = mock(OperatorSubtaskState.class); - operatorStates2.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates2.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); - operatorStates2.put(opID3, new SpyInjectingOperatorState( - opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); + taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, subtaskState2_1); + taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, subtaskState2_2); + taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, subtaskState2_3); // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); @@ -894,17 +904,13 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { // we acknowledge one more task from the first checkpoint and the second // checkpoint completely. 
The second checkpoint should then subsume the first checkpoint - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_3 = operatorStates2.get(opID3).getState(ackVertex3.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_3)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_1 = operatorStates2.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_1)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1_1 = operatorStates1.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_1)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_2 = operatorStates2.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_2)); // now, the second checkpoint should be confirmed, and the first discarded // actually both pending checkpoints are discarded, and the second has been transformed @@ -936,8 +942,7 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2)); // send the last remaining ack for the first checkpoint. 
This should not do anything - SubtaskState subtaskState1_3 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_3)); verify(subtaskState1_3, times(1)).discardState(); coord.shutdown(JobStatus.FINISHED); @@ -992,7 +997,8 @@ public void testCheckpointTimeoutIsolated() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // trigger a checkpoint, partially acknowledged assertTrue(coord.triggerCheckpoint(timestamp, false)); @@ -1003,13 +1009,11 @@ public void testCheckpointTimeoutIsolated() { OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); - Map operatorStates = checkpoint.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); + TaskStateSnapshot taskOperatorSubtaskStates1 = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStates1.putSubtaskStateByOperatorID(opID1, subtaskState1); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState = operatorStates.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), taskOperatorSubtaskStates1)); // wait until the checkpoint must have expired. 
// we check every 250 msecs conservatively for 5 seconds @@ -1027,7 +1031,7 @@ public void testCheckpointTimeoutIsolated() { assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // validate that the received states have been discarded - verify(subtaskState, times(1)).discardState(); + verify(subtaskState1, times(1)).discardState(); // no confirm message must have been sent verify(commitVertex.getCurrentExecutionAttempt(), times(0)).notifyCheckpointComplete(anyLong(), anyLong()); @@ -1071,7 +1075,8 @@ public void testHandleMessagesForNonExistingCheckpoints() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertTrue(coord.triggerCheckpoint(timestamp, false)); @@ -1134,7 +1139,8 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertTrue(coord.triggerCheckpoint(timestamp, false)); @@ -1145,26 +1151,18 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { long checkpointId = pendingCheckpoint.getCheckpointId(); OperatorID opIDtrigger = OperatorID.fromJobVertexID(triggerVertex.getJobvertexId()); - OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); - OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); - Map operatorStates = pendingCheckpoint.getOperatorStates(); - - operatorStates.put(opIDtrigger, new SpyInjectingOperatorState( - opIDtrigger, triggerVertex.getTotalNumberOfParallelSubtasks(), triggerVertex.getMaxParallelism())); - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); + TaskStateSnapshot taskOperatorSubtaskStatesTrigger = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskStateTrigger = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStatesTrigger.putSubtaskStateByOperatorID(opIDtrigger, subtaskStateTrigger); // acknowledge the first trigger vertex - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState storedTriggerSubtaskState = operatorStates.get(opIDtrigger).getState(triggerVertex.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStatesTrigger)); // verify that the subtask state has not been discarded - verify(storedTriggerSubtaskState, never()).discardState(); + verify(subtaskStateTrigger, never()).discardState(); - SubtaskState unknownSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot unknownSubtaskState = mock(TaskStateSnapshot.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState)); @@ -1172,7 +1170,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { // we should discard acknowledge messages from an unknown vertex belonging to our job 
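// (the state attached to such an acknowledgement is discarded right away; acknowledgements from
// a different job, exercised just below, must be ignored without discarding their state)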
verify(unknownSubtaskState, times(1)).discardState(); - SubtaskState differentJobSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot differentJobSubtaskState = mock(TaskStateSnapshot.class); // receive an acknowledge message from an unknown job coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); @@ -1181,22 +1179,22 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { verify(differentJobSubtaskState, never()).discardState(); // duplicate acknowledge message for the trigger vertex - SubtaskState triggerSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot triggerSubtaskState = mock(TaskStateSnapshot.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); // duplicate acknowledge messages for a known vertex should not trigger discarding the state verify(triggerSubtaskState, never()).discardState(); // let the checkpoint fail at the first ack vertex - reset(storedTriggerSubtaskState); + reset(subtaskStateTrigger); coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId)); assertTrue(pendingCheckpoint.isDiscarded()); // check that we've cleaned up the already acknowledged state - verify(storedTriggerSubtaskState, times(1)).discardState(); + verify(subtaskStateTrigger, times(1)).discardState(); - SubtaskState ackSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot ackSubtaskState = mock(TaskStateSnapshot.class); // late acknowledge message from the second ack vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState)); @@ -1211,7 +1209,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { // we should not interfere with different jobs verify(differentJobSubtaskState, never()).discardState(); - SubtaskState unknownSubtaskState2 = mock(SubtaskState.class); + TaskStateSnapshot unknownSubtaskState2 = mock(TaskStateSnapshot.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2)); @@ -1274,7 +1272,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); coord.startCheckpointScheduler(); @@ -1366,7 +1365,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), "dummy-path", - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); try { coord.startCheckpointScheduler(); @@ -1439,7 +1439,8 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -1468,18 +1469,16 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { OperatorID opID1 = 
OperatorID.fromJobVertexID(vertex1.getJobvertexId()); OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); - - Map operatorStates = pending.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); + TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class); + TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class); + when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1); + when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2); // acknowledge from one of the tasks - AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); + AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); - OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, pending.getNumberOfAcknowledgedTasks()); assertEquals(1, pending.getNumberOfNonAcknowledgedTasks()); assertFalse(pending.isDiscarded()); @@ -1493,8 +1492,7 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { assertFalse(savepointFuture.isDone()); // acknowledge the other task. 
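// (this is the last outstanding acknowledgement for the savepoint, so the pending checkpoint
// backing it can be finalized, as the comments below describe)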
- coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -1534,9 +1532,6 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); - subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); - subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); - assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -1596,7 +1591,8 @@ public void testSavepointsAreNotSubsumed() throws Exception { counter, new StandaloneCompletedCheckpointStore(10), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); String savepointDir = tmpFolder.newFolder().getAbsolutePath(); @@ -1702,7 +1698,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); coord.startCheckpointScheduler(); @@ -1775,7 +1772,8 @@ public void testMaxConcurrentAttempsWithSubsumption() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); coord.startCheckpointScheduler(); @@ -1857,7 +1855,8 @@ public ExecutionState answer(InvocationOnMock invocation){ new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); coord.startCheckpointScheduler(); @@ -1909,7 +1908,8 @@ public void testConcurrentSavepoints() throws Exception { checkpointIDCounter, new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); List> savepointFutures = new ArrayList<>(); @@ -1962,7 +1962,8 @@ public void testMinDelayBetweenSavepoints() throws Exception { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); String savepointDir = tmpFolder.newFolder().getAbsolutePath(); @@ -2024,7 +2025,8 @@ public void testRestoreLatestCheckpointedState() throws Exception { new StandaloneCheckpointIDCounter(), store, null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // trigger the checkpoint coord.triggerCheckpoint(timestamp, false); @@ -2035,20 +2037,8 @@ public void testRestoreLatestCheckpointedState() throws Exception { List keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, 
parallelism2); - PendingCheckpoint pending = coord.getPendingCheckpoints().get(checkpointId); - - OperatorID opID1 = OperatorID.fromJobVertexID(jobVertexID1); - OperatorID opID2 = OperatorID.fromJobVertexID(jobVertexID2); - - Map operatorStates = pending.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, jobVertex1.getParallelism(), jobVertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, jobVertex2.getParallelism(), jobVertex2.getMaxParallelism())); - for (int index = 0; index < jobVertex1.getParallelism(); index++) { - SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); + TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, @@ -2061,7 +2051,7 @@ public void testRestoreLatestCheckpointedState() throws Exception { } for (int index = 0; index < jobVertex2.getParallelism(); index++) { - SubtaskState subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index)); + TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index)); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, @@ -2150,43 +2140,45 @@ public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // trigger the checkpoint coord.triggerCheckpoint(timestamp, false); assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); List keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } for (int index = 0; index < jobVertex2.getParallelism(); index++) { - ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID2, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = 
new OperatorSubtaskState(null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2218,131 +2210,6 @@ public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws fail("The restoration should have failed because the max parallelism changed."); } - /** - * Tests that the checkpoint restoration fails if the parallelism of a job vertices with - * non-partitioned state has changed. - * - * @throws Exception - */ - @Test(expected=IllegalStateException.class) - public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Exception { - final JobID jid = new JobID(); - final long timestamp = System.currentTimeMillis(); - - final JobVertexID jobVertexID1 = new JobVertexID(); - final JobVertexID jobVertexID2 = new JobVertexID(); - int parallelism1 = 3; - int parallelism2 = 2; - int maxParallelism1 = 42; - int maxParallelism2 = 13; - - final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( - jobVertexID1, - parallelism1, - maxParallelism1); - final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( - jobVertexID2, - parallelism2, - maxParallelism2); - - List allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); - - allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); - allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); - - ExecutionVertex[] arrayExecutionVertices = - allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]); - - // set up the coordinator and validate the initial state - CheckpointCoordinator coord = new CheckpointCoordinator( - jid, - 600000, - 600000, - 0, - Integer.MAX_VALUE, - ExternalizedCheckpointSettings.none(), - arrayExecutionVertices, - arrayExecutionVertices, - arrayExecutionVertices, - new StandaloneCheckpointIDCounter(), - new StandaloneCompletedCheckpointStore(1), - null, - Executors.directExecutor()); - - // trigger the checkpoint - coord.triggerCheckpoint(timestamp, false); - - assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); - long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); - - List keyGroupPartitions1 = - StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); - List keyGroupPartitions2 = - StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); - - for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); - KeyGroupsStateHandle keyGroupState = generateKeyGroupState( - jobVertexID1, keyGroupPartitions1.get(index), false); - - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); - AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( - jid, - jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), - checkpointId, - new CheckpointMetrics(), - checkpointStateHandles); - - 
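/*
 * Editor's sketch -- illustrative only, not part of the patch. It contrasts the removed and
 * the new way these tests hand keyed state to the coordinator: instead of the deleted
 * SubtaskState(valueSizeTuple, ...) constructor, an OperatorSubtaskState carrying only
 * managed keyed state is registered in a TaskStateSnapshot under the OperatorID derived
 * from the job vertex. All calls appear in the diff; jobVertexID and keyGroupState stand
 * for the fixtures of the enclosing test.
 */
OperatorSubtaskState operatorSubtaskState =
		new OperatorSubtaskState(null, null, keyGroupState, null);

TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
taskStateSnapshot.putSubtaskStateByOperatorID(
		OperatorID.fromJobVertexID(jobVertexID), operatorSubtaskState);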
coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); - } - - - for (int index = 0; index < jobVertex2.getParallelism(); index++) { - - ChainedStateHandle state = generateStateForVertex(jobVertexID2, index); - KeyGroupsStateHandle keyGroupState = generateKeyGroupState( - jobVertexID2, keyGroupPartitions2.get(index), false); - - SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null); - AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( - jid, - jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), - checkpointId, - new CheckpointMetrics(), - checkpointStateHandles); - - coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); - } - - List completedCheckpoints = coord.getSuccessfulCheckpoints(); - - assertEquals(1, completedCheckpoints.size()); - - Map tasks = new HashMap<>(); - - int newParallelism1 = 4; - int newParallelism2 = 3; - - final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex( - jobVertexID1, - newParallelism1, - maxParallelism1); - - final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex( - jobVertexID2, - newParallelism2, - maxParallelism2); - - tasks.put(jobVertexID1, newJobVertex1); - tasks.put(jobVertexID2, newJobVertex2); - - coord.restoreLatestCheckpointedState(tasks, true, false); - - fail("The restoration should have failed because the parallelism of an vertex with " + - "non-partitioned state changed."); - } - @Test public void testRestoreLatestCheckpointedStateScaleIn() throws Exception { testRestoreLatestCheckpointedStateWithChangingParallelism(false); @@ -2420,7 +2287,8 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // trigger the checkpoint coord.triggerCheckpoint(timestamp, false); @@ -2436,18 +2304,19 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s //vertex 1 for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); - ChainedStateHandle opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false); + OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID1, index, 2, 8, false); KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(opStateBackend, null, keyedStateBackend, keyedStateRaw); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2458,19 +2327,21 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s 
for (int index = 0; index < jobVertex2.getParallelism(); index++) { KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true); - ChainedStateHandle opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false); - ChainedStateHandle opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true); - expectedOpStatesBackend.add(opStateBackend); - expectedOpStatesRaw.add(opStateRaw); - SubtaskState checkpointStateHandles = - new SubtaskState(new ChainedStateHandle<>( - Collections.singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw); + OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, false); + OperatorStateHandle opStateRaw = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, true); + expectedOpStatesBackend.add(new ChainedStateHandle<>(Collections.singletonList(opStateBackend))); + expectedOpStatesRaw.add(new ChainedStateHandle<>(Collections.singletonList(opStateRaw))); + + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2504,27 +2375,36 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s List>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism()); List>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { - KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); - KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); - TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + List operatorIDs = newJobVertex2.getOperatorIDs(); - ChainedStateHandle operatorState = taskStateHandles.getLegacyOperatorState(); - List> opStateBackend = taskStateHandles.getManagedOperatorState(); - List> opStateRaw = taskStateHandles.getRawOperatorState(); - Collection keyedStateBackend = taskStateHandles.getManagedKeyedState(); - Collection keyGroupStateRaw = taskStateHandles.getRawKeyedState(); + KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); + KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); - actualOpStatesBackend.add(opStateBackend); - actualOpStatesRaw.add(opStateRaw); - // the 'non partition state' is not null because it is recombined. 
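/*
 * Editor's sketch -- illustrative only, not part of the patch. It summarizes how the
 * rewritten restore assertions read state back: the TaskStateSnapshot attached to an
 * execution attempt is queried per OperatorID, and each OperatorSubtaskState exposes its
 * managed/raw operator and keyed state as collections (empty rather than null when an
 * operator is stateless). The calls mirror the diff; vertex, operatorIDs and headOpIndex
 * stand for the enclosing fixtures.
 */
TaskStateSnapshot restored = vertex.getCurrentExecutionAttempt().getTaskStateSnapshot();
OperatorSubtaskState headOpState =
		restored.getSubtaskStateByOperatorID(operatorIDs.get(headOpIndex));

Collection<OperatorStateHandle> managedOpState = headOpState.getManagedOperatorState();
Collection<KeyedStateHandle> managedKeyedState = headOpState.getManagedKeyedState();

// in these tests each operator contributes at most one handle per state type
assertTrue(managedOpState.size() <= 1);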
- assertNotNull(operatorState); - for (int index = 0; index < operatorState.getLength(); index++) { - assertNull(operatorState.get(index)); + TaskStateSnapshot taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + + final int headOpIndex = operatorIDs.size() - 1; + List> allParallelManagedOpStates = new ArrayList<>(operatorIDs.size()); + List> allParallelRawOpStates = new ArrayList<>(operatorIDs.size()); + + for (int idx = 0; idx < operatorIDs.size(); ++idx) { + OperatorID operatorID = operatorIDs.get(idx); + OperatorSubtaskState opState = taskStateHandles.getSubtaskStateByOperatorID(operatorID); + Collection opStateBackend = opState.getManagedOperatorState(); + Collection opStateRaw = opState.getRawOperatorState(); + allParallelManagedOpStates.add(opStateBackend); + allParallelRawOpStates.add(opStateRaw); + if (idx == headOpIndex) { + Collection keyedStateBackend = opState.getManagedKeyedState(); + Collection keyGroupStateRaw = opState.getRawKeyedState(); + compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); + compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); + } } - compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); - compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); + actualOpStatesBackend.add(allParallelManagedOpStates); + actualOpStatesRaw.add(allParallelRawOpStates); } + comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend); comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw); } @@ -2575,17 +2455,11 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception OperatorState taskState = new OperatorState(id.f1, parallelism1, maxParallelism1); operatorStates.put(id.f1, taskState); for (int index = 0; index < taskState.getParallelism(); index++) { - StreamStateHandle subNonPartitionedState = - generateStateForVertex(id.f0, index) - .get(0); OperatorStateHandle subManagedOperatorState = - generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false) - .get(0); + generatePartitionableStateHandle(id.f0, index, 2, 8, false); OperatorStateHandle subRawOperatorState = - generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true) - .get(0); - - OperatorSubtaskState subtaskState = new OperatorSubtaskState(subNonPartitionedState, + generatePartitionableStateHandle(id.f0, index, 2, 8, true); + OperatorSubtaskState subtaskState = new OperatorSubtaskState( subManagedOperatorState, subRawOperatorState, null, @@ -2623,7 +2497,6 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception expectedRawOperatorState.add(ChainedStateHandle.wrapSingleHandle(subRawOperatorState)); OperatorSubtaskState subtaskState = new OperatorSubtaskState( - null, subManagedOperatorState, subRawOperatorState, subManagedKeyedState, @@ -2699,63 +2572,71 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception new StandaloneCheckpointIDCounter(), standaloneCompletedCheckpointStore, null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); coord.restoreLatestCheckpointedState(tasks, false, true); for (int i = 0; i < newJobVertex1.getParallelism(); i++) { - TaskStateHandles taskStateHandles = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); - ChainedStateHandle actualSubNonPartitionedState = 
taskStateHandles.getLegacyOperatorState(); - List> actualSubManagedOperatorState = taskStateHandles.getManagedOperatorState(); - List> actualSubRawOperatorState = taskStateHandles.getRawOperatorState(); + final List operatorIds = newJobVertex1.getOperatorIDs(); + + TaskStateSnapshot stateSnapshot = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); - assertNull(taskStateHandles.getManagedKeyedState()); - assertNull(taskStateHandles.getRawKeyedState()); + OperatorSubtaskState headOpState = stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1)); + assertTrue(headOpState.getManagedKeyedState().isEmpty()); + assertTrue(headOpState.getRawKeyedState().isEmpty()); // operator5 { int operatorIndexInChain = 2; - assertNull(actualSubNonPartitionedState.get(operatorIndexInChain)); - assertNull(actualSubManagedOperatorState.get(operatorIndexInChain)); - assertNull(actualSubRawOperatorState.get(operatorIndexInChain)); + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + + assertTrue(opState.getManagedOperatorState().isEmpty()); + assertTrue(opState.getRawOperatorState().isEmpty()); } // operator1 { int operatorIndexInChain = 1; - ChainedStateHandle expectSubNonPartitionedState = generateStateForVertex(id1.f0, i); - ChainedStateHandle expectedManagedOpState = generateChainedPartitionableStateHandle( + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + + OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle( id1.f0, i, 2, 8, false); - ChainedStateHandle expectedRawOpState = generateChainedPartitionableStateHandle( + OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle( id1.f0, i, 2, 8, true); - assertTrue(CommonTestUtils.isSteamContentEqual( - expectSubNonPartitionedState.get(0).openInputStream(), - actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream())); + Collection managedOperatorState = opState.getManagedOperatorState(); + assertEquals(1, managedOperatorState.size()); + assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(), + managedOperatorState.iterator().next().openInputStream())); - assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(), - actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); - - assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(), - actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + Collection rawOperatorState = opState.getRawOperatorState(); + assertEquals(1, rawOperatorState.size()); + assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(), + rawOperatorState.iterator().next().openInputStream())); } // operator2 { int operatorIndexInChain = 0; - ChainedStateHandle expectSubNonPartitionedState = generateStateForVertex(id2.f0, i); - ChainedStateHandle expectedManagedOpState = generateChainedPartitionableStateHandle( + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + + OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle( id2.f0, i, 2, 8, false); - ChainedStateHandle expectedRawOpState = generateChainedPartitionableStateHandle( + OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle( id2.f0, 
i, 2, 8, true); - assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.get(0).openInputStream(), - actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream())); - - assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(), - actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + Collection managedOperatorState = opState.getManagedOperatorState(); + assertEquals(1, managedOperatorState.size()); + assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(), + managedOperatorState.iterator().next().openInputStream())); - assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(), - actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + Collection rawOperatorState = opState.getRawOperatorState(); + assertEquals(1, rawOperatorState.size()); + assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(), + rawOperatorState.iterator().next().openInputStream())); } } @@ -2763,38 +2644,45 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception List>> actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { - TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + + final List operatorIds = newJobVertex2.getOperatorIDs(); + + TaskStateSnapshot stateSnapshot = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); // operator 3 { int operatorIndexInChain = 1; + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + List> actualSubManagedOperatorState = new ArrayList<>(1); - actualSubManagedOperatorState.add(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain)); + actualSubManagedOperatorState.add(opState.getManagedOperatorState()); List> actualSubRawOperatorState = new ArrayList<>(1); - actualSubRawOperatorState.add(taskStateHandles.getRawOperatorState().get(operatorIndexInChain)); + actualSubRawOperatorState.add(opState.getRawOperatorState()); actualManagedOperatorStates.add(actualSubManagedOperatorState); actualRawOperatorStates.add(actualSubRawOperatorState); - - assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain)); } // operator 6 { int operatorIndexInChain = 0; - assertNull(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain)); - assertNull(taskStateHandles.getRawOperatorState().get(operatorIndexInChain)); - assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain)); + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + assertTrue(opState.getManagedOperatorState().isEmpty()); + assertTrue(opState.getRawOperatorState().isEmpty()); } KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false); KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true); + OperatorSubtaskState headOpState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1)); - Collection keyedStateBackend = taskStateHandles.getManagedKeyedState(); - Collection keyGroupStateRaw = taskStateHandles.getRawKeyedState(); + Collection keyedStateBackend = 
headOpState.getManagedKeyedState(); + Collection keyGroupStateRaw = headOpState.getRawKeyedState(); compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); @@ -2832,7 +2720,8 @@ public void testExternalizedCheckpoints() throws Exception { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), "fake-directory", - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); assertTrue(coord.triggerCheckpoint(timestamp, false)); @@ -2972,19 +2861,50 @@ public static Tuple2> serializeTogetherAndTrackOffsets( return new Tuple2<>(allSerializedValuesConcatenated, offsets); } - public static ChainedStateHandle generateStateForVertex( + public static StreamStateHandle generateStateForVertex( JobVertexID jobVertexID, int index) throws IOException { Random random = new Random(jobVertexID.hashCode() + index); int value = random.nextInt(); - return generateChainedStateHandle(value); + return generateStreamStateHandle(value); + } + + public static StreamStateHandle generateStreamStateHandle(Serializable value) throws IOException { + return TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value); } public static ChainedStateHandle generateChainedStateHandle( Serializable value) throws IOException { return ChainedStateHandle.wrapSingleHandle( - TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value)); + generateStreamStateHandle(value)); + } + + public static OperatorStateHandle generatePartitionableStateHandle( + JobVertexID jobVertexID, + int index, + int namedStates, + int partitionsPerState, + boolean rawState) throws IOException { + + Map> statesListsMap = new HashMap<>(namedStates); + + for (int i = 0; i < namedStates; ++i) { + List testStatesLists = new ArrayList<>(partitionsPerState); + // generate state + int seed = jobVertexID.hashCode() * index + i * namedStates; + if (rawState) { + seed = (seed + 1) * 31; + } + Random random = new Random(seed); + for (int j = 0; j < partitionsPerState; ++j) { + int simulatedStateValue = random.nextInt(); + testStatesLists.add(simulatedStateValue); + } + statesListsMap.put("state-" + i, testStatesLists); + } + + return generatePartitionableStateHandle(statesListsMap); } public static ChainedStateHandle generateChainedPartitionableStateHandle( @@ -3011,11 +2931,11 @@ public static ChainedStateHandle generateChainedPartitionab statesListsMap.put("state-" + i, testStatesLists); } - return generateChainedPartitionableStateHandle(statesListsMap); + return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap)); } - private static ChainedStateHandle generateChainedPartitionableStateHandle( - Map> states) throws IOException { + private static OperatorStateHandle generatePartitionableStateHandle( + Map> states) throws IOException { List> namedStateSerializables = new ArrayList<>(states.size()); @@ -3030,20 +2950,18 @@ private static ChainedStateHandle generateChainedPartitiona int idx = 0; for (Map.Entry> entry : states.entrySet()) { offsetsMap.put( - entry.getKey(), - new OperatorStateHandle.StateMetaInfo( - serializationWithOffsets.f1.get(idx), - OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + entry.getKey(), + new OperatorStateHandle.StateMetaInfo( + serializationWithOffsets.f1.get(idx), + OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); ++idx; } ByteStreamStateHandle streamStateHandle = new TestByteStreamStateHandleDeepCompare( - String.valueOf(UUID.randomUUID()), 
- serializationWithOffsets.f0); + String.valueOf(UUID.randomUUID()), + serializationWithOffsets.f0); - OperatorStateHandle operatorStateHandle = - new OperatorStateHandle(offsetsMap, streamStateHandle); - return ChainedStateHandle.wrapSingleHandle(operatorStateHandle); + return new OperatorStateHandle(offsetsMap, streamStateHandle); } static ExecutionJobVertex mockExecutionJobVertex( @@ -3137,24 +3055,22 @@ private static ExecutionVertex mockExecutionVertex( return vertex; } - static SubtaskState mockSubtaskState( + static TaskStateSnapshot mockSubtaskState( JobVertexID jobVertexID, int index, KeyGroupRange keyGroupRange) throws IOException { - ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID, index); - ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false); + OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false); KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false); - SubtaskState subtaskState = mock(SubtaskState.class, withSettings().serializable()); + TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState( + partitionableState, null, partitionedKeyGroupState, null) + ); - doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState(); - doReturn(partitionableState).when(subtaskState).getManagedOperatorState(); - doReturn(null).when(subtaskState).getRawOperatorState(); - doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState(); - doReturn(null).when(subtaskState).getRawKeyedState(); + subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState); - return subtaskState; + return subtaskStates; } public static void verifyStateRestore( @@ -3163,27 +3079,20 @@ public static void verifyStateRestore( for (int i = 0; i < executionJobVertex.getParallelism(); i++) { - TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + TaskStateSnapshot stateSnapshot = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); - ChainedStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i); - ChainedStateHandle actualNonPartitionedState = taskStateHandles.getLegacyOperatorState(); - assertTrue(CommonTestUtils.isSteamContentEqual( - expectNonPartitionedState.get(0).openInputStream(), - actualNonPartitionedState.get(0).openInputStream())); + OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID)); ChainedStateHandle expectedOpStateBackend = generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false); - List> actualPartitionableState = taskStateHandles.getManagedOperatorState(); - assertTrue(CommonTestUtils.isSteamContentEqual( expectedOpStateBackend.get(0).openInputStream(), - actualPartitionableState.get(0).iterator().next().openInputStream())); + operatorState.getManagedOperatorState().iterator().next().openInputStream())); KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState( jobVertexID, keyGroupPartitions.get(i), false); - Collection actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState(); - compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState); + 
compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState()); } } @@ -3308,7 +3217,8 @@ public void testStopPeriodicScheduler() throws Exception { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // Periodic CheckpointTriggerResult triggerResult = coord.triggerCheckpoint( @@ -3486,7 +3396,8 @@ public void testCheckpointStatsTrackerPendingCheckpointCallback() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class); coord.setCheckpointStatsTracker(tracker); @@ -3524,7 +3435,8 @@ public void testCheckpointStatsTrackerRestoreCallback() throws Exception { new StandaloneCheckpointIDCounter(), store, null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); store.addCheckpoint(new CompletedCheckpoint( new JobID(), @@ -3580,7 +3492,8 @@ public void testSavepointsAreNotAddedToCompletedCheckpointStore() throws Excepti checkpointIDCounter, completedCheckpointStore, null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // trigger a first checkpoint assertTrue( @@ -3631,16 +3544,244 @@ public void testSavepointsAreNotAddedToCompletedCheckpointStore() throws Excepti completedCheckpointStore.getLatestCheckpoint().getCheckpointID() == checkpointIDCounter.getLast()); } - private static final class SpyInjectingOperatorState extends OperatorState { + @Test + public void testSharedStateRegistrationOnRestore() throws Exception { + + final JobID jid = new JobID(); + final long timestamp = System.currentTimeMillis(); + + final JobVertexID jobVertexID1 = new JobVertexID(); + + int parallelism1 = 2; + int maxParallelism1 = 4; + + final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( + jobVertexID1, + parallelism1, + maxParallelism1); + + List allExecutionVertices = new ArrayList<>(parallelism1); + + allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); + + ExecutionVertex[] arrayExecutionVertices = + allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]); + + RecoverableCompletedCheckpointStore store = new RecoverableCompletedCheckpointStore(10); + + final List createdSharedStateRegistries = new ArrayList<>(2); + + // set up the coordinator and validate the initial state + CheckpointCoordinator coord = new CheckpointCoordinator( + jid, + 600000, + 600000, + 0, + Integer.MAX_VALUE, + ExternalizedCheckpointSettings.none(), + arrayExecutionVertices, + arrayExecutionVertices, + arrayExecutionVertices, + new StandaloneCheckpointIDCounter(), + store, + null, + Executors.directExecutor(), + new SharedStateRegistryFactory() { + @Override + public SharedStateRegistry create(Executor deleteExecutor) { + SharedStateRegistry instance = new SharedStateRegistry(deleteExecutor); + createdSharedStateRegistries.add(instance); + return instance; + } + }); + + final int numCheckpoints = 3; + + List keyGroupPartitions1 = + StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); + + for (int i = 0; i < numCheckpoints; ++i) { + performIncrementalCheckpoint(jid, coord, jobVertex1, keyGroupPartitions1, timestamp + i, i); + } + + List completedCheckpoints = 
coord.getSuccessfulCheckpoints(); + assertEquals(numCheckpoints, completedCheckpoints.size()); + + int sharedHandleCount = 0; + + List> sharedHandlesByCheckpoint = new ArrayList<>(numCheckpoints); + + for (int i = 0; i < numCheckpoints; ++i) { + sharedHandlesByCheckpoint.add(new HashMap(2)); + } + + int cp = 0; + for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { + for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) { + for (OperatorSubtaskState subtaskState : taskState.getStates()) { + for (KeyedStateHandle keyedStateHandle : subtaskState.getManagedKeyedState()) { + // test we are once registered with the current registry + verify(keyedStateHandle, times(1)).registerSharedStates(createdSharedStateRegistries.get(0)); + IncrementalKeyedStateHandle incrementalKeyedStateHandle = (IncrementalKeyedStateHandle) keyedStateHandle; + + sharedHandlesByCheckpoint.get(cp).putAll(incrementalKeyedStateHandle.getSharedState()); + + for (StreamStateHandle streamStateHandle : incrementalKeyedStateHandle.getSharedState().values()) { + assertTrue(!(streamStateHandle instanceof PlaceholderStreamStateHandle)); + verify(streamStateHandle, never()).discardState(); + ++sharedHandleCount; + } + + for (StreamStateHandle streamStateHandle : incrementalKeyedStateHandle.getPrivateState().values()) { + verify(streamStateHandle, never()).discardState(); + } + + verify(incrementalKeyedStateHandle.getMetaStateHandle(), never()).discardState(); + } + + verify(subtaskState, never()).discardState(); + } + } + ++cp; + } - private static final long serialVersionUID = -4004437428483663815L; + // 2 (parallelism) x (1 (CP0) + 2 (CP1) + 2 (CP2)) = 10 + assertEquals(10, sharedHandleCount); - public SpyInjectingOperatorState(OperatorID taskID, int parallelism, int maxParallelism) { - super(taskID, parallelism, maxParallelism); + // discard CP0 + store.removeOldestCheckpoint(); + + // we expect no shared state was discarded because the state of CP0 is still referenced by CP1 + for (Map cpList : sharedHandlesByCheckpoint) { + for (StreamStateHandle streamStateHandle : cpList.values()) { + verify(streamStateHandle, never()).discardState(); + } } - public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) { - super.putState(subtaskIndex, spy(subtaskState)); + // shutdown the store + store.shutdown(JobStatus.SUSPENDED); + + // restore the store + Map tasks = new HashMap<>(); + tasks.put(jobVertexID1, jobVertex1); + coord.restoreLatestCheckpointedState(tasks, true, false); + + // validate that all shared states are registered again after the recovery. + cp = 0; + for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { + for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) { + for (OperatorSubtaskState subtaskState : taskState.getStates()) { + for (KeyedStateHandle keyedStateHandle : subtaskState.getManagedKeyedState()) { + VerificationMode verificationMode; + // test we are once registered with the new registry + if (cp > 0) { + verificationMode = times(1); + } else { + verificationMode = never(); + } + + //check that all are registered with the new registry + verify(keyedStateHandle, verificationMode).registerSharedStates(createdSharedStateRegistries.get(1)); + } + } + } + ++cp; + } + + // discard CP1 + store.removeOldestCheckpoint(); + + // we expect that all shared state from CP0 is no longer referenced and discarded. CP2 is still live and also + // references the state from CP1, so we expect they are not discarded. 
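/*
 * Editor's sketch -- illustrative only, not part of the patch. It isolates the new
 * injection point the shared-state test above relies on: the coordinator is constructed
 * with a SharedStateRegistryFactory, so a test can capture every SharedStateRegistry the
 * coordinator creates and later verify that restored handles re-register against the
 * newest one. The interface and calls are taken from the diff; the coordinator
 * constructor arguments other than the two shown are elided.
 */
final List<SharedStateRegistry> createdRegistries = new ArrayList<>(2);

SharedStateRegistryFactory capturingFactory = new SharedStateRegistryFactory() {
	@Override
	public SharedStateRegistry create(Executor deleteExecutor) {
		SharedStateRegistry registry = new SharedStateRegistry(deleteExecutor);
		createdRegistries.add(registry);
		return registry;
	}
};

// non-test code simply passes SharedStateRegistry.DEFAULT_FACTORY in this position:
// new CheckpointCoordinator(..., Executors.directExecutor(), capturingFactory);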
+ for (Map cpList : sharedHandlesByCheckpoint) { + for (Map.Entry entry : cpList.entrySet()) { + String key = entry.getKey().getKeyString(); + int belongToCP = Integer.parseInt(String.valueOf(key.charAt(key.length() - 1))); + if (belongToCP == 0) { + verify(entry.getValue(), times(1)).discardState(); + } else { + verify(entry.getValue(), never()).discardState(); + } + } + } + + // discard CP2 + store.removeOldestCheckpoint(); + + // we expect all shared state was discarded now, because all CPs are + for (Map cpList : sharedHandlesByCheckpoint) { + for (StreamStateHandle streamStateHandle : cpList.values()) { + verify(streamStateHandle, times(1)).discardState(); + } + } + } + + private void performIncrementalCheckpoint( + JobID jid, + CheckpointCoordinator coord, + ExecutionJobVertex jobVertex1, + List keyGroupPartitions1, + long timestamp, + int cpSequenceNumber) throws Exception { + + // trigger the checkpoint + coord.triggerCheckpoint(timestamp, false); + + assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); + long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); + + for (int index = 0; index < jobVertex1.getParallelism(); index++) { + + KeyGroupRange keyGroupRange = keyGroupPartitions1.get(index); + + Map privateState = new HashMap<>(); + privateState.put( + new StateHandleID("private-1"), + spy(new ByteStreamStateHandle("private-1", new byte[]{'p'}))); + + Map sharedState = new HashMap<>(); + + // let all but the first CP overlap by one shared state. + if (cpSequenceNumber > 0) { + sharedState.put( + new StateHandleID("shared-" + (cpSequenceNumber - 1)), + spy(new PlaceholderStreamStateHandle())); + } + + sharedState.put( + new StateHandleID("shared-" + cpSequenceNumber), + spy(new ByteStreamStateHandle("shared-" + cpSequenceNumber + "-" + keyGroupRange, new byte[]{'s'}))); + + IncrementalKeyedStateHandle managedState = + spy(new IncrementalKeyedStateHandle( + new UUID(42L, 42L), + keyGroupRange, + checkpointId, + sharedState, + privateState, + spy(new ByteStreamStateHandle("meta", new byte[]{'m'})))); + + OperatorSubtaskState operatorSubtaskState = + spy(new OperatorSubtaskState( + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(managedState), + Collections.emptyList())); + + Map opStates = new HashMap<>(); + + opStates.put(jobVertex1.getOperatorIDs().get(0), operatorSubtaskState); + + TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot(opStates); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + new CheckpointMetrics(), + taskStateSnapshot); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointPropertiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointPropertiesTest.java index 52ac54ca37159..a0509c43ccb90 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointPropertiesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointPropertiesTest.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.checkpoint; +import org.apache.flink.util.InstantiationUtil; + import org.junit.Test; import static org.junit.Assert.assertFalse; @@ -109,6 +111,12 @@ public void testIsSavepoint() throws Exception { { CheckpointProperties props = CheckpointProperties.forStandardSavepoint(); 
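/*
 * Editor's sketch -- illustrative only, not part of the patch. It restates the core of
 * performIncrementalCheckpoint(...) above: an IncrementalKeyedStateHandle bundles private
 * state, shared state and a meta-state handle, and each checkpoint after the first
 * references the previous checkpoint's shared entry through a PlaceholderStreamStateHandle
 * so the SharedStateRegistry can de-duplicate it. The types and constructor shape are
 * taken from the diff; keyGroupRange, checkpointId and cpSequenceNumber are enclosing
 * fixtures.
 */
Map<StateHandleID, StreamStateHandle> privateState = new HashMap<>();
privateState.put(new StateHandleID("private-1"),
		new ByteStreamStateHandle("private-1", new byte[]{'p'}));

Map<StateHandleID, StreamStateHandle> sharedState = new HashMap<>();
if (cpSequenceNumber > 0) {
	// re-reference the shared file created by the previous checkpoint
	sharedState.put(new StateHandleID("shared-" + (cpSequenceNumber - 1)),
			new PlaceholderStreamStateHandle());
}
sharedState.put(new StateHandleID("shared-" + cpSequenceNumber),
		new ByteStreamStateHandle("shared-" + cpSequenceNumber, new byte[]{'s'}));

IncrementalKeyedStateHandle managedState = new IncrementalKeyedStateHandle(
		new UUID(42L, 42L), keyGroupRange, checkpointId,
		sharedState, privateState,
		new ByteStreamStateHandle("meta", new byte[]{'m'}));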
assertTrue(props.isSavepoint()); + + CheckpointProperties deserializedCheckpointProperties = + InstantiationUtil.deserializeObject( + InstantiationUtil.serializeObject(props), + getClass().getClassLoader()); + assertTrue(deserializedCheckpointProperties.isSavepoint()); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 7d2456881fc33..1788434255bdb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -29,23 +29,22 @@ import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.util.SerializableObject; + import org.hamcrest.BaseMatcher; import org.hamcrest.Description; import org.junit.Test; import org.mockito.Mockito; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -66,7 +65,6 @@ public class CheckpointStateRestoreTest { public void testSetState() { try { - final ChainedStateHandle serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject()); KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0); List testStates = Collections.singletonList(new SerializableObject()); final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates); @@ -109,7 +107,8 @@ public void testSetState() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // create ourselves a checkpoint with state final long timestamp = 34623786L; @@ -118,10 +117,19 @@ public void testSetState() { PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next(); final long checkpointId = pending.getCheckpointId(); - SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); + final TaskStateSnapshot subtaskStates = new TaskStateSnapshot(); + + subtaskStates.putSubtaskStateByOperatorID( + OperatorID.fromJobVertexID(statefulId), + new OperatorSubtaskState( + 
Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(serializedKeyGroupStates), + Collections.emptyList())); + + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId)); @@ -133,33 +141,26 @@ public void testSetState() { // verify that each stateful vertex got the state - final TaskStateHandles taskStateHandles = new TaskStateHandles( - serializedState, - Collections.>singletonList(null), - Collections.>singletonList(null), - Collections.singletonList(serializedKeyGroupStates), - null); - - BaseMatcher matcher = new BaseMatcher() { + BaseMatcher matcher = new BaseMatcher() { @Override public boolean matches(Object o) { - if (o instanceof TaskStateHandles) { - return o.equals(taskStateHandles); + if (o instanceof TaskStateSnapshot) { + return Objects.equals(o, subtaskStates); } return false; } @Override public void describeTo(Description description) { - description.appendValue(taskStateHandles); + description.appendValue(subtaskStates); } }; verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher)); verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher)); verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher)); - verify(statelessExec1, times(0)).setInitialState(Mockito.any()); - verify(statelessExec2, times(0)).setInitialState(Mockito.any()); + verify(statelessExec1, times(0)).setInitialState(Mockito.any()); + verify(statelessExec2, times(0)).setInitialState(Mockito.any()); } catch (Exception e) { e.printStackTrace(); @@ -183,7 +184,8 @@ public void testNoCheckpointAvailable() { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); try { coord.restoreLatestCheckpointedState(new HashMap(), true, false); @@ -240,19 +242,16 @@ public void testNonRestoredState() throws Exception { new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, - Executors.directExecutor()); - - StreamStateHandle serializedState = CheckpointCoordinatorTest - .generateChainedStateHandle(new SerializableObject()) - .get(0); + Executors.directExecutor(), + SharedStateRegistry.DEFAULT_FACTORY); // --- (2) Checkpoint misses state for a jobVertex (should work) --- Map checkpointTaskStates = new HashMap<>(); { OperatorState taskState = new OperatorState(operatorId1, 3, 3); - taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null)); - taskState.putState(1, new OperatorSubtaskState(serializedState, null, null, null, null)); - taskState.putState(2, new OperatorSubtaskState(serializedState, null, null, null, null)); + taskState.putState(0, new OperatorSubtaskState()); + taskState.putState(1, new OperatorSubtaskState()); + taskState.putState(2, new OperatorSubtaskState()); checkpointTaskStates.put(operatorId1, taskState); } @@ -279,7 +278,7 @@ public void 
testNonRestoredState() throws Exception { // There is no task for this { OperatorState taskState = new OperatorState(newOperatorID, 1, 1); - taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null)); + taskState.putState(0, new OperatorSubtaskState()); checkpointTaskStates.put(newOperatorID, taskState); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index 1fe4e65979cc8..320dc2df52bfb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -331,7 +331,7 @@ static class TestOperatorSubtaskState extends OperatorSubtaskState { boolean discarded; public TestOperatorSubtaskState() { - super(null, null, null, null, null); + super(); this.registered = false; this.discarded = false; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index 4846879244d14..293675c14e932 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -82,7 +82,7 @@ public void testCleanUpOnSubsume() throws Exception { operatorStates.put(new OperatorID(), state); boolean discardSubsumed = true; - CheckpointProperties props = new CheckpointProperties(false, false, discardSubsumed, true, true, true, true); + CheckpointProperties props = new CheckpointProperties(false, false, false, discardSubsumed, true, true, true, true); CompletedCheckpoint checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, @@ -122,7 +122,7 @@ public void testCleanUpOnShutdown() throws Exception { Mockito.reset(state); // Keep - CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); + CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false, false); CompletedCheckpoint checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, new HashMap<>(operatorStates), @@ -139,7 +139,7 @@ public void testCleanUpOnShutdown() throws Exception { assertEquals(true, file.exists()); // Discard - props = new CheckpointProperties(false, false, true, true, true, true, true); + props = new CheckpointProperties(false, false, false, true, true, true, true, true); checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, new HashMap<>(operatorStates), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CoordinatorShutdownTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CoordinatorShutdownTest.java index ec1bbd8e2157a..c58e3a0d0eee9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CoordinatorShutdownTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CoordinatorShutdownTest.java @@ -33,6 +33,7 @@ import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster; import org.apache.flink.runtime.testingUtils.TestingUtils; +import org.apache.flink.runtime.testtasks.FailingBlockingInvokable; import org.apache.flink.util.TestLogger; import org.junit.Test; @@ -190,26 +191,4 @@ public static void unblock() { } } - public static class 
FailingBlockingInvokable extends AbstractInvokable { - private static boolean blocking = true; - private static final Object lock = new Object(); - - @Override - public void invoke() throws Exception { - while (blocking) { - synchronized (lock) { - lock.wait(); - } - } - throw new RuntimeException("This exception is expected."); - } - - public static void unblock() { - blocking = false; - - synchronized (lock) { - lock.notifyAll(); - } - } - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index 7d103d0b297f0..ef31f0a3aa63c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -80,7 +80,7 @@ public class PendingCheckpointTest { @Test public void testCanBeSubsumed() throws Exception { // Forced checkpoints cannot be subsumed - CheckpointProperties forced = new CheckpointProperties(true, true, false, false, false, false, false); + CheckpointProperties forced = new CheckpointProperties(true, true, false, false, false, false, false, false); PendingCheckpoint pending = createPendingCheckpoint(forced, "ignored"); assertFalse(pending.canBeSubsumed()); @@ -92,7 +92,7 @@ public void testCanBeSubsumed() throws Exception { } // Non-forced checkpoints can be subsumed - CheckpointProperties subsumed = new CheckpointProperties(false, true, false, false, false, false, false); + CheckpointProperties subsumed = new CheckpointProperties(false, true, false, false, false, false, false, false); pending = createPendingCheckpoint(subsumed, "ignored"); assertTrue(pending.canBeSubsumed()); } @@ -106,7 +106,7 @@ public void testPersistExternally() throws Exception { File tmp = tmpFolder.newFolder(); // Persisted checkpoint - CheckpointProperties persisted = new CheckpointProperties(false, true, false, false, false, false, false); + CheckpointProperties persisted = new CheckpointProperties(false, true, false, false, false, false, false, false); PendingCheckpoint pending = createPendingCheckpoint(persisted, tmp.getAbsolutePath()); pending.acknowledgeTask(ATTEMPT_ID, null, new CheckpointMetrics()); @@ -115,7 +115,7 @@ public void testPersistExternally() throws Exception { assertEquals(1, tmp.listFiles().length); // Ephemeral checkpoint - CheckpointProperties ephemeral = new CheckpointProperties(false, false, true, true, true, true, true); + CheckpointProperties ephemeral = new CheckpointProperties(false, false, false, true, true, true, true, true); pending = createPendingCheckpoint(ephemeral, null); pending.acknowledgeTask(ATTEMPT_ID, null, new CheckpointMetrics()); @@ -130,7 +130,7 @@ public void testPersistExternally() throws Exception { */ @Test public void testCompletionFuture() throws Exception { - CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); + CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false, false); // Abort declined PendingCheckpoint pending = createPendingCheckpoint(props, "ignored"); @@ -192,7 +192,7 @@ public void testCompletionFuture() throws Exception { @Test @SuppressWarnings("unchecked") public void testAbortDiscardsState() throws Exception { - CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); + CheckpointProperties props = new 
CheckpointProperties(false, true, false, false, false, false, false, false); QueueExecutor executor = new QueueExecutor(); OperatorState state = mock(OperatorState.class); @@ -324,13 +324,13 @@ public void testNullSubtaskStateLeadsToStatelessTask() throws Exception { @Test public void testNonNullSubtaskStateLeadsToStatefulTask() throws Exception { PendingCheckpoint pending = createPendingCheckpoint(CheckpointProperties.forStandardCheckpoint(), null); - pending.acknowledgeTask(ATTEMPT_ID, mock(SubtaskState.class), mock(CheckpointMetrics.class)); + pending.acknowledgeTask(ATTEMPT_ID, mock(TaskStateSnapshot.class), mock(CheckpointMetrics.class)); Assert.assertFalse(pending.getOperatorStates().isEmpty()); } @Test public void testSetCanceller() { - final CheckpointProperties props = new CheckpointProperties(false, false, true, true, true, true, true); + final CheckpointProperties props = new CheckpointProperties(false, false, false, true, true, true, true, true); PendingCheckpoint aborted = createPendingCheckpoint(props, null); aborted.abortDeclined(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/RestoredCheckpointStatsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/RestoredCheckpointStatsTest.java index 85b151635efe9..d43283d992b55 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/RestoredCheckpointStatsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/RestoredCheckpointStatsTest.java @@ -31,7 +31,7 @@ public class RestoredCheckpointStatsTest { public void testSimpleAccess() throws Exception { long checkpointId = Integer.MAX_VALUE + 1L; long triggerTimestamp = Integer.MAX_VALUE + 1L; - CheckpointProperties props = new CheckpointProperties(true, true, false, false, true, false, true); + CheckpointProperties props = new CheckpointProperties(true, true, false, false, false, true, false, true); long restoreTimestamp = Integer.MAX_VALUE + 1L; String externalPath = "external-path"; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java index 77423c213fd81..dc2b11ebd08bc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java @@ -18,13 +18,14 @@ package org.apache.flink.runtime.checkpoint; -import org.apache.curator.framework.CuratorFramework; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.state.RetrievableStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment; + +import org.apache.curator.framework.CuratorFramework; import org.apache.zookeeper.data.Stat; import org.junit.AfterClass; import org.junit.Before; @@ -106,8 +107,9 @@ public void testRecover() throws Exception { assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints()); // Recover - sharedStateRegistry.clear(); - checkpoints.recover(sharedStateRegistry); + sharedStateRegistry.close(); + sharedStateRegistry = new SharedStateRegistry(); + checkpoints.recover(); assertEquals(3, 
ZOOKEEPER.getClient().getChildren().forPath(CHECKPOINT_PATH).size()); assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints()); @@ -148,8 +150,8 @@ public void testShutdownDiscardsCheckpoints() throws Exception { assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertNull(client.checkExists().forPath(CHECKPOINT_PATH + ZooKeeperCompletedCheckpointStore.checkpointIdToPath(checkpoint.getCheckpointID()))); - sharedStateRegistry.clear(); - store.recover(sharedStateRegistry); + sharedStateRegistry.close(); + store.recover(); assertEquals(0, store.getNumberOfRetainedCheckpoints()); } @@ -182,8 +184,8 @@ public void testSuspendKeepsCheckpoints() throws Exception { assertEquals("The checkpoint node should not be locked.", 0, stat.getNumChildren()); // Recover again - sharedStateRegistry.clear(); - store.recover(sharedStateRegistry); + sharedStateRegistry.close(); + store.recover(); CompletedCheckpoint recovered = store.getLatestCheckpoint(); assertEquals(checkpoint, recovered); @@ -209,8 +211,8 @@ public void testLatestCheckpointRecovery() throws Exception { checkpointStore.addCheckpoint(checkpoint); } - sharedStateRegistry.clear(); - checkpointStore.recover(sharedStateRegistry); + sharedStateRegistry.close(); + checkpointStore.recover(); CompletedCheckpoint latestCheckpoint = checkpointStore.getLatestCheckpoint(); @@ -239,8 +241,9 @@ public void testConcurrentCheckpointOperations() throws Exception { zkCheckpointStore1.addCheckpoint(completedCheckpoint); // recover the checkpoint by a different checkpoint store - sharedStateRegistry.clear(); - zkCheckpointStore2.recover(sharedStateRegistry); + sharedStateRegistry.close(); + sharedStateRegistry = new SharedStateRegistry(); + zkCheckpointStore2.recover(); CompletedCheckpoint recoveredCheckpoint = zkCheckpointStore2.getLatestCheckpoint(); assertTrue(recoveredCheckpoint instanceof TestCompletedCheckpoint); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java index 91bab85fef84a..3171f1f84e388 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java @@ -52,7 +52,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyCollection; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -162,11 +161,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { stateStorage, Executors.directExecutor()); - SharedStateRegistry sharedStateRegistry = spy(new SharedStateRegistry()); - zooKeeperCompletedCheckpointStore.recover(sharedStateRegistry); - - verify(retrievableStateHandle1.retrieveState(), times(1)).registerSharedStatesAfterRestored(sharedStateRegistry); - verify(retrievableStateHandle2.retrieveState(), times(1)).registerSharedStatesAfterRestored(sharedStateRegistry); + zooKeeperCompletedCheckpointStore.recover(); CompletedCheckpoint latestCompletedCheckpoint = zooKeeperCompletedCheckpointStore.getLatestCheckpoint(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java index de1f599e137bc..acedb5071b252 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java @@ -77,7 +77,6 @@ public static Collection createOperatorStates( OperatorState taskState = new OperatorState(new OperatorID(), numSubtasksPerTask, 128); - boolean hasNonPartitionableState = random.nextBoolean(); boolean hasOperatorStateBackend = random.nextBoolean(); boolean hasOperatorStateStream = random.nextBoolean(); @@ -87,7 +86,6 @@ public static Collection createOperatorStates( for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) { - StreamStateHandle nonPartitionableState = null; StreamStateHandle operatorStateBackend = new TestByteStreamStateHandleDeepCompare("b", ("Beautiful").getBytes(ConfigConstants.DEFAULT_CHARSET)); StreamStateHandle operatorStateStream = @@ -101,11 +99,6 @@ public static Collection createOperatorStates( offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); - if (hasNonPartitionableState) { - nonPartitionableState = - new TestByteStreamStateHandleDeepCompare("a", ("Hi").getBytes(ConfigConstants.DEFAULT_CHARSET)); - } - if (hasOperatorStateBackend) { operatorStateHandleBackend = new OperatorStateHandle(offsetsMap, operatorStateBackend); } @@ -130,7 +123,6 @@ public static Collection createOperatorStates( } taskState.putState(subtaskIdx, new OperatorSubtaskState( - nonPartitionableState, operatorStateHandleBackend, operatorStateHandleStream, keyedStateStream, @@ -175,15 +167,11 @@ public static Collection createTaskStates( for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) { - List nonPartitionableStates = new ArrayList<>(chainLength); List operatorStatesBackend = new ArrayList<>(chainLength); List operatorStatesStream = new ArrayList<>(chainLength); for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) { - StreamStateHandle nonPartitionableState = - new TestByteStreamStateHandleDeepCompare("a-" + chainIdx, ("Hi-" + chainIdx).getBytes( - ConfigConstants.DEFAULT_CHARSET)); StreamStateHandle operatorStateBackend = new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET)); StreamStateHandle operatorStateStream = @@ -193,10 +181,6 @@ public static Collection createTaskStates( offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); - if (chainIdx != noNonPartitionableStateAtIndex) { - nonPartitionableStates.add(nonPartitionableState); - } - if (chainIdx != noOperatorStateBackendAtIndex) { OperatorStateHandle operatorStateHandleBackend = new OperatorStateHandle(offsetsMap, operatorStateBackend); @@ -222,7 +206,6 @@ public static Collection createTaskStates( } taskState.putState(subtaskIdx, new SubtaskState( - new ChainedStateHandle<>(nonPartitionableStates), new ChainedStateHandle<>(operatorStatesBackend), new ChainedStateHandle<>(operatorStatesStream), keyedStateStream, diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java deleted file mode 100644 index 16f3769160b79..0000000000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.checkpoint.savepoint; - -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.api.java.tuple.Tuple4; -import org.apache.flink.core.fs.FSDataOutputStream; -import org.apache.flink.core.fs.FileSystem; -import org.apache.flink.core.fs.Path; -import org.apache.flink.migration.runtime.checkpoint.savepoint.SavepointV0; -import org.apache.flink.migration.runtime.checkpoint.savepoint.SavepointV0Serializer; -import org.apache.flink.migration.runtime.state.KvStateSnapshot; -import org.apache.flink.migration.runtime.state.memory.MemValueState; -import org.apache.flink.migration.runtime.state.memory.SerializedStateHandle; -import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskState; -import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskStateList; -import org.apache.flink.migration.util.MigrationInstantiationUtil; -import org.apache.flink.runtime.checkpoint.SubtaskState; -import org.apache.flink.runtime.checkpoint.TaskState; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.KeyedStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; -import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; -import org.apache.flink.util.FileUtils; -import org.apache.flink.util.InstantiationUtil; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.concurrent.ThreadLocalRandom; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -@SuppressWarnings("deprecation") -public class MigrationV0ToV1Test { - - @Rule - public TemporaryFolder tmp = new TemporaryFolder(); - - /** - * Simple test of savepoint methods. 
- */ - @Test - public void testSavepointMigrationV0ToV1() throws Exception { - - String target = tmp.getRoot().getAbsolutePath(); - - assertEquals(0, tmp.getRoot().listFiles().length); - - long checkpointId = ThreadLocalRandom.current().nextLong(Integer.MAX_VALUE); - int numTaskStates = 4; - int numSubtaskStates = 16; - - Collection expected = - createTaskStatesOld(numTaskStates, numSubtaskStates); - - SavepointV0 savepoint = new SavepointV0(checkpointId, expected); - - assertEquals(SavepointV0.VERSION, savepoint.getVersion()); - assertEquals(checkpointId, savepoint.getCheckpointId()); - assertEquals(expected, savepoint.getOldTaskStates()); - - assertFalse(savepoint.getOldTaskStates().isEmpty()); - - Exception latestException = null; - Path path = null; - FSDataOutputStream fdos = null; - - FileSystem fs = null; - - try { - - // Try to create a FS output stream - for (int attempt = 0; attempt < 10; attempt++) { - path = new Path(target, FileUtils.getRandomFilename("savepoint-")); - - if (fs == null) { - fs = FileSystem.get(path.toUri()); - } - - try { - fdos = fs.create(path, FileSystem.WriteMode.NO_OVERWRITE); - break; - } catch (Exception e) { - latestException = e; - } - } - - if (fdos == null) { - throw new IOException("Failed to create file output stream at " + path, latestException); - } - - try (DataOutputStream dos = new DataOutputStream(fdos)) { - dos.writeInt(SavepointStore.MAGIC_NUMBER); - dos.writeInt(savepoint.getVersion()); - SavepointV0Serializer.INSTANCE.serializeOld(savepoint, dos); - } - - ClassLoader cl = Thread.currentThread().getContextClassLoader(); - - Savepoint sp = SavepointStore.loadSavepoint(path.toString(), cl); - int t = 0; - for (TaskState taskState : sp.getTaskStates()) { - for (int p = 0; p < taskState.getParallelism(); ++p) { - SubtaskState subtaskState = taskState.getState(p); - ChainedStateHandle legacyOperatorState = subtaskState.getLegacyOperatorState(); - for (int c = 0; c < legacyOperatorState.getLength(); ++c) { - StreamStateHandle stateHandle = legacyOperatorState.get(c); - try (InputStream is = stateHandle.openInputStream()) { - Tuple4 expTestState = new Tuple4<>(0, t, p, c); - Tuple4 actTestState; - //check function state - if (p % 4 != 0) { - assertEquals(1, is.read()); - actTestState = InstantiationUtil.deserializeObject(is, cl); - assertEquals(expTestState, actTestState); - } else { - assertEquals(0, is.read()); - } - - //check operator state - expTestState.f0 = 1; - actTestState = InstantiationUtil.deserializeObject(is, cl); - assertEquals(expTestState, actTestState); - } - } - - //check keyed state - KeyedStateHandle keyedStateHandle = subtaskState.getManagedKeyedState(); - - if (t % 3 != 0) { - - assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle); - - KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle; - - assertEquals(1, keyGroupsStateHandle.getKeyGroupRange().getNumberOfKeyGroups()); - assertEquals(p, keyGroupsStateHandle.getGroupRangeOffsets().getKeyGroupRange().getStartKeyGroup()); - - ByteStreamStateHandle stateHandle = - (ByteStreamStateHandle) keyGroupsStateHandle.getDelegateStateHandle(); - HashMap> testKeyedState = - MigrationInstantiationUtil.deserializeObject(stateHandle.getData(), cl); - - assertEquals(2, testKeyedState.size()); - for (KvStateSnapshot snapshot : testKeyedState.values()) { - MemValueState.Snapshot castedSnapshot = (MemValueState.Snapshot) snapshot; - byte[] data = castedSnapshot.getData(); - assertEquals(t, data[0]); - assertEquals(p, data[1]); - } - } else { - 
assertEquals(null, keyedStateHandle); - } - } - - ++t; - } - - savepoint.dispose(); - - } finally { - // Dispose - SavepointStore.removeSavepointFile(path.toString()); - } - } - - private static Collection createTaskStatesOld( - int numTaskStates, int numSubtaskStates) throws Exception { - - List taskStates = new ArrayList<>(numTaskStates); - - for (int i = 0; i < numTaskStates; i++) { - org.apache.flink.migration.runtime.checkpoint.TaskState taskState = - new org.apache.flink.migration.runtime.checkpoint.TaskState(new JobVertexID(), numSubtaskStates); - for (int j = 0; j < numSubtaskStates; j++) { - - StreamTaskState[] streamTaskStates = new StreamTaskState[2]; - - for (int k = 0; k < streamTaskStates.length; k++) { - StreamTaskState state = new StreamTaskState(); - Tuple4 testState = new Tuple4<>(0, i, j, k); - if (j % 4 != 0) { - state.setFunctionState(new SerializedStateHandle(testState)); - } - testState = new Tuple4<>(1, i, j, k); - state.setOperatorState(new SerializedStateHandle<>(testState)); - - if ((0 == k) && (i % 3 != 0)) { - HashMap> testKeyedState = new HashMap<>(2); - for (int l = 0; l < 2; ++l) { - String name = "keyed-" + l; - KvStateSnapshot testKeyedSnapshot = - new MemValueState.Snapshot<>( - IntSerializer.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - IntSerializer.INSTANCE, - new ValueStateDescriptor<>(name, Integer.class, 0), - new byte[]{(byte) i, (byte) j}); - testKeyedState.put(name, testKeyedSnapshot); - } - state.setKvStates(testKeyedState); - } - streamTaskStates[k] = state; - } - - StreamTaskStateList streamTaskStateList = new StreamTaskStateList(streamTaskStates); - org.apache.flink.migration.util.SerializedValue< - org.apache.flink.migration.runtime.state.StateHandle> handle = - new org.apache.flink.migration.util.SerializedValue< - org.apache.flink.migration.runtime.state.StateHandle>(streamTaskStateList); - - taskState.putState(j, new org.apache.flink.migration.runtime.checkpoint.SubtaskState(handle, 0, 0)); - } - - taskStates.add(taskState); - } - - return taskStates; - } -} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/clusterframework/ResourceManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/clusterframework/ResourceManagerTest.java index 3ca0327c02199..6013e91e2cec5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/clusterframework/ResourceManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/clusterframework/ResourceManagerTest.java @@ -39,6 +39,7 @@ import org.apache.flink.runtime.highavailability.TestingHighAvailabilityServices; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.jobmaster.JobMasterRegistrationSuccess; import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService; import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; @@ -48,9 +49,11 @@ import org.apache.flink.runtime.registration.RegistrationResponse; import org.apache.flink.runtime.resourcemanager.JobLeaderIdService; import org.apache.flink.runtime.resourcemanager.ResourceManagerConfiguration; +import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.resourcemanager.StandaloneResourceManager; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; -import 
org.apache.flink.runtime.rpc.TestingSerialRpcService; +import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.taskexecutor.SlotReport; import org.apache.flink.runtime.taskexecutor.TaskExecutorGateway; import org.apache.flink.runtime.taskexecutor.TaskExecutorRegistrationSuccess; @@ -64,6 +67,8 @@ import org.junit.BeforeClass; import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + import scala.Option; import java.util.ArrayList; @@ -93,6 +98,8 @@ public class ResourceManagerTest extends TestLogger { private static Configuration config = new Configuration(); + private final Time timeout = Time.seconds(10L); + private TestingHighAvailabilityServices highAvailabilityServices; private TestingLeaderRetrievalService jobManagerLeaderRetrievalService; @@ -479,7 +486,7 @@ public void testHeartbeatTimeoutWithTaskExecutor() throws Exception { final ResourceID resourceManagerResourceID = ResourceID.generate(); final TaskExecutorGateway taskExecutorGateway = mock(TaskExecutorGateway.class); - final TestingSerialRpcService rpcService = new TestingSerialRpcService(); + final TestingRpcService rpcService = new TestingRpcService(); rpcService.registerGateway(taskManagerAddress, taskExecutorGateway); final ResourceManagerConfiguration resourceManagerConfiguration = new ResourceManagerConfiguration( @@ -519,18 +526,19 @@ public void testHeartbeatTimeoutWithTaskExecutor() throws Exception { resourceManager.start(); + final ResourceManagerGateway rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class); + final UUID rmLeaderSessionId = UUID.randomUUID(); rmLeaderElectionService.isLeader(rmLeaderSessionId); final SlotReport slotReport = new SlotReport(); // test registration response successful and it will trigger monitor heartbeat target, schedule heartbeat request at interval time - CompletableFuture successfulFuture = resourceManager.registerTaskExecutor( - rmLeaderSessionId, + CompletableFuture successfulFuture = rmGateway.registerTaskExecutor( taskManagerAddress, taskManagerResourceID, slotReport, - Time.milliseconds(0L)); - RegistrationResponse response = successfulFuture.get(5, TimeUnit.SECONDS); + timeout); + RegistrationResponse response = successfulFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); assertTrue(response instanceof TaskExecutorRegistrationSuccess); ArgumentCaptor heartbeatRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); @@ -557,7 +565,7 @@ public void testHeartbeatTimeoutWithTaskExecutor() throws Exception { // run the timeout runnable to simulate a heartbeat timeout timeoutRunnable.run(); - verify(taskExecutorGateway).disconnectResourceManager(any(TimeoutException.class)); + verify(taskExecutorGateway, Mockito.timeout(timeout.toMilliseconds())).disconnectResourceManager(any(TimeoutException.class)); } finally { rpcService.stopService(); @@ -569,13 +577,13 @@ public void testHeartbeatTimeoutWithJobManager() throws Exception { final String jobMasterAddress = "jm"; final ResourceID jmResourceId = new ResourceID(jobMasterAddress); final ResourceID rmResourceId = ResourceID.generate(); - final UUID rmLeaderId = UUID.randomUUID(); - final UUID jmLeaderId = UUID.randomUUID(); + final ResourceManagerId rmLeaderId = ResourceManagerId.generate(); + final JobMasterId jobMasterId = JobMasterId.generate(); final JobID jobId = new JobID(); final JobMasterGateway jobMasterGateway = mock(JobMasterGateway.class); - final TestingSerialRpcService rpcService = new TestingSerialRpcService(); + final 
TestingRpcService rpcService = new TestingRpcService(); rpcService.registerGateway(jobMasterAddress, jobMasterGateway); final ResourceManagerConfiguration resourceManagerConfiguration = new ResourceManagerConfiguration( @@ -583,7 +591,7 @@ public void testHeartbeatTimeoutWithJobManager() throws Exception { Time.seconds(5L)); final TestingLeaderElectionService rmLeaderElectionService = new TestingLeaderElectionService(); - final TestingLeaderRetrievalService jmLeaderRetrievalService = new TestingLeaderRetrievalService(jobMasterAddress, jmLeaderId); + final TestingLeaderRetrievalService jmLeaderRetrievalService = new TestingLeaderRetrievalService(jobMasterAddress, jobMasterId.toUUID()); final TestingHighAvailabilityServices highAvailabilityServices = new TestingHighAvailabilityServices(); highAvailabilityServices.setResourceManagerLeaderElectionService(rmLeaderElectionService); highAvailabilityServices.setJobMasterLeaderRetriever(jobId, jmLeaderRetrievalService); @@ -620,17 +628,18 @@ public void testHeartbeatTimeoutWithJobManager() throws Exception { resourceManager.start(); - rmLeaderElectionService.isLeader(rmLeaderId); + final ResourceManagerGateway rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class); + + rmLeaderElectionService.isLeader(rmLeaderId.toUUID()).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); // test registration response successful and it will trigger monitor heartbeat target, schedule heartbeat request at interval time - CompletableFuture successfulFuture = resourceManager.registerJobManager( - rmLeaderId, - jmLeaderId, + CompletableFuture successfulFuture = rmGateway.registerJobManager( + jobMasterId, jmResourceId, jobMasterAddress, jobId, - Time.milliseconds(0L)); - RegistrationResponse response = successfulFuture.get(5, TimeUnit.SECONDS); + timeout); + RegistrationResponse response = successfulFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); assertTrue(response instanceof JobMasterRegistrationSuccess); ArgumentCaptor heartbeatRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); @@ -657,7 +666,7 @@ public void testHeartbeatTimeoutWithJobManager() throws Exception { // run the timeout runnable to simulate a heartbeat timeout timeoutRunnable.run(); - verify(jobMasterGateway).disconnectResourceManager(eq(jmLeaderId), eq(rmLeaderId), any(TimeoutException.class)); + verify(jobMasterGateway, Mockito.timeout(timeout.toMilliseconds())).disconnectResourceManager(eq(rmLeaderId), any(TimeoutException.class)); } finally { rpcService.stopService(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/ConjunctFutureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/ConjunctFutureTest.java new file mode 100644 index 0000000000000..f92504ef84977 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/ConjunctFutureTest.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.concurrent; + +import org.apache.flink.runtime.concurrent.FutureUtils.ConjunctFuture; +import org.apache.flink.util.TestLogger; + +import org.hamcrest.collection.IsIterableContainingInAnyOrder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Tests for the {@link ConjunctFuture} and {@link FutureUtils.WaitingConjunctFuture}. + */ +@RunWith(Parameterized.class) +public class ConjunctFutureTest extends TestLogger { + + @Parameterized.Parameters + public static Collection parameters (){ + return Arrays.asList(new ConjunctFutureFactory(), new WaitingFutureFactory()); + } + + @Parameterized.Parameter + public FutureFactory futureFactory; + + @Test + public void testConjunctFutureFailsOnEmptyAndNull() throws Exception { + try { + futureFactory.createFuture(null); + fail(); + } catch (NullPointerException ignored) {} + + try { + futureFactory.createFuture(Arrays.asList( + new CompletableFuture<>(), + null, + new CompletableFuture<>())); + fail(); + } catch (NullPointerException ignored) {} + } + + @Test + public void testConjunctFutureCompletion() throws Exception { + // some futures that we combine + java.util.concurrent.CompletableFuture future1 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future2 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future3 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future4 = new java.util.concurrent.CompletableFuture<>(); + + // some future is initially completed + future2.complete(new Object()); + + // build the conjunct future + ConjunctFuture result = futureFactory.createFuture(Arrays.asList(future1, future2, future3, future4)); + + CompletableFuture resultMapped = result.thenAccept(value -> {}); + + assertEquals(4, result.getNumFuturesTotal()); + assertEquals(1, result.getNumFuturesCompleted()); + assertFalse(result.isDone()); + assertFalse(resultMapped.isDone()); + + // complete two more futures + future4.complete(new Object()); + assertEquals(2, result.getNumFuturesCompleted()); + assertFalse(result.isDone()); + assertFalse(resultMapped.isDone()); + + future1.complete(new Object()); + assertEquals(3, result.getNumFuturesCompleted()); + assertFalse(result.isDone()); + assertFalse(resultMapped.isDone()); + + // complete one future again + future1.complete(new Object()); + assertEquals(3, result.getNumFuturesCompleted()); + assertFalse(result.isDone()); + assertFalse(resultMapped.isDone()); + + // complete the final future + future3.complete(new Object()); + assertEquals(4, 
result.getNumFuturesCompleted()); + assertTrue(result.isDone()); + assertTrue(resultMapped.isDone()); + } + + @Test + public void testConjunctFutureFailureOnFirst() throws Exception { + + java.util.concurrent.CompletableFuture future1 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future2 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future3 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future4 = new java.util.concurrent.CompletableFuture<>(); + + // build the conjunct future + ConjunctFuture result = futureFactory.createFuture(Arrays.asList(future1, future2, future3, future4)); + + CompletableFuture resultMapped = result.thenAccept(value -> {}); + + assertEquals(4, result.getNumFuturesTotal()); + assertEquals(0, result.getNumFuturesCompleted()); + assertFalse(result.isDone()); + assertFalse(resultMapped.isDone()); + + future2.completeExceptionally(new IOException()); + + assertEquals(0, result.getNumFuturesCompleted()); + assertTrue(result.isDone()); + assertTrue(resultMapped.isDone()); + + try { + result.get(); + fail(); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IOException); + } + + try { + resultMapped.get(); + fail(); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IOException); + } + } + + @Test + public void testConjunctFutureFailureOnSuccessive() throws Exception { + + java.util.concurrent.CompletableFuture future1 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future2 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future3 = new java.util.concurrent.CompletableFuture<>(); + java.util.concurrent.CompletableFuture future4 = new java.util.concurrent.CompletableFuture<>(); + + // build the conjunct future + ConjunctFuture result = futureFactory.createFuture(Arrays.asList(future1, future2, future3, future4)); + assertEquals(4, result.getNumFuturesTotal()); + + java.util.concurrent.CompletableFuture resultMapped = result.thenAccept(value -> {}); + + future1.complete(new Object()); + future3.complete(new Object()); + future4.complete(new Object()); + + future2.completeExceptionally(new IOException()); + + assertEquals(3, result.getNumFuturesCompleted()); + assertTrue(result.isDone()); + assertTrue(resultMapped.isDone()); + + try { + result.get(); + fail(); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IOException); + } + + try { + resultMapped.get(); + fail(); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof IOException); + } + } + + /** + * Tests that the conjunct future returns upon completion the collection of all future values + */ + @Test + public void testConjunctFutureValue() throws ExecutionException, InterruptedException { + java.util.concurrent.CompletableFuture future1 = java.util.concurrent.CompletableFuture.completedFuture(1); + java.util.concurrent.CompletableFuture future2 = java.util.concurrent.CompletableFuture.completedFuture(2L); + java.util.concurrent.CompletableFuture future3 = new java.util.concurrent.CompletableFuture<>(); + + ConjunctFuture> result = FutureUtils.combineAll(Arrays.asList(future1, future2, future3)); + + assertFalse(result.isDone()); + + future3.complete(.1); + + assertTrue(result.isDone()); + + assertThat(result.get(), IsIterableContainingInAnyOrder.containsInAnyOrder(1, 2L, .1)); + } + + @Test + public void testConjunctOfNone() 
throws Exception { + final ConjunctFuture result = futureFactory.createFuture(Collections.>emptyList()); + + assertEquals(0, result.getNumFuturesTotal()); + assertEquals(0, result.getNumFuturesCompleted()); + assertTrue(result.isDone()); + } + + /** + * Factory to create {@link ConjunctFuture} for testing. + */ + private interface FutureFactory { + ConjunctFuture createFuture(Collection> futures); + } + + private static class ConjunctFutureFactory implements FutureFactory { + + @Override + public ConjunctFuture createFuture(Collection> futures) { + return FutureUtils.combineAll(futures); + } + } + + private static class WaitingFutureFactory implements FutureFactory { + + @Override + public ConjunctFuture createFuture(Collection> futures) { + return FutureUtils.waitForAll(futures); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/FutureUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/FutureUtilsTest.java index cc95e7ad38ca9..c624ef2ec8fb8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/FutureUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/FutureUtilsTest.java @@ -18,225 +18,208 @@ package org.apache.flink.runtime.concurrent; -import org.apache.flink.runtime.concurrent.FutureUtils.ConjunctFuture; - +import org.apache.flink.api.common.time.Time; +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.testingUtils.TestingUtils; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FlinkException; import org.apache.flink.util.TestLogger; -import org.hamcrest.collection.IsIterableContainingInAnyOrder; + import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.mockito.invocation.InvocationOnMock; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; - -import static org.junit.Assert.*; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** - * Tests for the utility methods in {@link FutureUtils} + * Tests for the utility methods in {@link FutureUtils}. 
*/ -@RunWith(Parameterized.class) -public class FutureUtilsTest extends TestLogger{ - - @Parameterized.Parameters - public static Collection parameters (){ - return Arrays.asList(new ConjunctFutureFactory(), new WaitingFutureFactory()); - } - - @Parameterized.Parameter - public FutureFactory futureFactory; - - @Test - public void testConjunctFutureFailsOnEmptyAndNull() throws Exception { - try { - futureFactory.createFuture(null); - fail(); - } catch (NullPointerException ignored) {} - - try { - futureFactory.createFuture(Arrays.asList( - new CompletableFuture<>(), - null, - new CompletableFuture<>())); - fail(); - } catch (NullPointerException ignored) {} - } +public class FutureUtilsTest extends TestLogger { + /** + * Tests that we can retry an operation. + */ @Test - public void testConjunctFutureCompletion() throws Exception { - // some futures that we combine - java.util.concurrent.CompletableFuture future1 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future2 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future3 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future4 = new java.util.concurrent.CompletableFuture<>(); - - // some future is initially completed - future2.complete(new Object()); - - // build the conjunct future - ConjunctFuture result = futureFactory.createFuture(Arrays.asList(future1, future2, future3, future4)); - - CompletableFuture resultMapped = result.thenAccept(value -> {}); - - assertEquals(4, result.getNumFuturesTotal()); - assertEquals(1, result.getNumFuturesCompleted()); - assertFalse(result.isDone()); - assertFalse(resultMapped.isDone()); - - // complete two more futures - future4.complete(new Object()); - assertEquals(2, result.getNumFuturesCompleted()); - assertFalse(result.isDone()); - assertFalse(resultMapped.isDone()); - - future1.complete(new Object()); - assertEquals(3, result.getNumFuturesCompleted()); - assertFalse(result.isDone()); - assertFalse(resultMapped.isDone()); - - // complete one future again - future1.complete(new Object()); - assertEquals(3, result.getNumFuturesCompleted()); - assertFalse(result.isDone()); - assertFalse(resultMapped.isDone()); - - // complete the final future - future3.complete(new Object()); - assertEquals(4, result.getNumFuturesCompleted()); - assertTrue(result.isDone()); - assertTrue(resultMapped.isDone()); + public void testRetrySuccess() throws Exception { + final int retries = 10; + final AtomicInteger atomicInteger = new AtomicInteger(0); + CompletableFuture retryFuture = FutureUtils.retry( + () -> + CompletableFuture.supplyAsync( + () -> { + if (atomicInteger.incrementAndGet() == retries) { + return true; + } else { + throw new FlinkFutureException("Test exception"); + } + }, + TestingUtils.defaultExecutor()), + retries, + TestingUtils.defaultExecutor()); + + assertTrue(retryFuture.get()); + assertTrue(retries == atomicInteger.get()); } - @Test - public void testConjunctFutureFailureOnFirst() throws Exception { - - java.util.concurrent.CompletableFuture future1 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future2 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future3 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future4 = new java.util.concurrent.CompletableFuture<>(); - - // build the conjunct future - ConjunctFuture result = 
futureFactory.createFuture(Arrays.asList(future1, future2, future3, future4)); - - CompletableFuture resultMapped = result.thenAccept(value -> {}); - - assertEquals(4, result.getNumFuturesTotal()); - assertEquals(0, result.getNumFuturesCompleted()); - assertFalse(result.isDone()); - assertFalse(resultMapped.isDone()); - - future2.completeExceptionally(new IOException()); - - assertEquals(0, result.getNumFuturesCompleted()); - assertTrue(result.isDone()); - assertTrue(resultMapped.isDone()); + /** + * Tests that a retry future is failed after all retries have been consumed. + */ + @Test(expected = FutureUtils.RetryException.class) + public void testRetryFailure() throws Throwable { + final int retries = 3; - try { - result.get(); - fail(); - } catch (ExecutionException e) { - assertTrue(e.getCause() instanceof IOException); - } + CompletableFuture retryFuture = FutureUtils.retry( + () -> FutureUtils.completedExceptionally(new FlinkException("Test exception")), + retries, + TestingUtils.defaultExecutor()); try { - resultMapped.get(); - fail(); - } catch (ExecutionException e) { - assertTrue(e.getCause() instanceof IOException); + retryFuture.get(); + } catch (ExecutionException ee) { + throw ExceptionUtils.stripExecutionException(ee); } } + /** + * Tests that we can cancel a retry future. + */ @Test - public void testConjunctFutureFailureOnSuccessive() throws Exception { - - java.util.concurrent.CompletableFuture future1 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future2 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future3 = new java.util.concurrent.CompletableFuture<>(); - java.util.concurrent.CompletableFuture future4 = new java.util.concurrent.CompletableFuture<>(); - - // build the conjunct future - ConjunctFuture result = futureFactory.createFuture(Arrays.asList(future1, future2, future3, future4)); - assertEquals(4, result.getNumFuturesTotal()); - - java.util.concurrent.CompletableFuture resultMapped = result.thenAccept(value -> {}); - - future1.complete(new Object()); - future3.complete(new Object()); - future4.complete(new Object()); - - future2.completeExceptionally(new IOException()); - - assertEquals(3, result.getNumFuturesCompleted()); - assertTrue(result.isDone()); - assertTrue(resultMapped.isDone()); - - try { - result.get(); - fail(); - } catch (ExecutionException e) { - assertTrue(e.getCause() instanceof IOException); - } - - try { - resultMapped.get(); - fail(); - } catch (ExecutionException e) { - assertTrue(e.getCause() instanceof IOException); + public void testRetryCancellation() throws Exception { + final int retries = 10; + final AtomicInteger atomicInteger = new AtomicInteger(0); + final OneShotLatch notificationLatch = new OneShotLatch(); + final OneShotLatch waitLatch = new OneShotLatch(); + final AtomicReference atomicThrowable = new AtomicReference<>(null); + + CompletableFuture retryFuture = FutureUtils.retry( + () -> + CompletableFuture.supplyAsync( + () -> { + if (atomicInteger.incrementAndGet() == 2) { + notificationLatch.trigger(); + try { + waitLatch.await(); + } catch (InterruptedException e) { + atomicThrowable.compareAndSet(null, e); + } + } + + throw new FlinkFutureException("Test exception"); + }, + TestingUtils.defaultExecutor()), + retries, + TestingUtils.defaultExecutor()); + + // await that we have failed once + notificationLatch.await(); + + assertFalse(retryFuture.isDone()); + + // cancel the retry future + retryFuture.cancel(false); + + // let the retry 
operation continue + waitLatch.trigger(); + + assertTrue(retryFuture.isCancelled()); + assertEquals(2, atomicInteger.get()); + + if (atomicThrowable.get() != null) { + throw new FlinkException("Exception occurred in the retry operation.", atomicThrowable.get()); + } } } /** - * Tests that the conjunct future returns upon completion the collection of all future values + * Tests that retry with delay fails after having exceeded all retries. */ - @Test - public void testConjunctFutureValue() throws ExecutionException, InterruptedException { - java.util.concurrent.CompletableFuture future1 = java.util.concurrent.CompletableFuture.completedFuture(1); - java.util.concurrent.CompletableFuture future2 = java.util.concurrent.CompletableFuture.completedFuture(2L); - java.util.concurrent.CompletableFuture future3 = new java.util.concurrent.CompletableFuture<>(); - - ConjunctFuture> result = FutureUtils.combineAll(Arrays.asList(future1, future2, future3)); - - assertFalse(result.isDone()); - - future3.complete(.1); - - assertTrue(result.isDone()); + @Test(expected = FutureUtils.RetryException.class) + public void testRetryWithDelayFailure() throws Throwable { + CompletableFuture retryFuture = FutureUtils.retryWithDelay( + () -> FutureUtils.completedExceptionally(new FlinkException("Test exception")), + 3, + Time.milliseconds(1L), + TestingUtils.defaultScheduledExecutor()); - assertThat(result.get(), IsIterableContainingInAnyOrder.containsInAnyOrder(1, 2L, .1)); + try { + retryFuture.get(TestingUtils.TIMEOUT().toMilliseconds(), TimeUnit.MILLISECONDS); + } catch (ExecutionException ee) { + throw ExceptionUtils.stripExecutionException(ee); + } } + /** + * Tests that the delay is respected between subsequent retries of a retry future with retry delay. + */ @Test - public void testConjunctOfNone() throws Exception { - final ConjunctFuture result = futureFactory.createFuture(Collections.>emptyList()); - - assertEquals(0, result.getNumFuturesTotal()); - assertEquals(0, result.getNumFuturesCompleted()); - assertTrue(result.isDone()); + public void testRetryWithDelay() throws Exception { + final int retries = 4; + final Time delay = Time.milliseconds(50L); + final AtomicInteger countDown = new AtomicInteger(retries); + + CompletableFuture retryFuture = FutureUtils.retryWithDelay( + () -> { + if (countDown.getAndDecrement() == 0) { + return CompletableFuture.completedFuture(true); + } else { + return FutureUtils.completedExceptionally(new FlinkException("Test exception.")); + } + }, + retries, + delay, + TestingUtils.defaultScheduledExecutor()); + + long start = System.currentTimeMillis(); + + Boolean result = retryFuture.get(); + + long completionTime = System.currentTimeMillis() - start; + + assertTrue(result); + assertTrue("The completion time should be at least retries times the delay between retries.", completionTime >= retries * delay.toMilliseconds()); } /** - * Factory to create {@link ConjunctFuture} for testing. + * Tests that all scheduled tasks are canceled if the retry future is being cancelled. 
*/ - private interface FutureFactory { - ConjunctFuture createFuture(Collection> futures); - } - - private static class ConjunctFutureFactory implements FutureFactory { - - @Override - public ConjunctFuture createFuture(Collection> futures) { - return FutureUtils.combineAll(futures); - } - } - - private static class WaitingFutureFactory implements FutureFactory { - - @Override - public ConjunctFuture createFuture(Collection> futures) { - return FutureUtils.waitForAll(futures); - } + @Test + public void testRetryWithDelayCancellation() { + ScheduledFuture scheduledFutureMock = mock(ScheduledFuture.class); + ScheduledExecutor scheduledExecutorMock = mock(ScheduledExecutor.class); + doReturn(scheduledFutureMock).when(scheduledExecutorMock).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); + doAnswer( + (InvocationOnMock invocation) -> { + invocation.getArgumentAt(0, Runnable.class).run(); + return null; + }).when(scheduledExecutorMock).execute(any(Runnable.class)); + + CompletableFuture retryFuture = FutureUtils.retryWithDelay( + () -> FutureUtils.completedExceptionally(new FlinkException("Test exception")), + 1, + TestingUtils.infiniteTime(), + scheduledExecutorMock); + + assertFalse(retryFuture.isDone()); + + verify(scheduledExecutorMock).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); + + retryFuture.cancel(false); + + assertTrue(retryFuture.isCancelled()); + verify(scheduledFutureMock).cancel(anyBoolean()); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java index 36c9cadeaec3e..9ed4851cad516 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.blob.BlobKey; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; @@ -30,7 +31,6 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.operators.BatchTask; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.SerializedValue; import org.junit.Test; @@ -73,7 +73,7 @@ public void testSerialization() { final SerializedValue serializedJobVertexInformation = new SerializedValue<>(new TaskInformation( vertexID, taskName, currentNumberOfSubtasks, numberOfKeyGroups, invokableClass.getName(), taskConfiguration)); final int targetSlotNumber = 47; - final TaskStateHandles taskStateHandles = new TaskStateHandles(); + final TaskStateSnapshot taskStateHandles = new TaskStateSnapshot(); final TaskDeploymentDescriptor orig = new TaskDeploymentDescriptor( serializedJobInformation, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/dispatcher/DispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/dispatcher/DispatcherTest.java index 38146848f2369..da76115006027 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/dispatcher/DispatcherTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/dispatcher/DispatcherTest.java @@ -22,28 +22,37 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.blob.BlobServer; -import org.apache.flink.runtime.blob.BlobService; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.highavailability.TestingHighAvailabilityServices; import org.apache.flink.runtime.highavailability.nonha.standalone.StandaloneHaServices; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobmanager.OnCompletionActions; +import org.apache.flink.runtime.jobmanager.SubmittedJobGraphStore; import org.apache.flink.runtime.jobmaster.JobManagerRunner; +import org.apache.flink.runtime.jobmaster.JobManagerServices; +import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.rpc.FatalErrorHandler; import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.runtime.rpc.RpcUtils; import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.util.TestingFatalErrorHandler; import org.apache.flink.util.TestLogger; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; import org.mockito.Mockito; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -53,6 +62,23 @@ */ public class DispatcherTest extends TestLogger { + private static RpcService rpcService; + private static final Time timeout = Time.seconds(10L); + + @BeforeClass + public static void setup() { + rpcService = new TestingRpcService(); + } + + @AfterClass + public static void teardown() { + if (rpcService != null) { + rpcService.stopService(); + + rpcService = null; + } + } + /** * Tests that we can submit a job to the Dispatcher which then spawns a * new JobManagerRunner. 
@@ -60,29 +86,30 @@ public class DispatcherTest extends TestLogger { @Test public void testJobSubmission() throws Exception { TestingFatalErrorHandler fatalErrorHandler = new TestingFatalErrorHandler(); - RpcService rpcService = new TestingRpcService(); - HighAvailabilityServices haServices = new StandaloneHaServices("localhost", "localhost"); + HighAvailabilityServices haServices = new StandaloneHaServices( + "localhost", + "localhost", + "localhost"); HeartbeatServices heartbeatServices = new HeartbeatServices(1000L, 10000L); JobManagerRunner jobManagerRunner = mock(JobManagerRunner.class); - final Time timeout = Time.seconds(5L); final JobGraph jobGraph = mock(JobGraph.class); final JobID jobId = new JobID(); when(jobGraph.getJobID()).thenReturn(jobId); - try { - final TestingDispatcher dispatcher = new TestingDispatcher( - rpcService, - Dispatcher.DISPATCHER_NAME, - new Configuration(), - haServices, - mock(BlobServer.class), - heartbeatServices, - mock(MetricRegistry.class), - fatalErrorHandler, - jobManagerRunner, - jobId); + final TestingDispatcher dispatcher = new TestingDispatcher( + rpcService, + Dispatcher.DISPATCHER_NAME, + new Configuration(), + haServices, + mock(BlobServer.class), + heartbeatServices, + mock(MetricRegistry.class), + fatalErrorHandler, + jobManagerRunner, + jobId); + try { dispatcher.start(); DispatcherGateway dispatcherGateway = dispatcher.getSelfGateway(DispatcherGateway.class); @@ -96,7 +123,60 @@ public void testJobSubmission() throws Exception { // check that no error has occurred fatalErrorHandler.rethrowError(); } finally { - rpcService.stopService(); + RpcUtils.terminateRpcEndpoint(dispatcher, timeout); + } + } + + /** + * Tests that the dispatcher takes part in the leader election. + */ + @Test + public void testLeaderElection() throws Exception { + TestingFatalErrorHandler fatalErrorHandler = new TestingFatalErrorHandler(); + TestingHighAvailabilityServices haServices = new TestingHighAvailabilityServices(); + + UUID expectedLeaderSessionId = UUID.randomUUID(); + CompletableFuture leaderSessionIdFuture = new CompletableFuture<>(); + SubmittedJobGraphStore mockSubmittedJobGraphStore = mock(SubmittedJobGraphStore.class); + TestingLeaderElectionService testingLeaderElectionService = new TestingLeaderElectionService() { + @Override + public void confirmLeaderSessionID(UUID leaderSessionId) { + super.confirmLeaderSessionID(leaderSessionId); + leaderSessionIdFuture.complete(leaderSessionId); + } + }; + + haServices.setSubmittedJobGraphStore(mockSubmittedJobGraphStore); + haServices.setDispatcherLeaderElectionService(testingLeaderElectionService); + HeartbeatServices heartbeatServices = new HeartbeatServices(1000L, 1000L); + final JobID jobId = new JobID(); + + final TestingDispatcher dispatcher = new TestingDispatcher( + rpcService, + Dispatcher.DISPATCHER_NAME, + new Configuration(), + haServices, + mock(BlobServer.class), + heartbeatServices, + mock(MetricRegistry.class), + fatalErrorHandler, + mock(JobManagerRunner.class), + jobId); + + try { + dispatcher.start(); + + assertFalse(leaderSessionIdFuture.isDone()); + + testingLeaderElectionService.isLeader(expectedLeaderSessionId); + + UUID actualLeaderSessionId = leaderSessionIdFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + assertEquals(expectedLeaderSessionId, actualLeaderSessionId); + + verify(mockSubmittedJobGraphStore, Mockito.timeout(timeout.toMilliseconds()).atLeast(1)).getJobIds(); + } finally { + RpcUtils.terminateRpcEndpoint(dispatcher, timeout); } } @@ -137,8 +217,8 @@ 
protected JobManagerRunner createJobManagerRunner( Configuration configuration, RpcService rpcService, HighAvailabilityServices highAvailabilityServices, - BlobService blobService, HeartbeatServices heartbeatServices, + JobManagerServices jobManagerServices, MetricRegistry metricRegistry, OnCompletionActions onCompleteActions, FatalErrorHandler fatalErrorHandler) throws Exception { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheManagerTest.java index 606d8c9a0491e..a4b48e80cdd1b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheManagerTest.java @@ -18,24 +18,22 @@ package org.apache.flink.runtime.execution.librarycache; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobClient; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.blob.BlobServer; -import org.apache.flink.runtime.blob.BlobService; import org.apache.flink.runtime.blob.VoidBlobStore; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.api.common.JobID; import org.apache.flink.util.OperatingSystem; +import org.apache.flink.util.TestLogger; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import static org.junit.Assert.*; -import static org.junit.Assume.assumeTrue; - import java.io.File; import java.io.IOException; import java.net.InetSocketAddress; @@ -45,7 +43,19 @@ import java.util.Collections; import java.util.List; -public class BlobLibraryCacheManagerTest { +import static org.apache.flink.runtime.blob.BlobCacheCleanupTest.checkFileCountForJob; +import static org.apache.flink.runtime.blob.BlobCacheCleanupTest.checkFilesExist; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeTrue; + +/** + * Tests for {@link BlobLibraryCacheManager}. 
+ */ +public class BlobLibraryCacheManagerTest extends TestLogger { @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); @@ -57,10 +67,13 @@ public class BlobLibraryCacheManagerTest { @Test public void testLibraryCacheManagerJobCleanup() throws IOException, InterruptedException { - JobID jid = new JobID(); - List keys = new ArrayList(); + JobID jobId1 = new JobID(); + JobID jobId2 = new JobID(); + List keys1 = new ArrayList<>(); + List keys2 = new ArrayList<>(); BlobServer server = null; - BlobLibraryCacheManager libraryCacheManager = null; + BlobCache cache = null; + BlobLibraryCacheManager libCache = null; final byte[] buf = new byte[128]; @@ -68,107 +81,231 @@ public void testLibraryCacheManagerJobCleanup() throws IOException, InterruptedE Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L); server = new BlobServer(config, new VoidBlobStore()); - InetSocketAddress blobSocketAddress = new InetSocketAddress(server.getPort()); - BlobClient bc = new BlobClient(blobSocketAddress, config); + InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.getPort()); + BlobClient bc = new BlobClient(serverAddress, config); + cache = new BlobCache(serverAddress, config, new VoidBlobStore()); - keys.add(bc.put(buf)); + keys1.add(bc.put(jobId1, buf)); buf[0] += 1; - keys.add(bc.put(buf)); + keys1.add(bc.put(jobId1, buf)); + keys2.add(bc.put(jobId2, buf)); bc.close(); - long cleanupInterval = 1000l; - libraryCacheManager = new BlobLibraryCacheManager(server, cleanupInterval); - libraryCacheManager.registerJob(jid, keys, Collections.emptyList()); - - assertEquals(2, checkFilesExist(keys, server, true)); - assertEquals(2, libraryCacheManager.getNumberOfCachedLibraries()); - assertEquals(1, libraryCacheManager.getNumberOfReferenceHolders(jid)); - - libraryCacheManager.unregisterJob(jid); + libCache = new BlobLibraryCacheManager(cache); + cache.registerJob(jobId1); + cache.registerJob(jobId2); + + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId1)); + checkFileCountForJob(2, jobId1, server); + checkFileCountForJob(0, jobId1, cache); + checkFileCountForJob(1, jobId2, server); + checkFileCountForJob(0, jobId2, cache); + + libCache.registerJob(jobId1, keys1, Collections.emptyList()); + ClassLoader classLoader1 = libCache.getClassLoader(jobId1); + + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId1)); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId2)); + assertEquals(2, checkFilesExist(jobId1, keys1, cache, true)); + checkFileCountForJob(2, jobId1, server); + checkFileCountForJob(2, jobId1, cache); + assertEquals(0, checkFilesExist(jobId2, keys2, cache, false)); + checkFileCountForJob(1, jobId2, server); + checkFileCountForJob(0, jobId2, cache); + + libCache.registerJob(jobId2, keys2, Collections.emptyList()); + ClassLoader classLoader2 = libCache.getClassLoader(jobId2); + assertNotEquals(classLoader1, classLoader2); - // because we cannot guarantee that there are not thread races in the build system, we - // loop for a certain while until the references disappear - { - long deadline = System.currentTimeMillis() + 30000; - do { - Thread.sleep(500); - } - while (libraryCacheManager.getNumberOfCachedLibraries() > 0 && - System.currentTimeMillis() < deadline); + try { + 
libCache.registerJob(jobId2, keys1, Collections.emptyList()); + fail("Should fail with an IllegalStateException"); + } + catch (IllegalStateException e) { + // that's what we want } - - // this fails if we exited via a timeout - assertEquals(0, libraryCacheManager.getNumberOfCachedLibraries()); - assertEquals(0, libraryCacheManager.getNumberOfReferenceHolders(jid)); - - // the blob cache should no longer contain the files - assertEquals(0, checkFilesExist(keys, server, false)); try { - server.getURL(keys.get(0)); - fail("name-addressable BLOB should have been deleted"); - } catch (IOException e) { - // expected + libCache.registerJob( + jobId2, keys2, + Collections.singletonList(new URL("file:///tmp/does-not-exist"))); + fail("Should fail with an IllegalStateException"); } - try { - server.getURL(keys.get(1)); - fail("name-addressable BLOB should have been deleted"); - } catch (IOException e) { - // expected + catch (IllegalStateException e) { + // that's what we want } - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + + assertEquals(2, libCache.getNumberOfManagedJobs()); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId1)); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId2)); + assertEquals(2, checkFilesExist(jobId1, keys1, cache, true)); + checkFileCountForJob(2, jobId1, server); + checkFileCountForJob(2, jobId1, cache); + assertEquals(1, checkFilesExist(jobId2, keys2, cache, true)); + checkFileCountForJob(1, jobId2, server); + checkFileCountForJob(1, jobId2, cache); + + libCache.unregisterJob(jobId1); + + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId1)); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId2)); + assertEquals(2, checkFilesExist(jobId1, keys1, cache, true)); + checkFileCountForJob(2, jobId1, server); + checkFileCountForJob(2, jobId1, cache); + assertEquals(1, checkFilesExist(jobId2, keys2, cache, true)); + checkFileCountForJob(1, jobId2, server); + checkFileCountForJob(1, jobId2, cache); + + libCache.unregisterJob(jobId2); + + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId1)); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId2)); + assertEquals(2, checkFilesExist(jobId1, keys1, cache, true)); + checkFileCountForJob(2, jobId1, server); + checkFileCountForJob(2, jobId1, cache); + assertEquals(1, checkFilesExist(jobId2, keys2, cache, true)); + checkFileCountForJob(1, jobId2, server); + checkFileCountForJob(1, jobId2, cache); + + // only BlobCache#releaseJob() calls clean up files (tested in BlobCacheCleanupTest etc. } finally { - if (server != null) { - server.close(); + if (libCache != null) { + libCache.shutdown(); } - if (libraryCacheManager != null) { - try { - libraryCacheManager.shutdown(); - } - catch (IOException e) { - e.printStackTrace(); - } + // should have been closed by the libraryCacheManager, but just in case + if (cache != null) { + cache.close(); + } + + if (server != null) { + server.close(); } } } /** - * Checks how many of the files given by blob keys are accessible. 
- * - * @param keys - * blob keys to check - * @param blobService - * BLOB store to use - * @param doThrow - * whether exceptions should be ignored (false), or throws (true) - * - * @return number of files we were able to retrieve via {@link BlobService#getURL(BlobKey)} + * Tests that the {@link BlobLibraryCacheManager} cleans up after calling {@link + * BlobLibraryCacheManager#unregisterTask(JobID, ExecutionAttemptID)}. */ - private static int checkFilesExist( - List keys, BlobService blobService, boolean doThrow) - throws IOException { - int numFiles = 0; + @Test + public void testLibraryCacheManagerTaskCleanup() throws IOException, InterruptedException { + + JobID jobId = new JobID(); + ExecutionAttemptID attempt1 = new ExecutionAttemptID(); + ExecutionAttemptID attempt2 = new ExecutionAttemptID(); + List keys = new ArrayList<>(); + BlobServer server = null; + BlobCache cache = null; + BlobLibraryCacheManager libCache = null; + + final byte[] buf = new byte[128]; + + try { + Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + temporaryFolder.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L); + + server = new BlobServer(config, new VoidBlobStore()); + InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.getPort()); + BlobClient bc = new BlobClient(serverAddress, config); + cache = new BlobCache(serverAddress, config, new VoidBlobStore()); + + keys.add(bc.put(jobId, buf)); + buf[0] += 1; + keys.add(bc.put(jobId, buf)); + + bc.close(); + + libCache = new BlobLibraryCacheManager(cache); + cache.registerJob(jobId); + + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(0, jobId, cache); + + libCache.registerTask(jobId, attempt1, keys, Collections.emptyList()); + ClassLoader classLoader1 = libCache.getClassLoader(jobId); + + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + libCache.registerTask(jobId, attempt2, keys, Collections.emptyList()); + ClassLoader classLoader2 = libCache.getClassLoader(jobId); + assertEquals(classLoader1, classLoader2); - for (BlobKey key : keys) { try { - blobService.getURL(key); - ++numFiles; - } catch (IOException e) { - if (doThrow) { - throw e; - } + libCache.registerTask( + jobId, new ExecutionAttemptID(), Collections.emptyList(), + Collections.emptyList()); + fail("Should fail with an IllegalStateException"); + } + catch (IllegalStateException e) { + // that's what we want } + + try { + libCache.registerTask( + jobId, new ExecutionAttemptID(), keys, + Collections.singletonList(new URL("file:///tmp/does-not-exist"))); + fail("Should fail with an IllegalStateException"); + } + catch (IllegalStateException e) { + // that's what we want + } + + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(2, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + libCache.unregisterTask(jobId, attempt1); + + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(2, checkFilesExist(jobId, keys, 
cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + libCache.unregisterTask(jobId, attempt2); + + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + // only BlobCache#releaseJob() calls clean up files (tested in BlobCacheCleanupTest etc. } + finally { + if (libCache != null) { + libCache.shutdown(); + } - return numFiles; + // should have been closed by the libraryCacheManager, but just in case + if (cache != null) { + cache.close(); + } + + if (server != null) { + server.close(); + } + } } /** @@ -176,14 +313,14 @@ private static int checkFilesExist( * BlobLibraryCacheManager#unregisterTask(JobID, ExecutionAttemptID)}. */ @Test - public void testLibraryCacheManagerTaskCleanup() throws IOException, InterruptedException { + public void testLibraryCacheManagerMixedJobTaskCleanup() throws IOException, InterruptedException { - JobID jid = new JobID(); - ExecutionAttemptID executionId1 = new ExecutionAttemptID(); - ExecutionAttemptID executionId2 = new ExecutionAttemptID(); - List keys = new ArrayList(); + JobID jobId = new JobID(); + ExecutionAttemptID attempt1 = new ExecutionAttemptID(); + List keys = new ArrayList<>(); BlobServer server = null; - BlobLibraryCacheManager libraryCacheManager = null; + BlobCache cache = null; + BlobLibraryCacheManager libCache = null; final byte[] buf = new byte[128]; @@ -191,63 +328,96 @@ public void testLibraryCacheManagerTaskCleanup() throws IOException, Interrupted Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L); server = new BlobServer(config, new VoidBlobStore()); - InetSocketAddress blobSocketAddress = new InetSocketAddress(server.getPort()); - BlobClient bc = new BlobClient(blobSocketAddress, config); + InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.getPort()); + BlobClient bc = new BlobClient(serverAddress, config); + cache = new BlobCache(serverAddress, config, new VoidBlobStore()); - keys.add(bc.put(buf)); + keys.add(bc.put(jobId, buf)); buf[0] += 1; - keys.add(bc.put(buf)); + keys.add(bc.put(jobId, buf)); - long cleanupInterval = 1000l; - libraryCacheManager = new BlobLibraryCacheManager(server, cleanupInterval); - libraryCacheManager.registerTask(jid, executionId1, keys, Collections.emptyList()); - libraryCacheManager.registerTask(jid, executionId2, keys, Collections.emptyList()); + bc.close(); - assertEquals(2, checkFilesExist(keys, server, true)); - assertEquals(2, libraryCacheManager.getNumberOfCachedLibraries()); - assertEquals(2, libraryCacheManager.getNumberOfReferenceHolders(jid)); + libCache = new BlobLibraryCacheManager(cache); + cache.registerJob(jobId); - libraryCacheManager.unregisterTask(jid, executionId1); + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(0, jobId, cache); - assertEquals(2, checkFilesExist(keys, server, true)); - assertEquals(2, libraryCacheManager.getNumberOfCachedLibraries()); - assertEquals(1, libraryCacheManager.getNumberOfReferenceHolders(jid)); + libCache.registerJob(jobId, keys, Collections.emptyList()); + ClassLoader classLoader1 = 
libCache.getClassLoader(jobId); - libraryCacheManager.unregisterTask(jid, executionId2); + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); - // because we cannot guarantee that there are not thread races in the build system, we - // loop for a certain while until the references disappear - { - long deadline = System.currentTimeMillis() + 30000; - do { - Thread.sleep(100); - } - while (libraryCacheManager.getNumberOfCachedLibraries() > 0 && - System.currentTimeMillis() < deadline); + libCache.registerTask(jobId, attempt1, keys, Collections.emptyList()); + ClassLoader classLoader2 = libCache.getClassLoader(jobId); + assertEquals(classLoader1, classLoader2); + + try { + libCache.registerTask( + jobId, new ExecutionAttemptID(), Collections.emptyList(), + Collections.emptyList()); + fail("Should fail with an IllegalStateException"); + } + catch (IllegalStateException e) { + // that's what we want } - // this fails if we exited via a timeout - assertEquals(0, libraryCacheManager.getNumberOfCachedLibraries()); - assertEquals(0, libraryCacheManager.getNumberOfReferenceHolders(jid)); + try { + libCache.registerTask( + jobId, new ExecutionAttemptID(), keys, + Collections.singletonList(new URL("file:///tmp/does-not-exist"))); + fail("Should fail with an IllegalStateException"); + } + catch (IllegalStateException e) { + // that's what we want + } - // the blob cache should no longer contain the files - assertEquals(0, checkFilesExist(keys, server, false)); + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(2, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); - bc.close(); - } finally { - if (server != null) { - server.close(); + libCache.unregisterJob(jobId); + + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + libCache.unregisterTask(jobId, attempt1); + + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(2, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(2, jobId, cache); + + // only BlobCache#releaseJob() calls clean up files (tested in BlobCacheCleanupTest etc. + } + finally { + if (libCache != null) { + libCache.shutdown(); } - if (libraryCacheManager != null) { - try { - libraryCacheManager.shutdown(); - } - catch (IOException e) { - e.printStackTrace(); - } + // should have been closed by the libraryCacheManager, but just in case + if (cache != null) { + cache.close(); + } + + if (server != null) { + server.close(); } } } @@ -256,14 +426,17 @@ public void testLibraryCacheManagerTaskCleanup() throws IOException, Interrupted public void testRegisterAndDownload() throws IOException { assumeTrue(!OperatingSystem.isWindows()); //setWritable doesn't work on Windows. 
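The library-cache tests above exercise the reworked, job-scoped BLOB API: BLOBs are uploaded per `JobID`, the `BlobCache` ref-counts jobs via `registerJob`/`releaseJob`, and `BlobLibraryCacheManager` is now constructed from the cache alone, without a cleanup interval. The following is a minimal sketch of that usage pattern, assembled from the calls appearing in these tests (the wrapper class and the `storageDir` parameter are illustrative only):

```java
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.BlobServerOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.blob.BlobCache;
import org.apache.flink.runtime.blob.BlobClient;
import org.apache.flink.runtime.blob.BlobKey;
import org.apache.flink.runtime.blob.BlobServer;
import org.apache.flink.runtime.blob.VoidBlobStore;
import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager;

import java.net.InetSocketAddress;
import java.util.Collections;

class JobScopedBlobSketch {

    static void run(String storageDir) throws Exception {
        Configuration config = new Configuration();
        config.setString(BlobServerOptions.STORAGE_DIRECTORY, storageDir);
        config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L);

        BlobServer server = new BlobServer(config, new VoidBlobStore());
        InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.getPort());
        BlobCache cache = new BlobCache(serverAddress, config, new VoidBlobStore());

        JobID jobId = new JobID();
        BlobKey key;
        try (BlobClient client = new BlobClient(serverAddress, config)) {
            // BLOBs are now stored under a job id
            key = client.put(jobId, new byte[]{1, 2, 3});
        }

        // no cleanup-interval argument any more; clean-up is driven by the cache
        BlobLibraryCacheManager libCache = new BlobLibraryCacheManager(cache);
        cache.registerJob(jobId);
        libCache.registerJob(jobId, Collections.singletonList(key), Collections.emptyList());
        ClassLoader loader = libCache.getClassLoader(jobId); // class loader is obtained per job

        libCache.unregisterJob(jobId); // releases the class-loader reference
        cache.releaseJob(jobId);       // file clean-up happens on the cache side

        libCache.shutdown();
        cache.close();
        server.close();
    }
}
```

Actual file deletion is covered separately (see the `BlobCacheCleanupTest` references above); the manager itself only tracks class-loader references.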
+ JobID jobId = new JobID(); BlobServer server = null; BlobCache cache = null; + BlobLibraryCacheManager libCache = null; File cacheDir = null; try { // create the blob transfer services Configuration config = new Configuration(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1_000_000L); server = new BlobServer(config, new VoidBlobStore()); InetSocketAddress serverAddress = new InetSocketAddress("localhost", server.getPort()); @@ -271,56 +444,85 @@ public void testRegisterAndDownload() throws IOException { // upload some meaningless data to the server BlobClient uploader = new BlobClient(serverAddress, config); - BlobKey dataKey1 = uploader.put(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}); - BlobKey dataKey2 = uploader.put(new byte[]{11, 12, 13, 14, 15, 16, 17, 18}); + BlobKey dataKey1 = uploader.put(jobId, new byte[]{1, 2, 3, 4, 5, 6, 7, 8}); + BlobKey dataKey2 = uploader.put(jobId, new byte[]{11, 12, 13, 14, 15, 16, 17, 18}); uploader.close(); - BlobLibraryCacheManager libCache = new BlobLibraryCacheManager(cache, 1000000000L); - - assertEquals(0, libCache.getNumberOfCachedLibraries()); + libCache = new BlobLibraryCacheManager(cache); + assertEquals(0, libCache.getNumberOfManagedJobs()); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(0, jobId, cache); // first try to access a non-existing entry + assertEquals(0, libCache.getNumberOfReferenceHolders(new JobID())); try { libCache.getClassLoader(new JobID()); fail("Should fail with an IllegalStateException"); } catch (IllegalStateException e) { - // that#s what we want + // that's what we want } - // now register some BLOBs as libraries + // register some BLOBs as libraries { - JobID jid = new JobID(); - ExecutionAttemptID executionId = new ExecutionAttemptID(); Collection keys = Collections.singleton(dataKey1); - libCache.registerTask(jid, executionId, keys, Collections.emptyList()); - assertEquals(1, libCache.getNumberOfReferenceHolders(jid)); - assertEquals(1, libCache.getNumberOfCachedLibraries()); - assertNotNull(libCache.getClassLoader(jid)); - - // un-register them again - libCache.unregisterTask(jid, executionId); + cache.registerJob(jobId); + ExecutionAttemptID executionId = new ExecutionAttemptID(); + libCache.registerTask(jobId, executionId, keys, Collections.emptyList()); + ClassLoader classLoader1 = libCache.getClassLoader(jobId); + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(1, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(1, jobId, cache); + assertNotNull(libCache.getClassLoader(jobId)); + + libCache.registerJob(jobId, keys, Collections.emptyList()); + ClassLoader classLoader2 = libCache.getClassLoader(jobId); + assertEquals(classLoader1, classLoader2); + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(2, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(1, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(1, jobId, cache); + assertNotNull(libCache.getClassLoader(jobId)); + + // un-register the job + libCache.unregisterJob(jobId); + // still one task + assertEquals(1, libCache.getNumberOfManagedJobs()); + assertEquals(1, libCache.getNumberOfReferenceHolders(jobId)); + assertEquals(1, checkFilesExist(jobId, keys, cache, true)); + checkFileCountForJob(2, jobId, server); + 
checkFileCountForJob(1, jobId, cache); + + // unregister the task registration + libCache.unregisterTask(jobId, executionId); + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId)); + // changing the libCache registration does not influence the BLOB stores... + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(1, jobId, cache); // Don't fail if called again - libCache.unregisterTask(jid, executionId); + libCache.unregisterJob(jobId); + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId)); - assertEquals(0, libCache.getNumberOfReferenceHolders(jid)); + libCache.unregisterTask(jobId, executionId); + assertEquals(0, libCache.getNumberOfManagedJobs()); + assertEquals(0, libCache.getNumberOfReferenceHolders(jobId)); - // library is still cached (but not associated with job any more) - assertEquals(1, libCache.getNumberOfCachedLibraries()); + cache.releaseJob(jobId); - // should not be able to access the classloader any more - try { - libCache.getClassLoader(jid); - fail("Should fail with an IllegalStateException"); - } - catch (IllegalStateException e) { - // that#s what we want - } + // library is still cached (but not associated with job any more) + checkFileCountForJob(2, jobId, server); + checkFileCountForJob(1, jobId, cache); } - cacheDir = new File(cache.getStorageDir(), "cache"); + // see BlobUtils for the directory layout + cacheDir = cache.getStorageLocation(jobId, new BlobKey()).getParentFile(); assertTrue(cacheDir.exists()); // make sure no further blobs can be downloaded by removing the write @@ -329,12 +531,14 @@ public void testRegisterAndDownload() throws IOException { // since we cannot download this library any more, this call should fail try { - libCache.registerTask(new JobID(), new ExecutionAttemptID(), Collections.singleton(dataKey2), - Collections.emptyList()); + cache.registerJob(jobId); + libCache.registerTask(jobId, new ExecutionAttemptID(), Collections.singleton(dataKey2), + Collections.emptyList()); fail("This should fail with an IOException"); } catch (IOException e) { // splendid! + cache.releaseJob(jobId); } } finally { if (cacheDir != null) { @@ -345,6 +549,9 @@ public void testRegisterAndDownload() throws IOException { if (cache != null) { cache.close(); } + if (libCache != null) { + libCache.shutdown(); + } if (server != null) { server.close(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheRecoveryITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheRecoveryITCase.java index e5efd19ba59f7..e52310e6b361d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheRecoveryITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/execution/librarycache/BlobLibraryCacheRecoveryITCase.java @@ -32,6 +32,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; import org.apache.flink.util.TestLogger; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -49,6 +50,9 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +/** + * Integration test for {@link BlobLibraryCacheManager}. 
+ */ public class BlobLibraryCacheRecoveryITCase extends TestLogger { @Rule @@ -65,7 +69,6 @@ public void testRecoveryRegisterAndDownload() throws Exception { InetSocketAddress[] serverAddress = new InetSocketAddress[2]; BlobLibraryCacheManager[] libServer = new BlobLibraryCacheManager[2]; BlobCache cache = null; - BlobLibraryCacheManager libCache = null; BlobStoreService blobStoreService = null; Configuration config = new Configuration(); @@ -75,6 +78,7 @@ public void testRecoveryRegisterAndDownload() throws Exception { temporaryFolder.newFolder().getAbsolutePath()); config.setString(HighAvailabilityOptions.HA_STORAGE_PATH, temporaryFolder.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 3_600L); try { blobStoreService = BlobUtils.createBlobStoreFromConfig(config); @@ -82,7 +86,7 @@ public void testRecoveryRegisterAndDownload() throws Exception { for (int i = 0; i < server.length; i++) { server[i] = new BlobServer(config, blobStoreService); serverAddress[i] = new InetSocketAddress("localhost", server[i].getPort()); - libServer[i] = new BlobLibraryCacheManager(server[i], 3600 * 1000); + libServer[i] = new BlobLibraryCacheManager(server[i]); } // Random data @@ -91,23 +95,23 @@ public void testRecoveryRegisterAndDownload() throws Exception { List keys = new ArrayList<>(2); + JobID jobId = new JobID(); + // Upload some data (libraries) try (BlobClient client = new BlobClient(serverAddress[0], config)) { - keys.add(client.put(expected)); // Request 1 - keys.add(client.put(expected, 32, 256)); // Request 2 + keys.add(client.put(jobId, expected)); // Request 1 + keys.add(client.put(jobId, expected, 32, 256)); // Request 2 } // The cache cache = new BlobCache(serverAddress[0], config, blobStoreService); - libCache = new BlobLibraryCacheManager(cache, 3600 * 1000); // Register uploaded libraries - JobID jobId = new JobID(); ExecutionAttemptID executionId = new ExecutionAttemptID(); libServer[0].registerTask(jobId, executionId, keys, Collections.emptyList()); // Verify key 1 - File f = new File(cache.getURL(keys.get(0)).toURI()); + File f = cache.getFile(jobId, keys.get(0)); assertEquals(expected.length, f.length()); try (FileInputStream fis = new FileInputStream(f)) { @@ -120,13 +124,11 @@ public void testRecoveryRegisterAndDownload() throws Exception { // Shutdown cache and start with other server cache.close(); - libCache.shutdown(); cache = new BlobCache(serverAddress[1], config, blobStoreService); - libCache = new BlobLibraryCacheManager(cache, 3600 * 1000); // Verify key 1 - f = new File(cache.getURL(keys.get(0)).toURI()); + f = cache.getFile(jobId, keys.get(0)); assertEquals(expected.length, f.length()); try (FileInputStream fis = new FileInputStream(f)) { @@ -138,7 +140,7 @@ public void testRecoveryRegisterAndDownload() throws Exception { } // Verify key 2 - f = new File(cache.getURL(keys.get(1)).toURI()); + f = cache.getFile(jobId, keys.get(1)); assertEquals(256, f.length()); try (FileInputStream fis = new FileInputStream(f)) { @@ -151,8 +153,8 @@ public void testRecoveryRegisterAndDownload() throws Exception { // Remove blobs again try (BlobClient client = new BlobClient(serverAddress[1], config)) { - client.delete(keys.get(0)); - client.delete(keys.get(1)); + client.delete(jobId, keys.get(0)); + client.delete(jobId, keys.get(1)); } // Verify everything is clean below recoveryDir/ @@ -164,6 +166,11 @@ public void testRecoveryRegisterAndDownload() throws Exception { assertEquals("Unclean state backend: " + Arrays.toString(recoveryFiles), 0, 
recoveryFiles.length); } finally { + for (BlobLibraryCacheManager s : libServer) { + if (s != null) { + s.shutdown(); + } + } for (BlobServer s : server) { if (s != null) { s.close(); @@ -174,10 +181,6 @@ public void testRecoveryRegisterAndDownload() throws Exception { cache.close(); } - if (libCache != null) { - libCache.shutdown(); - } - if (blobStoreService != null) { blobStoreService.closeAndCleanupAllData(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java index 0eed90d271bcc..c9b7a40a78b0a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.clusterframework.types.ResourceProfile; @@ -38,7 +39,6 @@ import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot; import org.apache.flink.runtime.jobmanager.slots.SlotOwner; import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.testtasks.NoOpInvokable; @@ -51,8 +51,10 @@ import java.util.Iterator; import java.util.concurrent.TimeUnit; -import static org.mockito.Mockito.*; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; /** * Tests that the execution vertex handles locality preferences well. 
@@ -169,7 +171,7 @@ public void testLocalityBasedOnState() throws Exception { // target state ExecutionVertex target = graph.getAllVertices().get(targetVertexId).getTaskVertices()[i]; - target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateHandles.class)); + target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateSnapshot.class)); } // validate that the target vertices have the state's location as the location preference diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/filecache/FileCacheDeleteValidationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/filecache/FileCacheDeleteValidationTest.java index 89ab975dc8903..0782e207627e6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/filecache/FileCacheDeleteValidationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/filecache/FileCacheDeleteValidationTest.java @@ -22,8 +22,9 @@ import org.apache.flink.api.common.cache.DistributedCache.DistributedCacheEntry; import org.apache.flink.core.fs.Path; -import com.google.common.base.Charsets; -import com.google.common.io.Files; +import org.apache.flink.shaded.guava18.com.google.common.base.Charsets; +import org.apache.flink.shaded.guava18.com.google.common.io.Files; + import org.junit.After; import org.junit.Before; import org.junit.Rule; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/TestingHighAvailabilityServices.java b/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/TestingHighAvailabilityServices.java index 0a7e9c848b19b..dba7bef5fa264 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/TestingHighAvailabilityServices.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/TestingHighAvailabilityServices.java @@ -38,12 +38,16 @@ public class TestingHighAvailabilityServices implements HighAvailabilityServices private volatile LeaderRetrievalService resourceManagerLeaderRetriever; + private volatile LeaderRetrievalService dispatcherLeaderRetriever; + private ConcurrentHashMap jobMasterLeaderRetrievers = new ConcurrentHashMap<>(); private ConcurrentHashMap jobManagerLeaderElectionServices = new ConcurrentHashMap<>(); private volatile LeaderElectionService resourceManagerLeaderElectionService; + private volatile LeaderElectionService dispatcherLeaderElectionService; + private volatile CheckpointRecoveryFactory checkpointRecoveryFactory; private volatile SubmittedJobGraphStore submittedJobGraphStore; @@ -56,6 +60,10 @@ public void setResourceManagerLeaderRetriever(LeaderRetrievalService resourceMan this.resourceManagerLeaderRetriever = resourceManagerLeaderRetriever; } + public void setDispatcherLeaderRetriever(LeaderRetrievalService dispatcherLeaderRetriever) { + this.dispatcherLeaderRetriever = dispatcherLeaderRetriever; + } + public void setJobMasterLeaderRetriever(JobID jobID, LeaderRetrievalService jobMasterLeaderRetriever) { this.jobMasterLeaderRetrievers.put(jobID, jobMasterLeaderRetriever); } @@ -68,6 +76,10 @@ public void setResourceManagerLeaderElectionService(LeaderElectionService leader this.resourceManagerLeaderElectionService = leaderElectionService; } + public void setDispatcherLeaderElectionService(LeaderElectionService leaderElectionService) { + this.dispatcherLeaderElectionService = leaderElectionService; + } + public void setCheckpointRecoveryFactory(CheckpointRecoveryFactory checkpointRecoveryFactory) { this.checkpointRecoveryFactory = checkpointRecoveryFactory; } @@ -90,6 +102,16 
@@ public LeaderRetrievalService getResourceManagerLeaderRetriever() { } } + @Override + public LeaderRetrievalService getDispatcherLeaderRetriever() { + LeaderRetrievalService service = this.dispatcherLeaderRetriever; + if (service != null) { + return service; + } else { + throw new IllegalStateException("ResourceManagerLeaderRetriever has not been set"); + } + } + @Override public LeaderRetrievalService getJobManagerLeaderRetriever(JobID jobID) { LeaderRetrievalService service = this.jobMasterLeaderRetrievers.get(jobID); @@ -116,6 +138,17 @@ public LeaderElectionService getResourceManagerLeaderElectionService() { } } + @Override + public LeaderElectionService getDispatcherLeaderElectionService() { + LeaderElectionService service = dispatcherLeaderElectionService; + + if (service != null) { + return service; + } else { + throw new IllegalStateException("DispatcherLeaderElectionService has not been set"); + } + } + @Override public LeaderElectionService getJobManagerLeaderElectionService(JobID jobID) { LeaderElectionService service = this.jobManagerLeaderElectionServices.get(jobID); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/TestingManualHighAvailabilityServices.java b/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/TestingManualHighAvailabilityServices.java index 0735d17a31bf0..1f319ebfa6c61 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/TestingManualHighAvailabilityServices.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/TestingManualHighAvailabilityServices.java @@ -45,9 +45,12 @@ public class TestingManualHighAvailabilityServices implements HighAvailabilitySe private final ManualLeaderService resourceManagerLeaderService; + private final ManualLeaderService dispatcherLeaderService; + public TestingManualHighAvailabilityServices() { jobManagerLeaderServices = new HashMap<>(4); resourceManagerLeaderService = new ManualLeaderService(); + dispatcherLeaderService = new ManualLeaderService(); } @Override @@ -55,6 +58,11 @@ public LeaderRetrievalService getResourceManagerLeaderRetriever() { return resourceManagerLeaderService.createLeaderRetrievalService(); } + @Override + public LeaderRetrievalService getDispatcherLeaderRetriever() { + return dispatcherLeaderService.createLeaderRetrievalService(); + } + @Override public LeaderRetrievalService getJobManagerLeaderRetriever(JobID jobID) { ManualLeaderService leaderService = getOrCreateJobManagerLeaderService(jobID); @@ -72,6 +80,11 @@ public LeaderElectionService getResourceManagerLeaderElectionService() { return resourceManagerLeaderService.createLeaderElectionService(); } + @Override + public LeaderElectionService getDispatcherLeaderElectionService() { + return dispatcherLeaderService.createLeaderElectionService(); + } + @Override public LeaderElectionService getJobManagerLeaderElectionService(JobID jobID) { ManualLeaderService leaderService = getOrCreateJobManagerLeaderService(jobID); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/nonha/standalone/StandaloneHaServicesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/nonha/standalone/StandaloneHaServicesTest.java index 2d51360777660..1cf2e5bb4ebdb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/nonha/standalone/StandaloneHaServicesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/highavailability/nonha/standalone/StandaloneHaServicesTest.java 
@@ -39,6 +39,7 @@ public class StandaloneHaServicesTest extends TestLogger { private final String jobManagerAddress = "jobManager"; + private final String dispatcherAddress = "dispatcher"; private final String resourceManagerAddress = "resourceManager"; private StandaloneHaServices standaloneHaServices; @@ -46,7 +47,10 @@ public class StandaloneHaServicesTest extends TestLogger { @Before public void setupTest() { - standaloneHaServices = new StandaloneHaServices(resourceManagerAddress, jobManagerAddress); + standaloneHaServices = new StandaloneHaServices( + resourceManagerAddress, + dispatcherAddress, + jobManagerAddress); } @After diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SlotPoolRpcTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SlotPoolRpcTest.java index 8d613ac480f76..9d742e278fa44 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SlotPoolRpcTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SlotPoolRpcTest.java @@ -25,15 +25,16 @@ import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.jobmanager.scheduler.NoResourceAvailableException; import org.apache.flink.runtime.jobmanager.scheduler.ScheduledUnit; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.rpc.RpcService; import org.apache.flink.runtime.rpc.akka.AkkaRpcService; import org.apache.flink.runtime.util.clock.SystemClock; +import org.apache.flink.util.TestLogger; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -47,7 +48,7 @@ /** * Tests for the SlotPool using a proper RPC setup. 
*/ -public class SlotPoolRpcTest { +public class SlotPoolRpcTest extends TestLogger { private static RpcService rpcService; @@ -80,7 +81,7 @@ public void testSlotAllocationNoResourceManager() throws Exception { Time.days(1), Time.days(1), Time.milliseconds(100) // this is the timeout for the request tested here ); - pool.start(UUID.randomUUID(), "foobar"); + pool.start(JobMasterId.generate(), "foobar"); CompletableFuture future = pool.allocateSlot(mock(ScheduledUnit.class), DEFAULT_TESTING_PROFILE, null, Time.days(1)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SlotPoolTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SlotPoolTest.java index aeceb59e6ef9d..5993dcbce53a2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SlotPoolTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SlotPoolTest.java @@ -26,20 +26,20 @@ import org.apache.flink.runtime.jobmanager.scheduler.ScheduledUnit; import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot; import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; import org.apache.flink.runtime.resourcemanager.SlotRequest; -import org.apache.flink.runtime.rpc.MainThreadValidatorUtil; import org.apache.flink.runtime.rpc.RpcService; -import org.apache.flink.runtime.rpc.TestingSerialRpcService; +import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.util.TestLogger; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; import java.util.List; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -51,218 +51,268 @@ import static org.mockito.Matchers.any; import static org.mockito.Mockito.RETURNS_MOCKS; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class SlotPoolTest extends TestLogger { + private final Time timeout = Time.seconds(10L); + private RpcService rpcService; private JobID jobId; - private MainThreadValidatorUtil mainThreadValidatorUtil; - - private SlotPool slotPool; - - private ResourceManagerGateway resourceManagerGateway; - @Before public void setUp() throws Exception { - - this.rpcService = new TestingSerialRpcService(); + this.rpcService = new TestingRpcService(); this.jobId = new JobID(); - this.slotPool = new SlotPool(rpcService, jobId); - - this.mainThreadValidatorUtil = new MainThreadValidatorUtil(slotPool); - - mainThreadValidatorUtil.enterMainThread(); - - final String jobManagerAddress = "foobar"; - - slotPool.start(UUID.randomUUID(), jobManagerAddress); - - this.resourceManagerGateway = mock(ResourceManagerGateway.class); - when(resourceManagerGateway - .requestSlot(any(UUID.class), any(UUID.class), any(SlotRequest.class), any(Time.class))) - .thenReturn(mock(CompletableFuture.class, RETURNS_MOCKS)); - - slotPool.connectToResourceManager(UUID.randomUUID(), resourceManagerGateway); } @After public void tearDown() throws Exception { - mainThreadValidatorUtil.exitMainThread(); + rpcService.stopService(); } @Test public void testAllocateSimpleSlot() throws Exception { - ResourceID resourceID = new ResourceID("resource"); - 
slotPool.registerTaskManager(resourceID); - - ScheduledUnit task = mock(ScheduledUnit.class); - CompletableFuture future = slotPool.allocateSlot(task, DEFAULT_TESTING_PROFILE, null, Time.milliseconds(0L)); - assertFalse(future.isDone()); - - ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); - verify(resourceManagerGateway).requestSlot(any(UUID.class), any(UUID.class), slotRequestArgumentCaptor.capture(), any(Time.class)); - - final SlotRequest slotRequest = slotRequestArgumentCaptor.getValue(); - - AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); - assertTrue(slotPool.offerSlot(allocatedSlot).get()); - - SimpleSlot slot = future.get(1, TimeUnit.SECONDS); - assertTrue(future.isDone()); - assertTrue(slot.isAlive()); - assertEquals(resourceID, slot.getTaskManagerID()); - assertEquals(jobId, slot.getJobID()); - assertEquals(slotPool.getSlotOwner(), slot.getOwner()); - assertEquals(slotPool.getAllocatedSlots().get(slot.getAllocatedSlot().getSlotAllocationId()), slot); + ResourceManagerGateway resourceManagerGateway = createResourceManagerGatewayMock(); + final SlotPool slotPool = new SlotPool(rpcService, jobId); + + try { + SlotPoolGateway slotPoolGateway = setupSlotPool(slotPool, resourceManagerGateway); + ResourceID resourceID = new ResourceID("resource"); + slotPoolGateway.registerTaskManager(resourceID); + + ScheduledUnit task = mock(ScheduledUnit.class); + CompletableFuture future = slotPoolGateway.allocateSlot(task, DEFAULT_TESTING_PROFILE, null, timeout); + assertFalse(future.isDone()); + + ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); + verify(resourceManagerGateway, Mockito.timeout(timeout.toMilliseconds())).requestSlot(any(JobMasterId.class), slotRequestArgumentCaptor.capture(), any(Time.class)); + + final SlotRequest slotRequest = slotRequestArgumentCaptor.getValue(); + + AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); + assertTrue(slotPoolGateway.offerSlot(allocatedSlot).get()); + + SimpleSlot slot = future.get(1, TimeUnit.SECONDS); + assertTrue(future.isDone()); + assertTrue(slot.isAlive()); + assertEquals(resourceID, slot.getTaskManagerID()); + assertEquals(jobId, slot.getJobID()); + assertEquals(slotPool.getSlotOwner(), slot.getOwner()); + assertEquals(slotPool.getAllocatedSlots().get(slot.getAllocatedSlot().getSlotAllocationId()), slot); + } finally { + slotPool.shutDown(); + } } @Test public void testAllocationFulfilledByReturnedSlot() throws Exception { - ResourceID resourceID = new ResourceID("resource"); - slotPool.registerTaskManager(resourceID); + ResourceManagerGateway resourceManagerGateway = createResourceManagerGatewayMock(); + final SlotPool slotPool = new SlotPool(rpcService, jobId); + + try { + SlotPoolGateway slotPoolGateway = setupSlotPool(slotPool, resourceManagerGateway); + ResourceID resourceID = new ResourceID("resource"); + slotPool.registerTaskManager(resourceID); - CompletableFuture future1 = slotPool.allocateSlot(mock(ScheduledUnit.class),DEFAULT_TESTING_PROFILE, null, Time.milliseconds(0L)); - CompletableFuture future2 = slotPool.allocateSlot(mock(ScheduledUnit.class),DEFAULT_TESTING_PROFILE, null, Time.milliseconds(0L)); + CompletableFuture future1 = slotPoolGateway.allocateSlot(mock(ScheduledUnit.class), DEFAULT_TESTING_PROFILE, null, timeout); + CompletableFuture future2 = 
slotPoolGateway.allocateSlot(mock(ScheduledUnit.class), DEFAULT_TESTING_PROFILE, null, timeout); - assertFalse(future1.isDone()); - assertFalse(future2.isDone()); + assertFalse(future1.isDone()); + assertFalse(future2.isDone()); - ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); - verify(resourceManagerGateway, times(2)) - .requestSlot(any(UUID.class), any(UUID.class), slotRequestArgumentCaptor.capture(), any(Time.class)); + ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); + verify(resourceManagerGateway, Mockito.timeout(timeout.toMilliseconds()).times(2)) + .requestSlot(any(JobMasterId.class), slotRequestArgumentCaptor.capture(), any(Time.class)); - final List slotRequests = slotRequestArgumentCaptor.getAllValues(); + final List slotRequests = slotRequestArgumentCaptor.getAllValues(); - AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequests.get(0).getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); - assertTrue(slotPool.offerSlot(allocatedSlot).get()); + AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequests.get(0).getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); + assertTrue(slotPoolGateway.offerSlot(allocatedSlot).get()); - SimpleSlot slot1 = future1.get(1, TimeUnit.SECONDS); - assertTrue(future1.isDone()); - assertFalse(future2.isDone()); + SimpleSlot slot1 = future1.get(1, TimeUnit.SECONDS); + assertTrue(future1.isDone()); + assertFalse(future2.isDone()); - // return this slot to pool - slot1.releaseSlot(); + // return this slot to pool + slot1.releaseSlot(); - // second allocation fulfilled by previous slot returning - SimpleSlot slot2 = future2.get(1, TimeUnit.SECONDS); - assertTrue(future2.isDone()); + // second allocation fulfilled by previous slot returning + SimpleSlot slot2 = future2.get(1, TimeUnit.SECONDS); + assertTrue(future2.isDone()); - assertNotEquals(slot1, slot2); - assertTrue(slot1.isReleased()); - assertTrue(slot2.isAlive()); - assertEquals(slot1.getTaskManagerID(), slot2.getTaskManagerID()); - assertEquals(slot1.getSlotNumber(), slot2.getSlotNumber()); - assertEquals(slotPool.getAllocatedSlots().get(slot1.getAllocatedSlot().getSlotAllocationId()), slot2); + assertNotEquals(slot1, slot2); + assertTrue(slot1.isReleased()); + assertTrue(slot2.isAlive()); + assertEquals(slot1.getTaskManagerID(), slot2.getTaskManagerID()); + assertEquals(slot1.getSlotNumber(), slot2.getSlotNumber()); + assertEquals(slotPool.getAllocatedSlots().get(slot1.getAllocatedSlot().getSlotAllocationId()), slot2); + } finally { + slotPool.shutDown(); + } } @Test public void testAllocateWithFreeSlot() throws Exception { - ResourceID resourceID = new ResourceID("resource"); - slotPool.registerTaskManager(resourceID); + ResourceManagerGateway resourceManagerGateway = createResourceManagerGatewayMock(); + final SlotPool slotPool = new SlotPool(rpcService, jobId); - CompletableFuture future1 = slotPool.allocateSlot(mock(ScheduledUnit.class),DEFAULT_TESTING_PROFILE, null, Time.milliseconds(0L)); - assertFalse(future1.isDone()); + try { + SlotPoolGateway slotPoolGateway = setupSlotPool(slotPool, resourceManagerGateway); + ResourceID resourceID = new ResourceID("resource"); + slotPoolGateway.registerTaskManager(resourceID); - ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); - verify(resourceManagerGateway).requestSlot(any(UUID.class), any(UUID.class), slotRequestArgumentCaptor.capture(), any(Time.class)); + CompletableFuture future1 = 
slotPoolGateway.allocateSlot(mock(ScheduledUnit.class), DEFAULT_TESTING_PROFILE, null, timeout); + assertFalse(future1.isDone()); - final SlotRequest slotRequest = slotRequestArgumentCaptor.getValue(); + ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); + verify(resourceManagerGateway, Mockito.timeout(timeout.toMilliseconds())).requestSlot(any(JobMasterId.class), slotRequestArgumentCaptor.capture(), any(Time.class)); - AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); - assertTrue(slotPool.offerSlot(allocatedSlot).get()); + final SlotRequest slotRequest = slotRequestArgumentCaptor.getValue(); - SimpleSlot slot1 = future1.get(1, TimeUnit.SECONDS); - assertTrue(future1.isDone()); + AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); + assertTrue(slotPoolGateway.offerSlot(allocatedSlot).get()); - // return this slot to pool - slot1.releaseSlot(); + SimpleSlot slot1 = future1.get(1, TimeUnit.SECONDS); + assertTrue(future1.isDone()); - CompletableFuture future2 = slotPool.allocateSlot(mock(ScheduledUnit.class),DEFAULT_TESTING_PROFILE, null, Time.milliseconds(0L)); + // return this slot to pool + slot1.releaseSlot(); - // second allocation fulfilled by previous slot returning - SimpleSlot slot2 = future2.get(1, TimeUnit.SECONDS); - assertTrue(future2.isDone()); + CompletableFuture future2 = slotPoolGateway.allocateSlot(mock(ScheduledUnit.class), DEFAULT_TESTING_PROFILE, null, timeout); - assertNotEquals(slot1, slot2); - assertTrue(slot1.isReleased()); - assertTrue(slot2.isAlive()); - assertEquals(slot1.getTaskManagerID(), slot2.getTaskManagerID()); - assertEquals(slot1.getSlotNumber(), slot2.getSlotNumber()); + // second allocation fulfilled by previous slot returning + SimpleSlot slot2 = future2.get(1, TimeUnit.SECONDS); + assertTrue(future2.isDone()); + + assertNotEquals(slot1, slot2); + assertTrue(slot1.isReleased()); + assertTrue(slot2.isAlive()); + assertEquals(slot1.getTaskManagerID(), slot2.getTaskManagerID()); + assertEquals(slot1.getSlotNumber(), slot2.getSlotNumber()); + } finally { + slotPool.shutDown(); + } } @Test public void testOfferSlot() throws Exception { - ResourceID resourceID = new ResourceID("resource"); - slotPool.registerTaskManager(resourceID); + ResourceManagerGateway resourceManagerGateway = createResourceManagerGatewayMock(); + final SlotPool slotPool = new SlotPool(rpcService, jobId); + + try { + SlotPoolGateway slotPoolGateway = setupSlotPool(slotPool, resourceManagerGateway); + ResourceID resourceID = new ResourceID("resource"); + slotPoolGateway.registerTaskManager(resourceID); - CompletableFuture future = slotPool.allocateSlot(mock(ScheduledUnit.class),DEFAULT_TESTING_PROFILE, null, Time.milliseconds(0L)); - assertFalse(future.isDone()); + CompletableFuture future = slotPoolGateway.allocateSlot(mock(ScheduledUnit.class), DEFAULT_TESTING_PROFILE, null, timeout); + assertFalse(future.isDone()); - ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); - verify(resourceManagerGateway).requestSlot(any(UUID.class), any(UUID.class), slotRequestArgumentCaptor.capture(), any(Time.class)); + ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); + verify(resourceManagerGateway, Mockito.timeout(timeout.toMilliseconds())).requestSlot(any(JobMasterId.class), slotRequestArgumentCaptor.capture(), any(Time.class)); - final 
SlotRequest slotRequest = slotRequestArgumentCaptor.getValue(); + final SlotRequest slotRequest = slotRequestArgumentCaptor.getValue(); - // slot from unregistered resource - AllocatedSlot invalid = createAllocatedSlot(new ResourceID("unregistered"), slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); - assertFalse(slotPool.offerSlot(invalid).get()); + // slot from unregistered resource + AllocatedSlot invalid = createAllocatedSlot(new ResourceID("unregistered"), slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); + assertFalse(slotPoolGateway.offerSlot(invalid).get()); - AllocatedSlot notRequested = createAllocatedSlot(resourceID, new AllocationID(), jobId, DEFAULT_TESTING_PROFILE); + AllocatedSlot notRequested = createAllocatedSlot(resourceID, new AllocationID(), jobId, DEFAULT_TESTING_PROFILE); - // we'll also accept non requested slots - assertTrue(slotPool.offerSlot(notRequested).get()); + // we'll also accept non requested slots + assertTrue(slotPoolGateway.offerSlot(notRequested).get()); - AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); + AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); - // accepted slot - assertTrue(slotPool.offerSlot(allocatedSlot).get()); - SimpleSlot slot = future.get(1, TimeUnit.SECONDS); - assertTrue(future.isDone()); - assertTrue(slot.isAlive()); + // accepted slot + assertTrue(slotPoolGateway.offerSlot(allocatedSlot).get()); + SimpleSlot slot = future.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + assertTrue(slot.isAlive()); - // duplicated offer with using slot - assertTrue(slotPool.offerSlot(allocatedSlot).get()); - assertTrue(future.isDone()); - assertTrue(slot.isAlive()); + // duplicated offer with using slot + assertTrue(slotPoolGateway.offerSlot(allocatedSlot).get()); + assertTrue(slot.isAlive()); - // duplicated offer with free slot - slot.releaseSlot(); - assertTrue(slot.isReleased()); - assertTrue(slotPool.offerSlot(allocatedSlot).get()); + // duplicated offer with free slot + slot.releaseSlot(); + assertTrue(slotPoolGateway.offerSlot(allocatedSlot).get()); + } finally { + slotPool.shutDown(); + } } @Test public void testReleaseResource() throws Exception { - ResourceID resourceID = new ResourceID("resource"); - slotPool.registerTaskManager(resourceID); + ResourceManagerGateway resourceManagerGateway = createResourceManagerGatewayMock(); + + final CompletableFuture slotReturnFuture = new CompletableFuture<>(); + + final SlotPool slotPool = new SlotPool(rpcService, jobId) { + @Override + public void returnAllocatedSlot(Slot slot) { + super.returnAllocatedSlot(slot); + + slotReturnFuture.complete(true); + } + }; + + try { + SlotPoolGateway slotPoolGateway = setupSlotPool(slotPool, resourceManagerGateway); + ResourceID resourceID = new ResourceID("resource"); + slotPoolGateway.registerTaskManager(resourceID); + + CompletableFuture future1 = slotPoolGateway.allocateSlot(mock(ScheduledUnit.class), DEFAULT_TESTING_PROFILE, null, timeout); + + ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); + verify(resourceManagerGateway, Mockito.timeout(timeout.toMilliseconds())).requestSlot(any(JobMasterId.class), slotRequestArgumentCaptor.capture(), any(Time.class)); + + final SlotRequest slotRequest = slotRequestArgumentCaptor.getValue(); + + CompletableFuture future2 = slotPoolGateway.allocateSlot(mock(ScheduledUnit.class), 
DEFAULT_TESTING_PROFILE, null, timeout); - CompletableFuture future1 = slotPool.allocateSlot(mock(ScheduledUnit.class),DEFAULT_TESTING_PROFILE, null, Time.milliseconds(0L)); + AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); + assertTrue(slotPoolGateway.offerSlot(allocatedSlot).get()); - ArgumentCaptor slotRequestArgumentCaptor = ArgumentCaptor.forClass(SlotRequest.class); - verify(resourceManagerGateway).requestSlot(any(UUID.class), any(UUID.class), slotRequestArgumentCaptor.capture(), any(Time.class)); + SimpleSlot slot1 = future1.get(1, TimeUnit.SECONDS); + assertTrue(future1.isDone()); + assertFalse(future2.isDone()); - final SlotRequest slotRequest = slotRequestArgumentCaptor.getValue(); + slotPoolGateway.releaseTaskManager(resourceID); - CompletableFuture future2 = slotPool.allocateSlot(mock(ScheduledUnit.class),DEFAULT_TESTING_PROFILE, null, Time.milliseconds(0L)); + // wait until the slot has been returned + slotReturnFuture.get(); - AllocatedSlot allocatedSlot = createAllocatedSlot(resourceID, slotRequest.getAllocationId(), jobId, DEFAULT_TESTING_PROFILE); - assertTrue(slotPool.offerSlot(allocatedSlot).get()); + assertTrue(slot1.isReleased()); + + // slot released and not usable, second allocation still not fulfilled + Thread.sleep(10); + assertFalse(future2.isDone()); + } finally { + slotPool.shutDown(); + } + } + + private static ResourceManagerGateway createResourceManagerGatewayMock() { + ResourceManagerGateway resourceManagerGateway = mock(ResourceManagerGateway.class); + when(resourceManagerGateway + .requestSlot(any(JobMasterId.class), any(SlotRequest.class), any(Time.class))) + .thenReturn(mock(CompletableFuture.class, RETURNS_MOCKS)); + + return resourceManagerGateway; + } + + private static SlotPoolGateway setupSlotPool( + SlotPool slotPool, + ResourceManagerGateway resourceManagerGateway) throws Exception { + final String jobManagerAddress = "foobar"; - SimpleSlot slot1 = future1.get(1, TimeUnit.SECONDS); - assertTrue(future1.isDone()); - assertFalse(future2.isDone()); + slotPool.start(JobMasterId.generate(), jobManagerAddress); - slotPool.releaseTaskManager(resourceID); - assertTrue(slot1.isReleased()); + slotPool.connectToResourceManager(resourceManagerGateway); - // slot released and not usable, second allocation still not fulfilled - Thread.sleep(10); - assertFalse(future2.isDone()); + return slotPool.getSelfGateway(SlotPoolGateway.class); } static AllocatedSlot createAllocatedSlot( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPoolTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPoolTest.java index a186d56e9c9c7..03f82d81e0c24 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPoolTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPoolTest.java @@ -18,9 +18,11 @@ package org.apache.flink.runtime.io.network.buffer; -import com.google.common.collect.Lists; import org.apache.flink.core.memory.MemoryType; import org.apache.flink.runtime.util.event.EventListener; + +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.After; import org.junit.AfterClass; import org.junit.Before; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SpilledSubpartitionViewTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SpilledSubpartitionViewTest.java index 8f8da9340cf18..fa62593387122 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SpilledSubpartitionViewTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SpilledSubpartitionViewTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network.partition; -import com.google.common.collect.Lists; import org.apache.flink.runtime.io.disk.iomanager.BufferFileWriter; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; @@ -29,6 +28,9 @@ import org.apache.flink.runtime.io.network.util.TestInfiniteBufferProvider; import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider; import org.apache.flink.runtime.io.network.util.TestSubpartitionConsumer; + +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.AfterClass; import org.junit.Test; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index fe819a419b749..e685f17bb1230 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import com.google.common.collect.Lists; import org.apache.flink.api.common.JobID; import org.apache.flink.core.memory.MemoryType; import org.apache.flink.runtime.execution.CancelTaskException; @@ -42,6 +41,9 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; import org.apache.flink.runtime.taskmanager.TaskActions; + +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index 1d30a9a999582..4a32d7380b342 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import com.google.common.collect.Lists; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; @@ -27,6 +26,9 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.util.TestBufferFactory; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; + +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.Test; import scala.Tuple2; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestPooledBufferProvider.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestPooledBufferProvider.java index 339b6f44ca74e..d7e96439eac74 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestPooledBufferProvider.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestPooledBufferProvider.java @@ -18,13 +18,14 @@ package org.apache.flink.runtime.io.network.util; -import com.google.common.collect.Queues; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.buffer.BufferRecycler; import org.apache.flink.runtime.util.event.EventListener; +import org.apache.flink.shaded.guava18.com.google.common.collect.Queues; + import java.io.IOException; import java.util.Queue; import java.util.concurrent.ArrayBlockingQueue; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/iterative/concurrent/BlockingBackChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/iterative/concurrent/BlockingBackChannelTest.java index c9015e9dacb49..f7b5c1dee5e22 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/iterative/concurrent/BlockingBackChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/iterative/concurrent/BlockingBackChannelTest.java @@ -22,10 +22,10 @@ import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.runtime.iterative.io.SerializedUpdateBuffer; -import com.google.common.collect.Lists; import org.junit.Test; import org.mockito.Mockito; +import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.concurrent.ArrayBlockingQueue; @@ -46,7 +46,7 @@ public class BlockingBackChannelTest { public void multiThreaded() throws InterruptedException { BlockingQueue dataChannel = new ArrayBlockingQueue(1); - List actionLog = Lists.newArrayList(); + List actionLog = new ArrayList<>(); SerializedUpdateBuffer buffer = Mockito.mock(SerializedUpdateBuffer.class); BlockingBackChannel channel = new BlockingBackChannel(buffer); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/iterative/concurrent/BrokerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/iterative/concurrent/BrokerTest.java index e12cb32c52ff9..e462e082551f6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/iterative/concurrent/BrokerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/iterative/concurrent/BrokerTest.java @@ -20,9 +20,9 @@ import org.apache.flink.util.Preconditions; -import com.google.common.collect.Lists; import org.junit.Test; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Random; @@ -52,7 +52,7 @@ void mediate(int subtasks) throws InterruptedException, ExecutionException { final ExecutorService executorService = Executors.newFixedThreadPool(subtasks * 2); try { - List> tasks = Lists.newArrayList(); + List> tasks = new ArrayList<>(); Broker broker = new Broker(); for (int subtask = 0; subtask < subtasks; subtask++) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerCleanupITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerCleanupITCase.java new file mode 100644 index 0000000000000..b2b455b8d7659 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerCleanupITCase.java @@ -0,0 
+1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.jobmanager; + +import akka.actor.ActorSystem; +import akka.testkit.JavaTestKit; +import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.AkkaOptions; +import org.apache.flink.configuration.BlobServerOptions; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.akka.ListeningBehaviour; +import org.apache.flink.runtime.blob.BlobClient; +import org.apache.flink.runtime.blob.BlobKey; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.instance.ActorGateway; +import org.apache.flink.runtime.instance.AkkaActorGateway; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.messages.JobManagerMessages; +import org.apache.flink.runtime.testingUtils.TestingCluster; +import org.apache.flink.runtime.testingUtils.TestingUtils; +import org.apache.flink.runtime.testtasks.FailingBlockingInvokable; +import org.apache.flink.runtime.testtasks.NoOpInvokable; +import org.apache.flink.util.TestLogger; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.FiniteDuration; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Arrays; + +import static org.apache.flink.runtime.testingUtils.TestingUtils.DEFAULT_AKKA_ASK_TIMEOUT; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +/** + * Small test to check that the {@link org.apache.flink.runtime.blob.BlobServer} cleanup is executed + * after job termination. + */ +public class JobManagerCleanupITCase extends TestLogger { + + @Rule + public TemporaryFolder tmpFolder = new TemporaryFolder(); + + private static ActorSystem system; + + @BeforeClass + public static void setup() { + system = AkkaUtils.createLocalActorSystem(new Configuration()); + } + + @AfterClass + public static void teardown() { + JavaTestKit.shutdownActorSystem(system); + } + + /** + * Specifies which test case to run in {@link #testBlobServerCleanup(TestCase)}. + */ + private enum TestCase { + JOB_FINISHES_SUCESSFULLY, + JOB_IS_CANCELLED, + JOB_FAILS, + JOB_SUBMISSION_FAILS + } + + /** + * Test cleanup for a job that finishes ordinarily. 
+ */ + @Test + public void testBlobServerCleanupFinishedJob() throws IOException { + testBlobServerCleanup(TestCase.JOB_FINISHES_SUCESSFULLY); + } + + /** + * Test cleanup for a job which is cancelled after submission. + */ + @Test + public void testBlobServerCleanupCancelledJob() throws IOException { + testBlobServerCleanup(TestCase.JOB_IS_CANCELLED); + } + + /** + * Test cleanup for a job that fails (first a task fails, then the job recovers, then the whole + * job fails due to a limited restart policy). + */ + @Test + public void testBlobServerCleanupFailedJob() throws IOException { + testBlobServerCleanup(TestCase.JOB_FAILS); + } + + /** + * Test cleanup for a job that fails job submission (emulated by an additional BLOB not being + * present). + */ + @Test + public void testBlobServerCleanupFailedSubmission() throws IOException { + testBlobServerCleanup(TestCase.JOB_SUBMISSION_FAILS); + } + + private void testBlobServerCleanup(final TestCase testCase) throws IOException { + final int num_tasks = 2; + final File blobBaseDir = tmpFolder.newFolder(); + + new JavaTestKit(system) {{ + new Within(duration("30 seconds")) { + @Override + protected void run() { + // Setup + + TestingCluster cluster = null; + BlobClient bc = null; + + try { + Configuration config = new Configuration(); + config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 2); + config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, 1); + config.setString(AkkaOptions.ASK_TIMEOUT, DEFAULT_AKKA_ASK_TIMEOUT()); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, blobBaseDir.getAbsolutePath()); + + config.setString(ConfigConstants.RESTART_STRATEGY, "fixeddelay"); + config.setInteger(ConfigConstants.RESTART_STRATEGY_FIXED_DELAY_ATTEMPTS, 1); + config.setString(ConfigConstants.RESTART_STRATEGY_FIXED_DELAY_DELAY, "1 s"); + // BLOBs are deleted from BlobCache between 1s and 2s after last reference + // -> the BlobCache may still have the BLOB or not (let's test both cases randomly) + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L); + + cluster = new TestingCluster(config); + cluster.start(); + + final ActorGateway jobManagerGateway = cluster.getLeaderGateway( + TestingUtils.TESTING_DURATION()); + + // we can set the leader session ID to None because we don't use this gateway to send messages + final ActorGateway testActorGateway = new AkkaActorGateway(getTestActor(), + HighAvailabilityServices.DEFAULT_LEADER_ID); + + // Create a task + + JobVertex source = new JobVertex("Source"); + if (testCase == TestCase.JOB_FAILS || testCase == TestCase.JOB_IS_CANCELLED) { + source.setInvokableClass(FailingBlockingInvokable.class); + } else { + source.setInvokableClass(NoOpInvokable.class); + } + source.setParallelism(num_tasks); + + JobGraph jobGraph = new JobGraph("BlobCleanupTest", source); + final JobID jid = jobGraph.getJobID(); + + // request the blob port from the job manager + Future future = jobManagerGateway + .ask(JobManagerMessages.getRequestBlobManagerPort(), remaining()); + int blobPort = (Integer) Await.result(future, remaining()); + + // upload a blob + BlobKey key1; + bc = new BlobClient(new InetSocketAddress("localhost", blobPort), + config); + try { + key1 = bc.put(jid, new byte[10]); + } finally { + bc.close(); + } + jobGraph.addBlob(key1); + + if (testCase == TestCase.JOB_SUBMISSION_FAILS) { + // add an invalid key so that the submission fails + jobGraph.addBlob(new BlobKey()); + } + + // Submit the job and wait for all vertices to be running + jobManagerGateway.tell( + new 
JobManagerMessages.SubmitJob( + jobGraph, + ListeningBehaviour.EXECUTION_RESULT), + testActorGateway); + if (testCase == TestCase.JOB_SUBMISSION_FAILS) { + expectMsgClass(JobManagerMessages.JobResultFailure.class); + } else { + expectMsgClass(JobManagerMessages.JobSubmitSuccess.class); + + if (testCase == TestCase.JOB_FAILS) { + // fail a task so that the job is going to be recovered (we actually do not + // need the blocking part of the invokable and can start throwing right away) + FailingBlockingInvokable.unblock(); + + // job will get restarted, BlobCache may re-download the BLOB if already deleted + // then the tasks will fail again and the restart strategy will finalise the job + + expectMsgClass(JobManagerMessages.JobResultFailure.class); + } else if (testCase == TestCase.JOB_IS_CANCELLED) { + jobManagerGateway.tell( + new JobManagerMessages.CancelJob(jid), + testActorGateway); + expectMsgClass(JobManagerMessages.CancellationResponse.class); + + // job will be cancelled and everything should be cleaned up + + expectMsgClass(JobManagerMessages.JobResultFailure.class); + } else { + expectMsgClass(JobManagerMessages.JobResultSuccess.class); + } + } + + // both BlobServer and BlobCache should eventually delete all files + + File[] blobDirs = blobBaseDir.listFiles(new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + return name.startsWith("blobStore-"); + } + }); + assertNotNull(blobDirs); + for (File blobDir : blobDirs) { + waitForEmptyBlobDir(blobDir, remaining()); + } + + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } finally { + if (bc != null) { + try { + bc.close(); + } catch (IOException ignored) { + } + } + if (cluster != null) { + cluster.shutdown(); + } + } + } + }; + }}; + + // after everything has been shut down, the storage directory itself should be empty + assertArrayEquals(new File[] {}, blobBaseDir.listFiles()); + } + + /** + * Waits until the given {@link org.apache.flink.runtime.blob.BlobService} storage directory + * does not contain any job-related folders any more. + * + * @param blobDir + * directory of a {@link org.apache.flink.runtime.blob.BlobServer} or {@link + * org.apache.flink.runtime.blob.BlobCache} + * @param remaining + * remaining time for this test + * + * @see org.apache.flink.runtime.blob.BlobUtils + */ + private static void waitForEmptyBlobDir(File blobDir, FiniteDuration remaining) + throws InterruptedException { + long deadline = System.currentTimeMillis() + remaining.toMillis(); + String[] blobDirContents; + do { + blobDirContents = blobDir.list(new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + return name.startsWith("job_"); + } + }); + if (blobDirContents == null || blobDirContents.length == 0) { + return; + } + Thread.sleep(100); + } while (System.currentTimeMillis() < deadline); + + fail("Timeout while waiting for " + blobDir.getAbsolutePath() + " to become empty. 
Current contents: " + Arrays.toString(blobDirContents)); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index a63b02d785f19..173730a80e884 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -18,17 +18,8 @@ package org.apache.flink.runtime.jobmanager; -import akka.actor.ActorRef; -import akka.actor.ActorSystem; -import akka.actor.Identify; -import akka.actor.PoisonPill; -import akka.actor.Props; -import akka.japi.pf.FI; -import akka.japi.pf.ReceiveBuilder; -import akka.pattern.Patterns; -import akka.testkit.CallingThreadDispatcher; -import akka.testkit.JavaTestKit; import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.HighAvailabilityOptions; @@ -37,15 +28,15 @@ import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.akka.ListeningBehaviour; import org.apache.flink.runtime.blob.BlobServer; -import org.apache.flink.runtime.blob.BlobService; import org.apache.flink.runtime.checkpoint.CheckpointIDCounter; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory; import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager; import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy; @@ -59,6 +50,7 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings; @@ -69,9 +61,7 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.metrics.MetricRegistry; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManager; import org.apache.flink.runtime.testingUtils.TestingJobManager; @@ -83,22 +73,25 @@ import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; import 
org.apache.flink.util.InstantiationUtil; - +import org.apache.flink.util.Preconditions; import org.apache.flink.util.TestLogger; + +import akka.actor.ActorRef; +import akka.actor.ActorSystem; +import akka.actor.Identify; +import akka.actor.PoisonPill; +import akka.actor.Props; +import akka.japi.pf.FI; +import akka.japi.pf.ReceiveBuilder; +import akka.pattern.Patterns; +import akka.testkit.CallingThreadDispatcher; +import akka.testkit.JavaTestKit; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import scala.Int; -import scala.Option; -import scala.PartialFunction; -import scala.concurrent.Await; -import scala.concurrent.Future; -import scala.concurrent.duration.Deadline; -import scala.concurrent.duration.FiniteDuration; -import scala.runtime.BoxedUnit; import java.util.ArrayList; import java.util.Arrays; @@ -107,12 +100,22 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import scala.Int; +import scala.Option; +import scala.PartialFunction; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; +import scala.runtime.BoxedUnit; + import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; @@ -159,6 +162,7 @@ public void testJobRecoveryWhenLosingLeadership() throws Exception { flinkConfiguration.setString(HighAvailabilityOptions.HA_MODE, "zookeeper"); flinkConfiguration.setString(HighAvailabilityOptions.HA_STORAGE_PATH, temporaryFolder.newFolder().toString()); flinkConfiguration.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, slots); + flinkConfiguration.setLong(BlobServerOptions.CLEANUP_INTERVAL, 3_600L); try { Scheduler scheduler = new Scheduler(TestingUtils.defaultExecutionContext()); @@ -180,6 +184,9 @@ public void testJobRecoveryWhenLosingLeadership() throws Exception { archive = system.actorOf(JobManager.getArchiveProps(MemoryArchivist.class, 10, Option.empty())); + BlobServer blobServer = new BlobServer( + flinkConfiguration, + testingHighAvailabilityServices.createBlobStore()); Props jobManagerProps = Props.create( TestingJobManager.class, flinkConfiguration, @@ -187,11 +194,8 @@ public void testJobRecoveryWhenLosingLeadership() throws Exception { TestingUtils.defaultExecutor(), instanceManager, scheduler, - new BlobLibraryCacheManager( - new BlobServer( - flinkConfiguration, - testingHighAvailabilityServices.createBlobStore()), - 3600000L), + blobServer, + new BlobLibraryCacheManager(blobServer), archive, new FixedDelayRestartStrategy.FixedDelayRestartStrategyFactory(Int.MaxValue(), 100), timeout, @@ -354,6 +358,7 @@ public void testFailingJobRecovery() throws Exception { final Collection recoveredJobs = new ArrayList<>(2); + BlobServer blobServer = mock(BlobServer.class); Props jobManagerProps = Props.create( TestingFailingHAJobManager.class, flinkConfiguration, @@ -361,7 +366,8 @@ public void testFailingJobRecovery() throws Exception { TestingUtils.defaultExecutor(), mock(InstanceManager.class), mock(Scheduler.class), - new BlobLibraryCacheManager(mock(BlobService.class), 1 << 20), + blobServer, + new BlobLibraryCacheManager(blobServer), 
ActorRef.noSender(), new FixedDelayRestartStrategy.FixedDelayRestartStrategyFactory(Int.MaxValue(), 100), timeout, @@ -398,6 +404,7 @@ public TestingFailingHAJobManager( Executor ioExecutor, InstanceManager instanceManager, Scheduler scheduler, + BlobServer blobServer, BlobLibraryCacheManager libraryCacheManager, ActorRef archive, RestartStrategyFactory restartStrategyFactory, @@ -414,6 +421,7 @@ public TestingFailingHAJobManager( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, @@ -552,10 +560,11 @@ public static class BlockingStatefulInvokable extends BlockingInvokable implemen @Override public void setInitialState( - TaskStateHandles taskStateHandles) throws Exception { + TaskStateSnapshot taskStateHandles) throws Exception { int subtaskIndex = getIndexInSubtaskGroup(); if (subtaskIndex < recoveredStates.length) { - try (FSDataInputStream in = taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) { + OperatorStateHandle operatorStateHandle = extractSingletonOperatorState(taskStateHandles); + try (FSDataInputStream in = operatorStateHandle.openInputStream()) { recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader()); } } @@ -567,10 +576,21 @@ public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, Checkpoi String.valueOf(UUID.randomUUID()), InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId())); - ChainedStateHandle chainedStateHandle = - new ChainedStateHandle(Collections.singletonList(byteStreamStateHandle)); - SubtaskState checkpointStateHandles = - new SubtaskState(chainedStateHandle, null, null, null, null); + Map stateNameToPartitionOffsets = new HashMap<>(1); + stateNameToPartitionOffsets.put( + "test-state", + new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + + OperatorStateHandle operatorStateHandle = new OperatorStateHandle(stateNameToPartitionOffsets, byteStreamStateHandle); + + TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot(); + checkpointStateHandles.putSubtaskStateByOperatorID( + OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()), + new OperatorSubtaskState( + Collections.singletonList(operatorStateHandle), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList())); getEnvironment().acknowledgeCheckpoint( checkpointMetaData.getCheckpointId(), @@ -608,5 +628,17 @@ public static void awaitCompletedCheckpoints() throws InterruptedException { public static long[] getRecoveredStates() { return recoveredStates; } + + private static OperatorStateHandle extractSingletonOperatorState(TaskStateSnapshot taskStateHandles) { + Set> subtaskStateMappings = taskStateHandles.getSubtaskStateMappings(); + Preconditions.checkNotNull(subtaskStateMappings); + Preconditions.checkState(subtaskStateMappings.size() == 1); + OperatorSubtaskState subtaskState = subtaskStateMappings.iterator().next().getValue(); + Collection managedOperatorState = + Preconditions.checkNotNull(subtaskState).getManagedOperatorState(); + Preconditions.checkNotNull(managedOperatorState); + Preconditions.checkState(managedOperatorState.size() == 1); + return managedOperatorState.iterator().next(); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerStartupTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerStartupTest.java index 82b510ca6d38f..838b12458de74 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerStartupTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerStartupTest.java @@ -33,9 +33,9 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.util.StartupUtils; import org.apache.flink.util.NetUtils; - import org.apache.flink.util.OperatingSystem; import org.apache.flink.util.TestLogger; + import org.junit.After; import org.junit.Before; import org.junit.Rule; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobSubmitTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobSubmitTest.java index 79b9c1c56765d..6a39293342d92 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobSubmitTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobSubmitTest.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.jobmanager; import akka.actor.ActorSystem; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.JobManagerOptions; import org.apache.flink.runtime.akka.AkkaUtils; @@ -136,12 +137,13 @@ public void testFailureWhenJarBlobsMissing() { // upload two dummy bytes and add their keys to the job graph as dependencies BlobKey key1, key2; BlobClient bc = new BlobClient(new InetSocketAddress("localhost", blobPort), jmConfig); + JobID jobId = jg.getJobID(); try { - key1 = bc.put(new byte[10]); - key2 = bc.put(new byte[10]); + key1 = bc.put(jobId, new byte[10]); + key2 = bc.put(jobId, new byte[10]); // delete one of the blobs to make sure that the startup failed - bc.delete(key2); + bc.delete(jobId, key2); } finally { bc.close(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/ScheduleOrUpdateConsumersTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/ScheduleOrUpdateConsumersTest.java index eb4d96fb29759..9c781ec101c18 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/ScheduleOrUpdateConsumersTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/ScheduleOrUpdateConsumersTest.java @@ -18,8 +18,6 @@ package org.apache.flink.runtime.jobmanager.scheduler; -import com.google.common.collect.Lists; - import org.apache.flink.runtime.io.network.api.writer.RecordWriter; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.JobVertex; @@ -31,6 +29,8 @@ import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.types.IntValue; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobManagerRunnerMockTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobManagerRunnerMockTest.java index 435c23dcd0c61..b4f50fbb01668 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobManagerRunnerMockTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobManagerRunnerMockTest.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.JobExecutionResult; import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; import 
org.apache.flink.runtime.blob.BlobServer; import org.apache.flink.runtime.blob.BlobStore; @@ -132,7 +133,7 @@ public void testStartAndShutdown() throws Exception { assertTrue(!jobCompletion.isJobFinished()); assertTrue(!jobCompletion.isJobFailed()); - verify(jobManager).start(any(UUID.class)); + verify(jobManager).start(any(JobMasterId.class), any(Time.class)); runner.shutdown(); verify(leaderElectionService).stop(); @@ -164,9 +165,9 @@ public void testShutdownBeforeGrantLeadership() throws Exception { public void testJobFinished() throws Exception { runner.start(); - UUID leaderSessionID = UUID.randomUUID(); - runner.grantLeadership(leaderSessionID); - verify(jobManager).start(leaderSessionID); + JobMasterId jobMasterId = JobMasterId.generate(); + runner.grantLeadership(jobMasterId.toUUID()); + verify(jobManager).start(eq(jobMasterId), any(Time.class)); assertTrue(!jobCompletion.isJobFinished()); // runner been told by JobManager that job is finished @@ -184,9 +185,9 @@ public void testJobFinished() throws Exception { public void testJobFailed() throws Exception { runner.start(); - UUID leaderSessionID = UUID.randomUUID(); - runner.grantLeadership(leaderSessionID); - verify(jobManager).start(leaderSessionID); + JobMasterId jobMasterId = JobMasterId.generate(); + runner.grantLeadership(jobMasterId.toUUID()); + verify(jobManager).start(eq(jobMasterId), any(Time.class)); assertTrue(!jobCompletion.isJobFinished()); // runner been told by JobManager that job is failed @@ -203,13 +204,13 @@ public void testJobFailed() throws Exception { public void testLeadershipRevoked() throws Exception { runner.start(); - UUID leaderSessionID = UUID.randomUUID(); - runner.grantLeadership(leaderSessionID); - verify(jobManager).start(leaderSessionID); + JobMasterId jobMasterId = JobMasterId.generate(); + runner.grantLeadership(jobMasterId.toUUID()); + verify(jobManager).start(eq(jobMasterId), any(Time.class)); assertTrue(!jobCompletion.isJobFinished()); runner.revokeLeadership(); - verify(jobManager).suspendExecution(any(Throwable.class)); + verify(jobManager).suspend(any(Throwable.class), any(Time.class)); assertFalse(runner.isShutdown()); } @@ -218,18 +219,18 @@ public void testLeadershipRevoked() throws Exception { public void testRegainLeadership() throws Exception { runner.start(); - UUID leaderSessionID = UUID.randomUUID(); - runner.grantLeadership(leaderSessionID); - verify(jobManager).start(leaderSessionID); + JobMasterId jobMasterId = JobMasterId.generate(); + runner.grantLeadership(jobMasterId.toUUID()); + verify(jobManager).start(eq(jobMasterId), any(Time.class)); assertTrue(!jobCompletion.isJobFinished()); runner.revokeLeadership(); - verify(jobManager).suspendExecution(any(Throwable.class)); + verify(jobManager).suspend(any(Throwable.class), any(Time.class)); assertFalse(runner.isShutdown()); - UUID leaderSessionID2 = UUID.randomUUID(); - runner.grantLeadership(leaderSessionID2); - verify(jobManager).start(leaderSessionID2); + JobMasterId jobMasterId2 = JobMasterId.generate(); + runner.grantLeadership(jobMasterId2.toUUID()); + verify(jobManager).start(eq(jobMasterId2), any(Time.class)); } private static class TestingOnCompletionActions implements OnCompletionActions, FatalErrorHandler { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterTest.java index 0c4d3762884e7..64cc13b624094 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterTest.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.blob.BlobServer; import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.concurrent.ScheduledExecutor; @@ -33,8 +34,11 @@ import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobmanager.OnCompletionActions; import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; +import org.apache.flink.runtime.messages.Acknowledge; +import org.apache.flink.runtime.registration.RegistrationResponse; import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; -import org.apache.flink.runtime.rpc.TestingSerialRpcService; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; +import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.taskexecutor.TaskExecutorGateway; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.util.TestingFatalErrorHandler; @@ -47,21 +51,19 @@ import java.net.InetAddress; import java.net.URL; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; -import static org.mockito.Mockito.when; @RunWith(PowerMockRunner.class) @PrepareForTest(BlobLibraryCacheManager.class) public class JobMasterTest extends TestLogger { + private final Time testingTimeout = Time.seconds(10L); + @Test public void testHeartbeatTimeoutWithTaskManager() throws Exception { final TestingHighAvailabilityServices haServices = new TestingHighAvailabilityServices(); @@ -73,7 +75,7 @@ public void testHeartbeatTimeoutWithTaskManager() throws Exception { final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); final String jobManagerAddress = "jm"; - final UUID jmLeaderId = UUID.randomUUID(); + final JobMasterId jobMasterId = JobMasterId.generate(); final ResourceID jmResourceId = new ResourceID(jobManagerAddress); final String taskManagerAddress = "tm"; @@ -81,7 +83,7 @@ public void testHeartbeatTimeoutWithTaskManager() throws Exception { final TaskManagerLocation taskManagerLocation = new TaskManagerLocation(tmResourceId, InetAddress.getLoopbackAddress(), 1234); final TaskExecutorGateway taskExecutorGateway = mock(TaskExecutorGateway.class); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); + final TestingRpcService rpc = new TestingRpcService(); rpc.registerGateway(taskManagerAddress, taskExecutorGateway); final long heartbeatInterval = 1L; @@ -89,6 +91,8 @@ public void testHeartbeatTimeoutWithTaskManager() throws Exception { final ScheduledExecutor scheduledExecutor = mock(ScheduledExecutor.class); final HeartbeatServices heartbeatServices = new TestingHeartbeatServices(heartbeatInterval, heartbeatTimeout, scheduledExecutor); + BlobServer blobServer = mock(BlobServer.class); + when(blobServer.getPort()).thenReturn(1337); final JobGraph jobGraph = new JobGraph(); @@ -101,18 +105,28 @@ public void testHeartbeatTimeoutWithTaskManager() throws Exception { haServices, heartbeatServices, 
Executors.newScheduledThreadPool(1), + blobServer, mock(BlobLibraryCacheManager.class), mock(RestartStrategyFactory.class), - Time.of(10, TimeUnit.SECONDS), + testingTimeout, null, mock(OnCompletionActions.class), testingFatalErrorHandler, new FlinkUserCodeClassLoader(new URL[0])); - jobMaster.start(jmLeaderId); + CompletableFuture startFuture = jobMaster.start(jobMasterId, testingTimeout); + + // wait for the start to complete + startFuture.get(testingTimeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + final JobMasterGateway jobMasterGateway = jobMaster.getSelfGateway(JobMasterGateway.class); // register task manager will trigger monitor heartbeat target, schedule heartbeat request at interval time - jobMaster.registerTaskManager(taskManagerAddress, taskManagerLocation, jmLeaderId, Time.milliseconds(0L)); + CompletableFuture registrationResponse = jobMasterGateway + .registerTaskManager(taskManagerAddress, taskManagerLocation, testingTimeout); + + // wait for the completion of the registration + registrationResponse.get(); ArgumentCaptor heartbeatRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); verify(scheduledExecutor, times(1)).scheduleAtFixedRate( @@ -124,7 +138,7 @@ public void testHeartbeatTimeoutWithTaskManager() throws Exception { Runnable heartbeatRunnable = heartbeatRunnableCaptor.getValue(); ArgumentCaptor timeoutRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); - verify(scheduledExecutor).schedule(timeoutRunnableCaptor.capture(), eq(heartbeatTimeout), eq(TimeUnit.MILLISECONDS)); + verify(scheduledExecutor, timeout(testingTimeout.toMilliseconds())).schedule(timeoutRunnableCaptor.capture(), eq(heartbeatTimeout), eq(TimeUnit.MILLISECONDS)); Runnable timeoutRunnable = timeoutRunnableCaptor.getValue(); @@ -136,7 +150,7 @@ public void testHeartbeatTimeoutWithTaskManager() throws Exception { // run the timeout runnable to simulate a heartbeat timeout timeoutRunnable.run(); - verify(taskExecutorGateway).disconnectJobManager(eq(jobGraph.getJobID()), any(TimeoutException.class)); + verify(taskExecutorGateway, timeout(testingTimeout.toMilliseconds())).disconnectJobManager(eq(jobGraph.getJobID()), any(TimeoutException.class)); // check if a concurrent error occurred testingFatalErrorHandler.rethrowError(); @@ -150,8 +164,8 @@ public void testHeartbeatTimeoutWithTaskManager() throws Exception { public void testHeartbeatTimeoutWithResourceManager() throws Exception { final String resourceManagerAddress = "rm"; final String jobManagerAddress = "jm"; - final UUID rmLeaderId = UUID.randomUUID(); - final UUID jmLeaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); + final JobMasterId jobMasterId = JobMasterId.generate(); final ResourceID rmResourceId = new ResourceID(resourceManagerAddress); final ResourceID jmResourceId = new ResourceID(jobManagerAddress); final JobGraph jobGraph = new JobGraph(); @@ -170,16 +184,15 @@ public void testHeartbeatTimeoutWithResourceManager() throws Exception { final ResourceManagerGateway resourceManagerGateway = mock(ResourceManagerGateway.class); when(resourceManagerGateway.registerJobManager( - any(UUID.class), - any(UUID.class), + any(JobMasterId.class), any(ResourceID.class), anyString(), any(JobID.class), any(Time.class) )).thenReturn(CompletableFuture.completedFuture(new JobMasterRegistrationSuccess( - heartbeatInterval, rmLeaderId, rmResourceId))); + heartbeatInterval, resourceManagerId, rmResourceId))); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); + final 
TestingRpcService rpc = new TestingRpcService(); rpc.registerGateway(resourceManagerAddress, resourceManagerGateway); final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); @@ -193,23 +206,26 @@ public void testHeartbeatTimeoutWithResourceManager() throws Exception { haServices, heartbeatServices, Executors.newScheduledThreadPool(1), + mock(BlobServer.class), mock(BlobLibraryCacheManager.class), mock(RestartStrategyFactory.class), - Time.of(10, TimeUnit.SECONDS), + testingTimeout, null, mock(OnCompletionActions.class), testingFatalErrorHandler, new FlinkUserCodeClassLoader(new URL[0])); - jobMaster.start(jmLeaderId); + CompletableFuture startFuture = jobMaster.start(jobMasterId, testingTimeout); + + // wait for the start operation to complete + startFuture.get(testingTimeout.toMilliseconds(), TimeUnit.MILLISECONDS); // define a leader and see that a registration happens - rmLeaderRetrievalService.notifyListener(resourceManagerAddress, rmLeaderId); + rmLeaderRetrievalService.notifyListener(resourceManagerAddress, resourceManagerId.toUUID()); // register job manager success will trigger monitor heartbeat target between jm and rm - verify(resourceManagerGateway).registerJobManager( - eq(rmLeaderId), - eq(jmLeaderId), + verify(resourceManagerGateway, timeout(testingTimeout.toMilliseconds())).registerJobManager( + eq(jobMasterId), eq(jmResourceId), anyString(), eq(jobGraph.getJobID()), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/leaderelection/JobManagerLeaderElectionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/leaderelection/JobManagerLeaderElectionTest.java index 70800e50ad9d7..230ca91417792 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/leaderelection/JobManagerLeaderElectionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/leaderelection/JobManagerLeaderElectionTest.java @@ -25,9 +25,10 @@ import akka.pattern.Patterns; import akka.testkit.JavaTestKit; import akka.util.Timeout; - import org.apache.curator.framework.CuratorFramework; import org.apache.curator.test.TestingServer; +import org.apache.flink.configuration.BlobServerOptions; +import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.blob.BlobServer; @@ -47,13 +48,11 @@ import org.apache.flink.runtime.testutils.ZooKeeperTestUtils; import org.apache.flink.runtime.util.ZooKeeperUtils; import org.apache.flink.util.TestLogger; - import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; - import scala.Option; import scala.concurrent.Await; import scala.concurrent.Future; @@ -178,6 +177,9 @@ private Props createJobManagerProps(Configuration configuration) throws Exceptio SubmittedJobGraphStore submittedJobGraphStore = new StandaloneSubmittedJobGraphStore(); CheckpointRecoveryFactory checkpointRecoveryFactory = new StandaloneCheckpointRecoveryFactory(); + configuration.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L); + + BlobServer blobServer = new BlobServer(configuration, new VoidBlobStore()); return Props.create( TestingJobManager.class, configuration, @@ -185,7 +187,8 @@ private Props createJobManagerProps(Configuration configuration) throws Exceptio TestingUtils.defaultExecutor(), new InstanceManager(), new Scheduler(TestingUtils.defaultExecutionContext()), - new BlobLibraryCacheManager(new 
BlobServer(configuration, new VoidBlobStore()), 10L), + blobServer, + new BlobLibraryCacheManager(blobServer), ActorRef.noSender(), new NoRestartStrategy.NoRestartStrategyFactory(), AkkaUtils.getDefaultTimeoutAsFiniteDuration(), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/leaderelection/TestingLeaderElectionService.java b/flink-runtime/src/test/java/org/apache/flink/runtime/leaderelection/TestingLeaderElectionService.java index d4560833e8ab5..d951db5de39a8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/leaderelection/TestingLeaderElectionService.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/leaderelection/TestingLeaderElectionService.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.leaderelection; import java.util.UUID; +import java.util.concurrent.CompletableFuture; /** * Test {@link LeaderElectionService} implementation which directly forwards isLeader and notLeader @@ -28,43 +29,61 @@ public class TestingLeaderElectionService implements LeaderElectionService { private LeaderContender contender; private boolean hasLeadership = false; + private CompletableFuture confirmationFuture = null; + + /** + * Gets a future that completes when leadership is confirmed. + * + *
<p>
Note: the future is created upon calling {@link #isLeader(UUID)}. + */ + public synchronized CompletableFuture getConfirmationFuture() { + return confirmationFuture; + } @Override - public void start(LeaderContender contender) throws Exception { + public synchronized void start(LeaderContender contender) throws Exception { this.contender = contender; } @Override - public void stop() throws Exception { + public synchronized void stop() throws Exception { } @Override - public void confirmLeaderSessionID(UUID leaderSessionID) { - + public synchronized void confirmLeaderSessionID(UUID leaderSessionID) { + if (confirmationFuture != null) { + confirmationFuture.complete(leaderSessionID); + } } @Override - public boolean hasLeadership() { + public synchronized boolean hasLeadership() { return hasLeadership; } - public void isLeader(UUID leaderSessionID) { + public synchronized CompletableFuture isLeader(UUID leaderSessionID) { + if (confirmationFuture != null) { + confirmationFuture.cancel(false); + } + confirmationFuture = new CompletableFuture<>(); hasLeadership = true; contender.grantLeadership(leaderSessionID); + + return confirmationFuture; } - public void notLeader() { + public synchronized void notLeader() { hasLeadership = false; contender.revokeLeadership(); } - public void reset() { + public synchronized void reset() { contender = null; hasLeadership = false; } - public String getAddress() { + public synchronized String getAddress() { return contender.getAddress(); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java index bc420cc27799b..b36ac86e4b85f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java @@ -24,14 +24,17 @@ import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.StreamStateHandle; + import org.junit.Test; import java.io.IOException; @@ -68,13 +71,16 @@ public void testConfirmTaskCheckpointed() { KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42); - SubtaskState checkpointStateHandles = - new SubtaskState( - CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()), - CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8, false), - null, - CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())), - null); + TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot(); + checkpointStateHandles.putSubtaskStateByOperatorID( + new OperatorID(), + new OperatorSubtaskState( + 
CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false), + null, + CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())), + null + ) + ); AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint( new JobID(), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/MetricRegistryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/MetricRegistryTest.java index 5568467452771..ccbb4f4d9a443 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/MetricRegistryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/MetricRegistryTest.java @@ -205,6 +205,10 @@ public void testReporterScheduling() throws InterruptedException { MetricRegistry registry = new MetricRegistry(MetricRegistryConfiguration.fromConfiguration(config)); long start = System.currentTimeMillis(); + + // only start counting from now on + TestReporter3.reportCount = 0; + for (int x = 0; x < 10; x++) { Thread.sleep(100); int reportCount = TestReporter3.reportCount; @@ -218,7 +222,7 @@ public void testReporterScheduling() throws InterruptedException { * or after T=50. */ long maxAllowedReports = (curT - start) / 50 + 2; - Assert.assertTrue("Too many report were triggered.", maxAllowedReports >= reportCount); + Assert.assertTrue("Too many reports were triggered.", maxAllowedReports >= reportCount); } Assert.assertTrue("No report was triggered.", TestReporter3.reportCount > 0); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/AbstractOuterJoinTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/AbstractOuterJoinTaskTest.java index b265eaef13eb3..ac57d8d77077b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/AbstractOuterJoinTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/AbstractOuterJoinTaskTest.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.operators; -import com.google.common.base.Throwables; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.typeutils.TypeComparator; @@ -37,6 +36,9 @@ import org.apache.flink.runtime.operators.testutils.InfiniteIntTupleIterator; import org.apache.flink.runtime.operators.testutils.UniformIntTupleGenerator; import org.apache.flink.util.Collector; + +import org.apache.flink.shaded.guava18.com.google.common.base.Throwables; + import org.junit.Assert; import org.junit.Test; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/LeftOuterJoinTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/LeftOuterJoinTaskTest.java index 266723a778fe8..3b4c705fc8054 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/LeftOuterJoinTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/LeftOuterJoinTaskTest.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.operators; -import com.google.common.base.Throwables; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory; @@ -28,6 +27,9 @@ import org.apache.flink.runtime.operators.testutils.ExpectedTestException; import org.apache.flink.runtime.operators.testutils.InfiniteIntTupleIterator; import 
org.apache.flink.runtime.operators.testutils.UniformIntTupleGenerator; + +import org.apache.flink.shaded.guava18.com.google.common.base.Throwables; + import org.junit.Assert; import org.junit.Test; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/RightOuterJoinTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/RightOuterJoinTaskTest.java index 4d410316d89a6..7f1f2b465160b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/RightOuterJoinTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/RightOuterJoinTaskTest.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.operators; -import com.google.common.base.Throwables; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory; @@ -28,6 +27,9 @@ import org.apache.flink.runtime.operators.testutils.ExpectedTestException; import org.apache.flink.runtime.operators.testutils.InfiniteIntTupleIterator; import org.apache.flink.runtime.operators.testutils.UniformIntTupleGenerator; + +import org.apache.flink.shaded.guava18.com.google.common.base.Throwables; + import org.junit.Assert; import org.junit.Test; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/InPlaceMutableHashTableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/InPlaceMutableHashTableTest.java index 4db5ef80d087f..beeccecd21c82 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/InPlaceMutableHashTableTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/InPlaceMutableHashTableTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.operators.hash; -import com.google.common.collect.Ordering; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.util.CopyingListCollector; import org.apache.flink.api.common.typeutils.SameTypePairComparator; @@ -37,6 +36,9 @@ import org.apache.flink.runtime.operators.testutils.types.*; import org.apache.flink.util.Collector; import org.apache.flink.util.MutableObjectIterator; + +import org.apache.flink.shaded.guava18.com.google.common.collect.Ordering; + import org.junit.Test; import java.io.EOFException; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index 851fa967be729..8ed06b2ef3682 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -26,7 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -156,7 +156,7 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin } @Override - public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, 
SubtaskState subtaskState) { + public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 4f0242e131ab5..7514cc4200d74 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -27,7 +27,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -50,8 +50,8 @@ import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.types.Record; import org.apache.flink.util.MutableObjectIterator; - import org.apache.flink.util.Preconditions; + import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -354,7 +354,7 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin } @Override - public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) { + public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { throw new UnsupportedOperationException(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/registration/RegisteredRpcConnectionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/registration/RegisteredRpcConnectionTest.java index a4548673531fc..19a57563c3815 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/registration/RegisteredRpcConnectionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/registration/RegisteredRpcConnectionTest.java @@ -141,7 +141,7 @@ public void testRpcConnectionClose() throws Exception { // test RegisteredRpcConnection // ------------------------------------------------------------------------ - private static class TestRpcConnection extends RegisteredRpcConnection { + private static class TestRpcConnection extends RegisteredRpcConnection { private final RpcService rpcService; @@ -155,7 +155,7 @@ public TestRpcConnection(String targetAddress, UUID targetLeaderId, Executor exe } @Override - protected RetryingRegistration generateRegistration() { + protected RetryingRegistration generateRegistration() { return new RetryingRegistrationTest.TestRetryingRegistration(rpcService, getTargetAddress(), getTargetLeaderId()); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/registration/RetryingRegistrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/registration/RetryingRegistrationTest.java index da992bb103b57..ac0dbc57c848a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/registration/RetryingRegistrationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/registration/RetryingRegistrationTest.java @@ -320,7 +320,7 @@ public String getCorrelationId() { } } - static class 
TestRetryingRegistration extends RetryingRegistration { + static class TestRetryingRegistration extends RetryingRegistration { // we use shorter timeouts here to speed up the tests static final long INITIAL_TIMEOUT = 20; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdServiceTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdServiceTest.java index 7b8703e768be6..fb5ee8bad7ea7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdServiceTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/JobLeaderIdServiceTest.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.runtime.concurrent.ScheduledExecutor; import org.apache.flink.runtime.highavailability.TestingHighAvailabilityServices; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; import org.apache.flink.util.TestLogger; import org.junit.Test; @@ -62,7 +63,7 @@ public class JobLeaderIdServiceTest extends TestLogger { public void testAddingJob() throws Exception { final JobID jobId = new JobID(); final String address = "foobar"; - final UUID leaderId = UUID.randomUUID(); + final JobMasterId leaderId = JobMasterId.generate(); TestingHighAvailabilityServices highAvailabilityServices = new TestingHighAvailabilityServices(); TestingLeaderRetrievalService leaderRetrievalService = new TestingLeaderRetrievalService( null, @@ -83,10 +84,10 @@ public void testAddingJob() throws Exception { jobLeaderIdService.addJob(jobId); - CompletableFuture leaderIdFuture = jobLeaderIdService.getLeaderId(jobId); + CompletableFuture leaderIdFuture = jobLeaderIdService.getLeaderId(jobId); // notify the leader id service about the new leader - leaderRetrievalService.notifyListener(address, leaderId); + leaderRetrievalService.notifyListener(address, leaderId.toUUID()); assertEquals(leaderId, leaderIdFuture.get()); @@ -117,7 +118,7 @@ public void testRemovingJob() throws Exception { jobLeaderIdService.addJob(jobId); - CompletableFuture leaderIdFuture = jobLeaderIdService.getLeaderId(jobId); + CompletableFuture leaderIdFuture = jobLeaderIdService.getLeaderId(jobId); // remove the job before we could find a leader jobLeaderIdService.removeJob(jobId); @@ -183,7 +184,7 @@ public void testInitialJobTimeout() throws Exception { public void jobTimeoutAfterLostLeadership() throws Exception { final JobID jobId = new JobID(); final String address = "foobar"; - final UUID leaderId = UUID.randomUUID(); + final JobMasterId leaderId = JobMasterId.generate(); TestingHighAvailabilityServices highAvailabilityServices = new TestingHighAvailabilityServices(); TestingLeaderRetrievalService leaderRetrievalService = new TestingLeaderRetrievalService( null, @@ -228,10 +229,10 @@ public Object answer(InvocationOnMock invocation) throws Throwable { jobLeaderIdService.addJob(jobId); - CompletableFuture leaderIdFuture = jobLeaderIdService.getLeaderId(jobId); + CompletableFuture leaderIdFuture = jobLeaderIdService.getLeaderId(jobId); // notify the leader id service about the new leader - leaderRetrievalService.notifyListener(address, leaderId); + leaderRetrievalService.notifyListener(address, leaderId.toUUID()); assertEquals(leaderId, leaderIdFuture.get()); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerHATest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerHATest.java index 986f8480aadda..2b8792b2eba6f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerHATest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerHATest.java @@ -27,7 +27,7 @@ import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManagerConfiguration; import org.apache.flink.runtime.rpc.RpcService; -import org.apache.flink.runtime.rpc.TestingSerialRpcService; +import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.TestingFatalErrorHandler; import org.apache.flink.util.TestLogger; @@ -35,6 +35,7 @@ import org.junit.Test; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import static org.mockito.Mockito.mock; @@ -46,9 +47,17 @@ public class ResourceManagerHATest extends TestLogger { @Test public void testGrantAndRevokeLeadership() throws Exception { ResourceID rmResourceId = ResourceID.generate(); - RpcService rpcService = new TestingSerialRpcService(); + RpcService rpcService = new TestingRpcService(); + + CompletableFuture leaderSessionIdFuture = new CompletableFuture<>(); + + TestingLeaderElectionService leaderElectionService = new TestingLeaderElectionService() { + @Override + public void confirmLeaderSessionID(UUID leaderId) { + leaderSessionIdFuture.complete(leaderId); + } + }; - TestingLeaderElectionService leaderElectionService = new TestingLeaderElectionService(); TestingHighAvailabilityServices highAvailabilityServices = new TestingHighAvailabilityServices(); highAvailabilityServices.setResourceManagerLeaderElectionService(leaderElectionService); @@ -73,6 +82,8 @@ public void testGrantAndRevokeLeadership() throws Exception { TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); + CompletableFuture revokedLeaderIdFuture = new CompletableFuture<>(); + final ResourceManager resourceManager = new StandaloneResourceManager( rpcService, @@ -84,20 +95,33 @@ public void testGrantAndRevokeLeadership() throws Exception { resourceManagerRuntimeServices.getSlotManager(), metricRegistry, resourceManagerRuntimeServices.getJobLeaderIdService(), - testingFatalErrorHandler); - resourceManager.start(); - // before grant leadership, resourceManager's leaderId is null - Assert.assertEquals(null, resourceManager.getLeaderSessionId()); - final UUID leaderId = UUID.randomUUID(); - leaderElectionService.isLeader(leaderId); - // after grant leadership, resourceManager's leaderId has value - Assert.assertEquals(leaderId, resourceManager.getLeaderSessionId()); - // then revoke leadership, resourceManager's leaderId is null again - leaderElectionService.notLeader(); - Assert.assertEquals(null, resourceManager.getLeaderSessionId()); - - if (testingFatalErrorHandler.hasExceptionOccurred()) { - testingFatalErrorHandler.rethrowError(); + testingFatalErrorHandler) { + + @Override + public void revokeLeadership() { + super.revokeLeadership(); + runAsyncWithoutFencing( + () -> revokedLeaderIdFuture.complete(getFencingToken())); + } + }; + + try { + resourceManager.start(); + + Assert.assertNotNull(resourceManager.getFencingToken()); + final UUID leaderId = UUID.randomUUID(); + leaderElectionService.isLeader(leaderId); + // after grant leadership, resourceManager's leaderId has value + Assert.assertEquals(leaderId, 
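// With the ResourceManager now running on an asynchronous TestingRpcService, the granted
// leader session id can no longer be read back synchronously; the test instead observes the
// confirmation through a future completed from confirmLeaderSessionID. A condensed sketch of
// that pattern (names follow this test, the local variable names are illustrative):
//
//     CompletableFuture<UUID> confirmedLeaderId = new CompletableFuture<>();
//     TestingLeaderElectionService leaderElectionService = new TestingLeaderElectionService() {
//         @Override
//         public void confirmLeaderSessionID(UUID leaderId) {
//             confirmedLeaderId.complete(leaderId);
//         }
//     };
//     // ... start the ResourceManager with this election service, then:
//     UUID leaderId = UUID.randomUUID();
//     leaderElectionService.isLeader(leaderId);
//     Assert.assertEquals(leaderId, confirmedLeaderId.get()); // blocks until leadership is confirmed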
leaderSessionIdFuture.get()); + // then revoke leadership, resourceManager's leaderId should be different + leaderElectionService.notLeader(); + Assert.assertNotEquals(leaderId, revokedLeaderIdFuture.get()); + + if (testingFatalErrorHandler.hasExceptionOccurred()) { + testingFatalErrorHandler.rethrowError(); + } + } finally { + rpcService.stopService(); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerJobMasterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerJobMasterTest.java index 10d6a72721a86..156bc73ca3972 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerJobMasterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerJobMasterTest.java @@ -26,37 +26,44 @@ import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.highavailability.TestingHighAvailabilityServices; import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.jobmaster.JobMasterRegistrationSuccess; +import org.apache.flink.runtime.leaderelection.LeaderElectionService; import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService; import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; +import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; import org.apache.flink.runtime.metrics.MetricRegistry; +import org.apache.flink.runtime.resourcemanager.exceptions.ResourceManagerException; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; import org.apache.flink.runtime.rpc.FatalErrorHandler; -import org.apache.flink.runtime.rpc.TestingSerialRpcService; +import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.registration.RegistrationResponse; +import org.apache.flink.runtime.rpc.exceptions.FencingTokenMismatchException; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.TestingFatalErrorHandler; +import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.TestLogger; import org.junit.After; import org.junit.Before; import org.junit.Test; -import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.*; public class ResourceManagerJobMasterTest extends TestLogger { - private TestingSerialRpcService rpcService; + private TestingRpcService rpcService; - private final Time timeout = Time.milliseconds(0L); + private final Time timeout = Time.seconds(10L); @Before public void setup() throws Exception { - rpcService = new TestingSerialRpcService(); + rpcService = new TestingRpcService(); } @After @@ -71,23 +78,21 @@ public void teardown() throws Exception { public void testRegisterJobMaster() throws Exception { String jobMasterAddress = "/jobMasterAddress1"; JobID jobID = mockJobMaster(jobMasterAddress); - TestingLeaderElectionService resourceManagerLeaderElectionService = new TestingLeaderElectionService(); - UUID jmLeaderID = UUID.randomUUID(); + JobMasterId jobMasterId = JobMasterId.generate(); final ResourceID jmResourceId = new ResourceID(jobMasterAddress); - TestingLeaderRetrievalService jobMasterLeaderRetrievalService = new 
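// Replacing TestingSerialRpcService with the asynchronous TestingRpcService means RPC results
// are no longer available synchronously, which is why the zero-millisecond timeout became a
// real one and responses are awaited on their futures, roughly:
//
//     private final Time timeout = Time.seconds(10L);   // was Time.milliseconds(0L)
//     ...
//     RegistrationResponse response =
//         successfulFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS);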
TestingLeaderRetrievalService(jobMasterAddress, jmLeaderID); + TestingLeaderRetrievalService jobMasterLeaderRetrievalService = new TestingLeaderRetrievalService(jobMasterAddress, jobMasterId.toUUID()); TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); - final ResourceManager resourceManager = createAndStartResourceManager(resourceManagerLeaderElectionService, jobID, jobMasterLeaderRetrievalService, testingFatalErrorHandler); - final UUID rmLeaderSessionId = grantResourceManagerLeadership(resourceManagerLeaderElectionService); + final ResourceManager resourceManager = createAndStartResourceManager(mock(LeaderElectionService.class), jobID, jobMasterLeaderRetrievalService, testingFatalErrorHandler); + final ResourceManagerGateway rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class); // test response successful - CompletableFuture successfulFuture = resourceManager.registerJobManager( - rmLeaderSessionId, - jmLeaderID, + CompletableFuture successfulFuture = rmGateway.registerJobManager( + jobMasterId, jmResourceId, jobMasterAddress, jobID, timeout); - RegistrationResponse response = successfulFuture.get(5L, TimeUnit.SECONDS); + RegistrationResponse response = successfulFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); assertTrue(response instanceof JobMasterRegistrationSuccess); if (testingFatalErrorHandler.hasExceptionOccurred()) { @@ -102,24 +107,28 @@ public void testRegisterJobMaster() throws Exception { public void testRegisterJobMasterWithUnmatchedLeaderSessionId1() throws Exception { String jobMasterAddress = "/jobMasterAddress1"; JobID jobID = mockJobMaster(jobMasterAddress); - TestingLeaderElectionService resourceManagerLeaderElectionService = new TestingLeaderElectionService(); - UUID jmLeaderID = UUID.randomUUID(); + JobMasterId jobMasterId = JobMasterId.generate(); final ResourceID jmResourceId = new ResourceID(jobMasterAddress); - TestingLeaderRetrievalService jobMasterLeaderRetrievalService = new TestingLeaderRetrievalService(jobMasterAddress, jmLeaderID); + TestingLeaderRetrievalService jobMasterLeaderRetrievalService = new TestingLeaderRetrievalService(jobMasterAddress, jobMasterId.toUUID()); TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); - final ResourceManager resourceManager = createAndStartResourceManager(resourceManagerLeaderElectionService, jobID, jobMasterLeaderRetrievalService, testingFatalErrorHandler); - final UUID rmLeaderSessionId = grantResourceManagerLeadership(resourceManagerLeaderElectionService); + final ResourceManager resourceManager = createAndStartResourceManager(mock(LeaderElectionService.class), jobID, jobMasterLeaderRetrievalService, testingFatalErrorHandler); + final ResourceManagerGateway wronglyFencedGateway = rpcService.connect(resourceManager.getAddress(), ResourceManagerId.generate(), ResourceManagerGateway.class) + .get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); // test throw exception when receive a registration from job master which takes unmatched leaderSessionId - UUID differentLeaderSessionID = UUID.randomUUID(); - CompletableFuture unMatchedLeaderFuture = resourceManager.registerJobManager( - differentLeaderSessionID, - jmLeaderID, + CompletableFuture unMatchedLeaderFuture = wronglyFencedGateway.registerJobManager( + jobMasterId, jmResourceId, jobMasterAddress, jobID, timeout); - assertTrue(unMatchedLeaderFuture.get(5, TimeUnit.SECONDS) instanceof RegistrationResponse.Decline); + + try { + unMatchedLeaderFuture.get(5L, 
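// Fencing is now enforced in the RPC layer: a ResourceManagerGateway obtained with a fencing
// token that does not match the ResourceManager's current one fails the call exceptionally
// instead of returning RegistrationResponse.Decline. Condensed sketch of the check performed
// in this test:
//
//     ResourceManagerGateway wronglyFencedGateway = rpcService
//         .connect(resourceManager.getAddress(), ResourceManagerId.generate(), ResourceManagerGateway.class)
//         .get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS);
//     try {
//         wronglyFencedGateway
//             .registerJobManager(jobMasterId, jmResourceId, jobMasterAddress, jobID, timeout)
//             .get(5L, TimeUnit.SECONDS);
//         fail("Should fail because we are using the wrong fencing token.");
//     } catch (ExecutionException e) {
//         assertTrue(ExceptionUtils.stripExecutionException(e) instanceof FencingTokenMismatchException);
//     }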
TimeUnit.SECONDS); + fail("Should fail because we are using the wrong fencing token."); + } catch (ExecutionException e) { + assertTrue(ExceptionUtils.stripExecutionException(e) instanceof FencingTokenMismatchException); + } if (testingFatalErrorHandler.hasExceptionOccurred()) { testingFatalErrorHandler.rethrowError(); @@ -139,15 +148,13 @@ public void testRegisterJobMasterWithUnmatchedLeaderSessionId2() throws Exceptio HighAvailabilityServices.DEFAULT_LEADER_ID); TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); final ResourceManager resourceManager = createAndStartResourceManager(resourceManagerLeaderElectionService, jobID, jobMasterLeaderRetrievalService, testingFatalErrorHandler); - final UUID rmLeaderSessionId = grantResourceManagerLeadership(resourceManagerLeaderElectionService); - final UUID jmLeaderSessionId = grantResourceManagerLeadership(resourceManagerLeaderElectionService); + final ResourceManagerGateway rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class); final ResourceID jmResourceId = new ResourceID(jobMasterAddress); // test throw exception when receive a registration from job master which takes unmatched leaderSessionId - UUID differentLeaderSessionID = UUID.randomUUID(); - CompletableFuture unMatchedLeaderFuture = resourceManager.registerJobManager( - rmLeaderSessionId, - differentLeaderSessionID, + JobMasterId differentJobMasterId = JobMasterId.generate(); + CompletableFuture unMatchedLeaderFuture = rmGateway.registerJobManager( + differentJobMasterId, jmResourceId, jobMasterAddress, jobID, @@ -172,15 +179,13 @@ public void testRegisterJobMasterFromInvalidAddress() throws Exception { HighAvailabilityServices.DEFAULT_LEADER_ID); TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); final ResourceManager resourceManager = createAndStartResourceManager(resourceManagerLeaderElectionService, jobID, jobMasterLeaderRetrievalService, testingFatalErrorHandler); - final UUID rmLeaderSessionId = grantResourceManagerLeadership(resourceManagerLeaderElectionService); - final UUID jmLeaderSessionId = grantResourceManagerLeadership(resourceManagerLeaderElectionService); + final ResourceManagerGateway rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class); final ResourceID jmResourceId = new ResourceID(jobMasterAddress); // test throw exception when receive a registration from job master which takes invalid address String invalidAddress = "/jobMasterAddress2"; - CompletableFuture invalidAddressFuture = resourceManager.registerJobManager( - rmLeaderSessionId, - jmLeaderSessionId, + CompletableFuture invalidAddressFuture = rmGateway.registerJobManager( + new JobMasterId(HighAvailabilityServices.DEFAULT_LEADER_ID), jmResourceId, invalidAddress, jobID, @@ -204,25 +209,28 @@ public void testRegisterJobMasterWithFailureLeaderListener() throws Exception { "localhost", HighAvailabilityServices.DEFAULT_LEADER_ID); TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); - final ResourceManager resourceManager = createAndStartResourceManager(resourceManagerLeaderElectionService, jobID, jobMasterLeaderRetrievalService, testingFatalErrorHandler); - final UUID rmLeaderSessionId = grantResourceManagerLeadership(resourceManagerLeaderElectionService); - final UUID jmLeaderSessionId = grantResourceManagerLeadership(resourceManagerLeaderElectionService); + final ResourceManager resourceManager = createAndStartResourceManager( + resourceManagerLeaderElectionService, + jobID, + 
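// JobMasterId is essentially a fencing-token wrapper around a UUID: tests either generate a
// fresh one or wrap an existing UUID such as the default leader id, and unwrap it again with
// toUUID() where components still expect plain UUIDs, for example:
//
//     JobMasterId generated = JobMasterId.generate();
//     JobMasterId fromDefault = new JobMasterId(HighAvailabilityServices.DEFAULT_LEADER_ID);
//     // the leader retrieval test service still works with the unwrapped UUID:
//     new TestingLeaderRetrievalService(jobMasterAddress, generated.toUUID());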
jobMasterLeaderRetrievalService, + testingFatalErrorHandler); + final ResourceManagerGateway rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class); final ResourceID jmResourceId = new ResourceID(jobMasterAddress); JobID unknownJobIDToHAServices = new JobID(); - // verify return RegistrationResponse.Decline when failed to start a job master Leader retrieval listener - CompletableFuture declineFuture = resourceManager.registerJobManager( - rmLeaderSessionId, - jmLeaderSessionId, + + // this should fail because we try to register a job leader listener for an unknown job id + CompletableFuture registrationFuture = rmGateway.registerJobManager( + new JobMasterId(HighAvailabilityServices.DEFAULT_LEADER_ID), jmResourceId, jobMasterAddress, unknownJobIDToHAServices, timeout); - RegistrationResponse response = declineFuture.get(5, TimeUnit.SECONDS); - assertTrue(response instanceof RegistrationResponse.Decline); - if (testingFatalErrorHandler.hasExceptionOccurred()) { - testingFatalErrorHandler.rethrowError(); + try { + registrationFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + assertTrue(ExceptionUtils.stripExecutionException(e) instanceof ResourceManagerException); } } @@ -234,9 +242,9 @@ private JobID mockJobMaster(String jobMasterAddress) { } private ResourceManager createAndStartResourceManager( - TestingLeaderElectionService resourceManagerLeaderElectionService, + LeaderElectionService resourceManagerLeaderElectionService, JobID jobID, - TestingLeaderRetrievalService jobMasterLeaderRetrievalService, + LeaderRetrievalService jobMasterLeaderRetrievalService, FatalErrorHandler fatalErrorHandler) throws Exception { ResourceID rmResourceId = ResourceID.generate(); TestingHighAvailabilityServices highAvailabilityServices = new TestingHighAvailabilityServices(); @@ -274,11 +282,4 @@ private ResourceManager createAndStartResourceManager( resourceManager.start(); return resourceManager; } - - private UUID grantResourceManagerLeadership(TestingLeaderElectionService resourceManagerLeaderElectionService) { - UUID leaderSessionId = UUID.randomUUID(); - resourceManagerLeaderElectionService.isLeader(leaderSessionId); - return leaderSessionId; - } - } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerTaskExecutorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerTaskExecutorTest.java index fc96f0d27a651..8add1685dbea7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerTaskExecutorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/ResourceManagerTaskExecutorTest.java @@ -23,17 +23,20 @@ import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.TestingHighAvailabilityServices; +import org.apache.flink.runtime.leaderelection.LeaderElectionService; import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; import org.apache.flink.runtime.rpc.FatalErrorHandler; -import org.apache.flink.runtime.rpc.TestingSerialRpcService; +import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.registration.RegistrationResponse; +import 
org.apache.flink.runtime.rpc.exceptions.FencingTokenMismatchException; import org.apache.flink.runtime.taskexecutor.SlotReport; import org.apache.flink.runtime.taskexecutor.TaskExecutorGateway; import org.apache.flink.runtime.taskexecutor.TaskExecutorRegistrationSuccess; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.TestingFatalErrorHandler; +import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.TestLogger; import org.junit.After; import org.junit.Before; @@ -41,15 +44,19 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; public class ResourceManagerTaskExecutorTest extends TestLogger { - private TestingSerialRpcService rpcService; + private final Time timeout = Time.seconds(10L); + + private TestingRpcService rpcService; private SlotReport slotReport = new SlotReport(); @@ -61,20 +68,26 @@ public class ResourceManagerTaskExecutorTest extends TestLogger { private StandaloneResourceManager resourceManager; - private UUID leaderSessionId; + private ResourceManagerGateway rmGateway; + + private ResourceManagerGateway wronglyFencedGateway; private TestingFatalErrorHandler testingFatalErrorHandler; @Before public void setup() throws Exception { - rpcService = new TestingSerialRpcService(); + rpcService = new TestingRpcService(); taskExecutorResourceID = mockTaskExecutor(taskExecutorAddress); resourceManagerResourceID = ResourceID.generate(); - TestingLeaderElectionService rmLeaderElectionService = new TestingLeaderElectionService(); testingFatalErrorHandler = new TestingFatalErrorHandler(); + TestingLeaderElectionService rmLeaderElectionService = new TestingLeaderElectionService(); resourceManager = createAndStartResourceManager(rmLeaderElectionService, testingFatalErrorHandler); - leaderSessionId = grantLeadership(rmLeaderElectionService); + rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class); + wronglyFencedGateway = rpcService.connect(resourceManager.getAddress(), ResourceManagerId.generate(), ResourceManagerGateway.class) + .get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + grantLeadership(rmLeaderElectionService).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } @After @@ -90,13 +103,13 @@ public void testRegisterTaskExecutor() throws Exception { try { // test response successful CompletableFuture successfulFuture = - resourceManager.registerTaskExecutor(leaderSessionId, taskExecutorAddress, taskExecutorResourceID, slotReport, Time.milliseconds(0L)); - RegistrationResponse response = successfulFuture.get(5, TimeUnit.SECONDS); + rmGateway.registerTaskExecutor(taskExecutorAddress, taskExecutorResourceID, slotReport, timeout); + RegistrationResponse response = successfulFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); assertTrue(response instanceof TaskExecutorRegistrationSuccess); // test response successful with instanceID not equal to previous when receive duplicate registration from taskExecutor CompletableFuture duplicateFuture = - resourceManager.registerTaskExecutor(leaderSessionId, taskExecutorAddress, taskExecutorResourceID, slotReport, Time.milliseconds(0L)); + rmGateway.registerTaskExecutor(taskExecutorAddress, taskExecutorResourceID, slotReport, timeout); RegistrationResponse duplicateResponse 
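// registerTaskExecutor no longer takes an explicit leader session id; the fencing token travels
// with the gateway, so a correctly fenced self gateway is all a test needs, roughly:
//
//     ResourceManagerGateway rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class);
//     RegistrationResponse response = rmGateway
//         .registerTaskExecutor(taskExecutorAddress, taskExecutorResourceID, slotReport, timeout)
//         .get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS);
//     assertTrue(response instanceof TaskExecutorRegistrationSuccess);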
= duplicateFuture.get(); assertTrue(duplicateResponse instanceof TaskExecutorRegistrationSuccess); assertNotEquals(((TaskExecutorRegistrationSuccess) response).getRegistrationId(), ((TaskExecutorRegistrationSuccess) duplicateResponse).getRegistrationId()); @@ -114,10 +127,15 @@ public void testRegisterTaskExecutor() throws Exception { public void testRegisterTaskExecutorWithUnmatchedLeaderSessionId() throws Exception { try { // test throw exception when receive a registration from taskExecutor which takes unmatched leaderSessionId - UUID differentLeaderSessionID = UUID.randomUUID(); CompletableFuture unMatchedLeaderFuture = - resourceManager.registerTaskExecutor(differentLeaderSessionID, taskExecutorAddress, taskExecutorResourceID, slotReport, Time.milliseconds(0L)); - assertTrue(unMatchedLeaderFuture.get(5, TimeUnit.SECONDS) instanceof RegistrationResponse.Decline); + wronglyFencedGateway.registerTaskExecutor(taskExecutorAddress, taskExecutorResourceID, slotReport, timeout); + + try { + unMatchedLeaderFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + fail("Should have failed because we are using a wrongly fenced ResourceManagerGateway."); + } catch (ExecutionException e) { + assertTrue(ExceptionUtils.stripExecutionException(e) instanceof FencingTokenMismatchException); + } } finally { if (testingFatalErrorHandler.hasExceptionOccurred()) { testingFatalErrorHandler.rethrowError(); @@ -134,8 +152,8 @@ public void testRegisterTaskExecutorFromInvalidAddress() throws Exception { // test throw exception when receive a registration from taskExecutor which takes invalid address String invalidAddress = "/taskExecutor2"; CompletableFuture invalidAddressFuture = - resourceManager.registerTaskExecutor(leaderSessionId, invalidAddress, taskExecutorResourceID, slotReport, Time.milliseconds(0L)); - assertTrue(invalidAddressFuture.get(5, TimeUnit.SECONDS) instanceof RegistrationResponse.Decline); + rmGateway.registerTaskExecutor(invalidAddress, taskExecutorResourceID, slotReport, timeout); + assertTrue(invalidAddressFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS) instanceof RegistrationResponse.Decline); } finally { if (testingFatalErrorHandler.hasExceptionOccurred()) { testingFatalErrorHandler.rethrowError(); @@ -150,7 +168,7 @@ private ResourceID mockTaskExecutor(String taskExecutorAddress) { return taskExecutorResourceID; } - private StandaloneResourceManager createAndStartResourceManager(TestingLeaderElectionService rmLeaderElectionService, FatalErrorHandler fatalErrorHandler) throws Exception { + private StandaloneResourceManager createAndStartResourceManager(LeaderElectionService rmLeaderElectionService, FatalErrorHandler fatalErrorHandler) throws Exception { TestingHighAvailabilityServices highAvailabilityServices = new TestingHighAvailabilityServices(); HeartbeatServices heartbeatServices = new HeartbeatServices(5L, 5L); highAvailabilityServices.setResourceManagerLeaderElectionService(rmLeaderElectionService); @@ -182,14 +200,15 @@ private StandaloneResourceManager createAndStartResourceManager(TestingLeaderEle metricRegistry, jobLeaderIdService, fatalErrorHandler); + resourceManager.start(); + return resourceManager; } - private UUID grantLeadership(TestingLeaderElectionService leaderElectionService) { + private CompletableFuture grantLeadership(TestingLeaderElectionService leaderElectionService) { UUID leaderSessionId = UUID.randomUUID(); - leaderElectionService.isLeader(leaderSessionId); - return leaderSessionId; + return leaderElectionService.isLeader(leaderSessionId); 
} } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotManagerTest.java index 93e96a7b0f026..80b445fbb2b9e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotManagerTest.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.concurrent.ScheduledExecutor; import org.apache.flink.runtime.instance.InstanceID; import org.apache.flink.runtime.messages.Acknowledge; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.resourcemanager.SlotRequest; import org.apache.flink.runtime.resourcemanager.exceptions.ResourceManagerException; import org.apache.flink.runtime.resourcemanager.registration.TaskExecutorConnection; @@ -43,7 +44,6 @@ import org.mockito.ArgumentCaptor; import java.util.Arrays; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -74,7 +74,7 @@ public class SlotManagerTest extends TestLogger { */ @Test public void testTaskManagerRegistration() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final TaskExecutorGateway taskExecutorGateway = mock(TaskExecutorGateway.class); @@ -88,7 +88,7 @@ public void testTaskManagerRegistration() throws Exception { final SlotStatus slotStatus2 = new SlotStatus(slotId2, resourceProfile); final SlotReport slotReport = new SlotReport(Arrays.asList(slotStatus1, slotStatus2)); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerTaskManager(taskManagerConnection, slotReport); assertTrue("The number registered slots does not equal the expected number.",2 == slotManager.getNumberRegisteredSlots()); @@ -103,7 +103,7 @@ public void testTaskManagerRegistration() throws Exception { */ @Test public void testTaskManagerUnregistration() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final JobID jobId = new JobID(); @@ -113,7 +113,7 @@ public void testTaskManagerUnregistration() throws Exception { any(JobID.class), any(AllocationID.class), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class))).thenReturn(new CompletableFuture<>()); final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -134,7 +134,7 @@ public void testTaskManagerUnregistration() throws Exception { resourceProfile, "foobar"); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerTaskManager(taskManagerConnection, slotReport); assertTrue("The number registered slots does not equal the expected number.",2 == slotManager.getNumberRegisteredSlots()); @@ -166,7 +166,7 @@ public void testTaskManagerUnregistration() throws Exception { */ @Test public 
void testSlotRequestWithoutFreeSlots() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceProfile resourceProfile = new ResourceProfile(42.0, 1337); final SlotRequest slotRequest = new SlotRequest( new JobID(), @@ -176,7 +176,7 @@ public void testSlotRequestWithoutFreeSlots() throws Exception { ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerSlotRequest(slotRequest); @@ -189,7 +189,7 @@ public void testSlotRequestWithoutFreeSlots() throws Exception { */ @Test public void testSlotRequestWithResourceAllocationFailure() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceProfile resourceProfile = new ResourceProfile(42.0, 1337); final SlotRequest slotRequest = new SlotRequest( new JobID(), @@ -200,7 +200,7 @@ public void testSlotRequestWithResourceAllocationFailure() throws Exception { ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); doThrow(new ResourceManagerException("Test exception")).when(resourceManagerActions).allocateResource(any(ResourceProfile.class)); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerSlotRequest(slotRequest); @@ -216,7 +216,7 @@ public void testSlotRequestWithResourceAllocationFailure() throws Exception { */ @Test public void testSlotRequestWithFreeSlot() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceID resourceID = ResourceID.generate(); final JobID jobId = new JobID(); final SlotID slotId = new SlotID(resourceID, 0); @@ -231,7 +231,7 @@ public void testSlotRequestWithFreeSlot() throws Exception { ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { // accept an incoming slot request final TaskExecutorGateway taskExecutorGateway = mock(TaskExecutorGateway.class); @@ -240,7 +240,7 @@ public void testSlotRequestWithFreeSlot() throws Exception { eq(jobId), eq(allocationId), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class))).thenReturn(CompletableFuture.completedFuture(Acknowledge.get())); final TaskExecutorConnection taskExecutorConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -254,7 +254,7 @@ public void testSlotRequestWithFreeSlot() throws Exception { assertTrue("The slot request should be accepted", slotManager.registerSlotRequest(slotRequest)); - verify(taskExecutorGateway).requestSlot(eq(slotId), eq(jobId), eq(allocationId), eq(targetAddress), eq(leaderId), any(Time.class)); + verify(taskExecutorGateway).requestSlot(eq(slotId), eq(jobId), eq(allocationId), eq(targetAddress), eq(resourceManagerId), any(Time.class)); TaskManagerSlot slot = slotManager.getSlot(slotId); @@ -268,7 +268,7 @@ public void testSlotRequestWithFreeSlot() throws Exception { */ @Test public void 
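// Throughout this test the SlotManager is started with the ResourceManager's fencing token
// instead of a raw leader UUID, and the Mockito stubs for TaskExecutorGateway.requestSlot have
// to match that token. A condensed sketch of the setup used by these test cases:
//
//     ResourceManagerId resourceManagerId = ResourceManagerId.generate();
//     SlotManager slotManager = new SlotManager(
//         TestingUtils.defaultScheduledExecutor(),
//         TestingUtils.infiniteTime(),
//         TestingUtils.infiniteTime(),
//         TestingUtils.infiniteTime());
//     slotManager.start(resourceManagerId, Executors.directExecutor(), resourceManagerActions);
//
//     when(taskExecutorGateway.requestSlot(
//             any(SlotID.class), any(JobID.class), any(AllocationID.class),
//             anyString(), eq(resourceManagerId), any(Time.class)))
//         .thenReturn(CompletableFuture.completedFuture(Acknowledge.get()));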
testUnregisterPendingSlotRequest() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final SlotID slotId = new SlotID(ResourceID.generate(), 0); final AllocationID allocationId = new AllocationID(); @@ -279,7 +279,7 @@ public void testUnregisterPendingSlotRequest() throws Exception { any(JobID.class), any(AllocationID.class), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class))).thenReturn(new CompletableFuture<>()); final ResourceProfile resourceProfile = new ResourceProfile(1.0, 1); @@ -290,7 +290,7 @@ public void testUnregisterPendingSlotRequest() throws Exception { final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerTaskManager(taskManagerConnection, slotReport); TaskManagerSlot slot = slotManager.getSlot(slotId); @@ -315,7 +315,7 @@ public void testUnregisterPendingSlotRequest() throws Exception { */ @Test public void testFulfillingPendingSlotRequest() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceID resourceID = ResourceID.generate(); final JobID jobId = new JobID(); final SlotID slotId = new SlotID(resourceID, 0); @@ -337,7 +337,7 @@ public void testFulfillingPendingSlotRequest() throws Exception { eq(jobId), eq(allocationId), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class))).thenReturn(CompletableFuture.completedFuture(Acknowledge.get())); final TaskExecutorConnection taskExecutorConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -345,7 +345,7 @@ public void testFulfillingPendingSlotRequest() throws Exception { final SlotStatus slotStatus = new SlotStatus(slotId, resourceProfile); final SlotReport slotReport = new SlotReport(slotStatus); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { assertTrue("The slot request should be accepted", slotManager.registerSlotRequest(slotRequest)); @@ -355,7 +355,7 @@ public void testFulfillingPendingSlotRequest() throws Exception { taskExecutorConnection, slotReport); - verify(taskExecutorGateway).requestSlot(eq(slotId), eq(jobId), eq(allocationId), eq(targetAddress), eq(leaderId), any(Time.class)); + verify(taskExecutorGateway).requestSlot(eq(slotId), eq(jobId), eq(allocationId), eq(targetAddress), eq(resourceManagerId), any(Time.class)); TaskManagerSlot slot = slotManager.getSlot(slotId); @@ -368,7 +368,7 @@ public void testFulfillingPendingSlotRequest() throws Exception { */ @Test public void testFreeSlot() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceID resourceID = ResourceID.generate(); final JobID jobId = new JobID(); final SlotID slotId = new SlotID(resourceID, 0); @@ -385,7 +385,7 @@ public void testFreeSlot() throws Exception { final SlotStatus slotStatus = new SlotStatus(slotId, resourceProfile, jobId, allocationId); final SlotReport slotReport = new SlotReport(slotStatus); - try (SlotManager slotManager = 
createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerTaskManager( taskExecutorConnection, @@ -414,8 +414,7 @@ public void testFreeSlot() throws Exception { */ @Test public void testDuplicatePendingSlotRequest() throws Exception { - - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final AllocationID allocationId = new AllocationID(); final ResourceProfile resourceProfile1 = new ResourceProfile(1.0, 2); @@ -423,7 +422,7 @@ public void testDuplicatePendingSlotRequest() throws Exception { final SlotRequest slotRequest1 = new SlotRequest(new JobID(), allocationId, resourceProfile1, "foobar"); final SlotRequest slotRequest2 = new SlotRequest(new JobID(), allocationId, resourceProfile2, "barfoo"); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { assertTrue(slotManager.registerSlotRequest(slotRequest1)); assertFalse(slotManager.registerSlotRequest(slotRequest2)); } @@ -439,7 +438,7 @@ public void testDuplicatePendingSlotRequest() throws Exception { */ @Test public void testDuplicatePendingSlotRequestAfterSlotReport() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final JobID jobId = new JobID(); final AllocationID allocationId = new AllocationID(); @@ -454,7 +453,7 @@ public void testDuplicatePendingSlotRequestAfterSlotReport() throws Exception { final SlotRequest slotRequest = new SlotRequest(jobId, allocationId, resourceProfile, "foobar"); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerTaskManager(taskManagerConnection, slotReport); assertFalse(slotManager.registerSlotRequest(slotRequest)); @@ -467,7 +466,7 @@ public void testDuplicatePendingSlotRequestAfterSlotReport() throws Exception { */ @Test public void testDuplicatePendingSlotRequestAfterSuccessfulAllocation() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final AllocationID allocationId = new AllocationID(); final ResourceProfile resourceProfile1 = new ResourceProfile(1.0, 2); @@ -481,7 +480,7 @@ public void testDuplicatePendingSlotRequestAfterSuccessfulAllocation() throws Ex any(JobID.class), any(AllocationID.class), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class))).thenReturn(CompletableFuture.completedFuture(Acknowledge.get())); final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -490,7 +489,7 @@ public void testDuplicatePendingSlotRequestAfterSuccessfulAllocation() throws Ex final SlotStatus slotStatus = new SlotStatus(slotId, resourceProfile1); final SlotReport slotReport = new SlotReport(slotStatus); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = 
createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerTaskManager(taskManagerConnection, slotReport); assertTrue(slotManager.registerSlotRequest(slotRequest1)); @@ -512,7 +511,7 @@ public void testDuplicatePendingSlotRequestAfterSuccessfulAllocation() throws Ex */ @Test public void testAcceptingDuplicateSlotRequestAfterAllocationRelease() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final AllocationID allocationId = new AllocationID(); final ResourceProfile resourceProfile1 = new ResourceProfile(1.0, 2); @@ -526,7 +525,7 @@ public void testAcceptingDuplicateSlotRequestAfterAllocationRelease() throws Exc any(JobID.class), any(AllocationID.class), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class))).thenReturn(CompletableFuture.completedFuture(Acknowledge.get())); final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -535,7 +534,7 @@ public void testAcceptingDuplicateSlotRequestAfterAllocationRelease() throws Exc final SlotStatus slotStatus = new SlotStatus(slotId, new ResourceProfile(2.0, 2)); final SlotReport slotReport = new SlotReport(slotStatus); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerTaskManager(taskManagerConnection, slotReport); assertTrue(slotManager.registerSlotRequest(slotRequest1)); @@ -565,7 +564,7 @@ public void testAcceptingDuplicateSlotRequestAfterAllocationRelease() throws Exc */ @Test public void testReceivingUnknownSlotReport() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final InstanceID unknownInstanceID = new InstanceID(); @@ -574,7 +573,7 @@ public void testReceivingUnknownSlotReport() throws Exception { final SlotStatus unknownSlotStatus = new SlotStatus(unknownSlotId, unknownResourceProfile); final SlotReport unknownSlotReport = new SlotReport(unknownSlotStatus); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { // check that we don't have any slots registered assertTrue(0 == slotManager.getNumberRegisteredSlots()); @@ -591,7 +590,7 @@ public void testReceivingUnknownSlotReport() throws Exception { */ @Test public void testUpdateSlotReport() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final JobID jobId = new JobID(); @@ -614,7 +613,7 @@ public void testUpdateSlotReport() throws Exception { final TaskExecutorGateway taskExecutorGateway = mock(TaskExecutorGateway.class); final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { // check that we don't have any slots registered assertTrue(0 == 
slotManager.getNumberRegisteredSlots()); @@ -651,7 +650,7 @@ public void testTaskManagerTimeout() throws Exception { final long tmTimeout = 500L; final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final TaskExecutorGateway taskExecutorGateway = mock(TaskExecutorGateway.class); final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -669,7 +668,7 @@ public void testTaskManagerTimeout() throws Exception { TestingUtils.infiniteTime(), Time.milliseconds(tmTimeout))) { - slotManager.start(leaderId, mainThreadExecutor, resourceManagerActions); + slotManager.start(resourceManagerId, mainThreadExecutor, resourceManagerActions); mainThreadExecutor.execute(new Runnable() { @Override @@ -693,7 +692,7 @@ public void testSlotRequestTimeout() throws Exception { final long allocationTimeout = 50L; final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final JobID jobId = new JobID(); final AllocationID allocationId = new AllocationID(); @@ -708,7 +707,7 @@ public void testSlotRequestTimeout() throws Exception { Time.milliseconds(allocationTimeout), TestingUtils.infiniteTime())) { - slotManager.start(leaderId, mainThreadExecutor, resourceManagerActions); + slotManager.start(resourceManagerId, mainThreadExecutor, resourceManagerActions); final AtomicReference atomicException = new AtomicReference<>(null); @@ -740,7 +739,7 @@ public void run() { @Test @SuppressWarnings("unchecked") public void testTaskManagerSlotRequestTimeoutHandling() throws Exception { - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final JobID jobId = new JobID(); @@ -756,7 +755,7 @@ public void testTaskManagerSlotRequestTimeoutHandling() throws Exception { any(JobID.class), eq(allocationId), anyString(), - any(UUID.class), + any(ResourceManagerId.class), any(Time.class))).thenReturn(slotRequestFuture1, slotRequestFuture2); final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -768,7 +767,7 @@ public void testTaskManagerSlotRequestTimeoutHandling() throws Exception { final SlotStatus slotStatus2 = new SlotStatus(slotId2, resourceProfile); final SlotReport slotReport = new SlotReport(Arrays.asList(slotStatus1, slotStatus2)); - try (SlotManager slotManager = createSlotManager(leaderId, resourceManagerActions)) { + try (SlotManager slotManager = createSlotManager(resourceManagerId, resourceManagerActions)) { slotManager.registerTaskManager(taskManagerConnection, slotReport); @@ -781,7 +780,7 @@ public void testTaskManagerSlotRequestTimeoutHandling() throws Exception { eq(jobId), eq(allocationId), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class)); TaskManagerSlot failedSlot = slotManager.getSlot(slotIdCaptor.getValue()); @@ -794,7 +793,7 @@ public void testTaskManagerSlotRequestTimeoutHandling() throws Exception { eq(jobId), eq(allocationId), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class)); // the second attempt succeeds @@ -819,7 +818,7 @@ public void testTaskManagerSlotRequestTimeoutHandling() throws Exception { 
@SuppressWarnings("unchecked") public void testSlotReportWhileActiveSlotRequest() throws Exception { final long verifyTimeout = 1000L; - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final JobID jobId = new JobID(); @@ -834,7 +833,7 @@ public void testSlotReportWhileActiveSlotRequest() throws Exception { any(JobID.class), eq(allocationId), anyString(), - any(UUID.class), + any(ResourceManagerId.class), any(Time.class))).thenReturn(slotRequestFuture1, CompletableFuture.completedFuture(Acknowledge.get())); final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -854,7 +853,7 @@ public void testSlotReportWhileActiveSlotRequest() throws Exception { TestingUtils.infiniteTime(), TestingUtils.infiniteTime())) { - slotManager.start(leaderId, mainThreadExecutor, resourceManagerActions); + slotManager.start(resourceManagerId, mainThreadExecutor, resourceManagerActions); CompletableFuture registrationFuture = CompletableFuture.supplyAsync( () -> { @@ -882,7 +881,7 @@ public void testSlotReportWhileActiveSlotRequest() throws Exception { eq(jobId), eq(allocationId), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class)); final SlotID requestedSlotId = slotIdCaptor.getValue(); @@ -908,7 +907,7 @@ public void testSlotReportWhileActiveSlotRequest() throws Exception { eq(jobId), eq(allocationId), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class)); final SlotID requestedSlotId2 = slotIdCaptor.getValue(); @@ -935,7 +934,7 @@ public void testTimeoutForUnusedTaskManager() throws Exception { final long taskManagerTimeout = 50L; final long verifyTimeout = taskManagerTimeout * 10L; - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final ResourceManagerActions resourceManagerActions = mock(ResourceManagerActions.class); final ScheduledExecutor scheduledExecutor = TestingUtils.defaultScheduledExecutor(); @@ -952,7 +951,7 @@ public void testTimeoutForUnusedTaskManager() throws Exception { eq(jobId), eq(allocationId), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class))).thenReturn(CompletableFuture.completedFuture(Acknowledge.get())); final TaskExecutorConnection taskManagerConnection = new TaskExecutorConnection(taskExecutorGateway); @@ -971,7 +970,7 @@ public void testTimeoutForUnusedTaskManager() throws Exception { TestingUtils.infiniteTime(), Time.of(taskManagerTimeout, TimeUnit.MILLISECONDS))) { - slotManager.start(leaderId, mainThreadExecutor, resourceManagerActions); + slotManager.start(resourceManagerId, mainThreadExecutor, resourceManagerActions); CompletableFuture.supplyAsync( () -> { @@ -991,7 +990,7 @@ public void testTimeoutForUnusedTaskManager() throws Exception { eq(jobId), eq(allocationId), anyString(), - eq(leaderId), + eq(resourceManagerId), any(Time.class)); CompletableFuture idleFuture = CompletableFuture.supplyAsync( @@ -1023,14 +1022,14 @@ public void testTimeoutForUnusedTaskManager() throws Exception { } } - private SlotManager createSlotManager(UUID leaderId, ResourceManagerActions resourceManagerActions) { + private SlotManager createSlotManager(ResourceManagerId resourceManagerId, ResourceManagerActions resourceManagerActions) { SlotManager slotManager = new SlotManager( TestingUtils.defaultScheduledExecutor(), TestingUtils.infiniteTime(), 
TestingUtils.infiniteTime(), TestingUtils.infiniteTime()); - slotManager.start(leaderId, Executors.directExecutor(), resourceManagerActions); + slotManager.start(resourceManagerId, Executors.directExecutor(), resourceManagerActions); return slotManager; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotProtocolTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotProtocolTest.java index 844e1597f7c94..6de4d52af391a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotProtocolTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/resourcemanager/slotmanager/SlotProtocolTest.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.concurrent.ScheduledExecutor; import org.apache.flink.runtime.concurrent.ScheduledExecutorServiceAdapter; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.resourcemanager.SlotRequest; import org.apache.flink.runtime.resourcemanager.registration.TaskExecutorConnection; import org.apache.flink.runtime.taskexecutor.SlotReport; @@ -39,7 +40,6 @@ import org.mockito.Mockito; import java.util.Collections; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; @@ -77,7 +77,7 @@ public static void afterClass() { public void testSlotsUnavailableRequest() throws Exception { final JobID jobID = new JobID(); - final UUID rmLeaderID = UUID.randomUUID(); + final ResourceManagerId rmLeaderID = ResourceManagerId.generate(); try (SlotManager slotManager = new SlotManager( scheduledExecutor, @@ -103,7 +103,7 @@ public void testSlotsUnavailableRequest() throws Exception { TaskExecutorGateway taskExecutorGateway = mock(TaskExecutorGateway.class); Mockito.when( taskExecutorGateway - .requestSlot(any(SlotID.class), any(JobID.class), any(AllocationID.class), any(String.class), any(UUID.class), any(Time.class))) + .requestSlot(any(SlotID.class), any(JobID.class), any(AllocationID.class), any(String.class), any(ResourceManagerId.class), any(Time.class))) .thenReturn(mock(CompletableFuture.class)); final ResourceID resourceID = ResourceID.generate(); @@ -118,7 +118,7 @@ public void testSlotsUnavailableRequest() throws Exception { // 4) Slot becomes available and TaskExecutor gets a SlotRequest verify(taskExecutorGateway, timeout(5000L)) - .requestSlot(eq(slotID), eq(jobID), eq(allocationID), any(String.class), any(UUID.class), any(Time.class)); + .requestSlot(eq(slotID), eq(jobID), eq(allocationID), any(String.class), any(ResourceManagerId.class), any(Time.class)); } } @@ -133,12 +133,12 @@ public void testSlotsUnavailableRequest() throws Exception { public void testSlotAvailableRequest() throws Exception { final JobID jobID = new JobID(); - final UUID rmLeaderID = UUID.randomUUID(); + final ResourceManagerId rmLeaderID = ResourceManagerId.generate(); TaskExecutorGateway taskExecutorGateway = mock(TaskExecutorGateway.class); Mockito.when( taskExecutorGateway - .requestSlot(any(SlotID.class), any(JobID.class), any(AllocationID.class), any(String.class), any(UUID.class), any(Time.class))) + .requestSlot(any(SlotID.class), any(JobID.class), any(AllocationID.class), any(String.class), any(ResourceManagerId.class), any(Time.class))) .thenReturn(mock(CompletableFuture.class)); try (SlotManager slotManager = new 
SlotManager( @@ -171,7 +171,7 @@ public void testSlotAvailableRequest() throws Exception { // a SlotRequest is routed to the TaskExecutor verify(taskExecutorGateway, timeout(5000)) - .requestSlot(eq(slotID), eq(jobID), eq(allocationID), any(String.class), any(UUID.class), any(Time.class)); + .requestSlot(eq(slotID), eq(jobID), eq(allocationID), any(String.class), any(ResourceManagerId.class), any(Time.class)); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java new file mode 100644 index 0000000000000..ab43f770ecd2c --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestEndpointITCase.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.rest.handler.AbstractRestHandler; +import org.apache.flink.runtime.rest.handler.HandlerRequest; +import org.apache.flink.runtime.rest.handler.RestHandlerException; +import org.apache.flink.runtime.rest.messages.MessageHeaders; +import org.apache.flink.runtime.rest.messages.MessageParameters; +import org.apache.flink.runtime.rest.messages.MessagePathParameter; +import org.apache.flink.runtime.rest.messages.MessageQueryParameter; +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.testingUtils.TestingUtils; +import org.apache.flink.util.ConfigurationException; +import org.apache.flink.util.TestLogger; + +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.junit.Assert; +import org.junit.Test; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +/** + * IT cases for {@link RestClient} and {@link RestServerEndpoint}. 
+ */ +public class RestEndpointITCase extends TestLogger { + + private static final JobID PATH_JOB_ID = new JobID(); + private static final JobID QUERY_JOB_ID = new JobID(); + private static final String JOB_ID_KEY = "jobid"; + private static final Time timeout = Time.seconds(10L); + + @Test + public void testEndpoints() throws ConfigurationException, IOException, InterruptedException, ExecutionException { + Configuration config = new Configuration(); + + RestServerEndpointConfiguration serverConfig = RestServerEndpointConfiguration.fromConfiguration(config); + RestClientConfiguration clientConfig = RestClientConfiguration.fromConfiguration(config); + + RestServerEndpoint serverEndpoint = new TestRestServerEndpoint(serverConfig); + RestClient clientEndpoint = new TestRestClient(clientConfig); + + try { + serverEndpoint.start(); + + TestParameters parameters = new TestParameters(); + parameters.jobIDPathParameter.resolve(PATH_JOB_ID); + parameters.jobIDQueryParameter.resolve(Collections.singletonList(QUERY_JOB_ID)); + + // send first request and wait until the handler blocks + CompletableFuture response1; + synchronized (TestHandler.LOCK) { + response1 = clientEndpoint.sendRequest( + serverConfig.getEndpointBindAddress(), + serverConfig.getEndpointBindPort(), + new TestHeaders(), + parameters, + new TestRequest(1)); + TestHandler.LOCK.wait(); + } + + // send second request and verify response + CompletableFuture response2 = clientEndpoint.sendRequest( + serverConfig.getEndpointBindAddress(), + serverConfig.getEndpointBindPort(), + new TestHeaders(), + parameters, + new TestRequest(2)); + Assert.assertEquals(2, response2.get().id); + + // wake up blocked handler + synchronized (TestHandler.LOCK) { + TestHandler.LOCK.notifyAll(); + } + // verify response to first request + Assert.assertEquals(1, response1.get().id); + } finally { + clientEndpoint.shutdown(timeout); + serverEndpoint.shutdown(timeout); + } + } + + private static class TestRestServerEndpoint extends RestServerEndpoint { + + TestRestServerEndpoint(RestServerEndpointConfiguration configuration) { + super(configuration); + } + + @Override + protected Collection> initializeHandlers() { + return Collections.singleton(new TestHandler()); + } + } + + private static class TestHandler extends AbstractRestHandler { + + public static final Object LOCK = new Object(); + + TestHandler() { + super(new TestHeaders()); + } + + @Override + protected CompletableFuture handleRequest(@Nonnull HandlerRequest request) throws RestHandlerException { + Assert.assertEquals(request.getPathParameter(JobIDPathParameter.class), PATH_JOB_ID); + Assert.assertEquals(request.getQueryParameter(JobIDQueryParameter.class).get(0), QUERY_JOB_ID); + + if (request.getRequestBody().id == 1) { + synchronized (LOCK) { + try { + LOCK.notifyAll(); + LOCK.wait(); + } catch (InterruptedException ignored) { + } + } + } + return CompletableFuture.completedFuture(new TestResponse(request.getRequestBody().id)); + } + } + + private static class TestRestClient extends RestClient { + + TestRestClient(RestClientConfiguration configuration) { + super(configuration, TestingUtils.defaultExecutor()); + } + } + + private static class TestRequest implements RequestBody { + public final int id; + + @JsonCreator + public TestRequest(@JsonProperty("id") int id) { + this.id = id; + } + } + + private static class TestResponse implements ResponseBody { + public final int id; + + @JsonCreator + public TestResponse(@JsonProperty("id") int id) { + this.id = id; + } + } + + private static class 
TestHeaders implements MessageHeaders { + + @Override + public HttpMethodWrapper getHttpMethod() { + return HttpMethodWrapper.POST; + } + + @Override + public String getTargetRestEndpointURL() { + return "/test/:jobid"; + } + + @Override + public Class getRequestClass() { + return TestRequest.class; + } + + @Override + public Class getResponseClass() { + return TestResponse.class; + } + + @Override + public HttpResponseStatus getResponseStatusCode() { + return HttpResponseStatus.OK; + } + + @Override + public TestParameters getUnresolvedMessageParameters() { + return new TestParameters(); + } + } + + private static class TestParameters extends MessageParameters { + private final JobIDPathParameter jobIDPathParameter = new JobIDPathParameter(); + private final JobIDQueryParameter jobIDQueryParameter = new JobIDQueryParameter(); + + @Override + public Collection> getPathParameters() { + return Collections.singleton(jobIDPathParameter); + } + + @Override + public Collection> getQueryParameters() { + return Collections.singleton(jobIDQueryParameter); + } + } + + static class JobIDPathParameter extends MessagePathParameter { + JobIDPathParameter() { + super(JOB_ID_KEY); + } + + @Override + public JobID convertFromString(String value) { + return JobID.fromHexString(value); + } + + @Override + protected String convertToString(JobID value) { + return value.toString(); + } + } + + static class JobIDQueryParameter extends MessageQueryParameter { + JobIDQueryParameter() { + super(JOB_ID_KEY, MessageParameterRequisiteness.MANDATORY); + } + + @Override + public JobID convertValueFromString(String value) { + return JobID.fromHexString(value); + } + + @Override + public String convertStringToValue(JobID value) { + return value.toString(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/messages/MessageParametersTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/messages/MessageParametersTest.java new file mode 100644 index 0000000000000..7458821870c7f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/messages/MessageParametersTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rest.messages; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.util.TestLogger; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Collection; +import java.util.Collections; + +/** + * Tests for {@link MessageParameters}. 
+ */ +public class MessageParametersTest extends TestLogger { + @Test + public void testResolveUrl() { + String genericUrl = "/jobs/:jobid/state"; + TestMessageParameters parameters = new TestMessageParameters(); + JobID pathJobID = new JobID(); + JobID queryJobID = new JobID(); + parameters.pathParameter.resolve(pathJobID); + parameters.queryParameter.resolve(Collections.singletonList(queryJobID)); + + String resolvedUrl = MessageParameters.resolveUrl(genericUrl, parameters); + + Assert.assertEquals("/jobs/" + pathJobID + "/state?jobid=" + queryJobID, resolvedUrl); + } + + private static class TestMessageParameters extends MessageParameters { + private final TestPathParameter pathParameter = new TestPathParameter(); + private final TestQueryParameter queryParameter = new TestQueryParameter(); + + @Override + public Collection> getPathParameters() { + return Collections.singleton(pathParameter); + } + + @Override + public Collection> getQueryParameters() { + return Collections.singleton(queryParameter); + } + } + + private static class TestPathParameter extends MessagePathParameter { + + TestPathParameter() { + super("jobid"); + } + + @Override + public JobID convertFromString(String value) { + return JobID.fromHexString(value); + } + + @Override + protected String convertToString(JobID value) { + return value.toString(); + } + } + + private static class TestQueryParameter extends MessageQueryParameter { + + TestQueryParameter() { + super("jobid", MessageParameterRequisiteness.MANDATORY); + } + + @Override + public JobID convertValueFromString(String value) { + return JobID.fromHexString(value); + } + + @Override + public String convertStringToValue(JobID value) { + return value.toString(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/AsyncCallsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/AsyncCallsTest.java index 00762b9f0719a..f8eca1692c97f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/AsyncCallsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/AsyncCallsTest.java @@ -23,16 +23,23 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.rpc.akka.AkkaRpcService; +import org.apache.flink.runtime.rpc.exceptions.FencingTokenMismatchException; +import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.TestLogger; import org.junit.AfterClass; import org.junit.Test; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; import static org.junit.Assert.*; @@ -44,6 +51,8 @@ public class AsyncCallsTest extends TestLogger { private static final ActorSystem actorSystem = AkkaUtils.createDefaultActorSystem(); + private static final Time timeout = Time.seconds(10L); + private static final AkkaRpcService akkaRpcService = new AkkaRpcService(actorSystem, Time.milliseconds(10000L)); @@ -162,6 +171,119 @@ public void run() { assertTrue("call was not properly delayed", ((stop - start) / 1_000_000) >= delay); } + /** + * Tests that async code is not executed if the fencing token changes. 
+ */ + @Test + public void testRunAsyncWithFencing() throws Exception { + final Time shortTimeout = Time.milliseconds(100L); + final UUID newFencingToken = UUID.randomUUID(); + final CompletableFuture<UUID> resultFuture = new CompletableFuture<>(); + + testRunAsync( + endpoint -> { + endpoint.runAsync( + () -> resultFuture.complete(endpoint.getFencingToken())); + + return resultFuture; + }, + newFencingToken); + + try { + resultFuture.get(shortTimeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + fail("The async run operation should not complete since it is filtered out due to the changed fencing token."); + } catch (TimeoutException ignored) {} + } + + /** + * Tests that code can be executed in the main thread without respecting the fencing token. + */ + @Test + public void testRunAsyncWithoutFencing() throws Exception { + final CompletableFuture<UUID> resultFuture = new CompletableFuture<>(); + final UUID newFencingToken = UUID.randomUUID(); + + testRunAsync( + endpoint -> { + endpoint.runAsyncWithoutFencing( + () -> resultFuture.complete(endpoint.getFencingToken())); + return resultFuture; + }, + newFencingToken); + + assertEquals(newFencingToken, resultFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS)); + } + + /** + * Tests that async callables are not executed if the fencing token changes. + */ + @Test + public void testCallAsyncWithFencing() throws Exception { + final UUID newFencingToken = UUID.randomUUID(); + + CompletableFuture<Boolean> resultFuture = testRunAsync( + endpoint -> endpoint.callAsync(() -> true, timeout), + newFencingToken); + + try { + resultFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + fail("The async call operation should fail due to the changed fencing token."); + } catch (ExecutionException e) { + assertTrue(ExceptionUtils.stripExecutionException(e) instanceof FencingTokenMismatchException); + } + } + + /** + * Tests that async callables can be executed in the main thread without checking the fencing token.
+ */ + @Test + public void testCallAsyncWithoutFencing() throws Exception { + final UUID newFencingToken = UUID.randomUUID(); + + CompletableFuture<Boolean> resultFuture = testRunAsync( + endpoint -> endpoint.callAsyncWithoutFencing(() -> true, timeout), + newFencingToken); + + assertTrue(resultFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS)); + } + + private static <T> CompletableFuture<T> testRunAsync(Function<FencedTestEndpoint, CompletableFuture<T>> runAsyncCall, UUID newFencingToken) throws Exception { + final UUID initialFencingToken = UUID.randomUUID(); + final OneShotLatch enterSetNewFencingToken = new OneShotLatch(); + final OneShotLatch triggerSetNewFencingToken = new OneShotLatch(); + final FencedTestEndpoint fencedTestEndpoint = new FencedTestEndpoint( + akkaRpcService, + initialFencingToken, + enterSetNewFencingToken, + triggerSetNewFencingToken); + final FencedTestGateway fencedTestGateway = fencedTestEndpoint.getSelfGateway(FencedTestGateway.class); + + try { + fencedTestEndpoint.start(); + + CompletableFuture<Acknowledge> newFencingTokenFuture = fencedTestGateway.setNewFencingToken(newFencingToken, timeout); + + assertFalse(newFencingTokenFuture.isDone()); + + assertEquals(initialFencingToken, fencedTestEndpoint.getFencingToken()); + + CompletableFuture<T> result = runAsyncCall.apply(fencedTestEndpoint); + + enterSetNewFencingToken.await(); + + triggerSetNewFencingToken.trigger(); + + newFencingTokenFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + return result; + } finally { + fencedTestEndpoint.shutDown(); + fencedTestEndpoint.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } + } + // ------------------------------------------------------------------------ // test RPC endpoint // ------------------------------------------------------------------------ @@ -209,4 +331,39 @@ public boolean hasConcurrentAccess() { return concurrentAccess; } } + + public interface FencedTestGateway extends FencedRpcGateway<UUID> { + CompletableFuture<Acknowledge> setNewFencingToken(UUID fencingToken, @RpcTimeout Time timeout); + } + + public static class FencedTestEndpoint extends FencedRpcEndpoint<UUID> implements FencedTestGateway { + + private final OneShotLatch enteringSetNewFencingToken; + private final OneShotLatch triggerSetNewFencingToken; + + protected FencedTestEndpoint( + RpcService rpcService, + UUID initialFencingToken, + OneShotLatch enteringSetNewFencingToken, + OneShotLatch triggerSetNewFencingToken) { + super(rpcService, initialFencingToken); + + this.enteringSetNewFencingToken = enteringSetNewFencingToken; + this.triggerSetNewFencingToken = triggerSetNewFencingToken; + } + + @Override + public CompletableFuture<Acknowledge> setNewFencingToken(UUID fencingToken, Time timeout) { + enteringSetNewFencingToken.trigger(); + try { + triggerSetNewFencingToken.await(); + } catch (InterruptedException e) { + throw new RuntimeException("TriggerSetNewFencingToken OneShotLatch was interrupted."); + } + + setFencingToken(fencingToken); + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/FencedRpcEndpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/FencedRpcEndpointTest.java new file mode 100644 index 0000000000000..62d5354fd51be --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/FencedRpcEndpointTest.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.concurrent.FlinkFutureException; +import org.apache.flink.runtime.messages.Acknowledge; +import org.apache.flink.runtime.rpc.exceptions.FencingTokenMismatchException; +import org.apache.flink.runtime.rpc.exceptions.RpcException; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.TestLogger; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class FencedRpcEndpointTest extends TestLogger { + + private static final Time timeout = Time.seconds(10L); + private static RpcService rpcService; + + @BeforeClass + public static void setup() { + rpcService = new TestingRpcService(); + } + + @AfterClass + public static void teardown() throws ExecutionException, InterruptedException, TimeoutException { + if (rpcService != null) { + rpcService.stopService(); + rpcService.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } + } + + /** + * Tests that the fencing token can be retrieved from the FencedRpcEndpoint and self + * FencedRpcGateway. Moreover it tests that you can only set the fencing token from + * the main thread. 
+ */ + @Test + public void testFencingTokenSetting() throws Exception { + final UUID initialFencingToken = UUID.randomUUID(); + final String value = "foobar"; + FencedTestingEndpoint fencedTestingEndpoint = new FencedTestingEndpoint(rpcService, initialFencingToken, value); + FencedTestingGateway fencedTestingGateway = fencedTestingEndpoint.getSelfGateway(FencedTestingGateway.class); + FencedTestingGateway fencedGateway = fencedTestingEndpoint.getSelfGateway(FencedTestingGateway.class); + + try { + fencedTestingEndpoint.start(); + + assertEquals(initialFencingToken, fencedGateway.getFencingToken()); + assertEquals(initialFencingToken, fencedTestingEndpoint.getFencingToken()); + + final UUID newFencingToken = UUID.randomUUID(); + + try { + fencedTestingEndpoint.setFencingToken(newFencingToken); + fail("Fencing token can only be set from within the main thread."); + } catch (AssertionError ignored) { + // expected to fail + } + + assertEquals(initialFencingToken, fencedTestingEndpoint.getFencingToken()); + + CompletableFuture setFencingFuture = fencedTestingGateway.rpcSetFencingToken(newFencingToken, timeout); + + // wait for the completion of the set fencing token operation + setFencingFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + // self gateway should adapt its fencing token + assertEquals(newFencingToken, fencedGateway.getFencingToken()); + assertEquals(newFencingToken, fencedTestingEndpoint.getFencingToken()); + } finally { + fencedTestingEndpoint.shutDown(); + fencedTestingEndpoint.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } + } + + /** + * Tests that messages with the wrong fencing token are filtered out. + */ + @Test + public void testFencing() throws Exception { + final UUID initialFencingToken = UUID.randomUUID(); + final UUID wrongFencingToken = UUID.randomUUID(); + final String value = "barfoo"; + FencedTestingEndpoint fencedTestingEndpoint = new FencedTestingEndpoint(rpcService, initialFencingToken, value); + + try { + fencedTestingEndpoint.start(); + + final FencedTestingGateway properFencedGateway = rpcService.connect(fencedTestingEndpoint.getAddress(), initialFencingToken, FencedTestingGateway.class) + .get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + final FencedTestingGateway wronglyFencedGateway = rpcService.connect(fencedTestingEndpoint.getAddress(), wrongFencingToken, FencedTestingGateway.class) + .get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + assertEquals(value, properFencedGateway.foobar(timeout).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS)); + + try { + wronglyFencedGateway.foobar(timeout).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + fail("This should fail since we have the wrong fencing token."); + } catch (ExecutionException e) { + assertTrue(ExceptionUtils.stripExecutionException(e) instanceof FencingTokenMismatchException); + } + + final UUID newFencingToken = UUID.randomUUID(); + + CompletableFuture newFencingTokenFuture = properFencedGateway.rpcSetFencingToken(newFencingToken, timeout); + + // wait for the new fencing token to be set + newFencingTokenFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + // this should no longer work because of the new fencing token + try { + properFencedGateway.foobar(timeout).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + fail("This should fail since we have the wrong fencing token by now."); + } catch (ExecutionException e) { + assertTrue(ExceptionUtils.stripExecutionException(e) instanceof 
FencingTokenMismatchException); + } + + } finally { + fencedTestingEndpoint.shutDown(); + fencedTestingEndpoint.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } + } + + /** + * Tests that the self gateway always uses the current fencing token whereas the remote + * gateway has a fixed fencing token. + */ + @Test + public void testRemoteAndSelfGateways() throws Exception { + final UUID initialFencingToken = UUID.randomUUID(); + final UUID newFencingToken = UUID.randomUUID(); + final String value = "foobar"; + + final FencedTestingEndpoint fencedTestingEndpoint = new FencedTestingEndpoint(rpcService, initialFencingToken, value); + + try { + fencedTestingEndpoint.start(); + + FencedTestingGateway selfGateway = fencedTestingEndpoint.getSelfGateway(FencedTestingGateway.class); + FencedTestingGateway remoteGateway = rpcService.connect(fencedTestingEndpoint.getAddress(), initialFencingToken, FencedTestingGateway.class) + .get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + assertEquals(initialFencingToken, selfGateway.getFencingToken()); + assertEquals(initialFencingToken, remoteGateway.getFencingToken()); + + assertEquals(value, selfGateway.foobar(timeout).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS)); + assertEquals(value, remoteGateway.foobar(timeout).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS)); + + CompletableFuture<Acknowledge> newFencingTokenFuture = selfGateway.rpcSetFencingToken(newFencingToken, timeout); + + // wait for the new fencing token to be set + newFencingTokenFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + assertEquals(newFencingToken, selfGateway.getFencingToken()); + assertNotEquals(newFencingToken, remoteGateway.getFencingToken()); + + assertEquals(value, selfGateway.foobar(timeout).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS)); + + try { + remoteGateway.foobar(timeout).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + fail("This should have failed because we don't have the right fencing token."); + } catch (ExecutionException e) { + assertTrue(ExceptionUtils.stripExecutionException(e) instanceof FencingTokenMismatchException); + } + } finally { + fencedTestingEndpoint.shutDown(); + fencedTestingEndpoint.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } + } + + /** + * Tests that calls via the MainThreadExecutor fail after the fencing token changes.
+ */ + @Test + public void testMainThreadExecutorUnderChangingFencingToken() throws Exception { + final Time shortTimeout = Time.milliseconds(100L); + final UUID initialFencingToken = UUID.randomUUID(); + final String value = "foobar"; + final FencedTestingEndpoint fencedTestingEndpoint = new FencedTestingEndpoint(rpcService, initialFencingToken, value); + + try { + fencedTestingEndpoint.start(); + + FencedTestingGateway selfGateway = fencedTestingEndpoint.getSelfGateway(FencedTestingGateway.class); + + CompletableFuture<Acknowledge> mainThreadExecutorComputation = selfGateway.triggerMainThreadExecutorComputation(timeout); + + // we know that subsequent calls on the same gateway are executed sequentially + // therefore, we know that the change fencing token call is executed after the trigger MainThreadExecutor + // computation + final UUID newFencingToken = UUID.randomUUID(); + CompletableFuture<Acknowledge> newFencingTokenFuture = selfGateway.rpcSetFencingToken(newFencingToken, timeout); + + newFencingTokenFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + // trigger the computation + CompletableFuture<Acknowledge> triggerFuture = selfGateway.triggerComputationLatch(timeout); + + triggerFuture.get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + // wait for the main thread executor computation to fail + try { + mainThreadExecutorComputation.get(shortTimeout.toMilliseconds(), TimeUnit.MILLISECONDS); + fail("The MainThreadExecutor computation should not be able to complete because it was filtered out due to the changed fencing token, leading to a timeout exception."); + } catch (TimeoutException ignored) { + // as predicted + } + + } finally { + fencedTestingEndpoint.shutDown(); + fencedTestingEndpoint.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } + } + + /** + * Tests that all calls from an unfenced remote gateway are ignored and that one cannot obtain + * the fencing token from such a gateway.
+ */ + @Test + public void testUnfencedRemoteGateway() throws Exception { + final UUID initialFencingToken = UUID.randomUUID(); + final String value = "foobar"; + + final FencedTestingEndpoint fencedTestingEndpoint = new FencedTestingEndpoint(rpcService, initialFencingToken, value); + + try { + fencedTestingEndpoint.start(); + + FencedTestingGateway unfencedGateway = rpcService.connect(fencedTestingEndpoint.getAddress(), FencedTestingGateway.class) + .get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + + try { + unfencedGateway.foobar(timeout).get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + fail("This should have failed because we have an unfenced gateway."); + } catch (ExecutionException e) { + assertTrue(ExceptionUtils.stripExecutionException(e) instanceof RpcException); + } + + try { + unfencedGateway.getFencingToken(); + fail("We should not be able to call getFencingToken on an unfenced gateway."); + } catch (UnsupportedOperationException ignored) { + // we should not be able to call getFencingToken on an unfenced gateway + } + } finally { + fencedTestingEndpoint.shutDown(); + fencedTestingEndpoint.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } + } + + public interface FencedTestingGateway extends FencedRpcGateway { + CompletableFuture foobar(@RpcTimeout Time timeout); + + CompletableFuture rpcSetFencingToken(UUID fencingToken, @RpcTimeout Time timeout); + + CompletableFuture triggerMainThreadExecutorComputation(@RpcTimeout Time timeout); + + CompletableFuture triggerComputationLatch(@RpcTimeout Time timeout); + } + + private static class FencedTestingEndpoint extends FencedRpcEndpoint implements FencedTestingGateway { + + private final OneShotLatch computationLatch; + + private final String value; + + protected FencedTestingEndpoint(RpcService rpcService, UUID initialFencingToken, String value) { + super(rpcService, initialFencingToken); + + computationLatch = new OneShotLatch(); + + this.value = value; + } + + @Override + public CompletableFuture foobar(Time timeout) { + return CompletableFuture.completedFuture(value); + } + + @Override + public CompletableFuture rpcSetFencingToken(UUID fencingToken, Time timeout) { + setFencingToken(fencingToken); + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + + @Override + public CompletableFuture triggerMainThreadExecutorComputation(Time timeout) { + return CompletableFuture.supplyAsync( + () -> { + try { + computationLatch.await(); + } catch (InterruptedException e) { + throw new FlinkFutureException("Waiting on latch failed.", e); + } + + return value; + }, + getRpcService().getExecutor()) + .thenApplyAsync( + (String v) -> Acknowledge.get(), + getMainThreadExecutor()); + } + + @Override + public CompletableFuture triggerComputationLatch(Time timeout) { + computationLatch.trigger(); + + return CompletableFuture.completedFuture(Acknowledge.get()); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/TestingRpcService.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/TestingRpcService.java index 14cf35a4d2a4e..4b9f3977fa6c1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/TestingRpcService.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/TestingRpcService.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.concurrent.FutureUtils; import org.apache.flink.runtime.rpc.akka.AkkaRpcService; +import java.io.Serializable; import java.util.concurrent.CompletableFuture; import 
java.util.concurrent.ConcurrentHashMap; @@ -104,7 +105,27 @@ public CompletableFuture connect(String address, Class return FutureUtils.completedExceptionally(new Exception("Gateway registered under " + address + " is not of type " + clazz)); } } else { - return FutureUtils.completedExceptionally(new Exception("No gateway registered under " + address + '.')); + return super.connect(address, clazz); + } + } + + @Override + public > CompletableFuture connect( + String address, + F fencingToken, + Class clazz) { + RpcGateway gateway = registeredConnections.get(address); + + if (gateway != null) { + if (clazz.isAssignableFrom(gateway.getClass())) { + @SuppressWarnings("unchecked") + C typedGateway = (C) gateway; + return CompletableFuture.completedFuture(typedGateway); + } else { + return FutureUtils.completedExceptionally(new Exception("Gateway registered under " + address + " is not of type " + clazz)); + } + } else { + return super.connect(address, fencingToken, clazz); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/TestingSerialRpcService.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/TestingSerialRpcService.java deleted file mode 100644 index cb38f6fee5612..0000000000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/TestingSerialRpcService.java +++ /dev/null @@ -1,440 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.rpc; - -import org.apache.flink.api.common.time.Time; -import org.apache.flink.runtime.concurrent.FutureUtils; -import org.apache.flink.runtime.concurrent.ScheduledExecutor; -import org.apache.flink.runtime.concurrent.ScheduledExecutorServiceAdapter; -import org.apache.flink.runtime.util.DirectExecutorService; -import org.apache.flink.util.Preconditions; - -import java.lang.annotation.Annotation; -import java.lang.reflect.InvocationHandler; -import java.lang.reflect.Method; -import java.lang.reflect.Proxy; -import java.util.List; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.Callable; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Delayed; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.Future; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.ScheduledThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -import static org.apache.flink.util.Preconditions.checkNotNull; - -/** - * An RPC Service implementation for testing. This RPC service directly executes all asynchronous - * calls one by one in the calling thread. 
- */ -public class TestingSerialRpcService implements RpcService { - - private final DirectExecutorService executorService; - private final ScheduledExecutorService scheduledExecutorService; - private final ConcurrentHashMap registeredConnections; - private final CompletableFuture terminationFuture; - - private final ScheduledExecutor scheduledExecutorServiceAdapter; - - public TestingSerialRpcService() { - executorService = new DirectExecutorService(); - scheduledExecutorService = new ScheduledThreadPoolExecutor(1); - this.registeredConnections = new ConcurrentHashMap<>(16); - this.terminationFuture = new CompletableFuture<>(); - - this.scheduledExecutorServiceAdapter = new ScheduledExecutorServiceAdapter(scheduledExecutorService); - } - - @Override - public ScheduledFuture scheduleRunnable(final Runnable runnable, final long delay, final TimeUnit unit) { - try { - unit.sleep(delay); - runnable.run(); - - return new DoneScheduledFuture(null); - } catch (Throwable e) { - throw new RuntimeException(e); - } - } - - @Override - public void execute(Runnable runnable) { - runnable.run(); - } - - @Override - public CompletableFuture execute(Callable callable) { - try { - T result = callable.call(); - - return CompletableFuture.completedFuture(result); - } catch (Exception e) { - return FutureUtils.completedExceptionally(e); - } - } - - @Override - public Executor getExecutor() { - return executorService; - } - - public ScheduledExecutor getScheduledExecutor() { - return scheduledExecutorServiceAdapter; - } - - @Override - public void stopService() { - executorService.shutdown(); - - scheduledExecutorService.shutdown(); - - boolean terminated = false; - - try { - terminated = scheduledExecutorService.awaitTermination(1, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - - if (!terminated) { - List runnables = scheduledExecutorService.shutdownNow(); - - for (Runnable runnable : runnables) { - runnable.run(); - } - } - - registeredConnections.clear(); - terminationFuture.complete(null); - } - - @Override - public CompletableFuture getTerminationFuture() { - return terminationFuture; - } - - @Override - public void stopServer(RpcServer selfGateway) { - registeredConnections.remove(selfGateway.getAddress()); - } - - @Override - public RpcServer startServer(S rpcEndpoint) { - final String address = UUID.randomUUID().toString(); - - InvocationHandler akkaInvocationHandler = new TestingSerialRpcService.TestingSerialInvocationHandler<>(address, rpcEndpoint); - ClassLoader classLoader = getClass().getClassLoader(); - - Set> implementedRpcGateways = RpcUtils.extractImplementedRpcGateways(rpcEndpoint.getClass()); - - implementedRpcGateways.add(RpcServer.class); - - - @SuppressWarnings("unchecked") - RpcServer rpcServer = (RpcServer) Proxy.newProxyInstance( - classLoader, - implementedRpcGateways.toArray(new Class[implementedRpcGateways.size()]), - akkaInvocationHandler); - - // register self - registeredConnections.putIfAbsent(rpcServer.getAddress(), rpcServer); - - return rpcServer; - } - - @Override - public String getAddress() { - return ""; - } - - @Override - public int getPort() { - return -1; - } - - @Override - public CompletableFuture connect(String address, Class clazz) { - RpcGateway gateway = registeredConnections.get(address); - - if (gateway != null) { - if (clazz.isAssignableFrom(gateway.getClass())) { - @SuppressWarnings("unchecked") - C typedGateway = (C) gateway; - return CompletableFuture.completedFuture(typedGateway); - } else { - return 
FutureUtils.completedExceptionally(new Exception("Gateway registered under " + address + " is not of type " + clazz)); - } - } else { - return FutureUtils.completedExceptionally(new Exception("No gateway registered under " + address + '.')); - } - } - - // ------------------------------------------------------------------------ - // connections - // ------------------------------------------------------------------------ - - public void registerGateway(String address, RpcGateway gateway) { - checkNotNull(address); - checkNotNull(gateway); - - if (registeredConnections.putIfAbsent(address, gateway) != null) { - throw new IllegalStateException("a gateway is already registered under " + address); - } - } - - public void clearGateways() { - registeredConnections.clear(); - } - - private static final class TestingSerialInvocationHandler implements InvocationHandler, RpcGateway, MainThreadExecutable, StartStoppable { - - private final T rpcEndpoint; - - /** default timeout for asks */ - private final Time timeout; - - private final String address; - - private TestingSerialInvocationHandler(String address, T rpcEndpoint) { - this(address, rpcEndpoint, Time.seconds(10)); - } - - private TestingSerialInvocationHandler(String address, T rpcEndpoint, Time timeout) { - this.rpcEndpoint = rpcEndpoint; - this.timeout = timeout; - this.address = address; - } - - @Override - public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { - Class declaringClass = method.getDeclaringClass(); - if (declaringClass.equals(MainThreadExecutable.class) || - declaringClass.equals(Object.class) || - declaringClass.equals(StartStoppable.class) || - declaringClass.equals(RpcServer.class) || - declaringClass.equals(RpcGateway.class)) { - return method.invoke(this, args); - } else { - final String methodName = method.getName(); - Class[] parameterTypes = method.getParameterTypes(); - Annotation[][] parameterAnnotations = method.getParameterAnnotations(); - Time futureTimeout = extractRpcTimeout(parameterAnnotations, args, timeout); - - Class returnType = method.getReturnType(); - - if (returnType.equals(CompletableFuture.class)) { - try { - Object result = handleRpcInvocationSync(methodName, parameterTypes, args, futureTimeout); - return CompletableFuture.completedFuture(result); - } catch (Throwable e) { - return FutureUtils.completedExceptionally(e); - } - } else { - return handleRpcInvocationSync(methodName, parameterTypes, args, futureTimeout); - } - } - } - - /** - * Handle rpc invocations by looking up the rpc method on the rpc endpoint and calling this - * method with the provided method arguments. If the method has a return value, it is returned - * to the sender of the call. 
- */ - private Object handleRpcInvocationSync(final String methodName, - final Class[] parameterTypes, - final Object[] args, - final Time futureTimeout) throws Exception { - final Method rpcMethod = lookupRpcMethod(methodName, parameterTypes); - Object result = rpcMethod.invoke(rpcEndpoint, args); - - if (result instanceof Future) { - Future future = (Future) result; - return future.get(futureTimeout.getSize(), futureTimeout.getUnit()); - } else { - return result; - } - } - - @Override - public void runAsync(Runnable runnable) { - runnable.run(); - } - - @Override - public CompletableFuture callAsync(Callable callable, Time callTimeout) { - try { - return CompletableFuture.completedFuture(callable.call()); - } catch (Throwable e) { - return FutureUtils.completedExceptionally(e); - } - } - - @Override - public void scheduleRunAsync(final Runnable runnable, final long delay) { - try { - TimeUnit.MILLISECONDS.sleep(delay); - runnable.run(); - } catch (Throwable e) { - throw new RuntimeException(e); - } - } - - @Override - public String getAddress() { - return address; - } - - // this is not a real hostname but the address above is also not a real akka RPC address - // and we keep it that way until actually needed by a test case - @Override - public String getHostname() { - return address; - } - - @Override - public void start() { - // do nothing - } - - @Override - public void stop() { - // do nothing - } - - /** - * Look up the rpc method on the given {@link RpcEndpoint} instance. - * - * @param methodName Name of the method - * @param parameterTypes Parameter types of the method - * @return Method of the rpc endpoint - * @throws NoSuchMethodException Thrown if the method with the given name and parameter types - * cannot be found at the rpc endpoint - */ - private Method lookupRpcMethod(final String methodName, - final Class[] parameterTypes) throws NoSuchMethodException { - return rpcEndpoint.getClass().getMethod(methodName, parameterTypes); - } - - // ------------------------------------------------------------------------ - // Helper methods - // ------------------------------------------------------------------------ - - /** - * Extracts the {@link RpcTimeout} annotated rpc timeout value from the list of given method - * arguments. If no {@link RpcTimeout} annotated parameter could be found, then the default - * timeout is returned. - * - * @param parameterAnnotations Parameter annotations - * @param args Array of arguments - * @param defaultTimeout Default timeout to return if no {@link RpcTimeout} annotated parameter - * has been found - * @return Timeout extracted from the array of arguments or the default timeout - */ - private static Time extractRpcTimeout(Annotation[][] parameterAnnotations, Object[] args, - Time defaultTimeout) { - if (args != null) { - Preconditions.checkArgument(parameterAnnotations.length == args.length); - - for (int i = 0; i < parameterAnnotations.length; i++) { - if (isRpcTimeout(parameterAnnotations[i])) { - if (args[i] instanceof Time) { - return (Time) args[i]; - } else { - throw new RuntimeException("The rpc timeout parameter must be of type " + - Time.class.getName() + ". 
The type " + args[i].getClass().getName() + - " is not supported."); - } - } - } - } - - return defaultTimeout; - } - - /** - * Checks whether any of the annotations is of type {@link RpcTimeout} - * - * @param annotations Array of annotations - * @return True if {@link RpcTimeout} was found; otherwise false - */ - private static boolean isRpcTimeout(Annotation[] annotations) { - for (Annotation annotation : annotations) { - if (annotation.annotationType().equals(RpcTimeout.class)) { - return true; - } - } - - return false; - } - - } - - private static class DoneScheduledFuture implements ScheduledFuture { - - private final V value; - - private DoneScheduledFuture(V value) { - this.value = value; - } - - @Override - public long getDelay(TimeUnit unit) { - return 0L; - } - - @Override - public int compareTo(Delayed o) { - return 0; - } - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return true; - } - - @Override - public V get() throws InterruptedException, ExecutionException { - return value; - } - - @Override - public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - return value; - } - } - -} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java index c1b3ccdfa34e1..9f6f88ec42eb8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java @@ -19,12 +19,15 @@ package org.apache.flink.runtime.state; import org.apache.flink.runtime.checkpoint.savepoint.CheckpointTestUtils; + import org.junit.Test; import java.util.Map; import java.util.Random; import java.util.UUID; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.powermock.api.mockito.PowerMockito.spy; @@ -59,8 +62,6 @@ public void testUnregisteredDiscarding() throws Exception { @Test public void testSharedStateDeRegistration() throws Exception { - Random rnd = new Random(42); - SharedStateRegistry registry = spy(new SharedStateRegistry()); // Create two state handles with overlapping shared state @@ -186,6 +187,76 @@ public void testSharedStateDeRegistration() throws Exception { verify(stateHandle2.getMetaStateHandle(), times(1)).discardState(); } + /** + * This tests that re-registration of shared state with another registry works as expected. This simulates a + * recovery from a checkpoint, when the checkpoint coordinator creates a new shared state registry and re-registers + * all live checkpoint states. + */ + @Test + public void testSharedStateReRegistration() throws Exception { + + SharedStateRegistry stateRegistryA = spy(new SharedStateRegistry()); + + IncrementalKeyedStateHandle stateHandleX = create(new Random(1)); + IncrementalKeyedStateHandle stateHandleY = create(new Random(2)); + IncrementalKeyedStateHandle stateHandleZ = create(new Random(3)); + + // Now we register first time ... 
+ stateHandleX.registerSharedStates(stateRegistryA); + stateHandleY.registerSharedStates(stateRegistryA); + stateHandleZ.registerSharedStates(stateRegistryA); + + try { + // Second attempt should fail + stateHandleX.registerSharedStates(stateRegistryA); + fail("Should not be able to register twice with the same registry."); + } catch (IllegalStateException ignore) { + } + + // Everything should be discarded for this handle + stateHandleZ.discardState(); + verify(stateHandleZ.getMetaStateHandle(), times(1)).discardState(); + for (StreamStateHandle stateHandle : stateHandleZ.getSharedState().values()) { + verify(stateHandle, times(1)).discardState(); + } + + // Close the first registry + stateRegistryA.close(); + + // Attempt to register to closed registry should trigger exception + try { + create(new Random(4)).registerSharedStates(stateRegistryA); + fail("Should not be able to register new state to closed registry."); + } catch (IllegalStateException ignore) { + } + + // All state should still get discarded + stateHandleY.discardState(); + verify(stateHandleY.getMetaStateHandle(), times(1)).discardState(); + for (StreamStateHandle stateHandle : stateHandleY.getSharedState().values()) { + verify(stateHandle, times(1)).discardState(); + } + + // This should still be unaffected + verify(stateHandleX.getMetaStateHandle(), never()).discardState(); + for (StreamStateHandle stateHandle : stateHandleX.getSharedState().values()) { + verify(stateHandle, never()).discardState(); + } + + // We re-register the handle with a new registry + SharedStateRegistry sharedStateRegistryB = spy(new SharedStateRegistry()); + stateHandleX.registerSharedStates(sharedStateRegistryB); + stateHandleX.discardState(); + + // Should be completely discarded because it is tracked through the new registry + verify(stateHandleX.getMetaStateHandle(), times(1)).discardState(); + for (StreamStateHandle stateHandle : stateHandleX.getSharedState().values()) { + verify(stateHandle, times(1)).discardState(); + } + + sharedStateRegistryB.close(); + } + private static IncrementalKeyedStateHandle create(Random rnd) { return new IncrementalKeyedStateHandle( UUID.nameUUIDFromBytes("test".getBytes()), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index f08ad2d98be0b..f6f73f20b789a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -18,11 +18,6 @@ package org.apache.flink.runtime.state; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; -import com.google.common.base.Joiner; -import org.apache.commons.io.output.ByteArrayOutputStream; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.functions.FoldFunction; import org.apache.flink.api.common.functions.ReduceFunction; @@ -65,11 +60,17 @@ import org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.runtime.state.internal.InternalValueState; import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory; +import org.apache.flink.shaded.guava18.com.google.common.base.Joiner; import org.apache.flink.types.IntValue; import org.apache.flink.util.FutureUtil; import org.apache.flink.util.IOUtils; import org.apache.flink.util.StateMigrationException; import org.apache.flink.util.TestLogger; 
+ +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.apache.commons.io.output.ByteArrayOutputStream; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -175,17 +176,15 @@ protected AbstractKeyedStateBackend restoreKeyedBackend( Environment env) throws Exception { AbstractKeyedStateBackend backend = getStateBackend().createKeyedStateBackend( - env, - new JobID(), - "test_op", - keySerializer, - numberOfKeyGroups, - keyGroupRange, - env.getTaskKvStateRegistry()); + env, + new JobID(), + "test_op", + keySerializer, + numberOfKeyGroups, + keyGroupRange, + env.getTaskKvStateRegistry()); - if (null != state) { - backend.restore(state); - } + backend.restore(state); return backend; } @@ -242,6 +241,7 @@ public void testBackendUsesRegisteredKryoDefaultSerializer() throws Exception { } assertEquals("Didn't see the expected Kryo exception.", 1, numExceptions); + backend.dispose(); } @Test @@ -301,6 +301,7 @@ public void testBackendUsesRegisteredKryoDefaultSerializerUsingGetOrCreate() thr } assertEquals("Didn't see the expected Kryo exception.", 1, numExceptions); + backend.dispose(); } @Test @@ -354,6 +355,7 @@ public void testBackendUsesRegisteredKryoSerializer() throws Exception { } assertEquals("Didn't see the expected Kryo exception.", 1, numExceptions); + backend.dispose(); } @Test @@ -409,6 +411,7 @@ public void testBackendUsesRegisteredKryoSerializerUsingGetOrCreate() throws Exc } assertEquals("Didn't see the expected Kryo exception.", 1, numExceptions); + backend.dispose(); } @@ -486,81 +489,91 @@ public void testKryoRegisteringRestoreResilienceWithDefaultSerializer() throws E CheckpointStreamFactory streamFactory = createStreamFactory(); SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); Environment env = new DummyEnvironment("test", 1, 0); - AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE, env); + AbstractKeyedStateBackend backend = null; - TypeInformation pojoType = new GenericTypeInfo<>(TestPojo.class); + try { + backend = createKeyedBackend(IntSerializer.INSTANCE, env); - // make sure that we are in fact using the KryoSerializer - assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof KryoSerializer); + TypeInformation pojoType = new GenericTypeInfo<>(TestPojo.class); - ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", pojoType); + // make sure that we are in fact using the KryoSerializer + assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof KryoSerializer); - ValueState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", pojoType); - // ============== create snapshot - no Kryo registration or specific / default serializers ============== + ValueState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); - // make some more modifications - backend.setCurrentKey(1); - state.update(new TestPojo("u1", 1)); + // ============== create snapshot - no Kryo registration or specific / default serializers ============== - backend.setCurrentKey(2); - state.update(new TestPojo("u2", 2)); + // make some more modifications + backend.setCurrentKey(1); + state.update(new TestPojo("u1", 1)); - KeyedStateHandle snapshot = runSnapshot(backend.snapshot( + backend.setCurrentKey(2); + state.update(new TestPojo("u2", 2)); + + KeyedStateHandle 
snapshot = runSnapshot(backend.snapshot( 682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint())); - snapshot.registerSharedStates(sharedStateRegistry); - backend.dispose(); + snapshot.registerSharedStates(sharedStateRegistry); + backend.dispose(); - // ========== restore snapshot - should use default serializer (ONLY SERIALIZATION) ========== + // ========== restore snapshot - should use default serializer (ONLY SERIALIZATION) ========== - // cast because our test serializer is not typed to TestPojo - env.getExecutionConfig().addDefaultKryoSerializer(TestPojo.class, (Class) CustomKryoTestSerializer.class); + // cast because our test serializer is not typed to TestPojo + env.getExecutionConfig().addDefaultKryoSerializer(TestPojo.class, (Class) CustomKryoTestSerializer.class); - backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env); + backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env); - // re-initialize to ensure that we create the KryoSerializer from scratch, otherwise - // initializeSerializerUnlessSet would not pick up our new config - kvId = new ValueStateDescriptor<>("id", pojoType); - state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + // re-initialize to ensure that we create the KryoSerializer from scratch, otherwise + // initializeSerializerUnlessSet would not pick up our new config + kvId = new ValueStateDescriptor<>("id", pojoType); + state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); - backend.setCurrentKey(1); + backend.setCurrentKey(1); - // update to test state backends that eagerly serialize, such as RocksDB - state.update(new TestPojo("u1", 11)); + // update to test state backends that eagerly serialize, such as RocksDB + state.update(new TestPojo("u1", 11)); - KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot( + KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot( 682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint())); - snapshot2.registerSharedStates(sharedStateRegistry); + snapshot2.registerSharedStates(sharedStateRegistry); + snapshot.discardState(); - snapshot.discardState(); + backend.dispose(); - backend.dispose(); + // ========= restore snapshot - should use default serializer (FAIL ON DESERIALIZATION) ========= - // ========= restore snapshot - should use default serializer (FAIL ON DESERIALIZATION) ========= + // cast because our test serializer is not typed to TestPojo + env.getExecutionConfig().addDefaultKryoSerializer(TestPojo.class, (Class) CustomKryoTestSerializer.class); - // cast because our test serializer is not typed to TestPojo - env.getExecutionConfig().addDefaultKryoSerializer(TestPojo.class, (Class) CustomKryoTestSerializer.class); + // on the second restore, since the custom serializer will be used for + // deserialization, we expect the deliberate failure to be thrown + expectedException.expect(ExpectedKryoTestException.class); - // on the second restore, since the custom serializer will be used for - // deserialization, we expect the deliberate failure to be thrown - expectedException.expect(ExpectedKryoTestException.class); + // state backends that eagerly deserializes (such as the memory state backend) will fail here + backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2, env); - // state backends that eagerly deserializes (such as the memory state backend) will fail here - backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2, env); + state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); - state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + backend.setCurrentKey(1); + // state backends that lazily deserializes (such as RocksDB) will fail here + state.value(); - backend.setCurrentKey(1); - // state backends that lazily deserializes (such as RocksDB) will fail here - state.value(); + snapshot2.discardState(); + backend.dispose(); + } finally { + // ensure to release native resources even when we exit through exception + IOUtils.closeQuietly(backend); + backend.dispose(); + } } /** @@ -579,78 +592,89 @@ public void testKryoRegisteringRestoreResilienceWithRegisteredSerializer() throw SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); Environment env = new DummyEnvironment("test", 1, 0); - AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE, env); + AbstractKeyedStateBackend backend = null; - TypeInformation pojoType = new GenericTypeInfo<>(TestPojo.class); + try { + backend = createKeyedBackend(IntSerializer.INSTANCE, env); - // make sure that we are in fact using the KryoSerializer - assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof KryoSerializer); + TypeInformation pojoType = new GenericTypeInfo<>(TestPojo.class); - ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", pojoType); - ValueState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + // make sure that we are in fact using the KryoSerializer + assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof KryoSerializer); - // ============== create snapshot - no Kryo registration or specific / default serializers ============== + ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", pojoType); + ValueState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); - // make some more modifications - backend.setCurrentKey(1); - state.update(new TestPojo("u1", 1)); + // ============== create snapshot - no Kryo registration or specific / default serializers ============== - backend.setCurrentKey(2); - state.update(new TestPojo("u2", 2)); + // make some more modifications + backend.setCurrentKey(1); + state.update(new TestPojo("u1", 1)); - KeyedStateHandle snapshot = runSnapshot(backend.snapshot( + backend.setCurrentKey(2); + state.update(new TestPojo("u2", 2)); + + KeyedStateHandle snapshot = runSnapshot(backend.snapshot( 682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint())); - snapshot.registerSharedStates(sharedStateRegistry); - backend.dispose(); + snapshot.registerSharedStates(sharedStateRegistry); + backend.dispose(); - // ========== restore snapshot - should use specific serializer (ONLY SERIALIZATION) ========== + // ========== restore snapshot - should use specific serializer (ONLY SERIALIZATION) ========== - env.getExecutionConfig().registerTypeWithKryoSerializer(TestPojo.class, CustomKryoTestSerializer.class); + env.getExecutionConfig().registerTypeWithKryoSerializer(TestPojo.class, CustomKryoTestSerializer.class); - backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env); + backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env); - // re-initialize to ensure that we create the KryoSerializer from scratch, otherwise - // initializeSerializerUnlessSet would not pick up our new config - kvId = new ValueStateDescriptor<>("id", pojoType); - 
state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + // re-initialize to ensure that we create the KryoSerializer from scratch, otherwise + // initializeSerializerUnlessSet would not pick up our new config + kvId = new ValueStateDescriptor<>("id", pojoType); + state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); - backend.setCurrentKey(1); + backend.setCurrentKey(1); - // update to test state backends that eagerly serialize, such as RocksDB - state.update(new TestPojo("u1", 11)); + // update to test state backends that eagerly serialize, such as RocksDB + state.update(new TestPojo("u1", 11)); - KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot( + KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot( 682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint())); - snapshot2.registerSharedStates(sharedStateRegistry); + snapshot2.registerSharedStates(sharedStateRegistry); - snapshot.discardState(); + snapshot.discardState(); - backend.dispose(); + backend.dispose(); - // ========= restore snapshot - should use specific serializer (FAIL ON DESERIALIZATION) ========= + // ========= restore snapshot - should use specific serializer (FAIL ON DESERIALIZATION) ========= - env.getExecutionConfig().registerTypeWithKryoSerializer(TestPojo.class, CustomKryoTestSerializer.class); + env.getExecutionConfig().registerTypeWithKryoSerializer(TestPojo.class, CustomKryoTestSerializer.class); - // on the second restore, since the custom serializer will be used for - // deserialization, we expect the deliberate failure to be thrown - expectedException.expect(ExpectedKryoTestException.class); + // on the second restore, since the custom serializer will be used for + // deserialization, we expect the deliberate failure to be thrown + expectedException.expect(ExpectedKryoTestException.class); - // state backends that eagerly deserializes (such as the memory state backend) will fail here - backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2, env); + // state backends that eagerly deserializes (such as the memory state backend) will fail here + backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2, env); - state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); - backend.setCurrentKey(1); - // state backends that lazily deserializes (such as RocksDB) will fail here - state.value(); + backend.setCurrentKey(1); + // state backends that lazily deserializes (such as RocksDB) will fail here + state.value(); + + backend.dispose(); + } finally { + // ensure that native resources are also released in case of exception + if (backend != null) { + backend.dispose(); + } + } } @Test @@ -1724,7 +1748,7 @@ public void testKeyGroupSnapshotRestore() throws Exception { final int MAX_PARALLELISM = 10; CheckpointStreamFactory streamFactory = createStreamFactory(); - AbstractKeyedStateBackend backend = createKeyedBackend( + final AbstractKeyedStateBackend backend = createKeyedBackend( IntSerializer.INSTANCE, MAX_PARALLELISM, new KeyGroupRange(0, MAX_PARALLELISM - 1), @@ -1768,7 +1792,7 @@ public void testKeyGroupSnapshotRestore() throws Exception { backend.dispose(); // backend for the first half of the key group range - AbstractKeyedStateBackend firstHalfBackend = restoreKeyedBackend( + final AbstractKeyedStateBackend firstHalfBackend = 
restoreKeyedBackend( IntSerializer.INSTANCE, MAX_PARALLELISM, new KeyGroupRange(0, 4), @@ -1776,7 +1800,7 @@ public void testKeyGroupSnapshotRestore() throws Exception { new DummyEnvironment("test", 1, 0)); // backend for the second half of the key group range - AbstractKeyedStateBackend secondHalfBackend = restoreKeyedBackend( + final AbstractKeyedStateBackend secondHalfBackend = restoreKeyedBackend( IntSerializer.INSTANCE, MAX_PARALLELISM, new KeyGroupRange(5, 9), @@ -2015,7 +2039,7 @@ public void testMapStateRestoreWithWrongSerializers() throws Exception { @Test public void testCopyDefaultValue() throws Exception { - AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); + final AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1)); @@ -2042,7 +2066,7 @@ public void testCopyDefaultValue() throws Exception { */ @Test public void testRequireNonNullNamespace() throws Exception { - AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); + final AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1)); @@ -2074,7 +2098,7 @@ public void testRequireNonNullNamespace() throws Exception { @SuppressWarnings("unchecked") protected void testConcurrentMapIfQueryable() throws Exception { final int numberOfKeyGroups = 1; - AbstractKeyedStateBackend backend = createKeyedBackend( + final AbstractKeyedStateBackend backend = createKeyedBackend( IntSerializer.INSTANCE, numberOfKeyGroups, new KeyGroupRange(0, 0), @@ -2382,9 +2406,9 @@ public void testAsyncSnapshotCancellation() throws Exception { streamFactory.setBlockerLatch(blocker); streamFactory.setAfterNumberInvocations(10); - AbstractKeyedStateBackend backend = null; + final AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); + try { - backend = createKeyedBackend(IntSerializer.INSTANCE); if (!backend.supportsAsynchronousSnapshots()) { return; @@ -2411,14 +2435,11 @@ public void testAsyncSnapshotCancellation() throws Exception { waiter.await(); // close the backend to see if the close is propagated to the stream - backend.close(); + IOUtils.closeQuietly(backend); //unblock the stream so that it can run into the IOException blocker.trigger(); - //dispose the backend - backend.dispose(); - runner.join(); try { @@ -2429,10 +2450,7 @@ public void testAsyncSnapshotCancellation() throws Exception { } } finally { - if (null != backend) { - IOUtils.closeQuietly(backend); - backend.dispose(); - } + backend.dispose(); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java index 4c87671de2eb9..348dce647e023 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java @@ -33,10 +33,12 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.jobmaster.JMTMRegistrationSuccess; import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.jobmaster.JobMasterRegistrationSuccess; import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService; 
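
Several of the keyed-backend tests above are restructured so the backend is always released in a finally block: close it quietly first, so a failure in close cannot mask the original test failure, then dispose it to free native resources such as RocksDB memory. A generic sketch of that cleanup pattern, using an invented `NativeResource` type and a local `closeQuietly` helper rather than Flink's `IOUtils`:

```java
import java.io.Closeable;
import java.io.IOException;

/** Hypothetical resource standing in for a state backend that holds native memory. */
final class NativeResource implements Closeable {

    void doWork() {
        // test body: may throw at any point
    }

    @Override
    public void close() throws IOException {
        // stop ongoing work, close open streams, ...
    }

    void dispose() {
        // release native memory exactly once
    }

    static void closeQuietly(Closeable closeable) {
        if (closeable != null) {
            try {
                closeable.close();
            } catch (IOException ignored) {
                // best effort: never mask the original failure
            }
        }
    }

    public static void main(String[] args) {
        NativeResource resource = new NativeResource();
        try {
            resource.doWork();
        } finally {
            // release native resources even when we exit through an exception
            closeQuietly(resource);
            resource.dispose();
        }
    }
}
```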
import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.metrics.groups.TaskManagerMetricGroup; import org.apache.flink.runtime.registration.RegistrationResponse; @@ -47,7 +49,7 @@ import org.apache.flink.runtime.resourcemanager.SlotRequest; import org.apache.flink.runtime.resourcemanager.StandaloneResourceManager; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; -import org.apache.flink.runtime.rpc.TestingSerialRpcService; +import org.apache.flink.runtime.rpc.TestingRpcService; import org.apache.flink.runtime.taskexecutor.slot.SlotOffer; import org.apache.flink.runtime.taskexecutor.slot.TaskSlotTable; import org.apache.flink.runtime.taskexecutor.slot.TimerService; @@ -57,6 +59,7 @@ import org.apache.flink.util.TestLogger; import org.hamcrest.Matchers; import org.junit.Test; +import org.mockito.Mockito; import java.net.InetAddress; import java.util.Arrays; @@ -76,6 +79,8 @@ public class TaskExecutorITCase extends TestLogger { + private final Time timeout = Time.seconds(10L); + @Test public void testSlotAllocation() throws Exception { TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); @@ -88,7 +93,7 @@ public void testSlotAllocation() throws Exception { final TestingLeaderRetrievalService rmLeaderRetrievalService = new TestingLeaderRetrievalService(null, null); final String rmAddress = "rm"; final String jmAddress = "jm"; - final UUID jmLeaderId = UUID.randomUUID(); + final JobMasterId jobMasterId = JobMasterId.generate(); final ResourceID rmResourceId = new ResourceID(rmAddress); final ResourceID jmResourceId = new ResourceID(jmAddress); final JobID jobId = new JobID(); @@ -96,9 +101,9 @@ public void testSlotAllocation() throws Exception { testingHAServices.setResourceManagerLeaderElectionService(rmLeaderElectionService); testingHAServices.setResourceManagerLeaderRetriever(rmLeaderRetrievalService); - testingHAServices.setJobMasterLeaderRetriever(jobId, new TestingLeaderRetrievalService(jmAddress, jmLeaderId)); + testingHAServices.setJobMasterLeaderRetriever(jobId, new TestingLeaderRetrievalService(jmAddress, jobMasterId.toUUID())); - TestingSerialRpcService rpcService = new TestingSerialRpcService(); + TestingRpcService rpcService = new TestingRpcService(); ResourceManagerConfiguration resourceManagerConfiguration = new ResourceManagerConfiguration( Time.milliseconds(500L), Time.milliseconds(500L)); @@ -158,18 +163,19 @@ public void testSlotAllocation() throws Exception { JobMasterGateway jmGateway = mock(JobMasterGateway.class); - when(jmGateway.registerTaskManager(any(String.class), any(TaskManagerLocation.class), eq(jmLeaderId), any(Time.class))) + when(jmGateway.registerTaskManager(any(String.class), any(TaskManagerLocation.class), any(Time.class))) .thenReturn(CompletableFuture.completedFuture(new JMTMRegistrationSuccess(taskManagerResourceId, 1234))); when(jmGateway.getHostname()).thenReturn(jmAddress); when(jmGateway.offerSlots( eq(taskManagerResourceId), any(Iterable.class), - eq(jmLeaderId), any(Time.class))).thenReturn(mock(CompletableFuture.class, RETURNS_MOCKS)); + when(jmGateway.getFencingToken()).thenReturn(jobMasterId); rpcService.registerGateway(rmAddress, resourceManager.getSelfGateway(ResourceManagerGateway.class)); rpcService.registerGateway(jmAddress, jmGateway); + 
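
Switching the integration test from a serial RPC service to `TestingRpcService` makes gateway calls asynchronous, so interactions on the mocked `JobMasterGateway` can no longer be verified immediately after the call returns; verification has to wait until the message is actually delivered. A minimal, self-contained illustration of Mockito's timeout-based verification; the `Greeter` interface and the executor are invented for the example.

```java
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class AsyncVerificationExample {

    /** Hypothetical collaborator invoked from another thread. */
    interface Greeter {
        void greet(String name);
    }

    public static void main(String[] args) {
        Greeter greeter = mock(Greeter.class);
        ExecutorService executor = Executors.newSingleThreadExecutor();

        // The interaction happens asynchronously, like an RPC delivered by another thread.
        executor.submit(() -> greeter.greet("flink"));

        // A plain verify(greeter).greet("flink") could run before the task does;
        // timeout(...) keeps checking until the interaction is seen or the deadline expires.
        verify(greeter, timeout(10_000)).greet("flink");

        executor.shutdown();
    }
}
```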
rpcService.registerGateway(taskExecutor.getAddress(), taskExecutor.getSelfGateway(TaskExecutorGateway.class)); final AllocationID allocationId = new AllocationID(); final SlotRequest slotRequest = new SlotRequest(jobId, allocationId, resourceProfile, jmAddress); @@ -179,30 +185,33 @@ public void testSlotAllocation() throws Exception { resourceManager.start(); taskExecutor.start(); + final ResourceManagerGateway rmGateway = resourceManager.getSelfGateway(ResourceManagerGateway.class); + // notify the RM that it is the leader rmLeaderElectionService.isLeader(rmLeaderId); // notify the TM about the new RM leader rmLeaderRetrievalService.notifyListener(rmAddress, rmLeaderId); - CompletableFuture registrationResponseFuture = resourceManager.registerJobManager( - rmLeaderId, - jmLeaderId, + CompletableFuture registrationResponseFuture = rmGateway.registerJobManager( + jobMasterId, jmResourceId, jmAddress, jobId, - Time.milliseconds(0L)); + timeout); RegistrationResponse registrationResponse = registrationResponseFuture.get(); assertTrue(registrationResponse instanceof JobMasterRegistrationSuccess); - resourceManager.requestSlot(jmLeaderId, rmLeaderId, slotRequest, Time.milliseconds(0L)); + CompletableFuture slotAck = rmGateway.requestSlot(jobMasterId, slotRequest, timeout); + + slotAck.get(); - verify(jmGateway).offerSlots( + verify(jmGateway, Mockito.timeout(timeout.toMilliseconds())).offerSlots( eq(taskManagerResourceId), (Iterable)argThat(Matchers.contains(slotOffer)), - eq(jmLeaderId), any(Time.class)); + any(Time.class)); } finally { if (testingFatalErrorHandler.hasExceptionOccurred()) { testingFatalErrorHandler.rethrowError(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java index b5a3c8025a463..714644514f166 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.clusterframework.types.AllocationID; @@ -54,6 +55,7 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobmaster.JMTMRegistrationSuccess; import org.apache.flink.runtime.jobmaster.JobMasterGateway; +import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; import org.apache.flink.runtime.memory.MemoryManager; @@ -64,26 +66,32 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway; +import org.apache.flink.runtime.resourcemanager.ResourceManagerId; +import org.apache.flink.runtime.rpc.RpcService; import org.apache.flink.runtime.rpc.TestingRpcService; -import org.apache.flink.runtime.rpc.TestingSerialRpcService; import org.apache.flink.runtime.taskexecutor.exceptions.SlotAllocationException; import org.apache.flink.runtime.taskexecutor.slot.SlotOffer; import 
org.apache.flink.runtime.taskexecutor.slot.TaskSlotTable; import org.apache.flink.runtime.taskexecutor.slot.TimerService; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; +import org.apache.flink.runtime.taskmanager.TaskExecutionState; import org.apache.flink.runtime.taskmanager.TaskManagerActions; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.util.TestingFatalErrorHandler; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.TestLogger; + +import org.junit.After; +import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; import org.mockito.ArgumentCaptor; import org.mockito.Matchers; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.slf4j.Logger; @@ -96,6 +104,7 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import static org.hamcrest.Matchers.contains; @@ -111,6 +120,21 @@ public class TaskExecutorTest extends TestLogger { private final Time timeout = Time.milliseconds(10000L); + private TestingRpcService rpc; + + @Before + public void setup() { + rpc = new TestingRpcService(); + } + + @After + public void teardown() { + if (rpc != null) { + rpc.stopService(); + rpc = null; + } + } + @Rule public TestName name = new TestName(); @@ -123,7 +147,6 @@ public void testHeartbeatTimeoutWithJobManager() throws Exception { final TaskManagerLocation taskManagerLocation = new TaskManagerLocation(tmResourceId, InetAddress.getLoopbackAddress(), 1234); final TaskSlotTable taskSlotTable = new TaskSlotTable(Arrays.asList(mock(ResourceProfile.class)), mock(TimerService.class)); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); final JobLeaderService jobLeaderService = new JobLeaderService(taskManagerLocation); final TestingHighAvailabilityServices haServices = new TestingHighAvailabilityServices(); final TestingLeaderRetrievalService rmLeaderRetrievalService = new TestingLeaderRetrievalService( @@ -168,31 +191,30 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation) thro when(jobMasterGateway.registerTaskManager( any(String.class), eq(taskManagerLocation), - eq(jmLeaderId), any(Time.class) )).thenReturn(CompletableFuture.completedFuture(new JMTMRegistrationSuccess(jmResourceId, blobPort))); when(jobMasterGateway.getAddress()).thenReturn(jobMasterAddress); when(jobMasterGateway.getHostname()).thenReturn("localhost"); - try { - final TaskExecutor taskManager = new TaskExecutor( - rpc, - tmConfig, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - heartbeatServices, - mock(MetricRegistry.class), - mock(TaskManagerMetricGroup.class), - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - new JobManagerTable(), - jobLeaderService, - testingFatalErrorHandler); + final TaskExecutor taskManager = new TaskExecutor( + rpc, + tmConfig, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServices, + heartbeatServices, + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + 
taskSlotTable, + new JobManagerTable(), + jobLeaderService, + testingFatalErrorHandler); + try { taskManager.start(); rpc.registerGateway(jobMasterAddress, jobMasterGateway); @@ -205,8 +227,8 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation) thro jmLeaderRetrievalService.notifyListener(jobMasterAddress, jmLeaderId); // register task manager success will trigger monitoring heartbeat target between tm and jm - verify(jobMasterGateway).registerTaskManager( - eq(taskManager.getAddress()), eq(taskManagerLocation), eq(jmLeaderId), any(Time.class)); + verify(jobMasterGateway, Mockito.timeout(timeout.toMilliseconds())).registerTaskManager( + eq(taskManager.getAddress()), eq(taskManagerLocation), any(Time.class)); // the timeout should trigger disconnecting from the JobManager verify(jobMasterGateway, timeout(heartbeatTimeout * 50L)).disconnectTaskManager(eq(taskManagerLocation.getResourceID()), any(TimeoutException.class)); @@ -215,7 +237,8 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation) thro testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -230,7 +253,7 @@ public void testHeartbeatTimeoutWithResourceManager() throws Exception { // register the mock resource manager gateway ResourceManagerGateway rmGateway = mock(ResourceManagerGateway.class); when(rmGateway.registerTaskExecutor( - any(UUID.class), anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) + anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) .thenReturn( CompletableFuture.completedFuture( new TaskExecutorRegistrationSuccess( @@ -238,7 +261,6 @@ public void testHeartbeatTimeoutWithResourceManager() throws Exception { rmResourceId, 10L))); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); rpc.registerGateway(rmAddress, rmGateway); final TestingLeaderRetrievalService testLeaderService = new TestingLeaderRetrievalService( @@ -280,33 +302,33 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation } ); - try { - final TaskExecutor taskManager = new TaskExecutor( - rpc, - taskManagerConfiguration, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - heartbeatServices, - mock(MetricRegistry.class), - mock(TaskManagerMetricGroup.class), - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - mock(JobManagerTable.class), - mock(JobLeaderService.class), - testingFatalErrorHandler); + final TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServices, + heartbeatServices, + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + mock(JobManagerTable.class), + mock(JobLeaderService.class), + testingFatalErrorHandler); + try { taskManager.start(); // define a leader and see that a registration happens testLeaderService.notifyListener(rmAddress, rmLeaderId); // register resource manager success will trigger monitoring heartbeat target between tm and rm - verify(rmGateway, atLeast(1)).registerTaskExecutor( - eq(rmLeaderId), eq(taskManager.getAddress()), eq(tmResourceId), any(SlotReport.class), any(Time.class)); + verify(rmGateway, 
Mockito.timeout(timeout.toMilliseconds()).atLeast(1)).registerTaskExecutor( + eq(taskManager.getAddress()), eq(tmResourceId), any(SlotReport.class), any(Time.class)); // heartbeat timeout should trigger disconnect TaskManager from ResourceManager verify(rmGateway, timeout(heartbeatTimeout * 50L)).disconnectTaskManager(eq(taskManagerLocation.getResourceID()), any(TimeoutException.class)); @@ -315,7 +337,8 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -324,7 +347,7 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation */ @Test public void testHeartbeatSlotReporting() throws Exception { - final long timeout = 1000L; + final long verificationTimeout = 1000L; final String rmAddress = "rm"; final String tmAddress = "tm"; final ResourceID rmResourceId = new ResourceID(rmAddress); @@ -334,7 +357,7 @@ public void testHeartbeatSlotReporting() throws Exception { // register the mock resource manager gateway ResourceManagerGateway rmGateway = mock(ResourceManagerGateway.class); when(rmGateway.registerTaskExecutor( - any(UUID.class), anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) + anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) .thenReturn( CompletableFuture.completedFuture( new TaskExecutorRegistrationSuccess( @@ -342,7 +365,6 @@ public void testHeartbeatSlotReporting() throws Exception { rmResourceId, 10L))); - final TestingRpcService rpc = new TestingRpcService(); rpc.registerGateway(rmAddress, rmGateway); final TestingLeaderRetrievalService testLeaderService = new TestingLeaderRetrievalService( @@ -398,25 +420,25 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation } ); - try { - final TaskExecutor taskManager = new TaskExecutor( - rpc, - taskManagerConfiguration, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - heartbeatServices, - mock(MetricRegistry.class), - mock(TaskManagerMetricGroup.class), - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - mock(JobManagerTable.class), - mock(JobLeaderService.class), - testingFatalErrorHandler); + final TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServices, + heartbeatServices, + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + mock(JobManagerTable.class), + mock(JobLeaderService.class), + testingFatalErrorHandler); + try { taskManager.start(); // wait for spied heartbeat manager instance @@ -426,10 +448,10 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation testLeaderService.notifyListener(rmAddress, rmLeaderId); // register resource manager success will trigger monitoring heartbeat target between tm and rm - verify(rmGateway, timeout(timeout).atLeast(1)).registerTaskExecutor( - eq(rmLeaderId), eq(taskManager.getAddress()), eq(tmResourceId), eq(slotReport1), any(Time.class)); + verify(rmGateway, timeout(verificationTimeout).atLeast(1)).registerTaskExecutor( + eq(taskManager.getAddress()), eq(tmResourceId), eq(slotReport1), any(Time.class)); - verify(heartbeatManager, 
timeout(timeout)).monitorTarget(any(ResourceID.class), any(HeartbeatTarget.class)); + verify(heartbeatManager, timeout(verificationTimeout)).monitorTarget(any(ResourceID.class), any(HeartbeatTarget.class)); TaskExecutorGateway taskExecutorGateway = taskManager.getSelfGateway(TaskExecutorGateway.class); @@ -439,7 +461,7 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation ArgumentCaptor slotReportArgumentCaptor = ArgumentCaptor.forClass(SlotReport.class); // wait for heartbeat response - verify(rmGateway, timeout(timeout)).heartbeatFromTaskManager( + verify(rmGateway, timeout(verificationTimeout)).heartbeatFromTaskManager( eq(taskManagerLocation.getResourceID()), slotReportArgumentCaptor.capture()); @@ -452,7 +474,8 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -461,64 +484,66 @@ public void testImmediatelyRegistersIfLeaderIsKnown() throws Exception { final ResourceID resourceID = ResourceID.generate(); final String resourceManagerAddress = "/resource/manager/address/one"; final ResourceID resourceManagerResourceId = new ResourceID(resourceManagerAddress); + final String dispatcherAddress = "localhost"; final String jobManagerAddress = "localhost"; - final TestingSerialRpcService rpc = new TestingSerialRpcService(); - try { - // register a mock resource manager gateway - ResourceManagerGateway rmGateway = mock(ResourceManagerGateway.class); - when(rmGateway.registerTaskExecutor( - any(UUID.class), anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) - .thenReturn(CompletableFuture.completedFuture(new TaskExecutorRegistrationSuccess( - new InstanceID(), resourceManagerResourceId, 10L))); - - TaskManagerConfiguration taskManagerServicesConfiguration = mock(TaskManagerConfiguration.class); - when(taskManagerServicesConfiguration.getNumberSlots()).thenReturn(1); - - rpc.registerGateway(resourceManagerAddress, rmGateway); - - TaskManagerLocation taskManagerLocation = mock(TaskManagerLocation.class); - when(taskManagerLocation.getResourceID()).thenReturn(resourceID); - - StandaloneHaServices haServices = new StandaloneHaServices( - resourceManagerAddress, - jobManagerAddress); - - final TaskSlotTable taskSlotTable = mock(TaskSlotTable.class); - final SlotReport slotReport = new SlotReport(); - when(taskSlotTable.createSlotReport(any(ResourceID.class))).thenReturn(slotReport); - - final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); - - TaskExecutor taskManager = new TaskExecutor( - rpc, - taskManagerServicesConfiguration, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - mock(HeartbeatServices.class, RETURNS_MOCKS), - mock(MetricRegistry.class), - mock(TaskManagerMetricGroup.class), - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - mock(JobManagerTable.class), - mock(JobLeaderService.class), - testingFatalErrorHandler); + // register a mock resource manager gateway + ResourceManagerGateway rmGateway = mock(ResourceManagerGateway.class); + when(rmGateway.registerTaskExecutor( + anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) + .thenReturn(CompletableFuture.completedFuture(new TaskExecutorRegistrationSuccess( + new InstanceID(), resourceManagerResourceId, 10L))); + + 
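
The mocked resource manager gateway above answers every registration with an already-completed future, which keeps the call asynchronous in shape but deterministic in the test. A small sketch of that stubbing style; `RegistrationGateway` is an invented interface, not a Flink API.

```java
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.concurrent.CompletableFuture;

public class CompletedFutureStubExample {

    /** Hypothetical asynchronous registration gateway. */
    interface RegistrationGateway {
        CompletableFuture<String> registerTaskExecutor(String address);
    }

    public static void main(String[] args) throws Exception {
        RegistrationGateway gateway = mock(RegistrationGateway.class);

        // Answer every registration immediately with a successful, already-completed future.
        when(gateway.registerTaskExecutor(anyString()))
            .thenReturn(CompletableFuture.completedFuture("registration-success"));

        // The caller can block on the future without touching a real network stack.
        String response = gateway.registerTaskExecutor("tm-address").get();
        System.out.println(response);
    }
}
```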
TaskManagerConfiguration taskManagerServicesConfiguration = mock(TaskManagerConfiguration.class); + when(taskManagerServicesConfiguration.getNumberSlots()).thenReturn(1); + + rpc.registerGateway(resourceManagerAddress, rmGateway); + TaskManagerLocation taskManagerLocation = mock(TaskManagerLocation.class); + when(taskManagerLocation.getResourceID()).thenReturn(resourceID); + + StandaloneHaServices haServices = new StandaloneHaServices( + resourceManagerAddress, + dispatcherAddress, + jobManagerAddress); + + final TaskSlotTable taskSlotTable = mock(TaskSlotTable.class); + final SlotReport slotReport = new SlotReport(); + when(taskSlotTable.createSlotReport(any(ResourceID.class))).thenReturn(slotReport); + + final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); + + TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerServicesConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServices, + mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + mock(JobManagerTable.class), + mock(JobLeaderService.class), + testingFatalErrorHandler); + + try { taskManager.start(); String taskManagerAddress = taskManager.getAddress(); - verify(rmGateway).registerTaskExecutor( - any(UUID.class), eq(taskManagerAddress), eq(resourceID), eq(slotReport), any(Time.class)); + verify(rmGateway, Mockito.timeout(timeout.toMilliseconds())).registerTaskExecutor( + eq(taskManagerAddress), eq(resourceID), eq(slotReport), any(Time.class)); // check if a concurrent error occurred testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -533,64 +558,63 @@ public void testTriggerRegistrationOnLeaderChange() throws Exception { final ResourceID rmResourceId1 = new ResourceID(address1); final ResourceID rmResourceId2 = new ResourceID(address2); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); - try { - // register the mock resource manager gateways - ResourceManagerGateway rmGateway1 = mock(ResourceManagerGateway.class); - ResourceManagerGateway rmGateway2 = mock(ResourceManagerGateway.class); - - when(rmGateway1.registerTaskExecutor( - any(UUID.class), anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) - .thenReturn(CompletableFuture.completedFuture( - new TaskExecutorRegistrationSuccess(new InstanceID(), rmResourceId1, 10L))); - when(rmGateway2.registerTaskExecutor( - any(UUID.class), anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) - .thenReturn(CompletableFuture.completedFuture( - new TaskExecutorRegistrationSuccess(new InstanceID(), rmResourceId2, 10L))); - - rpc.registerGateway(address1, rmGateway1); - rpc.registerGateway(address2, rmGateway2); - - TestingLeaderRetrievalService testLeaderService = new TestingLeaderRetrievalService( - null, - null); - - TestingHighAvailabilityServices haServices = new TestingHighAvailabilityServices(); - haServices.setResourceManagerLeaderRetriever(testLeaderService); - - TaskManagerConfiguration taskManagerServicesConfiguration = mock(TaskManagerConfiguration.class); - when(taskManagerServicesConfiguration.getNumberSlots()).thenReturn(1); - when(taskManagerServicesConfiguration.getConfiguration()).thenReturn(new 
Configuration()); - when(taskManagerServicesConfiguration.getTmpDirectories()).thenReturn(new String[1]); - - TaskManagerLocation taskManagerLocation = mock(TaskManagerLocation.class); - when(taskManagerLocation.getResourceID()).thenReturn(tmResourceID); - when(taskManagerLocation.getHostname()).thenReturn("foobar"); - - final TaskSlotTable taskSlotTable = mock(TaskSlotTable.class); - final SlotReport slotReport = new SlotReport(); - when(taskSlotTable.createSlotReport(any(ResourceID.class))).thenReturn(slotReport); - - final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); - - TaskExecutor taskManager = new TaskExecutor( - rpc, - taskManagerServicesConfiguration, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - mock(HeartbeatServices.class, RETURNS_MOCKS), - mock(MetricRegistry.class), - mock(TaskManagerMetricGroup.class), - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - mock(JobManagerTable.class), - mock(JobLeaderService.class), - testingFatalErrorHandler); + // register the mock resource manager gateways + ResourceManagerGateway rmGateway1 = mock(ResourceManagerGateway.class); + ResourceManagerGateway rmGateway2 = mock(ResourceManagerGateway.class); + + when(rmGateway1.registerTaskExecutor( + anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) + .thenReturn(CompletableFuture.completedFuture( + new TaskExecutorRegistrationSuccess(new InstanceID(), rmResourceId1, 10L))); + when(rmGateway2.registerTaskExecutor( + anyString(), any(ResourceID.class), any(SlotReport.class), any(Time.class))) + .thenReturn(CompletableFuture.completedFuture( + new TaskExecutorRegistrationSuccess(new InstanceID(), rmResourceId2, 10L))); + + rpc.registerGateway(address1, rmGateway1); + rpc.registerGateway(address2, rmGateway2); + + TestingLeaderRetrievalService testLeaderService = new TestingLeaderRetrievalService( + null, + null); + + TestingHighAvailabilityServices haServices = new TestingHighAvailabilityServices(); + haServices.setResourceManagerLeaderRetriever(testLeaderService); + + TaskManagerConfiguration taskManagerServicesConfiguration = mock(TaskManagerConfiguration.class); + when(taskManagerServicesConfiguration.getNumberSlots()).thenReturn(1); + when(taskManagerServicesConfiguration.getConfiguration()).thenReturn(new Configuration()); + when(taskManagerServicesConfiguration.getTmpDirectories()).thenReturn(new String[1]); + + TaskManagerLocation taskManagerLocation = mock(TaskManagerLocation.class); + when(taskManagerLocation.getResourceID()).thenReturn(tmResourceID); + when(taskManagerLocation.getHostname()).thenReturn("foobar"); + + final TaskSlotTable taskSlotTable = mock(TaskSlotTable.class); + final SlotReport slotReport = new SlotReport(); + when(taskSlotTable.createSlotReport(any(ResourceID.class))).thenReturn(slotReport); + + final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); + + TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerServicesConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServices, + mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + mock(JobManagerTable.class), + mock(JobLeaderService.class), + testingFatalErrorHandler); + try { 
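
The leader-change test below follows a listener pattern: the task executor subscribes to a leader retrieval service, and each time a new leader address is published it drops the old connection and registers with the new leader. A compact, self-contained sketch of that pattern; `LeaderListener`, `TestingLeaderService`, and `Connector` are invented names, not Flink classes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

public class LeaderChangeExample {

    /** Hypothetical callback notified whenever a new leader is elected. */
    interface LeaderListener {
        void notifyLeaderAddress(String leaderAddress, UUID leaderSessionId);
    }

    /** Hypothetical in-test leader service that simply forwards notifications. */
    static final class TestingLeaderService {
        private final List<LeaderListener> listeners = new ArrayList<>();

        void start(LeaderListener listener) {
            listeners.add(listener);
        }

        void notifyListener(String address, UUID sessionId) {
            for (LeaderListener listener : listeners) {
                listener.notifyLeaderAddress(address, sessionId);
            }
        }
    }

    /** Hypothetical component that re-registers whenever the leader changes. */
    static final class Connector implements LeaderListener {
        String connectedTo;

        @Override
        public void notifyLeaderAddress(String leaderAddress, UUID leaderSessionId) {
            // drop any previous connection and register with the new leader
            connectedTo = leaderAddress;
            System.out.println("registering with leader at " + leaderAddress);
        }
    }

    public static void main(String[] args) {
        TestingLeaderService leaderService = new TestingLeaderService();
        Connector connector = new Connector();
        leaderService.start(connector);

        leaderService.notifyListener("/resource/manager/address/one", UUID.randomUUID());
        leaderService.notifyListener("/resource/manager/address/two", UUID.randomUUID());
        // connector.connectedTo now points at the second address
    }
}
```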
taskManager.start(); String taskManagerAddress = taskManager.getAddress(); @@ -600,8 +624,8 @@ public void testTriggerRegistrationOnLeaderChange() throws Exception { // define a leader and see that a registration happens testLeaderService.notifyListener(address1, leaderId1); - verify(rmGateway1).registerTaskExecutor( - eq(leaderId1), eq(taskManagerAddress), eq(tmResourceID), any(SlotReport.class), any(Time.class)); + verify(rmGateway1, Mockito.timeout(timeout.toMilliseconds())).registerTaskExecutor( + eq(taskManagerAddress), eq(tmResourceID), any(SlotReport.class), any(Time.class)); assertNotNull(taskManager.getResourceManagerConnection()); // cancel the leader @@ -610,30 +634,30 @@ public void testTriggerRegistrationOnLeaderChange() throws Exception { // set a new leader, see that a registration happens testLeaderService.notifyListener(address2, leaderId2); - verify(rmGateway2).registerTaskExecutor( - eq(leaderId2), eq(taskManagerAddress), eq(tmResourceID), eq(slotReport), any(Time.class)); + verify(rmGateway2, Mockito.timeout(timeout.toMilliseconds())).registerTaskExecutor( + eq(taskManagerAddress), eq(tmResourceID), eq(slotReport), any(Time.class)); assertNotNull(taskManager.getResourceManagerConnection()); // check if a concurrent error occurred testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } /** * Tests that we can submit a task to the TaskManager given that we've allocated a slot there. */ - @Test(timeout = 1000L) + @Test(timeout = 10000L) public void testTaskSubmission() throws Exception { final Configuration configuration = new Configuration(); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); final TaskManagerConfiguration taskManagerConfiguration = TaskManagerConfiguration.fromConfiguration(configuration); final JobID jobId = new JobID(); final AllocationID allocationId = new AllocationID(); - final UUID jobManagerLeaderId = UUID.randomUUID(); + final JobMasterId jobMasterId = JobMasterId.generate(); final JobVertexID jobVertexId = new JobVertexID(); JobInformation jobInformation = new JobInformation( @@ -668,15 +692,18 @@ public void testTaskSubmission() throws Exception { Collections.emptyList()); final LibraryCacheManager libraryCacheManager = mock(LibraryCacheManager.class); - when(libraryCacheManager.getClassLoader(eq(jobId))).thenReturn(getClass().getClassLoader()); + when(libraryCacheManager.getClassLoader(any(JobID.class))).thenReturn(ClassLoader.getSystemClassLoader()); + + final JobMasterGateway jobMasterGateway = mock(JobMasterGateway.class); + when(jobMasterGateway.getFencingToken()).thenReturn(jobMasterId); final JobManagerConnection jobManagerConnection = new JobManagerConnection( jobId, ResourceID.generate(), - mock(JobMasterGateway.class), - jobManagerLeaderId, + jobMasterGateway, mock(TaskManagerActions.class), mock(CheckpointResponder.class), + mock(BlobCache.class), libraryCacheManager, mock(ResultPartitionConsumableNotifier.class), mock(PartitionProducerStateChecker.class)); @@ -705,30 +732,32 @@ public void testTaskSubmission() throws Exception { final HighAvailabilityServices haServices = mock(HighAvailabilityServices.class); when(haServices.getResourceManagerLeaderRetriever()).thenReturn(mock(LeaderRetrievalService.class)); - try { - final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); - - TaskExecutor taskManager = new TaskExecutor( - rpc, - 
taskManagerConfiguration, - mock(TaskManagerLocation.class), - mock(MemoryManager.class), - mock(IOManager.class), - networkEnvironment, - haServices, - mock(HeartbeatServices.class, RETURNS_MOCKS), - mock(MetricRegistry.class), - taskManagerMetricGroup, - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - jobManagerTable, - mock(JobLeaderService.class), - testingFatalErrorHandler); + final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); + + TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerConfiguration, + mock(TaskManagerLocation.class), + mock(MemoryManager.class), + mock(IOManager.class), + networkEnvironment, + haServices, + mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(MetricRegistry.class), + taskManagerMetricGroup, + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + jobManagerTable, + mock(JobLeaderService.class), + testingFatalErrorHandler); + try { taskManager.start(); - taskManager.submitTask(tdd, jobManagerLeaderId, timeout); + final TaskExecutorGateway tmGateway = taskManager.getSelfGateway(TaskExecutorGateway.class); + + tmGateway.submitTask(tdd, jobMasterId, timeout); CompletableFuture completionFuture = TestInvokable.completableFuture; @@ -737,7 +766,8 @@ public void testTaskSubmission() throws Exception { // check if a concurrent error occurred testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -762,7 +792,6 @@ public void invoke() throws Exception { public void testJobLeaderDetection() throws Exception { final JobID jobId = new JobID(); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); final Configuration configuration = new Configuration(); final TaskManagerConfiguration taskManagerConfiguration = TaskManagerConfiguration.fromConfiguration(configuration); final ResourceID resourceId = new ResourceID("foobar"); @@ -784,14 +813,13 @@ public void testJobLeaderDetection() throws Exception { haServices.setJobMasterLeaderRetriever(jobId, jobManagerLeaderRetrievalService); final String resourceManagerAddress = "rm"; - final UUID resourceManagerLeaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerLeaderId = ResourceManagerId.generate(); final ResourceID resourceManagerResourceId = new ResourceID(resourceManagerAddress); final ResourceManagerGateway resourceManagerGateway = mock(ResourceManagerGateway.class); final InstanceID registrationId = new InstanceID(); when(resourceManagerGateway.registerTaskExecutor( - eq(resourceManagerLeaderId), any(String.class), eq(resourceId), any(SlotReport.class), @@ -807,14 +835,12 @@ public void testJobLeaderDetection() throws Exception { when(jobMasterGateway.registerTaskManager( any(String.class), eq(taskManagerLocation), - eq(jobManagerLeaderId), any(Time.class) )).thenReturn(CompletableFuture.completedFuture(new JMTMRegistrationSuccess(jmResourceId, blobPort))); when(jobMasterGateway.getHostname()).thenReturn(jobManagerAddress); when(jobMasterGateway.offerSlots( any(ResourceID.class), any(Iterable.class), - any(UUID.class), any(Time.class))).thenReturn(mock(CompletableFuture.class, RETURNS_MOCKS)); rpc.registerGateway(resourceManagerAddress, resourceManagerGateway); @@ -824,47 +850,57 @@ public void testJobLeaderDetection() throws Exception { final SlotID slotId = new SlotID(resourceId, 0); final SlotOffer slotOffer = new 
SlotOffer(allocationId, 0, ResourceProfile.UNKNOWN); - try { - TaskExecutor taskManager = new TaskExecutor( - rpc, - taskManagerConfiguration, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - mock(HeartbeatServices.class, RETURNS_MOCKS), - mock(MetricRegistry.class), - mock(TaskManagerMetricGroup.class), - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - jobManagerTable, - jobLeaderService, - testingFatalErrorHandler); + TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServices, + mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + jobManagerTable, + jobLeaderService, + testingFatalErrorHandler); + try { taskManager.start(); + final TaskExecutorGateway tmGateway = taskManager.getSelfGateway(TaskExecutorGateway.class); + // tell the task manager about the rm leader - resourceManagerLeaderRetrievalService.notifyListener(resourceManagerAddress, resourceManagerLeaderId); + resourceManagerLeaderRetrievalService.notifyListener(resourceManagerAddress, resourceManagerLeaderId.toUUID()); // request slots from the task manager under the given allocation id - taskManager.requestSlot(slotId, jobId, allocationId, jobManagerAddress, resourceManagerLeaderId, timeout); + CompletableFuture slotRequestAck = tmGateway.requestSlot( + slotId, + jobId, + allocationId, + jobManagerAddress, + resourceManagerLeaderId, + timeout); + + slotRequestAck.get(); // now inform the task manager about the new job leader jobManagerLeaderRetrievalService.notifyListener(jobManagerAddress, jobManagerLeaderId); // the job leader should get the allocation id offered - verify(jobMasterGateway).offerSlots( + verify(jobMasterGateway, Mockito.timeout(timeout.toMilliseconds())).offerSlots( any(ResourceID.class), (Iterable)Matchers.argThat(contains(slotOffer)), - eq(jobManagerLeaderId), any(Time.class)); // check if a concurrent error occurred testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -876,7 +912,6 @@ public void testJobLeaderDetection() throws Exception { public void testSlotAcceptance() throws Exception { final JobID jobId = new JobID(); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); final Configuration configuration = new Configuration(); final TaskManagerConfiguration taskManagerConfiguration = TaskManagerConfiguration.fromConfiguration(configuration); final ResourceID resourceId = new ResourceID("foobar"); @@ -904,7 +939,6 @@ public void testSlotAcceptance() throws Exception { final InstanceID registrationId = new InstanceID(); when(resourceManagerGateway.registerTaskExecutor( - eq(resourceManagerLeaderId), any(String.class), eq(resourceId), any(SlotReport.class), @@ -923,37 +957,36 @@ public void testSlotAcceptance() throws Exception { when(jobMasterGateway.registerTaskManager( any(String.class), eq(taskManagerLocation), - eq(jobManagerLeaderId), any(Time.class) )).thenReturn(CompletableFuture.completedFuture(new JMTMRegistrationSuccess(jmResourceId, blobPort))); when(jobMasterGateway.getHostname()).thenReturn(jobManagerAddress); 
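
Many of the signature changes in this file drop the explicit leader-session UUID from RPC methods such as `registerTaskManager`, `offerSlots`, and `registerTaskExecutor`; instead the gateway exposes a fencing token via `getFencingToken()`, and calls made under a stale token are rejected by the receiving endpoint. A self-contained sketch of that fencing idea, with invented `FencingToken` and `FencedEndpoint` types standing in for ids such as `JobMasterId` and for fenced RPC endpoints:

```java
import java.util.UUID;
import java.util.concurrent.CompletableFuture;

public class FencingExample {

    /** Hypothetical fencing token; in the patch this role is played by ids such as JobMasterId. */
    static final class FencingToken {
        private final UUID id;

        private FencingToken(UUID id) {
            this.id = id;
        }

        static FencingToken generate() {
            return new FencingToken(UUID.randomUUID());
        }

        @Override
        public boolean equals(Object other) {
            return other instanceof FencingToken && ((FencingToken) other).id.equals(id);
        }

        @Override
        public int hashCode() {
            return id.hashCode();
        }

        @Override
        public String toString() {
            return id.toString();
        }
    }

    /** Hypothetical endpoint that only accepts messages carrying its current fencing token. */
    static final class FencedEndpoint {
        private final FencingToken fencingToken = FencingToken.generate();

        FencingToken getFencingToken() {
            return fencingToken;
        }

        CompletableFuture<String> submit(String payload, FencingToken senderToken) {
            if (!fencingToken.equals(senderToken)) {
                // message from an old leader session: fence it off instead of executing it
                CompletableFuture<String> rejected = new CompletableFuture<>();
                rejected.completeExceptionally(
                    new IllegalStateException("Stale fencing token: " + senderToken));
                return rejected;
            }
            return CompletableFuture.completedFuture("accepted: " + payload);
        }
    }

    public static void main(String[] args) throws Exception {
        FencedEndpoint endpoint = new FencedEndpoint();

        // a caller that learned the current token succeeds
        System.out.println(endpoint.submit("task", endpoint.getFencingToken()).get());

        // a caller still holding a token from a previous leader session is rejected
        endpoint.submit("task", FencingToken.generate())
            .exceptionally(error -> "rejected: " + error.getMessage())
            .thenAccept(System.out::println);
    }
}
```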
when(jobMasterGateway.offerSlots( - any(ResourceID.class), any(Iterable.class), eq(jobManagerLeaderId), any(Time.class))) + any(ResourceID.class), any(Iterable.class), any(Time.class))) .thenReturn(CompletableFuture.completedFuture((Collection)Collections.singleton(offer1))); rpc.registerGateway(resourceManagerAddress, resourceManagerGateway); rpc.registerGateway(jobManagerAddress, jobMasterGateway); - try { - TaskExecutor taskManager = new TaskExecutor( - rpc, - taskManagerConfiguration, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - mock(HeartbeatServices.class, RETURNS_MOCKS), - mock(MetricRegistry.class), - mock(TaskManagerMetricGroup.class), - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - jobManagerTable, - jobLeaderService, - testingFatalErrorHandler); + TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServices, + mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + jobManagerTable, + jobLeaderService, + testingFatalErrorHandler); + try { taskManager.start(); taskSlotTable.allocateSlot(0, jobId, allocationId1, Time.milliseconds(10000L)); @@ -963,8 +996,7 @@ public void testSlotAcceptance() throws Exception { // been properly started. jobLeaderService.addJob(jobId, jobManagerAddress); - verify(resourceManagerGateway).notifySlotAvailable( - eq(resourceManagerLeaderId), + verify(resourceManagerGateway, Mockito.timeout(timeout.toMilliseconds())).notifySlotAvailable( eq(registrationId), eq(new SlotID(resourceId, 1)), eq(allocationId2)); @@ -976,7 +1008,8 @@ public void testSlotAcceptance() throws Exception { // check if a concurrent error occurred testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -992,76 +1025,84 @@ public void testRejectAllocationRequestsForOutOfSyncSlots() throws Exception { final ResourceID resourceID = ResourceID.generate(); final String address1 = "/resource/manager/address/one"; - final UUID leaderId = UUID.randomUUID(); + final ResourceManagerId resourceManagerId = ResourceManagerId.generate(); final JobID jobId = new JobID(); final String jobManagerAddress = "foobar"; - final TestingSerialRpcService rpc = new TestingSerialRpcService(); - try { - // register the mock resource manager gateways - ResourceManagerGateway rmGateway1 = mock(ResourceManagerGateway.class); - rpc.registerGateway(address1, rmGateway1); - - TestingLeaderRetrievalService testLeaderService = new TestingLeaderRetrievalService( - "localhost", - HighAvailabilityServices.DEFAULT_LEADER_ID); - - TestingHighAvailabilityServices haServices = new TestingHighAvailabilityServices(); - haServices.setResourceManagerLeaderRetriever(testLeaderService); - - TaskManagerConfiguration taskManagerServicesConfiguration = mock(TaskManagerConfiguration.class); - when(taskManagerServicesConfiguration.getNumberSlots()).thenReturn(1); - - TaskManagerLocation taskManagerLocation = mock(TaskManagerLocation.class); - when(taskManagerLocation.getResourceID()).thenReturn(resourceID); - - final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); - - 
TaskExecutor taskManager = new TaskExecutor( - rpc, - taskManagerServicesConfiguration, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - mock(HeartbeatServices.class, RETURNS_MOCKS), - mock(MetricRegistry.class), - mock(TaskManagerMetricGroup.class), - mock(BroadcastVariableManager.class), - mock(FileCache.class), - mock(TaskSlotTable.class), - mock(JobManagerTable.class), - mock(JobLeaderService.class), - testingFatalErrorHandler); + // register the mock resource manager gateways + ResourceManagerGateway rmGateway1 = mock(ResourceManagerGateway.class); + rpc.registerGateway(address1, rmGateway1); + + TestingLeaderRetrievalService testLeaderService = new TestingLeaderRetrievalService( + address1, + HighAvailabilityServices.DEFAULT_LEADER_ID); + + TestingHighAvailabilityServices haServices = new TestingHighAvailabilityServices(); + haServices.setResourceManagerLeaderRetriever(testLeaderService); + + TaskManagerConfiguration taskManagerServicesConfiguration = mock(TaskManagerConfiguration.class); + when(taskManagerServicesConfiguration.getNumberSlots()).thenReturn(1); + + TaskManagerLocation taskManagerLocation = mock(TaskManagerLocation.class); + when(taskManagerLocation.getResourceID()).thenReturn(resourceID); + final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); + final TaskSlotTable taskSlotTable = mock(TaskSlotTable.class); + when(taskSlotTable.createSlotReport(any(ResourceID.class))).thenReturn(new SlotReport()); + when(taskSlotTable.getCurrentAllocation(1)).thenReturn(new AllocationID()); + + when(rmGateway1.registerTaskExecutor(anyString(), eq(resourceID), any(SlotReport.class), any(Time.class))).thenReturn( + CompletableFuture.completedFuture(new TaskExecutorRegistrationSuccess(new InstanceID(), ResourceID.generate(), 1000L))); + + TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerServicesConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServices, + mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + mock(JobManagerTable.class), + mock(JobLeaderService.class), + testingFatalErrorHandler); + + try { taskManager.start(); - String taskManagerAddress = taskManager.getAddress(); + + final TaskExecutorGateway tmGateway = taskManager.getSelfGateway(TaskExecutorGateway.class); + + String taskManagerAddress = tmGateway.getAddress(); // no connection initially, since there is no leader assertNull(taskManager.getResourceManagerConnection()); // define a leader and see that a registration happens - testLeaderService.notifyListener(address1, leaderId); + testLeaderService.notifyListener(address1, resourceManagerId.toUUID()); - verify(rmGateway1).registerTaskExecutor( - eq(leaderId), eq(taskManagerAddress), eq(resourceID), any(SlotReport.class), any(Time.class)); + verify(rmGateway1, Mockito.timeout(timeout.toMilliseconds())).registerTaskExecutor( + eq(taskManagerAddress), eq(resourceID), any(SlotReport.class), any(Time.class)); assertNotNull(taskManager.getResourceManagerConnection()); // test that allocating a slot works final SlotID slotID = new SlotID(resourceID, 0); - taskManager.requestSlot(slotID, jobId, new AllocationID(), jobManagerAddress, leaderId, timeout); + tmGateway.requestSlot(slotID, jobId, new AllocationID(), jobManagerAddress, 
resourceManagerId, timeout); // TODO: Figure out the concrete allocation behaviour between RM and TM. Maybe we don't need the SlotID... // test that we can't allocate slots which are blacklisted due to pending confirmation of the RM final SlotID unconfirmedFreeSlotID = new SlotID(resourceID, 1); - CompletableFuture requestSlotFuture = taskManager.requestSlot( + CompletableFuture requestSlotFuture = tmGateway.requestSlot( unconfirmedFreeSlotID, jobId, new AllocationID(), jobManagerAddress, - leaderId, + resourceManagerId, timeout); try { @@ -1069,29 +1110,30 @@ public void testRejectAllocationRequestsForOutOfSyncSlots() throws Exception { fail("The slot request should have failed."); } catch (Exception e) { - assertTrue(ExceptionUtils.containsThrowable(e, SlotAllocationException.class)); + assertTrue(ExceptionUtils.findThrowable(e, SlotAllocationException.class).isPresent()); } // re-register - verify(rmGateway1).registerTaskExecutor( - eq(leaderId), eq(taskManagerAddress), eq(resourceID), any(SlotReport.class), any(Time.class)); - testLeaderService.notifyListener(address1, leaderId); + verify(rmGateway1, Mockito.timeout(timeout.toMilliseconds())).registerTaskExecutor( + eq(taskManagerAddress), eq(resourceID), any(SlotReport.class), any(Time.class)); + testLeaderService.notifyListener(address1, resourceManagerId.toUUID()); // now we should be successful because the slots status has been synced // test that we can't allocate slots which are blacklisted due to pending confirmation of the RM - taskManager.requestSlot( + tmGateway.requestSlot( unconfirmedFreeSlotID, jobId, new AllocationID(), jobManagerAddress, - leaderId, + resourceManagerId, timeout); // check if a concurrent error occurred testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } } @@ -1103,7 +1145,6 @@ public void testRejectAllocationRequestsForOutOfSyncSlots() throws Exception { public void testSubmitTaskBeforeAcceptSlot() throws Exception { final JobID jobId = new JobID(); - final TestingSerialRpcService rpc = new TestingSerialRpcService(); final Configuration configuration = new Configuration(); final TaskManagerConfiguration taskManagerConfiguration = TaskManagerConfiguration.fromConfiguration(configuration); final ResourceID resourceId = new ResourceID("foobar"); @@ -1120,10 +1161,10 @@ public void testSubmitTaskBeforeAcceptSlot() throws Exception { final ResourceID resourceManagerResourceId = new ResourceID(resourceManagerAddress); final String jobManagerAddress = "jm"; - final UUID jobManagerLeaderId = UUID.randomUUID(); + final JobMasterId jobMasterId = JobMasterId.generate(); final LeaderRetrievalService resourceManagerLeaderRetrievalService = new TestingLeaderRetrievalService(resourceManagerAddress, resourceManagerLeaderId); - final LeaderRetrievalService jobManagerLeaderRetrievalService = new TestingLeaderRetrievalService(jobManagerAddress, jobManagerLeaderId); + final LeaderRetrievalService jobManagerLeaderRetrievalService = new TestingLeaderRetrievalService(jobManagerAddress, jobMasterId.toUUID()); haServices.setResourceManagerLeaderRetriever(resourceManagerLeaderRetrievalService); haServices.setJobMasterLeaderRetriever(jobId, jobManagerLeaderRetrievalService); @@ -1131,7 +1172,6 @@ public void testSubmitTaskBeforeAcceptSlot() throws Exception { final InstanceID registrationId = new InstanceID(); when(resourceManagerGateway.registerTaskExecutor( - eq(resourceManagerLeaderId), 
any(String.class), eq(resourceId), any(SlotReport.class), @@ -1151,10 +1191,10 @@ public void testSubmitTaskBeforeAcceptSlot() throws Exception { when(jobMasterGateway.registerTaskManager( any(String.class), eq(taskManagerLocation), - eq(jobManagerLeaderId), any(Time.class) )).thenReturn(CompletableFuture.completedFuture(new JMTMRegistrationSuccess(jmResourceId, blobPort))); when(jobMasterGateway.getHostname()).thenReturn(jobManagerAddress); + when(jobMasterGateway.updateTaskExecutionState(any(TaskExecutionState.class))).thenReturn(CompletableFuture.completedFuture(Acknowledge.get())); rpc.registerGateway(resourceManagerAddress, resourceManagerGateway); @@ -1167,43 +1207,47 @@ public void testSubmitTaskBeforeAcceptSlot() throws Exception { jobId, jmResourceId, jobMasterGateway, - jobManagerLeaderId, mock(TaskManagerActions.class), mock(CheckpointResponder.class), + mock(BlobCache.class), libraryCacheManager, mock(ResultPartitionConsumableNotifier.class), mock(PartitionProducerStateChecker.class)); - jobManagerTable.put(jobId, jobManagerConnection); + final TaskManagerMetricGroup taskManagerMetricGroup = mock(TaskManagerMetricGroup.class); + TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); + when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); + + when(taskManagerMetricGroup.addTaskForJob( + any(JobID.class), anyString(), any(JobVertexID.class), any(ExecutionAttemptID.class), + anyString(), anyInt(), anyInt()) + ).thenReturn(taskMetricGroup); + + final NetworkEnvironment networkMock = mock(NetworkEnvironment.class, Mockito.RETURNS_MOCKS); + + final TaskExecutor taskManager = new TaskExecutor( + rpc, + taskManagerConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + networkMock, + haServices, + mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(MetricRegistry.class), + taskManagerMetricGroup, + mock(BroadcastVariableManager.class), + mock(FileCache.class), + taskSlotTable, + jobManagerTable, + jobLeaderService, + testingFatalErrorHandler); try { - final TaskManagerMetricGroup taskManagerMetricGroup = mock(TaskManagerMetricGroup.class); - TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); - when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); + taskManager.start(); - when(taskManagerMetricGroup.addTaskForJob( - any(JobID.class), anyString(), any(JobVertexID.class), any(ExecutionAttemptID.class), - anyString(), anyInt(), anyInt()) - ).thenReturn(taskMetricGroup); + final TaskExecutorGateway tmGateway = taskManager.getSelfGateway(TaskExecutorGateway.class); - final TaskExecutor taskManager = new TaskExecutor( - rpc, - taskManagerConfiguration, - taskManagerLocation, - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - haServices, - mock(HeartbeatServices.class, RETURNS_MOCKS), - mock(MetricRegistry.class), - taskManagerMetricGroup, - mock(BroadcastVariableManager.class), - mock(FileCache.class), - taskSlotTable, - jobManagerTable, - jobLeaderService, - testingFatalErrorHandler); - taskManager.start(); taskSlotTable.allocateSlot(0, jobId, allocationId1, Time.milliseconds(10000L)); taskSlotTable.allocateSlot(1, jobId, allocationId2, Time.milliseconds(10000L)); @@ -1247,7 +1291,6 @@ public void testSubmitTaskBeforeAcceptSlot() throws Exception { jobMasterGateway.offerSlots( any(ResourceID.class), any(Iterable.class), - eq(jobManagerLeaderId), any(Time.class))) .thenReturn(offerResultFuture); @@ -1255,16 +1298,15 @@ public void 
testSubmitTaskBeforeAcceptSlot() throws Exception { // been properly started. This will also offer the slots to the job master jobLeaderService.addJob(jobId, jobManagerAddress); - verify(jobMasterGateway).offerSlots(any(ResourceID.class), any(Iterable.class), eq(jobManagerLeaderId), any(Time.class)); + verify(jobMasterGateway, Mockito.timeout(timeout.toMilliseconds())).offerSlots(any(ResourceID.class), any(Iterable.class), any(Time.class)); // submit the task without having acknowledge the offered slots - taskManager.submitTask(tdd, jobManagerLeaderId, timeout); + tmGateway.submitTask(tdd, jobMasterId, timeout); // acknowledge the offered slots offerResultFuture.complete(Collections.singleton(offer1)); - verify(resourceManagerGateway).notifySlotAvailable( - eq(resourceManagerLeaderId), + verify(resourceManagerGateway, Mockito.timeout(timeout.toMilliseconds())).notifySlotAvailable( eq(registrationId), eq(new SlotID(resourceId, 1)), any(AllocationID.class)); @@ -1276,8 +1318,79 @@ public void testSubmitTaskBeforeAcceptSlot() throws Exception { // check if a concurrent error occurred testingFatalErrorHandler.rethrowError(); } finally { - rpc.stopService(); + taskManager.shutDown(); + taskManager.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); } + } + /** + * This tests makes sure that duplicate JobMaster gained leadership messages are filtered out + * by the TaskExecutor. + * + * See FLINK-7526 + */ + @Test + public void testFilterOutDuplicateJobMasterRegistrations() throws Exception { + final long verificationTimeout = 500L; + final Configuration configuration = new Configuration(); + final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); + final JobLeaderService jobLeaderService = mock(JobLeaderService.class); + final TaskManagerConfiguration taskManagerConfiguration = TaskManagerConfiguration.fromConfiguration(configuration); + final TaskManagerLocation taskManagerLocation = new TaskManagerLocation(ResourceID.generate(), InetAddress.getLocalHost(), 1234); + + final HighAvailabilityServices haServicesMock = mock(HighAvailabilityServices.class, Mockito.RETURNS_MOCKS); + final HeartbeatServices heartbeatServicesMock = mock(HeartbeatServices.class, Mockito.RETURNS_MOCKS); + + final JobID jobId = new JobID(); + final JobMasterGateway jobMasterGateway = mock(JobMasterGateway.class); + when(jobMasterGateway.getHostname()).thenReturn("localhost"); + final JMTMRegistrationSuccess registrationMessage = new JMTMRegistrationSuccess(ResourceID.generate(), 1); + final JobManagerTable jobManagerTableMock = spy(new JobManagerTable()); + + final TaskExecutor taskExecutor = new TaskExecutor( + rpc, + taskManagerConfiguration, + taskManagerLocation, + mock(MemoryManager.class), + mock(IOManager.class), + mock(NetworkEnvironment.class), + haServicesMock, + heartbeatServicesMock, + mock(MetricRegistry.class), + mock(TaskManagerMetricGroup.class), + mock(BroadcastVariableManager.class), + mock(FileCache.class), + mock(TaskSlotTable.class), + jobManagerTableMock, + jobLeaderService, + testingFatalErrorHandler); + + try { + taskExecutor.start(); + + ArgumentCaptor jobLeaderListenerArgumentCaptor = ArgumentCaptor.forClass(JobLeaderListener.class); + + verify(jobLeaderService).start(anyString(), any(RpcService.class), any(HighAvailabilityServices.class), jobLeaderListenerArgumentCaptor.capture()); + + JobLeaderListener taskExecutorListener = jobLeaderListenerArgumentCaptor.getValue(); + + taskExecutorListener.jobManagerGainedLeadership(jobId, 
jobMasterGateway, registrationMessage); + + // duplicate job manager gained leadership message + taskExecutorListener.jobManagerGainedLeadership(jobId, jobMasterGateway, registrationMessage); + + ArgumentCaptor jobManagerConnectionArgumentCaptor = ArgumentCaptor.forClass(JobManagerConnection.class); + + verify(jobManagerTableMock, Mockito.timeout(verificationTimeout).times(1)).put(eq(jobId), jobManagerConnectionArgumentCaptor.capture()); + + JobManagerConnection jobManagerConnection = jobManagerConnectionArgumentCaptor.getValue(); + + assertEquals(jobMasterGateway, jobManagerConnection.getJobManagerGateway()); + + testingFatalErrorHandler.rethrowError(); + } finally { + taskExecutor.shutDown(); + taskExecutor.getTerminationFuture().get(timeout.toMilliseconds(), TimeUnit.MILLISECONDS); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index c6d2fec2f0daf..392dc29bf3ca8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -22,11 +22,13 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -49,7 +51,6 @@ import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.util.SerializedValue; @@ -145,6 +146,7 @@ public void testMixedAsyncCallsInOrder() { } private static Task createTask() throws Exception { + BlobCache blobCache = mock(BlobCache.class); LibraryCacheManager libCache = mock(LibraryCacheManager.class); when(libCache.getClassLoader(any(JobID.class))).thenReturn(ClassLoader.getSystemClassLoader()); @@ -187,7 +189,7 @@ private static Task createTask() throws Exception { Collections.emptyList(), Collections.emptyList(), 0, - new TaskStateHandles(), + new TaskStateSnapshot(), mock(MemoryManager.class), mock(IOManager.class), networkEnvironment, @@ -195,6 +197,7 @@ private static Task createTask() throws Exception { mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), + blobCache, libCache, mock(FileCache.class), new TestingTaskManagerRuntimeInfo(), @@ -228,7 +231,7 @@ public void invoke() throws Exception { } @Override - public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {} + public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception {} @Override public boolean triggerCheckpoint(CheckpointMetaData 
checkpointMetaData, CheckpointOptions checkpointOptions) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java index 40678de125424..ac0df3663f4dc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java @@ -20,39 +20,42 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.TaskInfo; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.blob.BlobCache; +import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; -import org.apache.flink.runtime.executiongraph.JobInformation; -import org.apache.flink.runtime.executiongraph.TaskInformation; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; -import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; -import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; -import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; -import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.JobInformation; +import org.apache.flink.runtime.executiongraph.TaskInformation; import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; +import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobgraph.tasks.StoppableTask; import org.apache.flink.runtime.memory.MemoryManager; -import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; +import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; + import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; -import scala.concurrent.duration.FiniteDuration; import java.lang.reflect.Field; import java.util.Collections; import java.util.concurrent.Executor; +import scala.concurrent.duration.FiniteDuration; + import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -88,7 +91,7 @@ public void doMocking(AbstractInvokable taskMock) throws Exception { Collections.emptyList(), Collections.emptyList(), 0, - mock(TaskStateHandles.class), + mock(TaskStateSnapshot.class), 
mock(MemoryManager.class), mock(IOManager.class), mock(NetworkEnvironment.class), @@ -96,6 +99,7 @@ public void doMocking(AbstractInvokable taskMock) throws Exception { mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), + mock(BlobCache.class), mock(LibraryCacheManager.class), mock(FileCache.class), tmRuntimeInfo, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index ba3e8201f0ce2..d4cd0cfcf64d5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.clusterframework.types.AllocationID; @@ -227,7 +228,8 @@ public void testFailExternallyRightAway() { @Test public void testLibraryCacheRegistrationFailed() { try { - Task task = createTask(TestInvokableCorrect.class, mock(LibraryCacheManager.class)); + Task task = createTask(TestInvokableCorrect.class, mock(BlobCache.class), + mock(LibraryCacheManager.class)); // task should be new and perfect assertEquals(ExecutionState.CREATED, task.getExecutionState()); @@ -260,6 +262,7 @@ public void testLibraryCacheRegistrationFailed() { @Test public void testExecutionFailsInNetworkRegistration() { try { + BlobCache blobCache = mock(BlobCache.class); // mock a working library cache LibraryCacheManager libCache = mock(LibraryCacheManager.class); when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); @@ -274,7 +277,7 @@ public void testExecutionFailsInNetworkRegistration() { when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class)); - Task task = createTask(TestInvokableCorrect.class, libCache, network, consumableNotifier, partitionProducerStateChecker, executor); + Task task = createTask(TestInvokableCorrect.class, blobCache, libCache, network, consumableNotifier, partitionProducerStateChecker, executor); task.registerExecutionListener(listener); @@ -617,6 +620,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { IntermediateDataSetID resultId = new IntermediateDataSetID(); ResultPartitionID partitionId = new ResultPartitionID(); + BlobCache blobCache = mock(BlobCache.class); LibraryCacheManager libCache = mock(LibraryCacheManager.class); when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); @@ -629,7 +633,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); - createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); + createTask(InvokableBlockingInInvoke.class, blobCache, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); // Test all branches of trigger partition state check @@ -638,7 +642,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { 
createQueuesAndActors(); // PartitionProducerDisposedException - Task task = createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); + Task task = createTask(InvokableBlockingInInvoke.class, blobCache, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); @@ -654,7 +658,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { createQueuesAndActors(); // Any other exception - Task task = createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); + Task task = createTask(InvokableBlockingInInvoke.class, blobCache, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); @@ -671,7 +675,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { createQueuesAndActors(); // TimeoutException handled special => retry - Task task = createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); + Task task = createTask(InvokableBlockingInInvoke.class, blobCache, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); SingleInputGate inputGate = mock(SingleInputGate.class); when(inputGate.getConsumedResultId()).thenReturn(resultId); @@ -702,7 +706,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { createQueuesAndActors(); // Success - Task task = createTask(InvokableBlockingInInvoke.class, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); + Task task = createTask(InvokableBlockingInInvoke.class, blobCache, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); SingleInputGate inputGate = mock(SingleInputGate.class); when(inputGate.getConsumedResultId()).thenReturn(resultId); @@ -882,26 +886,30 @@ private Task createTask(Class invokable) throws IOE } private Task createTask(Class invokable, Configuration config) throws IOException { + BlobCache blobCache = mock(BlobCache.class); LibraryCacheManager libCache = mock(LibraryCacheManager.class); when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - return createTask(invokable, libCache, config, new ExecutionConfig()); + return createTask(invokable, blobCache,libCache, config, new ExecutionConfig()); } private Task createTask(Class invokable, Configuration config, ExecutionConfig execConfig) throws IOException { + BlobCache blobCache = mock(BlobCache.class); LibraryCacheManager libCache = mock(LibraryCacheManager.class); when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - return createTask(invokable, libCache, config, execConfig); + return createTask(invokable, blobCache,libCache, config, execConfig); } private Task createTask( Class invokable, + BlobCache blobCache, LibraryCacheManager libCache) throws IOException { - return createTask(invokable, libCache, new Configuration(), new ExecutionConfig()); + return createTask(invokable, blobCache,libCache, new Configuration(), new ExecutionConfig()); } private Task createTask( 
Class invokable, + BlobCache blobCache, LibraryCacheManager libCache, Configuration config, ExecutionConfig execConfig) throws IOException { @@ -916,21 +924,23 @@ private Task createTask( when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); - return createTask(invokable, libCache, network, consumableNotifier, partitionProducerStateChecker, executor, config, execConfig); + return createTask(invokable, blobCache, libCache, network, consumableNotifier, partitionProducerStateChecker, executor, config, execConfig); } private Task createTask( Class invokable, + BlobCache blobCache, LibraryCacheManager libCache, NetworkEnvironment networkEnvironment, ResultPartitionConsumableNotifier consumableNotifier, PartitionProducerStateChecker partitionProducerStateChecker, Executor executor) throws IOException { - return createTask(invokable, libCache, networkEnvironment, consumableNotifier, partitionProducerStateChecker, executor, new Configuration(), new ExecutionConfig()); + return createTask(invokable, blobCache, libCache, networkEnvironment, consumableNotifier, partitionProducerStateChecker, executor, new Configuration(), new ExecutionConfig()); } private Task createTask( Class invokable, + BlobCache blobCache, LibraryCacheManager libCache, NetworkEnvironment networkEnvironment, ResultPartitionConsumableNotifier consumableNotifier, @@ -991,6 +1001,7 @@ private Task createTask( taskManagerConnection, inputSplitProvider, checkpointResponder, + blobCache, libCache, mock(FileCache.class), new TestingTaskManagerRuntimeInfo(taskManagerConfig), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testtasks/FailingBlockingInvokable.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testtasks/FailingBlockingInvokable.java new file mode 100644 index 0000000000000..37c141d6b9f02 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testtasks/FailingBlockingInvokable.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.testtasks; + +import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; + +/** + * Task which blocks until the (static) {@link #unblock()} method is called and then fails with an + * exception. 
+ */ +public class FailingBlockingInvokable extends AbstractInvokable { + private static volatile boolean blocking = true; + private static final Object lock = new Object(); + + @Override + public void invoke() throws Exception { + while (blocking) { + synchronized (lock) { + lock.wait(); + } + } + throw new RuntimeException("This exception is expected."); + } + + public static void unblock() { + blocking = false; + + synchronized (lock) { + lock.notifyAll(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java index a0c441247ae56..037ecd17cf268 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java @@ -21,7 +21,8 @@ import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; import org.apache.flink.runtime.jobgraph.JobStatus; -import org.apache.flink.runtime.state.SharedStateRegistry; +import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,14 +42,21 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS private final ArrayDeque suspended = new ArrayDeque<>(2); + private final int maxRetainedCheckpoints; + + public RecoverableCompletedCheckpointStore() { + this(1); + } + + public RecoverableCompletedCheckpointStore(int maxRetainedCheckpoints) { + Preconditions.checkArgument(maxRetainedCheckpoints > 0); + this.maxRetainedCheckpoints = maxRetainedCheckpoints; + } + @Override - public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { + public void recover() throws Exception { checkpoints.addAll(suspended); suspended.clear(); - - for (CompletedCheckpoint checkpoint : checkpoints) { - checkpoint.registerSharedStatesAfterRestored(sharedStateRegistry); - } } @Override @@ -56,13 +64,16 @@ public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { checkpoints.addLast(checkpoint); - - if (checkpoints.size() > 1) { - CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); - checkpointToSubsume.discardOnSubsume(); + if (checkpoints.size() > maxRetainedCheckpoints) { + removeOldestCheckpoint(); } } + public void removeOldestCheckpoint() throws Exception { + CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); + checkpointToSubsume.discardOnSubsume(); + } + @Override public CompletedCheckpoint getLatestCheckpoint() throws Exception { return checkpoints.isEmpty() ? 
null : checkpoints.getLast(); @@ -96,7 +107,7 @@ public int getNumberOfRetainedCheckpoints() { @Override public int getMaxNumberOfRetainedCheckpoints() { - return 1; + return maxRetainedCheckpoints; } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java index f262bf2c0b631..229f1eb08b859 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java @@ -24,10 +24,11 @@ import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.core.io.InputSplit; import org.apache.flink.core.testutils.CommonTestUtils; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -70,7 +71,8 @@ import java.util.concurrent.Executors; import static org.junit.Assume.assumeTrue; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Test that verifies the behavior of blocking shutdown hooks and of the @@ -177,6 +179,7 @@ public static void main(String[] args) throws Exception { new NoOpTaskManagerActions(), new NoOpInputSplitProvider(), new NoOpCheckpointResponder(), + mock(BlobCache.class), new FallbackLibraryCacheManager(), new FileCache(tmInfo.getTmpDirectories()), tmInfo, @@ -232,7 +235,7 @@ public InputSplit getNextInputSplit(ClassLoader userCodeClassLoader) { private static final class NoOpCheckpointResponder implements CheckpointResponder { @Override - public void acknowledgeCheckpoint(JobID j, ExecutionAttemptID e, long i, CheckpointMetrics c, SubtaskState s) {} + public void acknowledgeCheckpoint(JobID j, ExecutionAttemptID e, long i, CheckpointMetrics c, TaskStateSnapshot s) {} @Override public void declineCheckpoint(JobID j, ExecutionAttemptID e, long l, Throwable t) {} diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/JobManagerRegistrationTest.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/JobManagerRegistrationTest.scala index 1b9ee48646c73..95da9814bdbaa 100644 --- a/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/JobManagerRegistrationTest.scala +++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/JobManagerRegistrationTest.scala @@ -264,14 +264,15 @@ ImplicitSender with WordSpecLike with Matchers with BeforeAndAfterAll with Befor components._1, components._2, components._3, - ActorRef.noSender, components._4, + ActorRef.noSender, components._5, + components._6, highAvailabilityServices.getJobManagerLeaderElectionService( HighAvailabilityServices.DEFAULT_JOB_ID), highAvailabilityServices.getSubmittedJobGraphStore(), highAvailabilityServices.getCheckpointRecoveryFactory(), - components._8, + components._9, None) _system.actorOf(props) diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala 
b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala index e5655bb5f5ab4..87f80882730d1 100644 --- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala +++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala @@ -28,6 +28,7 @@ import akka.testkit.CallingThreadDispatcher import org.apache.flink.api.common.JobID import org.apache.flink.configuration.{Configuration, JobManagerOptions} import org.apache.flink.runtime.akka.AkkaUtils +import org.apache.flink.runtime.blob.BlobServer import org.apache.flink.runtime.checkpoint.savepoint.Savepoint import org.apache.flink.runtime.checkpoint.{CheckpointOptions, CheckpointRecoveryFactory} import org.apache.flink.runtime.clusterframework.FlinkResourceManager @@ -110,6 +111,7 @@ class TestingCluster( ioExecutor: Executor, instanceManager: InstanceManager, scheduler: Scheduler, + blobServer: BlobServer, libraryCacheManager: BlobLibraryCacheManager, archive: ActorRef, restartStrategyFactory: RestartStrategyFactory, @@ -127,6 +129,7 @@ class TestingCluster( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala index f50a832f6b92a..8b9ce15b142a6 100644 --- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala +++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{Executor, ScheduledExecutorService} import akka.actor.ActorRef import org.apache.flink.configuration.Configuration +import org.apache.flink.runtime.blob.BlobServer import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager import org.apache.flink.runtime.executiongraph.restart.RestartStrategyFactory @@ -34,15 +35,16 @@ import org.apache.flink.runtime.metrics.MetricRegistry import scala.concurrent.duration._ import scala.language.postfixOps -/** JobManager implementation extended by testing messages - * - */ +/** + * JobManager implementation extended by testing messages + */ class TestingJobManager( flinkConfiguration: Configuration, futureExecutor: ScheduledExecutorService, ioExecutor: Executor, instanceManager: InstanceManager, scheduler: Scheduler, + blobServer: BlobServer, libraryCacheManager: BlobLibraryCacheManager, archive: ActorRef, restartStrategyFactory: RestartStrategyFactory, @@ -58,6 +60,7 @@ class TestingJobManager( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingUtils.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingUtils.scala index ddbb82dc16630..02f83fd7fdd29 100644 --- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingUtils.scala +++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingUtils.scala @@ -24,7 +24,6 @@ import java.util.concurrent._ import akka.actor.{ActorRef, ActorSystem, Kill, Props} import akka.pattern.{Patterns, ask} -import com.google.common.util.concurrent.MoreExecutors import com.typesafe.config.ConfigFactory import grizzled.slf4j.Logger import 
org.apache.flink.api.common.time.Time @@ -130,7 +129,8 @@ object TestingUtils { * * @return Direct [[ExecutionContext]] which executes runnables directly */ - def directExecutionContext = ExecutionContext.fromExecutor(MoreExecutors.directExecutor()) + def directExecutionContext = ExecutionContext + .fromExecutor(org.apache.flink.runtime.concurrent.Executors.directExecutor()) /** @return A new [[QueuedActionExecutionContext]] */ def queuedActionExecutionContext = { diff --git a/flink-shaded-curator/flink-shaded-curator-recipes/pom.xml b/flink-shaded-curator/flink-shaded-curator-recipes/pom.xml index b539f96a5018c..61897d1ca12fb 100644 --- a/flink-shaded-curator/flink-shaded-curator-recipes/pom.xml +++ b/flink-shaded-curator/flink-shaded-curator-recipes/pom.xml @@ -41,13 +41,6 @@ under the License. curator-recipes ${curator.version} - - - - com.google.guava - guava - ${guava.version} - @@ -69,6 +62,16 @@ under the License. org.apache.curator:* + + + com.google + org.apache.flink.curator.shaded.com.google + + com.google.protobuf.** + com.google.inject.** + + + diff --git a/flink-shaded-curator/flink-shaded-curator-test/pom.xml b/flink-shaded-curator/flink-shaded-curator-test/pom.xml index 751b590efe0c8..2a181621e1628 100644 --- a/flink-shaded-curator/flink-shaded-curator-test/pom.xml +++ b/flink-shaded-curator/flink-shaded-curator-test/pom.xml @@ -69,6 +69,7 @@ org.apache.curator:curator-test + com.google.guava:guava @@ -76,6 +77,14 @@ org.apache.curator org.apache.flink.shaded.org.apache.curator + + com.google + org.apache.flink.curator.shaded.com.google + + com.google.protobuf.** + com.google.inject.** + + diff --git a/flink-streaming-java/pom.xml b/flink-streaming-java/pom.xml index aefee5d3d425b..2683546274e2d 100644 --- a/flink-streaming-java/pom.xml +++ b/flink-streaming-java/pom.xml @@ -56,6 +56,11 @@ under the License. ${project.version} + + org.apache.flink + flink-shaded-guava + + org.apache.commons commons-math3 @@ -68,12 +73,6 @@ under the License. 2.0.6 - - com.google.guava - guava - ${guava.version} - - diff --git a/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/api/graph/StreamGraphHasherV1.java b/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/api/graph/StreamGraphHasherV1.java deleted file mode 100644 index cfaa4b12185fd..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/api/graph/StreamGraphHasherV1.java +++ /dev/null @@ -1,281 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.migration.streaming.api.graph; - -import org.apache.flink.streaming.api.graph.StreamEdge; -import org.apache.flink.streaming.api.graph.StreamGraph; -import org.apache.flink.streaming.api.graph.StreamGraphHasher; -import org.apache.flink.streaming.api.graph.StreamNode; -import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; -import org.apache.flink.streaming.api.operators.ChainingStrategy; -import org.apache.flink.streaming.api.operators.StreamOperator; -import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; - -import com.google.common.hash.HashFunction; -import com.google.common.hash.Hasher; -import com.google.common.hash.Hashing; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.nio.charset.Charset; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Queue; -import java.util.Set; - -import static org.apache.flink.util.StringUtils.byteToHexString; - -/** - * StreamGraphHasher from Flink 1.1. This contains duplicated code to ensure that the algorithm does not change with - * future Flink versions. - * - *
DO NOT MODIFY THIS CLASS - */ -public class StreamGraphHasherV1 implements StreamGraphHasher { - - private static final Logger LOG = LoggerFactory.getLogger(StreamGraphHasherV1.class); - - @Override - public Map traverseStreamGraphAndGenerateHashes(StreamGraph streamGraph) { - // The hash function used to generate the hash - final HashFunction hashFunction = Hashing.murmur3_128(0); - final Map hashes = new HashMap<>(); - - Set visited = new HashSet<>(); - Queue remaining = new ArrayDeque<>(); - - // We need to make the source order deterministic. The source IDs are - // not returned in the same order, which means that submitting the same - // program twice might result in different traversal, which breaks the - // deterministic hash assignment. - List sources = new ArrayList<>(); - for (Integer sourceNodeId : streamGraph.getSourceIDs()) { - sources.add(sourceNodeId); - } - Collections.sort(sources); - - // - // Traverse the graph in a breadth-first manner. Keep in mind that - // the graph is not a tree and multiple paths to nodes can exist. - // - - // Start with source nodes - for (Integer sourceNodeId : sources) { - remaining.add(streamGraph.getStreamNode(sourceNodeId)); - visited.add(sourceNodeId); - } - - StreamNode currentNode; - while ((currentNode = remaining.poll()) != null) { - // Generate the hash code. Because multiple path exist to each - // node, we might not have all required inputs available to - // generate the hash code. - if (generateNodeHash(currentNode, hashFunction, hashes, streamGraph.isChainingEnabled())) { - // Add the child nodes - for (StreamEdge outEdge : currentNode.getOutEdges()) { - StreamNode child = outEdge.getTargetVertex(); - - if (!visited.contains(child.getId())) { - remaining.add(child); - visited.add(child.getId()); - } - } - } else { - // We will revisit this later. - visited.remove(currentNode.getId()); - } - } - - return hashes; - } - - /** - * Generates a hash for the node and returns whether the operation was - * successful. - * - * @param node The node to generate the hash for - * @param hashFunction The hash function to use - * @param hashes The current state of generated hashes - * @return true if the node hash has been generated. - * false, otherwise. If the operation is not successful, the - * hash needs be generated at a later point when all input is available. - * @throws IllegalStateException If node has user-specified hash and is - * intermediate node of a chain - */ - private boolean generateNodeHash( - StreamNode node, - HashFunction hashFunction, - Map hashes, - boolean isChainingEnabled) { - - // Check for user-specified ID - String userSpecifiedHash = node.getTransformationUID(); - - if (userSpecifiedHash == null) { - // Check that all input nodes have their hashes computed - for (StreamEdge inEdge : node.getInEdges()) { - // If the input node has not been visited yet, the current - // node will be visited again at a later point when all input - // nodes have been visited and their hashes set. - if (!hashes.containsKey(inEdge.getSourceId())) { - return false; - } - } - - Hasher hasher = hashFunction.newHasher(); - byte[] hash = generateDeterministicHash(node, hasher, hashes, isChainingEnabled); - - if (hashes.put(node.getId(), hash) != null) { - // Sanity check - throw new IllegalStateException("Unexpected state. Tried to add node hash " + - "twice. 
This is probably a bug in the JobGraph generator."); - } - - return true; - } else { - Hasher hasher = hashFunction.newHasher(); - byte[] hash = generateUserSpecifiedHash(node, hasher); - - for (byte[] previousHash : hashes.values()) { - if (Arrays.equals(previousHash, hash)) { - throw new IllegalArgumentException("Hash collision on user-specified ID. " + - "Most likely cause is a non-unique ID. Please check that all IDs " + - "specified via `uid(String)` are unique."); - } - } - - if (hashes.put(node.getId(), hash) != null) { - // Sanity check - throw new IllegalStateException("Unexpected state. Tried to add node hash " + - "twice. This is probably a bug in the JobGraph generator."); - } - - return true; - } - } - - /** - * Generates a hash from a user-specified ID. - */ - private byte[] generateUserSpecifiedHash(StreamNode node, Hasher hasher) { - hasher.putString(node.getTransformationUID(), Charset.forName("UTF-8")); - - return hasher.hash().asBytes(); - } - - /** - * Generates a deterministic hash from node-local properties and input and - * output edges. - */ - private byte[] generateDeterministicHash( - StreamNode node, - Hasher hasher, - Map hashes, - boolean isChainingEnabled) { - - // Include stream node to hash. We use the current size of the computed - // hashes as the ID. We cannot use the node's ID, because it is - // assigned from a static counter. This will result in two identical - // programs having different hashes. - generateNodeLocalHash(node, hasher, hashes.size()); - - // Include chained nodes to hash - for (StreamEdge outEdge : node.getOutEdges()) { - if (isChainable(outEdge, isChainingEnabled)) { - StreamNode chainedNode = outEdge.getTargetVertex(); - - // Use the hash size again, because the nodes are chained to - // this node. This does not add a hash for the chained nodes. - generateNodeLocalHash(chainedNode, hasher, hashes.size()); - } - } - - byte[] hash = hasher.hash().asBytes(); - - // Make sure that all input nodes have their hash set before entering - // this loop (calling this method). - for (StreamEdge inEdge : node.getInEdges()) { - byte[] otherHash = hashes.get(inEdge.getSourceId()); - - // Sanity check - if (otherHash == null) { - throw new IllegalStateException("Missing hash for input node " - + inEdge.getSourceVertex() + ". 
Cannot generate hash for " - + node + "."); - } - - for (int j = 0; j < hash.length; j++) { - hash[j] = (byte) (hash[j] * 37 ^ otherHash[j]); - } - } - - if (LOG.isDebugEnabled()) { - String udfClassName = ""; - if (node.getOperator() instanceof AbstractUdfStreamOperator) { - udfClassName = ((AbstractUdfStreamOperator) node.getOperator()) - .getUserFunction().getClass().getName(); - } - - LOG.debug("Generated hash '" + byteToHexString(hash) + "' for node " + - "'" + node.toString() + "' {id: " + node.getId() + ", " + - "parallelism: " + node.getParallelism() + ", " + - "user function: " + udfClassName + "}"); - } - - return hash; - } - - private boolean isChainable(StreamEdge edge, boolean isChainingEnabled) { - StreamNode upStreamVertex = edge.getSourceVertex(); - StreamNode downStreamVertex = edge.getTargetVertex(); - - StreamOperator headOperator = upStreamVertex.getOperator(); - StreamOperator outOperator = downStreamVertex.getOperator(); - - return downStreamVertex.getInEdges().size() == 1 - && outOperator != null - && headOperator != null - && upStreamVertex.isSameSlotSharingGroup(downStreamVertex) - && outOperator.getChainingStrategy() == ChainingStrategy.ALWAYS - && (headOperator.getChainingStrategy() == ChainingStrategy.HEAD || - headOperator.getChainingStrategy() == ChainingStrategy.ALWAYS) - && (edge.getPartitioner() instanceof ForwardPartitioner) - && upStreamVertex.getParallelism() == downStreamVertex.getParallelism() - && isChainingEnabled; - } - - private void generateNodeLocalHash(StreamNode node, Hasher hasher, int id) { - hasher.putInt(id); - - hasher.putInt(node.getParallelism()); - - if (node.getOperator() instanceof AbstractUdfStreamOperator) { - String udfClassName = ((AbstractUdfStreamOperator) node.getOperator()) - .getUserFunction().getClass().getName(); - - hasher.putString(udfClassName, Charset.forName("UTF-8")); - } - } -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/runtime/streamrecord/MultiplexingStreamRecordSerializer.java b/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/runtime/streamrecord/MultiplexingStreamRecordSerializer.java deleted file mode 100644 index b1471b233e4ba..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/runtime/streamrecord/MultiplexingStreamRecordSerializer.java +++ /dev/null @@ -1,293 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.migration.streaming.runtime.streamrecord; - -import org.apache.flink.api.common.typeutils.CompatibilityResult; -import org.apache.flink.api.common.typeutils.CompatibilityUtil; -import org.apache.flink.api.common.typeutils.CompositeTypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.TypeDeserializerAdapter; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.streaming.api.watermark.Watermark; -import org.apache.flink.streaming.runtime.streamrecord.StreamElement; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; - -import java.io.IOException; - -import static java.util.Objects.requireNonNull; - -/** - * Legacy multiplexing {@link TypeSerializer} for stream records, watermarks and other stream - * elements. - */ -public class MultiplexingStreamRecordSerializer extends TypeSerializer { - - - private static final long serialVersionUID = 1L; - - private static final int TAG_REC_WITH_TIMESTAMP = 0; - private static final int TAG_REC_WITHOUT_TIMESTAMP = 1; - private static final int TAG_WATERMARK = 2; - - - private final TypeSerializer typeSerializer; - - public MultiplexingStreamRecordSerializer(TypeSerializer serializer) { - if (serializer instanceof MultiplexingStreamRecordSerializer || serializer instanceof StreamRecordSerializer) { - throw new RuntimeException("StreamRecordSerializer given to StreamRecordSerializer as value TypeSerializer: " + serializer); - } - this.typeSerializer = requireNonNull(serializer); - } - - public TypeSerializer getContainedTypeSerializer() { - return this.typeSerializer; - } - - // ------------------------------------------------------------------------ - // Utilities - // ------------------------------------------------------------------------ - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public MultiplexingStreamRecordSerializer duplicate() { - TypeSerializer copy = typeSerializer.duplicate(); - return (copy == typeSerializer) ? 
this : new MultiplexingStreamRecordSerializer(copy); - } - - // ------------------------------------------------------------------------ - // Utilities - // ------------------------------------------------------------------------ - - @Override - public StreamRecord createInstance() { - return new StreamRecord(typeSerializer.createInstance()); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public StreamElement copy(StreamElement from) { - // we can reuse the timestamp since Instant is immutable - if (from.isRecord()) { - StreamRecord fromRecord = from.asRecord(); - return fromRecord.copy(typeSerializer.copy(fromRecord.getValue())); - } - else if (from.isWatermark()) { - // is immutable - return from; - } - else { - throw new RuntimeException(); - } - } - - @Override - public StreamElement copy(StreamElement from, StreamElement reuse) { - if (from.isRecord() && reuse.isRecord()) { - StreamRecord fromRecord = from.asRecord(); - StreamRecord reuseRecord = reuse.asRecord(); - - T valueCopy = typeSerializer.copy(fromRecord.getValue(), reuseRecord.getValue()); - fromRecord.copyTo(valueCopy, reuseRecord); - return reuse; - } - else if (from.isWatermark()) { - // is immutable - return from; - } - else { - throw new RuntimeException("Cannot copy " + from + " -> " + reuse); - } - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - int tag = source.readByte(); - target.write(tag); - - if (tag == TAG_REC_WITH_TIMESTAMP) { - // move timestamp - target.writeLong(source.readLong()); - typeSerializer.copy(source, target); - } - else if (tag == TAG_REC_WITHOUT_TIMESTAMP) { - typeSerializer.copy(source, target); - } - else if (tag == TAG_WATERMARK) { - target.writeLong(source.readLong()); - } - else { - throw new IOException("Corrupt stream, found tag: " + tag); - } - } - - @Override - public void serialize(StreamElement value, DataOutputView target) throws IOException { - if (value.isRecord()) { - StreamRecord record = value.asRecord(); - - if (record.hasTimestamp()) { - target.write(TAG_REC_WITH_TIMESTAMP); - target.writeLong(record.getTimestamp()); - } else { - target.write(TAG_REC_WITHOUT_TIMESTAMP); - } - typeSerializer.serialize(record.getValue(), target); - } - else if (value.isWatermark()) { - target.write(TAG_WATERMARK); - target.writeLong(value.asWatermark().getTimestamp()); - } - else { - throw new RuntimeException(); - } - } - - @Override - public StreamElement deserialize(DataInputView source) throws IOException { - int tag = source.readByte(); - if (tag == TAG_REC_WITH_TIMESTAMP) { - long timestamp = source.readLong(); - return new StreamRecord(typeSerializer.deserialize(source), timestamp); - } - else if (tag == TAG_REC_WITHOUT_TIMESTAMP) { - return new StreamRecord(typeSerializer.deserialize(source)); - } - else if (tag == TAG_WATERMARK) { - return new Watermark(source.readLong()); - } - else { - throw new IOException("Corrupt stream, found tag: " + tag); - } - } - - @Override - public StreamElement deserialize(StreamElement reuse, DataInputView source) throws IOException { - int tag = source.readByte(); - if (tag == TAG_REC_WITH_TIMESTAMP) { - long timestamp = source.readLong(); - T value = typeSerializer.deserialize(source); - StreamRecord reuseRecord = reuse.asRecord(); - reuseRecord.replace(value, timestamp); - return reuseRecord; - } - else if (tag == TAG_REC_WITHOUT_TIMESTAMP) { - T value = typeSerializer.deserialize(source); - StreamRecord reuseRecord = reuse.asRecord(); - reuseRecord.replace(value); - 
return reuseRecord; - } - else if (tag == TAG_WATERMARK) { - return new Watermark(source.readLong()); - } - else { - throw new IOException("Corrupt stream, found tag: " + tag); - } - } - - // -------------------------------------------------------------------------------------------- - // Serializer configuration snapshotting & compatibility - // -------------------------------------------------------------------------------------------- - - @Override - public MultiplexingStreamRecordSerializerConfigSnapshot snapshotConfiguration() { - return new MultiplexingStreamRecordSerializerConfigSnapshot<>(typeSerializer); - } - - @Override - public CompatibilityResult ensureCompatibility(TypeSerializerConfigSnapshot configSnapshot) { - if (configSnapshot instanceof MultiplexingStreamRecordSerializerConfigSnapshot) { - Tuple2, TypeSerializerConfigSnapshot> previousTypeSerializerAndConfig = - ((MultiplexingStreamRecordSerializerConfigSnapshot) configSnapshot).getSingleNestedSerializerAndConfig(); - - CompatibilityResult compatResult = CompatibilityUtil.resolveCompatibilityResult( - previousTypeSerializerAndConfig.f0, - UnloadableDummyTypeSerializer.class, - previousTypeSerializerAndConfig.f1, - typeSerializer); - - if (!compatResult.isRequiresMigration()) { - return CompatibilityResult.compatible(); - } else if (compatResult.getConvertDeserializer() != null) { - return CompatibilityResult.requiresMigration( - new MultiplexingStreamRecordSerializer<>( - new TypeDeserializerAdapter<>(compatResult.getConvertDeserializer()))); - } - } - - return CompatibilityResult.requiresMigration(); - } - - /** - * Configuration snapshot specific to the {@link MultiplexingStreamRecordSerializer}. - */ - public static final class MultiplexingStreamRecordSerializerConfigSnapshot - extends CompositeTypeSerializerConfigSnapshot { - - private static final int VERSION = 1; - - /** This empty nullary constructor is required for deserializing the configuration. */ - public MultiplexingStreamRecordSerializerConfigSnapshot() {} - - public MultiplexingStreamRecordSerializerConfigSnapshot(TypeSerializer typeSerializer) { - super(typeSerializer); - } - - @Override - public int getVersion() { - return VERSION; - } - } - - // ------------------------------------------------------------------------ - // Utilities - // ------------------------------------------------------------------------ - - @Override - public boolean equals(Object obj) { - if (obj instanceof MultiplexingStreamRecordSerializer) { - MultiplexingStreamRecordSerializer other = (MultiplexingStreamRecordSerializer) obj; - - return other.canEqual(this) && typeSerializer.equals(other.typeSerializer); - } else { - return false; - } - } - - @Override - public boolean canEqual(Object obj) { - return obj instanceof MultiplexingStreamRecordSerializer; - } - - @Override - public int hashCode() { - return typeSerializer.hashCode(); - } -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/runtime/streamrecord/StreamRecordSerializer.java b/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/runtime/streamrecord/StreamRecordSerializer.java deleted file mode 100644 index e018ba0ec4595..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/runtime/streamrecord/StreamRecordSerializer.java +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUStreamRecordWARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.migration.streaming.runtime.streamrecord; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.typeutils.CompatibilityResult; -import org.apache.flink.api.common.typeutils.CompatibilityUtil; -import org.apache.flink.api.common.typeutils.CompositeTypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.TypeDeserializerAdapter; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.util.Preconditions; - -import java.io.IOException; - -/** - * Serializer for {@link StreamRecord}. This version ignores timestamps and only deals with - * the element. - * - *

{@link MultiplexingStreamRecordSerializer} is a version that deals with timestamps and also - * multiplexes {@link org.apache.flink.streaming.api.watermark.Watermark Watermarks} in the same - * stream with {@link StreamRecord StreamRecords}. - * - * @see MultiplexingStreamRecordSerializer - * - * @param The type of value in the {@link StreamRecord} - */ -@Internal -public final class StreamRecordSerializer extends TypeSerializer> { - - private static final long serialVersionUID = 1L; - - private final TypeSerializer typeSerializer; - - public StreamRecordSerializer(TypeSerializer serializer) { - if (serializer instanceof StreamRecordSerializer) { - throw new RuntimeException("StreamRecordSerializer given to StreamRecordSerializer as value TypeSerializer: " + serializer); - } - this.typeSerializer = Preconditions.checkNotNull(serializer); - } - - public TypeSerializer getContainedTypeSerializer() { - return this.typeSerializer; - } - - // ------------------------------------------------------------------------ - // General serializer and type utils - // ------------------------------------------------------------------------ - - @Override - public StreamRecordSerializer duplicate() { - TypeSerializer serializerCopy = typeSerializer.duplicate(); - return serializerCopy == typeSerializer ? this : new StreamRecordSerializer(serializerCopy); - } - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public int getLength() { - return typeSerializer.getLength(); - } - - // ------------------------------------------------------------------------ - // Type serialization, copying, instantiation - // ------------------------------------------------------------------------ - - @Override - public StreamRecord createInstance() { - try { - return new StreamRecord(typeSerializer.createInstance()); - } catch (Exception e) { - throw new RuntimeException("Cannot instantiate StreamRecord.", e); - } - } - - @Override - public StreamRecord copy(StreamRecord from) { - return from.copy(typeSerializer.copy(from.getValue())); - } - - @Override - public StreamRecord copy(StreamRecord from, StreamRecord reuse) { - from.copyTo(typeSerializer.copy(from.getValue(), reuse.getValue()), reuse); - return reuse; - } - - @Override - public void serialize(StreamRecord value, DataOutputView target) throws IOException { - typeSerializer.serialize(value.getValue(), target); - } - - @Override - public StreamRecord deserialize(DataInputView source) throws IOException { - return new StreamRecord(typeSerializer.deserialize(source)); - } - - @Override - public StreamRecord deserialize(StreamRecord reuse, DataInputView source) throws IOException { - T element = typeSerializer.deserialize(reuse.getValue(), source); - reuse.replace(element); - return reuse; - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - typeSerializer.copy(source, target); - } - - // ------------------------------------------------------------------------ - - @Override - public boolean equals(Object obj) { - if (obj instanceof StreamRecordSerializer) { - StreamRecordSerializer other = (StreamRecordSerializer) obj; - - return other.canEqual(this) && typeSerializer.equals(other.typeSerializer); - } else { - return false; - } - } - - @Override - public boolean canEqual(Object obj) { - return obj instanceof StreamRecordSerializer; - } - - @Override - public int hashCode() { - return typeSerializer.hashCode(); - } - - // 
-------------------------------------------------------------------------------------------- - // Serializer configuration snapshotting & compatibility - // -------------------------------------------------------------------------------------------- - - @Override - public StreamRecordSerializerConfigSnapshot snapshotConfiguration() { - return new StreamRecordSerializerConfigSnapshot<>(typeSerializer); - } - - @Override - public CompatibilityResult> ensureCompatibility(TypeSerializerConfigSnapshot configSnapshot) { - if (configSnapshot instanceof StreamRecordSerializerConfigSnapshot) { - Tuple2, TypeSerializerConfigSnapshot> previousTypeSerializerAndConfig = - ((StreamRecordSerializerConfigSnapshot) configSnapshot).getSingleNestedSerializerAndConfig(); - - CompatibilityResult compatResult = CompatibilityUtil.resolveCompatibilityResult( - previousTypeSerializerAndConfig.f0, - UnloadableDummyTypeSerializer.class, - previousTypeSerializerAndConfig.f1, - typeSerializer); - - if (!compatResult.isRequiresMigration()) { - return CompatibilityResult.compatible(); - } else if (compatResult.getConvertDeserializer() != null) { - return CompatibilityResult.requiresMigration( - new StreamRecordSerializer<>( - new TypeDeserializerAdapter<>(compatResult.getConvertDeserializer()))); - } - } - - return CompatibilityResult.requiresMigration(); - } - - /** - * Configuration snapshot specific to the {@link StreamRecordSerializer}. - */ - public static final class StreamRecordSerializerConfigSnapshot extends CompositeTypeSerializerConfigSnapshot { - - private static final int VERSION = 1; - - /** This empty nullary constructor is required for deserializing the configuration. */ - public StreamRecordSerializerConfigSnapshot() {} - - public StreamRecordSerializerConfigSnapshot(TypeSerializer typeSerializer) { - super(typeSerializer); - } - - @Override - public int getVersion() { - return VERSION; - } - } -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java deleted file mode 100644 index cb3c7cce2943e..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.api.checkpoint; - -import org.apache.flink.annotation.PublicEvolving; - -import java.io.Serializable; - -/** - * This method must be implemented by functions that have state that needs to be - * checkpointed. The functions get a call whenever a checkpoint should take place - * and return a snapshot of their state, which will be checkpointed. - * - *

Deprecation and Replacement

- * The short cut replacement for this interface is via {@link ListCheckpointed} and works - * as shown in the example below. The {@code ListCheckpointed} interface returns a list of - * elements ( - * - *

{@code
- * public class ExampleFunction<T> implements MapFunction<T, T>, ListCheckpointed<Integer> {
- *
- *     private int count;
- *
- *     public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
- *         return Collections.singletonList(this.count);
- *     }
- *
- *     public void restoreState(List<Integer> state) throws Exception {
- *         this.count = state.isEmpty() ? 0 : state.get(0);
- *     }
- *
- *     public T map(T value) {
- *         count++;
- *         return value;
- *     }
- * }
- * }
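The deprecation note that follows points to ListCheckpointed and, for more control over how state is stored, to CheckpointedFunction. Below is a minimal sketch of the same counting mapper written against CheckpointedFunction and operator list state; the class name and the "count" state name are illustrative and not taken from this patch.

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

public class CountingMapper<T> implements MapFunction<T, T>, CheckpointedFunction {

    // transient: re-created in initializeState() after a restore
    private transient ListState<Integer> checkpointedCount;
    private int count;

    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        // replace the previous snapshot with the current counter value
        checkpointedCount.clear();
        checkpointedCount.add(count);
    }

    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        checkpointedCount = context.getOperatorStateStore()
                .getListState(new ListStateDescriptor<>("count", Integer.class));

        if (context.isRestored()) {
            // operator list state may hand back several entries after rescaling; sum them up
            for (Integer c : checkpointedCount.get()) {
                count += c;
            }
        }
    }

    @Override
    public T map(T value) {
        count++;
        return value;
    }
}

Because operator list state is redistributed on rescaling, the restore path sums whatever entries this subtask receives instead of assuming exactly one.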
- * - * @param The type of the operator state. - * - * @deprecated Please use {@link ListCheckpointed} as illustrated above, or - * {@link CheckpointedFunction} for more control over the checkpointing process. - */ -@Deprecated -@PublicEvolving -public interface Checkpointed extends CheckpointedRestoring { - - /** - * Gets the current state of the function of operator. The state must reflect the result of all - * prior invocations to this function. - * - * @param checkpointId The ID of the checkpoint. - * @param checkpointTimestamp The timestamp of the checkpoint, as derived by - * System.currentTimeMillis() on the JobManager. - * - * @return A snapshot of the operator state. - * - * @throws Exception Thrown if the creation of the state object failed. This causes the - * checkpoint to fail. The system may decide to fail the operation (and trigger - * recovery), or to discard this checkpoint attempt and to continue running - * and to try again with the next checkpoint attempt. - */ - T snapshotState(long checkpointId, long checkpointTimestamp) throws Exception; -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedAsynchronously.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedAsynchronously.java deleted file mode 100644 index 5138b49c4e36b..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedAsynchronously.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.api.checkpoint; - -import org.apache.flink.annotation.PublicEvolving; - -import java.io.Serializable; - -/** - * This interface marks a function/operator as checkpointed similar to the - * {@link Checkpointed} interface, but gives the Flink framework the option to - * perform the checkpoint asynchronously. Note that asynchronous checkpointing for - * this interface has not been implemented. - * - *

Deprecation and Replacement

- * The shortcut replacement for this interface is via {@link ListCheckpointed} and works - * as shown in the example below. Please refer to the JavaDocs of {@link ListCheckpointed} for - * a more detailed description of how to use the new interface. - * - *

{@code
- * public class ExampleFunction<T> implements MapFunction<T, T>, ListCheckpointed<Integer> {
- *
- *     private int count;
- *
- *     public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
- *         return Collections.singletonList(this.count);
- *     }
- *
- *     public void restoreState(List<Integer> state) throws Exception {
- *         this.count = state.isEmpty() ? 0 : state.get(0);
- *     }
- *
- *     public T map(T value) {
- *         count++;
- *         return value;
- *     }
- * }
- * }
- * - * @deprecated Please use {@link ListCheckpointed} and {@link CheckpointedFunction} instead, - * as illustrated in the example above. - */ -@Deprecated -@PublicEvolving -public interface CheckpointedAsynchronously extends Checkpointed {} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java deleted file mode 100644 index cfaa505f0847b..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.api.checkpoint; - -import org.apache.flink.annotation.PublicEvolving; - -import java.io.Serializable; - -/** - * This deprecated interface contains the methods for restoring from the legacy checkpointing mechanism of state. - * @param type of the restored state. - * - * @deprecated Please use {@link CheckpointedFunction} or {@link ListCheckpointed} after restoring your legacy state. - */ -@Deprecated -@PublicEvolving -public interface CheckpointedRestoring { - /** - * Restores the state of the function or operator to that of a previous checkpoint. - * This method is invoked when a function is executed as part of a recovery run. - * - *

Note that restoreState() is called before open(). - * - * @param state The state to be restored. - */ - void restoreState(T state) throws Exception; -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/LegacyWindowOperatorType.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/LegacyWindowOperatorType.java deleted file mode 100644 index bb6e4bc7b68eb..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/LegacyWindowOperatorType.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.api.datastream; - -/** - * For specifying what type of window operator was used to create the state - * that a {@link org.apache.flink.streaming.runtime.operators.windowing.WindowOperator} - * is restoring from. This is used to signal that state written using an aligned processing-time - * window operator should be restored. - */ -public enum LegacyWindowOperatorType { - - FAST_ACCUMULATING(true, false), - - FAST_AGGREGATING(false, true), - - NONE(false, false); - - // ------------------------------------------------------------------------ - - private final boolean fastAccumulating; - private final boolean fastAggregating; - - LegacyWindowOperatorType(boolean fastAccumulating, boolean fastAggregating) { - this.fastAccumulating = fastAccumulating; - this.fastAggregating = fastAggregating; - } - - public boolean isFastAccumulating() { - return fastAccumulating; - } - - public boolean isFastAggregating() { - return fastAggregating; - } - - @Override - public String toString() { - if (fastAccumulating) { - return "AccumulatingProcessingTimeWindowOperator"; - } else if (fastAggregating) { - return "AggregatingProcessingTimeWindowOperator"; - } else { - return "WindowOperator"; - } - } -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SplitStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SplitStream.java index 4be6b6e20d967..0beae32435d4c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SplitStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/SplitStream.java @@ -22,7 +22,7 @@ import org.apache.flink.streaming.api.transformations.SelectTransformation; import org.apache.flink.streaming.api.transformations.SplitTransformation; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; /** * The SplitStream represents an operator that has been split using an diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java index 348861f221a6e..f904a10fd7ac0 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/WindowedStream.java @@ -23,7 +23,6 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.functions.FoldFunction; -import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.RichFunction; import org.apache.flink.api.common.state.AggregatingStateDescriptor; @@ -50,19 +49,11 @@ import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.windowing.assigners.BaseAlignedWindowAssigner; import org.apache.flink.streaming.api.windowing.assigners.MergingWindowAssigner; -import org.apache.flink.streaming.api.windowing.assigners.SlidingAlignedProcessingTimeWindows; -import org.apache.flink.streaming.api.windowing.assigners.SlidingProcessingTimeWindows; -import org.apache.flink.streaming.api.windowing.assigners.TumblingAlignedProcessingTimeWindows; -import org.apache.flink.streaming.api.windowing.assigners.TumblingProcessingTimeWindows; import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner; import org.apache.flink.streaming.api.windowing.evictors.Evictor; import org.apache.flink.streaming.api.windowing.time.Time; -import org.apache.flink.streaming.api.windowing.triggers.ProcessingTimeTrigger; import org.apache.flink.streaming.api.windowing.triggers.Trigger; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.streaming.api.windowing.windows.Window; -import org.apache.flink.streaming.runtime.operators.windowing.AccumulatingProcessingTimeWindowOperator; -import org.apache.flink.streaming.runtime.operators.windowing.AggregatingProcessingTimeWindowOperator; import org.apache.flink.streaming.runtime.operators.windowing.EvictingWindowOperator; import org.apache.flink.streaming.runtime.operators.windowing.WindowOperator; import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalAggregateProcessWindowFunction; @@ -227,33 +218,7 @@ public SingleOutputStreamOperator reduce(ReduceFunction function) { //clean the closure function = input.getExecutionEnvironment().clean(function); - - String callLocation = Utils.getCallLocationName(); - String udfName = "WindowedStream." + callLocation; - - SingleOutputStreamOperator result = createFastTimeOperatorIfValid(function, input.getType(), udfName); - if (result != null) { - return result; - } - - LegacyWindowOperatorType legacyOpType = getLegacyWindowType(function); - return reduce(function, new PassThroughWindowFunction(), legacyOpType); - } - - /** - * Applies the given window function to each window. The window function is called for each - * evaluation of the window for each key individually. The output of the window function is - * interpreted as a regular non-windowed stream. - * - *

Arriving data is incrementally aggregated using the given reducer. - * - * @param reduceFunction The reduce function that is used for incremental aggregation. - * @param function The window function. - * @return The data stream that is the result of applying the window function to the window. - */ - @PublicEvolving - public SingleOutputStreamOperator reduce(ReduceFunction reduceFunction, WindowFunction function) { - return reduce(reduceFunction, function, LegacyWindowOperatorType.NONE); + return reduce(function, new PassThroughWindowFunction()); } /** @@ -265,39 +230,15 @@ public SingleOutputStreamOperator reduce(ReduceFunction reduceFunction * * @param reduceFunction The reduce function that is used for incremental aggregation. * @param function The window function. - * @param resultType Type information for the result type of the window function * @return The data stream that is the result of applying the window function to the window. */ - @PublicEvolving public SingleOutputStreamOperator reduce( - ReduceFunction reduceFunction, - WindowFunction function, - TypeInformation resultType) { - return reduce(reduceFunction, function, resultType, LegacyWindowOperatorType.NONE); - } - - /** - * Applies the given window function to each window. The window function is called for each - * evaluation of the window for each key individually. The output of the window function is - * interpreted as a regular non-windowed stream. - * - *

Arriving data is incrementally aggregated using the given reducer. - * - * @param reduceFunction The reduce function that is used for incremental aggregation. - * @param function The window function. - * @param legacyWindowOpType When migrating from an older Flink version, this flag indicates - * the type of the previous operator whose state we inherit. - * @return The data stream that is the result of applying the window function to the window. - */ - private SingleOutputStreamOperator reduce( ReduceFunction reduceFunction, - WindowFunction function, - LegacyWindowOperatorType legacyWindowOpType) { + WindowFunction function) { TypeInformation inType = input.getType(); TypeInformation resultType = getWindowFunctionReturnType(function, inType); - - return reduce(reduceFunction, function, resultType, legacyWindowOpType); + return reduce(reduceFunction, function, resultType); } /** @@ -310,15 +251,12 @@ private SingleOutputStreamOperator reduce( * @param reduceFunction The reduce function that is used for incremental aggregation. * @param function The window function. * @param resultType Type information for the result type of the window function. - * @param legacyWindowOpType When migrating from an older Flink version, this flag indicates - * the type of the previous operator whose state we inherit. * @return The data stream that is the result of applying the window function to the window. */ - private SingleOutputStreamOperator reduce( + public SingleOutputStreamOperator reduce( ReduceFunction reduceFunction, WindowFunction function, - TypeInformation resultType, - LegacyWindowOperatorType legacyWindowOpType) { + TypeInformation resultType) { if (reduceFunction instanceof RichFunction) { throw new UnsupportedOperationException("ReduceFunction of reduce can not be a RichFunction."); @@ -374,8 +312,7 @@ private SingleOutputStreamOperator reduce( new InternalSingleValueWindowFunction<>(function), trigger, allowedLateness, - lateDataOutputTag, - legacyWindowOpType); + lateDataOutputTag); } return input.transform(opName, resultType, operator); @@ -1183,12 +1120,6 @@ private SingleOutputStreamOperator apply(InternalWindowFunction result = createFastTimeOperatorIfValid(function, resultType, udfName); - if (result != null) { - return result; - } - - LegacyWindowOperatorType legacyWindowOpType = getLegacyWindowType(function); String opName; KeySelector keySel = input.getKeySelector(); @@ -1231,8 +1162,7 @@ private SingleOutputStreamOperator apply(InternalWindowFunction aggregate(AggregationFunction aggregato return reduce(aggregator); } - // ------------------------------------------------------------------------ - // Utilities - // ------------------------------------------------------------------------ - - private LegacyWindowOperatorType getLegacyWindowType(Function function) { - if (windowAssigner instanceof SlidingProcessingTimeWindows && trigger instanceof ProcessingTimeTrigger && evictor == null) { - if (function instanceof ReduceFunction) { - return LegacyWindowOperatorType.FAST_AGGREGATING; - } else if (function instanceof WindowFunction) { - return LegacyWindowOperatorType.FAST_ACCUMULATING; - } - } else if (windowAssigner instanceof TumblingProcessingTimeWindows && trigger instanceof ProcessingTimeTrigger && evictor == null) { - if (function instanceof ReduceFunction) { - return LegacyWindowOperatorType.FAST_AGGREGATING; - } else if (function instanceof WindowFunction) { - return LegacyWindowOperatorType.FAST_ACCUMULATING; - } - } - return LegacyWindowOperatorType.NONE; - } - - private 
SingleOutputStreamOperator createFastTimeOperatorIfValid( - ReduceFunction function, - TypeInformation resultType, - String functionName) { - - if (windowAssigner.getClass() == SlidingAlignedProcessingTimeWindows.class && trigger == null && evictor == null) { - SlidingAlignedProcessingTimeWindows timeWindows = (SlidingAlignedProcessingTimeWindows) windowAssigner; - final long windowLength = timeWindows.getSize(); - final long windowSlide = timeWindows.getSlide(); - - String opName = "Fast " + timeWindows + " of " + functionName; - - @SuppressWarnings("unchecked") - ReduceFunction reducer = (ReduceFunction) function; - - @SuppressWarnings("unchecked") - OneInputStreamOperator op = (OneInputStreamOperator) - new AggregatingProcessingTimeWindowOperator<>( - reducer, input.getKeySelector(), - input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()), - input.getType().createSerializer(getExecutionEnvironment().getConfig()), - windowLength, windowSlide); - return input.transform(opName, resultType, op); - - } else if (windowAssigner.getClass() == TumblingAlignedProcessingTimeWindows.class && trigger == null && evictor == null) { - TumblingAlignedProcessingTimeWindows timeWindows = (TumblingAlignedProcessingTimeWindows) windowAssigner; - final long windowLength = timeWindows.getSize(); - final long windowSlide = timeWindows.getSize(); - - String opName = "Fast " + timeWindows + " of " + functionName; - - @SuppressWarnings("unchecked") - ReduceFunction reducer = (ReduceFunction) function; - - @SuppressWarnings("unchecked") - OneInputStreamOperator op = (OneInputStreamOperator) - new AggregatingProcessingTimeWindowOperator<>( - reducer, - input.getKeySelector(), - input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()), - input.getType().createSerializer(getExecutionEnvironment().getConfig()), - windowLength, windowSlide); - return input.transform(opName, resultType, op); - } - - return null; - } - - private SingleOutputStreamOperator createFastTimeOperatorIfValid( - InternalWindowFunction, R, K, W> function, - TypeInformation resultType, - String functionName) { - - if (windowAssigner.getClass() == SlidingAlignedProcessingTimeWindows.class && trigger == null && evictor == null) { - SlidingAlignedProcessingTimeWindows timeWindows = (SlidingAlignedProcessingTimeWindows) windowAssigner; - final long windowLength = timeWindows.getSize(); - final long windowSlide = timeWindows.getSlide(); - - String opName = "Fast " + timeWindows + " of " + functionName; - - @SuppressWarnings("unchecked") - InternalWindowFunction, R, K, TimeWindow> timeWindowFunction = - (InternalWindowFunction, R, K, TimeWindow>) function; - - OneInputStreamOperator op = new AccumulatingProcessingTimeWindowOperator<>( - timeWindowFunction, input.getKeySelector(), - input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()), - input.getType().createSerializer(getExecutionEnvironment().getConfig()), - windowLength, windowSlide); - return input.transform(opName, resultType, op); - } else if (windowAssigner.getClass() == TumblingAlignedProcessingTimeWindows.class && trigger == null && evictor == null) { - TumblingAlignedProcessingTimeWindows timeWindows = (TumblingAlignedProcessingTimeWindows) windowAssigner; - final long windowLength = timeWindows.getSize(); - final long windowSlide = timeWindows.getSize(); - - String opName = "Fast " + timeWindows + " of " + functionName; - - @SuppressWarnings("unchecked") - InternalWindowFunction, R, K, TimeWindow> timeWindowFunction = - 
(InternalWindowFunction, R, K, TimeWindow>) function; - - OneInputStreamOperator op = new AccumulatingProcessingTimeWindowOperator<>( - timeWindowFunction, input.getKeySelector(), - input.getKeyType().createSerializer(getExecutionEnvironment().getConfig()), - input.getType().createSerializer(getExecutionEnvironment().getConfig()), - windowLength, windowSlide); - return input.transform(opName, resultType, op); - } - - return null; - } - public StreamExecutionEnvironment getExecutionEnvironment() { return input.getExecutionEnvironment(); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/AsyncFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/AsyncFunction.java index 5bb4459ac11e8..2ac218dc4e3ed 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/AsyncFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/AsyncFunction.java @@ -20,7 +20,6 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.functions.Function; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; import java.io.Serializable; @@ -28,21 +27,21 @@ * A function to trigger Async I/O operation. * *

For each #asyncInvoke, an async io operation can be triggered, and once it has been done, - * the result can be collected by calling {@link AsyncCollector#collect}. For each async + * the result can be collected by calling {@link ResultFuture#complete}. For each async * operation, its context is stored in the operator immediately after invoking * #asyncInvoke, avoiding blocking for each stream input as long as the internal buffer is not full. * - *

{@link AsyncCollector} can be passed into callbacks or futures to collect the result data. + *

{@link ResultFuture} can be passed into callbacks or futures to collect the result data. * An error can also be propagate to the async IO operator by - * {@link AsyncCollector#collect(Throwable)}. + * {@link ResultFuture#completeExceptionally(Throwable)}. * *

Callback example usage: * *

{@code
  * public class HBaseAsyncFunc implements AsyncFunction<String, String> {
  *
- *   public void asyncInvoke(String row, AsyncCollector<String> collector) throws Exception {
- *     HBaseCallback cb = new HBaseCallback(collector);
+ *   public void asyncInvoke(String row, ResultFuture<String> result) throws Exception {
+ *     HBaseCallback cb = new HBaseCallback(result);
  *     Get get = new Get(Bytes.toBytes(row));
  *     hbase.asyncGet(get, cb);
  *   }
@@ -54,16 +53,16 @@
  * 
{@code
  * public class HBaseAsyncFunc implements AsyncFunction<String, String> {
  *
- *   public void asyncInvoke(String row, final AsyncCollector<String> collector) throws Exception {
+ *   public void asyncInvoke(String row, final ResultFuture<String> resultFuture) throws Exception {
  *     Get get = new Get(Bytes.toBytes(row));
  *     ListenableFuture<Result> future = hbase.asyncGet(get);
  *     Futures.addCallback(future, new FutureCallback<Result>() {
  *       public void onSuccess(Result result) {
  *         List<String> ret = process(result);
- *         collector.collect(ret);
+ *         resultFuture.complete(ret);
  *       }
  *       public void onFailure(Throwable thrown) {
- *         collector.collect(thrown);
+ *         resultFuture.completeExceptionally(thrown);
  *       }
  *     });
  *   }
@@ -80,9 +79,9 @@ public interface AsyncFunction<IN, OUT> extends Function, Serializable {
 	 * Trigger async operation for each stream input.
 	 *
 	 * @param input element coming from an upstream task
-	 * @param collector to collect the result data
+	 * @param resultFuture to be completed with the result data
 	 * @exception Exception in case of a user code error. An exception will make the task fail and
 	 * trigger fail-over process.
 	 */
-	void asyncInvoke(IN input, AsyncCollector<OUT> collector) throws Exception;
+	void asyncInvoke(IN input, ResultFuture<OUT> resultFuture) throws Exception;
 }
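A minimal, self-contained sketch of the renamed contract as a user might implement it follows; the thread pool and the blockingLookup placeholder stand in for a real asynchronous client and are not part of this patch.

import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.async.ResultFuture;
import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;

public class AsyncDatabaseLookup extends RichAsyncFunction<String, Tuple2<String, String>> {

    private transient ExecutorService executor;

    @Override
    public void open(Configuration parameters) throws Exception {
        executor = Executors.newFixedThreadPool(4);
    }

    @Override
    public void close() throws Exception {
        executor.shutdown();
    }

    @Override
    public void asyncInvoke(String key, ResultFuture<Tuple2<String, String>> resultFuture) {
        CompletableFuture
                .supplyAsync(() -> blockingLookup(key), executor)
                .whenComplete((value, error) -> {
                    if (error != null) {
                        // hand the failure to the async I/O operator, which fails the task
                        resultFuture.completeExceptionally(error);
                    } else {
                        // complete exactly once, with a collection of results
                        resultFuture.complete(Collections.singleton(Tuple2.of(key, value)));
                    }
                });
    }

    // placeholder for a call against an external system
    private String blockingLookup(String key) {
        return "value-for-" + key;
    }
}

complete(...) should be called exactly once per input element; until then the async I/O operator keeps the in-flight element buffered.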
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/collector/AsyncCollector.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/ResultFuture.java
similarity index 76%
rename from flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/collector/AsyncCollector.java
rename to flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/ResultFuture.java
index 964c13ab4e8b2..934341e973606 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/collector/AsyncCollector.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/ResultFuture.java
@@ -16,21 +16,21 @@
  * limitations under the License.
  */
 
-package org.apache.flink.streaming.api.functions.async.collector;
+package org.apache.flink.streaming.api.functions.async;
 
 import org.apache.flink.annotation.PublicEvolving;
 
 import java.util.Collection;
 
 /**
- * {@link AsyncCollector} collects data / error in user codes while processing async i/o.
+ * {@link ResultFuture} collects data / error in user codes while processing async i/o.
  *
  * @param <OUT> Output type
  */
 @PublicEvolving
-public interface AsyncCollector<OUT> {
+public interface ResultFuture<OUT> {
 	/**
-	 * Set result.
+	 * Completes the result future with a collection of result objects.
 	 *
 	 * 

Note that it should be called for exactly one time in the user code. * Calling this function for multiple times will cause data lose. @@ -39,12 +39,12 @@ public interface AsyncCollector { * * @param result A list of results. */ - void collect(Collection result); + void complete(Collection result); /** - * Set error. + * Completes the result future exceptionally with an exception. * * @param error A Throwable object. */ - void collect(Throwable error); + void completeExceptionally(Throwable error); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java index 84f9e5338012f..b6ce862e126dc 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java @@ -43,7 +43,6 @@ import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.metrics.MetricGroup; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; import org.apache.flink.types.Value; import org.apache.flink.util.Preconditions; @@ -85,7 +84,7 @@ public void setRuntimeContext(RuntimeContext runtimeContext) { } @Override - public abstract void asyncInvoke(IN input, AsyncCollector collector) throws Exception; + public abstract void asyncInvoke(IN input, ResultFuture resultFuture) throws Exception; // ----------------------------------------------------------------------------------------- // Wrapper classes diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java index 85ddc5c40b0b3..60409793ed4d7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java @@ -22,7 +22,6 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; @@ -38,6 +37,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import static java.util.Objects.requireNonNull; import static org.apache.flink.util.Preconditions.checkState; @@ -50,23 +50,25 @@ * * @param Input type for {@link SinkFunction}. * @param Transaction to store all of the information required to handle a transaction. + * @param Context that will be shared across all invocations for the given {@link TwoPhaseCommitSinkFunction} + * instance. 
Context is created once */ @PublicEvolving -public abstract class TwoPhaseCommitSinkFunction +public abstract class TwoPhaseCommitSinkFunction extends RichSinkFunction implements CheckpointedFunction, CheckpointListener { private static final Logger LOG = LoggerFactory.getLogger(TwoPhaseCommitSinkFunction.class); - protected final ListStateDescriptor>> pendingCommitTransactionsDescriptor; - protected final ListStateDescriptor pendingTransactionsDescriptor; + protected final ListStateDescriptor> stateDescriptor; protected final LinkedHashMap pendingCommitTransactions = new LinkedHashMap<>(); @Nullable protected TXN currentTransaction; - protected ListState pendingTransactionsState; - protected ListState>> pendingCommitTransactionsState; + protected Optional userContext; + + protected ListState> state; /** * Use default {@link ListStateDescriptor} for internal state serialization. Helpful utilities for using this @@ -74,32 +76,30 @@ public abstract class TwoPhaseCommitSinkFunction * {@link TypeInformation#of(TypeHint)}. Example: *

 	 * {@code
-	 * TwoPhaseCommitSinkFunction(
-	 *     TypeInformation.of(TXN.class),
-	 *     TypeInformation.of(new TypeHint<List<Tuple2<Long, TXN>>>() {}));
+	 * TwoPhaseCommitSinkFunction(TypeInformation.of(new TypeHint<State<TXN, CONTEXT>>() {}));
 	 * }
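As a concrete instance of the constructor call shown above, a sink whose transaction handle is a hypothetical FileTransaction and which needs no shared context could capture the state type as sketched below; everything except the constructor signature documented above is illustrative.

import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.streaming.api.functions.sink.TwoPhaseCommitSinkFunction.State;

public final class StateTypeInfoExample {

    /** Hypothetical transaction handle, used only for this illustration. */
    public static class FileTransaction {}

    /** Captures the full generic state type so the sink's internal list state can be serialized. */
    public static TypeInformation<State<FileTransaction, Void>> stateTypeInfo() {
        return TypeInformation.of(new TypeHint<State<FileTransaction, Void>>() {});
    }
}

A sink subclass could then call super(stateTypeInfo()) in its constructor, or wrap the result in a ListStateDescriptor for the descriptor-based constructor.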
 	 * 
- * @param txnTypeInformation {@link TypeInformation} for transaction POJO. - * @param checkpointToTxnTypeInformation {@link TypeInformation} for mapping between checkpointId and transaction. + * @param stateTypeInformation {@link TypeInformation} for POJO holding state of opened transactions. */ - public TwoPhaseCommitSinkFunction( - TypeInformation txnTypeInformation, - TypeInformation>> checkpointToTxnTypeInformation) { - this(new ListStateDescriptor<>("pendingTransactions", txnTypeInformation), - new ListStateDescriptor<>("pendingCommitTransactions", checkpointToTxnTypeInformation)); + public TwoPhaseCommitSinkFunction(TypeInformation> stateTypeInformation) { + this(new ListStateDescriptor>("state", stateTypeInformation)); } /** * Instantiate {@link TwoPhaseCommitSinkFunction} with custom state descriptors. * - * @param pendingTransactionsDescriptor descriptor for transaction POJO. - * @param pendingCommitTransactionsDescriptor descriptor for mapping between checkpointId and transaction POJO. + * @param stateDescriptor descriptor for transactions POJO. */ - public TwoPhaseCommitSinkFunction( - ListStateDescriptor pendingTransactionsDescriptor, - ListStateDescriptor>> pendingCommitTransactionsDescriptor) { - this.pendingTransactionsDescriptor = requireNonNull(pendingTransactionsDescriptor, "pendingTransactionsDescriptor is null"); - this.pendingCommitTransactionsDescriptor = requireNonNull(pendingCommitTransactionsDescriptor, "pendingCommitTransactionsDescriptor is null"); + public TwoPhaseCommitSinkFunction(ListStateDescriptor> stateDescriptor) { + this.stateDescriptor = requireNonNull(stateDescriptor, "stateDescriptor is null"); + } + + protected Optional initializeUserContext() { + return Optional.empty(); + } + + protected Optional getUserContext() { + return userContext; } // ------ methods that should be implemented in child class to support two phase commit algorithm ------ @@ -154,6 +154,9 @@ protected void recoverAndAbort(TXN transaction) { abort(transaction); } + protected void finishRecoveringContext() { + } + // ------ entry points for above methods implementing {@CheckPointedFunction} and {@CheckpointListener} ------ @Override @@ -196,11 +199,11 @@ public final void notifyCheckpointComplete(long checkpointId) throws Exception { // ==> There should never be a case where we have no pending transaction here // - Iterator> pendingTransactionsIterator = pendingCommitTransactions.entrySet().iterator(); - checkState(pendingTransactionsIterator.hasNext(), "checkpoint completed, but no transaction pending"); + Iterator> pendingTransactionIterator = pendingCommitTransactions.entrySet().iterator(); + checkState(pendingTransactionIterator.hasNext(), "checkpoint completed, but no transaction pending"); - while (pendingTransactionsIterator.hasNext()) { - Map.Entry entry = pendingTransactionsIterator.next(); + while (pendingTransactionIterator.hasNext()) { + Map.Entry entry = pendingTransactionIterator.next(); Long pendingTransactionCheckpointId = entry.getKey(); TXN pendingTransaction = entry.getValue(); if (pendingTransactionCheckpointId > checkpointId) { @@ -214,12 +217,12 @@ public final void notifyCheckpointComplete(long checkpointId) throws Exception { LOG.debug("{} - committed checkpoint transaction {}", name(), pendingTransaction); - pendingTransactionsIterator.remove(); + pendingTransactionIterator.remove(); } } @Override - public final void snapshotState(FunctionSnapshotContext context) throws Exception { + public void snapshotState(FunctionSnapshotContext context) throws 
Exception { // this is like the pre-commit of a 2-phase-commit transaction // we are ready to commit and remember the transaction @@ -235,17 +238,15 @@ public final void snapshotState(FunctionSnapshotContext context) throws Exceptio currentTransaction = beginTransaction(); LOG.debug("{} - started new transaction '{}'", name(), currentTransaction); - pendingCommitTransactionsState.clear(); - pendingCommitTransactionsState.add(toTuple2List(pendingCommitTransactions)); - - pendingTransactionsState.clear(); - // in case of failure we might not be able to abort currentTransaction. Let's store it into the state - // so it can be aborted after a restart/crash - pendingTransactionsState.add(currentTransaction); + state.clear(); + state.add(new State<>( + this.currentTransaction, + new ArrayList<>(pendingCommitTransactions.values()), + userContext)); } @Override - public final void initializeState(FunctionInitializationContext context) throws Exception { + public void initializeState(FunctionInitializationContext context) throws Exception { // when we are restoring state with pendingCommitTransactions, we don't really know whether the // transactions were already committed, or whether there was a failure between // completing the checkpoint on the master, and notifying the writer here. @@ -260,27 +261,33 @@ public final void initializeState(FunctionInitializationContext context) throws // we can have more than one transaction to check in case of a scale-in event, or // for the reasons discussed in the 'notifyCheckpointComplete()' method. - pendingTransactionsState = context.getOperatorStateStore().getListState(pendingTransactionsDescriptor); - pendingCommitTransactionsState = context.getOperatorStateStore().getListState(pendingCommitTransactionsDescriptor); + state = context.getOperatorStateStore().getListState(stateDescriptor); if (context.isRestored()) { LOG.info("{} - restoring state", name()); - for (List> recoveredTransactions : pendingCommitTransactionsState.get()) { - for (Tuple2 recoveredTransaction : recoveredTransactions) { + for (State operatorState : state.get()) { + userContext = operatorState.getContext(); + List recoveredTransactions = operatorState.getPendingCommitTransactions(); + for (TXN recoveredTransaction : recoveredTransactions) { // If this fails, there is actually a data loss - recoverAndCommit(recoveredTransaction.f1); + recoverAndCommit(recoveredTransaction); LOG.info("{} committed recovered transaction {}", name(), recoveredTransaction); } - } - // Explicitly abort transactions that could be not closed cleanly - for (TXN pendingTransaction : pendingTransactionsState.get()) { - recoverAndAbort(pendingTransaction); - LOG.info("{} aborted recovered transaction {}", name(), pendingTransaction); + recoverAndAbort(operatorState.getPendingTransaction()); + LOG.info("{} aborted recovered transaction {}", name(), operatorState.getPendingTransaction()); + + if (userContext.isPresent()) { + finishRecoveringContext(); + } } - } else { - LOG.info("{} - no state to restore {}", name()); + } + // if in restore we didn't get any userContext or we are initializing from scratch + if (userContext == null) { + LOG.info("{} - no state to restore", name()); + + userContext = initializeUserContext(); } this.pendingCommitTransactions.clear(); @@ -306,11 +313,45 @@ private String name() { getRuntimeContext().getNumberOfParallelSubtasks()); } - private List> toTuple2List(LinkedHashMap transactions) { - List> result = new ArrayList<>(); - for (Map.Entry entry : transactions.entrySet()) { - 
result.add(Tuple2.of(entry.getKey(), entry.getValue())); + /** + * State POJO class coupling pendingTransaction, context and pendingCommitTransactions. + */ + public static class State { + protected TXN pendingTransaction; + protected List pendingCommitTransactions = new ArrayList<>(); + protected Optional context; + + public State() { + } + + public State(TXN pendingTransaction, List pendingCommitTransactions, Optional context) { + this.context = requireNonNull(context, "context is null"); + this.pendingTransaction = requireNonNull(pendingTransaction, "pendingTransaction is null"); + this.pendingCommitTransactions = requireNonNull(pendingCommitTransactions, "pendingCommitTransactions is null"); + } + + public TXN getPendingTransaction() { + return pendingTransaction; + } + + public void setPendingTransaction(TXN pendingTransaction) { + this.pendingTransaction = pendingTransaction; + } + + public List getPendingCommitTransactions() { + return pendingCommitTransactions; + } + + public void setPendingCommitTransactions(List pendingCommitTransactions) { + this.pendingCommitTransactions = pendingCommitTransactions; + } + + public Optional getContext() { + return context; + } + + public void setContext(Optional context) { + this.context = context; } - return result; } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java index 3c4cfbd0c9194..fedd791fc292e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java @@ -32,7 +32,6 @@ import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -66,7 +65,7 @@ */ @Internal public class ContinuousFileMonitoringFunction - extends RichSourceFunction implements CheckpointedFunction, CheckpointedRestoring { + extends RichSourceFunction implements CheckpointedFunction { private static final long serialVersionUID = 1L; @@ -375,12 +374,4 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { LOG.debug("{} checkpointed {}.", getClass().getSimpleName(), globalModificationTime); } } - - @Override - public void restoreState(Long state) throws Exception { - this.globalModificationTime = state; - - LOG.info("{} (taskIdx={}) restored global modification time from an older Flink version: {}", - getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask(), globalModificationTime); - } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java index 3a9e8e1f60c9a..78e181a978b3c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java @@ -25,30 +25,23 @@ import 
org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FileInputSplit; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.CheckpointedRestoringOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.OutputTypeConfigurable; import org.apache.flink.streaming.api.operators.StreamSourceContexts; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; -import java.io.ObjectInputStream; import java.io.Serializable; import java.util.ArrayList; -import java.util.LinkedList; import java.util.List; import java.util.PriorityQueue; import java.util.Queue; @@ -60,15 +53,15 @@ * The operator that reads the {@link TimestampedFileInputSplit splits} received from the preceding * {@link ContinuousFileMonitoringFunction}. Contrary to the {@link ContinuousFileMonitoringFunction} * which has a parallelism of 1, this operator can have DOP > 1. - *

- * As soon as a split descriptor is received, it is put in a queue, and have another + * + *

As soon as a split descriptor is received, it is put in a queue, and have another * thread read the actual data of the split. This architecture allows the separation of the * reading thread from the one emitting the checkpoint barriers, thus removing any potential * back-pressure. */ @Internal public class ContinuousFileReaderOperator extends AbstractStreamOperator - implements OneInputStreamOperator, OutputTypeConfigurable, CheckpointedRestoringOperator { + implements OneInputStreamOperator, OutputTypeConfigurable { private static final long serialVersionUID = 1L; @@ -203,11 +196,14 @@ public void dispose() throws Exception { public void close() throws Exception { super.close(); + // make sure that we hold the checkpointing lock + Thread.holdsLock(checkpointLock); + // close the reader to signal that no more splits will come. By doing this, // the reader will exit as soon as it finishes processing the already pending splits. // This method will wait until then. Further cleaning up is handled by the dispose(). - if (reader != null && reader.isAlive() && reader.isRunning()) { + while (reader != null && reader.isAlive() && reader.isRunning()) { reader.close(); checkpointLock.wait(); } @@ -422,83 +418,4 @@ public void snapshotState(StateSnapshotContext context) throws Exception { getClass().getSimpleName(), subtaskIdx, readerState.size(), readerState); } } - - // ------------------------------------------------------------------------ - // Restoring / Migrating from an older Flink version. - // ------------------------------------------------------------------------ - - @Override - public void restoreState(FSDataInputStream in) throws Exception { - - LOG.info("{} (taskIdx={}) restoring state from an older Flink version.", - getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask()); - - // this is just to read the byte indicating if we have udf state or not - int hasUdfState = in.read(); - - Preconditions.checkArgument(hasUdfState == 0); - - final ObjectInputStream ois = new ObjectInputStream(in); - final DataInputViewStreamWrapper div = new DataInputViewStreamWrapper(in); - - // read the split that was being read - FileInputSplit currSplit = (FileInputSplit) ois.readObject(); - - // read the pending splits list - List pendingSplits = new LinkedList<>(); - int noOfSplits = div.readInt(); - for (int i = 0; i < noOfSplits; i++) { - FileInputSplit split = (FileInputSplit) ois.readObject(); - pendingSplits.add(split); - } - - // read the state of the format - Serializable formatState = (Serializable) ois.readObject(); - - div.close(); - - if (restoredReaderState == null) { - restoredReaderState = new ArrayList<>(); - } - - // we do not know the modification time of the retrieved splits, so we assign them - // artificial ones, with the only constraint that they respect the relative order of the - // retrieved splits, because modification time is going to be used to sort the splits within - // the "pending splits" priority queue. 
- - long now = getProcessingTimeService().getCurrentProcessingTime(); - long runningModTime = Math.max(now, noOfSplits + 1); - - TimestampedFileInputSplit currentSplit = createTimestampedFileSplit(currSplit, --runningModTime, formatState); - restoredReaderState.add(currentSplit); - for (FileInputSplit split : pendingSplits) { - TimestampedFileInputSplit timestampedSplit = createTimestampedFileSplit(split, --runningModTime); - restoredReaderState.add(timestampedSplit); - } - - if (LOG.isDebugEnabled()) { - if (LOG.isDebugEnabled()) { - LOG.debug("{} (taskIdx={}) restored {} splits from legacy: {}.", - getClass().getSimpleName(), - getRuntimeContext().getIndexOfThisSubtask(), - restoredReaderState.size(), - restoredReaderState); - } - } - } - - private TimestampedFileInputSplit createTimestampedFileSplit(FileInputSplit split, long modificationTime) { - return createTimestampedFileSplit(split, modificationTime, null); - } - - private TimestampedFileInputSplit createTimestampedFileSplit(FileInputSplit split, long modificationTime, Serializable state) { - TimestampedFileInputSplit timestampedSplit = new TimestampedFileInputSplit( - modificationTime, split.getSplitNumber(), split.getPath(), - split.getStart(), split.getLength(), split.getHostnames()); - - if (state != null) { - timestampedSplit.setSplitState(state); - } - return timestampedSplit; - } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java index ab21586e440ba..604755d456df8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java @@ -96,13 +96,13 @@ public abstract class MessageAcknowledgingSourceBase private final TypeSerializer idSerializer; /** The list gathering the IDs of messages emitted during the current checkpoint. */ - private transient List idsForCurrentCheckpoint; + private transient Set idsForCurrentCheckpoint; /** * The list with IDs from checkpoints that were triggered, but not yet completed or notified of * completion. */ - protected transient ArrayDeque>> pendingCheckpoints; + protected transient ArrayDeque>> pendingCheckpoints; /** * Set which contain all processed ids. Ids are acknowledged after checkpoints. When restoring @@ -142,7 +142,7 @@ public void initializeState(FunctionInitializationContext context) throws Except .getOperatorStateStore() .getSerializableListState("message-acknowledging-source-state"); - this.idsForCurrentCheckpoint = new ArrayList<>(64); + this.idsForCurrentCheckpoint = new HashSet<>(64); this.pendingCheckpoints = new ArrayDeque<>(); this.idsProcessedButNotAcknowledged = new HashSet<>(); @@ -161,7 +161,7 @@ public void initializeState(FunctionInitializationContext context) throws Except pendingCheckpoints = SerializedCheckpointData.toDeque(retrievedStates.get(0), idSerializer); // build a set which contains all processed ids. It may be used to check if we have // already processed an incoming message. - for (Tuple2> checkpoint : pendingCheckpoints) { + for (Tuple2> checkpoint : pendingCheckpoints) { idsProcessedButNotAcknowledged.addAll(checkpoint.f1); } } else { @@ -185,7 +185,7 @@ public void close() throws Exception { * * @param uIds The list od IDs to acknowledge. 
*/ - protected abstract void acknowledgeIDs(long checkpointId, List uIds); + protected abstract void acknowledgeIDs(long checkpointId, Set uIds); /** * Adds an ID to be stored with the current checkpoint. @@ -213,7 +213,7 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { } pendingCheckpoints.addLast(new Tuple2<>(context.getCheckpointId(), idsForCurrentCheckpoint)); - idsForCurrentCheckpoint = new ArrayList<>(64); + idsForCurrentCheckpoint = new HashSet<>(64); this.checkpointedState.clear(); this.checkpointedState.add(SerializedCheckpointData.fromDeque(pendingCheckpoints, idSerializer)); @@ -223,8 +223,8 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { public void notifyCheckpointComplete(long checkpointId) throws Exception { LOG.debug("Committing Messages externally for checkpoint {}", checkpointId); - for (Iterator>> iter = pendingCheckpoints.iterator(); iter.hasNext();) { - Tuple2> checkpoint = iter.next(); + for (Iterator>> iter = pendingCheckpoints.iterator(); iter.hasNext();) { + Tuple2> checkpoint = iter.next(); long id = checkpoint.f0; if (id <= checkpointId) { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java index e7cdb99c24c3c..d0c0741ec74f7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java @@ -32,6 +32,7 @@ import java.util.Deque; import java.util.Iterator; import java.util.List; +import java.util.Set; /** * Abstract base class for data sources that receive elements from a message queue and @@ -110,7 +111,8 @@ public void close() throws Exception { * means of de-duplicating messages when the acknowledgment after a checkpoint * fails. 
*/ - protected final void acknowledgeIDs(long checkpointId, List uniqueIds) { + @Override + protected final void acknowledgeIDs(long checkpointId, Set uniqueIds) { LOG.debug("Acknowledging ids for checkpoint {}", checkpointId); Iterator>> iterator = sessionIdsPerSnapshot.iterator(); while (iterator.hasNext()) { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java index 77caa34d11fbf..13100db01d837 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.util.CorruptConfigurationException; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.util.ClassLoaderUtil; @@ -76,6 +77,7 @@ public class StreamConfig implements Serializable { private static final String OUT_STREAM_EDGES = "outStreamEdges"; private static final String IN_STREAM_EDGES = "inStreamEdges"; private static final String OPERATOR_NAME = "operatorName"; + private static final String OPERATOR_ID = "operatorID"; private static final String CHAIN_END = "chainEnd"; private static final String CHECKPOINTING_ENABLED = "checkpointing"; @@ -213,7 +215,7 @@ public void setStreamOperator(StreamOperator operator) { } } - public T getStreamOperator(ClassLoader cl) { + public > T getStreamOperator(ClassLoader cl) { try { return InstantiationUtil.readObjectFromConfig(this.config, SERIALIZEDUDF, cl); } @@ -411,6 +413,15 @@ public Map getTransitiveChainedTaskConfigs(ClassLoader cl } } + public void setOperatorID(OperatorID operatorID) { + this.config.setBytes(OPERATOR_ID, operatorID.getBytes()); + } + + public OperatorID getOperatorID() { + byte[] operatorIDBytes = config.getBytes(OPERATOR_ID, null); + return new OperatorID(Preconditions.checkNotNull(operatorIDBytes)); + } + public void setOperatorName(String name) { this.config.setString(OPERATOR_NAME, name); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java index bb9e47b6c6a70..9bbcec0c09198 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java @@ -25,9 +25,10 @@ import org.apache.flink.streaming.api.transformations.StreamTransformation; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; -import com.google.common.hash.HashFunction; -import com.google.common.hash.Hasher; -import com.google.common.hash.Hashing; +import org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction; +import org.apache.flink.shaded.guava18.com.google.common.hash.Hasher; +import org.apache.flink.shaded.guava18.com.google.common.hash.Hashing; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index e70962b9b2623..884b899116378 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -26,7 +26,6 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; -import org.apache.flink.migration.streaming.api.graph.StreamGraphHasherV1; import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; @@ -116,7 +115,7 @@ public static JobGraph createJobGraph(StreamGraph streamGraph) { private StreamingJobGraphGenerator(StreamGraph streamGraph) { this.streamGraph = streamGraph; this.defaultStreamGraphHasher = new StreamGraphHasherV2(); - this.legacyStreamGraphHashers = Arrays.asList(new StreamGraphHasherV1(), new StreamGraphUserHashHasher()); + this.legacyStreamGraphHashers = Arrays.asList(new StreamGraphUserHashHasher()); this.jobVertices = new HashMap<>(); this.builtVertices = new HashSet<>(); @@ -241,12 +240,14 @@ private List createChain( createChain(nonChainable.getTargetId(), nonChainable.getTargetId(), hashes, legacyHashes, 0, chainedOperatorHashes); } - List> operatorHashes = chainedOperatorHashes.get(startNodeId); - if (operatorHashes == null) { - operatorHashes = new ArrayList<>(); - chainedOperatorHashes.put(startNodeId, operatorHashes); + List> operatorHashes = + chainedOperatorHashes.computeIfAbsent(startNodeId, k -> new ArrayList<>()); + + byte[] primaryHashBytes = hashes.get(currentNodeId); + + for (Map legacyHash : legacyHashes) { + operatorHashes.add(new Tuple2<>(primaryHashBytes, legacyHash.get(currentNodeId))); } - operatorHashes.add(new Tuple2<>(hashes.get(currentNodeId), legacyHashes.get(1).get(currentNodeId))); chainedNames.put(currentNodeId, createChainedName(currentNodeId, chainableOutputs)); chainedMinResources.put(currentNodeId, createChainedMinResources(currentNodeId, chainableOutputs)); @@ -280,13 +281,16 @@ private List createChain( chainedConfigs.put(startNodeId, new HashMap()); } config.setChainIndex(chainIndex); - config.setOperatorName(streamGraph.getStreamNode(currentNodeId).getOperatorName()); + StreamNode node = streamGraph.getStreamNode(currentNodeId); + config.setOperatorName(node.getOperatorName()); chainedConfigs.get(startNodeId).put(currentNodeId, config); } + + config.setOperatorID(new OperatorID(primaryHashBytes)); + if (chainableOutputs.isEmpty()) { config.setChainEnd(); } - return transitiveOutEdges; } else { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index d711518e49f4c..fc043a8045d5c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -28,7 +28,6 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.MetricOptions; -import org.apache.flink.core.fs.FSDataInputStream; import 
org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.metrics.Counter; @@ -36,6 +35,8 @@ import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointOptions.CheckpointType; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.metrics.groups.OperatorMetricGroup; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; @@ -53,14 +54,12 @@ import org.apache.flink.runtime.state.StateInitializationContextImpl; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.util.OutputTag; @@ -161,7 +160,7 @@ public abstract class AbstractStreamOperator // ---------------- time handler ------------------ - private transient InternalTimeServiceManager timeServiceManager; + protected transient InternalTimeServiceManager timeServiceManager; // ---------------- two-input operator watermarks ------------------ @@ -179,7 +178,6 @@ public abstract class AbstractStreamOperator public void setup(StreamTask containingTask, StreamConfig config, Output> output) { this.container = containingTask; this.config = config; - this.metrics = container.getEnvironment().getMetricGroup().addOperator(config.getOperatorName()); this.output = new CountingOutput(output, ((OperatorMetricGroup) this.metrics).getIOMetricGroup().getNumRecordsOutCounter()); if (config.isChainStart()) { @@ -208,13 +206,13 @@ public MetricGroup getMetricGroup() { } @Override - public final void initializeState(OperatorStateHandles stateHandles) throws Exception { + public final void initializeState(OperatorSubtaskState stateHandles) throws Exception { Collection keyedStateHandlesRaw = null; Collection operatorStateHandlesRaw = null; Collection operatorStateHandlesBackend = null; - boolean restoring = null != stateHandles; + boolean restoring = (null != stateHandles); initKeyedState(); //TODO we should move the actual initialization of this from StreamTask to this class @@ -251,42 +249,6 @@ public final void initializeState(OperatorStateHandles stateHandles) throws Exce getContainingTask().getCancelables()); // access to register streams for canceling initializeState(initializationContext); - - if (restoring) { - - // finally restore the legacy state in case we are - // migrating from a previous Flink version. - - restoreStreamCheckpointed(stateHandles); - } - } - - /** - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. 
- */ - @Deprecated - private void restoreStreamCheckpointed(OperatorStateHandles stateHandles) throws Exception { - StreamStateHandle state = stateHandles.getLegacyOperatorState(); - if (null != state) { - if (this instanceof CheckpointedRestoringOperator) { - - LOG.debug("Restore state of task {} in chain ({}).", - stateHandles.getOperatorChainIndex(), getContainingTask().getName()); - - FSDataInputStream is = state.openInputStream(); - try { - getContainingTask().getCancelables().registerClosable(is); - ((CheckpointedRestoringOperator) this).restoreState(is); - } finally { - getContainingTask().getCancelables().unregisterClosable(is); - is.close(); - } - } else { - throw new Exception( - "Found legacy operator state for operator that does not implement StreamCheckpointedOperator."); - } - } } /** @@ -450,35 +412,6 @@ public void snapshotState(StateSnapshotContext context) throws Exception { } } - /** - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. - */ - @SuppressWarnings("deprecation") - @Deprecated - @Override - public StreamStateHandle snapshotLegacyOperatorState(long checkpointId, long timestamp, CheckpointOptions checkpointOptions) throws Exception { - if (this instanceof StreamCheckpointedOperator) { - CheckpointStreamFactory factory = getCheckpointStreamFactory(checkpointOptions); - - final CheckpointStreamFactory.CheckpointStateOutputStream outStream = - factory.createCheckpointStateOutputStream(checkpointId, timestamp); - - getContainingTask().getCancelables().registerClosable(outStream); - - try { - ((StreamCheckpointedOperator) this).snapshotState(outStream, checkpointId, timestamp); - return outStream.closeAndGetHandle(); - } - finally { - getContainingTask().getCancelables().unregisterClosable(outStream); - outStream.close(); - } - } else { - return null; - } - } - /** * Stream operators with state which can be restored need to override this hook method. * @@ -973,6 +906,11 @@ public void processWatermark2(Watermark mark) throws Exception { } } + @Override + public OperatorID getOperatorID() { + return config.getOperatorID(); + } + @VisibleForTesting public int numProcessingTimeTimers() { return timeServiceManager == null ? 
0 : diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index efbebf4056ee8..329ce183ce5d0 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -24,23 +24,15 @@ import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.util.functions.StreamingFunctionUtils; -import org.apache.flink.util.InstantiationUtil; -import org.apache.flink.util.Migration; - -import java.io.Serializable; import static java.util.Objects.requireNonNull; @@ -57,8 +49,7 @@ @PublicEvolving public abstract class AbstractUdfStreamOperator extends AbstractStreamOperator - implements OutputTypeConfigurable, - StreamCheckpointedOperator { + implements OutputTypeConfigurable { private static final long serialVersionUID = 1L; @@ -131,59 +122,6 @@ public void dispose() throws Exception { // checkpointing and recovery // ------------------------------------------------------------------------ - @Override - public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception { - if (userFunction instanceof Checkpointed) { - @SuppressWarnings("unchecked") - Checkpointed chkFunction = (Checkpointed) userFunction; - - Serializable udfState; - try { - udfState = chkFunction.snapshotState(checkpointId, timestamp); - if (udfState != null) { - out.write(1); - InstantiationUtil.serializeObject(out, udfState); - } else { - out.write(0); - } - } catch (Exception e) { - throw new Exception("Failed to draw state snapshot from function: " + e.getMessage(), e); - } - } - } - - @Override - public void restoreState(FSDataInputStream in) throws Exception { - boolean haveReadUdfStateFlag = false; - if (userFunction instanceof Checkpointed || - (userFunction instanceof CheckpointedRestoring)) { - @SuppressWarnings("unchecked") - CheckpointedRestoring chkFunction = (CheckpointedRestoring) userFunction; - - int hasUdfState = in.read(); - haveReadUdfStateFlag = true; - - if (hasUdfState == 1) { - Serializable functionState = InstantiationUtil.deserializeObject(in, getUserCodeClassloader()); - if (functionState != null) { - try { - chkFunction.restoreState(functionState); - } catch (Exception e) { - throw new Exception("Failed to restore state to function: " + e.getMessage(), e); - } - } - } - } - - if (in instanceof Migration && !haveReadUdfStateFlag) { - // absorb the introduced 
byte from the migration stream without too much further consequences - int hasUdfState = in.read(); - if (hasUdfState == 1) { - throw new Exception("Found UDF state but operator is not instance of CheckpointedRestoring"); - } - } - } - @Override public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { super.notifyOfCompletedCheckpoint(checkpointId); @@ -219,23 +157,11 @@ public Configuration getUserFunctionParameters() { private void checkUdfCheckpointingPreconditions() { - boolean newCheckpointInferface = false; - - if (userFunction instanceof CheckpointedFunction) { - newCheckpointInferface = true; - } - - if (userFunction instanceof ListCheckpointed) { - if (newCheckpointInferface) { - throw new IllegalStateException("User functions are not allowed to implement " + - "CheckpointedFunction AND ListCheckpointed."); - } - newCheckpointInferface = true; - } + if (userFunction instanceof CheckpointedFunction + && userFunction instanceof ListCheckpointed) { - if (newCheckpointInferface && userFunction instanceof Checkpointed) { - throw new IllegalStateException("User functions are not allowed to implement Checkpointed AND " + - "CheckpointedFunction/ListCheckpointed."); + throw new IllegalStateException("User functions are not allowed to implement " + + "CheckpointedFunction AND ListCheckpointed."); } } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/CheckpointedRestoringOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/CheckpointedRestoringOperator.java deleted file mode 100644 index 33304e413be3b..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/CheckpointedRestoringOperator.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.api.operators; - -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; -import org.apache.flink.streaming.api.graph.StreamConfig; -import org.apache.flink.streaming.runtime.tasks.StreamTask; - -/** - * Interface for {@link StreamOperator StreamOperators} that can restore from a Flink 1.1 - * legacy snapshot that was done using the {@link StreamCheckpointedOperator} interface. - * - * @deprecated {@link Checkpointed} has been deprecated as well. This class can be - * removed when we remove that interface. - */ -@Deprecated -public interface CheckpointedRestoringOperator { - - /** - * Restores the operator state, if this operator's execution is recovering from a checkpoint. 
- * This method restores the operator state (if the operator is stateful) and the key/value state
- * (if it had been used and was initialized when the snapshot occurred).
- *
- * <p>
This method is called after {@link StreamOperator#setup(StreamTask, StreamConfig, Output)} - * and before {@link StreamOperator#open()}. - * - * @param in The stream from which we have to restore our state. - * - * @throws Exception Exceptions during state restore should be forwarded, so that the system can - * properly react to failed state restore and fail the execution attempt. - */ - void restoreState(FSDataInputStream in) throws Exception; -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java index 17af3aa4a3f18..7d5cb9188ec3e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java @@ -44,7 +44,7 @@ * @param The type of namespace used for the timers. */ @Internal -class InternalTimeServiceManager { +public class InternalTimeServiceManager { private final int totalKeyGroups; private final KeyGroupsList localKeyGroupRange; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java deleted file mode 100644 index 986e2b76930fc..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.api.operators; - -import org.apache.flink.core.fs.FSDataOutputStream; - -/** - * @deprecated This interface is deprecated without replacement. - * All operators are now checkpointed. - */ -@Deprecated -public interface StreamCheckpointedOperator extends CheckpointedRestoringOperator { - - /** - * Called to draw a state snapshot from the operator. This method snapshots the operator state - * (if the operator is stateful). - * - * @param out The stream to which we have to write our state. - * @param checkpointId The ID of the checkpoint. - * @param timestamp The timestamp of the checkpoint. - * - * @throws Exception Forwards exceptions that occur while drawing snapshots from the operator - * and the key/value state. 
- */ - void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception; - -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index 61578b23a6d51..38b4aeedb1b2c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -20,10 +20,10 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.runtime.tasks.StreamTask; import java.io.Serializable; @@ -103,27 +103,12 @@ OperatorSnapshotResult snapshotState( long timestamp, CheckpointOptions checkpointOptions) throws Exception; - /** - * Takes a snapshot of the legacy operator state defined via {@link StreamCheckpointedOperator}. - * - * @return The handle to the legacy operator state, or null, if no state was snapshotted. - * @throws Exception This method should forward any type of exception that happens during snapshotting. - * - * @deprecated This method will be removed as soon as no more operators use the legacy state code paths - */ - @SuppressWarnings("deprecation") - @Deprecated - StreamStateHandle snapshotLegacyOperatorState( - long checkpointId, - long timestamp, - CheckpointOptions checkpointOptions) throws Exception; - /** * Provides state handles to restore the operator state. * * @param stateHandles state handles to the operator state. */ - void initializeState(OperatorStateHandles stateHandles) throws Exception; + void initializeState(OperatorSubtaskState stateHandles) throws Exception; /** * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager. 
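Illustration only, not part of this patch: a minimal sketch of how the StreamConfig#setOperatorID / #getOperatorID accessors added above round-trip an operator ID through the serialized task configuration. The class name below is hypothetical.

import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.streaming.api.graph.StreamConfig;

// Hypothetical sketch (editor illustration): the OperatorID assigned during job graph
// generation is written into the StreamConfig as raw bytes and read back by the
// operator via AbstractStreamOperator#getOperatorID().
public class OperatorIdRoundTripSketch {
    public static void main(String[] args) {
        StreamConfig config = new StreamConfig(new Configuration());

        OperatorID id = new OperatorID();      // randomly generated 16-byte ID
        config.setOperatorID(id);              // stored under the "operatorID" config key

        OperatorID restored = config.getOperatorID();
        System.out.println(id.equals(restored));   // expected: true
    }
}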
@@ -149,4 +134,5 @@ StreamStateHandle snapshotLegacyOperatorState( MetricGroup getMetricGroup(); + OperatorID getOperatorID(); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java index a0f626e254cfe..3dfa8aac41eda 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java @@ -27,7 +27,7 @@ import org.apache.flink.streaming.api.datastream.AsyncDataStream; import org.apache.flink.streaming.api.datastream.AsyncDataStream.OutputMode; import org.apache.flink.streaming.api.functions.async.AsyncFunction; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; +import org.apache.flink.streaming.api.functions.async.ResultFuture; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.ChainingStrategy; @@ -57,7 +57,7 @@ /** * The {@link AsyncWaitOperator} allows to asynchronously process incoming stream records. For that - * the operator creates an {@link AsyncCollector} which is passed to an {@link AsyncFunction}. + * the operator creates an {@link ResultFuture} which is passed to an {@link AsyncFunction}. * Within the async function, the user can complete the async collector arbitrarily. Once the async * collector has been completed, the result is emitted by the operator's emitter to downstream * operators. @@ -209,7 +209,7 @@ public void processElement(StreamRecord element) throws Exception { new ProcessingTimeCallback() { @Override public void onProcessingTime(long timestamp) throws Exception { - streamRecordBufferEntry.collect( + streamRecordBufferEntry.completeExceptionally( new TimeoutException("Async function call has timed out.")); } }); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/Emitter.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/Emitter.java index 2204109c1a739..0a1a2dbaaab1b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/Emitter.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/Emitter.java @@ -68,7 +68,7 @@ public Emitter( this.checkpointLock = Preconditions.checkNotNull(checkpointLock, "checkpointLock"); this.output = Preconditions.checkNotNull(output, "output"); - this.streamElementQueue = Preconditions.checkNotNull(streamElementQueue, "asyncCollectorBuffer"); + this.streamElementQueue = Preconditions.checkNotNull(streamElementQueue, "streamElementQueue"); this.operatorActions = Preconditions.checkNotNull(operatorActions, "operatorActions"); this.timestampedCollector = new TimestampedCollector<>(this.output); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/queue/StreamRecordQueueEntry.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/queue/StreamRecordQueueEntry.java index 2aca10e25ab82..796b44f91012a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/queue/StreamRecordQueueEntry.java +++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/queue/StreamRecordQueueEntry.java @@ -20,7 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.streaming.api.functions.async.AsyncFunction; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; +import org.apache.flink.streaming.api.functions.async.ResultFuture; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import java.util.Collection; @@ -28,14 +28,14 @@ /** * {@link StreamElementQueueEntry} implementation for {@link StreamRecord}. This class also acts - * as the {@link AsyncCollector} implementation which is given to the {@link AsyncFunction}. The + * as the {@link ResultFuture} implementation which is given to the {@link AsyncFunction}. The * async function completes this class with a collection of results. * * @param Type of the asynchronous collection result */ @Internal public class StreamRecordQueueEntry extends StreamElementQueueEntry> - implements AsyncCollectionResult, AsyncCollector { + implements AsyncCollectionResult, ResultFuture { /** Timestamp information. */ private final boolean hasTimestamp; @@ -74,12 +74,12 @@ protected CompletableFuture> getFuture() { } @Override - public void collect(Collection result) { + public void complete(Collection result) { resultFuture.complete(result); } @Override - public void collect(Throwable error) { + public void completeExceptionally(Throwable error) { resultFuture.completeExceptionally(error); } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java index b36ad22a7ee07..28496fc31be23 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/CoFeedbackTransformation.java @@ -22,7 +22,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.streaming.api.operators.ChainingStrategy; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.Collections; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java index e5d7c3ac297c1..03a4e52955743 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/FeedbackTransformation.java @@ -21,7 +21,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.streaming.api.operators.ChainingStrategy; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java index bc1be5b8ed3d4..c9362866eed92 100644 --- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/OneInputTransformation.java @@ -24,7 +24,7 @@ import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java index 942d019a30984..6f30e0f6afa76 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/PartitionTransformation.java @@ -22,7 +22,7 @@ import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SelectTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SelectTransformation.java index 6f47264cd5830..2f867cb72e187 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SelectTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SelectTransformation.java @@ -21,7 +21,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.streaming.api.operators.ChainingStrategy; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java index faa01f47cf188..faa033b106531 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SideOutputTransformation.java @@ -21,7 +21,7 @@ import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.util.OutputTag; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java index 5534cb994e8e7..30ef35eea6d8e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SinkTransformation.java @@ -25,7 +25,7 @@ import 
org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.StreamSink; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SplitTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SplitTransformation.java index 148478a4e380f..d20276111c81e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SplitTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SplitTransformation.java @@ -22,7 +22,7 @@ import org.apache.flink.streaming.api.collector.selector.OutputSelector; import org.apache.flink.streaming.api.operators.ChainingStrategy; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java index 7f561c51061ea..5ee055c97e3b2 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java @@ -24,7 +24,7 @@ import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java index bc522e791cc96..2bca7571e2833 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/UnionTransformation.java @@ -21,7 +21,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.streaming.api.operators.ChainingStrategy; -import com.google.common.collect.Lists; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import java.util.Collection; import java.util.List; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/assigners/TumblingAlignedProcessingTimeWindows.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/assigners/TumblingAlignedProcessingTimeWindows.java deleted file mode 100644 index 252f997d0b30b..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/assigners/TumblingAlignedProcessingTimeWindows.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.api.windowing.assigners; - -import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.streaming.api.windowing.time.Time; - -/** - * This is a special window assigner used to tell the system to use the - * "Fast Aligned Processing Time Window Operator" for windowing. - * - *
<p>
Prior Flink versions used that operator automatically for simple processing time - * windows (tumbling and sliding) when no custom trigger and no evictor was specified. - * In the current Flink version, that operator is only used when programs explicitly - * specify this window assigner. This is only intended for special cases where programs relied on - * the better performance of the fast aligned window operator, and are willing to accept the lack - * of support for various features as indicated below: - * - *
<ul>
- *     <li>No custom state backend can be selected, the operator always stores data on the Java heap.</li>
- *     <li>The operator does not support key groups, meaning it cannot change the parallelism.</li>
- *     <li>Future versions of Flink may not be able to resume from checkpoints/savepoints taken by this
- *     operator.</li>
- * </ul>
- *
- * <p>Future implementation plans: We plan to add some of the optimizations used by this operator to
- * the general window operator, so that future versions of Flink will not have the performance/functionality
- * trade-off any more.
- *
- * <p>
Note on implementation: The concrete operator instantiated by this assigner is either the - * {@link org.apache.flink.streaming.runtime.operators.windowing.AggregatingProcessingTimeWindowOperator} - * or {@link org.apache.flink.streaming.runtime.operators.windowing.AccumulatingProcessingTimeWindowOperator}. - */ -@PublicEvolving -public final class TumblingAlignedProcessingTimeWindows extends BaseAlignedWindowAssigner { - - private static final long serialVersionUID = -6217477609512299842L; - - public TumblingAlignedProcessingTimeWindows(long size) { - super(size); - } - - /** - * Creates a new {@code TumblingAlignedProcessingTimeWindows} {@link WindowAssigner} that assigns - * elements to time windows based on the element timestamp. - * - * @param size The size of the generated windows. - */ - public static TumblingAlignedProcessingTimeWindows of(Time size) { - return new TumblingAlignedProcessingTimeWindows(size.toMilliseconds()); - } -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/evictors/DeltaEvictor.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/evictors/DeltaEvictor.java index 57fec10711b7c..5eeaff8dde4ab 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/evictors/DeltaEvictor.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/windowing/evictors/DeltaEvictor.java @@ -23,7 +23,7 @@ import org.apache.flink.streaming.api.windowing.windows.Window; import org.apache.flink.streaming.runtime.operators.windowing.TimestampedValue; -import com.google.common.collect.Iterables; +import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables; import java.util.Iterator; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java deleted file mode 100644 index 83a752869e5d1..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java +++ /dev/null @@ -1,331 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
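Illustration only, not part of this patch: since the aligned window operators are removed here, a pipeline that explicitly used this assigner can typically be rewritten against the generic WindowOperator via TumblingProcessingTimeWindows. The job below is a hypothetical sketch, not taken from the Flink sources.

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.windowing.assigners.TumblingProcessingTimeWindows;
import org.apache.flink.streaming.api.windowing.time.Time;

// Hypothetical migration sketch (editor illustration): the generic assigner keeps
// the configured state backend and supports rescaling, unlike the removed
// aligned-window implementation described in the deleted javadoc above.
public class AlignedWindowMigrationSketch {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        DataStream<Tuple2<String, Integer>> counts = env
                .fromElements(Tuple2.of("a", 1), Tuple2.of("b", 2), Tuple2.of("a", 3))
                .keyBy(0)
                .window(TumblingProcessingTimeWindows.of(Time.seconds(5)))
                .sum(1);

        counts.print();
        env.execute("aligned window migration sketch");
    }
}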
- */ - -package org.apache.flink.streaming.runtime.operators.windowing; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.functions.Function; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FSDataOutputStream; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.operators.TimestampedCollector; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback; -import org.apache.flink.util.MathUtils; - -import org.apache.commons.math3.util.ArithmeticUtils; - -import static java.util.Objects.requireNonNull; - -/** - * Base class for special window operator implementation for windows that fire at the same time for - * all keys. - * - * @deprecated Deprecated in favour of the generic {@link WindowOperator}. This was an - * optimized implementation used for aligned windows. - */ -@Internal -@Deprecated -public abstract class AbstractAlignedProcessingTimeWindowOperator - extends AbstractUdfStreamOperator - implements OneInputStreamOperator, ProcessingTimeCallback { - - private static final long serialVersionUID = 3245500864882459867L; - - private static final long MIN_SLIDE_TIME = 50; - - // ----- fields for operator parametrization ----- - - private final Function function; - private final KeySelector keySelector; - - private final TypeSerializer keySerializer; - private final TypeSerializer stateTypeSerializer; - - private final long windowSize; - private final long windowSlide; - private final long paneSize; - private final int numPanesPerWindow; - - // ----- fields for operator functionality ----- - - private transient AbstractKeyedTimePanes panes; - - private transient TimestampedCollector out; - - private transient RestoredState restoredState; - - private transient long nextEvaluationTime; - private transient long nextSlideTime; - - protected AbstractAlignedProcessingTimeWindowOperator( - F function, - KeySelector keySelector, - TypeSerializer keySerializer, - TypeSerializer stateTypeSerializer, - long windowLength, - long windowSlide) { - super(function); - - if (windowLength < MIN_SLIDE_TIME) { - throw new IllegalArgumentException("Window length must be at least " + MIN_SLIDE_TIME + " msecs"); - } - if (windowSlide < MIN_SLIDE_TIME) { - throw new IllegalArgumentException("Window slide must be at least " + MIN_SLIDE_TIME + " msecs"); - } - if (windowLength < windowSlide) { - throw new IllegalArgumentException("The window size must be larger than the window slide"); - } - - final long paneSlide = ArithmeticUtils.gcd(windowLength, windowSlide); - if (paneSlide < MIN_SLIDE_TIME) { - throw new IllegalArgumentException(String.format( - "Cannot compute window of size %d msecs sliding by %d msecs. 
" + - "The unit of grouping is too small: %d msecs", windowLength, windowSlide, paneSlide)); - } - - this.function = requireNonNull(function); - this.keySelector = requireNonNull(keySelector); - this.keySerializer = requireNonNull(keySerializer); - this.stateTypeSerializer = requireNonNull(stateTypeSerializer); - this.windowSize = windowLength; - this.windowSlide = windowSlide; - this.paneSize = paneSlide; - this.numPanesPerWindow = MathUtils.checkedDownCast(windowLength / paneSlide); - } - - protected abstract AbstractKeyedTimePanes createPanes( - KeySelector keySelector, Function function); - - // ------------------------------------------------------------------------ - // startup and shutdown - // ------------------------------------------------------------------------ - - @Override - public void open() throws Exception { - super.open(); - - out = new TimestampedCollector<>(output); - - // decide when to first compute the window and when to slide it - // the values should align with the start of time (that is, the UNIX epoch, not the big bang) - final long now = getProcessingTimeService().getCurrentProcessingTime(); - nextEvaluationTime = now + windowSlide - (now % windowSlide); - nextSlideTime = now + paneSize - (now % paneSize); - - final long firstTriggerTime = Math.min(nextEvaluationTime, nextSlideTime); - - // check if we restored state and if we need to fire some windows based on that restored state - if (restoredState == null) { - // initial empty state: create empty panes that gather the elements per slide - panes = createPanes(keySelector, function); - } - else { - // restored state - panes = restoredState.panes; - - long nextPastEvaluationTime = restoredState.nextEvaluationTime; - long nextPastSlideTime = restoredState.nextSlideTime; - long nextPastTriggerTime = Math.min(nextPastEvaluationTime, nextPastSlideTime); - int numPanesRestored = panes.getNumPanes(); - - // fire windows from the past as long as there are more panes with data and as long - // as the missed trigger times have not caught up with the presence - while (numPanesRestored > 0 && nextPastTriggerTime < firstTriggerTime) { - // evaluate the window from the past - if (nextPastTriggerTime == nextPastEvaluationTime) { - computeWindow(nextPastTriggerTime); - nextPastEvaluationTime += windowSlide; - } - - // evaluate slide from the past - if (nextPastTriggerTime == nextPastSlideTime) { - panes.slidePanes(numPanesPerWindow); - numPanesRestored--; - nextPastSlideTime += paneSize; - } - - nextPastTriggerTime = Math.min(nextPastEvaluationTime, nextPastSlideTime); - } - } - - // make sure the first window happens - getProcessingTimeService().registerTimer(firstTriggerTime, this); - } - - @Override - public void close() throws Exception { - super.close(); - - // early stop the triggering thread, so it does not attempt to return any more data - stopTriggers(); - } - - @Override - public void dispose() throws Exception { - super.dispose(); - - // acquire the lock during shutdown, to prevent trigger calls at the same time - // fail-safe stop of the triggering thread (in case of an error) - stopTriggers(); - - // release the panes. panes may still be null if dispose is called - // after open() failed - if (panes != null) { - panes.dispose(); - } - } - - private void stopTriggers() { - // reset the action timestamps. 
this makes sure any pending triggers will not evaluate - nextEvaluationTime = -1L; - nextSlideTime = -1L; - } - - // ------------------------------------------------------------------------ - // Receiving elements and triggers - // ------------------------------------------------------------------------ - - @Override - public void processElement(StreamRecord element) throws Exception { - panes.addElementToLatestPane(element.getValue()); - } - - @Override - public void onProcessingTime(long timestamp) throws Exception { - // first we check if we actually trigger the window function - if (timestamp == nextEvaluationTime) { - // compute and output the results - computeWindow(timestamp); - - nextEvaluationTime += windowSlide; - } - - // check if we slide the panes by one. this may happen in addition to the - // window computation, or just by itself - if (timestamp == nextSlideTime) { - panes.slidePanes(numPanesPerWindow); - nextSlideTime += paneSize; - } - - long nextTriggerTime = Math.min(nextEvaluationTime, nextSlideTime); - getProcessingTimeService().registerTimer(nextTriggerTime, this); - } - - private void computeWindow(long timestamp) throws Exception { - out.setAbsoluteTimestamp(timestamp); - panes.truncatePanes(numPanesPerWindow); - panes.evaluateWindow(out, new TimeWindow(timestamp - windowSize, timestamp), this); - } - - // ------------------------------------------------------------------------ - // Checkpointing - // ------------------------------------------------------------------------ - - @Override - public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception { - super.snapshotState(out, checkpointId, timestamp); - - // we write the panes with the key/value maps into the stream, as well as when this state - // should have triggered and slided - - DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(out); - - outView.writeLong(nextEvaluationTime); - outView.writeLong(nextSlideTime); - - panes.writeToOutput(outView, keySerializer, stateTypeSerializer); - - outView.flush(); - } - - @Override - public void restoreState(FSDataInputStream in) throws Exception { - super.restoreState(in); - - DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(in); - - final long nextEvaluationTime = inView.readLong(); - final long nextSlideTime = inView.readLong(); - - AbstractKeyedTimePanes panes = createPanes(keySelector, function); - - panes.readFromInput(inView, keySerializer, stateTypeSerializer); - - restoredState = new RestoredState<>(panes, nextEvaluationTime, nextSlideTime); - } - - // ------------------------------------------------------------------------ - // Property access (for testing) - // ------------------------------------------------------------------------ - - public long getWindowSize() { - return windowSize; - } - - public long getWindowSlide() { - return windowSlide; - } - - public long getPaneSize() { - return paneSize; - } - - public int getNumPanesPerWindow() { - return numPanesPerWindow; - } - - public long getNextEvaluationTime() { - return nextEvaluationTime; - } - - public long getNextSlideTime() { - return nextSlideTime; - } - - // ------------------------------------------------------------------------ - // Utilities - // ------------------------------------------------------------------------ - - @Override - public String toString() { - return "Window (processing time) (length=" + windowSize + ", slide=" + windowSlide + ')'; - } - - // 
------------------------------------------------------------------------ - // ------------------------------------------------------------------------ - - private static final class RestoredState { - - final AbstractKeyedTimePanes panes; - final long nextEvaluationTime; - final long nextSlideTime; - - RestoredState(AbstractKeyedTimePanes panes, long nextEvaluationTime, long nextSlideTime) { - this.panes = panes; - this.nextEvaluationTime = nextEvaluationTime; - this.nextSlideTime = nextSlideTime; - } - } -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingProcessingTimeWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingProcessingTimeWindowOperator.java deleted file mode 100644 index d67121ab6d752..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingProcessingTimeWindowOperator.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.runtime.operators.windowing; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.functions.Function; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.runtime.state.ArrayListSerializer; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; -import org.apache.flink.streaming.api.windowing.windows.Window; -import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalWindowFunction; - -import java.util.ArrayList; - -/** - * Special window operator implementation for windows that fire at the same time for all keys with - * accumulating windows. - * - * @deprecated Deprecated in favour of the generic {@link WindowOperator}. This was an - * optimized implementation used for aligned windows. 
- */ -@Internal -@Deprecated -public class AccumulatingProcessingTimeWindowOperator - extends AbstractAlignedProcessingTimeWindowOperator, InternalWindowFunction, OUT, KEY, TimeWindow>> { - - private static final long serialVersionUID = 7305948082830843475L; - - public AccumulatingProcessingTimeWindowOperator( - InternalWindowFunction, OUT, KEY, TimeWindow> function, - KeySelector keySelector, - TypeSerializer keySerializer, - TypeSerializer valueSerializer, - long windowLength, - long windowSlide) { - super(function, keySelector, keySerializer, - new ArrayListSerializer<>(valueSerializer), windowLength, windowSlide); - } - - @Override - protected AccumulatingKeyedTimePanes createPanes(KeySelector keySelector, Function function) { - @SuppressWarnings("unchecked") - InternalWindowFunction, OUT, KEY, Window> windowFunction = (InternalWindowFunction, OUT, KEY, Window>) function; - - return new AccumulatingKeyedTimePanes<>(keySelector, windowFunction); - } -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingProcessingTimeWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingProcessingTimeWindowOperator.java deleted file mode 100644 index 674738322dfa3..0000000000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingProcessingTimeWindowOperator.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.runtime.operators.windowing; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.functions.Function; -import org.apache.flink.api.common.functions.ReduceFunction; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.functions.KeySelector; - -/** - * Special window operator implementation for windows that fire at the same time for all keys with - * aggregating windows. - * - * @deprecated Deprecated in favour of the generic {@link WindowOperator}. This was an - * optimized implementation used for aligned windows. 
- */ -@Internal -@Deprecated -public class AggregatingProcessingTimeWindowOperator - extends AbstractAlignedProcessingTimeWindowOperator> { - - private static final long serialVersionUID = 7305948082830843475L; - - public AggregatingProcessingTimeWindowOperator( - ReduceFunction function, - KeySelector keySelector, - TypeSerializer keySerializer, - TypeSerializer aggregateSerializer, - long windowLength, - long windowSlide) { - super(function, keySelector, keySerializer, aggregateSerializer, windowLength, windowSlide); - } - - @Override - protected AggregatingKeyedTimePanes createPanes(KeySelector keySelector, Function function) { - @SuppressWarnings("unchecked") - ReduceFunction windowFunction = (ReduceFunction) function; - - return new AggregatingKeyedTimePanes(keySelector, windowFunction); - } -} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java index d78de097441c5..29602af171b54 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java @@ -38,9 +38,9 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.OutputTag; -import com.google.common.base.Function; -import com.google.common.collect.FluentIterable; -import com.google.common.collect.Iterables; +import org.apache.flink.shaded.guava18.com.google.common.base.Function; +import org.apache.flink.shaded.guava18.com.google.common.collect.FluentIterable; +import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables; import java.util.Collection; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java index 880907dbf3251..b14739fed418d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java @@ -41,17 +41,12 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.metrics.MetricGroup; -import org.apache.flink.runtime.state.ArrayListSerializer; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.internal.InternalAppendingState; import org.apache.flink.runtime.state.internal.InternalListState; import org.apache.flink.runtime.state.internal.InternalMergingState; -import org.apache.flink.streaming.api.datastream.LegacyWindowOperatorType; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.InternalTimer; @@ -61,8 +56,6 @@ import org.apache.flink.streaming.api.operators.Triggerable; import 
org.apache.flink.streaming.api.windowing.assigners.BaseAlignedWindowAssigner; import org.apache.flink.streaming.api.windowing.assigners.MergingWindowAssigner; -import org.apache.flink.streaming.api.windowing.assigners.SlidingProcessingTimeWindows; -import org.apache.flink.streaming.api.windowing.assigners.TumblingProcessingTimeWindows; import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner; import org.apache.flink.streaming.api.windowing.triggers.Trigger; import org.apache.flink.streaming.api.windowing.triggers.TriggerResult; @@ -70,16 +63,9 @@ import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalWindowFunction; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.OutputTag; -import org.apache.flink.util.Preconditions; -import org.apache.commons.math3.util.ArithmeticUtils; - -import java.io.IOException; import java.io.Serializable; import java.util.Collection; -import java.util.Comparator; -import java.util.List; -import java.util.PriorityQueue; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -180,34 +166,6 @@ public class WindowOperator protected transient InternalTimerService internalTimerService; - // ------------------------------------------------------------------------ - // State restored in case of migration from an older version (backwards compatibility) - // ------------------------------------------------------------------------ - - /** - * A flag indicating if we are migrating from a regular {@link WindowOperator} - * or one of the deprecated {@link AccumulatingProcessingTimeWindowOperator} and - * {@link AggregatingProcessingTimeWindowOperator}. - */ - private final LegacyWindowOperatorType legacyWindowOperatorType; - - /** - * The elements restored when migrating from an older, deprecated - * {@link AccumulatingProcessingTimeWindowOperator} or - * {@link AggregatingProcessingTimeWindowOperator}. */ - private transient PriorityQueue> restoredFromLegacyAlignedOpRecords; - - /** - * The restored processing time timers when migrating from an - * older version of the {@link WindowOperator}. - */ - private transient PriorityQueue> restoredFromLegacyProcessingTimeTimers; - - /** The restored event time timer when migrating from an - * older version of the {@link WindowOperator}. - */ - private transient PriorityQueue> restoredFromLegacyEventTimeTimers; - /** * Creates a new {@code WindowOperator} based on the given policies and user functions. */ @@ -222,25 +180,6 @@ public WindowOperator( long allowedLateness, OutputTag lateDataOutputTag) { - this(windowAssigner, windowSerializer, keySelector, keySerializer, - windowStateDescriptor, windowFunction, trigger, allowedLateness, lateDataOutputTag, LegacyWindowOperatorType.NONE); - } - - /** - * Creates a new {@code WindowOperator} based on the given policies and user functions. 
- */ - public WindowOperator( - WindowAssigner windowAssigner, - TypeSerializer windowSerializer, - KeySelector keySelector, - TypeSerializer keySerializer, - StateDescriptor, ?> windowStateDescriptor, - InternalWindowFunction windowFunction, - Trigger trigger, - long allowedLateness, - OutputTag lateDataOutputTag, - LegacyWindowOperatorType legacyWindowOperatorType) { - super(windowFunction); checkArgument(!(windowAssigner instanceof BaseAlignedWindowAssigner), @@ -261,7 +200,6 @@ public WindowOperator( this.trigger = checkNotNull(trigger); this.allowedLateness = allowedLateness; this.lateDataOutputTag = lateDataOutputTag; - this.legacyWindowOperatorType = legacyWindowOperatorType; setChainingStrategy(ChainingStrategy.ALWAYS); } @@ -321,8 +259,6 @@ public long getCurrentProcessingTime() { getOrCreateKeyedState(VoidNamespaceSerializer.INSTANCE, mergingSetsStateDescriptor); mergingSetsState.setCurrentNamespace(VoidNamespace.INSTANCE); } - - registerRestoredLegacyStateState(); } @Override @@ -1036,256 +972,6 @@ public String toString() { } } - // ------------------------------------------------------------------------ - // Restoring / Migrating from an older Flink version. - // ------------------------------------------------------------------------ - - private static final int BEGIN_OF_STATE_MAGIC_NUMBER = 0x0FF1CE42; - - private static final int BEGIN_OF_PANE_MAGIC_NUMBER = 0xBADF00D5; - - @Override - public void restoreState(FSDataInputStream in) throws Exception { - super.restoreState(in); - - LOG.info("{} (taskIdx={}) restoring {} state from an older Flink version.", - getClass().getSimpleName(), legacyWindowOperatorType, getRuntimeContext().getIndexOfThisSubtask()); - - DataInputViewStreamWrapper streamWrapper = new DataInputViewStreamWrapper(in); - - switch (legacyWindowOperatorType) { - case NONE: - restoreFromLegacyWindowOperator(streamWrapper); - break; - case FAST_ACCUMULATING: - case FAST_AGGREGATING: - restoreFromLegacyAlignedWindowOperator(streamWrapper); - break; - } - } - - public void registerRestoredLegacyStateState() throws Exception { - - switch (legacyWindowOperatorType) { - case NONE: - reregisterStateFromLegacyWindowOperator(); - break; - case FAST_ACCUMULATING: - case FAST_AGGREGATING: - reregisterStateFromLegacyAlignedWindowOperator(); - break; - } - } - - private void restoreFromLegacyAlignedWindowOperator(DataInputViewStreamWrapper in) throws IOException { - Preconditions.checkArgument(legacyWindowOperatorType != LegacyWindowOperatorType.NONE); - - final long nextEvaluationTime = in.readLong(); - final long nextSlideTime = in.readLong(); - - validateMagicNumber(BEGIN_OF_STATE_MAGIC_NUMBER, in.readInt()); - - restoredFromLegacyAlignedOpRecords = new PriorityQueue<>(42, - new Comparator>() { - @Override - public int compare(StreamRecord o1, StreamRecord o2) { - return Long.compare(o1.getTimestamp(), o2.getTimestamp()); - } - } - ); - - switch (legacyWindowOperatorType) { - case FAST_ACCUMULATING: - restoreElementsFromLegacyAccumulatingAlignedWindowOperator(in, nextSlideTime); - break; - case FAST_AGGREGATING: - restoreElementsFromLegacyAggregatingAlignedWindowOperator(in, nextSlideTime); - break; - } - - if (LOG.isDebugEnabled()) { - LOG.debug("{} (taskIdx={}) restored {} events from legacy {}.", - getClass().getSimpleName(), - getRuntimeContext().getIndexOfThisSubtask(), - restoredFromLegacyAlignedOpRecords.size(), - legacyWindowOperatorType); - } - } - - private void restoreElementsFromLegacyAccumulatingAlignedWindowOperator(DataInputView in, long nextSlideTime) 
throws IOException { - int numPanes = in.readInt(); - final long paneSize = getPaneSize(); - long nextElementTimestamp = nextSlideTime - (numPanes * paneSize); - - @SuppressWarnings("unchecked") - ArrayListSerializer ser = new ArrayListSerializer<>((TypeSerializer) getStateDescriptor().getSerializer()); - - while (numPanes > 0) { - validateMagicNumber(BEGIN_OF_PANE_MAGIC_NUMBER, in.readInt()); - - nextElementTimestamp += paneSize - 1; // the -1 is so that the elements fall into the correct time-frame - - final int numElementsInPane = in.readInt(); - for (int i = numElementsInPane - 1; i >= 0; i--) { - K key = keySerializer.deserialize(in); - - @SuppressWarnings("unchecked") - List valueList = ser.deserialize(in); - for (IN record: valueList) { - restoredFromLegacyAlignedOpRecords.add(new StreamRecord<>(record, nextElementTimestamp)); - } - } - numPanes--; - } - } - - private void restoreElementsFromLegacyAggregatingAlignedWindowOperator(DataInputView in, long nextSlideTime) throws IOException { - int numPanes = in.readInt(); - final long paneSize = getPaneSize(); - long nextElementTimestamp = nextSlideTime - (numPanes * paneSize); - - while (numPanes > 0) { - validateMagicNumber(BEGIN_OF_PANE_MAGIC_NUMBER, in.readInt()); - - nextElementTimestamp += paneSize - 1; // the -1 is so that the elements fall into the correct time-frame - - final int numElementsInPane = in.readInt(); - for (int i = numElementsInPane - 1; i >= 0; i--) { - K key = keySerializer.deserialize(in); - - @SuppressWarnings("unchecked") - IN value = (IN) getStateDescriptor().getSerializer().deserialize(in); - restoredFromLegacyAlignedOpRecords.add(new StreamRecord<>(value, nextElementTimestamp)); - } - numPanes--; - } - } - - private long getPaneSize() { - Preconditions.checkArgument( - legacyWindowOperatorType == LegacyWindowOperatorType.FAST_ACCUMULATING || - legacyWindowOperatorType == LegacyWindowOperatorType.FAST_AGGREGATING); - - final long paneSlide; - if (windowAssigner instanceof SlidingProcessingTimeWindows) { - SlidingProcessingTimeWindows timeWindows = (SlidingProcessingTimeWindows) windowAssigner; - paneSlide = ArithmeticUtils.gcd(timeWindows.getSize(), timeWindows.getSlide()); - } else { - TumblingProcessingTimeWindows timeWindows = (TumblingProcessingTimeWindows) windowAssigner; - paneSlide = timeWindows.getSize(); // this is valid as windowLength == windowSlide == timeWindows.getSize - } - return paneSlide; - } - - private static void validateMagicNumber(int expected, int found) throws IOException { - if (expected != found) { - throw new IOException("Corrupt state stream - wrong magic number. 
" + - "Expected '" + Integer.toHexString(expected) + - "', found '" + Integer.toHexString(found) + '\''); - } - } - - private void restoreFromLegacyWindowOperator(DataInputViewStreamWrapper in) throws IOException { - Preconditions.checkArgument(legacyWindowOperatorType == LegacyWindowOperatorType.NONE); - - int numWatermarkTimers = in.readInt(); - this.restoredFromLegacyEventTimeTimers = new PriorityQueue<>(Math.max(numWatermarkTimers, 1)); - - for (int i = 0; i < numWatermarkTimers; i++) { - K key = keySerializer.deserialize(in); - W window = windowSerializer.deserialize(in); - long timestamp = in.readLong(); - - Timer timer = new Timer<>(timestamp, key, window); - restoredFromLegacyEventTimeTimers.add(timer); - } - - int numProcessingTimeTimers = in.readInt(); - this.restoredFromLegacyProcessingTimeTimers = new PriorityQueue<>(Math.max(numProcessingTimeTimers, 1)); - - for (int i = 0; i < numProcessingTimeTimers; i++) { - K key = keySerializer.deserialize(in); - W window = windowSerializer.deserialize(in); - long timestamp = in.readLong(); - - Timer timer = new Timer<>(timestamp, key, window); - restoredFromLegacyProcessingTimeTimers.add(timer); - } - - // just to read all the rest, although we do not really use this information. - int numProcessingTimeTimerTimestamp = in.readInt(); - for (int i = 0; i < numProcessingTimeTimerTimestamp; i++) { - in.readLong(); - in.readInt(); - } - - if (LOG.isDebugEnabled()) { - int subtaskIdx = getRuntimeContext().getIndexOfThisSubtask(); - - if (restoredFromLegacyEventTimeTimers != null && !restoredFromLegacyEventTimeTimers.isEmpty()) { - LOG.debug("{} (taskIdx={}) restored {} event time timers from an older Flink version: {}", - getClass().getSimpleName(), subtaskIdx, - restoredFromLegacyEventTimeTimers.size(), - restoredFromLegacyEventTimeTimers); - } - - if (restoredFromLegacyProcessingTimeTimers != null && !restoredFromLegacyProcessingTimeTimers.isEmpty()) { - LOG.debug("{} (taskIdx={}) restored {} processing time timers from an older Flink version: {}", - getClass().getSimpleName(), subtaskIdx, - restoredFromLegacyProcessingTimeTimers.size(), - restoredFromLegacyProcessingTimeTimers); - } - } - } - - public void reregisterStateFromLegacyWindowOperator() { - // if we restore from an older version, - // we have to re-register the recovered state. 
- - if (restoredFromLegacyEventTimeTimers != null && !restoredFromLegacyEventTimeTimers.isEmpty()) { - - LOG.info("{} (taskIdx={}) re-registering event-time timers from an older Flink version.", - getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask()); - - for (Timer timer : restoredFromLegacyEventTimeTimers) { - setCurrentKey(timer.key); - internalTimerService.registerEventTimeTimer(timer.window, timer.timestamp); - } - } - - if (restoredFromLegacyProcessingTimeTimers != null && !restoredFromLegacyProcessingTimeTimers.isEmpty()) { - - LOG.info("{} (taskIdx={}) re-registering processing-time timers from an older Flink version.", - getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask()); - - for (Timer timer : restoredFromLegacyProcessingTimeTimers) { - setCurrentKey(timer.key); - internalTimerService.registerProcessingTimeTimer(timer.window, timer.timestamp); - } - } - - // gc friendliness - restoredFromLegacyEventTimeTimers = null; - restoredFromLegacyProcessingTimeTimers = null; - } - - public void reregisterStateFromLegacyAlignedWindowOperator() throws Exception { - if (restoredFromLegacyAlignedOpRecords != null && !restoredFromLegacyAlignedOpRecords.isEmpty()) { - - LOG.info("{} (taskIdx={}) re-registering timers from legacy {} from an older Flink version.", - getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask(), legacyWindowOperatorType); - - while (!restoredFromLegacyAlignedOpRecords.isEmpty()) { - StreamRecord record = restoredFromLegacyAlignedOpRecords.poll(); - setCurrentKey(keySelector.getKey(record.getValue())); - processElement(record); - } - } - - // gc friendliness - restoredFromLegacyAlignedOpRecords = null; - } - // ------------------------------------------------------------------------ // Getters for testing // ------------------------------------------------------------------------ diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/streamrecord/StreamElementSerializer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/streamrecord/StreamElementSerializer.java index 1dc0ee2500084..d0ab60a3ab347 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/streamrecord/StreamElementSerializer.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/streamrecord/StreamElementSerializer.java @@ -29,7 +29,6 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.migration.streaming.runtime.streamrecord.MultiplexingStreamRecordSerializer; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamstatus.StreamStatus; @@ -292,9 +291,6 @@ public CompatibilityResult ensureCompatibility(TypeSerializerConf if (configSnapshot instanceof StreamElementSerializerConfigSnapshot) { previousTypeSerializerAndConfig = ((StreamElementSerializerConfigSnapshot) configSnapshot).getSingleNestedSerializerAndConfig(); - } else if (configSnapshot instanceof MultiplexingStreamRecordSerializer.MultiplexingStreamRecordSerializerConfigSnapshot) { - previousTypeSerializerAndConfig = - ((MultiplexingStreamRecordSerializer.MultiplexingStreamRecordSerializerConfigSnapshot) configSnapshot).getSingleNestedSerializerAndConfig(); } else { return CompatibilityResult.requiresMigration(); } diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java index 1a79f5429c206..0b03b7980f52e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java @@ -20,13 +20,9 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.CollectionUtil; -import org.apache.flink.util.Preconditions; import java.util.Collection; import java.util.List; @@ -40,8 +36,6 @@ public class OperatorStateHandles { private final int operatorChainIndex; - private final StreamStateHandle legacyOperatorState; - private final Collection managedKeyedState; private final Collection rawKeyedState; private final Collection managedOperatorState; @@ -49,40 +43,18 @@ public class OperatorStateHandles { public OperatorStateHandles( int operatorChainIndex, - StreamStateHandle legacyOperatorState, Collection managedKeyedState, Collection rawKeyedState, Collection managedOperatorState, Collection rawOperatorState) { this.operatorChainIndex = operatorChainIndex; - this.legacyOperatorState = legacyOperatorState; this.managedKeyedState = managedKeyedState; this.rawKeyedState = rawKeyedState; this.managedOperatorState = managedOperatorState; this.rawOperatorState = rawOperatorState; } - public OperatorStateHandles(TaskStateHandles taskStateHandles, int operatorChainIndex) { - Preconditions.checkNotNull(taskStateHandles); - - this.operatorChainIndex = operatorChainIndex; - - ChainedStateHandle legacyState = taskStateHandles.getLegacyOperatorState(); - this.legacyOperatorState = ChainedStateHandle.isNullOrEmpty(legacyState) ? 
- null : legacyState.get(operatorChainIndex); - - this.rawKeyedState = taskStateHandles.getRawKeyedState(); - this.managedKeyedState = taskStateHandles.getManagedKeyedState(); - - this.managedOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getManagedOperatorState(), operatorChainIndex); - this.rawOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getRawOperatorState(), operatorChainIndex); - } - - public StreamStateHandle getLegacyOperatorState() { - return legacyOperatorState; - } - public Collection getManagedKeyedState() { return managedKeyedState; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index c35a6dc5b3684..6089240c8db51 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -18,6 +18,7 @@ package org.apache.flink.streaming.runtime.tasks; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.CloseableRegistry; @@ -25,26 +26,24 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackend; -import org.apache.flink.runtime.state.StateUtil; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -54,7 +53,6 @@ import org.apache.flink.streaming.runtime.io.RecordWriterOutput; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer; -import org.apache.flink.util.CollectionUtil; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FutureUtil; import org.apache.flink.util.Preconditions; @@ 
-64,13 +62,11 @@ import java.io.Closeable; import java.io.IOException; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.RunnableFuture; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicReference; @@ -158,7 +154,7 @@ public abstract class StreamTask> /** The map of user-defined accumulators of this task. */ private Map> accumulatorMap; - private TaskStateHandles restoreStateHandles; + private TaskStateSnapshot taskStateSnapshot; /** The currently active background materialization threads. */ private final CloseableRegistry cancelables = new CloseableRegistry(); @@ -281,10 +277,12 @@ public final void invoke() throws Exception { // we also need to make sure that no triggers fire concurrently with the close logic // at the same time, this makes sure that during any "regular" exit where still synchronized (lock) { - isRunning = false; - // this is part of the main logic, so if this fails, the task is considered failed closeAllOperators(); + + // only set the StreamTask to not running after all operators have been closed! + // See FLINK-7430 + isRunning = false; } LOG.debug("Closed operators for task {}", getName()); @@ -508,8 +506,8 @@ RecordWriterOutput[] getStreamOutputs() { // ------------------------------------------------------------------------ @Override - public void setInitialState(TaskStateHandles taskStateHandles) { - this.restoreStateHandles = taskStateHandles; + public void setInitialState(TaskStateSnapshot taskStateHandles) { + this.taskStateSnapshot = taskStateHandles; } @Override @@ -658,12 +656,11 @@ private void checkpointState( private void initializeState() throws Exception { - boolean restored = null != restoreStateHandles; + boolean restored = null != taskStateSnapshot; if (restored) { - checkRestorePreconditions(operatorChain.getChainLength()); initializeOperators(true); - restoreStateHandles = null; // free for GC + taskStateSnapshot = null; // free for GC } else { initializeOperators(false); } @@ -674,8 +671,8 @@ private void initializeOperators(boolean restored) throws Exception { for (int chainIdx = 0; chainIdx < allOperators.length; ++chainIdx) { StreamOperator operator = allOperators[chainIdx]; if (null != operator) { - if (restored && restoreStateHandles != null) { - operator.initializeState(new OperatorStateHandles(restoreStateHandles, chainIdx)); + if (restored && taskStateSnapshot != null) { + operator.initializeState(taskStateSnapshot.getSubtaskStateByOperatorID(operator.getOperatorID())); } else { operator.initializeState(null); } @@ -683,26 +680,6 @@ private void initializeOperators(boolean restored) throws Exception { } } - private void checkRestorePreconditions(int operatorChainLength) { - - ChainedStateHandle nonPartitionableOperatorStates = - restoreStateHandles.getLegacyOperatorState(); - List> operatorStates = - restoreStateHandles.getManagedOperatorState(); - - if (nonPartitionableOperatorStates != null) { - Preconditions.checkState(nonPartitionableOperatorStates.getLength() == operatorChainLength, - "Invalid Invalid number of operator states. Found :" + nonPartitionableOperatorStates.getLength() - + ". Expected: " + operatorChainLength); - } - - if (!CollectionUtil.isNullOrEmpty(operatorStates)) { - Preconditions.checkArgument(operatorStates.size() == operatorChainLength, - "Invalid number of operator states. 
Found :" + operatorStates.size() + - ". Expected: " + operatorChainLength); - } - } - // ------------------------------------------------------------------------ // State backend // ------------------------------------------------------------------------ @@ -768,8 +745,13 @@ public AbstractKeyedStateBackend createKeyedStateBackend( cancelables.registerClosable(keyedStateBackend); // restore if we have some old state - Collection restoreKeyedStateHandles = - restoreStateHandles == null ? null : restoreStateHandles.getManagedKeyedState(); + Collection restoreKeyedStateHandles = null; + + if (taskStateSnapshot != null) { + OperatorSubtaskState stateByOperatorID = + taskStateSnapshot.getSubtaskStateByOperatorID(headOperator.getOperatorID()); + restoreKeyedStateHandles = stateByOperatorID != null ? stateByOperatorID.getManagedKeyedState() : null; + } keyedStateBackend.restore(restoreKeyedStateHandles); @@ -798,9 +780,11 @@ public CheckpointStreamFactory createSavepointStreamFactory(StreamOperator op } private String createOperatorIdentifier(StreamOperator operator, int vertexId) { + + TaskInfo taskInfo = getEnvironment().getTaskInfo(); return operator.getClass().getSimpleName() + - "_" + vertexId + - "_" + getEnvironment().getTaskInfo().getIndexOfThisSubtask(); + "_" + operator.getOperatorID() + + "_(" + taskInfo.getIndexOfThisSubtask() + "/" + taskInfo.getNumberOfParallelSubtasks() + ")"; } /** @@ -850,12 +834,7 @@ private static final class AsyncCheckpointRunnable implements Runnable, Closeabl private final StreamTask owner; - private final List snapshotInProgressList; - - private RunnableFuture futureKeyedBackendStateHandles; - private RunnableFuture futureKeyedStreamStateHandles; - - private List nonPartitionedStateHandles; + private final Map operatorSnapshotsInProgress; private final CheckpointMetaData checkpointMetaData; private final CheckpointMetrics checkpointMetrics; @@ -867,86 +846,66 @@ private static final class AsyncCheckpointRunnable implements Runnable, Closeabl AsyncCheckpointRunnable( StreamTask owner, - List nonPartitionedStateHandles, - List snapshotInProgressList, + Map operatorSnapshotsInProgress, CheckpointMetaData checkpointMetaData, CheckpointMetrics checkpointMetrics, long asyncStartNanos) { this.owner = Preconditions.checkNotNull(owner); - this.snapshotInProgressList = Preconditions.checkNotNull(snapshotInProgressList); + this.operatorSnapshotsInProgress = Preconditions.checkNotNull(operatorSnapshotsInProgress); this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData); this.checkpointMetrics = Preconditions.checkNotNull(checkpointMetrics); - this.nonPartitionedStateHandles = nonPartitionedStateHandles; this.asyncStartNanos = asyncStartNanos; - - if (!snapshotInProgressList.isEmpty()) { - // TODO Currently only the head operator of a chain can have keyed state, so simply access it directly. 
- int headIndex = snapshotInProgressList.size() - 1; - OperatorSnapshotResult snapshotInProgress = snapshotInProgressList.get(headIndex); - if (null != snapshotInProgress) { - this.futureKeyedBackendStateHandles = snapshotInProgress.getKeyedStateManagedFuture(); - this.futureKeyedStreamStateHandles = snapshotInProgress.getKeyedStateRawFuture(); - } - } } @Override public void run() { FileSystemSafetyNet.initializeSafetyNetForThread(); try { - // Keyed state handle future, currently only one (the head) operator can have this - KeyedStateHandle keyedStateHandleBackend = FutureUtil.runIfNotDoneAndGet(futureKeyedBackendStateHandles); - KeyedStateHandle keyedStateHandleStream = FutureUtil.runIfNotDoneAndGet(futureKeyedStreamStateHandles); - - List operatorStatesBackend = new ArrayList<>(snapshotInProgressList.size()); - List operatorStatesStream = new ArrayList<>(snapshotInProgressList.size()); - - for (OperatorSnapshotResult snapshotInProgress : snapshotInProgressList) { - if (null != snapshotInProgress) { - operatorStatesBackend.add( - FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture())); - operatorStatesStream.add( - FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture())); - } else { - operatorStatesBackend.add(null); - operatorStatesStream.add(null); - } - } + boolean hasState = false; + final TaskStateSnapshot taskOperatorSubtaskStates = + new TaskStateSnapshot(operatorSnapshotsInProgress.size()); - final long asyncEndNanos = System.nanoTime(); - final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000; + for (Map.Entry entry : operatorSnapshotsInProgress.entrySet()) { - checkpointMetrics.setAsyncDurationMillis(asyncDurationMillis); + OperatorID operatorID = entry.getKey(); + OperatorSnapshotResult snapshotInProgress = entry.getValue(); - ChainedStateHandle chainedNonPartitionedOperatorsState = - new ChainedStateHandle<>(nonPartitionedStateHandles); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( + FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture()), + FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture()), + FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateManagedFuture()), + FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateRawFuture()) + ); - ChainedStateHandle chainedOperatorStateBackend = - new ChainedStateHandle<>(operatorStatesBackend); + hasState |= operatorSubtaskState.hasState(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState); + } - ChainedStateHandle chainedOperatorStateStream = - new ChainedStateHandle<>(operatorStatesStream); + final long asyncEndNanos = System.nanoTime(); + final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000; - SubtaskState subtaskState = createSubtaskStateFromSnapshotStateHandles( - chainedNonPartitionedOperatorsState, - chainedOperatorStateBackend, - chainedOperatorStateStream, - keyedStateHandleBackend, - keyedStateHandleStream); + checkpointMetrics.setAsyncDurationMillis(asyncDurationMillis); if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING, CheckpointingOperation.AsynCheckpointState.COMPLETED)) { + TaskStateSnapshot acknowledgedState = hasState ? taskOperatorSubtaskStates : null; + + // we signal stateless tasks by reporting null, so that there are no attempts to assign empty state + // to stateless tasks on restore. 
This enables simple job modifications that only concern + // stateless without the need to assign them uids to match their (always empty) states. owner.getEnvironment().acknowledgeCheckpoint( checkpointMetaData.getCheckpointId(), checkpointMetrics, - subtaskState); + acknowledgedState); + + LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms", + owner.getName(), checkpointMetaData.getCheckpointId(), asyncDurationMillis); + + LOG.trace("{} - reported the following states in snapshot for checkpoint {}: {}.", + owner.getName(), checkpointMetaData.getCheckpointId(), acknowledgedState); - if (LOG.isDebugEnabled()) { - LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms", - owner.getName(), checkpointMetaData.getCheckpointId(), asyncDurationMillis); - } } else { LOG.debug("{} - asynchronous part of checkpoint {} could not be completed because it was closed before.", owner.getName(), @@ -988,38 +947,13 @@ public void close() { } } - private SubtaskState createSubtaskStateFromSnapshotStateHandles( - ChainedStateHandle chainedNonPartitionedOperatorsState, - ChainedStateHandle chainedOperatorStateBackend, - ChainedStateHandle chainedOperatorStateStream, - KeyedStateHandle keyedStateHandleBackend, - KeyedStateHandle keyedStateHandleStream) { - - boolean hasAnyState = keyedStateHandleBackend != null - || keyedStateHandleStream != null - || !chainedOperatorStateBackend.isEmpty() - || !chainedOperatorStateStream.isEmpty() - || !chainedNonPartitionedOperatorsState.isEmpty(); - - // we signal a stateless task by reporting null, so that there are no attempts to assign empty state to - // stateless tasks on restore. This allows for simple job modifications that only concern stateless without - // the need to assign them uids to match their (always empty) states. - return hasAnyState ? 
new SubtaskState( - chainedNonPartitionedOperatorsState, - chainedOperatorStateBackend, - chainedOperatorStateStream, - keyedStateHandleBackend, - keyedStateHandleStream) - : null; - } - private void cleanup() throws Exception { if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING, CheckpointingOperation.AsynCheckpointState.DISCARDED)) { LOG.debug("Cleanup AsyncCheckpointRunnable for checkpoint {} of {}.", checkpointMetaData.getCheckpointId(), owner.getName()); Exception exception = null; // clean up ongoing operator snapshot results and non partitioned state handles - for (OperatorSnapshotResult operatorSnapshotResult : snapshotInProgressList) { + for (OperatorSnapshotResult operatorSnapshotResult : operatorSnapshotsInProgress.values()) { if (operatorSnapshotResult != null) { try { operatorSnapshotResult.cancel(); @@ -1029,13 +963,6 @@ private void cleanup() throws Exception { } } - // discard non partitioned state handles - try { - StateUtil.bestEffortDiscardAllStateObjects(nonPartitionedStateHandles); - } catch (Exception discardException) { - exception = ExceptionUtils.firstOrSuppressed(discardException, exception); - } - if (null != exception) { throw exception; } @@ -1069,8 +996,7 @@ private static final class CheckpointingOperation { // ------------------------ - private final List nonPartitionedStates; - private final List snapshotInProgressList; + private final Map operatorSnapshotsInProgress; public CheckpointingOperation( StreamTask owner, @@ -1083,8 +1009,7 @@ public CheckpointingOperation( this.checkpointOptions = Preconditions.checkNotNull(checkpointOptions); this.checkpointMetrics = Preconditions.checkNotNull(checkpointMetrics); this.allOperators = owner.operatorChain.getAllOperators(); - this.nonPartitionedStates = new ArrayList<>(allOperators.length); - this.snapshotInProgressList = new ArrayList<>(allOperators.length); + this.operatorSnapshotsInProgress = new HashMap<>(allOperators.length); } public void executeCheckpointing() throws Exception { @@ -1119,7 +1044,7 @@ public void executeCheckpointing() throws Exception { } finally { if (failed) { // Cleanup to release resources - for (OperatorSnapshotResult operatorSnapshotResult : snapshotInProgressList) { + for (OperatorSnapshotResult operatorSnapshotResult : operatorSnapshotsInProgress.values()) { if (null != operatorSnapshotResult) { try { operatorSnapshotResult.cancel(); @@ -1129,18 +1054,6 @@ public void executeCheckpointing() throws Exception { } } - // Cleanup non partitioned state handles - for (StreamStateHandle nonPartitionedState : nonPartitionedStates) { - if (nonPartitionedState != null) { - try { - nonPartitionedState.discardState(); - } catch (Exception e) { - LOG.warn("Could not properly discard a non partitioned " + - "state. This might leave some orphaned files behind.", e); - } - } - } - if (LOG.isDebugEnabled()) { LOG.debug("{} - did NOT finish synchronous part of checkpoint {}." 
+ "Alignment duration: {} ms, snapshot duration {} ms", @@ -1155,22 +1068,12 @@ public void executeCheckpointing() throws Exception { @SuppressWarnings("deprecation") private void checkpointStreamOperator(StreamOperator op) throws Exception { if (null != op) { - // first call the legacy checkpoint code paths - nonPartitionedStates.add(op.snapshotLegacyOperatorState( - checkpointMetaData.getCheckpointId(), - checkpointMetaData.getTimestamp(), - checkpointOptions)); OperatorSnapshotResult snapshotInProgress = op.snapshotState( checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp(), checkpointOptions); - - snapshotInProgressList.add(snapshotInProgress); - } else { - nonPartitionedStates.add(null); - OperatorSnapshotResult emptySnapshotInProgress = new OperatorSnapshotResult(); - snapshotInProgressList.add(emptySnapshotInProgress); + operatorSnapshotsInProgress.put(op.getOperatorID(), snapshotInProgress); } } @@ -1178,8 +1081,7 @@ public void runAsyncCheckpointingAndAcknowledge() throws IOException { AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable( owner, - nonPartitionedStates, - snapshotInProgressList, + operatorSnapshotsInProgress, checkpointMetaData, checkpointMetrics, startAsyncPartNano); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java index acb531dc8a284..db9622b4e48b9 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java @@ -34,7 +34,8 @@ import org.apache.flink.streaming.util.MockContext; import org.apache.flink.streaming.util.keys.KeySelectorUtil; -import com.google.common.collect.ImmutableList; +import org.apache.flink.shaded.guava18.com.google.common.collect.ImmutableList; + import org.junit.Test; import java.io.Serializable; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java index 224b3761a99d2..aba37dfed0545 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java @@ -31,7 +31,6 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.metrics.MetricGroup; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; import org.junit.Test; @@ -55,7 +54,7 @@ public void testIterationRuntimeContext() throws Exception { private static final long serialVersionUID = -2023923961609455894L; @Override - public void asyncInvoke(Integer input, AsyncCollector collector) throws Exception { + public void asyncInvoke(Integer input, ResultFuture resultFuture) throws Exception { // no op } }; @@ -94,7 +93,7 @@ public void testRuntimeContext() throws Exception { private static final long serialVersionUID = 1707630162838967972L; @Override - public void asyncInvoke(Integer input, AsyncCollector collector) throws Exception { + public void asyncInvoke(Integer input, ResultFuture resultFuture) throws Exception { // no op } }; diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java index 7d3abc2c3ff91..4715c39fe384f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java @@ -20,11 +20,12 @@ import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.base.StringSerializer; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.streaming.api.operators.StreamSink; import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import java.io.BufferedWriter; @@ -32,7 +33,6 @@ import java.io.FileNotFoundException; import java.io.FileWriter; import java.io.IOException; -import java.io.Writer; import java.nio.charset.Charset; import java.nio.file.Files; import java.util.ArrayList; @@ -44,56 +44,52 @@ import static java.nio.file.StandardCopyOption.ATOMIC_MOVE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * Tests for {@link TwoPhaseCommitSinkFunction}. */ public class TwoPhaseCommitSinkFunctionTest { - @Test - public void testNotifyOfCompletedCheckpoint() throws Exception { - File tmpDirectory = Files.createTempDirectory(this.getClass().getSimpleName() + "_tmp").toFile(); - File targetDirectory = Files.createTempDirectory(this.getClass().getSimpleName() + "_target").toFile(); - OneInputStreamOperatorTestHarness testHarness = createTestHarness(tmpDirectory, targetDirectory); - - testHarness.setup(); - testHarness.open(); - testHarness.processElement("42", 0); - testHarness.snapshot(0, 1); - testHarness.processElement("43", 2); - testHarness.snapshot(1, 3); - testHarness.processElement("44", 4); - testHarness.snapshot(2, 5); - testHarness.notifyOfCompletedCheckpoint(1); - - assertExactlyOnceForDirectory(targetDirectory, Arrays.asList("42", "43")); - assertEquals(2, tmpDirectory.listFiles().length); // one for checkpointId 2 and second for the currentTransaction - testHarness.close(); + TestContext context; + + @Before + public void setUp() throws Exception { + context = new TestContext(); } - public OneInputStreamOperatorTestHarness createTestHarness(File tmpDirectory, File targetDirectory) throws Exception { - tmpDirectory.deleteOnExit(); - targetDirectory.deleteOnExit(); - FileBasedSinkFunction sinkFunction = new FileBasedSinkFunction(tmpDirectory, targetDirectory); - return new OneInputStreamOperatorTestHarness<>(new StreamSink<>(sinkFunction), StringSerializer.INSTANCE); + @After + public void tearDown() throws Exception { + context.close(); + } + + @Test + public void testNotifyOfCompletedCheckpoint() throws Exception { + context.harness.open(); + context.harness.processElement("42", 0); + context.harness.snapshot(0, 1); + context.harness.processElement("43", 2); + context.harness.snapshot(1, 3); + context.harness.processElement("44", 4); + context.harness.snapshot(2, 5); + context.harness.notifyOfCompletedCheckpoint(1); + + assertExactlyOnceForDirectory(context.targetDirectory, 
Arrays.asList("42", "43")); + assertEquals(2, context.tmpDirectory.listFiles().length); // one for checkpointId 2 and second for the currentTransaction } @Test public void testFailBeforeNotify() throws Exception { - File tmpDirectory = Files.createTempDirectory(this.getClass().getSimpleName() + "_tmp").toFile(); - File targetDirectory = Files.createTempDirectory(this.getClass().getSimpleName() + "_target").toFile(); - OneInputStreamOperatorTestHarness testHarness = createTestHarness(tmpDirectory, targetDirectory); - - testHarness.setup(); - testHarness.open(); - testHarness.processElement("42", 0); - testHarness.snapshot(0, 1); - testHarness.processElement("43", 2); - OperatorStateHandles snapshot = testHarness.snapshot(1, 3); - - assertTrue(tmpDirectory.setWritable(false)); + context.harness.open(); + context.harness.processElement("42", 0); + context.harness.snapshot(0, 1); + context.harness.processElement("43", 2); + OperatorStateHandles snapshot = context.harness.snapshot(1, 3); + + assertTrue(context.tmpDirectory.setWritable(false)); try { - testHarness.processElement("44", 4); - testHarness.snapshot(2, 5); + context.harness.processElement("44", 4); + context.harness.snapshot(2, 5); + fail("something should fail"); } catch (Exception ex) { if (!(ex.getCause() instanceof FileNotFoundException)) { @@ -101,17 +97,17 @@ public void testFailBeforeNotify() throws Exception { } // ignore } - testHarness.close(); + context.close(); + + assertTrue(context.tmpDirectory.setWritable(true)); - assertTrue(tmpDirectory.setWritable(true)); + context.open(); + context.harness.initializeState(snapshot); - testHarness = createTestHarness(tmpDirectory, targetDirectory); - testHarness.setup(); - testHarness.initializeState(snapshot); - testHarness.close(); + assertExactlyOnceForDirectory(context.targetDirectory, Arrays.asList("42", "43")); + context.close(); - assertExactlyOnceForDirectory(targetDirectory, Arrays.asList("42", "43")); - assertEquals(0, tmpDirectory.listFiles().length); + assertEquals(0, context.tmpDirectory.listFiles().length); } private void assertExactlyOnceForDirectory(File targetDirectory, List expectedValues) throws IOException { @@ -124,14 +120,12 @@ private void assertExactlyOnceForDirectory(File targetDirectory, List ex assertEquals(expectedValues, actualValues); } - private static class FileBasedSinkFunction extends TwoPhaseCommitSinkFunction { + private static class FileBasedSinkFunction extends TwoPhaseCommitSinkFunction { private final File tmpDirectory; private final File targetDirectory; public FileBasedSinkFunction(File tmpDirectory, File targetDirectory) { - super( - TypeInformation.of(FileTransaction.class), - TypeInformation.of(new TypeHint>>() {})); + super(TypeInformation.of(new TypeHint>() {})); if (!tmpDirectory.isDirectory() || !targetDirectory.isDirectory()) { throw new IllegalArgumentException(); @@ -161,7 +155,10 @@ protected void preCommit(FileTransaction transaction) throws Exception { @Override protected void commit(FileTransaction transaction) { try { - Files.move(transaction.tmpFile.toPath(), new File(targetDirectory, transaction.tmpFile.getName()).toPath(), ATOMIC_MOVE); + Files.move( + transaction.tmpFile.toPath(), + new File(targetDirectory, transaction.tmpFile.getName()).toPath(), + ATOMIC_MOVE); } catch (IOException e) { throw new IllegalStateException(e); } @@ -185,11 +182,41 @@ protected void recoverAndAbort(FileTransaction transaction) { private static class FileTransaction { private final File tmpFile; - private final transient Writer writer; + 
private final transient BufferedWriter writer; public FileTransaction(File tmpFile) throws IOException { this.tmpFile = tmpFile; this.writer = new BufferedWriter(new FileWriter(tmpFile)); } + + @Override + public String toString() { + return String.format("FileTransaction[%s]", tmpFile.getName()); + } + } + + private static class TestContext implements AutoCloseable { + public final File tmpDirectory = Files.createTempDirectory(TwoPhaseCommitSinkFunctionTest.class.getSimpleName() + "_tmp").toFile(); + public final File targetDirectory = Files.createTempDirectory(TwoPhaseCommitSinkFunctionTest.class.getSimpleName() + "_target").toFile(); + + public FileBasedSinkFunction sinkFunction; + public OneInputStreamOperatorTestHarness harness; + + private TestContext() throws Exception { + tmpDirectory.deleteOnExit(); + targetDirectory.deleteOnExit(); + open(); + } + + @Override + public void close() throws Exception { + harness.close(); + } + + public void open() throws Exception { + sinkFunction = new FileBasedSinkFunction(tmpDirectory, targetDirectory); + harness = new OneInputStreamOperatorTestHarness<>(new StreamSink<>(sinkFunction), StringSerializer.INSTANCE); + harness.setup(); + } } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java index e8b4c9e83c98e..4ed689d4e8626 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java @@ -25,9 +25,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.functions.source.RichSourceFunction; @@ -64,7 +64,6 @@ public class AbstractUdfStreamOperatorLifecycleTest { "UDF::open", "OPERATOR::run", "UDF::run", - "OPERATOR::snapshotLegacyOperatorState", "OPERATOR::snapshotState", "OPERATOR::close", "UDF::close", @@ -84,7 +83,7 @@ public class AbstractUdfStreamOperatorLifecycleTest { "UDF::close"); private static final String ALL_METHODS_STREAM_OPERATOR = "[close[], dispose[], getChainingStrategy[], " + - "getMetricGroup[], initializeState[class org.apache.flink.streaming.runtime.tasks.OperatorStateHandles], " + + "getMetricGroup[], getOperatorID[], initializeState[class org.apache.flink.runtime.checkpoint.OperatorSubtaskState], " + "notifyOfCompletedCheckpoint[long], open[], setChainingStrategy[class " + "org.apache.flink.streaming.api.operators.ChainingStrategy], setKeyContextElement1[class " + "org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " + @@ -92,7 +91,6 @@ public class AbstractUdfStreamOperatorLifecycleTest { "setup[class org.apache.flink.streaming.runtime.tasks.StreamTask, class " + "org.apache.flink.streaming.api.graph.StreamConfig, interface " + "org.apache.flink.streaming.api.operators.Output], " + - 
"snapshotLegacyOperatorState[long, long, class org.apache.flink.runtime.checkpoint.CheckpointOptions], " + "snapshotState[long, long, class org.apache.flink.runtime.checkpoint.CheckpointOptions]]"; private static final String ALL_METHODS_RICH_FUNCTION = "[close[], getIterationRuntimeContext[], getRuntimeContext[]" + @@ -132,6 +130,7 @@ public void testLifeCycleFull() throws Exception { MockSourceFunction srcFun = new MockSourceFunction(); cfg.setStreamOperator(new LifecycleTrackingStreamSource(srcFun, true)); + cfg.setOperatorID(new OperatorID()); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig); @@ -154,6 +153,7 @@ public void testLifeCycleCancel() throws Exception { StreamConfig cfg = new StreamConfig(new Configuration()); MockSourceFunction srcFun = new MockSourceFunction(); cfg.setStreamOperator(new LifecycleTrackingStreamSource(srcFun, false)); + cfg.setOperatorID(new OperatorID()); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig); @@ -204,7 +204,7 @@ public void close() throws Exception { } private static class LifecycleTrackingStreamSource> - extends StreamSource implements Serializable, StreamCheckpointedOperator { + extends StreamSource implements Serializable { private static final long serialVersionUID = 2431488948886850562L; private transient Thread testCheckpointer; @@ -262,12 +262,6 @@ public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); } - @Override - public StreamStateHandle snapshotLegacyOperatorState(long checkpointId, long timestamp, CheckpointOptions checkpointOptions) throws Exception { - ACTUAL_ORDER_TRACKING.add("OPERATOR::snapshotLegacyOperatorState"); - return super.snapshotLegacyOperatorState(checkpointId, timestamp, checkpointOptions); - } - @Override public void initializeState(StateInitializationContext context) throws Exception { ACTUAL_ORDER_TRACKING.add("OPERATOR::initializeState"); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/FoldApplyProcessWindowFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/FoldApplyProcessWindowFunctionTest.java deleted file mode 100644 index 7dba4af9bfef0..0000000000000 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/FoldApplyProcessWindowFunctionTest.java +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.streaming.api.operators; - -import org.apache.flink.api.common.JobExecutionResult; -import org.apache.flink.api.common.functions.FoldFunction; -import org.apache.flink.api.common.functions.util.ListCollector; -import org.apache.flink.api.common.state.FoldingState; -import org.apache.flink.api.common.state.FoldingStateDescriptor; -import org.apache.flink.api.common.state.KeyedStateStore; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.state.MapState; -import org.apache.flink.api.common.state.MapStateDescriptor; -import org.apache.flink.api.common.state.ReducingState; -import org.apache.flink.api.common.state.ReducingStateDescriptor; -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeutils.base.ByteSerializer; -import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.streaming.api.functions.source.SourceFunction; -import org.apache.flink.streaming.api.functions.windowing.FoldApplyProcessAllWindowFunction; -import org.apache.flink.streaming.api.functions.windowing.FoldApplyProcessWindowFunction; -import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction; -import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction; -import org.apache.flink.streaming.api.graph.StreamGraph; -import org.apache.flink.streaming.api.graph.StreamGraphGenerator; -import org.apache.flink.streaming.api.transformations.OneInputTransformation; -import org.apache.flink.streaming.api.transformations.SourceTransformation; -import org.apache.flink.streaming.api.transformations.StreamTransformation; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; -import org.apache.flink.streaming.runtime.operators.windowing.AccumulatingProcessingTimeWindowOperator; -import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessAllWindowFunction; -import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction; -import org.apache.flink.util.Collector; - -import org.junit.Assert; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.List; - -/** - * Tests for {@link FoldApplyProcessWindowFunction}. - */ -public class FoldApplyProcessWindowFunctionTest { - - /** - * Tests that the FoldWindowFunction gets the output type serializer set by the - * StreamGraphGenerator and checks that the FoldWindowFunction computes the correct result. 
- */ - @Test - public void testFoldWindowFunctionOutputTypeConfigurable() throws Exception{ - StreamExecutionEnvironment env = new DummyStreamExecutionEnvironment(); - - List> transformations = new ArrayList<>(); - - int initValue = 1; - - FoldApplyProcessWindowFunction foldWindowFunction = new FoldApplyProcessWindowFunction<>( - initValue, - new FoldFunction() { - @Override - public Integer fold(Integer accumulator, Integer value) throws Exception { - return accumulator + value; - } - - }, - new ProcessWindowFunction() { - @Override - public void process(Integer integer, - Context context, - Iterable input, - Collector out) throws Exception { - for (Integer in: input) { - out.collect(in); - } - } - }, - BasicTypeInfo.INT_TYPE_INFO - ); - - AccumulatingProcessingTimeWindowOperator windowOperator = new AccumulatingProcessingTimeWindowOperator<>( - new InternalIterableProcessWindowFunction<>(foldWindowFunction), - new KeySelector() { - private static final long serialVersionUID = -7951310554369722809L; - - @Override - public Integer getKey(Integer value) throws Exception { - return value; - } - }, - IntSerializer.INSTANCE, - IntSerializer.INSTANCE, - 3000, - 3000 - ); - - SourceFunction sourceFunction = new SourceFunction(){ - - private static final long serialVersionUID = 8297735565464653028L; - - @Override - public void run(SourceContext ctx) throws Exception { - - } - - @Override - public void cancel() { - - } - }; - - SourceTransformation source = new SourceTransformation<>("", new StreamSource<>(sourceFunction), BasicTypeInfo.INT_TYPE_INFO, 1); - - transformations.add(new OneInputTransformation<>(source, "test", windowOperator, BasicTypeInfo.INT_TYPE_INFO, 1)); - - StreamGraph streamGraph = StreamGraphGenerator.generate(env, transformations); - - List result = new ArrayList<>(); - List input = new ArrayList<>(); - List expected = new ArrayList<>(); - - input.add(1); - input.add(2); - input.add(3); - - for (int value : input) { - initValue += value; - } - - expected.add(initValue); - - FoldApplyProcessWindowFunction.Context ctx = foldWindowFunction.new Context() { - @Override - public TimeWindow window() { - return new TimeWindow(0, 1); - } - - @Override - public long currentProcessingTime() { - return 0; - } - - @Override - public long currentWatermark() { - return 0; - } - - @Override - public KeyedStateStore windowState() { - return new DummyKeyedStateStore(); - } - - @Override - public KeyedStateStore globalState() { - return new DummyKeyedStateStore(); - } - }; - - foldWindowFunction.open(new Configuration()); - - foldWindowFunction.process(0, ctx, input, new ListCollector<>(result)); - - Assert.assertEquals(expected, result); - } - - /** - * Tests that the FoldWindowFunction gets the output type serializer set by the - * StreamGraphGenerator and checks that the FoldWindowFunction computes the correct result. 
- */ - @Test - public void testFoldAllWindowFunctionOutputTypeConfigurable() throws Exception{ - StreamExecutionEnvironment env = new DummyStreamExecutionEnvironment(); - - List> transformations = new ArrayList<>(); - - int initValue = 1; - - FoldApplyProcessAllWindowFunction foldWindowFunction = new FoldApplyProcessAllWindowFunction<>( - initValue, - new FoldFunction() { - @Override - public Integer fold(Integer accumulator, Integer value) throws Exception { - return accumulator + value; - } - - }, - new ProcessAllWindowFunction() { - @Override - public void process(Context context, - Iterable input, - Collector out) throws Exception { - for (Integer in: input) { - out.collect(in); - } - } - }, - BasicTypeInfo.INT_TYPE_INFO - ); - - AccumulatingProcessingTimeWindowOperator windowOperator = new AccumulatingProcessingTimeWindowOperator<>( - new InternalIterableProcessAllWindowFunction<>(foldWindowFunction), - new KeySelector() { - private static final long serialVersionUID = -7951310554369722809L; - - @Override - public Byte getKey(Integer value) throws Exception { - return 0; - } - }, - ByteSerializer.INSTANCE, - IntSerializer.INSTANCE, - 3000, - 3000 - ); - - SourceFunction sourceFunction = new SourceFunction(){ - - private static final long serialVersionUID = 8297735565464653028L; - - @Override - public void run(SourceContext ctx) throws Exception { - - } - - @Override - public void cancel() { - - } - }; - - SourceTransformation source = new SourceTransformation<>("", new StreamSource<>(sourceFunction), BasicTypeInfo.INT_TYPE_INFO, 1); - - transformations.add(new OneInputTransformation<>(source, "test", windowOperator, BasicTypeInfo.INT_TYPE_INFO, 1)); - - StreamGraph streamGraph = StreamGraphGenerator.generate(env, transformations); - - List result = new ArrayList<>(); - List input = new ArrayList<>(); - List expected = new ArrayList<>(); - - input.add(1); - input.add(2); - input.add(3); - - for (int value : input) { - initValue += value; - } - - expected.add(initValue); - - FoldApplyProcessAllWindowFunction.Context ctx = foldWindowFunction.new Context() { - @Override - public TimeWindow window() { - return new TimeWindow(0, 1); - } - - @Override - public KeyedStateStore windowState() { - return new DummyKeyedStateStore(); - } - - @Override - public KeyedStateStore globalState() { - return new DummyKeyedStateStore(); - } - }; - - foldWindowFunction.open(new Configuration()); - - foldWindowFunction.process(ctx, input, new ListCollector<>(result)); - - Assert.assertEquals(expected, result); - } - - private static class DummyKeyedStateStore implements KeyedStateStore { - - @Override - public ValueState getState(ValueStateDescriptor stateProperties) { - return null; - } - - @Override - public ListState getListState(ListStateDescriptor stateProperties) { - return null; - } - - @Override - public ReducingState getReducingState(ReducingStateDescriptor stateProperties) { - return null; - } - - @Override - public FoldingState getFoldingState(FoldingStateDescriptor stateProperties) { - return null; - } - - @Override - public MapState getMapState(MapStateDescriptor stateProperties) { - return null; - } - } - - private static class DummyStreamExecutionEnvironment extends StreamExecutionEnvironment { - - @Override - public JobExecutionResult execute(String jobName) throws Exception { - return null; - } - } -} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/FoldApplyWindowFunctionTest.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/FoldApplyWindowFunctionTest.java deleted file mode 100644 index 7cf18ddd1a131..0000000000000 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/FoldApplyWindowFunctionTest.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.api.operators; - -import org.apache.flink.api.common.JobExecutionResult; -import org.apache.flink.api.common.functions.FoldFunction; -import org.apache.flink.api.common.functions.util.ListCollector; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.streaming.api.functions.source.SourceFunction; -import org.apache.flink.streaming.api.functions.windowing.FoldApplyWindowFunction; -import org.apache.flink.streaming.api.functions.windowing.WindowFunction; -import org.apache.flink.streaming.api.graph.StreamGraph; -import org.apache.flink.streaming.api.graph.StreamGraphGenerator; -import org.apache.flink.streaming.api.transformations.OneInputTransformation; -import org.apache.flink.streaming.api.transformations.SourceTransformation; -import org.apache.flink.streaming.api.transformations.StreamTransformation; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; -import org.apache.flink.streaming.runtime.operators.windowing.AccumulatingProcessingTimeWindowOperator; -import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableWindowFunction; -import org.apache.flink.util.Collector; - -import org.junit.Assert; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.List; - -/** - * Tests for {@link FoldApplyWindowFunction}. - */ -public class FoldApplyWindowFunctionTest { - - /** - * Tests that the FoldWindowFunction gets the output type serializer set by the - * StreamGraphGenerator and checks that the FoldWindowFunction computes the correct result. 
- */ - @Test - public void testFoldWindowFunctionOutputTypeConfigurable() throws Exception{ - StreamExecutionEnvironment env = new DummyStreamExecutionEnvironment(); - - List> transformations = new ArrayList<>(); - - int initValue = 1; - - FoldApplyWindowFunction foldWindowFunction = new FoldApplyWindowFunction<>( - initValue, - new FoldFunction() { - private static final long serialVersionUID = -4849549768529720587L; - - @Override - public Integer fold(Integer accumulator, Integer value) throws Exception { - return accumulator + value; - } - }, - new WindowFunction() { - @Override - public void apply(Integer integer, - TimeWindow window, - Iterable input, - Collector out) throws Exception { - for (Integer in: input) { - out.collect(in); - } - } - }, - BasicTypeInfo.INT_TYPE_INFO - ); - - AccumulatingProcessingTimeWindowOperator windowOperator = new AccumulatingProcessingTimeWindowOperator<>( - new InternalIterableWindowFunction<>( - foldWindowFunction), - new KeySelector() { - private static final long serialVersionUID = -7951310554369722809L; - - @Override - public Integer getKey(Integer value) throws Exception { - return value; - } - }, - IntSerializer.INSTANCE, - IntSerializer.INSTANCE, - 3000, - 3000 - ); - - SourceFunction sourceFunction = new SourceFunction(){ - - private static final long serialVersionUID = 8297735565464653028L; - - @Override - public void run(SourceContext ctx) throws Exception { - - } - - @Override - public void cancel() { - - } - }; - - SourceTransformation source = new SourceTransformation<>("", new StreamSource<>(sourceFunction), BasicTypeInfo.INT_TYPE_INFO, 1); - - transformations.add(new OneInputTransformation<>(source, "test", windowOperator, BasicTypeInfo.INT_TYPE_INFO, 1)); - - StreamGraph streamGraph = StreamGraphGenerator.generate(env, transformations); - - List result = new ArrayList<>(); - List input = new ArrayList<>(); - List expected = new ArrayList<>(); - - input.add(1); - input.add(2); - input.add(3); - - for (int value : input) { - initValue += value; - } - - expected.add(initValue); - - foldWindowFunction.apply(0, new TimeWindow(0, 1), input, new ListCollector(result)); - - Assert.assertEquals(expected, result); - } - - private static class DummyStreamExecutionEnvironment extends StreamExecutionEnvironment { - - @Override - public JobExecutionResult execute(String jobName) throws Exception { - return null; - } - } -} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java index f9a1cd00ed091..a3cb6304f60cb 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java @@ -29,23 +29,23 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.jobgraph.OperatorID; import 
org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.streaming.api.datastream.AsyncDataStream; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.functions.async.ResultFuture; import org.apache.flink.streaming.api.functions.async.RichAsyncFunction; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; import org.apache.flink.streaming.api.functions.sink.DiscardingSink; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.Output; @@ -166,11 +166,11 @@ private void freeExecutor() { } @Override - public void asyncInvoke(final Integer input, final AsyncCollector collector) throws Exception { + public void asyncInvoke(final Integer input, final ResultFuture resultFuture) throws Exception { executorService.submit(new Runnable() { @Override public void run() { - collector.collect(Collections.singletonList(input * 2)); + resultFuture.complete(Collections.singletonList(input * 2)); } }); } @@ -178,7 +178,7 @@ public void run() { /** * A special {@link AsyncFunction} without issuing - * {@link AsyncCollector#collect} until the latch counts to zero. + * {@link ResultFuture#complete} until the latch counts to zero. * This function is used in the testStateSnapshotAndRestore, ensuring * that {@link StreamElementQueueEntry} can stay * in the {@link StreamElementQueue} to be @@ -194,7 +194,7 @@ public LazyAsyncFunction() { } @Override - public void asyncInvoke(final Integer input, final AsyncCollector collector) throws Exception { + public void asyncInvoke(final Integer input, final ResultFuture resultFuture) throws Exception { this.executorService.submit(new Runnable() { @Override public void run() { @@ -205,7 +205,7 @@ public void run() { // do nothing } - collector.collect(Collections.singletonList(input)); + resultFuture.complete(Collections.singletonList(input)); } }); } @@ -500,7 +500,9 @@ public void testStateSnapshotAndRestore() throws Exception { AsyncDataStream.OutputMode.ORDERED); final StreamConfig streamConfig = testHarness.getStreamConfig(); + OperatorID operatorID = new OperatorID(42L, 4711L); streamConfig.setStreamOperator(operator); + streamConfig.setOperatorID(operatorID); final AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment( testHarness.jobConfig, @@ -540,7 +542,8 @@ public void testStateSnapshotAndRestore() throws Exception { // set the operator state from previous attempt into the restored one final OneInputStreamTask restoredTask = new OneInputStreamTask<>(); - restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles())); + TaskStateSnapshot subtaskStates = env.getCheckpointStateHandles(); + restoredTask.setInitialState(subtaskStates); final OneInputStreamTaskTestHarness restoredTaskHarness = new OneInputStreamTaskTestHarness<>(restoredTask, BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); @@ -553,6 +556,7 @@ public void testStateSnapshotAndRestore() throws Exception { 
AsyncDataStream.OutputMode.ORDERED); restoredTaskHarness.getStreamConfig().setStreamOperator(restoredOperator); + restoredTaskHarness.getStreamConfig().setOperatorID(operatorID); restoredTaskHarness.invoke(); restoredTaskHarness.waitForTaskRunning(); @@ -595,7 +599,7 @@ public void testStateSnapshotAndRestore() throws Exception { private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment { private volatile long checkpointId; - private volatile SubtaskState checkpointStateHandles; + private volatile TaskStateSnapshot checkpointStateHandles; private final OneShotLatch checkpointLatch = new OneShotLatch(); @@ -614,7 +618,7 @@ public long getCheckpointId() { public void acknowledgeCheckpoint( long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + TaskStateSnapshot checkpointStateHandles) { this.checkpointId = checkpointId; this.checkpointStateHandles = checkpointStateHandles; @@ -625,7 +629,7 @@ public OneShotLatch getCheckpointLatch() { return checkpointLatch; } - public SubtaskState getCheckpointStateHandles() { + public TaskStateSnapshot getCheckpointStateHandles() { return checkpointStateHandles; } } @@ -850,8 +854,8 @@ public void testTimeoutCleanup() throws Exception { private static final long serialVersionUID = -3718276118074877073L; @Override - public void asyncInvoke(Integer input, AsyncCollector collector) throws Exception { - collector.collect(Collections.singletonList(input)); + public void asyncInvoke(Integer input, ResultFuture resultFuture) throws Exception { + resultFuture.complete(Collections.singletonList(input)); } }, timeout, @@ -945,8 +949,8 @@ private static class UserExceptionAsyncFunction implements AsyncFunction collector) throws Exception { - collector.collect(new Exception("Test exception")); + public void asyncInvoke(Integer input, ResultFuture resultFuture) throws Exception { + resultFuture.completeExceptionally(new Exception("Test exception")); } } @@ -1008,7 +1012,7 @@ private static class NoOpAsyncFunction implements AsyncFunction collector) throws Exception { + public void asyncInvoke(IN input, ResultFuture resultFuture) throws Exception { // no op } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/EmitterTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/EmitterTest.java index da2d7959c83d4..7dedd14e6174d 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/EmitterTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/EmitterTest.java @@ -18,7 +18,7 @@ package org.apache.flink.streaming.api.operators.async; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; +import org.apache.flink.streaming.api.functions.async.ResultFuture; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.async.queue.OrderedStreamElementQueue; import org.apache.flink.streaming.api.operators.async.queue.StreamElementQueue; @@ -115,9 +115,9 @@ public void testEmitterWithOrderedQueue() throws Exception { queue.put(watermark1); queue.put(record3); - record2.collect(Arrays.asList(3, 4)); - record1.collect(Arrays.asList(1, 2)); - record3.collect(Arrays.asList(5, 6)); + record2.complete(Arrays.asList(3, 4)); + record1.complete(Arrays.asList(1, 2)); + record3.complete(Arrays.asList(5, 6)); synchronized (lock) { while (!queue.isEmpty()) { @@ -133,7 +133,7 @@ public void 
testEmitterWithOrderedQueue() throws Exception { } /** - * Tests that the emitter handles exceptions occurring in the {@link AsyncCollector} correctly. + * Tests that the emitter handles exceptions occurring in the {@link ResultFuture} correctly. */ @Test public void testEmitterWithExceptions() throws Exception { @@ -167,8 +167,8 @@ public void testEmitterWithExceptions() throws Exception { queue.put(record2); queue.put(watermark1); - record2.collect(testException); - record1.collect(Arrays.asList(1)); + record2.completeExceptionally(testException); + record1.complete(Arrays.asList(1)); synchronized (lock) { while (!queue.isEmpty()) { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/OrderedStreamElementQueueTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/OrderedStreamElementQueueTest.java index f3b68c4b36155..c7b811a28ba67 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/OrderedStreamElementQueueTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/OrderedStreamElementQueueTest.java @@ -110,15 +110,15 @@ public void testCompletionOrder() throws Exception { Assert.assertFalse(pollOperation.isDone()); - entry2.collect(Collections.emptyList()); + entry2.complete(Collections.emptyList()); - entry4.collect(Collections.emptyList()); + entry4.complete(Collections.emptyList()); Thread.sleep(10L); Assert.assertEquals(4, queue.size()); - entry1.collect(Collections.emptyList()); + entry1.complete(Collections.emptyList()); Assert.assertEquals(expected, pollOperation.get()); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/StreamElementQueueTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/StreamElementQueueTest.java index d3967567b7efd..7315f65d7dde4 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/StreamElementQueueTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/StreamElementQueueTest.java @@ -150,7 +150,7 @@ public void testPoll() throws InterruptedException { Assert.assertEquals(watermarkQueueEntry, queue.poll()); Assert.assertEquals(1, queue.size()); - streamRecordQueueEntry.collect(Collections.emptyList()); + streamRecordQueueEntry.complete(Collections.emptyList()); Assert.assertEquals(streamRecordQueueEntry, queue.poll()); @@ -191,7 +191,7 @@ public void testBlockingPut() throws Exception { // but it shouldn't ;-) Assert.assertFalse(putOperation.isDone()); - streamRecordQueueEntry.collect(Collections.emptyList()); + streamRecordQueueEntry.complete(Collections.emptyList()); // polling the completed head element frees the queue again Assert.assertEquals(streamRecordQueueEntry, queue.poll()); @@ -259,7 +259,7 @@ public void testBlockingPoll() throws Exception { Assert.assertFalse(pollOperation.isDone()); - streamRecordQueueEntry.collect(Collections.emptyList()); + streamRecordQueueEntry.complete(Collections.emptyList()); Assert.assertEquals(streamRecordQueueEntry, pollOperation.get()); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/UnorderedStreamElementQueueTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/UnorderedStreamElementQueueTest.java index cc0bc309a35df..acc6b8ea8aa62 100644 --- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/UnorderedStreamElementQueueTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/queue/UnorderedStreamElementQueueTest.java @@ -110,13 +110,13 @@ public void testCompletionOrder() throws Exception { executor); // this should not fulfill the poll, because R3 is behind W1 - record3.collect(Collections.emptyList()); + record3.complete(Collections.emptyList()); Thread.sleep(10L); Assert.assertFalse(firstPoll.isDone()); - record2.collect(Collections.emptyList()); + record2.complete(Collections.emptyList()); Assert.assertEquals(record2, firstPoll.get()); @@ -130,15 +130,15 @@ public void testCompletionOrder() throws Exception { }, executor); - record6.collect(Collections.emptyList()); - record4.collect(Collections.emptyList()); + record6.complete(Collections.emptyList()); + record4.complete(Collections.emptyList()); Thread.sleep(10L); // The future should not be completed because R1 has not been completed yet Assert.assertFalse(secondPoll.isDone()); - record1.collect(Collections.emptyList()); + record1.complete(Collections.emptyList()); Assert.assertEquals(record1, secondPoll.get()); @@ -180,7 +180,7 @@ public void testCompletionOrder() throws Exception { Assert.assertFalse(thirdPoll.isDone()); - record5.collect(Collections.emptyList()); + record5.complete(Collections.emptyList()); Assert.assertEquals(record5, thirdPoll.get()); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java index 58898d8c740b1..6dd08f6d53893 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java @@ -437,7 +437,8 @@ public void testUserProvidedHashing() { StreamGraph streamGraph = env.getStreamGraph(); int idx = 1; for (JobVertex jobVertex : streamGraph.getJobGraph().getVertices()) { - Assert.assertEquals(jobVertex.getIdAlternatives().get(1).toString(), userHashes.get(idx)); + List idAlternatives = jobVertex.getIdAlternatives(); + Assert.assertEquals(idAlternatives.get(idAlternatives.size() - 1).toString(), userHashes.get(idx)); --idx; } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java index c2cf7f3f91ed5..491b23d17b057 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java @@ -23,6 +23,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineOnCancellationBarrierException; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineSubsumedException; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -34,7 +35,6 @@ import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import 
org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; -import org.apache.flink.runtime.state.TaskStateHandles; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; @@ -1484,7 +1484,7 @@ long getLastReportedBytesBufferedInAlignment() { } @Override - public void setInitialState(TaskStateHandles taskStateHandles) throws Exception { + public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception { throw new UnsupportedOperationException("should never be called"); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java index 847db5cec006f..cde90104b91ac 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java @@ -22,13 +22,13 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; -import org.apache.flink.runtime.state.TaskStateHandles; import org.junit.Test; @@ -498,7 +498,7 @@ private CheckpointSequenceValidator(long... 
checkpointIDs) { } @Override - public void setInitialState(TaskStateHandles taskStateHandles) throws Exception { + public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception { throw new UnsupportedOperationException("should never be called"); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java index 6e3be0365fc33..65e59f8ac756c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.StreamMap; import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask; @@ -53,6 +54,7 @@ public void testOpenCloseAndTimestamps() throws Exception { StreamMap mapOperator = new StreamMap<>(new DummyMapFunction()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); testHarness.invoke(); testHarness.waitForTaskRunning(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java index 675ffa3570ba7..d621b0bb12adb 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.runtime.operators; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.StreamMap; import org.apache.flink.streaming.runtime.tasks.AsyncExceptionHandler; @@ -53,6 +54,7 @@ public void testCustomTimeServiceProvider() throws Throwable { StreamMap mapOperator = new StreamMap<>(new StreamTaskTimerTest.DummyMapFunction()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); testHarness.invoke(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java deleted file mode 100644 index a57dcf197a38c..0000000000000 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java +++ /dev/null @@ -1,1116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.runtime.operators.windowing; - -import org.apache.flink.api.common.ExecutionConfig; -import org.apache.flink.api.common.TaskInfo; -import org.apache.flink.api.common.accumulators.Accumulator; -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.api.common.typeutils.base.StringSerializer; -import org.apache.flink.api.java.ClosureCleaner; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.configuration.CoreOptions; -import org.apache.flink.runtime.execution.Environment; -import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; -import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; -import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction; -import org.apache.flink.streaming.api.functions.windowing.WindowFunction; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; -import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction; -import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableWindowFunction; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; -import org.apache.flink.streaming.runtime.tasks.StreamTask; -import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; -import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; -import org.apache.flink.util.Collector; - -import org.junit.After; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -/** - * Tests for {@link AccumulatingProcessingTimeWindowOperator}. 
- */ -@SuppressWarnings({"serial"}) -@PrepareForTest(InternalIterableWindowFunction.class) -@RunWith(PowerMockRunner.class) -public class AccumulatingAlignedProcessingTimeWindowOperatorTest { - - @SuppressWarnings("unchecked") - private final InternalIterableWindowFunction mockFunction = mock(InternalIterableWindowFunction.class); - - @SuppressWarnings("unchecked") - private final KeySelector mockKeySelector = mock(KeySelector.class); - - private final KeySelector identitySelector = new KeySelector() { - @Override - public Integer getKey(Integer value) { - return value; - } - }; - - private final InternalIterableWindowFunction validatingIdentityFunction = - new InternalIterableWindowFunction<>(new WindowFunction() { - @Override - public void apply(Integer key, TimeWindow window, Iterable values, Collector out) throws Exception { - for (Integer val : values) { - assertEquals(key, val); - out.collect(val); - } - } - }); - - private final InternalIterableProcessWindowFunction validatingIdentityProcessFunction = - new InternalIterableProcessWindowFunction<>(new ProcessWindowFunction() { - @Override - public void process(Integer key, Context context, Iterable values, Collector out) throws Exception { - for (Integer val : values) { - assertEquals(key, val); - out.collect(val); - } - } - }); - - // ------------------------------------------------------------------------ - - public AccumulatingAlignedProcessingTimeWindowOperatorTest() { - ClosureCleaner.clean(identitySelector, false); - ClosureCleaner.clean(validatingIdentityFunction, false); - ClosureCleaner.clean(validatingIdentityProcessFunction, false); - } - - // ------------------------------------------------------------------------ - - @After - public void checkNoTriggerThreadsRunning() { - // make sure that all the threads we trigger are shut down - long deadline = System.currentTimeMillis() + 5000; - while (StreamTask.TRIGGER_THREAD_GROUP.activeCount() > 0 && System.currentTimeMillis() < deadline) { - try { - Thread.sleep(10); - } - catch (InterruptedException ignored) {} - } - - assertTrue("Not all trigger threads where properly shut down", - StreamTask.TRIGGER_THREAD_GROUP.activeCount() == 0); - } - - // ------------------------------------------------------------------------ - - @Test - public void testInvalidParameters() { - try { - assertInvalidParameter(-1L, -1L); - assertInvalidParameter(10000L, -1L); - assertInvalidParameter(-1L, 1000L); - assertInvalidParameter(1000L, 2000L); - - // actual internal slide is too low here: - assertInvalidParameter(1000L, 999L); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testWindowSizeAndSlide() { - try { - AccumulatingProcessingTimeWindowOperator op; - - op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000); - assertEquals(5000, op.getWindowSize()); - assertEquals(1000, op.getWindowSlide()); - assertEquals(1000, op.getPaneSize()); - assertEquals(5, op.getNumPanesPerWindow()); - - op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1000, 1000); - assertEquals(1000, op.getWindowSize()); - assertEquals(1000, op.getWindowSlide()); - assertEquals(1000, op.getPaneSize()); - assertEquals(1, op.getNumPanesPerWindow()); - - op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 
1500, 1000); - assertEquals(1500, op.getWindowSize()); - assertEquals(1000, op.getWindowSlide()); - assertEquals(500, op.getPaneSize()); - assertEquals(3, op.getNumPanesPerWindow()); - - op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1200, 1100); - assertEquals(1200, op.getWindowSize()); - assertEquals(1100, op.getWindowSlide()); - assertEquals(100, op.getPaneSize()); - assertEquals(12, op.getNumPanesPerWindow()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testWindowTriggerTimeAlignment() throws Exception { - - try { - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, mockKeySelector, BasicTypeInfo.STRING_TYPE_INFO); - - testHarness.open(); - - assertTrue(op.getNextSlideTime() % 1000 == 0); - assertTrue(op.getNextEvaluationTime() % 1000 == 0); - testHarness.close(); - - op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1000, 1000); - - testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, mockKeySelector, BasicTypeInfo.STRING_TYPE_INFO); - - testHarness.open(); - - assertTrue(op.getNextSlideTime() % 1000 == 0); - assertTrue(op.getNextEvaluationTime() % 1000 == 0); - testHarness.close(); - - op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1500, 1000); - - testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, mockKeySelector, BasicTypeInfo.STRING_TYPE_INFO); - - testHarness.open(); - - assertTrue(op.getNextSlideTime() % 500 == 0); - assertTrue(op.getNextEvaluationTime() % 1000 == 0); - testHarness.close(); - - op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1200, 1100); - - testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, mockKeySelector, BasicTypeInfo.STRING_TYPE_INFO); - - testHarness.open(); - - assertEquals(0, op.getNextSlideTime() % 100); - assertEquals(0, op.getNextEvaluationTime() % 1100); - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testTumblingWindow() throws Exception { - try { - final int windowSize = 50; - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSize); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - final int numElements = 1000; - - long currentTime = 0; - - for (int i = 0; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - currentTime = currentTime + 10; - testHarness.setProcessingTime(currentTime); - } - - List result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - assertEquals(numElements, result.size()); - - Collections.sort(result); - for (int i = 0; i < numElements; i++) { 
- assertEquals(i, result.get(i).intValue()); - } - - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testTumblingWindowWithProcessFunction() throws Exception { - try { - final int windowSize = 50; - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityProcessFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSize); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - final int numElements = 1000; - - long currentTime = 0; - - for (int i = 0; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - currentTime = currentTime + 10; - testHarness.setProcessingTime(currentTime); - } - - List result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - assertEquals(numElements, result.size()); - - Collections.sort(result); - for (int i = 0; i < numElements; i++) { - assertEquals(i, result.get(i).intValue()); - } - - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testSlidingWindow() throws Exception { - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - final int numElements = 1000; - - long currentTime = 0; - - for (int i = 0; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - currentTime = currentTime + 10; - testHarness.setProcessingTime(currentTime); - } - - // get and verify the result - List result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - - // if we kept this running, each element would be in the result three times (for each slide). - // we are closing the window before the final panes are through three times, so we may have less - // elements. 
- if (result.size() < numElements || result.size() > 3 * numElements) { - fail("Wrong number of results: " + result.size()); - } - - Collections.sort(result); - int lastNum = -1; - int lastCount = -1; - - for (int num : result) { - if (num == lastNum) { - lastCount++; - assertTrue(lastCount <= 3); - } - else { - lastNum = num; - lastCount = 1; - } - } - - testHarness.close(); - } - - @Test - public void testSlidingWindowWithProcessFunction() throws Exception { - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityProcessFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - final int numElements = 1000; - - long currentTime = 0; - - for (int i = 0; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - currentTime = currentTime + 10; - testHarness.setProcessingTime(currentTime); - } - - // get and verify the result - List result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - - // if we kept this running, each element would be in the result three times (for each slide). - // we are closing the window before the final panes are through three times, so we may have less - // elements. - if (result.size() < numElements || result.size() > 3 * numElements) { - fail("Wrong number of results: " + result.size()); - } - - Collections.sort(result); - int lastNum = -1; - int lastCount = -1; - - for (int num : result) { - if (num == lastNum) { - lastCount++; - assertTrue(lastCount <= 3); - } - else { - lastNum = num; - lastCount = 1; - } - } - - testHarness.close(); - } - - @Test - public void testTumblingWindowSingleElements() throws Exception { - - try { - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, 50, 50); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - testHarness.setProcessingTime(0); - - testHarness.processElement(new StreamRecord<>(1)); - testHarness.processElement(new StreamRecord<>(2)); - - testHarness.setProcessingTime(50); - - testHarness.processElement(new StreamRecord<>(3)); - testHarness.processElement(new StreamRecord<>(4)); - testHarness.processElement(new StreamRecord<>(5)); - - testHarness.setProcessingTime(100); - - testHarness.processElement(new StreamRecord<>(6)); - - testHarness.setProcessingTime(200); - - List result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - assertEquals(6, result.size()); - - Collections.sort(result); - assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6), result); - - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testTumblingWindowSingleElementsWithProcessFunction() throws Exception { - - try { - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityProcessFunction, identitySelector, - IntSerializer.INSTANCE, 
IntSerializer.INSTANCE, 50, 50); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - testHarness.setProcessingTime(0); - - testHarness.processElement(new StreamRecord<>(1)); - testHarness.processElement(new StreamRecord<>(2)); - - testHarness.setProcessingTime(50); - - testHarness.processElement(new StreamRecord<>(3)); - testHarness.processElement(new StreamRecord<>(4)); - testHarness.processElement(new StreamRecord<>(5)); - - testHarness.setProcessingTime(100); - - testHarness.processElement(new StreamRecord<>(6)); - - testHarness.setProcessingTime(200); - - List result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - assertEquals(6, result.size()); - - Collections.sort(result); - assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6), result); - - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testSlidingWindowSingleElements() throws Exception { - try { - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.setProcessingTime(0); - - testHarness.open(); - - testHarness.processElement(new StreamRecord<>(1)); - testHarness.processElement(new StreamRecord<>(2)); - - testHarness.setProcessingTime(50); - testHarness.setProcessingTime(100); - testHarness.setProcessingTime(150); - - List result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - - assertEquals(6, result.size()); - - Collections.sort(result); - assertEquals(Arrays.asList(1, 1, 1, 2, 2, 2), result); - - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testSlidingWindowSingleElementsWithProcessFunction() throws Exception { - try { - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityProcessFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.setProcessingTime(0); - - testHarness.open(); - - testHarness.processElement(new StreamRecord<>(1)); - testHarness.processElement(new StreamRecord<>(2)); - - testHarness.setProcessingTime(50); - testHarness.setProcessingTime(100); - testHarness.setProcessingTime(150); - - List result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - - assertEquals(6, result.size()); - - Collections.sort(result); - assertEquals(Arrays.asList(1, 1, 1, 2, 2, 2), result); - - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void checkpointRestoreWithPendingWindowTumblingWithProcessFunction() { - try { - final int windowSize = 200; - - // tumbling window that triggers every 200 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - 
validatingIdentityProcessFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSize); - - OneInputStreamOperatorTestHarness testHarness = - new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setup(); - testHarness.open(); - - testHarness.setProcessingTime(0); - - // inject some elements - final int numElementsFirst = 700; - final int numElements = 1000; - for (int i = 0; i < numElementsFirst; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - // draw a snapshot and dispose the window - int beforeSnapShot = testHarness.getOutput().size(); - StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis()); - List resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput()); - int afterSnapShot = testHarness.getOutput().size(); - assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); - assertTrue(afterSnapShot <= numElementsFirst); - - // inject some random elements, which should not show up in the state - for (int i = 0; i < 300; i++) { - testHarness.processElement(new StreamRecord<>(i + numElementsFirst)); - } - - testHarness.close(); - op.dispose(); - - // re-create the operator and restore the state - op = new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityProcessFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSize); - - testHarness = new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setup(); - testHarness.restore(state); - testHarness.open(); - - // inject some more elements - for (int i = numElementsFirst; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - testHarness.setProcessingTime(400); - - // get and verify the result - List finalResult = new ArrayList<>(); - finalResult.addAll(resultAtSnapshot); - List finalPartialResult = extractFromStreamRecords(testHarness.getOutput()); - finalResult.addAll(finalPartialResult); - assertEquals(numElements, finalResult.size()); - - Collections.sort(finalResult); - for (int i = 0; i < numElements; i++) { - assertEquals(i, finalResult.get(i).intValue()); - } - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void checkpointRestoreWithPendingWindowTumbling() { - try { - final int windowSize = 200; - - // tumbling window that triggers every 200 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSize); - - OneInputStreamOperatorTestHarness testHarness = - new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setup(); - testHarness.open(); - - testHarness.setProcessingTime(0); - - // inject some elements - final int numElementsFirst = 700; - final int numElements = 1000; - for (int i = 0; i < numElementsFirst; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - // draw a snapshot and dispose the window - int beforeSnapShot = testHarness.getOutput().size(); - StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis()); - List resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput()); - int afterSnapShot = testHarness.getOutput().size(); - assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); - assertTrue(afterSnapShot <= numElementsFirst); - - // inject some 
random elements, which should not show up in the state - for (int i = 0; i < 300; i++) { - testHarness.processElement(new StreamRecord<>(i + numElementsFirst)); - } - - testHarness.close(); - op.dispose(); - - // re-create the operator and restore the state - op = new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSize); - - testHarness = new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setup(); - testHarness.restore(state); - testHarness.open(); - - // inject some more elements - for (int i = numElementsFirst; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - testHarness.setProcessingTime(400); - - // get and verify the result - List finalResult = new ArrayList<>(); - finalResult.addAll(resultAtSnapshot); - List finalPartialResult = extractFromStreamRecords(testHarness.getOutput()); - finalResult.addAll(finalPartialResult); - assertEquals(numElements, finalResult.size()); - - Collections.sort(finalResult); - for (int i = 0; i < numElements; i++) { - assertEquals(i, finalResult.get(i).intValue()); - } - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void checkpointRestoreWithPendingWindowSlidingWithProcessFunction() { - try { - final int factor = 4; - final int windowSlide = 50; - final int windowSize = factor * windowSlide; - - // sliding window (200 msecs) every 50 msecs - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityProcessFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSlide); - - OneInputStreamOperatorTestHarness testHarness = - new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setProcessingTime(0); - - testHarness.setup(); - testHarness.open(); - - // inject some elements - final int numElements = 1000; - final int numElementsFirst = 700; - - for (int i = 0; i < numElementsFirst; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - // draw a snapshot - List resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput()); - int beforeSnapShot = testHarness.getOutput().size(); - StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis()); - int afterSnapShot = testHarness.getOutput().size(); - assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); - - assertTrue(resultAtSnapshot.size() <= factor * numElementsFirst); - - // inject the remaining elements - these should not influence the snapshot - for (int i = numElementsFirst; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - testHarness.close(); - - // re-create the operator and restore the state - op = new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityProcessFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSlide); - - testHarness = new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setup(); - testHarness.restore(state); - testHarness.open(); - - // inject again the remaining elements - for (int i = numElementsFirst; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - testHarness.setProcessingTime(50); - testHarness.setProcessingTime(100); - testHarness.setProcessingTime(150); - testHarness.setProcessingTime(200); - 
testHarness.setProcessingTime(250); - testHarness.setProcessingTime(300); - testHarness.setProcessingTime(350); - - // get and verify the result - List finalResult = new ArrayList<>(resultAtSnapshot); - List finalPartialResult = extractFromStreamRecords(testHarness.getOutput()); - finalResult.addAll(finalPartialResult); - assertEquals(factor * numElements, finalResult.size()); - - Collections.sort(finalResult); - for (int i = 0; i < factor * numElements; i++) { - assertEquals(i / factor, finalResult.get(i).intValue()); - } - - testHarness.close(); - op.dispose(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void checkpointRestoreWithPendingWindowSliding() { - try { - final int factor = 4; - final int windowSlide = 50; - final int windowSize = factor * windowSlide; - - // sliding window (200 msecs) every 50 msecs - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSlide); - - OneInputStreamOperatorTestHarness testHarness = - new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setProcessingTime(0); - - testHarness.setup(); - testHarness.open(); - - // inject some elements - final int numElements = 1000; - final int numElementsFirst = 700; - - for (int i = 0; i < numElementsFirst; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - // draw a snapshot - List resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput()); - int beforeSnapShot = testHarness.getOutput().size(); - StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis()); - int afterSnapShot = testHarness.getOutput().size(); - assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); - - assertTrue(resultAtSnapshot.size() <= factor * numElementsFirst); - - // inject the remaining elements - these should not influence the snapshot - for (int i = numElementsFirst; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - testHarness.close(); - - // re-create the operator and restore the state - op = new AccumulatingProcessingTimeWindowOperator<>( - validatingIdentityFunction, identitySelector, - IntSerializer.INSTANCE, IntSerializer.INSTANCE, - windowSize, windowSlide); - - testHarness = new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setup(); - testHarness.restore(state); - testHarness.open(); - - // inject again the remaining elements - for (int i = numElementsFirst; i < numElements; i++) { - testHarness.processElement(new StreamRecord<>(i)); - } - - testHarness.setProcessingTime(50); - testHarness.setProcessingTime(100); - testHarness.setProcessingTime(150); - testHarness.setProcessingTime(200); - testHarness.setProcessingTime(250); - testHarness.setProcessingTime(300); - testHarness.setProcessingTime(350); - - // get and verify the result - List finalResult = new ArrayList<>(resultAtSnapshot); - List finalPartialResult = extractFromStreamRecords(testHarness.getOutput()); - finalResult.addAll(finalPartialResult); - assertEquals(factor * numElements, finalResult.size()); - - Collections.sort(finalResult); - for (int i = 0; i < factor * numElements; i++) { - assertEquals(i / factor, finalResult.get(i).intValue()); - } - - testHarness.close(); - op.dispose(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void 
testKeyValueStateInWindowFunction() { - try { - - StatefulFunction.globalCounts.clear(); - - // tumbling window that triggers every 20 milliseconds - AccumulatingProcessingTimeWindowOperator op = - new AccumulatingProcessingTimeWindowOperator<>( - new InternalIterableProcessWindowFunction<>(new StatefulFunction()), - identitySelector, - IntSerializer.INSTANCE, - IntSerializer.INSTANCE, - 50, - 50); - - OneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - testHarness.setProcessingTime(0); - - testHarness.processElement(new StreamRecord<>(1)); - testHarness.processElement(new StreamRecord<>(2)); - - op.processElement(new StreamRecord<>(1)); - op.processElement(new StreamRecord<>(2)); - op.processElement(new StreamRecord<>(1)); - op.processElement(new StreamRecord<>(1)); - op.processElement(new StreamRecord<>(2)); - op.processElement(new StreamRecord<>(2)); - - testHarness.setProcessingTime(1000); - - List result = extractFromStreamRecords(testHarness.getOutput()); - assertEquals(8, result.size()); - - Collections.sort(result); - assertEquals(Arrays.asList(1, 1, 1, 1, 2, 2, 2, 2), result); - - assertEquals(4, StatefulFunction.globalCounts.get(1).intValue()); - assertEquals(4, StatefulFunction.globalCounts.get(2).intValue()); - - testHarness.close(); - op.dispose(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - // ------------------------------------------------------------------------ - - private void assertInvalidParameter(long windowSize, long windowSlide) { - try { - new AccumulatingProcessingTimeWindowOperator( - mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, - windowSize, windowSlide); - fail("This should fail with an IllegalArgumentException"); - } - catch (IllegalArgumentException e) { - // expected - } - catch (Exception e) { - fail("Wrong exception. Expected IllegalArgumentException but found " + e.getClass().getSimpleName()); - } - } - - // ------------------------------------------------------------------------ - - private static class StatefulFunction extends ProcessWindowFunction { - - // we use a concurrent map here even though there is no concurrency, to - // get "volatile" style access to entries - private static final Map globalCounts = new ConcurrentHashMap<>(); - - private ValueState state; - - @Override - public void open(Configuration parameters) { - assertNotNull(getRuntimeContext()); - state = getRuntimeContext().getState( - new ValueStateDescriptor<>("totalCount", Integer.class, 0)); - } - - @Override - public void process(Integer key, - Context context, - Iterable values, - Collector out) throws Exception { - for (Integer i : values) { - // we need to update this state before emitting elements. 
Else, the test's main
-				// thread will have received all output elements before the state is updated and
-				// the checks may fail
-				state.update(state.value() + 1);
-				globalCounts.put(key, state.value());
-
-				out.collect(i);
-			}
-		}
-	}
-
-	// ------------------------------------------------------------------------
-
-	private static StreamTask createMockTask() {
-		Configuration configuration = new Configuration();
-		configuration.setString(CoreOptions.STATE_BACKEND, "jobmanager");
-
-		StreamTask task = mock(StreamTask.class);
-		when(task.getAccumulatorMap()).thenReturn(new HashMap>());
-		when(task.getName()).thenReturn("Test task name");
-		when(task.getExecutionConfig()).thenReturn(new ExecutionConfig());
-
-		final TaskManagerRuntimeInfo mockTaskManagerRuntimeInfo = mock(TaskManagerRuntimeInfo.class);
-		when(mockTaskManagerRuntimeInfo.getConfiguration()).thenReturn(configuration);
-
-		final Environment env = mock(Environment.class);
-		when(env.getTaskInfo()).thenReturn(new TaskInfo("Test task name", 1, 0, 1, 0));
-		when(env.getUserClassLoader()).thenReturn(AggregatingAlignedProcessingTimeWindowOperatorTest.class.getClassLoader());
-		when(env.getMetricGroup()).thenReturn(new UnregisteredTaskMetricsGroup());
-		when(env.getTaskManagerInfo()).thenReturn(new TestingTaskManagerRuntimeInfo());
-
-		when(task.getEnvironment()).thenReturn(env);
-		return task;
-	}
-
-	private static StreamTask createMockTaskWithTimer(
-			final ProcessingTimeService timerService) {
-		StreamTask mockTask = createMockTask();
-		when(mockTask.getProcessingTimeService()).thenReturn(timerService);
-		return mockTask;
-	}
-
-	@SuppressWarnings({"unchecked", "rawtypes"})
-	private List extractFromStreamRecords(Iterable input) {
-		List result = new ArrayList<>();
-		for (Object in : input) {
-			if (in instanceof StreamRecord) {
-				result.add((T) ((StreamRecord) in).getValue());
-			}
-		}
-		return result;
-	}
-
-	private static void shutdownTimerServiceAndWait(ProcessingTimeService timers) throws Exception {
-		timers.shutdownService();
-
-		while (!timers.isTerminated()) {
-			Thread.sleep(2);
-		}
-	}
-}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
deleted file mode 100644
index 62f4f0baf4afe..0000000000000
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
+++ /dev/null
@@ -1,863 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.runtime.operators.windowing;
-
-import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.functions.RichReduceFunction;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.api.common.typeutils.base.StringSerializer;
-import org.apache.flink.api.java.ClosureCleaner;
-import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.typeutils.TupleTypeInfo;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.StreamTask;
-import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
-import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
-
-import org.junit.After;
-import org.junit.Test;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-import static org.mockito.Mockito.mock;
-
-/**
- * Tests for aligned {@link AggregatingProcessingTimeWindowOperator}.
- */
-@SuppressWarnings("serial")
-public class AggregatingAlignedProcessingTimeWindowOperatorTest {
-
-	@SuppressWarnings("unchecked")
-	private final ReduceFunction mockFunction = mock(ReduceFunction.class);
-
-	@SuppressWarnings("unchecked")
-	private final KeySelector mockKeySelector = mock(KeySelector.class);
-
-	private final KeySelector, Integer> fieldOneSelector =
-			new KeySelector, Integer>() {
-		@Override
-		public Integer getKey(Tuple2 value) {
-			return value.f0;
-		}
-	};
-
-	private final ReduceFunction> sumFunction = new ReduceFunction>() {
-		@Override
-		public Tuple2 reduce(Tuple2 value1, Tuple2 value2) {
-			return new Tuple2<>(value1.f0, value1.f1 + value2.f1);
-		}
-	};
-
-	private final TypeSerializer> tupleSerializer =
-			new TupleTypeInfo>(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO)
-					.createSerializer(new ExecutionConfig());
-
-	private final Comparator> tupleComparator = new Comparator>() {
-		@Override
-		public int compare(Tuple2 o1, Tuple2 o2) {
-			int diff0 = o1.f0 - o2.f0;
-			int diff1 = o1.f1 - o2.f1;
-			return diff0 != 0 ?
diff0 : diff1; - } - }; - - // ------------------------------------------------------------------------ - - public AggregatingAlignedProcessingTimeWindowOperatorTest() { - ClosureCleaner.clean(fieldOneSelector, false); - ClosureCleaner.clean(sumFunction, false); - } - - // ------------------------------------------------------------------------ - - @After - public void checkNoTriggerThreadsRunning() { - // make sure that all the threads we trigger are shut down - long deadline = System.currentTimeMillis() + 5000; - while (StreamTask.TRIGGER_THREAD_GROUP.activeCount() > 0 && System.currentTimeMillis() < deadline) { - try { - Thread.sleep(10); - } - catch (InterruptedException ignored) {} - } - - assertTrue("Not all trigger threads where properly shut down", - StreamTask.TRIGGER_THREAD_GROUP.activeCount() == 0); - } - - // ------------------------------------------------------------------------ - - @Test - public void testInvalidParameters() { - try { - assertInvalidParameter(-1L, -1L); - assertInvalidParameter(10000L, -1L); - assertInvalidParameter(-1L, 1000L); - assertInvalidParameter(1000L, 2000L); - - // actual internal slide is too low here: - assertInvalidParameter(1000L, 999L); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testWindowSizeAndSlide() { - try { - AggregatingProcessingTimeWindowOperator op; - - op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000); - assertEquals(5000, op.getWindowSize()); - assertEquals(1000, op.getWindowSlide()); - assertEquals(1000, op.getPaneSize()); - assertEquals(5, op.getNumPanesPerWindow()); - - op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1000, 1000); - assertEquals(1000, op.getWindowSize()); - assertEquals(1000, op.getWindowSlide()); - assertEquals(1000, op.getPaneSize()); - assertEquals(1, op.getNumPanesPerWindow()); - - op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1500, 1000); - assertEquals(1500, op.getWindowSize()); - assertEquals(1000, op.getWindowSlide()); - assertEquals(500, op.getPaneSize()); - assertEquals(3, op.getNumPanesPerWindow()); - - op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1200, 1100); - assertEquals(1200, op.getWindowSize()); - assertEquals(1100, op.getWindowSlide()); - assertEquals(100, op.getPaneSize()); - assertEquals(12, op.getNumPanesPerWindow()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testWindowTriggerTimeAlignment() throws Exception { - try { - - AggregatingProcessingTimeWindowOperator op = - new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000); - - KeyedOneInputStreamOperatorTestHarness testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, mockKeySelector, BasicTypeInfo.STRING_TYPE_INFO); - testHarness.open(); - - assertTrue(op.getNextSlideTime() % 1000 == 0); - assertTrue(op.getNextEvaluationTime() % 1000 == 0); - testHarness.close(); - - op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1000, 1000); - - testHarness = - new 
KeyedOneInputStreamOperatorTestHarness<>(op, mockKeySelector, BasicTypeInfo.STRING_TYPE_INFO); - testHarness.open(); - - assertTrue(op.getNextSlideTime() % 1000 == 0); - assertTrue(op.getNextEvaluationTime() % 1000 == 0); - testHarness.close(); - - op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1500, 1000); - - testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, mockKeySelector, BasicTypeInfo.STRING_TYPE_INFO); - testHarness.open(); - - assertTrue(op.getNextSlideTime() % 500 == 0); - assertTrue(op.getNextEvaluationTime() % 1000 == 0); - testHarness.close(); - - op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1200, 1100); - - testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, mockKeySelector, BasicTypeInfo.STRING_TYPE_INFO); - testHarness.open(); - - assertTrue(op.getNextSlideTime() % 100 == 0); - assertTrue(op.getNextEvaluationTime() % 1100 == 0); - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testTumblingWindowUniqueElements() throws Exception { - - try { - final int windowSize = 50; - - AggregatingProcessingTimeWindowOperator> op = - new AggregatingProcessingTimeWindowOperator<>( - sumFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, - windowSize, windowSize); - - KeyedOneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, fieldOneSelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - final int numElements = 1000; - - long currentTime = 0; - - for (int i = 0; i < numElements; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); - testHarness.processElement(next); - currentTime = currentTime + 10; - testHarness.setProcessingTime(currentTime); - } - - // get and verify the result - List> result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - assertEquals(numElements, result.size()); - - testHarness.close(); - - Collections.sort(result, tupleComparator); - for (int i = 0; i < numElements; i++) { - assertEquals(i, result.get(i).f0.intValue()); - assertEquals(i, result.get(i).f1.intValue()); - } - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testTumblingWindowDuplicateElements() throws Exception { - try { - final int windowSize = 50; - - AggregatingProcessingTimeWindowOperator> op = - new AggregatingProcessingTimeWindowOperator<>( - sumFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, - windowSize, windowSize); - - KeyedOneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, fieldOneSelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.setProcessingTime(0); - testHarness.open(); - - final int numWindows = 10; - - long previousNextTime = 0; - int window = 1; - - long currentTime = 0; - - while (window <= numWindows) { - long nextTime = op.getNextEvaluationTime(); - int val = ((int) nextTime) ^ ((int) (nextTime >>> 32)); - - StreamRecord> next = new StreamRecord<>(new Tuple2<>(val, val)); - testHarness.processElement(next); - - if (nextTime != previousNextTime) { - window++; - previousNextTime = nextTime; - } - currentTime = currentTime + 1; - testHarness.setProcessingTime(currentTime); - } - - testHarness.setProcessingTime(currentTime + 100); - - 
List> result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - - testHarness.close(); - - // we have ideally one element per window. we may have more, when we emitted a value into the - // successive window (corner case), so we can have twice the number of elements, in the worst case. - assertTrue(result.size() >= numWindows && result.size() <= 2 * numWindows); - - // deduplicate for more accurate checks - HashSet> set = new HashSet<>(result); - assertTrue(set.size() == 10); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testSlidingWindow() throws Exception { - try { - // tumbling window that triggers every 20 milliseconds - AggregatingProcessingTimeWindowOperator> op = - new AggregatingProcessingTimeWindowOperator<>( - sumFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, - 150, 50); - - KeyedOneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, fieldOneSelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - final int numElements = 1000; - - long currentTime = 0; - - for (int i = 0; i < numElements; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); - testHarness.processElement(next); - currentTime = currentTime + 1; - testHarness.setProcessingTime(currentTime); - } - - // get and verify the result - List> result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - - testHarness.close(); - - // every element can occur between one and three times - if (result.size() < numElements || result.size() > 3 * numElements) { - System.out.println(result); - fail("Wrong number of results: " + result.size()); - } - - Collections.sort(result, tupleComparator); - int lastNum = -1; - int lastCount = -1; - - for (Tuple2 val : result) { - assertEquals(val.f0, val.f1); - - if (val.f0 == lastNum) { - lastCount++; - assertTrue(lastCount <= 3); - } - else { - lastNum = val.f0; - lastCount = 1; - } - } - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testSlidingWindowSingleElements() throws Exception { - try { - // tumbling window that triggers every 20 milliseconds - AggregatingProcessingTimeWindowOperator> op = - new AggregatingProcessingTimeWindowOperator<>( - sumFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, 150, 50); - - KeyedOneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, fieldOneSelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - testHarness.setProcessingTime(0); - - StreamRecord> next1 = new StreamRecord<>(new Tuple2<>(1, 1)); - testHarness.processElement(next1); - - StreamRecord> next2 = new StreamRecord<>(new Tuple2<>(2, 2)); - testHarness.processElement(next2); - - testHarness.setProcessingTime(50); - testHarness.setProcessingTime(100); - testHarness.setProcessingTime(150); - - List> result = extractFromStreamRecords(testHarness.extractOutputStreamRecords()); - assertEquals(6, result.size()); - - Collections.sort(result, tupleComparator); - assertEquals(Arrays.asList( - new Tuple2<>(1, 1), - new Tuple2<>(1, 1), - new Tuple2<>(1, 1), - new Tuple2<>(2, 2), - new Tuple2<>(2, 2), - new Tuple2<>(2, 2) - ), result); - - testHarness.close(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testPropagateExceptionsFromProcessElement() throws Exception { - - try { - ReduceFunction> 
failingFunction = new FailingFunction(100); - - // the operator has a window time that is so long that it will not fire in this test - final long hundredYears = 100L * 365 * 24 * 60 * 60 * 1000; - AggregatingProcessingTimeWindowOperator> op = - new AggregatingProcessingTimeWindowOperator<>( - failingFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, - hundredYears, hundredYears); - - KeyedOneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(op, fieldOneSelector, BasicTypeInfo.INT_TYPE_INFO); - - testHarness.open(); - - for (int i = 0; i < 100; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(1, 1)); - testHarness.processElement(next); - } - - try { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(1, 1)); - testHarness.processElement(next); - fail("This fail with an exception"); - } - catch (Exception e) { - assertTrue(e.getMessage().contains("Artificial Test Exception")); - } - - op.dispose(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void checkpointRestoreWithPendingWindowTumbling() { - try { - final int windowSize = 200; - - // tumbling window that triggers every 50 milliseconds - AggregatingProcessingTimeWindowOperator> op = - new AggregatingProcessingTimeWindowOperator<>( - sumFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, - windowSize, windowSize); - - OneInputStreamOperatorTestHarness, Tuple2> testHarness = - new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setProcessingTime(0); - - testHarness.setup(); - testHarness.open(); - - // inject some elements - final int numElementsFirst = 700; - final int numElements = 1000; - - for (int i = 0; i < numElementsFirst; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); - testHarness.processElement(next); - } - - // draw a snapshot - List> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput()); - int beforeSnapShot = resultAtSnapshot.size(); - StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis()); - int afterSnapShot = testHarness.getOutput().size(); - assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); - - assertTrue(resultAtSnapshot.size() <= numElementsFirst); - - // inject some random elements, which should not show up in the state - for (int i = numElementsFirst; i < numElements; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); - testHarness.processElement(next); - } - - testHarness.close(); - op.dispose(); - - // re-create the operator and restore the state - op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, - windowSize, windowSize); - - testHarness = new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setup(); - testHarness.restore(state); - testHarness.open(); - - // inject the remaining elements - for (int i = numElementsFirst; i < numElements; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); - testHarness.processElement(next); - } - - testHarness.setProcessingTime(200); - - // get and verify the result - List> finalResult = new ArrayList<>(resultAtSnapshot); - List> partialFinalResult = extractFromStreamRecords(testHarness.getOutput()); - finalResult.addAll(partialFinalResult); - assertEquals(numElements, finalResult.size()); - - Collections.sort(finalResult, tupleComparator); - for (int i = 0; i < numElements; i++) { 
- assertEquals(i, finalResult.get(i).f0.intValue()); - assertEquals(i, finalResult.get(i).f1.intValue()); - } - - testHarness.close(); - op.dispose(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void checkpointRestoreWithPendingWindowSliding() { - try { - final int factor = 4; - final int windowSlide = 50; - final int windowSize = factor * windowSlide; - - // sliding window (200 msecs) every 50 msecs - AggregatingProcessingTimeWindowOperator> op = - new AggregatingProcessingTimeWindowOperator<>( - sumFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, - windowSize, windowSlide); - - OneInputStreamOperatorTestHarness, Tuple2> testHarness = - new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setProcessingTime(0); - - testHarness.setup(); - testHarness.open(); - - // inject some elements - final int numElements = 1000; - final int numElementsFirst = 700; - - for (int i = 0; i < numElementsFirst; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); - testHarness.processElement(next); - } - - // draw a snapshot - List> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput()); - int beforeSnapShot = resultAtSnapshot.size(); - StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis()); - int afterSnapShot = testHarness.getOutput().size(); - assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); - - assertTrue(resultAtSnapshot.size() <= factor * numElementsFirst); - - // inject the remaining elements - these should not influence the snapshot - for (int i = numElementsFirst; i < numElements; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); - testHarness.processElement(next); - } - - testHarness.close(); - op.dispose(); - - // re-create the operator and restore the state - op = new AggregatingProcessingTimeWindowOperator<>( - sumFunction, fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, - windowSize, windowSlide); - - testHarness = new OneInputStreamOperatorTestHarness<>(op); - - testHarness.setup(); - testHarness.restore(state); - testHarness.open(); - - // inject again the remaining elements - for (int i = numElementsFirst; i < numElements; i++) { - StreamRecord> next = new StreamRecord<>(new Tuple2<>(i, i)); - testHarness.processElement(next); - } - - testHarness.setProcessingTime(50); - testHarness.setProcessingTime(100); - testHarness.setProcessingTime(150); - testHarness.setProcessingTime(200); - testHarness.setProcessingTime(250); - testHarness.setProcessingTime(300); - testHarness.setProcessingTime(350); - testHarness.setProcessingTime(400); - - // get and verify the result - List> finalResult = new ArrayList<>(resultAtSnapshot); - List> partialFinalResult = extractFromStreamRecords(testHarness.getOutput()); - finalResult.addAll(partialFinalResult); - assertEquals(numElements * factor, finalResult.size()); - - Collections.sort(finalResult, tupleComparator); - for (int i = 0; i < factor * numElements; i++) { - assertEquals(i / factor, finalResult.get(i).f0.intValue()); - assertEquals(i / factor, finalResult.get(i).f1.intValue()); - } - - testHarness.close(); - op.dispose(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testKeyValueStateInWindowFunctionTumbling() { - try { - final long twoSeconds = 2000; - - StatefulFunction.globalCounts.clear(); - - AggregatingProcessingTimeWindowOperator> op = - new 
AggregatingProcessingTimeWindowOperator<>( - new StatefulFunction(), fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, twoSeconds, twoSeconds); - - KeyedOneInputStreamOperatorTestHarness, Tuple2> testHarness = new KeyedOneInputStreamOperatorTestHarness<>( - op, - fieldOneSelector, - BasicTypeInfo.INT_TYPE_INFO); - - testHarness.setProcessingTime(0); - testHarness.open(); - - // because the window interval is so large, everything should be in one window - // and aggregate into one value per key - - for (int i = 0; i < 10; i++) { - StreamRecord> next1 = new StreamRecord<>(new Tuple2<>(1, i)); - testHarness.processElement(next1); - - StreamRecord> next2 = new StreamRecord<>(new Tuple2<>(2, i)); - testHarness.processElement(next2); - } - - testHarness.setProcessingTime(1000); - - int count1 = StatefulFunction.globalCounts.get(1); - int count2 = StatefulFunction.globalCounts.get(2); - - assertTrue(count1 >= 2 && count1 <= 2 * 10); - assertEquals(count1, count2); - - testHarness.close(); - op.dispose(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testKeyValueStateInWindowFunctionSliding() { - try { - final int factor = 2; - final int windowSlide = 50; - final int windowSize = factor * windowSlide; - - StatefulFunction.globalCounts.clear(); - - AggregatingProcessingTimeWindowOperator> op = - new AggregatingProcessingTimeWindowOperator<>( - new StatefulFunction(), fieldOneSelector, - IntSerializer.INSTANCE, tupleSerializer, windowSize, windowSlide); - - KeyedOneInputStreamOperatorTestHarness, Tuple2> testHarness = new KeyedOneInputStreamOperatorTestHarness<>( - op, - fieldOneSelector, - BasicTypeInfo.INT_TYPE_INFO); - - testHarness.setProcessingTime(0); - - testHarness.open(); - - // because the window interval is so large, everything should be in one window - // and aggregate into one value per key - final int numElements = 100; - - // because we do not release the lock here, these elements - for (int i = 0; i < numElements; i++) { - - StreamRecord> next1 = new StreamRecord<>(new Tuple2<>(1, i)); - StreamRecord> next2 = new StreamRecord<>(new Tuple2<>(2, i)); - StreamRecord> next3 = new StreamRecord<>(new Tuple2<>(1, i)); - StreamRecord> next4 = new StreamRecord<>(new Tuple2<>(2, i)); - - testHarness.processElement(next1); - testHarness.processElement(next2); - testHarness.processElement(next3); - testHarness.processElement(next4); - } - - testHarness.setProcessingTime(50); - testHarness.setProcessingTime(100); - testHarness.setProcessingTime(150); - testHarness.setProcessingTime(200); - - int count1 = StatefulFunction.globalCounts.get(1); - int count2 = StatefulFunction.globalCounts.get(2); - - assertTrue(count1 >= 2 && count1 <= 2 * numElements); - assertEquals(count1, count2); - - testHarness.close(); - op.dispose(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - // ------------------------------------------------------------------------ - - private void assertInvalidParameter(long windowSize, long windowSlide) { - try { - new AggregatingProcessingTimeWindowOperator<>( - mockFunction, mockKeySelector, - StringSerializer.INSTANCE, StringSerializer.INSTANCE, - windowSize, windowSlide); - fail("This should fail with an IllegalArgumentException"); - } - catch (IllegalArgumentException e) { - // expected - } - catch (Exception e) { - fail("Wrong exception. 
Expected IllegalArgumentException but found " + e.getClass().getSimpleName());
-		}
-	}
-
-	// ------------------------------------------------------------------------
-
-	private static class FailingFunction implements ReduceFunction> {
-
-		private final int failAfterElements;
-
-		private int numElements;
-
-		FailingFunction(int failAfterElements) {
-			this.failAfterElements = failAfterElements;
-		}
-
-		@Override
-		public Tuple2 reduce(Tuple2 value1, Tuple2 value2) throws Exception {
-			numElements++;
-
-			if (numElements >= failAfterElements) {
-				throw new Exception("Artificial Test Exception");
-			}
-
-			return new Tuple2<>(value1.f0, value1.f1 + value2.f1);
-		}
-	}
-
-	// ------------------------------------------------------------------------
-
-	private static class StatefulFunction extends RichReduceFunction> {
-
-		private static final Map globalCounts = new ConcurrentHashMap<>();
-
-		private ValueState state;
-
-		@Override
-		public void open(Configuration parameters) {
-			assertNotNull(getRuntimeContext());
-
-			// start with one, so the final count is correct and we test that we do not
-			// initialize with 0 always by default
-			state = getRuntimeContext().getState(new ValueStateDescriptor<>("totalCount", Integer.class, 1));
-		}
-
-		@Override
-		public Tuple2 reduce(Tuple2 value1, Tuple2 value2) throws Exception {
-			state.update(state.value() + 1);
-			globalCounts.put(value1.f0, state.value());
-
-			return new Tuple2<>(value1.f0, value1.f1 + value2.f1);
-		}
-	}
-
-	// ------------------------------------------------------------------------
-
-	@SuppressWarnings({"unchecked", "rawtypes"})
-	private List extractFromStreamRecords(Iterable input) {
-		List result = new ArrayList<>();
-		for (Object in : input) {
-			if (in instanceof StreamRecord) {
-				result.add((T) ((StreamRecord) in).getValue());
-			}
-		}
-		return result;
-	}
-}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
index a7c6f473a248b..f967a5b2b8d01 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
@@ -61,7 +61,6 @@
 import org.apache.flink.util.Collector;
 
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 
 import java.util.concurrent.TimeUnit;
@@ -310,31 +309,6 @@ public void testReduceProcessingTime() throws Exception {
 
 		processElementAndEnsureOutput(winOperator, winOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
 	}
-
-	/**
-	 * Ignored because we currently don't have the fast processing-time window operator.
-	 */
-	@Test
-	@SuppressWarnings("rawtypes")
-	@Ignore
-	public void testReduceFastProcessingTime() throws Exception {
-		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-		env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
-
-		DataStream> source = env.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
-
-		DataStream> window = source
-				.windowAll(SlidingProcessingTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
-				.reduce(new DummyReducer());
-
-		OneInputTransformation, Tuple2> transform =
-				(OneInputTransformation, Tuple2>) window.getTransformation();
-		OneInputStreamOperator, Tuple2> operator = transform.getOperator();
-		Assert.assertTrue(operator instanceof AggregatingProcessingTimeWindowOperator);
-
-		processElementAndEnsureOutput(operator, null, BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1));
-	}
-
 	@Test
 	@SuppressWarnings("rawtypes")
 	public void testReduceWithWindowFunctionEventTime() throws Exception {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ContinuousEventTimeTriggerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ContinuousEventTimeTriggerTest.java
index 9c14a9fdd2598..f0af9c2b5bf09 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ContinuousEventTimeTriggerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ContinuousEventTimeTriggerTest.java
@@ -25,7 +25,8 @@
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
-import com.google.common.collect.Lists;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
 import org.junit.Test;
 
 import java.util.Collection;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/CountTriggerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/CountTriggerTest.java
index 38dd01d72057d..47fd9c228f5fc 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/CountTriggerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/CountTriggerTest.java
@@ -23,7 +23,8 @@
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
-import com.google.common.collect.Lists;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/EventTimeSessionWindowsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/EventTimeSessionWindowsTest.java
index 23af8384b306f..5c4c989338f50 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/EventTimeSessionWindowsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/EventTimeSessionWindowsTest.java
@@ -28,7 +28,8 @@
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 import org.apache.flink.util.TestLogger;
 
-import com.google.common.collect.Lists;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
 import org.junit.Test;
 import org.mockito.Matchers;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/EventTimeTriggerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/EventTimeTriggerTest.java
index 2bcc19284c68f..f54367b699a2a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/EventTimeTriggerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/EventTimeTriggerTest.java
@@ -23,7 +23,8 @@
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
-import com.google.common.collect.Lists;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/MergingWindowSetTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/MergingWindowSetTest.java
index 0c45d0318133f..019facabf5975 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/MergingWindowSetTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/MergingWindowSetTest.java
@@ -30,7 +30,8 @@
 import org.apache.flink.streaming.api.windowing.triggers.Trigger;
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 
-import com.google.common.collect.Lists;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
 import org.junit.Test;
 import org.mockito.Matchers;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ProcessingTimeSessionWindowsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ProcessingTimeSessionWindowsTest.java
index ceda3b90cde35..f49799cee99e3 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ProcessingTimeSessionWindowsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ProcessingTimeSessionWindowsTest.java
@@ -28,7 +28,8 @@
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 import org.apache.flink.util.TestLogger;
 
-import com.google.common.collect.Lists;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
 import org.junit.Test;
 import org.mockito.Matchers;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ProcessingTimeTriggerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ProcessingTimeTriggerTest.java
index 791eb424a6eda..7e78854583aed 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ProcessingTimeTriggerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/ProcessingTimeTriggerTest.java
@@ -23,7 +23,8 @@
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
-import com.google.common.collect.Lists;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
+
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TimeWindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TimeWindowTranslationTest.java
index dc0e21ca43d39..d525ba6d665f1 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TimeWindowTranslationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TimeWindowTranslationTest.java
@@ -29,13 +29,10 @@
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.WindowedStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
 import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
-import org.apache.flink.streaming.api.windowing.assigners.SlidingAlignedProcessingTimeWindows;
 import org.apache.flink.streaming.api.windowing.assigners.SlidingEventTimeWindows;
-import org.apache.flink.streaming.api.windowing.assigners.TumblingAlignedProcessingTimeWindows;
 import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
 import org.apache.flink.streaming.api.windowing.time.Time;
 import org.apache.flink.streaming.api.windowing.triggers.EventTimeTrigger;
@@ -43,7 +40,6 @@
 import org.apache.flink.util.Collector;
 
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 
 import java.util.concurrent.TimeUnit;
@@ -96,60 +92,6 @@ public void apply(Tuple tuple,
 		Assert.assertTrue(operator2 instanceof WindowOperator);
 	}
 
-	/**
-	 * These tests ensure that the fast aligned time windows operator is used if the
-	 * conditions are right.
-	 */
-	@Test
-	public void testReduceAlignedTimeWindows() throws Exception {
-		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-		env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
-
-		DataStream> source = env.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
-
-		DummyReducer reducer = new DummyReducer();
-
-		DataStream> window1 = source
-				.keyBy(0)
-				.window(SlidingAlignedProcessingTimeWindows.of(Time.of(1000, TimeUnit.MILLISECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
-				.reduce(reducer);
-
-		OneInputTransformation, Tuple2> transform1 = (OneInputTransformation, Tuple2>) window1.getTransformation();
-		OneInputStreamOperator, Tuple2> operator1 = transform1.getOperator();
-		Assert.assertTrue(operator1 instanceof AggregatingProcessingTimeWindowOperator);
-	}
-
-	/**
-	 * These tests ensure that the fast aligned time windows operator is used if the
-	 * conditions are right.
-	 */
-	@Test
-	public void testApplyAlignedTimeWindows() throws Exception {
-		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-		env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
-
-		DataStream> source = env.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
-
-		DataStream> window1 = source
-				.keyBy(0)
-				.window(TumblingAlignedProcessingTimeWindows.of(Time.of(1000, TimeUnit.MILLISECONDS)))
-				.apply(new WindowFunction, Tuple2, Tuple, TimeWindow>() {
-					private static final long serialVersionUID = 1L;
-
-					@Override
-					public void apply(Tuple tuple,
-							TimeWindow window,
-							Iterable> values,
-							Collector> out) throws Exception {
-
-					}
-				});
-
-		OneInputTransformation, Tuple2> transform1 = (OneInputTransformation, Tuple2>) window1.getTransformation();
-		OneInputStreamOperator, Tuple2> operator1 = transform1.getOperator();
-		Assert.assertTrue(operator1 instanceof AccumulatingProcessingTimeWindowOperator);
-	}
-
 	@Test
 	@SuppressWarnings("rawtypes")
 	public void testReduceEventTimeWindows() throws Exception {
@@ -232,49 +174,6 @@ public void apply(Tuple tuple,
 		Assert.assertTrue(winOperator1.getStateDescriptor() instanceof ListStateDescriptor);
 	}
 
-	/**
-	 * These tests ensure that the fast aligned time windows operator is used if the
-	 * conditions are right.
-	 *
-	 * TODO: update once the fast aligned time windows operator is in
-	 */
-	@Ignore
-	@Test
-	public void testNonParallelFastTimeWindows() throws Exception {
-		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-
-		DataStream> source = env.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2));
-
-		DummyReducer reducer = new DummyReducer();
-
-		DataStream> window1 = source
-				.timeWindowAll(Time.of(1000, TimeUnit.MILLISECONDS),
-						Time.of(100, TimeUnit.MILLISECONDS))
-				.reduce(reducer);
-
-		OneInputTransformation, Tuple2> transform1 = (OneInputTransformation, Tuple2>) window1.getTransformation();
-		OneInputStreamOperator, Tuple2> operator1 = transform1.getOperator();
-		Assert.assertTrue(operator1 instanceof AggregatingProcessingTimeWindowOperator);
-
-		DataStream> window2 = source
-				.timeWindowAll(Time.of(1000, TimeUnit.MILLISECONDS))
-				.apply(new AllWindowFunction, Tuple2, TimeWindow>() {
-					private static final long serialVersionUID = 1L;
-
-					@Override
-					public void apply(
-							TimeWindow window,
-							Iterable> values,
-							Collector> out) throws Exception {
-
-					}
-				});
-
-		OneInputTransformation, Tuple2> transform2 = (OneInputTransformation, Tuple2>) window2.getTransformation();
-		OneInputStreamOperator, Tuple2> operator2 = transform2.getOperator();
-		Assert.assertTrue(operator2 instanceof AccumulatingProcessingTimeWindowOperator);
-	}
-
 	// ------------------------------------------------------------------------
 	//  UDFs
 	// ------------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorMigrationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorMigrationTest.java
index 9f1906445ec6a..d7df479094044 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorMigrationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorMigrationTest.java
@@ -29,7 +29,6 @@
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.typeutils.TypeInfoParser;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.streaming.api.datastream.LegacyWindowOperatorType;
 import org.apache.flink.streaming.api.functions.windowing.PassThroughWindowFunction;
 import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction;
 import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
@@ -84,7 +83,7 @@ public class WindowOperatorMigrationTest {
 
 	@Parameterized.Parameters(name = "Migration Savepoint: {0}")
 	public static Collection parameters () {
-		return Arrays.asList(MigrationVersion.v1_1, MigrationVersion.v1_2, MigrationVersion.v1_3);
+		return Arrays.asList(MigrationVersion.v1_2, MigrationVersion.v1_3);
 	}
 
 	/**
@@ -753,219 +752,6 @@ public void testRestoreApplyProcessingTimeWindows() throws Exception {
 		testHarness.close();
 	}
-
-	/**
-	 * Manually run this to write binary snapshot data.
- */ - @Ignore - @Test - public void writeAggregatingAlignedProcessingTimeWindowsSnapshot() throws Exception { - TypeInformation> inputType = TypeInfoParser.parse("Tuple2"); - - AggregatingProcessingTimeWindowOperator> operator = - new AggregatingProcessingTimeWindowOperator<>( - new ReduceFunction>() { - private static final long serialVersionUID = -8913160567151867987L; - - @Override - public Tuple2 reduce(Tuple2 value1, Tuple2 value2) throws Exception { - return new Tuple2<>(value1.f0, value1.f1 + value2.f1); - } - }, - new TupleKeySelector(), - BasicTypeInfo.STRING_TYPE_INFO.createSerializer(new ExecutionConfig()), - inputType.createSerializer(new ExecutionConfig()), - 3000, - 3000); - - OneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(operator, new TupleKeySelector(), BasicTypeInfo.STRING_TYPE_INFO); - - testHarness.open(); - - testHarness.setProcessingTime(3); - - // timestamp is ignored in processing time - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), Long.MAX_VALUE)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), 7000)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), 7000)); - - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - - // do a snapshot, close and restore again - OperatorStateHandles snapshot = testHarness.snapshot(0, 0); - OperatorSnapshotUtil.writeStateHandle( - snapshot, - "src/test/resources/win-op-migration-test-aggr-aligned-flink" + flinkGenerateSavepointVersion + "-snapshot"); - testHarness.close(); - } - - @Test - public void testRestoreAggregatingAlignedProcessingTimeWindows() throws Exception { - final int windowSize = 3; - - TypeInformation> inputType = TypeInfoParser.parse("Tuple2"); - - ReducingStateDescriptor> stateDesc = new ReducingStateDescriptor<>("window-contents", - new SumReducer(), - inputType.createSerializer(new ExecutionConfig())); - - WindowOperator, Tuple2, Tuple2, TimeWindow> operator = new WindowOperator<>( - TumblingProcessingTimeWindows.of(Time.of(windowSize, TimeUnit.SECONDS)), - new TimeWindow.Serializer(), - new TupleKeySelector(), - BasicTypeInfo.STRING_TYPE_INFO.createSerializer(new ExecutionConfig()), - stateDesc, - new InternalSingleValueWindowFunction<>(new PassThroughWindowFunction>()), - ProcessingTimeTrigger.create(), - 0, - null /* late data output tag */, - LegacyWindowOperatorType.FAST_AGGREGATING); - - OneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(operator, new TupleKeySelector(), BasicTypeInfo.STRING_TYPE_INFO); - - ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); - - testHarness.setup(); - - MigrationTestUtil.restoreFromSnapshot( - testHarness, - OperatorSnapshotUtil.getResourceFilename( - "win-op-migration-test-aggr-aligned-flink" + testMigrateVersion + "-snapshot"), - testMigrateVersion); - - testHarness.open(); - - testHarness.setProcessingTime(5000); - - expectedOutput.add(new StreamRecord<>(new Tuple2<>("key2", 3), 2999)); - expectedOutput.add(new StreamRecord<>(new Tuple2<>("key1", 2), 2999)); - - TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new Tuple2ResultSortComparator()); - - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); 
- testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - - testHarness.setProcessingTime(7000); - - expectedOutput.add(new StreamRecord<>(new Tuple2<>("key1", 3), 5999)); - - TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new Tuple2ResultSortComparator()); - - testHarness.close(); - } - - /** - * Manually run this to write binary snapshot data. - */ - @Ignore - @Test - public void writeAlignedProcessingTimeWindowsSnapshot() throws Exception { - TypeInformation> inputType = TypeInfoParser.parse("Tuple2"); - - AccumulatingProcessingTimeWindowOperator, Tuple2> operator = - new AccumulatingProcessingTimeWindowOperator<>( - new InternalIterableWindowFunction<>(new WindowFunction, Tuple2, String, TimeWindow>() { - - private static final long serialVersionUID = 6551516443265733803L; - - @Override - public void apply(String s, TimeWindow window, Iterable> input, Collector> out) throws Exception { - int sum = 0; - for (Tuple2 anInput : input) { - sum += anInput.f1; - } - out.collect(new Tuple2<>(s, sum)); - } - }), - new TupleKeySelector(), - BasicTypeInfo.STRING_TYPE_INFO.createSerializer(new ExecutionConfig()), - inputType.createSerializer(new ExecutionConfig()), - 3000, - 3000); - - OneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(operator, new TupleKeySelector(), BasicTypeInfo.STRING_TYPE_INFO); - - testHarness.open(); - - testHarness.setProcessingTime(3); - - // timestamp is ignored in processing time - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), Long.MAX_VALUE)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), 7000)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), 7000)); - - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - - // do a snapshot, close and restore again - OperatorStateHandles snapshot = testHarness.snapshot(0, 0); - OperatorSnapshotUtil.writeStateHandle( - snapshot, - "src/test/resources/win-op-migration-test-accum-aligned-flink" + flinkGenerateSavepointVersion + "-snapshot"); - testHarness.close(); - } - - @Test - public void testRestoreAccumulatingAlignedProcessingTimeWindows() throws Exception { - final int windowSize = 3; - - TypeInformation> inputType = TypeInfoParser.parse("Tuple2"); - - ReducingStateDescriptor> stateDesc = new ReducingStateDescriptor<>("window-contents", - new SumReducer(), - inputType.createSerializer(new ExecutionConfig())); - - WindowOperator, Tuple2, Tuple2, TimeWindow> operator = new WindowOperator<>( - TumblingProcessingTimeWindows.of(Time.of(windowSize, TimeUnit.SECONDS)), - new TimeWindow.Serializer(), - new TupleKeySelector(), - BasicTypeInfo.STRING_TYPE_INFO.createSerializer(new ExecutionConfig()), - stateDesc, - new InternalSingleValueWindowFunction<>(new PassThroughWindowFunction>()), - ProcessingTimeTrigger.create(), - 0, - null /* late data output tag */, - LegacyWindowOperatorType.FAST_ACCUMULATING); - - OneInputStreamOperatorTestHarness, Tuple2> testHarness = - new KeyedOneInputStreamOperatorTestHarness<>(operator, new TupleKeySelector(), BasicTypeInfo.STRING_TYPE_INFO); - - ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); - - testHarness.setup(); - - MigrationTestUtil.restoreFromSnapshot( - testHarness, - OperatorSnapshotUtil.getResourceFilename( - 
"win-op-migration-test-accum-aligned-flink" + testMigrateVersion + "-snapshot"), - testMigrateVersion); - - testHarness.open(); - - testHarness.setProcessingTime(5000); - - expectedOutput.add(new StreamRecord<>(new Tuple2<>("key2", 3), 2999)); - expectedOutput.add(new StreamRecord<>(new Tuple2<>("key1", 2), 2999)); - - TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new Tuple2ResultSortComparator()); - - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 1), 7000)); - - testHarness.setProcessingTime(7000); - - expectedOutput.add(new StreamRecord<>(new Tuple2<>("key1", 3), 5999)); - - TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new Tuple2ResultSortComparator()); - - testHarness.close(); - } - private static class TupleKeySelector implements KeySelector, String> { private static final long serialVersionUID = 1L; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java index 42c6c6f9aeb78..acdf45a635b49 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java @@ -68,8 +68,9 @@ import org.apache.flink.util.OutputTag; import org.apache.flink.util.TestLogger; -import com.google.common.base.Joiner; -import com.google.common.collect.Iterables; +import org.apache.flink.shaded.guava18.com.google.common.base.Joiner; +import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables; + import org.junit.Assert; import org.junit.Test; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java index 8748ed4da3b56..821438e993ef5 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java @@ -63,7 +63,6 @@ import org.apache.flink.util.Collector; import org.junit.Assert; -import org.junit.Ignore; import org.junit.Test; import java.util.concurrent.TimeUnit; @@ -336,32 +335,6 @@ public void testReduceProcessingTime() throws Exception { processElementAndEnsureOutput(winOperator, winOperator.getKeySelector(), BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1)); } - - /** - * Ignored because we currently don't have the fast processing-time window operator. 
- */ - @Test - @SuppressWarnings("rawtypes") - @Ignore - public void testReduceFastProcessingTime() throws Exception { - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime); - - DataStream> source = env.fromElements(Tuple2.of("hello", 1), Tuple2.of("hello", 2)); - - DataStream> window = source - .keyBy(new TupleKeySelector()) - .window(SlidingProcessingTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS))) - .reduce(new DummyReducer()); - - OneInputTransformation, Tuple2> transform = - (OneInputTransformation, Tuple2>) window.getTransformation(); - OneInputStreamOperator, Tuple2> operator = transform.getOperator(); - Assert.assertTrue(operator instanceof AggregatingProcessingTimeWindowOperator); - - processElementAndEnsureOutput(operator, null, BasicTypeInfo.STRING_TYPE_INFO, new Tuple2<>("hello", 1)); - } - @Test @SuppressWarnings("rawtypes") public void testReduceWithWindowFunctionEventTime() throws Exception { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java index 51328abbebc12..82642eab4dcc7 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; @@ -45,6 +46,7 @@ import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; @@ -93,6 +95,7 @@ public void testBlockingNonInterruptibleCheckpoint() throws Exception { Configuration taskConfig = new Configuration(); StreamConfig cfg = new StreamConfig(taskConfig); cfg.setStreamOperator(new TestOperator()); + cfg.setOperatorID(new OperatorID()); cfg.setStateBackend(new LockingStreamStateBackend()); Task task = createTask(taskConfig); @@ -154,6 +157,7 @@ private static Task createTask(Configuration taskConfig) throws IOException { mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), + mock(BlobCache.class), new FallbackLibraryCacheManager(), new FileCache(new String[] { EnvironmentInformation.getTemporaryFileDirectory() }), new TestingTaskManagerRuntimeInfo(), diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index 25b504b916c16..f73499c13162b 100644 --- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -24,8 +24,11 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -40,11 +43,11 @@ import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.DefaultOperatorStateBackend; import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; @@ -55,14 +58,12 @@ import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManagerActions; import org.apache.flink.runtime.util.EnvironmentInformation; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.streaming.api.TimeCharacteristic; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -73,7 +74,6 @@ import java.io.EOFException; import java.io.IOException; -import java.io.Serializable; import java.net.URL; import java.util.Collection; import java.util.Collections; @@ -104,12 +104,6 @@ public class InterruptSensitiveRestoreTest { private static final int OPERATOR_RAW = 1; private static final int KEYED_MANAGED = 2; private static final int KEYED_RAW = 3; - private static final int LEGACY = 4; - - @Test - public void testRestoreWithInterruptLegacy() throws Exception { - testRestoreWithInterrupt(LEGACY); - } @Test public void testRestoreWithInterruptOperatorManaged() throws Exception { @@ -143,10 +137,7 @@ private void testRestoreWithInterrupt(int mode) throws Exception { case KEYED_MANAGED: case KEYED_RAW: cfg.setStateKeySerializer(IntSerializer.INSTANCE); - cfg.setStreamOperator(new StreamSource<>(new TestSource())); - break; - case LEGACY: - cfg.setStreamOperator(new StreamSource<>(new 
TestSourceLegacy())); + cfg.setStreamOperator(new StreamSource<>(new TestSource(mode))); break; default: throw new IllegalArgumentException(); @@ -154,7 +145,7 @@ private void testRestoreWithInterrupt(int mode) throws Exception { StreamStateHandle lockingHandle = new InterruptLockingStateHandle(); - Task task = createTask(taskConfig, lockingHandle, mode); + Task task = createTask(cfg, taskConfig, lockingHandle, mode); // start the task and wait until it is in "restore" task.startTaskThread(); @@ -178,6 +169,7 @@ private void testRestoreWithInterrupt(int mode) throws Exception { // ------------------------------------------------------------------------ private static Task createTask( + StreamConfig streamConfig, Configuration taskConfig, StreamStateHandle state, int mode) throws IOException { @@ -186,11 +178,10 @@ private static Task createTask( when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); - ChainedStateHandle operatorState = null; - List keyedStateFromBackend = Collections.emptyList(); - List keyedStateFromStream = Collections.emptyList(); - List> operatorStateBackend = Collections.emptyList(); - List> operatorStateStream = Collections.emptyList(); + Collection keyedStateFromBackend = Collections.emptyList(); + Collection keyedStateFromStream = Collections.emptyList(); + Collection operatorStateBackend = Collections.emptyList(); + Collection operatorStateStream = Collections.emptyList(); Map operatorStateMetadata = new HashMap<>(1); OperatorStateHandle.StateMetaInfo metaInfo = @@ -203,14 +194,14 @@ private static Task createTask( Collections.singletonList(new OperatorStateHandle(operatorStateMetadata, state)); List keyedStateHandles = - Collections.singletonList(new KeyGroupsStateHandle(keyGroupRangeOffsets, state)); + Collections.singletonList(new KeyGroupsStateHandle(keyGroupRangeOffsets, state)); switch (mode) { case OPERATOR_MANAGED: - operatorStateBackend = Collections.singletonList(operatorStateHandles); + operatorStateBackend = operatorStateHandles; break; case OPERATOR_RAW: - operatorStateStream = Collections.singletonList(operatorStateHandles); + operatorStateStream = operatorStateHandles; break; case KEYED_MANAGED: keyedStateFromBackend = keyedStateHandles; @@ -218,20 +209,21 @@ private static Task createTask( case KEYED_RAW: keyedStateFromStream = keyedStateHandles; break; - case LEGACY: - operatorState = new ChainedStateHandle<>(Collections.singletonList(state)); - break; default: throw new IllegalArgumentException(); } - TaskStateHandles taskStateHandles = new TaskStateHandles( - operatorState, + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( operatorStateBackend, operatorStateStream, keyedStateFromBackend, keyedStateFromStream); + JobVertexID jobVertexID = new JobVertexID(); + OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID); + streamConfig.setOperatorID(operatorID); + TaskStateSnapshot stateSnapshot = new TaskStateSnapshot(); + stateSnapshot.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState); JobInformation jobInformation = new JobInformation( new JobID(), "test job name", @@ -241,7 +233,7 @@ private static Task createTask( Collections.emptyList()); TaskInformation taskInformation = new TaskInformation( - new JobVertexID(), + jobVertexID, "test task name", 1, 1, @@ -258,7 +250,7 @@ private static Task createTask( Collections.emptyList(), Collections.emptyList(), 0, - taskStateHandles, + stateSnapshot, mock(MemoryManager.class), 
mock(IOManager.class), networkEnvironment, @@ -266,6 +258,7 @@ mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), + mock(BlobCache.class), new FallbackLibraryCacheManager(), new FileCache(new String[] { EnvironmentInformation.getTemporaryFileDirectory() }), new TestingTaskManagerRuntimeInfo(), @@ -273,7 +266,6 @@ mock(ResultPartitionConsumableNotifier.class), mock(PartitionProducerStateChecker.class), mock(Executor.class)); - } // ------------------------------------------------------------------------ @@ -293,11 +285,11 @@ public FSDataInputStream openInputStream() throws IOException { FSDataInputStream is = new FSDataInputStream() { @Override - public void seek(long desired) throws IOException { + public void seek(long desired) { } @Override - public long getPos() throws IOException { + public long getPos() { return 0; } @@ -349,32 +341,14 @@ public long getStateSize() { // ------------------------------------------------------------------------ - private static class TestSourceLegacy implements SourceFunction, Checkpointed { + private static class TestSource implements SourceFunction, CheckpointedFunction { private static final long serialVersionUID = 1L; + private final int testType; - @Override - public void run(SourceContext ctx) throws Exception { - fail("should never be called"); + public TestSource(int testType) { + this.testType = testType; } - @Override - public void cancel() {} - - @Override - public Serializable snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - fail("should never be called"); - return null; - } - - @Override - public void restoreState(Serializable state) throws Exception { - fail("should never be called"); - } - } - - private static class TestSource implements SourceFunction, CheckpointedFunction { - private static final long serialVersionUID = 1L; - @Override public void run(SourceContext ctx) throws Exception { fail("should never be called"); @@ -390,6 +364,8 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { @Override public void initializeState(FunctionInitializationContext context) throws Exception { + // raw keyed state is already read by the timer service and all other state is read when the context is initialized; + // we only need to trigger the read of the raw operator state manually.
((StateInitializationContext) context).getRawOperatorStateInputs().iterator().next().getStream().read(); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java index f7987a122589f..8d80d66caed93 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java @@ -28,40 +28,35 @@ import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.streaming.api.collector.selector.OutputSelector; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.graph.StreamEdge; import org.apache.flink.streaming.api.graph.StreamNode; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.operators.StreamMap; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.streamstatus.StreamStatus; import org.apache.flink.streaming.util.TestHarnessUtil; -import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.apache.flink.util.TestLogger; import org.junit.Assert; import org.junit.Test; -import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -69,7 +64,6 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.TimeUnit; @@ -78,7 +72,6 @@ import scala.concurrent.duration.FiniteDuration; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -93,7 +86,7 @@ public class OneInputStreamTaskTest extends TestLogger { private static final ListStateDescriptor TEST_DESCRIPTOR = - new ListStateDescriptor<>("test", new IntSerializer()); + new ListStateDescriptor<>("test", new 
IntSerializer()); /** * This test verifies that open() and close() are correctly called. This test also verifies @@ -109,6 +102,7 @@ public void testOpenCloseAndTimestamps() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap(new TestOpenCloseMapFunction()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); long initialTime = 0L; ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); @@ -128,8 +122,8 @@ public void testOpenCloseAndTimestamps() throws Exception { assertTrue("RichFunction methods where not called.", TestOpenCloseMapFunction.closeCalled); TestHarnessUtil.assertOutputEquals("Output was not correct.", - expectedOutput, - testHarness.getOutput()); + expectedOutput, + testHarness.getOutput()); } /** @@ -151,6 +145,7 @@ public void testWatermarkAndStreamStatusForwarding() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap(new IdentityMap()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -172,8 +167,8 @@ public void testWatermarkAndStreamStatusForwarding() throws Exception { testHarness.waitForInputProcessing(); expectedOutput.add(new Watermark(initialTime)); TestHarnessUtil.assertOutputEquals("Output was not correct.", - expectedOutput, - testHarness.getOutput()); + expectedOutput, + testHarness.getOutput()); // contrary to checkpoint barriers these elements are not blocked by watermarks testHarness.processElement(new StreamRecord("Hello", initialTime)); @@ -213,7 +208,7 @@ public void testWatermarkAndStreamStatusForwarding() throws Exception { testHarness.processElement(new Watermark(initialTime + 6), 0, 0); testHarness.processElement(new Watermark(initialTime + 5), 1, 1); // this watermark should be advanced first testHarness.processElement(StreamStatus.IDLE, 1, 1); // once this is acknowledged, - // watermark (initial + 6) should be forwarded + // watermark (initial + 6) should be forwarded testHarness.waitForInputProcessing(); expectedOutput.add(new Watermark(initialTime + 5)); expectedOutput.add(new Watermark(initialTime + 6)); @@ -270,6 +265,7 @@ public void testWatermarksNotForwardedWithinChainWhenIdle() throws Exception { StreamConfig tailOperatorConfig = new StreamConfig(new Configuration()); headOperatorConfig.setStreamOperator(headOperator); + headOperatorConfig.setOperatorID(new OperatorID(42L, 42L)); headOperatorConfig.setChainStart(); headOperatorConfig.setChainIndex(0); headOperatorConfig.setChainedOutputs(Collections.singletonList(new StreamEdge( @@ -282,6 +278,7 @@ public void testWatermarksNotForwardedWithinChainWhenIdle() throws Exception { ))); watermarkOperatorConfig.setStreamOperator(watermarkOperator); + watermarkOperatorConfig.setOperatorID(new OperatorID(4711L, 42L)); watermarkOperatorConfig.setTypeSerializerIn1(StringSerializer.INSTANCE); watermarkOperatorConfig.setChainIndex(1); watermarkOperatorConfig.setChainedOutputs(Collections.singletonList(new StreamEdge( @@ -303,6 +300,7 @@ public void testWatermarksNotForwardedWithinChainWhenIdle() throws Exception { null)); tailOperatorConfig.setStreamOperator(tailOperator); + tailOperatorConfig.setOperatorID(new OperatorID(123L, 123L)); tailOperatorConfig.setTypeSerializerIn1(StringSerializer.INSTANCE); tailOperatorConfig.setBufferTimeout(0); tailOperatorConfig.setChainIndex(2); @@ -412,6 
+410,7 @@ public void testCheckpointBarriers() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap(new IdentityMap()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -471,6 +470,7 @@ public void testOvertakingCheckpointBarriers() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap(new IdentityMap()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -543,13 +543,11 @@ public void testSnapshottingAndRestoring() throws Exception { long checkpointId = 1L; long checkpointTimestamp = 1L; - long recoveryTimestamp = 3L; - long seed = 2L; int numberChainedTasks = 11; StreamConfig streamConfig = testHarness.getStreamConfig(); - configureChainedTestingStreamOperator(streamConfig, numberChainedTasks, seed, recoveryTimestamp); + configureChainedTestingStreamOperator(streamConfig, numberChainedTasks); AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment( testHarness.jobConfig, @@ -580,14 +578,19 @@ public void testSnapshottingAndRestoring() throws Exception { testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis()); final OneInputStreamTask restoredTask = new OneInputStreamTask(); - restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles())); - final OneInputStreamTaskTestHarness restoredTaskHarness = new OneInputStreamTaskTestHarness(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); + final OneInputStreamTaskTestHarness restoredTaskHarness = + new OneInputStreamTaskTestHarness(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO); StreamConfig restoredTaskStreamConfig = restoredTaskHarness.getStreamConfig(); - configureChainedTestingStreamOperator(restoredTaskStreamConfig, numberChainedTasks, seed, recoveryTimestamp); + configureChainedTestingStreamOperator(restoredTaskStreamConfig, numberChainedTasks); + + TaskStateSnapshot stateHandles = env.getCheckpointStateHandles(); + Assert.assertEquals(numberChainedTasks, stateHandles.getSubtaskStateMappings().size()); + + restoredTask.setInitialState(stateHandles); TestingStreamOperator.numberRestoreCalls = 0; @@ -601,32 +604,31 @@ public void testSnapshottingAndRestoring() throws Exception { TestingStreamOperator.numberRestoreCalls = 0; } + //============================================================================================== // Utility functions and classes //============================================================================================== private void configureChainedTestingStreamOperator( StreamConfig streamConfig, - int numberChainedTasks, - long seed, - long recoveryTimestamp) { + int numberChainedTasks) { Preconditions.checkArgument(numberChainedTasks >= 1, "The operator chain must at least " + "contain one operator."); - Random random = new Random(seed); - - TestingStreamOperator previousOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp); + TestingStreamOperator previousOperator = new TestingStreamOperator<>(); streamConfig.setStreamOperator(previousOperator); + streamConfig.setOperatorID(new OperatorID(0L, 
0L)); // create the chain of operators Map chainedTaskConfigs = new HashMap<>(numberChainedTasks - 1); List outputEdges = new ArrayList<>(numberChainedTasks - 1); for (int chainedIndex = 1; chainedIndex < numberChainedTasks; chainedIndex++) { - TestingStreamOperator chainedOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp); + TestingStreamOperator chainedOperator = new TestingStreamOperator<>(); StreamConfig chainedConfig = new StreamConfig(new Configuration()); chainedConfig.setStreamOperator(chainedOperator); + chainedConfig.setOperatorID(new OperatorID(0L, chainedIndex)); chainedTaskConfigs.put(chainedIndex, chainedConfig); StreamEdge outputEdge = new StreamEdge( @@ -673,7 +675,7 @@ public IN getKey(IN value) throws Exception { private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment { private volatile long checkpointId; - private volatile SubtaskState checkpointStateHandles; + private volatile TaskStateSnapshot checkpointStateHandles; private final OneShotLatch checkpointLatch = new OneShotLatch(); @@ -682,17 +684,17 @@ public long getCheckpointId() { } AcknowledgeStreamMockEnvironment( - Configuration jobConfig, Configuration taskConfig, - ExecutionConfig executionConfig, long memorySize, - MockInputSplitProvider inputSplitProvider, int bufferSize) { + Configuration jobConfig, Configuration taskConfig, + ExecutionConfig executionConfig, long memorySize, + MockInputSplitProvider inputSplitProvider, int bufferSize) { super(jobConfig, taskConfig, executionConfig, memorySize, inputSplitProvider, bufferSize); } @Override public void acknowledgeCheckpoint( - long checkpointId, - CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + long checkpointId, + CheckpointMetrics checkpointMetrics, + TaskStateSnapshot checkpointStateHandles) { this.checkpointId = checkpointId; this.checkpointStateHandles = checkpointStateHandles; @@ -703,25 +705,20 @@ public OneShotLatch getCheckpointLatch() { return checkpointLatch; } - public SubtaskState getCheckpointStateHandles() { + public TaskStateSnapshot getCheckpointStateHandles() { return checkpointStateHandles; } } private static class TestingStreamOperator - extends AbstractStreamOperator - implements OneInputStreamOperator, StreamCheckpointedOperator { + extends AbstractStreamOperator + implements OneInputStreamOperator { private static final long serialVersionUID = 774614855940397174L; public static int numberRestoreCalls = 0; public static int numberSnapshotCalls = 0; - private final long seed; - private final long recoveryTimestamp; - - private transient Random random; - @Override public void open() throws Exception { super.open(); @@ -747,7 +744,7 @@ public void open() throws Exception { @Override public void snapshotState(StateSnapshotContext context) throws Exception { ListState partitionableState = - getOperatorStateBackend().getListState(TEST_DESCRIPTOR); + getOperatorStateBackend().getListState(TEST_DESCRIPTOR); partitionableState.clear(); partitionableState.add(42); @@ -758,59 +755,14 @@ public void snapshotState(StateSnapshotContext context) throws Exception { @Override public void initializeState(StateInitializationContext context) throws Exception { - - } - - TestingStreamOperator(long seed, long recoveryTimestamp) { - this.seed = seed; - this.recoveryTimestamp = recoveryTimestamp; - } - - @Override - public void processElement(StreamRecord element) throws Exception { - - } - - @Override - public void snapshotState(FSDataOutputStream out, long checkpointId, long 
timestamp) throws Exception { - if (random == null) { - random = new Random(seed); + if (context.isRestored()) { + ++numberRestoreCalls; } - - Serializable functionState = generateFunctionState(); - Integer operatorState = generateOperatorState(); - - InstantiationUtil.serializeObject(out, functionState); - InstantiationUtil.serializeObject(out, operatorState); } @Override - public void restoreState(FSDataInputStream in) throws Exception { - numberRestoreCalls++; - - if (random == null) { - random = new Random(seed); - } - - assertEquals(this.recoveryTimestamp, recoveryTimestamp); - - assertNotNull(in); - - ClassLoader cl = Thread.currentThread().getContextClassLoader(); - - Serializable functionState = InstantiationUtil.deserializeObject(in, cl); - Integer operatorState = InstantiationUtil.deserializeObject(in, cl); - - assertEquals(random.nextInt(), functionState); - assertEquals(random.nextInt(), (int) operatorState); - } - - private Serializable generateFunctionState() { - return random.nextInt(); - } + public void processElement(StreamRecord element) throws Exception { - private Integer generateOperatorState() { - return random.nextInt(); } } @@ -893,8 +845,8 @@ protected void handleWatermark(Watermark mark) { *

If it receives a watermark when it's not expecting one, it'll throw an exception and fail. */ private static class TriggerableFailOnWatermarkTestOperator - extends AbstractStreamOperator - implements OneInputStreamOperator { + extends AbstractStreamOperator + implements OneInputStreamOperator { private static final long serialVersionUID = 2048954179291813243L; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java index 47a53500be822..b3b0a9f414e35 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.checkpoint.ExternallyInducedSource; import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -64,6 +65,7 @@ public void testCheckpointsTriggeredBySource() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamSource sourceOperator = new StreamSource<>(source); streamConfig.setStreamOperator(sourceOperator); + streamConfig.setOperatorID(new OperatorID()); // this starts the source thread testHarness.invoke(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java index 27818bcafa36d..8867632a5c3ff 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java @@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.functions.source.RichSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -63,6 +64,7 @@ public void testOpenClose() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamSource sourceOperator = new StreamSource<>(new OpenCloseTestSource()); streamConfig.setStreamOperator(sourceOperator); + streamConfig.setOperatorID(new OperatorID()); testHarness.invoke(); testHarness.waitForTaskCompletion(); @@ -106,6 +108,7 @@ public void testCheckpointing() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamSource, ?> sourceOperator = new StreamSource<>(new MockSource(numElements, sourceCheckpointDelay, sourceReadDelay)); streamConfig.setStreamOperator(sourceOperator); + streamConfig.setOperatorID(new OperatorID()); // prepare the diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java index 5b995c67b8e94..231f59e97fb2a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java @@ -28,7 +28,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; @@ -333,7 +333,7 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin } @Override - public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) { + public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java index 6e3c299f9ed9d..36bdc054b9340 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineOnCancellationBarrierException; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -91,6 +92,7 @@ public void testDeclineCallOnCancelBarrierOneInput() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap<>(new IdentityMap()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); StreamMockEnvironment environment = spy(testHarness.createEnvironment()); @@ -135,6 +137,7 @@ public void testDeclineCallOnCancelBarrierTwoInputs() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap op = new CoStreamMap<>(new UnionCoMap()); streamConfig.setStreamOperator(op); + streamConfig.setOperatorID(new OperatorID()); StreamMockEnvironment environment = spy(testHarness.createEnvironment()); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java index 4f2135d96bc0f..79e9583a8be72 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java @@ -23,6 +23,7 @@ import 
org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointOptions; @@ -42,6 +43,7 @@ import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; @@ -107,6 +109,7 @@ public void testConcurrentAsyncCheckpointCannotFailFinishedStreamTask() throws E final AbstractStateBackend blockingStateBackend = new BlockingStateBackend(); streamConfig.setStreamOperator(noOpStreamOperator); + streamConfig.setOperatorID(new OperatorID()); streamConfig.setStateBackend(blockingStateBackend); final long checkpointId = 0L; @@ -151,6 +154,7 @@ public void testConcurrentAsyncCheckpointCannotFailFinishedStreamTask() throws E mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), + mock(BlobCache.class), new FallbackLibraryCacheManager(), mock(FileCache.class), taskManagerRuntimeInfo, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index 923b912437a2c..9bb91ad7ba8a6 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -27,12 +27,15 @@ import org.apache.flink.configuration.CoreOptions; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.blob.BlobCache; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -49,14 +52,16 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.operators.testutils.MockEnvironment; +import 
org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.DoneFuture; import org.apache.flink.runtime.state.KeyGroupRange; @@ -65,12 +70,12 @@ import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackendFactory; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskExecutionState; import org.apache.flink.runtime.taskmanager.TaskExecutionStateListener; import org.apache.flink.runtime.taskmanager.TaskManagerActions; +import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.DirectExecutorService; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.streaming.api.TimeCharacteristic; @@ -79,7 +84,6 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OperatorSnapshotResult; import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -110,12 +114,14 @@ import java.util.Comparator; import java.util.List; import java.util.PriorityQueue; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.RunnableFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import scala.concurrent.Await; import scala.concurrent.Future; @@ -128,6 +134,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyCollectionOf; import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; @@ -136,7 +143,6 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; import static org.powermock.api.mockito.PowerMockito.whenNew; /** @@ -158,6 +164,7 @@ public class StreamTaskTest extends TestLogger { public void testEarlyCanceling() throws Exception { Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow(); StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setOperatorID(new OperatorID(4711L, 42L)); cfg.setStreamOperator(new SlowlyDeserializingOperator()); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); @@ -203,6 +210,7 @@ public void testStateBackendLoadingAndClosing() throws Exception { taskManagerConfig.setString(CoreOptions.STATE_BACKEND, 
MockStateBackend.class.getName()); StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setOperatorID(new OperatorID(4711L, 42L)); cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction())); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); @@ -227,6 +235,7 @@ public void testStateBackendClosingOnFailure() throws Exception { taskManagerConfig.setString(CoreOptions.STATE_BACKEND, MockStateBackend.class.getName()); StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setOperatorID(new OperatorID(4711L, 42L)); cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction())); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); @@ -301,9 +310,9 @@ public void testFailingCheckpointStreamOperator() throws Exception { streamTask.setEnvironment(mockEnvironment); // mock the operators - StreamOperator streamOperator1 = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); - StreamOperator streamOperator2 = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); - StreamOperator streamOperator3 = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); + StreamOperator streamOperator1 = mock(StreamOperator.class); + StreamOperator streamOperator2 = mock(StreamOperator.class); + StreamOperator streamOperator3 = mock(StreamOperator.class); // mock the returned snapshots OperatorSnapshotResult operatorSnapshotResult1 = mock(OperatorSnapshotResult.class); @@ -315,14 +324,12 @@ public void testFailingCheckpointStreamOperator() throws Exception { when(streamOperator2.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(operatorSnapshotResult2); when(streamOperator3.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenThrow(testException); - // mock the returned legacy snapshots - StreamStateHandle streamStateHandle1 = mock(StreamStateHandle.class); - StreamStateHandle streamStateHandle2 = mock(StreamStateHandle.class); - StreamStateHandle streamStateHandle3 = mock(StreamStateHandle.class); - - when(streamOperator1.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle1); - when(streamOperator2.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle2); - when(streamOperator3.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle3); + OperatorID operatorID1 = new OperatorID(); + OperatorID operatorID2 = new OperatorID(); + OperatorID operatorID3 = new OperatorID(); + when(streamOperator1.getOperatorID()).thenReturn(operatorID1); + when(streamOperator2.getOperatorID()).thenReturn(operatorID2); + when(streamOperator3.getOperatorID()).thenReturn(operatorID3); // set up the task @@ -346,10 +353,6 @@ public void testFailingCheckpointStreamOperator() throws Exception { verify(operatorSnapshotResult1).cancel(); verify(operatorSnapshotResult2).cancel(); - - verify(streamStateHandle1).discardState(); - verify(streamStateHandle2).discardState(); - verify(streamStateHandle3).discardState(); } /** @@ -371,12 +374,12 @@ public void testFailingAsyncCheckpointRunnable() throws Exception { CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); streamTask.setEnvironment(mockEnvironment); - StreamOperator streamOperator1 = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); - 
StreamOperator streamOperator2 = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); - StreamOperator streamOperator3 = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); - - // mock the new state handles / futures + // mock the operators + StreamOperator streamOperator1 = mock(StreamOperator.class); + StreamOperator streamOperator2 = mock(StreamOperator.class); + StreamOperator streamOperator3 = mock(StreamOperator.class); + // mock the new state operator snapshots OperatorSnapshotResult operatorSnapshotResult1 = mock(OperatorSnapshotResult.class); OperatorSnapshotResult operatorSnapshotResult2 = mock(OperatorSnapshotResult.class); OperatorSnapshotResult operatorSnapshotResult3 = mock(OperatorSnapshotResult.class); @@ -390,14 +393,12 @@ public void testFailingAsyncCheckpointRunnable() throws Exception { when(streamOperator2.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(operatorSnapshotResult2); when(streamOperator3.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(operatorSnapshotResult3); - // mock the legacy state snapshot - StreamStateHandle streamStateHandle1 = mock(StreamStateHandle.class); - StreamStateHandle streamStateHandle2 = mock(StreamStateHandle.class); - StreamStateHandle streamStateHandle3 = mock(StreamStateHandle.class); - - when(streamOperator1.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle1); - when(streamOperator2.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle2); - when(streamOperator3.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle3); + OperatorID operatorID1 = new OperatorID(); + OperatorID operatorID2 = new OperatorID(); + OperatorID operatorID3 = new OperatorID(); + when(streamOperator1.getOperatorID()).thenReturn(operatorID1); + when(streamOperator2.getOperatorID()).thenReturn(operatorID2); + when(streamOperator3.getOperatorID()).thenReturn(operatorID3); StreamOperator[] streamOperators = {streamOperator1, streamOperator2, streamOperator3}; @@ -418,10 +419,6 @@ public void testFailingAsyncCheckpointRunnable() throws Exception { verify(operatorSnapshotResult1).cancel(); verify(operatorSnapshotResult2).cancel(); verify(operatorSnapshotResult3).cancel(); - - verify(streamStateHandle1).discardState(); - verify(streamStateHandle2).discardState(); - verify(streamStateHandle3).discardState(); } /** @@ -455,13 +452,13 @@ public Object answer(InvocationOnMock invocation) throws Throwable { return null; } - }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(SubtaskState.class)); + }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); StreamTask> streamTask = mock(StreamTask.class, Mockito.CALLS_REAL_METHODS); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); streamTask.setEnvironment(mockEnvironment); - StreamOperator streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); + StreamOperator streamOperator = mock(StreamOperator.class); KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class); KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class); @@ -505,18 +502,19 @@ public Object answer(InvocationOnMock invocation) throws 
Throwable { acknowledgeCheckpointLatch.await(); - ArgumentCaptor subtaskStateCaptor = ArgumentCaptor.forClass(SubtaskState.class); + ArgumentCaptor subtaskStateCaptor = ArgumentCaptor.forClass(TaskStateSnapshot.class); // check that the checkpoint has been completed verify(mockEnvironment).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), subtaskStateCaptor.capture()); - SubtaskState subtaskState = subtaskStateCaptor.getValue(); + TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue(); + OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue(); // check that the subtask state contains the expected state handles - assertEquals(managedKeyedStateHandle, subtaskState.getManagedKeyedState()); - assertEquals(rawKeyedStateHandle, subtaskState.getRawKeyedState()); - assertEquals(new ChainedStateHandle<>(Collections.singletonList(managedOperatorStateHandle)), subtaskState.getManagedOperatorState()); - assertEquals(new ChainedStateHandle<>(Collections.singletonList(rawOperatorStateHandle)), subtaskState.getRawOperatorState()); + assertEquals(Collections.singletonList(managedKeyedStateHandle), subtaskState.getManagedKeyedState()); + assertEquals(Collections.singletonList(rawKeyedStateHandle), subtaskState.getRawKeyedState()); + assertEquals(Collections.singletonList(managedOperatorStateHandle), subtaskState.getManagedOperatorState()); + assertEquals(Collections.singletonList(rawOperatorStateHandle), subtaskState.getRawOperatorState()); // check that the state handles have not been discarded verify(managedKeyedStateHandle, never()).discardState(); @@ -558,18 +556,24 @@ public void testAsyncCheckpointingConcurrentCloseBeforeAcknowledge() throws Exce Environment mockEnvironment = mock(Environment.class); when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo); - whenNew(SubtaskState.class).withAnyArguments().thenAnswer(new Answer() { - @Override - public SubtaskState answer(InvocationOnMock invocation) throws Throwable { + whenNew(OperatorSubtaskState.class). + withArguments( + anyCollectionOf(OperatorStateHandle.class), + anyCollectionOf(OperatorStateHandle.class), + anyCollectionOf(KeyedStateHandle.class), + anyCollectionOf(KeyedStateHandle.class)). 
+ thenAnswer(new Answer() { + @Override + public OperatorSubtaskState answer(InvocationOnMock invocation) throws Throwable { createSubtask.trigger(); completeSubtask.await(); - - return new SubtaskState( - (ChainedStateHandle) invocation.getArguments()[0], - (ChainedStateHandle) invocation.getArguments()[1], - (ChainedStateHandle) invocation.getArguments()[2], - (KeyedStateHandle) invocation.getArguments()[3], - (KeyedStateHandle) invocation.getArguments()[4]); + Object[] arguments = invocation.getArguments(); + return new OperatorSubtaskState( + (OperatorStateHandle) arguments[0], + (OperatorStateHandle) arguments[1], + (KeyedStateHandle) arguments[2], + (KeyedStateHandle) arguments[3] + ); } }); @@ -577,7 +581,9 @@ public SubtaskState answer(InvocationOnMock invocation) throws Throwable { CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); streamTask.setEnvironment(mockEnvironment); - StreamOperator streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); + final StreamOperator streamOperator = mock(StreamOperator.class); + final OperatorID operatorID = new OperatorID(); + when(streamOperator.getOperatorID()).thenReturn(operatorID); KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class); KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class); @@ -636,7 +642,7 @@ public SubtaskState answer(InvocationOnMock invocation) throws Throwable { } // check that the checkpoint has not been acknowledged - verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(SubtaskState.class)); + verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); // check that the state handles have been discarded verify(managedKeyedStateHandle).discardState(); @@ -676,7 +682,7 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { checkpointCompletedLatch.trigger(); return null; } - }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(SubtaskState.class)); + }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo); @@ -686,7 +692,10 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { // mock the operators StreamOperator statelessOperator = - mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); + mock(StreamOperator.class); + + final OperatorID operatorID = new OperatorID(); + when(statelessOperator.getOperatorID()).thenReturn(operatorID); // mock the returned empty snapshot result (all state handles are null) OperatorSnapshotResult statelessOperatorSnapshotResult = new OperatorSnapshotResult(); @@ -713,10 +722,96 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { Assert.assertNull(checkpointResult.get(0)); } + /** + * Tests that the StreamTask first closes all its operators before setting its + * state to not running (isRunning == false). + * + *

See FLINK-7430. + */ + @Test + public void testOperatorClosingBeforeStopRunning() throws Throwable { + Configuration taskConfiguration = new Configuration(); + StreamConfig streamConfig = new StreamConfig(taskConfiguration); + streamConfig.setStreamOperator(new BlockingCloseStreamOperator()); + streamConfig.setOperatorID(new OperatorID()); + + MockEnvironment mockEnvironment = new MockEnvironment( + "Test Task", + 32L * 1024L, + new MockInputSplitProvider(), + 1, + taskConfiguration, + new ExecutionConfig()); + StreamTask streamTask = new NoOpStreamTask<>(mockEnvironment); + final AtomicReference atomicThrowable = new AtomicReference<>(null); + + CompletableFuture invokeFuture = CompletableFuture.runAsync( + () -> { + try { + streamTask.invoke(); + } catch (Exception e) { + atomicThrowable.set(e); + } + }, + TestingUtils.defaultExecutor()); + + BlockingCloseStreamOperator.IN_CLOSE.await(); + + // check that the StreamTask is not yet in isRunning == false + assertTrue(streamTask.isRunning()); + + // let the operator finish its close operation + BlockingCloseStreamOperator.FINISH_CLOSE.trigger(); + + // wait until the invoke is complete + invokeFuture.get(); + + // now the StreamTask should no longer be running + assertFalse(streamTask.isRunning()); + + // check if an exception occurred + if (atomicThrowable.get() != null) { + throw atomicThrowable.get(); + } + } + // ------------------------------------------------------------------------ // Test Utilities // ------------------------------------------------------------------------ + private static class NoOpStreamTask> extends StreamTask { + + public NoOpStreamTask(Environment environment) { + setEnvironment(environment); + } + + @Override + protected void init() throws Exception {} + + @Override + protected void run() throws Exception {} + + @Override + protected void cleanup() throws Exception {} + + @Override + protected void cancelTask() throws Exception {} + } + + private static class BlockingCloseStreamOperator extends AbstractStreamOperator { + private static final long serialVersionUID = -9042150529568008847L; + + public static final OneShotLatch IN_CLOSE = new OneShotLatch(); + public static final OneShotLatch FINISH_CLOSE = new OneShotLatch(); + + @Override + public void close() throws Exception { + IN_CLOSE.trigger(); + FINISH_CLOSE.await(); + super.close(); + } + } + private static class TestingExecutionStateListener implements TaskExecutionStateListener { private ExecutionState executionState = null; @@ -763,6 +858,7 @@ public static Task createTask( StreamConfig taskConfig, Configuration taskManagerConfig) throws Exception { + BlobCache blobCache = mock(BlobCache.class); LibraryCacheManager libCache = mock(LibraryCacheManager.class); when(libCache.getClassLoader(any(JobID.class))).thenReturn(StreamTaskTest.class.getClassLoader()); @@ -803,7 +899,7 @@ public static Task createTask( Collections.emptyList(), Collections.emptyList(), 0, - new TaskStateHandles(), + new TaskStateSnapshot(), mock(MemoryManager.class), mock(IOManager.class), network, @@ -811,6 +907,7 @@ public static Task createTask( mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), + blobCache, libCache, mock(FileCache.class), new TestingTaskManagerRuntimeInfo(taskManagerConfig, new String[] {System.getProperty("java.io.tmpdir")}), diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java index a02fe4e7d9a56..19d48e195f2ef 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java @@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; @@ -142,6 +143,7 @@ public void setupOutputForSingletonOperatorChain() { streamConfig.setNumberOfOutputs(1); streamConfig.setTypeSerializerOut(outputSerializer); streamConfig.setVertexID(0); + streamConfig.setOperatorID(new OperatorID(4711L, 123L)); StreamOperator dummyOperator = new AbstractStreamOperator() { private static final long serialVersionUID = 1L; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java index 66531ac51ce28..d785c0d7517f5 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java @@ -23,6 +23,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.streaming.api.functions.co.RichCoMapFunction; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -64,6 +65,7 @@ public void testOpenCloseAndTimestamps() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap coMapOperator = new CoStreamMap(new TestOpenCloseMapFunction()); streamConfig.setStreamOperator(coMapOperator); + streamConfig.setOperatorID(new OperatorID()); long initialTime = 0L; ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); @@ -110,6 +112,7 @@ public void testWatermarkAndStreamStatusForwarding() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap coMapOperator = new CoStreamMap(new IdentityMap()); streamConfig.setStreamOperator(coMapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -216,6 +219,7 @@ public void testCheckpointBarriers() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap coMapOperator = new CoStreamMap(new IdentityMap()); streamConfig.setStreamOperator(coMapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -296,6 +300,7 @@ public void testOvertakingCheckpointBarriers() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap coMapOperator = new CoStreamMap(new IdentityMap()); 
streamConfig.setStreamOperator(coMapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java index 47e8726874d33..9156f3413e529 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java @@ -25,26 +25,21 @@ import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.CloseableRegistry; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FSDataOutputStream; -import org.apache.flink.migration.runtime.checkpoint.savepoint.SavepointV0Serializer; -import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskState; -import org.apache.flink.migration.util.MigrationInstantiationUtil; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner; import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackend; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -52,7 +47,6 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperatorTest; import org.apache.flink.streaming.api.operators.OperatorSnapshotResult; import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; @@ -70,7 +64,6 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import java.io.FileInputStream; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -88,7 +81,7 @@ /** * Base class for {@code AbstractStreamOperator} test harnesses. 
*/ -public class AbstractStreamOperatorTestHarness { +public class AbstractStreamOperatorTestHarness implements AutoCloseable { protected final StreamOperator operator; @@ -154,6 +147,7 @@ public AbstractStreamOperatorTestHarness( Configuration underlyingConfig = environment.getTaskConfiguration(); this.config = new StreamConfig(underlyingConfig); this.config.setCheckpointingEnabled(true); + this.config.setOperatorID(new OperatorID()); this.executionConfig = environment.getExecutionConfig(); this.closableRegistry = new CloseableRegistry(); this.checkpointLock = new Object(); @@ -305,38 +299,8 @@ public void setup(TypeSerializer outputSerializer) { setupCalled = true; } - public void initializeStateFromLegacyCheckpoint(String checkpointFilename) throws Exception { - - FileInputStream fin = new FileInputStream(checkpointFilename); - StreamTaskState state = MigrationInstantiationUtil.deserializeObject(fin, ClassLoader.getSystemClassLoader()); - fin.close(); - - if (!setupCalled) { - setup(); - } - - StreamStateHandle stateHandle = SavepointV0Serializer.convertOperatorAndFunctionState(state); - - List keyGroupStatesList = new ArrayList<>(); - if (state.getKvStates() != null) { - KeyGroupsStateHandle keyedStateHandle = SavepointV0Serializer.convertKeyedBackendState( - state.getKvStates(), - environment.getTaskInfo().getIndexOfThisSubtask(), - 0); - keyGroupStatesList.add(keyedStateHandle); - } - - // finally calling the initializeState() with the legacy operatorStateHandles - initializeState(new OperatorStateHandles(0, - stateHandle, - keyGroupStatesList, - Collections.emptyList(), - Collections.emptyList(), - Collections.emptyList())); - } - /** - * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}. + * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorSubtaskState)}. * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} * if it was not called before. * @@ -393,13 +357,11 @@ public void initializeState(OperatorStateHandles operatorStateHandles) throws Ex rawOperatorState, numSubtasks).get(subtaskIndex); - OperatorStateHandles massagedOperatorStateHandles = new OperatorStateHandles( - 0, - operatorStateHandles.getLegacyOperatorState(), - localManagedKeyGroupState, - localRawKeyGroupState, - localManagedOperatorState, - localRawOperatorState); + OperatorSubtaskState massagedOperatorStateHandles = new OperatorSubtaskState( + nullToEmptyCollection(localManagedOperatorState), + nullToEmptyCollection(localRawOperatorState), + nullToEmptyCollection(localManagedKeyGroupState), + nullToEmptyCollection(localRawKeyGroupState)); operator.initializeState(massagedOperatorStateHandles); } else { @@ -408,6 +370,10 @@ public void initializeState(OperatorStateHandles operatorStateHandles) throws Ex initializeCalled = true; } + private static Collection nullToEmptyCollection(Collection collection) { + return collection != null ? collection : Collections.emptyList(); + } + /** * Takes the different {@link OperatorStateHandles} created by calling {@link #snapshot(long, long)} * on different instances of {@link AbstractStreamOperatorTestHarness} (each one representing one subtask) @@ -467,7 +433,6 @@ public static OperatorStateHandles repackageState(OperatorStateHandles... 
handle return new OperatorStateHandles( 0, - null, mergedManagedKeyedState, mergedRawKeyedState, mergedManagedOperatorState, @@ -491,8 +456,6 @@ public void open() throws Exception { */ public OperatorStateHandles snapshot(long checkpointId, long timestamp) throws Exception { - CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(new JobID(), "test_op"); - OperatorSnapshotResult operatorStateResult = operator.snapshotState( checkpointId, timestamp, @@ -504,45 +467,14 @@ public OperatorStateHandles snapshot(long checkpointId, long timestamp) throws E OperatorStateHandle opManaged = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getOperatorStateManagedFuture()); OperatorStateHandle opRaw = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getOperatorStateRawFuture()); - // also snapshot legacy state, if any - StreamStateHandle legacyStateHandle = null; - - if (operator instanceof StreamCheckpointedOperator) { - - final CheckpointStreamFactory.CheckpointStateOutputStream outStream = - streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp); - - ((StreamCheckpointedOperator) operator).snapshotState(outStream, checkpointId, timestamp); - legacyStateHandle = outStream.closeAndGetHandle(); - } - return new OperatorStateHandles( 0, - legacyStateHandle, keyedManaged != null ? Collections.singletonList(keyedManaged) : null, keyedRaw != null ? Collections.singletonList(keyedRaw) : null, opManaged != null ? Collections.singletonList(opManaged) : null, opRaw != null ? Collections.singletonList(opRaw) : null); } - /** - * Calls {@link StreamCheckpointedOperator#snapshotState(FSDataOutputStream, long, long)} if - * the operator implements this interface. - */ - @Deprecated - public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception { - - CheckpointStreamFactory.CheckpointStateOutputStream outStream = stateBackend.createStreamFactory( - new JobID(), - "test_op").createCheckpointStateOutputStream(checkpointId, timestamp); - if (operator instanceof StreamCheckpointedOperator) { - ((StreamCheckpointedOperator) operator).snapshotState(outStream, checkpointId, timestamp); - return outStream.closeAndGetHandle(); - } else { - throw new RuntimeException("Operator is not StreamCheckpointedOperator"); - } - } - /** * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#notifyOfCompletedCheckpoint(long)} ()}. */ @@ -550,22 +482,6 @@ public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { operator.notifyOfCompletedCheckpoint(checkpointId); } - /** - * Calls {@link StreamCheckpointedOperator#restoreState(FSDataInputStream)} if - * the operator implements this interface. - */ - @Deprecated - @SuppressWarnings("deprecation") - public void restore(StreamStateHandle snapshot) throws Exception { - if (operator instanceof StreamCheckpointedOperator) { - try (FSDataInputStream in = snapshot.openInputStream()) { - ((StreamCheckpointedOperator) operator).restoreState(in); - } - } else { - throw new RuntimeException("Operator is not StreamCheckpointedOperator"); - } - } - /** * Calls close and dispose on the operator. 
*/ diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java index 0d42d9f1e3e50..c2ec63a6474d0 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java @@ -23,33 +23,23 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.ClosureCleaner; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; -import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.KeyedStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.util.Migration; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.List; -import java.util.concurrent.RunnableFuture; import static org.mockito.Matchers.any; import static org.mockito.Mockito.anyInt; @@ -142,61 +132,6 @@ public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwa } } - /** - * - */ - @Override - public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception { - // simply use an in-memory handle - MemoryStateBackend backend = new MemoryStateBackend(); - - CheckpointStreamFactory streamFactory = backend.createStreamFactory(new JobID(), "test_op"); - CheckpointStreamFactory.CheckpointStateOutputStream outStream = - streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp); - - if (operator instanceof StreamCheckpointedOperator) { - ((StreamCheckpointedOperator) operator).snapshotState(outStream, checkpointId, timestamp); - } - - if (keyedStateBackend != null) { - RunnableFuture keyedSnapshotRunnable = keyedStateBackend.snapshot( - checkpointId, - timestamp, - streamFactory, - CheckpointOptions.forFullCheckpoint()); - if (!keyedSnapshotRunnable.isDone()) { - Thread runner = new Thread(keyedSnapshotRunnable); - runner.start(); - } - outStream.write(1); - ObjectOutputStream oos = new ObjectOutputStream(outStream); - oos.writeObject(keyedSnapshotRunnable.get()); - oos.flush(); - } else { - outStream.write(0); - } - return outStream.closeAndGetHandle(); - } - - /** - * - */ - @Override - public void restore(StreamStateHandle snapshot) throws Exception { - try (FSDataInputStream inStream = snapshot.openInputStream()) { - - if (operator instanceof StreamCheckpointedOperator) { 
- ((StreamCheckpointedOperator) operator).restoreState(inStream); - } - - byte keyedStatePresent = (byte) inStream.read(); - if (keyedStatePresent == 1) { - ObjectInputStream ois = new ObjectInputStream(inStream); - this.restoredKeyedState = Collections.singletonList((KeyedStateHandle) ois.readObject()); - } - } - } - private static boolean hasMigrationHandles(Collection allKeyGroupsHandles) { for (KeyedStateHandle handle : allKeyGroupsHandles) { if (handle instanceof Migration) { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorSnapshotUtil.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorSnapshotUtil.java index 7e32723cbe02d..33f32e9f6bf5a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorSnapshotUtil.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorSnapshotUtil.java @@ -21,7 +21,6 @@ import org.apache.flink.runtime.checkpoint.savepoint.SavepointV1Serializer; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import java.io.DataInputStream; @@ -53,7 +52,8 @@ public static void writeStateHandle(OperatorStateHandles state, String path) thr dos.writeInt(state.getOperatorChainIndex()); - SavepointV1Serializer.serializeStreamStateHandle(state.getLegacyOperatorState(), dos); + // still required for compatibility + SavepointV1Serializer.serializeStreamStateHandle(null, dos); Collection rawOperatorState = state.getRawOperatorState(); if (rawOperatorState != null) { @@ -108,7 +108,8 @@ public static OperatorStateHandles readStateHandle(String path) throws IOExcepti try (DataInputStream dis = new DataInputStream(in)) { int index = dis.readInt(); - StreamStateHandle legacyState = SavepointV1Serializer.deserializeStreamStateHandle(dis); + // still required for compatibility to consume the bytes. 
+ SavepointV1Serializer.deserializeStreamStateHandle(dis); List rawOperatorState = null; int numRawOperatorStates = dis.readInt(); @@ -154,7 +155,12 @@ public static OperatorStateHandles readStateHandle(String path) throws IOExcepti } } - return new OperatorStateHandles(index, legacyState, managedKeyedState, rawKeyedState, managedOperatorState, rawOperatorState); + return new OperatorStateHandles( + index, + managedKeyedState, + rawKeyedState, + managedOperatorState, + rawOperatorState); } } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TestHarnessUtil.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TestHarnessUtil.java index 64894484cb75b..807b68c80b574 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TestHarnessUtil.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TestHarnessUtil.java @@ -21,7 +21,8 @@ import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import com.google.common.collect.Iterables; +import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables; + import org.junit.Assert; import java.util.ArrayList; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/migration/MigrationTestUtil.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/migration/MigrationTestUtil.java index f723b345fea65..1c95a047f7865 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/migration/MigrationTestUtil.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/migration/MigrationTestUtil.java @@ -29,22 +29,16 @@ public class MigrationTestUtil { /** * Restore from a snapshot taken with an older Flink version. * - * @param testHarness the test harness to restore the snapshot to. - * @param snapshotPath the absolute path to the snapshot. + * @param testHarness the test harness to restore the snapshot to. + * @param snapshotPath the absolute path to the snapshot. * @param snapshotFlinkVersion the Flink version of the snapshot. 
- * * @throws Exception */ public static void restoreFromSnapshot( - AbstractStreamOperatorTestHarness testHarness, - String snapshotPath, - MigrationVersion snapshotFlinkVersion) throws Exception { + AbstractStreamOperatorTestHarness testHarness, + String snapshotPath, + MigrationVersion snapshotFlinkVersion) throws Exception { - if (snapshotFlinkVersion == MigrationVersion.v1_1) { - // Flink 1.1 snapshots should be read using the legacy restore method - testHarness.initializeStateFromLegacyCheckpoint(snapshotPath); - } else { - testHarness.initializeState(OperatorSnapshotUtil.readStateHandle(snapshotPath)); - } + testHarness.initializeState(OperatorSnapshotUtil.readStateHandle(snapshotPath)); } } diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AsyncDataStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AsyncDataStream.scala index 67af484834b07..e91922a9e00b4 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AsyncDataStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AsyncDataStream.scala @@ -21,9 +21,9 @@ package org.apache.flink.streaming.api.scala import org.apache.flink.annotation.PublicEvolving import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.streaming.api.datastream.{AsyncDataStream => JavaAsyncDataStream} -import org.apache.flink.streaming.api.functions.async.collector.{AsyncCollector => JavaAsyncCollector} +import org.apache.flink.streaming.api.functions.async.{ResultFuture => JavaResultFuture} import org.apache.flink.streaming.api.functions.async.{AsyncFunction => JavaAsyncFunction} -import org.apache.flink.streaming.api.scala.async.{AsyncCollector, AsyncFunction, JavaAsyncCollectorWrapper} +import org.apache.flink.streaming.api.scala.async.{AsyncFunction, JavaResultFutureWrapper, ResultFuture} import org.apache.flink.util.Preconditions import scala.concurrent.duration.TimeUnit @@ -34,7 +34,7 @@ import scala.concurrent.duration.TimeUnit * Example: * {{{ * val input: DataStream[String] = ... - * val asyncFunction: (String, AsyncCollector[String]) => Unit = ... + * val asyncFunction: (String, ResultFuture[String]) => Unit = ... 
* * AsyncDataStream.orderedWait(input, asyncFunction, timeout, TimeUnit.MILLISECONDS, 100) * }}} @@ -68,8 +68,8 @@ object AsyncDataStream { : DataStream[OUT] = { val javaAsyncFunction = new JavaAsyncFunction[IN, OUT] { - override def asyncInvoke(input: IN, collector: JavaAsyncCollector[OUT]): Unit = { - asyncFunction.asyncInvoke(input, new JavaAsyncCollectorWrapper(collector)) + override def asyncInvoke(input: IN, resultFuture: JavaResultFuture[OUT]): Unit = { + asyncFunction.asyncInvoke(input, new JavaResultFutureWrapper(resultFuture)) } } @@ -126,7 +126,7 @@ object AsyncDataStream { timeout: Long, timeUnit: TimeUnit, capacity: Int) ( - asyncFunction: (IN, AsyncCollector[OUT]) => Unit) + asyncFunction: (IN, ResultFuture[OUT]) => Unit) : DataStream[OUT] = { Preconditions.checkNotNull(asyncFunction) @@ -134,9 +134,9 @@ object AsyncDataStream { val cleanAsyncFunction = input.executionEnvironment.scalaClean(asyncFunction) val func = new JavaAsyncFunction[IN, OUT] { - override def asyncInvoke(input: IN, collector: JavaAsyncCollector[OUT]): Unit = { + override def asyncInvoke(input: IN, resultFuture: JavaResultFuture[OUT]): Unit = { - cleanAsyncFunction(input, new JavaAsyncCollectorWrapper[OUT](collector)) + cleanAsyncFunction(input, new JavaResultFutureWrapper[OUT](resultFuture)) } } @@ -167,7 +167,7 @@ object AsyncDataStream { input: DataStream[IN], timeout: Long, timeUnit: TimeUnit) ( - asyncFunction: (IN, AsyncCollector[OUT]) => Unit) + asyncFunction: (IN, ResultFuture[OUT]) => Unit) : DataStream[OUT] = { unorderedWait(input, timeout, timeUnit, DEFAULT_QUEUE_CAPACITY)(asyncFunction) } @@ -195,8 +195,8 @@ object AsyncDataStream { : DataStream[OUT] = { val javaAsyncFunction = new JavaAsyncFunction[IN, OUT] { - override def asyncInvoke(input: IN, collector: JavaAsyncCollector[OUT]): Unit = { - asyncFunction.asyncInvoke(input, new JavaAsyncCollectorWrapper[OUT](collector)) + override def asyncInvoke(input: IN, resultFuture: JavaResultFuture[OUT]): Unit = { + asyncFunction.asyncInvoke(input, new JavaResultFutureWrapper[OUT](resultFuture)) } } @@ -251,7 +251,7 @@ object AsyncDataStream { timeout: Long, timeUnit: TimeUnit, capacity: Int) ( - asyncFunction: (IN, AsyncCollector[OUT]) => Unit) + asyncFunction: (IN, ResultFuture[OUT]) => Unit) : DataStream[OUT] = { Preconditions.checkNotNull(asyncFunction) @@ -259,8 +259,8 @@ object AsyncDataStream { val cleanAsyncFunction = input.executionEnvironment.scalaClean(asyncFunction) val func = new JavaAsyncFunction[IN, OUT] { - override def asyncInvoke(input: IN, collector: JavaAsyncCollector[OUT]): Unit = { - cleanAsyncFunction(input, new JavaAsyncCollectorWrapper[OUT](collector)) + override def asyncInvoke(input: IN, resultFuture: JavaResultFuture[OUT]): Unit = { + cleanAsyncFunction(input, new JavaResultFutureWrapper[OUT](resultFuture)) } } @@ -290,7 +290,7 @@ object AsyncDataStream { input: DataStream[IN], timeout: Long, timeUnit: TimeUnit) ( - asyncFunction: (IN, AsyncCollector[OUT]) => Unit) + asyncFunction: (IN, ResultFuture[OUT]) => Unit) : DataStream[OUT] = { orderedWait(input, timeout, timeUnit, DEFAULT_QUEUE_CAPACITY)(asyncFunction) diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/AsyncFunction.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/AsyncFunction.scala index 72e3702906801..aea6b57a3c0a1 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/AsyncFunction.scala +++ 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/AsyncFunction.scala @@ -24,13 +24,13 @@ import org.apache.flink.annotation.PublicEvolving * A function to trigger async I/O operations. * * For each asyncInvoke an async io operation can be triggered, and once it has been done, - * the result can be collected by calling AsyncCollector.collect. For each async operation, its + * the result can be collected by calling ResultFuture.complete. For each async operation, its * context is stored in the operator immediately after invoking asyncInvoke, avoiding blocking for * each stream input as long as the internal buffer is not full. * - * [[AsyncCollector]] can be passed into callbacks or futures to collect the result data. + * [[ResultFuture]] can be passed into callbacks or futures to collect the result data. * An error can also be propagate to the async IO operator by - * [[AsyncCollector.collect(Throwable)]]. + * [[ResultFuture.completeExceptionally(Throwable)]]. * * @tparam IN The type of the input element * @tparam OUT The type of the output elements @@ -42,7 +42,7 @@ trait AsyncFunction[IN, OUT] { * Trigger the async operation for each stream input * * @param input element coming from an upstream task - * @param collector to collect the result data + * @param resultFuture to be completed with the result data */ - def asyncInvoke(input: IN, collector: AsyncCollector[OUT]): Unit + def asyncInvoke(input: IN, resultFuture: ResultFuture[OUT]): Unit } diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/JavaAsyncCollectorWrapper.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/JavaResultFutureWrapper.scala similarity index 62% rename from flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/JavaAsyncCollectorWrapper.scala rename to flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/JavaResultFutureWrapper.scala index 3c5e95a83912e..7680b8964bf73 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/JavaAsyncCollectorWrapper.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/JavaResultFutureWrapper.scala @@ -19,25 +19,26 @@ package org.apache.flink.streaming.api.scala.async import org.apache.flink.annotation.Internal -import org.apache.flink.streaming.api.functions.async.collector.{AsyncCollector => JavaAsyncCollector} +import org.apache.flink.streaming.api.functions.async +import org.apache.flink.streaming.api.functions.async.ResultFuture import scala.collection.JavaConverters._ /** - * Internal wrapper class to map a Flink's Java API [[JavaAsyncCollector]] to a Scala - * [[AsyncCollector]]. + * Internal wrapper class to map a Flink's Java API [[ResultFuture]] to a Scala + * [[ResultFuture]]. 
* - * @param javaAsyncCollector to forward the calls to + * @param javaResultFuture to forward the calls to * @tparam OUT type of the output elements */ @Internal -class JavaAsyncCollectorWrapper[OUT](val javaAsyncCollector: JavaAsyncCollector[OUT]) - extends AsyncCollector[OUT] { - override def collect(result: Iterable[OUT]): Unit = { - javaAsyncCollector.collect(result.asJavaCollection) +class JavaResultFutureWrapper[OUT](val javaResultFuture: async.ResultFuture[OUT]) + extends ResultFuture[OUT] { + override def complete(result: Iterable[OUT]): Unit = { + javaResultFuture.complete(result.asJavaCollection) } - override def collect(throwable: Throwable): Unit = { - javaAsyncCollector.collect(throwable) + override def completeExceptionally(throwable: Throwable): Unit = { + javaResultFuture.completeExceptionally(throwable) } } diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/AsyncCollector.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/ResultFuture.scala similarity index 77% rename from flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/AsyncCollector.scala rename to flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/ResultFuture.scala index a149c88e639cc..516e69337d9d6 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/AsyncCollector.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/async/ResultFuture.scala @@ -21,30 +21,30 @@ package org.apache.flink.streaming.api.scala.async import org.apache.flink.annotation.PublicEvolving /** - * The async collector collects data/errors from the user code while processing + * The result future collects data/errors from the user code while processing * asynchronous I/O operations. * * @tparam OUT type of the output element */ @PublicEvolving -trait AsyncCollector[OUT] { +trait ResultFuture[OUT] { /** - * Complete the async collector with a set of result elements. + * Complete the ResultFuture with a set of result elements. * * Note that it should be called for exactly one time in the user code. * Calling this function for multiple times will cause data lose. * - * Put all results in a [[Iterable]] and then issue AsyncCollector.collect(Iterable). + * Put all results in a [[Iterable]] and then issue ResultFuture.complete(Iterable). * * @param result to complete the async collector with */ - def collect(result: Iterable[OUT]) + def complete(result: Iterable[OUT]) /** - * Complete this async collector with an error. + * Complete this ResultFuture with an error. 
* * @param throwable to complete the async collector with */ - def collect(throwable: Throwable) + def completeExceptionally(throwable: Throwable) } diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/TimeWindowTranslationTest.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/TimeWindowTranslationTest.scala index 104400f1c192a..35a56d7019a7b 100644 --- a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/TimeWindowTranslationTest.scala +++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/TimeWindowTranslationTest.scala @@ -25,15 +25,15 @@ import org.apache.flink.api.java.tuple.Tuple import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.scala.function.WindowFunction import org.apache.flink.streaming.api.transformations.OneInputTransformation -import org.apache.flink.streaming.api.windowing.assigners.{SlidingAlignedProcessingTimeWindows, SlidingEventTimeWindows, TumblingAlignedProcessingTimeWindows} +import org.apache.flink.streaming.api.windowing.assigners.SlidingEventTimeWindows import org.apache.flink.streaming.api.windowing.time.Time import org.apache.flink.streaming.api.windowing.triggers.EventTimeTrigger import org.apache.flink.streaming.api.windowing.windows.TimeWindow -import org.apache.flink.streaming.runtime.operators.windowing.{AccumulatingProcessingTimeWindowOperator, AggregatingProcessingTimeWindowOperator, WindowOperator} +import org.apache.flink.streaming.runtime.operators.windowing.WindowOperator import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase import org.apache.flink.util.Collector import org.junit.Assert._ -import org.junit.{Ignore, Test} +import org.junit.Test /** * These tests verify that the api calls on [[WindowedStream]] that use the "time" shortcut @@ -85,59 +85,6 @@ class TimeWindowTranslationTest extends StreamingMultipleProgramsTestBase { assertTrue(operator2.isInstanceOf[WindowOperator[_, _, _, _, _]]) } - /** - * These tests ensure that the fast aligned time windows operator is used if the - * conditions are right. - */ - @Test - def testReduceAlignedTimeWindows(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - - val source = env.fromElements(("hello", 1), ("hello", 2)) - - val window1 = source - .keyBy(0) - .window(SlidingAlignedProcessingTimeWindows.of(Time.seconds(1), Time.milliseconds(100))) - .reduce(new DummyReducer()) - - val transform1 = window1.javaStream.getTransformation - .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]] - - val operator1 = transform1.getOperator - - assertTrue(operator1.isInstanceOf[AggregatingProcessingTimeWindowOperator[_, _]]) - } - - /** - * These tests ensure that the fast aligned time windows operator is used if the - * conditions are right. 
- */ - @Test - def testApplyAlignedTimeWindows(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime) - - val source = env.fromElements(("hello", 1), ("hello", 2)) - - val window1 = source - .keyBy(0) - .window(TumblingAlignedProcessingTimeWindows.of(Time.minutes(1))) - .apply(new WindowFunction[(String, Int), (String, Int), Tuple, TimeWindow]() { - def apply( - key: Tuple, - window: TimeWindow, - values: Iterable[(String, Int)], - out: Collector[(String, Int)]) { } - }) - - val transform1 = window1.javaStream.getTransformation - .asInstanceOf[OneInputTransformation[(String, Int), (String, Int)]] - - val operator1 = transform1.getOperator - - assertTrue(operator1.isInstanceOf[AccumulatingProcessingTimeWindowOperator[_, _, _]]) - } - @Test def testReduceEventTimeWindows(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment diff --git a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/networking/NetworkFailureHandler.java b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/networking/NetworkFailureHandler.java index 0ce0b12559058..07d23412d0952 100644 --- a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/networking/NetworkFailureHandler.java +++ b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/networking/NetworkFailureHandler.java @@ -92,7 +92,14 @@ public void channelOpen(ChannelHandlerContext context, ChannelStateEvent event) final Channel sourceChannel = event.getChannel(); sourceChannel.setReadable(false); - if (blocked.get()) { + boolean isBlocked = blocked.get(); + LOG.debug("Attempt to open proxy channel from [{}] to [{}:{}] in state [blocked = {}]", + sourceChannel.getLocalAddress(), + remoteHost, + remotePort, + isBlocked); + + if (isBlocked) { sourceChannel.close(); return; } diff --git a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/networking/NetworkFailuresProxy.java b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/networking/NetworkFailuresProxy.java index 70300494a7e0e..5531811ca8c30 100644 --- a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/networking/NetworkFailuresProxy.java +++ b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/networking/NetworkFailuresProxy.java @@ -53,8 +53,6 @@ public class NetworkFailuresProxy implements AutoCloseable { private final Set networkFailureHandlers = Collections.newSetFromMap(new ConcurrentHashMap<>()); public NetworkFailuresProxy(int localPort, String remoteHost, int remotePort) { - LOG.info("Proxying [*:{}] to [{}:{}]", localPort, remoteHost, remotePort); - // Configure the bootstrap. serverBootstrap = new ServerBootstrap( new NioServerSocketChannelFactory(executor, executor)); @@ -83,6 +81,7 @@ public ChannelPipeline getPipeline() throws Exception { }); channel = serverBootstrap.bind(new InetSocketAddress(localPort)); + LOG.info("Proxying [*:{}] to [{}:{}]", getLocalPort(), remoteHost, remotePort); } /** diff --git a/flink-tests/pom.xml b/flink-tests/pom.xml index 85d90b3684edb..a0c68a81a75b0 100644 --- a/flink-tests/pom.xml +++ b/flink-tests/pom.xml @@ -193,6 +193,11 @@ under the License. 
test + + org.apache.flink + flink-shaded-guava + + org.scalatest scalatest_${scala.binary.version} diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java index 22ed84723e552..4d5fa719c73ec 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java @@ -27,9 +27,12 @@ import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.HighAvailabilityOptions; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils; import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.CheckpointListener; @@ -48,21 +51,25 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.TestLogger; +import org.apache.curator.test.TestingServer; import org.junit.After; -import org.junit.AfterClass; import org.junit.Before; -import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.rules.TestName; +import java.io.File; import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.apache.flink.test.checkpointing.AbstractEventTimeWindowCheckpointingITCase.StateBackendEnum.ROCKSDB_INCREMENTAL_ZK; import static org.apache.flink.test.util.TestUtils.tryExecute; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -87,6 +94,8 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog private static TestStreamEnvironment env; + private static TestingServer zkServer; + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); @@ -101,11 +110,27 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog } enum StateBackendEnum { - MEM, FILE, ROCKSDB_FULLY_ASYNC, ROCKSDB_INCREMENTAL, MEM_ASYNC, FILE_ASYNC + MEM, FILE, ROCKSDB_FULLY_ASYNC, ROCKSDB_INCREMENTAL, ROCKSDB_INCREMENTAL_ZK, MEM_ASYNC, FILE_ASYNC } - @BeforeClass - public static void startTestCluster() { + @Before + public void startTestCluster() throws Exception { + + // print a message when starting a test method to avoid Travis' "Maven produced no + // output for xxx seconds." 
messages + System.out.println( + "Starting " + getClass().getCanonicalName() + "#" + name.getMethodName() + "."); + + // Testing HA Scenario / ZKCompletedCheckpointStore with incremental checkpoints + if (ROCKSDB_INCREMENTAL_ZK.equals(stateBackendEnum)) { + zkServer = new TestingServer(); + zkServer.start(); + } + + TemporaryFolder temporaryFolder = new TemporaryFolder(); + temporaryFolder.create(); + final File haDir = temporaryFolder.newFolder(); + Configuration config = new Configuration(); config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, 2); config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, PARALLELISM / 2); @@ -113,28 +138,29 @@ public static void startTestCluster() { // the default network buffers size (10% of heap max =~ 150MB) seems to much for this test case config.setLong(TaskManagerOptions.NETWORK_BUFFERS_MEMORY_MAX, 80L << 20); // 80 MB - cluster = new LocalFlinkMiniCluster(config, false); + if (zkServer != null) { + config.setString(HighAvailabilityOptions.HA_MODE, "ZOOKEEPER"); + config.setString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, zkServer.getConnectString()); + config.setString(HighAvailabilityOptions.HA_STORAGE_PATH, haDir.toURI().toString()); + } + + // purposefully delay in the executor to tease out races + final ScheduledExecutorService executor = Executors.newScheduledThreadPool(10); + HighAvailabilityServices haServices = HighAvailabilityServicesUtils.createAvailableOrEmbeddedServices( + config, + new Executor() { + @Override + public void execute(Runnable command) { + executor.schedule(command, 500, MILLISECONDS); + } + }); + + cluster = new LocalFlinkMiniCluster(config, haServices, false); cluster.start(); env = new TestStreamEnvironment(cluster, PARALLELISM); env.getConfig().setUseSnapshotCompression(true); - } - - @AfterClass - public static void stopTestCluster() { - if (cluster != null) { - cluster.stop(); - } - } - @Before - public void beforeTest() throws IOException { - // print a message when starting a test method to avoid Travis' "Maven produced no - // output for xxx seconds." messages - System.out.println( - "Starting " + getClass().getCanonicalName() + "#" + name.getMethodName() + "."); - - // init state back-end switch (stateBackendEnum) { case MEM: this.stateBackend = new MemoryStateBackend(MAX_MEM_STATE_SIZE, false); @@ -159,7 +185,8 @@ public void beforeTest() throws IOException { this.stateBackend = rdb; break; } - case ROCKSDB_INCREMENTAL: { + case ROCKSDB_INCREMENTAL: + case ROCKSDB_INCREMENTAL_ZK: { String rocksDb = tempFolder.newFolder().getAbsolutePath(); String backups = tempFolder.newFolder().getAbsolutePath(); // we use the fs backend with small threshold here to test the behaviour with file @@ -173,16 +200,25 @@ public void beforeTest() throws IOException { this.stateBackend = rdb; break; } - + default: + throw new IllegalStateException("No backend selected."); } } - /** - * Prints a message when finishing a test method to avoid Travis' "Maven produced no output - * for xxx seconds." messages. - */ @After - public void afterTest() { + public void stopTestCluster() throws IOException { + if (cluster != null) { + cluster.stop(); + cluster = null; + } + + if (zkServer != null) { + zkServer.stop(); + zkServer = null; + } + + //Prints a message when finishing a test method to avoid Travis' "Maven produced no output + // for xxx seconds." messages. 
System.out.println( "Finished " + getClass().getCanonicalName() + "#" + name.getMethodName() + "."); } diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase.java new file mode 100644 index 0000000000000..394815f2ae61a --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.checkpointing; + +/** + * Integration tests for incremental RocksDB backend. + */ +public class HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase extends AbstractEventTimeWindowCheckpointingITCase { + + public HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase() { + super(StateBackendEnum.ROCKSDB_INCREMENTAL_ZK); + } + + @Override + protected int numElementsPerKey() { + return 3000; + } + + @Override + protected int windowSize() { + return 1000; + } + + @Override + protected int windowSlide() { + return 100; + } + + @Override + protected int numKeys() { + return 100; + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java index cad669364fdbe..99fb6ef80f643 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java @@ -44,7 +44,6 @@ import org.apache.flink.runtime.testingUtils.TestingCluster; import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -975,7 +974,7 @@ public void restoreState(List state) throws Exception { } } - private static class PartitionedStateSource extends StateSourceBase implements CheckpointedFunction, CheckpointedRestoring { + private static class PartitionedStateSource extends StateSourceBase implements CheckpointedFunction { private static final long serialVersionUID = -359715965103593462L; private static final int NUM_PARTITIONS = 7; @@ -1030,10 +1029,5 @@ public void initializeState(FunctionInitializationContext context) throws Except checkCorrectRestore[getRuntimeContext().getIndexOfThisSubtask()] = 
counter; } } - - @Override - public void restoreState(Integer state) throws Exception { - counterPartitions.add(state); - } } } diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java index a3d45dd48f9ad..1b7dafab6e41a 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java @@ -53,7 +53,7 @@ import org.apache.flink.runtime.messages.JobManagerMessages.DisposeSavepoint; import org.apache.flink.runtime.messages.JobManagerMessages.TriggerSavepoint; import org.apache.flink.runtime.messages.JobManagerMessages.TriggerSavepointSuccess; -import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory; @@ -74,11 +74,13 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.TestLogger; +import org.apache.flink.shaded.guava18.com.google.common.collect.HashMultimap; +import org.apache.flink.shaded.guava18.com.google.common.collect.Multimap; + import akka.actor.ActorRef; import akka.actor.ActorSystem; import akka.testkit.JavaTestKit; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -349,10 +351,6 @@ protected void run() { OperatorSubtaskState subtaskState = operatorState.getState(tdd.getSubtaskIndex()); assertNotNull(subtaskState); - - errMsg = "Initial operator state mismatch."; - assertEquals(errMsg, subtaskState.getLegacyOperatorState(), - tdd.getTaskStateHandles().getLegacyOperatorState().get(chainIndexAndJobVertex.f0)); } } @@ -375,17 +373,18 @@ protected void run() { assertTrue(errMsg, resp.getClass() == getDisposeSavepointSuccess().getClass()); // - Verification START ------------------------------------------- - // The checkpoint files List checkpointFiles = new ArrayList<>(); for (OperatorState stateForTaskGroup : savepoint.getOperatorStates()) { for (OperatorSubtaskState subtaskState : stateForTaskGroup.getStates()) { - StreamStateHandle streamTaskState = subtaskState.getLegacyOperatorState(); + Collection streamTaskState = subtaskState.getManagedOperatorState(); - if (streamTaskState != null) { - FileStateHandle fileStateHandle = (FileStateHandle) streamTaskState; - checkpointFiles.add(new File(fileStateHandle.getFilePath().toUri())); + if (streamTaskState != null && !streamTaskState.isEmpty()) { + for (OperatorStateHandle osh : streamTaskState) { + FileStateHandle fileStateHandle = (FileStateHandle) osh.getDelegateStateHandle(); + checkpointFiles.add(new File(fileStateHandle.getFilePath().toUri())); + } } } } diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java index f19d6904f03cd..530e97310e3f9 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java @@ -33,7 +33,8 @@ import 
org.apache.flink.streaming.api.operators.StreamGroupedFold; import org.apache.flink.streaming.api.operators.StreamGroupedReduce; -import com.google.common.collect.EvictingQueue; +import org.apache.flink.shaded.guava18.com.google.common.collect.EvictingQueue; + import org.junit.Assert; import java.util.Collections; diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java index 21be7ba84993b..eccc7e906d629 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/SavepointMigrationTestBase.java @@ -25,6 +25,7 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.CoreOptions; +import org.apache.flink.runtime.checkpoint.savepoint.SavepointSerializers; import org.apache.flink.runtime.client.JobListeningContext; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.jobgraph.JobGraph; @@ -39,6 +40,7 @@ import org.apache.commons.io.FileUtils; import org.junit.After; import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.rules.TemporaryFolder; import org.slf4j.Logger; @@ -64,6 +66,11 @@ */ public class SavepointMigrationTestBase extends TestBaseUtils { + @BeforeClass + public static void before() { + SavepointSerializers.setFailWhenLegacyStateDetected(false); + } + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointFrom11MigrationITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointFrom11MigrationITCase.java deleted file mode 100644 index da6e035741671..0000000000000 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointFrom11MigrationITCase.java +++ /dev/null @@ -1,562 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.test.checkpointing.utils; - -import org.apache.flink.api.common.accumulators.IntCounter; -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.common.functions.RichFlatMapFunction; -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.TypeHint; -import org.apache.flink.api.common.typeutils.base.LongSerializer; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.apache.flink.streaming.api.TimeCharacteristic; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; -import org.apache.flink.streaming.api.functions.source.RichSourceFunction; -import org.apache.flink.streaming.api.functions.source.SourceFunction; -import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.operators.TimestampedCollector; -import org.apache.flink.streaming.api.watermark.Watermark; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.util.Collector; - -import org.junit.Ignore; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -/** - * This verifies that we can restore a complete job from a Flink 1.1 savepoint. - * - *
<p>
The test pipeline contains both "Checkpointed" state and keyed user state. - */ -public class StatefulJobSavepointFrom11MigrationITCase extends SavepointMigrationTestBase { - private static final int NUM_SOURCE_ELEMENTS = 4; - private static final String EXPECTED_ELEMENTS_ACCUMULATOR = "NUM_EXPECTED_ELEMENTS"; - private static final String SUCCESSFUL_CHECK_ACCUMULATOR = "SUCCESSFUL_CHECKS"; - - /** - * This has to be manually executed to create the savepoint on Flink 1.1. - */ - @Test - @Ignore - public void testCreateSavepointOnFlink11() throws Exception { - - final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); - // we only test memory state backend yet - env.setStateBackend(new MemoryStateBackend()); - env.enableCheckpointing(500); - env.setParallelism(4); - env.setMaxParallelism(4); - - // create source - env - .addSource(new LegacyCheckpointedSource(NUM_SOURCE_ELEMENTS)).setMaxParallelism(1).uid("LegacyCheckpointedSource") - .flatMap(new LegacyCheckpointedFlatMap()).startNewChain().uid("LegacyCheckpointedFlatMap") - .keyBy(0) - .flatMap(new LegacyCheckpointedFlatMapWithKeyedState()).startNewChain().uid("LegacyCheckpointedFlatMapWithKeyedState") - .keyBy(0) - .flatMap(new KeyedStateSettingFlatMap()).startNewChain().uid("KeyedStateSettingFlatMap") - .keyBy(0) - .transform( - "custom_operator", - new TypeHint>() {}.getTypeInfo(), - new CheckpointedUdfOperator(new LegacyCheckpointedFlatMapWithKeyedState())).uid("LegacyCheckpointedOperator") - .addSink(new AccumulatorCountingSink>(EXPECTED_ELEMENTS_ACCUMULATOR)); - - executeAndSavepoint( - env, - "src/test/resources/stateful-udf-migration-itcase-flink1.1-savepoint", - new Tuple2<>(EXPECTED_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS)); - } - - /** - * This has to be manually executed to create the savepoint on Flink 1.1. 
- */ - @Test - @Ignore - public void testCreateSavepointOnFlink11WithRocksDB() throws Exception { - - final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); - RocksDBStateBackend rocksBackend = - new RocksDBStateBackend(new MemoryStateBackend()); -// rocksBackend.enableFullyAsyncSnapshots(); - env.setStateBackend(rocksBackend); - env.enableCheckpointing(500); - env.setParallelism(4); - env.setMaxParallelism(4); - - // create source - env - .addSource(new LegacyCheckpointedSource(NUM_SOURCE_ELEMENTS)).setMaxParallelism(1).uid("LegacyCheckpointedSource") - .flatMap(new LegacyCheckpointedFlatMap()).startNewChain().uid("LegacyCheckpointedFlatMap") - .keyBy(0) - .flatMap(new LegacyCheckpointedFlatMapWithKeyedState()).startNewChain().uid("LegacyCheckpointedFlatMapWithKeyedState") - .keyBy(0) - .flatMap(new KeyedStateSettingFlatMap()).startNewChain().uid("KeyedStateSettingFlatMap") - .keyBy(0) - .transform( - "custom_operator", - new TypeHint>() {}.getTypeInfo(), - new CheckpointedUdfOperator(new LegacyCheckpointedFlatMapWithKeyedState())).uid("LegacyCheckpointedOperator") - .addSink(new AccumulatorCountingSink>(EXPECTED_ELEMENTS_ACCUMULATOR)); - - executeAndSavepoint( - env, - "src/test/resources/stateful-udf-migration-itcase-flink1.1-rocksdb-savepoint", - new Tuple2<>(EXPECTED_ELEMENTS_ACCUMULATOR, NUM_SOURCE_ELEMENTS)); - } - - @Test - public void testSavepointRestoreFromFlink11() throws Exception { - - final int expectedSuccessfulChecks = 21; - - final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); - // we only test memory state backend yet - env.setStateBackend(new MemoryStateBackend()); - env.enableCheckpointing(500); - env.setParallelism(4); - env.setMaxParallelism(4); - - // create source - env - .addSource(new RestoringCheckingSource(NUM_SOURCE_ELEMENTS)).setMaxParallelism(1).uid("LegacyCheckpointedSource") - .flatMap(new RestoringCheckingFlatMap()).startNewChain().uid("LegacyCheckpointedFlatMap") - .keyBy(0) - .flatMap(new RestoringCheckingFlatMapWithKeyedState()).startNewChain().uid("LegacyCheckpointedFlatMapWithKeyedState") - .keyBy(0) - .flatMap(new KeyedStateCheckingFlatMap()).startNewChain().uid("KeyedStateSettingFlatMap") - .keyBy(0) - .transform( - "custom_operator", - new TypeHint>() {}.getTypeInfo(), - new RestoringCheckingUdfOperator(new RestoringCheckingFlatMapWithKeyedState())).uid("LegacyCheckpointedOperator") - .addSink(new AccumulatorCountingSink>(EXPECTED_ELEMENTS_ACCUMULATOR)); - - restoreAndExecute( - env, - getResourceFilename("stateful-udf-migration-itcase-flink1.1-savepoint"), - new Tuple2<>(SUCCESSFUL_CHECK_ACCUMULATOR, expectedSuccessfulChecks)); - } - - @Test - public void testSavepointRestoreFromFlink11FromRocksDB() throws Exception { - - final int expectedSuccessfulChecks = 21; - - final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); - // we only test memory state backend yet - env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend())); - env.enableCheckpointing(500); - env.setParallelism(4); - env.setMaxParallelism(4); - - // create source - env - .addSource(new RestoringCheckingSource(NUM_SOURCE_ELEMENTS)).setMaxParallelism(1).uid("LegacyCheckpointedSource") - .flatMap(new 
RestoringCheckingFlatMap()).startNewChain().uid("LegacyCheckpointedFlatMap") - .keyBy(0) - .flatMap(new RestoringCheckingFlatMapWithKeyedState()).startNewChain().uid("LegacyCheckpointedFlatMapWithKeyedState") - .keyBy(0) - .flatMap(new KeyedStateCheckingFlatMap()).startNewChain().uid("KeyedStateSettingFlatMap") - .keyBy(0) - .transform( - "custom_operator", - new TypeHint>() {}.getTypeInfo(), - new RestoringCheckingUdfOperator(new RestoringCheckingFlatMapWithKeyedState())).uid("LegacyCheckpointedOperator") - .addSink(new AccumulatorCountingSink>(EXPECTED_ELEMENTS_ACCUMULATOR)); - - restoreAndExecute( - env, - getResourceFilename("stateful-udf-migration-itcase-flink1.1-rocksdb-savepoint"), - new Tuple2<>(SUCCESSFUL_CHECK_ACCUMULATOR, expectedSuccessfulChecks)); - } - - private static class LegacyCheckpointedSource - implements SourceFunction>, Checkpointed { - - public static String checkpointedString = "Here be dragons!"; - - private static final long serialVersionUID = 1L; - - private volatile boolean isRunning = true; - - private final int numElements; - - public LegacyCheckpointedSource(int numElements) { - this.numElements = numElements; - } - - @Override - public void run(SourceContext> ctx) throws Exception { - - synchronized (ctx.getCheckpointLock()) { - for (long i = 0; i < numElements; i++) { - ctx.collect(new Tuple2<>(i, i)); - } - } - while (isRunning) { - Thread.sleep(20); - } - } - - @Override - public void cancel() { - isRunning = false; - } - - @Override - public void restoreState(String state) throws Exception { - assertEquals(checkpointedString, state); - } - - @Override - public String snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return checkpointedString; - } - } - - private static class RestoringCheckingSource - extends RichSourceFunction> - implements CheckpointedRestoring { - - private static final long serialVersionUID = 1L; - - private volatile boolean isRunning = true; - - private final int numElements; - - private String restoredState; - - public RestoringCheckingSource(int numElements) { - this.numElements = numElements; - } - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - getRuntimeContext().addAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR, new IntCounter()); - } - - @Override - public void run(SourceContext> ctx) throws Exception { - assertEquals(LegacyCheckpointedSource.checkpointedString, restoredState); - getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1); - - synchronized (ctx.getCheckpointLock()) { - for (long i = 0; i < numElements; i++) { - ctx.collect(new Tuple2<>(i, i)); - } - } - - while (isRunning) { - Thread.sleep(20); - } - } - - @Override - public void cancel() { - isRunning = false; - } - - @Override - public void restoreState(String state) throws Exception { - restoredState = state; - } - } - - private static class LegacyCheckpointedFlatMap extends RichFlatMapFunction, Tuple2> - implements Checkpointed> { - - private static final long serialVersionUID = 1L; - - public static Tuple2 checkpointedTuple = - new Tuple2<>("hello", 42L); - - @Override - public void flatMap(Tuple2 value, Collector> out) throws Exception { - out.collect(value); - } - - @Override - public void restoreState(Tuple2 state) throws Exception { - } - - @Override - public Tuple2 snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return checkpointedTuple; - } - } - - private static class RestoringCheckingFlatMap extends RichFlatMapFunction, 
Tuple2> - implements CheckpointedRestoring> { - - private static final long serialVersionUID = 1L; - - private transient Tuple2 restoredState; - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - getRuntimeContext().addAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR, new IntCounter()); - } - - @Override - public void flatMap(Tuple2 value, Collector> out) throws Exception { - out.collect(value); - - assertEquals(LegacyCheckpointedFlatMap.checkpointedTuple, restoredState); - getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1); - - } - - @Override - public void restoreState(Tuple2 state) throws Exception { - restoredState = state; - } - } - - private static class LegacyCheckpointedFlatMapWithKeyedState - extends RichFlatMapFunction, Tuple2> - implements Checkpointed> { - - private static final long serialVersionUID = 1L; - - public static Tuple2 checkpointedTuple = - new Tuple2<>("hello", 42L); - - private final ValueStateDescriptor stateDescriptor = - new ValueStateDescriptor("state-name", LongSerializer.INSTANCE); - - @Override - public void flatMap(Tuple2 value, Collector> out) throws Exception { - out.collect(value); - - getRuntimeContext().getState(stateDescriptor).update(value.f1); - } - - @Override - public void restoreState(Tuple2 state) throws Exception { - } - - @Override - public Tuple2 snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return checkpointedTuple; - } - } - - private static class RestoringCheckingFlatMapWithKeyedState extends RichFlatMapFunction, Tuple2> - implements CheckpointedRestoring> { - - private static final long serialVersionUID = 1L; - - private transient Tuple2 restoredState; - - private final ValueStateDescriptor stateDescriptor = - new ValueStateDescriptor("state-name", LongSerializer.INSTANCE); - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - getRuntimeContext().addAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR, new IntCounter()); - } - - @Override - public void flatMap(Tuple2 value, Collector> out) throws Exception { - out.collect(value); - - ValueState state = getRuntimeContext().getState(stateDescriptor); - if (state == null) { - throw new RuntimeException("Missing key value state for " + value); - } - - assertEquals(value.f1, state.value()); - assertEquals(LegacyCheckpointedFlatMap.checkpointedTuple, restoredState); - getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1); - } - - @Override - public void restoreState(Tuple2 state) throws Exception { - restoredState = state; - } - } - - private static class KeyedStateSettingFlatMap extends RichFlatMapFunction, Tuple2> { - - private static final long serialVersionUID = 1L; - - private final ValueStateDescriptor stateDescriptor = - new ValueStateDescriptor("state-name", LongSerializer.INSTANCE); - - @Override - public void flatMap(Tuple2 value, Collector> out) throws Exception { - out.collect(value); - - getRuntimeContext().getState(stateDescriptor).update(value.f1); - } - } - - private static class KeyedStateCheckingFlatMap extends RichFlatMapFunction, Tuple2> { - - private static final long serialVersionUID = 1L; - - private final ValueStateDescriptor stateDescriptor = - new ValueStateDescriptor("state-name", LongSerializer.INSTANCE); - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - getRuntimeContext().addAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR, new IntCounter()); - } - - 
@Override - public void flatMap(Tuple2 value, Collector> out) throws Exception { - out.collect(value); - - ValueState state = getRuntimeContext().getState(stateDescriptor); - if (state == null) { - throw new RuntimeException("Missing key value state for " + value); - } - - assertEquals(value.f1, state.value()); - getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1); - } - } - - private static class CheckpointedUdfOperator - extends AbstractUdfStreamOperator, FlatMapFunction, Tuple2>> - implements OneInputStreamOperator, Tuple2> { - private static final long serialVersionUID = 1L; - - private static final String CHECKPOINTED_STRING = "Oh my, that's nice!"; - - public CheckpointedUdfOperator(FlatMapFunction, Tuple2> userFunction) { - super(userFunction); - } - - @Override - public void processElement(StreamRecord> element) throws Exception { - output.collect(element); - } - - @Override - public void processWatermark(Watermark mark) throws Exception { - output.emitWatermark(mark); - } - - // Flink 1.1 -// @Override -// public StreamTaskState snapshotOperatorState( -// long checkpointId, long timestamp) throws Exception { -// StreamTaskState result = super.snapshotOperatorState(checkpointId, timestamp); -// -// AbstractStateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView( -// checkpointId, -// timestamp); -// -// out.writeUTF(checkpointedString); -// -// result.setOperatorState(out.closeAndGetHandle()); -// -// return result; -// } - } - - private static class RestoringCheckingUdfOperator - extends AbstractUdfStreamOperator, FlatMapFunction, Tuple2>> - implements OneInputStreamOperator, Tuple2> { - private static final long serialVersionUID = 1L; - - private String restoredState; - - public RestoringCheckingUdfOperator(FlatMapFunction, Tuple2> userFunction) { - super(userFunction); - } - - @Override - public void open() throws Exception { - super.open(); - } - - @Override - public void processElement(StreamRecord> element) throws Exception { - userFunction.flatMap(element.getValue(), new TimestampedCollector<>(output)); - - assertEquals(CheckpointedUdfOperator.CHECKPOINTED_STRING, restoredState); - getRuntimeContext().getAccumulator(SUCCESSFUL_CHECK_ACCUMULATOR).add(1); - } - - @Override - public void processWatermark(Watermark mark) throws Exception { - output.emitWatermark(mark); - } - - @Override - public void restoreState(FSDataInputStream in) throws Exception { - super.restoreState(in); - - DataInputViewStreamWrapper streamWrapper = new DataInputViewStreamWrapper(in); - - restoredState = streamWrapper.readUTF(); - } - } - - private static class AccumulatorCountingSink extends RichSinkFunction { - private static final long serialVersionUID = 1L; - - private final String accumulatorName; - - int count = 0; - - public AccumulatorCountingSink(String accumulatorName) { - this.accumulatorName = accumulatorName; - } - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - getRuntimeContext().addAccumulator(accumulatorName, new IntCounter()); - } - - @Override - public void invoke(T value) throws Exception { - count++; - getRuntimeContext().getAccumulator(accumulatorName).add(1); - } - } -} diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointFrom12MigrationITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointFrom12MigrationITCase.java index 7dd1144827167..6859c2d243b6b 100644 --- 
a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointFrom12MigrationITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/StatefulJobSavepointFrom12MigrationITCase.java @@ -29,14 +29,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FSDataOutputStream; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.TimeCharacteristic; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; -import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; import org.apache.flink.streaming.api.functions.source.RichSourceFunction; @@ -60,10 +54,13 @@ /** * This verifies that we can restore a complete job from a Flink 1.2 savepoint. * - *
<p>
The test pipeline contains both "Checkpointed" state and keyed user state. + *
<p>
The test for checkpointed (legacy state) was removed from this test for Flink 1.4 because compatibility with + * Flink 1.1 has been dropped. The legacy state in the binary savepoints is now ignored by the tests. * *
<p>
The tests will time out if they don't see the required number of successful checks within * a time limit. + * + * */ public class StatefulJobSavepointFrom12MigrationITCase extends SavepointMigrationTestBase { private static final int NUM_SOURCE_ELEMENTS = 4; @@ -247,7 +244,7 @@ protected String getRocksDBSavepointPath() { } private static class LegacyCheckpointedSource - implements SourceFunction>, Checkpointed { + implements SourceFunction> { public static String checkpointedString = "Here be dragons!"; @@ -283,21 +280,10 @@ public void run(SourceContext> ctx) throws Exception { public void cancel() { isRunning = false; } - - @Override - public void restoreState(String state) throws Exception { - assertEquals(checkpointedString, state); - } - - @Override - public String snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return checkpointedString; - } } private static class CheckingRestoringSource - extends RichSourceFunction> - implements CheckpointedRestoring { + extends RichSourceFunction> { private static final long serialVersionUID = 1L; @@ -322,7 +308,6 @@ public void open(Configuration parameters) throws Exception { @Override public void run(SourceContext> ctx) throws Exception { - assertEquals(LegacyCheckpointedSource.checkpointedString, restoredState); getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); // immediately trigger any set timers @@ -343,15 +328,9 @@ public void run(SourceContext> ctx) throws Exception { public void cancel() { isRunning = false; } - - @Override - public void restoreState(String state) throws Exception { - restoredState = state; - } } - private static class LegacyCheckpointedFlatMap extends RichFlatMapFunction, Tuple2> - implements Checkpointed> { + private static class LegacyCheckpointedFlatMap extends RichFlatMapFunction, Tuple2> { private static final long serialVersionUID = 1L; @@ -362,19 +341,9 @@ private static class LegacyCheckpointedFlatMap extends RichFlatMapFunction value, Collector> out) throws Exception { out.collect(value); } - - @Override - public void restoreState(Tuple2 state) throws Exception { - } - - @Override - public Tuple2 snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return checkpointedTuple; - } } - private static class CheckingRestoringFlatMap extends RichFlatMapFunction, Tuple2> - implements CheckpointedRestoring> { + private static class CheckingRestoringFlatMap extends RichFlatMapFunction, Tuple2> { private static final long serialVersionUID = 1L; @@ -393,20 +362,14 @@ public void open(Configuration parameters) throws Exception { public void flatMap(Tuple2 value, Collector> out) throws Exception { out.collect(value); - assertEquals(LegacyCheckpointedFlatMap.checkpointedTuple, restoredState); getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); } - @Override - public void restoreState(Tuple2 state) throws Exception { - restoredState = state; - } } private static class LegacyCheckpointedFlatMapWithKeyedState - extends RichFlatMapFunction, Tuple2> - implements Checkpointed> { + extends RichFlatMapFunction, Tuple2> { private static final long serialVersionUID = 1L; @@ -424,19 +387,10 @@ public void flatMap(Tuple2 value, Collector> out) assertEquals(value.f1, getRuntimeContext().getState(stateDescriptor).value()); } - - @Override - public void restoreState(Tuple2 state) throws Exception { - } - - @Override - public Tuple2 snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return 
checkpointedTuple; - } } - private static class CheckingRestoringFlatMapWithKeyedState extends RichFlatMapFunction, Tuple2> - implements CheckpointedRestoring> { + private static class CheckingRestoringFlatMapWithKeyedState + extends RichFlatMapFunction, Tuple2> { private static final long serialVersionUID = 1L; @@ -464,18 +418,12 @@ public void flatMap(Tuple2 value, Collector> out) } assertEquals(value.f1, state.value()); - assertEquals(LegacyCheckpointedFlatMap.checkpointedTuple, restoredState); getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); } - - @Override - public void restoreState(Tuple2 state) throws Exception { - restoredState = state; - } } - private static class CheckingRestoringFlatMapWithKeyedStateInOperator extends RichFlatMapFunction, Tuple2> - implements CheckpointedRestoring> { + private static class CheckingRestoringFlatMapWithKeyedStateInOperator + extends RichFlatMapFunction, Tuple2> { private static final long serialVersionUID = 1L; @@ -503,14 +451,8 @@ public void flatMap(Tuple2 value, Collector> out) } assertEquals(value.f1, state.value()); - assertEquals(LegacyCheckpointedFlatMap.checkpointedTuple, restoredState); getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); } - - @Override - public void restoreState(Tuple2 state) throws Exception { - restoredState = state; - } } private static class KeyedStateSettingFlatMap extends RichFlatMapFunction, Tuple2> { @@ -578,17 +520,6 @@ public void processElement(StreamRecord> element) throws Exce public void processWatermark(Watermark mark) throws Exception { output.emitWatermark(mark); } - - @Override - public void snapshotState( - FSDataOutputStream out, long checkpointId, long timestamp) throws Exception { - super.snapshotState(out, checkpointId, timestamp); - - DataOutputViewStreamWrapper streamWrapper = new DataOutputViewStreamWrapper(out); - - streamWrapper.writeUTF(CHECKPOINTED_STRING); - streamWrapper.flush(); - } } private static class CheckingRestoringUdfOperator @@ -615,8 +546,6 @@ public void open() throws Exception { @Override public void processElement(StreamRecord> element) throws Exception { userFunction.flatMap(element.getValue(), new TimestampedCollector<>(output)); - - assertEquals(CheckpointedUdfOperator.CHECKPOINTED_STRING, restoredState); getRuntimeContext().getAccumulator(SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR).add(1); } @@ -624,15 +553,6 @@ public void processElement(StreamRecord> element) throws Exce public void processWatermark(Watermark mark) throws Exception { output.emitWatermark(mark); } - - @Override - public void restoreState(FSDataInputStream in) throws Exception { - super.restoreState(in); - - DataInputViewStreamWrapper streamWrapper = new DataInputViewStreamWrapper(in); - - restoredState = streamWrapper.readUTF(); - } } private static class TimelyStatefulOperator diff --git a/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/LegacyCheckpointedStreamingProgram.java b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/LegacyCheckpointedStreamingProgram.java deleted file mode 100644 index 1431d9605b840..0000000000000 --- a/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/LegacyCheckpointedStreamingProgram.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.test.classloading.jar; - -import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.common.restartstrategy.RestartStrategies; -import org.apache.flink.runtime.state.CheckpointListener; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.streaming.api.functions.sink.SinkFunction; -import org.apache.flink.streaming.api.functions.source.SourceFunction; - -/** - * This test is the same as the {@link CheckpointedStreamingProgram} but using the - * old and deprecated {@link Checkpointed} interface. It stays here in order to - * guarantee that although deprecated, the old Checkpointed interface is still supported. - * This is necessary to not break user code. - * */ -public class LegacyCheckpointedStreamingProgram { - - private static final int CHECKPOINT_INTERVALL = 100; - - public static void main(String[] args) throws Exception { - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.getConfig().disableSysoutLogging(); - env.enableCheckpointing(CHECKPOINT_INTERVALL); - env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 10000)); - env.disableOperatorChaining(); - - DataStream text = env.addSource(new SimpleStringGenerator()); - text.map(new StatefulMapper()).addSink(new NoOpSink()); - env.setParallelism(1); - env.execute("Checkpointed Streaming Program"); - } - - // with Checkpointing - private static class SimpleStringGenerator implements SourceFunction, Checkpointed { - - private static final long serialVersionUID = 3700033137820808611L; - - public boolean running = true; - - @Override - public void run(SourceContext ctx) throws Exception { - while (running) { - Thread.sleep(1); - ctx.collect("someString"); - } - } - - @Override - public void cancel() { - running = false; - } - - @Override - public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return null; - } - - @Override - public void restoreState(Integer state) { - - } - } - - private static class StatefulMapper implements MapFunction, Checkpointed, CheckpointListener { - - private static final long serialVersionUID = 2703630582894634440L; - - private String someState; - private boolean atLeastOneSnapshotComplete = false; - private boolean restored = false; - - @Override - public StatefulMapper snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { - return this; - } - - @Override - public void restoreState(StatefulMapper state) { - restored = true; - this.someState = state.someState; - this.atLeastOneSnapshotComplete = state.atLeastOneSnapshotComplete; - } - - @Override - public String map(String value) throws Exception { - if (!atLeastOneSnapshotComplete) { - // throttle consumption by the checkpoint interval 
until we have one snapshot. - Thread.sleep(CHECKPOINT_INTERVALL); - } - if (atLeastOneSnapshotComplete && !restored) { - throw new RuntimeException("Intended failure, to trigger restore"); - } - if (restored) { - throw new SuccessException(); - //throw new RuntimeException("All good"); - } - someState = value; // update our state - return value; - } - - @Override - public void notifyCheckpointComplete(long checkpointId) throws Exception { - atLeastOneSnapshotComplete = true; - } - } - // -------------------------------------------------------------------------------------------- - - /** - * We intentionally use a user specified failure exception. - */ - private static class SuccessException extends Exception { - - private static final long serialVersionUID = 7073311460437532086L; - } - - private static class NoOpSink implements SinkFunction { - private static final long serialVersionUID = 2381410324190818620L; - - @Override - public void invoke(String value) throws Exception { - } - } -} diff --git a/flink-tests/src/test/java/org/apache/flink/test/runtime/entrypoint/StreamingNoop.java b/flink-tests/src/test/java/org/apache/flink/test/runtime/entrypoint/StreamingNoop.java new file mode 100644 index 0000000000000..cd88ae10deff1 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/runtime/entrypoint/StreamingNoop.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.runtime.entrypoint; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.utils.ParameterTool; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.DiscardingSink; +import org.apache.flink.streaming.api.functions.source.FileMonitoringFunction; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.ObjectOutputStream; + +/** + * A program to generate a job graph for entrypoint testing purposes. + * + *
<p>
The dataflow is a simple streaming program that continuously monitors a (non-existent) directory. + * Note that the job graph doesn't depend on any user code; it uses only built-in Flink classes. + * + *
<p>
Program arguments: + * --output [graph file] (default: 'job.graph') + */ +public class StreamingNoop { + public static void main(String[] args) throws Exception { + ParameterTool params = ParameterTool.fromArgs(args); + + // define the dataflow + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(2); + env.setRestartStrategy(RestartStrategies.fixedDelayRestart(10, 1000)); + env.readFileStream("input/", 60000, FileMonitoringFunction.WatchType.ONLY_NEW_FILES) + .addSink(new DiscardingSink()); + + // generate a job graph + final JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + File jobGraphFile = new File(params.get("output", "job.graph")); + try (FileOutputStream output = new FileOutputStream(jobGraphFile); + ObjectOutputStream obOutput = new ObjectOutputStream(output)){ + obOutput.writeObject(jobGraph); + } + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/operator/restore/AbstractOperatorRestoreTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/state/operator/restore/AbstractOperatorRestoreTestBase.java index 3d78242ac8993..00d0b2c130e48 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/state/operator/restore/AbstractOperatorRestoreTestBase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/state/operator/restore/AbstractOperatorRestoreTestBase.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.akka.ListeningBehaviour; +import org.apache.flink.runtime.checkpoint.savepoint.SavepointSerializers; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils; @@ -88,6 +89,11 @@ public abstract class AbstractOperatorRestoreTestBase extends TestLogger { private static final FiniteDuration timeout = new FiniteDuration(30L, TimeUnit.SECONDS); + @BeforeClass + public static void beforeClass() { + SavepointSerializers.setFailWhenLegacyStateDetected(false); + } + @BeforeClass public static void setupCluster() throws Exception { final Configuration configuration = new Configuration(); diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/api/StreamingOperatorsITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/api/StreamingOperatorsITCase.java index 32a04fa38b0bb..8a910d96dd063 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/streaming/api/StreamingOperatorsITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/api/StreamingOperatorsITCase.java @@ -30,12 +30,11 @@ import org.apache.flink.streaming.api.datastream.SplitStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.functions.async.ResultFuture; import org.apache.flink.streaming.api.functions.async.RichAsyncFunction; -import org.apache.flink.streaming.api.functions.async.collector.AsyncCollector; import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase; - import org.apache.flink.util.Collector; import org.apache.flink.util.MathUtils; @@ -243,11 +242,11 @@ public void close() throws Exception { @Override public 
void asyncInvoke(final Tuple2 input, - final AsyncCollector collector) throws Exception { + final ResultFuture resultFuture) throws Exception { executorService.submit(new Runnable() { @Override public void run() { - collector.collect(Collections.singletonList(input.f0 + input.f0)); + resultFuture.complete(Collections.singletonList(input.f0 + input.f0)); } }); } diff --git a/flink-yarn-tests/src/test/scala/org/apache/flink/yarn/TestingYarnJobManager.scala b/flink-yarn-tests/src/test/scala/org/apache/flink/yarn/TestingYarnJobManager.scala index b539961d8b638..bd72d6d5b721e 100644 --- a/flink-yarn-tests/src/test/scala/org/apache/flink/yarn/TestingYarnJobManager.scala +++ b/flink-yarn-tests/src/test/scala/org/apache/flink/yarn/TestingYarnJobManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{Executor, ScheduledExecutorService} import akka.actor.ActorRef import org.apache.flink.configuration.Configuration +import org.apache.flink.runtime.blob.BlobServer import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager import org.apache.flink.runtime.executiongraph.restart.RestartStrategyFactory @@ -58,6 +59,7 @@ class TestingYarnJobManager( ioExecutor: Executor, instanceManager: InstanceManager, scheduler: Scheduler, + blobServer: BlobServer, libraryCacheManager: BlobLibraryCacheManager, archive: ActorRef, restartStrategyFactory: RestartStrategyFactory, @@ -73,6 +75,7 @@ class TestingYarnJobManager( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/AbstractYarnClusterDescriptor.java b/flink-yarn/src/main/java/org/apache/flink/yarn/AbstractYarnClusterDescriptor.java index 55dc47f272dbe..0eb2cc5c4cf23 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/AbstractYarnClusterDescriptor.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/AbstractYarnClusterDescriptor.java @@ -253,6 +253,7 @@ private void isReadyForDeployment(ClusterSpecification clusterSpecification) thr // The number of cores can be configured in the config. // If not configured, it is set to the number of task slots int numYarnVcores = conf.getInt(YarnConfiguration.NM_VCORES, YarnConfiguration.DEFAULT_NM_VCORES); + numYarnVcores = numYarnVcores <= 0 ? 
YarnConfiguration.DEFAULT_NM_VCORES : numYarnVcores; int configuredVcores = flinkConfiguration.getInteger(YarnConfigOptions.VCORES, clusterSpecification.getSlotsPerTaskManager()); // don't configure more than the maximum configured number of vcores if (configuredVcores > numYarnVcores) { @@ -991,7 +992,7 @@ private static List uploadAndRegisterFiles( for (File shipFile : shipFiles) { LocalResource shipResources = Records.newRecord(LocalResource.class); - Path shipLocalPath = new Path("file://" + shipFile.getAbsolutePath()); + Path shipLocalPath = new Path("file:///" + shipFile.getAbsolutePath()); Path remotePath = Utils.setupLocalResource(fs, appId, shipLocalPath, shipResources, fs.getHomeDirectory()); @@ -1351,7 +1352,13 @@ protected ContainerLaunchContext setupApplicationMasterContainer( ContainerLaunchContext amContainer = Records.newRecord(ContainerLaunchContext.class); final Map startCommandValues = new HashMap<>(); - startCommandValues.put("java", "$JAVA_HOME/bin/java"); + if (System.getProperty("os.name").toLowerCase().startsWith("windows")){ + startCommandValues.put("java", "%JAVA_HOME%/bin/java"); + } + else { + startCommandValues.put("java", "$JAVA_HOME/bin/java"); + } + startCommandValues.put("jvmmem", "-Xmx" + Utils.calculateHeapSize(jobManagerMemoryMb, flinkConfiguration) + "m"); diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/YarnApplicationMasterRunner.java b/flink-yarn/src/main/java/org/apache/flink/yarn/YarnApplicationMasterRunner.java index 913090136925d..e71644e953674 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/YarnApplicationMasterRunner.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/YarnApplicationMasterRunner.java @@ -62,6 +62,8 @@ import org.slf4j.LoggerFactory; import java.io.File; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; @@ -92,7 +94,7 @@ public class YarnApplicationMasterRunner { private static final FiniteDuration TASKMANAGER_REGISTRATION_TIMEOUT = new FiniteDuration(5, TimeUnit.MINUTES); /** The process environment variables. */ - private static final Map ENV = System.getenv(); + private static final Map ENV = getSystemEnv(); /** The exit code returned if the initialization of the application master failed. */ private static final int INIT_ERROR_EXIT_CODE = 31; @@ -100,6 +102,25 @@ public class YarnApplicationMasterRunner { /** The exit code returned if the process exits because a critical actor died. */ private static final int ACTOR_DIED_EXIT_CODE = 32; + /** + * Add this private static method to convert the hostname to lowercase. 
+ */ + private static Map getSystemEnv(){ + final Map origSysEnv = System.getenv(); + final Map modifiedEnv = new HashMap<>(); + + for (Map.Entry entry : origSysEnv.entrySet()){ + modifiedEnv.put(entry.getKey(), entry.getValue()); + } + + String hostName = modifiedEnv.get(Environment.NM_HOST.key()); + if (hostName != null){ + modifiedEnv.put(Environment.NM_HOST.key(), hostName.toLowerCase()); + } + + return Collections.unmodifiableMap(modifiedEnv); + } + // ------------------------------------------------------------------------ // Program entry point // ------------------------------------------------------------------------ diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/YarnClusterDescriptorV2.java b/flink-yarn/src/main/java/org/apache/flink/yarn/YarnClusterDescriptorV2.java index 00b73a863ce1f..3e58da5ae5890 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/YarnClusterDescriptorV2.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/YarnClusterDescriptorV2.java @@ -18,9 +18,7 @@ package org.apache.flink.yarn; -import org.apache.flink.client.deployment.ClusterSpecification; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.yarn.entrypoint.YarnJobClusterEntrypoint; import org.apache.flink.yarn.entrypoint.YarnSessionClusterEntrypoint; @@ -45,9 +43,4 @@ protected String getYarnSessionClusterEntrypoint() { protected String getYarnJobClusterEntrypoint() { return YarnJobClusterEntrypoint.class.getName(); } - - @Override - public YarnClusterClient deployJobCluster(ClusterSpecification clusterSpecification, JobGraph jobGraph) { - throw new UnsupportedOperationException("Cannot yet deploy a per-job yarn cluster."); - } } diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/YarnResourceManager.java b/flink-yarn/src/main/java/org/apache/flink/yarn/YarnResourceManager.java index fb1a1c3d1d1e1..dd12fefaedc00 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/YarnResourceManager.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/YarnResourceManager.java @@ -28,7 +28,6 @@ import org.apache.flink.runtime.clusterframework.types.ResourceProfile; import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; -import org.apache.flink.runtime.instance.InstanceID; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.resourcemanager.JobLeaderIdService; import org.apache.flink.runtime.resourcemanager.ResourceManager; @@ -228,7 +227,7 @@ public void startNewWorker(ResourceProfile resourceProfile) { } @Override - public void stopWorker(InstanceID instanceId) { + public void stopWorker(ResourceID resourceID) { // TODO: Implement to stop the worker } @@ -294,7 +293,7 @@ public void onNodesUpdated(List list) { @Override public void onError(Throwable error) { - onFatalErrorAsync(error); + onFatalError(error); } //Utility methods diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/highavailability/YarnIntraNonHaMasterServices.java b/flink-yarn/src/main/java/org/apache/flink/yarn/highavailability/YarnIntraNonHaMasterServices.java index 75f8c0a240384..86db1c42c4855 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/highavailability/YarnIntraNonHaMasterServices.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/highavailability/YarnIntraNonHaMasterServices.java @@ -71,6 +71,9 @@ public class YarnIntraNonHaMasterServices extends AbstractYarnNonHaServices { /** The 
embedded leader election service used by JobManagers to find the resource manager. */ private final SingleLeaderElectionService resourceManagerLeaderElectionService; + /** The embedded leader election service for the dispatcher. */ + private final SingleLeaderElectionService dispatcherLeaderElectionService; + // ------------------------------------------------------------------------ /** @@ -100,6 +103,7 @@ public YarnIntraNonHaMasterServices( try { this.dispatcher = Executors.newSingleThreadExecutor(new ServicesThreadFactory()); this.resourceManagerLeaderElectionService = new SingleLeaderElectionService(dispatcher, DEFAULT_LEADER_ID); + this.dispatcherLeaderElectionService = new SingleLeaderElectionService(dispatcher, DEFAULT_LEADER_ID); // all good! successful = true; @@ -129,6 +133,17 @@ public LeaderRetrievalService getResourceManagerLeaderRetriever() { } } + @Override + public LeaderRetrievalService getDispatcherLeaderRetriever() { + enter(); + + try { + return dispatcherLeaderElectionService.createLeaderRetrievalService(); + } finally { + exit(); + } + } + @Override public LeaderElectionService getResourceManagerLeaderElectionService() { enter(); @@ -140,6 +155,16 @@ public LeaderElectionService getResourceManagerLeaderElectionService() { } } + @Override + public LeaderElectionService getDispatcherLeaderElectionService() { + enter(); + try { + return dispatcherLeaderElectionService; + } finally { + exit(); + } + } + @Override public LeaderElectionService getJobManagerLeaderElectionService(JobID jobID) { enter(); diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/highavailability/YarnPreConfiguredMasterNonHaServices.java b/flink-yarn/src/main/java/org/apache/flink/yarn/highavailability/YarnPreConfiguredMasterNonHaServices.java index 6686a52fa0603..c1466d21da692 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/highavailability/YarnPreConfiguredMasterNonHaServices.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/highavailability/YarnPreConfiguredMasterNonHaServices.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; +import org.apache.flink.runtime.dispatcher.Dispatcher; import org.apache.flink.runtime.highavailability.HighAvailabilityServices; import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils; import org.apache.flink.runtime.leaderelection.LeaderElectionService; @@ -66,6 +67,9 @@ public class YarnPreConfiguredMasterNonHaServices extends AbstractYarnNonHaServi /** The RPC URL under which the single ResourceManager can be reached while available. */ private final String resourceManagerRpcUrl; + /** The RPC URL under which the single Dispatcher can be reached while available. */ + private final String dispatcherRpcUrl; + // ------------------------------------------------------------------------ /** @@ -116,6 +120,13 @@ public YarnPreConfiguredMasterNonHaServices( addressResolution, config); + this.dispatcherRpcUrl = AkkaRpcServiceUtils.getRpcUrl( + rmHost, + rmPort, + Dispatcher.DISPATCHER_NAME, + addressResolution, + config); + // all well! 
successful = true; } @@ -144,6 +155,17 @@ public LeaderRetrievalService getResourceManagerLeaderRetriever() { } } + @Override + public LeaderRetrievalService getDispatcherLeaderRetriever() { + enter(); + + try { + return new StandaloneLeaderRetrievalService(dispatcherRpcUrl, DEFAULT_LEADER_ID); + } finally { + exit(); + } + } + @Override public LeaderElectionService getResourceManagerLeaderElectionService() { enter(); @@ -155,6 +177,16 @@ public LeaderElectionService getResourceManagerLeaderElectionService() { } } + @Override + public LeaderElectionService getDispatcherLeaderElectionService() { + enter(); + try { + throw new UnsupportedOperationException("Not supported on the TaskManager side"); + } finally { + exit(); + } + } + @Override public LeaderElectionService getJobManagerLeaderElectionService(JobID jobID) { enter(); diff --git a/flink-yarn/src/main/scala/org/apache/flink/yarn/YarnJobManager.scala b/flink-yarn/src/main/scala/org/apache/flink/yarn/YarnJobManager.scala index a2d166854fb1c..b8dacee3b90b2 100644 --- a/flink-yarn/src/main/scala/org/apache/flink/yarn/YarnJobManager.scala +++ b/flink-yarn/src/main/scala/org/apache/flink/yarn/YarnJobManager.scala @@ -24,6 +24,7 @@ import java.util.concurrent.{Executor, ScheduledExecutorService, TimeUnit} import akka.actor.ActorRef import org.apache.flink.configuration.{Configuration => FlinkConfiguration} import org.apache.flink.core.fs.Path +import org.apache.flink.runtime.blob.BlobServer import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory import org.apache.flink.runtime.clusterframework.ContaineredJobManager import org.apache.flink.runtime.clusterframework.messages.StopCluster @@ -49,7 +50,8 @@ import scala.language.postfixOps * @param instanceManager Instance manager to manage the registered * [[org.apache.flink.runtime.taskmanager.TaskManager]] * @param scheduler Scheduler to schedule Flink jobs - * @param libraryCacheManager Manager to manage uploaded jar files + * @param blobServer BLOB store for file uploads + * @param libraryCacheManager manages uploaded jar files and class paths * @param archive Archive for finished Flink jobs * @param restartStrategyFactory Restart strategy to be used in case of a job recovery * @param timeout Timeout for futures @@ -61,6 +63,7 @@ class YarnJobManager( ioExecutor: Executor, instanceManager: InstanceManager, scheduler: FlinkScheduler, + blobServer: BlobServer, libraryCacheManager: BlobLibraryCacheManager, archive: ActorRef, restartStrategyFactory: RestartStrategyFactory, @@ -76,6 +79,7 @@ class YarnJobManager( ioExecutor, instanceManager, scheduler, + blobServer, libraryCacheManager, archive, restartStrategyFactory, diff --git a/pom.xml b/pom.xml index 6ed08fdac8374..948fb78ec1aa0 100644 --- a/pom.xml +++ b/pom.xml @@ -84,8 +84,9 @@ under the License. UTF-8 UTF-8 - 2.4.1 - + never-match-me + 2.8.0 1C @@ -220,10 +221,15 @@ under the License. netty here. It will overwrite Hadoop's guava dependency (even though we handle it separatly in the flink-shaded-hadoop module). - We can use all guava versions everywhere by adding it directly as a dependency to each project. --> + + org.apache.flink + flink-shaded-guava + 18.0-1.0 + + com.google.code.findbugs @@ -1259,19 +1265,8 @@ under the License. shading, the root pom would have to be Scala suffixed and thereby all other modules. 
--> org.apache.flink:force-shading - com.google.guava:* - - - com.google - org.apache.flink.shaded.com.google - - com.google.protobuf.** - com.google.inject.** - - - diff --git a/test-infra/end-to-end-test/common.sh b/test-infra/end-to-end-test/common.sh index dea80fafd482a..26e1522b162b8 100644 --- a/test-infra/end-to-end-test/common.sh +++ b/test-infra/end-to-end-test/common.sh @@ -70,6 +70,7 @@ function stop_cluster { | grep -v "NoAvailableBrokersException" \ | grep -v "Async Kafka commit failed" \ | grep -v "DisconnectException" \ + | grep -v "AskTimeoutException" \ | grep -iq "error"; then echo "Found error in log files:" cat $FLINK_DIR/log/* @@ -80,6 +81,7 @@ function stop_cluster { | grep -v "NoAvailableBrokersException" \ | grep -v "Async Kafka commit failed" \ | grep -v "DisconnectException" \ + | grep -v "AskTimeoutException" \ | grep -iq "exception"; then echo "Found exception in log files:" cat $FLINK_DIR/log/* diff --git a/tools/maven/checkstyle.xml b/tools/maven/checkstyle.xml index 3f78054d4471a..514453ee5f403 100644 --- a/tools/maven/checkstyle.xml +++ b/tools/maven/checkstyle.xml @@ -211,7 +211,7 @@ This file is based on the checkstyle file of Apache Beam. - + diff --git a/tools/maven/scalastyle-config.xml b/tools/maven/scalastyle-config.xml index 0f7f6bbcb0484..848b2afc1ba0f 100644 --- a/tools/maven/scalastyle-config.xml +++ b/tools/maven/scalastyle-config.xml @@ -86,7 +86,7 @@ - + diff --git a/tools/maven/suppressions.xml b/tools/maven/suppressions.xml index 8a80341289dca..b19435eff4cc4 100644 --- a/tools/maven/suppressions.xml +++ b/tools/maven/suppressions.xml @@ -27,4 +27,13 @@ under the License. + + + + +
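For reference, a minimal sketch (not part of the patch) of how module code consumes Guava after the dependency change above: instead of importing com.google.common classes directly and relying on the guava relocation that this patch removes from the root pom's shade-plugin configuration, a module declares the new flink-shaded-guava 18.0-1.0 dependency and imports the relocated org.apache.flink.shaded.guava18 package, mirroring the import changes in SavepointITCase and UdfStreamOperatorCheckpointingITCase. The class and package names below are hypothetical and for illustration only.

package org.apache.flink.test.example;

import org.apache.flink.shaded.guava18.com.google.common.collect.EvictingQueue;

import java.util.Queue;

/**
 * Hypothetical example (illustration only): buffers the most recent elements
 * using the relocated Guava EvictingQueue provided by flink-shaded-guava.
 */
public class RecentElementsBuffer {

	/** Bounded queue that evicts the oldest element once the capacity of 16 is reached. */
	private final Queue<Long> recentElements = EvictingQueue.create(16);

	public void add(long value) {
		recentElements.add(value);
	}

	public int size() {
		return recentElements.size();
	}
}

Keeping Guava behind the flink-shaded-guava relocation means Flink's Guava version no longer leaks onto user or Hadoop classpaths, which is why the per-project guava relocation block could be dropped from the root pom.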