From 34f1e16249e3f16b0720d63b87fe6f5210dfffae Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Fri, 12 Aug 2016 10:32:30 +0200 Subject: [PATCH] [FLINK-4383] [rpc] Eagerly serialize remote rpc invocation messages This PR introduces an eager serialization for remote rpc invocation messages. That way it is possible to check whether the message is serializable and whether it exceeds the maximum allowed akka frame size. If either of these constraints is violated, a proper exception is thrown instead of simply swallowing the exception as Akka does it. Address PR comments --- .../flink/runtime/rpc/akka/AkkaGateway.java | 2 +- .../rpc/akka/AkkaInvocationHandler.java | 83 +++++-- .../flink/runtime/rpc/akka/AkkaRpcActor.java | 26 ++- .../runtime/rpc/akka/AkkaRpcService.java | 20 +- .../rpc/akka/messages/LocalRpcInvocation.java | 54 +++++ .../akka/messages/RemoteRpcInvocation.java | 206 +++++++++++++++++ .../rpc/akka/messages/RpcInvocation.java | 106 +++------ .../runtime/rpc/akka/AkkaRpcServiceTest.java | 2 +- .../rpc/akka/MessageSerializationTest.java | 210 ++++++++++++++++++ 9 files changed, 597 insertions(+), 112 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/MessageSerializationTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java index ec3091c839c05..f6125dc0a99da 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaGateway.java @@ -26,5 +26,5 @@ */ interface AkkaGateway extends RpcGateway { - ActorRef getRpcServer(); + ActorRef getRpcEndpoint(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java index 580b161bd041b..297104b4beac3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java @@ -25,13 +25,17 @@ import org.apache.flink.runtime.rpc.MainThreadExecutor; import org.apache.flink.runtime.rpc.RpcTimeout; import org.apache.flink.runtime.rpc.akka.messages.CallAsync; +import org.apache.flink.runtime.rpc.akka.messages.LocalRpcInvocation; +import org.apache.flink.runtime.rpc.akka.messages.RemoteRpcInvocation; import org.apache.flink.runtime.rpc.akka.messages.RpcInvocation; import org.apache.flink.runtime.rpc.akka.messages.RunAsync; import org.apache.flink.util.Preconditions; +import org.apache.log4j.Logger; import scala.concurrent.Await; import scala.concurrent.Future; import scala.concurrent.duration.FiniteDuration; +import java.io.IOException; import java.lang.annotation.Annotation; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; @@ -42,19 +46,28 @@ import static org.apache.flink.util.Preconditions.checkArgument; /** - * Invocation handler to be used with a {@link AkkaRpcActor}. The invocation handler wraps the - * rpc in a {@link RpcInvocation} message and then sends it to the {@link AkkaRpcActor} where it is + * Invocation handler to be used with an {@link AkkaRpcActor}. The invocation handler wraps the + * rpc in a {@link LocalRpcInvocation} message and then sends it to the {@link AkkaRpcActor} where it is * executed. */ class AkkaInvocationHandler implements InvocationHandler, AkkaGateway, MainThreadExecutor { - private final ActorRef rpcServer; + private static final Logger LOG = Logger.getLogger(AkkaInvocationHandler.class); + + private final ActorRef rpcEndpoint; + + // whether the actor ref is local and thus no message serialization is needed + private final boolean isLocal; // default timeout for asks private final Timeout timeout; - AkkaInvocationHandler(ActorRef rpcServer, Timeout timeout) { - this.rpcServer = Preconditions.checkNotNull(rpcServer); + private final long maximumFramesize; + + AkkaInvocationHandler(ActorRef rpcEndpoint, Timeout timeout, long maximumFramesize) { + this.rpcEndpoint = Preconditions.checkNotNull(rpcEndpoint); + this.isLocal = this.rpcEndpoint.path().address().hasLocalScope(); this.timeout = Preconditions.checkNotNull(timeout); + this.maximumFramesize = maximumFramesize; } @Override @@ -76,23 +89,43 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl parameterAnnotations, args); - RpcInvocation rpcInvocation = new RpcInvocation( - methodName, - filteredArguments.f0, - filteredArguments.f1); + RpcInvocation rpcInvocation; + + if (isLocal) { + rpcInvocation = new LocalRpcInvocation( + methodName, + filteredArguments.f0, + filteredArguments.f1); + } else { + try { + RemoteRpcInvocation remoteRpcInvocation = new RemoteRpcInvocation( + methodName, + filteredArguments.f0, + filteredArguments.f1); + + if (remoteRpcInvocation.getSize() > maximumFramesize) { + throw new IOException("The rpc invocation size exceeds the maximum akka framesize."); + } else { + rpcInvocation = remoteRpcInvocation; + } + } catch (IOException e) { + LOG.warn("Could not create remote rpc invocation message. Failing rpc invocation because...", e); + throw e; + } + } Class returnType = method.getReturnType(); if (returnType.equals(Void.TYPE)) { - rpcServer.tell(rpcInvocation, ActorRef.noSender()); + rpcEndpoint.tell(rpcInvocation, ActorRef.noSender()); result = null; } else if (returnType.equals(Future.class)) { // execute an asynchronous call - result = Patterns.ask(rpcServer, rpcInvocation, futureTimeout); + result = Patterns.ask(rpcEndpoint, rpcInvocation, futureTimeout); } else { // execute a synchronous call - Future futureResult = Patterns.ask(rpcServer, rpcInvocation, futureTimeout); + Future futureResult = Patterns.ask(rpcEndpoint, rpcInvocation, futureTimeout); FiniteDuration duration = timeout.duration(); result = Await.result(futureResult, duration); @@ -103,8 +136,8 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl } @Override - public ActorRef getRpcServer() { - return rpcServer; + public ActorRef getRpcEndpoint() { + return rpcEndpoint; } @Override @@ -117,19 +150,25 @@ public void scheduleRunAsync(Runnable runnable, long delay) { checkNotNull(runnable, "runnable"); checkArgument(delay >= 0, "delay must be zero or greater"); - // Unfortunately I couldn't find a way to allow only local communication. Therefore, the - // runnable field is transient transient - rpcServer.tell(new RunAsync(runnable, delay), ActorRef.noSender()); + if (isLocal) { + rpcEndpoint.tell(new RunAsync(runnable, delay), ActorRef.noSender()); + } else { + throw new RuntimeException("Trying to send a Runnable to a remote actor at " + + rpcEndpoint.path() + ". This is not supported."); + } } @Override public Future callAsync(Callable callable, Timeout callTimeout) { - // Unfortunately I couldn't find a way to allow only local communication. Therefore, the - // callable field is declared transient - @SuppressWarnings("unchecked") - Future result = (Future) Patterns.ask(rpcServer, new CallAsync(callable), callTimeout); + if(isLocal) { + @SuppressWarnings("unchecked") + Future result = (Future) Patterns.ask(rpcEndpoint, new CallAsync(callable), callTimeout); - return result; + return result; + } else { + throw new RuntimeException("Trying to send a Callable to a remote actor at " + + rpcEndpoint.path() + ". This is not supported."); + } } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java index 5e0a7da000931..dfcbcc3924374 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.rpc.RpcEndpoint; import org.apache.flink.runtime.rpc.RpcGateway; import org.apache.flink.runtime.rpc.akka.messages.CallAsync; +import org.apache.flink.runtime.rpc.akka.messages.LocalRpcInvocation; import org.apache.flink.runtime.rpc.akka.messages.RpcInvocation; import org.apache.flink.runtime.rpc.akka.messages.RunAsync; @@ -35,6 +36,7 @@ import scala.concurrent.Future; import scala.concurrent.duration.FiniteDuration; +import java.io.IOException; import java.lang.reflect.Method; import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; @@ -42,10 +44,10 @@ import static org.apache.flink.util.Preconditions.checkNotNull; /** - * Akka rpc actor which receives {@link RpcInvocation}, {@link RunAsync} and {@link CallAsync} + * Akka rpc actor which receives {@link LocalRpcInvocation}, {@link RunAsync} and {@link CallAsync} * messages. *

- * The {@link RpcInvocation} designates a rpc and is dispatched to the given {@link RpcEndpoint} + * The {@link LocalRpcInvocation} designates a rpc and is dispatched to the given {@link RpcEndpoint} * instance. *

* The {@link RunAsync} and {@link CallAsync} messages contain executable code which is executed @@ -95,15 +97,12 @@ public void onReceive(final Object message) { * @param rpcInvocation Rpc invocation message */ private void handleRpcInvocation(RpcInvocation rpcInvocation) { - Method rpcMethod = null; - try { - rpcMethod = lookupRpcMethod(rpcInvocation.getMethodName(), rpcInvocation.getParameterTypes()); - } catch (final NoSuchMethodException e) { - LOG.error("Could not find rpc method for rpc invocation: {}.", rpcInvocation, e); - } + String methodName = rpcInvocation.getMethodName(); + Class[] parameterTypes = rpcInvocation.getParameterTypes(); + + Method rpcMethod = lookupRpcMethod(methodName, parameterTypes); - if (rpcMethod != null) { if (rpcMethod.getReturnType().equals(Void.TYPE)) { // No return value to send back try { @@ -127,6 +126,12 @@ private void handleRpcInvocation(RpcInvocation rpcInvocation) { getSender().tell(new Status.Failure(e), getSelf()); } } + } catch(ClassNotFoundException e) { + LOG.error("Could not load method arguments.", e); + } catch (IOException e) { + LOG.error("Could not deserialize rpc invocation message.", e); + } catch (final NoSuchMethodException e) { + LOG.error("Could not find rpc method for rpc invocation: {}.", rpcInvocation, e); } } @@ -195,7 +200,8 @@ else if (runAsync.getDelay() == 0) { * @param methodName Name of the method * @param parameterTypes Parameter types of the method * @return Method of the rpc endpoint - * @throws NoSuchMethodException + * @throws NoSuchMethodException Thrown if the method with the given name and parameter types + * cannot be found at the rpc endpoint */ private Method lookupRpcMethod(final String methodName, final Class[] parameterTypes) throws NoSuchMethodException { return rpcEndpoint.getClass().getMethod(methodName, parameterTypes); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java index db40f10e10f39..b963c53a93a87 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java @@ -58,17 +58,27 @@ public class AkkaRpcService implements RpcService { private static final Logger LOG = LoggerFactory.getLogger(AkkaRpcService.class); + static final String MAXIMUM_FRAME_SIZE_PATH = "akka.remote.netty.tcp.maximum-frame-size"; + private final Object lock = new Object(); private final ActorSystem actorSystem; private final Timeout timeout; private final Set actors = new HashSet<>(4); + private final long maximumFramesize; private volatile boolean stopped; public AkkaRpcService(final ActorSystem actorSystem, final Timeout timeout) { this.actorSystem = checkNotNull(actorSystem, "actor system"); this.timeout = checkNotNull(timeout, "timeout"); + + if (actorSystem.settings().config().hasPath(MAXIMUM_FRAME_SIZE_PATH)) { + maximumFramesize = actorSystem.settings().config().getBytes(MAXIMUM_FRAME_SIZE_PATH); + } else { + // only local communication + maximumFramesize = Long.MAX_VALUE; + } } // this method does not mutate state and is thus thread-safe @@ -88,7 +98,7 @@ public Future connect(final String address, final Clas public C apply(Object obj) { ActorRef actorRef = ((ActorIdentity) obj).getRef(); - InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler(actorRef, timeout); + InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler(actorRef, timeout, maximumFramesize); @SuppressWarnings("unchecked") C proxy = (C) Proxy.newProxyInstance( @@ -116,7 +126,7 @@ public > C startServer(S rpcEndpo LOG.info("Starting RPC endpoint for {} at {} .", rpcEndpoint.getClass().getName(), actorRef.path()); - InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler(actorRef, timeout); + InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler(actorRef, timeout, maximumFramesize); // Rather than using the System ClassLoader directly, we derive the ClassLoader // from this class . That works better in cases where Flink runs embedded and all Flink @@ -142,12 +152,12 @@ public void stopServer(RpcGateway selfGateway) { if (stopped) { return; } else { - fromThisService = actors.remove(akkaClient.getRpcServer()); + fromThisService = actors.remove(akkaClient.getRpcEndpoint()); } } if (fromThisService) { - ActorRef selfActorRef = akkaClient.getRpcServer(); + ActorRef selfActorRef = akkaClient.getRpcEndpoint(); LOG.info("Stopping RPC endpoint {}.", selfActorRef.path()); selfActorRef.tell(PoisonPill.getInstance(), ActorRef.noSender()); } else { @@ -178,7 +188,7 @@ public String getAddress(RpcGateway selfGateway) { checkState(!stopped, "RpcService is stopped"); if (selfGateway instanceof AkkaGateway) { - ActorRef actorRef = ((AkkaGateway) selfGateway).getRpcServer(); + ActorRef actorRef = ((AkkaGateway) selfGateway).getRpcEndpoint(); return AkkaUtils.getAkkaURL(actorSystem, actorRef); } else { String className = AkkaGateway.class.getName(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java new file mode 100644 index 0000000000000..97c10d71bf141 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/LocalRpcInvocation.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.akka.messages; + +import org.apache.flink.util.Preconditions; + +/** + * Local rpc invocation message containing the remote procedure name, its parameter types and the + * corresponding call arguments. This message will only be sent if the communication is local and, + * thus, the message does not have to be serialized. + */ +public final class LocalRpcInvocation implements RpcInvocation { + + private final String methodName; + private final Class[] parameterTypes; + private final Object[] args; + + public LocalRpcInvocation(String methodName, Class[] parameterTypes, Object[] args) { + this.methodName = Preconditions.checkNotNull(methodName); + this.parameterTypes = Preconditions.checkNotNull(parameterTypes); + this.args = args; + } + + @Override + public String getMethodName() { + return methodName; + } + + @Override + public Class[] getParameterTypes() { + return parameterTypes; + } + + @Override + public Object[] getArgs() { + return args; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java new file mode 100644 index 0000000000000..bc26a29715c91 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RemoteRpcInvocation.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.akka.messages; + +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.SerializedValue; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; + +/** + * Remote rpc invocation message which is used when the actor communication is remote and, thus, the + * message has to be serialized. + *

+ * In order to fail fast and report an appropriate error message to the user, the method name, the + * parameter types and the arguments are eagerly serialized. In case the the invocation call + * contains a non-serializable object, then an {@link IOException} is thrown. + */ +public class RemoteRpcInvocation implements RpcInvocation, Serializable { + private static final long serialVersionUID = 6179354390913843809L; + + // Serialized invocation data + private SerializedValue serializedMethodInvocation; + + // Transient field which is lazily initialized upon first access to the invocation data + private transient RemoteRpcInvocation.MethodInvocation methodInvocation; + + public RemoteRpcInvocation( + final String methodName, + final Class[] parameterTypes, + final Object[] args) throws IOException { + + serializedMethodInvocation = new SerializedValue<>(new RemoteRpcInvocation.MethodInvocation(methodName, parameterTypes, args)); + methodInvocation = null; + } + + @Override + public String getMethodName() throws IOException, ClassNotFoundException { + deserializeMethodInvocation(); + + return methodInvocation.getMethodName(); + } + + @Override + public Class[] getParameterTypes() throws IOException, ClassNotFoundException { + deserializeMethodInvocation(); + + return methodInvocation.getParameterTypes(); + } + + @Override + public Object[] getArgs() throws IOException, ClassNotFoundException { + deserializeMethodInvocation(); + + return methodInvocation.getArgs(); + } + + /** + * Size (#bytes of the serialized data) of the rpc invocation message. + * + * @return Size of the remote rpc invocation message + */ + public long getSize() { + return serializedMethodInvocation.getByteArray().length; + } + + private void deserializeMethodInvocation() throws IOException, ClassNotFoundException { + if (methodInvocation == null) { + methodInvocation = serializedMethodInvocation.deserializeValue(ClassLoader.getSystemClassLoader()); + } + } + + // ------------------------------------------------------------------- + // Serialization methods + // ------------------------------------------------------------------- + + private void writeObject(ObjectOutputStream oos) throws IOException { + oos.writeObject(serializedMethodInvocation); + } + + @SuppressWarnings("unchecked") + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { + serializedMethodInvocation = (SerializedValue) ois.readObject(); + methodInvocation = null; + } + + // ------------------------------------------------------------------- + // Utility classes + // ------------------------------------------------------------------- + + /** + * Wrapper class for the method invocation information + */ + private static final class MethodInvocation implements Serializable { + private static final long serialVersionUID = 9187962608946082519L; + + private String methodName; + private Class[] parameterTypes; + private Object[] args; + + private MethodInvocation(final String methodName, final Class[] parameterTypes, final Object[] args) { + this.methodName = methodName; + this.parameterTypes = Preconditions.checkNotNull(parameterTypes); + this.args = args; + } + + String getMethodName() { + return methodName; + } + + Class[] getParameterTypes() { + return parameterTypes; + } + + Object[] getArgs() { + return args; + } + + private void writeObject(ObjectOutputStream oos) throws IOException { + oos.writeUTF(methodName); + + oos.writeInt(parameterTypes.length); + + for (Class parameterType : parameterTypes) { + oos.writeObject(parameterType); + } + + if (args != null) { + oos.writeBoolean(true); + + for (int i = 0; i < args.length; i++) { + try { + oos.writeObject(args[i]); + } catch (IOException e) { + throw new IOException("Could not serialize " + i + "th argument of method " + + methodName + ". This indicates that the argument type " + + args.getClass().getName() + " is not serializable. Arguments have to " + + "be serializable for remote rpc calls.", e); + } + } + } else { + oos.writeBoolean(false); + } + } + + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { + methodName = ois.readUTF(); + + int length = ois.readInt(); + + parameterTypes = new Class[length]; + + for (int i = 0; i < length; i++) { + try { + parameterTypes[i] = (Class) ois.readObject(); + } catch (IOException e) { + throw new IOException("Could not deserialize " + i + "th parameter type of method " + + methodName + '.', e); + } catch (ClassNotFoundException e) { + throw new ClassNotFoundException("Could not deserialize " + i + "th " + + "parameter type of method " + methodName + ". This indicates that the parameter " + + "type is not part of the system class loader.", e); + } + } + + boolean hasArgs = ois.readBoolean(); + + if (hasArgs) { + args = new Object[length]; + + for (int i = 0; i < length; i++) { + try { + args[i] = ois.readObject(); + } catch (IOException e) { + throw new IOException("Could not deserialize " + i + "th argument of method " + + methodName + '.', e); + } catch (ClassNotFoundException e) { + throw new ClassNotFoundException("Could not deserialize " + i + "th " + + "argument of method " + methodName + ". This indicates that the argument " + + "type is not part of the system class loader.", e); + } + } + } else { + args = null; + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java index 5d52ef1c0b298..b174c99a4d37c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/RpcInvocation.java @@ -18,81 +18,41 @@ package org.apache.flink.runtime.rpc.akka.messages; -import org.apache.flink.util.Preconditions; - import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; /** - * Rpc invocation message containing the remote procedure name, its parameter types and the - * corresponding call arguments. + * Interface for rpc invocation messages. The interface allows to request all necessary information + * to lookup a method and call it with the corresponding arguments. */ -public final class RpcInvocation implements Serializable { - private static final long serialVersionUID = -7058254033460536037L; - - private final String methodName; - private final Class[] parameterTypes; - private transient Object[] args; - - public RpcInvocation(String methodName, Class[] parameterTypes, Object[] args) { - this.methodName = Preconditions.checkNotNull(methodName); - this.parameterTypes = Preconditions.checkNotNull(parameterTypes); - this.args = args; - } - - public String getMethodName() { - return methodName; - } - - public Class[] getParameterTypes() { - return parameterTypes; - } - - public Object[] getArgs() { - return args; - } - - private void writeObject(ObjectOutputStream oos) throws IOException { - oos.defaultWriteObject(); - - if (args != null) { - // write has args true - oos.writeBoolean(true); - - for (int i = 0; i < args.length; i++) { - try { - oos.writeObject(args[i]); - } catch (IOException e) { - Class argClass = args[i].getClass(); - - throw new IOException("Could not write " + i + "th argument of method " + - methodName + ". The argument type is " + argClass + ". " + - "Make sure that this type is serializable.", e); - } - } - } else { - // write has args false - oos.writeBoolean(false); - } - } - - private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { - ois.defaultReadObject(); - - boolean hasArgs = ois.readBoolean(); - - if (hasArgs) { - int numberArguments = parameterTypes.length; - - args = new Object[numberArguments]; - - for (int i = 0; i < numberArguments; i++) { - args[i] = ois.readObject(); - } - } else { - args = null; - } - } +public interface RpcInvocation { + + /** + * Returns the method's name. + * + * @return Method name + * @throws IOException if the rpc invocation message is a remote message and could not be deserialized + * @throws ClassNotFoundException if the rpc invocation message is a remote message and contains + * serialized classes which cannot be found on the receiving side + */ + String getMethodName() throws IOException, ClassNotFoundException; + + /** + * Returns the method's parameter types + * + * @return Method's parameter types + * @throws IOException if the rpc invocation message is a remote message and could not be deserialized + * @throws ClassNotFoundException if the rpc invocation message is a remote message and contains + * serialized classes which cannot be found on the receiving side + */ + Class[] getParameterTypes() throws IOException, ClassNotFoundException; + + /** + * Returns the arguments of the remote procedure call + * + * @return Arguments of the remote procedure call + * @throws IOException if the rpc invocation message is a remote message and could not be deserialized + * @throws ClassNotFoundException if the rpc invocation message is a remote message and contains + * serialized classes which cannot be found on the receiving side + */ + Object[] getArgs() throws IOException, ClassNotFoundException; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java index 5e37e10ff0dc5..f26b40b8a997e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcServiceTest.java @@ -64,7 +64,7 @@ public void testJobMasterResourceManagerRegistration() throws Exception { AkkaGateway akkaClient = (AkkaGateway) rm; - jobMaster.registerAtResourceManager(AkkaUtils.getAkkaURL(actorSystem, akkaClient.getRpcServer())); + jobMaster.registerAtResourceManager(AkkaUtils.getAkkaURL(actorSystem, akkaClient.getRpcEndpoint())); // wait for successful registration FiniteDuration timeout = new FiniteDuration(200, TimeUnit.SECONDS); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/MessageSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/MessageSerializationTest.java new file mode 100644 index 0000000000000..ca8179c63f5e1 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/MessageSerializationTest.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.akka; + +import akka.actor.ActorSystem; +import akka.util.Timeout; +import com.typesafe.config.Config; +import com.typesafe.config.ConfigValueFactory; +import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.rpc.RpcEndpoint; +import org.apache.flink.runtime.rpc.RpcGateway; +import org.apache.flink.runtime.rpc.RpcMethod; +import org.apache.flink.runtime.rpc.RpcService; +import org.apache.flink.util.TestLogger; +import org.hamcrest.core.Is; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.FiniteDuration; + +import java.io.IOException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +/** + * Tests that akka rpc invocation messages are properly serialized and errors reported + */ +public class MessageSerializationTest extends TestLogger { + private static ActorSystem actorSystem1; + private static ActorSystem actorSystem2; + private static AkkaRpcService akkaRpcService1; + private static AkkaRpcService akkaRpcService2; + + private static final FiniteDuration timeout = new FiniteDuration(10L, TimeUnit.SECONDS); + private static final int maxFrameSize = 32000; + + @BeforeClass + public static void setup() { + Config akkaConfig = AkkaUtils.getDefaultAkkaConfig(); + Config modifiedAkkaConfig = akkaConfig.withValue(AkkaRpcService.MAXIMUM_FRAME_SIZE_PATH, ConfigValueFactory.fromAnyRef(maxFrameSize + "b")); + + actorSystem1 = AkkaUtils.createActorSystem(modifiedAkkaConfig); + actorSystem2 = AkkaUtils.createActorSystem(modifiedAkkaConfig); + + akkaRpcService1 = new AkkaRpcService(actorSystem1, new Timeout(timeout)); + akkaRpcService2 = new AkkaRpcService(actorSystem2, new Timeout(timeout)); + } + + @AfterClass + public static void teardown() { + akkaRpcService1.stopService(); + akkaRpcService2.stopService(); + + actorSystem1.shutdown(); + actorSystem2.shutdown(); + + actorSystem1.awaitTermination(); + actorSystem2.awaitTermination(); + } + + /** + * Tests that a local rpc call with a non serializable argument can be executed. + */ + @Test + public void testNonSerializableLocalMessageTransfer() throws InterruptedException, IOException { + LinkedBlockingQueue linkedBlockingQueue = new LinkedBlockingQueue<>(); + TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService1, linkedBlockingQueue); + + TestGateway testGateway = testEndpoint.getSelf(); + + NonSerializableObject expected = new NonSerializableObject(42); + + testGateway.foobar(expected); + + assertThat(linkedBlockingQueue.take(), Is.is(expected)); + } + + /** + * Tests that a remote rpc call with a non-serializable argument fails with an + * {@link IOException} (or an {@link java.lang.reflect.UndeclaredThrowableException} if the + * the method declaration does not include the {@link IOException} as throwable). + */ + @Test(expected = IOException.class) + public void testNonSerializableRemoteMessageTransfer() throws Exception { + LinkedBlockingQueue linkedBlockingQueue = new LinkedBlockingQueue<>(); + + TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService1, linkedBlockingQueue); + + String address = testEndpoint.getAddress(); + + Future remoteGatewayFuture = akkaRpcService2.connect(address, TestGateway.class); + + TestGateway remoteGateway = Await.result(remoteGatewayFuture, timeout); + + remoteGateway.foobar(new Object()); + + fail("Should have failed because Object is not serializable."); + } + + /** + * Tests that a remote rpc call with a serializable argument can be successfully executed. + */ + @Test + public void testSerializableRemoteMessageTransfer() throws Exception { + LinkedBlockingQueue linkedBlockingQueue = new LinkedBlockingQueue<>(); + + TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService1, linkedBlockingQueue); + + String address = testEndpoint.getAddress(); + + Future remoteGatewayFuture = akkaRpcService2.connect(address, TestGateway.class); + + TestGateway remoteGateway = Await.result(remoteGatewayFuture, timeout); + + int expected = 42; + + remoteGateway.foobar(expected); + + assertThat(linkedBlockingQueue.take(), Is.is(expected)); + } + + /** + * Tests that a message which exceeds the maximum frame size is detected and a corresponding + * exception is thrown. + */ + @Test(expected = IOException.class) + public void testMaximumFramesizeRemoteMessageTransfer() throws Exception { + LinkedBlockingQueue linkedBlockingQueue = new LinkedBlockingQueue<>(); + + TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService1, linkedBlockingQueue); + + String address = testEndpoint.getAddress(); + + Future remoteGatewayFuture = akkaRpcService2.connect(address, TestGateway.class); + + TestGateway remoteGateway = Await.result(remoteGatewayFuture, timeout); + + int bufferSize = maxFrameSize + 1; + byte[] buffer = new byte[bufferSize]; + + remoteGateway.foobar(buffer); + + fail("Should have failed due to exceeding the maximum framesize."); + } + + private interface TestGateway extends RpcGateway { + void foobar(Object object) throws IOException, InterruptedException; + } + + private static class TestEndpoint extends RpcEndpoint { + + private final LinkedBlockingQueue queue; + + protected TestEndpoint(RpcService rpcService, LinkedBlockingQueue queue) { + super(rpcService); + this.queue = queue; + } + + @RpcMethod + public void foobar(Object object) throws InterruptedException { + queue.put(object); + } + } + + private static class NonSerializableObject { + private final Object object = new Object(); + private final int value; + + NonSerializableObject(int value) { + this.value = value; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof NonSerializableObject) { + NonSerializableObject nonSerializableObject = (NonSerializableObject) obj; + + return value == nonSerializableObject.value; + } else { + return false; + } + } + + @Override + public int hashCode() { + return value * 41; + } + } +}