Skip to content

Commit

Permalink
Refactor WebSocket close for suspend/resume
Browse files Browse the repository at this point in the history
Ensure that WebSocket connection closure completes if the connection is
closed when the server side has used the proprietary suspend/resume
feature to suspend the connection.
  • Loading branch information
markt-asf committed Feb 13, 2024
1 parent 2070b43 commit b0e3b1b
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 8 deletions.
6 changes: 6 additions & 0 deletions java/org/apache/tomcat/websocket/Constants.java
Expand Up @@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

import jakarta.websocket.Extension;

Expand Down Expand Up @@ -94,6 +95,11 @@ public class Constants {
// Milliseconds so this is 20 seconds
public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;

// Configuration for session close timeout
public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT";
// Default is 30 seconds - setting is in milliseconds
public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30);

// Configuration for read idle timeout on WebSocket session
public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS";

Expand Down
67 changes: 61 additions & 6 deletions java/org/apache/tomcat/websocket/WsSession.java
Expand Up @@ -27,6 +27,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -115,6 +116,7 @@ public class WsSession implements Session {
private volatile long lastActiveRead = System.currentTimeMillis();
private volatile long lastActiveWrite = System.currentTimeMillis();
private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>();
private volatile Long sessionCloseTimeoutExpiry;


/**
Expand Down Expand Up @@ -593,7 +595,14 @@ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal
*/
state.set(State.CLOSED);
// ... and close the network connection.
wsRemoteEndpoint.close();
closeConnection();
} else {
/*
* Set close timeout. If the client fails to send a close message response within the timeout, the session
* and the connection will be closed when the timeout expires.
*/
sessionCloseTimeoutExpiry =
Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout()));
}

// Fail any uncompleted messages.
Expand Down Expand Up @@ -632,7 +641,7 @@ public void onClose(CloseReason closeReason) {
state.set(State.CLOSED);

// Close the network connection.
wsRemoteEndpoint.close();
closeConnection();
} else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) {
/*
* The local endpoint sent a close message the the same time as the remote endpoint. The local close is
Expand All @@ -644,12 +653,55 @@ public void onClose(CloseReason closeReason) {
* The local endpoint sent the first close message. The remote endpoint has now responded with its own close
* message so mark the session as fully closed and close the network connection.
*/
wsRemoteEndpoint.close();
closeConnection();
}
// CLOSING and CLOSED are NO-OPs
}


private void closeConnection() {
/*
* Close the network connection.
*/
wsRemoteEndpoint.close();
/*
* Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for
* tracking the session close timeout.
*/
webSocketContainer.unregisterSession(getSessionMapKey(), this);
}


/*
* Returns the session close timeout in milliseconds
*/
protected long getSessionCloseTimeout() {
long result = 0;
Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY);
if (obj instanceof Long) {
result = ((Long) obj).intValue();
}
if (result <= 0) {
result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT;
}
return result;
}


protected void checkCloseTimeout() {
// Skip the check if no session close timeout has been set.
if (sessionCloseTimeoutExpiry != null) {
// Check if the timeout has expired.
if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) {
// Check if the session has been closed in another thread while the timeout was being processed.
if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
closeConnection();
}
}
}
}


private void fireEndpointOnClose(CloseReason closeReason) {

// Fire the onClose event
Expand Down Expand Up @@ -722,16 +774,14 @@ private void sendCloseMessage(CloseReason closeReason) {
if (log.isDebugEnabled()) {
log.debug(sm.getString("wsSession.sendCloseFail", id), e);
}
wsRemoteEndpoint.close();
closeConnection();
// Failure to send a close message is not unexpected in the case of
// an abnormal closure (usually triggered by a failure to read/write
// from/to the client. In this case do not trigger the endpoint's
// error handling
if (closeCode != CloseCodes.CLOSED_ABNORMALLY) {
localEndpoint.onError(this, e);
}
} finally {
webSocketContainer.unregisterSession(getSessionMapKey(), this);
}
}

