diff --git a/impl-servlet/src/main/java/org/ocpsoft/rewrite/servlet/impl/DefaultRewriteLifecycleListener.java b/impl-servlet/src/main/java/org/ocpsoft/rewrite/servlet/impl/DefaultRewriteLifecycleListener.java index fad96ca18..50a25384a 100644 --- a/impl-servlet/src/main/java/org/ocpsoft/rewrite/servlet/impl/DefaultRewriteLifecycleListener.java +++ b/impl-servlet/src/main/java/org/ocpsoft/rewrite/servlet/impl/DefaultRewriteLifecycleListener.java @@ -1,11 +1,14 @@ package org.ocpsoft.rewrite.servlet.impl; +import javax.servlet.ServletRequest; + import org.ocpsoft.rewrite.event.Rewrite; import org.ocpsoft.rewrite.servlet.http.event.HttpServletRewrite; import org.ocpsoft.rewrite.servlet.spi.RewriteLifecycleListener; public class DefaultRewriteLifecycleListener implements RewriteLifecycleListener { + private static final String REQUEST_NESTING_KEY = DefaultRewriteLifecycleListener.class + "_request_nesting"; @Override public boolean handles(Rewrite payload) @@ -20,17 +23,38 @@ public int priority() } @Override - public void afterInboundLifecycle(HttpServletRewrite event) + public void beforeInboundRewrite(HttpServletRewrite event) { - HttpRewriteWrappedResponse.getInstance(event.getRequest()).flushBufferedStreams(); + incrementRequestNesting(event); } @Override - public void beforeInboundLifecycle(HttpServletRewrite event) - {} + public void afterInboundLifecycle(HttpServletRewrite event) + { + decrementRequestNesting(event); + if (getRequestNesting(event.getRequest()) == 0) + HttpRewriteWrappedResponse.getInstance(event.getRequest()).flushBufferedStreams(); + } + + private void decrementRequestNesting(HttpServletRewrite event) + { + if (getRequestNesting(event.getRequest()) > 0) + event.getRequest().setAttribute(REQUEST_NESTING_KEY, getRequestNesting(event.getRequest()) - 1); + } + + private void incrementRequestNesting(HttpServletRewrite event) + { + event.getRequest().setAttribute(REQUEST_NESTING_KEY, getRequestNesting(event.getRequest()) + 1); + } + + public static int getRequestNesting(ServletRequest event) + { + Integer nesting = (Integer) event.getAttribute(REQUEST_NESTING_KEY); + return nesting == null ? 0 : nesting; + } @Override - public void beforeInboundRewrite(HttpServletRewrite event) + public void beforeInboundLifecycle(HttpServletRewrite event) {} @Override diff --git a/impl-servlet/src/main/java/org/ocpsoft/rewrite/servlet/impl/HttpRewriteWrappedResponse.java b/impl-servlet/src/main/java/org/ocpsoft/rewrite/servlet/impl/HttpRewriteWrappedResponse.java index 26ed5c13c..4c6d82b1c 100644 --- a/impl-servlet/src/main/java/org/ocpsoft/rewrite/servlet/impl/HttpRewriteWrappedResponse.java +++ b/impl-servlet/src/main/java/org/ocpsoft/rewrite/servlet/impl/HttpRewriteWrappedResponse.java @@ -77,17 +77,33 @@ public HttpRewriteWrappedResponse(final HttpServletRequest request, final HttpSe private PrintWriter printWriter = new PrintWriter(new OutputStreamWriter(stream, Charset.forName(getCharacterEncoding())), true); private List bufferedStages = new ArrayList(); + private boolean buffersLocked = false; public boolean isBufferingActive() { return !bufferedStages.isEmpty(); } - public void addBufferStage(OutputBuffer stage) + public void addBufferStage(OutputBuffer stage) throws IllegalStateException { + if (areBuffersLocked()) + { + throw new IllegalStateException( + "Cannot add output buffers to Response once request processing has been passed to the application."); + } this.bufferedStages.add(stage); } + private boolean areBuffersLocked() + { + return buffersLocked; + } + + private void lockBuffers() + { + this.buffersLocked = true; + } + public void flushBufferedStreams() { if (isBufferingActive()) @@ -135,6 +151,7 @@ public PrintWriter getWriter() return printWriter; else try { + lockBuffers(); return super.getWriter(); } catch (IOException e) { @@ -149,6 +166,7 @@ public ServletOutputStream getOutputStream() return new ByteArrayServletOutputStream(stream); else try { + lockBuffers(); return super.getOutputStream(); } catch (IOException e) { @@ -159,6 +177,7 @@ public ServletOutputStream getOutputStream() @Override public void setContentLength(int contentLength) { + lockBuffers(); /* * Prevent content-length being set as the page might be modified. */ @@ -172,7 +191,10 @@ public void flushBuffer() throws IOException if (isBufferingActive()) stream.flush(); else + { + lockBuffers(); super.flushBuffer(); + } } private class ByteArrayServletOutputStream extends ServletOutputStream @@ -280,211 +302,185 @@ private void rewrite(final HttpOutboundServletRewrite event) @Override public void addCookie(Cookie cookie) { - // TODO Auto-generated method stub + lockBuffers(); super.addCookie(cookie); } @Override public boolean containsHeader(String name) { - // TODO Auto-generated method stub return super.containsHeader(name); } @Override public void sendError(int sc, String msg) throws IOException { - // TODO Auto-generated method stub + lockBuffers(); super.sendError(sc, msg); } @Override public void sendError(int sc) throws IOException { - // TODO Auto-generated method stub + lockBuffers(); super.sendError(sc); } @Override public void sendRedirect(String location) throws IOException { - // TODO Auto-generated method stub + lockBuffers(); super.sendRedirect(location); } @Override public void setDateHeader(String name, long date) { - // TODO Auto-generated method stub + lockBuffers(); super.setDateHeader(name, date); } @Override public void addDateHeader(String name, long date) { - // TODO Auto-generated method stub + lockBuffers(); super.addDateHeader(name, date); } @Override public void setHeader(String name, String value) { - // TODO Auto-generated method stub + lockBuffers(); super.setHeader(name, value); } @Override public void addHeader(String name, String value) { - // TODO Auto-generated method stub + lockBuffers(); super.addHeader(name, value); } @Override public void setIntHeader(String name, int value) { - // TODO Auto-generated method stub + lockBuffers(); super.setIntHeader(name, value); } @Override public void addIntHeader(String name, int value) { - // TODO Auto-generated method stub + lockBuffers(); super.addIntHeader(name, value); } @Override public void setStatus(int sc) { - // TODO Auto-generated method stub + lockBuffers(); super.setStatus(sc); } @Override + @SuppressWarnings("deprecation") public void setStatus(int sc, String sm) { - // TODO Auto-generated method stub + lockBuffers(); super.setStatus(sc, sm); } @Override public int getStatus() { - // TODO Auto-generated method stub return super.getStatus(); } @Override public String getHeader(String name) { - // TODO Auto-generated method stub return super.getHeader(name); } @Override public Collection getHeaders(String name) { - // TODO Auto-generated method stub return super.getHeaders(name); } @Override public Collection getHeaderNames() { - // TODO Auto-generated method stub return super.getHeaderNames(); } @Override public ServletResponse getResponse() { - // TODO Auto-generated method stub return super.getResponse(); } @Override public void setResponse(ServletResponse response) { - // TODO Auto-generated method stub super.setResponse(response); } @Override public void setCharacterEncoding(String charset) { - // TODO Auto-generated method stub + lockBuffers(); super.setCharacterEncoding(charset); } @Override public String getCharacterEncoding() { - // TODO Auto-generated method stub return super.getCharacterEncoding(); } @Override public void setBufferSize(int size) { - // TODO Auto-generated method stub super.setBufferSize(size); } @Override public int getBufferSize() { - // TODO Auto-generated method stub return super.getBufferSize(); } @Override public boolean isCommitted() { - // TODO Auto-generated method stub return super.isCommitted(); } @Override public void reset() { - // TODO Auto-generated method stub + stream.reset(); super.reset(); } @Override public void resetBuffer() { - // TODO Auto-generated method stub + stream.reset(); super.resetBuffer(); } @Override public void setLocale(Locale loc) { - // TODO Auto-generated method stub super.setLocale(loc); } @Override public Locale getLocale() { - // TODO Auto-generated method stub return super.getLocale(); } - @Override - public boolean isWrapperFor(ServletResponse wrapped) - { - // TODO Auto-generated method stub - return super.isWrapperFor(wrapped); - } - - @Override - public boolean isWrapperFor(Class wrappedType) - { - // TODO Auto-generated method stub - return super.isWrapperFor(wrappedType); - } - } diff --git a/impl-servlet/src/test/java/org/ocpsoft/rewrite/servlet/wrapper/BufferedResponseConfigurationProvider.java b/impl-servlet/src/test/java/org/ocpsoft/rewrite/servlet/wrapper/BufferedResponseConfigurationProvider.java index f57aa8915..77aa0b707 100644 --- a/impl-servlet/src/test/java/org/ocpsoft/rewrite/servlet/wrapper/BufferedResponseConfigurationProvider.java +++ b/impl-servlet/src/test/java/org/ocpsoft/rewrite/servlet/wrapper/BufferedResponseConfigurationProvider.java @@ -15,6 +15,8 @@ */ package org.ocpsoft.rewrite.servlet.wrapper; +import java.io.IOException; + import javax.servlet.ServletContext; import org.ocpsoft.rewrite.config.Configuration; @@ -24,6 +26,7 @@ import org.ocpsoft.rewrite.servlet.config.HttpOperation; import org.ocpsoft.rewrite.servlet.config.Path; import org.ocpsoft.rewrite.servlet.config.Response; +import org.ocpsoft.rewrite.servlet.config.SendStatus.SendError; import org.ocpsoft.rewrite.servlet.config.rule.Join; import org.ocpsoft.rewrite.servlet.http.event.HttpServletRewrite; import org.ocpsoft.rewrite.servlet.impl.HttpRewriteWrappedResponse; @@ -57,8 +60,8 @@ public Configuration getConfiguration(final ServletContext context) /* * Test unbuffered. Use a Join to perform a forward so we know buffering would have been activated. */ - .addRule(Join.path("/unbuffered").to("/other.html")) - .defineRule().when(Path.matches("/other.html")) + .addRule(Join.path("/unbuffered").to("/unbuffered.html")) + .defineRule().when(Path.matches("/unbuffered.html")) .perform(new HttpOperation() { @Override public void performHttp(HttpServletRewrite event, EvaluationContext context) @@ -72,6 +75,39 @@ public void performHttp(HttpServletRewrite event, EvaluationContext context) Response.setCode(201).perform(event, context); } } + }) + + /* + * Test buffer failure constraints. + */ + .addRule(Join.path("/bufferforward").to("/forward.html")) + .defineRule().when(Path.matches("/forward.html")) + .perform(new HttpOperation() { + @Override + public void performHttp(HttpServletRewrite event, EvaluationContext context) + { + Response.withOutputBufferedBy(new BufferedResponseToLowercase1()).perform(event, context); + Response.setCode(202).perform(event, context); + } + }) + + .defineRule().when(Path.matches("/bufferfail")) + .perform(new HttpOperation() { + @Override + public void performHttp(HttpServletRewrite event, EvaluationContext context) + { + try { + event.getResponse().getOutputStream(); // cause buffers to lock + Response.withOutputBufferedBy(new BufferedResponseToLowercase1()).perform(event, context); + } + catch (IllegalStateException e) { + SendError.code(503).perform(event, context); + } + catch (IOException e) + { + throw new RuntimeException(e); + } + } }); return config; diff --git a/impl-servlet/src/test/java/org/ocpsoft/rewrite/servlet/wrapper/BufferedResponseTest.java b/impl-servlet/src/test/java/org/ocpsoft/rewrite/servlet/wrapper/BufferedResponseTest.java index cdf5e0ba8..d62f4d95e 100644 --- a/impl-servlet/src/test/java/org/ocpsoft/rewrite/servlet/wrapper/BufferedResponseTest.java +++ b/impl-servlet/src/test/java/org/ocpsoft/rewrite/servlet/wrapper/BufferedResponseTest.java @@ -42,7 +42,8 @@ public static WebArchive getDeployment() .getDeployment() .addPackages(true, ServletRoot.class.getPackage()) .addAsWebResource(new StringAsset("UPPERCASE"), "index.html") - .addAsWebResource(new StringAsset("UPPERCASE"), "other.html") + .addAsWebResource(new StringAsset("UPPERCASE"), "unbuffered.html") + .addAsWebResource(new StringAsset("UPPERCASE"), "forward.html") .addAsServiceProvider(ConfigurationProvider.class, BufferedResponseConfigurationProvider.class); return deployment; } @@ -58,8 +59,23 @@ public void testResponseBufferingAppliesAllBuffers() throws Exception @Test public void testResponseBufferingOnlyAppliesWhenBuffersRegistered() throws Exception { - HttpAction action = get("/other.html"); + HttpAction action = get("/unbuffered"); Assert.assertEquals(201, action.getStatusCode()); Assert.assertEquals("UPPERCASE", action.getResponseContent()); } + + @Test + public void testResponseBufferingAcceptedAfterForward() throws Exception + { + HttpAction action = get("/bufferforward"); + Assert.assertEquals(202, action.getStatusCode()); + Assert.assertEquals("uppercase", action.getResponseContent()); + } + + @Test + public void testResponseBufferingRejectedAfterStreamAccessed() throws Exception + { + HttpAction action = get("/bufferfail"); + Assert.assertEquals(503, action.getStatusCode()); + } }