diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java index 32856da18b50..324b8ee70ed3 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java @@ -87,16 +87,20 @@ private void tryComplete(ChannelHandlerContext ctx) { return; } - ChannelPipeline p = ctx.pipeline(); - saslRpcClient.setupSaslHandler(p, HANDLER_NAME); - p.remove(SaslChallengeDecoder.class); - p.remove(this); + saslRpcClient.setupSaslHandler(ctx.pipeline(), HANDLER_NAME); + removeHandlers(ctx); setCryptoAESOption(); saslPromise.setSuccess(true); } + private void removeHandlers(ChannelHandlerContext ctx) { + ChannelPipeline p = ctx.pipeline(); + p.remove(SaslChallengeDecoder.class); + p.remove(this); + } + private void setCryptoAESOption() { boolean saslEncryptionEnabled = SaslUtil.QualityOfProtection.PRIVACY.getSaslQop() .equalsIgnoreCase(saslRpcClient.getSaslQOP()); @@ -142,6 +146,9 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Excep } else { saslPromise.tryFailure(new FallbackDisallowedException()); } + // When we switch to simple auth, we should also remove SaslChallengeDecoder and + // NettyHBaseSaslRpcClientHandler. + removeHandlers(ctx); return; } LOG.trace("Reading input token size={} for processing by initSASLContext", len); diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyHBaseSaslRpcServerHandler.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyHBaseSaslRpcServerHandler.java index 387318888a00..cb7a173625e1 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyHBaseSaslRpcServerHandler.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyHBaseSaslRpcServerHandler.java @@ -89,9 +89,11 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Excep boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop); ChannelPipeline p = ctx.pipeline(); if (useWrap) { - p.addBefore(DECODER_NAME, null, new SaslWrapHandler(saslServer::wrap)).addLast( - new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4), - new SaslUnwrapHandler(saslServer::unwrap)); + p.addBefore(DECODER_NAME, null, new SaslWrapHandler(saslServer::wrap)) + .addBefore(NettyRpcServerResponseEncoder.NAME, null, + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addBefore(NettyRpcServerResponseEncoder.NAME, null, + new SaslUnwrapHandler(saslServer::unwrap)); } conn.setupHandler(); p.remove(this); diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java index f3ead471fe61..bd024c90c3f4 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java @@ -132,7 +132,11 @@ protected void initChannel(Channel ch) throws Exception { initSSL(pipeline, conf.getBoolean(HBASE_SERVER_NETTY_TLS_SUPPORTPLAINTEXT, true)); } pipeline.addLast(NettyRpcServerPreambleHandler.DECODER_NAME, preambleDecoder) - .addLast(createNettyRpcServerPreambleHandler()); + .addLast(createNettyRpcServerPreambleHandler()) + // We need NettyRpcServerResponseEncoder here because NettyRpcServerPreambleHandler may + // send RpcResponse to client. + .addLast(NettyRpcServerResponseEncoder.NAME, + new NettyRpcServerResponseEncoder(metrics)); } }); try { diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerPreambleHandler.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerPreambleHandler.java index ca25dea17fe2..8269bbc60d88 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerPreambleHandler.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerPreambleHandler.java @@ -58,8 +58,9 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Excep LengthFieldBasedFrameDecoder decoder = new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4); decoder.setSingleDecode(true); - p.addLast(NettyHBaseSaslRpcServerHandler.DECODER_NAME, decoder); - p.addLast(new NettyHBaseSaslRpcServerHandler(rpcServer, conn)); + p.addBefore(NettyRpcServerResponseEncoder.NAME, NettyHBaseSaslRpcServerHandler.DECODER_NAME, + decoder).addBefore(NettyRpcServerResponseEncoder.NAME, null, + new NettyHBaseSaslRpcServerHandler(rpcServer, conn)); } else { conn.setupHandler(); } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerResponseEncoder.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerResponseEncoder.java index 30f8dba236a5..d3e338ffce02 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerResponseEncoder.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerResponseEncoder.java @@ -31,6 +31,8 @@ @InterfaceAudience.Private class NettyRpcServerResponseEncoder extends ChannelOutboundHandlerAdapter { + static final String NAME = "NettyRpcServerResponseEncoder"; + private final MetricsHBaseServer metrics; NettyRpcServerResponseEncoder(MetricsHBaseServer metrics) { diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java index 54c105802c55..f52357539dec 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java @@ -71,9 +71,10 @@ class NettyServerRpcConnection extends ServerRpcConnection { void setupHandler() { channel.pipeline() - .addLast("frameDecoder", new NettyRpcFrameDecoder(rpcServer.maxRequestSize, this)) - .addLast("decoder", new NettyRpcServerRequestDecoder(rpcServer.metrics, this)) - .addLast("encoder", new NettyRpcServerResponseEncoder(rpcServer.metrics)); + .addBefore(NettyRpcServerResponseEncoder.NAME, "frameDecoder", + new NettyRpcFrameDecoder(rpcServer.maxRequestSize, this)) + .addBefore(NettyRpcServerResponseEncoder.NAME, "decoder", + new NettyRpcServerRequestDecoder(rpcServer.metrics, this)); } void process(ByteBuf buf) throws IOException, InterruptedException { @@ -115,6 +116,6 @@ public NettyServerCall createCall(int id, final BlockingService service, @Override protected void doRespond(RpcResponse resp) { - channel.writeAndFlush(resp); + NettyFutureUtils.safeWriteAndFlush(channel, resp); } } diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcSkipInitialSaslHandshake.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcSkipInitialSaslHandshake.java new file mode 100644 index 000000000000..c47cceeb76a8 --- /dev/null +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcSkipInitialSaslHandshake.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +import static org.apache.hadoop.hbase.ipc.TestProtobufRpcServiceImpl.SERVICE; +import static org.apache.hadoop.hbase.ipc.TestProtobufRpcServiceImpl.newBlockingStub; +import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.getKeytabFileForTesting; +import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.getPrincipalForTesting; +import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.loginKerberosPrincipal; +import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.setSecuredConfiguration; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.InetSocketAddress; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseClassTestRule; +import org.apache.hadoop.hbase.HBaseTestingUtility; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.security.HBaseKerberosUtils; +import org.apache.hadoop.hbase.security.SecurityInfo; +import org.apache.hadoop.hbase.security.User; +import org.apache.hadoop.hbase.testclassification.MediumTests; +import org.apache.hadoop.hbase.testclassification.RPCTests; +import org.apache.hadoop.minikdc.MiniKdc; +import org.apache.hadoop.security.UserGroupInformation; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.mockito.Mockito; + +import org.apache.hbase.thirdparty.com.google.common.collect.Lists; +import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; +import org.apache.hbase.thirdparty.io.netty.channel.Channel; +import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; + +import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestProtos; +import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestRpcServiceProtos.TestProtobufRpcProto.BlockingInterface; + +@Category({ RPCTests.class, MediumTests.class }) +public class TestRpcSkipInitialSaslHandshake { + + @ClassRule + public static final HBaseClassTestRule CLASS_RULE = + HBaseClassTestRule.forClass(TestRpcSkipInitialSaslHandshake.class); + + protected static final HBaseTestingUtility TEST_UTIL = new HBaseTestingUtility(); + + protected static final File KEYTAB_FILE = + new File(TEST_UTIL.getDataTestDir("keytab").toUri().getPath()); + + protected static MiniKdc KDC; + protected static String HOST = "localhost"; + protected static String PRINCIPAL; + + protected String krbKeytab; + protected String krbPrincipal; + protected UserGroupInformation ugi; + protected Configuration clientConf; + protected Configuration serverConf; + + protected static void initKDCAndConf() throws Exception { + KDC = TEST_UTIL.setupMiniKdc(KEYTAB_FILE); + PRINCIPAL = "hbase/" + HOST; + KDC.createPrincipal(KEYTAB_FILE, PRINCIPAL); + HBaseKerberosUtils.setPrincipalForTesting(PRINCIPAL + "@" + KDC.getRealm()); + // set a smaller timeout and retry to speed up tests + TEST_UTIL.getConfiguration().setInt(RpcClient.SOCKET_TIMEOUT_READ, 2000000000); + TEST_UTIL.getConfiguration().setInt("hbase.security.relogin.maxretries", 1); + } + + protected static void stopKDC() throws InterruptedException { + if (KDC != null) { + KDC.stop(); + } + } + + protected final void setUpPrincipalAndConf() throws Exception { + krbKeytab = getKeytabFileForTesting(); + krbPrincipal = getPrincipalForTesting(); + ugi = loginKerberosPrincipal(krbKeytab, krbPrincipal); + clientConf = new Configuration(TEST_UTIL.getConfiguration()); + setSecuredConfiguration(clientConf); + clientConf.setBoolean(RpcClient.IPC_CLIENT_FALLBACK_TO_SIMPLE_AUTH_ALLOWED_KEY, true); + serverConf = new Configuration(TEST_UTIL.getConfiguration()); + } + + @BeforeClass + public static void setUp() throws Exception { + initKDCAndConf(); + } + + @AfterClass + public static void tearDown() throws Exception { + stopKDC(); + TEST_UTIL.cleanupTestDir(); + } + + @Before + public void setUpTest() throws Exception { + setUpPrincipalAndConf(); + } + + /** + * This test is for HBASE-27923,which NettyRpcServer may hange if it should skip initial sasl + * handshake. + */ + @Test + public void test() throws Exception { + SecurityInfo securityInfoMock = Mockito.mock(SecurityInfo.class); + Mockito.when(securityInfoMock.getServerPrincipal()) + .thenReturn(HBaseKerberosUtils.KRB_PRINCIPAL); + SecurityInfo.addInfo("TestProtobufRpcProto", securityInfoMock); + + final AtomicBoolean useSaslRef = new AtomicBoolean(false); + NettyRpcServer rpcServer = new NettyRpcServer(null, getClass().getSimpleName(), + Lists.newArrayList(new RpcServer.BlockingServiceAndInterface(SERVICE, null)), + new InetSocketAddress(HOST, 0), serverConf, new FifoRpcScheduler(serverConf, 1), true) { + + @Override + protected NettyRpcServerPreambleHandler createNettyRpcServerPreambleHandler() { + return new NettyRpcServerPreambleHandler(this) { + private NettyServerRpcConnection conn; + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + super.channelRead0(ctx, msg); + useSaslRef.set(conn.useSasl); + + } + + @Override + protected NettyServerRpcConnection createNettyServerRpcConnection(Channel channel) { + conn = super.createNettyServerRpcConnection(channel); + return conn; + } + }; + } + }; + + rpcServer.start(); + try (NettyRpcClient rpcClient = + new NettyRpcClient(clientConf, HConstants.DEFAULT_CLUSTER_ID.toString(), null, null)) { + BlockingInterface stub = newBlockingStub(rpcClient, rpcServer.getListenerAddress(), + User.create(UserGroupInformation.getCurrentUser())); + + String response = + stub.echo(null, TestProtos.EchoRequestProto.newBuilder().setMessage("test").build()) + .getMessage(); + assertTrue("test".equals(response)); + assertFalse(useSaslRef.get()); + + } finally { + rpcServer.stop(); + } + } +}