From 86ccc43940861703c2be96a5f35384407522125a Mon Sep 17 00:00:00 2001 From: Mark Thomas Date: Thu, 18 Feb 2021 16:41:57 +0000 Subject: [PATCH] Ensure ReadListener.onError() is fired if client drops the connection --- .../coyote/http11/Http11InputBuffer.java | 34 ++-- .../catalina/core/TestAsyncContextImpl.java | 170 +++++++++++++++- .../nonblocking/TestNonBlockingAPI.java | 192 ++++++++++++++++++ webapps/docs/changelog.xml | 6 + 4 files changed, 388 insertions(+), 14 deletions(-) diff --git a/java/org/apache/coyote/http11/Http11InputBuffer.java b/java/org/apache/coyote/http11/Http11InputBuffer.java index c12df8aff54..e3ace892372 100644 --- a/java/org/apache/coyote/http11/Http11InputBuffer.java +++ b/java/org/apache/coyote/http11/Http11InputBuffer.java @@ -761,11 +761,13 @@ void init(SocketWrapperBase socketWrapper) { private boolean fill(boolean block) throws IOException { if (log.isDebugEnabled()) { - log.debug("Before fill(): [" + parsingHeader + + log.debug("Before fill(): parsingHeader: [" + parsingHeader + "], parsingRequestLine: [" + parsingRequestLine + "], parsingRequestLinePhase: [" + parsingRequestLinePhase + "], parsingRequestLineStart: [" + parsingRequestLineStart + - "], byteBuffer.position() [" + byteBuffer.position() + "]"); + "], byteBuffer.position(): [" + byteBuffer.position() + + "], byteBuffer.limit(): [" + byteBuffer.limit() + + "], end: [" + end + "]"); } if (parsingHeader) { @@ -780,19 +782,25 @@ private boolean fill(boolean block) throws IOException { byteBuffer.limit(end).position(end); } - byteBuffer.mark(); - if (byteBuffer.position() < byteBuffer.limit()) { - byteBuffer.position(byteBuffer.limit()); - } - byteBuffer.limit(byteBuffer.capacity()); - SocketWrapperBase socketWrapper = this.wrapper; int nRead = -1; - if (socketWrapper != null) { - nRead = socketWrapper.read(block, byteBuffer); - } else { - throw new CloseNowException(sm.getString("iib.eof.error")); + byteBuffer.mark(); + try { + if (byteBuffer.position() < byteBuffer.limit()) { + byteBuffer.position(byteBuffer.limit()); + } + byteBuffer.limit(byteBuffer.capacity()); + SocketWrapperBase socketWrapper = this.wrapper; + if (socketWrapper != null) { + nRead = socketWrapper.read(block, byteBuffer); + } else { + throw new CloseNowException(sm.getString("iib.eof.error")); + } + } finally { + // Ensure that the buffer limit and position are returned to a + // consistent "ready for read" state if an error occurs during in + // the above code block. + byteBuffer.limit(byteBuffer.position()).reset(); } - byteBuffer.limit(byteBuffer.position()).reset(); if (log.isDebugEnabled()) { log.debug("Received [" diff --git a/test/org/apache/catalina/core/TestAsyncContextImpl.java b/test/org/apache/catalina/core/TestAsyncContextImpl.java index c8607e79586..e242917302e 100644 --- a/test/org/apache/catalina/core/TestAsyncContextImpl.java +++ b/test/org/apache/catalina/core/TestAsyncContextImpl.java @@ -17,6 +17,7 @@ package org.apache.catalina.core; import java.io.IOException; +import java.io.InputStream; import java.io.PrintWriter; import java.net.URI; import java.net.URISyntaxException; @@ -866,7 +867,7 @@ public void run() { } } - private static class TrackingListener implements AsyncListener { + public static class TrackingListener implements AsyncListener { private final boolean completeOnError; private final boolean completeOnTimeout; @@ -3016,4 +3017,171 @@ public void run() { } } + + /* + * Tests an error on an async thread when the client closes the connection + * before fully writing the request body. + * + * Required sequence is: + * - enter Servlet's service() method + * - startAsync() + * - start async thread + * - read partial body + * - close client connection + * - read on async thread -> I/O error + * - exit Servlet's service() method + * + * This test makes extensive use of instance fields in the Servlet that + * would normally be considered very poor practice. It is only safe in this + * test as the Servlet only processes a single request. + */ + @Test + public void testCanceledPost() throws Exception { + CountDownLatch partialReadLatch = new CountDownLatch(1); + CountDownLatch clientCloseLatch = new CountDownLatch(1); + CountDownLatch threadCompleteLatch = new CountDownLatch(1); + + AtomicBoolean testFailed = new AtomicBoolean(true); + + // Setup Tomcat instance + Tomcat tomcat = getTomcatInstance(); + + // No file system docBase required + Context ctx = tomcat.addContext("", null); + + PostServlet postServlet = new PostServlet(partialReadLatch, clientCloseLatch, threadCompleteLatch, testFailed); + Wrapper wrapper = Tomcat.addServlet(ctx, "postServlet", postServlet); + wrapper.setAsyncSupported(true); + ctx.addServletMappingDecoded("/*", "postServlet"); + + tomcat.start(); + + PostClient client = new PostClient(); + client.setPort(getPort()); + client.setRequest(new String[] { "POST / HTTP/1.1" + SimpleHttpClient.CRLF + + "Host: localhost:" + SimpleHttpClient.CRLF + + "Content-Length: 100" + SimpleHttpClient.CRLF + + SimpleHttpClient.CRLF + + "This is 16 bytes" + }); + client.connect(); + client.sendRequest(); + + // Wait server to read partial request body + partialReadLatch.await(); + + client.disconnect(); + + clientCloseLatch.countDown(); + + threadCompleteLatch.await(); + + Assert.assertFalse(testFailed.get()); + } + + + private static final class PostClient extends SimpleHttpClient { + + @Override + public boolean isResponseBodyOK() { + return true; + } + } + + + private static final class PostServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + + private final transient CountDownLatch partialReadLatch; + private final transient CountDownLatch clientCloseLatch; + private final transient CountDownLatch threadCompleteLatch; + private final AtomicBoolean testFailed; + + public PostServlet(CountDownLatch doPostLatch, CountDownLatch clientCloseLatch, + CountDownLatch threadCompleteLatch, AtomicBoolean testFailed) { + this.partialReadLatch = doPostLatch; + this.clientCloseLatch = clientCloseLatch; + this.threadCompleteLatch = threadCompleteLatch; + this.testFailed = testFailed; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + + AsyncContext ac = req.startAsync(); + Thread t = new PostServletThread(ac, partialReadLatch, clientCloseLatch, threadCompleteLatch, testFailed); + t.start(); + + try { + threadCompleteLatch.await(); + } catch (InterruptedException e) { + // Ignore + } + } + } + + + private static final class PostServletThread extends Thread { + + private final AsyncContext ac; + private final CountDownLatch partialReadLatch; + private final CountDownLatch clientCloseLatch; + private final CountDownLatch threadCompleteLatch; + private final AtomicBoolean testFailed; + + public PostServletThread(AsyncContext ac, CountDownLatch partialReadLatch, CountDownLatch clientCloseLatch, + CountDownLatch threadCompleteLatch, AtomicBoolean testFailed) { + this.ac = ac; + this.partialReadLatch = partialReadLatch; + this.clientCloseLatch = clientCloseLatch; + this.threadCompleteLatch = threadCompleteLatch; + this.testFailed = testFailed; + } + + @Override + public void run() { + try { + int bytesRead = 0; + byte[] buffer = new byte[32]; + InputStream is = null; + + try { + is = ac.getRequest().getInputStream(); + + // Read the partial request body + while (bytesRead < 16) { + int read = is.read(buffer); + if (read == -1) { + // Error condition + return; + } + bytesRead += read; + } + } catch (IOException ioe) { + // Error condition + return; + } finally { + partialReadLatch.countDown(); + } + + // Wait for client to close connection + clientCloseLatch.await(); + + // Read again + try { + is.read(); + } catch (IOException e) { + e.printStackTrace(); + // Required. Clear the error marker. + testFailed.set(false); + } + } catch (InterruptedException e) { + // Ignore + } finally { + threadCompleteLatch.countDown(); + } + } + } } diff --git a/test/org/apache/catalina/nonblocking/TestNonBlockingAPI.java b/test/org/apache/catalina/nonblocking/TestNonBlockingAPI.java index 99329bdfb78..0578b5f4bf4 100644 --- a/test/org/apache/catalina/nonblocking/TestNonBlockingAPI.java +++ b/test/org/apache/catalina/nonblocking/TestNonBlockingAPI.java @@ -32,7 +32,10 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Level; +import java.util.logging.LogManager; import javax.net.SocketFactory; @@ -46,6 +49,7 @@ import jakarta.servlet.ServletOutputStream; import jakarta.servlet.WriteListener; import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -54,7 +58,9 @@ import org.junit.Test; import org.apache.catalina.Context; +import org.apache.catalina.Wrapper; import org.apache.catalina.startup.BytesStreamer; +import org.apache.catalina.startup.SimpleHttpClient; import org.apache.catalina.startup.TesterServlet; import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.TomcatBaseTest; @@ -1114,4 +1120,190 @@ public void onError(Throwable t) { } } + + + /* + * Tests an error on an non-blocking read when the client closes the + * connection before fully writing the request body. + * + * Required sequence is: + * - enter Servlet's service() method + * - startAsync() + * - configure non-blocking read + * - read partial body + * - close client connection + * - error is triggered + * - exit Servlet's service() method + * + * This test makes extensive use of instance fields in the Servlet that + * would normally be considered very poor practice. It is only safe in this + * test as the Servlet only processes a single request. + */ + @Test + public void testCanceledPost() throws Exception { + + LogManager.getLogManager().getLogger("org.apache.coyote").setLevel(Level.ALL); + LogManager.getLogManager().getLogger("org.apache.tomcat.util.net").setLevel(Level.ALL); + + CountDownLatch partialReadLatch = new CountDownLatch(1); + CountDownLatch completeLatch = new CountDownLatch(1); + + AtomicBoolean testFailed = new AtomicBoolean(true); + + // Setup Tomcat instance + Tomcat tomcat = getTomcatInstance(); + + // No file system docBase required + Context ctx = tomcat.addContext("", null); + + PostServlet postServlet = new PostServlet(partialReadLatch, completeLatch, testFailed); + Wrapper wrapper = Tomcat.addServlet(ctx, "postServlet", postServlet); + wrapper.setAsyncSupported(true); + ctx.addServletMappingDecoded("/*", "postServlet"); + + tomcat.start(); + + PostClient client = new PostClient(); + client.setPort(getPort()); + client.setRequest(new String[] { "POST / HTTP/1.1" + SimpleHttpClient.CRLF + + "Host: localhost:" + SimpleHttpClient.CRLF + + "Content-Length: 100" + SimpleHttpClient.CRLF + + SimpleHttpClient.CRLF + + "This is 16 bytes" + }); + client.connect(); + client.sendRequest(); + + // Wait server to read partial request body + partialReadLatch.await(); + + client.disconnect(); + + completeLatch.await(); + + Assert.assertFalse(testFailed.get()); + } + + + private static final class PostClient extends SimpleHttpClient { + + @Override + public boolean isResponseBodyOK() { + return true; + } + } + + + private static final class PostServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + + private final transient CountDownLatch partialReadLatch; + private final transient CountDownLatch completeLatch; + private final AtomicBoolean testFailed; + + public PostServlet(CountDownLatch doPostLatch, CountDownLatch completeLatch, AtomicBoolean testFailed) { + this.partialReadLatch = doPostLatch; + this.completeLatch = completeLatch; + this.testFailed = testFailed; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + + AsyncContext ac = req.startAsync(); + ac.setTimeout(-1); + CanceledPostAsyncListener asyncListener = new CanceledPostAsyncListener(completeLatch); + ac.addListener(asyncListener); + + CanceledPostReadListener readListener = new CanceledPostReadListener(ac, partialReadLatch, testFailed); + req.getInputStream().setReadListener(readListener); + } + } + + + private static final class CanceledPostAsyncListener implements AsyncListener { + + private final transient CountDownLatch completeLatch; + + public CanceledPostAsyncListener(CountDownLatch completeLatch) { + this.completeLatch = completeLatch; + } + + @Override + public void onComplete(AsyncEvent event) throws IOException { + System.out.println("complete"); + completeLatch.countDown(); + } + + @Override + public void onTimeout(AsyncEvent event) throws IOException { + System.out.println("onTimeout"); + } + + @Override + public void onError(AsyncEvent event) throws IOException { + System.out.println("onError-async"); + } + + @Override + public void onStartAsync(AsyncEvent event) throws IOException { + System.out.println("onStartAsync"); + } + } + + private static final class CanceledPostReadListener implements ReadListener { + + private final AsyncContext ac; + private final CountDownLatch partialReadLatch; + private final AtomicBoolean testFailed; + private int totalRead = 0; + + public CanceledPostReadListener(AsyncContext ac, CountDownLatch partialReadLatch, AtomicBoolean testFailed) { + this.ac = ac; + this.partialReadLatch = partialReadLatch; + this.testFailed = testFailed; + } + + @Override + public void onDataAvailable() throws IOException { + ServletInputStream sis = ac.getRequest().getInputStream(); + boolean isReady; + + byte[] buffer = new byte[32]; + do { + if (partialReadLatch.getCount() == 0) { + System.out.println("debug"); + } + int bytesRead = sis.read(buffer); + + if (bytesRead == -1) { + return; + } + totalRead += bytesRead; + isReady = sis.isReady(); + System.out.println("Read [" + bytesRead + + "], buffer [" + new String(buffer, 0, bytesRead, StandardCharsets.UTF_8) + + "], total read [" + totalRead + + "], isReady [" + isReady + "]"); + } while (isReady); + if (totalRead == 16) { + partialReadLatch.countDown(); + } + } + + @Override + public void onAllDataRead() throws IOException { + ac.complete(); + } + + @Override + public void onError(Throwable throwable) { + throwable.printStackTrace(); + // This is the expected behaviour so clear the failed flag. + testFailed.set(false); + ac.complete(); + } + } } diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml index 4f3fa1c0aeb..6a2b77e9d2d 100644 --- a/webapps/docs/changelog.xml +++ b/webapps/docs/changelog.xml @@ -144,6 +144,12 @@ Avoid NullPointerException when a secure channel is closed before the SSL engine was initialized. (remm) + + Ensure that the ReadListener's onError() event + is triggered if the client closes the connection before sending the + entire request body and the server is ready the request body using + non-blocking I/O. (markt) +