Skip to content

Commit

Permalink
RATIS-1542. Support TlsConf in netty streaming. (#619)
Browse files Browse the repository at this point in the history
  • Loading branch information
szetszwo committed Mar 16, 2022
1 parent 26582bf commit da5d868
Show file tree
Hide file tree
Showing 24 changed files with 1,130 additions and 78 deletions.
Expand Up @@ -118,8 +118,8 @@ public RaftClient build() {
}
}
return ClientImplUtils.newRaftClient(clientId, group, leaderId, primaryDataStreamServer,
Objects.requireNonNull(clientRpc, "The 'clientRpc' field is not initialized."),
properties, retryPolicy);
Objects.requireNonNull(clientRpc, "The 'clientRpc' field is not initialized."), retryPolicy,
properties, parameters);
}

/** Set {@link RaftClient} ID. */
Expand Down
Expand Up @@ -21,6 +21,7 @@
import org.apache.ratis.client.DataStreamClientRpc;
import org.apache.ratis.client.RaftClient;
import org.apache.ratis.client.RaftClientRpc;
import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.protocol.RaftGroup;
import org.apache.ratis.protocol.RaftGroupId;
Expand All @@ -31,11 +32,12 @@

/** Client utilities for internal use. */
public interface ClientImplUtils {
@SuppressWarnings("checkstyle:ParameterNumber")
static RaftClient newRaftClient(ClientId clientId, RaftGroup group,
RaftPeerId leaderId, RaftPeer primaryDataStreamServer, RaftClientRpc clientRpc, RaftProperties properties,
RetryPolicy retryPolicy) {
return new RaftClientImpl(clientId, group, leaderId, primaryDataStreamServer, clientRpc, properties,
retryPolicy);
RaftPeerId leaderId, RaftPeer primaryDataStreamServer, RaftClientRpc clientRpc, RetryPolicy retryPolicy,
RaftProperties properties, Parameters parameters) {
return new RaftClientImpl(clientId, group, leaderId, primaryDataStreamServer, clientRpc, retryPolicy,
properties, parameters);
}

static DataStreamClient newDataStreamClient(ClientId clientId, RaftGroupId groupId, RaftPeer primaryDataStreamServer,
Expand Down
Expand Up @@ -24,6 +24,7 @@
import org.apache.ratis.client.api.LeaderElectionManagementApi;
import org.apache.ratis.client.api.SnapshotManagementApi;
import org.apache.ratis.client.retry.ClientRetryEvent;
import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.proto.RaftProtos.SlidingWindowEntry;
import org.apache.ratis.protocol.ClientId;
Expand Down Expand Up @@ -146,8 +147,9 @@ void set(Collection<RaftPeer> newPeers) {
private final ConcurrentMap<RaftPeerId, LeaderElectionManagementApi>
leaderElectionManagement = new ConcurrentHashMap<>();

@SuppressWarnings("checkstyle:ParameterNumber")
RaftClientImpl(ClientId clientId, RaftGroup group, RaftPeerId leaderId, RaftPeer primaryDataStreamServer,
RaftClientRpc clientRpc, RaftProperties properties, RetryPolicy retryPolicy) {
RaftClientRpc clientRpc, RetryPolicy retryPolicy, RaftProperties properties, Parameters parameters) {
this.clientId = clientId;
this.peers.set(group.getPeers());
this.groupId = group.getGroupId();
Expand All @@ -173,6 +175,7 @@ void set(Collection<RaftPeer> newPeers) {
.setRaftGroupId(groupId)
.setDataStreamServer(primaryDataStreamServer)
.setProperties(properties)
.setParameters(parameters)
.build());
this.adminApi = JavaUtils.memoize(() -> new AdminImpl(this));
}
Expand Down
13 changes: 13 additions & 0 deletions ratis-common/src/main/java/org/apache/ratis/conf/ConfUtils.java
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.ratis.conf;

import org.apache.ratis.security.TlsConf;
import org.apache.ratis.thirdparty.com.google.common.base.Objects;
import org.apache.ratis.util.NetUtils;
import org.apache.ratis.util.SizeInBytes;
Expand All @@ -35,6 +36,7 @@
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;

public interface ConfUtils {
Logger LOG = LoggerFactory.getLogger(ConfUtils.class);
Expand Down Expand Up @@ -198,6 +200,12 @@ static TimeDuration getTimeDuration(
return value;
}

static TlsConf getTlsConf(
Function<String, TlsConf> tlsConfGetter,
String key, Consumer<String> logger) {
return get((k, d) -> tlsConfGetter.apply(k), key, null, logger);
}

@SafeVarargs
static <T> T get(BiFunction<String, T, T> getter,
String key, T defaultValue, Consumer<String> logger, BiConsumer<String, T>... assertions) {
Expand Down Expand Up @@ -271,6 +279,11 @@ static void setTimeDuration(
set(timeDurationSetter, key, value, assertions);
}

static void setTlsConf(
BiConsumer<String, TlsConf> tlsConfSetter, String key, TlsConf value) {
set(tlsConfSetter, key, value);
}

@SafeVarargs
static <T> void set(
BiConsumer<String, T> setter, String key, T value,
Expand Down
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ratis.security;

import org.apache.ratis.security.TlsConf.Builder;
import org.apache.ratis.security.TlsConf.CertificatesConf;
import org.apache.ratis.security.TlsConf.PrivateKeyConf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.net.URL;
import java.util.Optional;

public interface SecurityTestUtils {
Logger LOG = LoggerFactory.getLogger(SecurityTestUtils.class);

ClassLoader CLASS_LOADER = SecurityTestUtils.class.getClassLoader();

static File getResource(String name) {
final File file = Optional.ofNullable(CLASS_LOADER.getResource(name))
.map(URL::getFile)
.map(File::new)
.orElse(null);
LOG.info("Getting resource {}: {}", name, file);
return file;
}

static TlsConf newServerTlsConfig(boolean mutualAuthn) {
LOG.info("newServerTlsConfig: mutualAuthn? {}", mutualAuthn);
return new Builder()
.setName("server")
.setPrivateKey(new PrivateKeyConf(getResource("ssl/server.pem")))
.setKeyCertificates(new CertificatesConf(getResource("ssl/server.crt")))
.setTrustCertificates(new CertificatesConf(getResource("ssl/client.crt")))
.setMutualTls(mutualAuthn)
.build();
}

static TlsConf newClientTlsConfig(boolean mutualAuthn) {
LOG.info("newClientTlsConfig: mutualAuthn? {}", mutualAuthn);
return new Builder()
.setName("client")
.setPrivateKey(new PrivateKeyConf(getResource("ssl/client.pem")))
.setKeyCertificates(new CertificatesConf(getResource("ssl/client.crt")))
.setTrustCertificates(new CertificatesConf(getResource("ssl/ca.crt")))
.setMutualTls(mutualAuthn)
.build();
}
}
Expand Up @@ -41,7 +41,7 @@ public class MiniRaftClusterWithGrpc extends MiniRaftCluster.RpcBase {
@Override
public MiniRaftClusterWithGrpc newCluster(String[] ids, RaftProperties prop) {
RaftConfigKeys.Rpc.setType(prop, SupportedRpcType.GRPC);
return new MiniRaftClusterWithGrpc(ids, prop);
return new MiniRaftClusterWithGrpc(ids, prop, null);
}
};

Expand All @@ -55,8 +55,8 @@ default Factory<MiniRaftClusterWithGrpc> getFactory() {
public static final DelayLocalExecutionInjection sendServerRequestInjection =
new DelayLocalExecutionInjection(GrpcService.GRPC_SEND_SERVER_REQUEST);

protected MiniRaftClusterWithGrpc(String[] ids, RaftProperties properties) {
super(ids, properties, null);
protected MiniRaftClusterWithGrpc(String[] ids, RaftProperties properties, Parameters parameters) {
super(ids, properties, parameters);
}

@Override
Expand All @@ -66,7 +66,7 @@ protected Parameters setPropertiesAndInitParameters(RaftPeerId id, RaftGroup gro
GrpcConfigKeys.Client.setPort(properties, NetUtils.createSocketAddr(address).getPort()));
Optional.ofNullable(getAddress(id, group, RaftPeer::getAdminAddress)).ifPresent(address ->
GrpcConfigKeys.Admin.setPort(properties, NetUtils.createSocketAddr(address).getPort()));
return null;
return parameters;
}

@Override
Expand Down
Expand Up @@ -17,7 +17,10 @@
*/
package org.apache.ratis.netty;

import org.apache.ratis.conf.ConfUtils;
import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.security.TlsConf;
import org.apache.ratis.thirdparty.io.netty.util.NettyRuntime;
import org.apache.ratis.util.TimeDuration;
import org.slf4j.Logger;
Expand Down Expand Up @@ -52,7 +55,7 @@ static void setPort(RaftProperties properties, int port) {
}

interface DataStream {
Logger LOG = LoggerFactory.getLogger(Server.class);
Logger LOG = LoggerFactory.getLogger(DataStream.class);
static Consumer<String> getDefaultLog() {
return LOG::info;
}
Expand All @@ -71,44 +74,62 @@ static void setPort(RaftProperties properties, int port) {
setInt(properties::setInt, PORT_KEY, port);
}

String CLIENT_WORKER_GROUP_SIZE_KEY = PREFIX + ".client.worker-group.size";
int CLIENT_WORKER_GROUP_SIZE_DEFAULT = Math.max(1, NettyRuntime.availableProcessors() * 2);

static int clientWorkerGroupSize(RaftProperties properties) {
return getInt(properties::getInt, CLIENT_WORKER_GROUP_SIZE_KEY,
CLIENT_WORKER_GROUP_SIZE_DEFAULT, getDefaultLog(), requireMin(1), requireMax(65536));
}

static void setClientWorkerGroupSize(RaftProperties properties, int clientWorkerGroupSize) {
setInt(properties::setInt, CLIENT_WORKER_GROUP_SIZE_KEY, clientWorkerGroupSize);
}

String CLIENT_WORKER_GROUP_SHARE_KEY = PREFIX + ".client.worker-group.share";
boolean CLIENT_WORKER_GROUP_SHARE_DEFAULT = false;
interface Client {
String PREFIX = DataStream.PREFIX + ".client";

static boolean clientWorkerGroupShare(RaftProperties properties) {
return getBoolean(properties::getBoolean, CLIENT_WORKER_GROUP_SHARE_KEY,
CLIENT_WORKER_GROUP_SHARE_DEFAULT, getDefaultLog());
}

static void setClientWorkerGroupShare(RaftProperties properties, boolean clientWorkerGroupShare) {
setBoolean(properties::setBoolean, CLIENT_WORKER_GROUP_SHARE_KEY, clientWorkerGroupShare);
}
String TLS_CONF_PARAMETER = PREFIX + ".tls.conf";
Class<TlsConf> TLS_CONF_CLASS = TlsConf.class;
static TlsConf tlsConf(Parameters parameters) {
return getTlsConf(key -> parameters.get(key, TLS_CONF_CLASS), TLS_CONF_PARAMETER, getDefaultLog());
}
static void setTlsConf(Parameters parameters, TlsConf conf) {
LOG.info("setTlsConf " + conf);
ConfUtils.setTlsConf((key, value) -> parameters.put(key, value, TLS_CONF_CLASS), TLS_CONF_PARAMETER, conf);
}

String CLIENT_REPLY_QUEUE_GRACE_PERIOD_KEY = PREFIX + ".client.reply.queue.grace-period";
TimeDuration CLIENT_REPLY_QUEUE_GRACE_PERIOD_DEFAULT = TimeDuration.ONE_SECOND;
String WORKER_GROUP_SIZE_KEY = PREFIX + ".worker-group.size";
int WORKER_GROUP_SIZE_DEFAULT = Math.max(1, NettyRuntime.availableProcessors() * 2);
static int workerGroupSize(RaftProperties properties) {
return getInt(properties::getInt, WORKER_GROUP_SIZE_KEY,
WORKER_GROUP_SIZE_DEFAULT, getDefaultLog(), requireMin(1), requireMax(65536));
}
static void setWorkerGroupSize(RaftProperties properties, int clientWorkerGroupSize) {
setInt(properties::setInt, WORKER_GROUP_SIZE_KEY, clientWorkerGroupSize);
}

static TimeDuration clientReplyQueueGracePeriod(RaftProperties properties) {
return getTimeDuration(properties.getTimeDuration(CLIENT_REPLY_QUEUE_GRACE_PERIOD_DEFAULT.getUnit()),
CLIENT_REPLY_QUEUE_GRACE_PERIOD_KEY, CLIENT_REPLY_QUEUE_GRACE_PERIOD_DEFAULT, getDefaultLog());
}
String WORKER_GROUP_SHARE_KEY = PREFIX + ".worker-group.share";
boolean WORKER_GROUP_SHARE_DEFAULT = false;
static boolean workerGroupShare(RaftProperties properties) {
return getBoolean(properties::getBoolean, WORKER_GROUP_SHARE_KEY,
WORKER_GROUP_SHARE_DEFAULT, getDefaultLog());
}
static void setWorkerGroupShare(RaftProperties properties, boolean clientWorkerGroupShare) {
setBoolean(properties::setBoolean, WORKER_GROUP_SHARE_KEY, clientWorkerGroupShare);
}

static void setClientReplyQueueGracePeriod(RaftProperties properties, TimeDuration timeoutDuration) {
setTimeDuration(properties::setTimeDuration, CLIENT_REPLY_QUEUE_GRACE_PERIOD_KEY, timeoutDuration);
String REPLY_QUEUE_GRACE_PERIOD_KEY = PREFIX + ".reply.queue.grace-period";
TimeDuration REPLY_QUEUE_GRACE_PERIOD_DEFAULT = TimeDuration.ONE_SECOND;
static TimeDuration replyQueueGracePeriod(RaftProperties properties) {
return getTimeDuration(properties.getTimeDuration(REPLY_QUEUE_GRACE_PERIOD_DEFAULT.getUnit()),
REPLY_QUEUE_GRACE_PERIOD_KEY, REPLY_QUEUE_GRACE_PERIOD_DEFAULT, getDefaultLog());
}
static void setReplyQueueGracePeriod(RaftProperties properties, TimeDuration timeoutDuration) {
setTimeDuration(properties::setTimeDuration, REPLY_QUEUE_GRACE_PERIOD_KEY, timeoutDuration);
}
}

interface Server {
String PREFIX = NettyConfigKeys.PREFIX + ".server";
String PREFIX = DataStream.PREFIX + ".server";

String TLS_CONF_PARAMETER = PREFIX + ".tls.conf";
Class<TlsConf> TLS_CONF_CLASS = TlsConf.class;
static TlsConf tlsConf(Parameters parameters) {
return getTlsConf(key -> parameters.get(key, TLS_CONF_CLASS), TLS_CONF_PARAMETER, getDefaultLog());
}
static void setTlsConf(Parameters parameters, TlsConf conf) {
LOG.info("setTlsConf " + conf);
ConfUtils.setTlsConf((key, value) -> parameters.put(key, value, TLS_CONF_CLASS), TLS_CONF_PARAMETER, conf);
}

String USE_EPOLL_KEY = PREFIX + ".use-epoll";
boolean USE_EPOLL_DEFAULT = false;
Expand Down
Expand Up @@ -17,29 +17,25 @@
*/
package org.apache.ratis.netty;

import org.apache.ratis.client.DataStreamClientRpc;
import org.apache.ratis.client.DataStreamClientFactory;
import org.apache.ratis.client.DataStreamClientRpc;
import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.datastream.SupportedDataStreamType;
import org.apache.ratis.netty.client.NettyClientStreamRpc;
import org.apache.ratis.netty.server.NettyServerStreamRpc;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.security.TlsConf;
import org.apache.ratis.server.DataStreamServerRpc;
import org.apache.ratis.server.DataStreamServerFactory;
import org.apache.ratis.server.DataStreamServerRpc;
import org.apache.ratis.server.RaftServer;

import java.util.Optional;

public class NettyDataStreamFactory implements DataStreamServerFactory, DataStreamClientFactory {
private final TlsConf tlsConf;
private final Parameters parameters;

public NettyDataStreamFactory(Parameters parameters) {
//TODO: RATIS-1542: get TlsConf from parameters and add tls tests
this((TlsConf) null);
}

private NettyDataStreamFactory(TlsConf tlsConf) {
this.tlsConf = tlsConf;
this.parameters = Optional.ofNullable(parameters).orElseGet(Parameters::new);
}

@Override
Expand All @@ -49,11 +45,11 @@ public SupportedDataStreamType getDataStreamType() {

@Override
public DataStreamClientRpc newDataStreamClientRpc(RaftPeer server, RaftProperties properties) {
return new NettyClientStreamRpc(server, tlsConf, properties);
return new NettyClientStreamRpc(server, NettyConfigKeys.DataStream.Client.tlsConf(parameters), properties);
}

@Override
public DataStreamServerRpc newDataStreamServerRpc(RaftServer server) {
return new NettyServerStreamRpc(server, tlsConf);
return new NettyServerStreamRpc(server, NettyConfigKeys.DataStream.Server.tlsConf(parameters));
}
}
13 changes: 8 additions & 5 deletions ratis-netty/src/main/java/org/apache/ratis/netty/NettyUtils.java
Expand Up @@ -96,7 +96,7 @@ static SslContextBuilder initSslContextBuilderForServer(TlsConf tlsConf) {
}

static SslContext buildSslContextForServer(TlsConf tlsConf) {
return buildSslContext(tlsConf, true, NettyUtils::initSslContextBuilderForServer);
return buildSslContext("server", tlsConf, NettyUtils::initSslContextBuilderForServer);
}

static SslContextBuilder initSslContextBuilderForClient(TlsConf tlsConf) {
Expand All @@ -109,18 +109,21 @@ static SslContextBuilder initSslContextBuilderForClient(TlsConf tlsConf) {
}

static SslContext buildSslContextForClient(TlsConf tlsConf) {
return buildSslContext(tlsConf, false, NettyUtils::initSslContextBuilderForClient);
return buildSslContext("client", tlsConf, NettyUtils::initSslContextBuilderForClient);
}

static SslContext buildSslContext(TlsConf tlsConf, boolean isServer, Function<TlsConf, SslContextBuilder> builder) {
static SslContext buildSslContext(String name, TlsConf tlsConf, Function<TlsConf, SslContextBuilder> builder) {
if (tlsConf == null) {
return null;
}
final SslContext sslContext;
try {
return builder.apply(tlsConf).build();
sslContext = builder.apply(tlsConf).build();
} catch (Exception e) {
final String message = "Failed to build a " + (isServer ? "server" : "client") + " SslContext from " + tlsConf;
final String message = "Failed to buildSslContext for " + name + " from " + tlsConf;
throw new IllegalArgumentException(message, e);
}
LOG.debug("buildSslContext for {} from {} returns {}", name, tlsConf, sslContext.getClass().getName());
return sslContext;
}
}

0 comments on commit da5d868

Please sign in to comment.