diff --git a/java/org/apache/tomcat/websocket/Constants.java b/java/org/apache/tomcat/websocket/Constants.java index 7ec14131bda5..f619c596428a 100644 --- a/java/org/apache/tomcat/websocket/Constants.java +++ b/java/org/apache/tomcat/websocket/Constants.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; import jakarta.websocket.Extension; @@ -94,6 +95,11 @@ public class Constants { // Milliseconds so this is 20 seconds public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000; + // Configuration for session close timeout + public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT"; + // Default is 30 seconds - setting is in milliseconds + public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30); + // Configuration for read idle timeout on WebSocket session public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS"; diff --git a/java/org/apache/tomcat/websocket/WsSession.java b/java/org/apache/tomcat/websocket/WsSession.java index f85c9c21ac35..be16756bf4cf 100644 --- a/java/org/apache/tomcat/websocket/WsSession.java +++ b/java/org/apache/tomcat/websocket/WsSession.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -115,6 +116,7 @@ public class WsSession implements Session { private volatile long lastActiveRead = System.currentTimeMillis(); private volatile long lastActiveWrite = System.currentTimeMillis(); private Map futures = new ConcurrentHashMap<>(); + private volatile Long sessionCloseTimeoutExpiry; /** @@ -593,7 +595,14 @@ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal */ state.set(State.CLOSED); // ... and close the network connection. - wsRemoteEndpoint.close(); + closeConnection(); + } else { + /* + * Set close timeout. If the client fails to send a close message response within the timeout, the session + * and the connection will be closed when the timeout expires. + */ + sessionCloseTimeoutExpiry = + Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout())); } // Fail any uncompleted messages. @@ -632,7 +641,7 @@ public void onClose(CloseReason closeReason) { state.set(State.CLOSED); // Close the network connection. - wsRemoteEndpoint.close(); + closeConnection(); } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) { /* * The local endpoint sent a close message the the same time as the remote endpoint. The local close is @@ -644,12 +653,55 @@ public void onClose(CloseReason closeReason) { * The local endpoint sent the first close message. The remote endpoint has now responded with its own close * message so mark the session as fully closed and close the network connection. */ - wsRemoteEndpoint.close(); + closeConnection(); } // CLOSING and CLOSED are NO-OPs } + private void closeConnection() { + /* + * Close the network connection. + */ + wsRemoteEndpoint.close(); + /* + * Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for + * tracking the session close timeout. + */ + webSocketContainer.unregisterSession(getSessionMapKey(), this); + } + + + /* + * Returns the session close timeout in milliseconds + */ + protected long getSessionCloseTimeout() { + long result = 0; + Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY); + if (obj instanceof Long) { + result = ((Long) obj).intValue(); + } + if (result <= 0) { + result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT; + } + return result; + } + + + protected void checkCloseTimeout() { + // Skip the check if no session close timeout has been set. + if (sessionCloseTimeoutExpiry != null) { + // Check if the timeout has expired. + if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) { + // Check if the session has been closed in another thread while the timeout was being processed. + if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) { + closeConnection(); + } + } + } + } + + private void fireEndpointOnClose(CloseReason closeReason) { // Fire the onClose event @@ -722,7 +774,7 @@ private void sendCloseMessage(CloseReason closeReason) { if (log.isDebugEnabled()) { log.debug(sm.getString("wsSession.sendCloseFail", id), e); } - wsRemoteEndpoint.close(); + closeConnection(); // Failure to send a close message is not unexpected in the case of // an abnormal closure (usually triggered by a failure to read/write // from/to the client. In this case do not trigger the endpoint's @@ -730,8 +782,6 @@ private void sendCloseMessage(CloseReason closeReason) { if (closeCode != CloseCodes.CLOSED_ABNORMALLY) { localEndpoint.onError(this, e); } - } finally { - webSocketContainer.unregisterSession(getSessionMapKey(), this); } } @@ -864,6 +914,11 @@ public String getQueryString() { @Override public Principal getUserPrincipal() { checkState(); + return getUserPrincipalInternal(); + } + + + public Principal getUserPrincipalInternal() { return userPrincipal; } diff --git a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java index 034e30d2a056..e6376ce4b210 100644 --- a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java +++ b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java @@ -604,7 +604,12 @@ Set getOpenSessions(Object key) { synchronized (endPointSessionMapLock) { Set sessions = endpointSessionMap.get(key); if (sessions != null) { - result.addAll(sessions); + // Some sessions may be in the process of closing + for (WsSession session : sessions) { + if (session.isOpen()) { + result.add(session); + } + } } } return result; @@ -1019,8 +1024,10 @@ public void backgroundProcess() { if (backgroundProcessCount >= processPeriod) { backgroundProcessCount = 0; + // Check all registered sessions. for (WsSession wsSession : sessions.keySet()) { wsSession.checkExpiration(); + wsSession.checkCloseTimeout(); } } diff --git a/java/org/apache/tomcat/websocket/server/WsServerContainer.java b/java/org/apache/tomcat/websocket/server/WsServerContainer.java index 8fb4eb967ca0..b3b37ca45640 100644 --- a/java/org/apache/tomcat/websocket/server/WsServerContainer.java +++ b/java/org/apache/tomcat/websocket/server/WsServerContainer.java @@ -349,7 +349,7 @@ protected void registerSession(Object key, WsSession wsSession) { */ @Override protected void unregisterSession(Object key, WsSession wsSession) { - if (wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) { + if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) { unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId()); } super.unregisterSession(key, wsSession); diff --git a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java index cb54821662a8..f624f5c87c95 100644 --- a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java +++ b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java @@ -23,6 +23,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import jakarta.servlet.ServletContextEvent; +import jakarta.servlet.ServletContextListener; import jakarta.websocket.ClientEndpointConfig; import jakarta.websocket.CloseReason; import jakarta.websocket.ContainerProvider; @@ -39,7 +41,9 @@ import org.apache.catalina.servlets.DefaultServlet; import org.apache.catalina.startup.Tomcat; import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint; +import org.apache.tomcat.websocket.server.Constants; import org.apache.tomcat.websocket.server.TesterEndpointConfig; +import org.apache.tomcat.websocket.server.WsServerContainer; public class TestWsSessionSuspendResume extends WebSocketBaseTest { @@ -141,4 +145,99 @@ void addMessage(String message) { } } } + + + @Test + public void testSuspendThenClose() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + ctx.addApplicationListener(SuspendCloseConfig.class.getName()); + ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName()); + + Tomcat.addServlet(ctx, "default", new DefaultServlet()); + ctx.addServletMappingDecoded("/", "default"); + + tomcat.start(); + + WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer(); + + ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build(); + Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig, + new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH)); + + wsSession.getBasicRemote().sendText("start test"); + + // Wait for the client response to be received by the server + int count = 0; + while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) { + Thread.sleep(100); + count ++; + } + Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed()); + } + + + public static final class SuspendCloseConfig extends TesterEndpointConfig { + private static final String PATH = "/echo"; + + @Override + protected Class getEndpointClass() { + return SuspendCloseEndpoint.class; + } + + @Override + protected ServerEndpointConfig getServerEndpointConfig() { + return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build(); + } + } + + + public static final class SuspendCloseEndpoint extends Endpoint { + + // Yes, a static variable is a hack. + private static WsSession serverSession; + + @Override + public void onOpen(Session session, EndpointConfig epc) { + serverSession = (WsSession) session; + // Set a short session close timeout (milliseconds) + serverSession.getUserProperties().put( + org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000)); + // Any message will trigger the suspend then close + serverSession.addMessageHandler(String.class, message -> { + try { + serverSession.getBasicRemote().sendText("server session open"); + serverSession.getBasicRemote().sendText("suspending server session"); + serverSession.suspend(); + serverSession.getBasicRemote().sendText("closing server session"); + serverSession.close(); + } catch (IOException ioe) { + ioe.printStackTrace(); + // Attempt to make the failure more obvious + throw new RuntimeException(ioe); + } + }); + } + + @Override + public void onError(Session session, Throwable t) { + t.printStackTrace(); + } + + public static boolean isServerSessionFullyClosed() { + return serverSession.isClosed(); + } + } + + + public static class WebSocketFastServerTimeout implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + container.setProcessPeriod(0); + } + } } \ No newline at end of file diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml index 377e08e7a3f3..227c4cb3dc54 100644 --- a/webapps/docs/changelog.xml +++ b/webapps/docs/changelog.xml @@ -220,6 +220,11 @@ Review usage of debug logging and downgrade trace or data dumping operations from debug level to trace. (remm) + + Ensure that WebSocket connection closure completes if the connection is + closed when the server side has used the proprietary suspend/resume + feature to suspend the connection. (markt) + diff --git a/webapps/docs/web-socket-howto.xml b/webapps/docs/web-socket-howto.xml index 49d155bd25db..60231694cbf0 100644 --- a/webapps/docs/web-socket-howto.xml +++ b/webapps/docs/web-socket-howto.xml @@ -64,6 +64,13 @@ the timeout to use in milliseconds. For an infinite timeout, use -1.

+

The session close timeout defaults to 30000 milliseconds (30 seconds). This + may be changed by setting the property + org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT in the user + properties collection attached to the WebSocket session. The value assigned + to this property should be a Long and represents the timeout to + use in milliseconds. Values less than or equal to zero will be ignored.

+

In addition to the Session.setMaxIdleTimeout(long) method which is part of the Jakarta WebSocket API, Tomcat provides greater control of the timing out the session due to lack of activity. Setting the property