Expand Down Expand Up @@ -864,6 +914,11 @@ public String getQueryString() {
@Override
public Principal getUserPrincipal() {
checkState();
return getUserPrincipalInternal();
}


public Principal getUserPrincipalInternal() {
return userPrincipal;
}

Expand Down
9 changes: 8 additions & 1 deletion java/org/apache/tomcat/websocket/WsWebSocketContainer.java
Expand Up @@ -604,7 +604,12 @@ Set<Session> getOpenSessions(Object key) {
synchronized (endPointSessionMapLock) {
Set<WsSession> sessions = endpointSessionMap.get(key);
if (sessions != null) {
result.addAll(sessions);
// Some sessions may be in the process of closing
for (WsSession session : sessions) {
if (session.isOpen()) {
result.add(session);
}
}
}
}
return result;
Expand Down Expand Up @@ -1019,8 +1024,10 @@ public void backgroundProcess() {
if (backgroundProcessCount >= processPeriod) {
backgroundProcessCount = 0;

// Check all registered sessions.
for (WsSession wsSession : sessions.keySet()) {
wsSession.checkExpiration();
wsSession.checkCloseTimeout();
}
}

Expand Down
Expand Up @@ -349,7 +349,7 @@ protected void registerSession(Object key, WsSession wsSession) {
*/
@Override
protected void unregisterSession(Object key, WsSession wsSession) {
if (wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) {
if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) {
unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId());
}
super.unregisterSession(key, wsSession);
Expand Down
99 changes: 99 additions & 0 deletions test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
Expand Up @@ -23,6 +23,8 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.servlet.ServletContextEvent;
import jakarta.servlet.ServletContextListener;
import jakarta.websocket.ClientEndpointConfig;
import jakarta.websocket.CloseReason;
import jakarta.websocket.ContainerProvider;
Expand All @@ -39,7 +41,9 @@
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
import org.apache.tomcat.websocket.server.Constants;
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
import org.apache.tomcat.websocket.server.WsServerContainer;

public class TestWsSessionSuspendResume extends WebSocketBaseTest {

Expand Down Expand Up @@ -141,4 +145,99 @@ void addMessage(String message) {
}
}
}


@Test
public void testSuspendThenClose() throws Exception {
Tomcat tomcat = getTomcatInstance();

Context ctx = getProgrammaticRootContext();
ctx.addApplicationListener(SuspendCloseConfig.class.getName());
ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName());

Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");

tomcat.start();

WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();

ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH));

wsSession.getBasicRemote().sendText("start test");

// Wait for the client response to be received by the server
int count = 0;
while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) {
Thread.sleep(100);
count ++;
}
Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed());
}


public static final class SuspendCloseConfig extends TesterEndpointConfig {
private static final String PATH = "/echo";

@Override
protected Class<?> getEndpointClass() {
return SuspendCloseEndpoint.class;
}

@Override
protected ServerEndpointConfig getServerEndpointConfig() {
return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
}
}


public static final class SuspendCloseEndpoint extends Endpoint {

// Yes, a static variable is a hack.
private static WsSession serverSession;

@Override
public void onOpen(Session session, EndpointConfig epc) {
serverSession = (WsSession) session;
// Set a short session close timeout (milliseconds)
serverSession.getUserProperties().put(
org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000));
// Any message will trigger the suspend then close
serverSession.addMessageHandler(String.class, message -> {
try {
serverSession.getBasicRemote().sendText("server session open");
serverSession.getBasicRemote().sendText("suspending server session");
serverSession.suspend();
serverSession.getBasicRemote().sendText("closing server session");
serverSession.close();
} catch (IOException ioe) {
ioe.printStackTrace();
// Attempt to make the failure more obvious
throw new RuntimeException(ioe);
}
});
}

@Override
public void onError(Session session, Throwable t) {
t.printStackTrace();
}

public static boolean isServerSessionFullyClosed() {
return serverSession.isClosed();
}
}


public static class WebSocketFastServerTimeout implements ServletContextListener {

@Override
public void contextInitialized(ServletContextEvent sce) {
WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
container.setProcessPeriod(0);
}
}
}
5 changes: 5 additions & 0 deletions webapps/docs/changelog.xml
Expand Up @@ -220,6 +220,11 @@
Review usage of debug logging and downgrade trace or data dumping
operations from debug level to trace. (remm)
</fix>
<fix>
Ensure that WebSocket connection closure completes if the connection is
closed when the server side has used the proprietary suspend/resume
feature to suspend the connection. (markt)
</fix>
</changelog>
</subsection>
<subsection name="Web applications">
Expand Down
7 changes: 7 additions & 0 deletions webapps/docs/web-socket-howto.xml
Expand Up @@ -64,6 +64,13 @@
the timeout to use in milliseconds. For an infinite timeout, use
<code>-1</code>.</p>

<p>The session close timeout defaults to 30000 milliseconds (30 seconds). This
may be changed by setting the property
<code>org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT</code> in the user
properties collection attached to the WebSocket session. The value assigned
to this property should be a <code>Long</code> and represents the timeout to
use in milliseconds. Values less than or equal to zero will be ignored.</p>

<p>In addition to the <code>Session.setMaxIdleTimeout(long)</code> method which
is part of the Jakarta WebSocket API, Tomcat provides greater control of the
timing out the session due to lack of activity. Setting the property
Expand Down

0 comments on commit b0e3b1b

Please sign in to comment.