[SPARK-17714][CORE][TEST-MAVEN][TEST-HADOOP2.6] Avoid using ExecutorClassLoader to load Netty generated classes

## What changes were proposed in this pull request?

Netty's `MessageToMessageEncoder` uses [Javassist](https://github.com/netty/netty/blob/91a0bdc17a8298437d6de08a8958d753799bd4a6/common/src/main/java/io/netty/util/internal/JavassistTypeParameterMatcherGenerator.java#L62) to generate a matcher class, and the implementation calls `Class.forName` to check whether this class has already been generated. If `MessageEncoder` or `MessageDecoder` is created inside `ExecutorClassLoader.findClass`, the result is a `ClassCircularityError`. Loading the Netty-generated class calls `ExecutorClassLoader.findClass` to search for it, `ExecutorClassLoader` tries to load it over RPC, and that in turn triggers another attempt to load the nonexistent matcher class. The JVM reports `ClassCircularityError` to prevent this infinite recursion.
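
The re-entrant lookup can be modelled without Netty or Spark. The following is a minimal sketch, not Spark's code: `RemoteLoader`, `probeForMatcher`, and the class name `demo.GeneratedMatcher` are hypothetical, and the sketch throws its own `IllegalStateException` at the point where the JVM would report `ClassCircularityError`.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class CircularLoadSketch {
    // Hypothetical name standing in for the Netty-generated matcher class.
    static final String MATCHER = "demo.GeneratedMatcher";

    /** Hypothetical stand-in for a loader that fetches class bytes remotely. */
    static final class RemoteLoader extends ClassLoader {
        // Names this loader is currently resolving (single-threaded sketch).
        private final Deque<String> inProgress = new ArrayDeque<>();

        RemoteLoader(ClassLoader parent) {
            super(parent);
        }

        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            if (inProgress.contains(name)) {
                // The sketch's own marker for the cycle; the JVM has its own check.
                throw new IllegalStateException("circular load of " + name);
            }
            inProgress.push(name);
            try {
                // Setting up the remote fetch probes for the matcher class
                // through this same loader, which re-enters findClass.
                probeForMatcher();
                throw new ClassNotFoundException(name);
            } finally {
                inProgress.pop();
            }
        }

        private void probeForMatcher() {
            try {
                loadClass(MATCHER);
            } catch (ClassNotFoundException e) {
                // Not generated yet; a real implementation would generate it here.
            }
        }
    }

    /** Returns a short description of what happened when loading `name`. */
    static String tryLoad(String name) {
        RemoteLoader loader = new RemoteLoader(CircularLoadSketch.class.getClassLoader());
        try {
            loader.loadClass(name);
            return "loaded";
        } catch (ClassNotFoundException e) {
            return "not found";
        } catch (IllegalStateException e) {
            return e.getMessage();
        }
    }

    public static void main(String[] args) {
        System.out.println(tryLoad(MATCHER));            // reaches findClass and re-enters it
        System.out.println(tryLoad("java.lang.String")); // resolved by the parent, never reaches findClass
    }
}
```

Classes that a parent loader can resolve never reach `findClass`, so the cycle only appears for names that fall through to the remote lookup.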

##### Why it only happens in Maven builds

This is because Maven and SBT build different class loader trees. The Maven build sets a URLClassLoader as the current context class loader to run the tests, which exposes this issue. The class loader tree is as follows:

```
bootstrap class loader ------ ... ----- REPL class loader ---- ExecutorClassLoader
|
|
URLClassLoader
```
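
A small Java sketch can make the shape of such a tree concrete. It is illustrative only: the loaders below are empty `URLClassLoader` instances with hypothetical names (`testBranch`, `replBranch`, `executorLike`), and the system class loader stands in for the shared root. It shows that a loader on one branch never appears in the parent chain of a loader on the other branch.

```java
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.List;

class LoaderTreeSketch {
    /**
     * Walks parent links from the given loader upwards. The bootstrap loader
     * is represented as null in Java, so it is not included in the result.
     */
    static List<ClassLoader> chain(ClassLoader leaf) {
        List<ClassLoader> out = new ArrayList<>();
        for (ClassLoader cl = leaf; cl != null; cl = cl.getParent()) {
            out.add(cl);
        }
        return out;
    }

    public static void main(String[] args) {
        ClassLoader root = ClassLoader.getSystemClassLoader();

        // Hypothetical branch 1: a loader installed by the test runner.
        ClassLoader testBranch = new URLClassLoader(new URL[0], root);

        // Hypothetical branch 2: a REPL loader with an executor-style child.
        ClassLoader replBranch = new URLClassLoader(new URL[0], root);
        ClassLoader executorLike = new URLClassLoader(new URL[0], replBranch);

        List<ClassLoader> parents = chain(executorLike);
        System.out.println("sees replBranch: " + parents.contains(replBranch));
        System.out.println("sees testBranch: " + parents.contains(testBranch));
    }
}
```

Because parent delegation only walks upwards, a class defined by `testBranch` is invisible to `executorLike`, which matches the Maven situation described here.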

The SBT build uses the bootstrap class loader directly, and `ReplSuite.test("propagation of local properties")` is the first test in `ReplSuite`, which happens to load `io/netty/util/internal/__matchers__/org/apache/spark/network/protocol/MessageMatcher` into the bootstrap class loader. (Note: in the Maven build, it is loaded into the URLClassLoader, so it cannot be found through ExecutorClassLoader.) This issue can be reproduced in SBT as well. Here are the steps to reproduce:
- Enable `hadoop.caller.context.enabled`.
- Replace `Class.forName` with `Utils.classForName` in `object CallerContext`.
- Ignore `ReplSuite.test("propagation of local properties")`.
- Run `ReplSuite` using SBT.
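
The second step matters because the two lookups can consult different class loaders. The sketch below assumes that `Utils.classForName` resolves classes through the thread's context class loader; that is an assumption, since its implementation is not shown here. `RecordingLoader`, `viaContextLoader`, `viaCallerLoader`, and the class name `demo.NoSuchClass` are hypothetical names used only for illustration.

```java
import java.util.ArrayList;
import java.util.List;

class ContextLookupSketch {
    /** Records every class name it is asked for, then delegates as usual. */
    static final class RecordingLoader extends ClassLoader {
        final List<String> seen = new ArrayList<>();

        RecordingLoader(ClassLoader parent) {
            super(parent);
        }

        @Override
        protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
            seen.add(name);
            return super.loadClass(name, resolve);
        }
    }

    /** Hypothetical helper: resolve through the thread's context class loader. */
    static Class<?> viaContextLoader(String name) throws ClassNotFoundException {
        ClassLoader ctx = Thread.currentThread().getContextClassLoader();
        ClassLoader loader = (ctx != null) ? ctx : ContextLookupSketch.class.getClassLoader();
        return Class.forName(name, true, loader);
    }

    /** Plain lookup: resolves through the loader that defined this class. */
    static Class<?> viaCallerLoader(String name) throws ClassNotFoundException {
        return Class.forName(name);
    }

    /** Installs a recording context loader, runs one lookup, returns what it saw. */
    static List<String> namesSeen(String name, boolean useContextLoader) {
        Thread t = Thread.currentThread();
        ClassLoader original = t.getContextClassLoader();
        RecordingLoader recorder = new RecordingLoader(ContextLookupSketch.class.getClassLoader());
        t.setContextClassLoader(recorder);
        try {
            if (useContextLoader) {
                viaContextLoader(name);
            } else {
                viaCallerLoader(name);
            }
        } catch (ClassNotFoundException e) {
            // Expected for the deliberately absent name used in main.
        } finally {
            t.setContextClassLoader(original);
        }
        return recorder.seen;
    }

    public static void main(String[] args) {
        String missing = "demo.NoSuchClass"; // hypothetical, deliberately absent
        System.out.println("context lookup consulted: " + namesSeen(missing, true));
        System.out.println("caller lookup consulted:  " + namesSeen(missing, false));
    }
}
```

Under that assumption, switching to the context-loader lookup routes the request through whatever loader is installed on the thread, which is how an `ExecutorClassLoader`-style loader would become involved.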

This PR just creates a singleton `MessageEncoder` and a singleton `MessageDecoder` and makes sure they are created before switching to `ExecutorClassLoader`. `TransportContext` is created when `RpcEnv` is created, and that happens before `ExecutorClassLoader` is created.
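
The eager-singleton idea can be sketched in plain Java. This is an illustration of the pattern, not the patch itself: `Encoder`, `Context`, `CREATED`, and `createdUnder` are hypothetical names, and the anonymous `ClassLoader` merely stands in for a loader installed later.

```java
import java.util.concurrent.atomic.AtomicInteger;

class EagerSingletonSketch {
    /** Hypothetical stand-in for a stateless, shareable handler. */
    static final class Encoder {
        // Declared before INSTANCE so that it is initialized first.
        static final AtomicInteger CREATED = new AtomicInteger();
        static final Encoder INSTANCE = new Encoder();

        // Context class loader that was current when the instance was built.
        final ClassLoader createdUnder = Thread.currentThread().getContextClassLoader();

        private Encoder() {
            CREATED.incrementAndGet();
        }
    }

    /** Hypothetical stand-in for the object that wires up the pipeline. */
    static final class Context {
        // Referencing INSTANCE from a static field forces Encoder to be
        // initialized as soon as Context itself is initialized.
        private static final Encoder ENCODER = Encoder.INSTANCE;

        Encoder encoder() {
            return ENCODER;
        }
    }

    public static void main(String[] args) {
        Thread t = Thread.currentThread();
        ClassLoader original = t.getContextClassLoader();
        Context first = new Context(); // Encoder is created here, under `original`

        // Later, a different context class loader is installed.
        t.setContextClassLoader(new ClassLoader(original) {});
        try {
            Context second = new Context();
            System.out.println("same instance: " + (first.encoder() == second.encoder()));
            System.out.println("created early: " + (second.encoder().createdUnder == original));
            System.out.println("instances:     " + Encoder.CREATED.get());
        } finally {
            t.setContextClassLoader(original);
        }
    }
}
```

Sharing one instance this way is only safe when the handler keeps no per-connection state, which is the premise of the sketch.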

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16859 from zsxwing/SPARK-17714.
zsxwing committed Feb 13, 2017
1 parent 3dbff9b commit 905fdf0
Showing 6 changed files with 38 additions and 27 deletions.
@@ -62,8 +62,20 @@ public class TransportContext {
private final RpcHandler rpcHandler;
private final boolean closeIdleConnections;

private final MessageEncoder encoder;
private final MessageDecoder decoder;
/**
* Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created
* before switching the current context class loader to ExecutorClassLoader.
*
* Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the
* implementation calls "Class.forName" to check if this class is already generated. If the
* following two objects are created in "ExecutorClassLoader.findClass", it will cause
* "ClassCircularityError". This is because loading this Netty generated class will call
* "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to use
* RPC to load it and cause to load the non-exist matcher class again. JVM will report
* `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714)
*/
private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;

public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
this(conf, rpcHandler, false);
@@ -75,8 +87,6 @@ public TransportContext(
boolean closeIdleConnections) {
this.conf = conf;
this.rpcHandler = rpcHandler;
this.encoder = new MessageEncoder();
this.decoder = new MessageDecoder();
this.closeIdleConnections = closeIdleConnections;
}

@@ -135,9 +145,9 @@ public TransportChannelHandler initializePipeline(
try {
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", encoder)
.addLast("encoder", ENCODER)
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast("decoder", decoder)
.addLast("decoder", DECODER)
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
// would require more logic to guarantee if this were not part of the same event loop.
@@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {

private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);

public static final MessageDecoder INSTANCE = new MessageDecoder();

private MessageDecoder() {}

@Override
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
Message.Type msgType = Message.Type.decode(in);
@@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {

private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);

public static final MessageEncoder INSTANCE = new MessageEncoder();

private MessageEncoder() {}

/***
* Encodes a Message by invoking its encode() method. For non-data messages, we will add one
* ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
@@ -18,15 +18,14 @@
package org.apache.spark.network.server;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.RequestMessage;
import org.apache.spark.network.protocol.ResponseMessage;
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
@@ -48,7 +47,7 @@
* on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
* timeout if the client is continuously sending but getting no responses, for simplicity.
*/
public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);

private final TransportClient client;
@@ -114,11 +113,13 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
}

@Override
public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
if (request instanceof RequestMessage) {
requestHandler.handle((RequestMessage) request);
} else {
} else if (request instanceof ResponseMessage) {
responseHandler.handle((ResponseMessage) request);
} else {
ctx.fireChannelRead(request);
}
}

@@ -49,11 +49,11 @@
public class ProtocolSuite {
private void testServerToClient(Message msg) {
EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
new MessageEncoder());
MessageEncoder.INSTANCE);
serverChannel.writeOutbound(msg);

EmbeddedChannel clientChannel = new EmbeddedChannel(
NettyUtils.createFrameDecoder(), new MessageDecoder());
NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);

while (!serverChannel.outboundMessages().isEmpty()) {
clientChannel.writeInbound(serverChannel.readOutbound());
@@ -65,11 +65,11 @@ private void testServerToClient(Message msg) {

private void testClientToServer(Message msg) {
EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
new MessageEncoder());
MessageEncoder.INSTANCE);
clientChannel.writeOutbound(msg);

EmbeddedChannel serverChannel = new EmbeddedChannel(
NettyUtils.createFrameDecoder(), new MessageDecoder());
NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);

while (!clientChannel.outboundMessages().isEmpty()) {
serverChannel.writeInbound(clientChannel.readOutbound());
16 changes: 4 additions & 12 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2608,12 +2608,8 @@ private[util] object CallerContext extends Logging {
val callerContextSupported: Boolean = {
SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
try {
// `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
// master Maven build, so do not use it before resolving SPARK-17714.
// scalastyle:off classforname
Class.forName("org.apache.hadoop.ipc.CallerContext")
Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
// scalastyle:on classforname
Utils.classForName("org.apache.hadoop.ipc.CallerContext")
Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
true
} catch {
case _: ClassNotFoundException =>
@@ -2688,12 +2684,8 @@ private[spark] class CallerContext(
def setCurrentContext(): Unit = {
if (CallerContext.callerContextSupported) {
try {
// `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
// master Maven build, so do not use it before resolving SPARK-17714.
// scalastyle:off classforname
val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
// scalastyle:on classforname
val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
val hdfsContext = builder.getMethod("build").invoke(builderInst)
callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)