From a4ac78be86cea58455c2441710a6bbe19d0a38e1 Mon Sep 17 00:00:00 2001 From: "Christopher L. Shannon (cshannon)" Date: Mon, 18 May 2015 15:01:44 +0000 Subject: [PATCH] Adding maxFrameSize to the Stomp Protocol to be consistent with Openwire. This applies to https://issues.apache.org/jira/browse/AMQ-5776 --- .../activemq/transport/stomp/StompCodec.java | 15 +- .../transport/stomp/StompTransportFilter.java | 8 + .../transport/stomp/StompWireFormat.java | 56 ++- .../stomp/StompMaxFrameSizeTest.java | 362 ++++++++++++++++++ 4 files changed, 422 insertions(+), 19 deletions(-) create mode 100644 activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompMaxFrameSizeTest.java diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java index 989b1d8b440..3581d3b1882 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; import org.apache.activemq.transport.tcp.TcpTransport; import org.apache.activemq.util.ByteArrayOutputStream; @@ -33,6 +34,7 @@ public class StompCodec { TcpTransport transport; StompWireFormat wireFormat; + AtomicLong frameSize = new AtomicLong(); ByteArrayOutputStream currentCommand = new ByteArrayOutputStream(); boolean processedHeaders = false; String action; @@ -71,12 +73,14 @@ public void parse(ByteArrayInputStream input, int readSize) throws Exception { // end of headers section, parse action and header if (b == '\n' && (previousByte == '\n' || currentCommand.endsWith(crlfcrlf))) { DataByteArrayInputStream data = new DataByteArrayInputStream(currentCommand.toByteArray()); - action = wireFormat.parseAction(data); - headers = wireFormat.parseHeaders(data); + try { + action = wireFormat.parseAction(data, frameSize); + headers = wireFormat.parseHeaders(data, frameSize); + String contentLengthHeader = headers.get(Stomp.Headers.CONTENT_LENGTH); if ((action.equals(Stomp.Commands.SEND) || action.equals(Stomp.Responses.MESSAGE)) && contentLengthHeader != null) { - contentLength = wireFormat.parseContentLength(contentLengthHeader); + contentLength = wireFormat.parseContentLength(contentLengthHeader, frameSize); } else { contentLength = -1; } @@ -100,6 +104,10 @@ public void parse(ByteArrayInputStream input, int readSize) throws Exception { transport.doConsume(new StompFrameError(new ProtocolException("The maximum data length was exceeded", true))); return; } + if (frameSize.incrementAndGet() > wireFormat.getMaxFrameSize()) { + transport.doConsume(new StompFrameError(new ProtocolException("The maximum frame size was exceeded", true))); + return; + } } } else { // read desired content length @@ -123,6 +131,7 @@ protected void processCommand() throws Exception { awaitingCommandStart = true; currentCommand.reset(); contentLength = -1; + frameSize.set(0); } public static String detectVersion(Map headers) throws ProtocolException { diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java index 9cf003ed275..87774dbc7bc 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java @@ -168,4 +168,12 @@ public void setMaxDataLength(int maxDataLength) { public int getMaxDataLength() { return wireFormat.getMaxDataLength(); } + + public void setMaxFrameSize(int maxFrameSize) { + wireFormat.setMaxFrameSize(maxFrameSize); + } + + public long getMaxFrameSize() { + return wireFormat.getMaxFrameSize(); + } } diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java index 1a95443131a..25ba91b3649 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java @@ -25,6 +25,7 @@ import java.io.PushbackInputStream; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; import org.apache.activemq.util.ByteArrayInputStream; import org.apache.activemq.util.ByteArrayOutputStream; @@ -44,10 +45,15 @@ public class StompWireFormat implements WireFormat { private static final int MAX_HEADER_LENGTH = 1024 * 10; private static final int MAX_HEADERS = 1000; private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100; + public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE; private int version = 1; private int maxDataLength = MAX_DATA_LENGTH; + private long maxFrameSize = DEFAULT_MAX_FRAME_SIZE; private String stompVersion = Stomp.DEFAULT_VERSION; + + //The current frame size as it is unmarshalled from the stream + private AtomicLong frameSize = new AtomicLong(); @Override public ByteSequence marshal(Object command) throws IOException { @@ -98,12 +104,12 @@ public void marshal(Object command, DataOutput os) throws IOException { public Object unmarshal(DataInput in) throws IOException { try { - + // parse action - String action = parseAction(in); + String action = parseAction(in, frameSize); // Parse the headers - HashMap headers = parseHeaders(in); + HashMap headers = parseHeaders(in, frameSize); // Read in the data part. byte[] data = NO_DATA; @@ -111,7 +117,7 @@ public Object unmarshal(DataInput in) throws IOException { if ((action.equals(Stomp.Commands.SEND) || action.equals(Stomp.Responses.MESSAGE)) && contentLength != null) { // Bless the client, he's telling us how much data to read in. - int length = parseContentLength(contentLength); + int length = parseContentLength(contentLength, frameSize); data = new byte[length]; in.readFully(data); @@ -125,14 +131,17 @@ public Object unmarshal(DataInput in) throws IOException { // We don't know how much to read.. data ends when we hit a 0 byte b; ByteArrayOutputStream baos = null; - while ((b = in.readByte()) != 0) { - + while ((b = in.readByte()) != 0) { if (baos == null) { baos = new ByteArrayOutputStream(); } else if (baos.size() > getMaxDataLength()) { throw new ProtocolException("The maximum data length was exceeded", true); + } else { + if (frameSize.incrementAndGet() > getMaxFrameSize()) { + throw new ProtocolException("The maximum frame size was exceeded", true); + } } - + baos.write(b); } @@ -146,6 +155,8 @@ public Object unmarshal(DataInput in) throws IOException { } catch (ProtocolException e) { return new StompFrameError(e); + } finally { + frameSize.set(0); } } @@ -178,9 +189,9 @@ private ByteSequence readHeaderLine(DataInput in, int maxLength, String errorMes return line; } - protected String parseAction(DataInput in) throws IOException { + protected String parseAction(DataInput in, AtomicLong frameSize) throws IOException { String action = null; - + // skip white space to next real action line while (true) { action = readLine(in, MAX_COMMAND_LENGTH, "The maximum command length was exceeded"); @@ -193,19 +204,20 @@ protected String parseAction(DataInput in) throws IOException { } } } - + frameSize.addAndGet(action.length()); return action; } - protected HashMap parseHeaders(DataInput in) throws IOException { - HashMap headers = new HashMap(25); + protected HashMap parseHeaders(DataInput in, AtomicLong frameSize) throws IOException { + HashMap headers = new HashMap(25); while (true) { ByteSequence line = readHeaderLine(in, MAX_HEADER_LENGTH, "The maximum header length was exceeded"); if (line != null && line.length > 1) { - + if (headers.size() > MAX_HEADERS) { throw new ProtocolException("The maximum number of headers was exceeded", true); } + frameSize.addAndGet(line.length); try { @@ -245,8 +257,8 @@ protected HashMap parseHeaders(DataInput in) throws IOException } return headers; } - - protected int parseContentLength(String contentLength) throws ProtocolException { + + protected int parseContentLength(String contentLength, AtomicLong frameSize) throws ProtocolException { int length; try { length = Integer.parseInt(contentLength.trim()); @@ -257,6 +269,10 @@ protected int parseContentLength(String contentLength) throws ProtocolException if (length > getMaxDataLength()) { throw new ProtocolException("The maximum data length was exceeded", true); } + + if (frameSize.addAndGet(length) > getMaxFrameSize()) { + throw new ProtocolException("The maximum frame size was exceeded", true); + } return length; } @@ -325,7 +341,7 @@ private String decodeHeader(InputStream header) throws IOException { return new String(decoded.toByteArray(), "UTF-8"); } - + @Override public int getVersion() { return version; @@ -351,4 +367,12 @@ public void setMaxDataLength(int maxDataLength) { public int getMaxDataLength() { return maxDataLength; } + + public long getMaxFrameSize() { + return maxFrameSize; + } + + public void setMaxFrameSize(long maxFrameSize) { + this.maxFrameSize = maxFrameSize; + } } diff --git a/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompMaxFrameSizeTest.java b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompMaxFrameSizeTest.java new file mode 100644 index 00000000000..ec367757769 --- /dev/null +++ b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompMaxFrameSizeTest.java @@ -0,0 +1,362 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.stomp; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.net.Socket; +import java.util.Arrays; +import java.util.Collection; + +import javax.net.SocketFactory; +import javax.net.ssl.SSLSocketFactory; + +import org.apache.activemq.broker.TransportConnector; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class StompMaxFrameSizeTest extends StompTestSupport { + + enum TestType {FRAME_MAX_GREATER_THAN_HEADER_MAX, FRAME_MAX_LESS_THAN_HEADER_MAX, FRAME_MAX_LESS_THAN_ACTION_MAX}; + + //set max data size higher than max frame size so that max frame size gets tested + private static final int MAX_DATA_SIZE = 100 * 1024; + private StompConnection connection; + private TestType testType; + private int maxFrameSize; + + /** + * This defines the different possible max header sizes for this test. + */ + @Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { + //The maximum size exceeds the default max header size of 10 * 1024 + {TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX, 64 * 1024}, + //The maximum size is less than the default max header size of 10 * 1024 + {TestType.FRAME_MAX_LESS_THAN_HEADER_MAX, 5 * 1024}, + //The maximum size is less than the default max action size of 1024 + {TestType.FRAME_MAX_LESS_THAN_ACTION_MAX, 512} + }); + } + + public StompMaxFrameSizeTest(TestType testType, int maxFrameSize) { + this.testType = testType; + this.maxFrameSize = maxFrameSize; + } + + @Override + public void setUp() throws Exception { + System.setProperty("javax.net.ssl.trustStore", "src/test/resources/client.keystore"); + System.setProperty("javax.net.ssl.trustStorePassword", "password"); + System.setProperty("javax.net.ssl.trustStoreType", "jks"); + System.setProperty("javax.net.ssl.keyStore", "src/test/resources/server.keystore"); + System.setProperty("javax.net.ssl.keyStorePassword", "password"); + System.setProperty("javax.net.ssl.keyStoreType", "jks"); + super.setUp(); + } + + @Override + public void tearDown() throws Exception { + if (connection != null) { + try { + connection.close(); + } catch (Throwable ex) {} + } + super.tearDown(); + } + + @Override + protected void addStompConnector() throws Exception { + TransportConnector connector = null; + + connector = brokerService.addConnector("stomp+ssl://0.0.0.0:"+ sslPort + + "?transport.maxDataLength=" + MAX_DATA_SIZE + "&transport.maxFrameSize=" + maxFrameSize); + sslPort = connector.getConnectUri().getPort(); + connector = brokerService.addConnector("stomp://0.0.0.0:" + port + + "?transport.maxDataLength=" + MAX_DATA_SIZE + "&transport.maxFrameSize=" + maxFrameSize); + port = connector.getConnectUri().getPort(); + connector = brokerService.addConnector("stomp+nio://0.0.0.0:" + nioPort + + "?transport.maxDataLength=" + MAX_DATA_SIZE + "&transport.maxFrameSize=" + maxFrameSize); + nioPort = connector.getConnectUri().getPort(); + connector = brokerService.addConnector("stomp+nio+ssl://0.0.0.0:" + nioSslPort + + "?transport.maxDataLength=" + MAX_DATA_SIZE + "&transport.maxFrameSize=" + maxFrameSize); + nioSslPort = connector.getConnectUri().getPort(); + } + + /** + * These tests should cause a Stomp error because the body size is greater than the + * max allowed frame size + */ + + @Test(timeout = 60000) + public void testOversizedBodyOnPlainSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doOversizedTestMessage(port, false, maxFrameSize + 100); + } + + @Test(timeout = 60000) + public void testOversizedBodyOnNioSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doOversizedTestMessage(nioPort, false, maxFrameSize + 100); + } + + @Test(timeout = 60000) + public void testOversizedBodyOnSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doOversizedTestMessage(sslPort, true, maxFrameSize + 100); + } + + @Test(timeout = 60000) + public void testOversizedBodyOnNioSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doOversizedTestMessage(nioSslPort, true, maxFrameSize + 100); + } + + + /** + * These tests should cause a Stomp error because even though the body size is less than max frame size, + * the action and headers plus data size should cause a max frame size failure + */ + @Test(timeout = 60000) + public void testOversizedTotalFrameOnPlainSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doOversizedTestMessage(port, false, maxFrameSize - 50); + } + + @Test(timeout = 60000) + public void testOversizedTotalFrameOnNioSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doOversizedTestMessage(nioPort, false, maxFrameSize - 50); + } + + @Test(timeout = 60000) + public void testOversizedTotalFrameOnSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doOversizedTestMessage(sslPort, true, maxFrameSize - 50); + } + + @Test(timeout = 60000) + public void testOversizedTotalFrameOnNioSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doOversizedTestMessage(nioSslPort, true, maxFrameSize - 50); + } + + + /** + * These tests will test a successful Stomp message when the total size is than max frame size + */ + @Test(timeout = 60000) + public void testUndersizedTotalFrameOnPlainSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doUndersizedTestMessage(port, false); + } + + @Test(timeout = 60000) + public void testUndersizedTotalFrameOnNioSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doUndersizedTestMessage(nioPort, false); + } + + @Test(timeout = 60000) + public void testUndersizedTotalFrameOnSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doUndersizedTestMessage(sslPort, true); + } + + @Test(timeout = 60000) + public void testUndersizedTotalFrameOnNioSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_GREATER_THAN_HEADER_MAX); + doUndersizedTestMessage(nioSslPort, true); + } + + /** + * These tests test that a Stomp error occurs if the action size exceeds maxFrameSize + * when the maxFrameSize length is less than the default max action length + */ + + @Test(timeout = 60000) + public void testOversizedActionOnPlainSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_LESS_THAN_ACTION_MAX); + doTestOversizedAction(port, false); + } + + @Test(timeout = 60000) + public void testOversizedActionOnNioSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_LESS_THAN_ACTION_MAX); + doTestOversizedAction(nioPort, false); + } + + @Test(timeout = 60000) + public void testOversizedActionOnSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_LESS_THAN_ACTION_MAX); + doTestOversizedAction(sslPort, true); + } + + @Test(timeout = 60000) + public void testOversizedActionOnNioSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_LESS_THAN_ACTION_MAX); + doTestOversizedAction(nioSslPort, true); + } + + + /** + * These tests will test that a Stomp error occurs if the header size exceeds maxFrameSize + * when the maxFrameSize length is less than the default max header length + */ + @Test(timeout = 60000) + public void testOversizedHeadersOnPlainSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_LESS_THAN_HEADER_MAX); + doTestOversizedHeaders(port, false); + } + + @Test(timeout = 60000) + public void testOversizedHeadersOnNioSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_LESS_THAN_HEADER_MAX); + doTestOversizedHeaders(nioPort, false); + } + + @Test(timeout = 60000) + public void testOversizedHeadersOnSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_LESS_THAN_HEADER_MAX); + doTestOversizedHeaders(sslPort, true); + } + + @Test(timeout = 60000) + public void testOversizedHeadersOnNioSslSocket() throws Exception { + Assume.assumeTrue(testType == TestType.FRAME_MAX_LESS_THAN_HEADER_MAX); + doTestOversizedHeaders(nioSslPort, true); + } + + + protected void doTestOversizedAction(int port, boolean useSsl) throws Exception { + initializeStomp(port, useSsl); + + char[] actionArray = new char[maxFrameSize + 100]; + Arrays.fill(actionArray, 'A'); + String action = new String(actionArray); + + String frame = action + "\n" + "destination:/queue/" + getQueueName() + "\n\n" + "body" + Stomp.NULL; + stompConnection.sendFrame(frame); + + StompFrame received = stompConnection.receive(500000); + assertNotNull(received); + assertEquals("ERROR", received.getAction()); + assertTrue(received.getBody().contains("maximum frame size")); + } + + protected void doTestOversizedHeaders(int port, boolean useSsl) throws Exception { + initializeStomp(port, useSsl); + + StringBuilder headers = new StringBuilder(maxFrameSize + 100); + int i = 0; + while (headers.length() < maxFrameSize + 1) { + headers.append("key" + i++ + ":value\n"); + } + + String frame = "SEND\n" + headers.toString() + "\n" + "destination:/queue/" + getQueueName() + + headers.toString() + "\n\n" + "body" + Stomp.NULL; + stompConnection.sendFrame(frame); + + StompFrame received = stompConnection.receive(5000); + assertNotNull(received); + assertEquals("ERROR", received.getAction()); + assertTrue(received.getBody().contains("maximum frame size")); + } + + protected void doOversizedTestMessage(int port, boolean useSsl, int dataSize) throws Exception { + initializeStomp(port, useSsl); + + int size = dataSize + 100; + char[] bigBodyArray = new char[size]; + Arrays.fill(bigBodyArray, 'a'); + String bigBody = new String(bigBodyArray); + + String frame = "SEND\n" + "destination:/queue/" + getQueueName() + "\n\n" + bigBody + Stomp.NULL; + + stompConnection.sendFrame(frame); + + StompFrame received = stompConnection.receive(5000); + assertNotNull(received); + assertEquals("ERROR", received.getAction()); + assertTrue(received.getBody().contains("maximum frame size")); + } + + protected void doUndersizedTestMessage(int port, boolean useSsl) throws Exception { + initializeStomp(port, useSsl); + + int size = 100; + char[] bigBodyArray = new char[size]; + Arrays.fill(bigBodyArray, 'a'); + String bigBody = new String(bigBodyArray); + + String frame = "SEND\n" + "destination:/queue/" + getQueueName() + "\n\n" + bigBody + Stomp.NULL; + + stompConnection.sendFrame(frame); + + StompFrame received = stompConnection.receive(); + assertNotNull(received); + assertEquals("MESSAGE", received.getAction()); + assertEquals(bigBody, received.getBody()); + } + + protected StompConnection stompConnect(int port, boolean ssl) throws Exception { + if (stompConnection == null) { + stompConnection = new StompConnection(); + } + + Socket socket = null; + if (ssl) { + socket = createSslSocket(port); + } else { + socket = createSocket(port); + } + + stompConnection.open(socket); + + return stompConnection; + } + + protected void initializeStomp(int port, boolean useSsl) throws Exception{ + stompConnect(port, useSsl); + + String frame = "CONNECT\n" + "login:system\n" + "passcode:manager\n\n" + Stomp.NULL; + stompConnection.sendFrame(frame); + + frame = stompConnection.receiveFrame(); + assertTrue(frame.startsWith("CONNECTED")); + + frame = "SUBSCRIBE\n" + "destination:/queue/" + getQueueName() + "\n" + "ack:auto\n\n" + Stomp.NULL; + stompConnection.sendFrame(frame); + } + + protected Socket createSocket(int port) throws IOException { + return new Socket("127.0.0.1", port); + } + + protected Socket createSslSocket(int port) throws IOException { + SocketFactory factory = SSLSocketFactory.getDefault(); + return factory.createSocket("127.0.0.1", port); + } +}