Skip to content

Commit

Permalink
[FLINK-4392] [rpc] Make RPC Service thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
StephanEwen committed Dec 23, 2016
1 parent 0bb7b57 commit 4b7ab28
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 25 deletions.
Expand Up @@ -19,11 +19,12 @@
package org.apache.flink.runtime.rpc.akka;

import akka.actor.ActorRef;
import org.apache.flink.runtime.rpc.RpcGateway;

/**
* Interface for Akka based rpc gateways
*/
interface AkkaGateway {
interface AkkaGateway extends RpcGateway {

ActorRef getRpcServer();
}
Expand Up @@ -28,47 +28,61 @@
import akka.dispatch.Mapper;
import akka.pattern.AskableActorSelection;
import akka.util.Timeout;

import org.apache.flink.runtime.akka.AkkaUtils;
import org.apache.flink.runtime.rpc.MainThreadExecutor;
import org.apache.flink.runtime.rpc.RpcGateway;
import org.apache.flink.runtime.rpc.RpcEndpoint;
import org.apache.flink.runtime.rpc.RpcService;
import org.apache.flink.util.Preconditions;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.concurrent.Future;

import javax.annotation.concurrent.ThreadSafe;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;

/**
* Akka based {@link RpcService} implementation. The rpc service starts an Akka actor to receive
* rpcs from a {@link RpcGateway}.
* Akka based {@link RpcService} implementation. The RPC service starts an Akka actor to receive
* RPC invocations from a {@link RpcGateway}.
*/
@ThreadSafe
public class AkkaRpcService implements RpcService {

private static final Logger LOG = LoggerFactory.getLogger(AkkaRpcService.class);

private final Object lock = new Object();

private final ActorSystem actorSystem;
private final Timeout timeout;
private final Collection<ActorRef> actors = new HashSet<>(4);
private final Set<ActorRef> actors = new HashSet<>(4);

private volatile boolean stopped;

public AkkaRpcService(final ActorSystem actorSystem, final Timeout timeout) {
this.actorSystem = Preconditions.checkNotNull(actorSystem, "actor system");
this.timeout = Preconditions.checkNotNull(timeout, "timeout");
this.actorSystem = checkNotNull(actorSystem, "actor system");
this.timeout = checkNotNull(timeout, "timeout");
}

// this method does not mutate state and is thus thread-safe
@Override
public <C extends RpcGateway> Future<C> connect(final String address, final Class<C> clazz) {
LOG.info("Try to connect to remote rpc server with address {}. Returning a {} gateway.", address, clazz.getName());
checkState(!stopped, "RpcService is stopped");

final ActorSelection actorSel = actorSystem.actorSelection(address);
LOG.debug("Try to connect to remote RPC endpoint with address {}. Returning a {} gateway.",
address, clazz.getName());

final ActorSelection actorSel = actorSystem.actorSelection(address);
final AskableActorSelection asker = new AskableActorSelection(actorSel);

final Future<Object> identify = asker.ask(new Identify(42), timeout);

return identify.map(new Mapper<Object, C>(){
@Override
public C apply(Object obj) {
Expand All @@ -89,56 +103,86 @@ public C apply(Object obj) {

@Override
public <C extends RpcGateway, S extends RpcEndpoint<C>> C startServer(S rpcEndpoint) {
Preconditions.checkNotNull(rpcEndpoint, "rpc endpoint");

LOG.info("Start Akka rpc actor to handle rpcs for {}.", rpcEndpoint.getClass().getName());
checkNotNull(rpcEndpoint, "rpc endpoint");

Props akkaRpcActorProps = Props.create(AkkaRpcActor.class, rpcEndpoint);
ActorRef actorRef;

synchronized (lock) {
checkState(!stopped, "RpcService is stopped");
actorRef = actorSystem.actorOf(akkaRpcActorProps);
actors.add(actorRef);
}

ActorRef actorRef = actorSystem.actorOf(akkaRpcActorProps);
actors.add(actorRef);
LOG.info("Starting RPC endpoint for {} at {} .", rpcEndpoint.getClass().getName(), actorRef.path());

InvocationHandler akkaInvocationHandler = new AkkaInvocationHandler(actorRef, timeout);

// Rather than using the System ClassLoader directly, we derive the ClassLoader
// from this class . That works better in cases where Flink runs embedded and all Flink
// code is loaded dynamically (for example from an OSGI bundle) through a custom ClassLoader
ClassLoader classLoader = getClass().getClassLoader();

@SuppressWarnings("unchecked")
C self = (C) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
classLoader,
new Class<?>[]{rpcEndpoint.getSelfGatewayType(), MainThreadExecutor.class, AkkaGateway.class},
akkaInvocationHandler);

return self;
}

@Override
public <C extends RpcGateway> void stopServer(C selfGateway) {
public void stopServer(RpcGateway selfGateway) {
if (selfGateway instanceof AkkaGateway) {
AkkaGateway akkaClient = (AkkaGateway) selfGateway;

if (actors.contains(akkaClient.getRpcServer())) {
ActorRef selfActorRef = akkaClient.getRpcServer();

LOG.info("Stop Akka rpc actor {}.", selfActorRef.path());
boolean fromThisService;
synchronized (lock) {
if (stopped) {
return;
} else {
fromThisService = actors.remove(akkaClient.getRpcServer());
}
}

if (fromThisService) {
ActorRef selfActorRef = akkaClient.getRpcServer();
LOG.info("Stopping RPC endpoint {}.", selfActorRef.path());
selfActorRef.tell(PoisonPill.getInstance(), ActorRef.noSender());
} else {
LOG.debug("RPC endpoint {} already stopped or from different RPC service");
}
}
}

@Override
public void stopService() {
LOG.info("Stop Akka rpc service.");
actorSystem.shutdown();
LOG.info("Stopping Akka RPC service.");

synchronized (lock) {
if (stopped) {
return;
}

stopped = true;
actorSystem.shutdown();
actors.clear();
}

actorSystem.awaitTermination();
}

@Override
public <C extends RpcGateway> String getAddress(C selfGateway) {
checkState(!stopped, "RpcService is stopped");

if (selfGateway instanceof AkkaGateway) {
ActorRef actorRef = ((AkkaGateway) selfGateway).getRpcServer();
return AkkaUtils.getAkkaURL(actorSystem, actorRef);
} else {
String className = AkkaGateway.class.getName();
throw new RuntimeException("Cannot get address for non " + className + '.');
throw new IllegalArgumentException("Cannot get address for non " + className + '.');
}
}
}

0 comments on commit 4b7ab28

Please sign in to comment.