From 7ab4f848377b43ad7c37b666124ae2af9943e6a1 Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Mon, 5 Sep 2016 12:13:29 +0200 Subject: [PATCH] [FLINK-4580] [rpc] Verify that the rpc endpoint supports the rpc gateway at connect time When calling RpcService.connect it is checked that the rpc endpoint supports the specified rpc gateway. If not, then a RpcConnectionException is thrown. The verification is implemented as an additional message following after the Identify message. The reason for this is that the ActorSystem won't wait for the Identify message to time out after it has determined that the specified actor does not exist. For user-level messages this seems to be not the case and, thus, we would have to wait for the timeout. --- .../flink/runtime/rpc/akka/AkkaRpcActor.java | 15 ++++++- .../runtime/rpc/akka/AkkaRpcService.java | 18 ++++++-- .../rpc/akka/messages/VerifyRpcGateway.java | 41 +++++++++++++++++++ .../runtime/rpc/akka/AkkaRpcActorTest.java | 26 ++++++++++++ 4 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/VerifyRpcGateway.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java index 2373be9414ed2..17decf878ca5e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java @@ -32,6 +32,8 @@ import org.apache.flink.runtime.rpc.akka.messages.RpcInvocation; import org.apache.flink.runtime.rpc.akka.messages.RunAsync; +import org.apache.flink.runtime.rpc.akka.messages.VerifyRpcGateway; +import org.apache.flink.runtime.rpc.exceptions.RpcConnectionException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -103,7 +105,18 @@ public void apply(Object message) throws Exception { private void handleMessage(Object message) { mainThreadValidator.enterMainThread(); try { - if (message instanceof RunAsync) { + if (message instanceof VerifyRpcGateway) { + VerifyRpcGateway verifyRpcGateway = (VerifyRpcGateway) message; + + if (verifyRpcGateway.getClazz().isAssignableFrom(rpcEndpoint.getSelfGatewayType())) { + getSender().tell(new Status.Success(getSelf()), getSelf()); + } else { + getSender().tell(new Status.Failure( + new RpcConnectionException("The provided rpc gateway " + verifyRpcGateway.getClazz() + + " is not supported by the rpc endpoint " + rpcEndpoint + '.')), + getSelf()); + } + } else if (message instanceof RunAsync) { handleRunAsync((RunAsync) message); } else if (message instanceof CallAsync) { handleCallAsync((CallAsync) message); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java index 060a1ef499cb0..fa99799f8ef48 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcService.java @@ -27,6 +27,7 @@ import akka.actor.Props; import akka.dispatch.Mapper; import akka.pattern.AskableActorSelection; +import akka.pattern.Patterns; import akka.util.Timeout; import org.apache.flink.runtime.akka.AkkaUtils; @@ -35,6 +36,7 @@ import org.apache.flink.runtime.rpc.RpcEndpoint; import org.apache.flink.runtime.rpc.RpcService; import org.apache.flink.runtime.rpc.StartStoppable; +import org.apache.flink.runtime.rpc.akka.messages.VerifyRpcGateway; import org.apache.flink.runtime.rpc.exceptions.RpcConnectionException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,6 +44,7 @@ import scala.concurrent.ExecutionContext; import scala.concurrent.Future; import scala.concurrent.duration.FiniteDuration; +import scala.reflect.ClassTag$; import javax.annotation.concurrent.ThreadSafe; import java.lang.reflect.InvocationHandler; @@ -98,9 +101,10 @@ public Future connect(final String address, final Clas final AskableActorSelection asker = new AskableActorSelection(actorSel); final Future identify = asker.ask(new Identify(42), timeout); - return identify.map(new Mapper(){ + + final Future rpcGatewayVerification = identify.flatMap(new Mapper>(){ @Override - public C checkedApply(Object obj) throws Exception { + public Future checkedApply(Object obj) throws Exception { ActorIdentity actorIdentity = (ActorIdentity) obj; @@ -109,6 +113,15 @@ public C checkedApply(Object obj) throws Exception { } else { ActorRef actorRef = actorIdentity.getRef(); + return Patterns.ask(actorRef, new VerifyRpcGateway(clazz), timeout).mapTo(ClassTag$.MODULE$.apply(ActorRef.class)); + } + } + }, actorSystem.dispatcher()); + + return rpcGatewayVerification.map(new Mapper() { + @Override + public C checkedApply(ActorRef actorRef) throws Exception { + final String address = AkkaUtils.getAkkaURL(actorSystem, actorRef); InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler(address, actorRef, timeout, maximumFramesize); @@ -125,7 +138,6 @@ public C checkedApply(Object obj) throws Exception { akkaInvocationHandler); return proxy; - } } }, actorSystem.dispatcher()); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/VerifyRpcGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/VerifyRpcGateway.java new file mode 100644 index 0000000000000..4d72305d4e609 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/messages/VerifyRpcGateway.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.akka.messages; + +import org.apache.flink.runtime.rpc.RpcGateway; + +import java.io.Serializable; + +/** + * Message to verify that the {@link org.apache.flink.runtime.rpc.RpcEndpoint} supports the + * specified {@link RpcGateway}. + */ +public class VerifyRpcGateway implements Serializable { + private static final long serialVersionUID = 2253701464989577738L; + + private final Class clazz; + + public VerifyRpcGateway(Class clazz) { + this.clazz = clazz; + } + + public Class getClazz() { + return clazz; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActorTest.java index a6ceb9104a1ef..a3edee2a9df99 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActorTest.java @@ -69,6 +69,8 @@ public void testAddressResolution() throws Exception { Future futureRpcGateway = akkaRpcService.connect(rpcEndpoint.getAddress(), DummyRpcGateway.class); + rpcEndpoint.start(); + DummyRpcGateway rpcGateway = Await.result(futureRpcGateway, timeout.duration()); assertEquals(rpcEndpoint.getAddress(), rpcGateway.getAddress()); @@ -90,6 +92,28 @@ public void testFailingAddressResolution() throws Exception { } } + /** + * Tests that connect fails if one tries to associate with a rpc endpoint which does not support + * the specified rpc gateway. + */ + @Test + public void testFailingRpcGatewayVerification() throws Exception { + DummyRpcEndpoint rpcEndpoint = new DummyRpcEndpoint(akkaRpcService); + + rpcEndpoint.start(); + + Future futureRpcGateway = akkaRpcService.connect(rpcEndpoint.getAddress(), FalseRpcGateway.class); + + try { + FalseRpcGateway gateway = Await.result(futureRpcGateway, timeout.duration()); + + fail("The rpc connection should have failed, because the rpc endpoint does not " + + "support the specified rpc gateway."); + } catch (RpcConnectionException exception) { + // we're expecting this exception here + } + } + /** * Tests that the {@link AkkaRpcActor} stashes messages until the corresponding * {@link RpcEndpoint} has been started. @@ -122,6 +146,8 @@ private interface DummyRpcGateway extends RpcGateway { Future foobar(); } + private interface FalseRpcGateway extends RpcGateway {} + private static class DummyRpcEndpoint extends RpcEndpoint { private volatile int _foobar = 42;