From 09f7af98f48aeb91984eac0173a8d1dedbd37ac5 Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 14 Nov 2016 15:18:11 +0800 Subject: [PATCH 01/11] add ObjectChecker --- .../flink/api/scala/ObjectChecker.scala | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala new file mode 100644 index 0000000000000..109aa4677c358 --- /dev/null +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala + +import org.apache.flink.annotation.Internal +import org.apache.flink.api.common.InvalidProgramException + +/** + * Scala Object checker tries to verify if a class is implemented by + * Scala Object + */ +@Internal +object ObjectChecker { + def isSingleton[A](a: A)(implicit ev: A <:< Singleton = null) = + Option(ev).isDefined + + def assertScalaSingleton[A](a: A) = { + if (isSingleton(a)) { + val msg = "User defined function implemented by class " + a.getClass.getName + + " might be implemented by a Scala Object,it is forbidden by Flink since concurrent modification risks." + throw new InvalidProgramException(msg) + } + } +} From 34b4fa5216b3a4939ce3016182f7591fad0f50da Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 14 Nov 2016 15:21:30 +0800 Subject: [PATCH 02/11] add if check scala singleton in Execution config --- .../flink/api/common/ExecutionConfig.java | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java b/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java index 3cde5e76dcce7..b2168d41903ee 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java @@ -91,6 +91,8 @@ public class ExecutionConfig implements Serializable, Archiveable Date: Mon, 14 Nov 2016 15:53:24 +0800 Subject: [PATCH 03/11] add check for stream related class --- .../api/scala/AllWindowedStream.scala | 34 ++++++++----- .../api/scala/CoGroupedStreams.scala | 20 +++++--- .../api/scala/ConnectedStreams.scala | 24 ++++++--- .../streaming/api/scala/DataStream.scala | 27 ++++++---- .../streaming/api/scala/JoinedStreams.scala | 22 +++++--- .../streaming/api/scala/KeyedStream.scala | 10 ++-- .../scala/StreamExecutionEnvironment.scala | 51 +++++++++++-------- .../streaming/api/scala/WindowedStream.scala | 33 +++++++----- 8 files changed, 138 insertions(+), 83 deletions(-) diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala index 83104e8318f8f..8b381c9fdda5c 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala @@ -111,7 +111,7 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { * @return The data stream that is the result of applying the reduce function to the window. */ def reduce(function: ReduceFunction[T]): DataStream[T] = { - asScalaStream(javaStream.reduce(clean(function))) + asScalaStream(javaStream.reduce(clean(check(function)))) } /** @@ -133,7 +133,7 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { if (function == null) { throw new NullPointerException("Reduce function must not be null.") } - val cleanFun = clean(function) + val cleanFun = clean(check(function)) val reducer = new ScalaReduceFunction[T](cleanFun) reduce(reducer) @@ -227,7 +227,7 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { if (function == null) { throw new NullPointerException("Fold function must not be null.") } - val cleanFun = clean(function) + val cleanFun = clean(check(function)) val folder = new ScalaFoldFunction[T,R](cleanFun) fold(initialValue, folder) @@ -312,7 +312,7 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { def apply[R: TypeInformation]( function: AllWindowFunction[T, R, W]): DataStream[R] = { - val cleanedFunction = clean(function) + val cleanedFunction = clean(check(function)) val javaFunction = new ScalaAllWindowFunctionWrapper[T, R, W](cleanedFunction) asScalaStream(javaStream.apply(javaFunction, implicitly[TypeInformation[R]])) @@ -332,7 +332,7 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { def apply[R: TypeInformation]( function: (W, Iterable[T], Collector[R]) => Unit): DataStream[R] = { - val cleanedFunction = clean(function) + val cleanedFunction = clean(check(function)) val applyFunction = new ScalaAllWindowFunction[T, R, W](cleanedFunction) asScalaStream(javaStream.apply(applyFunction, implicitly[TypeInformation[R]])) @@ -355,8 +355,8 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { preAggregator: ReduceFunction[T], windowFunction: AllWindowFunction[T, R, W]): DataStream[R] = { - val cleanedReducer = clean(preAggregator) - val cleanedWindowFunction = clean(windowFunction) + val cleanedReducer = clean(check(preAggregator)) + val cleanedWindowFunction = clean(check(windowFunction)) val applyFunction = new ScalaAllWindowFunctionWrapper[T, R, W](cleanedWindowFunction) @@ -388,8 +388,8 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { throw new NullPointerException("WindowApply function must not be null.") } - val cleanReducer = clean(preAggregator) - val cleanWindowFunction = clean(windowFunction) + val cleanReducer = clean(check(preAggregator)) + val cleanWindowFunction = clean(check(windowFunction)) val reducer = new ScalaReduceFunction[T](cleanReducer) val applyFunction = new ScalaAllWindowFunction[T, R, W](cleanWindowFunction) @@ -417,8 +417,8 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { preAggregator: FoldFunction[T, R], windowFunction: AllWindowFunction[R, R, W]): DataStream[R] = { - val cleanFolder = clean(preAggregator) - val cleanWindowFunction = clean(windowFunction) + val cleanFolder = clean(check(preAggregator)) + val cleanWindowFunction = clean(check(windowFunction)) val applyFunction = new ScalaAllWindowFunctionWrapper[R, R, W](cleanWindowFunction) @@ -455,9 +455,8 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { throw new NullPointerException("WindowApply function must not be null.") } - val cleanFolder = clean(preAggregator) - val cleanWindowFunction = clean(windowFunction) - + val cleanFolder = clean(check(preAggregator)) + val cleanWindowFunction = clean(check(windowFunction)) val folder = new ScalaFoldFunction[T, R](cleanFolder) val applyFunction = new ScalaAllWindowFunction[R, R, W](cleanWindowFunction) @@ -568,6 +567,13 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaClean(f) } + /** + * Check if the user defined function is a legal function. + */ + private[flink] def check[F <: AnyRef](f: F): F = { + new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaCheck(f) + } + /** * Gets the output type. */ diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/CoGroupedStreams.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/CoGroupedStreams.scala index 52c53d57a51af..f2f7085bc2671 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/CoGroupedStreams.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/CoGroupedStreams.scala @@ -62,7 +62,7 @@ class CoGroupedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { * Specifies a [[KeySelector]] for elements from the first input. */ def where[KEY: TypeInformation](keySelector: T1 => KEY): Where[KEY] = { - val cleanFun = clean(keySelector) + val cleanFun = clean(check(keySelector)) val keyType = implicitly[TypeInformation[KEY]] val javaSelector = new KeySelector[T1, KEY] with ResultTypeQueryable[KEY] { def getKey(in: T1) = cleanFun(in) @@ -85,7 +85,7 @@ class CoGroupedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { * Specifies a [[KeySelector]] for elements from the second input. */ def equalTo(keySelector: T2 => KEY): EqualTo = { - val cleanFun = clean(keySelector) + val cleanFun = clean(check(keySelector)) val localKeyType = keyType val javaSelector = new KeySelector[T2, KEY] with ResultTypeQueryable[KEY] { def getKey(in: T2) = cleanFun(in) @@ -112,7 +112,7 @@ class CoGroupedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { throw new UnsupportedOperationException( "You first need to specify KeySelectors for both inputs using where() and equalTo().") } - new WithWindow[W](clean(assigner), null, null) + new WithWindow[W](clean(check(assigner)), null, null) } /** @@ -159,7 +159,7 @@ class CoGroupedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { require(fun != null, "CoGroup function must not be null.") val coGrouper = new CoGroupFunction[T1, T2, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def coGroup( left: java.lang.Iterable[T1], right: java.lang.Iterable[T2], out: Collector[O]) = { @@ -178,7 +178,7 @@ class CoGroupedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { require(fun != null, "CoGroup function must not be null.") val coGrouper = new CoGroupFunction[T1, T2, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def coGroup( left: java.lang.Iterable[T1], right: java.lang.Iterable[T2], out: Collector[O]) = { @@ -202,7 +202,7 @@ class CoGroupedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { .window(windowAssigner) .trigger(trigger) .evictor(evictor) - .apply(clean(function), implicitly[TypeInformation[T]])) + .apply(clean(check(function)), implicitly[TypeInformation[T]])) } } @@ -216,4 +216,12 @@ class CoGroupedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { private[flink] def clean[F <: AnyRef](f: F): F = { new StreamExecutionEnvironment(input1.javaStream.getExecutionEnvironment).scalaClean(f) } + + /** + * Check if the user defined function is a legal function. + */ + private[flink] def check[F <: AnyRef](f: F): F = { + new StreamExecutionEnvironment(input1.javaStream.getExecutionEnvironment).scalaCheck(f) + } + } diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala index a7325a4dc60ae..d4bfc4a86eb9e 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala @@ -65,8 +65,8 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { if (fun1 == null || fun2 == null) { throw new NullPointerException("Map function must not be null.") } - val cleanFun1 = clean(fun1) - val cleanFun2 = clean(fun2) + val cleanFun1 = clean(check(fun1)) + val cleanFun2 = clean(check(fun2)) val comapper = new CoMapFunction[IN1, IN2, R] { def map1(in1: IN1): R = cleanFun1(in1) def map2(in2: IN2): R = cleanFun2(in2) @@ -176,8 +176,8 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { if (fun1 == null || fun2 == null) { throw new NullPointerException("FlatMap functions must not be null.") } - val cleanFun1 = clean(fun1) - val cleanFun2 = clean(fun2) + val cleanFun1 = clean(check(fun1)) + val cleanFun2 = clean(check(fun2)) val flatMapper = new CoFlatMapFunction[IN1, IN2, R] { def flatMap1(value: IN1, out: Collector[R]): Unit = cleanFun1(value, out) def flatMap2(value: IN2, out: Collector[R]): Unit = cleanFun2(value, out) @@ -203,8 +203,8 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { if (fun1 == null || fun2 == null) { throw new NullPointerException("FlatMap functions must not be null.") } - val cleanFun1 = clean(fun1) - val cleanFun2 = clean(fun2) + val cleanFun1 = clean(check(fun1)) + val cleanFun2 = clean(check(fun2)) val flatMapper = new CoFlatMapFunction[IN1, IN2, R] { def flatMap1(value: IN1, out: Collector[R]) = { cleanFun1(value) foreach out.collect } @@ -285,8 +285,8 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { val keyType1 = implicitly[TypeInformation[K1]] val keyType2 = implicitly[TypeInformation[K2]] - val cleanFun1 = clean(fun1) - val cleanFun2 = clean(fun2) + val cleanFun1 = clean(check(fun1)) + val cleanFun2 = clean(check(fun2)) val keyExtractor1 = new KeySelectorWithType[IN1, K1](cleanFun1, keyType1) val keyExtractor2 = new KeySelectorWithType[IN2, K2](cleanFun2, keyType2) @@ -302,6 +302,14 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaClean(f) } + /** + * Check if the user defined function is a legal function. + */ + private[flink] def check[F <: AnyRef](f: F): F = { + new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaCheck(f) + } + + @PublicEvolving def transform[R: TypeInformation]( functionName: String, diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala index dbc91bd90ba0d..a49d88c70bf40 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala @@ -313,7 +313,7 @@ class DataStream[T](stream: JavaStream[T]) { */ def keyBy[K: TypeInformation](fun: T => K): KeyedStream[T, K] = { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val keyType: TypeInformation[K] = implicitly[TypeInformation[K]] val keyExtractor = new KeySelector[T, K] with ResultTypeQueryable[K] { @@ -356,7 +356,7 @@ class DataStream[T](stream: JavaStream[T]) { : DataStream[T] = { val keyType = implicitly[TypeInformation[K]] - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val keyExtractor = new KeySelector[T, K] with ResultTypeQueryable[K] { def getKey(in: T) = cleanFun(in) @@ -492,7 +492,7 @@ class DataStream[T](stream: JavaStream[T]) { if (fun == null) { throw new NullPointerException("Map function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val mapper = new MapFunction[T, R] { def map(in: T): R = cleanFun(in) } @@ -533,7 +533,7 @@ class DataStream[T](stream: JavaStream[T]) { if (fun == null) { throw new NullPointerException("FlatMap function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val flatMapper = new FlatMapFunction[T, R] { def flatMap(in: T, out: Collector[R]) { cleanFun(in, out) } } @@ -548,7 +548,7 @@ class DataStream[T](stream: JavaStream[T]) { if (fun == null) { throw new NullPointerException("FlatMap function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val flatMapper = new FlatMapFunction[T, R] { def flatMap(in: T, out: Collector[R]) { cleanFun(in) foreach out.collect } } @@ -572,7 +572,7 @@ class DataStream[T](stream: JavaStream[T]) { if (fun == null) { throw new NullPointerException("Filter function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val filterFun = new FilterFunction[T] { def filter(in: T) = cleanFun(in) } @@ -677,7 +677,7 @@ class DataStream[T](stream: JavaStream[T]) { */ @deprecated def assignTimestamps(extractor: TimestampExtractor[T]): DataStream[T] = { - asScalaStream(stream.assignTimestamps(clean(extractor))) + asScalaStream(stream.assignTimestamps(clean(check(extractor)))) } /** @@ -757,7 +757,7 @@ class DataStream[T](stream: JavaStream[T]) { */ @PublicEvolving def assignAscendingTimestamps(extractor: T => Long): DataStream[T] = { - val cleanExtractor = clean(extractor) + val cleanExtractor = clean(check(extractor)) val extractorFunction = new AscendingTimestampExtractor[T] { def extractAscendingTimestamp(element: T): Long = { cleanExtractor(element) @@ -782,7 +782,7 @@ class DataStream[T](stream: JavaStream[T]) { if (fun == null) { throw new NullPointerException("OutputSelector must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val selector = new OutputSelector[T] { def select(in: T): java.lang.Iterable[String] = { cleanFun(in).toIterable.asJava @@ -955,7 +955,7 @@ class DataStream[T](stream: JavaStream[T]) { if (fun == null) { throw new NullPointerException("Sink function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val sinkFunction = new SinkFunction[T] { def invoke(in: T) = cleanFun(in) } @@ -970,6 +970,13 @@ class DataStream[T](stream: JavaStream[T]) { new StreamExecutionEnvironment(stream.getExecutionEnvironment).scalaClean(f) } + /** + * Check if the user defined function is a legal function. + */ + private[flink] def check[F <: AnyRef](f: F): F = { + new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaCheck(f) + } + /** * Transforms the [[DataStream]] by using a custom [[OneInputStreamOperator]]. * diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/JoinedStreams.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/JoinedStreams.scala index 93b5cc885b4ee..6ccb02c746bfa 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/JoinedStreams.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/JoinedStreams.scala @@ -60,7 +60,7 @@ class JoinedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { * Specifies a [[KeySelector]] for elements from the first input. */ def where[KEY: TypeInformation](keySelector: T1 => KEY): Where[KEY] = { - val cleanFun = clean(keySelector) + val cleanFun = clean(check(keySelector)) val keyType = implicitly[TypeInformation[KEY]] val javaSelector = new KeySelector[T1, KEY] with ResultTypeQueryable[KEY] { def getKey(in: T1) = cleanFun(in) @@ -83,7 +83,7 @@ class JoinedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { * Specifies a [[KeySelector]] for elements from the second input. */ def equalTo(keySelector: T2 => KEY): EqualTo = { - val cleanFun = clean(keySelector) + val cleanFun = clean(check(keySelector)) val localKeyType = keyType val javaSelector = new KeySelector[T2, KEY] with ResultTypeQueryable[KEY] { def getKey(in: T2) = cleanFun(in) @@ -110,7 +110,7 @@ class JoinedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { "You first need to specify KeySelectors for both inputs using where() and equalTo().") } - new WithWindow[W](clean(assigner), null, null) + new WithWindow[W](clean(check(assigner)), null, null) } /** @@ -153,7 +153,7 @@ class JoinedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { require(fun != null, "Join function must not be null.") val joiner = new FlatJoinFunction[T1, T2, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def join(left: T1, right: T2, out: Collector[O]) = { out.collect(cleanFun(left, right)) } @@ -169,7 +169,7 @@ class JoinedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { require(fun != null, "Join function must not be null.") val joiner = new FlatJoinFunction[T1, T2, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def join(left: T1, right: T2, out: Collector[O]) = { cleanFun(left, right, out) } @@ -191,7 +191,7 @@ class JoinedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { .window(windowAssigner) .trigger(trigger) .evictor(evictor) - .apply(clean(function), implicitly[TypeInformation[T]])) + .apply(clean(check(function)), implicitly[TypeInformation[T]])) } /** @@ -208,7 +208,7 @@ class JoinedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { .window(windowAssigner) .trigger(trigger) .evictor(evictor) - .apply(clean(function), implicitly[TypeInformation[T]])) + .apply(clean(check(function)), implicitly[TypeInformation[T]])) } } } @@ -221,4 +221,12 @@ class JoinedStreams[T1, T2](input1: DataStream[T1], input2: DataStream[T2]) { private[flink] def clean[F <: AnyRef](f: F): F = { new StreamExecutionEnvironment(input1.javaStream.getExecutionEnvironment).scalaClean(f) } + + /** + * Check if the user defined function is a legal function. + */ + private[flink] def check[F <: AnyRef](f: F): F = { + new StreamExecutionEnvironment(input1.javaStream.getExecutionEnvironment).scalaCheck(f) + } + } diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala index f2999b394fca4..26afd7e41faff 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala @@ -176,7 +176,7 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] if (fun == null) { throw new NullPointerException("Reduce function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val reducer = new ReduceFunction[T] { def reduce(v1: T, v2: T) : T = { cleanFun(v1, v2) } } @@ -209,7 +209,7 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] if (fun == null) { throw new NullPointerException("Fold function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val folder = new FoldFunction[T,R] { def fold(acc: R, v: T) = { cleanFun(acc, v) @@ -390,7 +390,7 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] throw new NullPointerException("Filter function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val stateTypeInfo: TypeInformation[S] = implicitly[TypeInformation[S]] val serializer: TypeSerializer[S] = stateTypeInfo.createSerializer(getExecutionConfig) @@ -419,7 +419,7 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] throw new NullPointerException("Map function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val stateTypeInfo: TypeInformation[S] = implicitly[TypeInformation[S]] val serializer: TypeSerializer[S] = stateTypeInfo.createSerializer(getExecutionConfig) @@ -448,7 +448,7 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] throw new NullPointerException("Flatmap function must not be null.") } - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) val stateTypeInfo: TypeInformation[S] = implicitly[TypeInformation[S]] val serializer: TypeSerializer[S] = stateTypeInfo.createSerializer(getExecutionConfig) diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala index 432e8ac6cd330..3ab00bd2dda3c 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala @@ -24,7 +24,7 @@ import org.apache.flink.api.common.io.{FileInputFormat, FilePathFilter, InputFor import org.apache.flink.api.common.restartstrategy.RestartStrategies.RestartStrategyConfiguration import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer -import org.apache.flink.api.scala.ClosureCleaner +import org.apache.flink.api.scala.{ClosureCleaner, ObjectChecker} import org.apache.flink.runtime.state.AbstractStateBackend import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaEnv} import org.apache.flink.streaming.api.functions.source._ @@ -119,13 +119,13 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { // ------------------------------------------------------------------------ // Checkpointing Settings // ------------------------------------------------------------------------ - + /** * Gets the checkpoint config, which defines values like checkpoint interval, delay between * checkpoints, etc. */ def getCheckpointConfig = javaEnv.getCheckpointConfig() - + /** * Enables checkpointing for the streaming job. The distributed state of the streaming * dataflow will be periodically snapshotted. In case of a failure, the streaming @@ -165,12 +165,12 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { * * NOTE: Checkpointing iterative streaming dataflows in not properly supported at * the moment. For that reason, iterative jobs will not be started if used - * with enabled checkpointing. To override this mechanism, use the + * with enabled checkpointing. To override this mechanism, use the * [[enableCheckpointing(long, CheckpointingMode, boolean)]] method. * - * @param interval + * @param interval * Time interval between state checkpoints in milliseconds. - * @param mode + * @param mode * The checkpointing mode, selecting between "exactly once" and "at least once" guarantees. */ def enableCheckpointing(interval : Long, @@ -190,10 +190,10 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { * * NOTE: Checkpointing iterative streaming dataflows in not properly supported at * the moment. For that reason, iterative jobs will not be started if used - * with enabled checkpointing. To override this mechanism, use the + * with enabled checkpointing. To override this mechanism, use the * [[enableCheckpointing(long, CheckpointingMode, boolean)]] method. * - * @param interval + * @param interval * Time interval between state checkpoints in milliseconds. */ def enableCheckpointing(interval : Long) : StreamExecutionEnvironment = { @@ -214,7 +214,7 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { javaEnv.enableCheckpointing() this } - + def getCheckpointingMode = javaEnv.getCheckpointingMode() /** @@ -222,11 +222,11 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { * It defines in what form the key/value state, accessible from operations on * [[KeyedStream]] is maintained (heap, managed memory, externally), and where state * snapshots/checkpoints are stored, both for the key/value state, and for checkpointed - * functions (implementing the interface + * functions (implementing the interface * [[org.apache.flink.streaming.api.checkpoint.Checkpointed]]. * *

The [[org.apache.flink.runtime.state.memory.MemoryStateBackend]] for example - * maintains the state in heap memory, as objects. It is lightweight without extra + * maintains the state in heap memory, as objects. It is lightweight without extra * dependencies, but can checkpoint only small states (some counters). * *

In contrast, the [[org.apache.flink.runtime.state.filesystem.FsStateBackend]] @@ -582,17 +582,17 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { /** * Create a DataStream using a user defined source function for arbitrary - * source functionality. By default sources have a parallelism of 1. - * To enable parallel execution, the user defined source should implement - * ParallelSourceFunction or extend RichParallelSourceFunction. - * In these cases the resulting source will have the parallelism of the environment. + * source functionality. By default sources have a parallelism of 1. + * To enable parallel execution, the user defined source should implement + * ParallelSourceFunction or extend RichParallelSourceFunction. + * In these cases the resulting source will have the parallelism of the environment. * To change this afterwards call DataStreamSource.setParallelism(int) * */ def addSource[T: TypeInformation](function: SourceFunction[T]): DataStream[T] = { require(function != null, "Function must not be null.") - - val cleanFun = scalaClean(function) + + val cleanFun = scalaClean(scalaCheck(function)) val typeInfo = implicitly[TypeInformation[T]] asScalaStream(javaEnv.addSource(cleanFun).returns(typeInfo)) } @@ -604,7 +604,7 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { def addSource[T: TypeInformation](function: SourceContext[T] => Unit): DataStream[T] = { require(function != null, "Function must not be null.") val sourceFunction = new SourceFunction[T] { - val cleanFun = scalaClean(function) + val cleanFun = scalaClean(scalaCheck(function)) override def run(ctx: SourceContext[T]) { cleanFun(ctx) } @@ -617,7 +617,7 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { * Triggers the program execution. The environment will execute all parts of * the program that have resulted in a "sink" operation. Sink operations are * for example printing results or forwarding them to a message queue. - * + * * The program execution will be logged and displayed with a generated * default name. */ @@ -627,7 +627,7 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { * Triggers the program execution. The environment will execute all parts of * the program that have resulted in a "sink" operation. Sink operations are * for example printing results or forwarding them to a message queue. - * + * * The program execution will be logged and displayed with the provided name. */ def execute(jobName: String) = javaEnv.execute(jobName) @@ -668,6 +668,17 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) { } f } + + /** + * Check if the user defined function is a legal function. + * Currently, check if the function was implemented by a Scala Singleton Object. + */ + private[flink] def scalaCheck[F <: AnyRef](f: F): F = { + if (getConfig.isScalaObjectFunctionForbidden) { + ObjectChecker.assertScalaSingleton(f) + } + f + } } object StreamExecutionEnvironment { diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala index 76d9cdab0e829..49331b29e59fe 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala @@ -114,7 +114,7 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { * @return The data stream that is the result of applying the reduce function to the window. */ def reduce(function: ReduceFunction[T]): DataStream[T] = { - asScalaStream(javaStream.reduce(clean(function))) + asScalaStream(javaStream.reduce(clean(check(function)))) } /** @@ -136,7 +136,7 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { if (function == null) { throw new NullPointerException("Reduce function must not be null.") } - val cleanFun = clean(function) + val cleanFun = clean(check(function)) val reducer = new ScalaReduceFunction[T](cleanFun) reduce(reducer) } @@ -228,7 +228,7 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { if (function == null) { throw new NullPointerException("Fold function must not be null.") } - val cleanFun = clean(function) + val cleanFun = clean(check(function)) val folder = new ScalaFoldFunction[T, R](cleanFun) fold(initialValue, folder) } @@ -311,7 +311,7 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { def apply[R: TypeInformation]( function: WindowFunction[T, R, K, W]): DataStream[R] = { - val cleanFunction = clean(function) + val cleanFunction = clean(check(function)) val applyFunction = new ScalaWindowFunctionWrapper[T, R, K, W](cleanFunction) asScalaStream(javaStream.apply(applyFunction, implicitly[TypeInformation[R]])) } @@ -333,7 +333,7 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { throw new NullPointerException("WindowApply function must not be null.") } - val cleanedFunction = clean(function) + val cleanedFunction = clean(check(function)) val applyFunction = new ScalaWindowFunction[T, R, K, W](cleanedFunction) asScalaStream(javaStream.apply(applyFunction, implicitly[TypeInformation[R]])) @@ -356,8 +356,8 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { preAggregator: ReduceFunction[T], function: WindowFunction[T, R, K, W]): DataStream[R] = { - val cleanedPreAggregator = clean(preAggregator) - val cleanedWindowFunction = clean(function) + val cleanedPreAggregator = clean(check(preAggregator)) + val cleanedWindowFunction = clean(check(function)) val applyFunction = new ScalaWindowFunctionWrapper[T, R, K, W](cleanedWindowFunction) @@ -389,8 +389,8 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { throw new NullPointerException("WindowApply function must not be null.") } - val cleanReducer = clean(preAggregator) - val cleanWindowFunction = clean(windowFunction) + val cleanReducer = clean(check(preAggregator)) + val cleanWindowFunction = clean(check(windowFunction)) val reducer = new ScalaReduceFunction[T](cleanReducer) val applyFunction = new ScalaWindowFunction[T, R, K, W](cleanWindowFunction) @@ -417,8 +417,8 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { foldFunction: FoldFunction[T, R], function: WindowFunction[R, R, K, W]): DataStream[R] = { - val cleanedFunction = clean(function) - val cleanedFoldFunction = clean(foldFunction) + val cleanedFunction = clean(check(function)) + val cleanedFoldFunction = clean(check(foldFunction)) val applyFunction = new ScalaWindowFunctionWrapper[R, R, K, W](cleanedFunction) @@ -454,8 +454,8 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { throw new NullPointerException("WindowApply function must not be null.") } - val cleanFolder = clean(foldFunction) - val cleanWindowFunction = clean(windowFunction) + val cleanFolder = clean(check(foldFunction)) + val cleanWindowFunction = clean(check(windowFunction)) val folder = new ScalaFoldFunction[T, R](cleanFolder) val applyFunction = new ScalaWindowFunction[R, R, K, W](cleanWindowFunction) @@ -567,6 +567,13 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaClean(f) } + /** + * Check if the user defined function is a legal function. + */ + private[flink] def check[F <: AnyRef](f: F): F = { + new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaCheck(f) + } + /** * Gets the output type. */ From 340d35d3fd6a3da245b6aa31e818cb656c26af38 Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 14 Nov 2016 16:39:43 +0800 Subject: [PATCH 04/11] add check for DataSet related class --- .../scala/org/apache/flink/ml/package.scala | 6 +-- .../flink/api/scala/CoGroupDataSet.scala | 4 +- .../apache/flink/api/scala/CrossDataSet.scala | 2 +- .../org/apache/flink/api/scala/DataSet.scala | 43 ++++++++++++------- .../flink/api/scala/GroupedDataSet.scala | 8 ++-- .../apache/flink/api/scala/joinDataSet.scala | 4 +- .../scala/unfinishedKeyPairOperation.scala | 4 +- 7 files changed, 41 insertions(+), 30 deletions(-) diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala index 554e155201045..9c4001752a9a3 100644 --- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala @@ -54,20 +54,20 @@ package object ml { broadcastVariable: DataSet[B])( fun: (T, B) => O) : DataSet[O] = { - dataSet.map(new BroadcastSingleElementMapper[T, B, O](dataSet.clean(fun))) + dataSet.map(new BroadcastSingleElementMapper[T, B, O](dataSet.clean(dataSet.check(fun)))) .withBroadcastSet(broadcastVariable, "broadcastVariable") } def filterWithBcVariable[B, O](broadcastVariable: DataSet[B])(fun: (T, B) => Boolean) : DataSet[T] = { - dataSet.filter(new BroadcastSingleElementFilter[T, B](dataSet.clean(fun))) + dataSet.filter(new BroadcastSingleElementFilter[T, B](dataSet.clean(dataSet.check(fun)))) .withBroadcastSet(broadcastVariable, "broadcastVariable") } def mapWithBcVariableIteration[B, O: TypeInformation: ClassTag]( broadcastVariable: DataSet[B])(fun: (T, B, Int) => O) : DataSet[O] = { - dataSet.map(new BroadcastSingleElementMapperWithIteration[T, B, O](dataSet.clean(fun))) + dataSet.map(new BroadcastSingleElementMapperWithIteration[T, B, O](dataSet.clean(dataSet.check(fun)))) .withBroadcastSet(broadcastVariable, "broadcastVariable") } } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/CoGroupDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/CoGroupDataSet.scala index aa6b47b964923..b985d85c16268 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/CoGroupDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/CoGroupDataSet.scala @@ -84,7 +84,7 @@ class CoGroupDataSet[L, R]( fun: (Iterator[L], Iterator[R]) => O): DataSet[O] = { require(fun != null, "CoGroup function must not be null.") val coGrouper = new CoGroupFunction[L, R, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def coGroup(left: java.lang.Iterable[L], right: java.lang.Iterable[R], out: Collector[O]) = { out.collect(cleanFun(left.iterator().asScala, right.iterator().asScala)) } @@ -114,7 +114,7 @@ class CoGroupDataSet[L, R]( fun: (Iterator[L], Iterator[R], Collector[O]) => Unit): DataSet[O] = { require(fun != null, "CoGroup function must not be null.") val coGrouper = new CoGroupFunction[L, R, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def coGroup(left: java.lang.Iterable[L], right: java.lang.Iterable[R], out: Collector[O]) = { cleanFun(left.iterator.asScala, right.iterator.asScala, out) } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/CrossDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/CrossDataSet.scala index 325aa27cf3f84..39c65a31b14a9 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/CrossDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/CrossDataSet.scala @@ -58,7 +58,7 @@ class CrossDataSet[L, R]( def apply[O: TypeInformation: ClassTag](fun: (L, R) => O): DataSet[O] = { require(fun != null, "Cross function must not be null.") val crosser = new CrossFunction[L, R, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def cross(left: L, right: R): O = { cleanFun(left, right) } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala index 4e7be042901ff..19bd40cc19abb 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala @@ -128,6 +128,17 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { f } + /** + * Check if the user defined function is a legal function. + * Currently, check if the function was implemented by a Scala Singleton Object. + */ + private[flink] def check[F <: AnyRef](f: F): F = { + if (set.getExecutionEnvironment.getConfig.isScalaObjectFunctionForbidden) { + ObjectChecker.assertScalaSingleton(f) + } + f + } + // -------------------------------------------------------------------------------------------- // General methods // -------------------------------------------------------------------------------------------- @@ -294,7 +305,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("Map function must not be null.") } val mapper = new MapFunction[T, R] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def map(in: T): R = cleanFun(in) } wrap(new MapOperator[T, R](javaSet, @@ -336,7 +347,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("MapPartition function must not be null.") } val partitionMapper = new MapPartitionFunction[T, R] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def mapPartition(in: java.lang.Iterable[T], out: Collector[R]) { cleanFun(in.iterator().asScala, out) } @@ -361,7 +372,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("MapPartition function must not be null.") } val partitionMapper = new MapPartitionFunction[T, R] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def mapPartition(in: java.lang.Iterable[T], out: Collector[R]) { cleanFun(in.iterator().asScala) foreach out.collect } @@ -395,7 +406,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("FlatMap function must not be null.") } val flatMapper = new FlatMapFunction[T, R] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def flatMap(in: T, out: Collector[R]) { cleanFun(in, out) } } wrap(new FlatMapOperator[T, R](javaSet, @@ -413,7 +424,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("FlatMap function must not be null.") } val flatMapper = new FlatMapFunction[T, R] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def flatMap(in: T, out: Collector[R]) { cleanFun(in) foreach out.collect } } wrap(new FlatMapOperator[T, R](javaSet, @@ -440,7 +451,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("Filter function must not be null.") } val filter = new FilterFunction[T] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def filter(in: T) = cleanFun(in) } wrap(new FilterOperator[T](javaSet, filter, getCallLocationName())) @@ -581,7 +592,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("Reduce function must not be null.") } val reducer = new ReduceFunction[T] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def reduce(v1: T, v2: T) = { cleanFun(v1, v2) } } wrap(new ReduceOperator[T](javaSet, reducer, getCallLocationName())) @@ -613,7 +624,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("GroupReduce function must not be null.") } val reducer = new GroupReduceFunction[T, R] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def reduce(in: java.lang.Iterable[T], out: Collector[R]) { cleanFun(in.iterator().asScala, out) } @@ -632,7 +643,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("GroupReduce function must not be null.") } val reducer = new GroupReduceFunction[T, R] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def reduce(in: java.lang.Iterable[T], out: Collector[R]) { out.collect(cleanFun(in.iterator().asScala)) } @@ -688,7 +699,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { throw new NullPointerException("Combine function must not be null.") } val combiner = new GroupCombineFunction[T, R] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def combine(in: java.lang.Iterable[T], out: Collector[R]) { cleanFun(in.iterator().asScala, out) } @@ -778,7 +789,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { */ def distinct[K: TypeInformation](fun: T => K): DataSet[T] = { val keyExtractor = new KeySelector[T, K] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def getKey(in: T) = cleanFun(in) } wrap(new DistinctOperator[T]( @@ -853,7 +864,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { def groupBy[K: TypeInformation](fun: T => K): GroupedDataSet[T] = { val keyType = implicitly[TypeInformation[K]] val keyExtractor = new KeySelector[T, K] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def getKey(in: T) = cleanFun(in) } new GroupedDataSet[T](this, @@ -1394,7 +1405,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { */ def partitionByHash[K: TypeInformation](fun: T => K): DataSet[T] = { val keyExtractor = new KeySelector[T, K] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def getKey(in: T) = cleanFun(in) } val op = new PartitionOperator[T]( @@ -1450,7 +1461,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { */ def partitionByRange[K: TypeInformation](fun: T => K): DataSet[T] = { val keyExtractor = new KeySelector[T, K] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def getKey(in: T) = cleanFun(in) } val op = new PartitionOperator[T]( @@ -1512,7 +1523,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], fun: T => K) : DataSet[T] = { val keyExtractor = new KeySelector[T, K] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def getKey(in: T) = cleanFun(in) } @@ -1578,7 +1589,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { */ def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] ={ val keyExtractor = new KeySelector[T, K] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def getKey(in: T) = cleanFun(in) } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala index 72608b374b071..f0628762f2311 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala @@ -307,7 +307,7 @@ class GroupedDataSet[T: ClassTag]( strategy: CombineHint): DataSet[T] = { require(fun != null, "Reduce function must not be null.") val reducer = new ReduceFunction[T] { - val cleanFun = set.clean(fun) + val cleanFun = set.clean(set.check(fun)) def reduce(v1: T, v2: T) = { cleanFun(v1, v2) } @@ -350,7 +350,7 @@ class GroupedDataSet[T: ClassTag]( fun: (Iterator[T]) => R): DataSet[R] = { require(fun != null, "Group reduce function must not be null.") val reducer = new GroupReduceFunction[T, R] { - val cleanFun = set.clean(fun) + val cleanFun = set.clean(set.check(fun)) def reduce(in: java.lang.Iterable[T], out: Collector[R]) { out.collect(cleanFun(in.iterator().asScala)) } @@ -369,7 +369,7 @@ class GroupedDataSet[T: ClassTag]( fun: (Iterator[T], Collector[R]) => Unit): DataSet[R] = { require(fun != null, "Group reduce function must not be null.") val reducer = new GroupReduceFunction[T, R] { - val cleanFun = set.clean(fun) + val cleanFun = set.clean(set.check(fun)) def reduce(in: java.lang.Iterable[T], out: Collector[R]) { cleanFun(in.iterator().asScala, out) } @@ -437,7 +437,7 @@ class GroupedDataSet[T: ClassTag]( fun: (Iterator[T], Collector[R]) => Unit): DataSet[R] = { require(fun != null, "GroupCombine function must not be null.") val combiner = new GroupCombineFunction[T, R] { - val cleanFun = set.clean(fun) + val cleanFun = set.clean(set.check(fun)) def combine(in: java.lang.Iterable[T], out: Collector[R]) { cleanFun(in.iterator().asScala, out) } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala index 82435059ad91f..c039d0f782d45 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala @@ -75,7 +75,7 @@ class JoinDataSet[L, R]( def apply[O: TypeInformation: ClassTag](fun: (L, R) => O): DataSet[O] = { require(fun != null, "Join function must not be null.") val joiner = new FlatJoinFunction[L, R, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def join(left: L, right: R, out: Collector[O]) = { out.collect(cleanFun(left, right)) } @@ -106,7 +106,7 @@ class JoinDataSet[L, R]( def apply[O: TypeInformation: ClassTag](fun: (L, R, Collector[O]) => Unit): DataSet[O] = { require(fun != null, "Join function must not be null.") val joiner = new FlatJoinFunction[L, R, O] { - val cleanFun = clean(fun) + val cleanFun = clean(check(fun)) def join(left: L, right: R, out: Collector[O]) = { cleanFun(left, right, out) } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala index 462007567d768..ca380e071aa0d 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala @@ -84,7 +84,7 @@ private[flink] abstract class UnfinishedKeyPairOperation[L, R, O]( def where[K: TypeInformation](fun: (L) => K) = { val keyType = implicitly[TypeInformation[K]] val keyExtractor = new KeySelector[L, K] { - val cleanFun = leftInput.clean(fun) + val cleanFun = leftInput.clean(leftInput.check(fun)) def getKey(in: L) = cleanFun(in) } val leftKey = new Keys.SelectorFunctionKeys[L, K](keyExtractor, leftInput.getType, keyType) @@ -133,7 +133,7 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O]( def equalTo[K: TypeInformation](fun: (R) => K): O = { val keyType = implicitly[TypeInformation[K]] val keyExtractor = new KeySelector[R, K] { - val cleanFun = unfinished.leftInput.clean(fun) + val cleanFun = unfinished.leftInput.clean(unfinished.leftInput.check(fun)) def getKey(in: R) = cleanFun(in) } val rightKey = new Keys.SelectorFunctionKeys[R, K]( From 226d6b28fdc67877395d022fcc64c101a712dda2 Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 14 Nov 2016 18:56:35 +0800 Subject: [PATCH 05/11] add test for ScalaObjectChecker --- .../flink/api/scala/ObjectChecker.scala | 18 ++++- .../api/scala/ScalaObjectCheckerTest.scala | 70 +++++++++++++++++++ .../scala/ScalaObjectcheckerStreamTest.scala | 52 ++++++++++++++ 3 files changed, 137 insertions(+), 3 deletions(-) create mode 100644 flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaObjectCheckerTest.scala create mode 100644 flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/ScalaObjectcheckerStreamTest.scala diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala index 109aa4677c358..2b6bdd2538be6 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala @@ -27,10 +27,22 @@ import org.apache.flink.api.common.InvalidProgramException */ @Internal object ObjectChecker { - def isSingleton[A](a: A)(implicit ev: A <:< Singleton = null) = - Option(ev).isDefined + def isSingleton[A](a: A): Boolean = { + val cls = a.getClass + val clsName = cls.getName + if (clsName.length > 0) { + val lastChar = clsName.charAt(clsName.length() - 1); + if (lastChar == '$') { + true + } else { + false + } + } else { + false + } + } - def assertScalaSingleton[A](a: A) = { + def assertScalaSingleton[A](a: A)(implicit ev: A <:< Singleton = null) = { if (isSingleton(a)) { val msg = "User defined function implemented by class " + a.getClass.getName + " might be implemented by a Scala Object,it is forbidden by Flink since concurrent modification risks." diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaObjectCheckerTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaObjectCheckerTest.scala new file mode 100644 index 0000000000000..a53e29fb19f26 --- /dev/null +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaObjectCheckerTest.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala + +import org.apache.flink.api.common.InvalidProgramException +import org.apache.flink.api.common.functions.RichMapFunction +import org.apache.flink.api.scala.ScalaObjectCheckerTest.{AScalaObject, RichMapClass, RichMapObject} +import org.junit.Test + + +class ScalaObjectCheckerTest { + + + @Test(expected = classOf[InvalidProgramException]) + def testAssertScalaForbidScalaObjectFunction(): Unit = { + ObjectChecker.assertScalaSingleton(AScalaObject) + } + + @Test + def testAssertScalaForbidScalaObjectFunction2(): Unit = { + class AScalaClass + ObjectChecker.assertScalaSingleton(new AScalaClass) + } + + @Test(expected = classOf[InvalidProgramException]) + def testEnvForObject(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val src = env.fromElements(1, 2, 3, 4) + + src.map(RichMapObject) + } + + @Test + def testEnvForClass(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val src = env.fromCollection(Seq(1, 2, 3)) + + src.map(new RichMapClass) + } +} + +object ScalaObjectCheckerTest { + + object AScalaObject + + object RichMapObject extends RichMapFunction[Int, Int] { + override def map(value: Int): Int = value * 2 + } + + class RichMapClass extends RichMapFunction[Int, Int] { + override def map(value: Int): Int = value * 2 + } + +} diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/ScalaObjectcheckerStreamTest.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/ScalaObjectcheckerStreamTest.scala new file mode 100644 index 0000000000000..9413355036fcf --- /dev/null +++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/ScalaObjectcheckerStreamTest.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.streaming.api.scala + +import org.apache.flink.api.common.InvalidProgramException +import org.apache.flink.api.common.functions.{MapFunction, RichMapFunction} +import org.apache.flink.streaming.api.scala.ScalaObjectCheckStreamTest.{RichMapClass, RichMapObject} +import org.junit.Test + +class ScalaObjectCheckStreamTest { + + @Test(expected = classOf[InvalidProgramException]) + def testStreamEnvForObject(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val src = env.fromElements(1, 2, 3, 4) + src.map(RichMapObject) + } + + + @Test + def testStreamEnvForClass(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val src = env.fromElements(1, 2, 3) + src.map(new RichMapClass) + } +} + +object ScalaObjectCheckStreamTest { + + object RichMapObject extends RichMapFunction[Int, Int] { + override def map(value: Int): Int = value * 2 + } + + class RichMapClass extends MapFunction[Int, Int] { + override def map(value: Int): Int = value * 2 + } +} From 41ed6d87e8114ee022e73902b95b7a75e3d237c8 Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 14 Nov 2016 20:00:26 +0800 Subject: [PATCH 06/11] add check for more dataset functions --- .../org/apache/flink/api/scala/CoGroupDataSet.scala | 3 ++- .../org/apache/flink/api/scala/CrossDataSet.scala | 1 + .../scala/org/apache/flink/api/scala/DataSet.scala | 10 +++++----- .../org/apache/flink/api/scala/GroupedDataSet.scala | 4 ++++ .../scala/org/apache/flink/api/scala/joinDataSet.scala | 5 +++-- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/CoGroupDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/CoGroupDataSet.scala index b985d85c16268..cd3d3a6f3a386 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/CoGroupDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/CoGroupDataSet.scala @@ -149,7 +149,7 @@ class CoGroupDataSet[L, R]( rightInput.javaSet, leftKeys, rightKeys, - coGrouper, + check(coGrouper), implicitly[TypeInformation[O]], buildGroupSortList(leftInput.getType, groupSortKeyPositionsFirst, groupSortOrdersFirst), buildGroupSortList(rightInput.getType, groupSortKeyPositionsSecond, groupSortOrdersSecond), @@ -164,6 +164,7 @@ class CoGroupDataSet[L, R]( // ---------------------------------------------------------------------------------------------- def withPartitioner[K : TypeInformation](partitioner : Partitioner[K]) : CoGroupDataSet[L, R] = { + check(partitioner) if (partitioner != null) { val typeInfo : TypeInformation[K] = implicitly[TypeInformation[K]] diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/CrossDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/CrossDataSet.scala index 39c65a31b14a9..cf03f33601852 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/CrossDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/CrossDataSet.scala @@ -83,6 +83,7 @@ class CrossDataSet[L, R]( */ def apply[O: TypeInformation: ClassTag](crosser: CrossFunction[L, R, O]): DataSet[O] = { require(crosser != null, "Cross function must not be null.") + check(crosser) val crossOperator = new CrossOperator[L, R, O]( leftInput.javaSet, rightInput.javaSet, diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala index 19bd40cc19abb..a4289fe534ebb 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala @@ -293,7 +293,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { } wrap(new MapOperator[T, R](javaSet, implicitly[TypeInformation[R]], - mapper, + check(mapper), getCallLocationName())) } @@ -329,7 +329,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { } wrap(new MapPartitionOperator[T, R](javaSet, implicitly[TypeInformation[R]], - partitionMapper, + check(partitionMapper), getCallLocationName())) } @@ -393,7 +393,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { } wrap(new FlatMapOperator[T, R](javaSet, implicitly[TypeInformation[R]], - flatMapper, + check(flatMapper), getCallLocationName())) } @@ -609,7 +609,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { } wrap(new GroupReduceOperator[T, R](javaSet, implicitly[TypeInformation[R]], - reducer, + check(reducer), getCallLocationName())) } @@ -675,7 +675,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { } wrap(new GroupCombineOperator[T, R](javaSet, implicitly[TypeInformation[R]], - combiner, + check(combiner), getCallLocationName())) } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala index f0628762f2311..8037e4afc1c3a 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala @@ -320,6 +320,7 @@ class GroupedDataSet[T: ClassTag]( * using an associative reduce function. */ def reduce(reducer: ReduceFunction[T]): DataSet[T] = { + set.check(reducer) reduce(getCallLocationName(), reducer, CombineHint.OPTIMIZER_CHOOSES) } @@ -330,6 +331,7 @@ class GroupedDataSet[T: ClassTag]( */ @PublicEvolving def reduce(reducer: ReduceFunction[T], strategy: CombineHint): DataSet[T] = { + set.check(reducer) reduce(getCallLocationName(), reducer, strategy) } @@ -337,6 +339,7 @@ class GroupedDataSet[T: ClassTag]( reducer: ReduceFunction[T], strategy: CombineHint): DataSet[T] = { require(reducer != null, "Reduce function must not be null.") + set.check(reducer) wrap(new ReduceOperator[T](createUnsortedGrouping(), reducer, callLocationName). setCombineHint(strategy)) } @@ -386,6 +389,7 @@ class GroupedDataSet[T: ClassTag]( */ def reduceGroup[R: TypeInformation: ClassTag](reducer: GroupReduceFunction[T, R]): DataSet[R] = { require(reducer != null, "GroupReduce function must not be null.") + set.check(reducer) wrap( new GroupReduceOperator[T, R](maybeCreateSortedGrouping(), implicitly[TypeInformation[R]], reducer, getCallLocationName())) diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala index c039d0f782d45..3a75cab693336 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala @@ -139,7 +139,7 @@ class JoinDataSet[L, R]( */ def apply[O: TypeInformation: ClassTag](joiner: FlatJoinFunction[L, R, O]): DataSet[O] = { require(joiner != null, "Join function must not be null.") - + check(joiner) val joinOperator = new EquiJoin[L, R, O]( leftInput.javaSet, rightInput.javaSet, @@ -167,7 +167,7 @@ class JoinDataSet[L, R]( */ def apply[O: TypeInformation: ClassTag](fun: JoinFunction[L, R, O]): DataSet[O] = { require(fun != null, "Join function must not be null.") - + check(fun) val generatedFunction: FlatJoinFunction[L, R, O] = new WrappingFlatJoinFunction[L, R, O](fun) val joinOperator = new EquiJoin[L, R, O]( @@ -193,6 +193,7 @@ class JoinDataSet[L, R]( // ---------------------------------------------------------------------------------------------- def withPartitioner[K : TypeInformation](partitioner : Partitioner[K]) : JoinDataSet[L, R] = { + check(partitioner) if (partitioner != null) { val typeInfo : TypeInformation[K] = implicitly[TypeInformation[K]] From 764aaef51295002d532d7678afd057dcb9d9cfe9 Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 14 Nov 2016 20:11:57 +0800 Subject: [PATCH 07/11] add check for more stream functions --- .../flink/streaming/api/scala/AllWindowedStream.scala | 2 ++ .../flink/streaming/api/scala/ConnectedStreams.scala | 10 +++++++--- .../apache/flink/streaming/api/scala/DataStream.scala | 7 +++++++ .../flink/streaming/api/scala/WindowedStream.scala | 2 +- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala index 8b381c9fdda5c..58d755fddd6fe 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/AllWindowedStream.scala @@ -211,6 +211,8 @@ class AllWindowedStream[T, W <: Window](javaStream: JavaAllWStream[T, W]) { throw new NullPointerException("Fold function must not be null.") } + check(function) + val resultType : TypeInformation[R] = implicitly[TypeInformation[R]] asScalaStream(javaStream.fold(initialValue, function, resultType)) } diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala index d4bfc4a86eb9e..f6f12fa1b48a7 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala @@ -96,7 +96,9 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { throw new NullPointerException("Map function must not be null.") } - val outType : TypeInformation[R] = implicitly[TypeInformation[R]] + check(coMapper) + + val outType : TypeInformation[R] = implicitly[TypeInformation[R]] asScalaStream(javaStream.map(coMapper).returns(outType).asInstanceOf[JavaStream[R]]) } @@ -125,6 +127,8 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { throw new NullPointerException("CoProcessFunction function must not be null.") } + check(coFlatMapper) + val outType : TypeInformation[R] = implicitly[TypeInformation[R]] asScalaStream(javaStream.process(coProcessFunction, outType)) @@ -153,8 +157,8 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { if (coFlatMapper == null) { throw new NullPointerException("FlatMap function must not be null.") } - - val outType : TypeInformation[R] = implicitly[TypeInformation[R]] + check(coFlatMapper) + val outType : TypeInformation[R] = implicitly[TypeInformation[R]] asScalaStream(javaStream.flatMap(coFlatMapper).returns(outType).asInstanceOf[JavaStream[R]]) } diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala index a49d88c70bf40..1f411c8e4156d 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala @@ -508,6 +508,8 @@ class DataStream[T](stream: JavaStream[T]) { throw new NullPointerException("Map function must not be null.") } + check(mapper) + val outType : TypeInformation[R] = implicitly[TypeInformation[R]] asScalaStream(stream.map(mapper).returns(outType).asInstanceOf[JavaStream[R]]) } @@ -521,6 +523,8 @@ class DataStream[T](stream: JavaStream[T]) { throw new NullPointerException("FlatMap function must not be null.") } + check(flatMapper) + val outType : TypeInformation[R] = implicitly[TypeInformation[R]] asScalaStream(stream.flatMap(flatMapper).returns(outType).asInstanceOf[JavaStream[R]]) } @@ -562,6 +566,9 @@ class DataStream[T](stream: JavaStream[T]) { if (filter == null) { throw new NullPointerException("Filter function must not be null.") } + + check(filter) + asScalaStream(stream.filter(filter)) } diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala index 49331b29e59fe..5d7162d45dc2d 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/WindowedStream.scala @@ -210,7 +210,7 @@ class WindowedStream[T, K, W <: Window](javaStream: JavaWStream[T, K, W]) { if (function == null) { throw new NullPointerException("Fold function must not be null.") } - + check(function) val resultType : TypeInformation[R] = implicitly[TypeInformation[R]] asScalaStream(javaStream.fold(initialValue, function, resultType)) From d27a87ac033ad041ad44ea3f590337f8d1f869b2 Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 14 Nov 2016 20:39:57 +0800 Subject: [PATCH 08/11] fix error file=/home/travis/build/Renkai/flink/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala message=File line length exceeds 100 characters line=48 --- .../main/scala/org/apache/flink/api/scala/ObjectChecker.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala index 2b6bdd2538be6..e4494f7e2b16e 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala @@ -45,7 +45,8 @@ object ObjectChecker { def assertScalaSingleton[A](a: A)(implicit ev: A <:< Singleton = null) = { if (isSingleton(a)) { val msg = "User defined function implemented by class " + a.getClass.getName + - " might be implemented by a Scala Object,it is forbidden by Flink since concurrent modification risks." + " might be implemented by a Scala Object," + + "it is forbidden by Flink since concurrent modification risks." throw new InvalidProgramException(msg) } } From 77d39b9608d41ea8d65204a993e55a749ff7c4f0 Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 14 Nov 2016 21:28:29 +0800 Subject: [PATCH 09/11] insert new line for scala style --- .../src/main/scala/org/apache/flink/ml/package.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala index 9c4001752a9a3..26e6c3096c2f4 100644 --- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/package.scala @@ -54,20 +54,23 @@ package object ml { broadcastVariable: DataSet[B])( fun: (T, B) => O) : DataSet[O] = { - dataSet.map(new BroadcastSingleElementMapper[T, B, O](dataSet.clean(dataSet.check(fun)))) + dataSet.map(new BroadcastSingleElementMapper[T, B, O]( + dataSet.clean(dataSet.check(fun)))) .withBroadcastSet(broadcastVariable, "broadcastVariable") } def filterWithBcVariable[B, O](broadcastVariable: DataSet[B])(fun: (T, B) => Boolean) : DataSet[T] = { - dataSet.filter(new BroadcastSingleElementFilter[T, B](dataSet.clean(dataSet.check(fun)))) + dataSet.filter(new BroadcastSingleElementFilter[T, B]( + dataSet.clean(dataSet.check(fun)))) .withBroadcastSet(broadcastVariable, "broadcastVariable") } def mapWithBcVariableIteration[B, O: TypeInformation: ClassTag]( broadcastVariable: DataSet[B])(fun: (T, B, Int) => O) : DataSet[O] = { - dataSet.map(new BroadcastSingleElementMapperWithIteration[T, B, O](dataSet.clean(dataSet.check(fun)))) + dataSet.map(new BroadcastSingleElementMapperWithIteration[T, B, O]( + dataSet.clean(dataSet.check(fun)))) .withBroadcastSet(broadcastVariable, "broadcastVariable") } } From 52c8acf9820900d97162a8ba445256604bccc67d Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 5 Dec 2016 18:09:46 +0800 Subject: [PATCH 10/11] modify function isSingleton --- .../apache/flink/api/scala/ObjectChecker.scala | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala index e4494f7e2b16e..835b908952873 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ObjectChecker.scala @@ -27,22 +27,11 @@ import org.apache.flink.api.common.InvalidProgramException */ @Internal object ObjectChecker { - def isSingleton[A](a: A): Boolean = { - val cls = a.getClass - val clsName = cls.getName - if (clsName.length > 0) { - val lastChar = clsName.charAt(clsName.length() - 1); - if (lastChar == '$') { - true - } else { - false - } - } else { - false - } + def isSingleton(obj: Any): Boolean = { + obj.getClass.getFields.map(_.getName) contains "MODULE$" } - def assertScalaSingleton[A](a: A)(implicit ev: A <:< Singleton = null) = { + def assertScalaSingleton(a: Any) = { if (isSingleton(a)) { val msg = "User defined function implemented by class " + a.getClass.getName + " might be implemented by a Scala Object," + From bb646784faa0f72392f47967cb6df52cf1baeca5 Mon Sep 17 00:00:00 2001 From: renkai Date: Mon, 5 Dec 2016 20:42:29 +0800 Subject: [PATCH 11/11] fix for rebase --- .../org/apache/flink/streaming/api/scala/ConnectedStreams.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala index f6f12fa1b48a7..b424258d4e98b 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala @@ -127,7 +127,7 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { throw new NullPointerException("CoProcessFunction function must not be null.") } - check(coFlatMapper) + check(coProcessFunction) val outType : TypeInformation[R] = implicitly[TypeInformation[R]]