Skip to content

Commit

Permalink
[ISSUE #6957] Support Proxy Protocol for gRPC and Remoting Server (#6958
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dingshuangxi888 authored Jul 4, 2023
1 parent 9554282 commit 00fc42b
Show file tree
Hide file tree
Showing 15 changed files with 563 additions and 59 deletions.
1 change: 1 addition & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ maven_install(
"software.amazon.awssdk:s3:2.20.29",
"com.fasterxml.jackson.core:jackson-databind:2.13.4.2",
"com.adobe.testing:s3mock-junit4:2.11.0",
"io.github.aliyunmq:rocketmq-grpc-netty-codec-haproxy:1.0.0",
],
fetch_sources = True,
repositories = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.rocketmq.common.constant;

public class HAProxyConstants {

public static final String PROXY_PROTOCOL_PREFIX = "proxy_protocol_";
public static final String PROXY_PROTOCOL_ADDR = PROXY_PROTOCOL_PREFIX + "addr";
public static final String PROXY_PROTOCOL_PORT = PROXY_PROTOCOL_PREFIX + "port";
public static final String PROXY_PROTOCOL_SERVER_ADDR = PROXY_PROTOCOL_PREFIX + "server_addr";
public static final String PROXY_PROTOCOL_SERVER_PORT = PROXY_PROTOCOL_PREFIX + "server_port";
public static final String PROXY_PROTOCOL_TLV_PREFIX = PROXY_PROTOCOL_PREFIX + "tlv_0x";
}
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,11 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>io.github.aliyunmq</groupId>
<artifactId>rocketmq-grpc-netty-codec-haproxy</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>com.conversantmedia</groupId>
<artifactId>disruptor</artifactId>
Expand Down
2 changes: 2 additions & 0 deletions proxy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ java_library(
"@maven//:io_grpc_grpc_services",
"@maven//:io_grpc_grpc_stub",
"@maven//:io_netty_netty_all",
"@maven//:io_github_aliyunmq_rocketmq_grpc_netty_codec_haproxy",
"@maven//:io_openmessaging_storage_dledger",
"@maven//:io_opentelemetry_opentelemetry_api",
"@maven//:io_opentelemetry_opentelemetry_exporter_otlp",
Expand Down Expand Up @@ -94,6 +95,7 @@ java_library(
"@maven//:io_grpc_grpc_netty_shaded",
"@maven//:io_grpc_grpc_stub",
"@maven//:io_netty_netty_all",
"@maven//:io_github_aliyunmq_rocketmq_grpc_netty_codec_haproxy",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:io_opentelemetry_opentelemetry_exporter_otlp",
"@maven//:io_opentelemetry_opentelemetry_exporter_prometheus",
Expand Down
4 changes: 4 additions & 0 deletions proxy/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
</dependency>
<dependency>
<groupId>io.github.aliyunmq</groupId>
<artifactId>rocketmq-grpc-netty-codec-haproxy</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static GrpcServerBuilder newBuilder(ThreadPoolExecutor executor, int port
protected GrpcServerBuilder(ThreadPoolExecutor executor, int port) {
serverBuilder = NettyServerBuilder.forPort(port);

serverBuilder.protocolNegotiator(new OptionalSSLProtocolNegotiator());
serverBuilder.protocolNegotiator(new ProxyAndTlsProtocolNegotiator());

// build server
int bossLoopNum = ConfigurationManager.getProxyConfig().getGrpcBossLoopNum();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,61 @@
*/
package org.apache.rocketmq.proxy.grpc;

import io.grpc.Attributes;
import io.grpc.netty.shaded.io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.shaded.io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.netty.shaded.io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.netty.shaded.io.grpc.netty.ProtocolNegotiationEvent;
import io.grpc.netty.shaded.io.netty.buffer.ByteBuf;
import io.grpc.netty.shaded.io.netty.channel.ChannelHandler;
import io.grpc.netty.shaded.io.netty.channel.ChannelHandlerContext;
import io.grpc.netty.shaded.io.netty.channel.ChannelInboundHandlerAdapter;
import io.grpc.netty.shaded.io.netty.handler.codec.ByteToMessageDecoder;
import io.grpc.netty.shaded.io.netty.handler.codec.ProtocolDetectionResult;
import io.grpc.netty.shaded.io.netty.handler.codec.ProtocolDetectionState;
import io.grpc.netty.shaded.io.netty.handler.codec.haproxy.HAProxyMessage;
import io.grpc.netty.shaded.io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.grpc.netty.shaded.io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslHandler;
import io.grpc.netty.shaded.io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.grpc.netty.shaded.io.netty.handler.ssl.util.SelfSignedCertificate;
import io.grpc.netty.shaded.io.netty.util.AsciiString;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import io.grpc.netty.shaded.io.netty.util.CharsetUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.common.constant.HAProxyConstants;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.logging.org.slf4j.Logger;
import org.apache.rocketmq.logging.org.slf4j.LoggerFactory;
import org.apache.rocketmq.proxy.config.ConfigurationManager;
import org.apache.rocketmq.proxy.config.ProxyConfig;
import org.apache.rocketmq.proxy.grpc.constant.AttributeKeys;
import org.apache.rocketmq.remoting.common.TlsMode;
import org.apache.rocketmq.remoting.netty.TlsSystemConfig;

public class OptionalSSLProtocolNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;

public class ProxyAndTlsProtocolNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
protected static final Logger log = LoggerFactory.getLogger(LoggerName.PROXY_LOGGER_NAME);

private static final String HA_PROXY_DECODER = "HAProxyDecoder";
private static final String HA_PROXY_HANDLER = "HAProxyHandler";
private static final String TLS_MODE_HANDLER = "TlsModeHandler";
/**
* the length of the ssl record header (in bytes)
*/
private static final int SSL_RECORD_HEADER_LENGTH = 5;

private static SslContext sslContext;

public OptionalSSLProtocolNegotiator() {
public ProxyAndTlsProtocolNegotiator() {
sslContext = loadSslContext();
}

Expand All @@ -64,11 +81,12 @@ public AsciiString scheme() {

@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return new PortUnificationServerHandler(grpcHandler);
return new ProxyAndTlsProtocolHandler(grpcHandler);
}

@Override
public void close() {}
public void close() {
}

private static SslContext loadSslContext() {
try {
Expand All @@ -85,8 +103,8 @@ private static SslContext loadSslContext() {
String tlsCertPath = ConfigurationManager.getProxyConfig().getTlsCertPath();
try (InputStream serverKeyInputStream = Files.newInputStream(
Paths.get(tlsKeyPath));
InputStream serverCertificateStream = Files.newInputStream(
Paths.get(tlsCertPath))) {
InputStream serverCertificateStream = Files.newInputStream(
Paths.get(tlsCertPath))) {
SslContext res = GrpcSslContexts.forServer(serverCertificateStream,
serverKeyInputStream)
.trustManager(InsecureTrustManagerFactory.INSTANCE)
Expand All @@ -102,21 +120,103 @@ private static SslContext loadSslContext() {
}
}

public static class PortUnificationServerHandler extends ByteToMessageDecoder {
private static class ProxyAndTlsProtocolHandler extends ByteToMessageDecoder {

private final GrpcHttp2ConnectionHandler grpcHandler;

public ProxyAndTlsProtocolHandler(GrpcHttp2ConnectionHandler grpcHandler) {
this.grpcHandler = grpcHandler;
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
try {
ProtocolDetectionResult<HAProxyProtocolVersion> ha = HAProxyMessageDecoder.detectProtocol(
in);
if (ha.state() == ProtocolDetectionState.NEEDS_MORE_DATA) {
return;
}
if (ha.state() == ProtocolDetectionState.DETECTED) {
ctx.pipeline().addAfter(ctx.name(), HA_PROXY_DECODER, new HAProxyMessageDecoder())
.addAfter(HA_PROXY_DECODER, HA_PROXY_HANDLER, new HAProxyMessageHandler())
.addAfter(HA_PROXY_HANDLER, TLS_MODE_HANDLER, new TlsModeHandler(grpcHandler));
} else {
ctx.pipeline().addAfter(ctx.name(), TLS_MODE_HANDLER, new TlsModeHandler(grpcHandler));
}

ctx.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
ctx.pipeline().remove(this);
} catch (Exception e) {
log.error("process proxy protocol negotiator failed.", e);
throw e;
}
}
}

private static class HAProxyMessageHandler extends ChannelInboundHandlerAdapter {

private ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault();

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof HAProxyMessage) {
replaceEventWithMessage((HAProxyMessage) msg);
ctx.fireUserEventTriggered(pne);
} else {
super.channelRead(ctx, msg);
}
ctx.pipeline().remove(this);
}

/**
* The definition of key refers to the implementation of nginx
* <a href="https://nginx.org/en/docs/http/ngx_http_core_module.html#var_proxy_protocol_addr">ngx_http_core_module</a>
*
* @param msg
*/
private void replaceEventWithMessage(HAProxyMessage msg) {
Attributes.Builder builder = InternalProtocolNegotiationEvent.getAttributes(pne).toBuilder();
if (StringUtils.isNotBlank(msg.sourceAddress())) {
builder.set(AttributeKeys.PROXY_PROTOCOL_ADDR, msg.sourceAddress());
}
if (msg.sourcePort() > 0) {
builder.set(AttributeKeys.PROXY_PROTOCOL_PORT, String.valueOf(msg.sourcePort()));
}
if (StringUtils.isNotBlank(msg.destinationAddress())) {
builder.set(AttributeKeys.PROXY_PROTOCOL_SERVER_ADDR, msg.destinationAddress());
}
if (msg.destinationPort() > 0) {
builder.set(AttributeKeys.PROXY_PROTOCOL_SERVER_PORT, String.valueOf(msg.destinationPort()));
}
if (CollectionUtils.isNotEmpty(msg.tlvs())) {
msg.tlvs().forEach(tlv -> {
Attributes.Key<String> key = AttributeKeys.valueOf(
HAProxyConstants.PROXY_PROTOCOL_TLV_PREFIX + String.format("%02x", tlv.typeByteValue()));
String value = StringUtils.trim(tlv.content().toString(CharsetUtil.UTF_8));
builder.set(key, value);
});
}
pne = InternalProtocolNegotiationEvent
.withAttributes(InternalProtocolNegotiationEvent.getDefault(), builder.build());
}
}

private static class TlsModeHandler extends ByteToMessageDecoder {

private ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault();

private final ChannelHandler ssl;
private final ChannelHandler plaintext;

public PortUnificationServerHandler(GrpcHttp2ConnectionHandler grpcHandler) {
public TlsModeHandler(GrpcHttp2ConnectionHandler grpcHandler) {
this.ssl = InternalProtocolNegotiators.serverTls(sslContext)
.newHandler(grpcHandler);
this.plaintext = InternalProtocolNegotiators.serverPlaintext()
.newHandler(grpcHandler);
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
throws Exception {
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
try {
TlsMode tlsMode = TlsSystemConfig.tlsMode;
if (TlsMode.ENFORCING.equals(tlsMode)) {
Expand All @@ -134,12 +234,21 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
ctx.pipeline().addAfter(ctx.name(), null, this.plaintext);
}
}
ctx.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
ctx.fireUserEventTriggered(pne);
ctx.pipeline().remove(this);
} catch (Exception e) {
log.error("process ssl protocol negotiator failed.", e);
throw e;
}
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
pne = (ProtocolNegotiationEvent) evt;
} else {
super.userEventTriggered(ctx, evt);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.proxy.grpc.constant;

import io.grpc.Attributes;
import org.apache.rocketmq.common.constant.HAProxyConstants;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class AttributeKeys {

public static final Attributes.Key<String> PROXY_PROTOCOL_ADDR =
Attributes.Key.create(HAProxyConstants.PROXY_PROTOCOL_ADDR);

public static final Attributes.Key<String> PROXY_PROTOCOL_PORT =
Attributes.Key.create(HAProxyConstants.PROXY_PROTOCOL_PORT);

public static final Attributes.Key<String> PROXY_PROTOCOL_SERVER_ADDR =
Attributes.Key.create(HAProxyConstants.PROXY_PROTOCOL_SERVER_ADDR);

public static final Attributes.Key<String> PROXY_PROTOCOL_SERVER_PORT =
Attributes.Key.create(HAProxyConstants.PROXY_PROTOCOL_SERVER_PORT);

private static final Map<String, Attributes.Key<String>> ATTRIBUTES_KEY_MAP = new ConcurrentHashMap<>();

public static Attributes.Key<String> valueOf(String name) {
return ATTRIBUTES_KEY_MAP.computeIfAbsent(name, key -> Attributes.Key.create(name));
}
}
Loading

0 comments on commit 00fc42b

Please sign in to comment.