diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpClientTransport.java b/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpClientTransport.java index 2480daae225..3549664eaf5 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpClientTransport.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpClientTransport.java @@ -16,8 +16,8 @@ */ package org.apache.activemq.transport.http; -import java.io.DataInputStream; import java.io.IOException; +import java.io.InputStream; import java.io.InterruptedIOException; import java.net.URI; import java.security.cert.X509Certificate; @@ -26,7 +26,7 @@ import org.apache.activemq.command.ShutdownInfo; import org.apache.activemq.transport.FutureResponse; -import org.apache.activemq.transport.util.TextWireFormat; +import org.apache.activemq.transport.http.marshallers.HttpTransportMarshaller; import org.apache.activemq.util.ByteArrayOutputStream; import org.apache.activemq.util.IOExceptionSupport; import org.apache.activemq.util.IdGenerator; @@ -90,8 +90,8 @@ public class HttpClientTransport extends HttpTransportSupport { protected boolean canSendCompressed = false; private int minSendAsCompressedSize = 0; - public HttpClientTransport(TextWireFormat wireFormat, URI remoteUrl) { - super(wireFormat, remoteUrl); + public HttpClientTransport(final HttpTransportMarshaller marshaller, URI remoteUrl) { + super(marshaller, remoteUrl); } public FutureResponse asyncRequest(Object command) throws IOException { @@ -106,8 +106,7 @@ public void oneway(Object command) throws IOException { } HttpPost httpMethod = new HttpPost(getRemoteUrl().toString()); configureMethod(httpMethod); - String data = getTextWireFormat().marshalText(command); - byte[] bytes = data.getBytes("UTF-8"); + byte[] bytes = asBytes(command); if (useCompression && canSendCompressed && bytes.length > minSendAsCompressedSize) { ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); GZIPOutputStream stream = new GZIPOutputStream(bytesOut); @@ -147,17 +146,24 @@ public void oneway(Object command) throws IOException { } } + private byte[] asBytes(final Object command) throws IOException { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + getMarshaller().marshal(command, outputStream); + return outputStream.toByteArray(); + } + @Override public Object request(Object command) throws IOException { return null; } - private DataInputStream createDataInputStream(HttpResponse answer) throws IOException { - Header encoding = answer.getEntity().getContentEncoding(); - if (encoding != null && "gzip".equalsIgnoreCase(encoding.getValue())) { - return new DataInputStream(new GZIPInputStream(answer.getEntity().getContent())); + private InputStream createInputStream(final HttpResponse answer) throws IOException { + final InputStream inputStream = answer.getEntity().getContent(); + final Header encoding = answer.getEntity().getContentEncoding(); + if (encoding == null || !"gzip".equalsIgnoreCase(encoding.getValue())) { + return inputStream; } else { - return new DataInputStream(answer.getEntity().getContent()); + return new GZIPInputStream(inputStream); } } @@ -195,8 +201,8 @@ public void run() { } } else { receiveCounter++; - DataInputStream stream = createDataInputStream(answer); - Object command = getTextWireFormat().unmarshal(stream); + final InputStream stream = createInputStream(answer); + final Object command = getMarshaller().unmarshal(stream); if (command == null) { LOG.debug("Received null command from url: " + remoteUrl); } else { @@ -415,6 +421,6 @@ public void setPeerCertificates(X509Certificate[] certificates) { @Override public WireFormat getWireFormat() { - return getTextWireFormat(); + return getMarshaller().getWireFormat(); } } diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTransportFactory.java b/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTransportFactory.java index 02ecf771807..22f067c71f3 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTransportFactory.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTransportFactory.java @@ -27,8 +27,10 @@ import org.apache.activemq.transport.TransportFactory; import org.apache.activemq.transport.TransportLoggerFactory; import org.apache.activemq.transport.TransportServer; +import org.apache.activemq.transport.http.marshallers.HttpTransportMarshaller; +import org.apache.activemq.transport.http.marshallers.HttpWireFormatMarshaller; +import org.apache.activemq.transport.http.marshallers.TextWireFormatMarshallers; import org.apache.activemq.transport.util.TextWireFormat; -import org.apache.activemq.transport.xstream.XStreamWireFormat; import org.apache.activemq.util.IOExceptionSupport; import org.apache.activemq.util.IntrospectionSupport; import org.apache.activemq.util.URISupport; @@ -39,6 +41,16 @@ public class HttpTransportFactory extends TransportFactory { private static final Logger LOG = LoggerFactory.getLogger(HttpTransportFactory.class); + private static final String WIRE_FORMAT_XSTREAM = "xstream"; + private final String defaultWireFormatType; + + public HttpTransportFactory() { + defaultWireFormatType = WIRE_FORMAT_XSTREAM; + } + + public HttpTransportFactory(final String defaultWireFormatType) { + this.defaultWireFormatType = defaultWireFormatType; + } @Override public TransportServer doBind(URI location) throws IOException { @@ -57,22 +69,18 @@ public TransportServer doBind(URI location) throws IOException { } } - protected TextWireFormat asTextWireFormat(WireFormat wireFormat) { - if (wireFormat instanceof TextWireFormat) { - return (TextWireFormat)wireFormat; - } - LOG.trace("Not created with a TextWireFormat: {}", wireFormat); - return new XStreamWireFormat(); + protected WireFormat processWireFormat(final WireFormat wireFormat) { + return wireFormat; } @Override protected String getDefaultWireFormatType() { - return "xstream"; + return defaultWireFormatType; } @Override protected Transport createTransport(URI location, WireFormat wf) throws IOException { - TextWireFormat textWireFormat = asTextWireFormat(wf); + final WireFormat wireFormat = processWireFormat(wf); // need to remove options from uri URI uri; try { @@ -82,7 +90,14 @@ protected Transport createTransport(URI location, WireFormat wf) throws IOExcept cause.initCause(e); throw cause; } - return new HttpClientTransport(textWireFormat, uri); + return new HttpClientTransport(createMarshaller(wireFormat), uri); + } + + protected HttpTransportMarshaller createMarshaller(final WireFormat wireFormat) + { + return wireFormat instanceof TextWireFormat ? + TextWireFormatMarshallers.newTransportMarshaller((TextWireFormat)wireFormat) : + new HttpWireFormatMarshaller(wireFormat); } @Override diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTransportSupport.java b/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTransportSupport.java index d01ce25d2b9..97ee904a824 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTransportSupport.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTransportSupport.java @@ -19,7 +19,7 @@ import java.net.URI; import org.apache.activemq.transport.TransportThreadSupport; -import org.apache.activemq.transport.util.TextWireFormat; +import org.apache.activemq.transport.http.marshallers.HttpTransportMarshaller; /** * A useful base class for HTTP Transport implementations. @@ -27,15 +27,15 @@ * */ public abstract class HttpTransportSupport extends TransportThreadSupport { - private TextWireFormat textWireFormat; + private HttpTransportMarshaller marshaller; private URI remoteUrl; private String proxyHost; private int proxyPort = 8080; private String proxyUser; private String proxyPassword; - public HttpTransportSupport(TextWireFormat textWireFormat, URI remoteUrl) { - this.textWireFormat = textWireFormat; + public HttpTransportSupport(final HttpTransportMarshaller marshaller, final URI remoteUrl) { + this.marshaller = marshaller; this.remoteUrl = remoteUrl; } @@ -53,12 +53,8 @@ public URI getRemoteUrl() { return remoteUrl; } - public TextWireFormat getTextWireFormat() { - return textWireFormat; - } - - public void setTextWireFormat(TextWireFormat textWireFormat) { - this.textWireFormat = textWireFormat; + public HttpTransportMarshaller getMarshaller() { + return marshaller; } public String getProxyHost() { diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTunnelServlet.java b/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTunnelServlet.java index e6dc7c9d0e6..2143acbe0b4 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTunnelServlet.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/http/HttpTunnelServlet.java @@ -17,10 +17,8 @@ package org.apache.activemq.transport.http; import java.io.BufferedReader; -import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; -import java.io.InputStreamReader; import java.util.HashMap; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -38,10 +36,14 @@ import org.apache.activemq.command.WireFormatInfo; import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.TransportAcceptListener; +import org.apache.activemq.transport.http.marshallers.HttpTransportMarshaller; +import org.apache.activemq.transport.http.marshallers.HttpWireFormatMarshaller; +import org.apache.activemq.transport.http.marshallers.TextWireFormatMarshallers; import org.apache.activemq.transport.util.TextWireFormat; import org.apache.activemq.transport.xstream.XStreamWireFormat; import org.apache.activemq.util.IOExceptionSupport; import org.apache.activemq.util.ServiceListener; +import org.apache.activemq.wireformat.WireFormat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,7 +58,7 @@ public class HttpTunnelServlet extends HttpServlet { private TransportAcceptListener listener; private HttpTransportFactory transportFactory; - private TextWireFormat wireFormat; + private HttpTransportMarshaller marshaller; private ConcurrentMap clients = new ConcurrentHashMap(); private final long requestTimeout = 30000L; private HashMap transportOptions; @@ -74,10 +76,15 @@ public void init() throws ServletException { throw new ServletException("No such attribute 'transportFactory' available in the ServletContext"); } transportOptions = (HashMap)getServletContext().getAttribute("transportOptions"); - wireFormat = (TextWireFormat)getServletContext().getAttribute("wireFormat"); + WireFormat wireFormat = (WireFormat) getServletContext().getAttribute("wireFormat"); if (wireFormat == null) { wireFormat = createWireFormat(); } + if (wireFormat instanceof TextWireFormat) { + marshaller = TextWireFormatMarshallers.newServletMarshaller(wireFormat); + } else { + marshaller = new HttpWireFormatMarshaller(wireFormat); + } } @Override @@ -104,8 +111,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t packet = (Command)transportChannel.getQueue().poll(requestTimeout, TimeUnit.MILLISECONDS); - DataOutputStream stream = new DataOutputStream(response.getOutputStream()); - wireFormat.marshal(packet, stream); + marshaller.marshal(packet, response.getOutputStream()); count++; } catch (InterruptedException ignore) { } @@ -124,8 +130,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) stream = new GZIPInputStream(stream); } - // Read the command directly from the reader, assuming UTF8 encoding - Command command = (Command) wireFormat.unmarshalText(new InputStreamReader(stream, "UTF-8")); + final Command command = (Command)marshaller.unmarshal(stream); if (command instanceof WireFormatInfo) { WireFormatInfo info = (WireFormatInfo) command; diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpTextWireFormatMarshaller.java b/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpTextWireFormatMarshaller.java new file mode 100644 index 00000000000..a3dd25bb24e --- /dev/null +++ b/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpTextWireFormatMarshaller.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.marshallers; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import org.apache.activemq.transport.util.TextWireFormat; +import org.apache.activemq.wireformat.WireFormat; + +/** + * A {@link HttpTransportMarshaller} implementation using a {@link TextWireFormat} and UTF8 encoding. + */ +public class HttpTextWireFormatMarshaller implements HttpTransportMarshaller +{ + private static final Charset CHARSET = StandardCharsets.UTF_8; + private final TextWireFormat wireFormat; + + public HttpTextWireFormatMarshaller(final TextWireFormat wireFormat) { + this.wireFormat = wireFormat; + } + + @Override + public void marshal(final Object command, final OutputStream outputStream) throws IOException { + final String s = wireFormat.marshalText(command); + outputStream.write(s.getBytes(CHARSET)); + } + + @Override + public Object unmarshal(final InputStream stream) throws IOException { + return wireFormat.unmarshalText(new InputStreamReader(stream, CHARSET)); + } + + @Override + public WireFormat getWireFormat() { + return wireFormat; + } +} diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpTransportMarshaller.java b/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpTransportMarshaller.java new file mode 100644 index 00000000000..68c154efc0d --- /dev/null +++ b/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpTransportMarshaller.java @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.marshallers; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +import org.apache.activemq.wireformat.WireFormat; + +/** + * A generic interface for marshallers used for HTTP communication. + */ +public interface HttpTransportMarshaller +{ + /** + * The implementations of this method should be able to marshall the supplied object into the output stream. + * + * @param command the object to marshall + * @param outputStream output stream for the serialised form. + * @throws IOException + */ + void marshal(final Object command, final OutputStream outputStream) throws IOException; + + /** + * The implementations of this method handle unmarshalling of objects from a wire format into Java objects. + * + * @param stream the stream with the serialised form of an object + * @return the deserialised object + * @throws IOException + */ + Object unmarshal(final InputStream stream) throws IOException; + + /** + * + * @return the wire format used by this marshaller + */ + WireFormat getWireFormat(); +} diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpWireFormatMarshaller.java b/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpWireFormatMarshaller.java new file mode 100644 index 00000000000..660faadfafb --- /dev/null +++ b/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/HttpWireFormatMarshaller.java @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.marshallers; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +import org.apache.activemq.wireformat.WireFormat; + +public class HttpWireFormatMarshaller implements HttpTransportMarshaller +{ + private final WireFormat wireFormat; + + public HttpWireFormatMarshaller(final WireFormat wireFormat) { + this.wireFormat = wireFormat; + } + + @Override + public void marshal(final Object command, final OutputStream outputStream) throws IOException { + final DataOutputStream out = new DataOutputStream(outputStream); + wireFormat.marshal(command, out); + out.flush(); + } + + @Override + public Object unmarshal(final InputStream stream) throws IOException { + return wireFormat.unmarshal(new DataInputStream(stream)); + } + + @Override + public WireFormat getWireFormat() { + return wireFormat; + } +} diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/TextWireFormatMarshallers.java b/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/TextWireFormatMarshallers.java new file mode 100644 index 00000000000..f7796671002 --- /dev/null +++ b/activemq-http/src/main/java/org/apache/activemq/transport/http/marshallers/TextWireFormatMarshallers.java @@ -0,0 +1,103 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.marshallers; + +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.Reader; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import org.apache.activemq.transport.util.TextWireFormat; +import org.apache.activemq.wireformat.WireFormat; + +/** + * A factory for marshallers {@link HttpTransportMarshaller} that maintain compatibility with the original + * ActiveMQ code that used {@link TextWireFormat#marshalText(Object)} and {@link TextWireFormat#marshal(Object)} depending + * on the context. + * All text handling is done using UTF-8. + */ +public class TextWireFormatMarshallers { + private static final Charset CHARSET = StandardCharsets.UTF_8; + + /** + * The returned marshaller uses {@link TextWireFormat#marshal(Object)} and {@link TextWireFormat#unmarshalText(Reader)}. + */ + public static HttpTransportMarshaller newServletMarshaller(final WireFormat wireFormat) { + return new MarshalPlainUnmarshalTextMarshaller((TextWireFormat)wireFormat); + } + + /** + * The returned marshaller uses {@link TextWireFormat#marshalText(Object)} and {@link TextWireFormat#unmarshal(DataInput)} + */ + public static HttpTransportMarshaller newTransportMarshaller(final TextWireFormat textWireFormat) { + return new MarshalTextUnmarshalPlainMarshaller(textWireFormat); + } + + private static class MarshalTextUnmarshalPlainMarshaller implements HttpTransportMarshaller { + private final TextWireFormat wireFormat; + + private MarshalTextUnmarshalPlainMarshaller(final TextWireFormat wireFormat) { + this.wireFormat = wireFormat; + } + + @Override + public void marshal(final Object command, final OutputStream outputStream) throws IOException { + final String s = wireFormat.marshalText(command); + outputStream.write(s.getBytes(CHARSET)); + } + + @Override + public Object unmarshal(final InputStream stream) throws IOException { + return wireFormat.unmarshal(new DataInputStream(stream)); + } + + @Override + public WireFormat getWireFormat() { + return wireFormat; + } + } + + private static class MarshalPlainUnmarshalTextMarshaller implements HttpTransportMarshaller { + private final TextWireFormat wireFormat; + + private MarshalPlainUnmarshalTextMarshaller(final TextWireFormat wireFormat) { + this.wireFormat = wireFormat; + } + + @Override + public void marshal(final Object command, final OutputStream outputStream) throws IOException { + wireFormat.marshal(command, new DataOutputStream(outputStream)); + } + + @Override + public Object unmarshal(final InputStream stream) throws IOException { + return wireFormat.unmarshalText(new InputStreamReader(stream, CHARSET)); + } + + @Override + public WireFormat getWireFormat() { + return wireFormat; + } + } +} \ No newline at end of file diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/https/HttpsClientTransport.java b/activemq-http/src/main/java/org/apache/activemq/transport/https/HttpsClientTransport.java index 2e432fcaf14..dbedf6f3027 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/https/HttpsClientTransport.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/https/HttpsClientTransport.java @@ -22,7 +22,7 @@ import org.apache.activemq.broker.SslContext; import org.apache.activemq.transport.http.HttpClientTransport; -import org.apache.activemq.transport.util.TextWireFormat; +import org.apache.activemq.transport.http.marshallers.HttpTransportMarshaller; import org.apache.activemq.util.IOExceptionSupport; import org.apache.http.conn.ClientConnectionManager; import org.apache.http.conn.scheme.Scheme; @@ -32,8 +32,8 @@ public class HttpsClientTransport extends HttpClientTransport { - public HttpsClientTransport(TextWireFormat wireFormat, URI remoteUrl) { - super(wireFormat, remoteUrl); + public HttpsClientTransport(final HttpTransportMarshaller marshaller, URI remoteUrl) { + super(marshaller, remoteUrl); } @Override diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/https/HttpsTransportFactory.java b/activemq-http/src/main/java/org/apache/activemq/transport/https/HttpsTransportFactory.java index bf382aa74f3..0f64145979f 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/https/HttpsTransportFactory.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/https/HttpsTransportFactory.java @@ -37,6 +37,12 @@ */ public class HttpsTransportFactory extends HttpTransportFactory { + public HttpsTransportFactory() {} + + public HttpsTransportFactory(final String defaultWireFormatType) { + super(defaultWireFormatType); + } + public TransportServer doBind(String brokerId, URI location) throws IOException { return doBind(location); } @@ -67,6 +73,6 @@ protected Transport createTransport(URI location, WireFormat wf) throws Malforme cause.initCause(e); throw cause; } - return new HttpsClientTransport(asTextWireFormat(wf), uri); + return new HttpsClientTransport(createMarshaller(wf), uri); } } diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpClientTransportTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpClientTransportTest.java new file mode 100644 index 00000000000..49281b9fda1 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpClientTransportTest.java @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http; + +import java.io.ByteArrayInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.activemq.command.ConsumerInfo; +import org.apache.activemq.transport.http.marshallers.TextWireFormatMarshallers; +import org.apache.activemq.transport.xstream.XStreamWireFormat; +import org.apache.commons.io.IOUtils; +import org.apache.commons.io.output.ByteArrayOutputStream; +import org.apache.http.HttpResponse; +import org.apache.http.ProtocolVersion; +import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.entity.InputStreamEntity; +import org.apache.http.message.BasicHttpResponse; +import org.apache.http.message.BasicStatusLine; +import org.apache.http.params.BasicHttpParams; +import org.hamcrest.CoreMatchers; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.mockito.stubbing.Answer; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.when; + +public class HttpClientTransportTest { + + @Rule + public final MockitoRule rule = MockitoJUnit.rule(); + + @Mock + private HttpClient sendHttpClient; + + @Mock + private HttpClient receiveHttpClient; + + @Test + public void testPreservesAsymmetricalMarshalling() throws Exception { + final AtomicReference unmarshalledCommand = new AtomicReference<>(); + + final HttpClientTransport httpClientTransport = new HttpClientTransport(TextWireFormatMarshallers.newTransportMarshaller(new XStreamWireFormat()), URI.create("http://localhost")) { + @Override + public HttpClient getSendHttpClient() { + return sendHttpClient; + } + + @Override + public HttpClient getReceiveHttpClient() { + return receiveHttpClient; + } + + @Override + public void doConsume(final Object command) { + unmarshalledCommand.set(command); + try { + stop(); + } catch (Exception e) { + } + } + }; + + final AtomicReference marshalledCommand = new AtomicReference<>(); + + { + when(sendHttpClient.getParams()).thenReturn(new BasicHttpParams()); + when(sendHttpClient.execute(Mockito.any())).thenAnswer(new Answer() { + @Override + public HttpResponse answer(final InvocationOnMock invocation) throws Throwable { + final HttpPost method = invocation.getArgumentAt(0, HttpPost.class); + final String entityBody = IOUtils.toString(method.getEntity().getContent()); + marshalledCommand.set(entityBody); + return newHttpOkResponse(); + } + }); + + httpClientTransport.oneway(new ConsumerInfo()); + assertThat(marshalledCommand.get(), CoreMatchers.startsWith("<")); + } + + { + final BasicHttpResponse httpOkResponse = newHttpOkResponse(); + httpOkResponse.setEntity(new InputStreamEntity(new ByteArrayInputStream(toMarshalledMessage(marshalledCommand)))); + when(receiveHttpClient.execute(Mockito.any())).thenReturn(httpOkResponse); + httpClientTransport.run(); + assertThat(unmarshalledCommand.get(), CoreMatchers.instanceOf(ConsumerInfo.class)); + } + } + + private byte[] toMarshalledMessage(AtomicReference marshalledCommand) throws IOException { + final byte[] textBytes = marshalledCommand.get().getBytes(StandardCharsets.UTF_8); + + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream dataOutputStream = new DataOutputStream(baos); + dataOutputStream.writeInt(textBytes.length); + dataOutputStream.write(textBytes); + dataOutputStream.flush(); + return baos.toByteArray(); + } + + private BasicHttpResponse newHttpOkResponse() { + return new BasicHttpResponse(new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), HttpURLConnection.HTTP_OK, "OK")); + } +} \ No newline at end of file diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpOpenWireSendAndReceiveTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpOpenWireSendAndReceiveTest.java new file mode 100755 index 00000000000..11ee097925e --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpOpenWireSendAndReceiveTest.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.http; + +import org.apache.activemq.ActiveMQConnectionFactory; +import org.apache.activemq.broker.BrokerService; +import org.apache.activemq.test.JmsTopicSendReceiveWithTwoConnectionsTest; +import org.apache.activemq.transport.TransportFactory; +import org.apache.activemq.transport.http.openwire.AssertingTransportFactory; +import org.apache.activemq.transport.http.openwire.CustomHttpTransportFactory; +import org.apache.activemq.transport.http.openwire.SpyMarshaller; +import org.apache.activemq.transport.http.marshallers.HttpWireFormatMarshaller; + +import java.util.LinkedList; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; +import static org.junit.Assert.assertThat; + +/** + * + */ +public class HttpOpenWireSendAndReceiveTest extends JmsTopicSendReceiveWithTwoConnectionsTest { + private static final String CUSTOM_HTTP_PROTOCOL = "http"; + private static final String WIRE_FORMAT_OPENWIRE = "default"; + private final AssertingTransportFactory clientTransportFactory = new AssertingTransportFactory(WIRE_FORMAT_OPENWIRE, HttpWireFormatMarshaller.class); + + protected BrokerService broker; + + @Override + protected void setUp() throws Exception { + if (broker == null) { + broker = createBroker(); + broker.start(); + } + super.setUp(); + WaitForJettyListener.waitForJettySocketToAccept(getBrokerURL()); + } + + @Override + protected void tearDown() throws Exception { + super.tearDown(); + if (broker != null) { + broker.stop(); + } + } + + @Override + public void testSendReceive() throws Exception + { + super.testSendReceive(); + final LinkedList usedMarshallers = clientTransportFactory.getSpyMarshallers(); + assertThat(usedMarshallers.size(), equalTo(2)); + final SpyMarshaller marshaller1 = usedMarshallers.pop(); + final SpyMarshaller marshaller2 = usedMarshallers.pop(); + + assertThat(marshaller1.getMarshallCallsCnt(), not(equalTo(0))); + assertThat(marshaller1.getUnmarshallCallsCnt(), not(equalTo(0))); + + assertThat(marshaller1.getMarshallCallsCnt(), equalTo(marshaller2.getMarshallCallsCnt())); + assertThat(marshaller1.getUnmarshallCallsCnt(), equalTo(marshaller2.getUnmarshallCallsCnt())); + } + + protected String getBrokerURL() { + return "http://localhost:8161"; + } + + protected BrokerService createBroker() throws Exception { + final BrokerService broker = new BrokerService(); + broker.setPersistent(false); + addConnector(broker, getBrokerURL()); + return broker; + } + + @Override + protected ActiveMQConnectionFactory createConnectionFactory() { + TransportFactory.registerTransportFactory(CUSTOM_HTTP_PROTOCOL, clientTransportFactory); + return new ActiveMQConnectionFactory(getBrokerURL()); + } + + private static void addConnector(final BrokerService brokerService, final String brokerURL) throws Exception { + TransportFactory.registerTransportFactory(CUSTOM_HTTP_PROTOCOL, new CustomHttpTransportFactory()); + + brokerService.addConnector(brokerURL); + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpTunnelServletTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpTunnelServletTest.java new file mode 100644 index 00000000000..42d2913f122 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/http/HttpTunnelServletTest.java @@ -0,0 +1,156 @@ +package org.apache.activemq.transport.http; + +import java.io.IOException; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; +import javax.servlet.ReadListener; +import javax.servlet.ServletConfig; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.activemq.command.ConsumerInfo; +import org.apache.activemq.transport.TransportAcceptListener; +import org.hamcrest.CoreMatchers; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import static org.hamcrest.CoreMatchers.not; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HttpTunnelServletTest { + + @Rule + public final MockitoRule rule = MockitoJUnit.rule(); + + @Mock + private HttpServletRequest request; + @Mock + private HttpServletResponse response; + + private final MockServletOutputStream servletOutputStream = new MockServletOutputStream(); + + @Before + public void setup() throws IOException { + when(response.getOutputStream()).thenReturn(servletOutputStream); + } + + @Test + public void testPreservesAsymmetricalMarshalling() throws Exception { + final AtomicReference commandRef = new AtomicReference<>(); + + final BlockingQueueTransport transportChannel = newTransportChannel(commandRef); + + final HttpTunnelServlet httpTunnelServlet = newServlet(transportChannel); + + httpTunnelServlet.doGet(request, response); //marshall + + final String wireFormatMessage = servletOutputStream.getContent(); + assertThat(wireFormatMessage, not(CoreMatchers.startsWith("<"))); + + final String message = toTextMessage(wireFormatMessage); + when(request.getInputStream()).thenReturn(new MockServletInputStream(message)); + httpTunnelServlet.doPost(request, response); //unmarshallText + assertThat(commandRef.get(), CoreMatchers.instanceOf(ConsumerInfo.class)); + } + + private HttpTunnelServlet newServlet(final BlockingQueueTransport transportChannel) throws ServletException { + final HttpTunnelServlet httpTunnelServlet = new HttpTunnelServlet() { + @Override + protected BlockingQueueTransport getTransportChannel(final HttpServletRequest request, final HttpServletResponse response) { + return transportChannel; + } + }; + final ServletConfig servletConfig = mock(ServletConfig.class); + + final ServletContext servletContext = mockServletContext(); + when(servletConfig.getServletContext()).thenReturn(servletContext); + httpTunnelServlet.init(servletConfig); + return httpTunnelServlet; + } + + private BlockingQueueTransport newTransportChannel(final AtomicReference commandRef) { + final BlockingQueueTransport transportChannel = new BlockingQueueTransport(new ArrayBlockingQueue<>(10)) { + @Override + public void doConsume(final Object command) { + commandRef.set(command); + } + }; + transportChannel.getQueue().offer(new ConsumerInfo()); + return transportChannel; + } + + private ServletContext mockServletContext() { + final ServletContext servletContext = mock(ServletContext.class); + final TransportAcceptListener acceptListener = mock(TransportAcceptListener.class); + when(servletContext.getAttribute(eq("acceptListener"))).thenReturn(acceptListener); + when(servletContext.getAttribute(eq("transportFactory"))).thenReturn(new HttpTransportFactory()); + return servletContext; + } + + private String toTextMessage(final String message) { + return message.substring(message.indexOf('<')); + } + + private static class MockServletOutputStream extends ServletOutputStream { + private final StringBuilder sb = new StringBuilder(); + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setWriteListener(final WriteListener writeListener) { + } + + @Override + public void write(final int b) throws IOException { + sb.append((char)b); + } + + public String getContent() { + final String s = sb.toString(); + sb.setLength(0); + return s; + } + } + + private class MockServletInputStream extends ServletInputStream { + private final String string; + private int pos; + + private MockServletInputStream(final String message) { + string = message; + } + + @Override + public boolean isFinished() { + return pos==string.length(); + } + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setReadListener(final ReadListener readListener) { + } + + @Override + public int read() throws IOException { + return isFinished() ? -1 : string.charAt(pos++); + } + } +} \ No newline at end of file diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/AssertingTransportFactory.java b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/AssertingTransportFactory.java new file mode 100644 index 00000000000..dc1a691422c --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/AssertingTransportFactory.java @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.openwire; + +import org.apache.activemq.transport.http.HttpTransportFactory; +import org.apache.activemq.transport.http.marshallers.HttpTransportMarshaller; +import org.apache.activemq.wireformat.WireFormat; +import org.hamcrest.CoreMatchers; + +import java.util.LinkedList; + +import static org.junit.Assert.assertThat; + +/** + * Ensures that all transports created by this factory are of the expected type. + */ +public class AssertingTransportFactory extends HttpTransportFactory { + private final Class expectedMarshallerType; + private final LinkedList spyMarshallers = new LinkedList<>(); + + public AssertingTransportFactory(final String wireFormat, final Class expectedMarshallerType) { + super(wireFormat); + this.expectedMarshallerType = expectedMarshallerType; + } + + @Override + protected HttpTransportMarshaller createMarshaller(final WireFormat wireFormat) + { + final HttpTransportMarshaller marshaller = super.createMarshaller(wireFormat); + assertThat("Unexpected marshaller used", marshaller, CoreMatchers.instanceOf(expectedMarshallerType)); + final SpyMarshaller spyMarshaller = new SpyMarshaller(marshaller); + spyMarshallers.add(spyMarshaller); + return spyMarshaller; + } + + public LinkedList getSpyMarshallers() { + return spyMarshallers; + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/CustomHttpTransportFactory.java b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/CustomHttpTransportFactory.java new file mode 100644 index 00000000000..16258514e22 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/CustomHttpTransportFactory.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.openwire; + +import org.apache.activemq.transport.TransportServer; +import org.apache.activemq.transport.http.HttpTransportFactory; +import org.apache.activemq.transport.http.HttpTransportServer; +import org.apache.activemq.util.IOExceptionSupport; +import org.apache.activemq.util.IntrospectionSupport; +import org.apache.activemq.util.URISupport; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.Map; + +public class CustomHttpTransportFactory extends HttpTransportFactory +{ + @Override + public TransportServer doBind(final URI location) throws IOException { + try { + final Map options = new HashMap(URISupport.parseParameters(location)); + final HttpTransportServer result = new CustomHttpTransportServer(location, this); + final Map transportOptions = IntrospectionSupport.extractProperties(options, "transport."); + result.setTransportOption(transportOptions); + return result; + } catch (final URISyntaxException e) { + throw IOExceptionSupport.create(e); + } + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/CustomHttpTransportServer.java b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/CustomHttpTransportServer.java new file mode 100644 index 00000000000..3f27994e554 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/CustomHttpTransportServer.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.openwire; + +import org.apache.activemq.openwire.OpenWireFormatFactory; +import org.apache.activemq.transport.http.HttpTransportFactory; +import org.apache.activemq.transport.http.HttpTransportServer; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.util.component.AbstractLifeCycle.AbstractLifeCycleListener; +import org.eclipse.jetty.util.component.LifeCycle; + +import java.net.URI; + +public class CustomHttpTransportServer extends HttpTransportServer { + private final HttpTransportFactory transportFactory; + + public CustomHttpTransportServer(final URI location, final CustomHttpTransportFactory transportFactory) { + super(location, transportFactory); + this.transportFactory = transportFactory; + } + + @Override + protected void createServer() { + super.createServer(); + + server.addLifeCycleListener(new AbstractLifeCycleListener() + { + @Override + public void lifeCycleStarting(final LifeCycle event) + { + setupServletContext((ServletContextHandler)server.getHandler()); + } + }); + } + + private void setupServletContext(final ServletContextHandler handler) { + ServletContextAttributes.setAcceptListener(handler, getAcceptListener()); + ServletContextAttributes.setTransportOptions(handler, transportOptions); + ServletContextAttributes.setTransportFactory(handler, transportFactory); + ServletContextAttributes.setWireFormat(handler, new OpenWireFormatFactory().createWireFormat()); + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/ServletContextAttributes.java b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/ServletContextAttributes.java new file mode 100644 index 00000000000..22b09368659 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/ServletContextAttributes.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.openwire; + +import org.apache.activemq.transport.TransportAcceptListener; +import org.apache.activemq.transport.http.HttpTransportFactory; +import org.apache.activemq.wireformat.WireFormat; +import org.eclipse.jetty.servlet.ServletContextHandler; + +import java.util.Map; + +public final class ServletContextAttributes { + + private ServletContextAttributes() {} + + public static void setWireFormat(final ServletContextHandler servletContext, final WireFormat wireFormat) { + servletContext.setAttribute("wireFormat", wireFormat); + } + + public static void setTransportFactory(final ServletContextHandler servletContext, final HttpTransportFactory transportFactory) { + servletContext.setAttribute("transportFactory", transportFactory); + } + + public static void setTransportOptions(final ServletContextHandler servletContext, final Map transportOptions) { + servletContext.setAttribute("transportOptions", transportOptions); + } + + public static void setAcceptListener(final ServletContextHandler servletContext, final TransportAcceptListener acceptListener) { + servletContext.setAttribute("acceptListener", acceptListener); + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/SpyMarshaller.java b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/SpyMarshaller.java new file mode 100644 index 00000000000..ae9e97825c4 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/http/openwire/SpyMarshaller.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.http.openwire; + +import org.apache.activemq.transport.http.marshallers.HttpTransportMarshaller; +import org.apache.activemq.wireformat.WireFormat; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.concurrent.atomic.AtomicInteger; + +public class SpyMarshaller implements HttpTransportMarshaller { + private final HttpTransportMarshaller marshaller; + + private final AtomicInteger marshallCalls = new AtomicInteger(); + private final AtomicInteger unmarshallCalls = new AtomicInteger(); + + public SpyMarshaller(final HttpTransportMarshaller marshaller) { + this.marshaller = marshaller; + } + + @Override + public void marshal(final Object command, final OutputStream outputStream) throws IOException { + marshallCalls.incrementAndGet(); + marshaller.marshal(command, outputStream); + } + + @Override + public Object unmarshal(final InputStream stream) throws IOException { + unmarshallCalls.incrementAndGet(); + return marshaller.unmarshal(stream); + } + + public int getMarshallCallsCnt() + { + return marshallCalls.get(); + } + + public int getUnmarshallCallsCnt() + { + return unmarshallCalls.get(); + } + + @Override + public WireFormat getWireFormat() { + return marshaller.getWireFormat(); + } +}