Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ public Connection createConnection(URI remoteURI, String username, String passwo
}

public Connection createConnection(URI remoteURI, String username, String password, String clientId, boolean syncPublish) throws JMSException {
ConnectionFactory factory = createConnectionFactory(remoteURI, username, password, syncPublish);
return createConnection(remoteURI, username, password, clientId, syncPublish, null);
}

public Connection createConnection(URI remoteURI, String username, String password, String clientId, boolean syncPublish,
String sslProtocol) throws JMSException {
ConnectionFactory factory = createConnectionFactory(remoteURI, username, password, syncPublish, sslProtocol);

Connection connection = factory.createConnection();
connection.setExceptionListener(new ExceptionListener() {
Expand Down Expand Up @@ -166,7 +171,12 @@ private TopicConnectionFactory createTopicConnectionFactory(
}

private ConnectionFactory createConnectionFactory(
URI remoteURI, String username, String password, boolean syncPublish) {
URI remoteURI, String username, String password, boolean syncPublish) {
return createConnectionFactory(remoteURI, username, password, syncPublish, null);
}

private ConnectionFactory createConnectionFactory(
URI remoteURI, String username, String password, boolean syncPublish, String sslProtocol) {

String clientScheme;
boolean useSSL = false;
Expand Down Expand Up @@ -204,6 +214,9 @@ private ConnectionFactory createConnectionFactory(

if (useSSL) {
amqpURI += "?transport.verifyHost=false";
if (sslProtocol != null) {
amqpURI += "&transport.enabledProtocols=" + sslProtocol;
}
}

LOG.debug("In createConnectionFactory using URI: {}", amqpURI);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
*/
package org.apache.activemq.transport.amqp;

import jakarta.jms.Connection;
import jakarta.jms.JMSException;
import java.net.URI;

import java.net.URISyntaxException;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -32,6 +36,13 @@ protected URI getBrokerURI() {
return amqpNioPlusSslURI;
}

protected Connection createConnection(String clientId, boolean syncPublish) throws JMSException {
Connection connection = JMSClientContext.INSTANCE.createConnection(getBrokerURI(), "admin", "password", clientId, syncPublish,
enabledProtocols);
connection.start();
return connection;
}

@Override
protected boolean isUseTcpConnector() {
return false;
Expand All @@ -46,4 +57,15 @@ protected boolean isUseNioPlusSslConnector() {
protected String getTargetConnectorName() {
return "amqp+nio+ssl";
}

@Test(timeout=30000)
public void testSslHandshakeRenegotiationTlsv12() throws Exception {
testSslHandshakeRenegotiation("TLSv1.2");
}

@Test(timeout=30000)
public void testSslHandshakeRenegotiationTlsv13() throws Exception {
testSslHandshakeRenegotiation("TLSv1.3");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,29 @@
*/
package org.apache.activemq.transport.amqp;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import io.netty.channel.Channel;
import io.netty.handler.ssl.SslHandler;
import jakarta.jms.Message;
import jakarta.jms.MessageConsumer;
import jakarta.jms.MessageProducer;
import jakarta.jms.Queue;
import jakarta.jms.Session;
import jakarta.jms.TextMessage;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.URI;

import org.apache.activemq.broker.TransportConnector;
import org.apache.activemq.transport.amqp.joram.ActiveMQAdmin;
import org.apache.activemq.util.NioSslTestUtil;
import org.apache.qpid.jms.JmsConnection;
import org.apache.qpid.jms.provider.amqp.AmqpProvider;
import org.apache.qpid.jms.transports.netty.NettyTcpTransport;
import org.objectweb.jtests.jms.framework.TestConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -27,6 +48,16 @@
public class JMSClientSslTest extends JMSClientTest {
protected static final Logger LOG = LoggerFactory.getLogger(JMSClientSslTest.class);

protected String enabledProtocols = null;


@Override
public void setUp() throws Exception {
enabledProtocols = null;
super.setUp();
}


@Override
protected URI getBrokerURI() {
return amqpSslURI;
Expand All @@ -46,4 +77,65 @@ protected boolean isUseSslConnector() {
protected String getTargetConnectorName() {
return "amqp+ssl";
}

protected void testSslHandshakeRenegotiation(String protocol) throws Exception {
enabledProtocols = protocol;

ActiveMQAdmin.enableJMSFrameTracing();

connection = createConnection();

JmsConnection jmsCon = (JmsConnection) connection;
NettyTcpTransport transport = getNettyTransport(jmsCon);
Channel channel = getNettyChannel(transport);
SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
assertEquals(protocol, sslHandler.engine().getSession().getProtocol());

// trigger handshakes
for (int i = 0; i < 10; i++) {
sslHandler.engine().beginHandshake();
}

// give some time for the handshake updates
Thread.sleep(100);

// check status advances if NIOSSL, then continue
// below to verify transports are not stuck
checkHandshakeStatusAdvances(((InetSocketAddress)channel.localAddress()).getPort());

// Make sure messages still work
Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE);
Queue queue = session.createQueue(getDestinationName());
MessageProducer p = session.createProducer(queue);

TextMessage message = session.createTextMessage();
message.setText("hello");
p.send(message);

MessageConsumer consumer = session.createConsumer(queue);
Message msg = consumer.receive(100);
assertNotNull(msg);
assertTrue(msg instanceof TextMessage);

}

// This only applies to NIO SSL
protected void checkHandshakeStatusAdvances(int localPort) throws Exception {
TransportConnector connector = brokerService.getTransportConnectorByScheme(getBrokerURI().getScheme());
NioSslTestUtil.checkHandshakeStatusAdvances(connector, localPort);
}

private NettyTcpTransport getNettyTransport(JmsConnection jmsCon) throws Exception {
Field providerField = JmsConnection.class.getDeclaredField("provider");
providerField.setAccessible(true);
AmqpProvider provider = (AmqpProvider) providerField.get(jmsCon);
return (NettyTcpTransport) provider.getTransport();
}

private Channel getNettyChannel(NettyTcpTransport transport) throws Exception {
Field channelField = NettyTcpTransport.class.getDeclaredField("channel");
channelField.setAccessible(true);
return (Channel) channelField.get(transport);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
*/
package org.apache.activemq.transport.amqp.auto;

import jakarta.jms.Connection;
import jakarta.jms.JMSException;
import java.net.URI;

import org.apache.activemq.transport.amqp.JMSClientContext;
import org.apache.activemq.transport.amqp.JMSClientSslTest;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -33,6 +37,13 @@ protected URI getBrokerURI() {
return autoNioPlusSslURI;
}

protected Connection createConnection(String clientId, boolean syncPublish) throws JMSException {
Connection connection = JMSClientContext.INSTANCE.createConnection(getBrokerURI(), "admin", "password", clientId, syncPublish,
enabledProtocols);
connection.start();
return connection;
}

@Override
protected boolean isUseTcpConnector() {
return false;
Expand All @@ -47,4 +58,15 @@ protected boolean isUseAutoNioPlusSslConnector() {
protected String getTargetConnectorName() {
return "auto+nio+ssl";
}

@Test(timeout=30000)
public void testSslHandshakeRenegotiationTlsv12() throws Exception {
testSslHandshakeRenegotiation("TLSv1.2");
}

@Test(timeout=30000)
public void testSslHandshakeRenegotiationTlsv13() throws Exception {
testSslHandshakeRenegotiation("TLSv1.3");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,24 @@ public void serviceRead() {
if (!plain.hasRemaining()) {
int readCount = secureRead(plain);

if (readCount == 0) {
/*
* 1) If data is read, continue below to the processCommand() call
* and handle processing the data in the buffer. This takes priority
* and some handshake status updates (like NEED_WRAP) can be handled
* concurrently with application data (like TLSv1.3 key updates)
* when the broker sends data to a client.
*
* 2) If no data is read, it's possible that the connection is waiting
* for us to process a handshake update (either KeyUpdate for
* TLS1.3 or renegotiation for TLSv1.2) so we need to check and process
* any handshake updates. If the handshake status was updated,
* we want to continue and loop again to recheck if we can now read new
* application data into the buffer after processing the updates.
*
* 3) If no data is read, and no handshake update is needed, then we
* are finished and can break.
*/
if (readCount == 0 && !handleHandshakeUpdate()) {
break;
}

Expand All @@ -184,7 +201,11 @@ public void serviceRead() {
receiveCounter.addAndGet(readCount);
}

if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
// Try and process commands if there is any data in plain if status is OK
// Handshake renegotiation can happen concurrently with application data reads
// so it's possible to have read data that needs processing even if the
// handshake status indicates NEED_UNWRAP
if (status == SSLEngineResult.Status.OK && plain.hasRemaining()) {
processCommand(plain);
//we have received enough bytes to detect the protocol
if (receiveCounter.get() >= 8) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.util;

import static org.junit.Assert.assertTrue;

import java.lang.reflect.Field;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLSocket;
import org.apache.activemq.broker.TransportConnection;
import org.apache.activemq.broker.TransportConnector;
import org.apache.activemq.transport.nio.NIOSSLTransport;

public class NioSslTestUtil {

public static void checkHandshakeStatusAdvances(TransportConnector connector, SSLSocket socket) throws Exception {
checkHandshakeStatusAdvances(connector, socket.getLocalPort(), 10000, 10);
}

public static void checkHandshakeStatusAdvances(TransportConnector connector, int localPort) throws Exception {
checkHandshakeStatusAdvances(connector, localPort, 10000, 10);
}

public static void checkHandshakeStatusAdvances(TransportConnector connector, int localPort,
long duration, long sleepMillis) throws Exception {

TransportConnection con = connector.getConnections().stream()
.filter(tc -> tc.getRemoteAddress().contains(
Integer.toString(localPort))).findFirst().orElseThrow();

Field field = NIOSSLTransport.class.getDeclaredField("handshakeStatus");
field.setAccessible(true);
NIOSSLTransport t = con.getTransport().narrow(NIOSSLTransport.class);
// If this is the NIOSSLTransport then verify we exit NEED_WRAP and NEED_TASK
if (t != null) {
assertTrue(Wait.waitFor(() -> {
SSLEngineResult.HandshakeStatus status = (SSLEngineResult.HandshakeStatus) field.get(t);
return status != HandshakeStatus.NEED_WRAP && status != HandshakeStatus.NEED_TASK;
}, duration, sleepMillis));
}
}
}
Loading